In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Generate immune mechanism hierarchy tree (Family → Sub-branch → Subtype)
with case-insensitive merging, canonical variant retention, acronym restoration,
smart 'Activation of X' ↔ 'X activation' merging, and multi-format outputs.

Outputs:
- full tree (all subtypes)
- no-leaf tree (family + branches only)
- two-leaf tree (truncated leaf listing)
"""

import os
import re
import pandas as pd
from collections import defaultdict

# ========= CONFIG =========
# Input: grouped mechanism table (must provide Family, Subtype, Frequency columns).
INPUT_GROUPED_CSV = "outputs/grouped_tree/Vaxjo_PMIDs_mechanism_grouped.csv"

# Tree artifacts go under OUTDIR; per-family trees go under outputs/trees.
OUTDIR = "outputs/grouped_tree/final/"
os.makedirs(OUTDIR, exist_ok=True)
os.makedirs("outputs/trees", exist_ok=True)

OUTPUT_FULL, OUTPUT_NOLEAF, OUTPUT_TWOLEAF, OUTPUT_SUMMARY = (
    os.path.join(OUTDIR, "tree_full.txt"),
    os.path.join(OUTDIR, "tree_no_leaf.txt"),
    os.path.join(OUTDIR, "tree_two_leaf.txt"),
    os.path.join(OUTDIR, "tree_summary.csv"),
)

# When True, also write one full tree file per family.
PER_FAMILY_FILES = True

# ========= NORMALIZATION HELPERS =========
# En dash (U+2013) and em dash (U+2014) -> ASCII hyphen. Written as escapes so the
# table keeps matching real text regardless of the encoding this file is saved in.
_DASH_TO_HYPHEN = str.maketrans({"\u2013": "-", "\u2014": "-"})

# Whole-word plural -> singular rewrites, applied case-insensitively in this order.
_SINGULARIZE = (
    (r"\bresponses\b", "response"),
    (r"\bactivations\b", "activation"),
    (r"\bcells\b", "cell"),
    (r"\bcytokines\b", "cytokine"),
    (r"\bantibodies\b", "antibody"),
    (r"\bpathways\b", "pathway"),
    (r"\bmechanisms\b", "mechanism"),
)


def normalize_text(s: str) -> str:
    """Return a whitespace/plural/dash-normalized version of *s*.

    The value is stringified and stripped. Common plurals are singularized (whole
    words, case-insensitive; the replacement is lower-case). Whitespace runs collapse
    to one space, and en/em dashes become "-".
    """
    s = str(s).strip()
    for plural, singular in _SINGULARIZE:
        s = re.sub(plural, singular, s, flags=re.I)
    s = re.sub(r"\s+", " ", s)
    s = s.translate(_DASH_TO_HYPHEN)
    return s.strip()

# (regex, canonical spelling) pairs applied case-insensitively, in order, by canonicalize().
# Greek letters are written as escapes (\u03b3 = gamma, \u03ba = kappa) so the patterns
# and replacements stay correct even if this file is re-saved in another encoding.
CANONICAL_REPLACEMENTS = [
    (r"\bT[- ]?cell\b", "T cell"),
    (r"\bTh[- ]?1\b", "Th1"),
    (r"\bTh[- ]?2\b", "Th2"),
    (r"\bTh[- ]?17\b", "Th17"),
    # IFN-gamma in its common spellings: "IFN-<gamma>", "IFN - <gamma>", "IFN<gamma>", "IFN-gamma".
    (r"\bIFN ?-? ?(?:\u03b3|gamma)\b", "IFN-\u03b3"),
    # NF-kappaB with Greek kappa, Latin k, or spelled-out "kappa": "NF-kB", "NFkB", ...
    (r"\bNF.?(?:\u03ba|k|kappa)B\b", "NF-\u03baB"),
]

def canonicalize(s: str) -> str:
    """Normalize *s*, then rewrite known immunology spellings to one canonical form.

    Runs normalize_text first, then every CANONICAL_REPLACEMENTS rule in order
    (case-insensitive).
    """
    text = normalize_text(s)
    for pattern, replacement in CANONICAL_REPLACEMENTS:
        text = re.sub(pattern, replacement, text, flags=re.I)
    return text

# (pattern, replacement) pairs applied case-insensitively, in order, by restore_acronyms().
# The TLR rule keeps an optional numeric suffix so "tlr4" becomes "TLR4"
# (a plain \btlr\b cannot match there because the digit is a word character).
_ACRONYM_FIXES = (
    (r"\btlr(\d*)\b", r"TLR\1"),
    (r"\bdc\b", "DC"),
    (r"\bnlrp3\b", "NLRP3"),
    (r"\bifn\b", "IFN"),
    (r"\bmhc\b", "MHC"),
    (r"\bmyd88\b", "MyD88"),
    (r"\btrif\b", "TRIF"),
    (r"\bsting\b", "STING"),
)


def restore_acronyms(s: str) -> str:
    """Restore consistent capitalization for immune acronyms in *s*.

    Only whole-word occurrences are rewritten, so e.g. "casting" is left alone.
    """
    for pattern, replacement in _ACRONYM_FIXES:
        s = re.sub(pattern, replacement, s, flags=re.I)
    return s

# ========= LOAD DATA =========
df = pd.read_csv(INPUT_GROUPED_CSV)
required_cols = {"Family", "Subtype", "Frequency"}
missing_cols = required_cols - set(df.columns)
if missing_cols:
    raise ValueError(
        f"CSV must contain {sorted(required_cols)}; missing {sorted(missing_cols)} "
        f"(columns found: {list(df.columns)})"
    )

# Normalize
# Family names are used verbatim as FAMILY_TO_SUBBRANCH keys, so only trim stray
# whitespace; blank names fall back to the same bucket as missing ones.
df["Family"] = (
    df["Family"].fillna("Other / Unclassified").astype(str).str.strip()
    .replace("", "Other / Unclassified")
)
df["Subtype"] = df["Subtype"].fillna("Unspecified").astype(str).apply(canonicalize)
# Coerce counts to int up front. A blank/non-numeric Frequency would otherwise crash the
# int() conversion during tree construction, and float counts would make the family
# totals (summed from this column) disagree with the int() counts stored in the tree.
df["Frequency"] = pd.to_numeric(df["Frequency"], errors="coerce").fillna(0).astype(int)

# ========= Frequency-based variant retention =========
def get_most_frequent_variants(df, col):
    """Map each normalized label of *col* to its most frequent raw spelling.

    The key is the canonicalized, lower-cased label. Frequencies of identical raw
    spellings are summed first; on a tie the spelling encountered first wins.
    """
    totals_by_key = {}
    for raw_label, weight in zip(df[col], df["Frequency"]):
        bucket = totals_by_key.setdefault(canonicalize(raw_label).lower(), {})
        bucket[raw_label] = bucket.get(raw_label, 0) + weight

    winners = {}
    for norm_key, spellings in totals_by_key.items():
        # max() over the keys returns the first maximal entry in insertion order.
        winners[norm_key] = max(spellings, key=spellings.get)
    return winners

family_display = get_most_frequent_variants(df, "Family")
subtype_display = get_most_frequent_variants(df, "Subtype")

# ========= Build family -> branch map (simplified here; customize as needed) =========
# Branches are tried in dict order with re.I and the first keyword hit wins (see the tree
# construction loop), so specific branches must come before catch-all ones.
FAMILY_TO_SUBBRANCH = {
    "T cell activation / polarization": {
        # \b keeps "Th17" from being captured by the Th1 branch.
        "Th1 branch": [r"\bTh1\b"],
        "Th2 branch": [r"\bTh2\b"],
        "Th17 branch": [r"Th17"],
        # (?!\d) keeps CD40 / CD80 / CD86 (co-stimulatory markers) out of this branch.
        "CD4/CD8 branch": [r"\bCD4(?!\d)", r"\bCD8(?!\d)"],
        "Tfh branch": [r"Tfh"],
        "Regulatory T cell branch": [r"Treg", r"regulatory T"],
        # Generic T cell terms last, so "CD4+ T cell" / "regulatory T cell" reach their
        # specific branches above instead of being swallowed here.
        "T cell branch": [r"T cell", r"T-cell", r"T lymphocyte"],
    },
    "Dendritic cell activation": {
        "DC maturation": [r"maturation"],
        "DC polarization": [r"polarization"],
        "Plasmacytoid DC": [r"plasmacytoid"],
        "Antigen presentation-related DC": [r"antigen", r"\bAPC\b", r"presentation"],
        "TLR-related DC": [r"\bTLR"],
        "Other DC activation": [r"dendritic"],
    },

    "TLR signaling": {
        "TLR2 branch": [r"\bTLR2\b"],
        "TLR3 branch": [r"\bTLR3\b"],
        "TLR4 branch": [r"\bTLR4\b"],
        "TLR5 branch": [r"\bTLR5\b"],
        "TLR7/8 branch": [r"\bTLR7\b", r"\bTLR8\b"],
        "TLR9 branch": [r"\bTLR9\b"],
        "MyD88/TRIF-related": [r"MyD88", r"TRIF"],
        # \d* so TLR1, TLR10, TLR13, ... land here rather than in "Other".
        "Other TLR-related": [r"toll-?like receptor", r"\bTLR\d*\b"],
    },
    "Cytokine signaling / production": {
        # IL-1 family first: the generic interleukin pattern below would otherwise claim
        # every IL-1 term. (?!\d) stops it from matching IL-10, IL-12, IL-17, ...
        "Inflammasome / IL-1 family": [r"\bIL[- ]?1(?!\d)", r"inflammasome"],
        "Interleukins": [r"\bIL[- ]?\d", r"interleukin"],
        "Interferons": [r"\bIFN"],
        "TNF": [r"\bTNF"],
        "Chemokines": [r"chemokine", r"\bCCL", r"\bCXCL"],
        "Other cytokines": [r"cytokine"],
    },
    "Macrophage / innate immune activation": {
        "Macrophage": [r"macrophage"],
        "NK / Monocyte": [r"\bNK\b", r"monocyte"],
        "Innate immune cells": [r"innate"],
        "Neutrophils / Granulocytes": [r"neutrophil", r"granulocyte"],
        "Other innate activation": [r"activation"],
    },
    "Pattern recognition / PRR sensing": {
        "PRR family": [r"\bPRR\b"],
        "RIG-I-like": [r"\bRIG"],
        "NOD-like": [r"\bNOD"],
        "Pattern recognition": [r"pattern recognition"],
        "C-type lectin receptors": [r"Dectin", r"Mincle", r"\bMCL\b"],
        "Other pattern sensors": [r"recognition", r"sensing"],
    },
    "NLRP3 inflammasome activation": {
        "NLRP3 core branch": [r"\bNLRP3\b"],
        "MAPK/JNK pathway": [r"\bMAPK\b", r"\bJNK\b"],
        "Caspase / pyroptosis": [r"caspase", r"pyroptosis"],
        "Other inflammasome activity": [r"inflammasome"],
    },
    "Antigen presentation / APCs": {
        "APC activation": [r"activation", r"\bAPC\b"],
        "Cross-presentation": [r"cross-?presentation", r"\bcross\b"],
        "MHC / Co-stimulation": [r"\bMHC\b", r"\bCD40\b", r"\bCD80\b", r"\bCD86\b", r"co-?stimul"],
        "Migration / trafficking": [r"migration", r"traffick"],
        "Antigen processing / uptake": [r"antigen", r"uptake", r"processing"],
        "Other APC function": [r"presentation"],
    },
    "B cell / antibody production": {
        "B cell activation": [r"\bB cell\b", r"\bB-cell\b"],
        "Antibody production": [r"antibody", r"\bIgG\b", r"\bIgA\b", r"\bIgM\b", r"\bIgE\b"],
        "Humoral immunity": [r"humoral"],
        "Plasma cell / differentiation": [r"\bplasma\b", r"plasmablast"],
        "Germinal center / memory": [r"germinal", r"memory"],
        "Other B cell mechanisms": [r"\bB\b", r"antibody"],
    },
    "Complement / depot / formulation": {
        "Complement activation": [r"complement"],
        "Depot / release mechanisms": [r"depot", r"release"],
        "Adjuvant formulation / emulsions": [r"\balum\b", r"emulsion", r"formulation"],
        "Other": [r"activation"],
    },
    "STING / TRIF / MyD88 / RIG-I signaling": {
        "STING": [r"\bSTING\b"],
        "TRIF": [r"\bTRIF\b"],
        "MyD88": [r"\bMyD88\b"],
        "RIG-I-like": [r"\bRIG"],
        "NOD-like": [r"\bNOD"],
        "Other signaling adaptors": [r"adaptor", r"signaling"],
    },
    "Inflammatory response": {
        # \u03ba is Greek kappa, matching the canonical "NF-<kappa>B" spelling produced by canonicalize().
        "Pro-inflammatory genes": [r"inflamm", r"NF[- ]?\u03baB", r"NF[- ]?kB", r"NFkB"],
        "Cytokine-mediated inflammation": [r"cytokine"],
        "Chemokine signaling": [r"chemokine", r"\bCCL", r"\bCXCL"],
        "Immune suppression / regulation": [r"regulation", r"inhibition"],
        "Other": [r"response", r"activation"],
    },
    "Adjuvant synergy / immune modulation": {
        "Immune enhancement": [r"enhanc", r"promotion"],
        "Costimulation": [r"co-?stimul", r"\bCD40\b", r"\bCD86\b"],
        "Immune modulation": [r"modulat"],
        "Synergy": [r"synerg", r"combination", r"co-?activation"],
        "Other": [r"activation"],
    },
}

# ========= Tree construction =========
# tree[family][branch] -> [(subtype, frequency), ...]. A row whose subtype matches no
# keyword (or whose family has no configured branches) is filed under "Other".
tree = defaultdict(lambda: defaultdict(list))
family_totals = df.groupby("Family")["Frequency"].sum().to_dict()

for fam, sub, raw_freq in zip(df["Family"], df["Subtype"], df["Frequency"]):
    freq = int(raw_freq)
    for subbranch, kws in FAMILY_TO_SUBBRANCH.get(fam, {}).items():
        if any(re.search(kw, sub, flags=re.I) for kw in kws):
            tree[fam][subbranch].append((sub, freq))
            break
    else:
        tree[fam]["Other"].append((sub, freq))

# ========= Merge "Activation of X" <-> "X activation" =========
# The optional "the" is consumed explicitly. With the old alternation
# "Activation of|Activation of the", the first branch always won, so "the" stayed in the core.
ACTIVATION_PAT = re.compile(r"^activation\s+of\s+(?:the\s+)?(?P<core>.+)$", re.I)
_ACTIVATION_SUFFIX = " activation"


def merge_activation_variants(subtype_freqs):
    """Collapse "Activation of [the] X", "X activation" and plain "X" into one entry.

    Entries are grouped case-insensitively on their core phrase X, and the group's
    frequencies are summed. Each group is displayed under its most frequent original
    spelling; on a tie, the spelling seen first wins.

    Args:
        subtype_freqs: Mapping of subtype label -> frequency.

    Returns:
        dict mapping the chosen display label -> summed frequency.
    """
    groups = {}
    for sub, freq in subtype_freqs.items():
        m = ACTIVATION_PAT.match(sub)
        if m:
            key = m.group("core").strip().lower()
        elif sub.lower().endswith(_ACTIVATION_SUFFIX):
            key = sub[: -len(_ACTIVATION_SUFFIX)].strip().lower()
        else:
            key = sub.lower()
        group = groups.setdefault(key, {"variants": {}, "freq": 0})
        group["variants"][sub] = freq
        group["freq"] += freq
    return {
        max(g["variants"].items(), key=lambda kv: kv[1])[0]: g["freq"]
        for g in groups.values()
    }

# ========= Sort helpers =========
def sort_branches(branch_items):
    """Order (branch_name, [(subtype, freq), ...]) pairs for display.

    Branches are ranked by descending total frequency. Any branch named "other"
    (case-insensitive) goes to the end. The sort is stable, so equally ranked
    branches keep their input order.
    """
    def rank(pair):
        label, members = pair
        branch_sum = sum(count for _, count in members)
        # A huge sentinel sorts "Other" after every (negated) real total.
        return (1e12, 0) if label.lower() == "other" else (-branch_sum, 0)

    ordered = list(branch_items)
    ordered.sort(key=rank)
    return ordered

# ========= Renderer =========
# Box-drawing glyphs as escapes so they survive any re-encoding of this file:
# U+251C U+2500 (tee), U+2514 U+2500 (elbow), U+2502 (vertical bar).
_TREE_TEE = "\u251c\u2500"
_TREE_ELBOW = "\u2514\u2500"
_TREE_PIPE = "\u2502"


def render_family_block(fam, leaf_mode="full", max_leafs=2):
    """Render one family as an indented text tree.

    Args:
        fam: Family name as it appears in ``tree`` / ``family_totals``.
        leaf_mode: Controls how subtypes are listed.
            - ``"full"``: every subtype.
            - ``"noleaf"``: branches only.
            - ``"twoleaf"``: the top ``max_leafs`` subtypes per branch, then a
              "... (N more, M total freq)" line.
            - Any other value: truncates to ``max_leafs`` without that summary line.
        max_leafs: Subtypes shown per branch when not in ``"full"`` mode.

    Returns:
        The rendered block, one line per node, ending with a newline.
    """
    fam_display = restore_acronyms(family_display.get(canonicalize(fam).lower(), fam))
    lines = [f"{fam_display} ({family_totals.get(fam, 0)})"]

    branches = sort_branches(tree[fam].items())
    for b_idx, (branch, entries) in enumerate(branches):
        is_last_branch = b_idx == len(branches) - 1

        # Sum duplicate subtype names, then fold "Activation of X" / "X activation" variants.
        summed = defaultdict(int)
        for sub, freq in entries:
            summed[sub] += freq
        merged = merge_activation_variants(summed)
        # Highest frequency first; ties broken alphabetically (case-insensitive).
        entries_sorted = sorted(merged.items(), key=lambda x: (-x[1], x[0].lower()))
        branch_total = sum(freq for _, freq in entries_sorted)

        branch_glyph = _TREE_ELBOW if is_last_branch else _TREE_TEE
        lines.append(f"   {branch_glyph} {restore_acronyms(branch)} ({branch_total})")

        if leaf_mode == "noleaf":
            continue

        # No sibling branch follows the last one, so its leaves get no vertical bar.
        gutter = "   " + (" " if is_last_branch else _TREE_PIPE) + "    "
        shown = entries_sorted if leaf_mode == "full" else entries_sorted[:max_leafs]
        hidden = entries_sorted[len(shown):]
        show_summary = leaf_mode == "twoleaf" and bool(hidden)

        for i, (sub, freq) in enumerate(shown):
            # The elbow belongs on the final visible line, which is the summary line if any.
            is_last_leaf = i == len(shown) - 1 and not show_summary
            leaf_glyph = _TREE_ELBOW if is_last_leaf else _TREE_TEE
            sub_disp = restore_acronyms(subtype_display.get(canonicalize(sub).lower(), sub))
            lines.append(f"{gutter}{leaf_glyph} {sub_disp} ({freq})")

        if show_summary:
            hidden_freq = sum(freq for _, freq in hidden)
            lines.append(f"{gutter}{_TREE_ELBOW} ... ({len(hidden)} more, {hidden_freq} total freq)")

    lines.append("")
    return "\n".join(lines)

# ========= Generate outputs =========
# Largest families first; each family is written in all three render modes.
families_sorted = sorted(family_totals.items(), key=lambda kv: kv[1], reverse=True)
summary_rows = []

with open(OUTPUT_FULL, "w", encoding="utf-8") as full_out, \
     open(OUTPUT_NOLEAF, "w", encoding="utf-8") as noleaf_out, \
     open(OUTPUT_TWOLEAF, "w", encoding="utf-8") as twoleaf_out:
    targets = ((full_out, "full"), (noleaf_out, "noleaf"), (twoleaf_out, "twoleaf"))
    for fam, _ in families_sorted:
        for handle, mode in targets:
            handle.write(render_family_block(fam, mode))
        # One summary row per raw branch (before activation-variant merging; totals are unaffected).
        summary_rows.extend(
            {
                "Family": fam,
                "Branch": branch,
                "Branch_Total": sum(freq for _, freq in entries),
                "Family_Total": family_totals[fam],
            }
            for branch, entries in tree[fam].items()
        )

# ========= Write per-family files =========
# Characters that are path separators or are rejected in file names on common platforms
# (notably Windows). Each one is replaced by "_"; for names containing only "/", this
# gives the same file names as before.
_UNSAFE_FILENAME_CHARS = re.compile(r'[\\/:*?"<>|]')

if PER_FAMILY_FILES:
    for fam, _ in families_sorted:
        # Family names come from the input CSV, so sanitize them before using them as paths.
        # NOTE(review): two families that sanitize to the same name overwrite each other.
        safe_name = _UNSAFE_FILENAME_CHARS.sub("_", fam).strip() or "unnamed_family"
        fname = os.path.join("outputs/trees", f"{safe_name}.txt")
        with open(fname, "w", encoding="utf-8") as ff:
            ff.write(render_family_block(fam, "full"))

# ========= Summary =========
# Explicit columns keep sort_values working even when no rows were produced
# (a frame built from an empty list has no columns at all).
summary_df = pd.DataFrame(summary_rows, columns=["Family", "Branch", "Branch_Total", "Family_Total"])
summary_df.sort_values(["Family", "Branch_Total"], ascending=[True, False]).to_csv(OUTPUT_SUMMARY, index=False)

print(f"\u2705 Full tree written to: {OUTPUT_FULL}")
print(f"\u2705 No-leaf tree written to: {OUTPUT_NOLEAF}")
print(f"\u2705 Two-leaf tree written to: {OUTPUT_TWOLEAF}")
print(f"\u2705 Summary CSV written to: {OUTPUT_SUMMARY}")


In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Plot immune mechanism hierarchy (Family → Branch)
as an interactive Sunburst chart.

Input: outputs/grouped_tree/final/tree_summary.csv (generated by previous script)
Output:
  - outputs/grouped_tree/final/immune_mechanism_sunburst.html
  - outputs/grouped_tree/final/immune_mechanism_sunburst.png (if Chrome available)
"""

import os
import pandas as pd
import plotly.express as px
import plotly.io as pio

# ========= CONFIG =========
INPUT_SUMMARY = "outputs/grouped_tree/final/tree_summary.csv"
OUTDIR = "outputs/grouped_tree/final"
os.makedirs(OUTDIR, exist_ok=True)  # make sure the export directory exists

# Interactive and static exports of the same figure.
OUTPUT_HTML, OUTPUT_PNG = (
    os.path.join(OUTDIR, "immune_mechanism_sunburst.html"),
    os.path.join(OUTDIR, "immune_mechanism_sunburst.png"),
)

# ========= LOAD =========
# Missing cells become "Unknown"; the numeric column is re-coerced (to 0) below.
df = pd.read_csv(INPUT_SUMMARY).fillna("Unknown")

# ========= Detect numeric frequency column =========
# First column, in file order, whose lower-cased header is a known count-like name.
freq_col = next(
    (c for c in df.columns if c.lower() in ["frequency", "freq", "count", "branch_total", "total", "n"]),
    None,
)
if freq_col is None:
    raise ValueError(f"No frequency-like column found in {list(df.columns)}")

df["Frequency"] = pd.to_numeric(df[freq_col], errors="coerce").fillna(0).astype(int)

# ========= Deduplicate hierarchy labels =========
def dedup_hierarchy(df, levels):
    """Blank out (set to None) any level label that merely repeats its parent level.

    Adjacent pairs of *levels* are processed left to right. *df* is mutated in
    place, and the same DataFrame is returned for convenience.
    """
    for upper, lower in zip(levels, levels[1:]):
        repeats_parent = df[upper] == df[lower]
        df.loc[repeats_parent, lower] = None
    return df

