# NETWORK PLOTS BASED ON THE IMPORTANCE OF CELL POPULATIONS IN CLASSIFICATION

### Required Libraries

In [None]:
import pandas as pd
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
from matplotlib import cm
from sklearn.manifold import TSNE

### Read Data

In [None]:
# Input tables.
# CONS_FILE is read into `cons`; later cells use its "Consolidated Name",
# "Category", "reagent" and "median" columns.
CONS_FILE = "Consolidated_A_B_minimal.csv"
# IMP_A_FILE / IMP_B_FILE are read into `impA` / `impB`; later cells use their
# "population" and "neglog10p" columns.
IMP_A_FILE = "univariate_importance_classifier_WCCT.csv"
IMP_B_FILE = "univariate_importance_classifier_VXA.csv"
#Parameters
# Minimum absolute Spearman correlation for an edge to be drawn (see build_graph).
EDGE_CUTOFF = 0.70
# Settings forwarded to sklearn's TSNE when computing the shared layout.
TSNE_SEED = 42
TSNE_PERPLEXITY = 10
TSNE_MAX_ITER = 1200
# Range that scale_sizes maps importance values onto (matplotlib scatter areas).
SIZE_MIN, SIZE_MAX = 200, 1800
# Resolution used when draw_graph saves the PNG.
DPI = 300

In [None]:
# Load the consolidated profile table and both importance tables in one pass;
# the unpacking order follows the order of the path tuple.
cons, impA, impB = (
    pd.read_csv(csv_path) for csv_path in (CONS_FILE, IMP_A_FILE, IMP_B_FILE)
)

### Cleaning

In [None]:
def clean_text(x: str) -> str:
    """Normalise a population name so it can be used as a join key.

    The value is converted with ``str()``, every ``_`` and ``.`` is turned
    into a space, and surrounding whitespace is removed.

    Stripping happens *after* the separator replacement: a leading or
    trailing ``_``/``.`` would otherwise be converted into a space that
    survives in the key and silently breaks name matching between tables.
    """
    return str(x).replace("_", " ").replace(".", " ").strip()

In [None]:
# Apply the same name normalisation to the key column of every table so the
# population names can be intersected afterwards.
for _frame, _key_col in (
    (cons, "Consolidated Name"),
    (impA, "population"),
    (impB, "population"),
):
    _frame[_key_col] = _frame[_key_col].apply(clean_text)

### Match cell populations

In [None]:
# Keep only the populations that are present in all three tables.
valid_pops = sorted(
    set(cons["Consolidated Name"]).intersection(impA["population"], impB["population"])
)
cons = cons.loc[cons["Consolidated Name"].isin(valid_pops)]
impA = impA.loc[impA["population"].isin(valid_pops)]
impB = impB.loc[impB["population"].isin(valid_pops)]

### Category colors

In [None]:
def norm_cat(c):
    """Return a printable category label.

    Missing values (as detected by ``pd.isna``) become ``"Other"``; anything
    else is converted to ``str`` and stripped of surrounding whitespace.
    """
    if pd.isna(c):
        return "Other"
    return str(c).strip()

In [None]:
import matplotlib as mpl
import matplotlib.cm as cm
from matplotlib.colors import to_hex

# Normalise the category labels before anything depends on them. Without this
# a missing Category makes sorted() fail (float NaN cannot be ordered against
# str labels) and the later groupby on "Category" silently drops those rows.
# assign() builds a new frame, so the filtered slice is not written in place.
cons = cons.assign(Category=cons["Category"].map(norm_cat))
cats = sorted(cons["Category"].unique())

