In [2]:
import pandas as pd
import torch
import lightning as L

from model.modeling_demolta import DeMOLTaConfig
from trainer import LitMOLLAForRegression, LitDeMOLTaForRegression, SaveTrainableParamsCheckpoint
from datautils import LitMOLLAFineTuneDataModule, LitDeMOLTaFineTuneDataModule

In [3]:
BATCH_SIZE = 4
SEED = 42
TEXT_MODEL_NAME = './Llama-2-7b-hf'
# TEXT_MODEL_NAME = 'meta-llama/Llama-2-7b-hf'
# HF_ACCESS_TOKEN = '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'

In [4]:
L.seed_everything(SEED)

Global seed set to 42


42

In [5]:
lit_finetune_data_module = LitDeMOLTaFineTuneDataModule(
    df_path='./data/train.csv',
    batch_size=BATCH_SIZE,
    seed=SEED,
    k_fold=5,
    train_fold=0,
)

In [6]:
demolta_config = DeMOLTaConfig(
    num_layers=12,
    node_hidden_dim=768,
    edge_hidden_dim=256,
    node_ff_dim=3072,
    edge_ff_dim=1536,
    num_heads=12,
)

In [None]:
lit_model = LitMOLLAForRegression(
    demolta_config=demolta_config,
    text_model_name=TEXT_MODEL_NAME,
)

In [4]:
lit_finetune_data_module = LitMOLLAFineTuneDataModule(
    df_path='./data/train.csv',
    batch_size=BATCH_SIZE,
    query = 'What is the LC-MS/MS percentage value of the molecule after reacting with MLM (Mouse Liver Microsome) for 30 minutes?',
    column_name='MLM',
    tokenizer_name = TEXT_MODEL_NAME,
    seed=SEED,
    k_fold=5,
    train_fold=0,
)

In [5]:
demolta_config = DeMOLTaConfig(
    num_layers=12,
    node_hidden_dim=768,
    edge_hidden_dim=256,
    node_ff_dim=3072,
    edge_ff_dim=1536,
    num_heads=12,
)

In [6]:
lit_model = LitMOLLAForRegression(
    demolta_config=demolta_config,
    text_model_name=TEXT_MODEL_NAME,
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



In [7]:
checkpoint = torch.load('./checkpoints/mola-pretrain-base-Llama-2-7b-hf-step=60000-train_loss=2.1271-val_loss=2.60.ckpt')
lit_model.load_state_dict(checkpoint, strict=False)

_IncompatibleKeys(missing_keys=['model.language_model.model.embed_tokens.weight', 'model.language_model.model.layers.0.self_attn.q_proj.weight', 'model.language_model.model.layers.0.self_attn.k_proj.weight', 'model.language_model.model.layers.0.self_attn.v_proj.weight', 'model.language_model.model.layers.0.self_attn.o_proj.weight', 'model.language_model.model.layers.0.self_attn.rotary_emb.inv_freq', 'model.language_model.model.layers.0.mlp.gate_proj.weight', 'model.language_model.model.layers.0.mlp.down_proj.weight', 'model.language_model.model.layers.0.mlp.up_proj.weight', 'model.language_model.model.layers.0.input_layernorm.weight', 'model.language_model.model.layers.0.post_attention_layernorm.weight', 'model.language_model.model.layers.1.self_attn.q_proj.weight', 'model.language_model.model.layers.1.self_attn.k_proj.weight', 'model.language_model.model.layers.1.self_attn.v_proj.weight', 'model.language_model.model.layers.1.self_attn.o_proj.weight', 'model.language_model.model.layers

In [8]:
# for param in lit_model.model.mol_model.parameters():
#     param.requires_grad = False

In [9]:
checkpoint_callback = SaveTrainableParamsCheckpoint(
    monitor='val_loss',
    dirpath='./checkpoints/',
    filename='molla-llama2-pretrain=60000-finetune-{epoch}-{val_loss:.4f}',
    save_top_k=1,
)

In [10]:
trainer = L.Trainer(
    accelerator='gpu',
    max_epochs=10,
    callbacks=[checkpoint_callback],
    val_check_interval=0.5,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [11]:
trainer.fit(lit_model, lit_finetune_data_module)

You are using a CUDA device ('NVIDIA GeForce RTX 3070') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Skipped loading some Tensorflow models, missing a dependency. No module named 'tensorflow'
Skipped loading some Jax models, missing a dependency. No module named 'jax'
  df = df.groupby('SMILES').mean().reset_index()
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")


OutOfMemoryError: CUDA out of memory. Tried to allocate 64.00 MiB (GPU 0; 8.00 GiB total capacity; 7.25 GiB already allocated; 0 bytes free; 7.28 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF