## Helper functions

In [1]:
import altair as alt
import pandas as pd
from typing import Iterable, Optional
from Bio import Entrez, SeqIO
from Bio.Seq import Seq
from Bio.SeqFeature import CompoundLocation
import time
import altair as alt
import pandas as pd
import ipywidgets as widgets
from IPython.display import display, clear_output
from typing import Iterable, Optional, Dict, Any
import altair as alt
import pandas as pd
from typing import Iterable, Optional

# Turn off Altair's max-rows guard so that large tables (e.g. per-nucleotide CDS
# tables) can be handed to charts without tripping the row limit.
alt.data_transformers.disable_max_rows()

DataTransformerRegistry.enable('default')

In [2]:
# NCBI Entrez configuration: identify yourself before any request is made.
# The address below is a placeholder and should be replaced with a real contact e-mail.
Entrez.email = "you@example.com"      # <-- put your email
# Entrez.api_key = "YOUR_NCBI_API_KEY" # optional but recommended

# IUPAC complement map (handles ambiguity)
_COMP_MAP = str.maketrans("ACGTRYMKBDHVNacgtrymkbdhvn", "TGCAYRKMVHDBNtgcayrkmvhdbn")

def fetch_genbank_record(accession, tries=5, sleep_s=0.4):
    """
    Fetch a GenBank record (nuccore) by accession using NCBI Entrez and parse it with Biopython.

    Parameters
    ----------
    accession : str
        GenBank/RefSeq accession (e.g., "NC_001498.1" or "AB016162.1").
    tries : int, optional
        Maximum number of attempts (default: 5).
    sleep_s : float, optional
        Base sleep (in seconds) between attempts; the pause grows linearly with the
        attempt number (default: 0.4).

    Returns
    -------
    Bio.SeqRecord.SeqRecord
        The record parsed by `SeqIO.read(..., "gb")`.

    Raises
    ------
    RuntimeError
        If every attempt fails. The last underlying exception is attached as
        `__cause__` so the original traceback is not lost.

    Notes
    -----
    - Set `Entrez.email` (and optionally `Entrez.api_key`) before calling.
    - The request uses `rettype="gb"` and `retmode="text"`.
    - Backoff is linear: `sleep_s * attempt_number`. No sleep happens after the
      final failed attempt, since nothing would follow it but the error.

    Examples
    --------
    >>> Entrez.email = "me@example.com"
    >>> rec = fetch_genbank_record("NC_001498.1")
    >>> rec.id
    'NC_001498.1'
    """
    last_err = None
    for attempt in range(1, tries + 1):
        try:
            with Entrez.efetch(db="nuccore", id=accession, rettype="gb", retmode="text") as fh:
                return SeqIO.read(fh, "gb")
        except Exception as err:  # deliberately broad: network and parse failures are both retried
            last_err = err
            if attempt < tries:
                # Only back off when another attempt will actually follow.
                time.sleep(sleep_s * attempt)
    raise RuntimeError(
        f"Failed to fetch {accession} after {tries} tries: {last_err}"
    ) from last_err

def _segments_1based(loc):
    """
    Turn a feature location into a sorted list of 1-based, end-inclusive segments.

    Parameters
    ----------
    loc : location object
        Either a `CompoundLocation` (its `.parts` are used) or any single
        location exposing `.start` and `.end`.

    Returns
    -------
    list[tuple[int, int]]
        `(start, end)` tuples, sorted ascending. Each tuple is
        `(int(part.start) + 1, int(part.end))`, i.e. a 0-based/end-exclusive
        interval re-expressed as 1-based/end-inclusive.

    Notes
    -----
    - The result is sorted by coordinate, so the original order of the parts of
      a joined location is not retained.
      NOTE(review): for a join that wraps the origin of a circular record this
      would reorder the parts -- confirm whether such features occur.

    Examples
    --------
    >>> segs = _segments_1based(feature.location)
    >>> segs[:3]
    [(123, 245), (456, 789), (900, 1020)]
    """
    if isinstance(loc, CompoundLocation):
        pieces = loc.parts
    else:
        pieces = (loc,)
    # +1 on the start only: a 0-based exclusive end equals the 1-based inclusive end.
    return sorted((int(piece.start) + 1, int(piece.end)) for piece in pieces)

def _coding_walk(segs, strand):
    """
    Produce the sequence of genomic positions (1-based) in **5′→3′ coding order**.

    Parameters
    ----------
    segs : list[tuple[int, int]]
        Genomic segments as returned by `_segments_1based`, i.e., (start, end) 1-based inclusive,
        sorted by ascending genomic coordinate.
    strand : int
        Strand of the feature: `1` (plus/forward), `-1` (minus/reverse), or `0/None` (unknown).

    Returns
    -------
    list[int]
        A flattened list of genomic positions in **coding order**:
        - On the plus strand: increasing within each segment, segments traversed from left to right.
        - On the minus strand: decreasing within each segment, segments traversed from right to left.
        - If `strand` is unknown, falls back to ascending genomic order.

    Examples
    --------
    >>> _coding_walk([(100,102),(200,201)], strand=1)
    [100, 101, 102, 200, 201]
    >>> _coding_walk([(100,102),(200,201)], strand=-1)
    [201, 200, 102, 101, 100]
    """
    if strand == 1:
        pos = []
        for s, e in segs:
            pos.extend(range(s, e + 1))
        return pos
    elif strand == -1:
        pos = []
        for s, e in sorted(segs, reverse=True):
            pos.extend(range(e, s - 1, -1))
        return pos
    else:
        pos = []
        for s, e in segs:
            pos.extend(range(s, e + 1))
        return pos

def _coding_bases(genome_seq, positions, strand):
    """
    Extract bases in **coding-strand orientation** for a list of genomic positions.

    Parameters
    ----------
    genome_seq : Bio.Seq.Seq or sequence-like
        The full genomic sequence for the record (e.g., `record.seq`).
    positions : list[int]
        1-based genomic positions to extract, typically from `_coding_walk`.
    strand : int
        Strand of the feature: `1` (plus) or `-1` (minus). If `-1`, bases are complemented.

    Returns
    -------
    list[str]
        Uppercase single-letter nucleotide strings in **5′→3′ coding order** for the feature:
        - Plus strand: the bases at `positions` as-is.
        - Minus strand: the **complement** of the forward-strand bases at `positions`.

    Notes
    -----
    - This function complements (not reverse-complements) on the minus strand because
      `positions` produced by `_coding_walk` are already in coding order (i.e., reverse traversal
      of genomic coordinates for minus-strand features). Therefore only complementation is needed.

    Examples
    --------
    >>> _coding_bases(record.seq, [100,101,102], strand=1)
    ['A','T','G']
    >>> _coding_bases(record.seq, [200,199,198], strand=-1)
    ['C','A','T']  # complements of forward-strand bases at those loci
    """
    if strand == -1:
        return [str(genome_seq[p - 1]).translate(_COMP_MAP) for p in positions]
    else:
        return [str(genome_seq[p - 1]) for p in positions]

