In [None]:
import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
import warnings
warnings.filterwarnings('ignore')
import pickle
import time
import copy
from tqdm import tqdm
import numpy as np
import pandas as pd
import random
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import auc, roc_curve, precision_recall_curve
import anndata
import scanpy as sc
import torch
from torch import nn, Tensor
from scfoundation import load
from stformer.tokenizer import GeneVocab

In [None]:
class Tokenizer:
    """Turn an AnnData slide into scFoundation-style inputs and per-gene classification labels.

    Args:
        tokenizer_dir: directory containing ``OS_scRNA_gene_index.19264.tsv`` and
            ``ligand_database.csv``.
        adata: AnnData with expression in ``X`` and ``obsm['niche_ligands_expression']``.
        label_path: CSV with ``gene_symbols`` and ``class`` columns; ``'inner_member'`` genes are
            the positives and ``'outer_member'`` genes form the pathway ligand set.
        vocab: callable mapping a list of gene symbols to a list of integer ids.
        pad_value: expression value used to pad ``values`` sequences.
        pad_token: gene id used to pad ``gene_ids`` sequences.
        pathway_ligands: optional symbols restricting the pathway ligand set (intersected with the
            ``'outer_member'`` genes of the label file).
        min_pathway_ligands: keep a cell only if at least this many distinct pathway ligands are
            expressed in its niche. ``0`` (default) keeps every cell; ``1`` keeps only cells whose
            niche expresses some pathway ligand.
    """

    def __init__(self, tokenizer_dir, adata, label_path, vocab, pad_value, pad_token, pathway_ligands=None,
                 min_pathway_ligands=0):
        self.tokenizer_dir = tokenizer_dir
        self.adata = adata
        self.label_path = label_path
        self.vocab = vocab
        self.pad_value = pad_value
        self.pad_token = pad_token
        self.pathway_ligands = pathway_ligands
        self.min_pathway_ligands = min_pathway_ligands

    @staticmethod
    def _to_dense(matrix):
        """Return ``matrix`` as a dense ndarray (sparse inputs via ``toarray()``, dense passed through)."""
        return matrix.toarray() if hasattr(matrix, 'toarray') else np.asarray(matrix)

    @staticmethod
    def _lookup_labels(gene_ids, gene2label):
        """Map an integer array of gene ids to int64 labels; ids missing from ``gene2label`` get -100."""
        gene_ids = np.asarray(gene_ids).astype(np.int64, copy=False)
        size = max(max(gene2label, default=-1), int(gene_ids.max(initial=-1))) + 1
        table = np.full(size, -100, dtype=np.int64)
        for gene, label in gene2label.items():
            table[gene] = label
        return table[gene_ids]

    def prepare_data(self):
        """Select cells, build the model input matrix and the gene -> label map.

        Sets ``self.gene_ids``, ``self.ligand_ids``, ``self.data``, ``self.gene_targets``,
        ``self.gene_labels`` and ``self.gene2label``.
        """
        gene_list_df = pd.read_csv(os.path.join(self.tokenizer_dir, 'OS_scRNA_gene_index.19264.tsv'),
                                   header=0, delimiter='\t')
        gene_list = list(gene_list_df['gene_name'])
        self.gene_ids = np.array(self.vocab(gene_list), dtype=int)

        ligand_database = pd.read_csv(os.path.join(self.tokenizer_dir, 'ligand_database.csv'), header=0, index_col=0)
        ligand_symbol = ligand_database[ligand_database.sum(1) > 1].index.values
        ligand_symbol = gene_list_df.loc[gene_list_df['gene_name'].isin(ligand_symbol), 'gene_name'].values
        # The ligand id list is tiled 25 times so it can index the columns of
        # obsm['niche_ligands_expression']; presumably one block per niche neighbour — TODO confirm.
        self.ligand_ids = np.array(self.vocab(ligand_symbol.tolist()) * 25, dtype=int)

        df_label = pd.read_csv(self.label_path, header=0)
        # Only label genes detected in at least 100 cells are kept.
        gene_subset, _ = sc.pp.filter_genes(self.adata, min_cells=100, inplace=False)
        selected_gene_names = self.adata.var_names.values[gene_subset]
        df_label = df_label[df_label['gene_symbols'].isin(selected_gene_names)]

        n_cells = self.adata.shape[0]
        if self.min_pathway_ligands <= 0:
            # Every cell trivially has >= 0 pathway ligands in its niche: no filtering needed.
            selected_cell_index = list(range(n_cells))
        else:
            outer_members = df_label.loc[df_label['class'] == 'outer_member', 'gene_symbols']
            if self.pathway_ligands is not None:
                pathway_ligands = set(self.vocab(list(set(self.pathway_ligands).intersection(outer_members))))
            else:
                pathway_ligands = set(self.vocab(list(outer_members)))
            niche_ligands_expression = self._to_dense(self.adata.obsm['niche_ligands_expression'])
            selected_cell_index = [
                i for i in tqdm(range(niche_ligands_expression.shape[0]))
                if len(set(self.ligand_ids[np.nonzero(niche_ligands_expression[i])[0]]) & pathway_ligands)
                >= self.min_pathway_ligands
            ]

        gexpr_feature = self._to_dense(self.adata.X[selected_cell_index])
        S = gexpr_feature.sum(1)
        # Two extra columns with log10 total counts (presumably scFoundation's T/S tokens; no
        # resampling here, so T == S). NOTE(review): zero-count cells yield log10(0) = -inf.
        T = S
        TS = np.concatenate([[np.log10(T)], [np.log10(S)]], axis=0).T
        self.data = np.concatenate([gexpr_feature, TS], axis=1)
        print(f"Selected cell number: {len(selected_cell_index): 6d}/{self.adata.shape[0]: 6d}")

        # Positives: pathway inner members. Negatives: an equal-size seeded sample of detected genes
        # that do not appear anywhere in the label file.
        member_genes = df_label.loc[df_label['class'] == 'inner_member', 'gene_symbols'].values.tolist()
        random.seed(0)
        candidates = sorted(set(selected_gene_names).difference(set(df_label['gene_symbols'])))
        nonmember_genes = random.sample(candidates, len(member_genes))
        gene_targets = np.array(self.vocab(member_genes + nonmember_genes))
        gene_labels = [1] * len(member_genes) + [0] * len(nonmember_genes)

        gene2label = dict(zip(gene_targets, gene_labels))
        for g in self.gene_ids:
            gene2label.setdefault(g, -100)
        # Ids 19264-19266 follow the 19264 genes (presumably T, S and pad tokens); never labeled.
        gene2label[19264] = gene2label[19265] = gene2label[19266] = -100

        self.gene_targets = gene_targets
        self.gene_labels = gene_labels
        self.gene2label = gene2label

    def prepare_train_and_valid_data(self, train_index, valid_index):
        """Build train/valid tensors for one cross-validation fold.

        Only the target genes of a split keep their 0/1 label; every other gene is relabelled -100.
        A cell enters a split if it expresses at least one of that split's target genes.

        Args:
            train_index, valid_index: integer indices into ``self.gene_targets``.

        Returns:
            ``(train_data, valid_data)`` dicts with ``'values'``, ``'padding'``, ``'gene_ids'`` and
            ``'gene_labels'`` tensors.
        """
        train_targets = set(self.gene_targets[train_index])
        valid_targets = set(self.gene_targets[valid_index])

        gene2label_train = {g: (lbl if g in train_targets else -100) for g, lbl in self.gene2label.items()}
        gene2label_valid = {g: (lbl if g in valid_targets else -100) for g, lbl in self.gene2label.items()}

        # Column j of self.data is treated as gene id j (see the arange in _build_split), so the
        # non-zero column indices of a row are the gene ids it expresses.
        train_rows = [d for d in self.data if not train_targets.isdisjoint(np.nonzero(d)[0])]
        valid_rows = [d for d in self.data if not valid_targets.isdisjoint(np.nonzero(d)[0])]

        return self._build_split(train_rows, gene2label_train), self._build_split(valid_rows, gene2label_valid)

    def _build_split(self, rows, gene2label):
        """Pack rows into padded non-zero value / gene-id tensors plus per-position labels."""
        values = torch.from_numpy(np.array(rows)).float()
        gene_ids = torch.arange(values.shape[1]).repeat(values.shape[0], 1)
        nonzero = values != 0
        packed_values, padding = load.gatherData(values, nonzero, self.pad_value)
        packed_gene_ids, _ = load.gatherData(gene_ids, nonzero, self.pad_token)
        gene_labels = torch.from_numpy(self._lookup_labels(packed_gene_ids.numpy(), gene2label))
        return {'values': packed_values, 'padding': padding, 'gene_ids': packed_gene_ids, 'gene_labels': gene_labels}

