In [1]:
import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
import warnings
warnings.filterwarnings('ignore')
import pickle
import time
import itertools
import copy
from tqdm import tqdm
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import auc, roc_curve
import anndata
import scanpy as sc
import random 
import torch
from torch import nn, Tensor
from flash_attn.bert_padding import pad_input
from scfoundation import load

In [2]:
class SlideData():
    def __init__(self, data_path, slide, lr_path, pad_value, pad_token):
        self.data_path = data_path
        self.slide = slide
        self.lr_path = lr_path
        self.pad_value = pad_value
        self.pad_token = pad_token
        self.load_data()

    def load_data(self):
        """Load the deconvolved Visium slide and pre-process it.

        Steps:
          1. read ``{data_path}/{slide}_Visium_deconv.h5ad``;
          2. keep only genes listed in ``scfoundation_gene_df.csv`` (looked up
             under the module-level ``tokenizer_dir``) and attach their symbols
             as ``var['gene_name']``;
          3. total-count normalise (target 1e4) and log1p-transform every layer
             (``get_sc_data`` later indexes the layers by cell-type name);
          4. turn ``obsm['q05_cell_abundance_w_sf']`` into per-spot proportions,
             zeroing entries below 0.05 and re-normalising.

        Sets ``self.adata`` and ``self.celltype_proportion``.
        """
        adata = sc.read_h5ad(f'{self.data_path}/{self.slide}_Visium_deconv.h5ad')

        scfoundation_gene_df = pd.read_csv(f'{tokenizer_dir}/scfoundation_gene_df.csv')
        scfoundation_gene_df = scfoundation_gene_df.set_index('gene_ids')
        total_gene_num = adata.shape[1]
        # Materialise the subset explicitly: writing to `.var` of a view makes
        # anndata copy implicitly (with a warning); `.copy()` states the intent.
        adata = adata[:, adata.var_names.isin(scfoundation_gene_df.index)].copy()
        adata.var['gene_name'] = scfoundation_gene_df.loc[adata.var_names, 'gene_symbols'].values
        seleted_gene_num = adata.shape[1]

        print(
            f"match {seleted_gene_num}/{total_gene_num} genes "
            f"in vocabulary of size 19264."
        )

        # Normalise each layer by routing it through `adata.X`, because the
        # scanpy preprocessing calls below operate on X.
        for celltype in list(adata.layers.keys()):
            adata.X = adata.layers[celltype]
            sc.pp.normalize_total(adata, target_sum=1e4)
            sc.pp.log1p(adata)
            # Drop the marker log1p leaves behind so the next layer is not
            # treated as already log-transformed.
            adata.uns.pop('log1p')
            adata.layers[celltype] = adata.X

        # Strip the first 23 characters of every column name without mutating
        # the table stored in `adata.obsm`.
        # NOTE(review): presumably this removes a fixed "..._w_sf_" prefix so the
        # columns match the layer names - confirm against the h5ad file.
        celltype_proportion = adata.obsm['q05_cell_abundance_w_sf'].rename(columns=lambda x: x[23:])
        celltype_proportion = celltype_proportion.div(celltype_proportion.sum(axis=1), axis=0)
        # Discard minor cell types (< 5 % of the spot) and renormalise the rest.
        celltype_proportion[celltype_proportion < 0.05] = 0
        celltype_proportion = celltype_proportion.div(celltype_proportion.sum(axis=1), axis=0)

        self.adata = adata
        self.celltype_proportion = celltype_proportion

    def get_sc_data(self):
        barcode_list = []
        gexpr_feature = []
        celltypes_labels = []
        for i in range(self.adata.shape[0]):
            barcode = self.adata.obs.index[i]
            ct_prop = self.celltype_proportion.iloc[i][self.celltype_proportion.iloc[i]>0]
            cell_num = 0
            for ct in ct_prop.index:
                celltypes_labels.append(ct)
                cell_num += 1
                barcode_list.append(f'{barcode}_{cell_num}')
                gexpr_feature.append(self.adata.layers[ct][i].A)
        gexpr_feature = np.concatenate(gexpr_feature)

        adata_sc = anndata.AnnData(X=gexpr_feature, obs=pd.DataFrame({'celltype': celltypes_labels}, index=barcode_list), var=pd.DataFrame({'gene_name': self.adata.var['gene_name'].values}, index=self.adata.var_names.values), obsm=None)
        self.adata = adata_sc

    def get_lr_pairs(self):
        gene_list_df = pd.read_csv(f'{tokenizer_dir}/OS_scRNA_gene_index.19264.tsv', header=0, delimiter='\t')
        gene_list_df.set_index('gene_name', inplace=True)

        ligand_database = pd.read_csv(tokenizer_dir+'ligand_database.csv', header=0, index_col=0)
        ligand_symbol = ligand_database[ligand_database.sum(1)>1].index.values
        ligand_ids = gene_list_df.loc[ligand_symbol.tolist(),'index'].values

        lr_df = pd.read_csv(self.lr_path, sep='\t', header=0)
        lr_df = lr_df[lr_df['ligand_ensembl_gene_id'].isin(self.adata.var_names) & lr_df['receptor_ensembl_gene_id'].isin(self.adata.var_names)]
        lr_df['ligand_gene_symbol'] = self.adata.var.loc[lr_df['ligand_ensembl_gene_id'], 'gene_name'].values
        lr_df['receptor_gene_symbol'] = self.adata.var.loc[lr_df['receptor_ensembl_gene_id'], 'gene_name'].values
        lr_df['ligand_gene_id'] = gene_list_df.loc[lr_df['ligand_gene_symbol'].tolist(),'index'].values
        lr_df['receptor_gene_id'] = gene_list_df.loc[lr_df['receptor_gene_symbol'].tolist(),'index'].values
        lr_df = lr_df[lr_df['ligand_gene_id'].isin(ligand_ids)]
        lr_pairs = list(zip(lr_df['ligand_gene_id'], lr_df['receptor_gene_id']))
        
        random.seed(0)
        lr_df = lr_df.iloc[random.sample(range(lr_df.shape[0]), 150)]
        ligand_ids = list(set(lr_df['ligand_gene_id']))
        receptor_ids = list(set(lr_df['receptor_gene_id']))
        ligand_receptor_ids = ligand_ids + receptor_ids
        ligand_receptor_labels = [1]*len(ligand_ids) + [0]*len(receptor_ids)
        
        self.ligand_receptor_ids = np.array(ligand_receptor_ids)
        self.ligand_receptor_labels = np.array(ligand_receptor_labels)
        self.lr_pairs = set(lr_pairs)

    def prepare_data(self, sample_num):
        gene_list_df = pd.read_csv(f'{tokenizer_dir}/OS_scRNA_gene_index.19264.tsv', header=0, delimiter='\t')
        gene_list = list(gene_list_df['gene_name'])
        self.gene_ids = gene_list_df['index'].values

        gexpr_feature = self.adata.X
        idx = self.adata.obs_names.tolist()
        col = self.adata.var.gene_name.tolist()
        gexpr_feature = pd.DataFrame(gexpr_feature, index=idx, columns=col)
        gexpr_feature, _ = load.main_gene_selection(gexpr_feature, gene_list)
        S = gexpr_feature.sum(1)
        T = S
        TS = np.concatenate([[np.log10(T)],[np.log10(S)]],axis=0).T
        data = np.concatenate([gexpr_feature,TS],axis=1)
        self.data = data[:sample_num]
        # self.data = self.data[np.where(self.adata.obs['celltype'].values=='myofibroblast cell')[0]]

    def _pad_information_of_split_token(self, token_num):
        max_token_num = token_num.max().item()
        total_cell_num = token_num.size(0)
        key_padding_mask = torch.zeros((total_cell_num, max_token_num), dtype=torch.bool)
        for i,val in enumerate(token_num):
            key_padding_mask[i, val:] = True
        indices = (~key_padding_mask.view(-1)).nonzero(as_tuple=True)[0]
        return indices, total_cell_num, max_token_num, key_padding_mask

    def prepare_train_and_valid_data(self, train_index, valid_index):
        """Build the train and validation tensors for one cross-validation fold.

        Requires ``get_lr_pairs`` and ``prepare_data`` to have run (reads
        ``self.ligand_receptor_ids``, ``self.ligand_receptor_labels``,
        ``self.lr_pairs``, ``self.gene_ids`` and ``self.data``).

        For each split independently:
          * the fold's genes are divided into ligands (label 1) and receptors
            (label 0); every ligand x receptor combination found in
            ``self.lr_pairs`` is a positive pair, and an equal number of the
            remaining combinations is sampled as negatives (seed 1 for train,
            seed 2 for valid);
          * only cells with non-zero expression for both genes of at least one
            labelled pair are kept;
          * non-zero entries of each kept cell are gathered into padded
            tensors via ``load.gatherData``;
          * every (ligand, receptor) combination expressed in a cell receives
            label 1, 0 or -100 (unlabelled / padding).

        Args:
            train_index: positions into ``self.ligand_receptor_ids`` for training.
            valid_index: positions into ``self.ligand_receptor_ids`` for validation.

        Returns:
            ``(train_data, valid_data)``; each is a dict with keys ``values``,
            ``padding``, ``gene_ids``, ``ligand``, ``receptor`` and ``lr_labels``.

        NOTE(review): the train and valid halves are copy-pasted; a shared
        helper parameterised by the seed would remove the duplication.
        """
        ligand_receptor_ids_train = self.ligand_receptor_ids[train_index]
        ligand_receptor_labels_train = self.ligand_receptor_labels[train_index]
        ligand_receptor_ids_valid = self.ligand_receptor_ids[valid_index]
        ligand_receptor_labels_valid = self.ligand_receptor_labels[valid_index]

        # --- train split: candidate pairs, positives and sampled negatives ---
        ligand_ids_train = set(ligand_receptor_ids_train[np.where(ligand_receptor_labels_train==1)[0]])
        receptor_ids_train = set(ligand_receptor_ids_train[np.where(ligand_receptor_labels_train==0)[0]])
        all_pairs_train = set(itertools.product(ligand_ids_train, receptor_ids_train))
        pos_lr_train = list(all_pairs_train.intersection(self.lr_pairs))
        random.seed(1)
        # sorted() gives the sampler a deterministic population order.
        # random.sample raises ValueError if there are fewer negatives than positives.
        neg_lr_train = random.sample(sorted(all_pairs_train.difference(self.lr_pairs)), len(pos_lr_train))
        lr_train = pos_lr_train + neg_lr_train
        
        print(f"number of pos/neg lr pairs in train set: {len(pos_lr_train)} / {len(neg_lr_train)}")
        
        # --- valid split: same procedure with a different seed ---
        ligand_ids_valid = set(ligand_receptor_ids_valid[np.where(ligand_receptor_labels_valid==1)[0]])
        receptor_ids_valid = set(ligand_receptor_ids_valid[np.where(ligand_receptor_labels_valid==0)[0]])
        all_pairs_valid = set(itertools.product(ligand_ids_valid, receptor_ids_valid))
        pos_lr_valid = list(all_pairs_valid.intersection(self.lr_pairs))
        random.seed(2)
        neg_lr_valid = random.sample(sorted(all_pairs_valid.difference(self.lr_pairs)), len(pos_lr_valid))
        lr_valid = pos_lr_valid + neg_lr_valid

        print(f"number of pos/neg lr pairs in valid set: {len(pos_lr_valid)} / {len(neg_lr_valid)}")

        # Select cells expressing at least one labelled pair. `d[:-2]` skips the
        # two trailing total-count columns appended by prepare_data.
        # From here on `train_index` / `valid_index` are rebound to *cell* indices.
        samples_l_ids_train = [set(self.gene_ids[np.nonzero(d[:-2])[0]]).intersection(ligand_ids_train) for d in self.data]
        samples_r_ids_train = [set(self.gene_ids[np.nonzero(d[:-2])[0]]).intersection(receptor_ids_train) for d in self.data]
        train_index = [k for k in range(len(self.data)) if len(set(itertools.product(samples_l_ids_train[k], samples_r_ids_train[k])).intersection(set(lr_train))) > 0]
        samples_l_ids_valid = [set(self.gene_ids[np.nonzero(d[:-2])[0]]).intersection(ligand_ids_valid) for d in self.data]
        samples_r_ids_valid = [set(self.gene_ids[np.nonzero(d[:-2])[0]]).intersection(receptor_ids_valid) for d in self.data]
        valid_index = [k for k in range(len(self.data)) if len(set(itertools.product(samples_l_ids_valid[k], samples_r_ids_valid[k])).intersection(set(lr_valid))) > 0]

        # Gather the non-zero entries of every kept cell; gene ids here are the
        # column positions of `self.data`. `*_center_l` / `*_center_r` flag the
        # gathered tokens that are ligands / receptors of the split.
        train_data = [self.data[k] for k in train_index]
        train_data = torch.from_numpy(np.array(train_data)).float()
        train_data_gene_ids = torch.arange(train_data.shape[1]).repeat(train_data.shape[0], 1)
        train_data_index = train_data != 0
        train_values, train_padding = load.gatherData(train_data, train_data_index, self.pad_value)
        train_gene_ids, _ = load.gatherData(train_data_gene_ids, train_data_index, self.pad_token)
        train_center_l = torch.isin(train_gene_ids, torch.tensor(list(ligand_ids_train)))
        train_center_r = torch.isin(train_gene_ids, torch.tensor(list(receptor_ids_train)))
        
        valid_data = [self.data[k] for k in valid_index]
        valid_data = torch.from_numpy(np.array(valid_data)).float()
        valid_data_gene_ids = torch.arange(valid_data.shape[1]).repeat(valid_data.shape[0], 1)
        valid_data_index = valid_data != 0
        valid_values, valid_padding = load.gatherData(valid_data, valid_data_index, self.pad_value)
        valid_gene_ids, _ = load.gatherData(valid_data_gene_ids, valid_data_index, self.pad_token)
        valid_center_l = torch.isin(valid_gene_ids, torch.tensor(list(ligand_ids_valid)))
        valid_center_r = torch.isin(valid_gene_ids, torch.tensor(list(receptor_ids_valid)))
        
        # Label lookup for the train split. The pad token is added to both sides
        # so that combinations involving padding resolve to -100 as well.
        ligand_ids_train = ligand_ids_train.union({self.pad_token})
        receptor_ids_train = receptor_ids_train.union({self.pad_token})
        all_pairs_train = set(itertools.product(ligand_ids_train, receptor_ids_train))
        lr2label_train = dict(zip(all_pairs_train, [-100]*len(all_pairs_train)))
        for lr in pos_lr_train:
            lr2label_train[lr] = 1
        for lr in neg_lr_train:
            lr2label_train[lr] = 0

        # Per cell: padded ligand ids x padded receptor ids, enumerated with
        # itertools.product (ligand-major order), then mapped to labels.
        split_indices, total_cell_num, max_l_num, split_key_padding_mask = self._pad_information_of_split_token(train_center_l.sum(1))
        l_ids_train = pad_input(train_gene_ids[train_center_l].unsqueeze(-1), split_indices, total_cell_num, max_l_num).squeeze(-1)
        l_ids_train[split_key_padding_mask] = self.pad_token
        split_indices, total_cell_num, max_r_num, split_key_padding_mask = self._pad_information_of_split_token(train_center_r.sum(1))
        r_ids_train = pad_input(train_gene_ids[train_center_r].unsqueeze(-1), split_indices, total_cell_num, max_r_num).squeeze(-1)
        r_ids_train[split_key_padding_mask] = self.pad_token
        lr_pairs_train = [list(itertools.product(l_ids_train[i].tolist(), r_ids_train[i].tolist())) for i in range(total_cell_num)]
        train_lr_labels = torch.tensor([[lr2label_train[lr] for lr in cell] for cell in lr_pairs_train])
        train_data = {'values': train_values, 'padding': train_padding, 'gene_ids': train_gene_ids, 'ligand':train_center_l, 'receptor':train_center_r, 'lr_labels': train_lr_labels}
        
        # Same label construction for the valid split.
        ligand_ids_valid = ligand_ids_valid.union({self.pad_token})
        receptor_ids_valid = receptor_ids_valid.union({self.pad_token})
        all_pairs_valid = set(itertools.product(ligand_ids_valid, receptor_ids_valid))
        lr2label_valid = dict(zip(all_pairs_valid, [-100]*len(all_pairs_valid)))
        for lr in pos_lr_valid:
            lr2label_valid[lr] = 1
        for lr in neg_lr_valid:
            lr2label_valid[lr] = 0

        split_indices, total_cell_num, max_l_num, split_key_padding_mask = self._pad_information_of_split_token(valid_center_l.sum(1))
        l_ids_valid = pad_input(valid_gene_ids[valid_center_l].unsqueeze(-1), split_indices, total_cell_num, max_l_num).squeeze(-1)
        l_ids_valid[split_key_padding_mask] = self.pad_token
        split_indices, total_cell_num, max_r_num, split_key_padding_mask = self._pad_information_of_split_token(valid_center_r.sum(1))
        r_ids_valid = pad_input(valid_gene_ids[valid_center_r].unsqueeze(-1), split_indices, total_cell_num, max_r_num).squeeze(-1)
        r_ids_valid[split_key_padding_mask] = self.pad_token
        lr_pairs_valid = [list(itertools.product(l_ids_valid[i].tolist(), r_ids_valid[i].tolist())) for i in range(total_cell_num)]
        valid_lr_labels = torch.tensor([[lr2label_valid[lr] for lr in cell] for cell in lr_pairs_valid])
        valid_data = {'values': valid_values, 'padding': valid_padding, 'gene_ids': valid_gene_ids, 'ligand':valid_center_l, 'receptor':valid_center_r, 'lr_labels': valid_lr_labels}

        print(
            f"train set number of samples: {train_gene_ids.shape[0]}, "
            f"\n\t feature length of center cell: {train_gene_ids.shape[1]}"
            f"\n\t feature length of lr pairs: {train_lr_labels.shape[1]}"
            f"\n\t number of pos/neg lr pairs: {(train_lr_labels==1).sum().item()} / {(train_lr_labels==0).sum().item()}"
        )
        print(
            f"valid set number of samples: {valid_gene_ids.shape[0]}, "
            f"\n\t feature length of center cell: {valid_gene_ids.shape[1]}"
            f"\n\t feature length of lr pairs: {valid_lr_labels.shape[1]}"
            f"\n\t number of pos/neg lr pairs: {(valid_lr_labels==1).sum().item()} / {(valid_lr_labels==0).sum().item()}"
        )

        return train_data, valid_data

In [3]:
class scF_lrc(nn.Module):
    """scFoundation encoder followed by a ligand-receptor pair classifier.

    Each cell is encoded with the supplied token embedding, positional
    embedding and encoder. The encoded ligand and receptor tokens of a cell
    are then combined into every (ligand, receptor) pair, and each pair's
    concatenated embedding is classified by ``LRCDecoder``.
    """

    def __init__(
            self,
            scf_token_emb,
            scf_pos_emb,
            scf_encoder,
            d_model: int,
            n_lrc: int = 2,
            nlayers_lrc: int = 3,
    ):
        """
        Args:
            scf_token_emb: module embedding expression values; called with
                ``output_weight=0``.
            scf_pos_emb: module embedding gene ids.
            scf_encoder: encoder called as ``encoder(x, padding_mask=...)``.
            d_model: embedding width of a single token.
            n_lrc: number of output classes of the pair classifier.
            nlayers_lrc: number of layers of the pair classifier.
        """
        super(scF_lrc, self).__init__()

        # encoder
        self.token_emb = scf_token_emb
        self.pos_emb = scf_pos_emb
        self.encoder = scf_encoder
        self.lrc_decoder = LRCDecoder(d_model, n_lrc, nlayers=nlayers_lrc)

    def _pad_information_of_split_input(self, encoder_feature_lens, max_seqlen):
        """Return ``(indices, total_cell_num, max_seqlen, key_padding_mask)``.

        ``key_padding_mask`` has shape ``(total_cell_num, max_seqlen)`` and is
        True where column ``j >= encoder_feature_lens[i]``; ``indices`` are the
        flat positions of the non-padded slots.
        """
        total_cell_num = encoder_feature_lens.size(0)
        # Broadcast comparison instead of a per-row loop, which would index a
        # (possibly CUDA) tensor once per cell.
        positions = torch.arange(max_seqlen, device=encoder_feature_lens.device)
        key_padding_mask = positions.unsqueeze(0) >= encoder_feature_lens.unsqueeze(1)
        indices = (~key_padding_mask.view(-1)).nonzero(as_tuple=True)[0]
        return indices, total_cell_num, max_seqlen, key_padding_mask

    def forward(self, gene_values, padding_label, gene_ids, ligand, receptor, max_l_seqlen, max_r_seqlen):
        """Score every ligand x receptor pair of every cell.

        Args:
            gene_values: expression values, one row per cell.
            padding_label: bool mask forwarded to the encoder as ``padding_mask``.
            gene_ids: gene ids aligned with ``gene_values``.
            ligand: bool mask selecting ligand tokens.
            receptor: bool mask selecting receptor tokens.
            max_l_seqlen: width to which the ligand tokens of a cell are padded.
            max_r_seqlen: width to which the receptor tokens of a cell are padded.

        Returns:
            Decoder output of shape
            ``(cells, max_l_seqlen * max_r_seqlen, n_lrc)``; pairs are ordered
            ligand-major (all receptors for ligand 0, then ligand 1, ...),
            matching ``itertools.product(ligands, receptors)``.
        """
        # token and positional embedding
        x = self.token_emb(torch.unsqueeze(gene_values, 2), output_weight = 0)

        position_emb = self.pos_emb(gene_ids)
        x += position_emb
        x = self.encoder(x, padding_mask=padding_label)

        # Re-pack the ligand / receptor token embeddings into fixed-width,
        # zero-padded blocks.
        split_src_indices, total_cell_num, max_seqlen, _ = self._pad_information_of_split_input(ligand.sum(1), max_l_seqlen)
        x_ligand = pad_input(x[ligand], split_src_indices, total_cell_num, max_seqlen)
        split_src_indices, total_cell_num, max_seqlen, _ = self._pad_information_of_split_input(receptor.sum(1), max_r_seqlen)
        x_receptor = pad_input(x[receptor], split_src_indices, total_cell_num, max_seqlen)

        # Cartesian product of ligand and receptor embeddings, built on-device.
        # The previous implementation moved every embedding to the CPU, paired
        # them with itertools.product and rebuilt a tensor from nested Python
        # lists, which was extremely slow and cut the autograd graph.
        n_l, n_r = x_ligand.size(1), x_receptor.size(1)
        x_lr = torch.cat(
            [
                x_ligand.unsqueeze(2).expand(-1, n_l, n_r, -1),
                x_receptor.unsqueeze(1).expand(-1, n_l, n_r, -1),
            ],
            dim=-1,
        ).flatten(1, 2)

        output = self.lrc_decoder(x_lr)

        return output