def _aa_from_codon(codon, transl_table=1):
    """
    Translate one codon into a one-letter amino-acid string.

    Parameters
    ----------
    codon : str
        Three-letter codon on the coding strand (5′→3′).
    transl_table : int, optional
        Genetic code table passed to `Seq.translate(table=...)` (default: 1).

    Returns
    -------
    str
        The translation as a string (e.g., 'M', 'L', or '*' for a stop).

    Notes
    -----
    - Translation is done with `to_stop=False`.
    - If the call with the requested table raises for any reason, the codon is
      translated again without an explicit `table` argument; an error from
      that second call propagates to the caller.

    Examples
    --------
    >>> _aa_from_codon("ATG")
    'M'
    >>> _aa_from_codon("TAA")
    '*'
    """
    try:
        protein = Seq(codon).translate(table=transl_table, to_stop=False)
    except Exception:
        # Best-effort fallback: retry with the library's default table.
        protein = Seq(codon).translate(to_stop=False)
    return str(protein)


In [3]:
def cds_position_table(record) -> pd.DataFrame:
    """
    Build a long-form table with **one row per (CDS × covered genomic position)**.

    Every feature of `record.features` whose `type` is `"CDS"` is walked in coding
    order (via `_segments_1based` → `_coding_walk` → `_coding_bases`), and for each
    nucleotide the function records its genomic coordinate, its index along the
    CDS, the coding-strand base, and — where available — the codon it belongs to,
    the translated amino acid, and its position inside that codon.

    Parameters
    ----------
    record : Bio.SeqRecord.SeqRecord
        Parsed record providing `record.id`, `record.seq` and `record.features`.
        The CDS qualifiers `gene`, `locus_tag`, `protein_id`, `product`,
        `transl_table` (default 1) and `codon_start` (default 1) are read when present.

    Returns
    -------
    pandas.DataFrame
        Sorted by `feature_index` then `cds_nt_index`, always with these columns
        (also when the record contains no CDS, in which case the table is empty):

        - **accession**: `record.id`.
        - **feature_index**: 1-based index of the CDS among the record's CDS features.
        - **gene**, **locus_tag**, **protein_id**, **product**: qualifier values,
          multiple values joined with `";"`, empty string when absent.
        - **strand**: the location's strand, or 0 when it is unset.
        - **genome_pos**: 1-based genomic coordinate (rows are in coding order).
        - **cds_nt_index**: 1-based nucleotide index along the CDS.
        - **codon_index**: 1-based index of the complete codon containing this
          nucleotide; `None` when there is no complete codon.
        - **codon_pos**: 1, 2 or 3; `NaN` for nucleotides before the reading frame
          starts (`codon_start` 2 or 3).
        - **nt**: uppercase coding-strand base at this position.
        - **codon**: uppercase coding-strand triplet, or `None`.
        - **aa**: translation of `codon` using the CDS's `transl_table`, or `None`.

    Notes
    -----
    - **Frame offset:** with `codon_start=2/3`, the first 1 or 2 nucleotides have
      `codon=None`, `aa=None`, `codon_index=None`, `codon_pos=NaN`. Codons are then
      read starting at CDS offset `codon_start - 1`, so `codon[codon_pos - 1] == nt`
      for every row that has a codon.
    - **Trailing partial codon:** nucleotides of an incomplete last codon keep their
      `codon_pos` (1–3) but have `codon`, `aa` and `codon_index` set to `None`.
    - A `codon_start` outside 1–3 is treated as 1.
    - Each distinct codon is translated once per CDS and the result is reused.

    Examples
    --------
    >>> # rec = fetch_genbank_record("AB016162.1")
    >>> tbl = cds_position_table(rec)
    >>> tbl.query("gene == 'P'")[["genome_pos","cds_nt_index","nt","codon","aa"]].head()
    """
    columns = [
        "accession", "feature_index", "gene", "locus_tag", "protein_id", "product",
        "strand", "genome_pos", "cds_nt_index", "codon_index", "codon_pos",
        "nt", "codon", "aa",
    ]
    all_rows = []
    genome_seq = record.seq
    cds_feats = [f for f in record.features if f.type == "CDS"]

    for idx, feat in enumerate(cds_feats, start=1):
        quals = feat.qualifiers
        strand = feat.location.strand or 0
        # Fields shared by every nucleotide row of this CDS.
        shared = {
            "accession": record.id,
            "feature_index": idx,
            "gene": ";".join(quals.get("gene", [])),
            "locus_tag": ";".join(quals.get("locus_tag", [])),
            "protein_id": ";".join(quals.get("protein_id", [])),
            "product": ";".join(quals.get("product", [])),
            "strand": strand,
        }
        transl_table = int(quals.get("transl_table", ["1"])[0])
        frame_offset = int(quals.get("codon_start", ["1"])[0]) - 1
        if frame_offset not in (0, 1, 2):
            frame_offset = 0

        # Coding-order genomic coordinates and the matching coding-strand bases.
        positions = _coding_walk(_segments_1based(feat.location), strand)
        bases = [str(b).upper() for b in _coding_bases(genome_seq, positions, strand)]
        n_bases = len(bases)
        aa_by_codon = {}  # per-CDS cache: same transl_table, so same codon -> same aa

        for i, (genome_pos, nt) in enumerate(zip(positions, bases)):
            codon = aa = codon_idx = None
            codon_pos = float("nan")

            framed_i = i - frame_offset  # index relative to the first in-frame base
            if framed_i >= 0:
                codon_idx0, pos_in_codon = divmod(framed_i, 3)
                codon_pos = pos_in_codon + 1
                # Start of this nucleotide's codon in `bases`. The frame offset is
                # already contained in `i`, so it must not be subtracted again.
                codon_start_i = i - pos_in_codon
                if codon_start_i + 3 <= n_bases:
                    codon = "".join(bases[codon_start_i:codon_start_i + 3])
                    if codon not in aa_by_codon:
                        aa_by_codon[codon] = _aa_from_codon(codon, transl_table)
                    aa = aa_by_codon[codon]
                    codon_idx = codon_idx0 + 1

            all_rows.append({
                **shared,
                "genome_pos": genome_pos,     # 1-based genomic coordinate
                "cds_nt_index": i + 1,        # 1-based along CDS
                "codon_index": codon_idx,     # 1-based codon index in the CDS
                "codon_pos": codon_pos,       # 1/2/3 (NaN before the frame starts)
                "nt": nt,                     # coding-strand nucleotide at this position
                "codon": codon,               # coding-strand 5'->3'
                "aa": aa,
            })

    # Explicit columns keep the schema stable (and sortable) when there are no rows.
    return pd.DataFrame(all_rows, columns=columns).sort_values(
        ["feature_index", "cds_nt_index"], ignore_index=True
    )


