In [None]:
import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
from typing import Dict
import time
import warnings
warnings.filterwarnings('ignore')
import copy
import pickle
from tqdm import tqdm
import scanpy as sc
import numpy as np
import pandas as pd
import random
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import auc, roc_curve, precision_recall_curve
import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn

from stformer import logger
from stformer.tokenizer import GeneVocab
from stformer.tokenizer import tokenize_and_pad_batch_2
from stformer.model import TransformerModel

In [None]:
class Tokenizer():
    """Build a gene-classification dataset for one pathway from a niche-annotated AnnData.

    On construction (via ``load_data``) it:
      * maps the tokenizer gene list and the ligand database onto vocabulary ids,
      * keeps only cells whose niche expresses at least two distinct pathway ligands,
      * draws the labelled gene set: pathway ``inner_member`` genes (label 1) plus an equal
        number of randomly sampled genes absent from the label file (label 0); every other
        gene id gets -100 and is excluded from loss/metrics downstream.

    ``tokenize_data`` / ``prepare_data`` / ``prepare_dataloader`` then build the
    train/validation tensors and DataLoaders for one gene-level cross-validation split.

    Args:
        tokenizer_dir: directory containing ``OS_scRNA_gene_index.19264.tsv`` and
            ``ligand_database.csv``.
        adata: AnnData providing ``X``, ``var_names``, ``obsm['niche_ligands_expression']``
            and ``obsm['niche_composition']``.
        label_path: CSV with ``gene_symbols`` and ``class`` columns.
        vocab: GeneVocab; calling it on a list of symbols yields ids, indexing it with a
            token yields that token's id.
        pad_value: padding value passed to ``tokenize_and_pad_batch_2``.
        pad_token: padding token string.
        pathway_ligands: optional ligand symbols restricting which ``outer_member`` genes of
            the label file count as pathway ligands (default: all of them).
    """

    def __init__(self, tokenizer_dir, adata, label_path, vocab, pad_value, pad_token, pathway_ligands=None):
        self.tokenizer_dir = tokenizer_dir
        self.adata = adata
        self.label_path = label_path
        self.vocab = vocab
        self.pad_value = pad_value
        self.pad_token = pad_token
        self.pathway_ligands = pathway_ligands
        self.load_data()

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _to_dense(matrix):
        """Return ``matrix`` as a dense ndarray (scipy sparse matrix/array or any array-like)."""
        if hasattr(matrix, 'toarray'):
            return matrix.toarray()
        return np.asarray(matrix)

    @staticmethod
    def _available_cpus():
        """CPUs usable by this process; ``os.sched_getaffinity`` only exists on some platforms."""
        try:
            return len(os.sched_getaffinity(0))
        except AttributeError:
            return os.cpu_count() or 1

    @staticmethod
    def _cells_with_distinct_ligands(niche_expression, ligand_ids, wanted_ids, min_distinct=2):
        """Row indices whose nonzero columns cover at least ``min_distinct`` distinct ids of ``wanted_ids``.

        Column j of ``niche_expression`` corresponds to token ``ligand_ids[j]``. The same id can
        occupy several columns (the ligand list is tiled once per niche cell type), so each
        id is counted at most once per row.
        """
        niche_expression = np.asarray(niche_expression)
        ligand_ids = np.asarray(ligand_ids)
        distinct_hits = np.zeros(niche_expression.shape[0], dtype=int)
        for lid in np.unique(ligand_ids[np.isin(ligand_ids, list(wanted_ids))]):
            distinct_hits += (niche_expression[:, ligand_ids == lid] != 0).any(axis=1)
        return np.nonzero(distinct_hits >= min_distinct)[0]

    @staticmethod
    def _cells_expressing_any(expression, gene_ids, target_ids):
        """Row indices of ``expression`` with a nonzero value in a column whose gene id is in ``target_ids``."""
        target_cols = np.isin(np.asarray(gene_ids), np.asarray(target_ids))
        return np.nonzero((np.asarray(expression)[:, target_cols] != 0).any(axis=1))[0]

    @staticmethod
    def _niche_biases(niche_composition, n_ligands):
        """Cross-attention bias: each niche cell type's log-proportion repeated ``n_ligands`` times.

        Zero proportions give ``-inf`` on purpose (those cell-type blocks get no attention weight).
        """
        with np.errstate(divide='ignore'):
            log_props = np.log(np.asarray(niche_composition))
        return np.repeat(log_props, n_ligands, axis=1).astype(np.float64)

    @staticmethod
    def _as_model_inputs(tokenized):
        """Rename tokenizer output keys to the keys read by ``train``/``evaluate``."""
        return {
            "center_gene_ids": tokenized["center_genes"],
            "input_center_values": tokenized["center_values"],
            "niche_gene_ids": tokenized["niche_genes"],
            "input_niche_values": tokenized["niche_values"],
            "cross_attn_bias": tokenized["cross_attn_bias"],
            "center_labels": tokenized["center_labels"],
        }

    # ------------------------------------------------------------------ pipeline

    def load_data(self):
        """Read gene/ligand tables, select cells and build the gene-id -> label mapping.

        Sets ``gene_ids``, ``ligand_ids``, ``expression_matrix``, ``niche_ligands_expression``,
        ``niche_composition``, ``gene_targets``, ``gene_labels`` and ``gene2label``.
        """
        gene_list_df = pd.read_csv(os.path.join(self.tokenizer_dir, 'OS_scRNA_gene_index.19264.tsv'), header=0, delimiter='\t')
        gene_list = list(gene_list_df['gene_name'])
        self.gene_ids = np.array(self.vocab(gene_list), dtype=int)

        # ligand_ids[j] is the token of niche-feature column j; the ligand list is tiled once
        # per niche cell type. The cell-type count is read from adata because
        # self.niche_composition is only assigned further below.
        n_niche_types = self.adata.obsm['niche_composition'].shape[1]
        ligand_database = pd.read_csv(os.path.join(self.tokenizer_dir, 'ligand_database.csv'), header=0, index_col=0)
        ligand_symbol = ligand_database[ligand_database.sum(1) > 1].index.values
        ligand_symbol = gene_list_df.loc[gene_list_df['gene_name'].isin(ligand_symbol), 'gene_name'].values
        self.ligand_ids = np.array(self.vocab(ligand_symbol.tolist()) * n_niche_types, dtype=int)

        df_label = pd.read_csv(self.label_path, header=0)
        selected_gene_index = sc.pp.filter_genes(self.adata, min_cells=100, inplace=False)
        selected_gene_names = self.adata.var_names.values[selected_gene_index[0]]
        df_label = df_label[df_label['gene_symbols'].isin(selected_gene_names)]

        outer_members = df_label.loc[df_label['class'] == 'outer_member', 'gene_symbols']
        if self.pathway_ligands is not None:
            pathway_ligands = list(set(self.pathway_ligands).intersection(set(outer_members)))
        else:
            pathway_ligands = list(outer_members)
        pathway_ligand_ids = set(self.vocab(pathway_ligands))

        # Keep cells whose niche expresses at least two distinct pathway ligands.
        niche_ligands_expression = self._to_dense(self.adata.obsm['niche_ligands_expression'])
        selected_cell_index = self._cells_with_distinct_ligands(
            niche_ligands_expression, self.ligand_ids, pathway_ligand_ids, min_distinct=2)

        self.expression_matrix = self._to_dense(self.adata.X[selected_cell_index])
        self.niche_ligands_expression = niche_ligands_expression[selected_cell_index]
        self.niche_composition = self._to_dense(self.adata.obsm['niche_composition'][selected_cell_index])
        logger.info(f"Selected cell number: {len(selected_cell_index): 6d}/{self.adata.shape[0]: 6d}")

        member_genes = df_label.loc[df_label['class'] == 'inner_member', 'gene_symbols'].values.tolist()
        # A local RNG seeded with 0 draws the same sample as random.seed(0) + random.sample,
        # without resetting the global `random` state for the rest of the notebook.
        candidates = sorted(set(selected_gene_names).difference(set(df_label['gene_symbols'])))
        nonmember_genes = random.Random(0).sample(candidates, len(member_genes))
        gene_targets = np.array(self.vocab(member_genes + nonmember_genes))
        gene_labels = [1] * len(member_genes) + [0] * len(nonmember_genes)

        gene2label = dict(zip(gene_targets, gene_labels))
        for g in self.gene_ids:
            if g not in gene2label:
                gene2label[g] = -100
        gene2label[self.vocab[self.pad_token]] = -100

        self.gene_targets = gene_targets
        self.gene_labels = gene_labels
        self.gene2label = gene2label

    def tokenize_data(self, train_index, valid_index):
        """Tokenize the cells for one gene-level cross-validation split.

        Args:
            train_index, valid_index: positions into ``self.gene_targets`` (e.g. from
                ``StratifiedKFold.split``) choosing the train / validation target genes.

        A cell joins a split when it has nonzero expression of at least one of that split's
        target genes. Labels of genes outside the split are set to -100.
        Sets ``self.tokenized_train`` and ``self.tokenized_valid``.
        """
        # The bias width must match the niche ligand layout (one ligand block per cell type);
        # it used to be hard-coded as 986 ligands per type.
        n_niche_types = self.niche_composition.shape[1]
        n_ligands = len(self.ligand_ids) // n_niche_types
        biases = self._niche_biases(self.niche_composition, n_ligands)

        gene_targets_train = self.gene_targets[train_index]
        gene_targets_valid = self.gene_targets[valid_index]
        train_cells = self._cells_expressing_any(self.expression_matrix, self.gene_ids, gene_targets_train)
        valid_cells = self._cells_expressing_any(self.expression_matrix, self.gene_ids, gene_targets_valid)

        pad_id = self.vocab[self.pad_token]
        tokenized_train = tokenize_and_pad_batch_2(
            self.expression_matrix[train_cells],
            self.niche_ligands_expression[train_cells],
            biases[train_cells],
            self.gene_ids,
            self.ligand_ids,
            pad_id=pad_id,
            pad_value=self.pad_value,
        )
        tokenized_valid = tokenize_and_pad_batch_2(
            self.expression_matrix[valid_cells],
            self.niche_ligands_expression[valid_cells],
            biases[valid_cells],
            self.gene_ids,
            self.ligand_ids,
            pad_id=pad_id,
            pad_value=self.pad_value,
        )

        # Only genes of the respective split keep their 0/1 label.
        train_target_set = set(gene_targets_train.tolist())
        valid_target_set = set(gene_targets_valid.tolist())
        gene2label_train = {g: (lbl if g in train_target_set else -100) for g, lbl in self.gene2label.items()}
        gene2label_valid = {g: (lbl if g in valid_target_set else -100) for g, lbl in self.gene2label.items()}

        tokenized_train['center_labels'] = torch.from_numpy(np.vectorize(gene2label_train.get)(tokenized_train['center_genes'].numpy()))
        tokenized_valid['center_labels'] = torch.from_numpy(np.vectorize(gene2label_valid.get)(tokenized_valid['center_genes'].numpy()))

        logger.info(
            f"train set number of samples: {tokenized_train['center_genes'].shape[0]}, "
            f"\n\t feature length of center cell: {tokenized_train['center_genes'].shape[1]}"
            f"\n\t feature length of niche cells: {tokenized_train['niche_genes'].shape[1]}"
        )
        logger.info(
            f"valid set number of samples: {tokenized_valid['center_genes'].shape[0]}, "
            f"\n\t feature length of center cell: {tokenized_valid['center_genes'].shape[1]}"
            f"\n\t feature length of niche cells: {tokenized_valid['niche_genes'].shape[1]}"
        )

        self.tokenized_train = tokenized_train
        self.tokenized_valid = tokenized_valid

    def prepare_data(self):
        """Expose the tokenized splits under the input names used by the training loop."""
        self.train_data_pt = self._as_model_inputs(self.tokenized_train)
        self.valid_data_pt = self._as_model_inputs(self.tokenized_valid)

    def prepare_dataloader(self, batch_size):
        """Return ``(train_loader, valid_loader)``; neither loader shuffles (cells are pre-shuffled)."""
        num_workers = min(self._available_cpus(), batch_size // 2)
        train_loader = DataLoader(
            dataset=SeqDataset(self.train_data_pt),
            batch_size=batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=num_workers,
            pin_memory=True,
        )
        valid_loader = DataLoader(
            dataset=SeqDataset(self.valid_data_pt),
            batch_size=batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=num_workers,
            pin_memory=True,
        )
        return train_loader, valid_loader


class SeqDataset(Dataset):
    """Map-style dataset over a dict of tensors sharing the same first dimension.

    Item ``idx`` is a dict with the same keys, each holding row ``idx`` of its tensor.
    The dataset length is taken from ``data["center_gene_ids"]``.
    """

    def __init__(self, data: Dict[str, torch.Tensor]):
        self.data = data

    def __len__(self):
        center_ids = self.data["center_gene_ids"]
        return center_ids.shape[0]

    def __getitem__(self, idx):
        sample = {}
        for key, tensor in self.data.items():
            sample[key] = tensor[idx]
        return sample

In [None]:
# Shared classification loss for the GCL head; used by both train() and evaluate()
# on the logits/labels of tokens labelled 0 or 1.
criterion_cls = nn.CrossEntropyLoss()

def train(model: nn.Module, loader: DataLoader, scaler, optimizer, scheduler, log_interval, epoch, mode) -> None:
    """
    Train the model for one epoch on the gene-classification (GCL) objective.

    Args:
        model: called as model(niche ids, niche values, niche padding mask, center ids,
            center values, center padding mask, cross-attention bias, GCL=True); its output
            dict must hold ``"gcl_output"`` logits indexable by the center-label mask.
        loader: DataLoader yielding dicts built by ``Tokenizer.prepare_data``.
        scaler: GradScaler used for the mixed-precision backward pass and optimizer step.
        optimizer: stepped once per processed batch through ``scaler``.
        scheduler: only read for its current learning rate when logging.
        log_interval: emit a running loss/accuracy line every ``log_interval`` batches.
        epoch: epoch number, used only in log messages.
        mode: ``'sp'`` masks niche padding tokens; ``'sc'`` masks every niche token.

    Raises:
        ValueError: ``mode`` is neither ``'sp'`` nor ``'sc'``.

    Reads the notebook globals ``device``, ``vocab``, ``pad_token``, ``split`` and
    ``criterion_cls``.
    """
    amp = True

    model.train()
    total_gcl = 0.0
    total_error = 0.0
    n_accumulated = 0  # batches included in the running totals since the last log line
    start_time = time.time()

    num_batches = len(loader)
    pad_id = vocab[pad_token]

    for batch, batch_data in enumerate(loader):
        center_gene_ids = batch_data["center_gene_ids"].to(device)
        # Batches with fewer than 8 samples are skipped.
        if center_gene_ids.size(0) < 8:
            continue
        center_labels = batch_data["center_labels"].to(device)
        # Only tokens labelled 0/1 are this split's training targets; -100 marks the rest.
        labelled = torch.logical_or(center_labels == 1, center_labels == 0)
        if not labelled.any():
            # Nothing to learn from; the per-gene mean below would have no terms.
            continue
        input_center_values = batch_data["input_center_values"].to(device)
        niche_gene_ids = batch_data["niche_gene_ids"].to(device)
        input_niche_values = batch_data["input_niche_values"].to(device)
        cross_attn_bias = batch_data["cross_attn_bias"].to(device)

        if mode == 'sp':
            encoder_src_key_padding_mask = niche_gene_ids.eq(pad_id)
        elif mode == 'sc':
            # Mask every niche token.
            encoder_src_key_padding_mask = torch.ones_like(niche_gene_ids, dtype=torch.bool)
        else:
            raise ValueError(f"mode must be 'sp' or 'sc', got {mode!r}")
        decoder_src_key_padding_mask = center_gene_ids.eq(pad_id)

        with torch.cuda.amp.autocast(enabled=amp):
            output_dict = model(
                niche_gene_ids,
                input_niche_values,
                encoder_src_key_padding_mask,
                center_gene_ids,
                input_center_values,
                decoder_src_key_padding_mask,
                cross_attn_bias,
                GCL=True,
            )
            batch_logits = output_dict["gcl_output"][labelled]
            batch_labels = center_labels[labelled]
            targets = center_gene_ids[labelled]
            # Mean loss per target gene, then mean over genes, so frequently expressed genes
            # do not dominate. torch.unique is required: a Python set() of tensor elements
            # hashes by object identity and never deduplicates.
            loss_gcl = torch.stack([
                criterion_cls(batch_logits[targets == gene], batch_labels[targets == gene])
                for gene in torch.unique(targets)
            ]).mean()

        error_rate_gcl = 1 - (
            (batch_logits.argmax(1) == batch_labels).sum().item()
        ) / batch_labels.size(0)

        model.zero_grad()
        scaler.scale(loss_gcl).backward()
        scaler.unscale_(optimizer)
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings("always")
            torch.nn.utils.clip_grad_norm_(
                model.parameters(),
                1.0,
                error_if_nonfinite=not scaler.is_enabled(),
            )
            if len(w) > 0:
                logger.warning(
                    f"Found infinite gradient. This may be caused by the gradient "
                    f"scaler. The current scale is {scaler.get_scale()}. This warning "
                    "can be ignored if no longer occurs after autoscaling of the scaler."
                )
        scaler.step(optimizer)
        scaler.update()

        total_gcl += loss_gcl.item()
        total_error += error_rate_gcl
        n_accumulated += 1

        if batch % log_interval == 0 and batch > 0:
            lr = scheduler.get_last_lr()[0]
            # Average over the batches actually accumulated: the first window spans
            # log_interval + 1 batches and skipped batches contribute nothing.
            n_window = max(n_accumulated, 1)
            sec_per_batch = (time.time() - start_time) / n_window
            cur_gcl = total_gcl / n_window
            cur_error = total_error / n_window
            logger.info(
                f"| Split {split} | epoch {epoch:3d} | {batch:3d}/{num_batches:3d} batches | "
                f"lr {lr:05.8f} | sec/batch {sec_per_batch:5.1f} | "
                f"gcl {cur_gcl:5.5f} | "
                f"acc {1-cur_error:1.5f} | "
            )
            total_gcl = 0
            total_error = 0
            n_accumulated = 0
            start_time = time.time()

def py_softmax(vector):
    """Numerically stable softmax over all entries of ``vector``.

    Subtracting the maximum leaves the result unchanged mathematically but prevents
    ``np.exp`` overflow (which yields inf/inf = NaN) for large logits.
    """
    vector = np.asarray(vector)
    if vector.size == 0:
        return np.exp(vector)
    e = np.exp(vector - vector.max())
    return e / e.sum()

def evaluate(model: nn.Module, loader: DataLoader, mode, curve) -> tuple:
    """
    Evaluate the model on the evaluation data.

    Collects GCL logits/labels of every center-gene token labelled 0 or 1 and logs
    accuracy, area under the requested curve and mean cross-entropy.

    Args:
        model, loader, mode: as in ``train``.
        curve: ``'roc'`` or ``'prc'``.

    Returns:
        ``(auc, fpr, tpr)`` for ``'roc'``; ``(auc, recall, precision)`` for ``'prc'``.

    Raises:
        ValueError: unknown ``mode`` or ``curve``.
        RuntimeError: the loader produced no labelled tokens, so nothing can be scored.

    Reads the notebook globals ``device``, ``vocab``, ``pad_token`` and ``criterion_cls``.
    """
    if curve not in ('roc', 'prc'):
        raise ValueError(f"curve must be 'roc' or 'prc', got {curve!r}")
    amp = True

    model.eval()
    total_gcl = 0.0
    total_error = 0.0
    total_num = 0

    logits = []
    labels = []
    pad_id = vocab[pad_token]

    with torch.no_grad():
        for batch_data in tqdm(loader):
            center_gene_ids = batch_data["center_gene_ids"].to(device)
            if center_gene_ids.size(0) < 8:
                continue
            center_labels = batch_data["center_labels"].to(device)
            labelled = torch.logical_or(center_labels == 1, center_labels == 0)
            if not labelled.any():
                # Cross-entropy of an empty selection is NaN and would poison the loss total.
                continue
            input_center_values = batch_data["input_center_values"].to(device)
            niche_gene_ids = batch_data["niche_gene_ids"].to(device)
            input_niche_values = batch_data["input_niche_values"].to(device)
            cross_attn_bias = batch_data["cross_attn_bias"].to(device)

            if mode == 'sp':
                encoder_src_key_padding_mask = niche_gene_ids.eq(pad_id)
            elif mode == 'sc':
                encoder_src_key_padding_mask = torch.ones_like(niche_gene_ids, dtype=torch.bool)
            else:
                raise ValueError(f"mode must be 'sp' or 'sc', got {mode!r}")
            decoder_src_key_padding_mask = center_gene_ids.eq(pad_id)

            with torch.cuda.amp.autocast(enabled=amp):
                output_dict = model(
                    niche_gene_ids,
                    input_niche_values,
                    encoder_src_key_padding_mask,
                    center_gene_ids,
                    input_center_values,
                    decoder_src_key_padding_mask,
                    cross_attn_bias,
                    GCL=True,
                )
                batch_logits = output_dict["gcl_output"][labelled]
            batch_labels = center_labels[labelled]
            logits.append(batch_logits.to('cpu'))
            labels.append(batch_labels.to('cpu'))

            n_tokens = batch_labels.size(0)
            total_error += n_tokens - (batch_logits.argmax(1) == batch_labels).sum().item()
            total_num += n_tokens
            total_gcl += criterion_cls(batch_logits, batch_labels).item() * n_tokens

    if not logits:
        raise RuntimeError("No labelled center-gene tokens in the evaluation loader.")
    logits = torch.concat(logits).float()
    labels = torch.concat(labels)

    # Positive-class (label 1) probability; torch.softmax is numerically stable.
    y_score = torch.softmax(logits, dim=1)[:, 1].numpy()
    y_true = labels.numpy()

    if curve == 'roc':
        curve_x, curve_y, _ = roc_curve(y_true, y_score)  # fpr, tpr
    else:
        precision, recall, _ = precision_recall_curve(y_true, y_score)
        curve_x, curve_y = recall, precision
    # NOTE: for 'prc' this is the trapezoidal area under the PR curve (kept for comparability
    # with earlier runs); sklearn's average_precision_score is the non-interpolated summary.
    auc_value = auc(curve_x, curve_y)

    val_err = total_error / total_num
    val_loss = total_gcl / total_num
    logger.info("-" * 89)
    logger.info(
        f"valid accuracy: {1-val_err:1.4f} | "
        f"valid auc: {auc_value:1.4f} | "
        f"valid loss: {val_loss:1.4f} | "
    )
    logger.info("-" * 89)

    return auc_value, curve_x, curve_y


def train_and_evaluate(model, train_loader, valid_loader, epochs, mean_curve_x, mode, curve):
    """Train for ``epochs`` epochs and keep the evaluation curve of the best-AUC epoch.

    Uses Adam (lr 1e-4; eps 1e-4 under AMP), a StepLR decay of 0.9 per epoch and an AMP
    GradScaler.

    Returns:
        (interp_curve_y, best_curve_x, best_curve_y): the best epoch's curve interpolated
        onto ``mean_curve_x`` (for averaging across splits) and the raw best curve.

    Raises:
        ValueError: ``epochs < 1`` or ``curve`` not in {'roc', 'prc'}.
    """
    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")
    if curve not in ('roc', 'prc'):
        raise ValueError(f"curve must be 'roc' or 'prc', got {curve!r}")

    lr = 1e-4
    amp = True
    schedule_ratio = 0.9
    schedule_interval = 1
    log_interval = 10

    optimizer = torch.optim.Adam(
        model.parameters(), lr=lr, eps=1e-4 if amp else 1e-8
    )
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, schedule_interval, gamma=schedule_ratio
    )
    scaler = torch.cuda.amp.GradScaler(enabled=amp)

    best_val_auc = -np.inf
    best_curve_x = None
    best_curve_y = None

    for epoch in range(1, epochs + 1):
        train(model, train_loader, scaler, optimizer, scheduler, log_interval, epoch, mode)

        val_auc, curve_x, curve_y = evaluate(model, valid_loader, mode, curve)
        # NaN AUCs (e.g. single-class validation labels) rank lowest; the first epoch is
        # always kept so a curve exists for the interpolation below.
        score = -np.inf if np.isnan(val_auc) else val_auc
        if best_curve_x is None or score > best_val_auc:
            best_val_auc = score
            best_curve_x = curve_x
            best_curve_y = curve_y
            logger.info(f"Best model with auc {best_val_auc:1.4f}")

        scheduler.step()

    if curve == 'roc':
        interp_curve_y = np.interp(mean_curve_x, best_curve_x, best_curve_y)
        interp_curve_y[0] = 0.0
    else:
        # precision_recall_curve returns decreasing recall; np.interp needs increasing x.
        interp_curve_y = np.interp(mean_curve_x, best_curve_x[::-1], best_curve_y[::-1])
        interp_curve_y[0] = 1.0

    return interp_curve_y, best_curve_x, best_curve_y

