In [None]:
import os
# Determinism switches for GPU runs; they should be in place before CUDA/cuBLAS is first used.
# CUBLAS_WORKSPACE_CONFIG is needed by torch.use_deterministic_algorithms(True), which
# set_seeds() turns on.
# NOTE(review): CUDNN_DETERMINISTIC does not appear to be read by PyTorch itself; cuDNN
# determinism is actually requested via torch.backends.cudnn.deterministic in set_seeds().
os.environ.update({
    "CUBLAS_WORKSPACE_CONFIG": ":16:8",
    "CUDNN_DETERMINISTIC": "1",
})


In [None]:
import scanpy as sc
import pandas as pd
import numpy as np
import os
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_auc_score
from tqdm import tqdm
from torch_geometric.utils import softmax
from torch_geometric.nn import global_add_pool, global_mean_pool
import torch

In [None]:
import random
def set_seeds(seed):
  """Seed every RNG the notebook uses and force deterministic GPU kernels.

  Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU and all CUDA
  devices), then disables cuDNN autotuning and enables PyTorch's deterministic
  algorithm mode so repeated runs give identical results.
  """
  for seeder in (random.seed, torch.manual_seed, torch.cuda.manual_seed,
                 torch.cuda.manual_seed_all, np.random.seed):
    seeder(seed)
  # Trade speed for reproducibility: no benchmark-selected (possibly non-deterministic) kernels.
  torch.backends.cudnn.benchmark = False
  torch.backends.cudnn.deterministic = True
  torch.use_deterministic_algorithms(True)

In [None]:
# Other datasets analysed with this notebook: "cardio.h5ad", "icb.h5ad".
h5ad_path = "covid.h5ad"
adata = sc.read_h5ad(h5ad_path)
adata

In [None]:
# One row per cell: expression features first, then patient / cell type / label.
# adata.X is either a scipy sparse matrix (exposes .toarray()) or already dense; test
# for that explicitly rather than using a bare `except:`, which would also hide real
# failures such as running out of memory while densifying.
df = pd.DataFrame(adata.X.toarray() if hasattr(adata.X, "toarray") else adata.X)
df.index = adata.obs.index
# Downstream code relies on these being the LAST three columns in this order:
# features = every column but the last 3, label = the last column.
df["patient"] = adata.obs["patient"]
df["cell_type_annotation"] = adata.obs["cell_type_annotation"]
df["response"] = adata.obs["label"]

In [None]:
def get_data_batch_count(df, all_ct, samples, meta=None, binary=True, attn2=True):
  """Stack the cells of ``samples`` into model-ready tensors.

  Parameters
  ----------
  df : pandas.DataFrame
      One row per cell; all columns except the last three are features, and the
      last column is the label. Must contain ``"patient"`` and
      ``"cell_type_annotation"``.
  all_ct : sequence
      Every cell type; its order fixes each type's slot index.
  samples : list
      Patient ids, in the order they appear in the output.
  meta : pandas.DataFrame, optional
      Per-patient covariates, looked up with ``meta.loc[samples, :]``.
  binary : bool
      Labels as float (True) or long (False).
  attn2 : bool
      If True a cell's group id is ``patient_position * len(all_ct) + type_slot``;
      otherwise it is just ``patient_position``.

  Returns
  -------
  (X, y, batch, meta)
      Cell features (float), one label per patient (taken from its first cell),
      per-cell group ids, and the meta rows as a float tensor (or None).
  """
  slot_of = {cell_type: slot for slot, cell_type in enumerate(all_ct)}
  n_slots = len(all_ct)
  n_features = df.shape[-1] - 3

  meta_tensor = None if meta is None else torch.tensor(meta.loc[samples, :].to_numpy(), dtype=torch.float)

  feature_blocks, labels, group_ids = [], [], []
  for position, patient in enumerate(samples):
    cells = df[df["patient"] == patient]
    feature_blocks.append(cells.iloc[:, :n_features].to_numpy())
    labels.append(cells.iloc[:, -1].to_numpy()[0])
    if attn2:
      offset = position * n_slots
      group_ids.append([offset + slot_of[cell_type] for cell_type in cells["cell_type_annotation"].to_list()])
    else:
      group_ids.append([position] * len(cells))

  X = torch.tensor(np.concatenate(feature_blocks), dtype=torch.float)
  y = torch.tensor(labels, dtype=torch.float if binary else torch.long)
  batch = torch.tensor(np.concatenate(group_ids))
  return X, y, batch, meta_tensor


