
# Phase 3 - Extended Visualizations

This notebook adds richer plots for the phenotypes.
Inputs are read from the artifacts directory `ART` set in the first code cell (`phase3_outputs/` under `BASE`):
- phase3_integrated_data.csv
- cluster_labels.csv
- gower_distance.npy
- k_silhouette_scan.csv
- pca_loadings.csv (optional)
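- umap_embeddings.csv or tsne_embeddings.csv (one of them is enough; UMAP is used when both exist)
- stability_report.txt (optional; printed by the last cell if present)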

Figures are written to the `figs/` subfolder of the artifacts directory (`FIGS` in the first code cell).


In [None]:

from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.cluster.hierarchy import linkage, fcluster, dendrogram




# Project locations.
# NOTE(review): BASE is an absolute, machine-specific path; edit it when
# running this notebook anywhere else.
BASE = Path("C:/Users/HP/OneDrive/Desktop/VERO_code/Phase_3")
# Directory holding the upstream outputs that every cell below reads from.
ART = BASE / "phase3_outputs"
# All figures produced by this notebook are saved here.
FIGS = ART / "figs"
# Only the figs folder is created; ART must already exist because the inputs
# are read from it (no parents=True, so a wrong BASE fails loudly here).
FIGS.mkdir(exist_ok=True)

# Integrated feature table, plus the assignments stored in the "cluster"
# column of cluster_labels.csv (as a NumPy array).
df = pd.read_csv(ART / "phase3_integrated_data.csv")
labels = pd.read_csv(ART / "cluster_labels.csv")["cluster"].values

# Later cells draw `labels` and columns of `df` against the same embedding
# rows, so a length mismatch means the plots would be misaligned.
if len(labels) != len(df):
    print(
        f"WARNING: cluster_labels.csv has {len(labels)} rows but "
        f"phase3_integrated_data.csv has {len(df)}; plots may be misaligned."
    )

print("Data shape:", df.shape, "Clusters:", np.unique(labels, return_counts=True))


In [None]:

# 1) Embedding maps by cluster and by outcomes (size encoding)
# Pick the 2-D embedding to plot: UMAP is preferred, t-SNE is the fallback.
# Both names stay None when neither file exists.
emb_df, name = None, None
for emb_file, emb_tag in (("umap_embeddings.csv", "UMAP"), ("tsne_embeddings.csv", "tSNE")):
    emb_path = ART / emb_file
    if emb_path.exists():
        emb_df, name = pd.read_csv(emb_path), emb_tag
        break

def find_cols(patterns):
    """Return the columns of the global ``df`` whose lower-cased name contains a pattern.

    patterns -- iterable of substrings; they are compared against the
                lower-cased column name, so they should be lower-case.

    Returns a list of column names in ``df.columns`` order, each name at most
    once even if ``df`` has duplicate column labels.
    """
    matched = {}  # dict keys keep insertion order and drop duplicates
    for column in df.columns:
        lowered = column.lower()
        for pattern in patterns:
            if pattern in lowered:
                matched[column] = None
                break
    return list(matched)

def _scatter_embedding(title, filename, **scatter_kwargs):
    """Scatter the first two columns of ``emb_df``, save the figure, and show it.

    title          -- figure title
    filename       -- name of the PNG written into ``FIGS``
    scatter_kwargs -- forwarded to ``Axes.scatter`` (e.g. ``c=...``, ``s=...``)
    """
    fig, ax = plt.subplots(figsize=(6, 5))
    ax.scatter(emb_df.iloc[:, 0], emb_df.iloc[:, 1], **scatter_kwargs)
    ax.set_xlabel(emb_df.columns[0])
    ax.set_ylabel(emb_df.columns[1])
    ax.set_title(title)
    fig.tight_layout()
    fig.savefig(FIGS / filename, dpi=150)
    plt.show()


if emb_df is not None:
    # colored by cluster
    _scatter_embedding(f"{name} by cluster", f"{name.lower()}_by_cluster.png", c=labels, s=14)

    # size by a binary outcome if available
    cand_bin = find_cols(["readmission","adr","toxicity","frailty","mortality","death","event"])
    # The substring patterns also hit non-binary columns (e.g. "frailty" matches
    # "frailty_score"), so take the first candidate that really is 0/1 instead
    # of blindly using cand_bin[0].
    col = next(
        (
            c for c in cand_bin
            if df[c].notna().any() and set(df[c].dropna().unique()).issubset({0, 1})
        ),
        None,
    )
    if col is not None:
        # Missing outcomes are drawn like 0 (small marker).
        vals = df[col].fillna(0).astype(int).values
        sizes = np.where(vals == 1, 24, 8)
        _scatter_embedding(
            f"{name} marker size by {col} (1 larger)",
            f"{name.lower()}_size_{col}.png",
            s=sizes,
        )
    elif cand_bin:
        print("No 0/1-coded outcome among", cand_bin, "- skipping binary size plot.")

    # size by a continuous outcome if available
    cand_cont = find_cols(["frailty_score","risk_index","survival_time","time_to_event"])
    # Only numeric columns can be min-max scaled into marker sizes.
    col = next((c for c in cand_cont if pd.api.types.is_numeric_dtype(df[c])), None)
    if col is not None:
        # Cast to float so bool / nullable-integer columns support the arithmetic below.
        v = df[col].astype(float).to_numpy()
        if np.isnan(v).all():
            lo = hi = 0.0  # nothing to scale; avoids the all-NaN nanmin/nanmax warning
        else:
            lo, hi = np.nanmin(v), np.nanmax(v)
        if hi > lo:
            # Linear map of [lo, hi] onto marker areas [8, 30].
            s = 8 + 22 * (v - lo) / (hi - lo)
        else:
            # Constant (or empty) column: one uniform marker size.
            s = np.full(v.shape, 10.0)
        _scatter_embedding(
            f"{name} marker size by {col}",
            f"{name.lower()}_size_{col}_cont.png",
            s=s,
        )
    elif cand_cont:
        print("No numeric outcome among", cand_cont, "- skipping continuous size plot.")
else:
    print("No embedding CSV found. Skipping embedding plots.")


In [None]:

# 2) Silhouette plot for the best k
KSCAN = ART / "k_silhouette_scan.csv"
GOWER = ART / "gower_distance.npy"

if KSCAN.exists() and GOWER.exists():
    # scikit-learn is only needed by this cell, so it is imported here.
    from sklearn.metrics import silhouette_samples

    def square_to_condensed(square):
        """Return the strict upper triangle of a square matrix as a 1-D array, row by row.

        This is the condensed layout passed to ``linkage`` below.
        """
        idx = np.triu_indices_from(square, k=1)
        return square[idx]

    # Best k = the row of the scan with the highest silhouette.
    ks = pd.read_csv(KSCAN).sort_values("silhouette", ascending=False)
    best_k = int(ks.iloc[0]["k"])
    D = np.load(GOWER)

    # Re-cluster with average linkage and cut the tree into at most best_k groups.
    Z = linkage(square_to_condensed(D), method="average")
    lab = fcluster(Z, best_k, criterion="maxclust")

    # "maxclust" may yield fewer groups than requested; silhouette_samples
    # needs at least 2 clusters and fewer clusters than samples.
    n_found = len(np.unique(lab))
    if not 1 < n_found < len(lab):
        print(
            f"Hierarchical cut for k={best_k} produced {n_found} cluster(s) "
            f"for {len(lab)} samples. Skipping silhouette plot."
        )
    else:
        sil_vals = silhouette_samples(D, lab, metric="precomputed")

        plt.figure(figsize=(7,4))
        y_lower = 10
        for i in sorted(np.unique(lab)):
            # One horizontal band per cluster, values sorted ascending.
            ith = sil_vals[lab == i]
            ith.sort()
            size = ith.shape[0]
            y_upper = y_lower + size
            plt.fill_betweenx(np.arange(y_lower, y_upper), 0, ith)
            y_lower = y_upper + 10  # 10-row gap between consecutive bands
        # Dashed line marks the mean silhouette over all samples.
        plt.axvline(np.mean(sil_vals), linestyle="--")
        plt.xlabel("Silhouette coefficient")
        plt.ylabel("Samples grouped by cluster")
        plt.title(f"Silhouette plot (k={best_k})")
        plt.tight_layout()
        plt.savefig(FIGS / "silhouette_plot.png", dpi=150)
        plt.show()
else:
    print("k_silhouette_scan.csv or gower_distance.npy missing. Skipping silhouette plot.")


In [None]:

# 3) Dendrogram from the Gower distance
GOWER = ART / "gower_distance.npy"
if not GOWER.exists():
    print("gower_distance.npy missing. Skipping dendrogram.")
else:
    D = np.load(GOWER)

    def square_to_condensed(square):
        """Flatten the strict upper triangle of a square matrix, row by row."""
        rows, cols = np.triu_indices_from(square, k=1)
        return square[rows, cols]

    # Average-linkage tree built from the condensed distances.
    Z = linkage(square_to_condensed(D), method="average")

    fig, ax = plt.subplots(figsize=(8, 4))
    dendrogram(Z, no_labels=True, color_threshold=0.0, ax=ax)
    ax.set_title("Hierarchical dendrogram (average linkage)")
    fig.tight_layout()
    fig.savefig(FIGS / "dendrogram.png", dpi=150)
    plt.show()


In [None]:

# 4) Distance heatmap ordered by cluster
GOWER = ART / "gower_distance.npy"
if not GOWER.exists():
    print("gower_distance.npy missing. Skipping heatmap.")
else:
    D = np.load(GOWER)
    # Permute rows and columns together so samples sharing a label are adjacent.
    order = np.argsort(labels)
    D_ord = D[np.ix_(order, order)]

    fig, ax = plt.subplots(figsize=(6, 5))
    im = ax.imshow(D_ord, cmap="viridis")
    ax.set_title("Gower distance heatmap (ordered by cluster)")
    fig.colorbar(im, ax=ax)
    fig.tight_layout()
    fig.savefig(FIGS / "gower_heatmap_by_cluster.png", dpi=150)
    plt.show()


In [None]:

# 5) PCA biplot from loadings (if available)
LOAD = ART / "pca_loadings.csv"
if LOAD.exists():
    # NOTE(review): this cell reads row 0 as PC1 and row 1 as PC2, with one
    # column per feature -- confirm against the notebook that writes the file.
    loadings_raw = pd.read_csv(LOAD)
    # Keep numeric columns only: a text column (e.g. saved component names)
    # cannot be drawn as an arrow and would make plt.arrow fail.
    loadings = loadings_raw.select_dtypes(include=[np.number])
    dropped = [c for c in loadings_raw.columns if c not in loadings.columns]
    if dropped:
        print("Ignoring non-numeric columns in pca_loadings.csv:", dropped)

    if loadings.shape[0] >= 2 and loadings.shape[1] > 0:
        pc1 = loadings.iloc[0].values
        pc2 = loadings.iloc[1].values
        feats = loadings.columns.tolist()
        # Label roughly 20 arrows at most so the text stays readable.
        label_every = max(1, len(feats)//20)

        plt.figure(figsize=(7,6))
        plt.axhline(0, color="k", linewidth=0.5)
        plt.axvline(0, color="k", linewidth=0.5)
        for i, f in enumerate(feats):
            plt.arrow(0, 0, pc1[i], pc2[i], head_width=0.02, length_includes_head=True)
            if i % label_every == 0:
                # Place the label slightly beyond the arrow tip.
                plt.text(pc1[i]*1.1, pc2[i]*1.1, f, fontsize=7)
        plt.xlabel("PC1 loading")
        plt.ylabel("PC2 loading")
        plt.title("PCA biplot (loadings)")
        plt.tight_layout()
        plt.savefig(FIGS / "pca_biplot_loadings.png", dpi=150)
        plt.show()
    else:
        print("pca_loadings.csv needs at least two rows of numeric loadings. Skipping biplot.")
else:
    print("pca_loadings.csv not found. Skipping biplot.")


In [None]:

# 6) Cluster profile heatmap. Top numeric features by between-cluster variance.
num_cols = [c for c in df.columns if c != "cluster" and pd.api.types.is_numeric_dtype(df[c])]
if num_cols:
    # Group on df["cluster"] when the table carries it; otherwise fall back to
    # `labels` (assumes cluster_labels.csv is row-aligned with df -- TODO confirm).
    if "cluster" in df.columns:
        cluster_key = df["cluster"]
    else:
        cluster_key = pd.Series(labels, index=df.index, name="cluster")

    # Per-cluster mean of every numeric feature, computed in one pass.
    cluster_means = df.groupby(cluster_key)[num_cols].mean()
    # Population variance (ddof=0) of those means across clusters, NaNs skipped.
    between_var = cluster_means.var(axis=0, ddof=0)
    # Drop undefined variances so they cannot disturb the ranking; the stable
    # sort keeps the original column order among ties.
    top_feats = (
        between_var.dropna()
        .sort_values(ascending=False, kind="stable")
        .head(20)
        .index.tolist()
    )

    if top_feats:
        prof = cluster_means[top_feats]
        # Min-max scale each feature across clusters. The inf/NaN cleanup must
        # apply to the quotient: a feature with zero range gives 0/0 = NaN,
        # which is shown as 0 instead of a blank cell.
        span = prof.max() - prof.min()
        prof = ((prof - prof.min()) / span).replace([np.inf, -np.inf], np.nan).fillna(0)

        fig, ax = plt.subplots(figsize=(max(8, len(top_feats)/2), 5))
        im = ax.imshow(prof.values, aspect="auto", cmap="magma")
        ax.set_yticks(range(prof.shape[0]))
        ax.set_yticklabels(prof.index.tolist())
        ax.set_xticks(range(prof.shape[1]))
        ax.set_xticklabels(prof.columns.tolist(), rotation=60, ha="right")
        ax.set_title("Cluster profiles (scaled means of top variance features)")
        fig.colorbar(im, ax=ax)
        fig.tight_layout()
        fig.savefig(FIGS / "cluster_profile_heatmap.png", dpi=150)
        plt.show()
    else:
        print("No numeric feature has a defined between-cluster variance. Skipping profile heatmap.")
else:
    print("No numeric features found for profile heatmap.")


In [None]:

# 7) Outcome separation quick visuals. Up to four binary and four continuous.
# NOTE(review): same helper as the find_cols defined in the embedding cell;
# re-defined here with identical behaviour.
def find_cols(patterns):
    """List the columns of the global ``df`` whose lower-cased name contains any pattern.

    patterns -- iterable of lower-case substrings to look for.

    The result follows ``df.columns`` order and contains each name once.
    """
    lowered_names = [(col, col.lower()) for col in df.columns]
    hits = [col for col, low in lowered_names if any(p in low for p in patterns)]
    # dict.fromkeys de-duplicates while keeping first-seen order.
    return list(dict.fromkeys(hits))

cand_bin = find_cols(["readmission","adr","toxicity","frailty","mortality","death","event"])
cand_cont = find_cols(["frailty_score","risk_index","survival_time","time_to_event"])

# Group on df["cluster"] when the table carries it; otherwise fall back to
# `labels` (assumes cluster_labels.csv is row-aligned with df -- TODO confirm).
if "cluster" in df.columns:
    cluster_key = df["cluster"]
else:
    cluster_key = pd.Series(labels, index=df.index, name="cluster")

# Filter first, then cap at four: the substring patterns also match
# non-binary columns (e.g. "frailty" hits "frailty_score"), which would
# otherwise use up the four slots.
binary_cols = [
    c for c in cand_bin
    if set(pd.Series(df[c]).dropna().unique()).issubset({0,1})
][:4]
for c in binary_cols:
    # Mean of a 0/1 column per cluster = event rate.
    rate = df.groupby(cluster_key)[c].mean()
    plt.figure(figsize=(5,3))
    plt.bar(rate.index.astype(str), rate.values)
    plt.title(f"Rate by cluster - {c}")
    plt.ylabel("Rate")
    plt.tight_layout()
    plt.savefig(FIGS / f"rate_by_cluster__{c}.png", dpi=150)
    plt.show()

numeric_cols = [c for c in cand_cont if pd.api.types.is_numeric_dtype(df[c])][:4]
for c in numeric_cols:
    plt.figure(figsize=(6,4))
    # One set of bin edges for all clusters so the overlaid histograms are
    # directly comparable (per-cluster bins=20 would give each its own widths).
    bins = np.histogram_bin_edges(df[c].dropna().astype(float).to_numpy(), bins=20)
    drew_any = False
    for cl in sorted(np.unique(labels)):
        v = df.loc[cluster_key == cl, c].dropna().astype(float).to_numpy()
        if len(v) > 0:
            plt.hist(v, bins=bins, alpha=0.5, label=f"C{cl}")
            drew_any = True
    if drew_any:
        # legend() warns when there is nothing labelled to show.
        plt.legend()
    plt.title(f"Distribution by cluster - {c}")
    plt.tight_layout()
    plt.savefig(FIGS / f"dist_by_cluster__{c}.png", dpi=150)
    plt.show()


In [None]:

# 8) Print stability summary if present
stab = ART / "stability_report.txt"
# Show the report verbatim when it exists, otherwise a hint on how to produce it.
message = (
    stab.read_text(encoding="utf-8")
    if stab.exists()
    else "stability_report.txt not found. Run stability step in phenotyping notebook to produce it."
)
print(message)
