In [1]:
import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
from typing import Dict
import time
import warnings
warnings.filterwarnings('ignore')
import copy
import pickle
from tqdm import tqdm
import scanpy as sc
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import auc, roc_curve, confusion_matrix
import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn
import matplotlib.pyplot as plt
import seaborn as sns

from stformer import logger
from stformer.tokenizer import GeneVocab
from stformer.tokenizer import tokenize_and_pad_batch_2
from stformer.model import TransformerModel

In [2]:
class Tokenizer():
    """Prepare tokenized center-cell / niche inputs and DataLoaders from an AnnData object.

    On construction the expression matrices, gene ids, ligand ids and integer
    cell-type labels are loaded (see ``load_data``). The usual call sequence is
    ``tokenize_data`` -> ``prepare_data`` -> ``prepare_dataloader``.

    Args:
        tokenizer_dir: Directory holding ``OS_scRNA_gene_index.19264.tsv`` and
            ``ligand_database.csv``; a trailing path separator is optional.
        adata: Object exposing ``X``, ``obsm['niche_ligands_expression']``,
            ``obsm['niche_composition']`` and a categorical ``obs['cell_type']``.
        vocab: Callable mapping a list of gene names to ids; it is also indexed
            with ``pad_token`` to obtain the padding id.
        pad_value: Forwarded as ``pad_value`` to ``tokenize_and_pad_batch_2``.
        pad_token: Vocabulary token whose id is used for padding.
    """

    # How many times the ligand id list is tiled when building ``ligand_ids``.
    # NOTE(review): presumably the maximum number of niche slots - confirm.
    LIGAND_TILE_COUNT = 50
    # How many consecutive bias entries each niche-composition value expands to.
    # NOTE(review): presumably the number of ligand genes per niche slot - confirm.
    BIAS_REPEAT = 986

    def __init__(self, tokenizer_dir, adata, vocab, pad_value, pad_token):
        self.tokenizer_dir = tokenizer_dir
        self.adata = adata
        self.vocab = vocab
        self.pad_value = pad_value
        self.pad_token = pad_token
        self.load_data()

    @staticmethod
    def _to_dense(matrix):
        """Return ``matrix`` as a dense ndarray (sparse inputs via ``toarray``)."""
        if hasattr(matrix, 'toarray'):
            return matrix.toarray()
        return np.asarray(matrix)

    def load_data(self):
        """Load dense matrices, gene/ligand ids and cell-type codes onto ``self``."""
        self.expression_matrix = self._to_dense(self.adata.X)
        self.niche_ligands_expression = self._to_dense(self.adata.obsm['niche_ligands_expression'])
        self.niche_composition = self._to_dense(self.adata.obsm['niche_composition'])

        # os.path.join makes both reads work with or without a trailing slash
        # on tokenizer_dir.
        gene_index_path = os.path.join(self.tokenizer_dir, 'OS_scRNA_gene_index.19264.tsv')
        gene_list_df = pd.read_csv(gene_index_path, header=0, delimiter='\t')
        gene_list = list(gene_list_df['gene_name'])
        self.gene_ids = np.array(self.vocab(gene_list), dtype=int)

        ligand_db_path = os.path.join(self.tokenizer_dir, 'ligand_database.csv')
        ligand_database = pd.read_csv(ligand_db_path, header=0, index_col=0)
        # Keep ligands whose row sum exceeds 1, then re-order them to follow the
        # gene index file.
        ligand_symbol = ligand_database[ligand_database.sum(1) > 1].index.values
        ligand_symbol = gene_list_df.loc[gene_list_df['gene_name'].isin(ligand_symbol), 'gene_name'].values
        self.ligand_ids = np.array(self.vocab(ligand_symbol.tolist()) * self.LIGAND_TILE_COUNT, dtype=int)

        self.celltype_labels = self.adata.obs['cell_type'].cat.codes.values

    def _tokenize_subset(self, index, biases):
        """Tokenize and pad the rows selected by ``index``."""
        return tokenize_and_pad_batch_2(
            self.expression_matrix[index],
            self.niche_ligands_expression[index],
            biases[index],
            self.gene_ids,
            self.ligand_ids,
            pad_id = self.vocab[self.pad_token],
            pad_value = self.pad_value,
        )

    def tokenize_data(self, train_index, valid_index):
        """Tokenize the train and validation rows.

        Results are stored in ``self.tokenized_train`` / ``self.tokenized_valid``
        and the matching label tensors in ``self.celltype_labels_train`` /
        ``self.celltype_labels_valid``.
        """
        # Each composition value p becomes BIAS_REPEAT consecutive copies of
        # log(p); a zero composition therefore yields -inf.
        log_composition = np.log(self.niche_composition).astype(np.float64)
        biases = np.repeat(log_composition, self.BIAS_REPEAT, axis=1)

        self.celltype_labels_train = torch.from_numpy(self.celltype_labels[train_index]).long()
        self.celltype_labels_valid = torch.from_numpy(self.celltype_labels[valid_index]).long()

        tokenized_train = self._tokenize_subset(train_index, biases)
        tokenized_valid = self._tokenize_subset(valid_index, biases)

        logger.info(
            f"train set number of samples: {tokenized_train['center_genes'].shape[0]}, "
            f"\n\t feature length of center cell: {tokenized_train['center_genes'].shape[1]}"
            f"\n\t feature length of niche cells: {tokenized_train['niche_genes'].shape[1]}"
        )
        logger.info(
            f"valid set number of samples: {tokenized_valid['center_genes'].shape[0]}, "
            f"\n\t feature length of center cell: {tokenized_valid['center_genes'].shape[1]}"
            f"\n\t feature length of niche cells: {tokenized_valid['niche_genes'].shape[1]}"
        )

        self.tokenized_train = tokenized_train
        self.tokenized_valid = tokenized_valid

    @staticmethod
    def _as_model_inputs(tokenized, labels):
        """Rename tokenizer outputs to the keys the train/evaluate loops read."""
        return {
            "center_gene_ids": tokenized["center_genes"],
            "input_center_values": tokenized["center_values"],
            "niche_gene_ids": tokenized["niche_genes"],
            "input_niche_values": tokenized["niche_values"],
            "cross_attn_bias": tokenized["cross_attn_bias"],
            "celltype_labels": labels,
        }

    def prepare_data(self):
        """Build ``self.train_data_pt`` / ``self.valid_data_pt`` from the tokenized data."""
        self.train_data_pt = self._as_model_inputs(self.tokenized_train, self.celltype_labels_train)
        self.valid_data_pt = self._as_model_inputs(self.tokenized_valid, self.celltype_labels_valid)

    @staticmethod
    def _make_loader(data_pt, batch_size):
        """Create a sequential (unshuffled) DataLoader over ``data_pt``."""
        try:
            available_cpus = len(os.sched_getaffinity(0))
        except AttributeError:
            # sched_getaffinity only exists on some platforms (e.g. Linux).
            available_cpus = os.cpu_count() or 1
        return DataLoader(
            dataset=SeqDataset(data_pt),
            batch_size=batch_size,
            # NOTE(review): the training loader is not shuffled either; kept as
            # in the original - confirm this is intended.
            shuffle=False,
            drop_last=False,
            num_workers=min(available_cpus, batch_size // 2),
            pin_memory=True,
        )

    def prepare_dataloader(self, batch_size):
        """Return ``(train_loader, valid_loader)`` built from the prepared data."""
        train_loader = self._make_loader(self.train_data_pt, batch_size)
        valid_loader = self._make_loader(self.valid_data_pt, batch_size)
        return train_loader, valid_loader


class SeqDataset(Dataset):
    """Map-style dataset over a dict of tensors that share their first dimension."""

    def __init__(self, data: Dict[str, torch.Tensor]):
        # The dict is stored by reference; nothing is copied.
        self.data = data

    def __len__(self):
        # The sample count is read from the leading dimension of the center gene ids.
        n_samples = self.data["center_gene_ids"].shape[0]
        return n_samples

    def __getitem__(self, idx):
        # Index every tensor with the same idx and keep the original keys.
        sample = {}
        for key, tensor in self.data.items():
            sample[key] = tensor[idx]
        return sample
    
    
def downsample(adata, n, min_cells=400):
    """Keep at most ``n`` randomly chosen cells of every sufficiently large cell type.

    Cell types with ``min_cells`` or fewer cells are dropped entirely. Every
    remaining type is sampled without replacement; a type that has more than
    ``min_cells`` but fewer than ``n`` cells keeps all of its cells (previously
    this case raised ``ValueError`` from ``np.random.choice``).

    Args:
        adata: Object supporting ``len()``, boolean-mask indexing and
            ``obs['cell_type']``.
        n: Maximum number of cells to keep per cell type.
        min_cells: A cell type is kept only if it has strictly more cells than this.

    Returns:
        ``adata`` indexed by the boolean keep mask.

    Note:
        The global NumPy RNG is re-seeded with 0 before each draw, so the
        selection is deterministic and the global RNG state is modified.
    """
    cell_types = adata.obs['cell_type']
    celltype_counts = cell_types.value_counts()
    to_downsample = celltype_counts[celltype_counts > min_cells].index.tolist()

    keep_cells = np.zeros(len(adata), dtype=bool)
    for ct in cell_types.unique():
        if ct not in to_downsample:
            continue
        candidate_idx = np.where(cell_types == ct)[0]
        np.random.seed(0)
        sampled_idx = np.random.choice(
            candidate_idx,
            # Never request more cells than exist for this type.
            size=min(n, len(candidate_idx)),
            replace=False,
        )
        keep_cells[sampled_idx] = True

    return adata[keep_cells]

In [None]:
# Classification loss shared by train() and evaluate().
criterion_cls = nn.CrossEntropyLoss()

def train(model: nn.Module, loader: DataLoader, scaler, optimizer, scheduler, log_interval, epoch, mode) -> None:
    """
    Train the model for one epoch.

    Reads the notebook-level globals ``device``, ``vocab``, ``pad_token``,
    ``split`` (log message only) and ``logger``.

    Args:
        model: Model called with niche/center ids, values, padding masks and
            the cross-attention bias; must return a dict with ``"cls_output"``.
        loader: Yields dicts with the keys produced by ``Tokenizer.prepare_data``.
        scaler: Gradient scaler used for mixed-precision training.
        optimizer: Optimizer stepped once per processed batch.
        scheduler: Only queried for the current learning rate when logging.
        log_interval: Log whenever the batch index is a positive multiple of this.
        epoch: Epoch number, used in the log message only.
        mode: ``'sp'`` masks only padded niche tokens; ``'sc'`` masks every
            niche token.

    Raises:
        ValueError: If ``mode`` is neither ``'sp'`` nor ``'sc'``.
    """
    if mode not in ('sp', 'sc'):
        raise ValueError(f"mode must be 'sp' or 'sc', got {mode!r}")

    amp = True

    model.train()
    total_cls = 0.0
    total_error = 0.0
    # Batches accumulated since the last log line. The first window spans batch
    # indices 0..log_interval and skipped batches add nothing, so dividing by
    # log_interval would be wrong.
    batches_in_window = 0

    start_time = time.time()

    num_batches = len(loader)
    pad_id = vocab[pad_token]

    for batch, batch_data in enumerate(loader):
        center_gene_ids = batch_data["center_gene_ids"].to(device)
        # Batches with fewer than 8 samples are skipped entirely.
        if center_gene_ids.size(0) < 8:
            continue
        input_center_values = batch_data["input_center_values"].to(device)
        niche_gene_ids = batch_data["niche_gene_ids"].to(device)
        input_niche_values = batch_data["input_niche_values"].to(device)
        cross_attn_bias = batch_data["cross_attn_bias"].to(device)
        celltype_labels = batch_data["celltype_labels"].to(device)

        if mode == 'sp':
            encoder_src_key_padding_mask = niche_gene_ids.eq(pad_id)
        else:  # mode == 'sc'
            # ones_like already allocates on niche_gene_ids' device.
            encoder_src_key_padding_mask = torch.ones_like(niche_gene_ids, dtype=torch.bool)
        decoder_src_key_padding_mask = center_gene_ids.eq(pad_id)

        with torch.cuda.amp.autocast(enabled=amp):
            output_dict = model(
                    niche_gene_ids,
                    input_niche_values,
                    encoder_src_key_padding_mask,
                    center_gene_ids,
                    input_center_values,
                    decoder_src_key_padding_mask,
                    cross_attn_bias,
                    CLS = True,
                )

            loss_cls = criterion_cls(output_dict["cls_output"], celltype_labels)

            n_correct = (output_dict["cls_output"].argmax(1) == celltype_labels).sum().item()
            error_rate_cls = 1 - n_correct / celltype_labels.size(0)

        model.zero_grad()
        scaler.scale(loss_cls).backward()
        # Unscale before clipping so the threshold applies to the true gradients.
        scaler.unscale_(optimizer)
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings("always")
            torch.nn.utils.clip_grad_norm_(
                model.parameters(),
                1.0,
                error_if_nonfinite=not scaler.is_enabled(),
            )
            if len(w) > 0:
                logger.warning(
                    f"Found infinite gradient. This may be caused by the gradient "
                    f"scaler. The current scale is {scaler.get_scale()}. This warning "
                    "can be ignored if no longer occurs after autoscaling of the scaler."
                )
        scaler.step(optimizer)
        scaler.update()

        total_cls += loss_cls.item()
        total_error += error_rate_cls
        batches_in_window += 1

        if batch % log_interval == 0 and batch > 0:
            lr = scheduler.get_last_lr()[0]
            sec_per_batch = (time.time() - start_time) / batches_in_window
            cur_cls = total_cls / batches_in_window
            cur_error = total_error / batches_in_window
            logger.info(
                f"| Split {split} | {epoch:3d} | {batch:3d}/{num_batches:3d} batches | "
                f"lr {lr:05.8f} | sec/batch {sec_per_batch:5.1f} | "
                f"cls {cur_cls:5.5f} | "
                f"acc {1-cur_error:1.5f} | "
            )
            total_cls = 0
            total_error = 0
            batches_in_window = 0
            start_time = time.time()

def evaluate(model: nn.Module, loader: DataLoader, mode) -> tuple:
    """
    Evaluate the model on the evaluation data.

    Reads the notebook-level globals ``device``, ``vocab``, ``pad_token`` and
    ``criterion_cls``. Batches with fewer than 8 samples are skipped, so their
    samples contribute neither to the metrics nor to the returned label arrays.

    Args:
        model: Model returning a dict with ``"cls_output"``.
        loader: Yields dicts with the keys produced by ``Tokenizer.prepare_data``.
        mode: ``'sp'`` masks only padded niche tokens; ``'sc'`` masks every
            niche token.

    Returns:
        ``(val_loss, val_accuracy, true_labels, predicted_labels)``: the
        sample-weighted mean loss and the accuracy as Python floats, and the
        true / predicted labels as NumPy arrays in loader order.

    Raises:
        ValueError: If ``mode`` is invalid or no batch was large enough to evaluate.
    """
    if mode not in ('sp', 'sc'):
        raise ValueError(f"mode must be 'sp' or 'sc', got {mode!r}")

    amp = True

    model.eval()
    total_cls = 0.0
    total_correct = 0
    total_num = 0

    true_labels = []
    predicted_labels = []

    pad_id = vocab[pad_token]

    with torch.no_grad():
        for batch_data in tqdm(loader):
            center_gene_ids = batch_data["center_gene_ids"].to(device)
            if center_gene_ids.size(0) < 8:
                continue
            input_center_values = batch_data["input_center_values"].to(device)
            niche_gene_ids = batch_data["niche_gene_ids"].to(device)
            input_niche_values = batch_data["input_niche_values"].to(device)
            cross_attn_bias = batch_data["cross_attn_bias"].to(device)
            celltype_labels = batch_data["celltype_labels"].to(device)

            if mode == 'sp':
                encoder_src_key_padding_mask = niche_gene_ids.eq(pad_id)
            else:  # mode == 'sc'
                encoder_src_key_padding_mask = torch.ones_like(niche_gene_ids, dtype=torch.bool)
            decoder_src_key_padding_mask = center_gene_ids.eq(pad_id)

            with torch.cuda.amp.autocast(enabled=amp):
                output_dict = model(
                        niche_gene_ids,
                        input_niche_values,
                        encoder_src_key_padding_mask,
                        center_gene_ids,
                        input_center_values,
                        decoder_src_key_padding_mask,
                        cross_attn_bias,
                        CLS = True,
                    )

                loss_cls = criterion_cls(output_dict["cls_output"], celltype_labels)

            batch_predictions = output_dict["cls_output"].argmax(1)
            batch_size_seen = celltype_labels.size(0)

            total_correct += (batch_predictions == celltype_labels).sum().item()
            total_num += batch_size_seen
            # .item() keeps the running loss a Python float rather than a
            # device tensor.
            total_cls += loss_cls.item() * batch_size_seen

            true_labels.append(celltype_labels.to('cpu'))
            predicted_labels.append(batch_predictions.to('cpu'))

    if total_num == 0:
        raise ValueError("evaluate() processed no samples: every batch had fewer than 8 samples")

    val_accuracy = total_correct / total_num
    val_loss = total_cls / total_num

    true_labels = torch.cat(true_labels).numpy()
    predicted_labels = torch.cat(predicted_labels).numpy()

    return val_loss, val_accuracy, true_labels, predicted_labels

def train_and_evaluate(model, train_loader, valid_loader, epochs, mode):
    """Train for ``epochs`` epochs and keep the snapshot with the best validation accuracy.

    Uses Adam (lr 1e-4) with a per-epoch StepLR decay of 0.9 and a CUDA gradient
    scaler. After every epoch the model is evaluated on ``valid_loader``; when
    the accuracy improves, a deep copy of the model and that epoch's label
    arrays are retained. The first evaluated epoch is always recorded, so the
    returned model is ``None`` only when ``epochs`` is less than 1.

    Args:
        model: Model to optimize in place.
        train_loader: Loader passed to ``train``.
        valid_loader: Loader passed to ``evaluate``.
        epochs: Number of epochs to run.
        mode: Forwarded to ``train`` and ``evaluate`` (``'sp'`` or ``'sc'``).

    Returns:
        ``(best_val_accuracy, best_model, best_true_labels, best_predicted_labels)``.
    """
    lr = 1e-4
    amp = True
    schedule_ratio = 0.9
    schedule_interval = 1
    log_interval = 10

    optimizer = torch.optim.Adam(
            model.parameters(), lr=lr, eps=1e-4 if amp else 1e-8
    )
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, schedule_interval, gamma=schedule_ratio
    )
    scaler = torch.cuda.amp.GradScaler(enabled=amp)

    best_val_accuracy = 0
    best_model = None
    best_true_labels = None
    best_predicted_labels = None

    for epoch in range(1, epochs+1):
        train(model, train_loader, scaler, optimizer, scheduler, log_interval, epoch, mode)

        val_loss, val_accuracy, true_labels, predicted_labels = evaluate(model, valid_loader, mode)
        logger.info("-" * 89)
        logger.info(
            f"| end of epoch {epoch:3d} | "
            f"valid accuracy: {val_accuracy:1.4f} | "
            f"valid loss: {val_loss:5.4f} | "
        )
        logger.info("-" * 89)

        # "best_model is None" guarantees a snapshot exists even if the
        # accuracy never rises above 0.
        if best_model is None or val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            logger.info(f"Best model with accuracy {best_val_accuracy:1.4f}")

            best_model = copy.deepcopy(model)
            best_true_labels = true_labels
            best_predicted_labels = predicted_labels

        scheduler.step()

    return best_val_accuracy, best_model, best_true_labels, best_predicted_labels

