In [None]:
import numpy as np
import pandas as pd

from biom import load_table, Table

from sklearn.preprocessing import StandardScaler
from scipy.stats import pearsonr, kruskal

import matplotlib.pyplot as plt
try:
    import ptitprince as pt
    HAS_PTITPRINCE = True
except Exception:
    HAS_PTITPRINCE = False
import seaborn as sns
import textwrap

In [None]:
# Read the metadata table: tab-separated, header on the first row, first column
# used as the index, and every value kept as a string.
metadata = pd.read_table(
    "tests/thdmi/thdmi_metadata_filtered_samples_clean_variables.tsv",
    dtype=str,
    header=0,
    index_col=0,
)

In [None]:
# Prefix of the score/loading tables read below; the plotting cells also use it
# to name the figures they save.
out_prefix = 'THDMI'

# Tab-separated tables whose first column becomes the index.
df = pd.read_table(f'{out_prefix}.scores.tsv', index_col=0)
X_loadings = pd.read_table(f'{out_prefix}.Lx.tsv', index_col=0)
Y_loadings = pd.read_table(f'{out_prefix}.Ly.tsv', index_col=0)

In [None]:
def plot_X_Y(df, metadata, group_col, palette):
    """Scatter each canonical-variate pair (U_k against V_k), coloured by group.

    One panel is drawn per component. Inside a panel every group gets its own
    scatter, a dashed least-squares line and a Pearson r / p-value in the
    legend; the panel title reports the Pearson correlation over all samples.

    Parameters
    ----------
    df : pandas.DataFrame
        Score table with columns ``U1..Uk`` and ``V1..Vk``. It is not modified.
    metadata : pandas.DataFrame
        Table containing ``group_col``; aligned to ``df`` on the index.
    group_col : str
        Metadata column used to split and colour the samples.
    palette : dict
        Maps a group value to a colour; groups missing from it use "#555".

    Raises
    ------
    ValueError
        If ``df`` has no column named ``U1``, ``U2``, ...

    Notes
    -----
    Saves ``../results/{out_prefix}_CV_scatter.png`` and ``.svg`` and shows the
    figure. ``out_prefix`` is read from the notebook namespace.
    """
    # Attach the grouping on a copy so the caller's score table is left untouched.
    scores = df.copy()
    scores[group_col] = metadata[group_col]

    # Count only real score columns ("U1", "U2", ...). A substring match on "U"
    # would also count any other column whose name merely contains a "U".
    n_components = sum(
        1
        for name in df.columns
        if isinstance(name, str) and name[:1] == "U" and name[1:].isdigit()
    )
    if n_components == 0:
        raise ValueError("df has no canonical-variate columns named U1, U2, ...")

    def _format_p(p_value):
        # Very small p-values are reported as a bound rather than in full.
        return '<.0001' if p_value < .0001 else f"{p_value:.1e}"

    plt.figure(figsize=(4 * n_components, 4))
    for k in range(1, n_components + 1):
        u_col, v_col = f"U{k}", f"V{k}"
        axp = plt.subplot(1, n_components, k)

        # A pair with a missing score cannot enter the correlation or the fit.
        panel = scores.dropna(subset=[u_col, v_col])

        for cohort, group in panel.groupby(group_col):
            x, y = group[u_col], group[v_col]
            colour = palette.get(cohort, "#555")
            if len(x) > 1:
                r, p = pearsonr(x, y)
                axp.scatter(x, y, alpha=0.3, color=colour,
                            label=f"{cohort} (r={r:.2f}, p={p:.1e})")
                # Per-group trend line, drawn over the sorted x values.
                slope, intercept = np.polyfit(x, y, 1)
                xs = np.sort(x)
                axp.plot(xs, slope * xs + intercept, linestyle="--",
                         linewidth=2, color=colour)
            else:
                # A single point has no correlation and no line to fit.
                axp.scatter(x, y, alpha=0.3, color=colour,
                            label=f"{cohort} (n=1)")

        r_all, p_all = pearsonr(panel[u_col], panel[v_col])
        axp.set_xlabel("Canonical Variate (X)")
        axp.set_ylabel("Canonical Variate (Y)")
        axp.set_title(f"CC{k}: r={r_all:.2f}, p={_format_p(p_all)}")

        handles, labels = axp.get_legend_handles_labels()
        if handles:
            axp.legend(handles, labels, fontsize=8, frameon=True)
        axp.grid(False)

    plt.tight_layout()
    plt.savefig(f"../results/{out_prefix}_CV_scatter.png", dpi=400)
    plt.savefig(f"../results/{out_prefix}_CV_scatter.svg", dpi=400)
    plt.show()

