In [1]:
%env R_HOME=/work/gpt/lib/R

env: R_HOME=/work/gpt/lib/R


In [2]:
### Imports
import copy
import gc
import json
import os
from pathlib import Path
import sys
import time
import traceback
from typing import List, Tuple, Dict, Union, Optional
import warnings

import torch
from anndata import AnnData
import scanpy as sc
import scvi
import numpy as np
import pandas as pd
import wandb
from scipy.sparse import issparse
import matplotlib.pyplot as plt
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from torchtext.vocab import Vocab
from torchtext._torchtext import (
    Vocab as VocabPybind,
)

from scgpt.tokenizer.gene_tokenizer import GeneVocab

sys.path.insert(0, "../")
import scgpt as scg
from scgpt.model import TransformerModel, AdversarialDiscriminator
from scgpt.tokenizer import tokenize_and_pad_batch, random_mask_value
from scgpt.loss import (
    masked_mse_loss,
    masked_relative_error,
    criterion_neg_log_bernoulli,
)
from scgpt.preprocess import Preprocessor
from scgpt import SubsetsBatchSampler
from scgpt.utils import set_seed, category_str2int, eval_scib_metrics

# Default size (4 x 4) for every scanpy figure drawn in this notebook.
sc.set_figure_params(figsize=(4, 4))
# Sets KMP_WARNINGS for this process (presumably silences OpenMP runtime
# warnings -- confirm).
os.environ["KMP_WARNINGS"] = "off"
# Suppress all Python warnings for the rest of the session. Code that needs to
# observe warnings has to re-enable them locally (train() does so while clipping).
warnings.filterwarnings('ignore')

Global seed set to 0
  IPython.display.set_matplotlib_formats(*ipython_format)


In [3]:
# wandb is already imported in the imports cell; the repeated import is harmless.
import wandb
# Authenticate against Weights & Biases. In this run the call prompted for an
# API key (see the cell output), so it blocks unattended execution unless
# credentials are already configured.
wandb.login()

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

  ········································


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/ucloud/.netrc


True

In [4]:
### Define functions
def _digitize(x: np.ndarray, bins: np.ndarray) -> np.ndarray:
    """
    Digitize the data into bins. This method spreads data uniformly when bins
    have same values.

    Args:

    x (:class:`np.ndarray`):
        The data to digitize.
    bins (:class:`np.ndarray`):
        The bins to use for digitization, in increasing order.

    Returns:

    :class:`np.ndarray`:
        The digitized data.
    """
    assert x.ndim == 1 and bins.ndim == 1

    left_digits = np.digitize(x, bins)
    right_difits = np.digitize(x, bins, right=True)

    rands = np.random.rand(len(x))  # uniform random numbers

    digits = rands * (right_difits - left_digits) + left_digits
    digits = np.ceil(digits).astype(np.int64)
    return digits
    
def prepare_data(
    sort_seq_batch=False,
) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
    """
    Build the per-epoch training and validation tensors with fresh random masks.

    The function reads notebook-level globals rather than taking them as
    arguments: ``tokenized_train`` / ``tokenized_valid`` (dicts with ``"genes"``
    and ``"values"``), ``train_batch_labels`` / ``valid_batch_labels`` (numpy
    arrays), ``mask_ratio``, ``mask_value``, ``pad_value`` and ``epoch`` (used
    only in the progress message).

    Args:
        sort_seq_batch: when true, the samples of each split are reordered by
            ascending batch label.

    Returns:
        ``(train_data_pt, valid_data_pt)``. Each dict holds ``"gene_ids"``,
        ``"values"`` (masked input values), ``"target_values"`` (unmasked
        values) and ``"batch_labels"`` (long tensor).
    """
    # Mask train first, then valid: both draws share one random stream, so the
    # order is kept fixed.
    masked_values_train = random_mask_value(
        tokenized_train["values"],
        mask_ratio=mask_ratio,
        mask_value=mask_value,
        pad_value=pad_value,
    )
    masked_values_valid = random_mask_value(
        tokenized_valid["values"],
        mask_ratio=mask_ratio,
        mask_value=mask_value,
        pad_value=pad_value,
    )
    # Ratio = masked positions / positions whose value differs from pad_value.
    print(
        f"random masking at epoch {epoch:3d}, ratio of masked values in train: ",
        f"{(masked_values_train == mask_value).sum() / (masked_values_train - pad_value).count_nonzero():.4f}",
    )

    def _assemble(tokenized, masked_values, batch_labels):
        # Pack one split into the dict layout consumed by SeqDataset.
        gene_tokens = tokenized["genes"]
        input_values = masked_values
        target_values = tokenized["values"]
        labels_pt = torch.from_numpy(batch_labels).long()

        if sort_seq_batch:
            # Same permutation for every tensor so rows stay aligned.
            order = np.argsort(batch_labels)
            gene_tokens = gene_tokens[order]
            input_values = input_values[order]
            target_values = target_values[order]
            labels_pt = labels_pt[order]

        return {
            "gene_ids": gene_tokens,
            "values": input_values,
            "target_values": target_values,
            "batch_labels": labels_pt,
        }

    train_data_pt = _assemble(tokenized_train, masked_values_train, train_batch_labels)
    valid_data_pt = _assemble(tokenized_valid, masked_values_valid, valid_batch_labels)

    return train_data_pt, valid_data_pt


# dataset
class SeqDataset(Dataset):
    """Dataset view over a dict of equally long, row-aligned tensors.

    Sample ``i`` is a dict with the same keys as ``data`` whose values are the
    ``i``-th entry of each tensor. The number of samples is taken from the
    ``"gene_ids"`` entry.
    """

    def __init__(self, data: Dict[str, torch.Tensor]):
        self.data = data

    def __len__(self):
        gene_ids = self.data["gene_ids"]
        return gene_ids.shape[0]

    def __getitem__(self, idx):
        sample = {}
        for key, tensor in self.data.items():
            sample[key] = tensor[idx]
        return sample


# data_loader
def prepare_dataloader(
    data_pt: Dict[str, torch.Tensor],
    batch_size: int,
    shuffle: bool = False,
    intra_domain_shuffle: bool = False,
    drop_last: bool = False,
    num_workers: int = 0,
) -> DataLoader:
    """
    Wrap ``data_pt`` in a :class:`SeqDataset` and return a ``DataLoader`` for it.

    The notebook-level flag ``per_seq_batch_sample`` selects the batching mode:

    * false: a plain ``DataLoader`` using ``batch_size`` / ``shuffle`` /
      ``drop_last``;
    * true: the sample indices are grouped by ``data_pt["batch_labels"]`` and
      handed to ``SubsetsBatchSampler``. ``shuffle`` is passed as
      ``inter_subset_shuffle`` and ``intra_domain_shuffle`` as
      ``intra_subset_shuffle``.

    Both modes use ``pin_memory=True``.
    """
    seq_dataset = SeqDataset(data_pt)

    if not per_seq_batch_sample:
        return DataLoader(
            dataset=seq_dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            num_workers=num_workers,
            pin_memory=True,
        )

    # One list of sample indices per distinct batch label, in ascending label order.
    labels = data_pt["batch_labels"].numpy()
    indices_per_label = [
        np.where(labels == label)[0].tolist() for label in np.unique(labels)
    ]
    subset_sampler = SubsetsBatchSampler(
        indices_per_label,
        batch_size,
        intra_subset_shuffle=intra_domain_shuffle,
        inter_subset_shuffle=shuffle,
        drop_last=drop_last,
    )
    return DataLoader(
        dataset=seq_dataset,
        batch_sampler=subset_sampler,
        num_workers=num_workers,
        pin_memory=True,
    )

def train(model: nn.Module, loader: DataLoader) -> None:
    """
    Train the model for one epoch.

    Every batch from ``loader`` goes through a forward pass, the enabled loss
    terms are summed, and one optimizer step is taken through the gradient
    scaler. The per-batch loss terms are sent to wandb, and a running summary
    is written to the logger every ``config.log_interval`` batches.

    Besides its arguments the function reads these notebook-level globals:
    ``config``, ``device``, ``vocab``, ``pad_token``, ``mask_value``, ``DSBN``,
    ``explicit_zero_prob``, ``criterion``, ``criterion_dab``, ``optimizer``,
    ``scheduler``, ``scaler``, ``logger`` and ``epoch``.

    Args:
        model: the network to update in place.
        loader: yields dicts with ``"gene_ids"``, ``"values"``,
            ``"target_values"`` and ``"batch_labels"``.
    """
    model.train()
    # Running sums since the last log line; reset after every log line.
    total_loss, total_mse, total_gepc = 0.0, 0.0, 0.0
    total_error = 0.0
    log_interval = config.log_interval
    start_time = time.time()

    num_batches = len(loader)
    for batch, batch_data in enumerate(loader):
        input_gene_ids = batch_data["gene_ids"].to(device)
        input_values = batch_data["values"].to(device)
        target_values = batch_data["target_values"].to(device)
        batch_labels = batch_data["batch_labels"].to(device)

        # True wherever the gene token is the padding token.
        src_key_padding_mask = input_gene_ids.eq(vocab[pad_token])
        with torch.cuda.amp.autocast(enabled=config.amp):
            output_dict = model(
                input_gene_ids,
                input_values,
                src_key_padding_mask=src_key_padding_mask,
                batch_labels=batch_labels if DSBN else None,
                MVC=config.GEPC,
                ECS=config.ecs_thres > 0,
            )

            masked_positions = input_values.eq(mask_value)  # the positions to predict
            # Base objective: masked MSE on the "mlm_output" predictions.
            loss = loss_mse = criterion(
                output_dict["mlm_output"], target_values, masked_positions
            )
            metrics_to_log = {"train/mse": loss_mse.item()}
            if explicit_zero_prob:
                # Bernoulli negative log-likelihood on "mlm_zero_probs".
                loss_zero_log_prob = criterion_neg_log_bernoulli(
                    output_dict["mlm_zero_probs"], target_values, masked_positions
                )
                loss = loss + loss_zero_log_prob
                metrics_to_log.update({"train/nzlp": loss_zero_log_prob.item()})
            if config.GEPC:
                # Same masked criterion, applied to the "mvc_output" predictions.
                loss_gepc = criterion(
                    output_dict["mvc_output"], target_values, masked_positions
                )
                loss = loss + loss_gepc
                metrics_to_log.update({"train/mvc": loss_gepc.item()})
            if config.GEPC and explicit_zero_prob:
                loss_gepc_zero_log_prob = criterion_neg_log_bernoulli(
                    output_dict["mvc_zero_probs"], target_values, masked_positions
                )
                loss = loss + loss_gepc_zero_log_prob
                metrics_to_log.update(
                    {"train/mvc_nzlp": loss_gepc_zero_log_prob.item()}
                )
            if config.ecs_thres > 0:
                # The model's own "loss_ecs" term, weighted by a fixed factor of 10.
                loss_ecs = 10 * output_dict["loss_ecs"]
                loss = loss + loss_ecs
                metrics_to_log.update({"train/ecs": loss_ecs.item()})
            # Cross-entropy of "dab_output" against the batch labels; always
            # included, weighted by config.dab_weight.
            loss_dab = criterion_dab(output_dict["dab_output"], batch_labels)
            loss = loss + config.dab_weight * loss_dab
            metrics_to_log.update({"train/dab": loss_dab.item()})

        model.zero_grad()
        scaler.scale(loss).backward()
        # Unscale before clipping so the threshold applies to the real gradients.
        scaler.unscale_(optimizer)
        with warnings.catch_warnings(record=True) as w:
            # Re-enable warnings locally so anything raised while clipping is recorded.
            warnings.filterwarnings("always")
            torch.nn.utils.clip_grad_norm_(
                model.parameters(),
                1.0,
                # Raise on non-finite gradients only when the scaler is disabled.
                error_if_nonfinite=False if scaler.is_enabled() else True,
            )
            if len(w) > 0:
                # Any recorded warning is reported as a non-finite gradient.
                logger.warning(
                    f"Found infinite gradient. This may be caused by the gradient "
                    f"scaler. The current scale is {scaler.get_scale()}. This warning "
                    "can be ignored if no longer occurs after autoscaling of the scaler."
                )
        scaler.step(optimizer)
        scaler.update()

        # One wandb step per training batch.
        wandb.log(metrics_to_log)

        # Relative error is only monitored; it is not part of the loss.
        with torch.no_grad():
            mre = masked_relative_error(
                output_dict["mlm_output"], target_values, masked_positions
            )

        total_loss += loss.item()
        total_mse += loss_mse.item()
        total_gepc += loss_gepc.item() if config.GEPC else 0.0
        total_error += mre.item()
        # Batch 0 is skipped, so the first window contains log_interval + 1
        # batches while still being divided by log_interval.
        if batch % log_interval == 0 and batch > 0:
            lr = scheduler.get_last_lr()[0]
            ms_per_batch = (time.time() - start_time) * 1000 / log_interval
            cur_loss = total_loss / log_interval
            cur_mse = total_mse / log_interval
            cur_gepc = total_gepc / log_interval if config.GEPC else 0.0
            cur_error = total_error / log_interval
            # ppl = math.exp(cur_loss)
            logger.info(
                f"| epoch {epoch:3d} | {batch:3d}/{num_batches:3d} batches | "
                f"lr {lr:05.4f} | ms/batch {ms_per_batch:5.2f} | "
                f"loss {cur_loss:5.2f} | mse {cur_mse:5.2f} | mre {cur_error:5.2f} |"
                + (f"gepc {cur_gepc:5.2f} |" if config.GEPC else "")
            )
            total_loss = 0
            total_mse = 0
            total_gepc = 0
            total_error = 0
            start_time = time.time()


def define_wandb_metrcis():
    """
    Register the summary behaviour of the metrics logged by this notebook.

    The validation metrics keep their minimum as run summary and are plotted
    against the ``"epoch"`` step metric; ``"test/avg_bio"`` keeps its maximum.
    """
    validation_metrics = (
        "valid/mse",
        "valid/mre",
        "valid/dab",
        "valid/sum_mse_dab",
    )
    for metric_name in validation_metrics:
        wandb.define_metric(metric_name, summary="min", step_metric="epoch")
    wandb.define_metric("test/avg_bio", summary="max")