In [None]:
from scfoundation import load

def initialize_model(model_file, n_cls):
    """Build a TransformerModel, load pretrained weights and freeze most of it.

    The token and position embeddings are deep-copied from the scFoundation
    checkpoint. Weights from ``model_file`` are then loaded, except for keys
    containing ``cls_decoder`` or ``gcl_decoder``. Finally every parameter is
    frozen apart from ``transformer_decoder.layers[4:6]`` and ``cls_decoder``.

    Reads the notebook-level globals ``embsize``, ``nhead``, ``d_hid``,
    ``nlayers``, ``dropout`` and ``cell_emb_style``.

    Args:
        model_file: Path to a checkpoint that ``torch.load`` restores as an
            object exposing ``state_dict()``.
        n_cls: Number of output classes for the classification decoder.

    Returns:
        The partially frozen model.
    """
    pretrainmodel, pretrainconfig = load.load_model_frommmf('scfoundation/models/models.ckpt')

    # Each of the four embedding slots receives its own deep copy.
    embedding_kwargs = dict(
        scfoundation_token_emb1 = copy.deepcopy(pretrainmodel.token_emb),
        scfoundation_token_emb2 = copy.deepcopy(pretrainmodel.token_emb),
        scfoundation_pos_emb1 = copy.deepcopy(pretrainmodel.pos_emb),
        scfoundation_pos_emb2 = copy.deepcopy(pretrainmodel.pos_emb),
    )
    model = TransformerModel(
        embsize,
        nhead,
        d_hid,
        nlayers,
        do_cls = True,
        nlayers_cls = 3,
        n_cls = n_cls,
        dropout = dropout,
        cell_emb_style = cell_emb_style,
        **embedding_kwargs,
    )

    pt_model = torch.load(model_file, map_location='cpu')

    # Start from the fresh model's state and overwrite it with every pretrained
    # entry that does not belong to one of the excluded decoders.
    excluded_parts = ('cls_decoder', 'gcl_decoder')
    merged_state = model.state_dict()
    for key, weight in pt_model.state_dict().items():
        if any(part in key for part in excluded_parts):
            continue
        merged_state[key] = weight
    model.load_state_dict(merged_state)

    def _trainable_param_count():
        # Keyed by data pointer so parameters sharing storage are counted once.
        sizes_by_ptr = {p.data_ptr(): p.numel() for p in model.parameters() if p.requires_grad}
        return sum(sizes_by_ptr.values())

    pre_freeze_param_count = _trainable_param_count()

    for para in model.parameters():
        para.requires_grad = False
    trainable_modules = (model.transformer_decoder.layers[4:6], model.cls_decoder)
    for module in trainable_modules:
        for para in module.parameters():
            para.requires_grad = True

    post_freeze_param_count = _trainable_param_count()

    logger.info(f"Total Pre freeze Params {(pre_freeze_param_count )}")
    logger.info(f"Total Post freeze Params {(post_freeze_param_count )}")

    return model