In [None]:
# Fixed colour for each cohort label.
palette = {
    'Japan': '#9467bd',
    'Mexico': '#2ca02c',
    'Spain': '#ff7f0e',
    'US': '#1f77b4',
    'UK': '#d62728',
}

plot_X_Y(df, metadata, group_col='thdmi_cohort', palette=palette)

In [None]:
def kw_test(df, metadata, group_col):
    """Kruskal–Wallis test and distribution plot of the CV1 scores per group.

    Two figures are produced, one for ``U1`` (X side) and one for ``V1``
    (Y side). In each, groups are ordered by increasing median score and the
    title reports the Kruskal–Wallis H statistic and p-value across groups.

    Parameters
    ----------
    df : pandas.DataFrame
        Score table with columns ``U1`` and ``V1``. It is not modified.
    metadata : pandas.DataFrame
        Table containing ``group_col``; aligned to ``df`` on the index.
    group_col : str
        Metadata column defining the groups.

    Notes
    -----
    ``palette`` is read from the notebook namespace. A raincloud plot is drawn
    when ptitprince was imported (``HAS_PTITPRINCE``); otherwise a seaborn
    box + strip plot is drawn instead of failing on the missing module.
    Saves ``../results/{X,Y}_CV1_cohort_raincloud_plot.{png,svg}`` and shows
    each figure.
    """
    # Attach the grouping on a copy so the caller's score table is left untouched.
    scores = df.copy()
    scores[group_col] = metadata[group_col]

    def _raincloud_panel(score_col, title_prefix, y_label, file_stem):
        """Test one score column across groups, then plot and save it."""
        # Groups ordered by increasing median of the score.
        order = list(
            scores.groupby(group_col)[score_col].median().sort_values().index
        )
        samples = [
            scores.loc[scores[group_col] == level, score_col].dropna()
            for level in order
        ]
        kw_stat, kw_pval = kruskal(*samples)
        p_lb = '<.0001' if kw_pval < .0001 else f"{kw_pval:.1e}"

        plt.figure(figsize=(5, 5))
        if HAS_PTITPRINCE:
            pt.RainCloud(x=group_col, y=score_col, data=scores, palette=palette,
                         width_viol=0.5, width_box=0.2, orient="v", alpha=0.8,
                         order=order)
        else:
            # ptitprince is optional (see the guarded import); fall back to
            # plain seaborn so the statistics are still visualised.
            ax = plt.gca()
            sns.boxplot(x=group_col, y=score_col, data=scores, order=order,
                        hue=group_col, hue_order=order, palette=palette,
                        dodge=False, width=0.5, ax=ax)
            sns.stripplot(x=group_col, y=score_col, data=scores, order=order,
                          color="0.3", size=3, alpha=0.5, ax=ax)
            # hue duplicates the x grouping, so its legend carries no information.
            legend = ax.get_legend()
            if legend is not None:
                legend.remove()

        plt.grid(False)
        plt.title(f"{title_prefix}\nKruskal–Wallis: H={kw_stat:.2f}, p={p_lb}")
        plt.xlabel('')
        plt.ylabel(y_label, fontsize=12)
        # Lay out only after every label is set so none of them gets clipped.
        plt.tight_layout()
        plt.savefig(f"../results/{file_stem}.png", dpi=400)
        plt.savefig(f"../results/{file_stem}.svg", dpi=400)
        plt.show()

    # X side (U1), then Y side (V1); axis labels match the CV scatter plot.
    _raincloud_panel("U1", "X CV1", "Canonical Variate (X)",
                     "X_CV1_cohort_raincloud_plot")
    _raincloud_panel("V1", "Y CV1", "Canonical Variate (Y)",
                     "Y_CV1_cohort_raincloud_plot")



In [None]:
# Compare the first canonical variates between cohorts.
kw_test(df, metadata, group_col='thdmi_cohort')

