# bHLH Classes and Table Inputs

This notebook builds the main input tables used for downstream plotting, using the longest isoform per gene and InterPro domain ranges.

**Inputs**
- `data/intermediate/Metadata_CSVs/InterPro_Domains_cleaned.csv`
- `data/intermediate/Metadata_CSVs/Pfam_Domains_cleaned.csv`
- `data/intermediate/Metadata_CSVs/Transcript_Attributes.csv`
- `data/intermediate/longest_isoform.fasta`
- `data/raw/LI_HGNC.fasta`
- `data/raw/LS_classes.csv`

**Outputs**
- `data/intermediate/longest_isoform.csv`
- `data/intermediate/table_input.csv`
- `data/intermediate/table_input_withPAS.csv`

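**Note**: All file paths are resolved relative to the directory given by the `BHLH_PROJECT_ROOT` environment variable, which defaults to the current working directory. Set it to the project root if you run this notebook from a different directory.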


In [None]:
import pandas as pd
import numpy as np
from Bio import SeqIO
from pathlib import Path

import os

# Root directory that every relative data path in this notebook is resolved
# against. Override with the BHLH_PROJECT_ROOT environment variable; defaults
# to the current working directory. Resolved to an absolute path once here.
project_root = Path(os.getenv("BHLH_PROJECT_ROOT", ".")).resolve()


def p(*parts):
    """Return ``project_root`` joined with *parts*, as a string path.

    Example: ``p("data", "raw", "LS_classes.csv")`` ->
    ``"<project_root>/data/raw/LS_classes.csv"``.
    """
    return str(project_root.joinpath(*parts))


## 1) Load inputs

In [None]:
# Cleaned metadata tables live together under data/intermediate/Metadata_CSVs.
metadata_dir = ("data", "intermediate", "Metadata_CSVs")

interpro = pd.read_csv(p(*metadata_dir, "InterPro_Domains_cleaned.csv"))
transcripts = pd.read_csv(p(*metadata_dir, "Transcript_Attributes.csv"))
# NOTE(review): `pfam` is not referenced by any later cell visible here.
pfam = pd.read_csv(p(*metadata_dir, "Pfam_Domains_cleaned.csv"))

# Class labels come from the raw (uncleaned) input folder.
classes = pd.read_csv(p("data", "raw", "LS_classes.csv"))


## 2) Build longest-isoform InterPro table (IPR011598 only)

In [None]:
# Restrict InterPro hits to the longest-isoform transcripts and to IPR011598,
# then collapse to one row per gene spanning all of that gene's IPR011598 hits.
fasta_file = p("data", "intermediate", "longest_isoform.fasta")
sequences = list(SeqIO.parse(fasta_file, "fasta"))
# Record id = first whitespace-delimited token of the FASTA header (the
# transcript id); same convention the withPAS cell uses.
transcript_ids = {rec.id for rec in sequences}

filtered_interpro = interpro[
    interpro["ensembl_transcript_id"].isin(transcript_ids)
    & (interpro["interpro"] == "IPR011598")
].copy()

# Coerce coordinates to numbers (as the withPAS cell does) so min/max below
# compare numerically; a text-typed column would otherwise aggregate
# lexicographically ("100" < "99"). Unparseable values become NaN.
for coord_col in ("interpro_start", "interpro_end"):
    filtered_interpro[coord_col] = pd.to_numeric(filtered_interpro[coord_col], errors="coerce")

# One row per gene: earliest start and latest end across its domain hits;
# the remaining columns take the first value seen for the gene.
final_interpro = filtered_interpro.groupby("HGNC symbol").agg({
    "ensembl_gene_id": "first",
    "ensembl_transcript_id": "first",
    "interpro": "first",
    "interpro_start": "min",
    "interpro_end": "max",
    "interpro_short_description": "first",
    "interpro_description": "first",
}).reset_index()

final_interpro.to_csv(p("data", "intermediate", "longest_isoform.csv"), index=False)
print("Saved: data/intermediate/longest_isoform.csv")


## 3) Build `table_input.csv`

In [None]:
# Display order for the bHLH genes; rows are filtered to, and sorted by, this list.
ordered_list = ["TFAP4", "MLX", "MLXIPL", "MLXIP", "TCFL5", "SOHLH1", "SOHLH2", "MYC", "MYCN", "MYCL", "MAX", "MNT", "MXD3", "MXD4", "MXI1", "MXD1", "SREBF2", "SREBF1", "MITF", "TFE3", "TFEC", "TFEB", "USF3", "USF2", "USF1", "NCOA1", "NCOA2", "NCOA3", "NPAS2", "CLOCK", "ARNTL2", "ARNTL", "ARNT2", "ARNT", "NPAS4", "AHRR", "AHR", "SIM2", "SIM1", "NPAS3", "NPAS1", "HIF1A", "HIF3A", "EPAS1", "HELT", "BHLHE41", "BHLHE40", "HEYL", "HEY2", "HEY1", "HES7", "HES6", "HES5", "HES3", "HES2", "HES4", "HES1", "ATOH8", "TCF4", "TCF3", "TCF12", "MYOG", "MYF6", "MYOD1", "MYF5", "FIGLA", "ID1", "ID4", "ID3", "ID2", "ASCL2", "ASCL1", "ASCL4", "ASCL5", "ASCL3", "TAL1", "LYL1", "TAL2", "NHLH2", "NHLH1", "MESP2", "MSGN1", "MESP1", "PTF1A", "FERD3L", "ATOH7", "ATOH1", "BHLHA9", "BHLHA15", "BHLHE23", "BHLHE22", "OLIG1", "OLIG3", "OLIG2", "NEUROG2", "NEUROG3", "NEUROG1", "NEUROD2", "NEUROD6", "NEUROD4", "NEUROD1", "TCF21", "MSC", "TCF24", "TCF23", "TWIST2", "TWIST1", "HAND2", "HAND1", "TCF15", "SCX"]

li_file = p("data", "raw", "LI_HGNC.fasta")
# SeqRecords keyed by FASTA record id; looked up by HGNC symbol below.
sequences = SeqIO.to_dict(SeqIO.parse(li_file, "fasta"))
file_path = p("data", "intermediate", "longest_isoform.csv")

data = pd.read_csv(file_path)

# Attach the sequence and its length; symbols absent from the FASTA get None.
data["Sequence"] = data["HGNC symbol"].map(
    lambda sym: str(sequences[sym].seq) if sym in sequences else None
)
data["length"] = data["HGNC symbol"].map(
    lambda sym: len(sequences[sym].seq) if sym in sequences else None
)

# Number of distinct transcripts per Ensembl gene id in the transcripts table.
isoforms_count = (
    transcripts.groupby("ensembl_gene_id")["ensembl_transcript_id"]
    .nunique()
    .rename("number.of.isoforms")
    .reset_index()
)

# Keep only listed genes and sort them in list order via an ordered categorical.
in_list = data["HGNC symbol"].isin(ordered_list)
ordered_data = data.loc[in_list].copy()
ordered_data["HGNC symbol"] = pd.Categorical(
    ordered_data["HGNC symbol"], categories=ordered_list, ordered=True
)
ordered_data = ordered_data.sort_values("HGNC symbol")

# A left merge keeps the left frame's (sorted) row order.
ordered_data = ordered_data.merge(isoforms_count, on="ensembl_gene_id", how="left")
ordered_data.to_csv(p("data", "intermediate", "table_input.csv"), index=False)

print("Saved: data/intermediate/table_input.csv")


## 4) Build `table_input_withPAS.csv` (one row per gene and domain: IPR011598 plus the PAS domain IPR000014)

In [None]:
# This cell re-reads its inputs so it can be run on its own.
interpro_path = p("data", "intermediate", "Metadata_CSVs", "InterPro_Domains_cleaned.csv")
transcripts_path = p("data", "intermediate", "Metadata_CSVs", "Transcript_Attributes.csv")
classes_path = p("data", "raw", "LS_classes.csv")
fasta_longest_isoform = p("data", "intermediate", "longest_isoform.fasta")

output_table_input = p("data", "intermediate", "table_input_withPAS.csv")

interpro = pd.read_csv(interpro_path)
transcripts = pd.read_csv(transcripts_path)
# Normalize the class column name to an R-style dotted identifier.
classes = pd.read_csv(classes_path).rename(
    columns={"Ledent2002+Simionato2007": "Ledent2002.Simionato2007"}
)

# Longest-isoform sequences keyed by FASTA record id (the transcript id).
sequences = {}
for record in SeqIO.parse(fasta_longest_isoform, "fasta"):
    sequences[record.id] = str(record.seq)

fasta_transcript_ids = set(sequences)

domains_of_interest = ["IPR011598", "IPR000014"]


def _first_present_column(columns, candidates):
    """Return the first name in *candidates* found in *columns*, else None."""
    return next((name for name in candidates if name in columns), None)


transcript_col = _first_present_column(
    interpro.columns, ("ensembl_transcript_id", "transcript_id")
)
if transcript_col is None:
    raise ValueError("InterPro CSV is missing 'ensembl_transcript_id' or 'transcript_id'.")

# Hits on longest-isoform transcripts, for the domains of interest only.
keep_rows = (
    interpro[transcript_col].isin(fasta_transcript_ids)
    & interpro["interpro"].isin(domains_of_interest)
)
interpro_fasta = interpro[keep_rows].copy()

gene_col_input = _first_present_column(
    interpro_fasta.columns,
    ("HGNC symbol", "HGNC_symbol", "HGNC", "gene_symbol", "gene", "Gene name", "Gene name "),
)
if gene_col_input is None:
    raise ValueError("HGNC symbol column not found in InterPro CSV.")

interpro_fasta = interpro_fasta.rename(columns={gene_col_input: "HGNC.symbol"})
# Non-numeric coordinates become NaN so min/max below are numeric.
for coord in ("interpro_start", "interpro_end"):
    interpro_fasta[coord] = pd.to_numeric(interpro_fasta[coord], errors="coerce")

# One row per gene x domain, spanning the earliest start to the latest end.
aggregations = {
    "ensembl_gene_id": "first",
    transcript_col: "first",
    "interpro_start": "min",
    "interpro_end": "max",
    "interpro_short_description": "first",
    "interpro_description": "first",
}
grouped = (
    interpro_fasta.groupby(["HGNC.symbol", "interpro"], dropna=False)
    .agg(aggregations)
    .reset_index()
)

# Add sequence and length

def get_sequence_for_row(row):
    """Return the protein sequence for one row of ``grouped``, or None.

    The row's transcript id (column named by ``transcript_col``) is looked up
    in ``sequences`` first. If that is missing/NaN or unknown, the row's
    ``HGNC.symbol`` is tried as a fallback key.
    """
    transcript = row.get(transcript_col)
    if pd.notna(transcript) and transcript in sequences:
        return sequences[transcript]

    symbol = row["HGNC.symbol"]
    found = pd.notna(symbol) and symbol in sequences
    return sequences[symbol] if found else None

grouped["Sequence"] = grouped.apply(get_sequence_for_row, axis=1)
grouped["length"] = grouped["Sequence"].map(lambda s: len(s) if pd.notna(s) else None)

# Isoform counts: distinct transcript ids per gene symbol in the transcripts
# table. Both columns are resolved against the transcripts table itself; the
# transcript column detected on the InterPro table may be named differently
# here, and reusing it raised KeyError.
tx_gene_col = next(
    (c for c in ("HGNC symbol", "HGNC_symbol", "gene_symbol") if c in transcripts.columns),
    None,
)
tx_id_col = next(
    (c for c in ("ensembl_transcript_id", "transcript_id") if c in transcripts.columns),
    None,
)

if tx_gene_col and tx_id_col:
    isoforms_count = (
        transcripts.groupby(tx_gene_col)[tx_id_col]
        .nunique()
        .reset_index()
        .rename(columns={tx_id_col: "number.of.isoforms", tx_gene_col: "HGNC.symbol"})
    )
else:
    # Keep the best-effort fallback, but make it visible instead of silent.
    print(
        "Warning: transcripts table has no recognized gene-symbol or transcript-id "
        "column; number.of.isoforms defaults to 1 for every gene."
    )
    isoforms_count = pd.DataFrame(columns=["HGNC.symbol", "number.of.isoforms"])

grouped = grouped.merge(isoforms_count, on="HGNC.symbol", how="left")
# Genes without a count default to 1 isoform.
grouped["number.of.isoforms"] = grouped["number.of.isoforms"].fillna(1).astype(int)

# Add class labels
if "HGNC symbol" in classes.columns:
    classes = classes.rename(columns={"HGNC symbol": "HGNC.symbol"})
elif "HGNC_symbol" in classes.columns:
    classes = classes.rename(columns={"HGNC_symbol": "HGNC.symbol"})

grouped = grouped.merge(classes[["HGNC.symbol", "Ledent2002.Simionato2007"]], on="HGNC.symbol", how="left")

# Order rows
ordered_list = ["TFAP4", "MLX", "MLXIPL", "MLXIP", "TCFL5", "SOHLH1", "SOHLH2", "MYC", "MYCN", "MYCL", "MAX", "MNT", "MXD3", "MXD4", "MXI1", "MXD1", "SREBF2", "SREBF1", "MITF", "TFE3", "TFEC", "TFEB", "USF3", "USF2", "USF1", "NCOA1", "NCOA2", "NCOA3", "NPAS2", "CLOCK", "ARNTL2", "ARNTL", "ARNT2", "ARNT", "NPAS4", "AHRR", "AHR", "SIM2", "SIM1", "NPAS3", "NPAS1", "HIF1A", "HIF3A", "EPAS1", "HELT", "BHLHE41", "BHLHE40", "HEYL", "HEY2", "HEY1", "HES7", "HES6", "HES5", "HES3", "HES2", "HES4", "HES1", "ATOH8", "TCF4", "TCF3", "TCF12", "MYOG", "MYF6", "MYOD1", "MYF5", "FIGLA", "ID1", "ID4", "ID3", "ID2", "ASCL2", "ASCL1", "ASCL4", "ASCL5", "ASCL3", "TAL1", "LYL1", "TAL2", "NHLH2", "NHLH1", "MESP2", "MSGN1", "MESP1", "PTF1A", "FERD3L", "ATOH7", "ATOH1", "BHLHA9", "BHLHA15", "BHLHE23", "BHLHE22", "OLIG1", "OLIG3", "OLIG2", "NEUROG2", "NEUROG3", "NEUROG1", "NEUROD2", "NEUROD6", "NEUROD4", "NEUROD1", "TCF21", "MSC", "TCF24", "TCF23", "TWIST2", "TWIST1", "HAND2", "HAND1", "TCF15", "SCX"]

grouped = grouped[grouped["HGNC.symbol"].isin(ordered_list)].copy()
grouped["HGNC.symbol"] = pd.Categorical(grouped["HGNC.symbol"], categories=ordered_list, ordered=True)
# A gene can have two rows (one per domain). Sorting on the gene alone used
# pandas' default unstable quicksort, so those rows could come out in either
# order. (gene, domain) is unique after the groupby, so this order is fully
# deterministic.
grouped = grouped.sort_values(["HGNC.symbol", "interpro"])

expected_cols = [
    "HGNC.symbol", "ensembl_gene_id", "ensembl_transcript_id", "interpro",
    "interpro_start", "interpro_end", "interpro_short_description", "interpro_description",
    "Sequence", "length", "number.of.isoforms", "Ledent2002.Simionato2007"
]

# Normalize the transcript column name for the output schema.
if transcript_col != "ensembl_transcript_id":
    grouped = grouped.rename(columns={transcript_col: "ensembl_transcript_id"})

# Guarantee a fixed output schema: absent columns are written as empty.
for c in expected_cols:
    if c not in grouped.columns:
        grouped[c] = pd.NA

grouped = grouped[expected_cols]

grouped.to_csv(output_table_input, index=False)
print("Saved:", output_table_input)


## Exploratory checks (optional)

- PAS-domain entries can be inspected by filtering InterPro for `IPR000014`.
- Class mismatches can be reviewed by comparing the PAS-domain genes against the class labels in `LS_classes.csv`.