In [None]:
class Model(torch.nn.Module):
  """Two-level attention-pooling classifier: cells -> cell types -> sample.

  ``forward`` runs, in order:
    1. ``lin``: per-cell MLP (``n_layers_lin`` x Linear/ReLU/Dropout; identity if 0).
    2. Cell pooling. With ``attn1`` each cell gets a score from ``w_c``; scores are
       normalised per ``batch`` group with a segment softmax (torch_geometric), or
       squashed independently with a sigmoid if ``use_softmax`` is False, and the
       weighted cells are summed per group. Without ``attn1`` cells are mean-pooled.
    3. ``lin2``: optional MLP on the pooled vectors (``n_layers_lin2`` layers).
    4. Cell-type pooling (``attn2``): pooled vectors are viewed as
       (n_samples, n_ct, width) and combined with weights from ``w_ct``
       (softmax over the cell-type axis, or sigmoid).
    5. If ``meta`` is given, its ``lin_meta`` embedding is concatenated.
    6. ``lin_out``: linear read-out to ``n_out`` logits per sample.

  With ``attn2`` the ``batch`` ids are expected to be
  ``sample_position * n_ct + cell_type_slot`` (as built by ``get_data_batch_count``)
  and ``ct_size`` to be ``n_ct * n_samples``.
  """

  def __init__(self, n_in, n_out=1, n_in_meta=0, attn1=True, attn2=True, dropout=0.0, use_softmax=True, n_layers_lin=1, n_layers_lin2=0, n_layers_lin_meta=1, n_hid=32, n_hid2=32):
    super().__init__()
    # Per-cell encoder; with n_layers_lin=0 the Sequential is empty (identity).
    self.lin = torch.nn.Sequential(
        *self.get_lin_layers(n_layers_lin, n_in, n_hid, n_hid, dropout)
    )
    curr_in = n_in if len(self.lin)==0 else n_hid
    # Cell-level attention scorer: one logit per cell (dropout is applied to the logit).
    self.w_c = torch.nn.Sequential(
        torch.nn.Linear(curr_in, 1),
        torch.nn.Dropout(dropout)
    )
    # Width of the cell-pooled vectors; forward() reshapes with it.
    self.n_in1 = curr_in
    # Optional MLP applied after cell pooling.
    self.lin2 = torch.nn.Sequential(
        *self.get_lin_layers(n_layers_lin2, curr_in, n_hid2, n_hid2, dropout)
    )
    curr_in = curr_in if len(self.lin2)==0 else n_hid2
    # Cell-type-level attention scorer: one logit per (sample, cell type).
    self.w_ct = torch.nn.Sequential(
        torch.nn.Linear(curr_in, 1),
        torch.nn.Dropout(dropout)
    )
    if n_in_meta > 0:
        # Metadata embedding to the pooled width (raw metadata if n_layers_lin_meta == 0);
        # it is concatenated to the pooled representation before lin_out.
        self.lin_meta = torch.nn.Sequential(
            *self.get_lin_layers(n_layers_lin_meta, n_in_meta, curr_in, curr_in, dropout)
        )
        curr_in += (n_in_meta if n_layers_lin_meta == 0 else curr_in)
    self.lin_out = torch.nn.Linear(curr_in, n_out)
    self.attn1 = attn1
    self.attn2 = attn2
    self.use_softmax = use_softmax

  def get_lin_layers(self, n_layers, n_in, n_hid, n_out, dropout):
    """Return ``n_layers`` x [Linear, ReLU, Dropout] mapping n_in -> n_hid ... -> n_out (empty if 0)."""
    layers = []
    for i in range(n_layers):
      curr_in = n_in if i == 0 else n_hid
      curr_out = n_out if i == n_layers - 1 else n_hid
      layers.extend([torch.nn.Linear(curr_in, curr_out), torch.nn.ReLU(), torch.nn.Dropout(dropout)])
    return layers

  def forward(self, X, batch, ct_size, n_ct, meta=None):
    """Return logits of shape (n_samples, n_out).

    X: (n_cells, n_in) features; batch: (n_cells,) group id per cell; ct_size: number
    of groups (``n_ct * n_samples`` when ``attn2``); n_ct: number of cell types;
    meta: optional per-sample covariates, only valid if built with ``n_in_meta > 0``.
    """
    X = self.lin(X)
    if self.attn1:
        # squeeze(-1) keeps a 1-D score vector even for a single cell; a bare
        # squeeze() would turn (1, 1) into a 0-d tensor.
        scores = self.w_c(X).squeeze(-1)
        if self.use_softmax:
            w_c = softmax(scores, batch)
        else:
            w_c = torch.sigmoid(scores)
        X = X * w_c.unsqueeze(dim=-1)
        if self.attn2:
            X = global_add_pool(X, batch, size=ct_size).reshape(-1, n_ct, self.n_in1)
        else:
            X = global_add_pool(X, batch)
    else:
        if self.attn2:
            X = global_mean_pool(X, batch, size=ct_size).reshape(-1, n_ct, self.n_in1)
        else:
            X = global_mean_pool(X, batch)
    X = self.lin2(X)
    if self.attn2:
        ct_scores = self.w_ct(X)  # (n_samples, n_ct, 1)
        if self.use_softmax:
            w_ct = torch.softmax(ct_scores, dim=1)
        else:
            w_ct = torch.sigmoid(ct_scores)
        X = torch.sum(X * w_ct, dim=1)
    if meta is not None:
        if not hasattr(self, "lin_meta"):
            raise ValueError("meta was passed but the model was built with n_in_meta=0")
        X = torch.cat([X, self.lin_meta(meta)], dim=1)
    return self.lin_out(X)

  def decompose_logits(self, X, batch, ct_size, n_ct):
    """Split the bias-free output logits into per-cell-type contributions.

    Always uses softmax attention at both levels and ignores ``meta``, ``attn1``,
    ``attn2`` and ``use_softmax``: it mirrors ``forward`` only for a model built with
    attn1=attn2=use_softmax=True and no metadata. Because ``lin_out`` is linear,
    summing the first output over the cell-type axis and adding ``lin_out.bias``
    reproduces that configuration's ``forward`` logits (in eval mode).

    Returns
    -------
    (contributions, w_ct)
        contributions: (n_samples, n_ct) when n_out == 1, else (n_samples, n_ct, n_out).
        w_ct: (n_samples, n_ct) cell-type attention weights.
    """
    X = self.lin(X)
    w_c = softmax(self.w_c(X).squeeze(-1), batch)
    X = global_add_pool(X * w_c.unsqueeze(dim=-1), batch, size=ct_size).reshape(-1, n_ct, self.n_in1)
    X = self.lin2(X)
    w_ct = torch.softmax(self.w_ct(X), dim=1)  # (n_samples, n_ct, 1)
    X = X @ self.lin_out.weight.T              # (n_samples, n_ct, n_out)
    # squeeze(-1) rather than squeeze() so the sample axis survives a one-sample batch.
    return (w_ct * X).squeeze(-1), w_ct.squeeze(-1)


## Cross-validation (repeated nested CV with Optuna hyper-parameter tuning)

In [None]:
# CHANGE THIS
attn1 = True      # attention pooling over cells (False: mean pooling)
attn2 = True      # attention pooling over cell types
n = 10            # repeats of the outer cross-validation
n_skf = 10        # outer folds
n_skf_in = 10     # inner folds used for hyper-parameter tuning
meta = None       # optional per-patient covariates, indexed by patient id
use_meta = False


n_classes = len(set(adata.obs["label"]))
binary = n_classes == 2

def wrapper_objective(train_samples):
    """Build an Optuna objective that tunes hyper-parameters by inner CV on ``train_samples``.

    Parameters
    ----------
    train_samples : pandas.DataFrame
        One row per patient with ``"patient"`` and ``"label"`` columns; ``run`` passes
        the outer training split, so held-out test patients are never used for tuning.

    Returns
    -------
    callable
        ``objective(trial) -> float`` giving the ROC-AUC of out-of-fold predictions
        pooled over ``n_skf_in`` stratified inner folds.

    Notes
    -----
    Reads the notebook globals ``df``, ``all_ct``, ``device``, ``meta``, ``use_meta``,
    ``attn1``, ``attn2``, ``binary``, ``n_classes`` and ``n_skf_in``.
    """
    def objective(trial):
        # Search space (all categorical).
        n_epochs = trial.suggest_categorical("n_epochs", [100, 500, 1000])
        dropout = trial.suggest_categorical("dropout", [0, 0.3, 0.5, 0.7])
        weight_decay = trial.suggest_categorical("weight_decay", [1e-4, 1e-3, 1e-2])
        n_layers_lin = trial.suggest_categorical("n_layers_lin", [1, 2])
        n_hid = trial.suggest_categorical("n_hid", [32, 64, 128])
        lr = trial.suggest_categorical("lr", [1e-3, 5e-3])
        n_layers_lin_meta = trial.suggest_categorical("n_layers_lin_meta", [0, 1, 2]) if use_meta else 1

        skf = StratifiedKFold(n_skf_in, shuffle=True, random_state=0)
        preds_ = []
        truths_ = []

        for train_idx, val_idx in skf.split(train_samples, train_samples["label"]):
            train_samples_in = train_samples.iloc[train_idx, :]["patient"].to_list()
            val_samples = train_samples.iloc[val_idx, :]["patient"].to_list()

            X_train, y_train, batch_train, meta_train = get_data_batch_count(df, all_ct, train_samples_in, binary=binary, meta=meta if use_meta else None, attn2=attn2)
            X_val,y_val, batch_val, meta_val = get_data_batch_count(df, all_ct, val_samples, binary=binary, meta=meta if use_meta else None, attn2=attn2)

            X_train, y_train, batch_train, meta_train = X_train.to(device), y_train.to(device), batch_train.to(device), meta_train.to(device) if meta_train is not None else meta_train
            X_val, y_val, batch_val, meta_val = X_val.to(device), y_val.to(device), batch_val.to(device), meta_val.to(device) if meta_val is not None else meta_val

            set_seeds(0)  # identical initialisation for every trial and fold
            model = Model(X_train.shape[-1], n_out=1 if binary else n_classes, n_in_meta=0 if not use_meta else meta_train.shape[-1], \
                        attn1=attn1, attn2=attn2, use_softmax=True, dropout=dropout, n_layers_lin=n_layers_lin, n_layers_lin2=0, \
                        n_layers_lin_meta=n_layers_lin_meta, n_hid=n_hid, n_hid2=0).to(device)
            opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
            loss_fn = torch.nn.BCEWithLogitsLoss() if binary else torch.nn.CrossEntropyLoss()

            # Full-batch training: every epoch is one step over all training patients.
            for epoch in range(n_epochs):
                model.train()
                opt.zero_grad()
                pred = model(X_train, batch_train, len(all_ct)*len(y_train), len(all_ct), meta=meta_train)
                loss = loss_fn(pred.squeeze(), y_train.squeeze())
                loss.backward()
                opt.step()

            with torch.no_grad():
                model.eval()
                logits = model(X_val, batch_val, len(all_ct)*len(y_val), len(all_ct), meta=meta_val)
                # reshape / row-wise softmax instead of squeeze(): keeps one entry per
                # patient even for a fold holding a single patient, where squeeze()
                # gives a 0-d array (binary) or drops the patient axis (multi-class).
                probs = torch.sigmoid(logits.reshape(-1)) if binary else torch.softmax(logits, dim=-1)
                preds_.extend(probs.cpu().numpy())
                truths_.extend(y_val.reshape(-1).cpu().numpy())

        return roc_auc_score(np.stack(truths_), np.stack(preds_), multi_class="ovo")
    return objective


import optuna as optuna
from optuna.samplers import TPESampler


if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
# Cell-type vocabulary; its order fixes the cell-type slot used by get_data_batch_count.
all_ct = adata.obs["cell_type_annotation"].unique()
# Distinct (patient, label) pairs, presumably one row per patient — confirm no patient has two labels.
samples = adata.obs[["patient", "label"]].drop_duplicates()

def run():
    """Repeated nested cross-validation on the full dataset.

    For each of ``n`` repeats, patients are split into ``n_skf`` stratified outer folds.
    On each outer training split, 30 Optuna TPE trials choose hyper-parameters by inner
    CV (``wrapper_objective``); a model with the best parameters is retrained on the
    whole outer training split and scored on the held-out patients. Held-out
    probabilities are pooled per repeat into one ROC-AUC.

    Returns
    -------
    (float, float)
        Mean and standard deviation of the per-repeat ROC-AUCs.
    """
    aucs = []
    for i in tqdm(range(n)):
        skf = StratifiedKFold(n_skf, shuffle=True, random_state=i)
        preds_ = []
        truths_ = []
        for train_idx, test_idx in skf.split(samples, samples["label"]):

            train_samples = samples.iloc[train_idx, :]
            test_samples = samples.iloc[test_idx, :]["patient"].to_list()

            X_test, y_test, batch_test, meta_test  = get_data_batch_count(df, all_ct, test_samples, binary=binary, meta=meta if use_meta else None, attn2=attn2)
            X_test, y_test, batch_test, meta_test = X_test.to(device), y_test.to(device), batch_test.to(device), meta_test.to(device) if meta_test is not None else meta_test

            # Hyper-parameters are tuned on the outer training split only.
            sampler = TPESampler(seed=0)
            study = optuna.create_study(direction="maximize", sampler=sampler)
            study.optimize(wrapper_objective(train_samples), n_trials=30)
            best_params = study.best_params

            X_train, y_train, batch_train, meta_train = get_data_batch_count(df, all_ct, train_samples["patient"].to_list(), binary=binary, meta=meta if use_meta else None, attn2=attn2)
            X_train, y_train, batch_train, meta_train = X_train.to(device), y_train.to(device), batch_train.to(device), meta_train.to(device) if meta_train is not None else meta_train

            set_seeds(i)
            model = Model(X_train.shape[-1], n_out=1 if binary else n_classes, n_in_meta=0 if not use_meta else meta_train.shape[-1], \
                        attn1=attn1, attn2=attn2, use_softmax=True, dropout=best_params["dropout"], \
                        n_layers_lin=best_params["n_layers_lin"], n_layers_lin2=0, \
                        n_layers_lin_meta=1 if not use_meta else best_params["n_layers_lin_meta"], n_hid=best_params["n_hid"], n_hid2=0).to(device)
            opt = torch.optim.Adam(model.parameters(), lr=best_params["lr"], weight_decay=best_params["weight_decay"])
            loss_fn = torch.nn.BCEWithLogitsLoss() if binary else torch.nn.CrossEntropyLoss()

            for epoch in range(best_params["n_epochs"]):
                model.train()
                opt.zero_grad()
                pred = model(X_train, batch_train, len(all_ct)*len(y_train), len(all_ct), meta=meta_train)
                loss = loss_fn(pred.squeeze(), y_train.squeeze())
                loss.backward()
                opt.step()

            with torch.no_grad():
                model.eval()
                logits = model(X_test, batch_test, len(all_ct)*len(y_test), len(all_ct), meta=meta_test)
                # reshape / row-wise softmax instead of squeeze(): a test fold holding a
                # single patient must still contribute exactly one entry per patient.
                probs = torch.sigmoid(logits.reshape(-1)) if binary else torch.softmax(logits, dim=-1)
                preds_.extend(probs.cpu().numpy())
                truths_.extend(y_test.reshape(-1).cpu().numpy())
        aucs.append(roc_auc_score(np.stack(truths_), np.stack(preds_), multi_class="ovo"))
    return np.mean(aucs), np.std(aucs)
mean_auc, std_auc = run()
print("AUC: {} +/- {}".format(mean_auc, std_auc))

## Vary training-set size (fraction of patients used for training)

In [None]:
# CHANGE THIS
train_sizes = [0.25, 0.5, 0.75]  # fraction of patients used for training
n = 100                          # random stratified train/test splits per training size
n_skf_in = 4                     # inner folds used for hyper-parameter tuning
attn1 = True

from sklearn.model_selection import train_test_split
n_classes = len(set(adata.obs["label"]))
binary = n_classes == 2

def objective(trial):
    """Optuna objective for the training-size experiment: pooled inner-CV ROC-AUC.

    Samples ``n_epochs``, ``dropout``, ``weight_decay``, ``n_layers_lin``, ``n_hid`` and
    ``lr`` from ``trial``, then runs ``n_skf_in``-fold stratified CV over the patients in
    the notebook-global ``train_samples`` (assigned by the loop below before each
    ``study.optimize`` call), using cells from the global ``df``.

    Returns
    -------
    float
        ROC-AUC of the out-of-fold predictions pooled over all inner folds.
    """
    n_epochs = trial.suggest_categorical("n_epochs", [100, 500, 1000])
    dropout = trial.suggest_categorical("dropout", [0, 0.3, 0.5, 0.7])
    weight_decay = trial.suggest_categorical("weight_decay", [1e-4, 1e-3, 1e-2])
    n_layers_lin = trial.suggest_categorical("n_layers_lin", [1, 2])
    n_hid = trial.suggest_categorical("n_hid", [32, 64, 128])
    lr = trial.suggest_categorical("lr", [1e-3, 5e-3])

    skf = StratifiedKFold(n_skf_in, shuffle=True, random_state=0)
    preds_ = []
    truths_ = []

    for train_idx, val_idx in skf.split(train_samples, train_samples["label"]):
        train_samples_in = train_samples.iloc[train_idx, :]["patient"].to_list()
        val_samples = train_samples.iloc[val_idx, :]["patient"].to_list()

        X_train, y_train, batch_train, _ = get_data_batch_count(df, all_ct, train_samples_in, binary=binary)
        X_val,y_val, batch_val, _ = get_data_batch_count(df, all_ct, val_samples, binary=binary)

        X_train, y_train, batch_train = X_train.to(device), y_train.to(device), batch_train.to(device)
        X_val, y_val, batch_val = X_val.to(device), y_val.to(device), batch_val.to(device)

        set_seeds(0)  # identical initialisation for every trial and fold
        model = Model(X_train.shape[-1], n_out=1 if binary else n_classes, \
                      attn1=attn1, attn2=True, use_softmax=True, dropout=dropout, \
                    n_layers_lin=n_layers_lin, n_layers_lin2=0, n_hid=n_hid, n_hid2=0).to(device)
        opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
        loss_fn = torch.nn.BCEWithLogitsLoss() if binary else torch.nn.CrossEntropyLoss()

        for epoch in range(n_epochs):
            model.train()
            opt.zero_grad()
            pred = model(X_train, batch_train, len(all_ct)*len(y_train), len(all_ct))
            loss = loss_fn(pred.squeeze(), y_train.squeeze())
            loss.backward()
            opt.step()

        with torch.no_grad():
            model.eval()
            logits = model(X_val, batch_val, len(all_ct)*len(y_val), len(all_ct))
            # reshape / row-wise softmax instead of squeeze(): keeps one entry per
            # patient even for a validation fold holding a single patient.
            probs = torch.sigmoid(logits.reshape(-1)) if binary else torch.softmax(logits, dim=-1)
            preds_.extend(probs.cpu().numpy())
            truths_.extend(y_val.reshape(-1).cpu().numpy())

    return roc_auc_score(np.stack(truths_), np.stack(preds_), multi_class="ovo")


import optuna as optuna
from optuna.samplers import TPESampler


device = "cuda:0" if torch.cuda.is_available() else "cpu"
all_ct = adata.obs["cell_type_annotation"].unique()
samples = adata.obs[["patient", "label"]].drop_duplicates()

# For each training fraction: n random stratified splits; tune on the training part
# (inner CV via `objective`), retrain with the best parameters, score the test part.
for train_size in train_sizes:
    aucs = []
    for i in tqdm(range(n)):
        train_idx, test_idx = train_test_split(range(len(samples)), stratify=samples["label"].to_list(), train_size=train_size, random_state=i)
        # `objective` reads this global as the pool for its inner CV.
        train_samples = samples.iloc[train_idx, :]
        test_samples = samples.iloc[test_idx, :]["patient"].to_list()

        X_test, y_test, batch_test, _  = get_data_batch_count(df, all_ct, test_samples, binary=binary)
        X_test, y_test, batch_test = X_test.to(device), y_test.to(device), batch_test.to(device)

        sampler = TPESampler(seed=0)
        study = optuna.create_study(direction="maximize", sampler=sampler)
        study.optimize(objective, n_trials=30)
        best_params = study.best_params

        X_train, y_train, batch_train, _  = get_data_batch_count(df, all_ct, train_samples["patient"].to_list(), binary=binary)
        X_train, y_train, batch_train = X_train.to(device), y_train.to(device), batch_train.to(device)
        set_seeds(i)
        model = Model(X_train.shape[-1], n_out=1 if binary else n_classes, \
                      attn1=attn1, attn2=True, use_softmax=True, dropout=best_params["dropout"], \
                      n_layers_lin=best_params["n_layers_lin"], n_layers_lin2=0, n_hid=best_params["n_hid"], n_hid2=0).to(device)
        opt = torch.optim.Adam(model.parameters(), lr=best_params["lr"], weight_decay=best_params["weight_decay"])
        loss_fn = torch.nn.BCEWithLogitsLoss() if binary else torch.nn.CrossEntropyLoss()

        for epoch in range(best_params["n_epochs"]):
            model.train()
            opt.zero_grad()
            pred = model(X_train, batch_train, len(all_ct)*len(y_train), len(all_ct))
            loss = loss_fn(pred.squeeze(), y_train.squeeze())
            loss.backward()
            opt.step()

        with torch.no_grad():
            model.eval()
            logits = model(X_test, batch_test, len(all_ct)*len(y_test), len(all_ct))
            # One score (binary) or one probability row (multi-class) per test patient,
            # independent of how many test patients there are.
            probs = torch.sigmoid(logits.reshape(-1)) if binary else torch.softmax(logits, dim=-1)
        auc = roc_auc_score(y_test.reshape(-1).cpu().numpy(), probs.cpu().numpy(), multi_class="ovo")
        aucs.append(auc)
    print("Train size", train_size, np.mean(aucs), "+/-", np.std(aucs))

