In [None]:
import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
import warnings
warnings.filterwarnings('ignore')
from collections import Counter
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import auc, roc_curve
import anndata
import scanpy as sc
import torch
from torch import nn
from scfoundation import load

In [None]:
class Tokenizer():
    """Convert an expression matrix into padded inputs for the scFoundation encoder.

    Typical use: construct, call ``prepare_data()``, then ``tokenize_data()``.

    Args:
        adata: Object exposing the expression matrix as ``adata.X`` (rows are
            pooled into one embedding each downstream; columns are indexed as
            gene ids). ``X`` may be a dense array or anything with ``toarray()``.
        pad_value: Filler passed to ``load.gatherData`` for the value tensor.
        pad_token: Filler passed to ``load.gatherData`` for the gene-id tensor.
    """

    def __init__(self, adata, pad_value, pad_token):
        self.adata = adata
        self.pad_value = pad_value
        self.pad_token = pad_token
        # Filled in by prepare_data(); None means "not prepared yet".
        self.data = None

    def prepare_data(self):
        """Densify ``adata.X`` and append two log10(total count) columns.

        The result is stored in ``self.data`` with shape
        ``(n_rows, n_columns + 2)``.
        """
        expr = self.adata.X
        # `.A` only exists on scipy sparse matrices / np.matrix; `toarray()` is
        # the portable way to densify, and dense inputs are used as they are.
        if hasattr(expr, 'toarray'):
            expr = expr.toarray()
        gexpr_feature = np.asarray(expr)

        # Per-row total counts. The two appended columns are kept under the
        # original names T and S (presumably scFoundation's target/source
        # total-count tokens — confirm); both carry the same value here.
        # NOTE(review): a row whose total is 0 yields log10(0) = -inf.
        S = gexpr_feature.sum(1)
        T = S
        TS = np.stack([np.log10(T), np.log10(S)], axis=1)
        self.data = np.concatenate([gexpr_feature, TS], axis=1)

    def tokenize_data(self):
        """Gather the non-zero entries of every row into padded tensors.

        Runs ``prepare_data()`` first if it has not been called yet.

        Returns:
            dict with keys ``'values'`` and ``'padding'`` (the two outputs of
            ``load.gatherData`` for the value tensor) and ``'gene_ids'`` (the
            column indices gathered with the same selection mask).
        """
        if self.data is None:
            self.prepare_data()

        data = torch.from_numpy(self.data).float()
        # Column index of every entry, replicated for each row.
        data_gene_ids = torch.arange(data.shape[1]).repeat(data.shape[0], 1)

        # Only entries that are not exactly zero are kept.
        data_index = data != 0
        gene_values, gene_padding = load.gatherData(data, data_index, self.pad_value)
        gene_ids, _ = load.gatherData(data_gene_ids, data_index, self.pad_token)

        return {'values': gene_values, 'padding': gene_padding, 'gene_ids': gene_ids}

In [None]:
class scF_Ccl(nn.Module):
    """Encoder-only wrapper that returns one max-pooled embedding per input row.

    Args:
        scf_token_emb: callable embedding the expression values; invoked with
            ``output_weight=0``.
        scf_pos_emb: callable embedding the gene ids.
        scf_encoder: callable encoder; invoked with a ``padding_mask`` keyword.
    """

    def __init__(self, scf_token_emb, scf_pos_emb, scf_encoder):
        super().__init__()

        # Reuse the sub-modules handed in by the caller.
        self.token_emb = scf_token_emb
        self.pos_emb = scf_pos_emb
        self.encoder = scf_encoder

    def forward(self, gene_values, padding_label, gene_ids):
        """Embed, encode and max-pool every row over its non-padded positions.

        Args:
            gene_values: 2-D tensor of values; a trailing axis is added before
                it is passed to the token embedding.
            padding_label: boolean mask, True where a position is padding.
            gene_ids: ids passed to the positional embedding.

        Returns:
            Tensor with one pooled vector per row of the batch.
        """
        hidden = self.token_emb(gene_values.unsqueeze(2), output_weight=0)
        hidden += self.pos_emb(gene_ids)
        hidden = self.encoder(hidden, padding_mask=padding_label)

        # Pool row by row so that padded positions never take part in the max.
        # NOTE(review): a row that is entirely padding leaves nothing to pool —
        # confirm callers never produce one.
        keep_mask = ~padding_label
        pooled_rows = []
        for row_hidden, row_keep in zip(hidden, keep_mask):
            pooled_rows.append(row_hidden[row_keep].max(dim=0)[0])

        return torch.stack(pooled_rows)

