In [1]:
import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
from typing import Dict
import time
import warnings
warnings.filterwarnings('ignore')
import copy
import pickle
import itertools
from tqdm import tqdm
import random
import scanpy as sc
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import auc, roc_curve
import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn
from flash_attn.bert_padding import pad_input

from stformer import logger
from stformer.tokenizer import GeneVocab
from stformer.tokenizer import tokenize_and_pad_batch
from stformer.model import TransformerModel

In [2]:
class SlideData():
    """
    Per-slide data container for ligand-receptor classification (LRC) fine-tuning.

    Pipeline (each step stores its results on ``self``):
      1. ``__init__`` / ``load_data``: read the deconvolved Visium slide, keep genes listed in
         ``scfoundation_gene_df.csv``, drop spots that express no ligand, normalise cell-type
         proportions and log-normalise every layer (one layer per cell type).
      2. ``get_lr_pairs``: read the ligand-receptor table and sample the genes used for CV.
      3. ``get_niche_samples``: build one (center cell type + niche) sample per cell type
         present in each spot.
      4. ``tokenize_data`` / ``prepare_data`` / ``prepare_dataloader``: split LR genes into
         train/valid, label LR pairs, tokenize and wrap the splits into DataLoaders.

    Relies on the notebook-level names ``tokenizer_dir`` (unless ``tokenizer_dir`` is passed),
    ``logger``, ``sc``, ``tokenize_and_pad_batch``, ``pad_input`` and ``SeqDataset``.
    """

    # Cell-type proportions below this fraction are treated as absent before renormalising.
    MIN_CELLTYPE_PROPORTION = 0.05

    def __init__(self, data_path, slide, lr_path, vocab, pad_value, pad_token, tokenizer_dir=None):
        self.data_path = data_path
        self.slide = slide
        self.lr_path = lr_path
        self.vocab = vocab
        self.pad_value = pad_value
        self.pad_token = pad_token
        # None -> fall back to the notebook-level `tokenizer_dir` when files are resolved.
        self.tokenizer_dir = tokenizer_dir
        self.load_data()

    def _tokenizer_path(self, filename):
        """Path of ``filename`` inside the tokenizer directory."""
        tok_dir = self.tokenizer_dir if self.tokenizer_dir is not None else tokenizer_dir
        return os.path.join(tok_dir, filename)

    @staticmethod
    def _to_dense(matrix):
        """``matrix`` as a dense ndarray (scipy sparse matrices/arrays or array-likes)."""
        return matrix.toarray() if hasattr(matrix, 'toarray') else np.asarray(matrix)

    @staticmethod
    def _dense_row(matrix, i):
        """Row ``i`` of a 2-D (sparse or dense) matrix as a dense array of shape (1, n_cols)."""
        return SlideData._to_dense(matrix[i:i + 1])

    def load_data(self):
        """
        Read and preprocess the slide.

        Sets ``self.adata``, ``self.celltype_proportion`` (spots x cell types, rows summing to 1),
        ``self.gene_ids`` (vocab id of every kept gene, in ``adata.var`` order) and
        ``self.ligand_ids`` (vocab ids of the ligand genes).
        """
        adata = sc.read_h5ad(os.path.join(self.data_path, f'{self.slide}_Visium_deconv.h5ad'))

        # Keep only genes listed by scFoundation and attach their gene symbols.
        scfoundation_gene_df = pd.read_csv(self._tokenizer_path('scfoundation_gene_df.csv'))
        scfoundation_gene_df = scfoundation_gene_df.set_index('gene_ids')
        total_gene_num = adata.shape[1]
        adata = adata[:, adata.var_names.isin(scfoundation_gene_df.index)].copy()
        adata.var['gene_name'] = scfoundation_gene_df.loc[adata.var_names, 'gene_symbols'].values
        selected_gene_num = adata.shape[1]
        gene_ids = np.array(self.vocab(adata.var["gene_name"].tolist()), dtype=int)

        logger.info(
            f"match {selected_gene_num}/{total_gene_num} genes "
            f"in vocabulary of size 19264."
        )

        # Ligands: database rows whose summed columns exceed 1
        # (presumably genes listed by more than one source).
        ligand_database = pd.read_csv(self._tokenizer_path('ligand_database.csv'), header=0, index_col=0)
        ligand_symbol = ligand_database[ligand_database.sum(1) > 1].index.values
        ligand_ids = self.vocab(ligand_symbol.tolist())

        # Drop spots that express none of the ligands.
        ligand_counts = adata[:, adata.var['gene_name'].isin(ligand_symbol)].X.sum(1)
        adata = adata[np.asarray(ligand_counts).ravel() > 0, :].copy()

        # Column names carry a 23-character prefix (presumably cell2location's
        # 'q05cell_abundance_w_sf_'); strip it so they match the per-cell-type layer names.
        celltype_proportion = adata.obsm['q05_cell_abundance_w_sf'].rename(columns=lambda x: x[23:])
        celltype_proportion = celltype_proportion.div(celltype_proportion.sum(axis=1), axis=0)
        celltype_proportion = celltype_proportion.mask(
            celltype_proportion < self.MIN_CELLTYPE_PROPORTION, 0
        )
        celltype_proportion = celltype_proportion.div(celltype_proportion.sum(axis=1), axis=0)

        # Library-size normalise (target 1e4) and log1p every layer, routed through adata.X
        # because the scanpy calls operate on X.
        for celltype in adata.layers.keys():
            adata.X = adata.layers[celltype]
            sc.pp.normalize_total(adata, target_sum=1e4)
            sc.pp.log1p(adata)
            # Drop scanpy's log1p marker, presumably so the next layer is not treated as
            # already log-transformed.
            adata.uns.pop('log1p')
            adata.layers[celltype] = adata.X
        # NOTE(review): after this loop adata.X holds the *last* layer; tokenize_data derives
        # max_seq_len from it.

        self.adata = adata
        self.celltype_proportion = celltype_proportion
        self.gene_ids = gene_ids
        self.ligand_ids = ligand_ids

    def get_lr_pairs(self, n_pairs=300, seed=0):
        """
        Load known ligand-receptor pairs and sample the genes used for cross-validation.

        Args:
            n_pairs: number of LR pairs to sample (capped at the number available).
            seed: seed for Python's ``random`` module used for the sampling.

        Sets ``self.lr_pairs`` (every known (ligand_id, receptor_id) pair among kept genes),
        ``self.ligand_receptor_ids`` (unique ligand ids of the sampled pairs followed by their
        unique receptor ids) and ``self.ligand_receptor_labels`` (1 = ligand, 0 = receptor).
        """
        lr_df = pd.read_csv(self.lr_path, sep='\t', header=0)
        var_names = self.adata.var_names
        lr_df = lr_df[
            lr_df['ligand_ensembl_gene_id'].isin(var_names)
            & lr_df['receptor_ensembl_gene_id'].isin(var_names)
        ].copy()
        gene_names = self.adata.var['gene_name']
        lr_df['ligand_gene_symbol'] = gene_names.loc[lr_df['ligand_ensembl_gene_id']].values
        lr_df['receptor_gene_symbol'] = gene_names.loc[lr_df['receptor_ensembl_gene_id']].values
        lr_df['ligand_gene_id'] = self.vocab(lr_df['ligand_gene_symbol'].tolist())
        lr_df['receptor_gene_id'] = self.vocab(lr_df['receptor_gene_symbol'].tolist())
        lr_df = lr_df[lr_df['ligand_gene_id'].isin(self.ligand_ids)]
        lr_pairs = list(zip(lr_df['ligand_gene_id'], lr_df['receptor_gene_id']))

        random.seed(seed)
        picked = random.sample(range(lr_df.shape[0]), min(n_pairs, lr_df.shape[0]))
        lr_df = lr_df.iloc[picked]
        ligand_ids = list(set(lr_df['ligand_gene_id']))
        receptor_ids = list(set(lr_df['receptor_gene_id']))

        self.ligand_receptor_ids = np.array(ligand_ids + receptor_ids)
        self.ligand_receptor_labels = np.array([1] * len(ligand_ids) + [0] * len(receptor_ids))
        self.lr_pairs = set(lr_pairs)

    def get_niche_samples(self, sample_num):
        """
        Build center/niche expression samples, one per (spot, cell type present in the spot).

        Each sample stacks the center cell type's expression row on top of the rows of every
        cell type present in the spot (the niche); non-ligand genes are zeroed in the niche
        rows. Only the first ``sample_num`` samples (in spot order) are kept.

        Sets ``self.expression`` (list of 2-D arrays), ``self.ctprop`` (niche cell-type
        proportions per sample) and ``self.celltypes`` (center cell type per sample).
        """
        samples_expression = []
        samples_ctprop = []
        celltypes_labels = []
        non_ligand = ~np.isin(self.gene_ids, self.ligand_ids)
        early_stop = sample_num is not None and sample_num >= 0
        for i in range(self.adata.shape[0]):
            if early_stop and len(samples_expression) >= sample_num:
                break  # every further sample would be cut by the slicing below
            row = self.celltype_proportion.iloc[i]
            ct_prop = row[row > 0]
            cell_rows = [(ct, self._dense_row(self.adata.layers[ct], i)) for ct in ct_prop.index]

            # np.concatenate copies, so zeroing the niche does not touch the center rows.
            niche_counts = np.concatenate([counts for _, counts in cell_rows])
            niche_counts[:, non_ligand] = 0
            niche_ctprop = ct_prop.values

            for ct, counts in cell_rows:
                samples_expression.append(np.concatenate([counts, niche_counts], axis=0))
                samples_ctprop.append(niche_ctprop)
                celltypes_labels.append(ct)

        self.expression = samples_expression[:sample_num]
        self.ctprop = samples_ctprop[:sample_num]
        self.celltypes = celltypes_labels[:sample_num]

    def _pad_information_of_split_token(self, token_num):
        """
        Padding metadata for packing a variable number of tokens per row into a dense grid.

        Args:
            token_num: 1-D integer tensor, number of valid tokens in each row.

        Returns:
            indices: flat positions of the valid slots in a (rows, max_token_num) grid.
            total_cell_num: number of rows.
            max_token_num: largest entry of ``token_num``.
            key_padding_mask: bool tensor (rows, max_token_num), True at padded slots.
        """
        max_token_num = token_num.max().item()
        total_cell_num = token_num.size(0)
        positions = torch.arange(max_token_num, device=token_num.device)
        key_padding_mask = positions.unsqueeze(0) >= token_num.unsqueeze(1)
        indices = (~key_padding_mask.view(-1)).nonzero(as_tuple=True)[0]
        return indices, total_cell_num, max_token_num, key_padding_mask

    def _split_ligand_receptor_ids(self, index):
        """Ligand and receptor vocab-id sets among ``self.ligand_receptor_ids[index]``."""
        ids = self.ligand_receptor_ids[index]
        labels = self.ligand_receptor_labels[index]
        return set(ids[labels == 1]), set(ids[labels == 0])

    def _sample_lr_pairs(self, ligand_ids, receptor_ids, seed):
        """
        Positive (known) and randomly sampled negative pairs among ligand_ids x receptor_ids.

        As many negatives as positives are drawn, capped at the number of unknown combinations.
        """
        all_pairs = set(itertools.product(ligand_ids, receptor_ids))
        pos_lr = list(all_pairs.intersection(self.lr_pairs))
        neg_pool = sorted(all_pairs.difference(self.lr_pairs))
        random.seed(seed)
        neg_lr = random.sample(neg_pool, min(len(pos_lr), len(neg_pool)))
        return pos_lr, neg_lr

    @staticmethod
    def _samples_with_pairs(center_genes, ligand_ids, receptor_ids, labelled_pairs):
        """Indices of samples whose expressed genes contain at least one pair of ``labelled_pairs``."""
        labelled_pairs = set(labelled_pairs)
        return [
            k for k, genes in enumerate(center_genes)
            if any(pair in labelled_pairs
                   for pair in itertools.product(genes & ligand_ids, genes & receptor_ids))
        ]

    def _tokenize(self, sample_index, max_seq_len, max_niche_cell_num):
        """Tokenize and pad the samples at ``sample_index`` (append_cls=False, include_zero_gene=False)."""
        return tokenize_and_pad_batch(
            [self.expression[k] for k in sample_index],
            [self.ctprop[k] for k in sample_index],
            self.gene_ids,
            max_len=max_seq_len,
            max_niche_cell_num=max_niche_cell_num,
            vocab=self.vocab,
            pad_token=self.pad_token,
            pad_value=self.pad_value,
            append_cls=False,
            include_zero_gene=False,
        )

    def _gather_padded_ids(self, gene_ids, mask):
        """Per-row gene ids selected by ``mask``, packed to the left and padded with the <pad> id."""
        indices, n_rows, max_num, padding_mask = self._pad_information_of_split_token(mask.sum(1))
        ids = pad_input(gene_ids[mask].unsqueeze(-1), indices, n_rows, max_num).squeeze(-1)
        ids[padding_mask] = self.vocab[self.pad_token]
        return ids

    def _attach_lr_annotations(self, tokenized, ligand_ids, receptor_ids, pos_lr, neg_lr):
        """
        Add ligand/receptor masks and per-sample LR-pair labels to ``tokenized`` in place.

        ``lr_labels[s]`` enumerates itertools.product(center ligands, center receptors) of
        sample ``s`` (both padded with the <pad> id): 1 = positive pair, 0 = sampled negative,
        -100 = unlabelled combination or padding.
        """
        ligand_tensor = torch.tensor(list(ligand_ids))
        receptor_tensor = torch.tensor(list(receptor_ids))
        tokenized['niche_l'] = torch.isin(tokenized['niche_genes'], ligand_tensor)
        tokenized['center_l'] = torch.isin(tokenized['center_genes'], ligand_tensor)
        tokenized['center_r'] = torch.isin(tokenized['center_genes'], receptor_tensor)

        pad_id = self.vocab[self.pad_token]
        lr2label = dict.fromkeys(
            itertools.product(ligand_ids | {pad_id}, receptor_ids | {pad_id}), -100
        )
        lr2label.update(dict.fromkeys(pos_lr, 1))
        lr2label.update(dict.fromkeys(neg_lr, 0))

        l_ids = self._gather_padded_ids(tokenized['center_genes'], tokenized['center_l'])
        r_ids = self._gather_padded_ids(tokenized['center_genes'], tokenized['center_r'])
        tokenized['lr_labels'] = torch.tensor([
            [lr2label[pair] for pair in itertools.product(l_row, r_row)]
            for l_row, r_row in zip(l_ids.tolist(), r_ids.tolist())
        ])

    @staticmethod
    def _log_split(name, tokenized):
        """Log sample count, feature lengths and label balance of one tokenized split."""
        lr_labels = tokenized['lr_labels']
        logger.info(
            f"{name} set number of samples: {tokenized['center_genes'].shape[0]}, "
            f"\n\t feature length of center cell: {tokenized['center_genes'].shape[1]}"
            f"\n\t feature length of niche cells: {tokenized['niche_genes'].shape[1]}"
            f"\n\t feature length of lr pairs: {lr_labels.shape[1]}"
            f"\n\t number of pos/neg lr pairs: {(lr_labels == 1).sum().item()} / {(lr_labels == 0).sum().item()}"
        )

    def tokenize_data(self, train_index, valid_index):
        """
        Split LR genes into train/valid, label LR pairs and tokenize the matching samples.

        Args:
            train_index, valid_index: positions into ``self.ligand_receptor_ids``
                (e.g. one StratifiedKFold split).

        Positives are known LR pairs whose ligand and receptor both fall in the split; an equal
        number of negatives (capped at what is available) is sampled from the remaining
        ligand x receptor combinations. A sample is kept for a split when its center cell
        (row 0 of the stacked expression) expresses at least one labelled pair of that split.

        Sets ``self.tokenized_train`` / ``self.tokenized_valid``, ``self.max_seq_len`` and
        ``self.max_niche_cell_num``.
        """
        ligand_ids_train, receptor_ids_train = self._split_ligand_receptor_ids(train_index)
        ligand_ids_valid, receptor_ids_valid = self._split_ligand_receptor_ids(valid_index)

        pos_lr_train, neg_lr_train = self._sample_lr_pairs(ligand_ids_train, receptor_ids_train, seed=1)
        logger.info(f"number of pos/neg lr pairs in train set: {len(pos_lr_train)} / {len(neg_lr_train)}")
        pos_lr_valid, neg_lr_valid = self._sample_lr_pairs(ligand_ids_valid, receptor_ids_valid, seed=2)
        logger.info(f"number of pos/neg lr pairs in valid set: {len(pos_lr_valid)} / {len(neg_lr_valid)}")

        # Genes expressed by each sample's center cell, computed once for both splits.
        center_genes = [set(self.gene_ids[np.nonzero(d[0])[0]]) for d in self.expression]
        train_samples = self._samples_with_pairs(
            center_genes, ligand_ids_train, receptor_ids_train, pos_lr_train + neg_lr_train
        )
        valid_samples = self._samples_with_pairs(
            center_genes, ligand_ids_valid, receptor_ids_valid, pos_lr_valid + neg_lr_valid
        )

        # +2 leaves room for special tokens (kept from the original sizing).
        max_seq_len = np.max(np.count_nonzero(self._to_dense(self.adata.X), axis=1)) + 2
        max_niche_cell_num = (self.celltype_proportion > 0).sum(1).max()
        self.max_seq_len = max_seq_len
        self.max_niche_cell_num = max_niche_cell_num

        tokenized_train = self._tokenize(train_samples, max_seq_len, max_niche_cell_num)
        tokenized_valid = self._tokenize(valid_samples, max_seq_len, max_niche_cell_num)

        self._attach_lr_annotations(
            tokenized_train, ligand_ids_train, receptor_ids_train, pos_lr_train, neg_lr_train
        )
        self._attach_lr_annotations(
            tokenized_valid, ligand_ids_valid, receptor_ids_valid, pos_lr_valid, neg_lr_valid
        )

        self._log_split('train', tokenized_train)
        self._log_split('valid', tokenized_valid)

        self.tokenized_train = tokenized_train
        self.tokenized_valid = tokenized_valid

    @staticmethod
    def _model_inputs(tokenized):
        """Map tokenizer output keys to the batch keys read by the training/evaluation loops."""
        return {
            "center_gene_ids": tokenized["center_genes"],
            "input_center_values": tokenized["center_values"],
            "target_center_values": tokenized["center_values"],
            "niche_gene_ids": tokenized["niche_genes"],
            "input_niche_values": tokenized["niche_values"],
            "niche_feature_lens": tokenized["niche_feature_lens"],
            "cross_attn_bias": tokenized["cross_attn_bias"],
            "niche_l": tokenized["niche_l"],
            "center_l": tokenized["center_l"],
            "center_r": tokenized["center_r"],
            "lr_labels": tokenized["lr_labels"],
        }

    def prepare_data(self):
        """Build ``self.train_data_pt`` / ``self.valid_data_pt`` from the tokenized splits."""
        self.train_data_pt = self._model_inputs(self.tokenized_train)
        self.valid_data_pt = self._model_inputs(self.tokenized_valid)

    @staticmethod
    def _make_loader(data_pt, batch_size):
        """Unshuffled DataLoader over ``data_pt`` with pinned memory."""
        if hasattr(os, 'sched_getaffinity'):
            n_cpus = len(os.sched_getaffinity(0))
        else:  # os.sched_getaffinity is not available on every platform (e.g. macOS, Windows)
            n_cpus = os.cpu_count() or 1
        return DataLoader(
            dataset=SeqDataset(data_pt),
            batch_size=batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=min(n_cpus, batch_size // 2),
            pin_memory=True,
        )

    def prepare_dataloader(self, batch_size):
        """Return ``(train_loader, valid_loader)`` for the prepared splits."""
        return (
            self._make_loader(self.train_data_pt, batch_size),
            self._make_loader(self.valid_data_pt, batch_size),
        )

class SeqDataset(Dataset):
    """Map-style dataset over a dict of tensors; item ``i`` is ``{key: tensor[i]}``."""

    def __init__(self, data: Dict[str, torch.Tensor]):
        # Every tensor is indexed along dim 0; "center_gene_ids" defines the length.
        self.data = data

    def __len__(self):
        n_samples = self.data["center_gene_ids"].shape[0]
        return n_samples

    def __getitem__(self, idx):
        sample = {}
        for key, tensor in self.data.items():
            sample[key] = tensor[idx]
        return sample

In [3]:
# Cross-entropy over the LR-pair classes (0 = negative pair, 1 = positive pair).
criterion_cls = nn.CrossEntropyLoss(reduction='mean')

def train(model: nn.Module, loader: DataLoader, valid_loader, max_batch) -> tuple:
    """
    Fine-tune the model's LRC head for (at most) one pass over ``loader``.

    Batches ``0..max_batch`` (inclusive) are trained on. Every ``10 * log_interval`` batches the
    model is evaluated on ``valid_loader`` and deep-copied whenever the validation ROC AUC
    improves. If the last trained batch was not such a checkpoint, one final evaluation is run,
    so a short loader or a small ``max_batch`` still yields a best model and ROC curve.

    Reads notebook globals ``slideData``, ``device``, ``vocab``, ``pad_token``, ``split``,
    ``criterion_cls`` and ``evaluate``.

    Returns:
        (best_model, best_fpr, best_tpr); ``(None, 0, 0)`` when no batch was trained or no
        evaluation produced an AUC above 0.
    """
    lr = 1e-4
    amp = True
    schedule_ratio = 0.9
    schedule_interval = 1
    log_interval = 10
    eval_interval = 10 * log_interval

    optimizer = torch.optim.Adam(
        model.parameters(), lr=lr, eps=1e-4 if amp else 1e-8
    )
    # NOTE(review): the scheduler is never stepped, so the learning rate stays at `lr`;
    # it is only queried for logging.
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, schedule_interval, gamma=schedule_ratio
    )
    scaler = torch.cuda.amp.GradScaler(enabled=amp)

    # Mutable holder so the nested checkpoint helper can update it.
    best = {"model": None, "auc": 0, "fpr": 0, "tpr": 0}

    def _evaluate_and_checkpoint():
        """Validate the current model and keep a deep copy if its AUC is the best so far."""
        auc_value, fpr, tpr = evaluate(model, valid_loader)
        if auc_value > best["auc"]:
            best.update(model=copy.deepcopy(model), auc=auc_value, fpr=fpr, tpr=tpr)
        model.train()  # evaluate() puts the model in eval mode

    model.train()
    total_lrc = 0.0
    total_error = 0.0
    window_batches = 0  # batches accumulated since the last progress log

    # Longest per-sample ligand / receptor lists over the whole training split.
    max_l_seqlen = slideData.tokenized_train['center_l'].sum(1).max().item()
    max_r_seqlen = slideData.tokenized_train['center_r'].sum(1).max().item()

    start_time = time.time()
    num_batches = len(loader)
    last_batch = -1
    last_eval_batch = -1

    for batch, batch_data in enumerate(loader):
        if batch > max_batch:
            break
        center_gene_ids = batch_data["center_gene_ids"].to(device)
        input_center_values = batch_data["input_center_values"].to(device)
        niche_gene_ids = batch_data["niche_gene_ids"].to(device)
        input_niche_values = batch_data["input_niche_values"].to(device)
        cross_attn_bias = batch_data["cross_attn_bias"].to(device)
        niche_l = batch_data["niche_l"].to(device)
        center_l = batch_data["center_l"].to(device)
        center_r = batch_data["center_r"].to(device)
        lr_labels = batch_data["lr_labels"].to(device)

        # True at <pad> token positions.
        encoder_src_key_padding_mask = niche_gene_ids.eq(vocab[pad_token])
        decoder_src_key_padding_mask = center_gene_ids.eq(vocab[pad_token])

        with torch.cuda.amp.autocast(enabled=amp):
            output_dict = model(
                niche_gene_ids,
                input_niche_values,
                encoder_src_key_padding_mask,
                center_gene_ids,
                input_center_values,
                decoder_src_key_padding_mask,
                cross_attn_bias,
                niche_l=niche_l,
                center_l=center_l,
                center_r=center_r,
                max_l_seqlen=max_l_seqlen,
                max_r_seqlen=max_r_seqlen,
                LRC=True,
            )
            # Only pairs labelled 1 (positive) or 0 (negative) are trained on; -100 marks
            # unlabelled combinations and padding.
            labelled = torch.logical_or(lr_labels == 1, lr_labels == 0)
            batch_logits = output_dict["lrc_output"][labelled]
            batch_labels = lr_labels[labelled]
            loss_lrc = criterion_cls(batch_logits, batch_labels)

            error_rate_lrc = 1 - (
                (batch_logits.argmax(1) == batch_labels).sum().item()
            ) / batch_labels.size(0)

        model.zero_grad()
        scaler.scale(loss_lrc).backward()
        scaler.unscale_(optimizer)  # clip the true (unscaled) gradients
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings("always")
            torch.nn.utils.clip_grad_norm_(
                model.parameters(),
                1.0,
                error_if_nonfinite=False if scaler.is_enabled() else True,
            )
            if len(w) > 0:
                logger.warning(
                    f"Found infinite gradient. This may be caused by the gradient "
                    f"scaler. The current scale is {scaler.get_scale()}. This warning "
                    "can be ignored if no longer occurs after autoscaling of the scaler."
                )
        scaler.step(optimizer)
        scaler.update()

        total_lrc += loss_lrc.item()
        total_error += error_rate_lrc
        window_batches += 1
        last_batch = batch

        if batch % log_interval == 0 and batch > 0:
            lr = scheduler.get_last_lr()[0]
            # Average over the batches actually accumulated (the first window has one extra).
            sec_per_batch = (time.time() - start_time) / window_batches
            cur_lrc = total_lrc / window_batches
            cur_error = total_error / window_batches
            logger.info(
                f"| Split {split} | {batch:3d}/{num_batches:3d} batches | "
                f"lr {lr:05.8f} | sec/batch {sec_per_batch:5.1f} | "
                f"gcl {cur_lrc:5.5f} | "
                f"err {cur_error:1.5f} | "
            )
            total_lrc = 0
            total_error = 0
            window_batches = 0
            start_time = time.time()
        if batch % eval_interval == 0 and batch > 0:
            _evaluate_and_checkpoint()
            last_eval_batch = batch
            start_time = time.time()

    # Also consider the final weights when training stopped between checkpoints; otherwise a
    # short run would return (None, 0, 0) and break the ROC interpolation downstream.
    if last_batch >= 0 and last_eval_batch != last_batch:
        _evaluate_and_checkpoint()

    return best["model"], best["fpr"], best["tpr"]