# The summary CSV only carries Family -> Branch, so mirror Branch into Subtype to
# give the three-level dedup a column to work on.
if "Subtype" not in df.columns:
    df = df.assign(Subtype=df["Branch"])

df = dedup_hierarchy(df, ["Family", "Branch", "Subtype"])

# ========= Sort "Other" last =========
def sort_with_other_last(series):
    """Return *series* as an ordered Categorical, sorted with any "other" label last.

    Categories are built from the labels actually present, and "other" is detected
    case-insensitively. Previously a label spelled e.g. "other" became NaN, because
    only "Other" was a category. Missing values, such as labels blanked by
    ``dedup_hierarchy``, stay missing instead of becoming the literal string "None".
    """
    present = series.notna()
    labels = series.astype(str).where(present)
    categories = sorted(labels[present].unique(), key=lambda x: (x.lower() == "other", x))
    return pd.Categorical(labels, categories=categories, ordered=True)

# Apply the "Other last" category ordering to both hierarchy levels.
df = df.assign(
    Family=sort_with_other_last(df["Family"]),
    Branch=sort_with_other_last(df["Branch"]),
)

# ========= Build plot =========
# The title is set once in update_layout below; a px-level title would just be overridden.
fig = px.sunburst(
    df,
    path=["Family", "Branch"],
    values="Frequency",
    color="Family",
    color_discrete_sequence=px.colors.qualitative.Pastel,
    width=1000,
    height=900,
)