In [None]:
class scF_Gcl(nn.Module):
    """scFoundation embeddings + encoder topped with a per-token gene-classification head.

    The embedding and encoder modules are reused as given (typically taken from a pretrained
    scFoundation model); the head is a freshly initialised ``GclDecoder``.
    """

    def __init__(self, scf_token_emb, scf_pos_emb, scf_encoder, d_model: int, n_gcl: int = 2,
                 nlayers_gcl: int = 3):
        super().__init__()
        # Reused scFoundation encoder parts.
        self.token_emb = scf_token_emb
        self.pos_emb = scf_pos_emb
        self.encoder = scf_encoder
        # New classification head producing n_gcl logits per position.
        self.gcl_decoder = GclDecoder(d_model, n_gcl, nlayers=nlayers_gcl)

    def forward(self, gene_values, padding_label, gene_ids):
        """Return per-position logits from the classification head."""
        hidden = self.token_emb(gene_values.unsqueeze(2), output_weight=0)
        hidden += self.pos_emb(gene_ids)
        hidden = self.encoder(hidden, padding_mask=padding_label)
        return self.gcl_decoder(hidden)


class GclDecoder(nn.Module):
    """MLP head for gene classification.

    Applies ``nlayers - 1`` blocks of (Linear -> activation -> LayerNorm), each keeping the width
    ``d_model``, and then a final Linear projection to ``n_gcl`` logits.
    """

    def __init__(
        self,
        d_model: int,
        n_gcl: int,
        nlayers: int = 3,
        activation: callable = nn.ReLU,
    ):
        super().__init__()
        hidden_blocks = []
        for _ in range(nlayers - 1):
            hidden_blocks += [nn.Linear(d_model, d_model), activation(), nn.LayerNorm(d_model)]
        self._decoder = nn.ModuleList(hidden_blocks)
        self.out_layer = nn.Linear(d_model, n_gcl)

    def forward(self, x: Tensor) -> Tensor:
        """Map ``x`` (presumably [batch_size, seq_len, embsize]) to logits over the last dim."""
        features = x
        for module in self._decoder:
            features = module(features)
        return self.out_layer(features)

