In [1]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader

import json
import numpy as np
import scanpy as sc
import time
import copy
from scipy.sparse import issparse

from utils import set_seed, AttrDict
from vocab import Vocab
from preprocess import Preprocessor, get_interactions, get_z
from tokenizer import Tokenizer, random_mask_value
from model import TransformerModel, BioFormerModel
from loss import masked_mse_loss, masked_relative_error, criterion_neg_log_bernoulli

In [4]:
torch.optim.lr_scheduler.ReduceLROnPlateau?

[1;31mInit signature:[0m
[0mtorch[0m[1;33m.[0m[0moptim[0m[1;33m.[0m[0mlr_scheduler[0m[1;33m.[0m[0mReduceLROnPlateau[0m[1;33m([0m[1;33m
[0m    [0moptimizer[0m[1;33m,[0m[1;33m
[0m    [0mmode[0m[1;33m=[0m[1;34m'min'[0m[1;33m,[0m[1;33m
[0m    [0mfactor[0m[1;33m=[0m[1;36m0.1[0m[1;33m,[0m[1;33m
[0m    [0mpatience[0m[1;33m=[0m[1;36m10[0m[1;33m,[0m[1;33m
[0m    [0mthreshold[0m[1;33m=[0m[1;36m0.0001[0m[1;33m,[0m[1;33m
[0m    [0mthreshold_mode[0m[1;33m=[0m[1;34m'rel'[0m[1;33m,[0m[1;33m
[0m    [0mcooldown[0m[1;33m=[0m[1;36m0[0m[1;33m,[0m[1;33m
[0m    [0mmin_lr[0m[1;33m=[0m[1;36m0[0m[1;33m,[0m[1;33m
[0m    [0meps[0m[1;33m=[0m[1;36m1e-08[0m[1;33m,[0m[1;33m
[0m    [0mverbose[0m[1;33m=[0m[1;34m'deprecated'[0m[1;33m,[0m[1;33m
[0m[1;33m)[0m[1;33m[0m[1;33m[0m[0m
[1;31mDocstring:[0m     
Reduce learning rate when a metric has stopped improving.
Models often benefit from reducing th

In [1]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader

import json
import numpy as np
import scanpy as sc
import time
import copy
from scipy.sparse import issparse

from utils import set_seed, AttrDict
from vocab import Vocab
from preprocess import Preprocessor, get_interactions, get_z
from tokenizer import Tokenizer, random_mask_value
from model import TransformerModel, BioFormerModel
from loss import masked_mse_loss, masked_relative_error, criterion_neg_log_bernoulli

config = AttrDict(json.load(open('config.json')))
print(config)

if config.seed:
    set_seed(config.seed)

if config.wandb:
    import wandb
    wandb.login()
    run = wandb.init(
        project='BioFormer',
        config = config,
        name = config.run_name if config.run_name else None
    )

# Pre-processing
# pad_token = "<pad>"
# special_tokens = [pad_token, "<cls>", "<eoc>"]
# mask_value = -1 # in the value vector corresponding to msk token (!= msk token index in vocab)
# pad_value = -2  # in the value vector corresponding to pad token (!= pad token index in vocab)

# Import data
path_to_transcriptional_interactions = '../data/transcriptional_interactions.csv'
dataset_name = config.dataset_name

if dataset_name == 'BREAST_25K':
    adata = sc.read_h5ad('../data/breast_25k.h5ad')
    data_is_raw = True

elif dataset_name == 'BREAST_12K':
    adata = sc.read_h5ad('../data/breast_12k.h5ad')
    data_is_raw = True

elif dataset_name == 'DERMAL_100K':
    adata = sc.read_h5ad('../data/dermal_100k.h5ad')
    adata.var["gene_name"] = adata.var.feature_name.tolist()
    data_is_raw = True

elif dataset_name == 'HYPOXIA_9K':
    adata = sc.read_h5ad('../data/scsHypoxiaTimeSub.h5ad')
    adata.X = adata.layers['raw_count']
    adata.var['gene_name'] = adata.var.index.tolist()
    data_is_raw = True

print(dataset_name)
print(adata)

# Pre-process RNA-seq data
preprocessor = Preprocessor(use_key="X",  # the key in adata.layers to use as raw data
                            filter_gene_by_counts=3,  # step 1
                            filter_cell_by_counts=False,  # step 2
                            normalize_total=1e4,  # 3. whether to normalize the raw data and to what sum
                            result_normed_key="X_normed",  # the key in adata.layers to store the normalized data
                            log1p=data_is_raw,  # 4. whether to log1p the normalized data
                            result_log1p_key="X_log1p",
                            subset_hvg=config.n_hvg,  # 5. whether to subset the raw data to highly variable genes
                            hvg_flavor="seurat_v3" if data_is_raw else "cell_ranger",
                            binning=config.n_bins,  # 6. whether to bin the raw data and to what number of bins
                            result_binned_key="X_binned",  # the key in adata.layers to store the binned data
                            )
preprocessor(adata, batch_key=None)

# Vocab
genes = adata.var["gene_name"].tolist()
vocab = Vocab(genes)
vocab.set_default_index(vocab["<pad>"]) # index to return if token not found in vocab
print(f'Init vocab of size {len(vocab)} with {config.n_hvg} unique genes...')
print(f'CLS in vocab: {vocab.stoi['<cls>']}')

# Tokenize & Pad
tokenizer = Tokenizer(vocab = vocab,
                      append_cls = True,
                      cls_token = "<cls>",
                      pad_token = "<pad>",
                      pad_value = -2,
                      include_zero_gene= config.include_zero_gene, 
                      )
tokenized = tokenizer.tokenize_and_pad_batch(adata.layers["X_binned"].toarray() if issparse(adata.layers["X_binned"]) else adata.layers["X_binned"],
                                             np.array(vocab(genes), dtype=int),
                                             max_len=config.n_hvg + 1,
                                             )
print(f"Tot samples: {tokenized['genes'].shape[0]}")
print(f"Input length: {tokenized['genes'].shape[1]}")

{'run_name': 'test-autorun', 'dataset_name': 'HYPOXIA_9K', 'model': 'scGPT', 'd_model': 32, 'nhead': 4, 'nlayers': 8, 'n_hvg': 50, 'do_pair_bias': True, 'do_opm': True, 'd_z': 32, 'd_opm': 8, 'init_z': False, 'do_train': True, 'epochs': 3, 'batch_size': 16, 'wandb': False, 'seed': 5289, 'n_bins': 51, 'include_zero_gene': False, 'explicit_zero_prob': True, 'log_interval': 100, 'lr': 0.0001, 'amp': True, 'schedule_ratio': 0.9, 'save_model': False}
HYPOXIA_9K
AnnData object with n_obs × n_vars = 9234 × 19046
    obs: 'nCount_RNA', 'nFeature_RNA', 'SampleTags', 'percent.mt', 'HypoxicState', 'TimePoint', 'nCount_SCT', 'nFeature_SCT', 'S.Score', 'G2M.Score', 'Phase', 'seurat_clusters', 'SampleTagsShort', 'active_ident'
    var: 'variable_gene', 'gene_name'
    uns: 'active_ident_colors', 'seurat_clusters_colors'
    obsm: 'X_pca', 'X_umap'
    layers: 'raw_count'
Filtering genes by counts ...
Normalizing total counts ...
Log1p transforming ...
Subsetting highly variable genes ...
No batch_ke

In [None]:
batch = next(iter(loader))
batch['values']

tensor([[ 0., 35., -1., 34., 43., 40., 24., 11., 50., 13., 30.,  6., 18., -1.,
         26., 11.,  5., -2., -2., -2., -2., -2., -2., -2., -2., -2., -2., -2.,
         -2., -2., -2., -2., -2., -2., -2., -2., -2., -2., -2., -2., -2.],
        [ 0., -1., 11.,  9., 12.,  4., 35., 20., 30., 17., 41., 14., 37., -1.,
         47., 23., 43., 50., 12., 20., 33., 13., 30., 12., -1., 39., -2., -2.,
         -2., -2., -2., -2., -2., -2., -2., -2., -2., -2., -2., -2., -2.],
        [ 0., 36., 19., 22., 13., 30., -1., 30., 40., 50., -1., 14., 36., 21.,
         45.,  2., -1., 11., 26., 43.,  7., 39.,  9., -2., -2., -2., -2., -2.,
         -2., -2., -2., -2., -2., -2., -2., -2., -2., -2., -2., -2., -2.],
        [ 0., 46.,  6., 33., -1., 22., 21., 25., 50., 27., 40., 36., -1.,  3.,
          9.,  3., 30., -2., -2., -2., -2., -2., -2., -2., -2., -2., -2., -2.,
         -2., -2., -2., -2., -2., -2., -2., -2., -2., -2., -2., -2., -2.],
        [ 0., 20., 11., 32., 11.,  1., 15.,  4., 43., 24., 39., 50.,

In [48]:
g, x = batch['gene_ids'], batch['values']
model(g, x)

{'mlm_output': tensor([[ 6.9914e-02,  1.0955e-01,  1.2178e-01,  4.6768e-02,  1.5070e-01,
           1.5231e-01,  7.3118e-02,  3.9406e-02,  4.4318e-03,  5.5104e-02,
           5.4776e-02,  6.8119e-02,  6.7590e-02,  6.9702e-02,  2.9926e-02,
           3.3949e-02,  1.9985e-01,  1.5031e-01,  1.5031e-01,  1.5031e-01,
           1.5031e-01,  1.5031e-01,  1.5031e-01,  1.5031e-01,  1.5031e-01,
           1.5031e-01,  1.5031e-01,  1.5031e-01,  1.5031e-01,  1.5031e-01,
           1.5031e-01,  1.5031e-01,  1.5031e-01,  1.5031e-01,  1.5031e-01,
           1.5031e-01,  1.5031e-01,  1.5031e-01,  1.5031e-01,  1.5031e-01,
           1.5031e-01],
         [ 8.4718e-02,  9.5660e-02,  9.8267e-02,  4.8800e-02,  1.2695e-01,
           9.5691e-02,  1.2213e-01,  1.0694e-01,  1.0744e-01,  3.5587e-02,
           2.0089e-02,  3.8034e-02,  3.8196e-02,  1.5177e-01,  1.1638e-01,
          -3.1438e-03,  6.9284e-02,  9.0792e-03,  3.6398e-02,  1.3420e-01,
           1.7233e-02,  2.0623e-02,  6.5182e-02,  3.6120e-02, 

In [49]:
config.do_train = True

# Instantiate model
if config.model == "scGPT":
    model = TransformerModel(ntoken=len(vocab),
                             d_model=config.d_model,
                             nhead=config.nhead,
                             nlayers=config.nlayers,
                             pad_id = vocab.stoi['<pad>'],
                             explicit_zero_prob=config.explicit_zero_prob
                             ) 
elif config.model == "BioFormer":
    model = BioFormerModel(ntoken=len(vocab),
                           d_model=config.d_model,
                           d_z = config.d_z,
                           d_opm = config.d_opm,
                           nhead=config.nhead,
                           nlayers=config.nlayers,
                           do_pair_bias=config.do_pair_bias,
                           do_opm=config.do_opm,
                           pad_id = vocab.stoi['<pad>'],
                           explicit_zero_prob=config.explicit_zero_prob
                           ) 
print(model)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# model = torch.nn.DataParallel(model)

# Parameters count
n_params = sum(p.numel() for p in model.parameters())
model_size_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
print(f'''device: {device} | model: {config.model} | d_model: {config.d_model} | nhead: {config.nhead} | nlayers: {config.nlayers} | tot. params: {n_params/1e6:.2f}M | model size: {model_size_bytes/1e6:.2f}MB''')
if config.wandb:
    wandb.config.update({"Model Parameters": n_params})

# Max memory required per intermediate step:
if config.model == 'BioFormer':
        # no. elements in the [B, r, r, c, c] opm intermediate matrix
        tmp = config.batch_size * (config.n_hvg + 1) ** 2 * config.d_opm ** 2
        
        # number of bytes required (using np.float32)
        tmp = tmp * 4

        print(f'memory required for opm: {tmp/1e6 :,.2f}MB')
        

# RNA-seq Dataset
class SeqDataset(Dataset):
    """
    Create RNA-seq dataset from vocabulary with keys ['gene_ids', 'valeus', 'target_vaules', 'interactions'].
    """
    def __init__(self, data: dict):
        self.data = data

    def __len__(self):
        return self.data["gene_ids"].shape[0]

    def __getitem__(self, idx):
        return {k: v[idx] for k, v in self.data.items()}

# Mask and get interactions
def prepare_data():
    """
    1. Random mask the data
    2. Get the interaction matrix z
    3. Convert to torch.Dataset.
    
    """
    masked_values = random_mask_value(tokenized["values"])
    print(f"Random masking at epoch {epoch}...")

    B, r = masked_values.shape
    if config.init_z:
        tf = get_interactions(genes, path_to_transcriptional_interactions)
    interactions = get_z(tokenized["genes"], tf, vocab.itos) if config.init_z else torch.zeros((B, r, r))    # [B, r, r]

    data_pt = {
        "gene_ids": tokenized["genes"],           # [B, r]
        "values": masked_values,                  # [B, r]
        "target_values": tokenized["values"],     # [B, r]
        "interactions": interactions              # [B, r, r]
    }

    return SeqDataset(data_pt)

# --------------------------------------------------------------------------- #
# --------------------------- TRAINING LOOP --------------------------------- #
# --------------------------------------------------------------------------- #

criterion = masked_mse_loss
criterion_dab = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(),
                              lr=config.lr,
                              eps=1e-4 if config.amp else 1e-8
                              )
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                            1,
                                            gamma=config.schedule_ratio
                                            )
scaler = torch.cuda.amp.GradScaler(enabled=config.amp)

best_val_loss = float("inf")
best_model = None

if config.wandb:
    wandb.define_metric("valid/mse", summary="min", step_metric="epoch")
    wandb.define_metric("valid/mre", summary="min", step_metric="epoch")

for epoch in range(1, config.epochs + 1):
    epoch_start_time = time.time()
    
    dataset = prepare_data()

    train_dataset, valid_dataset = torch.utils.data.random_split(dataset, [0.9, 0.1])

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
    )
    valid_loader = DataLoader(
        dataset=valid_dataset,
        batch_size=config.batch_size,
        shuffle=True,
    )

    # -------------------------------- TRAINING ----------------------------------- #
    if config.do_train:
        model.train()

        loader = train_loader

        total_loss = 0.0
        total_mse = 0.0
        total_gepc = 0.0
        total_mre = 0.0
        log_interval = config.log_interval
        start_time = time.time()

        num_batches = len(loader)
        for batch, batch_data in enumerate(loader):
            input_gene_ids = batch_data["gene_ids"].to(device)
            input_values = batch_data["values"].to(device)
            target_values = batch_data["target_values"].to(device)
            
            if config.model == "BioFormer":
                z = batch_data['interactions'].to(device)

            # ---------- forward -------------------
            with torch.cuda.amp.autocast(enabled=config.amp):
                
                if config.model == "scGPT":
                    output_dict = model(input_gene_ids, input_values)
                elif config.model == "BioFormer":
                    output_dict = model(input_gene_ids, input_values, z)
                
                masked_positions = input_values.eq(-1)          # default value for the mask position
                loss = loss_mse = criterion(output_dict["mlm_output"], target_values, masked_positions)
                
                metrics_to_log = {"train/mse": loss_mse.item()}
                
                if config.explicit_zero_prob:
                    loss_zero_log_prob = criterion_neg_log_bernoulli(output_dict["mlm_zero_probs"], target_values, masked_positions)
                    loss += loss_zero_log_prob
                    metrics_to_log.update({"train/nzlp": loss_zero_log_prob.item()})
                
            # -------------- backward ------------------
            model.zero_grad()
            
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            scaler.step(optimizer)
            scaler.update()

            scheduler.step()
            
            # --------------- logs & stats ---------------------
            if config.wandb:
                wandb.log(metrics_to_log)

            with torch.no_grad():
                mre = masked_relative_error(output_dict["mlm_output"], target_values, masked_positions)

            total_loss += loss.item()                               # sum of all losses
            total_mse += loss_mse.item()                            # MSE alone
            total_mre += mre.item()                                 # MRE alone
            
            # For logging purposes, aggregate loss across log_interval batches 
            if batch % log_interval == 0 and batch > 0:
                lr = scheduler.get_last_lr()[0]
                ms_per_batch = (time.time() - start_time) * 1000 / log_interval
                cur_loss = total_loss / log_interval
                cur_mse = total_mse / log_interval
                cur_mre = total_mre / log_interval
                
                print(f"| epoch {epoch:3d} | {batch:3d}/{num_batches:3d} batches | lr {lr:05.4f} | ms/batch {ms_per_batch:5.2f} | train/loss {cur_loss:5.2f} | train/mse {cur_mse:5.2f} |" + f"train/mre {cur_mre:5.2f} |" )
                
                total_loss = 0
                total_mse = 0
                total_mre = 0
                start_time = time.time()

    # -------------------------------- VALIDATION ----------------------------------- #
    model.eval()
    
    loader = valid_loader
    
    total_loss = 0.0
    total_mre = 0.0
    total_num = 0
    with torch.no_grad():
        for batch_data in loader:
            input_gene_ids = batch_data["gene_ids"].to(device)
            input_values = batch_data["values"].to(device)
            target_values = batch_data["target_values"].to(device)

            if config.model == "BioFormer":
                interactions = batch_data['interactions'].to(device)

            with torch.cuda.amp.autocast(enabled=config.amp):
                
                if config.model == "scGPT":
                    output_dict = model(input_gene_ids, input_values)
                elif config.model == "BioFormer":
                    output_dict = model(input_gene_ids, input_values, interactions)
                
                output_values = output_dict["mlm_output"]

                masked_positions = input_values.eq(-1)
                loss = criterion(output_values, target_values, masked_positions)

            total_loss += loss.item() * len(input_gene_ids)
            total_mre += masked_relative_error(output_values, target_values, masked_positions).item() * len(input_gene_ids)
            total_num += len(input_gene_ids)

    if config.wandb:
        wandb.log({ 
            "valid/mse": total_loss / total_num,
            "valid/mre": total_mre / total_num,
            "epoch": epoch
            })

    val_loss = total_loss / total_num
    val_mre = total_mre / total_num
    
    # -------------------------------- EPOCH-RELATED STATS ----------------------------------- #
    elapsed = time.time() - epoch_start_time
    print("-" * 89)
    print(f"| end of epoch {epoch:3d} | runtime: {elapsed:5.2f}s | valid/mse {val_loss:5.4f} | valid/mre {val_mre:5.4f}")
    print("-" * 89)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model = copy.deepcopy(model)
        best_model_epoch = epoch
        print(f"New best model found at epoch {epoch} with valid/mse {best_val_loss:5.4f}")
# --------------------------------- END OF TRAINING LOOP -------------------------------------- #

# --------------------------------- final house-keeping --------------------------------------- #
if config.save_model:
    if config.save_model[-1] != "/":
        config.save_model += "/"
    dir = f"{config.save_model}/{config.run_name}_{time.time():.0f}.pt"
    torch.save(best_model.state_dict(), dir)

if config.wandb:
    run.finish()

TransformerModel(
  (emb_g): GeneEncoder(
    (embedding): Embedding(52, 32, padding_idx=50)
    (enc_norm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
  )
  (emb_x): ContinuousValueEncoder(
    (dropout): Dropout(p=0.2, inplace=False)
    (linear1): Linear(in_features=1, out_features=32, bias=True)
    (activation): ReLU()
    (linear2): Linear(in_features=32, out_features=32, bias=True)
    (norm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
  )
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-7): 8 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=32, out_features=32, bias=True)
        )
        (linear1): Linear(in_features=32, out_features=128, bias=True)
        (dropout): Dropout(p=0.2, inplace=False)
        (linear2): Linear(in_features=128, out_features=32, bias=True)
        (norm1): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
      



Random masking at epoch 1...




| epoch   1 | 100/520 batches | lr 0.0000 | ms/batch 74.12 | train/loss 802.87 | train/mse 802.87 |train/mre 16825.11 |
| epoch   1 | 200/520 batches | lr 0.0000 | ms/batch 86.45 | train/loss 799.72 | train/mse 799.72 |train/mre 17633.63 |
| epoch   1 | 300/520 batches | lr 0.0000 | ms/batch 86.03 | train/loss 810.66 | train/mse 810.66 |train/mre 18230.28 |


KeyboardInterrupt: 

## Load Pre-trained

In [1]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader

import json
import numpy as np
import scanpy as sc
import time
import copy
from scipy.sparse import issparse

from utils import set_seed, AttrDict
from vocab import Vocab
from preprocess import Preprocessor, get_interactions, get_z
from tokenizer import Tokenizer, random_mask_value
from model import TransformerModel, BioFormerModel
from loss import masked_mse_loss, masked_relative_error, criterion_neg_log_bernoulli

PATH = '../checkpoints/bioformer-af2_1730519158'
config = AttrDict(json.load(open(PATH + '.json')))
print(config)

if config.seed:
    set_seed(config.seed)

# Import data
path_to_transcriptional_interactions = '../data/transcriptional_interactions.csv'
dataset_name = config.dataset_name

if dataset_name == 'BREAST_25K':
    adata = sc.read_h5ad('../data/breast_25k.h5ad')
    data_is_raw = True

elif dataset_name == 'BREAST_12K':
    adata = sc.read_h5ad('../data/breast_12k.h5ad')
    data_is_raw = True

elif dataset_name == 'DERMAL_100K':
    adata = sc.read_h5ad('../data/dermal_100k.h5ad')
    adata.var["gene_name"] = adata.var.feature_name.tolist()
    data_is_raw = True

elif dataset_name == 'HYPOXIA_9K':
    adata = sc.read_h5ad('../data/scsHypoxiaTimeSub.h5ad')
    adata.X = adata.layers['raw_count']
    adata.var['gene_name'] = adata.var.index.tolist()
    data_is_raw = True

print(dataset_name)
print(adata)

# Pre-process RNA-seq data
preprocessor = Preprocessor(use_key="X",  # the key in adata.layers to use as raw data
                            filter_gene_by_counts=3,  # step 1
                            filter_cell_by_counts=False,  # step 2
                            normalize_total=1e4,  # 3. whether to normalize the raw data and to what sum
                            result_normed_key="X_normed",  # the key in adata.layers to store the normalized data
                            log1p=data_is_raw,  # 4. whether to log1p the normalized data
                            result_log1p_key="X_log1p",
                            subset_hvg=config.n_hvg,  # 5. whether to subset the raw data to highly variable genes
                            hvg_flavor="seurat_v3" if data_is_raw else "cell_ranger",
                            binning=config.n_bins,  # 6. whether to bin the raw data and to what number of bins
                            result_binned_key="X_binned",  # the key in adata.layers to store the binned data
                            )
preprocessor(adata, batch_key=None)

# Vocab
genes = adata.var["gene_name"].tolist()
vocab = Vocab(genes)
vocab.set_default_index(vocab["<pad>"]) # index to return if token not found in vocab
print(f'Init vocab of size {len(vocab)} with {config.n_hvg} unique genes...')
print(f"CLS in vocab: {vocab.stoi['<cls>']}")

# Tokenize & Pad
tokenizer = Tokenizer(vocab = vocab,
                      append_cls = True,
                      cls_token = "<cls>",
                      pad_token = "<pad>",
                      pad_value = -2,
                      include_zero_gene= config.include_zero_gene, 
                      )
tokenized = tokenizer.tokenize_and_pad_batch(adata.layers["X_binned"].toarray() if issparse(adata.layers["X_binned"]) else adata.layers["X_binned"],
                                             np.array(vocab(genes), dtype=int),
                                             max_len=config.n_hvg + 1,
                                             )
print(f"Tot samples: {tokenized['genes'].shape[0]}")
print(f"Input length: {tokenized['genes'].shape[1]}")

# Instantiate model
if config.model == "scGPT":
    model = TransformerModel(ntoken=len(vocab),
                             d_model=config.d_model,
                             nhead=config.nhead,
                             nlayers=config.nlayers,
                             pad_id = vocab.stoi['<pad>'],
                             explicit_zero_prob=config.explicit_zero_prob
                             ) 
    model.load_state_dict(torch.load(PATH + '.pt', weights_only=True))

elif config.model == "BioFormer":
    model = BioFormerModel(ntoken=len(vocab),
                           d_model=config.d_model,
                           d_z = config.d_z,
                           d_opm = config.d_opm,
                           nhead=config.nhead,
                           nlayers=config.nlayers,
                           do_pair_bias=config.do_pair_bias,
                           do_opm=config.do_opm,
                           pad_id = vocab.stoi['<pad>'],
                           explicit_zero_prob=config.explicit_zero_prob
                           ) 
    state_dict = torch.load(PATH + '.pt', weights_only=True, map_location=torch.device('cpu'))
    state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
    model.load_state_dict(state_dict)

print(model)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model = torch.nn.DataParallel(model)

# Parameters count
n_params = sum(p.numel() for p in model.parameters())
model_size_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
print(f'''device: {device} | model: {config.model} | d_model: {config.d_model} | nhead: {config.nhead} | nlayers: {config.nlayers} | tot. params: {n_params/1e6:.2f}M | model size: {model_size_bytes/1e6:.2f}MB''')

# Max memory required per intermediate step:
if config.model == 'BioFormer':
        # no. elements in the [B, r, r, c, c] opm intermediate matrix
        tmp = config.batch_size * (config.n_hvg + 1) ** 2 * config.d_opm ** 2
        
        # number of bytes required (using np.float32)
        tmp = tmp * 4

        print(f'memory required for opm: {tmp/1e6 :,.2f}MB')
        

# RNA-seq Dataset
class SeqDataset(Dataset):
    """
    Create RNA-seq dataset from vocabulary with keys ['gene_ids', 'valeus', 'target_vaules', 'interactions'].
    """
    def __init__(self, data: dict):
        self.data = data

    def __len__(self):
        return self.data["gene_ids"].shape[0]

    def __getitem__(self, idx):
        return {k: v[idx] for k, v in self.data.items()}

# Mask and get interactions
def prepare_data():
    """
    1. Random mask the data
    2. Get the interaction matrix z
    3. Convert to torch.Dataset.
    
    """
    masked_values = random_mask_value(tokenized["values"])
    # print(f"Random masking at epoch {epoch}...")

    B, r = masked_values.shape
    if config.init_z:
        tf = get_interactions(genes, path_to_transcriptional_interactions)
    interactions = get_z(tokenized["genes"], tf, vocab.itos) if config.init_z else torch.zeros((B, r, r))    # [B, r, r]

    data_pt = {
        "gene_ids": tokenized["genes"],           # [B, r]
        "values": masked_values,                  # [B, r]
        "target_values": tokenized["values"],     # [B, r]
        "interactions": interactions              # [B, r, r]
    }

    return SeqDataset(data_pt)

# --------------------------------------------------------------------------- #
# ------------------------------ TEST MODEL --------------------------------- #
# --------------------------------------------------------------------------- #                                         )

criterion = masked_mse_loss

epoch_start_time = time.time()

dataset = prepare_data()

loader = DataLoader(
    dataset=dataset,
    batch_size=config.batch_size,
    shuffle=True,
)


model.eval()

total_loss = 0.0
total_mre = 0.0
total_num = 0


{'run_name': 'bioformer-af2', 'dataset_name': 'HYPOXIA_9K', 'model': 'BioFormer', 'd_model': 256, 'nhead': 8, 'nlayers': 48, 'n_hvg': 255, 'do_pair_bias': True, 'do_opm': True, 'd_z': 128, 'd_opm': 32, 'init_z': False, 'do_train': True, 'epochs': 5, 'batch_size': 16, 'wandb': True, 'seed': 5289, 'n_bins': 51, 'include_zero_gene': False, 'explicit_zero_prob': True, 'log_interval': 100, 'lr': 0.01, 'amp': True, 'schedule_ratio': 0.1, 'save_model': '../checkpoints/'}
HYPOXIA_9K
AnnData object with n_obs × n_vars = 9234 × 19046
    obs: 'nCount_RNA', 'nFeature_RNA', 'SampleTags', 'percent.mt', 'HypoxicState', 'TimePoint', 'nCount_SCT', 'nFeature_SCT', 'S.Score', 'G2M.Score', 'Phase', 'seurat_clusters', 'SampleTagsShort', 'active_ident'
    var: 'variable_gene', 'gene_name'
    uns: 'active_ident_colors', 'seurat_clusters_colors'
    obsm: 'X_pca', 'X_umap'
    layers: 'raw_count'
Filtering genes by counts ...
Normalizing total counts ...
Log1p transforming ...
Subsetting highly variable ge

In [None]:
batch_data = next(iter(loader))
batch_data = dataset[[0]]

input_gene_ids = batch_data["gene_ids"].to(device)
input_values = batch_data["values"].to(device)
target_values = batch_data["target_values"].to(device)
if config.model == "BioFormer":
            interactions = batch_data['interactions'].to(device)

: 

In [None]:
model(input_gene_ids, input_values, interactions)

In [None]:
# with torch.no_grad():
#     for i, batch_data in enumerate(loader):
        
#         print(f'Batch no. {i}')

#         input_gene_ids = batch_data["gene_ids"].to(device)
#         input_values = batch_data["values"].to(device)
#         target_values = batch_data["target_values"].to(device)

#         if config.model == "BioFormer":
#             interactions = batch_data['interactions'].to(device)

#         with torch.cuda.amp.autocast(enabled=config.amp):
            
#             if config.model == "scGPT":
#                 output_dict = model(input_gene_ids, input_values)
#             elif config.model == "BioFormer":
#                 output_dict = model(input_gene_ids, input_values, interactions)
            
#             output_values = output_dict["mlm_output"]

#             masked_positions = input_values.eq(-1)
#             loss = criterion(output_values, target_values, masked_positions)

#         total_loss += loss.item() * len(input_gene_ids)
#         total_mre += masked_relative_error(output_values, target_values, masked_positions).item() * len(input_gene_ids)
#         total_num += len(input_gene_ids)