fig.update_traces(
    textinfo="label+value+percent parent",
    hovertemplate="<b>%{label}</b><br>Parent: %{parent}<br>Count: %{value}<br>% of parent: %{percentParent:.1%}<extra></extra>",
)

# \u2014 = em dash, \u2192 = right arrow (escapes keep the title intact across re-encodings).
fig.update_layout(
    title=dict(text="Immune Mechanism Hierarchy \u2014 Family \u2192 Branch", x=0.5, font=dict(size=22)),
    margin=dict(t=80, l=0, r=0, b=0),
    uniformtext=dict(minsize=9, mode="show"),
)

# ========= Export =========
pio.write_html(fig, OUTPUT_HTML, include_plotlyjs="cdn")
print(f"\u2705 Interactive HTML saved to {OUTPUT_HTML}")

# PNG export is best-effort: the HTML above is the primary artifact. Report the real
# error instead of assuming a missing Chrome is the cause.
try:
    pio.write_image(fig, OUTPUT_PNG, width=1200, height=900, scale=3)
except Exception as exc:
    print(f"\u26a0\ufe0f Skipping PNG export \u2014 {exc}. If Chrome is missing, run: plotly_get_chrome")
else:
    print(f"\u2705 PNG image saved to {OUTPUT_PNG}")



In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Simplified immune mechanism pie chart (Family-level only)
from outputs/grouped_tree/final/tree_summary.csv
"""

import os
import pandas as pd
import plotly.express as px
import plotly.io as pio

# ========= CONFIG =========
INPUT_SUMMARY = "outputs/grouped_tree/final/tree_summary.csv"
OUTDIR = "outputs/grouped_tree/final"
os.makedirs(OUTDIR, exist_ok=True)

# "_simple" suffix: the labeled pie-chart cell later in this notebook writes
# immune_mechanism_pie.html/.png into the same directory. With identical names, this
# cell's output was silently overwritten.
OUTPUT_HTML = os.path.join(OUTDIR, "immune_mechanism_pie_simple.html")
OUTPUT_PNG = os.path.join(OUTDIR, "immune_mechanism_pie_simple.png")

# ========= LOAD =========
df = pd.read_csv(INPUT_SUMMARY).fillna("Unknown")

# Detect frequency column automatically: the first column, in file order, whose
# lower-cased header is a known count-like name.
freq_col = next(
    (c for c in df.columns if c.lower() in ["frequency", "freq", "count", "branch_total", "total", "n"]),
    None,
)
if freq_col is None:
    raise ValueError(f"No frequency-like column found in {list(df.columns)}")

# ========= Aggregate by Family =========
# Total frequency per family, largest first.
per_family = df.groupby("Family", dropna=False)[freq_col].sum().reset_index()
family_summary = per_family.sort_values(freq_col, ascending=False)

# Ensure "Other" is last: move those rows after every other family, keeping the order of the rest.
family_summary["Family"] = family_summary["Family"].astype(str)
is_other = family_summary["Family"].str.lower() == "other"
if is_other.any():
    family_summary = pd.concat([family_summary[~is_other], family_summary[is_other]])

# ========= Plot pie chart =========
# Full pie (hole=0.0); use hole=0.4 for a donut instead.
fig = px.pie(
    family_summary,
    names="Family",
    values=freq_col,
    color="Family",
    color_discrete_sequence=px.colors.qualitative.Pastel,
    title="Immune Mechanism Families (Simplified Pie Chart)",
    hole=0.0,
)

# Slices show name + share; hover adds the raw count. The layout title replaces the
# provisional title passed to px.pie.
fig.update_traces(
    textinfo="label+percent",
    textfont_size=14,
    hovertemplate="<b>%{label}</b><br>Count: %{value}<br>Percent: %{percent}<extra></extra>",
).update_layout(
    title={"text": "Immune Mechanism Family Distribution", "x": 0.5, "font": {"size": 22}},
    margin={"t": 80, "l": 0, "r": 0, "b": 0},
)

# ========= Export =========
pio.write_html(fig, OUTPUT_HTML, include_plotlyjs="cdn")
print(f"\u2705 Interactive Family Pie chart saved to {OUTPUT_HTML}")

# PNG export is best-effort; report the actual failure rather than assuming it is Chrome.
try:
    pio.write_image(fig, OUTPUT_PNG, width=1200, height=800, scale=3)
except Exception as exc:
    print(f"\u26a0\ufe0f Skipping PNG export \u2014 {exc}. If Chrome is missing, run: plotly_get_chrome")
else:
    print(f"\u2705 PNG image saved to {OUTPUT_PNG}")


In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Generate a simplified pie chart (inner-layer only)
showing immune mechanism family distribution with:
- Name + count + percent labels
- No legend (clean look)
- Publication-quality layout
"""