In [None]:
criterion_cls = nn.CrossEntropyLoss()

def train(model: nn.Module, train_data, batch_size, scaler, optimizer, scheduler, log_interval, epoch) -> None:
    """Train the gene-classification model for one epoch.

    The loss is the mean over the distinct target genes present in a batch of that gene's mean
    cross-entropy, so every gene contributes equally regardless of how many cells express it.

    Args:
        model: called as ``model(values, padding, gene_ids)``; returns per-position logits.
        train_data: dict of ``'values'``, ``'padding'``, ``'gene_ids'``, ``'gene_labels'`` tensors.
            Positions labelled 0/1 are trained on; every other label is ignored.
        batch_size: cells per optimisation step; a trailing partial batch is skipped.
        scaler: AMP ``GradScaler``.
        optimizer: stepped once per batch.
        scheduler: only read for the current learning rate shown in log lines.
        log_interval: print running averages every this many batches.
        epoch: epoch number shown in log lines.

    Note:
        Reads the notebook globals ``device`` and ``split``.
    """
    amp = True

    model.train()
    total_gcl = 0.0
    total_error = 0.0
    n_accum = 0  # batches accumulated since the last log line
    start_time = time.time()

    train_values = train_data['values']
    train_padding = train_data['padding']
    train_gene_ids = train_data['gene_ids']
    train_gene_labels = train_data['gene_labels']

    n_samples = len(train_values)
    num_batches = n_samples // batch_size  # only full batches are processed
    for batch, k in enumerate(range(0, n_samples - batch_size + 1, batch_size), start=1):
        batch_gene_ids = train_gene_ids[k:k + batch_size].to(device)
        batch_train_gene_labels = train_gene_labels[k:k + batch_size].to(device)
        labeled = (batch_train_gene_labels == 0) | (batch_train_gene_labels == 1)

        if labeled.any():
            with torch.cuda.amp.autocast(enabled=amp):
                output = model(train_values[k:k + batch_size].to(device),
                               train_padding[k:k + batch_size].to(device),
                               batch_gene_ids)
                batch_logits = output[labeled]
                batch_labels = batch_train_gene_labels[labeled]
                targets = batch_gene_ids[labeled]

                # Per-position CE, averaged within each gene, then across genes. (Iterating a
                # tensor inside set() would hash elements by identity and never group genes.)
                per_position = nn.functional.cross_entropy(batch_logits, batch_labels, reduction='none')
                _, gene_index, gene_counts = torch.unique(targets, return_inverse=True, return_counts=True)
                per_gene = torch.zeros(gene_counts.numel(), dtype=per_position.dtype, device=per_position.device)
                per_gene = per_gene.index_add(0, gene_index, per_position) / gene_counts
                loss_gcl = per_gene.mean()

                error_rate_gcl = (batch_logits.argmax(1) != batch_labels).float().mean().item()

            model.zero_grad()
            scaler.scale(loss_gcl).backward()
            scaler.unscale_(optimizer)
            scaler.step(optimizer)
            scaler.update()

            total_gcl += loss_gcl.item()
            total_error += error_rate_gcl
            n_accum += 1

        if batch % log_interval == 0 and n_accum > 0:
            lr = scheduler.get_last_lr()[0]
            sec_per_batch = (time.time() - start_time) / log_interval
            cur_gcl = total_gcl / n_accum
            cur_error = total_error / n_accum
            print(f"| Split {split} | epoch {epoch:3d} | {batch:3d}/{num_batches:3d} batches | "
                f"lr {lr:05.8f} | sec/batch {sec_per_batch:5.1f} | "
                f"gcl {cur_gcl:5.5f} | "
                f"acc {1-cur_error:1.5f} | "
            )
            total_gcl = 0.0
            total_error = 0.0
            n_accum = 0
            start_time = time.time()

