In [22]:
import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
import warnings
warnings.filterwarnings('ignore')
import pickle
import time
import copy
from tqdm import tqdm
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import auc, roc_curve, confusion_matrix
import anndata
import scanpy as sc
import torch

from torch import nn, Tensor
from scfoundation import load

In [None]:
class Tokenizer():
    """Turn an AnnData object into padded scFoundation encoder inputs.

    ``prepare_data`` builds a dense (cells x (genes + 2)) matrix: each cell's expression values
    followed by two extra columns, log10(T) and log10(S), where S is the cell's total expression
    and T is set equal to S. ``prepare_train_and_valid_data`` then keeps only the non-zero entries
    of each cell (via ``load.gatherData``) and returns dicts of tensors for the classifier.
    """

    def __init__(self, tokenizer_dir, adata, pad_value, pad_token):
        # tokenizer_dir is stored for reference only; no method of this class reads it.
        self.tokenizer_dir = tokenizer_dir
        self.adata = adata
        # pad_value fills padded expression slots, pad_token fills padded gene-id slots.
        self.pad_value = pad_value
        self.pad_token = pad_token

    def prepare_data(self):
        """Build ``self.data`` (expression + log10 T/S columns) and ``self.celltype_labels``.

        ``self.celltype_labels`` are the integer codes of the categorical ``obs['cell_type']``.
        """
        X = self.adata.X
        # Densify: sparse matrices via .toarray(); the old `.A` shortcut is deprecated/removed in
        # recent SciPy and never existed on dense ndarrays.
        gexpr_feature = X.toarray() if hasattr(X, 'toarray') else np.asarray(X)
        S = gexpr_feature.sum(1)
        T = S
        # NOTE: a cell with zero total expression gets log10(0) = -inf in both columns.
        TS = np.stack([np.log10(T), np.log10(S)], axis=1)
        self.data = np.concatenate([gexpr_feature, TS], axis=1)

        self.celltype_labels = self.adata.obs['cell_type'].cat.codes.values

    def _gather(self, index):
        """Build the model-input dict (values, padding, gene_ids, celltype_labels) for cells at ``index``."""
        celltype_labels = torch.from_numpy(self.celltype_labels[index]).long()
        values = torch.from_numpy(self.data[index]).float()
        # column position doubles as the gene id
        gene_ids = torch.arange(values.shape[1]).repeat(values.shape[0], 1)

        # Only non-zero entries are kept (an appended log10 column that is exactly 0 is dropped too).
        # load.gatherData presumably packs the selected entries and pads with the given value,
        # returning (packed, padding_mask) — TODO confirm against scfoundation.load.
        nonzero = values != 0
        gathered_values, padding = load.gatherData(values, nonzero, self.pad_value)
        gathered_gene_ids, _ = load.gatherData(gene_ids, nonzero, self.pad_token)
        return {'values': gathered_values, 'padding': padding, 'gene_ids': gathered_gene_ids,
                'celltype_labels': celltype_labels}

    def prepare_train_and_valid_data(self, train_index, valid_index):
        """Return ``(train_data, valid_data)`` input dicts for the given cell positions."""
        return self._gather(train_index), self._gather(valid_index)


def downsample(adata, n, min_count=400, seed=0):
    """Keep only abundant cell types and subsample each of them to at most ``n`` cells.

    Cell types with more than ``min_count`` cells are kept. Each is reduced to ``n`` randomly
    chosen cells, or keeps all its cells when it has fewer than ``n``. Cell types with
    ``min_count`` or fewer cells are dropped entirely.

    Sampling is reproducible: each cell type is drawn from a fresh ``RandomState(seed)``. This
    gives the same draws as seeding NumPy's global RNG with ``seed`` before each type, but leaves
    the global random state untouched.

    Args:
        adata: AnnData-like object with an ``obs['cell_type']`` column; supports ``len`` and
            boolean-mask indexing.
        n: maximum number of cells kept per retained cell type.
        min_count: cell types need strictly more than this many cells to be retained.
        seed: seed for the per-cell-type sampler.

    Returns:
        ``adata`` indexed by the boolean mask of kept cells (for AnnData, a view).
    """
    cell_types = adata.obs['cell_type']
    celltype_counts = cell_types.value_counts()
    to_downsample = celltype_counts[celltype_counts > min_count].index.tolist()

    keep_cells = np.zeros(len(adata), dtype=bool)
    for ct in to_downsample:
        ct_idx = np.where(cell_types == ct)[0]
        rng = np.random.RandomState(seed)
        # Cap at the type's size: choice(replace=False) raises when size > population, which
        # happens for any type with min_count < count < n.
        sampled_idx = rng.choice(ct_idx, size=min(n, len(ct_idx)), replace=False)
        keep_cells[sampled_idx] = True

    return adata[keep_cells]

In [3]:
class scF_Cls(nn.Module):
    """Cell-type classifier built on scFoundation's pretrained encoder.

    Reuses the given token embedding, positional embedding and encoder modules. It max-pools the
    encoder output over each cell's non-padded positions and feeds the pooled vector to a
    ``ClsDecoder`` MLP head.
    """

    def __init__(
            self,
            scf_token_emb,
            scf_pos_emb,
            scf_encoder,
            d_model: int,
            n_cls: int = 2,
            nlayers_cls: int = 3,
    ):
        super().__init__()

        # pretrained scFoundation components (the passed-in modules themselves, not copies)
        self.token_emb = scf_token_emb
        self.pos_emb = scf_pos_emb
        self.encoder = scf_encoder
        # classification head: d_model -> n_cls logits
        self.cls_decoder = ClsDecoder(d_model, n_cls, nlayers=nlayers_cls)

    def forward(self, gene_values, padding_label, gene_ids):
        """Return class logits for a batch of cells.

        Args:
            gene_values: gathered expression values, one row per cell.
            padding_label: boolean mask, True at padded positions (these are excluded from pooling).
            gene_ids: gene id for each position of ``gene_values``.
        """
        hidden = self.token_emb(gene_values.unsqueeze(2), output_weight=0)
        hidden += self.pos_emb(gene_ids)
        hidden = self.encoder(hidden, padding_mask=padding_label)

        # per-cell max over its real (non-padded) positions; cells have different lengths
        pooled = torch.stack([
            cell_hidden[~cell_pad].max(dim=0).values
            for cell_hidden, cell_pad in zip(hidden, padding_label)
        ])
        return self.cls_decoder(pooled)


