In [None]:
import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
from typing import Dict
import warnings
warnings.filterwarnings('ignore')
from tqdm import tqdm
import matplotlib.pyplot as plt
import scanpy as sc
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn

from stformer import logger
from stformer.tokenizer import GeneVocab
from stformer.tokenizer import tokenize_and_pad_batch_2

In [None]:
class Tokenizer():
    """Turn a niche-annotated AnnData object into padded model inputs and a DataLoader.

    Call order: constructor (loads matrices and token ids) -> ``tokenize_data()`` ->
    ``prepare_data()`` -> ``prepare_dataloader(batch_size)``; each step reads
    attributes set by the previous one.

    Args:
        tokenizer_dir: directory holding ``OS_scRNA_gene_index.19264.tsv`` and
            ``ligand_database.csv``; a trailing path separator is optional.
        adata: AnnData-like object providing ``X``, ``obsm['niche_ligands_expression']``
            and ``obsm['niche_composition']`` (scipy-sparse or dense).
        vocab: gene vocabulary; called with a list of gene names to obtain ids and
            indexed with ``pad_token`` to obtain the padding id.
        pad_value: padding value forwarded to ``tokenize_and_pad_batch_2``.
        pad_token: padding token string.
    """

    # Width of the bias block emitted for each column of niche_composition
    # (that column's log value is repeated this many times).
    # NOTE(review): presumably equals the number of selected ligands so that bias
    # columns line up with niche-gene columns — confirm.
    BIAS_BLOCK_WIDTH = 986
    # Number of times the ligand id list is tiled to form the niche gene sequence.
    # NOTE(review): presumably equals niche_composition.shape[1] — confirm.
    N_LIGAND_TILES = 25

    def __init__(self, tokenizer_dir, adata, vocab, pad_value, pad_token):
        self.tokenizer_dir = tokenizer_dir
        self.adata = adata
        self.vocab = vocab
        self.pad_value = pad_value
        self.pad_token = pad_token
        self.load_data()

    @staticmethod
    def _to_dense(matrix):
        """Return ``matrix`` as a dense ndarray (scipy-sparse, np.matrix or array-like input)."""
        # ``.A`` is deprecated/removed on scipy sparse matrices in recent SciPy releases and
        # does not exist on plain ndarrays, so prefer ``toarray()`` when it is available.
        if hasattr(matrix, 'toarray'):
            return matrix.toarray()
        return np.asarray(matrix)

    def load_data(self):
        """Load dense input matrices plus the gene and (tiled) ligand token ids."""
        self.expression_matrix = self._to_dense(self.adata.X)
        self.niche_ligands_expression = self._to_dense(self.adata.obsm['niche_ligands_expression'])
        self.niche_composition = self._to_dense(self.adata.obsm['niche_composition'])

        # os.path.join works whether or not tokenizer_dir ends with a path separator.
        gene_list_df = pd.read_csv(
            os.path.join(self.tokenizer_dir, 'OS_scRNA_gene_index.19264.tsv'),
            header=0,
            delimiter='\t',
        )
        gene_list = list(gene_list_df['gene_name'])
        self.gene_ids = np.array(self.vocab(gene_list), dtype=int)

        ligand_database = pd.read_csv(
            os.path.join(self.tokenizer_dir, 'ligand_database.csv'),
            header=0,
            index_col=0,
        )
        # Keep genes whose row sum across the database columns exceeds 1, restricted to
        # the gene index and listed in gene-index order.
        ligand_symbol = ligand_database[ligand_database.sum(1) > 1].index.values
        ligand_symbol = gene_list_df.loc[gene_list_df['gene_name'].isin(ligand_symbol), 'gene_name'].values
        # List multiplication tiles the whole ligand id list N_LIGAND_TILES times.
        self.ligand_ids = np.array(self.vocab(ligand_symbol.tolist()) * self.N_LIGAND_TILES, dtype=int)

    def _niche_attention_bias(self):
        """Cross-attention bias: log of each niche_composition entry, repeated
        BIAS_BLOCK_WIDTH times per column (columns kept in order), as float64."""
        # log(0) -> -inf is kept as-is; only the divide-by-zero warning is silenced.
        with np.errstate(divide='ignore'):
            log_composition = np.log(self.niche_composition)
        return np.repeat(log_composition, self.BIAS_BLOCK_WIDTH, axis=1).astype(np.float64)

    def tokenize_data(self):
        """Pad/tokenize center and niche features; stores the result in ``self.tokenized_data``."""
        tokenized_data = tokenize_and_pad_batch_2(
            self.expression_matrix,
            self.niche_ligands_expression,
            self._niche_attention_bias(),
            self.gene_ids,
            self.ligand_ids,
            pad_id = self.vocab[self.pad_token],
            pad_value = self.pad_value,
        )

        logger.info(
            f"tokenize sample number: {tokenized_data['center_genes'].shape[0]}, "
            f"\n\t feature length of center cell: {tokenized_data['center_genes'].shape[1]}"
            f"\n\t feature length of niche cells: {tokenized_data['niche_genes'].shape[1]}"
        )

        self.tokenized_data = tokenized_data

    def prepare_data(self):
        """Rename tokenizer outputs into the key layout consumed by ``evaluate``."""
        self.data_pt = {
            "center_gene_ids": self.tokenized_data["center_genes"],
            "input_center_values": self.tokenized_data["center_values"],
            "niche_gene_ids": self.tokenized_data["niche_genes"],
            "input_niche_values": self.tokenized_data["niche_values"],
            "cross_attn_bias": self.tokenized_data["cross_attn_bias"],
        }

    def prepare_dataloader(self, batch_size):
        """Build an ordered (non-shuffled, no dropped tail) DataLoader over ``self.data_pt``."""
        # os.sched_getaffinity only exists on some platforms (e.g. Linux); fall back to
        # the total CPU count elsewhere.
        if hasattr(os, 'sched_getaffinity'):
            n_cpus = len(os.sched_getaffinity(0))
        else:
            n_cpus = os.cpu_count() or 1
        data_loader = DataLoader(
            dataset=SeqDataset(self.data_pt),
            batch_size=batch_size,
            shuffle=False,  # batches follow dataset index order
            drop_last=False,
            num_workers=min(n_cpus, batch_size // 2),
            pin_memory=True,
        )
        return data_loader


class SeqDataset(Dataset):
    """Map-style dataset over a dict of row-indexable arrays/tensors.

    Sample ``i`` is a dict with the same keys as ``data``, each holding ``value[i]``.
    """

    def __init__(self, data: Dict[str, torch.Tensor]):
        # Kept by reference; rows are sliced lazily in __getitem__.
        self.data = data

    def __len__(self):
        # The sample count comes from "center_gene_ids"; the other entries are
        # assumed to share its first dimension.
        n_rows = self.data["center_gene_ids"].shape[0]
        return n_rows

    def __getitem__(self, idx):
        sample = {}
        for key, values in self.data.items():
            sample[key] = values[idx]
        return sample

In [None]:
def evaluate(model: nn.Module, loader: DataLoader, mode, device=None, pad_id=None, amp=None) -> torch.Tensor:
    """
    Run ``model`` over every batch of ``loader`` and collect the cell embeddings.

    Args:
        model: module called as ``model(niche_ids, niche_values, niche_padding_mask,
            center_ids, center_values, center_padding_mask, cross_attn_bias)``; its
            output dict must contain ``'cell_emb'``.
        loader: iterable of batch dicts with keys "center_gene_ids",
            "input_center_values", "niche_gene_ids", "input_niche_values",
            "cross_attn_bias".
        mode: ``'sp'`` masks only niche positions equal to the pad id;
            ``'sc'`` masks every niche position.
        device: device batches are moved to; defaults to the notebook-level ``device``.
        pad_id: padding token id; defaults to ``vocab[pad_token]`` from the notebook.
        amp: enable CUDA autocast; defaults to the notebook-level ``amp``.

    Returns:
        float32 CPU tensor of embeddings concatenated in loader order.

    Raises:
        ValueError: if ``mode`` is not 'sp' or 'sc', or the loader yields no batches.
    """
    if mode not in ('sp', 'sc'):
        raise ValueError(f"mode must be 'sp' or 'sc', got {mode!r}")

    # Fall back to the notebook-level globals that this function historically relied on.
    notebook_ns = globals()
    if device is None:
        device = notebook_ns['device']
    if pad_id is None:
        pad_id = notebook_ns['vocab'][notebook_ns['pad_token']]
    if amp is None:
        amp = notebook_ns['amp']

    model.eval()
    cell_embeddings = []

    with torch.no_grad():
        for batch_data in tqdm(loader):
            center_gene_ids = batch_data["center_gene_ids"].to(device)
            input_center_values = batch_data["input_center_values"].to(device)
            niche_gene_ids = batch_data["niche_gene_ids"].to(device)
            input_niche_values = batch_data["input_niche_values"].to(device)
            cross_attn_bias = batch_data["cross_attn_bias"].to(device)

            if mode == 'sp':
                encoder_src_key_padding_mask = niche_gene_ids.eq(pad_id)
            else:  # 'sc': hide every niche token
                encoder_src_key_padding_mask = torch.ones_like(niche_gene_ids, dtype=torch.bool)
            decoder_src_key_padding_mask = center_gene_ids.eq(pad_id)

            # torch.autocast(device_type='cuda') replaces the deprecated torch.cuda.amp.autocast.
            with torch.autocast(device_type='cuda', enabled=amp):
                output_dict = model(
                    niche_gene_ids,
                    input_niche_values,
                    encoder_src_key_padding_mask,
                    center_gene_ids,
                    input_center_values,
                    decoder_src_key_padding_mask,
                    cross_attn_bias,
                )
            # Under autocast the embedding may come back in half precision; store float32.
            cell_embeddings.append(output_dict['cell_emb'].float().cpu())

    if not cell_embeddings:
        raise ValueError("loader yielded no batches; nothing to embed")
    return torch.cat(cell_embeddings)

In [None]:
import copy

from stformer.model import TransformerModel
from tasks.scfoundation import load

# Architecture hyperparameters, passed positionally to TransformerModel as
# (embsize, nhead, d_hid, nlayers).
embsize = 768
d_hid = 3072
nhead = 12
nlayers = 6
dropout = 0.1
cell_emb_style = 'max-pool'

# Load the pretrained scFoundation model; its token and positional embedding modules
# are each deep-copied twice into the new model.
pretrainmodel, pretrainconfig = load.load_model_frommmf('scfoundation/models/models.ckpt')
scf_token_emb = pretrainmodel.token_emb
scf_pos_emb = pretrainmodel.pos_emb

model = TransformerModel(
    embsize, nhead, d_hid, nlayers,
    dropout=dropout,
    cell_emb_style=cell_emb_style,
    scfoundation_token_emb1=copy.deepcopy(scf_token_emb),
    scfoundation_token_emb2=copy.deepcopy(scf_token_emb),
    scfoundation_pos_emb1=copy.deepcopy(scf_pos_emb),
    scfoundation_pos_emb2=copy.deepcopy(scf_pos_emb),
)

# Drop the donor model once its embeddings have been copied.
# NOTE(review): the following cell rebinds `model` via torch.load, discarding this
# instance — confirm whether this construction is still needed.
del pretrainmodel, scf_token_emb, scf_pos_emb

In [None]:
# The checkpoint holds a pickled nn.Module (it is wrapped in DataParallel below), so
# torch.load needs weights_only=False: since PyTorch 2.6 the default (True) refuses to
# unpickle arbitrary objects. Unpickling executes code — only load trusted checkpoints.
# The TransformerModel class must be importable (imported in the previous cell).
model = torch.load('../pretraining/models/model_4.1M.ckpt', map_location='cpu', weights_only=False)

In [None]:
# Preferred GPU layout for this machine. DataParallel uses the first id as its output
# device, and the module's parameters must live there, so `device` follows gpu_ids[0].
PREFERRED_GPU_IDS = [2, 0, 3, 1]

# Unwrap first so re-running this cell does not nest DataParallel inside DataParallel.
if isinstance(model, nn.DataParallel):
    model = model.module

n_gpus = torch.cuda.device_count()
if n_gpus > max(PREFERRED_GPU_IDS):
    gpu_ids = PREFERRED_GPU_IDS
else:
    # Fewer GPUs than the preferred layout: use every visible GPU (possibly none).
    gpu_ids = list(range(n_gpus))

if gpu_ids:
    device = torch.device(f"cuda:{gpu_ids[0]}")
    model = nn.DataParallel(model, device_ids=gpu_ids)
else:
    # No CUDA device visible: run on CPU without DataParallel.
    device = torch.device("cpu")

model.to(device);  # trailing ';' suppresses the (very long) model repr

In [None]:
# Padding settings shared by the Tokenizer and evaluate() cells.
pad_token = "<pad>"
pad_value = 103
# 'sp' masks only padded niche tokens in evaluate(); 'sc' masks all niche tokens.
mode = 'sp'

tokenizer_dir = '../stformer/tokenizer/'
vocab_file = f"{tokenizer_dir}scfoundation_gene_vocab.json"

vocab = GeneVocab.from_file(vocab_file)
vocab.append_token(pad_token)
# Use the pad token's id as the vocabulary's default index
# (presumably returned for out-of-vocabulary lookups — confirm in GeneVocab).
vocab.set_default_index(vocab[pad_token])

In [None]:
# Slide identifier; also used in the saved figure name further down.
slide = '10X001'
adata_path = f'../datasets/{slide}_niche.h5ad'

adata = sc.read_h5ad(adata_path)
adata

In [None]:
batch_size = 50
amp = True  # read by evaluate() to enable mixed-precision autocast

tokenizer = Tokenizer(
    tokenizer_dir=tokenizer_dir,
    adata=adata,
    vocab=vocab,
    pad_value=pad_value,
    pad_token=pad_token,
)
tokenizer.tokenize_data()   # pad/tokenize center and niche features
tokenizer.prepare_data()    # rename outputs into the dataset key layout
data_loader = tokenizer.prepare_dataloader(batch_size=batch_size)

In [None]:
# One embedding row per sample, in data_loader order.
cell_embeddings = evaluate(model=model, loader=data_loader, mode=mode)

In [None]:
# The loader is built with shuffle=False and drop_last=False, so rows should match cells 1:1.
assert cell_embeddings.shape[0] == adata.n_obs, "embedding count does not match number of cells"
adata.obsm['cell_embeddings'] = cell_embeddings.numpy()

# obs['cell_type'] holds 1-based integer codes into uns['cell_types_list']. Keep the raw
# codes in their own column so re-running this cell maps from codes again instead of
# mapping already-translated names (which would turn every label into NaN).
if 'cell_type_code' not in adata.obs:
    adata.obs['cell_type_code'] = adata.obs['cell_type']
code_to_name = dict(enumerate(adata.uns['cell_types_list'], start=1))
adata.obs['cell_type'] = adata.obs['cell_type_code'].map(code_to_name).astype('category')
adata

In [None]:
sc.pp.neighbors(adata, use_rep='cell_embeddings')
sc.tl.umap(adata)

# show=False makes scanpy return the Axes instead of showing (and, in Jupyter, closing)
# the figure; with the default, plt.gca() would grab a fresh empty figure and the saved
# PDF would be blank. Optional styling: frameon=False, legend_loc=..., legend_fontsize=...
ax = sc.pl.umap(adata, title='', color='cell_type', show=False)
ax.set_xlabel('')
ax.set_ylabel('')

fig_dir = 'figures/cell_clustering'
os.makedirs(fig_dir, exist_ok=True)
ax.figure.savefig(f'{fig_dir}/umap_4M_pt_{slide}.pdf', bbox_inches='tight')

In [None]:
import scib

# scib needs a batch column; every cell on this single slide gets batch 1.
adata.obs['batch'] = [1] * adata.n_obs

metric_flags = dict(silhouette_=True, nmi_=True, ari_=True)
results = scib.metrics.metrics(
    adata,
    adata_int=adata,
    label_key='cell_type',
    batch_key='batch',
    embed="cell_embeddings",
    **metric_flags,
)

# Column 0 of the returned table holds the scores; convert it to a plain dict.
result_dict = results[0].to_dict()