In [None]:
# Model hyperparameters; initialize_model() reads these as globals when it
# constructs TransformerModel.
embsize = 768
d_hid = 3072
nhead = 12
nlayers = 6
dropout = 0.1
cell_emb_style = 'max-pool'
# NOTE(review): CLS is not referenced by the visible code; train()/evaluate()
# pass CLS=True literally.
CLS = True
# 'sp' masks only padded niche tokens; 'sc' masks every niche token
# (see train()/evaluate()).
mode = 'sp'

# Padding configuration; train()/evaluate() read vocab and pad_token as globals.
pad_token = "<pad>"
pad_value = 103
tokenizer_dir = '../stformer/tokenizer/'
vocab_file = tokenizer_dir + "scfoundation_gene_vocab.json"
vocab = GeneVocab.from_file(vocab_file)
vocab.append_token(pad_token)
# Unknown tokens resolve to the padding id.
vocab.set_default_index(vocab[pad_token])

# Pretrained checkpoint loaded by initialize_model().
model_file = '../pretraining/models/model_4.1M.ckpt'

# Cross validation

In [None]:
dataset = 'pancreas_cosmx'
adata = sc.read_h5ad('../datasets/pancreas_cosmx_niche.h5ad')

# obs['cell_type'] holds 1-based positions into uns['cell_types_list'];
# translate them to names and store the column as categorical.
code_to_name = dict(enumerate(adata.uns['cell_types_list'], start=1))
adata.obs['cell_type'] = adata.obs['cell_type'].map(code_to_name).astype('category')