In [None]:
from scfoundation import load

def initialize_model(model_file):
    """Build the stFormer TransformerModel, load pretrained weights and freeze most of it.

    Two independent copies each of the scFoundation token and position embeddings are
    passed to the model. The weights in ``model_file`` are then loaded, except those of
    ``cls_decoder`` / ``gcl_decoder``. Finally only ``transformer_decoder.layers[4:6]`` and
    ``gcl_decoder`` are left trainable.

    Reads the notebook globals ``embsize``, ``nhead``, ``d_hid``, ``nlayers``, ``dropout``
    and ``cell_emb_style``.
    """
    pretrainmodel, pretrainconfig = load.load_model_frommmf('scfoundation/models/models.ckpt')

    model = TransformerModel(
        embsize,
        nhead,
        d_hid,
        nlayers,
        do_gcl=True,
        nlayers_gcl=3,
        n_gcl=2,
        dropout=dropout,
        cell_emb_style=cell_emb_style,
        scfoundation_token_emb1=copy.deepcopy(pretrainmodel.token_emb),
        scfoundation_token_emb2=copy.deepcopy(pretrainmodel.token_emb),
        scfoundation_pos_emb1=copy.deepcopy(pretrainmodel.pos_emb),
        scfoundation_pos_emb2=copy.deepcopy(pretrainmodel.pos_emb),
    )

    # model_file holds a pickled nn.Module (its .state_dict() is read below), so the full
    # unpickler is needed; torch >= 2.6 defaults to weights_only=True and would reject it.
    # Only load checkpoints from a trusted source: unpickling can execute arbitrary code.
    pt_model = torch.load(model_file, map_location='cpu', weights_only=False)

    model_dict = model.state_dict()
    # Task heads are re-initialised; everything else is taken from the pretrained model.
    pretrained_dict = {
        k: v
        for k, v in pt_model.state_dict().items()
        if 'cls_decoder' not in k and 'gcl_decoder' not in k
    }
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)

    # Keyed by data_ptr so parameters shared between modules are counted once.
    pre_freeze_param_count = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters() if p.requires_grad).values())
    for para in model.parameters():
        para.requires_grad = False
    for para in model.transformer_decoder.layers[4:6].parameters():
        para.requires_grad = True
    for para in model.gcl_decoder.parameters():
        para.requires_grad = True
    post_freeze_param_count = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters() if p.requires_grad).values())

    logger.info(f"Total Pre freeze Params {(pre_freeze_param_count )}")
    logger.info(f"Total Post freeze Params {(post_freeze_param_count )}")

    return model