In [4]:
def _featureindex_to_transl_table(record):
    """
    Build a mapping from **CDS feature index → translation table (transl_table)**.

    This helper iterates over `record.features` in order, counts only features with
    `feat.type == "CDS"`, and assigns each CDS a 1-based **feature_index** matching the
    enumeration used elsewhere (e.g., in `cds_position_table`). For each CDS, it reads the
    GenBank qualifier `transl_table` (NCBI genetic code ID). If the qualifier is absent,
    it defaults to **1** (the Standard Code).

    Parameters
    ----------
    record : Bio.SeqRecord.SeqRecord
        A parsed GenBank/RefSeq record (e.g., from `SeqIO.read(..., "gb")`) that has
        a list of annotated features in `record.features`. Only CDS features are considered.

    Returns
    -------
    dict[int, int]
        A dictionary mapping `feature_index` (1-based integer) to the translation table
        ID (integer), e.g. `{1: 1, 2: 11, 3: 1, ...}`.

    Notes
    -----
    - **Ordering matters**: indices are assigned in the exact order the CDS features
      appear in `record.features`. This is the same convention assumed by
      `cds_position_table` and the variant annotation pipeline.
    - If a CDS lacks a `transl_table` qualifier, this function records **1**.
    - This function **does not** inspect sub-features or infer context; it purely
      reads qualifiers from each CDS.

    Examples
    --------
    >>> rec = fetch_genbank_record("NC_001498.1")
    >>> _featureindex_to_transl_table(rec)
    {1: 1, 2: 1, 3: 1}
    """
    m = {}
    idx = 0
    for feat in record.features:
        if feat.type != "CDS":
            continue
        idx += 1
        m[idx] = int(feat.qualifiers.get("transl_table", ["1"])[0])
    return m


# Per-(variant × CDS) annotation columns, in output order.
_ANNOTATION_COLUMNS = (
    "region", "feature_index", "gene", "product", "protein_id", "strand",
    "codon_index", "codon_pos", "cds_nt_index",
    "ref_codon", "new_codon", "ref_aa", "new_aa", "effect",
)

# Summary columns of the collapsed (one row per variant) view, in output order.
_COLLAPSED_COLUMNS = (
    "region", "n_cds_overlap", "feature_indexes", "genes", "products", "protein_ids",
    "strands", "codon_indexes", "codon_positions", "ref_codons", "new_codons",
    "ref_aas", "new_aas", "effect",
)


def _reference_check_columns(variants_df, record):
    """Return the (ref_genome_base, ref_match, is_snv, note) lists, one entry per variant row."""
    seq = record.seq
    seq_len = len(seq)
    ref_bases, ref_matches, snv_flags, notes = [], [], [], []
    rows = zip(variants_df["CHROM"], variants_df["POS"], variants_df["REF"], variants_df["ALT"])
    for chrom, pos, ref, alt in rows:
        if str(chrom) != record.id:
            # Variant belongs to another sequence: nothing can be checked.
            ref_bases.append(None)
            ref_matches.append(None)
            snv_flags.append(False)
            notes.append("non_SNV")
            continue

        pos = int(pos)
        ref, alt = str(ref), str(alt)
        is_snv = len(ref) == 1 and len(alt) == 1
        # Guard the lookup: POS < 1 would silently wrap around through negative
        # indexing, POS > len(seq) would raise IndexError.
        in_range = 1 <= pos <= seq_len
        ref_base = str(seq[pos - 1]).upper() if in_range else None
        matches = (ref_base == ref.upper()) if (is_snv and in_range) else None

        if not is_snv:
            note = "non_SNV"
        elif not in_range:
            note = "POS_out_of_range"
        elif not matches:
            note = "REF!=reference"
        else:
            note = ""

        ref_bases.append(ref_base)
        ref_matches.append(matches)
        snv_flags.append(is_snv)
        notes.append(note)
    return ref_bases, ref_matches, snv_flags, notes


def _annotate_merged_row(r, transl_map):
    """Return the `_ANNOTATION_COLUMNS` fields for one row of the variants × CDS join."""
    has_cds = not pd.isna(r.get("feature_index"))

    if r["is_snv"] and not has_cds:
        # SNV outside every CDS (or not on this record).
        row = dict.fromkeys(_ANNOTATION_COLUMNS)
        row["region"] = "noncoding"
        row["effect"] = "noncoding"
        return row

    # CDS context copied from the join; effect fields are filled in below.
    row = {
        "region": "CDS" if has_cds else "noncoding",
        "feature_index": int(r["feature_index"]) if has_cds else None,
        "gene": r.get("gene"),
        "product": r.get("product"),
        "protein_id": r.get("protein_id"),
        "strand": None if pd.isna(r.get("strand")) else int(r["strand"]),
        "codon_index": r.get("codon_index"),
        "codon_pos": r.get("codon_pos"),
        "cds_nt_index": r.get("cds_nt_index"),
        "ref_codon": r.get("codon"),
        "new_codon": None,
        "ref_aa": r.get("aa"),
        "new_aa": None,
        "effect": "unsupported_variant",
    }
    if not r["is_snv"]:
        return row

    # From here on: an SNV inside a CDS.
    strand = 0 if row["strand"] is None else row["strand"]
    row["strand"] = strand
    ref_codon, codon_pos = row["ref_codon"], row["codon_pos"]
    if pd.isna(ref_codon) or pd.isna(codon_pos):
        # No complete codon around this nucleotide.
        row["effect"] = "cds-edge"
        return row

    # Express ALT on the coding strand. Only minus-strand features are
    # complemented, mirroring how `_coding_bases` builds the reference codons.
    alt_coding = str(r["ALT"]).upper()
    if strand == -1:
        alt_coding = alt_coding.translate(_COMP_MAP)

    offset = int(codon_pos) - 1
    ref_codon = str(ref_codon).upper()
    new_codon = ref_codon[:offset] + alt_coding + ref_codon[offset + 1:]

    table = transl_map.get(row["feature_index"], 1)
    ref_aa = row["ref_aa"]
    # pd.isna covers both None and the NaN a left join leaves behind.
    ref_aa = _aa_from_codon(ref_codon, table) if pd.isna(ref_aa) else str(ref_aa)
    new_aa = _aa_from_codon(new_codon, table)

    row.update({
        "codon_pos": int(codon_pos),
        "ref_codon": ref_codon,
        "new_codon": new_codon,
        "ref_aa": ref_aa,
        "new_aa": new_aa,
        "effect": "synonymous" if ref_aa == new_aa else "nonsynonymous",
    })
    return row