# Keep 490 cells of every sufficiently large cell type (see downsample()).
adata = downsample(adata, 490)

In [None]:
# 5-fold stratified cross-validation for several niche sizes. After each niche
# size the accumulated accuracies and confusion matrices are re-written to disk,
# so partial results survive an interrupted run.
best_val_accuracy_dict = {}
cm_dict = {}

epochs = 20
batch_size = 100
n_splits = 5

output_dir = 'figures/cell_classification'
os.makedirs(output_dir, exist_ok=True)

for max_niche_cell_num in [5, 10, 15, 20, 25, 30, 35]:

    # Point the generic obsm keys read by Tokenizer at this niche size.
    adata.obsm['niche_celltypes'] = adata.obsm[f'niche_celltypes_niche{max_niche_cell_num}']
    adata.obsm['niche_composition'] = adata.obsm[f'niche_composition_niche{max_niche_cell_num}']
    adata.obsm['niche_ligands_expression'] = adata.obsm[f'niche_ligands_expression_niche{max_niche_cell_num}']

    tokenizer = Tokenizer(tokenizer_dir, adata, vocab, pad_value, pad_token)
    # A fixed label set keeps every fold's confusion matrix the same shape and order.
    class_labels = np.unique(tokenizer.celltype_labels)

    best_val_accuracy_list = []
    best_true_labels_list = []
    best_predicted_labels_list = []

    skf = StratifiedKFold(n_splits=n_splits, random_state=0, shuffle=True)

    # `split` must remain a module-level name: train() reads it for its log line.
    for split, (train_index, valid_index) in enumerate(
        skf.split(tokenizer.adata.obs_names.values, tokenizer.celltype_labels), start=1
    ):
        logger.info(f"Cross-validate on dataset {dataset} - split {split}")
        tokenizer.tokenize_data(train_index, valid_index)
        tokenizer.prepare_data()
        train_loader, valid_loader = tokenizer.prepare_dataloader(batch_size)

        model = initialize_model(model_file, len(set(tokenizer.celltype_labels)))
        model = nn.DataParallel(model, device_ids = [0, 3, 1, 2])
        device = torch.device("cuda:0")
        model.to(device)

        best_val_accuracy, best_model, best_true_labels, best_predicted_labels = train_and_evaluate(model, train_loader, valid_loader, epochs, mode)
        best_val_accuracy_list.append(best_val_accuracy)
        best_true_labels_list.append(best_true_labels)
        best_predicted_labels_list.append(best_predicted_labels)

    cm_dict[max_niche_cell_num] = [
        confusion_matrix(fold_true, fold_pred, labels=class_labels, normalize='true')
        for fold_true, fold_pred in zip(best_true_labels_list, best_predicted_labels_list)
    ]

    best_val_accuracy_dict[max_niche_cell_num] = best_val_accuracy_list

    # Context managers close each file handle deterministically.
    with open(f'{output_dir}/cv_accuracy_dict_4M.pkl', 'wb') as fh:
        pickle.dump(best_val_accuracy_dict, fh)
    with open(f'{output_dir}/cm_dict_4M.pkl', 'wb') as fh:
        pickle.dump(cm_dict, fh)
    with open(f'{output_dir}/celltype_labels.pkl', 'wb') as fh:
        pickle.dump(adata.obs['cell_type'].cat.categories.tolist(), fh)