In [None]:
def evaluate(model: nn.Module, data, batch_size, device=None, amp=None) -> torch.Tensor:
    """Run ``model`` over ``data`` in mini-batches and collect the outputs on CPU.

    Args:
        model: module called as ``model(values, padding, gene_ids)``.
        data: dict with keys ``'values'``, ``'padding'`` and ``'gene_ids'``,
            each sliceable along the first axis.
        batch_size: number of rows sent through the model per step.
        device: where each batch is moved before the forward pass. When
            omitted, the notebook-level ``device`` variable is used if it
            exists, otherwise the device of the model's first parameter
            (CPU for a parameter-free model).
        amp: whether to run the forward pass under CUDA autocast. When
            omitted, the notebook-level ``amp`` variable is used if it exists,
            otherwise autocast is disabled.

    Returns:
        The per-batch model outputs concatenated along the first axis, on CPU.
    """
    # Keep the historical behaviour (reading notebook globals) as the default,
    # but let callers pass the settings explicitly instead of relying on
    # cells having been executed in a particular order.
    if device is None:
        device = globals().get('device')
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
    if amp is None:
        amp = globals().get('amp', False)

    model.eval()

    gene_values = data['values']
    gene_padding = data['padding']
    gene_ids = data['gene_ids']

    cell_embeddings = []

    with torch.no_grad():
        for start in tqdm(range(0, len(gene_values), batch_size)):
            stop = start + batch_size
            # torch.autocast('cuda', ...) is the non-deprecated spelling of
            # torch.cuda.amp.autocast(...).
            with torch.autocast(device_type='cuda', enabled=bool(amp)):
                geneembmerge = model(gene_values[start:stop].to(device),
                                     gene_padding[start:stop].to(device),
                                     gene_ids[start:stop].to(device))

            # Move each batch off the accelerator straight away.
            cell_embeddings.append(geneembmerge.to('cpu'))

    return torch.cat(cell_embeddings)