def _summarize_cds_effects(effects):
    """Reduce the per-CDS effects of one variant to a single summary label."""
    has_non = "nonsynonymous" in effects
    has_syn = "synonymous" in effects
    has_edge = "cds-edge" in effects
    if has_non and has_syn:
        return "mixed(nonsyn+syn)"
    if has_non:
        return "nonsynonymous"
    if has_syn and has_edge:
        return "mixed(syn+edge)"
    if has_syn:
        return "synonymous"
    if has_edge:
        return "cds-edge"
    # Only non-SNV rows remain (they carry "unsupported_variant").
    return "unsupported_variant"


def _collapse_per_variant(per_cds, var_cols, list_join):
    """Collapse the per-(variant × CDS) table to one row per distinct variant."""
    def joined(values):
        return list_join.join("" if pd.isna(v) else str(v) for v in values)

    collapsed = []
    # dropna=False: the variant columns legitimately contain None (e.g. `ref_match`
    # of every non-SNV); the default would silently drop those variants.
    for _, group in per_cds.groupby(var_cols, dropna=False):
        row = {c: group[c].iloc[0] for c in var_cols}
        cds_rows = group[group["region"] == "CDS"].sort_values("feature_index")
        if cds_rows.empty:
            row.update(dict.fromkeys(_COLLAPSED_COLUMNS))
            row.update(region="noncoding", n_cds_overlap=0, effect="noncoding")
        else:
            row.update({
                "region": "CDS-overlap" if len(cds_rows) > 1 else "CDS",
                "n_cds_overlap": len(cds_rows),
                "feature_indexes": joined(cds_rows["feature_index"].astype(int)),
                "genes": joined(cds_rows["gene"]),
                "products": joined(cds_rows["product"]),
                "protein_ids": joined(cds_rows["protein_id"]),
                "strands": joined(cds_rows["strand"].astype(int)),
                "codon_indexes": joined(cds_rows["codon_index"]),
                "codon_positions": joined(cds_rows["codon_pos"]),
                "ref_codons": joined(cds_rows["ref_codon"]),
                "new_codons": joined(cds_rows["new_codon"]),
                "ref_aas": joined(cds_rows["ref_aa"]),
                "new_aas": joined(cds_rows["new_aa"]),
                "effect": _summarize_cds_effects(cds_rows["effect"].tolist()),
            })
        collapsed.append(row)

    columns = list(var_cols) + [c for c in _COLLAPSED_COLUMNS if c not in var_cols]
    return pd.DataFrame(collapsed, columns=columns)


def annotate_variants(df_pos: pd.DataFrame,
                      variants_df: pd.DataFrame,
                      record,
                      collapse_overlaps: bool = False,
                      list_join: str = "|") -> pd.DataFrame:
    """
    Annotate single-nucleotide variants (SNVs) as **synonymous** or **nonsynonymous** with
    codon/AA context, and optionally collapse overlapping CDS effects to one row per variant.

    The function performs four steps:

    1) **Reference sanity check** (forward strand). A copy of `variants_df` gains:
         * `ref_genome_base`: uppercase reference base at `POS`; `None` if
           `CHROM != record.id` or `POS` lies outside `1..len(record.seq)`.
         * `ref_match`: `True`/`False` for in-range SNVs depending on whether `REF`
           (case-insensitive) equals `ref_genome_base`; `None` otherwise.
         * `is_snv`: `True` when `REF` and `ALT` are both one character long and
           `CHROM == record.id`; `False` otherwise.
         * `note`: `"non_SNV"` when `is_snv` is false, `"POS_out_of_range"` for an
           SNV whose `POS` is outside the sequence, `"REF!=reference"` on a
           mismatch, `""` otherwise.
       The caller's `variants_df` is not modified.

    2) **Join to CDS positions**: left join of the variants to `df_pos` on
       (`CHROM`, `POS`) ↔ (`accession`, `genome_pos`); colliding column names from
       `df_pos` receive the suffix `_cds`.

    3) **Per-CDS effect calling**, one output row per joined row:
         * `is_snv` false → `effect = "unsupported_variant"`; CDS context is copied
           when the position falls in a CDS.
         * SNV without a CDS match → `effect = "noncoding"`.
         * SNV in a CDS but without complete codon context (`codon` or `codon_pos`
           missing) → `effect = "cds-edge"`.
         * otherwise ALT is expressed on the coding strand (complemented through
           `_COMP_MAP` only when `strand == -1`), substituted into the reference
           codon at index `codon_pos - 1`, and both codons are translated with the
           table from `_featureindex_to_transl_table(record)` (default 1);
           `effect` is `"synonymous"` if the amino acid is unchanged, else
           `"nonsynonymous"`.

    4) **Optional collapse** (`collapse_overlaps=True`): rows are grouped by all
       variant columns (the columns of `variants_df` plus the four added in step 1;
       missing values are kept as their own group). Per-CDS fields are concatenated
       with `list_join` in `feature_index` order and `effect` is summarized as:
         * `"mixed(nonsyn+syn)"` if both nonsynonymous and synonymous occur;
         * `"nonsynonymous"` if any CDS is nonsynonymous and none synonymous;
         * `"mixed(syn+edge)"` if synonymous and cds-edge occur;
         * `"synonymous"` if only synonymous occurs;
         * `"cds-edge"` if only edge cases occur;
         * `"unsupported_variant"` for a non-SNV that overlaps a CDS;
         * `"noncoding"` if no CDS overlaps.

    Parameters
    ----------
    df_pos : pandas.DataFrame
        Table from `cds_position_table(record)`. Columns used: `accession`,
        `genome_pos`, `feature_index`, `strand`, `codon_index`, `codon_pos`,
        `cds_nt_index`, `codon`, `aa`, and `gene`, `product`, `protein_id` if present.
    variants_df : pandas.DataFrame
        Variant table with at least `CHROM`, `POS` (1-based), `REF`, `ALT`. Any
        further columns are carried through to the output unchanged.
    record : Bio.SeqRecord.SeqRecord
        The record `df_pos` was built from; supplies `record.id`, `record.seq`
        and the per-CDS translation tables.
    collapse_overlaps : bool, optional
        `False` (default): one row per (variant × overlapping CDS).
        `True`: one row per distinct variant.
    list_join : str, optional
        Delimiter for concatenated fields in the collapsed view (default `"|"`).

    Returns
    -------
    pandas.DataFrame
        With `collapse_overlaps=False`: the variant columns followed by `region`
        ("CDS" or "noncoding"), `feature_index`, `gene`, `product`, `protein_id`,
        `strand`, `codon_index`, `codon_pos`, `cds_nt_index`, `ref_codon`,
        `new_codon`, `ref_aa`, `new_aa`, `effect`.

        With `collapse_overlaps=True`: the variant columns followed by `region`
        ("CDS", "CDS-overlap" or "noncoding"), `n_cds_overlap`, `feature_indexes`,
        `genes`, `products`, `protein_ids`, `strands`, `codon_indexes`,
        `codon_positions`, `ref_codons`, `new_codons`, `ref_aas`, `new_aas`,
        `effect`. The concatenated fields are missing (`None`/NaN) for noncoding
        variants. The index is a plain `RangeIndex`.

        In both cases the column set is the same for empty and non-empty input.

    Limitations
    -----------
    - Only SNVs are annotated; everything else is flagged `"unsupported_variant"`.
    - Ambiguous ALT symbols are only complemented, never resolved.
    - Effects are computed against the reference codon from `df_pos`, even when
      `note == "REF!=reference"`.

    Examples
    --------
    >>> rec = fetch_genbank_record("AB016162.1")
    >>> df_pos = cds_position_table(rec)
    >>> variants_df = pd.DataFrame({
    ...     "Sample": ["S1","S1"],
    ...     "CHROM": ["AB016162.1","AB016162.1"],
    ...     "POS": [81, 84],
    ...     "REF": ["A","A"],
    ...     "ALT": ["G","G"]
    ... })
    >>> per_cds = annotate_variants(df_pos, variants_df, rec, collapse_overlaps=False)
    >>> per_var = annotate_variants(df_pos, variants_df, rec, collapse_overlaps=True)
    """
    # 1. Basic REF sanity check vs forward-strand reference
    ref_bases, ref_matches, snv_flags, notes = _reference_check_columns(variants_df, record)
    variants_df = variants_df.copy()
    variants_df["ref_genome_base"] = ref_bases
    variants_df["ref_match"] = ref_matches
    variants_df["is_snv"] = snv_flags
    variants_df["note"] = notes
    var_cols = list(variants_df.columns)  # includes Sample, CHROM, POS, etc.

    # 2. Join variants to CDS positions
    merged = variants_df.merge(
        df_pos,
        left_on=["CHROM", "POS"],
        right_on=["accession", "genome_pos"],
        how="left",
        suffixes=("", "_cds"),
    )

    # 3. Per-(variant × CDS) effect calling
    transl_map = _featureindex_to_transl_table(record)
    out_rows = []
    for _, r in merged.iterrows():
        row = {c: r[c] for c in var_cols}
        row.update(_annotate_merged_row(r, transl_map))
        out_rows.append(row)

    # Explicit columns keep the schema stable even when there are no rows.
    per_cds_columns = var_cols + [c for c in _ANNOTATION_COLUMNS if c not in var_cols]
    per_cds = pd.DataFrame(out_rows, columns=per_cds_columns)

    if not collapse_overlaps:
        return per_cds

    # 4. Collapse to one row per variant
    return _collapse_per_variant(per_cds, var_cols, list_join)