In [None]:
# --- Model hyper-parameters (read as globals by initialize_model) -------------
embsize = 768
d_hid = 3072
nhead = 12
nlayers = 6
dropout = 0.1
cell_emb_style = 'max-pool'
GCL = True  # NOTE(review): not referenced elsewhere in the visible notebook
mode = 'sp'    # 'sp': mask only niche padding tokens; 'sc': mask every niche token (see train/evaluate)
curve = 'prc'  # 'roc' or 'prc': which curve evaluate() scores and the CV summary averages

# --- Vocabulary / padding ------------------------------------------------------
pad_token = "<pad>"
pad_value = 103  # NOTE(review): padding value passed to the tokenizer; meaning of 103 not shown here
tokenizer_dir = '../stformer/tokenizer/'  # trailing slash needed: vocab_file below uses plain concatenation
vocab_file = tokenizer_dir + "scfoundation_gene_vocab.json"
vocab = GeneVocab.from_file(vocab_file)
vocab.append_token(pad_token)
# Presumably maps out-of-vocabulary symbols to the pad id — confirm in GeneVocab.
vocab.set_default_index(vocab[pad_token])

In [None]:
# --- Experiment inputs -------------------------------------------------------
model_file = '../pretraining/models/model_4.1M.ckpt'
label_path = 'gene_lists/TGF-beta_signaling_pathway_KEGG.csv'
# Ligand symbols given to Tokenizer to restrict which niche ligands count as pathway ligands.
pathway_ligands = [
    'BMP2', 'BMP4', 'BMP5', 'BMP6', 'BMP7', 'BMP8A', 'BMP8B',
    'GDF5', 'GDF6', 'GDF7',
    'INHBA', 'INHBB', 'INHBC', 'INHBE',
    'TGFB1', 'TGFB2', 'TGFB3',
    'NODAL',
]

