## MLM_pretrain

In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
project_path = '/media/workspace/caoguangshuo/scPlantLLM'
os.chdir(f'{project_path}/s03_scPlantLLM/trainer')
import sys
import torch
import time
import json
from torch import nn
import copy
from loss import masked_mse_loss
from utils import data_loader, generation_evaluate, pretrain_generation,load_config, train, evaluate, test
import datetime
import wandb
sys.path.insert(0, "../")
from model import TransformerModel

In [None]:
start_time = time.time()
config = load_config('setting.json')

In [3]:
hyperparameter_defaults = dict(
    parallel=True,
    epochs=1, 
    batch_size=64,
    lr=1e-4,
    ntoken= 185622,
    nctype= 44, 
    nbatch_effect= 238,
    ecs_threshold=0.0, 
    layer_size=512,
    hlayer_size=512,
    nlayers=6,
    nhead=8,
    nlayers_cls=3,
    dropout=0.5,
    schedule_ratio=0.9, 
    save_eval_interval=5,
    fast_transformer=True,
    explicit_zero_prob=False,
    pre_norm=True,
)
current_time = datetime.datetime.now()
timestamp = current_time.strftime("%YY%mM%dD%HH%MM%SS")
run = wandb.init(
    config=hyperparameter_defaults,
    project="test",
    entity="aibio",
    group=f"{config.train_strategy}_{config.input_emb_style}",
)
model_config = wandb.config
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if config.input_emb_style == "category":
    n_input_bins = config.n_bins + 2 # pad_value:-2, cls_value:0, masked_value:-1