In [12]:
try:
    from IPython.display import HTML, display
except Exception:
    HTML = None
    display = None

def make_gene_variant_tracks(
    df: pd.DataFrame,
    vdf: pd.DataFrame,
    *,
    gene: Optional[str] = None,                 # initial gene to show
    width: int = 900,
    height_variants: int = 100,
    height_aa: int = 50,
    height_codonpos: Optional[int] = None,      # add codon-pos panel if not None
    details_panel: bool = True,
    details_width: int = 340,
    product_priority: Optional[Iterable[str]] = ("phosphoprotein",),
    tooltip_all_cols: bool = True,
    fix_tooltip_clip: bool = True
) -> alt.Chart:
    """
    Return an Altair chart for a selected gene with:
      • AA rectangles (1-nt wide) colored by AA + nucleotide letter inside
      • Variants panel (AF vs position) with deterministic horizontal jitter (sorted by product)
      • Optional codon-position panel
      • Hover-driven details side panel
      • A dropdown to switch the displayed gene

    Neither input frame is modified; both are copied before helper columns are added.

    Parameters
    ----------
    df : DataFrame with columns at least
         ['gene','product','genome_pos' or 'pos_num','codon','codon_pos','aa'].
    vdf: DataFrame with columns
         ['gene', 'af', 'pos' or 'pos_num'] (+ optional: 'var_id','product','effect','sample_frac', ...).
         A precomputed 'x_center' column is used as-is when present.
    gene : Initial gene to display (also set as the dropdown's default). If None, the first sorted gene is used.
    width, height_variants, height_aa, height_codonpos : Sizes for tracks (px).
         The codon-position panel is only added when `height_codonpos` is not None.
    details_panel : If True, show right-hand “Variant details” panel on hover.
    details_width : Width of details panel (px).
    product_priority : Iterable of product names to place first within jitter ordering.
         A single string is treated as one product name; duplicates are ignored.
    tooltip_all_cols : If True, include all vdf columns in tooltips (with light formatting).
    fix_tooltip_clip : If True (in Jupyter), inject CSS so tooltips aren’t clipped.
         This also changes the global Altair renderer embed options.

    Returns
    -------
    alt.Chart
        The variants/AA/(codon-pos) tracks stacked vertically, optionally
        concatenated horizontally with the details panel.

    Raises
    ------
    ValueError
        If `df` or `vdf` lacks a 'gene' column, if `df` has neither 'pos_num'
        nor 'genome_pos', if `vdf` has none of 'pos_num', 'pos' or 'x_center',
        or if no gene names can be found in either frame.
    """
    # --- Optional CSS to prevent tooltip clipping (Jupyter) ---
    if fix_tooltip_clip and HTML is not None and display is not None:
        alt.renderers.set_embed_options(actions=False, tooltip={"theme": "light", "offset": 12})
        display(HTML("""
        <style>
        .vega-embed, .vega-embed * { overflow: visible !important; }
        .vg-tooltip { max-width: none !important; white-space: nowrap !important; z-index: 999999 !important; }
        </style>
        """))

    # --- Prepare df (tracks) ---
    df = df.copy()
    if "gene" not in df.columns:
        raise ValueError("df must contain a 'gene' column.")
    track_pos = df.get("pos_num", df.get("genome_pos"))
    if track_pos is None:
        # Without a position column every rectangle would silently get a NaN x.
        raise ValueError("df must contain a 'pos_num' or 'genome_pos' column.")
    df["pos_num"] = pd.to_numeric(track_pos, errors="coerce")
    df["pos_num_next"] = df["pos_num"] + 1  # right edge of each 1-nt rectangle

    def _nt_from_codon_row(r):
        """Return the nucleotide at `codon_pos` (1-based) of the row's codon, or ""."""
        codon = r.get("codon", "")
        if isinstance(codon, float):
            # A missing codon is a float NaN, which would stringify to the
            # 3-letter "nan" and be mistaken for a real codon.
            return ""
        c = str(codon or "")
        cp = r.get("codon_pos")
        try:
            cp = int(cp)
        except Exception:
            return ""
        return c[cp-1] if (len(c) == 3 and cp in (1, 2, 3)) else ""

    df["nt_letter"] = df.apply(_nt_from_codon_row, axis=1)

    # --- Prepare vdf (variants) ---
    vdf = vdf.copy()
    if "gene" not in vdf.columns:
        raise ValueError("vdf must contain a 'gene' column.")
    variant_pos = vdf.get("pos_num", vdf.get("pos"))
    if variant_pos is None and "x_center" not in vdf.columns:
        raise ValueError("vdf must contain a 'pos_num', 'pos' or 'x_center' column.")
    vdf["pos_num"] = pd.to_numeric(variant_pos, errors="coerce")
    if "x_center" not in vdf.columns:
        # Centre each variant on its 1-nt wide rectangle.
        vdf["x_center"] = vdf["pos_num"] + 0.5
    for c in vdf.select_dtypes(include=["object"]).columns:
        vdf[c] = vdf[c].fillna("").astype(str)

    # Product order (for jitter sorting)
    if "product" in vdf.columns:
        if product_priority:
            if isinstance(product_priority, str):
                # A bare string would otherwise be split into characters.
                priority = [product_priority]
            else:
                # Drop duplicates (keeping order): Categorical rejects repeated categories.
                priority = list(dict.fromkeys(product_priority))
            remaining = [p for p in sorted(vdf["product"].unique()) if p not in priority]
            product_order = priority + remaining
        else:
            product_order = sorted(vdf["product"].unique())
        vdf["product"] = pd.Categorical(vdf["product"], categories=product_order, ordered=True)
        vdf["product_rank"] = vdf["product"].cat.codes
    else:
        vdf["product_rank"] = 0

    # --- Gene dropdown options & default ---
    gene_options = sorted(set(df["gene"].dropna().astype(str)) | set(vdf["gene"].dropna().astype(str)))
    if not gene_options:
        raise ValueError("No genes found in df/vdf.")
    default_gene = gene or gene_options[0]
    if default_gene not in gene_options:
        # Keep an unknown requested gene selectable so the dropdown default is valid.
        gene_options = [default_gene] + [g for g in gene_options if g != default_gene]

    # Altair parameter for dropdown (works in v5; fallback for v4)
    try:
        gene_param = alt.param(
            name="gene_sel",
            bind=alt.binding_select(options=gene_options, name="Gene: "),
            value=default_gene
        )
        gene_filter = (alt.datum.gene == gene_param)
    except Exception:
        # Altair v4
        gene_param = alt.selection_single(
            fields=["gene"], bind=alt.binding_select(options=gene_options, name="Gene: "),
            init={"gene": default_gene}
        )
        gene_filter = gene_param

    # --- Interactions: wheel zoom/pan + hover (for details) ---
    zoom = alt.selection_interval(bind="scales", encodings=["x"])
    hover_fields = ["var_id"] if "var_id" in vdf.columns else ["x_center", "af"]
    try:
        # Altair v5 expects a boolean here; the string "none" is the deprecated v4 spelling.
        hover = alt.selection_point(fields=hover_fields,
                                    on="mouseover", nearest=True, empty=False)
    except Exception:
        # Altair v4
        hover = alt.selection_single(fields=hover_fields,
                                     on="mouseover", nearest=True, empty="none")

    def add_zoom(chart: alt.Chart, include_hover: bool = False) -> alt.Chart:
        """Attach the shared x-zoom (and optionally the hover selection) to `chart`."""
        try:
            return chart.add_params(zoom, hover) if include_hover else chart.add_params(zoom)
        except Exception:
            # Altair v4
            out = chart.add_selection(zoom)
            return out.add_selection(hover) if include_hover else out

    # --- Tooltips from all vdf columns (optional) ---
    def _altair_type(series: pd.Series) -> str:
        """Map a column to the Altair shorthand type: Q for int/float, N otherwise."""
        return "Q" if (pd.api.types.is_integer_dtype(series) or pd.api.types.is_float_dtype(series)) else "N"

    vdf_tooltips = alt.Undefined
    if tooltip_all_cols:
        tips = []
        for col in vdf.columns:
            vtype = _altair_type(vdf[col])
            fmt = None
            if vtype == "Q":
                if col in {"af", "sample_frac"}:
                    fmt = ".3f"
                elif col in {"pos", "pos_num", "x_center", "cds_nt_index", "samples"}:
                    fmt = ".0f"
            tips.append(alt.Tooltip(f"{col}:{vtype}", title=col, format=fmt) if fmt else alt.Tooltip(f"{col}:{vtype}", title=col))
        vdf_tooltips = tips

    # --- Base for df tracks ---
    base = alt.Chart(df).transform_filter(gene_filter).encode(y=alt.Y("product:N", title=None))

    # AA rectangles + nucleotide letters (filtered by gene)
    aa_rects = (
        base.mark_rect(stroke="black")
        .encode(
            x=alt.X("pos_num:Q", title=None),
            x2="pos_num_next:Q",
            color=alt.Color("aa:N", title="aa", legend=None, scale=alt.Scale(scheme="category20"))
        )
        .properties(width=width, height=height_aa)
    )
    aa_text = (
        alt.Chart(df).transform_filter(gene_filter)
        .transform_calculate(x_center="datum.pos_num + 0.5")
        .mark_text(baseline="middle", font="monospace", fontSize=12, color="black",
                   stroke="white", strokeWidth=1)
        .encode(
            x=alt.X("x_center:Q", title=None),
            y=alt.Y("product:N", title=None),
            text=alt.Text("nt_letter:N")
        )
        .properties(width=width, height=height_aa)
    )
    aa_panel = add_zoom(aa_rects + aa_text)

    # Variants: filter by gene first; group per (gene, x_center) for jitter.
    # Within a group of n variants, row number rn (ordered by product rank, then
    # descending AF) is spread symmetrically around x_center.
    vars_track = (
        alt.Chart(vdf).transform_filter(gene_filter)
        .transform_joinaggregate(n="count()", groupby=["gene", "x_center"])
        .transform_window(
            rn="row_number()",
            sort=[dict(field="product_rank", order="ascending"),
                  dict(field="af",            order="descending")],
            groupby=["gene", "x_center"]
        )
        .transform_calculate(spread="0.45 / max(1, (datum.n - 1) / 2)")
        .transform_calculate(x_jitter="datum.x_center + (datum.rn - (datum.n + 1) / 2) * datum.spread")
        .mark_circle(stroke="black")
        .encode(
            x=alt.X("x_jitter:Q", title=None),
            y=alt.Y("af:Q", title="AF", scale=alt.Scale(domain=[0, 1]), axis=alt.Axis(grid=True)),
            color=alt.Color("effect:N", title="Effect", scale=alt.Scale(scheme="set1")) if "effect" in vdf.columns else alt.ColorValue("steelblue"),
            size=alt.Size("sample_frac:Q", title="sample_frac", scale=alt.Scale(range=[30, 300])) if "sample_frac" in vdf.columns else alt.value(80),
            tooltip=vdf_tooltips
        )
        .properties(width=width, height=height_variants)
    )
    vars_track = add_zoom(vars_track, include_hover=details_panel)

    # Optional codon-position panel
    stack_parts = [vars_track, aa_panel]
    if height_codonpos is not None:
        codonpos_panel = (
            alt.Chart(df).transform_filter(gene_filter)
            .mark_rect(stroke="black")
            .encode(
                x=alt.X("pos_num:Q", title=None),
                x2="pos_num_next:Q",
                color=alt.Color("codon_pos:O", title="codon_pos")
            )
            .properties(width=width, height=height_codonpos)
        )
        codonpos_panel = add_zoom(codonpos_panel)
        stack_parts.append(codonpos_panel)

    left_stack = alt.vconcat(*stack_parts).resolve_scale(x="shared", color="independent")

    # Details panel (also filtered by gene + hover)
    if details_panel:
        cols = list(vdf.columns)
        # 20 px per field, capped at the larger of the left stack height and 240 px.
        panel_height = min(20 * len(cols), max(height_variants + height_aa + (height_codonpos or 0), 240))
        details = (
            alt.Chart(vdf)
            .transform_filter(gene_filter)
            .transform_filter(hover)
            .transform_fold(cols, as_=["field", "value"])
            .mark_text(align="left", font="monospace", fontSize=12)
            .encode(y=alt.Y("field:N", sort=cols, title=None), text="value:N")
            .properties(width=details_width, height=panel_height, title="Variant details")
        )
        final = (left_stack | details)
    else:
        final = left_stack

    # Put the gene dropdown control at the top-level container (single control, affects all children)
    try:
        final = final.add_params(gene_param)
    except Exception:
        # v4: add the selection to the container (control appears once)
        final = final.add_selection(gene_param)

    return final