# Leave out one FOV

In [3]:
dataset = 'pancreas_cosmx'
adata = sc.read_h5ad('../datasets/pancreas_cosmx_niche.h5ad')

# Point the generic obsm keys read by Tokenizer at the niche-size-20 variants.
max_niche_cell_num = 20
for obsm_key in ('niche_celltypes', 'niche_composition', 'niche_ligands_expression'):
    adata.obsm[obsm_key] = adata.obsm[f'{obsm_key}_niche{max_niche_cell_num}']

# obs['cell_type'] holds 1-based positions into uns['cell_types_list'];
# translate them to names and store the column as categorical.
code_to_name = dict(enumerate(adata.uns['cell_types_list'], start=1))
adata.obs['cell_type'] = adata.obs['cell_type'].map(code_to_name).astype('category')

# Training pool: every FOV except 52, downsampled to 300 cells per kept cell type.
adata1 = downsample(adata[adata.obs['fov'] != 52], 300)

In [None]:
from sklearn.model_selection import train_test_split

# Train on the downsampled non-FOV-52 cells with a stratified 80/20
# train/validation split.
tokenizer = Tokenizer(tokenizer_dir, adata1, vocab, pad_value, pad_token)

epochs = 20
batch_size = 100

# train() reads the global `split` for its log line.
split = 0
# Tokenizer stores the AnnData as `.adata`; the original `tokenizer.adata1`
# raised AttributeError.
train_index, valid_index = train_test_split(
    np.arange(tokenizer.adata.shape[0]),
    test_size=0.2,
    random_state=42,
    stratify=tokenizer.celltype_labels,
)