def py_softmax(vector, axis=None):
    """
    Numerically stable softmax.

    Args:
        vector: array-like of scores.
        axis: axis to normalise over. ``None`` (the default) normalises over all elements,
            which for 1-D input is the usual softmax.

    Returns:
        float ndarray with the input's shape; entries are non-negative and sum to 1 along ``axis``.
    """
    scores = np.asarray(vector, dtype=float)
    # Subtracting the max leaves the result unchanged but keeps exp() from overflowing.
    e = np.exp(scores - np.max(scores, axis=axis, keepdims=True))
    return e / np.sum(e, axis=axis, keepdims=True)

def evaluate(model: nn.Module, loader: DataLoader) -> tuple:
    """
    Evaluate the LRC head on ``loader``.

    Only LR pairs labelled 1 (positive) or 0 (negative) are scored. Logs accuracy, ROC AUC,
    mean loss and error rate over all labelled pairs.

    Reads notebook globals ``slideData``, ``device``, ``vocab``, ``pad_token`` and ``criterion_cls``.

    Returns:
        (auc_value, fpr, tpr): ROC AUC and ROC curve (``sklearn.metrics.roc_curve``) of the
        positive-class probability.
    """
    amp = True

    model.eval()
    total_lrc = 0.0
    total_error = 0.0
    total_num = 0

    # Longest per-sample ligand / receptor lists over the whole validation split.
    max_l_seqlen = slideData.tokenized_valid['center_l'].sum(1).max().item()
    max_r_seqlen = slideData.tokenized_valid['center_r'].sum(1).max().item()

    logits = []
    labels = []

    with torch.no_grad():
        for batch_data in tqdm(loader):
            center_gene_ids = batch_data["center_gene_ids"].to(device)
            input_center_values = batch_data["input_center_values"].to(device)
            niche_gene_ids = batch_data["niche_gene_ids"].to(device)
            input_niche_values = batch_data["input_niche_values"].to(device)
            cross_attn_bias = batch_data["cross_attn_bias"].to(device)
            niche_l = batch_data["niche_l"].to(device)
            center_l = batch_data["center_l"].to(device)
            center_r = batch_data["center_r"].to(device)
            lr_labels = batch_data["lr_labels"].to(device)

            # True at <pad> token positions.
            encoder_src_key_padding_mask = niche_gene_ids.eq(vocab[pad_token])
            decoder_src_key_padding_mask = center_gene_ids.eq(vocab[pad_token])

            with torch.cuda.amp.autocast(enabled=amp):
                output_dict = model(
                    niche_gene_ids,
                    input_niche_values,
                    encoder_src_key_padding_mask,
                    center_gene_ids,
                    input_center_values,
                    decoder_src_key_padding_mask,
                    cross_attn_bias,
                    niche_l=niche_l,
                    center_l=center_l,
                    center_r=center_r,
                    max_l_seqlen=max_l_seqlen,
                    max_r_seqlen=max_r_seqlen,
                    LRC=True,
                )

            # -100 marks unlabelled combinations and padding; skip them.
            labelled = torch.logical_or(lr_labels == 1, lr_labels == 0)
            batch_logits = output_dict["lrc_output"][labelled]
            batch_labels = lr_labels[labelled]
            logits.append(batch_logits.to('cpu'))
            labels.append(batch_labels.to('cpu'))

            n_labelled = batch_labels.size(0)
            n_correct = (batch_logits.argmax(1) == batch_labels).sum().item()
            total_error += n_labelled - n_correct
            total_num += n_labelled
            total_lrc += criterion_cls(batch_logits, batch_labels).item() * n_labelled

    logits = torch.cat(logits)
    labels = torch.cat(labels)

    # Softmax in float32: under autocast the logits may be float16, where a plain exp()
    # overflows above ~11 and would turn scores into NaN.
    y_score = torch.softmax(logits.float(), dim=1)[:, 1].numpy()
    y_true = labels.numpy()
    fpr, tpr, _ = roc_curve(y_true, y_score)
    auc_value = auc(fpr, tpr)

    val_err = total_error / total_num
    val_loss = total_lrc / total_num
    logger.info("-" * 89)
    logger.info(
        f"valid accuracy: {1-val_err:1.4f} | "
        f"valid auc: {auc_value:1.4f} | "
        f"valid loss {val_loss:1.4f} | "
        f"valid err {val_err:1.4f}"
    )
    logger.info("-" * 89)

    return auc_value, fpr, tpr

