## MLM_pretrain

In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "4"
import sys
import torch
import time
import json
from torch import nn
import copy
from loss import masked_mse_loss
from utils import data_loader, generation_evaluate, pretrain_generation,load_config, train, evaluate, test, inference
import datetime
import wandb
# sys.path.insert(0, "../")
from scplantllm.model import TransformerModel

In [2]:
start_time = time.time()
config = load_config('setting.json')

In [3]:
hyperparameter_defaults = dict(
    parallel=True,
    epochs=1, 
    batch_size=64,
    lr=1e-4,
    ntoken= 185622,
    nctype= 44, 
    nbatch_effect= 238,
    ecs_threshold=0.0, 
    layer_size=512,
    hlayer_size=512,
    nlayers=6,
    nhead=8,
    nlayers_cls=3,
    dropout=0.5,
    schedule_ratio=0.9, 
    save_eval_interval=5,
    fast_transformer=True,
    explicit_zero_prob=False,
    pre_norm=True,
)
current_time = datetime.datetime.now()
timestamp = current_time.strftime("%YY%mM%dD%HH%MM%SS")
os.environ["WANDB_MODE"] = "offline"
os.environ["WANDB_SILENT"] = "true"
run = wandb.init(
    config=hyperparameter_defaults,
    project="test",
    entity="aibio",
    group=f"{config.train_strategy}_{config.input_emb_style}",
    mode="offline"
)
model_config = wandb.config
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if config.input_emb_style == "category":
    n_input_bins = config.n_bins + 2 # pad_value:-2, cls_value:0, masked_value:-1
else:
    n_input_bins = config.n_bins

In [4]:
print("Loading model")
model = TransformerModel(
    ntoken=model_config.ntoken, 
    d_model=model_config.layer_size, 
    nhead=model_config.nhead, 
    d_hid=model_config.hlayer_size,
    nlayers=model_config.nlayers, 
    nlayers_cls=model_config.nlayers_cls, 
    n_cls=model_config.nctype, 
    dropout=model_config.dropout, 
    pad_value=int(config.pad_value),
    pad_token_id=config.pad_token_id, 
    use_batch_labels=config.use_batch_labels, 
    num_batch_labels=model_config.nbatch_effect, 
    input_emb_style=config.input_emb_style, 
    n_input_bins= n_input_bins, 
    cell_emb_style="cls", 
    use_fast_transformer=model_config.fast_transformer, 
    pre_norm=model_config.pre_norm,)

model.to(device)

Loading model