import os
import pandas as pd
import plotly.graph_objects as go
import plotly.io as pio

# ===============================
# 1) INPUT FILE - same summary CSV
# ===============================
INPUT_SUMMARY = "outputs/grouped_tree/final/tree_summary.csv"
OUTPUT_HTML, OUTPUT_PNG = (
    "outputs/grouped_tree/final/immune_mechanism_pie.html",
    "outputs/grouped_tree/final/immune_mechanism_pie.png",
)

# ===============================
# 2) LOAD & AGGREGATE DATA
# ===============================
df = pd.read_csv(INPUT_SUMMARY)

# compute totals per Family: Family_Total is repeated on every branch row of a family,
# so max() recovers that single value. Largest families first.
df_family = (
    df.groupby("Family", as_index=False)["Family_Total"]
    .max()
    .sort_values("Family_Total", ascending=False)
)

# ===============================
# 3) BUILD PIE CHART
# ===============================
# custom text labels: "Family<br>count (pct%)"
grand_total = df_family["Family_Total"].sum()
# Guard an all-zero (or empty) summary so percentages are 0.0 rather than NaN.
df_family["percent"] = 100 * df_family["Family_Total"] / (grand_total or 1)
# Built column-wise (not DataFrame.apply(axis=1)) so an empty frame still yields a column.
df_family["label_text"] = [
    f"{fam}<br>{count} ({pct:.1f}%)"
    for fam, count, pct in zip(df_family["Family"], df_family["Family_Total"], df_family["percent"])
]

fig = go.Figure(
    data=[
        go.Pie(
            labels=df_family["Family"],
            values=df_family["Family_Total"],
            text=df_family["label_text"],
            textinfo="text",
            textposition="inside",
            insidetextorientation="radial",
            # <extra></extra> hides the secondary trace-name box, matching the other charts.
            hovertemplate="<b>%{label}</b><br>%{value} counts<br>%{percent}<extra></extra>",
            marker=dict(line=dict(color="white", width=1.5)),
        )
    ]
)

# ===============================
# 4) LAYOUT TUNING
# ===============================
fig.update_layout(
    title={"text": "Immune Mechanism Family Distribution", "x": 0.5, "font": {"size": 22}},
    showlegend=False,  # every slice is labeled directly, so the legend is redundant
    width=1200,
    height=900,
    margin={"t": 100, "b": 100, "l": 100, "r": 100},
)

# ===============================
# 5) SAVE OUTPUTS
# ===============================
os.makedirs(os.path.dirname(OUTPUT_HTML), exist_ok=True)
pio.write_html(fig, OUTPUT_HTML, include_plotlyjs="cdn")
print(f"\u2705 Interactive HTML saved to {OUTPUT_HTML}")

# To export static PNG (needs Chrome via kaleido); best-effort, the HTML is already saved.
try:
    pio.write_image(fig, OUTPUT_PNG, width=1200, height=900, scale=3)
    print(f"\u2705 Saved static image to {OUTPUT_PNG}")
except Exception as e:
    print(f"\u26a0\ufe0f Could not export PNG automatically: {e}")
    print("   You can still open the HTML and use the camera icon to download manually.")