## Vary cell count (down-sample each patient's cells)

In [None]:
# CHANGE THIS
cell_props = [0.25, 0.5, 0.75]  # fraction of each patient's cells kept
n = 10         # repeats of the outer cross-validation
n_skf = 10     # outer folds
n_skf_in = 10  # inner folds used for hyper-parameter tuning
attn1 = True

from sklearn.model_selection import StratifiedKFold

n_classes = len(set(adata.obs["label"]))
binary = n_classes == 2
def objective(trial):
    """Optuna objective for the cell-count experiment: pooled inner-CV ROC-AUC.

    Samples ``n_epochs``, ``dropout``, ``weight_decay``, ``n_layers_lin``, ``n_hid`` and
    ``lr`` from ``trial``, then runs ``n_skf_in``-fold stratified CV over the patients in
    the notebook-global ``train_samples``, using cells from the notebook-global
    ``df_subsampled``. Both globals are assigned by the loop below before each
    ``study.optimize`` call.

    Returns
    -------
    float
        ROC-AUC of the out-of-fold predictions pooled over all inner folds.
    """
    n_epochs = trial.suggest_categorical("n_epochs", [100, 500, 1000])
    dropout = trial.suggest_categorical("dropout", [0, 0.3, 0.5, 0.7])
    weight_decay = trial.suggest_categorical("weight_decay", [1e-4, 1e-3, 1e-2])
    n_layers_lin = trial.suggest_categorical("n_layers_lin", [1, 2])
    n_hid = trial.suggest_categorical("n_hid", [32, 64, 128])
    lr = trial.suggest_categorical("lr", [1e-3, 5e-3])

    skf = StratifiedKFold(n_skf_in, shuffle=True, random_state=0)
    preds_ = []
    truths_ = []

    for train_idx, val_idx in skf.split(train_samples, train_samples["label"]):
        train_samples_in = train_samples.iloc[train_idx, :]["patient"].to_list()
        val_samples = train_samples.iloc[val_idx, :]["patient"].to_list()

        X_train, y_train, batch_train, _ = get_data_batch_count(df_subsampled, all_ct, train_samples_in, binary=binary)
        X_val,y_val, batch_val, _ = get_data_batch_count(df_subsampled, all_ct, val_samples, binary=binary)

        X_train, y_train, batch_train = X_train.to(device), y_train.to(device), batch_train.to(device)
        X_val, y_val, batch_val = X_val.to(device), y_val.to(device), batch_val.to(device)

        set_seeds(0)  # identical initialisation for every trial and fold
        model = Model(X_train.shape[-1], n_out=1 if binary else n_classes, \
                      attn1=attn1, attn2=True, use_softmax=True, dropout=dropout, \
                    n_layers_lin=n_layers_lin, n_layers_lin2=0, n_hid=n_hid, n_hid2=0).to(device)
        opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
        loss_fn = torch.nn.BCEWithLogitsLoss() if binary else torch.nn.CrossEntropyLoss()

        for epoch in range(n_epochs):
            model.train()
            opt.zero_grad()
            pred = model(X_train, batch_train, len(all_ct)*len(y_train), len(all_ct))
            loss = loss_fn(pred.squeeze(), y_train.squeeze())
            loss.backward()
            opt.step()

        with torch.no_grad():
            model.eval()
            logits = model(X_val, batch_val, len(all_ct)*len(y_val), len(all_ct))
            # reshape / row-wise softmax instead of squeeze(): keeps one entry per
            # patient even for a validation fold holding a single patient.
            probs = torch.sigmoid(logits.reshape(-1)) if binary else torch.softmax(logits, dim=-1)
            preds_.extend(probs.cpu().numpy())
            truths_.extend(y_val.reshape(-1).cpu().numpy())

    return roc_auc_score(np.stack(truths_), np.stack(preds_), multi_class="ovo")


import optuna as optuna
from optuna.samplers import TPESampler


device = "cuda:0" if torch.cuda.is_available() else "cpu"
all_ct = adata.obs["cell_type_annotation"].unique()
samples = adata.obs[["patient", "label"]].drop_duplicates()