dataset = 'human_myocardial_infarction'
slide = 'ACH005'
adata = sc.read_h5ad(f'../datasets/{slide}_niche.h5ad')

# Optional experiment switches (disabled):
# max_niche_cell_num = 20 # [1,5,10,15,20,25]
# adata.obsm['niche_celltypes'] = adata.obsm[f'niche_celltypes_niche{max_niche_cell_num}']
# adata.obsm['niche_composition'] = adata.obsm[f'niche_composition_niche{max_niche_cell_num}']
# adata.obsm['niche_ligands_expression'] = adata.obsm[f'niche_ligands_expression_niche{max_niche_cell_num}']
# selected_celltype = 'Fibroblast'
# adata = adata[adata.obs['cell_type']==list(adata.uns['cell_types_list']).index(selected_celltype)+1]

# Shuffle the cell order once with a fixed seed; the DataLoaders themselves do not shuffle.
np.random.seed(0)
shuffled_indices = np.random.permutation(adata.n_obs)
adata = adata[shuffled_indices]
# adata = adata[:2000]  # quick-debug subset

tokenizer = Tokenizer(tokenizer_dir, adata, label_path, vocab, pad_value, pad_token, pathway_ligands)

# --- Cross-validation over target genes --------------------------------------
epochs = 2
batch_size = 30