def py_softmax(vector):
    """Softmax over all elements of ``vector`` (NumPy).

    The maximum is subtracted before exponentiating so large logits do not overflow to
    ``inf`` (which would otherwise produce ``nan`` probabilities). Empty input returns an empty array.
    """
    values = np.asarray(vector)
    if values.size:
        values = values - values.max()
    e = np.exp(values)
    return e / e.sum()

def evaluate(model: nn.Module, valid_data, batch_size, curve):
    """Score the gene classifier on held-out cells and compute a ROC or PR curve.

    Every validation cell is scored (including a trailing partial batch); only positions
    labelled 0/1 contribute to the metrics.

    Args:
        model: called as ``model(values, padding, gene_ids)``; returns per-position 2-class logits.
        valid_data: dict of ``'values'``, ``'padding'``, ``'gene_ids'``, ``'gene_labels'`` tensors.
        batch_size: cells per forward pass.
        curve: ``'roc'`` or ``'prc'``.

    Returns:
        ``(auc_value, fpr, tpr)`` for ``'roc'``; ``(auc_value, recall, precision)`` for ``'prc'``.

    Raises:
        ValueError: on an unknown ``curve`` or when no labelled validation position exists.

    Note:
        Reads the notebook globals ``device`` and ``criterion_cls``.
    """
    if curve not in ('roc', 'prc'):
        raise ValueError(f"curve must be 'roc' or 'prc', got {curve!r}")
    amp = True

    model.eval()
    total_gcl = 0.0
    total_error = 0
    total_num = 0

    logits = []
    labels = []

    valid_values = valid_data['values']
    valid_padding = valid_data['padding']
    valid_gene_ids = valid_data['gene_ids']
    valid_gene_labels = valid_data['gene_labels']

    with torch.no_grad():
        for k in tqdm(range(0, len(valid_values), batch_size)):
            with torch.cuda.amp.autocast(enabled=amp):
                output = model(valid_values[k:k+batch_size].to(device),
                               valid_padding[k:k+batch_size].to(device),
                               valid_gene_ids[k:k+batch_size].to(device))

            batch_valid_gene_labels = valid_gene_labels[k:k+batch_size].to(device)
            labeled = (batch_valid_gene_labels == 0) | (batch_valid_gene_labels == 1)
            if not labeled.any():
                continue
            # fp32 for the loss and the scores (autocast may return half-precision logits).
            batch_logits = output[labeled].float()
            batch_labels = batch_valid_gene_labels[labeled]
            logits.append(batch_logits.cpu())
            labels.append(batch_labels.cpu())

            n_batch = batch_labels.size(0)
            n_correct = (batch_logits.argmax(1) == batch_labels).sum().item()
            total_error += n_batch - n_correct
            total_num += n_batch
            total_gcl += criterion_cls(batch_logits, batch_labels).item() * n_batch

    if not labels:
        raise ValueError("validation data contains no labeled gene positions")
    logits = torch.cat(logits)
    labels = torch.cat(labels)

    # Probability of the positive class (column 1).
    y_score = torch.softmax(logits, dim=1)[:, 1].numpy()
    y_true = labels.numpy()

    if curve == 'roc':
        fpr, tpr, _ = roc_curve(y_true, y_score)
        auc_value = auc(fpr, tpr)
    else:
        precision, recall, _ = precision_recall_curve(y_true, y_score)
        auc_value = auc(recall, precision)

    val_err = total_error / total_num
    val_loss = total_gcl / total_num
    print("-" * 89)
    print(
        f"valid accuracy: {1-val_err:1.4f} | "
        f"valid auc: {auc_value:1.4f} | "
        f"valid loss: {val_loss:1.4f} | "
    )
    print("-" * 89)

    if curve == 'roc':
        return auc_value, fpr, tpr
    return auc_value, recall, precision


def train_and_evaluate(model, train_loader, valid_loader, batch_size, epochs, mean_curve_x, curve):
    """Fine-tune ``model`` for ``epochs`` epochs and keep the curve of the best validation epoch.

    Uses Adam (lr 1e-4) with a per-epoch StepLR decay (gamma 0.9) and AMP gradient scaling.

    Args:
        model: classifier passed to ``train``/``evaluate``.
        train_loader, valid_loader: tensor dicts as produced by ``Tokenizer.prepare_train_and_valid_data``.
        batch_size: cells per batch.
        epochs: number of training epochs.
        mean_curve_x: common x grid (FPR for ROC, recall for PR) used for cross-fold averaging.
        curve: ``'roc'`` or ``'prc'``.

    Returns:
        ``(interp_curve_y, best_curve_x, best_curve_y)``: the best epoch's curve interpolated onto
        ``mean_curve_x``, plus the raw curve.

    Raises:
        ValueError: on an unknown ``curve``.
        RuntimeError: if no epoch produced a comparable (non-NaN) validation AUC.
    """
    if curve not in ('roc', 'prc'):
        raise ValueError(f"curve must be 'roc' or 'prc', got {curve!r}")
    lr = 1e-4
    amp = True
    schedule_ratio = 0.9
    schedule_interval = 1
    log_interval = 10

    optimizer = torch.optim.Adam(
            model.parameters(), lr=lr, eps=1e-4 if amp else 1e-8
    )
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, schedule_interval, gamma=schedule_ratio
    )
    scaler = torch.cuda.amp.GradScaler(enabled=amp)

    # Start below any real AUC so the first evaluated epoch is always recorded.
    best_val_auc = -np.inf
    best_curve_x = None
    best_curve_y = None

    for epoch in range(1, epochs+1):
        train(model, train_loader, batch_size, scaler, optimizer, scheduler, log_interval, epoch)

        val_auc, curve_x, curve_y = evaluate(model, valid_loader, batch_size, curve)
        if val_auc > best_val_auc:
            best_val_auc = val_auc
            best_curve_x = curve_x
            best_curve_y = curve_y
            print(f"Best model with auc {best_val_auc:1.4f}")

        scheduler.step()

    if best_curve_x is None:
        raise RuntimeError("no epoch produced a comparable validation AUC; nothing to interpolate")

    if curve == 'roc':
        interp_curve_y = np.interp(mean_curve_x, best_curve_x, best_curve_y)
        interp_curve_y[0] = 0.0
    else:
        # precision_recall_curve returns recall in decreasing order; np.interp needs increasing x.
        interp_curve_y = np.interp(mean_curve_x, best_curve_x[::-1], best_curve_y[::-1])
        interp_curve_y[0] = 1.0

    return interp_curve_y, best_curve_x, best_curve_y

In [None]:
class scFoundation(nn.Module):
    """Full scFoundation encoder-decoder assembled from already-built component modules.

    The encoder sees only the positions described by ``padding_label``; its outputs are then
    written into the slots of the decoder input flagged by ``encoder_labels`` before decoding.
    """

    def __init__(
            self,
            scf_token_emb,
            scf_pos_emb,
            scf_encoder,
            scf_decoder,
            scf_decoder_embed,
            scf_norm,
            scf_to_final,
    ):
        super().__init__()
        # Encoder side (registration order kept: it fixes state_dict / parameters() order).
        self.token_emb = scf_token_emb
        self.pos_emb = scf_pos_emb
        self.encoder = scf_encoder
        # Decoder side.
        self.decoder = scf_decoder
        self.decoder_embed = scf_decoder_embed
        self.norm = scf_norm
        self.to_final = scf_to_final

    def forward(self, x, padding_label, encoder_position_gene_ids, encoder_labels, decoder_data,
                decoder_position_gene_ids, decoder_data_padding_labels, **kwargs):
        """Encode, scatter encoder outputs into the decoder sequence, decode, and project.

        ``kwargs`` is accepted and ignored. The final projection's trailing dim (assumed size 1)
        is squeezed away.
        """
        enc = self.token_emb(x.unsqueeze(2), output_weight=0)
        enc += self.pos_emb(encoder_position_gene_ids)
        enc = self.encoder(enc, padding_mask=padding_label)

        dec = self.token_emb(decoder_data.unsqueeze(2))
        dec_pos = self.pos_emb(decoder_position_gene_ids)
        # The non-padded encoder outputs overwrite the decoder slots flagged by encoder_labels,
        # matched in row-major order.
        rows, cols = encoder_labels.eq(True).nonzero(as_tuple=True)
        dec[rows, cols] = enc[~padding_label].to(dec.dtype)
        dec += dec_pos

        dec = self.decoder_embed(dec)
        dec = self.decoder(dec, padding_mask=decoder_data_padding_labels)
        dec = self.norm(dec)
        return self.to_final(dec).squeeze(2)

