In [10]:
import argparse
import gzip
import os
from collections import defaultdict


def get_gene_regions(annotation_file):
    """Parse gene, exon and intron coordinates from a GTF or GFF3 annotation.

    Parameters
    ----------
    annotation_file : str
        Path to an annotation ending in ``.gtf``, ``.gff3``, ``.gtf.gz`` or
        ``.gff3.gz``. Gzipped files are decompressed on the fly.

    Returns
    -------
    gene_regions : dict
        ``gene_id -> {"chr": ..., "start": ..., "end": ...}`` for every
        ``gene`` feature.
    exon_regions : defaultdict
        ``gene_id -> transcript_id -> [(chr, start, end), ...]`` for every
        ``exon`` feature, in file order.
    intron_regions : defaultdict
        ``gene_id -> transcript_id -> [(chr, start, end), ...]`` for the gaps
        between the exons of each multi-exon transcript, ordered by start.
        Single-exon transcripts get no entry.

    All coordinates are 1-based and inclusive at both ends, as in the input.

    Raises
    ------
    AssertionError
        If the file name does not end in a supported extension.
    KeyError
        If a ``gene`` record lacks ``gene_id`` or an ``exon`` record lacks
        ``gene_id`` / ``transcript_id``.
    """
    assert annotation_file.endswith(
        (".gff3", ".gtf", ".gff3.gz", ".gtf.gz")), "Error: Unsupported annotation file format"

    gene_regions = {}
    exon_regions = defaultdict(lambda: defaultdict(list))
    intron_regions = defaultdict(lambda: defaultdict(list))

    def parse_attributes_gff3(attributes):
        # GFF3 column 9: "key=value" pairs separated by ";". Empty fields (e.g.
        # from a trailing ";") are skipped; split only on the first "=".
        attr_dict = {}
        for field in attributes.split(";"):
            field = field.strip()
            if not field:
                continue
            key, _, value = field.partition("=")
            attr_dict[key] = value
        return attr_dict

    def parse_attributes_gtf(attributes):
        # GTF column 9: 'key "value";' pairs. Split only on the first space so
        # values that themselves contain spaces are kept intact.
        attr_dict = {}
        for field in attributes.split(";"):
            field = field.strip()
            if not field:
                continue
            key, _, value = field.partition(" ")
            attr_dict[key] = value.strip().replace('"', '')
        return attr_dict

    # Decide the format from the file suffix only, not from any directory name.
    is_gff3 = annotation_file.endswith((".gff3", ".gff3.gz"))
    parse_attributes = parse_attributes_gff3 if is_gff3 else parse_attributes_gtf
    open_func = gzip.open if annotation_file.endswith(".gz") else open

    with open_func(annotation_file, "rt") as f:
        for line in f:
            if is_gff3 and line.startswith("##FASTA"):
                # GFF3 may end with an embedded FASTA section; no features follow.
                break
            if line.startswith("#"):
                continue
            parts = line.rstrip("\r\n").split("\t")
            if len(parts) < 9:
                # Blank or truncated line: not a feature record.
                continue
            chrom, feature_type = parts[0], parts[2]
            if feature_type not in ("gene", "exon"):
                # Only genes and exons are collected; skip parsing the rest.
                continue
            start, end = int(parts[3]), int(parts[4])
            attr_dict = parse_attributes(parts[8])
            if feature_type == "gene":
                # 1-based, start-inclusive, end-inclusive
                gene_regions[attr_dict["gene_id"]] = {"chr": chrom, "start": start, "end": end}
            else:
                transcript_id = attr_dict["transcript_id"]
                gene_id = attr_dict["gene_id"]
                # 1-based, start-inclusive, end-inclusive
                exon_regions[gene_id][transcript_id].append((chrom, start, end))

    # Introns are the gaps between exons once sorted by start. Track the
    # furthest exon end seen so far so nested/overlapping exons never produce a
    # "gap" that actually overlaps an exon.
    for gene_id, transcripts in exon_regions.items():
        for transcript_id, exons in transcripts.items():
            exons_sorted = sorted(exons, key=lambda exon: exon[1])
            covered_end = exons_sorted[0][2]
            for chrom, exon_start, exon_end in exons_sorted[1:]:
                intron_start = covered_end + 1
                intron_end = exon_start - 1
                # Inclusive coordinates: start == end is a valid 1-bp intron.
                if intron_start <= intron_end:
                    intron_regions[gene_id][transcript_id].append(
                        (chrom, intron_start, intron_end))  # 1-based, start-inclusive, end-inclusive
                covered_end = max(covered_end, exon_end)

    return gene_regions, exon_regions, intron_regions

In [14]:
# GENCODE v46 annotation to parse. Set the GENCODE_GTF environment variable to
# point at a local copy instead of editing this absolute path.
annotation_file = os.environ.get(
    "GENCODE_GTF",
    "/Users/nh661/postdoc/projs/rna_snps/gencode/gencode.v46.annotation.gtf.gz",
)
gene_regions, exon_regions, intron_regions = get_gene_regions(annotation_file)

