In [29]:
import pandas as pd
import torch
import lightning as L

from model.modeling_demolta import DeMOLTaConfig
from trainer import LitMOLAForRegression, get_finetune_dataloader, scaffold_split, SaveTrainableParamsCheckpoint

In [2]:
# Run-wide configuration for the fine-tuning cells below.
SEED = 42
# Seed Python's `random`, NumPy and PyTorch up front so that weight
# initialisation and training are repeatable after a kernel restart; SEED is
# otherwise only handed to the data splitter. `workers=True` lets Lightning
# derive the seeds of dataloader worker processes from SEED as well.
L.seed_everything(SEED, workers=True)

BATCH_SIZE = 4
# Hugging Face model id passed to the Lightning module as `text_model_name`.
TEXT_MODEL_NAME = 'facebook/galactica-125m'

In [3]:
# Read the training table and pull its SMILES column out as a plain Python
# list (one entry per row, in row order).
df = pd.read_csv('./data/train.csv')
smiles = df.loc[:, 'SMILES'].to_list()

In [4]:
# Split the table into folds; the result is iterated as (train, validation)
# pairs by the next cell.
# NOTE(review): 0.2 is presumably the validation fraction -- confirm against
# scaffold_split's signature.
# NOTE(review): the keyword is spelled `spplitter`; it is kept as-is because it
# has to match the parameter name declared in `trainer.scaffold_split`. Verify
# whether that spelling is intentional there.
dfs = scaffold_split(
    df,
    smiles,
    0.2,
    k_fold=5,
    seed=SEED,
    spplitter='fingerprints',
)

In [5]:
# Use only the first (train, validation) fold for this run.
# Taking it with next() instead of a `for ... break` loop makes an empty split
# fail right here; the loop form would bind nothing and leave later cells to
# crash with a NameError or silently reuse stale variables from the kernel.
first_fold = next(iter(dfs), None)
if first_fold is None:
    raise ValueError('scaffold_split returned no folds')
train_df, val_df = first_fold

In [6]:
# Batches of BATCH_SIZE rows drawn from the training fold.
train_dataloader = get_finetune_dataloader(batch_size=BATCH_SIZE, df=train_df)

In [7]:
# Batches of BATCH_SIZE rows drawn from the validation fold.
# NOTE(review): built with the same helper and arguments as the training
# loader -- confirm the helper does not shuffle, or that shuffling validation
# data is acceptable here.
val_dataloader = get_finetune_dataloader(batch_size=BATCH_SIZE, df=val_df)

In [8]:
# Hyperparameters for the DeMOLTa model: 12 layers, width 384, 6 heads and a
# feed-forward width of 1536 (4 x hidden_dim).
# NOTE(review): these presumably have to match the architecture the pretrained
# checkpoint was trained with -- confirm before changing any of them.
demolta_config = DeMOLTaConfig(
    hidden_dim=384,
    num_heads=6,
    ff_dim=1536,
    num_layers=12,
)

In [9]:
# Lightning module that is fine-tuned below, built from the DeMOLTa config and
# the text model named in the configuration cell.
# NOTE(review): `n_class=2` on a regression module presumably means two
# regression targets -- confirm against LitMOLAForRegression.
lit_model = LitMOLAForRegression(
    text_model_name=TEXT_MODEL_NAME,
    demolta_config=demolta_config,
    n_class=2,
)



In [24]:
# Read the pretrained weights from disk.
# map_location='cpu' keeps the load independent of the device the file was
# saved from (no failure when that device is missing, no second copy of the
# weights on the GPU); load_state_dict copies the values into the model's own
# parameters afterwards.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source. Consider weights_only=True if the installed PyTorch supports
# it and the file holds nothing but tensors.
checkpoint = torch.load(
    './checkpoint/mola-pretrain-step=50000-val_loss=1.48.ckpt',
    map_location='cpu',
)

In [28]:
# Copy the pretrained weights into the model.
# strict=False is required because the pretraining checkpoint carries no
# weights for the regression head, but it also silences every other mismatch.
# Inspect the report so that a key-prefix or architecture mismatch cannot
# leave the rest of the model at its random initialisation unnoticed.
load_result = lit_model.load_state_dict(checkpoint, strict=False)

# Only the freshly added regression head may be absent from the checkpoint.
unexpected_missing = [
    key for key in load_result.missing_keys
    if not key.startswith('model.regressor.')
]
if load_result.unexpected_keys or unexpected_missing:
    raise RuntimeError(
        'Pretrained checkpoint does not match the model: '
        f'unexpected={load_result.unexpected_keys}, '
        f'missing={unexpected_missing}'
    )

load_result  # displayed as the cell output

_IncompatibleKeys(missing_keys=['model.regressor.weight', 'model.regressor.bias'], unexpected_keys=[])

In [None]:
# Checkpointing callback handed to the trainer: it watches `val_loss`, writes
# into ./checkpoint/ and is configured to keep three checkpoints.
# NOTE(review): presumably a ModelCheckpoint subclass that stores only the
# trainable parameters and ranks by lowest val_loss -- confirm in `trainer`.
checkpoint_callback = SaveTrainableParamsCheckpoint(
    dirpath='./checkpoint/',
    filename='mola-pretrain50000-finetune+{step}-{val_loss:.2f}',
    monitor='val_loss',
    save_top_k=3,
)

In [10]:
# Trainer for the fine-tuning run: single-accelerator GPU training for at most
# 10 epochs, with checkpointing driven by `checkpoint_callback`.
trainer = L.Trainer(
    accelerator='gpu',
    # 'bf16' is a legacy alias that Lightning 2.x maps to 'bf16-mixed' with a
    # warning; spell out the canonical value (requires Lightning >= 2.0).
    precision='bf16-mixed',
    max_epochs=10,
    callbacks=[checkpoint_callback],
    # A float is a fraction of the training epoch: validate twice per epoch.
    val_check_interval=0.5,
)

In [12]:
# Start fine-tuning; validation runs on `val_dataloader` at the interval
# configured on the trainer.
trainer.fit(
    model=lit_model,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)

(tensor(2221.4438, grad_fn=<MseLossBackward0>),
 tensor(4246.1055, grad_fn=<MseLossBackward0>))