## MLM_pretrain

In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
project_path = '/media/workspace/caoguangshuo/scPlantGPT'
os.chdir(f'{project_path}/s03_scPlantGPT/trainer')
import sys
import torch
import time
import json
from torch import nn
import copy
from loss import masked_mse_loss
from utils import data_loader, generation_evaluate, pretrain_generation,load_config, train, evaluate, test, inference
import datetime
import wandb
sys.path.insert(0, "../")
from scplantllm.model import TransformerModel

In [2]:
start_time = time.time()
config = load_config('../Util/test.json')

In [3]:
hyperparameter_defaults = dict(
    parallel=True,
    epochs=1, 
    batch_size=64,
    lr=1e-4,
    ntoken= 185622,
    nctype= 44, 
    nbatch_effect= 238,
    ecs_threshold=0.0, 
    layer_size=512,
    hlayer_size=512,
    nlayers=6,
    nhead=8,
    nlayers_cls=3,
    dropout=0.5,
    schedule_ratio=0.9, 
    save_eval_interval=5,
    fast_transformer=True,
    explicit_zero_prob=False,
    pre_norm=True,
)
current_time = datetime.datetime.now()
timestamp = current_time.strftime("%YY%mM%dD%HH%MM%SS")
run = wandb.init(
    config=hyperparameter_defaults,
    project="test",
    entity="aibio",
    group=f"{config.train_strategy}_{config.input_emb_style}",
)
model_config = wandb.config
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if config.input_emb_style == "category":
    n_input_bins = config.n_bins + 2 # pad_value:-2, cls_value:0, masked_value:-1