In [None]:
class scFoundation(nn.Module):
    """Encoder/decoder wrapper assembled from externally supplied sub-modules.

    Args:
        scf_token_emb: callable embedding expression values.
        scf_pos_emb: callable embedding gene ids.
        scf_encoder: callable encoder; invoked with a ``padding_mask`` keyword.
        scf_decoder: callable decoder; invoked with a ``padding_mask`` keyword.
        scf_decoder_embed: callable applied to the decoder input before decoding.
        scf_norm: callable applied to the decoder output.
        scf_to_final: callable producing the final per-position output.
    """

    def __init__(self, scf_token_emb, scf_pos_emb, scf_encoder, scf_decoder,
                 scf_decoder_embed, scf_norm, scf_to_final):
        super().__init__()

        # Encoder side.
        self.token_emb = scf_token_emb
        self.pos_emb = scf_pos_emb
        self.encoder = scf_encoder

        # Decoder side.
        self.decoder = scf_decoder
        self.decoder_embed = scf_decoder_embed
        self.norm = scf_norm
        self.to_final = scf_to_final

    def forward(self, x, padding_label, encoder_position_gene_ids, encoder_labels, decoder_data,
                decoder_position_gene_ids, decoder_data_padding_labels, **kwargs):
        """Encode the gathered positions, scatter them into the decoder input, decode.

        Extra keyword arguments are accepted and ignored.

        Returns:
            Output of ``to_final`` with its axis 2 squeezed.
        """
        # Encoder: value embedding + positional embedding, then encode.
        enc_hidden = self.token_emb(x.unsqueeze(2), output_weight=0)
        enc_hidden += self.pos_emb(encoder_position_gene_ids)
        enc_hidden = self.encoder(enc_hidden, padding_mask=padding_label)

        # Decoder input: embed every position, then overwrite the positions
        # flagged in encoder_labels with the non-padded encoder outputs.
        dec_hidden = self.token_emb(decoder_data.unsqueeze(2))
        dec_position_emb = self.pos_emb(decoder_position_gene_ids)
        row_idx, col_idx = torch.nonzero(encoder_labels == True, as_tuple=True)
        dec_hidden[row_idx, col_idx] = enc_hidden[~padding_label].to(dec_hidden.dtype)

        dec_hidden += dec_position_emb

        dec_hidden = self.decoder_embed(dec_hidden)
        decoded = self.decoder(dec_hidden, padding_mask=decoder_data_padding_labels)

        final = self.to_final(self.norm(decoded))
        return final.squeeze(2)

In [None]:
# Which checkpoint to embed with: 'pretrained' or 'fine-tuned'.
model_type = 'pretrained'
# File name under fine-tuning_mse/models/; only read when model_type == 'fine-tuned'.
# It must be defined, otherwise switching model_type raises NameError.
model_file = 'ft-scf-10X001.ckpt'

if model_type == 'pretrained':
    pretrainmodel, pretrainconfig = load.load_model_frommmf('scfoundation/models/models.ckpt')
elif model_type == 'fine-tuned':
    # The fine-tuned checkpoint is used as a whole pickled module (its
    # .token_emb/.pos_emb/.encoder attributes are read below), so full
    # unpickling is required; recent PyTorch releases default to weights_only=True.
    # NOTE(security): unpickling executes arbitrary code — only load checkpoints
    # you produced yourself or otherwise trust.
    # NOTE(review): `pretrainconfig` is not defined on this branch.
    pretrainmodel = torch.load(f'fine-tuning_mse/models/{model_file}', map_location='cpu',
                               weights_only=False)
else:
    # Fail here with a clear message instead of a NameError in a later cell.
    raise ValueError(f"Unknown model_type {model_type!r}; expected 'pretrained' or 'fine-tuned'")

In [None]:
# Wrap only the embedding layers and the encoder of the loaded checkpoint;
# the decoder is not needed to produce cell embeddings.
model = scF_Ccl(
    scf_token_emb=pretrainmodel.token_emb,
    scf_pos_emb=pretrainmodel.pos_emb,
    scf_encoder=pretrainmodel.encoder,
)

In [None]:
# Preferred GPU order. The first usable entry hosts the model, which
# nn.DataParallel requires to be on device_ids[0].
preferred_gpu_ids = [2, 3, 0, 1]
# Drop ids that do not exist on this machine so the cell also runs on hosts
# with fewer than four GPUs (or none).
usable_gpu_ids = [gpu_id for gpu_id in preferred_gpu_ids if gpu_id < torch.cuda.device_count()]

if usable_gpu_ids:
    device = torch.device(f"cuda:{usable_gpu_ids[0]}")
    # Guard against re-running the cell, which would otherwise nest wrappers.
    if not isinstance(model, nn.DataParallel):
        model = nn.DataParallel(model, device_ids=usable_gpu_ids)
else:
    device = torch.device("cpu")

# Assigning the result also keeps the (very long) module repr out of the cell output.
model = model.to(device)

In [None]:
# Filler for padded slots of the gathered gene-id tensor
# (passed to load.gatherData in Tokenizer.tokenize_data).
pad_token = 19266
# Filler for padded slots of the gathered expression-value tensor
# (passed to load.gatherData in Tokenizer.tokenize_data).
pad_value = 103

