In [9]:
import numpy as np
import torch

In [10]:
# - load importance_train-hg30_test-mm10_scores.npy
# - sdc = score dict
sdc_hg38 = np.load(
    "importance_train-hg38_test-mm10_scores.npy", allow_pickle=True
).item()
sdc_mm10 = np.load(
    "importance_train-mm10_test-hg38_scores.npy", allow_pickle=True
).item()

In [11]:
def mwcor(
    x: torch.Tensor, y: torch.Tensor, ws: torch.Tensor | str | None = None
) -> torch.Tensor:
    """Weighted correlation coefficient between the rows of x and y.

    Args:
        x (torch.Tensor): 1D or 2D tensor
        y (torch.Tensor): 1D or 2D tensor
        ws (str, optional): Weighting scheme: If tensor, it contains weights.
            If str, it must be "geometric_mean" or "arithmetic_mean". Defaults to None (= unweighted correlation).

    Returns:
        torch.Tensor: correlation coefficient(s).
        If x and y are 1D tensors, returns a scalar.
        If x and y are 2D tensors, returns a 1D tensor
    """

    if ws is not None:
        wx = x.abs()
        wx = wx / wx.max()
        wy = y.abs()
        wy = wy / wy.max()
        if isinstance(ws, torch.Tensor):
            w = ws
        elif ws == "geometric_mean":
            w = torch.sqrt(wx * wy)
        elif ws == "arithmetic_mean":
            w = 0.5 * (wx + wy)
        else:
            raise ValueError(
                "ws must be a tensor or None or 'geometric_mean' or 'arithmetic_mean'"
            )
    else:
        w = torch.ones_like(x)

    # - assert that x, y, are the same shape
    assert x.shape == y.shape, "x and y must be the same shape"
    # - assert x and y are either both vectors or both 2D tensors
    assert x.dim() in [1, 2], "x and y must be 1D or 2D tensors"
    if x.dim() == 1:
        x = x.unsqueeze(0)
        y = y.unsqueeze(0)
    # - deal with the weights
    assert w.dim() == 1, "w must be a 1D tensor"
    # - assert that w is the same length as the rows of x
    assert (
        w.shape[0] == x.shape[1]
    ), "w must have the same length as the rows of x and y"

    ws = w.sum()
    wn = w / ws

    mx = (x * wn).sum(dim=1)
    my = (y * wn).sum(dim=1)

    cxy = ((x.t() - mx).t() * (y.t() - my).t() * wn).sum(dim=1)
    cxx = ((x.t() - mx).t().square() * wn).sum(dim=1)
    cyy = ((y.t() - my).t().square() * wn).sum(dim=1)

    return cxy / (cxx * cyy).sqrt()

In [67]:
def plot_cors2(tm, bm, mm, gm, tstr, ax=None):
    if ax is None:
        fig, ax = plt.subplots()

    # Calculate correlations
    ttb = []
    ttg = []
    ttm = []
    for i in range(mm.shape[0]):
        ttb.extend(mwcor(torch.Tensor(tm[i, :]), torch.Tensor(bm[i, :]), None).numpy())
        ttg.extend(mwcor(torch.Tensor(tm[i, :]), torch.Tensor(gm[i, :]), None).numpy())
        ttm.extend(mwcor(torch.Tensor(tm[i, :]), torch.Tensor(mm[i, :]), None).numpy())

    # Create density plot
    df = pd.DataFrame({"MORALE": ttm, "GRL": ttg})
    cmap = sns.light_palette("#333333", as_cmap=True)

    # Perform Wilcoxon rank-sum test
    #    morale_vs_grl = stats.wilcoxon(np.array(ttm), np.array(ttg))
    morale_vs_grl_g = stats.wilcoxon(np.array(ttm) - np.array(ttg), alternative="less")
    morale_vs_grl_m = stats.wilcoxon(
        np.array(ttm) - np.array(ttg), alternative="greater"
    )
    # stats_results[f"{fac}_dFP"] = {}
    # stats_results[f"{fac}_dFP"]["MORALE_vs_GRL"] = {
    #    "statistic": morale_vs_grl.statistic,
    #    "pvalue": morale_vs_grl.pvalue,
    # }
    pvg = morale_vs_grl_g.pvalue
    pvm = morale_vs_grl_m.pvalue

    tstr = tstr + f"|p={pvm:.2e}"  # + f"|pvm={pvm:.2e}"
    # Use the provided axis for all plotting
    sns.kdeplot(
        data=df, x="MORALE", y="GRL", cmap=cmap, fill=True, levels=7, alpha=1, ax=ax
    )
    sns.kdeplot(
        data=df,
        x="MORALE",
        y="GRL",
        colors="lightgray",
        fill=False,
        levels=7,
        alpha=1,
        linewidth=2,
        ax=ax,
    )
    sns.kdeplot(
        data=df,
        x="MORALE",
        y="GRL",
        colors="black",
        fill=False,
        levels=7,
        alpha=1,
        linewidths=0.5,
        ax=ax,
    )

    min_val = np.min([df["MORALE"].min(), df["GRL"].min()]) * 1.05
    max_val = np.max([df["MORALE"].max(), df["GRL"].max()]) * 1.05

    ax.scatter(ttm, ttg, s=4, alpha=0.5, color="black")  # Increased point size
    ax.plot(
        [min_val, max_val],
        [min_val, max_val],
        color="firebrick",
        alpha=0.5,
        linewidth=2,
    )

    # Larger font sizes for labels and title
    ax.set_xlabel("MORALE", fontsize=22)
    ax.set_ylabel("GRL", fontsize=22)
    ax.grid(True, linestyle="--", alpha=0.5)
    ax.set_title(tstr, fontsize=26)

    # Make tick labels larger
    ax.tick_params(axis="both", which="major", labelsize=16)

    return ax