In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Hybrid pie chart:
- Inside labels for large slices
- Outside labels for small slices
- Shows both counts + percentages
"""

import os
import pandas as pd
import plotly.graph_objects as go
import plotly.io as pio

# ===============================
# 1) INPUT & OUTPUT PATHS
# ===============================
INPUT_SUMMARY = "outputs/grouped_tree/final/tree_summary.csv"
OUTPUT_HTML, OUTPUT_PNG = (
    "outputs/grouped_tree/final/immune_mechanism_pie_hybrid.html",
    "outputs/grouped_tree/final/immune_mechanism_pie_hybrid.png",
)

# ===============================
# 2) LOAD & PREPARE DATA
# ===============================
# Slices below this share of the total (in percent) get outside labels and a small pull-out.
SMALL_SLICE_PCT = 5

df = pd.read_csv(INPUT_SUMMARY)
# Family_Total repeats on every branch row of a family; max() recovers the single value.
df_family = df.groupby("Family", as_index=False)["Family_Total"].max()
df_family = df_family.sort_values("Family_Total", ascending=False)

grand_total = df_family["Family_Total"].sum()
# Guard an all-zero (or empty) summary so percentages are 0.0 rather than NaN.
df_family["percent"] = 100 * df_family["Family_Total"] / (grand_total or 1)
df_family["label_text"] = [
    f"{fam}<br>{count} ({pct:.1f}%)"
    for fam, count, pct in zip(df_family["Family"], df_family["Family_Total"], df_family["percent"])
]

# dynamically choose text position: small slices are labeled outside the pie
is_small = df_family["percent"] < SMALL_SLICE_PCT
df_family["textposition"] = ["outside" if small else "inside" for small in is_small]

# ===============================
# 3) BUILD PIE CHART
# ===============================
fig = go.Figure(
    data=[
        go.Pie(
            labels=df_family["Family"],
            values=df_family["Family_Total"],
            text=df_family["label_text"],
            textinfo="text",
            textposition=df_family["textposition"],
            insidetextorientation="radial",
            # <extra></extra> hides the secondary trace-name box, matching the other charts.
            hovertemplate="<b>%{label}</b><br>%{value} counts<br>%{percent}<extra></extra>",
            marker=dict(line=dict(color="white", width=1.5)),
            pull=[0.03 if small else 0 for small in is_small],  # small slices slightly pulled out
        )
    ]
)

# ===============================
# 4) LAYOUT
# ===============================
fig.update_layout(
    title={"text": "Immune Mechanism Family Distribution", "x": 0.5, "font": {"size": 22}},
    showlegend=False,  # slices carry their own labels
    width=1200,
    height=900,
    # Wider right margin leaves room for the outside labels of small slices.
    margin={"t": 100, "b": 100, "l": 100, "r": 150},
)

# ===============================
# 5) SAVE OUTPUTS
# ===============================
os.makedirs(os.path.dirname(OUTPUT_HTML), exist_ok=True)
pio.write_html(fig, OUTPUT_HTML, include_plotlyjs="cdn")
print(f"\u2705 Interactive HTML saved to {OUTPUT_HTML}")

# Static PNG is best-effort (kaleido/Chrome may be unavailable); the HTML is already saved.
try:
    pio.write_image(fig, OUTPUT_PNG, width=1200, height=900, scale=3)
    print(f"\u2705 Saved static image to {OUTPUT_PNG}")
except Exception as e:
    print(f"\u26a0\ufe0f Could not export PNG automatically: {e}")
    print("   You can still open the HTML and click the camera icon to download manually.")


In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Integrate immune mechanism tree with LLM-derived adjuvant mappings
to generate a Mechanism × Adjuvant heatmap.

X-axis  = immune mechanism family
Y-axis  = adjuvant
Cell    = count of mechanism mentions (from LLM text)
"""

import os, re, json
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict

# ======================================================
# 1) Load canonical immune mechanism families
# ======================================================
tree_summary = "outputs/grouped_tree/final/tree_summary.csv"
df_tree = pd.read_csv(tree_summary)

# NOTE(review): canonical_families is not referenced again in this cell; the regex
# table below is maintained by hand and may drift from the tree's family names.
canonical_families = list(df_tree["Family"].unique())

# For regex-based matching: families are tried in dict order with re.I, and the first hit wins.
# Short acronyms are anchored with \b (or (?!\d) where a trailing digit names another marker).
# This keeps them from firing inside unrelated words: "NOD" in "lymph node", "STING" in
# "boosting", "RIG" in "trigger", "NK" in "linking", "t cell" in "mast cell", CD4/CD8 in CD40/CD80/CD86.
FAMILY_PATTERNS = {
    "T cell activation / polarization": [
        r"\bT[- ]?cell", r"\bTh1\b", r"\bTh2\b", r"\bTh17\b", r"\bCD4(?!\d)", r"\bCD8(?!\d)",
    ],
    "Dendritic cell activation": [r"dendritic", r"\b(?:mo|[cmp])?DCs?\b"],
    "TLR signaling": [r"\bTLR", r"toll-?like receptor"],
    "Cytokine signaling / production": [r"cytokine", r"interleukin", r"\bIFN", r"\bTNF"],
    "Macrophage / innate immune activation": [r"macrophage", r"innate", r"\bNK", r"monocyte"],
    "Pattern recognition / PRR sensing": [r"\bPRR\b", r"pattern recognition", r"\bNOD(?:s|\d+)?\b", r"\bRIG\b"],
    "NLRP3 inflammasome activation": [r"NLRP3", r"inflammasome"],
    "Antigen presentation / APCs": [r"\bAPCs?\b", r"antigen", r"presentation", r"\bMHC"],
    "B cell / antibody production": [r"\bB[- ]?cell", r"antibody", r"humoral"],
    "Complement / depot / formulation": [r"complement", r"depot", r"alum", r"formulation"],
    "STING / TRIF / MyD88 / RIG-I signaling": [r"\bSTING\b", r"\bTRIF\b", r"\bMyD88\b", r"\bRIG\b"],
    # \u03ba is Greek kappa ("NF-<kappa>B"); written as an escape so it survives re-encoding.
    "Inflammatory response": [r"inflamm", r"NF[- ]?\u03baB", r"NF[- ]?kB"],
    "Adjuvant synergy / immune modulation": [r"synerg", r"modulat", r"enhanc", r"co-?stimul"],
}