else:
    n_input_bins = config.n_bins

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mcgshuo[0m ([33maibio[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [4]:
model = TransformerModel(
    ntoken=model_config.ntoken, 
    d_model=model_config.layer_size, 
    nhead=model_config.nhead, 
    d_hid=model_config.hlayer_size,
    nlayers=model_config.nlayers, 
    nlayers_cls=model_config.nlayers_cls, 
    n_cls=model_config.nctype, 
    dropout=model_config.dropout, 
    pad_value=int(config.pad_value),
    pad_token_id=config.pad_token_id, 
    do_mvc=config.GEPC, 
    do_dab=True, 
    use_batch_labels=config.use_batch_labels, 
    num_batch_labels=model_config.nbatch_effect, 
    domain_spec_batchnorm=config.DSBN, 
    input_emb_style=config.input_emb_style, 
    n_input_bins= n_input_bins, 
    cell_emb_style="cls", 
    mvc_decoder_style="inner product", 
    ecs_threshold=model_config.ecs_threshold, 
    explicit_zero_prob=model_config.explicit_zero_prob, 
    use_fast_transformer=model_config.fast_transformer, 
    pre_norm=model_config.pre_norm,)

model.to(device)



TransformerModel(
  (encoder): GeneEncoder(
    (embedding): Embedding(185622, 512, padding_idx=0)
    (enc_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (value_encoder): CategoryValueEncoder(
    (embedding): Embedding(103, 512, padding_idx=101)
    (enc_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-5): 6 x FlashTransformerEncoderLayer(
        (self_attn): FlashMHA(
          (Wqkv): Linear(in_features=512, out_features=1536, bias=True)
          (inner_attn): FlashAttention()
          (out_proj): Linear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=512, bias=True)
        (dropout): Dropout(p=0.5, inplace=False)
        (linear2): Linear(in_features=512, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, e

In [5]:
data_path = f'/public/workspace/caoguangshuo/scPlantGPT_analysis/github_test/model_data/has_celltype'

In [22]:
train_sampler, train_loader, train_size=model_config.metadata = data_loader(data_path, data_type='train', start_chunk=1, end_chunk=1, batch_size=model_config.batch_size, append_cls=True)
valid_sampler, valid_loader, valid_metadata = data_loader(data_path, data_type='valid', start_chunk=1, num_chunks=1, batch_size=model_config.batch_size,append_cls=True)
test_sampler, test_loader, test_metadata = data_loader(data_path,  data_type='test',start_chunk=1, num_chunks=1, batch_size=model_config.batch_size, append_cls=True)



In [23]:
criterion_gep_gepc = masked_mse_loss
optimizer = torch.optim.Adam(
    model.parameters(), lr=model_config.lr, eps= 1e-8
)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=model_config.schedule_ratio)
scaler = torch.cuda.amp.GradScaler(enabled=config.amp)

In [24]:
best_val_loss = float("inf")
for epoch in range(model_config.epochs):
    epoch_start_time = time.time()

    pretrain_generation(model, train_loader, criterion_gep_gepc, scaler, optimizer, scheduler, device, config, epoch)
    with torch.no_grad():
        val_loss = generation_evaluate(model, valid_loader, criterion_gep_gepc, device, config, epoch)

    current_time = datetime.datetime.now()
    timestamp = current_time.strftime("%YY%mM%dD%HH%MM%SS")
    save_path = f'./model_param/{config.train_strategy}'
    os.makedirs(save_path, exist_ok=True)

    checkpoint_path = os.path.join(save_path, f"{timestamp}_{config.input_emb_style}_model_{epoch}.pth")
    # torch.save(model.module.state_dict(), checkpoint_path)
        
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model = copy.deepcopy(model)
        best_model_epoch = epoch
        best_model_name = f"best_model_{config.input_emb_style}_{best_model_epoch}_{timestamp}.pth"


| epoch   0 | 100/356 batches | lr 0.00010 | ms/batch 133.24 | loss  1.98 | Scale Factor: 65536.0 | real loss: 1.98 | curl gep:  0.00 | mre 10245.78
| epoch   0 | 200/356 batches | lr 0.00010 | ms/batch 130.44 | loss  4.01 | Scale Factor: 32768.0 | real loss: 4.01 | curl gep:  0.00 | mre 7211.55
| epoch   0 | 300/356 batches | lr 0.00010 | ms/batch 130.66 | loss  1.88 | Scale Factor: 32768.0 | real loss: 1.88 | curl gep:  0.00 | mre 8512.01


In [25]:
end_time = time.time()
print(f"Total time: {end_time - start_time} seconds")

Total time: 1163.7337045669556 seconds


## CLS_pretrain

In [26]:
model_name = f"/media/workspace/caoguangshuo/scPlantGPT/s03_scPlantGPT/trainer/model_param/scAraGPT_pretrain_clean_label_nlayer_6_mask0.15/2024Y08M17D23H15M48S_category_model_10.pth"

try:
    model.load_state_dict(torch.load(model_name))
except:

    model_dict = model.state_dict()
    pretrained_dict = torch.load(model_name)
    pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items()
            if k in model_dict and v.shape == model_dict[k].shape
        }
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)

In [27]:
criterion_gep_gepc = masked_mse_loss
criterion_cls = nn.CrossEntropyLoss()
criterion_dab = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    model.parameters(), lr=model_config.lr, eps=1e-4 if config.amp else 1e-8
)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=model_config.schedule_ratio)
scaler = torch.cuda.amp.GradScaler(enabled=config.amp)
best_val_loss = float("inf")


In [28]:
for epoch in range(model_config.epochs):
    epoch_start_time = time.time()
    train(model, train_loader, criterion_gep_gepc, criterion_dab, criterion_cls, scaler, optimizer, scheduler, device, config,  epoch, model_config.parallel)
    epoch_end_time = time.time()
    print(f"Epoch {epoch} time: {epoch_end_time - epoch_start_time}")
    
    val_loss = evaluate(model, test_loader, criterion_gep_gepc, criterion_dab, criterion_cls, device, config, epoch)

    current_time = datetime.datetime.now()
    timestamp = current_time.strftime("%YY%mM%dD%HH%MM%SS")
    save_path = f'./model_param/{config.train_strategy}'
    os.makedirs(save_path, exist_ok=True)

    checkpoint_path = os.path.join(save_path, f"{timestamp}_{config.input_emb_style}_model_{epoch}.pth")
    # torch.save(model.module.state_dict(), checkpoint_path)
        
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model = copy.deepcopy(model)
        best_model_epoch = epoch
        best_model_name = f"best_model_{config.input_emb_style}_{best_model_epoch}_{timestamp}.pth"

end_time = time.time()
print(f"Train time: {end_time - start_time}")
print("Train finished!")

| epoch   0 | 100/356 batches | train/accuracy: 0.6409375, train/error_rate: 0.3690625
| epoch   0 | 100/356 batches | lr 0.00010 | ms/batch 136.15 | loss 80766.91 | scale factor: 65536.0 |scaled loss  1.23 |cls  1.23 | 
| epoch   0 | 200/356 batches | train/accuracy: 0.89828125, train/error_rate: 0.10171875
| epoch   0 | 200/356 batches | lr 0.00010 | ms/batch 134.10 | loss 21526.94 | scale factor: 65536.0 |scaled loss  0.33 |cls  0.33 | 
| epoch   0 | 300/356 batches | train/accuracy: 0.91671875, train/error_rate: 0.08328125
| epoch   0 | 300/356 batches | lr 0.00010 | ms/batch 135.96 | loss 17929.13 | scale factor: 65536.0 |scaled loss  0.27 |cls  0.27 | 
Epoch 0 time: 48.18659210205078
valid/loss: 7193.296875, valid/cls: 0.0037097649428209194, valid/accuracy: 0.928453947368421, valid/precision: 0.9130795682003277, valid/recall: 0.912497585510345, valid/macro_f1: 0.9098543813190176, valid/micro_f1: 0.928453947368421
Train time: 1213.9485495090485
Train finished!


## Prediction

In [29]:
celltype_vocab_path = f'{project_path}/s03_scPlantGPT/cross_data/Ara_celltype_record_clean_vocab.meta.json'
with open(celltype_vocab_path) as f:
        celltype_vocab = json.load(f)
celltype_vocab = {value : key for key, value in celltype_vocab.items()} 
batch_effect_vocab_file = f'{project_path}/s03_scPlantGPT/cross_data/Ara_batch_effect_vocab.meta.json'
with open(batch_effect_vocab_file) as f:
        batch_effect_vocab = json.load(f)
batch_effect_vocab  = {value : key for key, value in batch_effect_vocab.items()} 

In [30]:
fine_tune = True
if fine_tune:
    criterion_gep_gepc = masked_mse_loss
    criterion_cls = nn.CrossEntropyLoss()
    criterion_dab = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=model_config.lr, eps=1e-4 if config.amp else 1e-8)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=model_config.schedule_ratio)
    scaler = torch.cuda.amp.GradScaler(enabled=config.amp)

    best_val_loss = float("inf")
    for epoch in range(model_config.epochs):
            epoch_start_time = time.time()
            train(model, train_loader, criterion_gep_gepc, criterion_dab, criterion_cls, scaler, optimizer, scheduler, device, config,  epoch, model_config.parallel)
            epoch_end_time = time.time()
            print(f"Epoch {epoch} time: {epoch_end_time - epoch_start_time}")
    
            val_loss = evaluate(model, valid_loader, criterion_gep_gepc, criterion_dab, criterion_cls, device, config, epoch)



| epoch   0 | 100/356 batches | train/accuracy: 0.94125, train/error_rate: 0.06875
| epoch   0 | 100/356 batches | lr 0.00010 | ms/batch 135.85 | loss 13909.18 | scale factor: 65536.0 |scaled loss  0.21 |cls  0.21 | 
| epoch   0 | 200/356 batches | train/accuracy: 0.93953125, train/error_rate: 0.06046875
| epoch   0 | 200/356 batches | lr 0.00010 | ms/batch 135.15 | loss 12491.03 | scale factor: 65536.0 |scaled loss  0.19 |cls  0.19 | 
| epoch   0 | 300/356 batches | train/accuracy: 0.940625, train/error_rate: 0.059375
| epoch   0 | 300/356 batches | lr 0.00010 | ms/batch 136.78 | loss 12015.00 | scale factor: 65536.0 |scaled loss  0.18 |cls  0.18 | 
Epoch 0 time: 48.32874774932861
valid/loss: 6326.0869140625, valid/cls: 0.0029768527846930452, valid/accuracy: 0.944078947368421, valid/precision: 0.9214912245972599, valid/recall: 0.9215046225968255, valid/macro_f1: 0.9197182795961176, valid/micro_f1: 0.944078947368421


In [31]:
start_time = time.time()
cell_types_predictions, cell_types_labels, cell_names, probabilities, cell_embeddings, batch_labels_list = test(model, test_loader, test_metadata, device, config)
predict_end_time = time.time()
print(f"Using time to predict: {predict_end_time - start_time}")

test/accuracy: 0.9407894736842105, test/precision: 0.9237605047762502, test/recall: 0.9207030126581097, test/macro_f1: 0.9209807292497322, test/micro_f1: 0.9407894736842105
Using time to predict: 4.82179069519043


## Inference

In [36]:
start_time = time.time()
cell_types_predictions, cell_types_labels, cell_names, probabilities, cell_embeddings, batch_labels_list = inference(model, test_loader, test_metadata, device, config)
predict_end_time = time.time()
print(f"Using time to predict: {predict_end_time - start_time}")
print(cell_types_predictions[:5])

Using time to predict: 4.338064432144165
[17, 29, 17, 20, 19]


[1;34mwandb[0m: 🚀 View run [33mpear-pastry-17[0m at: [34mhttps://wandb.ai/aibio/test/runs/fr0jukih[0m
[1;34mwandb[0m: Find logs at: [1;35mwandb/run-20250314_171919-fr0jukih/logs[0m


In [None]:
wandb.finish()