# Example:
# chart = make_gene_variant_tracks(df, vdf, gene="P", width=900, height_variants=100, height_aa=50, height_codonpos=50)
# chart


In [15]:
def gene_variant_widget(
    df: pd.DataFrame,
    vdf: pd.DataFrame,
    *,
    gene: Optional[str] = None,
    width: int = 900,
    height_variants: int = 100,
    height_aa: int = 50,
    height_codonpos: Optional[int] = 50,      # None -> hide codon-pos track by default
    details_panel: bool = True,
    details_width: int = 340,
    product_priority: Optional[Iterable[str]] = ("phosphoprotein",),
    tooltip_all_cols: bool = True,
    fix_tooltip_clip: bool = True,
    display_now: bool = True
) -> Dict[str, Any]:
    """
    Create an interactive widget (dropdown + sliders) that renders the gene-level
    genome/variant chart from `make_gene_variant_tracks` and updates it dynamically.

    Parameters
    ----------
    df, vdf : pandas.DataFrame
        Dataframes used by `make_gene_variant_tracks`. Both must have a 'gene' column.
    gene : str, optional
        Initial gene shown in the chart. Defaults to the first (sorted) available gene;
        a value not present in either frame is ignored.
    width : int, default 900
        Width (px) of each track on the left.
    height_variants : int, default 100
        Height (px) of the variants track.
    height_aa : int, default 50
        Height (px) of the AA codon rectangles track.
    height_codonpos : Optional[int], default 50
        Height (px) of the codon-position (1/2/3) track. If None, the track is hidden initially.
    details_panel : bool, default True
        Show a right-hand side panel listing all fields for the hovered variant.
    details_width : int, default 340
        Width (px) of the details panel.
    product_priority : Iterable[str], default ("phosphoprotein",)
        Product names to prioritize inside the jitter ordering (left→right within a bin).
    tooltip_all_cols : bool, default True
        Include all columns from `vdf` in tooltips.
    fix_tooltip_clip : bool, default True
        Pass through to `make_gene_variant_tracks` to inject CSS so tooltips aren’t clipped (Jupyter).
    display_now : bool, default True
        If True, displays the widget immediately. The UI container is always
        available under the "ui" key of the returned dict.

    Returns
    -------
    dict
        {
          "ui": VBox widget containing controls + chart output,
          "controls": {dict of individual widgets},
          "render": callable to force a re-render,
        }

    Raises
    ------
    ValueError
        If either frame lacks a 'gene' column or no gene names are found.

    Notes
    -----
    • Requires the function `make_gene_variant_tracks` to be defined in the notebook scope.
    • The “Codon height” slider is enabled only when “Show codon-pos” is checked.
    • You can programmatically update any control (e.g., `controls["gene"].value = "P"`)
      and the chart will re-render.
    """
    # ---- Collect gene options from both frames ----
    if "gene" not in df.columns or "gene" not in vdf.columns:
        raise ValueError("Both df and vdf must contain a 'gene' column.")

    genes_df  = df["gene"].dropna().astype(str).unique().tolist()
    genes_vdf = vdf["gene"].dropna().astype(str).unique().tolist()
    gene_options = sorted(set(genes_df) | set(genes_vdf))
    if not gene_options:
        raise ValueError("No genes found in df or vdf.")

    default_gene = gene if (gene is not None and gene in gene_options) else gene_options[0]

    # ---- Controls ----
    gene_dd = widgets.Dropdown(
        options=gene_options,
        value=default_gene,
        description="Gene:",
        layout=widgets.Layout(width="280px"),
    )
    width_slider = widgets.IntSlider(
        value=int(width), min=400, max=2400, step=50,
        description="Width:", continuous_update=False, readout=True
        )
    hvar_slider = widgets.IntSlider(
        value=int(height_variants), min=60, max=240, step=10,
        description="Var height:", continuous_update=False
        )
    haa_slider = widgets.IntSlider(
        value=int(height_aa), min=30, max=140, step=5,
        description="AA height:", continuous_update=False
        )
    show_codon = widgets.Checkbox(
        value=(height_codonpos is not None), description="Show codon-pos"
        )
    hcod_slider = widgets.IntSlider(
        value=int(height_codonpos or 50), min=30, max=140, step=5,
        description="Codon height:", continuous_update=False,
        disabled=not show_codon.value
        )
    details_toggle = widgets.Checkbox(
        value=bool(details_panel), description="Show details panel"
        )

    # Advanced toggles (optional row)
    tips_toggle = widgets.Checkbox(
        value=bool(tooltip_all_cols), description="All tooltips"
        )
    clip_toggle = widgets.Checkbox(
        value=bool(fix_tooltip_clip), description="Fix tooltip clipping"
        )

    controls_row1 = widgets.HBox([gene_dd, width_slider, details_toggle])
    controls_row2 = widgets.HBox([hvar_slider, haa_slider, show_codon, hcod_slider])
    controls_row3 = widgets.HBox([tips_toggle, clip_toggle])

    out = widgets.Output()

    # ---- Render function ----
    def _render_chart(*_):
        """Rebuild the chart from the current control values and show it in `out`."""
        with out:
            clear_output(wait=True)
            chart = make_gene_variant_tracks(
                df=df,
                vdf=vdf,
                gene=gene_dd.value,
                width=width_slider.value,
                height_variants=hvar_slider.value,
                height_aa=haa_slider.value,
                # Forward the codon-pos controls; None hides the track. Without
                # this the checkbox and slider had no effect on the chart.
                height_codonpos=(hcod_slider.value if show_codon.value else None),
                details_panel=details_toggle.value,
                details_width=details_width,
                product_priority=product_priority,
                tooltip_all_cols=tips_toggle.value,
                fix_tooltip_clip=clip_toggle.value
            )
            display(chart)

    # ---- Wire up interactions ----
    for w in (gene_dd, width_slider, hvar_slider, haa_slider, details_toggle, tips_toggle, clip_toggle):
        w.observe(_render_chart, names="value")

    def _toggle_codon(change):
        """Enable/disable the codon-height slider with the checkbox, then re-render."""
        hcod_slider.disabled = not show_codon.value
        _render_chart()

    show_codon.observe(_toggle_codon, names="value")
    hcod_slider.observe(_render_chart, names="value")

    # ---- Initial render & assemble UI ----
    _render_chart()
    ui = widgets.VBox([controls_row1, controls_row2, controls_row3, out])

    if display_now:
        display(ui)

    return {
        "ui": ui,
        "controls": {
            "gene": gene_dd,
            "width": width_slider,
            "height_variants": hvar_slider,
            "height_aa": haa_slider,
            "show_codonpos": show_codon,
            "height_codonpos": hcod_slider,
            "details_panel": details_toggle,
            "tooltip_all_cols": tips_toggle,
            "fix_tooltip_clip": clip_toggle,
        },
        "render": _render_chart,
    }