for cell_prop in cell_props:
    aucs = []
    for i in tqdm(range(n)):
        # Keep `cell_prop` of every patient's cells. The draw depends only on the repeat
        # seed `i`, so it is built once per repeat instead of once per outer fold.
        # Each patient group gets the same `.sample(frac=cell_prop, random_state=i)`
        # call the former groupby(...).apply(..., include_groups=True) made, without
        # relying on that deprecated pandas option. `objective` reads this global.
        df_subsampled = pd.concat(
            [cells.sample(frac=cell_prop, random_state=i)
             for _, cells in df.groupby("patient", observed=False)],
            ignore_index=True,
        )
        skf = StratifiedKFold(n_skf, shuffle=True, random_state=i)
        preds_ = []
        truths_ = []
        for train_idx, test_idx in skf.split(samples, samples["label"]):

            # `objective` reads this global as the pool for its inner CV.
            train_samples = samples.iloc[train_idx, :]
            test_samples = samples.iloc[test_idx, :]["patient"].to_list()

            X_test, y_test, batch_test, _  = get_data_batch_count(df_subsampled, all_ct, test_samples, binary=binary)
            X_test, y_test, batch_test = X_test.to(device), y_test.to(device), batch_test.to(device)


            sampler = TPESampler(seed=0)
            study = optuna.create_study(direction="maximize", sampler=sampler)
            study.optimize(objective, n_trials=30)
            best_params = study.best_params


            X_train, y_train, batch_train, _  = get_data_batch_count(df_subsampled, all_ct, train_samples["patient"].to_list(), binary=binary)
            X_train, y_train, batch_train = X_train.to(device), y_train.to(device), batch_train.to(device)
            set_seeds(i)
            model = Model(X_train.shape[-1], n_out=1 if binary else n_classes, \
                        attn1=attn1, attn2=True, use_softmax=True, dropout=best_params["dropout"], \
                        n_layers_lin=best_params["n_layers_lin"], n_layers_lin2=0, n_hid=best_params["n_hid"], n_hid2=0).to(device)
            opt = torch.optim.Adam(model.parameters(), lr=best_params["lr"], weight_decay=best_params["weight_decay"])
            loss_fn = torch.nn.BCEWithLogitsLoss() if binary else torch.nn.CrossEntropyLoss()

            for epoch in range(best_params["n_epochs"]):
                model.train()
                opt.zero_grad()
                pred = model(X_train, batch_train, len(all_ct)*len(y_train), len(all_ct))
                loss = loss_fn(pred.squeeze(), y_train.squeeze())
                loss.backward()
                opt.step()

            with torch.no_grad():
                model.eval()
                logits = model(X_test, batch_test, len(all_ct)*len(y_test), len(all_ct))
                # reshape / row-wise softmax instead of squeeze(): a test fold holding a
                # single patient must still contribute exactly one entry per patient.
                probs = torch.sigmoid(logits.reshape(-1)) if binary else torch.softmax(logits, dim=-1)
                preds_.extend(probs.cpu().numpy())
                truths_.extend(y_test.reshape(-1).cpu().numpy())
        aucs.append(roc_auc_score(np.stack(truths_), np.stack(preds_), multi_class="ovo"))
    print("Cell prop", cell_prop, np.mean(aucs), "+/-", np.std(aucs))

## Vary cell-type annotation quality (randomly reassign a fraction of cell-type labels)

In [None]:
# CHANGE THIS
cell_props = [0.25, 0.5]  # fraction of each patient's cells whose cell type is re-drawn at random
n, n_skf, n_skf_in = 10, 10, 10  # outer repeats, outer folds, inner tuning folds
attn1 = True

from sklearn.model_selection import StratifiedKFold
def reassign_cell_types(df, prop, all_ct, seed=0):
    """Return a copy of ``df`` with a random fraction of each patient's cell types re-drawn.

    For every patient, ``int(n_cells * prop)`` of its cells are chosen without
    replacement and given a cell type drawn uniformly, with replacement, from
    ``all_ct`` (the new type may coincide with the old one).

    Parameters
    ----------
    df : pandas.DataFrame
        One row per cell with ``"patient"`` and ``"cell_type_annotation"`` columns.
        Rows are addressed by index label, so ``df.index`` should be unique.
    prop : float
        Fraction of each patient's cells to re-annotate.
    all_ct : array-like
        Cell types to draw replacements from.
    seed : int
        Seed of the private random generator.

    Returns
    -------
    pandas.DataFrame
        A modified copy; ``df`` itself is left unchanged.

    Notes
    -----
    A private ``np.random.RandomState(seed)`` produces the same draws as seeding the
    global NumPy RNG with ``np.random.seed(seed)``, but without overwriting the
    global RNG state as a side effect.
    """
    rng = np.random.RandomState(seed)

    df = df.copy()
    for patient in df["patient"].unique():
        patient_index = df.index[(df["patient"] == patient).to_numpy()]
        num_to_select = int(len(patient_index) * prop)

        selected_indices = rng.choice(patient_index, num_to_select, replace=False)
        new_annotations = rng.choice(all_ct, num_to_select, replace=True)

        df.loc[selected_indices, "cell_type_annotation"] = new_annotations

    return df

n_classes = len(set(adata.obs["label"]))
binary = n_classes == 2
def objective(trial):
    """Optuna objective for the annotation-noise experiment: pooled inner-CV ROC-AUC.

    Samples ``n_epochs``, ``dropout``, ``weight_decay``, ``n_layers_lin``, ``n_hid`` and
    ``lr`` from ``trial``, then runs ``n_skf_in``-fold stratified CV over the patients in
    the notebook-global ``train_samples``, using cells from the notebook-global
    ``df_subsampled`` (here: the re-annotated frame). Both globals are assigned by the
    loop below before each ``study.optimize`` call.

    Returns
    -------
    float
        ROC-AUC of the out-of-fold predictions pooled over all inner folds.
    """
    n_epochs = trial.suggest_categorical("n_epochs", [100, 500, 1000])
    dropout = trial.suggest_categorical("dropout", [0, 0.3, 0.5, 0.7])
    weight_decay = trial.suggest_categorical("weight_decay", [1e-4, 1e-3, 1e-2])
    n_layers_lin = trial.suggest_categorical("n_layers_lin", [1, 2])
    n_hid = trial.suggest_categorical("n_hid", [32, 64, 128])
    lr = trial.suggest_categorical("lr", [1e-3, 5e-3])

    skf = StratifiedKFold(n_skf_in, shuffle=True, random_state=0)
    preds_ = []
    truths_ = []

    for train_idx, val_idx in skf.split(train_samples, train_samples["label"]):
        train_samples_in = train_samples.iloc[train_idx, :]["patient"].to_list()
        val_samples = train_samples.iloc[val_idx, :]["patient"].to_list()

        X_train, y_train, batch_train, _ = get_data_batch_count(df_subsampled, all_ct, train_samples_in, binary=binary)
        X_val,y_val, batch_val, _ = get_data_batch_count(df_subsampled, all_ct, val_samples, binary=binary)

        X_train, y_train, batch_train = X_train.to(device), y_train.to(device), batch_train.to(device)
        X_val, y_val, batch_val = X_val.to(device), y_val.to(device), batch_val.to(device)

        set_seeds(0)  # identical initialisation for every trial and fold
        model = Model(X_train.shape[-1], n_out=1 if binary else n_classes, \
                      attn1=attn1, attn2=True, use_softmax=True, dropout=dropout, \
                    n_layers_lin=n_layers_lin, n_layers_lin2=0, n_hid=n_hid, n_hid2=0).to(device)
        opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
        loss_fn = torch.nn.BCEWithLogitsLoss() if binary else torch.nn.CrossEntropyLoss()

        for epoch in range(n_epochs):
            model.train()
            opt.zero_grad()
            pred = model(X_train, batch_train, len(all_ct)*len(y_train), len(all_ct))
            loss = loss_fn(pred.squeeze(), y_train.squeeze())
            loss.backward()
            opt.step()

        with torch.no_grad():
            model.eval()
            logits = model(X_val, batch_val, len(all_ct)*len(y_val), len(all_ct))
            # reshape / row-wise softmax instead of squeeze(): keeps one entry per
            # patient even for a validation fold holding a single patient.
            probs = torch.sigmoid(logits.reshape(-1)) if binary else torch.softmax(logits, dim=-1)
            preds_.extend(probs.cpu().numpy())
            truths_.extend(y_val.reshape(-1).cpu().numpy())

    return roc_auc_score(np.stack(truths_), np.stack(preds_), multi_class="ovo")