In [None]:
def initialize_model(model_file):
    """Build an ``scF_Gcl`` classifier on top of an scFoundation backbone.

    Everything is frozen except the last two encoder layers (``transformer_encoder[10:12]``) and
    the new classification head. Trainable-parameter counts before and after freezing are printed.

    Args:
        model_file: ``None`` to load the pretrained checkpoint ``scfoundation/models/models.ckpt``;
            otherwise a file name under ``fine-tuning_mse/models/`` containing a whole pickled model,
            optionally wrapped in ``DataParallel``.

    Returns:
        The ``scF_Gcl`` model.
    """
    if model_file is None:
        # load pretrained model
        pretrainmodel, _ = load.load_model_frommmf('scfoundation/models/models.ckpt')
    else:
        # The checkpoint is a whole pickled module, so weights_only=False is required on
        # torch >= 2.6 (where the default became True). This unpickles arbitrary code:
        # only load checkpoints you produced yourself.
        pretrainmodel = torch.load(f'fine-tuning_mse/models/{model_file}', map_location='cpu', weights_only=False)
        # Checkpoints saved from a DataParallel wrapper keep the real model under `.module`.
        pretrainmodel = getattr(pretrainmodel, 'module', pretrainmodel)

    model = scF_Gcl(pretrainmodel.token_emb,
            pretrainmodel.pos_emb,
            pretrainmodel.encoder,
            d_model = 768,
            n_gcl = 2,
            nlayers_gcl = 3
            )

    def _count_trainable(module):
        # Keyed by storage pointer so shared/tied parameters are counted once.
        return sum({p.data_ptr(): p.numel() for p in module.parameters() if p.requires_grad}.values())

    pre_freeze_param_count = _count_trainable(model)
    for param in model.parameters():
        param.requires_grad = False
    for param in model.encoder.transformer_encoder[10:12].parameters():
        param.requires_grad = True
    for param in model.gcl_decoder.parameters():
        param.requires_grad = True
    post_freeze_param_count = _count_trainable(model)
    print(f"Total Pre freeze Params {(pre_freeze_param_count )}")
    print(f"Total Post freeze Params {(post_freeze_param_count )}")

    return model

In [None]:
# Evaluation curve for this run: 'roc' or 'prc'.
curve = 'prc'

# Padding constants handed to the Tokenizer (gene-id pad token and expression pad value).
pad_token = 19266
pad_value = 103

tokenizer_dir = '../stformer/tokenizer/'
vocab_file = tokenizer_dir + "scfoundation_gene_vocab.json"
vocab = GeneVocab.from_file(vocab_file)
vocab.append_token("<pad>")
# Presumably makes unknown gene symbols resolve to the <pad> id instead of failing — TODO confirm GeneVocab semantics.
vocab.set_default_index(vocab["<pad>"])

In [None]:
model_file = None  # None -> pretrained scFoundation; otherwise a file under fine-tuning_mse/models/
label_path = 'gene_lists/TGF-beta_signaling_pathway_KEGG.csv'
pathway_ligands = ['BMP2','BMP4','BMP5','BMP6','BMP7','BMP8A','BMP8B','GDF5','GDF6','GDF7','INHBA','INHBB','INHBC','INHBE','TGFB1','TGFB2','TGFB3','NODAL']

dataset = 'human_myocardial_infarction'
slide = 'ACH005'
adata = sc.read_h5ad(f'../datasets/{slide}_niche.h5ad')

# Optional subsetting; None uses everything.
selected_celltype = None  # e.g. 'Fibroblast' to restrict to one cell type
max_cells = None          # e.g. 2000 for a quick debug run

if selected_celltype is not None:
    # obs['cell_type'] stores 1-based positions into uns['cell_types_list'].
    celltype_code = list(adata.uns['cell_types_list']).index(selected_celltype) + 1
    adata = adata[adata.obs['cell_type'] == celltype_code]
np.random.seed(0)
shuffled_indices = np.random.permutation(adata.n_obs)
adata = adata[shuffled_indices]
if max_cells is not None:
    adata = adata[:max_cells]

tokenizer = Tokenizer(tokenizer_dir, adata, label_path, vocab, pad_value, pad_token, pathway_ligands)
tokenizer.prepare_data()

epochs = 2
batch_size = 30

n_splits = 5
skf = StratifiedKFold(n_splits=n_splits, random_state=0, shuffle=True)
mean_curve_x = np.linspace(0, 1, 100)
all_curve_y = []
all_auc = []
all_wt = []

# `device` and `split` are read as globals by train() / evaluate().
device = torch.device("cuda:0")
for split, (train_index, valid_index) in enumerate(skf.split(tokenizer.gene_targets, tokenizer.gene_labels), start=1):
    print(f"Cross-validate on dataset {dataset} slide {slide} - split {split}")
    train_data, valid_data = tokenizer.prepare_train_and_valid_data(train_index, valid_index)

    # Seed torch so the randomly initialised classification head (and any dropout) is reproducible.
    torch.manual_seed(0)
    model = initialize_model(model_file)
    model = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
    model.to(device)

    interp_curve_y, best_curve_x, best_curve_y = train_and_evaluate(model, train_data, valid_data, batch_size, epochs, mean_curve_x, curve)
    all_curve_y.append(interp_curve_y)
    all_auc.append(auc(best_curve_x, best_curve_y))
    # Folds are later weighted by the number of points on their best curve.
    all_wt.append(len(best_curve_y))