In [15]:
example_gene_id = "ENSG00000187634.13"
# exon_regions is a defaultdict: indexing a missing/typo'd gene would silently
# insert an empty entry and display {} instead of failing, so check first.
assert example_gene_id in exon_regions, f"{example_gene_id} not found in annotation"
exon_regions[example_gene_id]

defaultdict(list,
            {'ENST00000616016.5': [('chr1', 923923, 924948),
              ('chr1', 925922, 926013),
              ('chr1', 930155, 930336),
              ('chr1', 931039, 931089),
              ('chr1', 935772, 935896),
              ('chr1', 939040, 939129),
              ('chr1', 939275, 939412),
              ('chr1', 941144, 941306),
              ('chr1', 942136, 942251),
              ('chr1', 942410, 942488),
              ('chr1', 942559, 943058),
              ('chr1', 943253, 943377),
              ('chr1', 943698, 943808),
              ('chr1', 943908, 944574)],
             'ENST00000618323.5': [('chr1', 923923, 924948),
              ('chr1', 925922, 926013),
              ('chr1', 930155, 930336),
              ('chr1', 931039, 931089),
              ('chr1', 935772, 935896),
              ('chr1', 939040, 939129),
              ('chr1', 939272, 939412),
              ('chr1', 941144, 941306),
              ('chr1', 942136, 942251),
              ('chr

In [16]:
example_gene_id = "ENSG00000187634.13"
# intron_regions is a defaultdict holding only genes with multi-exon
# transcripts; indexing a missing gene would silently insert an empty entry.
assert example_gene_id in intron_regions, f"{example_gene_id} has no introns in the annotation"
intron_regions[example_gene_id]

defaultdict(list,
            {'ENST00000616016.5': [('chr1', 924949, 925921),
              ('chr1', 926014, 930154),
              ('chr1', 930337, 931038),
              ('chr1', 931090, 935771),
              ('chr1', 935897, 939039),
              ('chr1', 939130, 939274),
              ('chr1', 939413, 941143),
              ('chr1', 941307, 942135),
              ('chr1', 942252, 942409),
              ('chr1', 942489, 942558),
              ('chr1', 943059, 943252),
              ('chr1', 943378, 943697),
              ('chr1', 943809, 943907)],
             'ENST00000618323.5': [('chr1', 924949, 925921),
              ('chr1', 926014, 930154),
              ('chr1', 930337, 931038),
              ('chr1', 931090, 935771),
              ('chr1', 935897, 939039),
              ('chr1', 939130, 939271),
              ('chr1', 939413, 941143),
              ('chr1', 941307, 942135),
              ('chr1', 942252, 942409),
              ('chr1', 942489, 942558),
              ('chr

In [32]:
def cluster_intervals(intervals, tolerance=0):
    """Group (start, end) intervals into clusters of overlapping or nearby intervals.

    Intervals are sorted by (start, end) and swept left to right. An interval
    joins the current cluster when its start is at most ``tolerance`` past the
    cluster's furthest end; otherwise it opens a new cluster.

    Parameters
    ----------
    intervals : sequence of (start, end) pairs
        Not modified; a sorted copy is used.
    tolerance : int or float, default 0
        Largest gap between a cluster's end and the next start that still
        merges them. With 0, an interval starting exactly at the cluster's
        end is merged.

    Returns
    -------
    list of list of (start, end)
        Clusters in ascending start order, each listing its intervals in
        sorted order. An empty input yields an empty list.
    """
    # sorted() rather than list.sort() so the caller's list is left untouched
    # (and tuples or other sequences are accepted).
    ordered = sorted(intervals, key=lambda iv: (iv[0], iv[1]))
    if not ordered:
        return []

    clusters = []
    current_cluster = [ordered[0]]
    min_start, max_end = ordered[0]

    for interval in ordered[1:]:
        start_current, end_current = interval
        # Because of the sort, start_current >= min_start; for well-formed
        # intervals (end >= start) and tolerance >= 0 the second test always holds.
        if start_current <= max_end + tolerance and end_current >= min_start - tolerance:
            current_cluster.append(interval)
            min_start = min(min_start, start_current)
            max_end = max(max_end, end_current)
        else:
            clusters.append(current_cluster)
            current_cluster = [interval]
            min_start, max_end = interval

    # Flush the cluster still open when the sweep ends.
    clusters.append(current_cluster)
    return clusters

# Example: the three intervals all overlap, so with a tolerance of 3 they form
# a single cluster.
intervals = [(1, 10), (2, 5), (3, 9)]
tolerance = 3
clusters = cluster_intervals(intervals, tolerance)
assert clusters == [[(1, 10), (2, 5), (3, 9)]], clusters
clusters

[[(1, 10), (2, 5), (3, 9)]]