In [74]:
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import scipy

# Set global font sizes
plt.rcParams.update(
    {
        "font.size": 18,  # Default font size
        "axes.titlesize": 22,  # Title font size
        "axes.labelsize": 18,  # Axis label size
        "xtick.labelsize": 18,  # X-tick label size
        "ytick.labelsize": 18,  # Y-tick label size
        "legend.fontsize": 18,  # Legend font size
        "figure.titlesize": 22,  # Figure title size
    }
)

TFs = sdc_hg38.keys()
MDs = ["BM-hg38", "BM-mm10", "MORALE-hg38", "GRL-hg38"]
DSs = ["dFP", "dFN"]
n_rows = 4
n_cols = 2

# Create figure and axes grid
fig, axes = plt.subplots(
    n_rows, n_cols, figsize=(14, 6 * n_rows)
)  # Increased width for better readability

# Iterate over factors
for idx, fac in enumerate(TFs):
    bm = sdc_hg38[fac]["BM-hg38"]["dFP"].mean(axis=1)
    tm = sdc_hg38[fac]["BM-mm10"]["dFP"].mean(axis=1)
    mm = sdc_hg38[fac]["MORALE-hg38"]["dFP"].mean(axis=1)
    gm = sdc_hg38[fac]["GRL-hg38"]["dFP"].mean(axis=1)

    plot_cors2(tm, bm, mm, gm, f"dFP:{fac}", ax=axes[idx, 0])

    bm = sdc_hg38[fac]["BM-hg38"]["dFN"].mean(axis=1)
    tm = sdc_hg38[fac]["BM-mm10"]["dFN"].mean(axis=1)
    mm = sdc_hg38[fac]["MORALE-hg38"]["dFN"].mean(axis=1)
    gm = sdc_hg38[fac]["GRL-hg38"]["dFN"].mean(axis=1)

    plot_cors2(tm, bm, mm, gm, f"dFN:{fac}", ax=axes[idx, 1])

# Adjust layout to prevent overlapping
plt.tight_layout()
# plt.show()

# Save with higher DPI for better quality
plt.savefig(
    "importance_correlations_unweighted_hg38-source.png",
    bbox_inches="tight",
    dpi=600,
    format="png",
)
plt.close()

  cset = contour_func(
  cset = contour_func(
  cset = contour_func(
  cset = contour_func(
  cset = contour_func(
  cset = contour_func(
  cset = contour_func(
  cset = contour_func(


In [75]:
# Set global font sizes
plt.rcParams.update(
    {
        "font.size": 18,  # Default font size
        "axes.titlesize": 22,  # Title font size
        "axes.labelsize": 18,  # Axis label size
        "xtick.labelsize": 18,  # X-tick label size
        "ytick.labelsize": 18,  # Y-tick label size
        "legend.fontsize": 18,  # Legend font size
        "figure.titlesize": 22,  # Figure title size
    }
)

TFs = sdc_hg38.keys()
MDs = ["BM-hg38", "BM-mm10", "MORALE-hg38", "GRL-hg38"]
DSs = ["dFP", "dFN"]
n_rows = 4
n_cols = 2

# Create figure and axes grid
fig, axes = plt.subplots(
    n_rows, n_cols, figsize=(14, 6 * n_rows)
)  # Increased width for better readability

# Iterate over factors
for idx, fac in enumerate(TFs):
    bm = sdc_mm10[fac]["BM-mm10"]["dFP"].mean(axis=1)
    tm = sdc_mm10[fac]["BM-hg38"]["dFP"].mean(axis=1)
    mm = sdc_mm10[fac]["MORALE-mm10"]["dFP"].mean(axis=1)
    gm = sdc_mm10[fac]["GRL-mm10"]["dFP"].mean(axis=1)

    plot_cors2(tm, bm, mm, gm, f"dFP:{fac}", ax=axes[idx, 0])

    bm = sdc_mm10[fac]["BM-mm10"]["dFN"].mean(axis=1)
    tm = sdc_mm10[fac]["BM-hg38"]["dFN"].mean(axis=1)
    mm = sdc_mm10[fac]["MORALE-mm10"]["dFN"].mean(axis=1)
    gm = sdc_mm10[fac]["GRL-mm10"]["dFN"].mean(axis=1)

    plot_cors2(tm, bm, mm, gm, f"dFN:{fac}", ax=axes[idx, 1])

# Adjust layout to prevent overlapping
plt.tight_layout()
# plt.show()

# Save with higher DPI for better quality
plt.savefig(
    "importance_correlations_unweighted_mm10-source.png",
    bbox_inches="tight",
    dpi=600,
    format="png",
)
plt.close()

  cset = contour_func(
  cset = contour_func(
  cset = contour_func(
  cset = contour_func(
  cset = contour_func(
  cset = contour_func(
  cset = contour_func(
  cset = contour_func(