In [None]:
import math 

def get_cross_valid_metrics(all_curve_y, all_auc, all_wt, curve):
    """Weighted average of per-fold curves and AUCs.

    Each fold is weighted by ``all_wt[i] / sum(all_wt)`` (the weights are also printed).

    Returns:
        ``(mean_curve_y, auc_mean, auc_sd, wts)``. For ROC the last point of the mean curve is
        pinned to 1.0. ``auc_sd`` is the weighted standard deviation of the fold AUCs.
    """
    weight_total = sum(all_wt)
    wts = [w / weight_total for w in all_wt]
    print(wts)

    mean_curve_y = np.sum([fold_y * w for fold_y, w in zip(all_curve_y, wts)], axis=0)
    if curve == 'roc':
        mean_curve_y[-1] = 1.0

    auc_mean = np.sum([fold_auc * w for fold_auc, w in zip(all_auc, wts)])
    deviations = np.asarray(all_auc) - auc_mean
    auc_sd = math.sqrt(np.average(deviations ** 2, weights=wts))
    return mean_curve_y, auc_mean, auc_sd, wts

mean_curve_y, auc_mean, auc_sd, wts = get_cross_valid_metrics(all_curve_y, all_auc, all_wt, curve)

print(f"Mean AUC: {auc_mean} +/- {auc_sd}")
cv_results = {'auc_mean':auc_mean, 'auc_sd':auc_sd, 'mean_curve_x':mean_curve_x, 'mean_curve_y':mean_curve_y, 'all_auc':all_auc, 'wts':wts, 'all_curve_y':all_curve_y}

# Name the output after the gene list actually analysed, so runs for different pathways cannot
# overwrite each other (a hard-coded tag such as 'nfkb' goes stale when label_path changes).
pathway_tag = os.path.splitext(os.path.basename(label_path))[0]
results_path = f'figures/gene_classification/scf_gcl_{slide}_{pathway_tag}_{curve}.pkl'
os.makedirs(os.path.dirname(results_path), exist_ok=True)
with open(results_path, 'wb') as fh:
    pickle.dump(cv_results, fh)

In [None]:
import matplotlib.pyplot as plt

def plot_ROC(bundled_data, title, curve):
    """Plot one mean ROC or precision-recall curve per entry of ``bundled_data``.

    Args:
        bundled_data: iterable of ``(auc_mean, auc_sd, mean_curve_x, mean_curve_y, sample, color)``.
        title: figure title.
        curve: ``'roc'`` adds the chance diagonal and FPR/TPR labels; ``'prc'`` uses recall/precision labels.
    """
    fig, ax = plt.subplots()
    lw = 2
    for auc_mean, auc_sd, mean_curve_x, mean_curve_y, sample, color in bundled_data:
        # Raw string: "\p" in a normal literal is an invalid escape sequence (SyntaxWarning on 3.12+).
        ax.plot(mean_curve_x, mean_curve_y, color=color, lw=lw,
                label=r"{0} (AUC {1:0.3f} $\pm$ {2:0.3f})".format(sample, auc_mean, auc_sd))
    if curve == 'roc':
        ax.plot([0, 1], [0, 1], color='black', lw=lw, linestyle='--')
        ax.set_xlabel('False Positive Rate')
        ax.set_ylabel('True Positive Rate')
    elif curve == 'prc':
        ax.set_xlabel('Recall')
        ax.set_ylabel('Precision')
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_title(title)
    ax.legend(loc="lower right")
    plt.show()

bundled_data = [(auc_mean, auc_sd, mean_curve_x, mean_curve_y, "scFoundation", "red")]

plot_ROC(bundled_data, 'Gene classification', curve)