class ClsDecoder(nn.Module):
    """
    MLP head for cell classification.

    Stacks ``nlayers - 1`` hidden blocks of Linear(d_model, d_model) -> activation -> LayerNorm(d_model),
    followed by a final Linear(d_model, n_cls). The output is raw logits (no softmax).
    """

    def __init__(
        self,
        d_model: int,
        n_cls: int,
        nlayers: int = 3,
        activation: callable = nn.ReLU,
    ):
        super().__init__()
        # Hidden blocks are created in Linear, activation, LayerNorm order; the parameter
        # initialisation and state_dict keys (_decoder.<i>.*) depend on that order.
        self._decoder = nn.ModuleList()
        for _ in range(nlayers - 1):
            self._decoder.extend([nn.Linear(d_model, d_model), activation(), nn.LayerNorm(d_model)])
        self.out_layer = nn.Linear(d_model, n_cls)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: Tensor, shape [batch_size, embsize]

        Returns:
            Tensor of shape [batch_size, n_cls] with class logits.
        """
        hidden = x
        for module in self._decoder:
            hidden = module(hidden)
        return self.out_layer(hidden)

In [5]:
# Classification loss shared by train() and evaluate().
criterion_cls = nn.CrossEntropyLoss()

def train(model: nn.Module, train_data, batch_size, scaler, optimizer, scheduler, log_interval, epoch,
          split=None, min_tail: int = 12) -> None:
    """
    Train the model for one epoch.

    Args:
        model: classifier called as ``model(values, padding, gene_ids)`` returning logits.
        train_data: dict with 'values', 'padding', 'gene_ids' and 'celltype_labels' tensors
            (as produced by Tokenizer.prepare_train_and_valid_data).
        batch_size: samples per optimisation step.
        scaler: AMP gradient scaler used for backward/step.
        optimizer: optimizer stepped once per batch.
        scheduler: only read here, to log the current learning rate.
        log_interval: print running loss/accuracy every ``log_interval`` batches.
        epoch: epoch number, used for logging only.
        split: CV split label for the log line; if None, a notebook-global ``split`` is used when
            one exists, otherwise '-'.
        min_tail: a trailing batch with at most this many samples is skipped (default keeps the
            original behaviour of dropping a final batch of <= 12 samples).
    """
    amp = True
    # Read the device from the model instead of a notebook-global `device`; for a
    # DataParallel model this is where its parameters live (device_ids[0]).
    device = next(model.parameters()).device
    use_amp = amp and device.type == 'cuda'
    if split is None:
        split = globals().get('split', '-')

    model.train()
    total_cls = 0.0
    total_error = 0.0

    start_time = time.time()

    train_values = train_data['values']
    train_padding = train_data['padding']
    train_gene_ids = train_data['gene_ids']
    train_celltype_labels = train_data['celltype_labels']

    n_samples = len(train_values)
    num_batches = int(np.ceil(n_samples / batch_size))
    for k in range(0, n_samples, batch_size):
        batch = k // batch_size + 1
        if k + min_tail >= n_samples:
            break
        end = k + batch_size
        with torch.autocast(device_type='cuda', enabled=use_amp):
            output = model(train_values[k:end].to(device),
                           train_padding[k:end].to(device),
                           train_gene_ids[k:end].to(device))

            batch_celltype_labels = train_celltype_labels[k:end].to(device)
            loss_cls = criterion_cls(output, batch_celltype_labels)

        n_correct = (output.argmax(1) == batch_celltype_labels).sum().item()
        error_rate_cls = 1 - n_correct / batch_celltype_labels.size(0)

        model.zero_grad()
        scaler.scale(loss_cls).backward()
        scaler.unscale_(optimizer)
        scaler.step(optimizer)
        scaler.update()

        total_cls += loss_cls.item()
        total_error += error_rate_cls
        if batch % log_interval == 0:
            lr = scheduler.get_last_lr()[0]
            sec_per_batch = (time.time() - start_time) / log_interval
            cur_cls = total_cls / log_interval
            cur_error = total_error / log_interval
            print(f"| Split {split} | {epoch:3d} | {batch:3d}/{num_batches:3d} batches | "
                f"lr {lr:05.8f} | sec/batch {sec_per_batch:5.1f} | "
                f"cls {cur_cls:5.5f} | "
                f"acc {1-cur_error:1.5f} | "
            )
            total_cls = 0
            total_error = 0
            start_time = time.time()

def evaluate(model: nn.Module, valid_data, batch_size) -> tuple:
    """Evaluate ``model`` on every sample of ``valid_data`` without gradient tracking.

    Args:
        model: classifier called as ``model(values, padding, gene_ids)`` returning logits.
        valid_data: dict with 'values', 'padding', 'gene_ids' and 'celltype_labels' tensors.
        batch_size: samples per forward pass.

    Returns:
        (val_loss, val_acc, true_labels, predicted_labels):
        - val_loss: sample-weighted mean cross-entropy (Python float);
        - val_acc: fraction of correctly classified samples;
        - true_labels / predicted_labels: numpy arrays of class codes, in input order.
        When ``valid_data`` is empty, loss and accuracy are NaN and the arrays are empty.
    """
    amp = True
    # Read the device from the model instead of a notebook-global `device`.
    device = next(model.parameters()).device
    use_amp = amp and device.type == 'cuda'

    model.eval()
    total_cls = 0.0
    total_correct = 0
    total_num = 0

    true_labels = []
    predicted_labels = []

    valid_values = valid_data['values']
    valid_padding = valid_data['padding']
    valid_gene_ids = valid_data['gene_ids']
    valid_celltype_labels = valid_data['celltype_labels']

    n_samples = len(valid_values)
    with torch.no_grad():
        # Every batch is evaluated, including a short final one, so the metrics and returned
        # predictions cover all samples (and a tiny set does not end up with zero batches).
        for k in tqdm(range(0, n_samples, batch_size)):
            end = k + batch_size
            with torch.autocast(device_type='cuda', enabled=use_amp):
                output = model(valid_values[k:end].to(device),
                               valid_padding[k:end].to(device),
                               valid_gene_ids[k:end].to(device))

                batch_celltype_labels = valid_celltype_labels[k:end].to(device)
                loss_cls = criterion_cls(output, batch_celltype_labels)

            batch_pred = output.argmax(1)
            n_batch = batch_celltype_labels.size(0)
            total_correct += (batch_pred == batch_celltype_labels).sum().item()
            total_num += n_batch
            # .item(): accumulate a Python float instead of a device tensor
            total_cls += loss_cls.item() * n_batch

            true_labels.append(batch_celltype_labels.to('cpu'))
            predicted_labels.append(batch_pred.to('cpu'))

    if total_num == 0:
        return float('nan'), float('nan'), np.array([], dtype=np.int64), np.array([], dtype=np.int64)

    val_acc = total_correct / total_num
    val_loss = total_cls / total_num

    true_labels = torch.cat(true_labels).numpy()
    predicted_labels = torch.cat(predicted_labels).numpy()

    return val_loss, val_acc, true_labels, predicted_labels


def train_and_evaluate(model, train_data, valid_data, batch_size, epochs,
                       lr=1e-4, schedule_ratio=0.9, schedule_interval=1, log_interval=10, amp=True):
    """Train ``model`` for ``epochs`` epochs, evaluating on ``valid_data`` after each one.

    Uses Adam (eps 1e-4 when ``amp`` else 1e-8) with a StepLR schedule. The schedule multiplies
    the learning rate by ``schedule_ratio`` every ``schedule_interval`` epochs.

    Note: the best epoch is chosen on ``valid_data`` itself, so the returned accuracy is an
    optimistically biased estimate whenever ``valid_data`` is also the evaluation split.

    Returns:
        (best_val_accuracy, best_model, best_true_labels, best_predicted_labels). ``best_model``
        is a deep copy of ``model`` taken at the best epoch (None only when ``epochs < 1``).
    """
    # `amp` here only sets Adam's eps and enables the GradScaler; train()/evaluate() pick their
    # own autocast settings.
    optimizer = torch.optim.Adam(
            model.parameters(), lr=lr, eps=1e-4 if amp else 1e-8
    )
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, schedule_interval, gamma=schedule_ratio
    )
    scaler = torch.cuda.amp.GradScaler(enabled=amp)

    best_val_accuracy = 0
    best_model = None
    best_true_labels = None
    best_predicted_labels = None

    for epoch in range(1, epochs + 1):
        train(model, train_data, batch_size, scaler, optimizer, scheduler, log_interval, epoch)

        val_loss, val_accuracy, true_labels, predicted_labels = evaluate(model, valid_data, batch_size)
        print("-" * 89)
        print(
            f"| end of epoch {epoch:3d} | "
            f"valid accuracy: {val_accuracy:1.4f} | "
            f"valid loss: {val_loss:5.4f} | "
        )
        print("-" * 89)

        # `best_model is None` guarantees a model is returned even if accuracy never exceeds 0.
        if best_model is None or val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            # snapshot of the weights at this epoch; `model` keeps training
            best_model = copy.deepcopy(model)
            print(f"Best model with accuracy {best_val_accuracy:1.4f}")

            best_true_labels = true_labels
            best_predicted_labels = predicted_labels

        scheduler.step()

    return best_val_accuracy, best_model, best_true_labels, best_predicted_labels

In [6]:
class scFoundation(nn.Module):
    """Full scFoundation encoder-decoder model assembled from pre-built components.

    Not instantiated anywhere in the code shown here. It is presumably kept so that ``torch.load``
    can unpickle fine-tuned checkpoints saved as whole modules of this class — TODO confirm
    against how the fine-tuning_mse checkpoints were written.
    """
    def __init__(
            self,
            scf_token_emb,
            scf_pos_emb,
            scf_encoder,
            scf_decoder,
            scf_decoder_embed,
            scf_norm,
            scf_to_final,
    ):
        super(scFoundation, self).__init__()

        # encoder
        self.token_emb = scf_token_emb
        self.pos_emb = scf_pos_emb

        # transformer encoder
        self.encoder = scf_encoder

        ##### decoder
        self.decoder = scf_decoder
        self.decoder_embed = scf_decoder_embed
        self.norm = scf_norm
        self.to_final = scf_to_final

    def forward(self, x, padding_label, encoder_position_gene_ids, encoder_labels, decoder_data,
                decoder_position_gene_ids, decoder_data_padding_labels, **kwargs):
        """Encode the gathered inputs, scatter them into the decoder sequence, decode and project.

        Args:
            x: gathered encoder input values; ``padding_label`` is True at their padded positions.
            encoder_position_gene_ids: gene ids for the positions of ``x``.
            encoder_labels: mask over decoder positions; the True entries receive the encoder
                outputs of the non-padded positions of ``x`` (in row-major order).
            decoder_data: decoder input values, with ids ``decoder_position_gene_ids`` and padding
                mask ``decoder_data_padding_labels``.
            **kwargs: accepted and ignored.

        Returns:
            ``to_final`` output with dimension 2 squeezed (assumes that dimension has size 1 —
            TODO confirm).
        """

        # token and positional embedding
        x = self.token_emb(torch.unsqueeze(x, 2), output_weight = 0)

        position_emb = self.pos_emb(encoder_position_gene_ids)
        x += position_emb
        x = self.encoder(x, padding_mask=padding_label)

        # embed the decoder sequence, then overwrite the positions flagged by encoder_labels
        # with the encoder outputs; this requires as many True entries as non-padded x positions
        decoder_data = self.token_emb(torch.unsqueeze(decoder_data, 2))
        position_emb = self.pos_emb(decoder_position_gene_ids)
        batch_idx, gen_idx = (encoder_labels == True).nonzero(as_tuple=True)
        decoder_data[batch_idx, gen_idx] = x[~padding_label].to(decoder_data.dtype)

        decoder_data += position_emb

        decoder_data = self.decoder_embed(decoder_data)
        x = self.decoder(decoder_data, padding_mask=decoder_data_padding_labels)

        x = self.norm(x)
        # final projection to the output values
        x = self.to_final(x)
        return x.squeeze(2)

In [None]:
def initialize_model(model_file, n_cls, d_model=768, nlayers_cls=3, trainable_blocks=slice(10, 12),
                     pretrained_ckpt='scfoundation/models/models.ckpt',
                     finetuned_dir='fine-tuning_mse/models'):
    """Build an ``scF_Cls`` classifier on an scFoundation backbone and freeze most of it.

    Args:
        model_file: None to start from the pretrained checkpoint ``pretrained_ckpt``; otherwise
            the file name of a fine-tuned model under ``finetuned_dir``.
        n_cls: number of output classes.
        d_model: embedding width of the backbone (must match the loaded model).
        nlayers_cls: number of layers of the classification head.
        trainable_blocks: slice of ``encoder.transformer_encoder`` blocks left trainable.
        pretrained_ckpt: path of the pretrained scFoundation checkpoint.
        finetuned_dir: directory holding fine-tuned checkpoints.

    Returns:
        The classifier. Only the selected encoder blocks and the classification head have
        ``requires_grad=True``.
    """
    if model_file is None:
        # load pretrained model (its config is not needed here)
        pretrainmodel, _ = load.load_model_frommmf(pretrained_ckpt)
    else:
        # The fine-tuned checkpoint is a whole pickled module (unwrapped via `.module`), not a
        # state_dict, so weights_only=False is required (PyTorch 2.6 changed the default to True).
        # Unpickling can run arbitrary code: only load trusted files.
        pretrainmodel = torch.load(f'{finetuned_dir}/{model_file}', map_location='cpu', weights_only=False)
        pretrainmodel = pretrainmodel.module

    model = scF_Cls(pretrainmodel.token_emb,
            pretrainmodel.pos_emb,
            pretrainmodel.encoder,
            d_model = d_model,
            n_cls = n_cls,
            nlayers_cls = nlayers_cls
            )

    def count_trainable_params():
        # trainable parameter count, deduplicated by data pointer
        return sum({p.data_ptr(): p.numel() for p in model.parameters() if p.requires_grad}.values())

    pre_freeze_param_count = count_trainable_params()
    # Freeze everything, then unfreeze the selected encoder blocks and the classification head.
    model.requires_grad_(False)
    model.encoder.transformer_encoder[trainable_blocks].requires_grad_(True)
    model.cls_decoder.requires_grad_(True)
    post_freeze_param_count = count_trainable_params()
    print(f"Total Pre freeze Params {pre_freeze_param_count}")
    print(f"Total Post freeze Params {post_freeze_param_count}")

    return model

In [None]:
# Padding used by Tokenizer: pad_value fills padded expression slots, pad_token padded gene-id slots.
pad_value, pad_token = 103, 19266
# Stored on the Tokenizer but not read by it.
tokenizer_dir = '../stformer/tokenizer/'

# None -> start from the pretrained scFoundation checkpoint; otherwise a file name of a
# fine-tuned model (see initialize_model).
model_file = None

# 5-fold stratified cross-validation

In [None]:
dataset = 'pancreas_cosmx'
adata = sc.read_h5ad('../datasets/pancreas_cosmx_niche.h5ad')

# obs['cell_type'] is assumed to hold 1-based indices into uns['cell_types_list']; replace them by names.
index_to_name = dict(enumerate(adata.uns['cell_types_list'], start=1))
adata.obs['cell_type'] = adata.obs['cell_type'].map(index_to_name).astype('category')

# balance cell types at 490 cells each (see downsample for which types are retained)
adata = downsample(adata, 490)

In [None]:
tokenizer = Tokenizer(tokenizer_dir, adata, pad_value, pad_token)
tokenizer.prepare_data()

epochs = 20
batch_size = 100

best_val_accuracy_list = []
best_true_labels_list = []
best_predicted_labels_list = []

n_splits = 5
skf = StratifiedKFold(n_splits=n_splits, random_state=0, shuffle=True)
# One output per category, so every cat.codes value is a valid class index for the loss.
n_cls = len(tokenizer.adata.obs['cell_type'].cat.categories)

# `split` is intentionally a notebook global: train() reports it in its progress log.
for split, (train_index, valid_index) in enumerate(
        skf.split(tokenizer.adata.obs_names.values, tokenizer.celltype_labels), start=1):
    print(f"Cross-validate on dataset {dataset} - split {split}")
    train_data, valid_data = tokenizer.prepare_train_and_valid_data(train_index, valid_index)

    model = initialize_model(model_file, n_cls)
    model = torch.nn.DataParallel(model, device_ids=[1, 3, 0, 2])
    # DataParallel expects the module on device_ids[0]
    device = torch.device("cuda:1")
    model.to(device)

    best_val_accuracy, best_model, best_true_labels, best_predicted_labels = train_and_evaluate(
        model, train_data, valid_data, batch_size, epochs)
    best_val_accuracy_list.append(best_val_accuracy)
    best_true_labels_list.append(best_true_labels)
    best_predicted_labels_list.append(best_predicted_labels)

    # Release this fold's GPU models before the next fold loads a fresh backbone.
    del model, best_model
    torch.cuda.empty_cache()

In [15]:
# Mean ± population standard deviation of the best validation accuracy across CV folds.
print(f"Mean accuracy: {np.mean(best_val_accuracy_list)} +/- {np.std(best_val_accuracy_list)}")

In [None]:
# `with` closes the file deterministically; makedirs avoids failing on a fresh checkout.
os.makedirs('figures/cell_classification', exist_ok=True)
with open('figures/cell_classification/cv_accuracy_list_scf.pkl', 'wb') as fh:
    pickle.dump(best_val_accuracy_list, fh)

In [29]:
# One row-normalised confusion matrix (each true class sums to 1) per CV fold.
cm_list = [
    confusion_matrix(best_true_labels_list[fold], best_predicted_labels_list[fold], normalize='true')
    for fold in range(n_splits)
]

In [None]:
# `with` closes the file deterministically; makedirs avoids failing on a fresh checkout.
os.makedirs('figures/cell_classification', exist_ok=True)
with open('figures/cell_classification/cm_list_scf.pkl', 'wb') as fh:
    pickle.dump(cm_list, fh)

# Leave-one-FOV-out evaluation (FOV 52 held out for testing)

In [None]:
dataset = 'pancreas_cosmx'
adata = sc.read_h5ad('../datasets/pancreas_cosmx_niche.h5ad')

# Expose the niche features computed for max_niche_cell_num under generic obsm keys.
max_niche_cell_num = 20
for niche_key in ('niche_celltypes', 'niche_composition', 'niche_ligands_expression'):
    adata.obsm[niche_key] = adata.obsm[f'{niche_key}_niche{max_niche_cell_num}']

# obs['cell_type'] is assumed to hold 1-based indices into uns['cell_types_list']; replace them by names.
index_to_name = dict(enumerate(adata.uns['cell_types_list'], start=1))
adata.obs['cell_type'] = adata.obs['cell_type'].map(index_to_name).astype('category')

# training pool: every FOV except the held-out FOV 52, balanced at 300 cells per cell type
adata1 = downsample(adata[adata.obs['fov'] != 52], 300)

In [None]:
from sklearn.model_selection import train_test_split

# Train/validate only on the non-held-out FOVs (adata1: FOV 52 removed, downsampled).
# Tokenizing the full `adata` here would leak FOV 52, the test FOV, into training and skip the downsampling.
tokenizer = Tokenizer(tokenizer_dir, adata1, pad_value, pad_token)
tokenizer.prepare_data()

epochs = 20
batch_size = 100

split = 0  # reported by train() in its progress log
all_index = np.arange(tokenizer.adata.shape[0])
train_index, valid_index = train_test_split(all_index, test_size=0.2, random_state=42, stratify=tokenizer.celltype_labels)

train_data, valid_data = tokenizer.prepare_train_and_valid_data(train_index, valid_index)

# one output per category of the training label space, so every cat.codes value is a valid class index
model = initialize_model(model_file, len(tokenizer.adata.obs['cell_type'].cat.categories))
model = torch.nn.DataParallel(model, device_ids=[1, 3, 0, 2])
device = torch.device("cuda:1")
model.to(device)

best_val_accuracy, best_model, best_true_labels, best_predicted_labels = train_and_evaluate(model, train_data, valid_data, batch_size, epochs)

In [None]:
adata2 = adata[adata.obs['fov'] == 52]

# Keep only test cells whose type exists in the training pool, and encode them in the training
# label space (adata1's categories). Subsetting AnnData typically drops unused categories, so
# adata3's own category codes would not line up with the classifier's output indices.
train_categories = adata1.obs['cell_type'].cat.categories
adata3 = adata2[adata2.obs['cell_type'].isin(train_categories)].copy()
adata3.obs['cell_type'] = adata3.obs['cell_type'].cat.set_categories(train_categories)

tokenizer = Tokenizer(tokenizer_dir, adata3, pad_value, pad_token)
tokenizer.prepare_data()

split = 0
test_index = np.arange(adata3.shape[0])
# The same cells are passed for both halves; only the second dict is needed.
_, test_data = tokenizer.prepare_train_and_valid_data(test_index, test_index)

test_loss, test_accuracy, true_labels, predicted_labels = evaluate(best_model, test_data, batch_size)
test_accuracy

In [None]:
# Row-normalised confusion matrix on the held-out FOV 52.
cm = confusion_matrix(true_labels, predicted_labels, normalize='true')
os.makedirs('figures/cell_classification', exist_ok=True)
with open('figures/cell_classification/cm_fov52_scf.pkl', 'wb') as fh:
    pickle.dump(cm, fh)

In [None]:
# Materialise adata2 (a view of adata) so it can be annotated and written independently.
adata2 = adata2.copy()
adata2.obs['predicted_celltypes'] = [''] * adata2.shape[0]

# Class code -> cell-type name, using the categories the test labels were encoded with.
predicted_celltypes = list(adata3.obs['cell_type'].cat.categories[predicted_labels])
# Predictions come back in adata3 order; if evaluate() skipped a short trailing batch,
# only the first len(predicted_celltypes) cells have one.
labelled_cells = adata3.obs_names[:len(predicted_celltypes)]
# A single .loc call: chained assignment (obs[col].loc[...] = ...) may silently write to a copy.
adata2.obs.loc[labelled_cells, 'predicted_celltypes'] = predicted_celltypes

In [None]:
# Persist FOV 52 with its predicted cell types.
fov52_output_path = 'figures/cell_classification/adata_fov52_scf.h5ad'
adata2.write(fov52_output_path)