TransformerModel(
  (encoder): GeneEncoder(
    (embedding): Embedding(185622, 512, padding_idx=0)
    (enc_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (value_encoder): CategoryValueEncoder(
    (embedding): Embedding(103, 512, padding_idx=101)
    (enc_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-5): 6 x FlashTransformerEncoderLayer(
        (self_attn): FlashMHA(
          (Wqkv): Linear(in_features=512, out_features=1536, bias=True)
          (inner_attn): FlashAttention()
          (out_proj): Linear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=512, bias=True)
        (dropout): Dropout(p=0.5, inplace=False)
        (linear2): Linear(in_features=512, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, e

In [5]:
data_path = 'data/processed/has_celltype'

In [6]:
train_sampler, train_loader, train_size=model_config.metadata = data_loader(data_path, data_type='train', start_chunk=1, end_chunk=1, batch_size=model_config.batch_size, append_cls=True)
valid_sampler, valid_loader, valid_metadata = data_loader(data_path, data_type='valid', start_chunk=1, num_chunks=1, batch_size=model_config.batch_size,append_cls=True)
test_sampler, test_loader, test_metadata = data_loader(data_path,  data_type='test',start_chunk=1, num_chunks=1, batch_size=model_config.batch_size, append_cls=True)

In [7]:
criterion_gep_gepc = masked_mse_loss
optimizer = torch.optim.Adam(
    model.parameters(), lr=model_config.lr, eps= 1e-8
)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=model_config.schedule_ratio)
scaler = torch.cuda.amp.GradScaler(enabled=config.amp)

In [8]:
best_val_loss = float("inf")
for epoch in range(model_config.epochs):
    epoch_start_time = time.time()

    pretrain_generation(model, train_loader, criterion_gep_gepc, scaler, optimizer, scheduler, device, config, epoch)
    with torch.no_grad():
        val_loss = generation_evaluate(model, valid_loader, criterion_gep_gepc, device, config, epoch)

    current_time = datetime.datetime.now()
    timestamp = current_time.strftime("%YY%mM%dD%HH%MM%SS")
    save_path = f'./model_param/{config.train_strategy}'
    os.makedirs(save_path, exist_ok=True)

    checkpoint_path = os.path.join(save_path, f"{timestamp}_{config.input_emb_style}_model_{epoch}.pth")
    # torch.save(model.module.state_dict(), checkpoint_path)
        
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model = copy.deepcopy(model)
        best_model_epoch = epoch
        best_model_name = f"best_model_{config.input_emb_style}_{best_model_epoch}_{timestamp}.pth"


  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


| epoch   0 | 100/356 batches | lr 0.00010 | ms/batch 490.48 | loss  3.29 | Scale Factor: 16384.0 | real loss: 3.29 | curl gep:  0.00 | mre 31082.99
| epoch   0 | 200/356 batches | lr 0.00010 | ms/batch 494.12 | loss  2.16 | Scale Factor: 16384.0 | real loss: 2.16 | curl gep:  0.00 | mre 16229.21
| epoch   0 | 300/356 batches | lr 0.00010 | ms/batch 487.44 | loss  1.99 | Scale Factor: 16384.0 | real loss: 1.99 | curl gep:  0.00 | mre 12339.96


In [9]:
end_time = time.time()
print(f"Total time: {end_time - start_time} seconds")

Total time: 211.82380628585815 seconds


## CLS_pretrain

In [10]:
model_name = f"./model_params/scPlantLLM_model.pth"

try:
    model.load_state_dict(torch.load(model_name))
except:

    model_dict = model.state_dict()
    pretrained_dict = torch.load(model_name)
    pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items()
            if k in model_dict and v.shape == model_dict[k].shape
        }
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)

In [11]:
criterion_gep_gepc = masked_mse_loss
criterion_cls = nn.CrossEntropyLoss()
criterion_dab = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    model.parameters(), lr=model_config.lr, eps=1e-4 if config.amp else 1e-8
)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=model_config.schedule_ratio)
scaler = torch.cuda.amp.GradScaler(enabled=config.amp)
best_val_loss = float("inf")


In [12]:
for epoch in range(model_config.epochs):
    epoch_start_time = time.time()
    train(model, train_loader, criterion_gep_gepc, criterion_dab, criterion_cls, scaler, optimizer, scheduler, device, config,  epoch, model_config.parallel)
    epoch_end_time = time.time()
    print(f"Epoch {epoch} time: {epoch_end_time - epoch_start_time}")
    
    val_loss = evaluate(model, test_loader, criterion_gep_gepc, criterion_dab, criterion_cls, device, config, epoch)

    current_time = datetime.datetime.now()
    timestamp = current_time.strftime("%YY%mM%dD%HH%MM%SS")
    save_path = f'./model_param/{config.train_strategy}'
    os.makedirs(save_path, exist_ok=True)

    checkpoint_path = os.path.join(save_path, f"{timestamp}_{config.input_emb_style}_model_{epoch}.pth")
    # torch.save(model.module.state_dict(), checkpoint_path)
        
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model = copy.deepcopy(model)
        best_model_epoch = epoch
        best_model_name = f"best_model_{config.input_emb_style}_{best_model_epoch}_{timestamp}.pth"

end_time = time.time()
print(f"Train time: {end_time - start_time}")
print("Train finished!")

| epoch   0 | 100/356 batches | train/accuracy: 0.79546875, train/error_rate: 0.21453125
| epoch   0 | 100/356 batches | lr 0.00010 | ms/batch 500.44 | loss 51709.86 | scale factor: 32768.0 |scaled loss  1.58 |cls  1.31 | 
| epoch   0 | 200/356 batches | train/accuracy: 0.93015625, train/error_rate: 0.06984375
| epoch   0 | 200/356 batches | lr 0.00010 | ms/batch 482.80 | loss 7034.68 | scale factor: 32768.0 |scaled loss  0.21 |cls  0.21 | 
| epoch   0 | 300/356 batches | train/accuracy: 0.9328125, train/error_rate: 0.0671875
| epoch   0 | 300/356 batches | lr 0.00010 | ms/batch 439.30 | loss 6845.37 | scale factor: 32768.0 |scaled loss  0.21 |cls  0.21 | 
Epoch 0 time: 163.0573172569275
valid/loss: 91665.9375, valid/cls: 0.003290872199853, valid/accuracy: 0.9342105263157895, valid/precision: 0.9230882822477018, valid/recall: 0.9104830924928804, valid/macro_f1: 0.9152900094932492, valid/micro_f1: 0.9342105263157895
Train time: 380.6561756134033
Train finished!


## Prediction

In [13]:
celltype_vocab_path = f'data/processed/has_celltype/cell_type_vocab.meta.json'
with open(celltype_vocab_path) as f:
        celltype_vocab = json.load(f)
celltype_vocab = {value : key for key, value in celltype_vocab.items()} 
batch_effect_vocab_file = f'data/processed/has_celltype/batch_effect_vocab.meta.json'
with open(batch_effect_vocab_file) as f:
        batch_effect_vocab = json.load(f)
batch_effect_vocab  = {value : key for key, value in batch_effect_vocab.items()} 

In [14]:
fine_tune = True
if fine_tune:
    criterion_gep_gepc = masked_mse_loss
    criterion_cls = nn.CrossEntropyLoss()
    criterion_dab = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=model_config.lr, eps=1e-4 if config.amp else 1e-8)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=model_config.schedule_ratio)
    scaler = torch.cuda.amp.GradScaler(enabled=config.amp)

    best_val_loss = float("inf")
    for epoch in range(model_config.epochs):
            epoch_start_time = time.time()
            train(model, train_loader, criterion_gep_gepc, criterion_dab, criterion_cls, scaler, optimizer, scheduler, device, config,  epoch, model_config.parallel)
            epoch_end_time = time.time()
            print(f"Epoch {epoch} time: {epoch_end_time - epoch_start_time}")
    
            val_loss = evaluate(model, valid_loader, criterion_gep_gepc, criterion_dab, criterion_cls, device, config, epoch)



