# Model Input Preparation

Title: ALS miRNA repression pipeline

What it does: builds canonical transcript sequences (ORF + 3'UTR), optionally applies variants, prepares the model input, and analyzes the model output.

# Setup

Required Python packages and command-line tools

In [None]:
# Python packages (a single install; repeated installs of the same package are unnecessary)
!pip -q install pandas biopython gffutils pysam pyranges cyvcf2 intervaltree

# bedtools is only needed for optional BED-based lookups. conda is not available
# on Colab, so install it from the system package manager instead.
!apt-get -qq install -y bedtools


Downloading the Ensembl GRCh38 reference files (release 114):

1. The genome annotation (GTF), used for:
   - identifying gene structures (genes, transcripts, exons, CDS, UTRs)
   - mapping the start and end positions of `CDS`, `three_prime_utr`, and `stop_codon` features
   - extracting the canonical transcript for each gene

2. The primary-assembly genome sequence (FASTA), used for:
   - retrieving the DNA sequence at the positions defined in the GTF
   - for each transcript, fetching the CDS + stop codon sequences to build the ORF (open reading frame), and fetching the 3'UTR sequence

These files are the foundation for extracting transcript and UTR sequences later in the pipeline.
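GTF coordinates are 1-based and inclusive at both ends, while Python string slices are 0-based and half-open, so a feature spanning `start..end` corresponds to `seq[start-1:end]` (the same slice `fetch_concat` uses later). The toy sketch below uses a made-up chromosome string and a hypothetical GTF line rather than the downloaded files.

```python
# Toy illustration: how one GTF row's coordinates select bases from a chromosome.
# The chromosome string and the GTF line are made up for demonstration only.
chrom_seq = "NNATGAAATAGCCNN"  # hypothetical chromosome "1"

gtf_line = '1\tensembl\tCDS\t3\t8\t.\t+\t0\tgene_name "TOY1"; transcript_id "TX1";'
cols = ["seqname", "source", "feature", "start", "end", "score", "strand", "frame", "attribute"]
row = dict(zip(cols, gtf_line.split("\t")))  # the 9 tab-separated GTF columns

start, end = int(row["start"]), int(row["end"])  # 1-based, inclusive
cds_piece = chrom_seq[start - 1:end]             # 0-based, half-open slice
print(row["feature"], start, end, cds_piece, len(cds_piece))  # CDS 3 8 ATGAAA 6
```

The feature length is always `end - start + 1`, which is also how the CDS lengths are computed later in the notebook.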

In [None]:
# Ensembl GRCh38 release 114 GTF
!wget -q ftp://ftp.ensembl.org/pub/release-114/gtf/homo_sapiens/Homo_sapiens.GRCh38.114.gtf.gz
!gunzip -f Homo_sapiens.GRCh38.114.gtf.gz

# GRCh38 genome fasta (primary assembly)
!wget -q ftp://ftp.ensembl.org/pub/release-114/fasta/homo_sapiens/dna/Homo_sapiens.GRCh38.dna.primary_assembly.fa.gz
!gunzip -f Homo_sapiens.GRCh38.dna.primary_assembly.fa.gz

# Canonical transcripts

In this step, we process the Ensembl GTF file to identify **canonical transcripts**, then filter them to focus on ALS-related genes.

- We extract only rows labeled as `feature == "transcript"`, and parse the attributes column to retain useful metadata: `gene_id`, `gene_name`, `transcript_id`, `transcript_biotype`, `tag`, and `transcript_support_level`.
- From this, we select transcripts tagged `Ensembl_canonical`: Ensembl's single representative transcript per gene (for protein-coding genes, usually the MANE Select transcript where one exists).
- Finally, we filter the list to retain only genes associated with ALS, organized into clinical relevance tiers (A–D), based on sources like [ClinGen](https://search.clinicalgenome.org/kb/affiliate/10096?page=1&size=25&search=&sort=classification&order=desc).

This filtered list (`als_canonical_df`) will be used to extract transcript sequences for downstream analysis.

The cell below streams the Ensembl GTF in chunks, keeps only the `transcript` rows, and parses the attribute column into `gene_id`, `gene_name`, `transcript_id`, `transcript_biotype`, `tag`, and `transcript_support_level`. The result is a transcript table (`tx`) that is later filtered to select canonical isoforms.
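One caveat: `parse_attrs` stores attributes in a plain dict, and Ensembl GTF lines can carry several `tag "..."` entries, so only the last `tag` value survives. The `Ensembl_canonical` filter further down therefore relies on that tag being the last one on the line. A minimal alternative, sketched below with the hypothetical name `parse_attrs_multi` and a made-up attribute string, joins repeated values with commas so that a substring test such as `str.contains("Ensembl_canonical")` still matches.

```python
def parse_attrs_multi(attr: str, sep: str = ",") -> dict:
    """Parse a GTF attribute string, keeping every value of a repeated key.

    Hypothetical variant of parse_attrs: repeated keys (e.g. several `tag`
    entries) are joined with `sep` instead of overwriting each other.
    """
    d = {}
    for item in str(attr).strip().split(";"):
        item = item.strip()
        if not item:
            continue
        if " " not in item:  # flag-like entry without a value
            d.setdefault(item, "")
            continue
        k, v = item.split(" ", 1)
        v = v.strip('"')
        # append to an existing non-empty value, otherwise store as-is
        d[k] = f"{d[k]}{sep}{v}" if d.get(k) else v
    return d

# Made-up attribute string with three tags
example = ('gene_id "G1"; transcript_id "T1"; '
           'tag "basic"; tag "Ensembl_canonical"; tag "MANE_Select";')
print(parse_attrs_multi(example)["tag"])  # basic,Ensembl_canonical,MANE_Select
```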

In [None]:
import pandas as pd

gtf_path = "Homo_sapiens.GRCh38.114.gtf"

# GTF columns
cols = ["seqname","source","feature","start","end","score","strand","frame","attribute"]

def parse_attrs(attr):
    d = {}
    for item in str(attr).strip().split(";"):
        item = item.strip()
        if not item:
            continue
        if " " not in item:   # tolerate flag-like entries
            d[item] = ""
            continue
        k, v = item.split(" ", 1)
        d[k] = v.strip('"')   # note: a repeated key (e.g. several `tag` entries) keeps only its last value
    return d

# Stream the file and keep only transcript rows with useful attributes
parts = []
for chunk in pd.read_csv(gtf_path, sep="\t", comment="#", names=cols,
                         chunksize=400_000, low_memory=False):
    t = chunk[chunk["feature"] == "transcript"].copy()
    if t.empty:
        continue
    attrs = t["attribute"].apply(parse_attrs).apply(pd.Series)
    t = pd.concat([t.drop(columns=["attribute"]), attrs], axis=1)
    parts.append(t[[
        "gene_id","gene_name","transcript_id","transcript_biotype",
        "tag","transcript_support_level"
    ]])

tx = pd.concat(parts, ignore_index=True)
tx.head(), tx.shape

(           gene_id gene_name    transcript_id transcript_biotype  \
 0  ENSG00000142611    PRDM16  ENST00000511072     protein_coding   
 1  ENSG00000142611    PRDM16  ENST00000607632    retained_intron   
 2  ENSG00000142611    PRDM16  ENST00000378391     protein_coding   
 3  ENSG00000142611    PRDM16  ENST00000514189     protein_coding   
 4  ENSG00000142611    PRDM16  ENST00000270722     protein_coding   
 
                  tag            transcript_support_level  
 0    gencode_primary                                   5  
 1                NaN                                   2  
 2      gencode_basic                                   1  
 3    gencode_primary                                   5  
 4  Ensembl_canonical  1 (assigned to previous version 9)  ,
 (387954, 6))

In [None]:
tx.to_csv("parsed_gtf.csv", index=False)

Filtering the parsed GTF to keep only canonical transcripts (tagged `Ensembl_canonical`), keeping the essential columns, removing duplicates, and saving the list as `canonical_transcripts.csv` for the next step.

In [None]:
import pandas as pd

# Load parsed GTF into DataFrame
gtf_df = pd.read_csv("parsed_gtf.csv")  # Or replace with your actual parsed GTF variable if already in memory

# Filter rows where 'tag' column contains 'Ensembl_canonical'
canonical_df = gtf_df[gtf_df['tag'].str.contains('Ensembl_canonical', na=False)]

# Keep only relevant columns
canonical_df = canonical_df[['gene_id', 'gene_name', 'transcript_id']]

# Remove duplicates (if multiple entries for same transcript)
canonical_df = canonical_df.drop_duplicates()

print(canonical_df.shape)
canonical_df.head()

# Save to file for next step
canonical_df.to_csv("canonical_transcripts.csv", index=False)

(78894, 3)


Filtering the canonical transcripts to include only ALS-related genes and creating a smaller DataFrame (als_canonical_df) for focused downstream analysis.

### Picking ALS genes

Source: [ClinGen gene-disease validity classifications](https://search.clinicalgenome.org/kb/affiliate/10096?page=1&size=25&search=&sort=classification&order=desc)

Organizing ALS-related genes into four tiers by ClinGen classification:

- Tier A: Definitive + Strong (ALS)
- Tier B: Definitive (other)
- Tier C: Moderate (ALS and other)
- Tier D: Limited (ALS and other)
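If tiers are switched on and off often, keeping them in one dict keyed by tier label makes the selection explicit. The sketch below is a hypothetical restructuring (`TIERS` and `genes_for` are illustrative names, and the gene lists are shortened; the full lists are in the next code cell). It builds an ordered, de-duplicated gene list from the chosen tiers.

```python
# Hypothetical tier table; gene lists shortened for illustration.
TIERS = {
    "A": ["SOD1", "TARDBP", "FUS", "C9orf72"],
    "B": ["ALS2", "VCP", "ATXN2"],
    "C": ["DCTN1", "SQSTM1"],
    "D": ["TIA1", "NEFH"],
}

def genes_for(tiers, table=TIERS):
    """Return the genes of the selected tiers in tier order, without duplicates."""
    seen, out = set(), []
    for t in tiers:
        for g in table[t]:
            if g not in seen:
                seen.add(g)
                out.append(g)
    return out

genes_ab = genes_for(["A", "B"])
print(len(genes_ab), genes_ab[:3])  # 7 ['SOD1', 'TARDBP', 'FUS']
```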

In [None]:
# ALS-related genes by tier (edit these lists to change the analysed gene set)

genes_tier_a = [
    "SOD1","TARDBP","OPTN","UBQLN2",
    "PFN1","ANXA11","FUS","VAPB","NEK1","KIF5A","GRN",
    "C9orf72","TBK1","CHMP2B","SPTLC2"
]

genes_tier_b = [
    "ALS2","SLC52A3","SLC52A2","MATR3","SPG11",
    "ERLIN2","ERLIN1","VCP","AR","ATXN2"
]

genes_tier_c = [
    "DCTN1","TUBA4A","CHCHD10","SQSTM1",
    "HNRNPA2B1","LRP12"
]

genes_tier_d = [
    "ARHGEF28","ARPP21","CAV1","CAV2","CFAP410",
    "DNAJC7","GLE1","GLT8D1","LGALSL","NEFH",
    "NUP50","PRPH","SS18L1","TAF15","TIA1",
    "FIG4","ERBB4","ANG","CCNF","CYLD"
]
# pick which tiers to analyse
genes = genes_tier_a + genes_tier_b + genes_tier_c + genes_tier_d

# Filter canonical transcripts for only those genes
als_canonical_df = canonical_df[canonical_df["gene_name"].isin(genes)]

print(als_canonical_df.shape)
als_canonical_df

(52, 3)


|  | gene_id | gene_name | transcript_id |
|---|---|---|---|
| 13854 | ENSG00000120948 | TARDBP | ENST00000240185 |
| 40542 | ENSG00000204843 | DCTN1 | ENST00000628224 |
| 40673 | ENSG00000003393 | ALS2 | ENST00000264276 |
| 41576 | ENSG00000127824 | TUBA4A | ENST00000248437 |
| 44300 | ENSG00000119862 | LGALSL | ENST00000238875 |
| 45208 | ENSG00000116001 | TIA1 | ENST00000433529 |
| 47579 | ENSG00000178568 | ERBB4 | ENST00000342788 |
| 67707 | ENSG00000083937 | CHMP2B | ENST00000263780 |
| 70310 | ENSG00000172995 | ARPP21 | ENST00000684406 |
| 73853 | ENSG00000016864 | GLT8D1 | ENST00000266014 |
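The filter returns 52 rows for 51 requested gene names, so at least one gene name maps to more than one canonical transcript (for example, if two gene IDs share a name). A quick coverage check, sketched below with a hypothetical helper `canonical_coverage` and toy data, lists requested genes with no canonical transcript and genes with more than one; with the real data it would be called as `canonical_coverage(canonical_df, genes)`.

```python
import pandas as pd

def canonical_coverage(canonical: pd.DataFrame, wanted: list) -> tuple:
    """Return (genes with no canonical transcript, genes with more than one)."""
    sub = canonical[canonical["gene_name"].isin(wanted)]
    counts = sub.groupby("gene_name")["transcript_id"].nunique()
    missing = sorted(set(wanted) - set(counts.index))
    multiple = counts[counts > 1].index.tolist()
    return missing, multiple

# Toy input standing in for canonical_df
toy = pd.DataFrame({
    "gene_name": ["SOD1", "FUS", "FUS", "VCP"],
    "transcript_id": ["T1", "T2", "T3", "T4"],
})
missing, multiple = canonical_coverage(toy, ["SOD1", "FUS", "VCP", "TBK1"])
print(missing, multiple)  # ['TBK1'] ['FUS']
```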


# Build transcript input (transcripts.txt)

Building the transcript input file by extracting the coding sequence (CDS + stop codon) as the ORF, and the `three_prime_utr` features as the 3'UTR, from the genome FASTA and the GTF.
Selecting one canonical transcript per ALS gene (falling back to the transcript with the longest CDS if a gene has none), reconstructing each sequence in transcript (5'→3') orientation, and saving `transcripts.txt` with the ORF, the 3'UTR, their lengths, and the concatenated ORF + 3'UTR.

## Choosing genes of interest

This section generates the core model input file `transcripts.txt`. It contains transcript-level nucleotide sequences of:

- **ORF (Open Reading Frame)**: concatenated CDS + stop codon
- **3' UTR**: untranslated region downstream of stop codon
- **orf_utr3**: concatenation of both

We:
- Parse the Ensembl GTF file to extract transcript structure.
- Use the genome FASTA to fetch actual DNA sequences.
- Select **one canonical transcript per ALS gene**, or fall back to the longest CDS if needed.
- Reverse-complement minus-strand sequences so that every sequence reads 5'→3' along the transcript.

This prepares the input needed for CNN models predicting miRNA-target interactions.
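The strand handling can be shown on a toy example. Like `fetch_concat` in the next cell, the sketch below (hypothetical helper `assemble`) sorts the 1-based, inclusive blocks by genomic start, concatenates them, and reverse-complements the whole string for minus-strand transcripts. The chromosome string and block coordinates are made up: the ORF `ATGAAATAG` is stored on the forward strand as its reverse complement, split across two exons separated by `NNNN`.

```python
def revcomp(s: str) -> str:
    """Reverse complement, using the same translation table as the notebook."""
    return s.translate(str.maketrans("ACGTacgtNn", "TGCAtgcaNn"))[::-1]

def assemble(chrom_seq: str, blocks, strand: str) -> str:
    """Concatenate 1-based inclusive blocks in genomic order; reverse-complement on '-'."""
    seq = "".join(chrom_seq[s - 1:e] for s, e in sorted(blocks))
    return revcomp(seq) if strand == "-" else seq

# Made-up chromosome: exon 1 at 3..7 ("CTATT"), intron "NNNN", exon 2 at 12..15 ("TCAT")
toy_chrom = "GGCTATTNNNNTCATGG"
orf = assemble(toy_chrom, [(12, 15), (3, 7)], "-")  # input block order does not matter
print(orf)  # ATGAAATAG
```

Reverse-complementing the joined string gives the same result as reverse-complementing each exon and reversing the exon order, which is why a single `revcomp` at the end is enough.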

In [None]:
# ==== transcripts.txt generator (CDS + stop_codon for ORF; three_prime_utr for UTR) ====
# Outputs columns: transcript, orf, orf_length, utr3, utr3_length, orf_utr3
# Requires:
#   Homo_sapiens.GRCh38.114.gtf
#   Homo_sapiens.GRCh38.dna.primary_assembly.fa
# Optional:
#   canonical_transcripts.csv  (columns: gene_name, transcript_id)

import os, re, pandas as pd
from Bio import SeqIO

# -------- CONFIG --------
GENES_OF_INTEREST = genes
GTF_PATH  = "Homo_sapiens.GRCh38.114.gtf"
FA_PATH   = "Homo_sapiens.GRCh38.dna.primary_assembly.fa"
CANONICAL_CSV = "canonical_transcripts.csv"  # used if present; else fallback to longest CDS
INCLUDE_STOP_CODON = True   # append stop_codon bases to CDS so ORF ends with TAA/TAG/TGA
# ------------------------

assert os.path.exists(GTF_PATH),  "GTF file not found"
assert os.path.exists(FA_PATH),   "FASTA file not found"

# Load genome once
genome = SeqIO.to_dict(SeqIO.parse(FA_PATH, "fasta"))

def revcomp(s: str) -> str:
    tbl = str.maketrans("ACGTacgtNn", "TGCAtgcaNn")
    return s.translate(tbl)[::-1]

def parse_attrs(attr: str) -> dict:
    d = {}
    for item in str(attr).strip().split(";"):
        item = item.strip()
        if not item:
            continue
        if " " not in item:
            d[item] = ""
            continue
        k, v = item.split(" ", 1)
        d[k] = v.strip('"')
    return d

cols = ["seqname","source","feature","start","end","score","strand","frame","attribute"]
gene_regex = re.compile("|".join([re.escape(g) for g in GENES_OF_INTEREST]))

# Collect per-transcript data
cds_blocks   = {}   # tx -> [(start,end)]
stop_blocks  = {}   # tx -> [(start,end)]
utr3_blocks  = {}   # tx -> [(start,end)]
tx_meta      = {}   # tx -> {"chrom":..., "strand":..., "gene_name":...}
tx_version   = {}   # tx -> transcript_version (string/number as in GTF)

# Stream GTF with a fast pre-filter on the attribute text
for chunk in pd.read_csv(GTF_PATH, sep="\t", comment="#", names=cols,
                         chunksize=300_000, dtype={"seqname": str}, low_memory=False):
    mask = chunk["attribute"].str.contains(gene_regex, na=False)
    chunk = chunk[mask]
    if chunk.empty:
        continue
    attrs = chunk["attribute"].apply(parse_attrs).apply(pd.Series)
    df = pd.concat([chunk.drop(columns=["attribute"]), attrs], axis=1)

    # keep only rows for our genes
    df = df[df["gene_name"].isin(GENES_OF_INTEREST)]
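    # skip chunks with no rows for our genes, or only gene-level rows (no transcript_id column)
    if df.empty or "transcript_id" not in df.columns:
        continue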

    # record transcript meta + version if available from any row
    for r in df[["transcript_id","gene_name","seqname","strand","transcript_version"]].dropna(subset=["transcript_id"]).drop_duplicates().itertuples(index=False):
        tx_meta.setdefault(r.transcript_id, {"chrom": str(r.seqname), "strand": r.strand, "gene_name": r.gene_name})
        if pd.notna(r.transcript_version):
            tx_version.setdefault(r.transcript_id, r.transcript_version)

    # CDS
    for r in df[df["feature"]=="CDS"][["transcript_id","start","end"]].dropna().itertuples(index=False):
        cds_blocks.setdefault(r.transcript_id, []).append((int(r.start), int(r.end)))

    # stop_codon
    for r in df[df["feature"]=="stop_codon"][["transcript_id","start","end"]].dropna().itertuples(index=False):
        stop_blocks.setdefault(r.transcript_id, []).append((int(r.start), int(r.end)))

    # three_prime_utr (case-insensitive)
    is_utr3 = df["feature"].astype(str).str.lower() == "three_prime_utr"
    for r in df[is_utr3][["transcript_id","start","end"]].dropna().itertuples(index=False):
        utr3_blocks.setdefault(r.transcript_id, []).append((int(r.start), int(r.end)))

# Canonical transcript per gene
if os.path.exists(CANONICAL_CSV):
    can = pd.read_csv(CANONICAL_CSV)[["gene_name","transcript_id"]].drop_duplicates()
    can = can[can["gene_name"].isin(GENES_OF_INTEREST)]
else:
    can = pd.DataFrame(columns=["gene_name","transcript_id"])

present_genes = set(can["gene_name"])
missing = [g for g in GENES_OF_INTEREST if g not in present_genes]

# Fallback: pick tx with longest CDS for any missing gene
if missing:
    cds_len = {tx: sum(e-s+1 for s, e in blocks) for tx, blocks in cds_blocks.items()}
    tx2gene = {tx: meta["gene_name"] for tx, meta in tx_meta.items()}
    rows = []
    for g in missing:
        candidates = [tx for tx in cds_blocks.keys() if tx2gene.get(tx)==g]
        if not candidates:
            continue
        best = max(candidates, key=lambda t: cds_len.get(t, 0))
        rows.append({"gene_name": g, "transcript_id": best})
    if rows:
        can = pd.concat([can, pd.DataFrame(rows)], ignore_index=True)

# Ensure one transcript per gene
can = can.drop_duplicates(subset=["gene_name"], keep="first").reset_index(drop=True)

def fetch_concat(chrom, blocks, strand):
    if not blocks:
        return ""
    blocks = sorted(blocks, key=lambda x: x[0])
    seq = "".join(str(genome[chrom].seq[s-1:e]) for s, e in blocks)  # 1-based inclusive -> 0-based slice
    return revcomp(seq) if strand == "-" else seq

rows = []
for r in can.itertuples(index=False):
    txid = r.transcript_id
    meta = tx_meta.get(txid)
    if not meta:
        continue
    chrom, strand = meta["chrom"], meta["strand"]

    # ORF = CDS (+ optional stop_codon)
    cds_seq = fetch_concat(chrom, cds_blocks.get(txid, []), strand)
    if not cds_seq:
        continue
    stop_seq = fetch_concat(chrom, stop_blocks.get(txid, []), strand) if INCLUDE_STOP_CODON else ""
    orf_seq = cds_seq + stop_seq
    orf_len = len(orf_seq)

    # UTR3
    utr_seq = fetch_concat(chrom, utr3_blocks.get(txid, []), strand)
    utr_len = len(utr_seq)

    # transcript with version (if known)
    ver = tx_version.get(txid, None)
    if ver is not None:
        try:
            ver_int = int(ver)
            tx_with_ver = f"{txid}.{ver_int}"
        except Exception:
            tx_with_ver = f"{txid}.{ver}"
    else:
        tx_with_ver = txid

    rows.append([tx_with_ver, orf_seq, orf_len, utr_seq, utr_len, orf_seq + utr_seq])

# Final output in CNN format
out = pd.DataFrame(rows, columns=["transcript","orf","orf_length","utr3","utr3_length","orf_utr3"])
out.to_csv("transcripts.txt", sep="\t", index=False)
print(f"Saved transcripts.txt with {len(out)} transcripts for: {', '.join(can['gene_name'])}")
display(out.head(6))

Saved transcripts.txt with 51 transcripts for: TARDBP, DCTN1, ALS2, TUBA4A, LGALSL, TIA1, ERBB4, CHMP2B, ARPP21, GLT8D1, NEK1, MATR3, ARHGEF28, SQSTM1, FIG4, CAV2, CAV1, HNRNPA2B1, UBQLN2, AR, SLC52A2, LRP12, ERLIN2, VCP, GLE1, C9orf72, ANXA11, OPTN, ERLIN1, ATXN2, KIF5A, TBK1, PRPH, ANG, SPTLC2, SPG11, CCNF, CYLD, FUS, DNAJC7, TAF15, PFN1, GRN, SS18L1, SLC52A3, VAPB, NEFH, CHCHD10, NUP50, CFAP410, SOD1


|  | transcript | orf | orf_length | utr3 | utr3_length | orf_utr3 |
|---|---|---|---|---|---|---|
| 0 | ENST00000240185.8 | ATGTCTGAATATATTCGGGTAACCGAAGATGAGAACGATGAGCCCA... | 1245 | ACAGTGGGGTTGTGGTTGGTTGGTATAGAATGGTGGGAATTCAAAT... | 2838 | ATGTCTGAATATATTCGGGTAACCGAAGATGAGAACGATGAGCCCA... |
| 1 | ENST00000628224.3 | ATGGCACAGAGCAAGAGGCACGTGTACAGCCGGACGCCCAGCGGCA... | 3837 | GCACTCCTTTCCCCTGCTGTCCCCTTCGACCCTCAGCCCTCTGGTG... | 344 | ATGGCACAGAGCAAGAGGCACGTGTACAGCCGGACGCCCAGCGGCA... |
| 2 | ENST00000264276.11 | ATGGACTCAAAGAAGAGAAGCTCAACAGAGGCAGAAGGATCCAAGG... | 4974 | GCTGCATAACAGCTTGAAAACTGGATTATCTACTACAGAGTGTTAT... | 1584 | ATGGACTCAAAGAAGAGAAGCTCAACAGAGGCAGAAGGATCCAAGG... |
| 3 | ENST00000248437.9 | ATGCGTGAATGCATCTCAGTCCACGTGGGGCAGGCAGGTGTCCAGA... | 1347 | AGCAGCTGCCTGGAGCCTATTCACTATGTTTATTGCAAAATCCTTT... | 642 | ATGCGTGAATGCATCTCAGTCCACGTGGGGCAGGCAGGTGTCCAGA... |
| 4 | ENST00000238875.10 | ATGGCGGGATCAGTGGCCGACAGCGATGCCGTGGTGAAACTAGATG... | 519 | TTTAAACCACCTCTATTTCAAATAGGATCACGTGCCACAACTATCT... | 2953 | ATGGCGGGATCAGTGGCCGACAGCGATGCCGTGGTGAAACTAGATG... |
| 5 | ENST00000433529.7 | ATGGAGGACGAGATGCCCAAGACTCTATACGTCGGTAACCTTTCCA... | 1161 | ATAAGGACTCCAGAATCTAAAGCCAGTGGCTTGAGGCTACAGGGAG... | 3275 | ATGGAGGACGAGATGCCCAAGACTCTATACGTCGGTAACCTTTCCA... |


## Sanity check

Verifying that each ORF/UTR3 we built is biologically consistent and correctly recorded:

- Loading `transcripts.txt`.
- Normalizing sequences (uppercasing and stripping non-ACGTN characters).
- Checking ORF properties: length is a multiple of 3; starts with ATG; ends with a stop codon (TAA/TAG/TGA); no internal stops; translation ends exactly at the terminal stop.
- Checking bookkeeping: recorded lengths match the actual string lengths; the concatenated `orf_utr3` equals `orf + utr3`.
- Summarizing pass/fail per transcript and listing any failures for quick debugging.

In [None]:
import pandas as pd
from Bio.Seq import Seq

df = pd.read_csv("transcripts.txt", sep="\t")

# Basic normalizations
df["orf"] = df["orf"].str.upper().str.replace(r"[^ACGTN]", "", regex=True)
df["utr3"] = df["utr3"].str.upper().str.replace(r"[^ACGTN]", "", regex=True)

# Helpers
STOP_CODONS = {"TAA", "TAG", "TGA"}

def ends_with_stop(s: str) -> bool:
    return len(s) >= 3 and s[-3:] in STOP_CODONS

def has_internal_stop(s: str) -> bool:
    if len(s) < 6:  # shorter than two codons can't have an internal stop
        return False
    # scan all codons except the last one
    return any(s[i:i+3] in STOP_CODONS for i in range(0, len(s)-3, 3))

def trans_ok(s: str):
    """Translate and confirm only the last AA is stop."""
    if len(s) % 3 != 0 or len(s) == 0:
        return False
    try:
        p = Seq(s).translate(to_stop=False)  # includes '*' if final codon is stop
        return (len(p) > 0) and (p[-1] == "*") and ("*" not in str(p[:-1]))
    except Exception:
        return False

# Checks
checks = pd.DataFrame({
    "transcript": df["transcript"],
    # length bookkeeping
    "len_matches_field": (df["orf"].str.len() == df["orf_length"]) & (df["utr3"].str.len() == df["utr3_length"]),
    "orf_mod3": (df["orf_length"] % 3 == 0),
    "orf_starts_ATG": df["orf"].str.startswith("ATG"),
    "orf_ends_stop": df["orf"].apply(ends_with_stop),
    "orf_has_internal_stop": df["orf"].apply(has_internal_stop),
    "translate_ok": df["orf"].apply(trans_ok),
    "concat_ok": (df["orf_utr3"] == (df["orf"] + df["utr3"])),
})

# Summary
total = len(checks)
fail_any = checks.assign(
    ok = checks["len_matches_field"] &
         checks["orf_mod3"] &
         checks["orf_starts_ATG"] &
         checks["orf_ends_stop"] &
         (~checks["orf_has_internal_stop"]) &
         checks["translate_ok"] &
         checks["concat_ok"]
)
n_ok = fail_any["ok"].sum()

print(f"Passed all checks: {n_ok}/{total}")
print("\nProblems (if any):")
display(fail_any.loc[~fail_any["ok"]])

# Optional: show exactly which conditions failed per transcript
bad = fail_any.loc[~fail_any["ok"], ["transcript","len_matches_field","orf_mod3","orf_starts_ATG",
                                     "orf_ends_stop","orf_has_internal_stop","translate_ok","concat_ok"]]
if not bad.empty:
    print("\nDetailed failures:")
    display(bad)

Passed all checks: 51/51

Problems (if any):


|  | transcript | len_matches_field | orf_mod3 | orf_starts_ATG | orf_ends_stop | orf_has_internal_stop | translate_ok | concat_ok | ok |
|---|---|---|---|---|---|---|---|---|---|


### Optional: Cross-check ORFs/UTRs against Ensembl REST API

You can validate that the extracted `ORF` and `UTR3` sequences match Ensembl's reference transcripts using this separate notebook:

👉 [Ensembl REST Validation Notebook](https://colab.research.google.com/drive/1MUrefOg_XwKjq4vpNi0ZIVs0_i7MF2mV#scrollTo=mnL1EwXUQ7tZ)

This notebook:
- Fetches reference CDS/cDNA from Ensembl
- Compares with your `transcripts.txt` output
- Flags any mismatches (useful for debugging or pipeline validation)
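For a quick in-notebook spot check, a sketch along the lines below could fetch reference sequences from the Ensembl REST API (`GET /sequence/id/:id` with `type=cds` or `type=cdna`) and compare them with one row of `transcripts.txt`. This is not the linked notebook's code; `fetch_ensembl_seq` and `compare_to_reference` are hypothetical helpers. It assumes that Ensembl's `cds` sequence includes the stop codon (as the ORF here does) and that the 3'UTR is everything in the cDNA after the CDS. The REST server serves the current Ensembl release, which may be newer than 114, so an occasional mismatch can reflect an annotation update rather than a pipeline error.

```python
import json
import urllib.request

REST_SERVER = "https://rest.ensembl.org"

def fetch_ensembl_seq(transcript: str, seq_type: str) -> str:
    """Fetch a transcript sequence ('cds' or 'cdna') from the Ensembl REST API.
    The version suffix is dropped, so the server's current annotation is returned."""
    tx_id = transcript.split(".")[0]
    url = f"{REST_SERVER}/sequence/id/{tx_id}?type={seq_type}"
    req = urllib.request.Request(url, headers={"Content-Type": "application/json"})
    with urllib.request.urlopen(req, timeout=30) as resp:
        return json.loads(resp.read().decode())["seq"]

def compare_to_reference(orf: str, utr3: str, ref_cds: str, ref_cdna: str) -> dict:
    """Compare a locally built ORF and 3'UTR with reference CDS and cDNA strings.
    Assumes ref_cds includes the stop codon and ref_cdna = 5'UTR + CDS + 3'UTR."""
    pos = ref_cdna.find(ref_cds)
    ref_utr3 = ref_cdna[pos + len(ref_cds):] if pos >= 0 else None
    return {"orf_match": orf == ref_cds, "utr3_match": utr3 == ref_utr3}

# Usage (requires network access):
# import pandas as pd
# row = pd.read_csv("transcripts.txt", sep="\t").iloc[0]
# print(compare_to_reference(row["orf"], row["utr3"],
#                            fetch_ensembl_seq(row["transcript"], "cds"),
#                            fetch_ensembl_seq(row["transcript"], "cdna")))
```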


# Canonical transcript anatomy

In this section, we build structured representations of each canonical transcript's **CDS** and **3'UTR** genomic intervals — needed for downstream mutation mapping, variant scoring, or visualization.

This step includes:

- Selecting one **canonical transcript** per ALS gene (from file or fallback)
- Collecting per-transcript **CDS** and **3'UTR** coordinates from the GTF
- Sorting blocks in transcript order (reverse if on minus strand)
- Calculating **relative positions** within each concatenated CDS or 3'UTR region (1-based cumulative offsets)
- Saving 3 output CSVs:
  - `canonical_meta.csv`: basic transcript metadata (gene, ID, version, strand, chrom)
  - `canonical_blocks.csv`: CDS/UTR lengths and block summaries
  - `canonical_blockmap.csv`: per-block genomic coordinates with transcript-relative offsets
- Exporting BED files for downstream lookup (e.g. with `bedtools` or for visualization)

These outputs define the **transcript structure maps** used to interpret variant positions in the CNN model.
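The cumulative offsets are what make variant interpretation possible: a genomic position inside a block maps to a 1-based position within the concatenated CDS or 3'UTR. The sketch below (hypothetical helper `genomic_to_region_offset`, assuming non-overlapping blocks from a single region of one transcript) walks the blocks in transcript order, as `blockmap` does. With the first two GRN CDS blocks from the output further down, position 44349426 maps to offset 139, matching that block's `tx_cum_start`. Because Ensembl GTF `CDS` features exclude the stop codon, positions inside the stop codon fall in neither region (e.g. ANG: the CDS ends at 20694005 and the 3'UTR starts at 20694009).

```python
def genomic_to_region_offset(pos, blocks, strand):
    """Map a 1-based genomic position to a 1-based offset within one spliced region.

    blocks: (start, end) genomic intervals, 1-based and inclusive, for one region
    (e.g. the CDS) of one transcript. Returns None if pos lies outside every block.
    """
    # transcript order: ascending start on '+', descending start on '-'
    ordered = sorted(blocks, reverse=(strand == "-"))
    cum = 0  # region bases already passed, in transcript order
    for start, end in ordered:
        if start <= pos <= end:
            # distance from the block's 5' end, measured in transcript orientation
            within = pos - start if strand == "+" else end - pos
            return cum + within + 1
        cum += end - start + 1
    return None

grn_cds = [(44349165, 44349302), (44349426, 44349551)]  # first two GRN CDS blocks (+ strand)
print(genomic_to_region_offset(44349426, grn_cds, "+"))  # 139
```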

In [None]:
# Build canonical transcript anatomy (no VCF required).
# Outputs:
#   - canonical_meta.csv
#   - canonical_blocks.csv
#   - canonical_blockmap.csv
#   - gene_spans.bed, utr3_regions.bed, cds_regions.bed

import os, re, pandas as pd

# ==== CONFIG ====
GENES = genes
GTF   = "Homo_sapiens.GRCh38.114.gtf"
CANON = "canonical_transcripts.csv"             # at least transcript_id; gene_name optional

assert os.path.exists(GTF),   "Upload Homo_sapiens.GRCh38.114.gtf"
assert os.path.exists(CANON), "Upload canonical_transcripts.csv"

# ==== Helpers ====
gtf_cols = ["seqname","source","feature","start","end","score","strand","frame","attribute"]

def parse_attrs(attr):
    d = {}
    for item in str(attr).strip().split(";"):
        item = item.strip()
        if not item:
            continue
        if " " not in item:
            d[item] = ""
        else:
            k, v = item.split(" ", 1)
            d[k] = v.strip('"')
    return d

def read_mini_gtf(gtf_path, genes):
    """Stream‑filter the GTF to only rows for the selected gene names (saves RAM)."""
    parts = []
    gene_pat = re.compile("|".join(re.escape(g) for g in genes))
    for chunk in pd.read_csv(gtf_path, sep="\t", comment="#", names=gtf_cols,
                             chunksize=300_000, dtype={"seqname": str}, low_memory=False):
        attrs = chunk["attribute"].apply(parse_attrs).apply(pd.Series)
        merged = pd.concat([chunk.drop(columns=["attribute"]), attrs], axis=1)
        # Keep only rows whose gene_name matches our list
        keep = merged["gene_name"].astype(str).str.fullmatch(gene_pat, na=False)
        m2 = merged[keep]
        if not m2.empty:
            parts.append(m2)
    if not parts:
        raise ValueError("No rows found for your genes in the provided GTF.")
    mini = pd.concat(parts, ignore_index=True)
    # Ensure numeric
    mini["start"] = mini["start"].astype(int)
    mini["end"]   = mini["end"].astype(int)
    return mini

mini = read_mini_gtf(GTF, GENES)

# ==== Canonical list: tolerate minimal CSVs ====
canon = pd.read_csv(CANON)
# Ensure at least transcript_id exists
assert "transcript_id" in canon.columns, "canonical_transcripts.csv must contain a 'transcript_id' column"

# If gene_name missing, fetch from GTF (transcript rows)
if "gene_name" not in canon.columns:
    tx_gene_map = (mini[mini["feature"]=="transcript"][["transcript_id","gene_name"]]
                   .drop_duplicates())
    canon = canon.merge(tx_gene_map, on="transcript_id", how="left")

# Filter to our target genes
canon = canon[canon["gene_name"].isin(GENES)].drop_duplicates(subset=["gene_name","transcript_id"])

# Attach transcript_version / chrom / strand from GTF transcript lines
tx_meta = (mini[mini["feature"]=="transcript"]
           [["transcript_id","transcript_version","seqname","strand","gene_name"]]
           .drop_duplicates())
canon = canon.merge(tx_meta, on=["transcript_id","gene_name"], how="left")

# If any gene is still missing a canonical transcript, pick the one with the longest CDS
missing = [g for g in GENES if g not in set(canon["gene_name"])]
if missing:
    cds_only = mini[(mini["feature"]=="CDS") & (mini["gene_name"].isin(missing))]
    if not cds_only.empty:
        cds_len = (cds_only.assign(len=lambda d: d["end"]-d["start"]+1)
                   .groupby(["gene_name","transcript_id"], as_index=False)["len"].sum())
        pick = cds_len.sort_values(["gene_name","len"], ascending=[True,False]).drop_duplicates("gene_name")
        # bring meta
        pick = pick.merge(tx_meta, on=["transcript_id","gene_name"], how="left")
        canon = pd.concat([canon, pick[canon.columns]], ignore_index=True)

# Make versioned transcript label
canon["transcript"] = canon.apply(
    lambda r: f"{r.transcript_id}.{int(r.transcript_version)}"
              if pd.notna(r.transcript_version) else r.transcript_id,
    axis=1
)

# Keep one transcript per gene (first if duplicates)
canon = canon.sort_values(["gene_name","transcript"]).drop_duplicates("gene_name", keep="first").reset_index(drop=True)

keep_tx = set(canon["transcript_id"])

# ==== Collect CDS and 3'UTR blocks for canonical transcripts ====
sub = mini[mini["transcript_id"].isin(keep_tx)].copy()

def blocks_for(df, featname_lower):
    t = df[df["feature"].str.lower()==featname_lower][["transcript_id","seqname","strand","start","end"]].copy()
    t["start"] = t["start"].astype(int)
    t["end"]   = t["end"].astype(int)
    # Genomic order; we'll flip to transcript order using strand later
    return t.sort_values(["transcript_id","start","end"]).reset_index(drop=True)

cds_blocks  = blocks_for(sub, "cds")
utr3_blocks = blocks_for(sub, "three_prime_utr")

# ==== Build per-block transcript-order maps with cumulative offsets ====
meta = canon.set_index("transcript_id")[["transcript","gene_name","seqname","strand"]].to_dict(orient="index")

def blockmap(df, region_label):
    rows = []
    for tx, g in df.groupby("transcript_id", sort=False):
        m = meta.get(tx, {"transcript":tx,"gene_name":None,"seqname":None,"strand":"+"})
        # transcript order: + strand = ascending start; - strand = reverse
        g_ord = g.sort_values("start", ascending=True)
        if m["strand"] == "-":
            g_ord = g_ord.iloc[::-1].copy()
        cum = 0
        for i, r in enumerate(g_ord.itertuples(index=False), start=1):
            length = int(r.end) - int(r.start) + 1
            rows.append({
                "gene_name": m["gene_name"],
                "transcript": m["transcript"],
                "transcript_id": tx,
                "region": region_label,
                "block_index_tx_order": i,
                "seqname": str(r.seqname),
                "strand": m["strand"],
                "genomic_start": int(r.start),
                "genomic_end":   int(r.end),
                "block_len":     length,
                "tx_cum_start":  cum + 1,           # 1-based within concatenated region
                "tx_cum_end":    cum + length       # inclusive
            })
            cum += length
    return pd.DataFrame(rows)

cds_map  = blockmap(cds_blocks,  "CDS")
utr3_map = blockmap(utr3_blocks, "three_prime_utr")
blockmap_all = pd.concat([cds_map, utr3_map], ignore_index=True)

# ==== Summaries + BEDs ====
canonical_meta = (canon[["gene_name","transcript","transcript_id","seqname","strand"]]
                  .drop_duplicates().sort_values("gene_name"))
canonical_meta.to_csv("canonical_meta.csv", index=False)

def region_lengths(blockmap_df):
    return (blockmap_df.groupby(["transcript_id","region"], as_index=False)
            .agg(region_len=("block_len","sum")))

lens = pd.concat([region_lengths(cds_map), region_lengths(utr3_map)], ignore_index=True)
lens_wide = (lens.pivot(index="transcript_id", columns="region", values="region_len")
                 .reset_index().fillna(0).rename_axis(None, axis=1)
                 .rename(columns={"CDS":"CDS_len","three_prime_utr":"UTR3_len"}))

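def blocks_string(df):
    # "start-end;start-end;..." per transcript, blocks in transcript order.
    # Grouping a Series (instead of DataFrameGroupBy.apply) avoids the pandas
    # deprecation warning about apply operating on the grouping columns.
    d = df.sort_values(["transcript_id","block_index_tx_order"])
    spans = d["genomic_start"].astype(str) + "-" + d["genomic_end"].astype(str)
    return spans.groupby(d["transcript_id"]).agg(";".join).reset_index(name="blocks")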

cds_blocks_str  = blocks_string(cds_map).rename(columns={"blocks":"CDS_blocks"})
utr3_blocks_str = blocks_string(utr3_map).rename(columns={"blocks":"UTR3_blocks"})

canonical_blocks = (canonical_meta.merge(lens_wide, on="transcript_id", how="left")
                    .merge(cds_blocks_str, on="transcript_id", how="left")
                    .merge(utr3_blocks_str, on="transcript_id", how="left"))
canonical_blocks.to_csv("canonical_blocks.csv", index=False)

blockmap_all.to_csv("canonical_blockmap.csv", index=False)

# BED is 0-based and half-open, so each 1-based GTF start is shifted by -1
# (the 1-based inclusive GTF end is already the correct BED end)

# gene spans BED (for quick VCF subsetting later)
gene_spans = (mini[mini["feature"]=="gene"][["seqname","start","end","gene_name"]]
              .drop_duplicates().sort_values(["seqname","start"])
              .assign(start=lambda d: d["start"] - 1))
gene_spans.to_csv("gene_spans.bed", sep="\t", header=False, index=False)

# UTR3 & CDS BEDs for canonical only
utr3_bed = (utr3_blocks[utr3_blocks["transcript_id"].isin(keep_tx)]
            [["seqname","start","end"]].sort_values(["seqname","start"])
            .assign(start=lambda d: d["start"] - 1))
utr3_bed.to_csv("utr3_regions.bed", sep="\t", header=False, index=False)

cds_bed = (cds_blocks[cds_blocks["transcript_id"].isin(keep_tx)]
           [["seqname","start","end"]].sort_values(["seqname","start"])
           .assign(start=lambda d: d["start"] - 1))
cds_bed.to_csv("cds_regions.bed", sep="\t", header=False, index=False)

# ==== Preview ====
print("Saved files:")
print(" - canonical_meta.csv")
print(" - canonical_blocks.csv")
print(" - canonical_blockmap.csv")
print(" - gene_spans.bed, utr3_regions.bed, cds_regions.bed")

print("\ncanonical_meta.csv")
display(pd.read_csv("canonical_meta.csv").head())

print("\ncanonical_blocks.csv")
display(pd.read_csv("canonical_blocks.csv").head())

print("\ncanonical_blockmap.csv")
display(pd.read_csv("canonical_blockmap.csv").head())

Saved files:
 - canonical_meta.csv
 - canonical_blocks.csv
 - canonical_blockmap.csv
 - gene_spans.bed, utr3_regions.bed, cds_regions.bed

canonical_meta.csv




|  | gene_name | transcript | transcript_id | seqname | strand |
|---|---|---|---|---|---|
| 0 | ALS2 | ENST00000264276.11 | ENST00000264276 | 2 | - |
| 1 | ANG | ENST00000397990.5 | ENST00000397990 | 14 | + |
| 2 | ANXA11 | ENST00000422982.8 | ENST00000422982 | 10 | - |
| 3 | AR | ENST00000374690.9 | ENST00000374690 | X | + |
| 4 | ARHGEF28 | ENST00000513042.7 | ENST00000513042 | 5 | + |



canonical_blocks.csv


|  | gene_name | transcript | transcript_id | seqname | strand | CDS_len | UTR3_len | CDS_blocks | UTR3_blocks |
|---|---|---|---|---|---|---|---|---|---|
| 0 | ALS2 | ENST00000264276.11 | ENST00000264276 | 2 | - | 4971 | 1584 | 201768866-201768885;201767229-201767383;201760... | 201700267-201701850 |
| 1 | ANG | ENST00000397990.5 | ENST00000397990 | 14 | + | 441 | 175 | 20693565-20694005 | 20694009-20694183 |
| 2 | ANXA11 | ENST00000422982.8 | ENST00000422982 | 10 | - | 1515 | 4964 | 80172807-80172861;80170800-80170915;80168969-8... | 80150889-80155852 |
| 3 | AR | ENST00000374690.9 | ENST00000374690 | X | + | 2760 | 6778 | 67545147-67546762;67643256-67643407;67686010-6... | 67723842-67730619 |
| 4 | ARHGEF28 | ENST00000513042.7 | ENST00000513042 | 5 | + | 5115 | 977 | 73684852-73684884;73749837-73749984;73752909-7... | 73941014-73941990 |



canonical_blockmap.csv


Unnamed: 0,gene_name,transcript,transcript_id,region,block_index_tx_order,seqname,strand,genomic_start,genomic_end,block_len,tx_cum_start,tx_cum_end
0,GRN,ENST00000053867.8,ENST00000053867,CDS,1,17,+,44349165,44349302,138,1,138
1,GRN,ENST00000053867.8,ENST00000053867,CDS,2,17,+,44349426,44349551,126,139,264
2,GRN,ENST00000053867.8,ENST00000053867,CDS,3,17,+,44349667,44349751,85,265,349
3,GRN,ENST00000053867.8,ENST00000053867,CDS,4,17,+,44350228,44350340,113,350,462
4,GRN,ENST00000053867.8,ENST00000053867,CDS,5,17,+,44350442,44350577,136,463,598


In [None]:
import pandas as pd

bm = pd.read_csv("canonical_blockmap.csv")

# restrict only to your transcripts (if needed)
# bm = bm[bm["transcript"].isin(your_list_of_transcripts)]

# Count how many blocks of each region type exist
region_counts = bm["region"].value_counts()
print("Region counts:\n", region_counts)

# Count distinct transcripts per gene and region
per_gene = bm.groupby(["gene_name", "region"])["transcript"].nunique().unstack(fill_value=0)
print("\nPer gene transcript counts:\n", per_gene)

Region counts:
 region
CDS                673
three_prime_utr     53
Name: count, dtype: int64

Per gene transcript counts:
 region     CDS  three_prime_utr
gene_name                      
ALS2         1                1
ANG          1                1
ANXA11       1                1
AR           1                1
ARHGEF28     1                1
ARPP21       1                1
ATXN2        1                1
C9orf72      1                1
CAV1         1                1
CAV2         1                1
CCNF         1                1
CFAP410      1                1
CHCHD10      1                1
CHMP2B       1                1
CYLD         1                1
DCTN1        1                1
DNAJC7       1                1
ERBB4        1                1
ERLIN1       1                1
ERLIN2       1                1
FIG4         1                1
FUS          1                1
GLE1         1                1
GLT8D1       1                1
GRN          1                1
HNRNPA2B1  

# Variant integration

## Input for the CNN model

This step loads the canonical blockmap of transcripts, keeps only the CDS and 3′UTR blocks, and builds one interval tree per chromosome. The ALS VCF file is then streamed with cyvcf2, and every variant whose REF span overlaps a CDS or UTR3 block is written to a hits table. The result, overlap_hits.csv, lists which variants overlap which canonical transcripts and regions.
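The overlap rule used in the cell below can be sketched without intervaltree. This minimal sketch assumes, as the cell's comments state, that blockmap coordinates are 1-based inclusive and VCF POS is 1-based; both are converted to 0-based half-open intervals before testing for overlap. The helper names are hypothetical, and the example blocks reuse the ANG CDS/UTR3 coordinates from canonical_blocks.csv with made-up alleles.

```python
# Sketch of the CDS/UTR3 overlap test in plain Python (no intervaltree).

def to_half_open(start1: int, end1: int) -> tuple[int, int]:
    """1-based inclusive [start1, end1] -> 0-based half-open [start0, end0)."""
    return start1 - 1, end1

def variant_span(pos1: int, ref: str) -> tuple[int, int]:
    """0-based half-open span covered by the VCF REF allele (at least 1 bp)."""
    pos0 = pos1 - 1
    return pos0, pos0 + max(1, len(ref))

def overlaps(a: tuple[int, int], b: tuple[int, int]) -> bool:
    """True if two half-open intervals share at least one base."""
    return a[0] < b[1] and b[0] < a[1]

def hits_for_variant(blocks, pos1, ref):
    """(transcript, region) labels of all blocks overlapped by the variant.

    blocks: (genomic_start, genomic_end, transcript, region), 1-based inclusive,
    all on the variant's chromosome.
    """
    span = variant_span(pos1, ref)
    return [(tx, region) for s, e, tx, region in blocks
            if overlaps(to_half_open(s, e), span)]

# ANG blocks; the 3-bp gap between them is not covered by either block.
blocks = [(20693565, 20694005, "ENST00000397990.5", "CDS"),
          (20694009, 20694183, "ENST00000397990.5", "UTR3")]
print(hits_for_variant(blocks, 20694005, "A"))   # [('ENST00000397990.5', 'CDS')]
```

A 4-bp REF starting at 20694006 would reach into the UTR3 block and be reported there, which mirrors how the scan below uses `len(ref)` to catch indels whose anchor base lies just outside a block.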

In [None]:
import csv, re
import pandas as pd
from intervaltree import Interval, IntervalTree
from cyvcf2 import VCF

VCF_PATH = "ALL.hg38.vcf.gz"   # input variant file (GRCh38 VCF); replace with your own
BLOCKMAP = "canonical_blockmap.csv"
OUT_CSV  = "overlap_hits.csv"

# 1) Load CDS/UTR3 regions and build per-chrom interval trees
bm = pd.read_csv(BLOCKMAP)
# normalize region labels
bm["region"] = (
    bm["region"].astype(str).str.strip().str.lower()
      .replace({"three_prime_utr":"utr3", "3utr":"utr3"})
      .replace({"cds":"CDS", "utr3":"UTR3"})
)
# now keep CDS + UTR3
regions = bm[bm["region"].isin(["CDS","UTR3"])][
    ["seqname","genomic_start","genomic_end","transcript","region"]
].copy()

def norm_chrom(x: str) -> str:
    x = str(x)
    return x[3:] if x.startswith("chr") else x

regions["seqname"] = regions["seqname"].map(norm_chrom)
trees = {}
for chrom, g in regions.groupby("seqname", sort=False):
    it = IntervalTree()
    # convert to 0-based half-open for tree queries
    for _, r in g.iterrows():
        s = int(r["genomic_start"]) - 1
        e = int(r["genomic_end"])     # inclusive in CSV -> half-open end
        it.addi(s, e, (r["transcript"], r["region"]))
    trees[chrom] = it

# 2) Stream the VCF and write overlapped hits incrementally
def is_dna(a):
    return bool(re.fullmatch(r"[ACGTacgt]+", a))

n, n_hit = 0, 0
with open(OUT_CSV, "w", newline="") as fh:
    w = csv.writer(fh)
    w.writerow(["chrom","pos","ref","alt","transcript","region"])
    for rec in VCF(VCF_PATH):
        n += 1
        chrom = norm_chrom(rec.CHROM)
        pos1  = int(rec.POS)               # 1-based
        pos0  = pos1 - 1                   # 0-based
        ref   = rec.REF
        alts  = [a for a in (rec.ALT or []) if a and is_dna(a)]
        if not alts:
            continue
        tree = trees.get(chrom)
        if not tree:
            continue
        # SNVs and left-normalized indels: use interval [pos0, pos0+len(ref))
        for iv in tree.overlap(pos0, pos0 + max(1, len(ref))):
            tx, region = iv.data
            for alt in alts:
                w.writerow([chrom, pos1, ref, alt, tx, region])
                n_hit += 1

print(f"Scanned {n:,} VCF records; wrote {n_hit:,} overlaps to {OUT_CSV}")

Scanned 167,751,855 VCF records; wrote 9,369 overlaps to overlap_hits.csv


In [None]:
import pandas as pd
hits = pd.read_csv("overlap_hits.csv")  # columns: chrom,pos,ref,alt,transcript,region

print("rows:", len(hits))
print("unique variant sites:", hits[['chrom','pos','ref','alt']].drop_duplicates().shape[0])
print("unique transcripts hit:", hits['transcript'].nunique())
print("by region:\n", hits['region'].value_counts())

rows: 9369
unique variant sites: 9368
unique transcripts hit: 49
by region:
 region
UTR3    4851
CDS     4518
Name: count, dtype: int64
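The summary shows one more row than unique variant sites, so at least one (chrom, pos, ref, alt) key appears twice. The small stdlib sketch below (hypothetical helper `repeated_sites`) groups rows by site so the repeated key and the transcripts it was written for can be inspected; it assumes the column order chrom, pos, ref, alt, transcript, region used by overlap_hits.csv.

```python
from collections import defaultdict

def repeated_sites(rows):
    """Group overlap rows by variant site and return only sites seen more than once.

    rows: iterable of (chrom, pos, ref, alt, transcript, region) tuples.
    Returns {(chrom, pos, ref, alt): [(transcript, region), ...]}.
    """
    by_site = defaultdict(list)
    for chrom, pos, ref, alt, tx, region in rows:
        by_site[(str(chrom), int(pos), ref, alt)].append((tx, region))
    return {site: hits for site, hits in by_site.items() if len(hits) > 1}
```

With the table loaded above, `repeated_sites(hits.itertuples(index=False, name=None))` would list the duplicated site; whether it comes from two transcripts sharing a locus or from a repeated VCF record is something to check rather than assume.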


In [None]:
tx_counts = hits.groupby('transcript').size().sort_values(ascending=False).head(20)
print(tx_counts)

transcript
ENST00000342788.9     546
ENST00000475243.6     421
ENST00000216484.7     376
ENST00000261866.12    371
ENST00000422982.8     360
ENST00000347635.9     342
ENST00000331758.8     313
ENST00000310624.7     297
ENST00000513042.7     296
ENST00000427738.8     274
ENST00000507142.6     265
ENST00000673436.1     260
ENST00000519638.3     254
ENST00000397066.9     247
ENST00000264276.11    246
ENST00000618183.5     239
ENST00000455537.7     223
ENST00000389805.9     210
ENST00000339818.9     202
ENST00000433529.7     199
dtype: int64


In [None]:
import pandas as pd

# files (adjust paths if needed)
hits   = pd.read_csv("overlap_hits.csv")                  # chrom,pos,ref,alt,transcript,region
bm     = pd.read_csv("canonical_blockmap.csv")            # must include: gene_name, transcript, region, ...

# combine the ALS gene tiers (genes_tier_a ... genes_tier_d must be defined in an earlier cell)
tiers  = set(genes_tier_a + genes_tier_b + genes_tier_c + genes_tier_d)

# reduce blockmap to transcript->gene mapping
tx2gene = bm[["transcript","gene_name"]].drop_duplicates()

# join hits to gene
hits_g = hits.merge(tx2gene, on="transcript", how="left")

# coverage among Tier A+B+C+D genes
tier_hits   = hits_g[hits_g["gene_name"].isin(tiers)].copy()
genes_hit   = sorted(tier_hits["gene_name"].dropna().unique())
genes_miss  = sorted(list(tiers - set(genes_hit)))

print(f"Total overlaps: {len(hits_g):,}")
print(f"Unique transcripts hit: {hits_g['transcript'].nunique()}")
print(f"Tier A+B+C+D genes with ≥1 overlapping variant: {len(genes_hit)} / {len(tiers)}")
print("Covered genes:", ", ".join(genes_hit) if genes_hit else "—")
print("Missing genes:", ", ".join(genes_miss) if genes_miss else "—")

# quick per-gene counts (how many overlapped variants per gene)
counts = (
    tier_hits
    .groupby("gene_name", as_index=False)
    .size()
    .sort_values("size", ascending=False)
)
counts.rename(columns={"size":"n_overlaps"}, inplace=True)
display(counts)

Total overlaps: 9,369
Unique transcripts hit: 49
Tier A+B+C+D genes with ≥1 overlapping variant: 49 / 51
Covered genes: ALS2, ANG, ANXA11, ARHGEF28, ARPP21, ATXN2, C9orf72, CAV1, CAV2, CCNF, CFAP410, CHCHD10, CHMP2B, CYLD, DCTN1, DNAJC7, ERBB4, ERLIN1, ERLIN2, FIG4, FUS, GLE1, GLT8D1, GRN, HNRNPA2B1, KIF5A, LGALSL, LRP12, MATR3, NEFH, NEK1, NUP50, OPTN, PFN1, PRPH, SLC52A2, SLC52A3, SOD1, SPG11, SPTLC2, SQSTM1, SS18L1, TAF15, TARDBP, TBK1, TIA1, TUBA4A, VAPB, VCP
Missing genes: AR, UBQLN2


Unnamed: 0,gene_name,n_overlaps
16,ERBB4,546
47,VAPB,421
39,SPTLC2,376
38,SPG11,371
2,ANXA11,360
31,NUP50,342
41,SS18L1,313
29,NEFH,297
3,ARHGEF28,296
13,CYLD,274


In [None]:
# mutate_from_overlaps.py
# Build mutated transcripts + combined file using precomputed overlaps.
# Inputs: transcripts.txt, canonical_blockmap.csv, overlap_hits.csv
# Outputs: transcripts_mutated.txt, transcripts_all.txt, variants_applied.csv, mutated_id_map.csv, transcripts_CNN.txt
# Extra: variants_applied_summary.txt with quick diagnostics

# transcripts_CNN.txt is what we need for the CNN model

import re, pandas as pd, numpy as np
from collections import Counter

BASE_TXT = "transcripts.txt"
BLOCKMAP = "canonical_blockmap.csv"
OVERLAP  = "overlap_hits.csv"

OUT_MUT   = "transcripts_mutated.txt"
OUT_ALL   = "transcripts_all.txt"
OUT_AUDIT = "variants_applied.csv"
OUT_MAP   = "mutated_id_map.csv"
OUT_SUM   = "variants_applied_summary.txt"

# --- helpers ---
_comp = str.maketrans("ACGTacgtNn", "TGCAtgcaNn")
def revcomp(s: str) -> str:
    return s.translate(_comp)[::-1]

def sanitize_id(s: str) -> str:
    return re.sub(r'[^A-Za-z0-9._-]', '_', str(s))

def chrom_key(x):
    x = str(x)
    return x[3:] if x.startswith("chr") else x

def norm_region(x: str) -> str:
    x = str(x)
    if x.lower() in ("utr3","three_prime_utr","three-prime-utr","3utr","3'utr","3_prime_utr"):
        return "UTR3"
    if x.upper()=="CDS": return "CDS"
    return x

# --- load base transcripts ---
base = pd.read_csv(BASE_TXT, sep="\t")
ren = {"transcript_id":"transcript","ORF_seq":"orf","ORF_len":"orf_length",
       "UTR3_seq":"utr3","UTR3_len":"utr3_length","Full_seq":"orf_utr3"}
for k,v in ren.items():
    if k in base.columns and v not in base.columns:
        base = base.rename(columns={k:v})

need_cols = ["transcript","orf","orf_length","utr3","utr3_length","orf_utr3"]
miss = [c for c in need_cols if c not in base.columns]
assert not miss, f"transcripts.txt missing columns: {miss}"

# ensure clean DNA strings
for c in ["orf","utr3","orf_utr3"]:
    base[c] = base[c].astype(str).str.upper().str.replace(r"[^ACGTN]", "N", regex=True)

orf   = dict(zip(base["transcript"], base["orf"]))
utr3  = dict(zip(base["transcript"], base["utr3"]))
orfL  = dict(zip(base["transcript"], base["orf_length"]))
utr3L = dict(zip(base["transcript"], base["utr3_length"]))

# --- load blockmap & index blocks per transcript for CDS/UTR3 ---
bm = pd.read_csv(BLOCKMAP)
if "region" in bm.columns:
    bm["region"] = bm["region"].map(norm_region)
need_bm = ["transcript","region","seqname","strand","genomic_start","genomic_end","tx_cum_start","tx_cum_end"]
assert set(need_bm).issubset(bm.columns), "canonical_blockmap.csv missing required columns"

# keep only our transcripts & regions of interest
bm = bm[bm["transcript"].isin(base["transcript"]) & bm["region"].isin(["CDS","UTR3"])].copy()

# sort and build per-transcript block lists
bm[["genomic_start","genomic_end","tx_cum_start","tx_cum_end"]] = bm[["genomic_start","genomic_end","tx_cum_start","tx_cum_end"]].astype(int)
per_tx_blocks = {}
for tx, g in bm.sort_values(["region","tx_cum_start"]).groupby("transcript"):
    per_tx_blocks[tx] = g[["region","seqname","strand","genomic_start","genomic_end","tx_cum_start","tx_cum_end"]].to_dict("records")

def locate_in_tx(tx, chrom, pos):
    """Return (region, strand, tx_offset0) if (chrom,pos) lands in CDS/UTR3; else None."""
    blocks = per_tx_blocks.get(tx, [])
    for b in blocks:
        if chrom_key(b["seqname"]) != chrom:
            continue
        s, e = int(b["genomic_start"]), int(b["genomic_end"])
        if s <= pos <= e:
            strand = b["strand"]
            tx0 = int(b["tx_cum_start"]) - 1
            off = (pos - s) if strand == "+" else (e - pos)
            return b["region"], strand, tx0 + off
    return None

def apply_allele_to_tx(tx, chrom, pos, ref, alt, loose_snv_ok=True):
    """
    Strictly apply edit when reference matches in transcript coordinates.
    If ref mismatch AND it's a single-base SNV, optionally 'rescue' by applying anyway,
    recorded as 'applied_loose_ref_mismatch'. Indels remain strict.
    """
    loc = locate_in_tx(tx, chrom, pos)
    if loc is None:
        return dict(ok=False, reason="not_in_CDS_or_UTR3", region=None, mode="")
    region, strand, tx_offset0 = loc

    full = orf[tx] + utr3[tx]
    orf_len = int(orfL[tx]); utr_len = int(utr3L[tx])
    if len(full) != orf_len + utr_len:
        return dict(ok=False, reason="length_mismatch_concat", region=region, mode="")

    tref, talt = (ref, alt) if strand == "+" else (revcomp(ref), revcomp(alt))
    tref, talt = tref.upper(), talt.upper()

    rlen = len(tref)
    # On the minus strand, POS (leftmost genomic base of REF) maps to the LAST base of
    # revcomp(REF) in transcript order, so the transcript slice starts rlen-1 bases earlier.
    # (Assumes REF does not cross a block boundary.)
    start0 = tx_offset0 if strand == "+" else tx_offset0 - (rlen - 1)
    if start0 < 0 or start0 + rlen > len(full):
        return dict(ok=False, reason="bounds_error", region=region, mode="")

    def _edit(mode, reason):
        mutated_full = full[:start0] + talt + full[start0+rlen:]
        # shift the ORF/UTR3 boundary only when the edit lies entirely inside the ORF
        split = orf_len + (len(talt) - rlen) if start0 + rlen <= orf_len else orf_len
        new_orf, new_utr3 = mutated_full[:split], mutated_full[split:]
        return dict(ok=True, reason=reason, region=region,
                    new_orf=new_orf, new_utr3=new_utr3,
                    new_orf_len=len(new_orf), new_utr3_len=len(new_utr3), mode=mode)

    ref_slice = full[start0:start0+rlen].upper()
    if ref_slice != tref:
        # allow only SNV rescue (1>1); indels stay strict
        if loose_snv_ok and rlen == 1 and len(talt) == 1 and ref_slice in "ACGT" and tref in "ACGT":
            return _edit("loose", "applied_loose_ref_mismatch")
        return dict(ok=False, reason=f"ref_mismatch_full:{ref_slice}->{tref}", region=region, mode="")
    # strict apply
    return _edit("strict", "applied")

# --- load & de-duplicate overlaps ---
over = pd.read_csv(OVERLAP)
# expected columns: chrom,pos,ref,alt,transcript,region (region optional)
must = {"chrom","pos","ref","alt","transcript"}
assert must.issubset(set(over.columns)), f"overlap_hits.csv missing columns: {must - set(over.columns)}"
over["chrom"] = over["chrom"].map(chrom_key)
if "region" in over.columns:
    over["region"] = over["region"].map(norm_region)

# keep only transcripts we have
over = over[over["transcript"].isin(base["transcript"])].copy()

# drop exact dupes
over = over.drop_duplicates(subset=["transcript","chrom","pos","ref","alt"]).reset_index(drop=True)

mut_rows, audit_rows, map_rows = [], [], []
var_idx = 1

for r in over.itertuples(index=False):
    tx   = r.transcript
    chrom = str(r.chrom)
    pos   = int(r.pos)          # 1-based genomic
    ref   = str(r.ref)
    alt   = str(r.alt)

    res = apply_allele_to_tx(tx, chrom, pos, ref, alt, loose_snv_ok=True)
    var_tag = f"chr{chrom}_{pos}_{ref}_{alt}"
    if res["ok"]:
        mut_id = sanitize_id(f"{tx}_{var_tag}_var{var_idx:04d}")
        mut_rows.append({
            "parent_transcript": tx,
            "transcript": mut_id,
            "orf": res["new_orf"],
            "orf_length": res["new_orf_len"],
            "utr3": res["new_utr3"],
            "utr3_length": res["new_utr3_len"],
            "orf_utr3": res["new_orf"] + res["new_utr3"]
        })
        audit_rows.append({
            "parent_transcript": tx,
            "mutated_transcript": mut_id,
            "variant": f"chr{chrom}:{pos} {ref}>{alt}",
            "region": res["region"],
            "status": "applied",
            "note": res["reason"],
            "mode": res.get("mode","")
        })
        map_rows.append({
            "mutated_transcript": mut_id,
            "parent_transcript": tx,
            "variant": f"chr{chrom}:{pos} {ref}>{alt}",
            "region": res["region"]
        })
        var_idx += 1
    else:
        audit_rows.append({
            "parent_transcript": tx,
            "mutated_transcript": "",
            "variant": f"chr{chrom}:{pos} {ref}>{alt}",
            "region": res.get("region"),
            "status": "skipped",
            "note": res["reason"],
            "mode": ""
        })

# --- write mutated-only ---
mut_df = pd.DataFrame(mut_rows,
                      columns=["parent_transcript","transcript","orf","orf_length","utr3","utr3_length","orf_utr3"])
mut_df.to_csv(OUT_MUT, sep="\t", index=False)

# --- write ALL (orig + mutated) ---
orig_cols = ["parent_transcript","transcript","orf","orf_length","utr3","utr3_length","orf_utr3"]
orig_df = base.copy()
orig_df.insert(0, "parent_transcript", orig_df["transcript"])
orig_df = orig_df[orig_cols]
all_df = pd.concat([orig_df, mut_df[orig_cols]], axis=0, ignore_index=True)
assert all_df["transcript"].is_unique, "Non-unique transcript IDs in transcripts_all.txt"
all_df.to_csv(OUT_ALL, sep="\t", index=False)

# --- write logs ---
audit_df = pd.DataFrame(audit_rows)
audit_df.to_csv(OUT_AUDIT, index=False)
pd.DataFrame(map_rows).to_csv(OUT_MAP, index=False)

# --- quick summary ---
applied = audit_df.query("status=='applied'")
skipped = audit_df.query("status=='skipped'")

lines = []
lines.append(f"Overlaps read: {len(over):,}")
lines.append(f"Applied total: {len(applied):,}  (strict: {(applied['mode']=='strict').sum():,} | loose SNV: {(applied['mode']=='loose').sum():,})")
lines.append(f"Skipped total: {len(skipped):,}")
if "region" in audit_df.columns:
    lines.append("\nApplied by region:")
    lines.extend(applied["region"].value_counts().to_string().splitlines())
lines.append("\nTop skip reasons:")
lines.extend(skipped["note"].value_counts().head(10).to_string().splitlines())

with open(OUT_SUM,"w") as fh:
    fh.write("\n".join(lines) + "\n")

print("\n".join(lines))
print(f"\nWrote: {OUT_MUT}, {OUT_ALL}, {OUT_AUDIT}, {OUT_MAP}, {OUT_SUM}")

# CNN-ready file (no parent_transcript)
all_df.drop(columns=["parent_transcript"]).to_csv("transcripts_CNN.txt", sep="\t", index=False)

# check the final model input
print("\ntranscripts_CNN.txt")
display(pd.read_csv("transcripts_CNN.txt", sep="\t").head())

Overlaps read: 9,368
Applied total: 8,616  (strict: 5,563 | loose SNV: 3,053)
Skipped total: 752

Applied by region:
region
CDS     4428
UTR3    4188

Top skip reasons:
note
ref_mismatch_full:G->C    27
ref_mismatch_full:C->G    23
ref_mismatch_full:A->G    23
ref_mismatch_full:C->A    20
ref_mismatch_full:T->G    20
ref_mismatch_full:A->T    17
ref_mismatch_full:T->A    16
ref_mismatch_full:T->C    15
ref_mismatch_full:G->A    15
ref_mismatch_full:A->C    14

Wrote: transcripts_mutated.txt, transcripts_all.txt, variants_applied.csv, mutated_id_map.csv, variants_applied_summary.txt

transcripts_CNN.txt


Unnamed: 0,transcript\torf\torf_length\tutr3\tutr3_length\torf_utr3
0,ENST00000240185.8\tATGTCTGAATATATTCGGGTAACCGAA...
1,ENST00000628224.3\tATGGCACAGAGCAAGAGGCACGTGTAC...
2,ENST00000264276.11\tATGGACTCAAAGAAGAGAAGCTCAAC...
3,ENST00000248437.9\tATGCGTGAATGCATCTCAGTCCACGTG...
4,ENST00000238875.10\tATGGCGGGATCAGTGGCCGACAGCGA...
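The strand arithmetic behind `locate_in_tx()` can be checked on a toy transcript. This is a sketch under stated assumptions: the helpers are hypothetical, blocks are `(genomic_start, genomic_end, tx_cum_start)` in transcript order with 1-based inclusive coordinates, and REF is assumed to lie inside one block. It also shows why a multi-base REF on the minus strand must be sliced starting `len(REF) - 1` bases before the offset of POS.

```python
# Hypothetical helpers mirroring the blockmap convention (1-based inclusive).
_COMP = str.maketrans("ACGTN", "TGCAN")

def revcomp(s: str) -> str:
    """Reverse complement of a DNA string (uppercased)."""
    return s.upper().translate(_COMP)[::-1]

def genomic_to_tx0(blocks, strand, pos1):
    """0-based transcript offset of 1-based genomic pos1, or None if outside all blocks."""
    for start, end, cum_start in blocks:
        if start <= pos1 <= end:
            off = (pos1 - start) if strand == "+" else (end - pos1)
            return (cum_start - 1) + off
    return None

def apply_vcf_allele(tx_seq, blocks, strand, pos1, ref, alt):
    """Return tx_seq with REF replaced by ALT, or None if REF does not match.

    On the minus strand, POS (leftmost genomic base of REF) maps to the LAST
    base of revcomp(REF) in transcript order.
    """
    off = genomic_to_tx0(blocks, strand, pos1)
    if off is None:
        return None
    if strand == "+":
        tref, talt, start = ref.upper(), alt.upper(), off
    else:
        tref, talt = revcomp(ref), revcomp(alt)
        start = off - (len(tref) - 1)
    if start < 0 or tx_seq[start:start + len(tref)] != tref:
        return None
    return tx_seq[:start] + talt + tx_seq[start + len(tref):]

# Toy minus-strand transcript: one block at genomic 101-110 whose + strand
# reads AACCGGTTAC, so the transcript is its reverse complement.
blocks = [(101, 110, 1)]
tx_seq = revcomp("AACCGGTTAC")                                 # GTAACCGGTT
print(apply_vcf_allele(tx_seq, blocks, "-", 103, "CC", "C"))   # GTAACCGTT
```

In the toy deletion, POS 103 maps to transcript offset 7, but `revcomp("CC")` actually occupies offsets 6-7; slicing from offset 7 would report a reference mismatch even though the variant is correct.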


## What's with mismatches?

In [None]:
import pandas as pd

# Load the audit file with applied/skipped status
audit = pd.read_csv("variants_applied.csv")

# Total overlaps (applied + skipped)
total = len(audit)

# Count mismatches specifically
mismatches = (audit["note"].astype(str)
              .str.startswith("ref_mismatch_full")).sum()

# % of all overlaps
pct = mismatches / total * 100

print(f"Total overlaps: {total}")
print(f"Ref mismatches: {mismatches} ({pct:.2f}%)")
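The skip note is written as `ref_mismatch_full:{ref_slice}->{tref}`, i.e. the base found in the transcript followed by the strand-adjusted VCF REF. Complementary pairs (G->C, C->G, A->T, T->A) appear among the top skip reasons next to non-complementary ones. The hypothetical helper below tallies the two kinds, which one might use as a rough check for a strand-handling problem; it is a diagnostic sketch, not a conclusion about the cause.

```python
from collections import Counter

_PAIR = {"A": "T", "T": "A", "C": "G", "G": "C"}

def parse_mismatch_note(note: str):
    """Split 'ref_mismatch_full:<tx_base>-><expected_ref>' into (observed, expected).

    Returns None for notes that are not reference mismatches.
    """
    prefix = "ref_mismatch_full:"
    if not note.startswith(prefix):
        return None
    observed, expected = note[len(prefix):].split("->", 1)
    return observed, expected

def classify_mismatches(notes):
    """Count mismatches as complement / other_single_base / multi_base."""
    counts = Counter()
    for note in notes:
        parsed = parse_mismatch_note(str(note))
        if parsed is None:
            continue
        observed, expected = parsed
        if len(observed) != 1 or len(expected) != 1:
            counts["multi_base"] += 1
        elif _PAIR.get(expected) == observed:
            counts["complement"] += 1
        else:
            counts["other_single_base"] += 1
    return counts
```

With the audit table loaded above, `classify_mismatches(audit["note"])` would give the breakdown.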

In [None]:
import pandas as pd

audit = pd.read_csv("variants_applied.csv")  # cols: parent_transcript, mutated_transcript, variant, region, status, note
bm    = pd.read_csv("canonical_blockmap.csv")  # has gene_name, transcript, region, seqname, genomic_start...
hits  = pd.read_csv("overlap_hits.csv")  # chrom,pos,ref,alt,transcript,region  (from your big scan)

# --- filter to ref_mismatch skips ---
bad = audit[(audit.status=="skipped") & (audit.note.str.startswith("ref_mismatch_full"))].copy()

# parse variant "chr<chrom>:<pos> <REF>><ALT>"
def parse_variant(s):
    # e.g. "chr21:31659783 C>T"
    chrom_pos, ra = s.split()
    chrom = chrom_pos.split(":")[0].replace("chr","")
    pos   = int(chrom_pos.split(":")[1])
    ref, alt = ra.split(">")
    return chrom, pos, ref, alt

bp = bad["variant"].apply(parse_variant)
bad["chrom"] = [c for c,_,_,_ in bp]
bad["pos"]   = [p for _,p,_,_ in bp]
bad["ref"]   = [r for *_,r,_ in bp]
bad["alt"]   = [a for *_,a in bp]

# add gene_name via blockmap
tx2gene = bm[["transcript","gene_name"]].drop_duplicates()
bad = bad.merge(tx2gene, left_on="parent_transcript", right_on="transcript", how="left", suffixes=("","_bm"))

print("Total ref_mismatch skips:", len(bad))
print("\nTop genes with ref_mismatch skips:")
print(bad.groupby("gene_name").size().sort_values(ascending=False).head(20))

print("\nTop transcripts with ref_mismatch skips:")
print(bad.groupby("parent_transcript").size().sort_values(ascending=False).head(20))

# Are they clustered? Show loci with many mismatches (same chrom:pos across many transcripts)
loci = bad.groupby(["chrom","pos"]).size().sort_values(ascending=False)
print("\nMost frequent loci among ref_mismatch (chrom:pos -> count):")
print(loci.head(20))

# Are multi-allelic sites involved? (same chrom:pos with multiple ALTs in the original overlap_hits)
multi_site = (hits
              .assign(chrom=hits["chrom"].astype(str).str.replace("^chr","",regex=True))
              .groupby(["chrom","pos"])["alt"].nunique()
              .reset_index(name="n_alt"))
multi_site_bad = bad.merge(multi_site, on=["chrom","pos"], how="left")
print("\nBad ref_mismatch rows that are at multi-allelic sites (n_alt>=2):",
      (multi_site_bad["n_alt"]>=2).sum(), " / ", len(multi_site_bad))

print("\nExamples of multi-allelic mismatch sites:")
print(multi_site_bad.loc[multi_site_bad["n_alt"]>=2, ["gene_name","parent_transcript","chrom","pos","variant","note"]].head(30))
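Another hypothesis worth testing for the mismatches, and for the large share of loose SNV rescues, is a systematic offset between blockmap coordinates and the orf+utr3 string. For example, the 3-bp gap between the ANG CDS and UTR3 blocks in canonical_blocks.csv is consistent with a separately annotated stop codon. If the ORF sequence includes that stop codon while UTR3 `tx_cum_start` values do not, UTR3 offsets would be three bases short. The hypothetical helper below probes nearby offsets for a REF match; the toy sequence is made up.

```python
def matching_shifts(tx_seq: str, start0: int, ref: str, max_shift: int = 3) -> list[int]:
    """Return every shift k in [-max_shift, max_shift] for which REF matches
    tx_seq at start0 + k (0-based offsets).
    """
    ref = ref.upper()
    shifts = []
    for k in range(-max_shift, max_shift + 1):
        s = start0 + k
        if s >= 0 and s + len(ref) <= len(tx_seq) and tx_seq[s:s + len(ref)].upper() == ref:
            shifts.append(k)
    return shifts

# Toy ORF+UTR3 string: REF "GGG" really sits at offset 12, but the computed
# offset is 9 (three bases short, as an uncounted stop codon would cause).
toy = "ATGAAACCCTAAGGGTTT"
print(matching_shifts(toy, 9, "GGG"))   # [3]
```

For single-base REF alleles a random shift matches about a quarter of the time, so the signal to look for is one shift value dominating across many skipped variants, not a match at any single site.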

# Prepare miRNA Input Sequences

In this step, we extract guide and passenger sequences for human miRNAs from the TargetScan file miR_Family_Info.txt.

- Only human entries (Species ID == 9606) are retained.
- A consistent mir identifier is built from the miRNA family name plus the arm suffix (_5p or _3p), taken from the MiRBase ID when present.
- The guide sequence is the Mature sequence (U converted to T); the passenger sequence is approximated as its reverse complement.
- Family-level identifiers (guide_family, pass_family) are also constructed for downstream compatibility.

Finally, two output files are created:

- mirseqs.txt with guide/passenger sequences and family names.
- mirseqs_pass.txt listing the corresponding passenger names, one per line (for filtering).
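The naming and sequence rules above can be checked in isolation. This sketch reproduces the same logic with hypothetical inputs. Note that if a family name already carries its arm, the suffix is appended again (e.g. `mir122-5p_5p`). That is harmless as long as the resulting names stay unique.

```python
import re

def mir_name(family: str, mirbase_id: str) -> str:
    """Family name -> lowercase 'mir'/'let' identifier, plus an arm suffix
    (_5p/_3p) when the MiRBase ID contains -5p or -3p."""
    fam = re.sub(r"[^A-Za-z0-9._-]+", "_", family).strip("_").lower()
    fam = fam.replace("mir-", "mir").replace("let-", "let")
    m = re.search(r"-([53]p)\b", mirbase_id or "", flags=re.I)
    return fam + (f"_{m.group(1).lower()}" if m else "")

def guide_and_passenger(mature_rna: str) -> tuple[str, str]:
    """Guide = mature sequence in DNA letters; passenger = its reverse
    complement (an approximation of the real passenger strand)."""
    guide = mature_rna.upper().replace("U", "T")
    passenger = guide.translate(str.maketrans("ACGT", "TGCA"))[::-1]
    return guide, passenger

print(mir_name("miR-122-5p", "hsa-miR-122-5p"))   # mir122-5p_5p
print(guide_and_passenger("UGGAGUGU"))            # ('TGGAGTGT', 'ACACTCCA')
```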

In [None]:
import re, pandas as pd

IN  = "miR_Family_Info.txt"      # TargetScan file
OUT_MAIN = "mirseqs.txt"
OUT_PASS = "mirseqs_pass.txt"

# helper
_comp = str.maketrans("ACGTUacgtuNn", "TGCAAtgcaaNn")
def revcomp(s): return s.translate(_comp)[::-1].replace("U","T")

def sanitize(s):
    s = re.sub(r'[^A-Za-z0-9._-]+', '_', s)
    return s.strip('_')

# load + keep human only
df = pd.read_csv(IN, sep="\t")
df = df[df["Species ID"] == 9606].copy()

# Build a stable miRNA "name" from family + (optional) arm hinted by MiRBase ID
# e.g., family "miR-122" -> name "mir122"; if MiRBase ID ends with "-5p"/"-3p" we append "_5p"/"_3p".
def name_from_row(r):
    fam = sanitize(r["miR family"]).lower().replace("mir-","mir").replace("let-","let")
    # optional arm tag from MiRBase ID (e.g. hsa-miR-133a-3p)
    m = re.search(r"-([53]p)\b", str(r["MiRBase ID"]) or "", flags=re.I)
    arm = f"_{m.group(1).lower()}" if m else ""
    return sanitize(f"{fam}{arm}")

df["mir"] = df.apply(name_from_row, axis=1)

# choose one representative mature sequence per 'mir' (if duplicates, keep the first)
df = df.sort_values(["miR family","MiRBase ID"]).drop_duplicates(subset=["mir"], keep="first")

# sequences: use Mature sequence as guide; passenger = reverse complement (approximation)
df["guide_seq"] = df["Mature sequence"].str.upper().str.replace("U","T")
df["pass_seq"]  = df["guide_seq"].map(revcomp)

# families
df["guide_family"] = df["miR family"].apply(lambda x: sanitize(str(x)).lower().replace("mir-","mir").replace("let-","let"))
df["pass_family"]  = df["guide_family"] + "_pass"

# write main file
df_out = df[["mir","guide_seq","pass_seq","guide_family","pass_family"]]
df_out.to_csv(OUT_MAIN, sep="\t", index=False)

# write passenger names (one per line)
with open(OUT_PASS, "w") as f:
    for n in df_out["mir"]:
        f.write(n + "_pass\n")

print(f"Wrote {OUT_MAIN} (rows: {len(df_out)}) and {OUT_PASS} (rows: {len(df_out)})")

Wrote mirseqs.txt (rows: 2244) and mirseqs_pass.txt (rows: 2244)


In [None]:
from google.colab import files
files.download("mirseqs.txt")
files.download("mirseqs_pass.txt")