tokenizer.tokenize_data(train_index, valid_index)
tokenizer.prepare_data()
train_loader, valid_loader = tokenizer.prepare_dataloader(batch_size)

model = initialize_model(model_file, len(set(tokenizer.celltype_labels)))
model = nn.DataParallel(model, device_ids = [2, 3, 1, 0])
device = torch.device("cuda:2")
model.to(device)

best_val_accuracy, best_model, best_true_labels, best_predicted_labels = train_and_evaluate(model, train_loader, valid_loader, epochs, mode)

In [None]:
# Held-out test set: cells from FOV 52 whose cell type was seen during training.
adata2 = adata[adata.obs['fov']==52]
train_categories = adata1.obs['cell_type'].cat.categories
adata3 = adata2[adata2.obs['cell_type'].isin(train_categories)].copy()
# Tokenizer derives integer labels from `cat.codes`. Re-using the training
# categories keeps the test label codes in the same index space as the model
# outputs, even if FOV 52 lacks some training cell types.
adata3.obs['cell_type'] = adata3.obs['cell_type'].cat.set_categories(train_categories)

tokenizer = Tokenizer(tokenizer_dir, adata3, vocab, pad_value, pad_token)

# The same indices are passed as "train" and "valid"; only one of the two
# identical sequential loaders is needed.
test_index = np.arange(adata3.shape[0])
tokenizer.tokenize_data(test_index, test_index)
tokenizer.prepare_data()
_, test_loader = tokenizer.prepare_dataloader(batch_size)