class LRCDecoder(nn.Module):
    """
    Classification head for concatenated ligand-receptor embeddings.

    The head works on vectors of width ``2 * d_model`` and consists of
    ``nlayers - 1`` blocks of Linear -> activation -> LayerNorm followed by
    a final Linear projection to ``n_lr`` logits.
    """

    def __init__(
        self,
        d_model: int,
        n_lr: int = 2,
        nlayers: int = 3,
        activation: callable = nn.ReLU,
    ):
        super().__init__()
        # A pair embedding is two token embeddings side by side.
        pair_width = 2 * d_model
        hidden_stack = []
        for _ in range(nlayers - 1):
            hidden_stack.extend(
                [
                    nn.Linear(pair_width, pair_width),
                    activation(),
                    nn.LayerNorm(pair_width),
                ]
            )
        self._decoder = nn.ModuleList(hidden_stack)
        self.out_layer = nn.Linear(pair_width, n_lr)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: Tensor whose last dimension is ``2 * d_model``

        Returns:
            Tensor with the same leading dimensions and ``n_lr`` logits last.
        """
        hidden = x
        for module in self._decoder:
            hidden = module(hidden)
        return self.out_layer(hidden)

In [4]:
# Classification loss over the ligand-receptor pair logits; read as a
# module-level global by both train() and evaluate().
criterion_cls = nn.CrossEntropyLoss()

def train(model: nn.Module, train_data, valid_data, batch_size, max_batch) -> tuple:
    """
    Train the model for one epoch.

    Batches are taken in order from ``train_data``; training stops after
    ``max_batch`` batches. Every ``10 * log_interval`` batches the model is
    scored on ``valid_data`` and the best-AUC snapshot is kept.

    Relies on the module-level names ``device``, ``split`` (used only in the
    log line), ``criterion_cls`` and ``evaluate``.

    Args:
        model: network called as
            ``model(values, padding, gene_ids, ligand, receptor, max_l, max_r)``.
        train_data: dict with keys ``values``, ``padding``, ``gene_ids``,
            ``ligand``, ``receptor`` and ``lr_labels``.
        valid_data: dict of the same layout, used for validation.
        batch_size: number of cells per batch.
        max_batch: maximum number of batches to train on.

    Returns:
        ``(best_model, best_fpr, best_tpr)``: a deep copy of the model at its
        best validation AUC and the matching ROC curve. If no periodic
        validation produced a snapshot (for example fewer than
        ``10 * log_interval`` batches were run), the final model is evaluated
        once and returned instead of ``(None, 0, 0)``.
    """
    lr = 1e-4
    amp = True
    schedule_ratio = 0.9
    schedule_interval = 1
    log_interval = 10

    optimizer = torch.optim.Adam(
            model.parameters(), lr=lr, eps=1e-4 if amp else 1e-8
    )
    # NOTE(review): the scheduler is only queried for logging; `step()` is
    # never called, so the learning rate stays constant - confirm intent.
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, schedule_interval, gamma=schedule_ratio
    )
    scaler = torch.cuda.amp.GradScaler(enabled=amp)

    model.train()
    total_lrc = 0.0
    total_error = 0.0

    best_model = None
    best_auc_value = 0
    best_fpr = 0
    best_tpr = 0

    # Pair features are padded to the widest ligand / receptor count of the set.
    max_l_seqlen = train_data['ligand'].sum(1).max().item()
    max_r_seqlen = train_data['receptor'].sum(1).max().item()

    start_time = time.time()

    train_values = train_data['values']
    train_padding = train_data['padding']
    train_gene_ids = train_data['gene_ids']
    train_ligand = train_data['ligand']
    train_receptor = train_data['receptor']
    train_lr_labels = train_data['lr_labels']

    num_batches = np.ceil(len(train_values)/batch_size).astype(int)
    for k in range(0, len(train_values), batch_size):
        batch = int(k/batch_size+1)
        if batch > max_batch:
            break
        with torch.cuda.amp.autocast(enabled=amp):
            output = model(train_values[k:k+batch_size].to(device),
                           train_padding[k:k+batch_size].to(device),
                           train_gene_ids[k:k+batch_size].to(device),
                           train_ligand[k:k+batch_size].to(device),
                           train_receptor[k:k+batch_size].to(device),
                           max_l_seqlen,
                           max_r_seqlen)

            batch_train_lr_labels = train_lr_labels[k:k+batch_size].to(device)
            # Only pairs labelled 0/1 contribute; everything else (-100) is
            # padding or an unlabelled combination.
            labeled = torch.logical_or(batch_train_lr_labels==1, batch_train_lr_labels==0)
            batch_logits = output[labeled]
            batch_labels = batch_train_lr_labels[labeled]
            loss_lrc = criterion_cls(batch_logits, batch_labels)

            error_rate_lrc = 1 - (
                    (batch_logits.argmax(1) == batch_labels)
                    .sum()
                    .item()
                ) / batch_labels.size(0)

        model.zero_grad()
        scaler.scale(loss_lrc).backward()
        scaler.unscale_(optimizer)
        scaler.step(optimizer)
        scaler.update()

        total_lrc += loss_lrc.item()
        total_error += error_rate_lrc

        if batch % log_interval == 0 and batch > 0:
            lr = scheduler.get_last_lr()[0]
            sec_per_batch = (time.time() - start_time) / log_interval
            cur_lrc = total_lrc / log_interval
            cur_error = total_error / log_interval
            print(f"| Split {split} | "
                f"{batch:3d}/{num_batches:3d} batches | "
                f"lr {lr:05.8f} | sec/batch {sec_per_batch:5.1f} | "
                f"lrc {cur_lrc:5.5f} | "
                f"err {cur_error:1.5f} | "
            )
            total_lrc = 0
            total_error = 0
            start_time = time.time()
        if batch % (10*log_interval) == 0 and batch > 0:
            auc_value, fpr, tpr = evaluate(model, valid_data, batch_size)
            if auc_value > best_auc_value:
                best_auc_value = auc_value
                best_fpr = fpr
                best_tpr = tpr
                best_model = copy.deepcopy(model)
            # evaluate() switches to eval mode; restore training mode.
            model.train()
            start_time = time.time()

    if best_model is None:
        # No snapshot was taken during the loop. Returning (None, 0, 0) would
        # break callers that interpolate the ROC curve, so score the final
        # model once and return that instead.
        _, best_fpr, best_tpr = evaluate(model, valid_data, batch_size)
        best_model = copy.deepcopy(model)
        model.train()

    return best_model, best_fpr, best_tpr

def py_softmax(vector):
    """Return the softmax of a 1-D array of logits.

    The maximum is subtracted before exponentiating so that large logits do
    not overflow to ``inf`` (which made the result NaN), and the computation
    is done in float64 so half-precision logits are handled as well.

    Args:
        vector: array-like of logits.

    Returns:
        float64 NumPy array of the same shape whose entries sum to 1; an
        empty input yields an empty array.
    """
    logits = np.asarray(vector, dtype=np.float64)
    if logits.size == 0:
        return logits
    e = np.exp(logits - logits.max())
    return e / e.sum()

def evaluate(model: nn.Module, valid_data, batch_size) -> tuple:
    """Score ``model`` on the validation set and report ROC statistics.

    Puts the model in eval mode (the caller must switch back to training
    mode), runs it batch by batch without gradients, and collects logits and
    labels of all pairs labelled 0 or 1.

    Relies on the module-level names ``device`` and ``criterion_cls``.

    Args:
        model: network called as
            ``model(values, padding, gene_ids, ligand, receptor, max_l, max_r)``.
        valid_data: dict with keys ``values``, ``padding``, ``gene_ids``,
            ``ligand``, ``receptor`` and ``lr_labels``.
        batch_size: number of cells per batch.

    Returns:
        ``(auc_value, fpr, tpr)`` computed from the class-1 probabilities.
    """
    amp = True

    model.eval()
    total_lrc = 0.0
    total_error = 0.0
    total_num = 0

    logits_list = []
    labels_list = []

    max_l_seqlen = valid_data['ligand'].sum(1).max().item()
    max_r_seqlen = valid_data['receptor'].sum(1).max().item()

    valid_values = valid_data['values']
    valid_padding = valid_data['padding']
    valid_gene_ids = valid_data['gene_ids']
    valid_ligand = valid_data['ligand']
    valid_receptor = valid_data['receptor']
    valid_lr_labels = valid_data['lr_labels']

    with torch.no_grad():
        for k in tqdm(range(0, len(valid_values), batch_size)):
            with torch.cuda.amp.autocast(enabled=amp):
                output = model(valid_values[k:k+batch_size].to(device),
                           valid_padding[k:k+batch_size].to(device),
                           valid_gene_ids[k:k+batch_size].to(device),
                           valid_ligand[k:k+batch_size].to(device),
                           valid_receptor[k:k+batch_size].to(device),
                           max_l_seqlen,
                           max_r_seqlen)

            batch_valid_lr_labels = valid_lr_labels[k:k+batch_size].to(device)
            # Keep only pairs labelled 0/1; -100 marks padding / unlabelled pairs.
            labeled = torch.logical_or(batch_valid_lr_labels==1, batch_valid_lr_labels==0)
            # Autocast can return half-precision logits; promote them so the
            # loss and the softmax below are computed in float32.
            batch_logits = output[labeled].float()
            batch_labels = batch_valid_lr_labels[labeled]

            logits_list.append(batch_logits.cpu())
            labels_list.append(batch_labels.cpu())

            accuracy = (batch_logits.argmax(1) == batch_labels).sum().item()
            total_error += batch_labels.size(0) - accuracy
            total_num += batch_labels.size(0)
            total_lrc += criterion_cls(batch_logits, batch_labels).item()*batch_labels.size(0)

    logits = torch.cat(logits_list)
    labels = torch.cat(labels_list)

    # Probability of class 1, using torch's numerically stable softmax.
    y_score = torch.softmax(logits, dim=1)[:, 1].numpy()
    y_true = labels.numpy()
    fpr, tpr, _ = roc_curve(y_true, y_score)
    auc_value = auc(fpr, tpr)

    val_err = total_error / total_num
    val_loss = total_lrc / total_num
    print("-" * 89)
    print(
        f"valid accuracy: {1-val_err:1.4f} | "
        f"valid auc: {auc_value:1.4f} | "
        f"valid loss {val_loss:1.4f} | "
        f"valid err {val_err:1.4f}"
    )
    print("-" * 89)

    return auc_value, fpr, tpr


def train_and_evaluate(model, train_data, valid_data, batch_size, max_batch):
    """Run ``train`` and hand back its ``(best_model, best_fpr, best_tpr)`` result.

    Thin wrapper: validation happens inside ``train`` itself.
    """
    outcome = train(model, train_data, valid_data, batch_size, max_batch)
    chosen_model, roc_fpr, roc_tpr = outcome
    return chosen_model, roc_fpr, roc_tpr

In [5]:
class scFoundation(nn.Module):
    """Encoder-decoder wrapper assembled from already-built sub-modules.

    The encoder processes the tokens given in ``x``; its outputs are written
    into the decoder input at the positions flagged by ``encoder_labels``
    before the decoder, normalisation and final projection are applied.
    """

    def __init__(
            self,
            scf_token_emb,
            scf_pos_emb,
            scf_encoder,
            scf_decoder,
            scf_decoder_embed,
            scf_norm,
            scf_to_final,
    ):
        super().__init__()

        # Embeddings shared by the encoder and the decoder inputs.
        self.token_emb = scf_token_emb
        self.pos_emb = scf_pos_emb

        # Encoder.
        self.encoder = scf_encoder

        # Decoder stack and output head.
        self.decoder = scf_decoder
        self.decoder_embed = scf_decoder_embed
        self.norm = scf_norm
        self.to_final = scf_to_final

    def forward(self, x, padding_label, encoder_position_gene_ids, encoder_labels, decoder_data,
                decoder_position_gene_ids, decoder_data_padding_labels, **kwargs):
        """Encode ``x``, merge the result into the decoder input and decode.

        Returns the output of ``to_final`` with dimension 2 squeezed out.
        Extra keyword arguments are accepted and ignored.
        """
        # Encoder branch: value embedding plus positional embedding.
        encoded = self.token_emb(x.unsqueeze(2), output_weight = 0)
        encoded += self.pos_emb(encoder_position_gene_ids)
        encoded = self.encoder(encoded, padding_mask=padding_label)

        # Decoder input: embed all positions, then overwrite the flagged
        # positions with the encoder outputs of the non-padded tokens.
        dec_in = self.token_emb(decoder_data.unsqueeze(2))
        dec_pos = self.pos_emb(decoder_position_gene_ids)
        row_idx, col_idx = (encoder_labels == True).nonzero(as_tuple=True)
        dec_in[row_idx, col_idx] = encoded[~padding_label].to(dec_in.dtype)

        dec_in += dec_pos

        dec_in = self.decoder_embed(dec_in)
        decoded = self.decoder(dec_in, padding_mask=decoder_data_padding_labels)

        projected = self.to_final(self.norm(decoded))
        return projected.squeeze(2)

In [6]:
def initialize_model(model_file):
    """Build an ``scF_lrc`` model whose only trainable part is the pair decoder.

    Args:
        model_file: ``None`` to start from the pretrained checkpoint
            ``scfoundation/models/models.ckpt`` (loaded through
            ``load.load_model_frommmf``), or the file name of a pickled
            model under ``scfoundation/fine-tuning/``.

    Returns:
        ``scF_lrc`` instance reusing the loaded token embedding, positional
        embedding and encoder (all frozen) with a freshly initialised,
        trainable ``lrc_decoder``.
    """
    if model_file is None:
        pretrainmodel, _ = load.load_model_frommmf('scfoundation/models/models.ckpt')
    else:
        # SECURITY: this un-pickles a complete module, which can execute
        # arbitrary code - only load checkpoints from a trusted source.
        # `weights_only=False` is spelled out because recent PyTorch releases
        # default to `weights_only=True`, which rejects pickled modules.
        pretrainmodel = torch.load(
            f'scfoundation/fine-tuning/{model_file}', map_location='cpu', weights_only=False
        )
        # Checkpoints saved from a DataParallel wrapper keep the real model in
        # `.module`; plain modules are used as they are.
        pretrainmodel = getattr(pretrainmodel, 'module', pretrainmodel)

    model = scF_lrc(pretrainmodel.token_emb,
            pretrainmodel.pos_emb,
            pretrainmodel.encoder,
            d_model = 768,
            n_lrc = 2,
            nlayers_lrc = 3
            )

    def _trainable_param_count():
        # Keyed by storage pointer so parameters shared between modules are
        # counted once.
        return sum(dict((p.data_ptr(), p.numel()) for p in model.parameters() if p.requires_grad).values())

    pre_freeze_param_count = _trainable_param_count()
    # Freeze everything, then re-enable only the classification head.
    for para in model.parameters():
        para.requires_grad = False
    for para in model.lrc_decoder.parameters():
        para.requires_grad = True
    post_freeze_param_count = _trainable_param_count()
    print(f"Total Pre freeze Params {(pre_freeze_param_count )}")
    print(f"Total Post freeze Params {(post_freeze_param_count )}")

    return model

In [7]:
# Padding sentinels handed to SlideData: `pad_token` fills padded gene-id
# slots and `pad_value` fills padded expression-value slots.
pad_token = 19266
pad_value = 103
# Directory holding the gene vocabulary and ligand tables that SlideData reads
# through this module-level name.
# NOTE(review): absolute, user-specific path - adjust before running elsewhere.
tokenizer_dir = '/home/shcao/spFormer/spformer/tokenizer/'

In [None]:
# Checkpoint to start from: None selects the pretrained scFoundation weights.
model_file = None # 'model_human_myocardial_infarction.ckpt'

# Slide to analyse and the ligand-receptor reference table.
dataset = 'human_myocardial_infarction_dataset'
slide = '10X001'
data_path = f'../data/{dataset}/'
lr_path = 'gene_lists/human_LR_pairs.txt'

slideData = SlideData(data_path, slide, lr_path, pad_value, pad_token)
slideData.get_sc_data()
slideData.get_lr_pairs()

# Number of pseudo-cells kept for the experiment.
sample_num = 5000
slideData.prepare_data(sample_num)

batch_size = 6
max_batch = 500

# Folds are stratified over the ligand / receptor gene labels.
n_splits = 5
skf = StratifiedKFold(n_splits=n_splits, random_state=0, shuffle=True)
all_tpr, all_roc_auc, all_tpr_wt = [], [], []

# Common FPR grid onto which every fold's ROC curve is interpolated.
mean_fpr = np.linspace(0, 1, 100)

fold_iter = skf.split(slideData.ligand_receptor_ids, slideData.ligand_receptor_labels)
for split, (train_index, valid_index) in enumerate(fold_iter, start=1):
    print(f"Cross-validate on dataset {dataset} slide {slide} - split {split}")
    train_data, valid_data = slideData.prepare_train_and_valid_data(train_index, valid_index)

    # A fresh model per fold, replicated over three GPUs with cuda:1 as the
    # primary device (`device` is read as a global by train()/evaluate()).
    model = initialize_model(model_file)
    model = nn.DataParallel(model, device_ids = [1, 3, 0])
    device = torch.device("cuda:1")
    model.to(device)

    best_model, best_fpr, best_tpr = train_and_evaluate(model, train_data, valid_data, batch_size, max_batch)

    interp_tpr = np.interp(mean_fpr, best_fpr, best_tpr)
    interp_tpr[0] = 0.0
    all_tpr.append(interp_tpr)
    all_roc_auc.append(auc(best_fpr, best_tpr))
    # Fold weight used later: number of points on the fold's ROC curve.
    all_tpr_wt.append(len(best_tpr))

In [None]:
import math 

def get_cross_valid_metrics(all_tpr, all_roc_auc, all_tpr_wt, mean_fpr=None):
    """Aggregate per-fold ROC curves and AUCs into cross-validation summaries.

    Args:
        all_tpr: one TPR curve per fold, all sampled on the same FPR grid.
        all_roc_auc: one AUC per fold.
        all_tpr_wt: one non-negative weight per fold (normalised here).
        mean_fpr: FPR grid of the curves. When omitted, an evenly spaced grid
            on [0, 1] with one point per curve sample is used, so the function
            no longer depends on a global variable.

    Returns:
        ``(mean_tpr, median_tpr, roc_auc_mean, roc_auc_sd, roc_auc_median, wts)``
        where ``mean_tpr`` / ``roc_auc_mean`` / ``roc_auc_sd`` are weighted by
        ``wts``, ``median_tpr`` is the point-wise median curve and
        ``roc_auc_median`` is the area under that median curve. The last point
        of both curves is forced to 1.0.
    """
    tpr_curves = np.asarray(all_tpr, dtype=float)
    fold_auc = np.asarray(all_roc_auc, dtype=float)

    total_wt = sum(all_tpr_wt)
    wts = [count/total_wt for count in all_tpr_wt]
    print(wts)
    wt_arr = np.asarray(wts, dtype=float)

    mean_tpr = np.sum(tpr_curves * wt_arr[:, None], axis=0)
    mean_tpr[-1] = 1.0
    median_tpr = np.median(tpr_curves, axis=0)
    median_tpr[-1] = 1.0

    roc_auc_mean = np.sum(fold_auc * wt_arr)
    # Weighted standard deviation of the fold AUCs around the weighted mean.
    roc_auc_sd = float(np.sqrt(np.average((fold_auc - roc_auc_mean)**2, weights=wt_arr)))

    if mean_fpr is None:
        mean_fpr = np.linspace(0, 1, tpr_curves.shape[1])
    roc_auc_median = auc(mean_fpr, median_tpr)

    return mean_tpr, median_tpr, roc_auc_mean, roc_auc_sd, roc_auc_median, wts

# Summarise the folds and persist the curves for later plotting/comparison.
mean_tpr, median_tpr, roc_auc_mean, roc_auc_sd, roc_auc_median, wts = get_cross_valid_metrics(all_tpr, all_roc_auc, all_tpr_wt)
print(f"Mean ROC AUC: {roc_auc_mean} +/- {roc_auc_sd}")
cv_results = {'roc_auc_mean':roc_auc_mean, 'roc_auc_sd':roc_auc_sd, 'roc_auc_median':roc_auc_median, 'mean_fpr':mean_fpr, 'mean_tpr':mean_tpr, 'median_tpr':median_tpr, 'all_roc_auc':all_roc_auc, 'wts':wts}

# Create the output directory if needed and close the file deterministically;
# the bare `open()` handle was previously never closed.
os.makedirs('roc_results', exist_ok=True)
with open(f'roc_results/scf-pt_lrc_{slide}_random2.pkl', 'wb') as results_file:
    pickle.dump(cv_results, results_file)

In [None]:
import matplotlib.pyplot as plt

def plot_ROC(bundled_data, title):
    """Draw one ROC curve per entry plus the chance diagonal and show the figure.

    Args:
        bundled_data: iterable of
            ``(roc_auc, roc_auc_sd, mean_fpr, mean_tpr, sample, color)`` tuples;
            ``sample`` is the legend name of the curve.
        title: figure title.
    """
    plt.figure()
    lw = 2
    for roc_auc, roc_auc_sd, mean_fpr, mean_tpr, sample, color in bundled_data:
        # Raw string: "\p" is not a valid escape sequence in a normal string
        # literal and triggers a warning on current Python versions.
        plt.plot(mean_fpr, mean_tpr, color=color,
                 lw=lw, label=r"{0} (AUC {1:0.2f} $\pm$ {2:0.2f})".format(sample, roc_auc, roc_auc_sd))
    plt.plot([0, 1], [0, 1], color='black', lw=lw, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(title)
    plt.legend(loc="lower right")
    plt.show()

# One (auc, auc sd, fpr grid, tpr curve, legend name, colour) tuple per curve.
# NOTE(review): the legend name is "scf-ft" although `model_file` is None and
# the results pickle is named "scf-pt" - confirm which label is intended.
bundled_data = [(roc_auc_mean, roc_auc_sd, mean_fpr, mean_tpr, "scf-ft", "blue")]

# NOTE(review): the title says 'Gene classification' while the model scores
# ligand-receptor pairs - confirm the wording.
plot_ROC(bundled_data, 'Gene classification')