In [None]:
# --- Top loadings plot (CV1) ---
def plot_loadings(top_x: pd.Series, top_y: pd.Series, outfile: str) -> None:
    """Draw side-by-side bar charts of the selected X and Y loadings.

    Each panel shows one series sorted from largest to smallest value, with
    non-negative bars dark grey and negative bars light grey. Both panels share
    a symmetric y-range of 1.1 times the largest absolute loading.

    Parameters
    ----------
    top_x, top_y : pandas.Series
        Loadings to plot; the index supplies the bar labels (wrapped at 20
        characters).
    outfile : str
        Path the figure is saved to.

    Raises
    ------
    ValueError
        If the two series contain no finite value at all.
    """
    magnitudes = np.abs(pd.concat([top_x, top_y]).to_numpy(dtype=float))
    finite = magnitudes[np.isfinite(magnitudes)]
    if finite.size == 0:
        raise ValueError("plot_loadings needs at least one finite loading")
    # Fall back to a unit range when every loading is exactly zero.
    lim = 1.1 * finite.max() or 1.0

    n_cols = max(len(top_x), len(top_y))
    fig_w = 1.2 + 0.35 * n_cols
    # Scope the seaborn style to this figure. Setting it globally after the
    # axes exist would not style them and would leak into every later plot.
    with sns.axes_style("whitegrid"):
        fig, axes = plt.subplots(1, 2, figsize=(fig_w + 7, 4), sharey=True)

    def _bar(ax, series, title):
        """Render one sorted loading series as a bar chart on ``ax``."""
        series = series.sort_values(ascending=False)
        # Place bars by position, not by label text: identical labels would
        # otherwise be drawn on top of each other at a single x location.
        positions = np.arange(len(series))
        xlabels = [textwrap.fill(str(i), width=20) for i in series.index]
        colors = ["0.30" if value >= 0 else "0.70" for value in series.values]
        ax.bar(positions, series.values, color=colors, edgecolor="none")
        ax.set_xticks(positions)
        ax.set_xticklabels(xlabels)
        ax.axhline(0, lw=1, color="0.3", linestyle="-")
        ax.set_ylim(-lim, lim)
        ax.set_ylabel("Loading")
        ax.set_title(title, fontsize=12, weight="semibold")
        ax.tick_params(axis="x", rotation=90, labelsize=9)
        ax.grid(False)
        sns.despine(ax=ax, left=True, bottom=True)

    # Titles report how many loadings were actually passed in.
    _bar(axes[0], top_x, f"Top {len(top_x)} X by CV1")
    _bar(axes[1], top_y, f"Top {len(top_y)} Y by CV1")

    fig.tight_layout()
    fig.savefig(outfile, bbox_inches="tight", dpi=300)
    plt.show()




In [None]:
# Taxonomy table (tab-separated), indexed by its 'Feature ID' column.
# NOTE(review): absolute, user-specific path — the notebook only runs where
# this file exists; consider a configurable data directory.
tax = pd.read_table(
    '/home/lakhatib/wol_tree/taxonomy.tsv',
    index_col='Feature ID',
)
def add_tax(gotus, taxonomy=None):
    """Map feature IDs to ``"<last taxonomy rank>; <feature id>"`` labels.

    Parameters
    ----------
    gotus : pandas.Index or pandas.Series
        Feature IDs to relabel.
    taxonomy : pandas.DataFrame, optional
        Table indexed by feature ID with a ``'Taxon'`` column of
        ``;``-separated lineages. Defaults to the notebook-level ``tax`` table.

    Returns
    -------
    Same kind of object as ``gotus``, holding the labels. An ID that is not in
    the taxonomy (or whose lineage is missing) is returned unchanged, so
    applying the function to already-relabeled values leaves them as they are.
    """
    table = tax if taxonomy is None else taxonomy

    # Build the lookup locally; the taxonomy table itself is left untouched.
    label_by_id = {}
    for feature_id, lineage in table['Taxon'].items():
        if pd.isna(lineage):
            continue
        # The label uses the text after the last ';' of the lineage.
        species = str(lineage).split(';')[-1].strip()
        label_by_id[feature_id] = f"{species}; {feature_id}"

    # Unknown IDs fall back to themselves rather than becoming NaN labels.
    return gotus.map(lambda feature_id: label_by_id.get(feature_id, feature_id))

In [None]:
# Replace the feature IDs in the X-loadings index with the labels returned by
# add_tax.
# NOTE(review): the index is overwritten in place, so running this cell a second
# time passes the already-relabeled index back into add_tax — confirm add_tax
# tolerates that, or reload X_loadings before re-running.
X_loadings.index = add_tax(X_loadings.index)

In [None]:
# Display the X loadings ordered from the smallest 'Lx' value to the largest.
X_loadings.sort_values('Lx', ascending=True)

In [None]:
# Keep the ten loadings of largest magnitude on each side (signs preserved),
# then plot both sets and save the figure.
top_X = (
    X_loadings["Lx"]
    .sort_values(key=lambda values: values.abs(), ascending=False)
    .iloc[:10]
)
top_Y = (
    Y_loadings["Ly"]
    .sort_values(key=lambda values: values.abs(), ascending=False)
    .iloc[:10]
)
plot_loadings(
    top_X,
    top_Y,
    outfile=f"../results/{out_prefix}_feature_loadings.png",
)