def train_and_evaluate(model, train_loader, valid_loader, max_batch):
    """
    Train one CV fold with ``train`` and hand back its best checkpoint.

    Returns:
        (best_model, best_fpr, best_tpr) exactly as produced by ``train``.
    """
    fold_result = train(model, train_loader, valid_loader, max_batch)
    trained_model, fpr, tpr = fold_result  # enforces the 3-tuple contract
    return trained_model, fpr, tpr

In [4]:
from scfoundation import load

def initialize_model(model_file, scfoundation_ckpt='scfoundation/models/models.ckpt'):
    """
    Build a TransformerModel with an LRC head, load pretrained stFormer weights and freeze
    everything except ``model.lrc_decoder``.

    Args:
        model_file: pretrained stFormer checkpoint. Either a pickled ``nn.Module`` (optionally
            wrapped in ``nn.DataParallel``) or a plain state dict.
        scfoundation_ckpt: scFoundation checkpoint whose token and positional embeddings are
            copied into both embedding slots of the new model.

    Reads notebook globals ``embsize``, ``nhead``, ``d_hid``, ``nlayers``, ``dropout`` and
    ``cell_emb_style``.

    Returns:
        The initialized model on CPU, with only the LRC decoder trainable.
    """
    pretrainmodel, _ = load.load_model_frommmf(scfoundation_ckpt)

    model = TransformerModel(
        embsize,
        nhead,
        d_hid,
        nlayers,
        do_lrc = True,
        nlayers_lrc = 3,
        n_lrc = 2,
        dropout = dropout,
        cell_emb_style = cell_emb_style,
        scfoundation_token_emb1 = copy.deepcopy(pretrainmodel.token_emb),
        scfoundation_token_emb2 = copy.deepcopy(pretrainmodel.token_emb),
        scfoundation_pos_emb1 = copy.deepcopy(pretrainmodel.pos_emb),
        scfoundation_pos_emb2 = copy.deepcopy(pretrainmodel.pos_emb),
    )

    # SECURITY: unpickling a full module can execute arbitrary code; only load trusted files.
    try:
        # torch >= 2.6 defaults to weights_only=True, which cannot restore a pickled nn.Module.
        checkpoint = torch.load(model_file, map_location='cpu', weights_only=False)
    except TypeError:
        # Older torch releases do not accept the weights_only argument.
        checkpoint = torch.load(model_file, map_location='cpu')
    if isinstance(checkpoint, nn.DataParallel):
        checkpoint = checkpoint.module
    pretrained_dict = checkpoint.state_dict() if isinstance(checkpoint, nn.Module) else checkpoint

    # Keep the freshly constructed task heads: drop any LRC / GCL decoder weights.
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items()
        if 'lrc_decoder' not in k and 'gcl_decoder' not in k
    }
    model_dict = model.state_dict()
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)

    def _trainable_param_count():
        # Keyed by data_ptr so parameters shared between modules are counted once.
        return sum({p.data_ptr(): p.numel() for p in model.parameters() if p.requires_grad}.values())

    pre_freeze_param_count = _trainable_param_count()
    model.requires_grad_(False)
    model.lrc_decoder.requires_grad_(True)
    post_freeze_param_count = _trainable_param_count()

    logger.info(f"Total Pre freeze Params {(pre_freeze_param_count )}")
    logger.info(f"Total Post freeze Params {(post_freeze_param_count )}")

    return model