# One colour per category, sampled evenly from the tab20 palette.
cmap = mpl.colormaps["tab20"].resampled(len(cats))
CAT2COLOR = {cat: to_hex(cmap(i)) for i, cat in enumerate(cats)}
def build_profiles(df_cons):
    """Build a z-scored population x reagent profile matrix.

    Rows are "Consolidated Name", columns are "reagent", and each cell is the
    mean of the "median" values for that pair. Every reagent column is then
    centred and divided by its population standard deviation; a zero standard
    deviation is replaced by 1.0 so constant columns become all zeros.
    Remaining missing cells are filled with 0.0.
    """
    profiles = df_cons.pivot_table(
        index="Consolidated Name",
        columns="reagent",
        values="median",
        aggfunc="mean",
    )

    def _zscore(column):
        spread = column.std(ddof=0) or 1.0
        return (column - column.mean()) / spread

    standardised = profiles.apply(_zscore)
    return standardised.fillna(0.0)
def corr_populations(mat):
    """Return the Spearman correlation between the rows of ``mat``.

    ``mat`` is transposed first so that each row (population) becomes a
    column, which is what ``DataFrame.corr`` correlates.
    """
    by_population = mat.transpose()
    return by_population.corr(method="spearman")

### Pooled data

In [None]:
# Average the "median" values per (Category, population, reagent) combination.
pooled_cons = cons.groupby(
    ["Category", "Consolidated Name", "reagent"], as_index=False
).agg(median=("median", "mean"))
pooled_cons["Study"] = "Pooled"

# Standardised population x reagent matrix; its row order fixes the node order.
mat_pooled = build_profiles(pooled_cons)
pops_order = mat_pooled.index.tolist()

# 2-D t-SNE embedding of the profiles, seeded so the layout is repeatable.
tsne = TSNE(
    n_components=2,
    perplexity=TSNE_PERPLEXITY,
    max_iter=TSNE_MAX_ITER,
    random_state=TSNE_SEED,
    init="random",
    learning_rate="auto",
)
xy = tsne.fit_transform(mat_pooled.values)

# Map each population name to its embedded coordinate (row i of xy).
pos_tsne = {name: xy[row] for row, name in enumerate(pops_order)}

In [None]:
def repel_positions(pos, sizes, pad=0.06, steps=1200, lr=0.05):
    """Push overlapping nodes apart and return the adjusted layout.

    Parameters
    ----------
    pos : dict
        Node name -> 2-D coordinate. Every key must also be present in
        ``sizes``.
    sizes : dict
        Node name -> marker size; only used to derive a per-node radius.
    pad : float
        Extra clearance added to the sum of two radii.
    steps : int
        Number of relaxation sweeps over all node pairs.
    lr : float
        Fraction of the overlap corrected per pair and sweep.

    Returns
    -------
    dict
        Node name -> coordinate, centred on the origin and scaled by 1.8.
    """
    labels = list(pos.keys())
    pts = np.array([pos[label] for label in labels])
    # Rescale each axis to roughly [0, 1]; the epsilon guards a zero range.
    pts = (pts - pts.min(0)) / (pts.max(0) - pts.min(0) + 1e-9)

    # Radii grow with sqrt(size) and are mapped onto [0.02, 0.07].
    rad = np.sqrt(np.array([sizes[label] for label in labels]))
    rad = (rad - rad.min()) / (rad.max() - rad.min() + 1e-9) * 0.05 + 0.02

    count = len(pts)
    for _sweep in range(steps):
        for a in range(count):
            for b in range(a + 1, count):
                d = pts[a] - pts[b]
                dist = np.linalg.norm(d) + 1e-12
                min_d = rad[a] + rad[b] + pad
                if dist >= min_d:
                    continue
                # Move both nodes by the same amount in opposite directions.
                push = (min_d - dist) * (d / dist)
                shift = lr * push
                pts[a] += shift
                pts[b] -= shift

    pts = (pts - pts.mean(0)) * 1.8
    return dict(zip(labels, pts))

In [None]:
def widen_positions(pos: dict, factor: float = 0.5) -> dict:
    """Scale a 2-D layout about its centroid.

    Each coordinate is moved to ``centroid + factor * (point - centroid)``,
    so a factor below 1 contracts the layout and a factor above 1 expands it.
    """
    stacked = np.vstack([np.asarray(xy, float) for xy in pos.values()])
    cx, cy = stacked.mean(axis=0)
    centre = [cx, cy]

    rescaled = {}
    for key, xy in pos.items():
        offset = np.asarray(xy, float) - centre
        rescaled[key] = offset * factor + centre
    return rescaled