def evaluate(model: nn.Module, loader: DataLoader) -> Tuple[float, float]:
    """
    Evaluate the model on the evaluation data.

    Runs the model in eval mode without gradients over every batch of
    ``loader``, accumulates sample-weighted sums of the masked criterion, the
    masked relative error and the batch-label cross-entropy, logs their means
    to wandb under ``valid/*`` together with the current ``epoch``, and returns
    the first two.

    Besides its arguments the function reads these notebook-level globals:
    ``config``, ``device``, ``vocab``, ``pad_token``, ``mask_value``, ``DSBN``,
    ``criterion``, ``criterion_dab`` and ``epoch``.

    Args:
        model: the network to evaluate.
        loader: yields dicts with ``"gene_ids"``, ``"values"``,
            ``"target_values"`` and ``"batch_labels"``.

    Returns:
        ``(mean masked loss, mean masked relative error)`` over all samples.

    Raises:
        ZeroDivisionError: if ``loader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    total_error = 0.0
    total_dab = 0.0
    total_num = 0
    with torch.no_grad():
        for batch_data in loader:
            input_gene_ids = batch_data["gene_ids"].to(device)
            input_values = batch_data["values"].to(device)
            target_values = batch_data["target_values"].to(device)
            batch_labels = batch_data["batch_labels"].to(device)

            # True wherever the gene token is the padding token.
            src_key_padding_mask = input_gene_ids.eq(vocab[pad_token])
            with torch.cuda.amp.autocast(enabled=config.amp):
                output_dict = model(
                    input_gene_ids,
                    input_values,
                    src_key_padding_mask=src_key_padding_mask,
                    batch_labels=batch_labels if DSBN else None,
                )
                output_values = output_dict["mlm_output"]

                # Only the masked positions contribute to the metrics.
                masked_positions = input_values.eq(mask_value)
                loss = criterion(output_values, target_values, masked_positions)
                loss_dab = criterion_dab(output_dict["dab_output"], batch_labels)

            # Weight each batch by its size so a short final batch does not
            # skew the means.
            n_cells = len(input_gene_ids)
            total_loss += loss.item() * n_cells
            total_error += masked_relative_error(
                output_values, target_values, masked_positions
            ).item() * n_cells
            total_dab += loss_dab.item() * n_cells
            total_num += n_cells

    wandb.log(
        {
            "valid/mse": total_loss / total_num,
            "valid/mre": total_error / total_num,
            "valid/dab": total_dab / total_num,
            "valid/sum_mse_dab": (total_loss + config.dab_weight * total_dab)
            / total_num,
            "epoch": epoch,
        },
    )

    return total_loss / total_num, total_error / total_num


def eval_testdata(
    model: nn.Module,
    adata_t: AnnData,
    include_types: List[str] = ["cls"],
) -> Optional[Dict]:
    """
    Evaluate the model on test dataset of adata_t.

    When ``"cls"`` is requested, every cell of a copy of ``adata_t`` is
    tokenized and encoded with ``model.encode_batch``. The L2-normalised
    embeddings are stored in ``obsm["X_scGPT"]`` and passed to
    ``eval_scib_metrics``, and two UMAP figures (coloured by ``str_batch`` and
    by ``celltype``) are drawn.

    Besides its arguments the function reads these notebook-level globals:
    ``input_layer_key``, ``gene_ids``, ``max_seq_len``, ``vocab``,
    ``pad_token``, ``pad_value``, ``config``, ``DSBN`` and ``logger``.

    Args:
        model: the network used to embed the cells; switched to eval mode.
        adata_t: annotated data with the layer named by ``input_layer_key`` and
            the obs columns ``batch_id``, ``str_batch`` and ``celltype``. It is
            copied, so the caller's object is left untouched.
        include_types: evaluations to run. Only ``"cls"`` is implemented. The
            default list is never mutated.

    Returns:
        The result dict when exactly one type was requested: the metrics
        returned by ``eval_scib_metrics`` plus the ``"batch_umap"`` and
        ``"celltype_umap"`` figures, or an empty dict when ``"cls"`` was not
        requested. ``None`` otherwise.
    """
    model.eval()

    # copy adata_t to avoid reuse previously computed results stored in adata_t
    adata_t = adata_t.copy()

    all_counts = (
        adata_t.layers[input_layer_key].A
        if issparse(adata_t.layers[input_layer_key])
        else adata_t.layers[input_layer_key]
    )

    batch_ids = adata_t.obs["batch_id"].tolist()
    batch_ids = np.array(batch_ids)

    # Bound up front so the final return cannot hit an unassigned name when
    # "cls" is not among the requested types.
    results = {}

    # Evaluate cls cell embeddings
    if "cls" in include_types:
        logger.info("Evaluating cls cell embeddings")
        tokenized_all = tokenize_and_pad_batch(
            all_counts,
            gene_ids,
            max_len=max_seq_len,
            vocab=vocab,
            pad_token=pad_token,
            pad_value=pad_value,
            append_cls=True,  # append <cls> token at the beginning
            include_zero_gene=True,
        )
        all_gene_ids, all_values = tokenized_all["genes"], tokenized_all["values"]
        src_key_padding_mask = all_gene_ids.eq(vocab[pad_token])
        with torch.no_grad(), torch.cuda.amp.autocast(enabled=config.amp):
            cell_embeddings = model.encode_batch(
                all_gene_ids,
                all_values.float(),
                src_key_padding_mask=src_key_padding_mask,
                batch_size=config.batch_size,
                batch_labels=torch.from_numpy(batch_ids).long() if DSBN else None,
                time_step=0,
                return_np=True,
            )
        # Scale every embedding to unit L2 norm.
        cell_embeddings = cell_embeddings / np.linalg.norm(
            cell_embeddings, axis=1, keepdims=True
        )

        adata_t.obsm["X_scGPT"] = cell_embeddings

        # Metric computation is best effort: on failure the traceback is
        # printed and the figures are still produced with avg_bio shown as 0.
        try:
            results = eval_scib_metrics(adata_t)
        except Exception as e:
            traceback.print_exc()
            logger.error(e)

        # The neighbour graph and UMAP layout are computed once and shared by
        # both figures; recomputing them with identical arguments between the
        # two plots only repeated the work.
        sc.pp.neighbors(adata_t, use_rep="X_scGPT")
        sc.tl.umap(adata_t, min_dist=0.3)
        fig = sc.pl.umap(
            adata_t,
            color=["str_batch"],
            title=[f"batch, avg_bio = {results.get('avg_bio', 0.0):.4f}"],
            frameon=False,
            return_fig=True,
            show=False,
        )

        results["batch_umap"] = fig

        fig = sc.pl.umap(
            adata_t,
            color=["celltype"],
            title=[
                f"celltype, avg_bio = {results.get('avg_bio', 0.0):.4f}",
            ],
            frameon=False,
            return_fig=True,
            show=False,
        )

        results["celltype_umap"] = fig

    if len(include_types) == 1:
        return results

In [None]:
## Define variables
dataset_id = "Human_Liver"
adata_path = "/work/NMF_project/reproducibility/data/Liver/Human_Liver.h5ad"
hvg_path = "/work/NMF_project/reproducibility/data/Liver/Human_Liver.features"
latent_path = "/work/NMF_project/reproducibility/data/Liver/Human_Liver"

# Fine-tune the pre-trained model five times on the same dataset. Every
# replicate gets its own wandb run, its own save/log directory and its own
# embedding file (`<latent_path>_scGPT_<rep>.txt`).
# The helper functions defined above (prepare_data, prepare_dataloader, train,
# evaluate) read many of the names assigned in this loop as globals.
for rep in range(5):
    ## Set hyperparameters
    hyperparameter_defaults = dict(
        seed=42,
        dataset_name=dataset_id, # Dataset name
        do_train=True, # Flag to indicate whether to do update model parameters during training
        load_model="/work/NMF_project/reproducibility/data/scGPT/Model/", # Path to pre-trained model
        GEPC=True,  # Gene expression modelling for cell objective
        ecs_thres=0.8,  # Elastic cell similarity objective, 0.0 to 1.0, 0.0 to disable
        dab_weight=1.0, # DAR objective weight for batch correction
        mask_ratio=0.4, # Default mask ratio
        epochs=15, # Default number of epochs for fine-tuning
        n_bins=51, # Default number of bins for value binning in data pre-processing
        lr=1e-4, # Default learning rate for fine-tuning
        batch_size=64, # Default batch size for fine-tuning
        layer_size=128,
        nlayers=4,
        nhead=4, # if load model, batch_size, layer_size, nlayers, nhead will be ignored
        dropout=0.2, # Default dropout rate during model fine-tuning
        schedule_ratio=0.9,  # Default rate for learning rate decay
        save_eval_interval=5, # Default model evaluation interval
        log_interval=100, # Default log interval
        fast_transformer=True, # Default setting
        pre_norm=False, # Default setting
        amp=True,  # # Default setting: Automatic Mixed Precision
    )

    ## Initialize the run on wandb
    run = wandb.init(
        config=hyperparameter_defaults,
        project="scGPT",
        reinit=True,
        settings=wandb.Settings(start_method="fork"),
    )
    config = wandb.config
    print(config)
    # NOTE(review): the same seed (42) is set for every replicate, so the
    # replicates can only differ through non-deterministic operations --
    # confirm this is intended rather than e.g. `seed + rep`.
    set_seed(config.seed)

    # Settings for input and preprocessing
    pad_token = "<pad>"
    special_tokens = [pad_token, "<cls>", "<eoc>"]
    mask_ratio = config.mask_ratio
    mask_value = -1
    pad_value = -2
    n_input_bins = config.n_bins
    # NOTE(review): n_hvg is not derived from the HVG file below; max_seq_len
    # presumably assumes that file lists at most 1200 genes -- confirm.
    n_hvg = 1200  # number of highly variable genes
    max_seq_len = n_hvg + 1
    per_seq_batch_sample = True
    DSBN = True  # Domain-spec batchnorm
    explicit_zero_prob = True  # whether explicit bernoulli for zeros

    # Settings for saving the model
    dataset_name = config.dataset_name
    save_dir = Path(f"/work/NMF_project/reproducibility/data/scGPT/FT/dev_{dataset_name}_{rep}/")
    save_dir.mkdir(parents=True, exist_ok=True)
    print(f"save to {save_dir}")
    logger = scg.logger
    # Remember which handlers are attached before adding this replicate's log
    # file, so exactly the new ones can be detached again during clean-up.
    # Otherwise every later replicate would also write into the run.log files
    # of all earlier replicates and the file handles would stay open.
    handlers_before_run = list(logger.handlers)
    scg.utils.add_file_handler(logger, save_dir / "run.log")
    run_log_handlers = [h for h in logger.handlers if h not in handlers_before_run]

    # Load data
    adata = sc.read(adata_path)
    ori_batch_col = "batch_label"
    adata.var = adata.var.set_index("features")
    data_is_raw = True

    # make the batch category column
    adata.obs["str_batch"] = adata.obs[ori_batch_col].astype(str)
    batch_id_labels = adata.obs["str_batch"].astype("category").cat.codes.values
    adata.obs["batch_id"] = batch_id_labels
    adata.var["gene_name"] = adata.var.index.tolist()

    # Define HVGs: one gene name per line in hvg_path. The names are kept in a
    # set so the membership test below is constant time per gene, and the file
    # is closed even if reading fails.
    with open(hvg_path, "r") as hvg_file:
        hvg = set(hvg_file.read().splitlines())
    adata.var['highly_variable'] = [g in hvg for g in adata.var_names]

    # Load the pretrained model
    if config.load_model is not None:
        model_dir = Path(config.load_model)
        model_config_file = model_dir / "args.json"
        model_file = model_dir / "best_model.pt"
        vocab_file = model_dir / "vocab.json"

        vocab = GeneVocab.from_file(vocab_file)
        for s in special_tokens:
            if s not in vocab:
                vocab.append_token(s)

        # Flag genes by vocabulary membership (1 = present, -1 = absent) and
        # keep only the genes that are present.
        adata.var["id_in_vocab"] = [
            1 if gene in vocab else -1 for gene in adata.var["gene_name"]
        ]
        gene_ids_in_vocab = np.array(adata.var["id_in_vocab"])
        logger.info(
            f"match {np.sum(gene_ids_in_vocab >= 0)}/{len(gene_ids_in_vocab)} genes "
            f"in vocabulary of size {len(vocab)}."
        )
        adata = adata[:, adata.var["id_in_vocab"] >= 0]

        # model
        with open(model_config_file, "r") as f:
            model_configs = json.load(f)
        logger.info(
            f"Resume model from {model_file}, the model args will be overriden by the "
            f"config {model_config_file}."
        )
        embsize = model_configs["embsize"]
        nhead = model_configs["nheads"]
        d_hid = model_configs["d_hid"]
        nlayers = model_configs["nlayers"]
        n_layers_cls = model_configs["n_layers_cls"]
    else:
        embsize = config.layer_size
        nhead = config.nhead
        nlayers = config.nlayers
        d_hid = config.layer_size

    # Preprocess the dataset: drop genes with fewer than 3 counts, normalise
    # each cell to 1e4 total counts (layer "X_normed"), log1p-transform a copy
    # of that (layer "X_log1p"), then restrict to the highly variable genes.
    sc.pp.filter_genes(adata, min_counts=3)
    normed = sc.pp.normalize_total(adata, target_sum=1e4, layer=None, inplace=False)["X"]
    sc.get._set_obs_rep(adata, normed, layer="X_normed")
    sc.get._set_obs_rep(adata,sc.get._get_obs_rep(adata, layer="X_normed"), layer="X_log1p")
    sc.pp.log1p(adata, layer="X_log1p")
    adata = adata[:, adata.var.highly_variable]

    # Value binning, one cell at a time: the quantiles of the cell's non-zero
    # values define the bin edges, zeros stay in bin 0.
    n_bins = config.n_bins  # NOTE: the first bin is always a spectial for zero
    binned_rows = []
    bin_edges = []
    layer_data = sc.get._get_obs_rep(adata, layer="X_log1p")
    layer_data = layer_data.A if issparse(layer_data) else layer_data
    for row in layer_data:
        non_zero_ids = row.nonzero()
        non_zero_row = row[non_zero_ids]
        if non_zero_row.size == 0:
            # A cell without any expressed HVG has no quantiles to compute;
            # without this guard the min()/max() checks below fail on an empty
            # array. Such a cell is encoded as all-zero bins with all-zero edges.
            logger.warning(
                "Found a cell with no non-zero value among the selected genes; "
                "all of its bins are set to 0."
            )
            binned_rows.append(np.zeros_like(row, dtype=np.int64))
            bin_edges.append(np.zeros(n_bins))
            continue
        bins = np.quantile(non_zero_row, np.linspace(0, 1, n_bins - 1))
        non_zero_digits = _digitize(x = non_zero_row, bins = bins)
        assert non_zero_digits.min() >= 1
        assert non_zero_digits.max() <= n_bins - 1
        binned_row = np.zeros_like(row, dtype=np.int64)
        binned_row[non_zero_ids] = non_zero_digits
        binned_rows.append(binned_row)
        bin_edges.append(np.concatenate([[0], bins]))
    adata.layers["X_binned"] = np.stack(binned_rows)
    adata.obsm["bin_edges"] = np.stack(bin_edges)

    # Sort the adata by batch_id in advance
    if per_seq_batch_sample:
        adata_sorted = adata[adata.obs["batch_id"].argsort()].copy()
    else:
        # The embedding step below always reads adata_sorted, so the name has
        # to exist when per-batch sampling is switched off as well.
        adata_sorted = adata

    # Define input layers and get counts
    input_layer_key = "X_binned"
    all_counts = (
        adata.layers[input_layer_key].A
        if issparse(adata.layers[input_layer_key])
        else adata.layers[input_layer_key]
    )
    genes = adata.var["gene_name"].tolist()

    # Get batch ids
    batch_ids = adata.obs["batch_id"].tolist()
    num_batch_types = len(set(batch_ids))
    batch_ids = np.array(batch_ids)

    # Create splits (90% train / 10% validation). No random_state is passed, so
    # the split follows numpy's global random state.
    (
        train_data,
        valid_data,
        train_batch_labels,
        valid_batch_labels,
    ) = train_test_split(
        all_counts, batch_ids, test_size=0.1, shuffle=True
    )

    # Define vocabulary
    if config.load_model is None:
        vocab = Vocab(
            VocabPybind(genes + special_tokens, None)
        )  # bidirectional lookup [gene <-> int]
    vocab.set_default_index(vocab["<pad>"])
    gene_ids = np.array(vocab(genes), dtype=int)

    # Tokenize training and validation data
    tokenized_train = tokenize_and_pad_batch(
        train_data,
        gene_ids,
        max_len=max_seq_len,
        vocab=vocab,
        pad_token=pad_token,
        pad_value=pad_value,
        append_cls=True,  # append <cls> token at the beginning
        include_zero_gene=True,
    )
    tokenized_valid = tokenize_and_pad_batch(
        valid_data,
        gene_ids,
        max_len=max_seq_len,
        vocab=vocab,
        pad_token=pad_token,
        pad_value=pad_value,
        append_cls=True,
        include_zero_gene=True,
    )
    logger.info(
        f"train set number of samples: {tokenized_train['genes'].shape[0]}, "
        f"\n\t feature length: {tokenized_train['genes'].shape[1]}"
    )
    logger.info(
        f"valid set number of samples: {tokenized_valid['genes'].shape[0]}, "
        f"\n\t feature length: {tokenized_valid['genes'].shape[1]}"
    )

    # Load the model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ntokens = len(vocab)  # size of vocabulary
    model = TransformerModel(
        ntokens,
        embsize,
        nhead,
        d_hid,
        nlayers,
        vocab=vocab,
        dropout=config.dropout,
        pad_token=pad_token,
        pad_value=pad_value,
        do_mvc=config.GEPC,
        do_dab=True,
        use_batch_labels=True,
        num_batch_labels=num_batch_types,
        domain_spec_batchnorm=DSBN,
        n_input_bins=n_input_bins,
        ecs_threshold=config.ecs_thres,
        explicit_zero_prob=explicit_zero_prob,
        use_fast_transformer=config.fast_transformer,
        pre_norm=config.pre_norm,
    )
    if config.load_model is not None:
        # Read the checkpoint once; it is reused by the fallback path below.
        pretrained_dict = torch.load(model_file)
        try:
            model.load_state_dict(pretrained_dict)
            logger.info(f"Loading all model params from {model_file}")
        except Exception:
            # Strict loading failed (e.g. missing keys or shape mismatches).
            # `Exception` instead of a bare `except` keeps KeyboardInterrupt
            # and SystemExit from being swallowed here.
            # only load params that are in the model and match the size
            model_dict = model.state_dict()
            pretrained_dict = {
                k: v
                for k, v in pretrained_dict.items()
                if k in model_dict and v.shape == model_dict[k].shape
            }
            for k, v in pretrained_dict.items():
                logger.info(f"Loading params {k} with shape {v.shape}")
            model_dict.update(pretrained_dict)
            model.load_state_dict(model_dict)

    model.to(device)
    wandb.watch(model)

    # Set model criteria
    criterion = masked_mse_loss
    criterion_dab = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(
        model.parameters(), lr=config.lr, eps=1e-4 if config.amp else 1e-8
    )
    # Multiply the learning rate by schedule_ratio after every epoch.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=config.schedule_ratio)
    scaler = torch.cuda.amp.GradScaler(enabled=config.amp)

    # Train the model
    best_val_loss = float("inf")
    best_avg_bio = 0.0
    best_model = None
    define_wandb_metrcis()

    for epoch in range(1, config.epochs + 1):
        epoch_start_time = time.time()
        # New random masks are drawn every epoch.
        train_data_pt, valid_data_pt = prepare_data(sort_seq_batch=per_seq_batch_sample)
        train_loader = prepare_dataloader(
            train_data_pt,
            batch_size=config.batch_size,
            shuffle=False,
            intra_domain_shuffle=True,
            drop_last=False,
        )
        valid_loader = prepare_dataloader(
            valid_data_pt,
            batch_size=config.batch_size,
            shuffle=False,
            intra_domain_shuffle=False,
            drop_last=False,
        )

        if config.do_train:
            train(
                model,
                loader=train_loader,
            )
        val_loss, val_mre = evaluate(
            model,
            loader=valid_loader,
        )
        elapsed = time.time() - epoch_start_time
        logger.info("-" * 89)
        logger.info(
            f"| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | "
            f"valid loss/mse {val_loss:5.4f} | mre {val_mre:5.4f}"
        )
        logger.info("-" * 89)

        # Keep a full copy of the model with the lowest validation loss so far.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model = copy.deepcopy(model)
            best_model_epoch = epoch
            logger.info(f"Best model with score {best_val_loss:5.4f}")

        scheduler.step()

    # Extract embeddings for all cells with the best model
    best_model.eval()
    adata_t = adata_sorted.copy()
    all_counts = (
        adata_t.layers[input_layer_key].A
        if issparse(adata_t.layers[input_layer_key])
        else adata_t.layers[input_layer_key]
    )
    batch_ids = adata_t.obs["batch_id"].tolist()
    batch_ids = np.array(batch_ids)
    tokenized_all = tokenize_and_pad_batch(
        all_counts,
        gene_ids,
        max_len=max_seq_len,
        vocab=vocab,
        pad_token=pad_token,
        pad_value=pad_value,
        append_cls=True,  # append <cls> token at the beginning
        include_zero_gene=True,
    )
    all_gene_ids, all_values = tokenized_all["genes"], tokenized_all["values"]
    src_key_padding_mask = all_gene_ids.eq(vocab[pad_token])
    with torch.no_grad(), torch.cuda.amp.autocast(enabled=config.amp):
        cell_embeddings = best_model.encode_batch(
            all_gene_ids,
            all_values.float(),
            src_key_padding_mask=src_key_padding_mask,
            batch_size=config.batch_size,
            batch_labels=torch.from_numpy(batch_ids).long() if DSBN else None,
            time_step=0,
            return_np=True,
        )
    # Scale every embedding to unit L2 norm.
    cell_embeddings = cell_embeddings / np.linalg.norm(
        cell_embeddings, axis=1, keepdims=True
    )

    # Save the embeddings, indexed by the cell names in the order used above
    df = pd.DataFrame(cell_embeddings, index=adata_sorted.obs.index)
    file_out = latent_path + "_scGPT_" + str(rep) + ".txt"
    df.to_csv(file_out)

    # Clean up
    del adata
    del adata_sorted
    del adata_t
    del best_model
    del tokenized_all
    del all_counts
    del model
    del tokenized_train
    del tokenized_valid
    # These objects are built from the model's parameters (or the optimizer);
    # dropping them as well lets the old model be collected before the next
    # replicate builds a new one.
    del optimizer
    del scheduler
    del scaler

    # Detach and close the log file handler(s) added for this replicate.
    for run_log_handler in run_log_handlers:
        logger.removeHandler(run_log_handler)
        run_log_handler.close()

    # End the logger and the run
    run.finish()
    wandb.finish()
    gc.collect()
    if torch.cuda.is_available():
        # Return cached, now unused GPU memory to the driver.
        torch.cuda.empty_cache()

    

VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
train/dab,▄▄▄▃▂▂▁▆▆▆▄▃▃▂▁▁▁█▇▆▆▅▄▂▂▁▆▆▆▅▄▂▂▁▁▁▁██▇
train/ecs,▂▂▂▁▁▅▅▄▅▄▄▅▄▃▃▃▄▄▆▇▇▅▄▅▅▆▆█▆█▆▆▆▆▆▆▆▆▇▆
train/mse,█▇█▇▆▆▆▇▇▆▆▇▆▆▆▆▆▅▄▄▃▃▂▂▂▂▃▂▂▂▂▂▂▁▂▁▁▁▁▁
train/mvc,█▇█▄▃▂▂▃▃▂▂▂▂▂▂▂▁▂▁▁▂▂▁▁▁▁▂▂▂▂▂▁▁▁▁▁▁▃▂▁
train/mvc_nzlp,██▇▃▄▃▂▆▅▃▃▂▂▂▂▂▁▂▂▂▂▂▂▂▂▁▅▅▄▃▂▁▁▁▁▁▁▂▂▁
train/nzlp,▃▃▂▂▃▂▃▃▂▂▃▃▂▃▂▃▂██▇█▇▆▅▅▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁

0,1
train/dab,5.88318
train/ecs,9.68107
train/mse,118.59852
train/mvc,137.23468
train/mvc_nzlp,1.01268
train/nzlp,0.68715


{'seed': 42, 'dataset_name': 'Human_Liver', 'do_train': True, 'load_model': '/work/NMF_project/reproducibility/data/scGPT/Model/', 'GEPC': True, 'ecs_thres': 0.8, 'dab_weight': 1.0, 'mask_ratio': 0.4, 'epochs': 15, 'n_bins': 51, 'lr': 0.0001, 'batch_size': 64, 'layer_size': 128, 'nlayers': 4, 'nhead': 4, 'dropout': 0.2, 'schedule_ratio': 0.9, 'save_eval_interval': 5, 'log_interval': 100, 'fast_transformer': True, 'pre_norm': False, 'amp': True}
save to /work/NMF_project/reproducibility/data/scGPT/FT/dev_Human_Liver_0
scGPT - INFO - match 24017/24334 genes in vocabulary of size 60697.
scGPT - INFO - Resume model from /work/NMF_project/reproducibility/data/scGPT/Model/best_model.pt, the model args will be overriden by the config /work/NMF_project/reproducibility/data/scGPT/Model/args.json.
scGPT - INFO - train set number of samples: 7599, 
	 feature length: 1201
scGPT - INFO - valid set number of samples: 845, 
	 feature length: 1201
Use domain specific batchnorm with affine=False
scGPT 

 54%|█████▍    | 71/132 [00:05<00:05, 12.06it/s]

In [None]:
## Define variables
# Dataset identifier: stored in the wandb config as `dataset_name` and used in the save directory name.
dataset_id = "Human_Pancreas"
# AnnData file that is loaded with sc.read() at the start of every replicate.
adata_path = "/work/NMF_project/reproducibility/data/Pancreas/Human_Pancreas.h5ad"
# Plain-text file that is split into lines to obtain the highly variable gene names.
hvg_path = "/work/NMF_project/reproducibility/data/Pancreas/Human_Pancreas.features"
# Output prefix: each replicate writes its embeddings to "<latent_path>_scGPT_<rep>.txt".
latent_path = "/work/NMF_project/reproducibility/data/Pancreas/Human_Pancreas"

# Fine-tune five replicates; `rep` is part of the save directory and of the output file name.
# NOTE(review): every replicate calls set_seed(config.seed) with the same seed (42), so all
# replicates start from an identical seed. If independent replicates are intended, the seed
# should probably vary with `rep` -- confirm.
for rep in range(5):
    ## Set hyperparameters
    hyperparameter_defaults = dict(
        seed=42,
        dataset_name=dataset_id, # Dataset name
        do_train=True, # Flag to indicate whether to update model parameters during training
        load_model="/work/NMF_project/reproducibility/data/scGPT/Model/", # Path to pre-trained model
        GEPC=True,  # Gene expression modelling for cell objective
        ecs_thres=0.8,  # Elastic cell similarity objective, 0.0 to 1.0, 0.0 to disable
        dab_weight=1.0, # DAR objective weight for batch correction
        mask_ratio=0.4, # Default mask ratio
        epochs=15, # Default number of epochs for fine-tuning
        n_bins=51, # Default number of bins for value binning in data pre-processing
        lr=1e-4, # Default learning rate for fine-tuning
        batch_size=64, # Default batch size for fine-tuning
        layer_size=128,
        nlayers=4,
        nhead=4, # when load_model is set, layer_size, nlayers and nhead are replaced by the values in the checkpoint's args.json
        dropout=0.2, # Default dropout rate during model fine-tuning
        schedule_ratio=0.9,  # Default rate for learning rate decay
        save_eval_interval=5, # Default model evaluation interval
        log_interval=100, # Default log interval
        fast_transformer=True, # Default setting
        pre_norm=False, # Default setting
        amp=True,  # Default setting: Automatic Mixed Precision
    )

    ## Initialize the run on wandb
    # A new wandb run is opened for every replicate (reinit=True); the hyperparameters
    # are read back through `wandb.config` for the rest of the loop body.
    run = wandb.init(
        config=hyperparameter_defaults,
        project="scGPT",
        reinit=True,
        settings=wandb.Settings(start_method="fork"),
    )
    config = wandb.config
    print(config)
    set_seed(config.seed)

    # Settings for input and preprocessing
    pad_token = "<pad>"
    special_tokens = [pad_token, "<cls>", "<eoc>"]
    # mask_ratio / mask_value are not used further down in this cell; presumably they are
    # read as globals by the data-preparation helpers defined earlier -- confirm.
    mask_ratio = config.mask_ratio
    mask_value = -1
    pad_value = -2
    n_input_bins = config.n_bins
    n_hvg = 1200  # number of highly variable genes
    max_seq_len = n_hvg + 1  # one extra position for the <cls> token added during tokenization
    per_seq_batch_sample = True
    DSBN = True  # Domain-spec batchnorm
    explicit_zero_prob = True  # whether explicit bernoulli for zeros

    # Settings for saving the model
    dataset_name = config.dataset_name
    save_dir = Path(f"/work/NMF_project/reproducibility/data/scGPT/FT/dev_{dataset_name}_{rep}/")
    save_dir.mkdir(parents=True, exist_ok=True)
    print(f"save to {save_dir}")
    logger = scg.logger
    # `scg.logger` is one shared logger object, so a file handler attached for an earlier
    # replicate (or an earlier dataset cell) would keep receiving this replicate's messages
    # and keep its run.log open. Detach and close those handlers before adding the new one.
    import logging  # local import keeps this fix self-contained; ideally it lives in the imports cell

    for stale_handler in [h for h in logger.handlers if isinstance(h, logging.FileHandler)]:
        logger.removeHandler(stale_handler)
        stale_handler.close()
    scg.utils.add_file_handler(logger, save_dir / "run.log")

    # Load data
    adata = sc.read(adata_path)
    ori_batch_col = "batch_label"
    # Use the "features" column of `var` as the gene index.
    adata.var = adata.var.set_index("features")
    data_is_raw = True

    # make the batch category column
    # Batch labels are cast to strings and then encoded as integer category codes.
    adata.obs["str_batch"] = adata.obs[ori_batch_col].astype(str)
    batch_id_labels = adata.obs["str_batch"].astype("category").cat.codes.values
    adata.obs["batch_id"] = batch_id_labels
    adata.var["gene_name"] = adata.var.index.tolist()

    # Define HVGs
    # The context manager closes the file even if reading fails.
    with open(hvg_path, "r") as f:
        hvg = f.read().splitlines()
    # Vectorised membership test: True for every gene that is listed in the HVG file.
    adata.var["highly_variable"] = adata.var_names.isin(hvg)

    # Load the pretrained model
    # With a checkpoint directory configured, the vocabulary and the architecture
    # settings come from that directory; otherwise the wandb config values are used.
    if config.load_model is not None:
        model_dir = Path(config.load_model)
        model_config_file = model_dir / "args.json"
        model_file = model_dir / "best_model.pt"
        vocab_file = model_dir / "vocab.json"

        vocab = GeneVocab.from_file(vocab_file)
        # Make sure the special tokens used by this notebook exist in the vocabulary.
        for s in special_tokens:
            if s not in vocab:
                vocab.append_token(s)

        # "id_in_vocab" is only a membership flag (1 = in vocabulary, -1 = missing),
        # not the actual token id.
        adata.var["id_in_vocab"] = [
            1 if gene in vocab else -1 for gene in adata.var["gene_name"]
        ]
        gene_ids_in_vocab = np.array(adata.var["id_in_vocab"])
        logger.info(
            f"match {np.sum(gene_ids_in_vocab >= 0)}/{len(gene_ids_in_vocab)} genes "
            f"in vocabulary of size {len(vocab)}."
        )
        # Keep only the genes that are present in the vocabulary.
        adata = adata[:, adata.var["id_in_vocab"] >= 0]

        # model
        # Architecture hyperparameters are taken from the checkpoint's args.json.
        with open(model_config_file, "r") as f:
            model_configs = json.load(f)
        logger.info(
            f"Resume model from {model_file}, the model args will be overriden by the "
            f"config {model_config_file}."
        )
        embsize = model_configs["embsize"]
        nhead = model_configs["nheads"]
        d_hid = model_configs["d_hid"]
        nlayers = model_configs["nlayers"]
        # Read from the checkpoint config but not used anywhere else in this cell.
        n_layers_cls = model_configs["n_layers_cls"]
    else:
        # No checkpoint: the vocabulary is built later from the dataset's own genes.
        embsize = config.layer_size
        nhead = config.nhead
        nlayers = config.nlayers
        d_hid = config.layer_size

    # Preprocess the dataset
    # Drop genes with fewer than 3 total counts, then store a library-size normalised
    # layer ("X_normed") and its log1p transform ("X_log1p").
    sc.pp.filter_genes(adata, min_counts=3)
    normed = sc.pp.normalize_total(adata, target_sum=1e4, layer=None, inplace=False)["X"]
    sc.get._set_obs_rep(adata, normed, layer="X_normed")
    sc.get._set_obs_rep(adata, sc.get._get_obs_rep(adata, layer="X_normed"), layer="X_log1p")
    sc.pp.log1p(adata, layer="X_log1p")
    # Restrict to the highly variable genes. The explicit copy avoids assigning new
    # layers to a view further down (AnnData would otherwise copy implicitly and warn).
    adata = adata[:, adata.var.highly_variable].copy()

    # Per-cell quantile binning of the non-zero log1p values.
    n_bins = config.n_bins  # NOTE: the first bin (0) is always the special bin for zero
    binned_rows = []
    bin_edges = []
    n_empty_cells = 0
    layer_data = sc.get._get_obs_rep(adata, layer="X_log1p")
    layer_data = layer_data.A if issparse(layer_data) else layer_data
    for row in layer_data:
        non_zero_ids = row.nonzero()
        non_zero_row = row[non_zero_ids]
        if non_zero_row.size == 0:
            # A cell without counts in any selected gene has nothing to take quantiles
            # of; keep it as an all-zero row with all-zero edges so the shapes still match.
            n_empty_cells += 1
            binned_rows.append(np.zeros_like(row, dtype=np.int64))
            bin_edges.append(np.zeros(n_bins))
            continue
        # n_bins - 1 quantile edges over the non-zero values of this cell.
        bins = np.quantile(non_zero_row, np.linspace(0, 1, n_bins - 1))
        non_zero_digits = _digitize(x=non_zero_row, bins=bins)
        assert non_zero_digits.min() >= 1
        assert non_zero_digits.max() <= n_bins - 1
        binned_row = np.zeros_like(row, dtype=np.int64)
        binned_row[non_zero_ids] = non_zero_digits
        binned_rows.append(binned_row)
        # Prepend 0 so every cell stores n_bins edge values.
        bin_edges.append(np.concatenate([[0], bins]))
    if n_empty_cells:
        logger.warning(
            f"{n_empty_cells} cells have no counts in the selected genes and were binned as all zeros."
        )
    adata.layers["X_binned"] = np.stack(binned_rows)
    adata.obsm["bin_edges"] = np.stack(bin_edges)

    # Sort the adata by batch_id in advance
    # `adata_sorted` is only created when per_seq_batch_sample is True.
    if per_seq_batch_sample:
        adata_sorted = adata[adata.obs["batch_id"].argsort()].copy()

    # Define input layers and get counts
    # The binned layer is the model input; it is densified if it is stored sparse.
    input_layer_key = "X_binned"
    all_counts = (
        adata.layers[input_layer_key].A
        if issparse(adata.layers[input_layer_key])
        else adata.layers[input_layer_key]
    )
    genes = adata.var["gene_name"].tolist()

    # Get batch ids
    batch_ids = adata.obs["batch_id"].tolist()
    num_batch_types = len(set(batch_ids))  # number of distinct batches in this dataset
    batch_ids = np.array(batch_ids)

    # Create splits
    # 90/10 shuffled split of cells and their batch labels. No random_state is passed,
    # so the split depends on the global RNG state -- presumably fixed by set_seed()
    # at the top of the loop; confirm.
    (
        train_data,
        valid_data,
        train_batch_labels,
        valid_batch_labels,
    ) = train_test_split(
        all_counts, batch_ids, test_size=0.1, shuffle=True
    )

    # Define vocabulary
    # Without a pretrained checkpoint the vocabulary is built from the dataset's genes
    # plus the special tokens; otherwise the checkpoint vocabulary loaded earlier is kept.
    if config.load_model is None:
        vocab = Vocab(
            VocabPybind(genes + special_tokens, None)
        )  # bidirectional lookup [gene <-> int]
    # Lookups of unknown tokens resolve to the <pad> index.
    vocab.set_default_index(vocab["<pad>"])
    gene_ids = np.array(vocab(genes), dtype=int)

    # Tokenize training and validation data
    # Both splits share every tokenizer setting, so the keyword arguments are spelled out once.
    tokenize_kwargs = dict(
        max_len=max_seq_len,
        vocab=vocab,
        pad_token=pad_token,
        pad_value=pad_value,
        append_cls=True,  # append <cls> token at the beginning
        include_zero_gene=True,
    )
    tokenized_train = tokenize_and_pad_batch(train_data, gene_ids, **tokenize_kwargs)
    tokenized_valid = tokenize_and_pad_batch(valid_data, gene_ids, **tokenize_kwargs)

    # Report the size of each tokenized split (number of cells, sequence length).
    train_gene_tokens = tokenized_train["genes"]
    valid_gene_tokens = tokenized_valid["genes"]
    logger.info(
        f"train set number of samples: {train_gene_tokens.shape[0]}, "
        f"\n\t feature length: {train_gene_tokens.shape[1]}"
    )
    logger.info(
        f"valid set number of samples: {valid_gene_tokens.shape[0]}, "
        f"\n\t feature length: {valid_gene_tokens.shape[1]}"
    )

    # Load the model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ntokens = len(vocab)  # size of vocabulary
    model = TransformerModel(
        ntokens,
        embsize,
        nhead,
        d_hid,
        nlayers,
        vocab=vocab,
        dropout=config.dropout,
        pad_token=pad_token,
        pad_value=pad_value,
        do_mvc=config.GEPC,
        do_dab=True,
        use_batch_labels=True,
        num_batch_labels=num_batch_types,
        domain_spec_batchnorm=DSBN,
        n_input_bins=n_input_bins,
        ecs_threshold=config.ecs_thres,
        explicit_zero_prob=explicit_zero_prob,
        use_fast_transformer=config.fast_transformer,
        pre_norm=config.pre_norm,
    )
    if config.load_model is not None:
        # Read the checkpoint from disk once. The model is still on the CPU at this point,
        # so the weights are mapped to the CPU as well; this also lets a checkpoint saved
        # on a GPU be opened on a CPU-only machine.
        pretrained_dict = torch.load(model_file, map_location="cpu")
        try:
            model.load_state_dict(pretrained_dict)
            logger.info(f"Loading all model params from {model_file}")
        except Exception as load_error:
            # `except Exception` (not a bare except) so Ctrl-C / SystemExit still propagate.
            logger.warning(
                f"Could not load the full state dict ({type(load_error).__name__}); "
                "falling back to parameters with matching name and shape."
            )
            # only load params that are in the model and match the size
            model_dict = model.state_dict()
            pretrained_dict = {
                k: v
                for k, v in pretrained_dict.items()
                if k in model_dict and v.shape == model_dict[k].shape
            }
            for k, v in pretrained_dict.items():
                logger.info(f"Loading params {k} with shape {v.shape}")
            model_dict.update(pretrained_dict)
            model.load_state_dict(model_dict)

    model.to(device)
    wandb.watch(model)

    # Set model criteria
    # NOTE(review): criterion, criterion_dab, optimizer, scheduler and scaler are not passed
    # to train()/evaluate() below; presumably those helpers read them as globals -- confirm
    # before renaming any of them.
    criterion = masked_mse_loss
    criterion_dab = nn.CrossEntropyLoss()
    # A larger Adam epsilon is used when mixed precision is enabled.
    optimizer = torch.optim.Adam(
        model.parameters(), lr=config.lr, eps=1e-4 if config.amp else 1e-8
    )
    # Step size 1 with one scheduler.step() per epoch: the learning rate is multiplied
    # by schedule_ratio after every epoch.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=config.schedule_ratio)
    scaler = torch.cuda.amp.GradScaler(enabled=config.amp)

    # Train the model
    best_val_loss = float("inf")
    best_avg_bio = 0.0  # not read again in this cell
    # Stays None unless some epoch reports a validation loss below +inf.
    best_model = None
    define_wandb_metrcis()

    for epoch in range(1, config.epochs + 1):
        epoch_start_time = time.time()
        # Data tensors and loaders are rebuilt at the start of every epoch.
        train_data_pt, valid_data_pt = prepare_data(sort_seq_batch=per_seq_batch_sample)
        train_loader = prepare_dataloader(
            train_data_pt,
            batch_size=config.batch_size,
            shuffle=False,
            intra_domain_shuffle=True,
            drop_last=False,
        )
        valid_loader = prepare_dataloader(
            valid_data_pt,
            batch_size=config.batch_size,
            shuffle=False,
            intra_domain_shuffle=False,
            drop_last=False,
        )

        # With do_train=False the model is only evaluated.
        if config.do_train:
            train(
                model,
                loader=train_loader,
            )
        val_loss, val_mre = evaluate(
            model,
            loader=valid_loader,
        )
        elapsed = time.time() - epoch_start_time
        logger.info("-" * 89)
        logger.info(
            f"| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | "
            f"valid loss/mse {val_loss:5.4f} | mre {val_mre:5.4f}"
        )
        logger.info("-" * 89)

        # Keep a deep copy of the model with the lowest validation loss seen so far.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model = copy.deepcopy(model)
            best_model_epoch = epoch  # recorded but not read again in this cell
            logger.info(f"Best model with score {best_val_loss:5.4f}")

        scheduler.step()
        
    # Extract embeddings
    if best_model is None:
        # No epoch reported a validation loss below +inf (e.g. NaN losses or zero epochs);
        # fall back to the final weights instead of failing on None.
        logger.warning("No best model was selected during training; using the final model.")
        best_model = model
    best_model.eval()
    # Honour the same flag that decides whether `adata_sorted` exists at all. The data is
    # only read below, so no additional AnnData copy is needed.
    adata_t = adata_sorted if per_seq_batch_sample else adata
    all_counts = (
        adata_t.layers[input_layer_key].A
        if issparse(adata_t.layers[input_layer_key])
        else adata_t.layers[input_layer_key]
    )
    batch_ids = np.array(adata_t.obs["batch_id"].tolist())
    tokenized_all = tokenize_and_pad_batch(
        all_counts,
        gene_ids,
        max_len=max_seq_len,
        vocab=vocab,
        pad_token=pad_token,
        pad_value=pad_value,
        append_cls=True,  # append <cls> token at the beginning
        include_zero_gene=True,
    )
    all_gene_ids, all_values = tokenized_all["genes"], tokenized_all["values"]
    # True wherever a position holds the <pad> token.
    src_key_padding_mask = all_gene_ids.eq(vocab[pad_token])
    with torch.no_grad(), torch.cuda.amp.autocast(enabled=config.amp):
        cell_embeddings = best_model.encode_batch(
            all_gene_ids,
            all_values.float(),
            src_key_padding_mask=src_key_padding_mask,
            batch_size=config.batch_size,
            batch_labels=torch.from_numpy(batch_ids).long() if DSBN else None,
            time_step=0,
            return_np=True,
        )
    # Scale every row (one embedding per cell) to unit L2 norm.
    cell_embeddings = cell_embeddings / np.linalg.norm(
        cell_embeddings, axis=1, keepdims=True
    )

    # Save the embeddings
    # Rows are labelled with the cell names of the AnnData object that was embedded,
    # so the index always matches the row order.
    df = pd.DataFrame(cell_embeddings, index=adata_t.obs.index)
    file_out = latent_path + "_scGPT_" + str(rep) + ".txt"
    df.to_csv(file_out)

    # Clean up
    # Drop the large per-replicate objects before the next replicate allocates its own.
    del adata
    if per_seq_batch_sample:
        del adata_sorted
    del adata_t
    del best_model
    del tokenized_all
    del all_counts
    del model
    del tokenized_train
    del tokenized_valid

    # End the logger and the run
    run.finish()
    wandb.finish()
    gc.collect()

    

In [None]:
## Define variables
# Dataset identifier: stored in the wandb config as `dataset_name` and used in the save directory name.
dataset_id = "Human_Kidney"
# AnnData file that is loaded with sc.read() at the start of every replicate.
adata_path = "/work/NMF_project/reproducibility/data/Kidney/Human_Kidney_sub.h5ad"
# Plain-text file that is split into lines to obtain the highly variable gene names.
hvg_path = "/work/NMF_project/reproducibility/data/Kidney/Human_Kidney_sub.features"
# Output prefix: each replicate writes its embeddings to "<latent_path>_scGPT_<rep>.txt".
latent_path = "/work/NMF_project/reproducibility/data/Kidney/Human_Kidney"

# Fine-tune five replicates; `rep` is part of the save directory and of the output file name.
# NOTE(review): every replicate calls set_seed(config.seed) with the same seed (42), so all
# replicates start from an identical seed. If independent replicates are intended, the seed
# should probably vary with `rep` -- confirm.
for rep in range(5):
    ## Set hyperparameters
    hyperparameter_defaults = dict(
        seed=42,
        dataset_name=dataset_id, # Dataset name
        do_train=True, # Flag to indicate whether to update model parameters during training
        load_model="/work/NMF_project/reproducibility/data/scGPT/Model/", # Path to pre-trained model
        GEPC=True,  # Gene expression modelling for cell objective
        ecs_thres=0.8,  # Elastic cell similarity objective, 0.0 to 1.0, 0.0 to disable
        dab_weight=1.0, # DAR objective weight for batch correction
        mask_ratio=0.4, # Default mask ratio
        epochs=15, # Default number of epochs for fine-tuning
        n_bins=51, # Default number of bins for value binning in data pre-processing
        lr=1e-4, # Default learning rate for fine-tuning
        batch_size=64, # Default batch size for fine-tuning
        layer_size=128,
        nlayers=4,
        nhead=4, # when load_model is set, layer_size, nlayers and nhead are replaced by the values in the checkpoint's args.json
        dropout=0.2, # Default dropout rate during model fine-tuning
        schedule_ratio=0.9,  # Default rate for learning rate decay
        save_eval_interval=5, # Default model evaluation interval
        log_interval=100, # Default log interval
        fast_transformer=True, # Default setting
        pre_norm=False, # Default setting
        amp=True,  # Default setting: Automatic Mixed Precision
    )

    ## Initialize the run on wandb
    # A new wandb run is opened for every replicate (reinit=True); the hyperparameters
    # are read back through `wandb.config` for the rest of the loop body.
    run = wandb.init(
        config=hyperparameter_defaults,
        project="scGPT",
        reinit=True,
        settings=wandb.Settings(start_method="fork"),
    )
    config = wandb.config
    print(config)
    set_seed(config.seed)

    # Settings for input and preprocessing
    pad_token = "<pad>"
    special_tokens = [pad_token, "<cls>", "<eoc>"]
    # mask_ratio / mask_value are not used further down in this cell; presumably they are
    # read as globals by the data-preparation helpers defined earlier -- confirm.
    mask_ratio = config.mask_ratio
    mask_value = -1
    pad_value = -2
    n_input_bins = config.n_bins
    n_hvg = 1200  # number of highly variable genes
    max_seq_len = n_hvg + 1  # one extra position for the <cls> token added during tokenization
    per_seq_batch_sample = True
    DSBN = True  # Domain-spec batchnorm
    explicit_zero_prob = True  # whether explicit bernoulli for zeros

    # Settings for saving the model
    dataset_name = config.dataset_name
    save_dir = Path(f"/work/NMF_project/reproducibility/data/scGPT/FT/dev_{dataset_name}_{rep}/")
    save_dir.mkdir(parents=True, exist_ok=True)
    print(f"save to {save_dir}")
    logger = scg.logger
    # `scg.logger` is one shared logger object, so a file handler attached for an earlier
    # replicate (or an earlier dataset cell) would keep receiving this replicate's messages
    # and keep its run.log open. Detach and close those handlers before adding the new one.
    import logging  # local import keeps this fix self-contained; ideally it lives in the imports cell

    for stale_handler in [h for h in logger.handlers if isinstance(h, logging.FileHandler)]:
        logger.removeHandler(stale_handler)
        stale_handler.close()
    scg.utils.add_file_handler(logger, save_dir / "run.log")

    # Load data
    adata = sc.read(adata_path)
    ori_batch_col = "batch_label"
    # Use the "features" column of `var` as the gene index.
    adata.var = adata.var.set_index("features")
    data_is_raw = True

    # make the batch category column
    # Batch labels are cast to strings and then encoded as integer category codes.
    adata.obs["str_batch"] = adata.obs[ori_batch_col].astype(str)
    batch_id_labels = adata.obs["str_batch"].astype("category").cat.codes.values
    adata.obs["batch_id"] = batch_id_labels
    adata.var["gene_name"] = adata.var.index.tolist()

    # Define HVGs
    # The context manager closes the file even if reading fails.
    with open(hvg_path, "r") as f:
        hvg = f.read().splitlines()
    # Vectorised membership test: True for every gene that is listed in the HVG file.
    adata.var["highly_variable"] = adata.var_names.isin(hvg)

    # Load the pretrained model
    # With a checkpoint directory configured, the vocabulary and the architecture
    # settings come from that directory; otherwise the wandb config values are used.
    if config.load_model is not None:
        model_dir = Path(config.load_model)
        model_config_file = model_dir / "args.json"
        model_file = model_dir / "best_model.pt"
        vocab_file = model_dir / "vocab.json"

        vocab = GeneVocab.from_file(vocab_file)
        # Make sure the special tokens used by this notebook exist in the vocabulary.
        for s in special_tokens:
            if s not in vocab:
                vocab.append_token(s)

        # "id_in_vocab" is only a membership flag (1 = in vocabulary, -1 = missing),
        # not the actual token id.
        adata.var["id_in_vocab"] = [
            1 if gene in vocab else -1 for gene in adata.var["gene_name"]
        ]
        gene_ids_in_vocab = np.array(adata.var["id_in_vocab"])
        logger.info(
            f"match {np.sum(gene_ids_in_vocab >= 0)}/{len(gene_ids_in_vocab)} genes "
            f"in vocabulary of size {len(vocab)}."
        )
        # Keep only the genes that are present in the vocabulary.
        adata = adata[:, adata.var["id_in_vocab"] >= 0]

        # model
        # Architecture hyperparameters are taken from the checkpoint's args.json.
        with open(model_config_file, "r") as f:
            model_configs = json.load(f)
        logger.info(
            f"Resume model from {model_file}, the model args will be overriden by the "
            f"config {model_config_file}."
        )
        embsize = model_configs["embsize"]
        nhead = model_configs["nheads"]
        d_hid = model_configs["d_hid"]
        nlayers = model_configs["nlayers"]
        # Read from the checkpoint config but not used anywhere else in this cell.
        n_layers_cls = model_configs["n_layers_cls"]
    else:
        # No checkpoint: the vocabulary is built later from the dataset's own genes.
        embsize = config.layer_size
        nhead = config.nhead
        nlayers = config.nlayers
        d_hid = config.layer_size

    # Preprocess the dataset
    # Drop genes with fewer than 3 total counts, then store a library-size normalised
    # layer ("X_normed") and its log1p transform ("X_log1p").
    sc.pp.filter_genes(adata, min_counts=3)
    normed = sc.pp.normalize_total(adata, target_sum=1e4, layer=None, inplace=False)["X"]
    sc.get._set_obs_rep(adata, normed, layer="X_normed")
    sc.get._set_obs_rep(adata, sc.get._get_obs_rep(adata, layer="X_normed"), layer="X_log1p")
    sc.pp.log1p(adata, layer="X_log1p")
    # Restrict to the highly variable genes. The explicit copy avoids assigning new
    # layers to a view further down (AnnData would otherwise copy implicitly and warn).
    adata = adata[:, adata.var.highly_variable].copy()

    # Per-cell quantile binning of the non-zero log1p values.
    n_bins = config.n_bins  # NOTE: the first bin (0) is always the special bin for zero
    binned_rows = []
    bin_edges = []
    n_empty_cells = 0
    layer_data = sc.get._get_obs_rep(adata, layer="X_log1p")
    layer_data = layer_data.A if issparse(layer_data) else layer_data
    for row in layer_data:
        non_zero_ids = row.nonzero()
        non_zero_row = row[non_zero_ids]
        if non_zero_row.size == 0:
            # A cell without counts in any selected gene has nothing to take quantiles
            # of; keep it as an all-zero row with all-zero edges so the shapes still match.
            n_empty_cells += 1
            binned_rows.append(np.zeros_like(row, dtype=np.int64))
            bin_edges.append(np.zeros(n_bins))
            continue
        # n_bins - 1 quantile edges over the non-zero values of this cell.
        bins = np.quantile(non_zero_row, np.linspace(0, 1, n_bins - 1))
        non_zero_digits = _digitize(x=non_zero_row, bins=bins)
        assert non_zero_digits.min() >= 1
        assert non_zero_digits.max() <= n_bins - 1
        binned_row = np.zeros_like(row, dtype=np.int64)
        binned_row[non_zero_ids] = non_zero_digits
        binned_rows.append(binned_row)
        # Prepend 0 so every cell stores n_bins edge values.
        bin_edges.append(np.concatenate([[0], bins]))
    if n_empty_cells:
        logger.warning(
            f"{n_empty_cells} cells have no counts in the selected genes and were binned as all zeros."
        )
    adata.layers["X_binned"] = np.stack(binned_rows)
    adata.obsm["bin_edges"] = np.stack(bin_edges)

    # Sort the adata by batch_id in advance
    # `adata_sorted` is only created when per_seq_batch_sample is True.
    if per_seq_batch_sample:
        adata_sorted = adata[adata.obs["batch_id"].argsort()].copy()

    # Define input layers and get counts
    # The binned layer is the model input; it is densified if it is stored sparse.
    input_layer_key = "X_binned"
    all_counts = (
        adata.layers[input_layer_key].A
        if issparse(adata.layers[input_layer_key])
        else adata.layers[input_layer_key]
    )
    genes = adata.var["gene_name"].tolist()

    # Get batch ids
    batch_ids = adata.obs["batch_id"].tolist()
    num_batch_types = len(set(batch_ids))  # number of distinct batches in this dataset
    batch_ids = np.array(batch_ids)

    # Create splits
    # 90/10 shuffled split of cells and their batch labels. No random_state is passed,
    # so the split depends on the global RNG state -- presumably fixed by set_seed()
    # at the top of the loop; confirm.
    (
        train_data,
        valid_data,
        train_batch_labels,
        valid_batch_labels,
    ) = train_test_split(
        all_counts, batch_ids, test_size=0.1, shuffle=True
    )

    # Define vocabulary
    # Without a pretrained checkpoint the vocabulary is built from the dataset's genes
    # plus the special tokens; otherwise the checkpoint vocabulary loaded earlier is kept.
    if config.load_model is None:
        vocab = Vocab(
            VocabPybind(genes + special_tokens, None)
        )  # bidirectional lookup [gene <-> int]
    # Lookups of unknown tokens resolve to the <pad> index.
    vocab.set_default_index(vocab["<pad>"])
    gene_ids = np.array(vocab(genes), dtype=int)

    # Tokenize training and validation data
    # Both splits share every tokenizer setting, so the keyword arguments are spelled out once.
    tokenize_kwargs = dict(
        max_len=max_seq_len,
        vocab=vocab,
        pad_token=pad_token,
        pad_value=pad_value,
        append_cls=True,  # append <cls> token at the beginning
        include_zero_gene=True,
    )
    tokenized_train = tokenize_and_pad_batch(train_data, gene_ids, **tokenize_kwargs)
    tokenized_valid = tokenize_and_pad_batch(valid_data, gene_ids, **tokenize_kwargs)

    # Report the size of each tokenized split (number of cells, sequence length).
    train_gene_tokens = tokenized_train["genes"]
    valid_gene_tokens = tokenized_valid["genes"]
    logger.info(
        f"train set number of samples: {train_gene_tokens.shape[0]}, "
        f"\n\t feature length: {train_gene_tokens.shape[1]}"
    )
    logger.info(
        f"valid set number of samples: {valid_gene_tokens.shape[0]}, "
        f"\n\t feature length: {valid_gene_tokens.shape[1]}"
    )

    # Load the model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ntokens = len(vocab)  # size of vocabulary
    model = TransformerModel(
        ntokens,
        embsize,
        nhead,
        d_hid,
        nlayers,
        vocab=vocab,
        dropout=config.dropout,
        pad_token=pad_token,
        pad_value=pad_value,
        do_mvc=config.GEPC,
        do_dab=True,
        use_batch_labels=True,
        num_batch_labels=num_batch_types,
        domain_spec_batchnorm=DSBN,
        n_input_bins=n_input_bins,
        ecs_threshold=config.ecs_thres,
        explicit_zero_prob=explicit_zero_prob,
        use_fast_transformer=config.fast_transformer,
        pre_norm=config.pre_norm,
    )
    if config.load_model is not None:
        # Read the checkpoint from disk once. The model is still on the CPU at this point,
        # so the weights are mapped to the CPU as well; this also lets a checkpoint saved
        # on a GPU be opened on a CPU-only machine.
        pretrained_dict = torch.load(model_file, map_location="cpu")
        try:
            model.load_state_dict(pretrained_dict)
            logger.info(f"Loading all model params from {model_file}")
        except Exception as load_error:
            # `except Exception` (not a bare except) so Ctrl-C / SystemExit still propagate.
            logger.warning(
                f"Could not load the full state dict ({type(load_error).__name__}); "
                "falling back to parameters with matching name and shape."
            )
            # only load params that are in the model and match the size
            model_dict = model.state_dict()
            pretrained_dict = {
                k: v
                for k, v in pretrained_dict.items()
                if k in model_dict and v.shape == model_dict[k].shape
            }
            for k, v in pretrained_dict.items():
                logger.info(f"Loading params {k} with shape {v.shape}")
            model_dict.update(pretrained_dict)
            model.load_state_dict(model_dict)

    model.to(device)
    wandb.watch(model)

    # Set model criteria
    # NOTE(review): criterion, criterion_dab, optimizer, scheduler and scaler are not passed
    # to train()/evaluate() below; presumably those helpers read them as globals -- confirm
    # before renaming any of them.
    criterion = masked_mse_loss
    criterion_dab = nn.CrossEntropyLoss()
    # A larger Adam epsilon is used when mixed precision is enabled.
    optimizer = torch.optim.Adam(
        model.parameters(), lr=config.lr, eps=1e-4 if config.amp else 1e-8
    )
    # Step size 1 with one scheduler.step() per epoch: the learning rate is multiplied
    # by schedule_ratio after every epoch.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=config.schedule_ratio)
    scaler = torch.cuda.amp.GradScaler(enabled=config.amp)

    # Train the model
    best_val_loss = float("inf")
    best_avg_bio = 0.0  # not read again in this cell
    # Stays None unless some epoch reports a validation loss below +inf.
    best_model = None
    define_wandb_metrcis()

    for epoch in range(1, config.epochs + 1):
        epoch_start_time = time.time()
        # Data tensors and loaders are rebuilt at the start of every epoch.
        train_data_pt, valid_data_pt = prepare_data(sort_seq_batch=per_seq_batch_sample)
        train_loader = prepare_dataloader(
            train_data_pt,
            batch_size=config.batch_size,
            shuffle=False,
            intra_domain_shuffle=True,
            drop_last=False,
        )
        valid_loader = prepare_dataloader(
            valid_data_pt,
            batch_size=config.batch_size,
            shuffle=False,
            intra_domain_shuffle=False,
            drop_last=False,
        )

        # With do_train=False the model is only evaluated.
        if config.do_train:
            train(
                model,
                loader=train_loader,
            )
        val_loss, val_mre = evaluate(
            model,
            loader=valid_loader,
        )
        elapsed = time.time() - epoch_start_time
        logger.info("-" * 89)
        logger.info(
            f"| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | "
            f"valid loss/mse {val_loss:5.4f} | mre {val_mre:5.4f}"
        )
        logger.info("-" * 89)

        # Keep a deep copy of the model with the lowest validation loss seen so far.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model = copy.deepcopy(model)
            best_model_epoch = epoch  # recorded but not read again in this cell
            logger.info(f"Best model with score {best_val_loss:5.4f}")

        scheduler.step()
        
    # Extract embeddings
    if best_model is None:
        # No epoch reported a validation loss below +inf (e.g. NaN losses or zero epochs);
        # fall back to the final weights instead of failing on None.
        logger.warning("No best model was selected during training; using the final model.")
        best_model = model
    best_model.eval()
    # Honour the same flag that decides whether `adata_sorted` exists at all. The data is
    # only read below, so no additional AnnData copy is needed.
    adata_t = adata_sorted if per_seq_batch_sample else adata
    all_counts = (
        adata_t.layers[input_layer_key].A
        if issparse(adata_t.layers[input_layer_key])
        else adata_t.layers[input_layer_key]
    )
    batch_ids = np.array(adata_t.obs["batch_id"].tolist())
    tokenized_all = tokenize_and_pad_batch(
        all_counts,
        gene_ids,
        max_len=max_seq_len,
        vocab=vocab,
        pad_token=pad_token,
        pad_value=pad_value,
        append_cls=True,  # append <cls> token at the beginning
        include_zero_gene=True,
    )
    all_gene_ids, all_values = tokenized_all["genes"], tokenized_all["values"]
    # True wherever a position holds the <pad> token.
    src_key_padding_mask = all_gene_ids.eq(vocab[pad_token])
    with torch.no_grad(), torch.cuda.amp.autocast(enabled=config.amp):
        cell_embeddings = best_model.encode_batch(
            all_gene_ids,
            all_values.float(),
            src_key_padding_mask=src_key_padding_mask,
            batch_size=config.batch_size,
            batch_labels=torch.from_numpy(batch_ids).long() if DSBN else None,
            time_step=0,
            return_np=True,
        )
    # Scale every row (one embedding per cell) to unit L2 norm.
    cell_embeddings = cell_embeddings / np.linalg.norm(
        cell_embeddings, axis=1, keepdims=True
    )

    # Save the embeddings
    # Rows are labelled with the cell names of the AnnData object that was embedded,
    # so the index always matches the row order.
    df = pd.DataFrame(cell_embeddings, index=adata_t.obs.index)
    file_out = latent_path + "_scGPT_" + str(rep) + ".txt"
    df.to_csv(file_out)

    # Clean up
    # Drop the large per-replicate objects before the next replicate allocates its own.
    del adata
    if per_seq_batch_sample:
        del adata_sorted
    del adata_t
    del best_model
    del tokenized_all
    del all_counts
    del model
    del tokenized_train
    del tokenized_valid

    # End the logger and the run
    run.finish()
    wandb.finish()
    gc.collect()

    

In [None]:
## Define variables
# Dataset identifier: stored in the wandb config as `dataset_name` and used in the save directory name.
dataset_id = "PBMC"
# AnnData file that is loaded with sc.read() at the start of every replicate.
adata_path = "/work/NMF_project/reproducibility/data/PBMC/PBMC.h5ad"
# Plain-text file that is split into lines to obtain the highly variable gene names.
hvg_path = "/work/NMF_project/reproducibility/data/PBMC/PBMC.features"
# Output path prefix for the saved embeddings (presumably "<latent_path>_scGPT_<rep>.txt",
# as in the other dataset cells -- confirm).
latent_path = "/work/NMF_project/reproducibility/data/PBMC/PBMC"

# Fine-tune five replicates; `rep` is part of the save directory name.
# NOTE(review): every replicate calls set_seed(config.seed) with the same seed (42), so all
# replicates start from an identical seed. If independent replicates are intended, the seed
# should probably vary with `rep` -- confirm.
for rep in range(5):
    ## Set hyperparameters
    hyperparameter_defaults = dict(
        seed=42,
        dataset_name=dataset_id, # Dataset name
        do_train=True, # Flag to indicate whether to update model parameters during training
        load_model="/work/NMF_project/reproducibility/data/scGPT/Model/", # Path to pre-trained model
        GEPC=True,  # Gene expression modelling for cell objective
        ecs_thres=0.8,  # Elastic cell similarity objective, 0.0 to 1.0, 0.0 to disable
        dab_weight=1.0, # DAR objective weight for batch correction
        mask_ratio=0.4, # Default mask ratio
        epochs=15, # Default number of epochs for fine-tuning
        n_bins=51, # Default number of bins for value binning in data pre-processing
        lr=1e-4, # Default learning rate for fine-tuning
        batch_size=64, # Default batch size for fine-tuning
        layer_size=128,
        nlayers=4,
        nhead=4, # when load_model is set, layer_size, nlayers and nhead are replaced by the values in the checkpoint's args.json
        dropout=0.2, # Default dropout rate during model fine-tuning
        schedule_ratio=0.9,  # Default rate for learning rate decay
        save_eval_interval=5, # Default model evaluation interval
        log_interval=100, # Default log interval
        fast_transformer=True, # Default setting
        pre_norm=False, # Default setting
        amp=True,  # Default setting: Automatic Mixed Precision
    )

    ## Initialize the run on wandb
    # A new wandb run is opened for every replicate (reinit=True); the hyperparameters
    # are read back through `wandb.config` for the rest of the loop body.
    run = wandb.init(
        config=hyperparameter_defaults,
        project="scGPT",
        reinit=True,
        settings=wandb.Settings(start_method="fork"),
    )
    config = wandb.config
    print(config)
    set_seed(config.seed)

    # Settings for input and preprocessing
    pad_token = "<pad>"
    special_tokens = [pad_token, "<cls>", "<eoc>"]
    # mask_ratio / mask_value are presumably read as globals by the data-preparation
    # helpers defined earlier -- confirm.
    mask_ratio = config.mask_ratio
    mask_value = -1
    pad_value = -2
    n_input_bins = config.n_bins
    n_hvg = 1200  # number of highly variable genes
    max_seq_len = n_hvg + 1  # one extra position for the <cls> token added during tokenization
    per_seq_batch_sample = True
    DSBN = True  # Domain-spec batchnorm
    explicit_zero_prob = True  # whether explicit bernoulli for zeros

    # Settings for saving the model
    dataset_name = config.dataset_name
    save_dir = Path(f"/work/NMF_project/reproducibility/data/scGPT/FT/dev_{dataset_name}_{rep}/")
    save_dir.mkdir(parents=True, exist_ok=True)
    print(f"save to {save_dir}")
    logger = scg.logger
    # `scg.logger` is one shared logger object, so a file handler attached for an earlier
    # replicate (or an earlier dataset cell) would keep receiving this replicate's messages
    # and keep its run.log open. Detach and close those handlers before adding the new one.
    import logging  # local import keeps this fix self-contained; ideally it lives in the imports cell

    for stale_handler in [h for h in logger.handlers if isinstance(h, logging.FileHandler)]:
        logger.removeHandler(stale_handler)
        stale_handler.close()
    scg.utils.add_file_handler(logger, save_dir / "run.log")

    # Load data
    adata = sc.read(adata_path)
    ori_batch_col = "batch_label"
    # Use the "features" column of `var` as the gene index.
    adata.var = adata.var.set_index("features")
    data_is_raw = True

    # make the batch category column
    # Batch labels are cast to strings and then encoded as integer category codes.
    adata.obs["str_batch"] = adata.obs[ori_batch_col].astype(str)
    batch_id_labels = adata.obs["str_batch"].astype("category").cat.codes.values
    adata.obs["batch_id"] = batch_id_labels
    adata.var["gene_name"] = adata.var.index.tolist()

    # Define HVGs
    # The context manager closes the file even if reading fails.
    with open(hvg_path, "r") as f:
        hvg = f.read().splitlines()
    # Vectorised membership test: True for every gene that is listed in the HVG file.
    adata.var["highly_variable"] = adata.var_names.isin(hvg)

    # Load the pretrained model
    # With a checkpoint directory configured, the vocabulary and the architecture
    # settings come from that directory; otherwise the wandb config values are used.
    if config.load_model is not None:
        model_dir = Path(config.load_model)
        model_config_file = model_dir / "args.json"
        model_file = model_dir / "best_model.pt"
        vocab_file = model_dir / "vocab.json"

        vocab = GeneVocab.from_file(vocab_file)
        # Make sure the special tokens used by this notebook exist in the vocabulary.
        for s in special_tokens:
            if s not in vocab:
                vocab.append_token(s)

        # "id_in_vocab" is only a membership flag (1 = in vocabulary, -1 = missing),
        # not the actual token id.
        adata.var["id_in_vocab"] = [
            1 if gene in vocab else -1 for gene in adata.var["gene_name"]
        ]
        gene_ids_in_vocab = np.array(adata.var["id_in_vocab"])
        logger.info(
            f"match {np.sum(gene_ids_in_vocab >= 0)}/{len(gene_ids_in_vocab)} genes "
            f"in vocabulary of size {len(vocab)}."
        )
        # Keep only the genes that are present in the vocabulary.
        adata = adata[:, adata.var["id_in_vocab"] >= 0]

        # model
        # Architecture hyperparameters are taken from the checkpoint's args.json.
        with open(model_config_file, "r") as f:
            model_configs = json.load(f)
        logger.info(
            f"Resume model from {model_file}, the model args will be overriden by the "
            f"config {model_config_file}."
        )
        embsize = model_configs["embsize"]
        nhead = model_configs["nheads"]
        d_hid = model_configs["d_hid"]
        nlayers = model_configs["nlayers"]
        # Read from the checkpoint config; no later use is visible in this part of the cell.
        n_layers_cls = model_configs["n_layers_cls"]
    else:
        # No checkpoint: the vocabulary is built later from the dataset's own genes.
        embsize = config.layer_size
        nhead = config.nhead
        nlayers = config.nlayers
        d_hid = config.layer_size

    # Preprocess the dataset
    # Drop genes with fewer than 3 total counts, then store a library-size normalised
    # layer ("X_normed") and its log1p transform ("X_log1p").
    sc.pp.filter_genes(adata, min_counts=3)
    normed = sc.pp.normalize_total(adata, target_sum=1e4, layer=None, inplace=False)["X"]
    sc.get._set_obs_rep(adata, normed, layer="X_normed")
    sc.get._set_obs_rep(adata, sc.get._get_obs_rep(adata, layer="X_normed"), layer="X_log1p")
    sc.pp.log1p(adata, layer="X_log1p")
    # Restrict to the highly variable genes. The explicit copy avoids assigning new
    # layers to a view further down (AnnData would otherwise copy implicitly and warn).
    adata = adata[:, adata.var.highly_variable].copy()

    # Per-cell quantile binning of the non-zero log1p values.
    n_bins = config.n_bins  # NOTE: the first bin (0) is always the special bin for zero
    binned_rows = []
    bin_edges = []
    n_empty_cells = 0
    layer_data = sc.get._get_obs_rep(adata, layer="X_log1p")
    layer_data = layer_data.A if issparse(layer_data) else layer_data
    for row in layer_data:
        non_zero_ids = row.nonzero()
        non_zero_row = row[non_zero_ids]
        if non_zero_row.size == 0:
            # A cell without counts in any selected gene has nothing to take quantiles
            # of; keep it as an all-zero row with all-zero edges so the shapes still match.
            n_empty_cells += 1
            binned_rows.append(np.zeros_like(row, dtype=np.int64))
            bin_edges.append(np.zeros(n_bins))
            continue
        # n_bins - 1 quantile edges over the non-zero values of this cell.
        bins = np.quantile(non_zero_row, np.linspace(0, 1, n_bins - 1))
        non_zero_digits = _digitize(x=non_zero_row, bins=bins)
        assert non_zero_digits.min() >= 1
        assert non_zero_digits.max() <= n_bins - 1
        binned_row = np.zeros_like(row, dtype=np.int64)
        binned_row[non_zero_ids] = non_zero_digits
        binned_rows.append(binned_row)
        # Prepend 0 so every cell stores n_bins edge values.
        bin_edges.append(np.concatenate([[0], bins]))
    if n_empty_cells:
        logger.warning(
            f"{n_empty_cells} cells have no counts in the selected genes and were binned as all zeros."
        )
    adata.layers["X_binned"] = np.stack(binned_rows)
    adata.obsm["bin_edges"] = np.stack(bin_edges)

    # Sort the adata by batch_id in advance
    if per_seq_batch_sample:
        adata_sorted = adata[adata.obs["batch_id"].argsort()].copy()

    # Define input layers and get counts
    input_layer_key = "X_binned"
    all_counts = (
        adata.layers[input_layer_key].A
        if issparse(adata.layers[input_layer_key])
        else adata.layers[input_layer_key]
    )
    genes = adata.var["gene_name"].tolist()

    # Get batch ids
    batch_ids = adata.obs["batch_id"].tolist()
    num_batch_types = len(set(batch_ids))
    batch_ids = np.array(batch_ids)

    # Create splits
    (
        train_data,
        valid_data,
        train_batch_labels,
        valid_batch_labels,
    ) = train_test_split(
        all_counts, batch_ids, test_size=0.1, shuffle=True
    )

    # Define vocabulary
    if config.load_model is None:
        vocab = Vocab(
            VocabPybind(genes + special_tokens, None)
        )  # bidirectional lookup [gene <-> int]
    vocab.set_default_index(vocab["<pad>"])
    gene_ids = np.array(vocab(genes), dtype=int)

    # Tokenize training and validation data
    tokenized_train = tokenize_and_pad_batch(
        train_data,
        gene_ids,
        max_len=max_seq_len,
        vocab=vocab,
        pad_token=pad_token,
        pad_value=pad_value,
        append_cls=True,  # append <cls> token at the beginning
        include_zero_gene=True,
    )
    tokenized_valid = tokenize_and_pad_batch(
        valid_data,
        gene_ids,
        max_len=max_seq_len,
        vocab=vocab,
        pad_token=pad_token,
        pad_value=pad_value,
        append_cls=True,
        include_zero_gene=True,
    )
    logger.info(
        f"train set number of samples: {tokenized_train['genes'].shape[0]}, "
        f"\n\t feature length: {tokenized_train['genes'].shape[1]}"
    )
    logger.info(
        f"valid set number of samples: {tokenized_valid['genes'].shape[0]}, "
        f"\n\t feature length: {tokenized_valid['genes'].shape[1]}"
    )

    # Load the model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ntokens = len(vocab)  # size of vocabulary
    model = TransformerModel(
        ntokens,
        embsize,
        nhead,
        d_hid,
        nlayers,
        vocab=vocab,
        dropout=config.dropout,
        pad_token=pad_token,
        pad_value=pad_value,
        do_mvc=config.GEPC,
        do_dab=True,
        use_batch_labels=True,
        num_batch_labels=num_batch_types,
        domain_spec_batchnorm=DSBN,
        n_input_bins=n_input_bins,
        ecs_threshold=config.ecs_thres,
        explicit_zero_prob=explicit_zero_prob,
        use_fast_transformer=config.fast_transformer,
        pre_norm=config.pre_norm,
    )
    if config.load_model is not None:
        try:
            model.load_state_dict(torch.load(model_file))
            logger.info(f"Loading all model params from {model_file}")
        except:
            # only load params that are in the model and match the size
            model_dict = model.state_dict()
            pretrained_dict = torch.load(model_file)
            pretrained_dict = {
                k: v
                for k, v in pretrained_dict.items()
                if k in model_dict and v.shape == model_dict[k].shape
            }
            for k, v in pretrained_dict.items():
                logger.info(f"Loading params {k} with shape {v.shape}")
            model_dict.update(pretrained_dict)
            model.load_state_dict(model_dict)

    model.to(device)
    wandb.watch(model)

    # Set model criteria
    criterion = masked_mse_loss
    criterion_dab = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(
        model.parameters(), lr=config.lr, eps=1e-4 if config.amp else 1e-8
    )
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=config.schedule_ratio)
    scaler = torch.cuda.amp.GradScaler(enabled=config.amp)

    # Train the model
    best_val_loss = float("inf")
    best_avg_bio = 0.0
    best_model = None
    define_wandb_metrcis()

    for epoch in range(1, config.epochs + 1):
        epoch_start_time = time.time()
        train_data_pt, valid_data_pt = prepare_data(sort_seq_batch=per_seq_batch_sample)
        train_loader = prepare_dataloader(
            train_data_pt,
            batch_size=config.batch_size,
            shuffle=False,
            intra_domain_shuffle=True,
            drop_last=False,
        )
        valid_loader = prepare_dataloader(
            valid_data_pt,
            batch_size=config.batch_size,
            shuffle=False,
            intra_domain_shuffle=False,
            drop_last=False,
        )

        if config.do_train:
            train(
                model,
                loader=train_loader,
            )
        val_loss, val_mre = evaluate(
            model,
            loader=valid_loader,
        )
        elapsed = time.time() - epoch_start_time
        logger.info("-" * 89)
        logger.info(
            f"| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | "
            f"valid loss/mse {val_loss:5.4f} | mre {val_mre:5.4f}"
        )
        logger.info("-" * 89)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model = copy.deepcopy(model)
            best_model_epoch = epoch
            logger.info(f"Best model with score {best_val_loss:5.4f}")

        scheduler.step()
        
    # Extract embeddings
    best_model.eval()
    adata_t = adata_sorted
    adata_t = adata_t.copy()
    all_counts = (
        adata_t.layers[input_layer_key].A
        if issparse(adata_t.layers[input_layer_key])
        else adata_t.layers[input_layer_key]
    )
    batch_ids = adata_t.obs["batch_id"].tolist()
    batch_ids = np.array(batch_ids)
    tokenized_all = tokenize_and_pad_batch(
        all_counts,
        gene_ids,
        max_len=max_seq_len,
        vocab=vocab,
        pad_token=pad_token,
        pad_value=pad_value,
        append_cls=True,  # append <cls> token at the beginning
        include_zero_gene=True,
    )
    all_gene_ids, all_values = tokenized_all["genes"], tokenized_all["values"]
    src_key_padding_mask = all_gene_ids.eq(vocab[pad_token])
    with torch.no_grad(), torch.cuda.amp.autocast(enabled=config.amp):
        cell_embeddings = best_model.encode_batch(
            all_gene_ids,
            all_values.float(),
            src_key_padding_mask=src_key_padding_mask,
            batch_size=config.batch_size,
            batch_labels=torch.from_numpy(batch_ids).long() if DSBN else None,
            time_step=0,
            return_np=True,
        )
    cell_embeddings = cell_embeddings / np.linalg.norm(
        cell_embeddings, axis=1, keepdims=True
    )
    
    # Save the embeddings
    df = pd.DataFrame(cell_embeddings, index=adata_sorted.obs.index)
    file_out = latent_path + "_scGPT_" + str(rep) + ".txt"
    df.to_csv(file_out)
    
    # Clean up
    del adata
    del adata_sorted
    del adata_t
    del best_model
    del tokenized_all
    del all_counts
    del model
    del tokenized_train
    del tokenized_valid
    
    # End the logger and the run
    run.finish()
    wandb.finish()
    gc.collect()

    

In [None]:
## Define variables
# Fine-tune the pre-trained scGPT model on the Human Lung dataset and write one
# table of L2-normalised cell embeddings per replicate.
# NOTE(review): the helpers called below (`_digitize`, `prepare_data`,
# `prepare_dataloader`, `train`, `evaluate`, `define_wandb_metrcis`) are defined
# in earlier cells and presumably read the module-level names assigned in this
# cell (`config`, `model`, `tokenized_train`, ...). Do not rename them without
# checking those helpers.
dataset_id = "Human_Lung"
adata_path = "/work/NMF_project/reproducibility/data/Lung/Human_Lung_sub.h5ad"
hvg_path = "/work/NMF_project/reproducibility/data/Lung/Human_Lung_sub.features"
latent_path = "/work/NMF_project/reproducibility/data/Lung/Human_Lung"

for rep in range(5):
    ## Set hyperparameters
    # NOTE(review): the seed is the same for every replicate, so replicates can
    # only differ through non-deterministic operations -- confirm this is intended.
    hyperparameter_defaults = dict(
        seed=42,
        dataset_name=dataset_id, # Dataset name
        do_train=True, # Flag to indicate whether to do update model parameters during training
        load_model="/work/NMF_project/reproducibility/data/scGPT/Model/", # Path to pre-trained model
        GEPC=True,  # Gene expression modelling for cell objective
        ecs_thres=0.8,  # Elastic cell similarity objective, 0.0 to 1.0, 0.0 to disable
        dab_weight=1.0, # DAR objective weight for batch correction
        mask_ratio=0.4, # Default mask ratio
        epochs=15, # Default number of epochs for fine-tuning
        n_bins=51, # Default number of bins for value binning in data pre-processing
        lr=1e-4, # Default learning rate for fine-tuning
        batch_size=64, # Default batch size for fine-tuning
        layer_size=128,
        nlayers=4,
        nhead=4, # if load model, batch_size, layer_size, nlayers, nhead will be ignored
        dropout=0.2, # Default dropout rate during model fine-tuning
        schedule_ratio=0.9,  # Default rate for learning rate decay
        save_eval_interval=5, # Default model evaluation interval
        log_interval=100, # Default log interval
        fast_transformer=True, # Default setting
        pre_norm=False, # Default setting
        amp=True,  # Default setting: Automatic Mixed Precision
    )

    ## Initialize the run on wandb
    run = wandb.init(
        config=hyperparameter_defaults,
        project="scGPT",
        reinit=True,
        settings=wandb.Settings(start_method="fork"),
    )
    config = wandb.config
    print(config)
    set_seed(config.seed)

    # Settings for input and preprocessing
    pad_token = "<pad>"
    special_tokens = [pad_token, "<cls>", "<eoc>"]
    mask_ratio = config.mask_ratio
    mask_value = -1
    pad_value = -2
    n_input_bins = config.n_bins
    n_hvg = 1200  # number of highly variable genes
    max_seq_len = n_hvg + 1  # +1 leaves room for the <cls> token added during tokenization
    per_seq_batch_sample = True
    DSBN = True  # Domain-spec batchnorm
    explicit_zero_prob = True  # whether explicit bernoulli for zeros

    # Settings for saving the model
    dataset_name = config.dataset_name
    save_dir = Path(f"/work/NMF_project/reproducibility/data/scGPT/FT/dev_{dataset_name}_{rep}/")
    save_dir.mkdir(parents=True, exist_ok=True)
    print(f"save to {save_dir}")
    logger = scg.logger
    # Remember which handlers this replicate attaches so they can be detached at
    # the end of the iteration; otherwise every later replicate would also write
    # into this replicate's run.log.
    handlers_before_run = list(logger.handlers)
    scg.utils.add_file_handler(logger, save_dir / "run.log")
    run_log_handlers = [h for h in logger.handlers if h not in handlers_before_run]

    # Load data
    adata = sc.read(adata_path)
    ori_batch_col = "batch_label"
    adata.var = adata.var.set_index("features")
    data_is_raw = True

    # make the batch category column
    adata.obs["str_batch"] = adata.obs[ori_batch_col].astype(str)
    batch_id_labels = adata.obs["str_batch"].astype("category").cat.codes.values
    adata.obs["batch_id"] = batch_id_labels
    adata.var["gene_name"] = adata.var.index.tolist()

    # Define HVGs: one gene name per line in the feature file
    with open(hvg_path, "r") as f:
        hvg = f.read().splitlines()
    adata.var["highly_variable"] = adata.var_names.isin(hvg)

    # Load the pretrained model
    if config.load_model is not None:
        model_dir = Path(config.load_model)
        model_config_file = model_dir / "args.json"
        model_file = model_dir / "best_model.pt"
        vocab_file = model_dir / "vocab.json"

        vocab = GeneVocab.from_file(vocab_file)
        for s in special_tokens:
            if s not in vocab:
                vocab.append_token(s)

        # Despite the column name this is only a membership flag (1 / -1),
        # not the token id of the gene.
        adata.var["id_in_vocab"] = [
            1 if gene in vocab else -1 for gene in adata.var["gene_name"]
        ]
        gene_ids_in_vocab = np.array(adata.var["id_in_vocab"])
        logger.info(
            f"match {np.sum(gene_ids_in_vocab >= 0)}/{len(gene_ids_in_vocab)} genes "
            f"in vocabulary of size {len(vocab)}."
        )
        # Explicit copy: the object is modified in place below, which a view
        # would only support through an implicit copy.
        adata = adata[:, adata.var["id_in_vocab"] >= 0].copy()

        # model
        with open(model_config_file, "r") as f:
            model_configs = json.load(f)
        logger.info(
            f"Resume model from {model_file}, the model args will be overriden by the "
            f"config {model_config_file}."
        )
        embsize = model_configs["embsize"]
        nhead = model_configs["nheads"]
        d_hid = model_configs["d_hid"]
        nlayers = model_configs["nlayers"]
        n_layers_cls = model_configs["n_layers_cls"]
    else:
        embsize = config.layer_size
        nhead = config.nhead
        nlayers = config.nlayers
        d_hid = config.layer_size

    # Preprocess the dataset
    sc.pp.filter_genes(adata, min_counts=3)
    normed = sc.pp.normalize_total(adata, target_sum=1e4, layer=None, inplace=False)["X"]
    sc.get._set_obs_rep(adata, normed, layer="X_normed")
    sc.get._set_obs_rep(adata,sc.get._get_obs_rep(adata, layer="X_normed"), layer="X_log1p")
    sc.pp.log1p(adata, layer="X_log1p")
    adata = adata[:, adata.var.highly_variable].copy()  # explicit copy, layers are added below
    n_bins = config.n_bins  # NOTE: bin 0 is reserved for zero expression
    binned_rows = []
    bin_edges = []
    layer_data = sc.get._get_obs_rep(adata, layer="X_log1p")
    layer_data = layer_data.A if issparse(layer_data) else layer_data
    n_empty_rows = 0
    for row in layer_data:
        non_zero_ids = row.nonzero()
        non_zero_row = row[non_zero_ids]
        if non_zero_row.size == 0:
            # No expressed HVG in this cell: quantiles of an empty array are
            # undefined, so emit an all-zero row with all-zero bin edges.
            n_empty_rows += 1
            binned_rows.append(np.zeros_like(row, dtype=np.int64))
            bin_edges.append(np.zeros(n_bins))
            continue
        # Per-cell quantile edges over the non-zero values only
        bins = np.quantile(non_zero_row, np.linspace(0, 1, n_bins - 1))
        non_zero_digits = _digitize(x = non_zero_row, bins = bins)
        assert non_zero_digits.min() >= 1
        assert non_zero_digits.max() <= n_bins - 1
        binned_row = np.zeros_like(row, dtype=np.int64)
        binned_row[non_zero_ids] = non_zero_digits
        binned_rows.append(binned_row)
        bin_edges.append(np.concatenate([[0], bins]))
    if n_empty_rows:
        logger.warning(
            f"{n_empty_rows} cells have no non-zero value among the selected genes; "
            f"their binned rows are all zero."
        )
    adata.layers["X_binned"] = np.stack(binned_rows)
    adata.obsm["bin_edges"] = np.stack(bin_edges)

    # Sort the adata by batch_id in advance
    if per_seq_batch_sample:
        adata_sorted = adata[adata.obs["batch_id"].argsort()].copy()
    else:
        # Keep the name bound: the embedding extraction below always reads it.
        adata_sorted = adata

    # Define input layers and get counts
    input_layer_key = "X_binned"
    all_counts = (
        adata.layers[input_layer_key].A
        if issparse(adata.layers[input_layer_key])
        else adata.layers[input_layer_key]
    )
    genes = adata.var["gene_name"].tolist()

    # Get batch ids
    batch_ids = adata.obs["batch_id"].tolist()
    num_batch_types = len(set(batch_ids))
    batch_ids = np.array(batch_ids)

    # Create splits (no random_state: the split relies on the global seed set above)
    (
        train_data,
        valid_data,
        train_batch_labels,
        valid_batch_labels,
    ) = train_test_split(
        all_counts, batch_ids, test_size=0.1, shuffle=True
    )

    # Define vocabulary
    if config.load_model is None:
        vocab = Vocab(
            VocabPybind(genes + special_tokens, None)
        )  # bidirectional lookup [gene <-> int]
    vocab.set_default_index(vocab["<pad>"])
    gene_ids = np.array(vocab(genes), dtype=int)

    # Tokenize training and validation data
    tokenized_train = tokenize_and_pad_batch(
        train_data,
        gene_ids,
        max_len=max_seq_len,
        vocab=vocab,
        pad_token=pad_token,
        pad_value=pad_value,
        append_cls=True,  # append <cls> token at the beginning
        include_zero_gene=True,
    )
    tokenized_valid = tokenize_and_pad_batch(
        valid_data,
        gene_ids,
        max_len=max_seq_len,
        vocab=vocab,
        pad_token=pad_token,
        pad_value=pad_value,
        append_cls=True,
        include_zero_gene=True,
    )
    logger.info(
        f"train set number of samples: {tokenized_train['genes'].shape[0]}, "
        f"\n\t feature length: {tokenized_train['genes'].shape[1]}"
    )
    logger.info(
        f"valid set number of samples: {tokenized_valid['genes'].shape[0]}, "
        f"\n\t feature length: {tokenized_valid['genes'].shape[1]}"
    )

    # Load the model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ntokens = len(vocab)  # size of vocabulary
    model = TransformerModel(
        ntokens,
        embsize,
        nhead,
        d_hid,
        nlayers,
        vocab=vocab,
        dropout=config.dropout,
        pad_token=pad_token,
        pad_value=pad_value,
        do_mvc=config.GEPC,
        do_dab=True,
        use_batch_labels=True,
        num_batch_labels=num_batch_types,
        domain_spec_batchnorm=DSBN,
        n_input_bins=n_input_bins,
        ecs_threshold=config.ecs_thres,
        explicit_zero_prob=explicit_zero_prob,
        use_fast_transformer=config.fast_transformer,
        pre_norm=config.pre_norm,
    )
    if config.load_model is not None:
        # Deserialize the checkpoint once and reuse it for the fallback path.
        pretrained_dict = torch.load(model_file)
        try:
            model.load_state_dict(pretrained_dict)
            logger.info(f"Loading all model params from {model_file}")
        except Exception as load_error:
            # Strict loading failed (e.g. missing/unexpected keys or shape
            # mismatch): only load params that are in the model and match the size.
            logger.warning(
                f"Could not load the full state dict ({load_error}); "
                f"falling back to loading matching params only."
            )
            model_dict = model.state_dict()
            pretrained_dict = {
                k: v
                for k, v in pretrained_dict.items()
                if k in model_dict and v.shape == model_dict[k].shape
            }
            for k, v in pretrained_dict.items():
                logger.info(f"Loading params {k} with shape {v.shape}")
            model_dict.update(pretrained_dict)
            model.load_state_dict(model_dict)
        # The weights now live in the model; release the checkpoint tensors.
        del pretrained_dict

    model.to(device)
    wandb.watch(model)

    # Set model criteria
    criterion = masked_mse_loss
    criterion_dab = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(
        model.parameters(), lr=config.lr, eps=1e-4 if config.amp else 1e-8
    )
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=config.schedule_ratio)
    scaler = torch.cuda.amp.GradScaler(enabled=config.amp)

    # Train the model
    best_val_loss = float("inf")
    best_avg_bio = 0.0
    best_model = None
    define_wandb_metrcis()

    for epoch in range(1, config.epochs + 1):
        epoch_start_time = time.time()
        # Data and loaders are rebuilt every epoch.
        # NOTE(review): presumably so that masking is re-drawn -- confirm in prepare_data.
        train_data_pt, valid_data_pt = prepare_data(sort_seq_batch=per_seq_batch_sample)
        train_loader = prepare_dataloader(
            train_data_pt,
            batch_size=config.batch_size,
            shuffle=False,
            intra_domain_shuffle=True,
            drop_last=False,
        )
        valid_loader = prepare_dataloader(
            valid_data_pt,
            batch_size=config.batch_size,
            shuffle=False,
            intra_domain_shuffle=False,
            drop_last=False,
        )

        if config.do_train:
            train(
                model,
                loader=train_loader,
            )
        val_loss, val_mre = evaluate(
            model,
            loader=valid_loader,
        )
        elapsed = time.time() - epoch_start_time
        logger.info("-" * 89)
        logger.info(
            f"| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | "
            f"valid loss/mse {val_loss:5.4f} | mre {val_mre:5.4f}"
        )
        logger.info("-" * 89)

        # Keep a snapshot of the model with the lowest validation loss so far
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model = copy.deepcopy(model)
            best_model_epoch = epoch
            logger.info(f"Best model with score {best_val_loss:5.4f}")

        scheduler.step()

    # Fail with a clear message instead of an AttributeError on None below.
    if best_model is None:
        raise RuntimeError(
            "No epoch improved on the initial validation loss (NaN/inf loss or "
            "config.epochs == 0), so there is no best model to extract embeddings from."
        )

    # Extract embeddings
    best_model.eval()
    adata_t = adata_sorted.copy()
    all_counts = (
        adata_t.layers[input_layer_key].A
        if issparse(adata_t.layers[input_layer_key])
        else adata_t.layers[input_layer_key]
    )
    batch_ids = adata_t.obs["batch_id"].tolist()
    batch_ids = np.array(batch_ids)
    tokenized_all = tokenize_and_pad_batch(
        all_counts,
        gene_ids,
        max_len=max_seq_len,
        vocab=vocab,
        pad_token=pad_token,
        pad_value=pad_value,
        append_cls=True,  # append <cls> token at the beginning
        include_zero_gene=True,
    )
    all_gene_ids, all_values = tokenized_all["genes"], tokenized_all["values"]
    src_key_padding_mask = all_gene_ids.eq(vocab[pad_token])
    with torch.no_grad(), torch.cuda.amp.autocast(enabled=config.amp):
        cell_embeddings = best_model.encode_batch(
            all_gene_ids,
            all_values.float(),
            src_key_padding_mask=src_key_padding_mask,
            batch_size=config.batch_size,
            batch_labels=torch.from_numpy(batch_ids).long() if DSBN else None,
            time_step=0,
            return_np=True,
        )
    # L2-normalise every cell embedding
    cell_embeddings = cell_embeddings / np.linalg.norm(
        cell_embeddings, axis=1, keepdims=True
    )

    # Save the embeddings (rows follow the order of adata_sorted)
    df = pd.DataFrame(cell_embeddings, index=adata_sorted.obs.index)
    file_out = latent_path + "_scGPT_" + str(rep) + ".txt"
    df.to_csv(file_out)

    # Clean up
    del adata
    del adata_sorted
    del adata_t
    del best_model
    del tokenized_all
    del all_counts
    del model
    del tokenized_train
    del tokenized_valid

    # End the logger and the run
    for run_log_handler in run_log_handlers:
        logger.removeHandler(run_log_handler)
        run_log_handler.close()
    run.finish()
    wandb.finish()
    gc.collect()

    

In [5]:
## Define variables
dataset_id = "Mixture"
adata_path = "/work/NMF_project/reproducibility/data/Mixture/mix.h5ad"
hvg_path = "/work/NMF_project/reproducibility/data/Mixture/mix.features"
latent_path = "/work/NMF_project/reproducibility/data/Mixture/mix"

for rep in range(5):
    ## Set hyperparameters
    hyperparameter_defaults = dict(
        seed=42,
        dataset_name=dataset_id, # Dataset name
        do_train=True, # Flag to indicate whether to do update model parameters during training
        load_model="/work/NMF_project/reproducibility/data/scGPT/Model/", # Path to pre-trained model
        GEPC=True,  # Gene expression modelling for cell objective
        ecs_thres=0.8,  # Elastic cell similarity objective, 0.0 to 1.0, 0.0 to disable
        dab_weight=1.0, # DAR objective weight for batch correction
        mask_ratio=0.4, # Default mask ratio
        epochs=15, # Default number of epochs for fine-tuning
        n_bins=51, # Default number of bins for value binning in data pre-processing
        lr=1e-4, # Default learning rate for fine-tuning
        batch_size=64, # Default batch size for fine-tuning
        layer_size=128,
        nlayers=4,
        nhead=4, # if load model, batch_size, layer_size, nlayers, nhead will be ignored
        dropout=0.2, # Default dropout rate during model fine-tuning
        schedule_ratio=0.9,  # Default rate for learning rate decay
        save_eval_interval=5, # Default model evaluation interval
        log_interval=100, # Default log interval
        fast_transformer=True, # Default setting
        pre_norm=False, # Default setting
        amp=True,  # # Default setting: Automatic Mixed Precision
    )

    ## Initialize the run on wandb
    run = wandb.init(
        config=hyperparameter_defaults,
        project="scGPT",
        reinit=True,
        settings=wandb.Settings(start_method="fork"),
    )
    config = wandb.config
    print(config)
    set_seed(config.seed)

    # Settings for input and preprocessing
    pad_token = "<pad>"
    special_tokens = [pad_token, "<cls>", "<eoc>"]
    mask_ratio = config.mask_ratio
    mask_value = -1
    pad_value = -2
    n_input_bins = config.n_bins
    n_hvg = 1200  # number of highly variable genes
    max_seq_len = n_hvg + 1
    per_seq_batch_sample = True
    DSBN = True  # Domain-spec batchnorm
    explicit_zero_prob = True  # whether explicit bernoulli for zeros

    # Settings for saving the model
    dataset_name = config.dataset_name
    save_dir = Path(f"/work/NMF_project/reproducibility/data/scGPT/FT/dev_{dataset_name}_{rep}/")
    save_dir.mkdir(parents=True, exist_ok=True)
    print(f"save to {save_dir}")
    logger = scg.logger
    scg.utils.add_file_handler(logger, save_dir / "run.log")

    # Load data
    adata = sc.read(adata_path)
    ori_batch_col = "batch_label"
    adata.var = adata.var.set_index("features")
    data_is_raw = True

    # make the batch category column
    adata.obs["str_batch"] = adata.obs[ori_batch_col].astype(str)
    batch_id_labels = adata.obs["str_batch"].astype("category").cat.codes.values
    adata.obs["batch_id"] = batch_id_labels
    adata.var["gene_name"] = adata.var.index.tolist()

    # Define HVGs
    f = open(hvg_path, "r")
    hvg = f.read().splitlines()
    f.close()
    adata.var['highly_variable'] = [True if g in hvg else False for g in adata.var_names]

    # Load the pretrained model
    if config.load_model is not None:
        model_dir = Path(config.load_model)
        model_config_file = model_dir / "args.json"
        model_file = model_dir / "best_model.pt"
        vocab_file = model_dir / "vocab.json"

        vocab = GeneVocab.from_file(vocab_file)
        for s in special_tokens:
            if s not in vocab:
                vocab.append_token(s)

        adata.var["id_in_vocab"] = [
            1 if gene in vocab else -1 for gene in adata.var["gene_name"]
        ]
        gene_ids_in_vocab = np.array(adata.var["id_in_vocab"])
        logger.info(
            f"match {np.sum(gene_ids_in_vocab >= 0)}/{len(gene_ids_in_vocab)} genes "
            f"in vocabulary of size {len(vocab)}."
        )
        adata = adata[:, adata.var["id_in_vocab"] >= 0]

        # model
        with open(model_config_file, "r") as f:
            model_configs = json.load(f)
        logger.info(
            f"Resume model from {model_file}, the model args will be overriden by the "
            f"config {model_config_file}."
        )
        embsize = model_configs["embsize"]
        nhead = model_configs["nheads"]
        d_hid = model_configs["d_hid"]
        nlayers = model_configs["nlayers"]
        n_layers_cls = model_configs["n_layers_cls"]
    else:
        embsize = config.layer_size
        nhead = config.nhead
        nlayers = config.nlayers
        d_hid = config.layer_size

    # Preprocess the dataset
    sc.pp.filter_genes(adata, min_counts=3)
    normed = sc.pp.normalize_total(adata, target_sum=1e4, layer=None, inplace=False)["X"]
    sc.get._set_obs_rep(adata, normed, layer="X_normed")
    sc.get._set_obs_rep(adata,sc.get._get_obs_rep(adata, layer="X_normed"), layer="X_log1p")
    sc.pp.log1p(adata, layer="X_log1p")
    adata = adata[:, adata.var.highly_variable]
    n_bins = config.n_bins  # NOTE: the first bin is always a spectial for zero
    binned_rows = []
    bin_edges = []
    layer_data = sc.get._get_obs_rep(adata, layer="X_log1p")
    layer_data = layer_data.A if issparse(layer_data) else layer_data
    for row in layer_data:
        non_zero_ids = row.nonzero()
        non_zero_row = row[non_zero_ids]
        bins = np.quantile(non_zero_row, np.linspace(0, 1, n_bins - 1))
        non_zero_digits = _digitize(x = non_zero_row, bins = bins)
        assert non_zero_digits.min() >= 1
        assert non_zero_digits.max() <= n_bins - 1
        binned_row = np.zeros_like(row, dtype=np.int64)
        binned_row[non_zero_ids] = non_zero_digits
        binned_rows.append(binned_row)
        bin_edges.append(np.concatenate([[0], bins]))
    adata.layers["X_binned"] = np.stack(binned_rows)
    adata.obsm["bin_edges"] = np.stack(bin_edges)

    # Sort the adata by batch_id in advance
    if per_seq_batch_sample:
        adata_sorted = adata[adata.obs["batch_id"].argsort()].copy()

    # Define input layers and get counts
    input_layer_key = "X_binned"
    all_counts = (
        adata.layers[input_layer_key].A
        if issparse(adata.layers[input_layer_key])
        else adata.layers[input_layer_key]
    )
    genes = adata.var["gene_name"].tolist()

    # Get batch ids
    batch_ids = adata.obs["batch_id"].tolist()
    num_batch_types = len(set(batch_ids))
    batch_ids = np.array(batch_ids)

    # Create splits
    (
        train_data,
        valid_data,
        train_batch_labels,
        valid_batch_labels,
    ) = train_test_split(
        all_counts, batch_ids, test_size=0.1, shuffle=True
    )

    # Define vocabulary
    if config.load_model is None:
        vocab = Vocab(
            VocabPybind(genes + special_tokens, None)
        )  # bidirectional lookup [gene <-> int]
    vocab.set_default_index(vocab["<pad>"])
    gene_ids = np.array(vocab(genes), dtype=int)

    # Tokenize training and validation data
    tokenized_train = tokenize_and_pad_batch(
        train_data,
        gene_ids,
        max_len=max_seq_len,
        vocab=vocab,
        pad_token=pad_token,
        pad_value=pad_value,
        append_cls=True,  # append <cls> token at the beginning
        include_zero_gene=True,
    )
    tokenized_valid = tokenize_and_pad_batch(
        valid_data,
        gene_ids,
        max_len=max_seq_len,
        vocab=vocab,
        pad_token=pad_token,
        pad_value=pad_value,
        append_cls=True,
        include_zero_gene=True,
    )
    logger.info(
        f"train set number of samples: {tokenized_train['genes'].shape[0]}, "
        f"\n\t feature length: {tokenized_train['genes'].shape[1]}"
    )
    logger.info(
        f"valid set number of samples: {tokenized_valid['genes'].shape[0]}, "
        f"\n\t feature length: {tokenized_valid['genes'].shape[1]}"
    )

    # Load the model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ntokens = len(vocab)  # size of vocabulary
    model = TransformerModel(
        ntokens,
        embsize,
        nhead,
        d_hid,
        nlayers,
        vocab=vocab,
        dropout=config.dropout,
        pad_token=pad_token,
        pad_value=pad_value,
        do_mvc=config.GEPC,
        do_dab=True,
        use_batch_labels=True,
        num_batch_labels=num_batch_types,
        domain_spec_batchnorm=DSBN,
        n_input_bins=n_input_bins,
        ecs_threshold=config.ecs_thres,
        explicit_zero_prob=explicit_zero_prob,
        use_fast_transformer=config.fast_transformer,
        pre_norm=config.pre_norm,
    )
    if config.load_model is not None:
        try:
            model.load_state_dict(torch.load(model_file))
            logger.info(f"Loading all model params from {model_file}")
        except:
            # only load params that are in the model and match the size
            model_dict = model.state_dict()
            pretrained_dict = torch.load(model_file)
            pretrained_dict = {
                k: v
                for k, v in pretrained_dict.items()
                if k in model_dict and v.shape == model_dict[k].shape
            }
            for k, v in pretrained_dict.items():
                logger.info(f"Loading params {k} with shape {v.shape}")
            model_dict.update(pretrained_dict)
            model.load_state_dict(model_dict)

    model.to(device)
    wandb.watch(model)

    # Set model criteria
    criterion = masked_mse_loss
    criterion_dab = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(
        model.parameters(), lr=config.lr, eps=1e-4 if config.amp else 1e-8
    )
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=config.schedule_ratio)
    scaler = torch.cuda.amp.GradScaler(enabled=config.amp)

    # Train the model
    # Keep the lowest validation loss seen so far together with a deep copy of
    # the model that produced it; that copy is what the embedding step uses.
    best_val_loss = float("inf")
    best_avg_bio = 0.0
    best_model = None
    define_wandb_metrcis()

    for epoch in range(1, config.epochs + 1):
        epoch_start_time = time.time()
        # Data tensors and loaders are rebuilt at the start of every epoch.
        train_data_pt, valid_data_pt = prepare_data(sort_seq_batch=per_seq_batch_sample)
        train_loader = prepare_dataloader(
            train_data_pt,
            batch_size=config.batch_size,
            shuffle=False,
            intra_domain_shuffle=True,
            drop_last=False,
        )
        valid_loader = prepare_dataloader(
            valid_data_pt,
            batch_size=config.batch_size,
            shuffle=False,
            intra_domain_shuffle=False,
            drop_last=False,
        )

        # Training can be switched off via the config; validation always runs.
        if config.do_train:
            train(
                model,
                loader=train_loader,
            )
        val_loss, val_mre = evaluate(
            model,
            loader=valid_loader,
        )
        elapsed = time.time() - epoch_start_time
        logger.info("-" * 89)
        logger.info(
            f"| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | "
            f"valid loss/mse {val_loss:5.4f} | mre {val_mre:5.4f}"
        )
        logger.info("-" * 89)

        # Snapshot the model whenever the validation loss strictly improves.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model = copy.deepcopy(model)
            best_model_epoch = epoch
            logger.info(f"Best model with score {best_val_loss:5.4f}")

        # One scheduler step per epoch.
        scheduler.step()

    # No snapshot exists when the loop never ran (config.epochs == 0) or when no
    # epoch produced a validation loss that compares below the running best
    # (e.g. NaN, for which `<` is always False). Fall back to the current model
    # instead of failing later with an AttributeError on None.
    if best_model is None:
        logger.warning(
            "No epoch improved the validation loss; "
            "falling back to the current model weights."
        )
        best_model = model
        
    # Extract embeddings
    # Embed every cell of `adata_sorted` with the selected model, then
    # L2-normalise each embedding row.
    best_model.eval()
    adata_t = adata_sorted.copy()
    # Densify the input layer if it is stored sparse. `.toarray()` is used
    # instead of the `.A` shorthand, which newer SciPy releases deprecate and
    # remove; both return the same dense ndarray.
    all_counts = (
        adata_t.layers[input_layer_key].toarray()
        if issparse(adata_t.layers[input_layer_key])
        else adata_t.layers[input_layer_key]
    )
    batch_ids = np.array(adata_t.obs["batch_id"].tolist())
    tokenized_all = tokenize_and_pad_batch(
        all_counts,
        gene_ids,
        max_len=max_seq_len,
        vocab=vocab,
        pad_token=pad_token,
        pad_value=pad_value,
        append_cls=True,  # append <cls> token at the beginning
        include_zero_gene=True,
    )
    all_gene_ids, all_values = tokenized_all["genes"], tokenized_all["values"]
    # True at every position holding the padding token.
    src_key_padding_mask = all_gene_ids.eq(vocab[pad_token])
    # Inference only: no gradients, autocast only when AMP is enabled.
    with torch.no_grad(), torch.cuda.amp.autocast(enabled=config.amp):
        cell_embeddings = best_model.encode_batch(
            all_gene_ids,
            all_values.float(),
            src_key_padding_mask=src_key_padding_mask,
            batch_size=config.batch_size,
            # Batch labels are only passed when DSBN is truthy.
            batch_labels=torch.from_numpy(batch_ids).long() if DSBN else None,
            time_step=0,
            return_np=True,
        )
    # Scale each row to unit Euclidean norm.
    cell_embeddings = cell_embeddings / np.linalg.norm(
        cell_embeddings, axis=1, keepdims=True
    )
    
    # Save the embeddings
    # One row per cell, indexed by the obs index of `adata_sorted`; the file is
    # written with DataFrame.to_csv and named "<latent_path>_scGPT_<rep>.txt".
    file_out = f"{latent_path}_scGPT_{rep}.txt"
    df = pd.DataFrame(data=cell_embeddings, index=adata_sorted.obs.index)
    df.to_csv(file_out)
    
    # Clean up
    # Drop the large per-replicate objects before the next replicate starts.
    del adata
    del adata_sorted
    del adata_t
    del best_model
    del tokenized_all
    del all_counts
    del model
    del tokenized_train
    del tokenized_valid
    # The optimizer was built from model.parameters() and the scheduler wraps
    # the optimizer, so both keep the model weights reachable after `del model`.
    # Remove them too so the weights can actually be collected.
    del optimizer
    del scheduler

    # End the logger and the run
    run.finish()
    wandb.finish()
    gc.collect()
    # Hand cached, now-unused CUDA blocks back to the driver once the Python
    # references are gone.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    

[34m[1mwandb[0m: Currently logged in as: [33mjespergrud[0m ([33mmadlab_sdu[0m). Use [1m`wandb login --relogin`[0m to force relogin


{'seed': 42, 'dataset_name': 'Mixture', 'do_train': True, 'load_model': '/work/NMF_project/reproducibility/data/scGPT/Model/', 'GEPC': True, 'ecs_thres': 0.8, 'dab_weight': 1.0, 'mask_ratio': 0.4, 'epochs': 15, 'n_bins': 51, 'lr': 0.0001, 'batch_size': 64, 'layer_size': 128, 'nlayers': 4, 'nhead': 4, 'dropout': 0.2, 'schedule_ratio': 0.9, 'save_eval_interval': 5, 'log_interval': 100, 'fast_transformer': True, 'pre_norm': False, 'amp': True}
save to /work/NMF_project/reproducibility/data/scGPT/FT/dev_Mixture_0
scGPT - INFO - match 13123/13760 genes in vocabulary of size 60697.
scGPT - INFO - Resume model from /work/NMF_project/reproducibility/data/scGPT/Model/best_model.pt, the model args will be overriden by the config /work/NMF_project/reproducibility/data/scGPT/Model/args.json.
scGPT - INFO - train set number of samples: 27747, 
	 feature length: 1201
scGPT - INFO - valid set number of samples: 3083, 
	 feature length: 1201
Use domain specific batchnorm with affine=False
scGPT - INFO

100%|██████████| 482/482 [00:41<00:00, 11.54it/s]


VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▂▃▃▃▄▅▅▅▆▇▇▇█
train/dab,▄▁▂▇▅▇▃█▅▆▆▄▇▄▆▇▄▃▅▂▅▅▄▅▄▄▃▃▄▄▄▃▆▃▆▂▅▄▄▅
train/ecs,███▇█▅▄█▄▇▆▃▅▃▆▆▂▆▃▃▄▃▄▃▂▆▃▁▄▂▂▃▁▂▂▁▂▁▂▂
train/mse,█▂▂▄▂▃▃▂▃▁▂▃▁▃▂▂▃▁▂▃▂▂▂▂▃▁▂▃▂▃▁▂▃▂▂▃▁▃▁▁
train/mvc,█▃▂▆▂▆▄▄▅▂▄▅▂▅▂▂▅▁▂▄▂▂▂▂▅▂▂▄▂▄▂▂▄▂▂▄▁▄▁▂
train/mvc_nzlp,▇▃▃▅▂▄▄▂▄▂▄▄▂▆▂▂▃▁▂▃▄▂▁▂▃▁▂▃▁▄▁▂█▁▃▃▂▂▁▁
train/nzlp,█▇▆▇▄▅▅▂▅▁▃▄▂▄▂▂▄▁▃▄▂▃▂▂▄▂▂▄▂▄▂▂▄▂▂▄▂▄▁▂
valid/dab,█▆▄▆▃▃▂▂▁▁▁▁▁▁▁
valid/mre,█▅▁▃▃▄▂▂▅▄▂▃▃▄▂
valid/mse,█▅▄▃▃▂▂▂▂▁▁▁▁▁▁

0,1
epoch,15.0
train/dab,2.52846
train/ecs,6.04161
train/mse,76.31644
train/mvc,78.44629
train/mvc_nzlp,0.43864
train/nzlp,0.39117


{'seed': 42, 'dataset_name': 'Mixture', 'do_train': True, 'load_model': '/work/NMF_project/reproducibility/data/scGPT/Model/', 'GEPC': True, 'ecs_thres': 0.8, 'dab_weight': 1.0, 'mask_ratio': 0.4, 'epochs': 15, 'n_bins': 51, 'lr': 0.0001, 'batch_size': 64, 'layer_size': 128, 'nlayers': 4, 'nhead': 4, 'dropout': 0.2, 'schedule_ratio': 0.9, 'save_eval_interval': 5, 'log_interval': 100, 'fast_transformer': True, 'pre_norm': False, 'amp': True}
save to /work/NMF_project/reproducibility/data/scGPT/FT/dev_Mixture_1
scGPT - INFO - match 13123/13760 genes in vocabulary of size 60697.
scGPT - INFO - Resume model from /work/NMF_project/reproducibility/data/scGPT/Model/best_model.pt, the model args will be overriden by the config /work/NMF_project/reproducibility/data/scGPT/Model/args.json.
scGPT - INFO - train set number of samples: 27747, 
	 feature length: 1201
scGPT - INFO - valid set number of samples: 3083, 
	 feature length: 1201
Use domain specific batchnorm with affine=False
scGPT - INFO

100%|██████████| 482/482 [00:41<00:00, 11.58it/s]


VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▂▃▃▃▄▅▅▅▆▇▇▇█
train/dab,▄▁▃▇▃▇▁▅▆▇▆▄█▆▅▆▆▃▅▂▄▄▄▅▄▄▃▃▅▄▄▅▆▃▆▂▄▃▅▅
train/ecs,█████▄▄█▄▆▅▄▅▃▆▆▂▆▃▃▄▃▄▃▂▆▃▁▄▂▃▃▁▄▄▁▃▁▂▃
train/mse,█▂▂▄▂▃▃▂▃▁▂▃▁▄▂▂▃▁▂▃▂▂▂▂▃▁▂▃▂▃▁▂▃▂▂▃▁▃▁▁
train/mvc,▄▂▂▃▁▃▃▁▃▁▁▄█▄▁▁▃▁▂▃▁▂▁▁▃▁▁▂▁▃▁▁▃▁▁▂▁▃▁▁
train/mvc_nzlp,█▃▃█▂▇▄▂▄▂▃▆▂▇▃▅▄▁▂▃▂▂▂▂▄▂▂▃▁▄▁▂▃▁▂▂▁▃▁▁
train/nzlp,█▇▆▇▄▆▅▂▅▂▂▄▂▅▂▂▅▁▃▄▂▃▂▂▄▂▂▄▂▄▂▂▄▂▂▄▂▄▁▂
valid/dab,█▅▅▃▃▃▂▂▂▁▁▁▁▁▁
valid/mre,█▃▁▃▄▃▁▂▄▄▂▃▂▃▂
valid/mse,█▄▃▂▂▂▂▂▂▁▁▁▁▁▁

0,1
epoch,15.0
train/dab,2.15723
train/ecs,6.56093
train/mse,75.89831
train/mvc,78.14529
train/mvc_nzlp,0.43278
train/nzlp,0.38992


{'seed': 42, 'dataset_name': 'Mixture', 'do_train': True, 'load_model': '/work/NMF_project/reproducibility/data/scGPT/Model/', 'GEPC': True, 'ecs_thres': 0.8, 'dab_weight': 1.0, 'mask_ratio': 0.4, 'epochs': 15, 'n_bins': 51, 'lr': 0.0001, 'batch_size': 64, 'layer_size': 128, 'nlayers': 4, 'nhead': 4, 'dropout': 0.2, 'schedule_ratio': 0.9, 'save_eval_interval': 5, 'log_interval': 100, 'fast_transformer': True, 'pre_norm': False, 'amp': True}
save to /work/NMF_project/reproducibility/data/scGPT/FT/dev_Mixture_2
scGPT - INFO - match 13123/13760 genes in vocabulary of size 60697.
scGPT - INFO - Resume model from /work/NMF_project/reproducibility/data/scGPT/Model/best_model.pt, the model args will be overriden by the config /work/NMF_project/reproducibility/data/scGPT/Model/args.json.
scGPT - INFO - train set number of samples: 27747, 
	 feature length: 1201
scGPT - INFO - valid set number of samples: 3083, 
	 feature length: 1201
Use domain specific batchnorm with affine=False
scGPT - INFO

100%|██████████| 482/482 [00:45<00:00, 10.58it/s]


VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▂▃▃▃▄▅▅▅▆▇▇▇█
train/dab,▄▁▃█▄▇▂▄▆█▄▅█▅▆█▅▃▆▂▆▅▆▅▅▅▃▃▄▄▄▄▆▄▆▂▅▃▅▆
train/ecs,███▇█▄▄█▄▇▆▄▅▂▆▆▃▆▄▃▄▃▄▃▂▆▃▁▄▂▂▂▁▂▂▁▃▁▂▂
train/mse,█▂▂▄▂▃▃▁▃▁▂▃▁▃▁▂▃▁▂▃▂▂▂▂▃▁▂▃▂▃▁▂▃▂▂▃▁▃▁▁
train/mvc,█▃▃▆▂▅▄▂▅▂▂▅▂▅▂▂▅▁▄▄▂▂▂▂▅▂▂▄▂▄▂▂▅▂▂▄▁▄▁▂
train/mvc_nzlp,▇▄▃▅▂▄▄▂▄▁▂▃▂▄▂▃█▁▂▃▂▂▂▂▃▁▂▃▂▃▂▂▃▁▂▂▁▂▁▁
train/nzlp,█▇▆▇▄▆▅▂▅▁▂▄▂▄▂▂▄▁▃▄▂▃▂▂▄▂▃▄▂▄▂▂▄▂▂▄▂▄▁▂
valid/dab,█▅▅▅▃▂▂▂▂▂▁▁▁▁▁
valid/mre,█▁▁▄▃▄▄▄▄▃▂▂▂▂▂
valid/mse,█▄▃▃▂▂▂▂▂▂▁▁▁▁▁

0,1
epoch,15.0
train/dab,2.48493
train/ecs,5.90701
train/mse,76.14938
train/mvc,78.35318
train/mvc_nzlp,0.43936
train/nzlp,0.39121


{'seed': 42, 'dataset_name': 'Mixture', 'do_train': True, 'load_model': '/work/NMF_project/reproducibility/data/scGPT/Model/', 'GEPC': True, 'ecs_thres': 0.8, 'dab_weight': 1.0, 'mask_ratio': 0.4, 'epochs': 15, 'n_bins': 51, 'lr': 0.0001, 'batch_size': 64, 'layer_size': 128, 'nlayers': 4, 'nhead': 4, 'dropout': 0.2, 'schedule_ratio': 0.9, 'save_eval_interval': 5, 'log_interval': 100, 'fast_transformer': True, 'pre_norm': False, 'amp': True}
save to /work/NMF_project/reproducibility/data/scGPT/FT/dev_Mixture_3
scGPT - INFO - match 13123/13760 genes in vocabulary of size 60697.
scGPT - INFO - Resume model from /work/NMF_project/reproducibility/data/scGPT/Model/best_model.pt, the model args will be overriden by the config /work/NMF_project/reproducibility/data/scGPT/Model/args.json.
scGPT - INFO - train set number of samples: 27747, 
	 feature length: 1201
scGPT - INFO - valid set number of samples: 3083, 
	 feature length: 1201
Use domain specific batchnorm with affine=False
scGPT - INFO

100%|██████████| 482/482 [00:45<00:00, 10.53it/s]


VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▂▃▃▃▄▅▅▅▆▇▇▇█
train/dab,▄▁▃▇▂▇▁▅▅▅▇▄▆▅▆█▄▄▅▂▄▄▄▅▅▄▄▂▄▄▄▄▅▃▄▂▄▃▅▅
train/ecs,███▆█▄▄█▄▇▆▃▅▃▆▆▃▆▃▃▄▃▄▃▂▆▃▁▄▂▃▃▁▄▅▁▄▁▄▃
train/mse,█▂▂▄▂▃▃▁▃▁▂▃▂▃▁▂▃▁▂▃▂▂▂▂▃▁▂▃▂▃▁▂▃▁▂▃▁▃▁▁
train/mvc,█▃▂▆▂▆▄▂▅▂▄▅▂▅▂▃▅▁▂▅▂▂▂▂▅▂▂▄▂▄▂▂▄▂▂▄▁▄▁▂
train/mvc_nzlp,▄▂▂▃▁▄▂▂▃▁▂▂▁█▃▂▄▁▁▂▁▁▁▁▃▁▁▂▁▂▁▁▃▁▁▂▁▂▁▁
train/nzlp,█▇▆▇▄▅▅▂▅▁▂▄▂▄▂▂▄▁▃▄▂▃▂▂▄▂▂▄▂▄▂▂▄▂▂▄▂▄▁▂
valid/dab,█▆▅▄▃▃▂▂▂▁▁▁▁▁▁
valid/mre,█▂▁▇▅▄▁▂▄▄▃▂▂▃▃
valid/mse,█▄▃▃▃▂▂▂▁▁▁▁▁▁▁

0,1
epoch,15.0
train/dab,2.46666
train/ecs,6.56948
train/mse,76.07622
train/mvc,78.41503
train/mvc_nzlp,0.44025
train/nzlp,0.39032


{'seed': 42, 'dataset_name': 'Mixture', 'do_train': True, 'load_model': '/work/NMF_project/reproducibility/data/scGPT/Model/', 'GEPC': True, 'ecs_thres': 0.8, 'dab_weight': 1.0, 'mask_ratio': 0.4, 'epochs': 15, 'n_bins': 51, 'lr': 0.0001, 'batch_size': 64, 'layer_size': 128, 'nlayers': 4, 'nhead': 4, 'dropout': 0.2, 'schedule_ratio': 0.9, 'save_eval_interval': 5, 'log_interval': 100, 'fast_transformer': True, 'pre_norm': False, 'amp': True}
save to /work/NMF_project/reproducibility/data/scGPT/FT/dev_Mixture_4
scGPT - INFO - match 13123/13760 genes in vocabulary of size 60697.
scGPT - INFO - Resume model from /work/NMF_project/reproducibility/data/scGPT/Model/best_model.pt, the model args will be overriden by the config /work/NMF_project/reproducibility/data/scGPT/Model/args.json.
scGPT - INFO - train set number of samples: 27747, 
	 feature length: 1201
scGPT - INFO - valid set number of samples: 3083, 
	 feature length: 1201
Use domain specific batchnorm with affine=False
scGPT - INFO

100%|██████████| 482/482 [00:45<00:00, 10.57it/s]


VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▂▃▃▃▄▅▅▅▆▇▇▇█
train/dab,▄▁▂▇▃▆▃▅▄▅▄▄█▄▆▇▄▃▅▂▆▄▄▅▄▄▃▃▄▄▄▄▆▃▅▂▄▃▅▅
train/ecs,███▇█▄▄█▄▇▅▃▅▃▆▆▂▆▃▃▄▃▄▃▂▆▃▁▄▂▃▃▁▄▅▁▃▁▃▃
train/mse,█▂▂▄▂▃▃▁▃▁▂▃▂▃▂▂▃▁▂▃▂▂▂▂▃▁▂▃▂▃▁▂▃▁▂▃▁▃▁▁
train/mvc,█▃▂▆▂▅▄▂▅▂▂▅▅▅▂▃▅▁▂▄▂▂▂▂▅▂▂▄▂▅▂▂▅▂▂▄▁▄▁▂
train/mvc_nzlp,▃▂▂▂▁▂▂▁▂▁▁▂▂▄▁▁▂▁▁▂▁▁▁▁▂▁▁▂▁▂▁▁█▁▁▁▁▁▁▁
train/nzlp,█▇▆▇▄▅▅▂▄▁▂▄▃▄▂▂▄▁▃▄▂▃▂▂▄▂▂▄▂▄▂▂▄▂▂▄▂▄▁▂
valid/dab,█▆▄▆▃▂▂▂▂▁▁▁▁▁▁
valid/mre,█▃▁█▃▃▅▃▄▅▄▄▄▄▃
valid/mse,█▄▃▄▂▂▂▂▂▂▁▁▁▁▁

0,1
epoch,15.0
train/dab,2.32589
train/ecs,6.58555
train/mse,76.10165
train/mvc,78.42239
train/mvc_nzlp,0.43788
train/nzlp,0.39053