test_loss, test_accuracy, true_labels, predicted_labels = evaluate(best_model, test_loader, mode)
test_accuracy

In [None]:
# Row-normalized confusion matrix for the held-out FOV; the context manager
# closes the file handle once the pickle is written.
cm = confusion_matrix(true_labels, predicted_labels, normalize='true')
with open('figures/cell_classification/cm_fov52_4M-niche20.pkl', 'wb') as fh:
    pickle.dump(cm, fh)

In [None]:
# Default every FOV-52 cell to '' and fill in predictions for the evaluated cells.
adata2.obs['predicted_celltypes'] = ['']*adata2.shape[0]
# The model's output indices refer to the training label space, so decode them
# with the training categories.
predicted_celltypes = [adata1.obs['cell_type'].cat.categories[i] for i in predicted_labels]
# evaluate() skips a final batch of fewer than 8 samples and the loader is
# sequential, so the predictions cover the leading cells of adata3.
evaluated_cells = adata3.obs_names[:len(predicted_celltypes)]
# A single .loc call writes into the frame itself; chained
# `obs[col].loc[...] = ...` may assign to a temporary instead.
adata2.obs.loc[evaluated_cells, 'predicted_celltypes'] = predicted_celltypes

In [None]:
# Persist the FOV-52 AnnData, including the obs['predicted_celltypes'] column
# added above. The target directory must already exist.
adata2.write('figures/cell_classification/adata_fov52_4M-niche20.h5ad')