In [None]:
import sys
# Make the project package shipped in the `kaggle-lib` dataset importable
# (the `src.*` imports further down are presumably resolved from this directory).
# The membership check keeps the cell idempotent: re-running it does not pile up
# duplicate entries on sys.path.
KAGGLE_LIB_DIR = '/kaggle/input/kaggle-lib/kaggle_lib'
if KAGGLE_LIB_DIR not in sys.path:
    sys.path.append(KAGGLE_LIB_DIR)

In [None]:
# Install omegaconf, angle-emb and emoji into the running kernel's environment
# (presumably dependencies of the kaggle_lib code imported below — confirm).
# NOTE(review): no versions are pinned, so a fresh session may resolve different
# releases than the ones this notebook was developed against; consider pinning.
%pip install omegaconf angle-emb emoji

In [None]:
import gc, os, torch, wandb, pandas
from datasets import Dataset
from src.config import Config
from src.models import model_helper_factory
from src.train_utils import fix_reproducibility, train
from src.logging import WandbLogger
from kaggle_secrets import UserSecretsClient

# Read the Weights & Biases API key from the notebook's Kaggle secrets; it is
# looked up under the label "wandb-api-key".
user_secrets = UserSecretsClient()
wandb_api_key = user_secrets.get_secret("wandb-api-key")

# Export the key via the environment and log in without passing it explicitly;
# wandb is expected to pick up WANDB_API_KEY on its own.
# NOTE(review): avoid displaying `wandb_api_key` in any cell output.
os.environ["WANDB_API_KEY"] = wandb_api_key
wandb.login()

In [None]:
# Model family to train; also used further down to build the dataset paths.
model_name = "bart"

# Load the project configuration for the chosen model and convert it in one
# step, so `config` is only ever bound to the converted mapping.
config = Config.to_dict(
    Config.load_config(
        config_path="/kaggle/input/kaggle-lib/kaggle_lib/config",
        model_name=model_name,
    )
)

# Seed for this run, stored alongside the loaded settings.
config["seed"] = 42

In [None]:
# Training cell: builds tokenizer, model, data loaders, optimizer, scheduler and
# loss through the model-specific helper, then runs `train` inside a W&B run.

# Turn off the advisory warnings emitted by Hugging Face Transformers.
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"

try:
    # `init_wandb` is used as a context manager; presumably it starts the W&B run
    # and finishes it on exit (see src.logging — confirm).
    with WandbLogger().init_wandb(**config["wandb"], config=config):
        # Keep the run's config under its own name. `config` (the plain dict
        # unpacked in the `with` header above) must stay untouched, otherwise a
        # second execution of this cell would feed the previous run's
        # wandb.config back into `init_wandb` instead of the original settings.
        run_config = wandb.config
        fix_reproducibility(run_config.seed)

        # Helper that builds every training component for the configured model.
        train_helper = model_helper_factory(run_config)

        # Tokenizer first: the model is constructed from it.
        tokenizer = train_helper.make_tokenizer()
        model = train_helper.make_model(tokenizer)

        # Load the train/validation frames and wrap them as `datasets.Dataset`.
        # NOTE(review): unpickling executes code embedded in the file — only
        # point these paths at datasets you trust.
        train_dataset = Dataset.from_pandas(
            pandas.read_pickle(f'/kaggle/input/{model_name}-train/train.pkl')
        )
        val_dataset = Dataset.from_pandas(
            pandas.read_pickle(f'/kaggle/input/{model_name}-train/val.pkl')
        )
        # Uncomment to smoke-test the pipeline on a small subset:
        # train_dataset = train_dataset.select(range(16))
        # val_dataset = val_dataset.select(range(60))

        # A single collator is shared by both loaders.
        collator = train_helper.make_data_collator(tokenizer, model)

        train_dataloader = train_helper.make_dataloader(
            train_dataset, collate_fn=collator, split="train"
        )
        val_dataloader = train_helper.make_dataloader(
            val_dataset, collate_fn=collator, split="val"
        )

        # Optimizer and scheduler; the scheduler is told how many batches the
        # training loader yields per epoch.
        optimizer = train_helper.make_optimizer(model)
        scheduler = train_helper.make_scheduler(
            optimizer, steps_per_epoch=len(train_dataloader)
        )

        loss_fn = train_helper.make_loss(tokenizer)

        train(
            model,
            train_dataloader,
            val_dataloader,
            optimizer,
            loss_fn,
            scheduler,
            run_config,
        )
finally:
    # Run the cleanup even when training fails or is interrupted, so a crashed
    # run does not leave collectable garbage and cached CUDA blocks behind.
    gc.collect()
    torch.cuda.empty_cache()
