In [1]:
from dotenv import load_dotenv
load_dotenv()

True

In [2]:
import lightning as L
import datasets
import pandas as pd

from torch.utils.data import DataLoader
from lightning.pytorch.callbacks import ModelCheckpoint

from src.model.modeling_instruction_tuning import LitInstructionModel

In [3]:
def train_valid_test_split(
    ds,
    split_ratio,
    seed: int,
):
    ds_train_test = ds.train_test_split(
        test_size=split_ratio[2] / sum(split_ratio), seed=seed
    )
    ds_train_valid = ds_train_test["train"].train_test_split(
        test_size=split_ratio[1] / (split_ratio[1] + split_ratio[0]), seed=seed
    )
    ds_train = ds_train_valid["train"]
    ds_valid = ds_train_valid["test"]
    ds_test = ds_train_test["test"]
    return ds_train, ds_valid, ds_test

In [4]:
SEED=42
DATASET_NAME = 'AutoML/bitviper'
MINI_BATCH_SIZE=2
N_BATCH = 8
BASE_MODEL_NAME='Qwen/Qwen3-0.6B'
EPOCHS=20
LEARNING_RATE = 5e-5
MAX_LENGTH=128
USE_QLORA = False
SPLITS = (1, 20, 79)


In [5]:
L.seed_everything(SEED)

Seed set to 42


42

In [6]:
ds = datasets.load_dataset(DATASET_NAME, split="train")
def preprocessing(example):
    example['sentence_noisy'] = example['text'][:MAX_LENGTH]
    example['sentence'] = example['label'][:MAX_LENGTH]
    return example
ds = ds.map(preprocessing)
ds_train, ds_valid, ds_test = train_valid_test_split(
    ds, SPLITS, SEED
)
ds_valid = ds_valid.select(range(len(ds_train)))
dl_train, dl_valid, dl_test = DataLoader(ds_train, batch_size=MINI_BATCH_SIZE), DataLoader(ds_valid, batch_size=MINI_BATCH_SIZE), DataLoader(ds_test, batch_size=1)

In [7]:
lit_model = LitInstructionModel(
    base_model_name=BASE_MODEL_NAME,
    lr=LEARNING_RATE,
    use_qlora=USE_QLORA,
    epochs=EPOCHS
)

In [None]:
checkpoint_callback = ModelCheckpoint(
    dirpath=f"checkpoint/inst/{DATASET_NAME.split('/')[1]}/{BASE_MODEL_NAME.split('/')[1]}/",
    filename="{epoch:02d}-{valid_loss:.4f}",
    every_n_epochs=1,
    save_top_k=-1,
    save_weights_only=True
)

In [9]:
trainer = L.Trainer(
    callbacks=[checkpoint_callback], 
    precision='16-mixed',
    max_epochs=EPOCHS,
    enable_checkpointing=True,
    accumulate_grad_batches=N_BATCH
)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [10]:
trainer.fit(lit_model, dl_train, dl_valid)

You are using a CUDA device ('NVIDIA GeForce RTX 5070 Ti') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type             | Params | Mode
--------------------------------------------------
0 | model | Qwen3ForCausalLM | 596 M  | eval
--------------------------------------------------
596 M     Trainable params
0         Non-trainable params
596 M     Total params
2,384.200 Total estimated model params size (MB)
0         Modules in train mode
427       Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

c:\Users\jinwo\.virtualenvs\KROP-L3im0CPD\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
c:\Users\jinwo\.virtualenvs\KROP-L3im0CPD\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]


Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined