# Transcript Data  
> This module retrieves genomic coordinates and other key information related to transcripts from a GTF file. Using **PyRanges**, it efficiently stores and visualizes transcript features like exons, CDS, and UTRs.  

The module also supports tasks like intron coordinate calculation, transcript length measurement, and batch queries. It includes optional sequence retrieval from a reference genome and tools for analyzing alternative splicing events. Caching is used to speed up repeated queries, making it ideal for working with large datasets.  


In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| default_exp transcript_data

In [1]:
#| export
import pandas as pd
from typing import List, Tuple, Dict, Optional
from functools import lru_cache
import pyranges as pr
from pyfaidx import Fasta
import logging
import re
from typing import Any, Dict, Union


In [2]:
# | export
class TranscriptData:
    """
    A class for managing transcript and gene information from a GTF file using PyRanges.

    Existing Features:
      - Lookup by transcript ID or gene ID/name
      - Support for exons, CDS, UTR queries
      - Intron coordinate calculation
      - Batch queries
      - Transcript length calculation
      - Caching/memoization for repeated queries
      - Basic logging/error handling

    NEW Features:
      (1) Nucleotide/protein sequence retrieval for CDS (with optional FASTA)
      (2) Alternative splicing analysis with splice junctions, isoform comparisons,
          and junction-chain interpretation.
    """

    def __init__(self, gtf_file: str, reference_fasta: Optional[str] = None):
        """
        Parse an annotation file into a PyRanges object and keep it on the instance.

        The optional FASTA path is only remembered here; it is opened later by
        the sequence-extraction methods.

        Args:
            gtf_file (str): Location of the GTF/GFF annotation to parse.
            reference_fasta (str, optional): Location of a reference genome FASTA.

        Raises:
            Exception: Any error raised by ``pr.read_gtf`` is logged and re-raised.
        """
        self.gtf_file, self.reference_fasta = gtf_file, reference_fasta
        logging.info(f"Loading GTF from {gtf_file}. This may take a while...")
        try:
            annotation = pr.read_gtf(gtf_file)
        except Exception as e:
            logging.error(f"Error reading GTF file: {e}")
            raise
        else:
            self.gr = annotation
            logging.info("GTF loaded successfully.")

    @lru_cache(maxsize=None)
    def get_exons(self, transcript_id: str) -> pr.PyRanges:
        """
        Returns a PyRanges of exons for the given transcript.
        Results are cached for faster repeat lookups.

        Args:
            transcript_id (str): The transcript ID to filter on.

        Returns:
            pr.PyRanges: PyRanges containing exon features for the transcript.
        """
        exons = self.gr[(self.gr.Feature == "exon") & (self.gr.transcript_id == transcript_id)]
        if len(exons) == 0:
            logging.warning(f"No exons found for transcript {transcript_id}.")
        return exons

    @lru_cache(maxsize=None)
    def get_cds(self, transcript_id: str) -> pr.PyRanges:
        """
        Returns a PyRanges of CDS features for the given transcript.
        Results are cached for faster repeat lookups.

        Args:
            transcript_id (str): The transcript ID to filter on.

        Returns:
            pr.PyRanges: PyRanges containing CDS features for the transcript.
        """
        cds = self.gr[(self.gr.Feature == "CDS") & (self.gr.transcript_id == transcript_id)]
        if len(cds) == 0:
            logging.warning(f"No CDS features found for transcript {transcript_id}.")
        return cds

    @lru_cache(maxsize=None)
    def get_utr(self, transcript_id: str, utr_type: str = None) -> pr.PyRanges:
        """
        Returns a PyRanges of UTR features for the given transcript.
        Optionally specify '5UTR' or '3UTR' to filter further.

        Args:
            transcript_id (str): The transcript ID to filter on.
            utr_type (str, optional): If '5UTR', return only 5' UTR;
                                      if '3UTR', return only 3' UTR;
                                      otherwise return all UTR features.

        Returns:
            pr.PyRanges: PyRanges containing UTR features for the transcript.
        """
        utr = self.gr[(self.gr.Feature.str.contains("UTR", na=False)) & (self.gr.transcript_id == transcript_id)]
        if utr_type == "5UTR":
            utr = utr[utr.Feature == "5UTR"]
        elif utr_type == "3UTR":
            utr = utr[utr.Feature == "3UTR"]

        if len(utr) == 0:
            logging.warning(f"No UTR features found for transcript {transcript_id} (type={utr_type}).")
        return utr

    def get_intron_ranges(self, transcript_id: str) -> pr.PyRanges:
        """
        Derive intron intervals as "transcript span minus exons".

        The span runs from the smallest exon Start to the largest exon End and
        takes chromosome and strand from the first exon row.

        Args:
            transcript_id (str): Transcript identifier to look up.

        Returns:
            pr.PyRanges: Intron coordinates; an empty PyRanges when the
            transcript has no exons.
        """
        exons = self.get_exons(transcript_id)
        if not len(exons):
            return pr.PyRanges()

        exon_table = exons.df
        span = pd.DataFrame(
            {
                "Chromosome": [exon_table["Chromosome"].iloc[0]],
                "Start": [exon_table["Start"].min()],
                "End": [exon_table["End"].max()],
                "Strand": [exon_table["Strand"].iloc[0]],
            }
        )
        # Whatever part of the span is not covered by an exon is intronic.
        return pr.PyRanges(span).subtract(exons)

    def get_exon_coords_and_strand(self, transcript_id: str) -> Tuple[List[List[int]], Optional[int]]:
        """
        Return exon coordinates and strand (+1 or -1) for a given transcript ID,
        mimicking the style of the Ensembl API example (list of [end, start] pairs).

        Args:
            transcript_id (str): The transcript ID to query.

        Returns:
            (exon_coord, strand):
                exon_coord is a list of [end, start] pairs
                strand is +1 or -1
        """
        exons = self.get_exons(transcript_id)
        if len(exons) == 0:
            return ([], None)

        df = exons.df.sort_values(by="Start")
        strand_symbol = df["Strand"].iloc[0]  # '+' or '-'
        strand = 1 if strand_symbol == '+' else -1

        exon_coord = [[row.End, row.Start] for _, row in df.iterrows()]
        return (exon_coord, strand)

    def get_transcript_length(self, transcript_id: str) -> int:
        """
        Return the total length of exons for the given transcript.

        Args:
            transcript_id (str): The transcript ID to query.

        Returns:
            int: Sum of all exon lengths for this transcript.
        """
        exons = self.get_exons(transcript_id)
        if len(exons) == 0:
            return 0
        df = exons.df
        lengths = df["End"] - df["Start"]
        return lengths.sum()

    def get_chromosome(self, transcript_id: str) -> Optional[str]:
        """
        Return the chromosome/contig name for the given transcript.
        Assumes that all exons in this transcript are on the same chromosome.

        Args:
            transcript_id (str): The transcript ID to query.

        Returns:
            str or None: Chromosome name (e.g., 'chr1', '1', etc.) or None if not found.
        """
        exons = self.get_exons(transcript_id)
        if len(exons) == 0:
            return None
        df = exons.df
        return df["Chromosome"].iloc[0]

    def get_strand(self, transcript_id: str) -> Optional[int]:
        """
        Return +1 or -1 for the transcript's strand.

        Args:
            transcript_id (str): The transcript ID to query.

        Returns:
            int or None: 1 or -1, or None if not found.
        """
        exons = self.get_exons(transcript_id)
        if len(exons) == 0:
            return None
        df = exons.df
        strand_symbol = df["Strand"].iloc[0]
        return 1 if strand_symbol == '+' else -1

    def get_transcripts_by_gene_id(self, gene_id: str) -> List[str]:
        """
        Return a list of transcript IDs associated with a given gene_id.

        Args:
            gene_id (str): The gene ID to search for.

        Returns:
            List[str]: Transcript IDs for that gene.
        """
        df = self.gr.df
        subset = df[df.gene_id == gene_id]
        t_ids = subset["transcript_id"].dropna().unique()
        if len(t_ids) == 0:
            logging.warning(f"No transcripts found for gene ID {gene_id}.")
        return list(t_ids)

    def get_transcripts_by_gene_name(self, gene_name: str) -> List[str]:
        """
        Return a list of transcript IDs associated with a given gene_name.

        Args:
            gene_name (str): The gene name to search for (e.g. BRCA1).

        Returns:
            List[str]: Transcript IDs for that gene.
        """
        df = self.gr.df
        if "gene_name" not in df.columns:
            logging.warning("No 'gene_name' column in GTF; cannot filter by gene name.")
            return []
        subset = df[df.gene_name == gene_name]
        t_ids = subset["transcript_id"].dropna().unique()
        if len(t_ids) == 0:
            logging.warning(f"No transcripts found for gene name {gene_name}.")
        return list(t_ids)

    def get_exons_batch(self, transcript_ids: List[str]) -> Dict[str, pr.PyRanges]:
        """
        Return a dict of transcript_id -> exons PyRanges for a list of transcript IDs.
        Useful for batch queries.

        Args:
            transcript_ids (list of str): List of transcript IDs to fetch.

        Returns:
            dict: {transcript_id: PyRanges}
        """
        result = {}
        for tid in transcript_ids:
            result[tid] = self.get_exons(tid)
        return result

    def get_exon_coords_and_strand_batch(self, transcript_ids: List[str]) -> Dict[str, Tuple[List[List[int]], Optional[int]]]:
        """
        Return a dict of transcript_id -> (exon_coord, strand), for batch querying.

        Args:
            transcript_ids (list of str): Transcript IDs to fetch.

        Returns:
            dict: {transcript_id: ([ [end, start], ... ], strand) }
        """
        result = {}
        for tid in transcript_ids:
            result[tid] = self.get_exon_coords_and_strand(tid)
        return result

    # ---------------------------------------------------------------------
    # NEW FEATURE (1): Sequence Extraction
    # ---------------------------------------------------------------------

    def get_cds_sequence(self,
                         transcript_id: str,
                         reference_fasta: Optional[str] = None
                         ) -> Optional[str]:
        """
        Return the nucleotide sequence of the CDS for a given transcript.
        If reference_fasta is not provided, the method will use self.reference_fasta,
        or prompt the user (via ``input``) if that is also None.

        CDS segments are read from the reference in ascending genomic order
        using ``[Start:End]`` slices and concatenated. For a minus-strand
        transcript the concatenated sequence is then reverse-complemented as a
        whole, so the result always reads 5'->3' along the transcript. The
        strand is taken from the first CDS row; all segments of one transcript
        are assumed to share it.

        Requires pyfaidx and a valid reference FASTA.

        Args:
            transcript_id (str): The transcript ID to query.
            reference_fasta (str, optional): Path to the reference genome FASTA file.

        Returns:
            str or None: Nucleotide sequence of the CDS, or None if no CDS found.
        """
        if Fasta is None:
            logging.error("pyfaidx is not installed. Cannot extract sequences.")
            return None

        # If no explicit FASTA was passed, fall back to the instance-level FASTA,
        # and only then ask interactively.
        if reference_fasta is None:
            reference_fasta = self.reference_fasta or input(
                "Please provide path to the reference genome FASTA file: "
            )

        cds = self.get_cds(transcript_id)
        if len(cds) == 0:
            logging.warning(f"No CDS found for transcript {transcript_id}.")
            return None

        # Genomic order matters: segments are stitched together left to right.
        df = cds.df.sort_values(by="Start")
        strand = df["Strand"].iloc[0]

        # str() lets callers pass a pathlib.Path as well as a plain string.
        fa = Fasta(str(reference_fasta))
        try:
            seq_pieces = [
                fa[str(chrom)][start:end].seq
                for chrom, start, end in zip(df["Chromosome"], df["Start"], df["End"])
            ]
        finally:
            # Release the FASTA/index file handles even if a lookup fails.
            fa.close()

        sequence = "".join(seq_pieces)
        if strand == "-":
            # Reverse-complementing each piece separately while keeping the
            # ascending order would scramble multi-segment CDSs; flipping the
            # joined sequence reverses both the bases and the segment order.
            sequence = self._revcomp(sequence)
        return sequence

    def get_protein_sequence(self,
                             transcript_id: str,
                             reference_fasta: Optional[str] = None
                             ) -> Optional[str]:
        """
        Return the translated protein sequence (in one-letter code) for a given transcript’s CDS.
        If reference_fasta is not provided, the method will use self.reference_fasta,
        or prompt the user if that is also None.

        Requires pyfaidx and a valid reference FASTA.

        Args:
            transcript_id (str): The transcript ID to query.
            reference_fasta (str, optional): Path to the reference genome FASTA file.

        Returns:
            str or None: Amino acid sequence, or None if no CDS is found.
        """
        cds_seq = self.get_cds_sequence(transcript_id, reference_fasta)
        if cds_seq is None:
            return None

        codon_table = self._get_standard_codon_table()
        protein = []
        for i in range(0, len(cds_seq), 3):
            codon = cds_seq[i : i + 3]
            if len(codon) < 3:
                break  # incomplete codon
            aa = codon_table.get(codon, "X")  # unknown => 'X'
            if aa == "*":  # stop codon
                break
            protein.append(aa)
        return "".join(protein)

    def _revcomp(self, seq: str) -> str:
        """
        Return the reverse-complement of a nucleotide sequence.
        """
        complement = {
            "A": "T", "C": "G", "G": "C", "T": "A",
            "a": "t", "c": "g", "g": "c", "t": "a",
            "N": "N", "n": "n"
        }
        rev = []
        for base in reversed(seq):
            rev.append(complement.get(base, "N"))
        return "".join(rev)

    def _get_standard_codon_table(self) -> Dict[str, str]:
        """
        Return a minimal codon table mapping triplets to single-letter amino acids.
        Stop codon => '*'
        """
        return {
            "ATA":"I","ATC":"I","ATT":"I","ATG":"M","ACA":"T","ACC":"T","ACG":"T","ACT":"T",
            "AAC":"N","AAT":"N","AAA":"K","AAG":"K","AGC":"S","AGT":"S","AGA":"R","AGG":"R",
            "CTA":"L","CTC":"L","CTG":"L","CTT":"L","CCA":"P","CCC":"P","CCG":"P","CCT":"P",
            "CAC":"H","CAT":"H","CAA":"Q","CAG":"Q","CGA":"R","CGC":"R","CGG":"R","CGT":"R",
            "GTA":"V","GTC":"V","GTG":"V","GTT":"V","GCA":"A","GCC":"A","GCG":"A","GCT":"A",
            "GAC":"D","GAT":"D","GAA":"E","GAG":"E","GGA":"G","GGC":"G","GGG":"G","GGT":"G",
            "TCA":"S","TCC":"S","TCG":"S","TCT":"S","TTC":"F","TTT":"F","TTA":"L","TTG":"L",
            "TAC":"Y","TAT":"Y","TAA":"*","TAG":"*","TGC":"C","TGT":"C","TGA":"*","TGG":"W"
        }

    # ---------------------------------------------------------------------
    # NEW FEATURE (2): Alternative Splicing Analysis
    # ---------------------------------------------------------------------

    def get_splice_junctions(self, transcript_id: str) -> List[Tuple[int, int]]:
        """
        Return the genomic start/end positions for each splice junction
        (the exon-exon boundaries) for a given transcript.

        Args:
            transcript_id (str): The transcript ID to query.

        Returns:
            List[Tuple[int, int]]: List of (donor_site, acceptor_site) for each splice junction.
        """
        exons = self.get_exons(transcript_id)
        if len(exons) < 2:
            logging.info(f"Transcript {transcript_id} has fewer than 2 exons; no internal junctions.")
            return []

        df_exons = exons.df.sort_values(by="Start")
        junctions = []
        for i in range(len(df_exons) - 1):
            exon_end = df_exons.iloc[i]["End"]
            next_exon_start = df_exons.iloc[i + 1]["Start"]
            junctions.append((exon_end, next_exon_start))

        return junctions

    def compare_transcripts_across_gene(self, gene_id: str) -> pd.DataFrame:
        """
        Compare exons of all transcripts for a given gene.
        Returns a DataFrame of all exons grouped by transcript ID,
        so you can quickly see which exons are shared or unique across isoforms.

        Args:
            gene_id (str): The gene ID to compare.

        Returns:
            pd.DataFrame: A dataframe with columns
                          [transcript_id, Chromosome, Start, End, Strand].
        """
        tid_list = self.get_transcripts_by_gene_id(gene_id)
        if not tid_list:
            logging.warning(f"No transcripts found for gene {gene_id}.")
            return pd.DataFrame()

        all_exons = []
        for tid in tid_list:
            exons = self.get_exons(tid)
            if len(exons) == 0:
                continue
            df_exons = exons.df.copy()
            df_exons["transcript_id"] = tid
            all_exons.append(
                df_exons[["transcript_id", "Chromosome", "Start", "End", "Strand"]]
            )

        if not all_exons:
            return pd.DataFrame()

        result = pd.concat(all_exons, ignore_index=True)
        result.sort_values(by=["transcript_id", "Start"], inplace=True)
        return result

    # ---------------------------------------------------------------------
    # Junction Chain Interpretation
    # ---------------------------------------------------------------------

    def get_junction_chain_signature(self, transcript_id: str) -> Optional[Tuple[Tuple[int, int], ...]]:
        """
        Return a tuple of (exon_end, next_exon_start) pairs for each splice junction
        in the given transcript. This provides a 'signature' to compare across
        transcripts to see if they have the same junction chain.

        Args:
            transcript_id (str): The transcript ID to query.

        Returns:
            tuple of (int, int) or None:
                A tuple of (end_of_exon_i, start_of_exon_(i+1)) for i in [0..n_exons-2].
                Returns None if fewer than 2 exons (no internal junctions).
        """
        exons = self.get_exons(transcript_id)
        if len(exons) < 2:
            logging.info(f"Transcript {transcript_id} has fewer than 2 exons; no junction chain.")
            return None

        df_exons = exons.df.sort_values(by="Start")
        junctions = []
        for i in range(len(df_exons) - 1):
            exon_end = df_exons.iloc[i]["End"]
            next_exon_start = df_exons.iloc[i + 1]["Start"]
            junctions.append((exon_end, next_exon_start))

        return tuple(junctions)

    def interpret_unique_junction_chains(self, gene_id: str) -> pd.DataFrame:
        """
        Group transcripts of a given gene by their unique junction chain signatures.
        This helps identify which isoforms share the exact same exon-exon boundaries
        and which are unique.

        Args:
            gene_id (str): The gene ID to analyze.

        Returns:
            pd.DataFrame:
                Columns:
                  - 'junction_chain_signature': The tuple of (exon_end, next_exon_start) pairs.
                  - 'transcript_count': How many transcripts share this chain.
                  - 'transcripts': A list of transcript IDs that have this chain.
        """
        tid_list = self.get_transcripts_by_gene_id(gene_id)
        if not tid_list:
            logging.warning(f"No transcripts found for gene {gene_id}.")
            return pd.DataFrame()

        chain_map = {}  # {chain_signature: [transcript_1, transcript_2, ...]}

        for tid in tid_list:
            signature = self.get_junction_chain_signature(tid)
            if signature is None:  # e.g., single-exon transcripts
                continue
            chain_map.setdefault(signature, []).append(tid)

        rows = []
        for chain_sig, transcripts in chain_map.items():
            rows.append({
                "junction_chain_signature": chain_sig,
                "transcript_count": len(transcripts),
                "transcripts": transcripts
            })

        df = pd.DataFrame(rows)
        df.sort_values("transcript_count", ascending=False, inplace=True)
        df.reset_index(drop=True, inplace=True)
        return df
    
    def get_gene_names_for_transcripts(self, transcript_ids: List[str], ignore_after_period: bool = True, alternative_column: Optional[str] = None) -> List[Optional[str]]:
        """
        Given a list of transcript IDs, return a list of the same length
        where each element is the corresponding gene name or alternative column value from the GTF.
        If a transcript is not found or if the target column is not available in the GTF,
        the result will contain None for that transcript.

        Args:
            transcript_ids (List[str]): A list of transcript IDs.
            ignore_after_period (bool): If True, strip the version suffix after the period.
            alternative_column (Optional[str]): If provided, use this column in place of 'gene_name'.

        Returns:
            List[Optional[str]]: A parallel list of gene names (or alternative column values) or None.
        """
        # Optionally strip the version suffix using regex
        if ignore_after_period:
            transcript_ids = [re.sub(r"\.\d+$", "", tid) for tid in transcript_ids]

        df = self.gr.df
        target_column = alternative_column if alternative_column is not None else "gene_name"
        
        if target_column not in df.columns:
            logging.warning(f"No '{target_column}' column in GTF; cannot retrieve gene names.")
            return [None] * len(transcript_ids)

        # Filter to only the rows with the requested transcript IDs
        subset = df[df.transcript_id.isin(transcript_ids)]

        # Build a dict: transcript_id -> list of unique values from the target column in the annotation
        mapping = (
            subset
            .groupby("transcript_id")[target_column]
            .apply(lambda x: list(x.unique()))
            .to_dict()
        )

        # For each transcript in the input, pick the first value from the mapping
        result = []
        for tid in transcript_ids:
            possible_names = mapping.get(tid, [])
            if possible_names:
                result.append(possible_names[0])
            else:
                result.append(None)

        return result
    
    def get_transcript_info(self, transcript_id: str) -> Dict[str, Any]:
        """
        Return a dictionary with basic info about the given transcript, including:
        - transcript_id
        - transcript_name (if available in the GTF, else "unknown")
        - transcript_type (if available in the GTF, else "unknown")
        - cds_start, cds_end (based on min/max of CDS ranges if present, else None)
        - chromosome
        - strand (either '+' or '-')
        """
        df = self.gr.df
        sub = df[df.transcript_id == transcript_id]

        # If we didn't find this transcript at all, return an empty dict or raise an error.
        if sub.empty:
            logging.warning(f"Transcript {transcript_id} not found in GTF.")
            return {}

        # Pull transcript_name, transcript_type from columns if they exist
        # (these column names vary in different GTF sources).
        if "transcript_name" in sub.columns:
            transcript_name = sub["transcript_name"].dropna().unique()
            if len(transcript_name) > 0:
                transcript_name = transcript_name[0]
            else:
                transcript_name = "unknown"
        else:
            transcript_name = "unknown"
        
        if "transcript_type" in sub.columns:
            transcript_type = sub["transcript_type"].dropna().unique()
            if len(transcript_type) > 0:
                transcript_type = transcript_type[0]
            else:
                transcript_type = "unknown"
        else:
            transcript_type = "unknown"

        # Derive chromosome and strand from any row of this transcript
        # (assuming a consistent chromosome/strand for all features).
        chromosome = str(sub["Chromosome"].iloc[0])
        strand_symbol = str(sub["Strand"].iloc[0])  # '+' or '-'

        # Compute the CDS boundaries using our existing get_cds() method.
        cds_ranges = self.get_cds(transcript_id)
        if len(cds_ranges) > 0:
            cds_df = cds_ranges.df
            cds_start = int(cds_df["Start"].min())
            cds_end = int(cds_df["End"].max())
        else:
            cds_start = None
            cds_end = None

        return {
            "transcript_id": transcript_id,
            "transcript_name": transcript_name,
            "transcript_type": transcript_type,
            "cds_start": cds_start,
            "cds_end": cds_end,
            "chromosome": chromosome,
            "strand": strand_symbol
        }
        
    def get_exon_psi_matrix(self,
                            gene_name: Optional[str] = None,
                            transcript_ids: Optional[List[str]] = None,
                            transcript_counts: Union[Dict[str, float], pd.DataFrame] = None
                            ) -> pd.DataFrame:
        """
        Compute an exon PSI (percent spliced in) matrix for a gene or a given list of transcript IDs.
        For each unique exon (defined by Chromosome, Start, End, and Strand) among the transcripts,
        the PSI is calculated as:
        
            PSI = (sum of counts for transcripts including the exon) / (total counts for all transcripts)
        
        Args:
            gene_name (str, optional): If provided, transcripts for this gene are retrieved.
            transcript_ids (List[str], optional): List of transcript IDs.
                Ignored if gene_name is provided.
            transcript_counts (dict or pd.DataFrame): Transcript-level counts.
                If a dict is provided, it is assumed to map transcript_id -> count (single-sample).
                If a DataFrame is provided, its index should be transcript_ids and its columns
                represent different samples.
        
        Returns:
            pd.DataFrame: A DataFrame where each row corresponds to a unique exon, with columns:
                - 'Chromosome', 'Start', 'End', 'Strand'
                - For a dict input: a column 'psi' (a value in [0,1])
                - For a DataFrame input: one column per sample (named 'psi_{sample}')
                - 'included_transcripts': the list of transcripts that include this exon.
        """
        if gene_name is not None:
            transcript_ids = self.get_transcripts_by_gene_name(gene_name)
            if not transcript_ids:
                logging.warning(f"No transcripts found for gene name {gene_name}.")
                return pd.DataFrame()
        elif transcript_ids is None:
            logging.error("Either gene_name or transcript_ids must be provided.")
            return pd.DataFrame()
        
        # Process transcript_counts
        if transcript_counts is None:
            logging.error("You must provide transcript_counts for PSI computation.")
            return pd.DataFrame()
            
        is_multi_sample = isinstance(transcript_counts, pd.DataFrame)
        if not is_multi_sample:
            if isinstance(transcript_counts, dict):
                counts_series = pd.Series(transcript_counts)
            else:
                logging.error("transcript_counts must be a dict or a DataFrame.")
                return pd.DataFrame()
            # Keep only transcripts in the provided list
            counts_series = counts_series[counts_series.index.isin(transcript_ids)]
            total_counts = counts_series.sum()
        else:
            counts_df = transcript_counts.loc[transcript_counts.index.intersection(transcript_ids)]
            if counts_df.empty:
                logging.error("No transcript counts found for the provided transcript IDs.")
                return pd.DataFrame()
            total_counts = counts_df.sum()  # Series: sample -> total count
        
        # Build a mapping from unique exon coordinates to the set of transcript IDs that include it.
        exon_mapping = {}
        for tid in transcript_ids:
            try:
                exons = self.get_exons(tid)
            except Exception as e:
                logging.warning(f"Error retrieving exons for transcript {tid}: {e}")
                continue
            if len(exons) == 0:
                continue
            # Iterate over each exon in this transcript
            for _, row in exons.df.iterrows():
                key = (row["Chromosome"], row["Start"], row["End"], row["Strand"])
                exon_mapping.setdefault(key, set()).add(tid)
        
        # For each unique exon, compute its PSI value(s)
        rows = []
        for exon_key, tid_set in exon_mapping.items():
            row_data = {"Chromosome": exon_key[0],
                        "Start": exon_key[1],
                        "End": exon_key[2],
                        "Strand": exon_key[3]}
            if is_multi_sample:
                # Sum counts over transcripts that include this exon for each sample
                included_counts = counts_df.loc[counts_df.index.intersection(list(tid_set))].sum()
                psi_dict = {}
                for sample in counts_df.columns:
                    tot = total_counts[sample]
                    psi = included_counts[sample] / tot if tot != 0 else np.nan
                    psi_dict[f"psi_{sample}"] = psi
                row_data.update(psi_dict)
            else:
                included_counts = counts_series[counts_series.index.isin(list(tid_set))].sum()
                psi = included_counts / total_counts if total_counts != 0 else np.nan
                row_data["psi"] = psi
            row_data["included_transcripts"] = list(tid_set)
            rows.append(row_data)
        
        psi_df = pd.DataFrame(rows)
        psi_df.sort_values(["Chromosome", "Start"], inplace=True)
        return psi_df





In [3]:
import os
import urllib.request
from pathlib import Path

# Ensembl release 109 downloads for mouse GRCm39: annotation (GTF) and genome (FASTA)
gtf_url = "ftp://ftp.ensembl.org/pub/release-109/gtf/mus_musculus/Mus_musculus.GRCm39.109.gtf.gz"
fasta_url = "ftp://ftp.ensembl.org/pub/release-109/fasta/mus_musculus/dna/Mus_musculus.GRCm39.dna.primary_assembly.fa.gz"

# Downloads live in ../data, which is created on first use
data_dir = Path("..") / "data"
data_dir.mkdir(parents=True, exist_ok=True)

gtf_file_local = data_dir / "Mus_musculus.GRCm39.109.gtf.gz"
fasta_file_local = data_dir / "Mus_musculus.GRCm39.dna.primary_assembly.fa.gz"

# Fetch each file only when it is not on disk yet
for remote_url, local_path in ((gtf_url, gtf_file_local), (fasta_url, fasta_file_local)):
    if local_path.is_file():
        continue
    print(f"Downloading {remote_url}...")
    urllib.request.urlretrieve(remote_url, local_path)

# Build the TranscriptData object from the local annotation and genome
td = TranscriptData(
    gtf_file=gtf_file_local,
    reference_fasta=fasta_file_local
)

# Example query: exons of a single mouse transcript
example_transcript_id = "ENSMUST00000070533"  # e.g., for mouse
exons = td.get_exons(example_transcript_id)
print("Exons:", exons)


Exons: +--------------+----------------+------------+-----------+-------+
|   Chromosome | Source         | Feature    |     Start | +22   |
|   (category) | (object)       | (object)   |   (int64) | ...   |
|--------------+----------------+------------+-----------+-------|
|            1 | ensembl_havana | exon       |   3740774 | ...   |
|            1 | ensembl_havana | exon       |   3491924 | ...   |
|            1 | ensembl_havana | exon       |   3284704 | ...   |
+--------------+----------------+------------+-----------+-------+
Stranded PyRanges object has 3 rows and 26 columns from 1 chromosomes.
For printing, the PyRanges was sorted on Chromosome and Strand.
22 hidden columns: End, Score, Strand, Frame, gene_id, gene_version, ... (+ 16 more.)


In [4]:
# Read the same GTF directly with PyRanges for ad-hoc inspection in the cells below.
ranges = pr.read_gtf(gtf_file_local)

In [5]:
# Show which columns (fixed GTF fields plus parsed attributes) are available.
ranges.columns

Index(['Chromosome', 'Source', 'Feature', 'Start', 'End', 'Score', 'Strand',
       'Frame', 'gene_id', 'gene_version', 'gene_name', 'gene_source',
       'gene_biotype', 'transcript_id', 'transcript_version',
       'transcript_name', 'transcript_source', 'transcript_biotype', 'tag',
       'transcript_support_level', 'exon_number', 'exon_id', 'exon_version',
       'ccds_id', 'protein_id', 'protein_version'],
      dtype='object')

In [6]:
# Peek at the transcript_id column; the output below shows NaN for rows without a transcript.
ranges.as_df()['transcript_id']

0                         NaN
1          ENSMUST00000194081
2          ENSMUST00000194081
3                         NaN
4          ENSMUST00000194393
                  ...        
1901233    ENSMUST00000189418
1901234    ENSMUST00000189418
1901235                   NaN
1901236    ENSMUST00000186353
1901237    ENSMUST00000186353
Name: transcript_id, Length: 1901238, dtype: object

In [7]:
# Load the example mouse dataset; process_mouse_data presumably comes from this
# star import — TODO confirm.
from allos.readers_tests import *
mouse_data = process_mouse_data()


🔎 Looking for file at: /data/analysis/data_mcandrew/Allos_new/allos_env/lib/python3.9/site-packages/allos/resources/e18.mouse.clusters.csv
✅ File found at: /data/analysis/data_mcandrew/Allos_new/allos_env/lib/python3.9/site-packages/allos/resources/e18.mouse.clusters.csv
✅ File already exists at: /data/analysis/data_mcandrew/Allos_new/allos_env/lib/python3.9/site-packages/allos/resources/data/mouse_1.txt.gz

🔄 Decompressing /data/analysis/data_mcandrew/Allos_new/allos_env/lib/python3.9/site-packages/allos/resources/data/mouse_1.txt.gz to /data/analysis/data_mcandrew/Allos_new/allos_env/lib/python3.9/site-packages/allos/resources/data/mouse_1.txt...
✅ Decompression complete.
Test data (mouse_1) downloaded successfully
✅ File already exists at: /data/analysis/data_mcandrew/Allos_new/allos_env/lib/python3.9/site-packages/allos/resources/data/mouse_2.txt.gz

🔄 Decompressing /data/analysis/data_mcandrew/Allos_new/allos_env/lib/python3.9/site-packages/allos/resources/data/mouse_2.txt.gz to 

  utils.warn_names_duplicates("obs")


In [8]:
# Collect the variable index of the dataset as a plain list; the next cell's
# output shows these are versioned Ensembl transcript IDs.
transcriptIds = mouse_data.var.index.to_list()

In [9]:
# Preview the first ten transcript IDs.
transcriptIds[:10]

['ENSMUST00000156717.1',
 'ENSMUST00000212520.1',
 'ENSMUST00000025798.12',
 'ENSMUST00000231280.1',
 'ENSMUST00000039286.4',
 'ENSMUST00000144552.7',
 'ENSMUST00000112304.8',
 'ENSMUST00000162041.7',
 'ENSMUST00000053506.6',
 'ENSMUST00000028207.12']

In [10]:
# Map every transcript ID to its gene name; version suffixes are stripped by default
# (ignore_after_period=True).
gene_names = td.get_gene_names_for_transcripts(transcript_ids=transcriptIds)

In [11]:
# Preview the gene names for the first ten transcripts.
gene_names[:10]

['Klc2',
 'Capn15',
 'Klc2',
 'Eva1c',
 'Atg5',
 'Znhit3',
 'Ppm1b',
 'Gcc2',
 'Bbs1',
 'Crat']

In [12]:
#| hide 
from nbdev.showdoc import *

In [13]:
#| hide
# Run nbdev's export step for this notebook's `#| export` cells.
import nbdev; nbdev.nbdev_export()