In [None]:
# --- Model architecture (alternative values from the original comments: 256 / 1024 / 4 / 12) ---
embsize, d_hid, nhead, nlayers = 768, 3072, 12, 6
dropout = 0.1
cell_emb_style = 'max-pool'
LRC = True  # NOTE(review): this variable is not read anywhere in the visible notebook

# --- Vocabulary: scFoundation genes plus an appended <pad> token ---
pad_token, pad_value = "<pad>", 103
tokenizer_dir = '../stformer/tokenizer/'
vocab_file = os.path.join(tokenizer_dir, "scfoundation_gene_vocab.json")
vocab = GeneVocab.from_file(vocab_file)
vocab.append_token(pad_token)
# Presumably maps unknown gene symbols to the <pad> id (torchtext-style default index).
vocab.set_default_index(vocab[pad_token])

In [None]:
model_file = '../pretraining/model.ckpt'

# Slide to cross-validate on.
dataset = 'human_myocardial_infarction_dataset'
slide = '10X001'
data_path = f'../data/{dataset}/'
lr_path = 'gene_lists/human_LR_pairs.txt'

# Load the slide, sample LR genes and build center/niche samples once; only the
# train/valid tokenization changes between folds.
slideData = SlideData(data_path, slide, lr_path, vocab, pad_value, pad_token)
slideData.get_lr_pairs()
sample_num = 5000
slideData.get_niche_samples(sample_num)