n_splits = 5
skf = StratifiedKFold(n_splits=n_splits, random_state=0, shuffle=True)
mean_curve_x = np.linspace(0, 1, 100)
all_curve_y, all_auc, all_wt = [], [], []

device = torch.device("cuda:0")
split = 0
# `split` stays a module-level global: train() reads it for its log lines.
for split, (train_index, valid_index) in enumerate(
        skf.split(tokenizer.gene_targets, tokenizer.gene_labels), start=1):
    logger.info(f"Cross-validate on dataset {dataset} slide {slide} - split {split}")
    tokenizer.tokenize_data(train_index, valid_index)
    tokenizer.prepare_data()
    train_loader, valid_loader = tokenizer.prepare_dataloader(batch_size)

    # A fresh pretrained model per split.
    model = nn.DataParallel(initialize_model(model_file), device_ids=[0, 1, 2])
    model.to(device)

    interp_curve_y, best_curve_x, best_curve_y = train_and_evaluate(
        model, train_loader, valid_loader, epochs, mean_curve_x, mode, curve)
    all_curve_y.append(interp_curve_y)
    all_auc.append(auc(best_curve_x, best_curve_y))
    all_wt.append(len(best_curve_y))

In [None]:
import math 

def get_cross_valid_metrics(all_curve_y, all_auc, all_wt, curve):
    """Combine per-split results into weighted cross-validation metrics.

    Each split is weighted by its share of ``all_wt``; the weights are printed.

    Returns:
        (mean_curve_y, auc_mean, auc_sd, wts): weighted mean of the interpolated curves,
        weighted mean AUC, weighted standard deviation of the AUCs, and the weights.
    """
    total_wt = sum(all_wt)
    wts = [count / total_wt for count in all_wt]
    print(wts)

    mean_curve_y = np.sum([curve_y * w for curve_y, w in zip(all_curve_y, wts)], axis=0)
    if curve == 'roc':
        # A ROC curve always ends at TPR = 1.
        mean_curve_y[-1] = 1.0

    auc_mean = np.sum([split_auc * w for split_auc, w in zip(all_auc, wts)])
    squared_dev = (np.asarray(all_auc) - auc_mean) ** 2
    auc_sd = math.sqrt(np.average(squared_dev, weights=wts))
    return mean_curve_y, auc_mean, auc_sd, wts
 