# Example usage:
# widget_pack = gene_variant_widget(df, vdf, gene="P", width=1000, height_variants=120, height_aa=60)
# widget_pack["ui"]  # already displayed if display_now=True


In [5]:
# Reference accession. It must match the CHROM values in the variant table.
accession = "NC_001498.1"

# Fetch the record and build CDS position table
rec = fetch_genbank_record(accession)
df_pos = cds_position_table(rec)

# Load your variants (tab-separated table downloaded from Galaxy)
variants_df = pd.read_csv("https://usegalaxy.org/api/datasets/f9cad7b01a4721353f9676ab5a7eb228/display?to_ext=tabular", sep="\t")

# HACK: overwrite CHROM in every row with the reference accession so the
# variants line up with the fetched record. Presumably the table uses a
# different contig name -- TODO confirm and drop this override once the
# upstream table carries the right CHROM. Uses `accession` rather than a
# repeated literal so the two cannot drift apart.
variants_df['CHROM'] = accession


# Annotate (per-CDS rows)
annot_per_cds = annotate_variants(df_pos, variants_df, rec, collapse_overlaps=False)


In [6]:
# Build a string identifier for every annotated row by joining
# CHROM, POS, ALT, effect and product with '-'.
# Two different ALT alleles at the same site therefore get different IDs,
# and so does the same change annotated against different products/effects.

