In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# unscaled, log-normalized counts, with conditions subsampled to the same number of cells 
# and 2000 highly variable genes calculated jointly across all perturbation conditions, including control, using scanpy28 with default parameters (Supplementary Methods)

In [3]:
from anndata import read_h5ad
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd 
import scanpy as sc
import seaborn as sns
import string
import sys
from sklearn.linear_model import LinearRegression
sys.path.append("../../scxmatch/src/")
from scxmatch import *
np.random.seed(42)

In [4]:
reference = 0
group_by = "dose_value"

In [5]:
ram = pd.read_csv("total_RAM.csv")
ram.loc[ram["k"] == 0, "k"] = ram.loc[ram["k"] == 0, "n_obs"] 
ram = ram.astype(float)

In [6]:
df = pd.concat([pd.read_csv(f"../evaluation_results/runtime_memory_log_{n_obs}.txt", delimiter=",") for n_obs in [500, 1000, 2000, 5000, 10000, 20000, 50000]])

In [7]:
df_filtered = df[~(df.eq(df.columns).all(axis=1))]
df_filtered.loc[df_filtered["k"].isna(), "k"] = df_filtered.loc[df_filtered["k"].isna(), "n_obs"] 
df_filtered = df_filtered.astype(float)

ValueError: could not convert string to float: '193.44 MB'

In [None]:
df_filtered.set_index(["k", "n_obs", "n_var"])]

In [None]:
conc = pd.concat([ram.set_index(["k", "n_obs", "n_var"]), df_filtered.set_index(["k", "n_obs", "n_var"])], axis=1).reset_index()

In [None]:
conc[conc.isna().any(axis=1)].sort_values(["n_obs", "k"])[["k", "n_obs", "n_var"]].sort_values(["k", "n_obs", "n_var"])

In [None]:
conc.dropna()["n_var"].unique()

In [None]:
conc["upper_lim_edges"] = conc["k"] * conc["n_obs"]

In [None]:
conc["t_total"] = conc["t_NN [s]"] + conc["t_matching [s]"]

In [None]:
f, axs = plt.subplots(1, 2, figsize=(12, 6))

X = conc.loc[(conc["k"] > 0) & (conc["n_edges"].notna()), "n_edges"].values.reshape(-1, 1)
y = conc.loc[(conc["k"] > 0) & (conc["n_edges"].notna()), "total_ram_gb"].values.reshape(-1, 1)
lr = LinearRegression().fit(X, y)
lr.coef_

conc[["k", "n_obs", "n_var"]] = conc[["k", "n_obs", "n_var"]].astype(int)
sns.scatterplot(conc.loc[conc["k"] > 0], ax=axs[0], x="n_edges", y="total_ram_gb", hue="n_obs", size="k", sizes=(50, 200), style="n_var", palette=sns.color_palette("hls", 7), alpha=0.5, legend=False)

X_fit = np.linspace(X.min(), X.max(), 100).reshape(-1, 1)
y_fit = lr.predict(X_fit)
axs[0].plot(X_fit, y_fit, color="grey", label="Linear Fit", ls="dotted")
axs[0].set_ylabel("Total RAM [GB]")
axs[0].set_xlabel("Number of edges in distance graph")
axs[0].set_xscale("log")
axs[0].set_title("Memory requirements")

sns.scatterplot(conc.loc[conc["k"] > 0], ax=axs[1], x="n_obs", y="t_total", hue="n_obs", size="k", sizes=(50, 200), style="n_var", palette=sns.color_palette("hls", 7), alpha=0.5, legend="full")
sns.move_legend(axs[1], loc="right", bbox_to_anchor=(1.24, 0.5))
axs[1].set_ylabel("Total time [s]")
axs[1].set_xlabel("Number of samples")
axs[1].set_title("Time requirements")

labels = ["a", "b"]
for i, label in enumerate(labels):
    axs[i].text(
        -0.05, 1.1,  # Position (normalized figure coordinates)
        labels[i],   # Corresponding letter
        transform=axs[i].transAxes,  # Relative to subplot
        fontsize=10, fontweight='bold', va='top', ha='left'
    )
    axs[i].tick_params(axis='x', labelrotation=90)

plt.tight_layout()

In [None]:
conc.loc[conc["n_obs"] == 50000]

In [None]:
df.loc[df["k"].isna(), "k"] = df.loc[df["k"].isna(), "#nodes"] - 1

In [None]:
df.rename({"test": "dose value", "s": "rel. support"}, axis=1, inplace=True)

In [None]:
groups = sorted(df["dose value"].unique())

In [None]:
colors = sns.color_palette("hls", len(groups) + 1)
pal = {group: colors[i] for i, group in enumerate(groups + [reference])}

In [None]:
melted = pd.melt(df[df["#nodes"] >= 1000], id_vars=["k", "ref", "dose value", "#nodes"], var_name="metric")

In [None]:
g = sns.relplot(melted, x="k", y="value", row="metric", hue="dose value", col="#nodes", facet_kws={"sharex": False, "sharey":False}, kind="line", palette=pal, marker="o", linestyle="dashed", height=3, aspect=1)
sns.move_legend(g, loc="upper center", ncol=4, bbox_to_anchor=(0.5, 1.05))
    
labels = string.ascii_lowercase
for i in range(len(melted["metric"].unique())):
    g.axes[i, 0].text(
        -0.05, 1.2,  # Position (normalized figure coordinates)
        labels[i],   # Corresponding letter
        transform=g.axes[i, 0].transAxes,  # Relative to subplot
        fontsize=10, fontweight='bold', va='top', ha='left'
    )
plt.savefig("../plots/matching_weights.pdf", bbox_inches="tight")