import optuna as optuna
from optuna.samplers import TPESampler


device = "cuda:0" if torch.cuda.is_available() else "cpu"
all_ct = adata.obs["cell_type_annotation"].unique()
samples = adata.obs[["patient", "label"]].drop_duplicates()

for cell_prop in cell_props:
    aucs = []
    for i in tqdm(range(n)):
        # Re-annotate `cell_prop` of every patient's cells. The result depends only on
        # the repeat seed `i`, so it is built once per repeat instead of once per outer
        # fold. `objective` reads this global.
        df_subsampled = reassign_cell_types(df, cell_prop, all_ct, seed=i)
        skf = StratifiedKFold(n_skf, shuffle=True, random_state=i)
        preds_ = []
        truths_ = []
        for train_idx, test_idx in skf.split(samples, samples["label"]):

            # `objective` reads this global as the pool for its inner CV.
            train_samples = samples.iloc[train_idx, :]
            test_samples = samples.iloc[test_idx, :]["patient"].to_list()

            X_test, y_test, batch_test, _  = get_data_batch_count(df_subsampled, all_ct, test_samples, binary=binary)
            X_test, y_test, batch_test = X_test.to(device), y_test.to(device), batch_test.to(device)


            sampler = TPESampler(seed=0)
            study = optuna.create_study(direction="maximize", sampler=sampler)
            study.optimize(objective, n_trials=30)
            best_params = study.best_params


            X_train, y_train, batch_train, _  = get_data_batch_count(df_subsampled, all_ct, train_samples["patient"].to_list(), binary=binary)
            X_train, y_train, batch_train = X_train.to(device), y_train.to(device), batch_train.to(device)
            set_seeds(i)
            model = Model(X_train.shape[-1], n_out=1 if binary else n_classes, \
                        attn1=attn1, attn2=True, use_softmax=True, dropout=best_params["dropout"], \
                        n_layers_lin=best_params["n_layers_lin"], n_layers_lin2=0, n_hid=best_params["n_hid"], n_hid2=0).to(device)
            opt = torch.optim.Adam(model.parameters(), lr=best_params["lr"], weight_decay=best_params["weight_decay"])
            loss_fn = torch.nn.BCEWithLogitsLoss() if binary else torch.nn.CrossEntropyLoss()

            for epoch in range(best_params["n_epochs"]):
                model.train()
                opt.zero_grad()
                pred = model(X_train, batch_train, len(all_ct)*len(y_train), len(all_ct))
                loss = loss_fn(pred.squeeze(), y_train.squeeze())
                loss.backward()
                opt.step()

            with torch.no_grad():
                model.eval()
                logits = model(X_test, batch_test, len(all_ct)*len(y_test), len(all_ct))
                # reshape / row-wise softmax instead of squeeze(): a test fold holding a
                # single patient must still contribute exactly one entry per patient.
                probs = torch.sigmoid(logits.reshape(-1)) if binary else torch.softmax(logits, dim=-1)
                preds_.extend(probs.cpu().numpy())
                truths_.extend(y_test.reshape(-1).cpu().numpy())
        aucs.append(roc_auc_score(np.stack(truths_), np.stack(preds_), multi_class="ovo"))
    print("Cell prop", cell_prop, np.mean(aucs), "+/-", np.std(aucs))

## Permutation test (cell-type importance under shuffled patient labels)

In [None]:
# CHANGE THIS
n_perm = 100  # label shuffles forming the null distribution

# Fixed hyper-parameters used by train() (no tuning in the permutation test).
dropout, lr, weight_decay = 0.5, 1e-3, 1e-3
n_epochs = 1000
n_layers_lin, n_hid = 1, 32


In [None]:
def get_data_batch_count_perm(tmp, all_ct, samples, meta=None, perm_annot=False, seed=None, binary=True):
    """Stack the cells of ``samples`` into model-ready tensors, optionally scrambling cell types.

    Unlike ``get_data_batch_count``, labels come from ``samples["label"]`` rather than
    from the cells, so a frame with shuffled labels can be passed in directly.

    Parameters
    ----------
    tmp : pandas.DataFrame
        One row per cell: feature columns followed by three trailing columns that
        include ``"patient"`` and ``"cell_type_annotation"``.
    all_ct : sequence
        Every cell type; its order fixes each type's slot index.
    samples : pandas.DataFrame
        One row per patient with ``"patient"`` and ``"label"`` columns.
    meta : pandas.DataFrame, optional
        Per-patient covariates, looked up by patient id.
    perm_annot : bool
        If True, randomly permute the group ids within each patient, i.e. shuffle
        which of its cells carry which cell-type annotation.
    seed : int, optional
        Seed for that permutation (only used when ``perm_annot``).
    binary : bool
        Labels as float (True) or long (False).

    Returns
    -------
    (X, y, batch, meta)
        Cell features, one label per patient, per-cell group ids
        (``patient_position * len(all_ct) + type_slot``), and meta as a float tensor
        (or None).
    """
    ct_dict = {ct: idx for idx, ct in enumerate(all_ct)}
    # Feature width comes from the frame passed in (previously the global ``df`` was read here).
    n_features = tmp.shape[-1] - 3
    # One private generator for the whole call. Re-seeding inside the per-patient loop
    # (previously set_seeds(seed)) restarted the stream for every patient, reset the
    # global Python/NumPy/torch RNGs as a side effect, and crashed for seed=None.
    rng = np.random.RandomState(seed) if perm_annot else None

    if meta is not None:
        meta = torch.tensor(meta.loc[samples["patient"], :].to_numpy(), dtype=torch.float)

    Xs = []
    batches = []
    for idx, sample in enumerate(samples["patient"].to_list()):
        sample_df = tmp[tmp["patient"] == sample]
        Xs.append(sample_df.iloc[:, :n_features].to_numpy())
        batch = [(idx * len(all_ct) + ct_dict[ct]) for ct in sample_df["cell_type_annotation"].to_list()]
        if perm_annot:
            batch = rng.permutation(batch)
        batches.append(batch)
    Xs = torch.tensor(np.concatenate(Xs), dtype=torch.float)
    batches = torch.tensor(np.concatenate(batches))
    ys = torch.tensor(samples["label"].to_list(), dtype=torch.float if binary else torch.long)
    return Xs, ys, batches, meta