annot_per_cds['var_id'] = (
    annot_per_cds.loc[:, ['CHROM', 'POS', 'ALT', 'effect', 'product']]
    .astype(str)
    .agg(lambda fields: '-'.join(fields), axis=1)
)

In [7]:
# Collapse the per-sample rows into one row per var_id.
# AF is summarised as median/min/max; the remaining columns take the group
# minimum, and `samples` counts the distinct samples carrying the variant.

var_summary = annot_per_cds.groupby('var_id', as_index=False).agg(
    samples=pd.NamedAgg(column='Sample', aggfunc='nunique'),
    af=pd.NamedAgg(column='AF', aggfunc='median'),
    af_min=pd.NamedAgg(column='AF', aggfunc='min'),
    af_max=pd.NamedAgg(column='AF', aggfunc='max'),
    pos_num=pd.NamedAgg(column='POS', aggfunc='min'),
    ref=pd.NamedAgg(column='REF', aggfunc='min'),
    alt=pd.NamedAgg(column='ALT', aggfunc='min'),
    cds_nt_index=pd.NamedAgg(column='cds_nt_index', aggfunc='min'),
    gene=pd.NamedAgg(column='gene', aggfunc='min'),
    product=pd.NamedAgg(column='product', aggfunc='min'),
    effect=pd.NamedAgg(column='effect', aggfunc='min'),
    ref_codon=pd.NamedAgg(column='ref_codon', aggfunc='min'),
    new_codon=pd.NamedAgg(column='new_codon', aggfunc='min'),
    ref_aa=pd.NamedAgg(column='ref_aa', aggfunc='min'),
    new_aa=pd.NamedAgg(column='new_aa', aggfunc='min'),
)
# Fraction of all distinct samples in the annotation table that carry the variant.
var_summary['sample_frac'] = var_summary['samples'].div(annot_per_cds['Sample'].nunique())

In [10]:
# Codon/AA track: the per-position CDS table as-is.
codon_track = df_pos
# Variant track: keep variants seen in more than two samples that are not noncoding.
variant_track = var_summary[
    (var_summary['samples'] > 2) & (var_summary['effect'] != 'noncoding')
].copy()

In [16]:
# Build (and immediately display) the interactive viewer, starting on gene "P".
widget_pack = gene_variant_widget(
    codon_track,
    variant_track,
    gene="P",
    width=1000,
    height_variants=120,
    height_aa=60,
)

VBox(children=(HBox(children=(Dropdown(description='Gene:', layout=Layout(width='280px'), options=('F', 'H', '…