batch_size, max_batch = 6, 500
n_splits = 5
skf = StratifiedKFold(n_splits=n_splits, random_state=0, shuffle=True)

all_tpr, all_roc_auc, all_tpr_wt = [], [], []
mean_fpr = np.linspace(0, 1, 100)  # shared FPR grid every fold's ROC curve is interpolated onto
device = torch.device("cuda:0")  # primary DataParallel device; also read by train()/evaluate()

# `split` stays a notebook global: train() includes it in its progress log.
for split, (train_index, valid_index) in enumerate(
    skf.split(slideData.ligand_receptor_ids, slideData.ligand_receptor_labels), start=1
):
    logger.info(f"Cross-validate on dataset {dataset} slide {slide} - split {split}")
    slideData.tokenize_data(train_index, valid_index)
    slideData.prepare_data()
    train_loader, valid_loader = slideData.prepare_dataloader(batch_size)

    model = nn.DataParallel(initialize_model(model_file), device_ids = [0, 3, 1])
    model.to(device)

    best_model, best_fpr, best_tpr = train_and_evaluate(model, train_loader, valid_loader, max_batch)

    interp_tpr = np.interp(mean_fpr, best_fpr, best_tpr)
    interp_tpr[0] = 0.0  # anchor every curve at the origin
    all_tpr.append(interp_tpr)
    all_roc_auc.append(auc(best_fpr, best_tpr))
    all_tpr_wt.append(len(best_tpr))  # folds are weighted by their number of ROC points