In [None]:
def scale_sizes(values):
    """Map importance values linearly onto the node-size range.

    Parameters
    ----------
    values : pandas.Series
        Importance scores indexed by population.

    Returns
    -------
    (pandas.Series, dict)
        Marker sizes in ``[SIZE_MIN, SIZE_MAX]`` with the same index, and the
        ``min`` / ``median`` / ``max`` of the cleaned values for the legend.

    Cleaning rules: ``+inf`` is capped at the largest finite value (it used
    to become 0, which drew the strongest population as the smallest node);
    ``-inf`` and NaN become 0; negative values are clipped to 0.
    """
    finite_only = values.replace([np.inf, -np.inf], np.nan)
    cap = finite_only.max()
    if pd.isna(cap):
        # Nothing finite to cap against; fall back to the neutral value.
        cap = 0.0
    v = (
        values.replace(np.inf, cap)
        .replace(-np.inf, np.nan)
        .fillna(0)
        .clip(lower=0)
    )
    vmin, vmed, vmax = float(v.min()), float(v.median()), float(v.max())
    if vmax == vmin:
        # No spread: every node gets the minimum size (float, like the other branch).
        scaled = pd.Series(float(SIZE_MIN), index=v.index)
    else:
        scaled = SIZE_MIN + (v - vmin) / (vmax - vmin) * (SIZE_MAX - SIZE_MIN)
    return scaled, {"min": vmin, "median": vmed, "max": vmax}

### Building Networks

In [None]:
def build_graph(cons_subset, imp_table, imp_col, pos_base):
    """Build the correlation network for one importance table.

    Parameters
    ----------
    cons_subset : pandas.DataFrame
        Rows with "Consolidated Name", "reagent" and "median" columns (and
        normally "Category"); turned into profiles by ``build_profiles``.
    imp_table : pandas.DataFrame
        Table with a "population" column and the importance column.
    imp_col : str
        Name of the importance column used for node sizes.
    pos_base : dict
        Population name -> 2-D coordinate used as the starting layout.

    Returns
    -------
    (networkx.Graph, dict, dict)
        The graph (nodes carry ``size`` and ``color``, edges carry
        ``weight``), the de-overlapped layout, and the legend statistics
        returned by ``scale_sizes``.
    """
    mat = build_profiles(cons_subset)
    corr = corr_populations(mat)
    pops = mat.index.tolist()

    # Populations missing from the importance table get importance 0.
    imp_series = imp_table.set_index("population")[imp_col].reindex(pops).fillna(0)
    sizes, legend_stats = scale_sizes(imp_series)

    # Colour by the categories of the data actually being plotted; only fall
    # back to the module-level table when the subset carries no Category.
    cat_source = cons_subset if "Category" in cons_subset.columns else cons
    cat_map = cat_source.set_index("Consolidated Name")["Category"].to_dict()
    colors = {p: CAT2COLOR.get(cat_map.get(p, "Other"), "#BBBBBB") for p in pops}

    G = nx.Graph()
    for p in pops:
        G.add_node(p, size=float(sizes.loc[p]), color=colors[p])

    # Scan the upper triangle once (row-major, i < j) and keep finite
    # correlations whose magnitude reaches the cutoff.
    corr_values = corr.to_numpy()
    rows, cols = np.triu_indices(len(pops), k=1)
    pair_vals = corr_values[rows, cols]
    keep = np.isfinite(pair_vals) & (np.abs(pair_vals) >= EDGE_CUTOFF)
    for i, j, val in zip(rows[keep], cols[keep], pair_vals[keep]):
        G.add_edge(pops[i], pops[j], weight=float(val))

    # Lay out only the populations that are nodes of this graph; extra keys in
    # pos_base would otherwise raise KeyError inside repel_positions.
    node_sizes = {p: float(sizes.loc[p]) for p in pops}
    layout_seed = {p: xy for p, xy in pos_base.items() if p in node_sizes}
    pos_clean = repel_positions(layout_seed, node_sizes)
    return G, pos_clean, legend_stats