# ======================================================
# 2) Parse LLM-generated JSON-like data
# ======================================================
file_path = "outputs/Vaxjo_PMIDs_mechanism_summary_raw_outputs_llama3.2.txt"

with open(file_path, "r", encoding="utf-8") as f:
    content = f.read()


def _iter_adjuvant_objects(text):
    """Yield each JSON object in *text* that has both "adjuvant" and "mechanism_subtypes" keys.

    The LLM output mixes prose with JSON whose "mechanism_subtypes" list holds nested
    objects. Splitting on "{" / "}" cannot reassemble such objects, so instead each "{"
    is handed to json.JSONDecoder.raw_decode. It parses one complete value starting
    there and reports where that value ends.
    """
    decoder = json.JSONDecoder()
    pos = text.find("{")
    while pos != -1:
        try:
            obj, end = decoder.raw_decode(text, pos)
        except json.JSONDecodeError:
            obj, end = None, None
        if isinstance(obj, dict) and "adjuvant" in obj and "mechanism_subtypes" in obj:
            yield obj
            next_start = end  # skip past it so its nested subtype objects are not rescanned
        else:
            # Invalid JSON here, or a wrapper object: step inside to look for nested candidates.
            next_start = pos + 1
        pos = text.find("{", next_start)


adjuvant_data = list(_iter_adjuvant_objects(content))

print(f"Parsed {len(adjuvant_data)} adjuvants from LLM output")

# ======================================================
# 3) Map each mechanism subtype -> family
# ======================================================
records = []
for entry in adjuvant_data:
    # "or" guards against explicit nulls in the LLM JSON.
    adjuvant = str(entry.get("adjuvant") or "").strip()
    subtypes = entry.get("mechanism_subtypes") or []
    if isinstance(subtypes, (str, dict)):
        subtypes = [subtypes]  # a single subtype given without a list
    for sub in subtypes:
        # Subtypes are normally {"mechanism subtype": ...} objects; also accept the
        # underscore spelling of the key and bare strings.
        if isinstance(sub, dict):
            mech = sub.get("mechanism subtype") or sub.get("mechanism_subtype") or ""
        else:
            mech = sub if isinstance(sub, str) else ""
        mech = str(mech).strip()
        fam_match = next(
            (
                fam
                for fam, pats in FAMILY_PATTERNS.items()
                if any(re.search(p, mech, flags=re.I) for p in pats)
            ),
            "Other / Unclassified",
        )
        records.append((adjuvant, fam_match))

df = pd.DataFrame(records, columns=["Adjuvant", "Family"])

# ======================================================
# 4) Aggregate counts
# ======================================================
# Count mentions per (adjuvant, family) pair, then spread families into columns (0 = none).
pair_counts = df.groupby(["Adjuvant", "Family"]).size().reset_index(name="Count")
heatmap_df = (
    pair_counts.pivot(index="Adjuvant", columns="Family", values="Count")
    .fillna(0)
    .astype(int)
)

# Keep top adjuvants for readability: the 25 with the most mentions overall.
top_adjuvants = heatmap_df.sum(axis=1).sort_values(ascending=False).head(25).index
heatmap_df = heatmap_df.loc[top_adjuvants]

# ======================================================
# 5) Plot heatmap
# ======================================================
if heatmap_df.empty:
    # Nothing was parsed or mapped; seaborn errors on an empty matrix, so report and skip.
    print("No adjuvant/mechanism records to plot - heatmap skipped.")
else:
    # Height grows with the number of adjuvant rows (min 8 inches).
    plt.figure(figsize=(16, max(8, len(heatmap_df) * 0.4)))
    sns.heatmap(
        heatmap_df,
        cmap="YlGnBu",
        linewidths=0.5,
        annot=True,
        fmt=".0f",
        cbar_kws={"label": "# of Mechanism Mentions"}
    )
    # \u2013 = en dash (escape keeps the title intact across re-encodings).
    plt.title("Mechanism\u2013Adjuvant Association Heatmap", fontsize=18, pad=20)
    plt.xlabel("Immune Mechanism Family", fontsize=13)
    plt.ylabel("Adjuvant", fontsize=13)
    plt.xticks(rotation=45, ha="right")
    plt.tight_layout()
    plt.show()


In [None]:
# Install seaborn into the active kernel (it is imported by the heatmap cell above).
# NOTE(review): that cell imports seaborn before this install runs; on a fresh
# environment run this cell first.
%pip install seaborn

In [None]:
import pandas as pd

# Load the CSV
df = pd.read_csv("outputs/Vaxjo_PMIDs_mechanism_subtypes_frequency.csv")

# Name of the count column; adjust if the CSV uses a different header.
FREQ_COLUMN = "Frequency"
if FREQ_COLUMN not in df.columns:
    raise KeyError(f"Column {FREQ_COLUMN!r} not found; available columns: {list(df.columns)}")
total = df[FREQ_COLUMN].sum()

# \U0001f522 is the "input numbers" emoji, written as an escape so it survives re-encoding.
print(f"\U0001f522 Total Frequency = {total:,}")