| epoch   0 | 100/356 batches | train/accuracy: 0.95, train/error_rate: 0.06
| epoch   0 | 100/356 batches | lr 0.00010 | ms/batch 413.25 | loss 11988.32 | scale factor: 65536.0 |scaled loss  0.18 |cls  0.18 | 
| epoch   0 | 200/356 batches | train/accuracy: 0.95734375, train/error_rate: 0.04265625
| epoch   0 | 200/356 batches | lr 0.00010 | ms/batch 403.45 | loss 8567.25 | scale factor: 65536.0 |scaled loss  0.13 |cls  0.13 | 
| epoch   0 | 300/356 batches | train/accuracy: 0.95125, train/error_rate: 0.04875
| epoch   0 | 300/356 batches | lr 0.00010 | ms/batch 395.58 | loss 9403.54 | scale factor: 65536.0 |scaled loss  0.14 |cls  0.14 | 
Epoch 0 time: 143.29024934768677
valid/loss: 68070.09375, valid/cls: 0.002264972477477338, valid/accuracy: 0.9588815789473685, valid/precision: 0.9478250416583902, valid/recall: 0.9407023940006436, valid/macro_f1: 0.9413680357006167, valid/micro_f1: 0.9588815789473685


In [15]:
start_time = time.time()
cell_types_predictions, cell_types_labels, cell_names, probabilities, cell_embeddings, batch_labels_list = test(model, test_loader, test_metadata, device, config)
predict_end_time = time.time()
print(cell_types_predictions[:5])
print(cell_types_labels[:5])
print(f"Using time to predict: {predict_end_time - start_time}")

test/accuracy: 0.944078947368421, test/precision: 0.938828142548302, test/recall: 0.914617949871456, test/macro_f1: 0.924546586675447, test/micro_f1: 0.944078947368421
[3, 3, 7, 3, 8]
[3, 3, 7, 3, 8]
Using time to predict: 8.770305871963501


## Inference

In [16]:
start_time = time.time()
cell_types_predictions, _, cell_names, probabilities, cell_embeddings, batch_labels_list = inference(model, test_loader, test_metadata, device, config)
predict_end_time = time.time()
print(f"Using time to predict: {predict_end_time - start_time}")
print(cell_types_predictions[:5])

Using time to predict: 8.455778360366821
[3, 3, 7, 3, 8]


In [17]:
wandb.finish()