In [None]:
def draw_graph(G, pos, title, legend_stats, png_name):
    """Draw the network, save it as a PNG and show it.

    Parameters
    ----------
    G : networkx.Graph
        Graph whose nodes carry ``size`` and ``color`` attributes.
    pos : dict
        Node name -> 2-D coordinate; rescaled with ``widen_positions``.
    title : str
        Figure title.
    legend_stats : dict
        ``min`` / ``median`` / ``max`` importance values for the size legend.
    png_name : str
        Output path passed to ``plt.savefig``.
    """
    plt.figure(figsize=(10, 8))
    pos = widen_positions(pos, factor=0.5)
    nx.draw_networkx_edges(G, pos, alpha=0.35, width=0.9, edge_color="#666666")
    ns = [G.nodes[n]["size"] for n in G.nodes()]
    nc = [G.nodes[n]["color"] for n in G.nodes()]
    nx.draw_networkx_nodes(G, pos, node_size=ns, node_color=nc,
                           edgecolors="white", linewidths=0.7, alpha=0.98)

    # Legends
    handles = [plt.Line2D([0], [0], marker='o', color='w', label=cat,
                          markerfacecolor=CAT2COLOR[cat], markersize=9)
               for cat in cats]
    leg1 = plt.legend(handles=handles, title="Category", loc="upper left",
                      bbox_to_anchor=(-0.02, 0.25), frameon=False)
    # A second plt.legend() call replaces the first, so pin it as an artist.
    plt.gca().add_artist(leg1)

    v_min = legend_stats["min"]
    v_med = legend_stats["median"]
    v_max = legend_stats["max"]
    span = v_max - v_min

    def _legend_area(value):
        # Same linear mapping as scale_sizes, so each legend marker has the
        # size a node with that importance is actually drawn with. With no
        # spread (or NaN stats) every node is SIZE_MIN.
        if not span > 0:
            return float(SIZE_MIN)
        return SIZE_MIN + (value - v_min) / span * (SIZE_MAX - SIZE_MIN)

    for lab, sval in [
        (f"min = {v_min:.2f}", _legend_area(v_min)),
        (f"med = {v_med:.2f}", _legend_area(v_med)),
        (f"max = {v_max:.2f}", _legend_area(v_max)),
    ]:
        # Empty scatter: contributes a legend entry without plotting a point.
        plt.scatter([], [], s=sval, c="#777777", edgecolors="white", linewidths=0.6, label=lab)
    plt.legend(scatterpoints=1, frameon=False, labelspacing=1.0,
               title="Importance (−log10 p)", loc="upper right")

    plt.title(title, fontsize=18, fontweight="bold")
    plt.axis("off")
    plt.tight_layout()
    plt.savefig(png_name, dpi=DPI)
    plt.show()

In [None]:
# Network from the pooled profiles; node sizes come from the "neglog10p"
# column of impA and the layout starts from the shared t-SNE coordinates.
G_A, posA, legA = build_graph(pooled_cons, impA, "neglog10p", pos_tsne)
# Render the figure and write it to Network_Classifier_A.png.
draw_graph(G_A, posA, "Network — node size from Study A (Classifier)", legA, "Network_Classifier_A.png")

In [None]:
# Same pooled profiles and starting layout; node sizes come from the
# "neglog10p" column of impB instead.
G_B, posB, legB = build_graph(pooled_cons, impB, "neglog10p", pos_tsne)
# Render the figure and write it to Network_Classifier_B.png.
draw_graph(G_B, posB, "Network — node size from Study B (Classifier)", legB, "Network_Classifier_B.png")