# NOTE(review): not referenced by any cell visible in this notebook — confirm before removing.
tokenizer_dir = '../tokenizer'

In [None]:
# Slide identifier; it selects the input .h5ad file and is reused in the
# name of the saved UMAP figure.
slide = '10X001'
# Read the dataset for that slide from the relative ../datasets directory.
adata = sc.read_h5ad(f'../datasets/{slide}_niche.h5ad')

In [None]:
# Inference settings: rows per forward pass, and whether evaluate() runs the
# model under CUDA autocast.
batch_size, amp = 50, True

# Build the padded encoder inputs (values, padding mask, gene ids) for the slide.
tokenizer = Tokenizer(adata=adata, pad_value=pad_value, pad_token=pad_token)
tokenizer.prepare_data()
data = tokenizer.tokenize_data()

In [None]:
# Embed every row of the tokenised data and attach the result to the AnnData
# object so that scanpy / scib can pick it up through obsm['cell_embeddings'].
cell_embeddings = evaluate(model=model, data=data, batch_size=batch_size)
tokenizer.adata.obsm['cell_embeddings'] = cell_embeddings.numpy()

In [None]:
# Translate the 1-based cell-type codes in obs['cell_type'] into the names
# stored in uns['cell_types_list'] (code k -> k-th entry of the list).
cell_type_names = list(tokenizer.adata.uns['cell_types_list'])
code_to_name = dict(enumerate(cell_type_names, start=1))

cell_type_column = tokenizer.adata.obs['cell_type']
# Re-running the mapping on already-translated names would turn every label
# into NaN, so only map while the column still holds codes.
already_named = cell_type_column.dropna().isin(cell_type_names).all()
if not already_named:
    cell_type_column = cell_type_column.map(code_to_name)

tokenizer.adata.obs['cell_type'] = cell_type_column.astype('category')

In [None]:
# Neighbourhood graph and UMAP layout computed on the model embeddings.
sc.pp.neighbors(tokenizer.adata, use_rep='cell_embeddings')
sc.tl.umap(tokenizer.adata)

# show=False makes scanpy hand back the Axes instead of displaying the figure
# itself. With the default settings the figure is shown inside sc.pl.umap, so a
# later plt.gca()/plt.savefig would act on a fresh, empty figure.
# Optional styling keywords: title, frameon, legend_loc, legend_fontsize.
ax = sc.pl.umap(tokenizer.adata, color='cell_type', show=False)
ax.set_xlabel('')
ax.set_ylabel('')

# NOTE(review): the file name always says "pt" (pretrained), whatever model_type is.
figure_path = f'figures/cell_clustering/umap_scf_pt_{slide}.pdf'
# Create the output folder on first use so savefig does not fail on a fresh checkout.
os.makedirs(os.path.dirname(figure_path), exist_ok=True)
ax.figure.savefig(figure_path, bbox_inches='tight')

In [None]:
# scib is imported here rather than in the first cell, so it is only needed
# when this evaluation cell is run.
import scib

# Every cell is placed in the same batch (value 1) to satisfy batch_key.
tokenizer.adata.obs['batch'] = [1]*tokenizer.adata.shape[0]

# The same AnnData object is passed both positionally and as adata_int;
# the metrics are computed on obsm['cell_embeddings'] against obs['cell_type'].
# Enabled here: silhouette, NMI and ARI.
results = scib.metrics.metrics(
        tokenizer.adata,
        adata_int=tokenizer.adata,
        label_key='cell_type',
        batch_key='batch',
        embed='cell_embeddings',
        silhouette_=True,
        nmi_=True,
        ari_=True,
    )

# NOTE(review): presumably `results` is a DataFrame whose single column is
# labelled 0 — confirm against the installed scib version.
result_dict = results[0].to_dict()