In [None]:
import math 

def get_cross_valid_metrics(all_tpr, all_roc_auc, all_tpr_wt, mean_fpr=None):
    """
    Aggregate per-fold ROC results into weighted mean / median curves and AUC statistics.

    Args:
        all_tpr: per-fold TPR arrays, all interpolated onto the same FPR grid.
        all_roc_auc: per-fold ROC AUC values.
        all_tpr_wt: per-fold weights (normalised here to sum to 1; the normalised weights are printed).
        mean_fpr: the FPR grid the TPR arrays live on. Defaults to ``np.linspace(0, 1, n)`` with
            ``n`` the TPR length, i.e. the grid used by the CV loop.

    Returns:
        (mean_tpr, median_tpr, roc_auc_mean, roc_auc_sd, roc_auc_median, wts):
        ``roc_auc_mean`` / ``roc_auc_sd`` are the weighted mean and standard deviation of the
        fold AUCs, and ``roc_auc_median`` is the AUC of the median curve.
    """
    total_wt = sum(all_tpr_wt)
    wts = [count / total_wt for count in all_tpr_wt]
    print(wts)

    mean_tpr = np.sum([tpr * wt for tpr, wt in zip(all_tpr, wts)], axis=0)
    mean_tpr[-1] = 1.0  # every ROC curve ends at (1, 1)
    median_tpr = np.median(all_tpr, axis=0)
    median_tpr[-1] = 1.0

    aucs = np.asarray(all_roc_auc, dtype=float)
    roc_auc_mean = np.sum(aucs * np.asarray(wts))
    roc_auc_sd = math.sqrt(np.average((aucs - roc_auc_mean) ** 2, weights=wts))

    # Use an explicit grid instead of silently reading a notebook-level `mean_fpr`.
    if mean_fpr is None:
        mean_fpr = np.linspace(0, 1, len(median_tpr))
    roc_auc_median = auc(mean_fpr, median_tpr)

    return mean_tpr, median_tpr, roc_auc_mean, roc_auc_sd, roc_auc_median, wts

mean_tpr, median_tpr, roc_auc_mean, roc_auc_sd, roc_auc_median, wts = get_cross_valid_metrics(
    all_tpr, all_roc_auc, all_tpr_wt
)

print(f"Mean ROC AUC: {roc_auc_mean} +/- {roc_auc_sd}")
cv_results = {
    'roc_auc_mean': roc_auc_mean,
    'roc_auc_sd': roc_auc_sd,
    'roc_auc_median': roc_auc_median,
    'mean_fpr': mean_fpr,
    'mean_tpr': mean_tpr,
    'median_tpr': median_tpr,
    'all_roc_auc': all_roc_auc,
    'wts': wts,
}
# Create the output directory if needed and close the file handle deterministically.
os.makedirs('roc_results', exist_ok=True)
with open(f'roc_results/stformer_lrc_{slide}_random2.pkl', 'wb') as results_file:
    pickle.dump(cv_results, results_file)

In [None]:
import matplotlib.pyplot as plt

def plot_ROC(bundled_data, title):
    """
    Plot one or more ROC curves together with the chance diagonal.

    Args:
        bundled_data: iterable of ``(roc_auc, roc_auc_sd, fpr, tpr, label, color)`` tuples,
            one per curve.
        title: figure title.
    """
    fig, ax = plt.subplots()
    lw = 2
    for roc_auc, roc_auc_sd, fpr, tpr, sample, color in bundled_data:
        # Raw string: "\p" is not a valid Python escape sequence.
        ax.plot(fpr, tpr, color=color, lw=lw,
                label=r"{0} (AUC {1:0.2f} $\pm$ {2:0.2f})".format(sample, roc_auc, roc_auc_sd))
    ax.plot([0, 1], [0, 1], color='black', lw=lw, linestyle='--')
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title(title)
    ax.legend(loc="lower right")
    plt.show()

# One (auc, auc_sd, fpr, tpr, label, color) entry per curve to draw.
bundled_data = [
    (roc_auc_mean, roc_auc_sd, mean_fpr, mean_tpr, "stFormer", "red"),
]
plot_ROC(bundled_data, 'Gene classification')