mean_curve_y, auc_mean, auc_sd, wts = get_cross_valid_metrics(all_curve_y, all_auc, all_wt, curve)

print(f"Mean AUC: {auc_mean} +/- {auc_sd}")
cv_results = {'auc_mean':auc_mean, 'auc_sd':auc_sd, 'mean_curve_x':mean_curve_x, 'mean_curve_y':mean_curve_y, 'all_auc':all_auc, 'wts':wts, 'all_curve_y':all_curve_y}

# Name the output after the pathway actually evaluated (the label file). A hard-coded
# "nfkb" tag did not match the pathway in label_path.
pathway_tag = os.path.splitext(os.path.basename(label_path))[0]
out_dir = 'figures/gene_classification'
os.makedirs(out_dir, exist_ok=True)
with open(os.path.join(out_dir, f'stformer_gcl_{slide}_{pathway_tag}_{curve}.pkl'), 'wb') as fh:
    pickle.dump(cv_results, fh)

In [None]:
import matplotlib.pyplot as plt

def plot_ROC(bundled_data, title, curve):
    """Plot one mean ROC or precision-recall curve per entry of ``bundled_data``.

    Args:
        bundled_data: iterable of (auc_mean, auc_sd, mean_curve_x, mean_curve_y, label, color).
        title: figure title.
        curve: 'roc' adds the chance diagonal and FPR/TPR axis labels; 'prc' uses
            recall/precision axis labels.
    """
    fig, ax = plt.subplots()
    lw = 2
    for auc_mean, auc_sd, mean_curve_x, mean_curve_y, sample, color in bundled_data:
        # Raw string: "\p" in a normal literal is an invalid escape (SyntaxWarning on 3.12+).
        ax.plot(mean_curve_x, mean_curve_y, color=color, lw=lw,
                label=r"{0} (AUC {1:0.3f} $\pm$ {2:0.3f})".format(sample, auc_mean, auc_sd))
    if curve == 'roc':
        ax.plot([0, 1], [0, 1], color='black', lw=lw, linestyle='--')
        ax.set_xlabel('False Positive Rate')
        ax.set_ylabel('True Positive Rate')
    elif curve == 'prc':
        ax.set_xlabel('Recall')
        ax.set_ylabel('Precision')
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_title(title)
    ax.legend(loc="lower right")
    plt.show()

# One entry per method: (auc_mean, auc_sd, mean_curve_x, mean_curve_y, legend label, color).
bundled_data = [(auc_mean, auc_sd, mean_curve_x, mean_curve_y, "stFormer", "red")]

plot_ROC(bundled_data, 'Gene classification', curve)