In [None]:

n_classes = len(set(adata.obs["label"]))
binary = n_classes == 2

if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
all_ct = adata.obs["cell_type_annotation"].unique()

def train(perm=False, seed=None, reduce=True, n_skf=5):
    """Cross-validated per-cell-type logit contributions, optionally under shuffled labels.

    Trains one model per fold of a stratified ``n_skf``-fold patient split, using the
    fixed hyper-parameters from the config cell (``dropout``, ``lr``, ``weight_decay``,
    ``n_epochs``, ``n_layers_lin``, ``n_hid``), and collects
    ``Model.decompose_logits`` for every held-out patient.

    Parameters
    ----------
    perm : bool
        Shuffle labels across patients (seeded by ``seed``) before splitting (null model).
    seed : int, optional
        Random state of the label shuffle.
    reduce : bool
        Collapse the per-class summary into one score per cell type.
    n_skf : int
        Number of folds.

    Returns
    -------
    (scores, ct_logits, truths)
        scores: binary -> mean contribution per (label 0/1, cell type), shape (2, n_ct),
        or with ``reduce`` the difference label1 - label0, shape (n_ct,).
        Multi-class -> per class, mean over that class's patients of
        (own-class contribution - mean of other classes'), shape (n_classes, n_ct),
        summed over classes with ``reduce``.
        ct_logits: held-out contributions stacked over patients.
        truths: matching integer labels.
    """
    samples = adata.obs[["patient", "label"]].drop_duplicates()
    if perm:
        # assign() builds a new frame instead of writing a column into a filtered copy.
        shuffled = samples["label"].sample(len(samples), random_state=seed).to_list()
        samples = samples.assign(label=shuffled)
    skf = StratifiedKFold(n_skf, shuffle=True, random_state=0)
    ct_logits = []
    truths = []
    for train_idx, test_idx in skf.split(samples, samples["label"]):

        train_samples = samples.iloc[train_idx, :]
        test_samples = samples.iloc[test_idx, :]

        X_test, y_test, batch_test, _ = get_data_batch_count_perm(df, all_ct, test_samples, binary=binary)
        X_test, y_test, batch_test = X_test.to(device), y_test.to(device), batch_test.to(device)

        X_train, y_train, batch_train, _  = get_data_batch_count_perm(df, all_ct, train_samples, binary=binary)
        X_train, y_train, batch_train = X_train.to(device), y_train.to(device), batch_train.to(device)
        set_seeds(0)
        model = Model(X_train.shape[-1], n_out=1 if binary else n_classes, attn1=True, attn2=True, use_softmax=True, dropout=dropout, \
                      n_layers_lin=n_layers_lin, n_layers_lin2=0, n_hid=n_hid, n_hid2=0).to(device)
        opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
        loss_fn = torch.nn.BCEWithLogitsLoss() if binary else torch.nn.CrossEntropyLoss()

        for epoch in range(n_epochs):
            model.train()
            opt.zero_grad()
            pred = model(X_train, batch_train, len(all_ct)*len(y_train), len(all_ct))
            loss = loss_fn(pred.squeeze(), y_train.squeeze())
            loss.backward()
            opt.step()

        with torch.no_grad():
            model.eval()
            ct_logit, _ = model.decompose_logits(X_test, batch_test, len(all_ct)*len(y_test), len(all_ct))
            # decompose_logits squeezes its output, which can drop the patient axis for a
            # fold with a single test patient; restore exactly one row per patient.
            ct_logit = ct_logit.reshape(len(y_test), len(all_ct), -1)
            if binary:
                ct_logit = ct_logit.squeeze(-1)
            ct_logits.extend(ct_logit)
            truths.extend(y_test.long().reshape(-1).cpu().numpy())
    ct_logits = torch.stack(ct_logits).cpu().numpy()
    truths = np.array(truths)
    if binary:
        summary = pd.DataFrame(ct_logits, columns=all_ct)
        summary["label"] = truths
        # NOTE(review): assumes the two labels are encoded as 0 and 1.
        summary = summary.groupby("label").mean().loc[[0, 1], :].to_numpy()
        if reduce:
            summary = summary[1] - summary[0]
    else:
        per_class = []
        for i in range(n_classes):
            curr_class_logits = ct_logits[:, :, i]
            other_class_logits = ct_logits[:, :, [c for c in range(n_classes) if c != i]].mean(-1)
            diff = curr_class_logits - other_class_logits
            per_class.append(np.mean(diff[truths == i], 0))
        summary = pd.DataFrame(per_class, columns=all_ct).to_numpy()
        if reduce:
            summary = summary.sum(0)
    return summary, ct_logits, truths

In [None]:
# train() returns (scores, per-patient logits, labels); only the (2, n_ct) scores are tabulated.
orig_scores, _, _ = train(False, reduce=False)
# NOTE(review): rows follow label values 0 and 1 (train() selects .loc[[0, 1]]); the names
# assume 0 = normal and 1 = covid, which only fits covid.h5ad — confirm for other datasets.
orig = pd.DataFrame(orig_scores, columns=all_ct, index=["normal", "covid"])
(orig.iloc[1, :] - orig.iloc[0, :]).sort_values(ascending=False)

In [None]:
# train() returns (scores, per-patient logits, labels); only the per-cell-type scores are
# compared. Stacking the raw tuples fails because their members differ in shape.
perms = []
for i in tqdm(range(n_perm)):
  perm_scores, _, _ = train(True, i)
  perms.append(perm_scores)
perms = np.stack(perms)      # (n_perm, n_ct): null scores under shuffled labels
truth, _, _ = train(False)   # (n_ct,): scores with the real labels

In [None]:
# One-sided permutation p-value per cell type: how often a label-shuffled run scores at
# least as high as the real labels. Counting the observed statistic as one of the
# permutations (+1 in numerator and denominator) keeps p strictly positive, as it must
# be with a finite number of permutations (smallest attainable p is 1 / (n_perm + 1)).
p_vals = (1 + (perms >= truth[None, :]).sum(0)) / (1 + n_perm)
p_vals

In [None]:
from statsmodels.stats.multitest import multipletests

# Benjamini-Hochberg FDR correction across cell types; keep only the adjusted p-values.
p_vals_corrected = multipletests(p_vals, alpha=0.05, method='fdr_bh')[1]
p_vals_corrected