else:
    n_input_bins = config.n_bins

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mcgshuo[0m ([33maibio[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [4]:
model = TransformerModel(
    ntoken=model_config.ntoken, 
    d_model=model_config.layer_size, 
    nhead=model_config.nhead, 
    d_hid=model_config.hlayer_size,
    nlayers=model_config.nlayers, 
    nlayers_cls=model_config.nlayers_cls, 
    n_cls=model_config.nctype, 
    dropout=model_config.dropout, 
    pad_value=int(config.pad_value),
    pad_token_id=config.pad_token_id, 
    do_mvc=config.GEPC, 
    do_dab=True, 
    use_batch_labels=config.use_batch_labels, 
    num_batch_labels=model_config.nbatch_effect, 
    domain_spec_batchnorm=config.DSBN, 
    input_emb_style=config.input_emb_style, 
    n_input_bins= n_input_bins, 
    cell_emb_style="cls", 
    mvc_decoder_style="inner product", 
    ecs_threshold=model_config.ecs_threshold, 
    explicit_zero_prob=model_config.explicit_zero_prob, 
    use_fast_transformer=model_config.fast_transformer, 
    pre_norm=model_config.pre_norm,)

model.to(device)



TransformerModel(
  (encoder): GeneEncoder(
    (embedding): Embedding(185622, 512, padding_idx=0)
    (enc_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (value_encoder): CategoryValueEncoder(
    (embedding): Embedding(103, 512, padding_idx=101)
    (enc_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-5): 6 x FlashTransformerEncoderLayer(
        (self_attn): FlashMHA(
          (Wqkv): Linear(in_features=512, out_features=1536, bias=True)
          (inner_attn): FlashAttention()
          (out_proj): Linear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=512, bias=True)
        (dropout): Dropout(p=0.5, inplace=False)
        (linear2): Linear(in_features=512, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, e

In [5]:
data_path = f'{project_path}/s03_scPlantGPT/cross_data/independent_clean_label/Ara_test'

In [6]:
train_sampler, train_loader, trai_size=model_config.bn_metadata = data_loader(data_path, data_type='train', start_chunk=1, end_chunk=1, batch_size=model_config.batch_size, append_cls=True)
valid_sampler, valid_loader, valid_metadata = data_loader(data_path, data_type='valid', start_chunk=1, num_chunks=1, batch_size=model_config.batch_size,append_cls=True)
test_sampler, test_loader, test_metadata = data_loader(data_path,  data_type='test',start_chunk=1, num_chunks=1, batch_size=model_config.batch_size, append_cls=True)



In [7]:
criterion_gep_gepc = masked_mse_loss
optimizer = torch.optim.Adam(
    model.parameters(), lr=model_config.lr, eps= 1e-8
)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=model_config.schedule_ratio)
scaler = torch.cuda.amp.GradScaler(enabled=config.amp)

In [8]:
best_val_loss = float("inf")
for epoch in range(model_config.epochs):
    epoch_start_time = time.time()

    pretrain_generation(model, train_loader, criterion_gep_gepc, scaler, optimizer, scheduler, device, config, epoch)
    with torch.no_grad():
        val_loss = generation_evaluate(model, valid_loader, criterion_gep_gepc, device, config,  epoch)

    current_time = datetime.datetime.now()
    timestamp = current_time.strftime("%YY%mM%dD%HH%MM%SS")
    save_path = f'./model_param/{config.train_strategy}'
    os.makedirs(save_path, exist_ok=True)

    checkpoint_path = os.path.join(save_path, f"{timestamp}_{config.input_emb_style}_model_{epoch}.pth")
    # torch.save(model.module.state_dict(), checkpoint_path)
        
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model = copy.deepcopy(model)
        best_model_epoch = epoch
        best_model_name = f"best_model_{config.input_emb_style}_{best_model_epoch}_{timestamp}.pth"


  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


| epoch   0 | 100/356 batches | lr 0.00010 | ms/batch 145.93 | loss  2.69 | Scale Factor: 16384.0 | real loss: 2.69 | curl gep:  0.00 | mre 30823.29
| epoch   0 | 200/356 batches | lr 0.00010 | ms/batch 142.69 | loss  2.13 | Scale Factor: 16384.0 | real loss: 2.13 | curl gep:  0.00 | mre 14751.19
| epoch   0 | 300/356 batches | lr 0.00010 | ms/batch 143.27 | loss  2.01 | Scale Factor: 16384.0 | real loss: 2.01 | curl gep:  0.00 | mre 10807.54


In [9]:
end_time = time.time()
print(f"Total time: {end_time - start_time} seconds")

Total time: 69.9159927368164 seconds


## CLS_pretrain

In [10]:
model_name = f"/media/workspace/caoguangshuo/scPlantGPT/s03_scPlantGPT/trainer/model_param/scAraGPT_pretrain_clean_label_nlayer_6_mask0.15/2024Y08M17D23H15M48S_category_model_10.pth"

try:
    model.load_state_dict(torch.load(model_name))
except:

    model_dict = model.state_dict()
    pretrained_dict = torch.load(model_name)
    pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items()
            if k in model_dict and v.shape == model_dict[k].shape
        }
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)

In [11]:
criterion_gep_gepc = masked_mse_loss
criterion_cls = nn.CrossEntropyLoss()
criterion_dab = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    model.parameters(), lr=model_config.lr, eps=1e-4 if config.amp else 1e-8
)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=model_config.schedule_ratio)
scaler = torch.cuda.amp.GradScaler(enabled=config.amp)
best_val_loss = float("inf")


In [12]:
for epoch in range(model_config.epochs):
    epoch_start_time = time.time()
    train(model, train_loader, criterion_gep_gepc, criterion_dab, criterion_cls, scaler, optimizer, scheduler, device, config,  epoch, model_config.parallel)
    epoch_end_time = time.time()
    print(f"Epoch {epoch} time: {epoch_end_time - epoch_start_time}")
    
    val_loss = evaluate(model, test_loader, criterion_gep_gepc, criterion_dab, criterion_cls, device, config, epoch)

    current_time = datetime.datetime.now()
    timestamp = current_time.strftime("%YY%mM%dD%HH%MM%SS")
    save_path = f'./model_param/{config.train_strategy}'
    os.makedirs(save_path, exist_ok=True)

    checkpoint_path = os.path.join(save_path, f"{timestamp}_{config.input_emb_style}_model_{epoch}.pth")
    # torch.save(model.module.state_dict(), checkpoint_path)
        
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model = copy.deepcopy(model)
        best_model_epoch = epoch
        best_model_name = f"best_model_{config.input_emb_style}_{best_model_epoch}_{timestamp}.pth"

end_time = time.time()
print(f"Train time: {end_time - start_time}")
print("Train finished!")

| epoch   0 | 100/356 batches | train/accuracy: 0.63625, train/error_rate: 0.37375
| epoch   0 | 100/356 batches | lr 0.00010 | ms/batch 152.11 | loss 81370.35 | scale factor: 65536.0 |scaled loss  1.24 |cls  1.24 | 
| epoch   0 | 200/356 batches | train/accuracy: 0.8825, train/error_rate: 0.1175
| epoch   0 | 200/356 batches | lr 0.00010 | ms/batch 161.64 | loss 24675.50 | scale factor: 65536.0 |scaled loss  0.38 |cls  0.38 | 
| epoch   0 | 300/356 batches | train/accuracy: 0.91234375, train/error_rate: 0.08765625
| epoch   0 | 300/356 batches | lr 0.00010 | ms/batch 173.82 | loss 18255.73 | scale factor: 65536.0 |scaled loss  0.28 |cls  0.28 | 
Epoch 0 time: 58.6326789855957
valid/loss: 11605.9091796875, valid/cls: 0.003593619496218468, valid/accuracy: 0.9259868421052632, valid/precision: 0.8966481461133624, valid/recall: 0.8810464374992562, valid/macro_f1: 0.8876422074238292, valid/micro_f1: 0.9259868421052632
Train time: 130.5401647090912
Train finished!


## Prediction

In [13]:
celltype_vocab_path = f'{project_path}/s03_scPlantGPT/cross_data/Ara_celltype_record_clean_vocab.meta.json'
with open(celltype_vocab_path) as f:
        celltype_vocab = json.load(f)
celltype_vocab = {value : key for key, value in celltype_vocab.items()} 
batch_effect_vocab_file = f'{project_path}/s03_scPlantGPT/cross_data/Ara_batch_effect_vocab.meta.json'
with open(batch_effect_vocab_file) as f:
        batch_effect_vocab = json.load(f)
batch_effect_vocab  = {value : key for key, value in batch_effect_vocab.items()} 

In [14]:
fine_tune = True
if fine_tune:
    criterion_gep_gepc = masked_mse_loss
    criterion_cls = nn.CrossEntropyLoss()
    criterion_dab = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=model_config.lr, eps=1e-4 if config.amp else 1e-8)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=model_config.schedule_ratio)
    scaler = torch.cuda.amp.GradScaler(enabled=config.amp)

    best_val_loss = float("inf")
    for epoch in range(model_config.epochs):
            epoch_start_time = time.time()
            train(model, train_loader, criterion_gep_gepc, criterion_dab, criterion_cls, scaler, optimizer, scheduler, device, config,  epoch, model_config.parallel)
            epoch_end_time = time.time()
            print(f"Epoch {epoch} time: {epoch_end_time - epoch_start_time}")
    
            val_loss = evaluate(model, valid_loader, criterion_gep_gepc, criterion_dab, criterion_cls, device, config, epoch)



| epoch   0 | 100/356 batches | train/accuracy: 0.94078125, train/error_rate: 0.06921875
| epoch   0 | 100/356 batches | lr 0.00010 | ms/batch 192.14 | loss 14419.05 | scale factor: 65536.0 |scaled loss  0.22 |cls  0.22 | 
| epoch   0 | 200/356 batches | train/accuracy: 0.9303125, train/error_rate: 0.0696875
| epoch   0 | 200/356 batches | lr 0.00010 | ms/batch 177.44 | loss 13790.64 | scale factor: 65536.0 |scaled loss  0.21 |cls  0.21 | 
| epoch   0 | 300/356 batches | train/accuracy: 0.93921875, train/error_rate: 0.06078125
| epoch   0 | 300/356 batches | lr 0.00010 | ms/batch 187.89 | loss 12033.29 | scale factor: 65536.0 |scaled loss  0.18 |cls  0.18 | 
Epoch 0 time: 67.1901388168335
valid/loss: 21615.615234375, valid/cls: 0.0025315810313546344, valid/accuracy: 0.9473684210526315, valid/precision: 0.9326877438457523, valid/recall: 0.9223693415099895, valid/macro_f1: 0.9244231323973169, valid/micro_f1: 0.9473684210526315


In [15]:
start_time = time.time()
cell_types_predictions, cell_types_labels, cell_names, probabilities, cell_embeddings, batch_labels_list = test(model, test_loader, test_metadata, device, config)
predict_end_time = time.time()
print(f"Using time to predict: {predict_end_time - start_time}")

test/accuracy: 0.9325657894736842, test/precision: 0.9188976906119279, test/recall: 0.8895667581705615, test/macro_f1: 0.9010580003023527, test/micro_f1: 0.9325657894736842
Using time to predict: 3.7367961406707764


In [None]:
wandb.finish()