In [1]:
# Load (or reload, if already loaded) IPython's autoreload extension.
%reload_ext autoreload
# Mode 2: re-import all modules before each cell runs, so edits to the local
# `src` package are picked up without restarting the kernel.
%autoreload 2

In [None]:
import os
import random

import accelerate
import torch
import transformers

from src._shared import (
    apply_lora_to_model,
    freeze_base_models,
    load_clip_model,
    load_config,
    load_tokenizers,
    prepare_dataset,
    save_model_and_logs,
    setup_environment,
    setup_trainer,
    train_model,
)

In [None]:
# Pin GPU selection before any CUDA work happens: enumerate devices in PCI bus
# order and expose only the device with index 1 to this process.
os.environ.update(
    {
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_VISIBLE_DEVICES": "1",
    }
)

# Load the training configuration and derive the run-level settings from it.
train_config = load_config()
(
    model_name_identifier,
    device,
    report_to,
    run,
    USE_WANDB,
    SEED,
) = setup_environment(train_config)

# Seed every RNG source in a fixed order, each with its own offset from SEED
# (+1, +2, +3, +4). The order is kept on purpose: a later call may re-seed a
# generator that an earlier library-level helper already touched.
for seed_offset, seed_fn in enumerate(
    (
        accelerate.utils.set_seed,
        transformers.set_seed,
        torch.manual_seed,
        random.seed,
    ),
    start=1,
):
    seed_fn(SEED + seed_offset)

In [None]:
# Load the two tokenizers described by the config. The names suggest one for
# the PLM side and one for the LLM side — TODO confirm against load_tokenizers.
tokenizer_plm, tokenizer_llm = load_tokenizers(train_config)
# Build the dataset from the config and both tokenizers; later cells index it
# by split name (e.g. dataset["train"]).
dataset = prepare_dataset(train_config, tokenizer_plm, tokenizer_llm)

In [None]:
# Sanity check: show the dataset summary, then the first training example.
print(dataset)
print(dataset["train"][0])

In [None]:
# Instantiate the model through the project helper, passing the loaded config
# and the device returned by setup_environment.
model = load_clip_model(train_config, device)

In [None]:
# Uncomment the next line to force the non-LoRA path regardless of the config:
# train_config['lora']['enabled'] = False

# Choose the fine-tuning strategy from the config flag.
if not train_config['lora']['enabled']:
    # LoRA disabled: hand the model to the project's freeze helper (called for
    # its effect on `model`; its return value is not used).
    freeze_base_models(model)
else:
    # LoRA enabled: rebind `model` to whatever the LoRA helper returns.
    model = apply_lora_to_model(model, train_config)

In [None]:
# Bare expression: the notebook renders the model's repr as the cell output.
model

In [9]:
# Debug aid (disabled): list every parameter name and whether its tensor is on CUDA.
# for param in model.named_parameters():
#     print(param[0], param[1].is_cuda)

In [None]:
# Report how many parameters will actually be updated during training.
#
# `print_trainable_parameters` is a PEFT model method. When LoRA is disabled
# the model is never wrapped by the LoRA helper, so it may not expose that
# method. Fall back to counting parameters directly instead of failing with
# AttributeError on the non-LoRA path.
if hasattr(model, "print_trainable_parameters"):
    model.print_trainable_parameters()
else:
    # assumes `model` is a torch.nn.Module exposing .parameters() — TODO confirm
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    all_params = sum(p.numel() for p in model.parameters())
    # Guard against a parameter-less model to avoid a ZeroDivisionError.
    trainable_pct = 100 * trainable_params / all_params if all_params else 0.0
    print(
        f"trainable params: {trainable_params:,d} || "
        f"all params: {all_params:,d} || "
        f"trainable%: {trainable_pct:.4f}"
    )

In [None]:
# NOTE(review): this import sits mid-notebook; it is kept in this cell so the
# cell stays self-contained, but it would normally belong with the top imports.
from src.model import utils

# Run the project's sanity checks on the model: one for CUDA placement and one
# for the requires_grad state of its parameters (behaviour inferred from the
# helper names — TODO confirm what they print or assert).
utils.check_model_on_cuda(model)
utils.check_model_parameters_requires_grad(model)

In [None]:
# Build the trainer from the model, the dataset, the run configuration and
# identifier, the W&B toggle, and both tokenizers (all passed positionally).
trainer = setup_trainer(
    model,
    dataset,
    train_config,
    model_name_identifier,
    USE_WANDB,
    tokenizer_plm,
    tokenizer_llm,
)

In [None]:
# Start training with the trainer built in the previous cell.
train_model(trainer)

In [None]:
# Persist the model and the trainer's logs via the project helper, using the
# run identifier and config (output location is decided inside the helper).
save_model_and_logs(model, trainer, model_name_identifier, train_config)

In [11]:
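# Disabled scratch copy of the LoRA/freeze step that already runs earlier in
# this notebook. NOTE(review): it reads the flag via attribute access
# (train_config.lora.enabled), whereas the active cell uses
# train_config['lora']['enabled']. Confirm which form the config object
# supports before re-enabling this, or delete the cell.
# if train_config.lora.enabled:
#     model = apply_lora_to_model(model, train_config)
# else:
#     freeze_base_models(model)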