In [1]:
%load_ext autoreload
%autoreload 2

from config import CFG
import dataset
import engine

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from datasets import load_dataset, load_metric

import torch

import numpy as np
import os
from sklearn.model_selection import train_test_split

In [2]:
# Smoke-test switch: when True, the data-preparation cell keeps only the first
# CFG.train_batch_size * 8 sentence pairs before splitting, so the pipeline runs fast.
debug = True

In [3]:
def set_seed(seed=42):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global RNG and torch, and exports
    ``PYTHONHASHSEED``.

    Note: ``PYTHONHASHSEED`` is read only when an interpreter starts, so setting it
    here affects child interpreters launched afterwards, not the current kernel.

    Args:
        seed: integer seed applied to every generator.
    """
    import random  # local import: the notebook's import cell does not import `random`

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)


# Seed immediately so everything below (model init, DataLoader shuffling, ...) is reproducible.
set_seed()

In [4]:
# Prefer the GPU when one is visible to torch; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [5]:
# Tokenizer matching the configured pretrained checkpoint.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=CFG.model_name)

In [6]:
# Keep a local copy of the pretrained weights so later runs load them from disk
# instead of from CFG.model_name.
# NOTE(review): the cache directory is not keyed on CFG.model_name — changing the
# model name will still load whatever is already saved here.
base_model_dir = 'model_checkpoints/base_model/'
if os.path.exists(base_model_dir):
    model = AutoModelForSeq2SeqLM.from_pretrained(base_model_dir)
else:
    model = AutoModelForSeq2SeqLM.from_pretrained(CFG.model_name)
    model.save_pretrained('model_checkpoints/base_model')

In [7]:
# Polish -> English sentence pairs from the Hugging Face `europa_eac_tm` dataset.
raw_dataset = load_dataset(
    "europa_eac_tm",
    language_pair=("pl", "en"),
)

Using custom data configuration pl2en-0da2ec5e9ea613fc
Reusing dataset europa_eac_tm (/home/bartek/.cache/huggingface/datasets/europa_eac_tm/pl2en-0da2ec5e9ea613fc/0.0.0/955b2501a836c2ea49cfe3e719aec65dcbbc3356bbbe53cf46f08406eb77386a)


In [8]:
# Extract the translation dicts in a single pass over the training split,
# then separate Polish sources (X) from English targets (y).
pairs = [row['translation'] for row in raw_dataset['train']]
X = [pair['pl'] for pair in pairs]
y = [pair['en'] for pair in pairs]

if debug:
    # Smoke-test mode: keep only a few batches' worth of examples.
    debug_batches = 8
    X = X[: CFG.train_batch_size * debug_batches]
    y = y[: CFG.train_batch_size * debug_batches]

# Pin random_state so the split does not depend on the global NumPy RNG state
# and is reproducible across kernel restarts.
x_train, x_valid, y_train, y_valid = train_test_split(X, y, test_size=0.2, random_state=42)
assert len(x_train) + len(x_valid) == len(X)

print(f'train size: {len(x_train)}, valid size: {len(x_valid)}')

train size: 51, valid size: 13


In [9]:
# Wrap the training split with the project's TranslationDataset (local `dataset` module).
# NOTE(review): presumably it tokenizes source/target pairs with `tokenizer` — confirm in the module.
train_ds = dataset.TranslationDataset(x_train, y_train, tokenizer)

In [10]:
# Training loader. shuffle=True so each epoch sees a different batch order;
# without it the training set is iterated in the same fixed order every epoch.
train_dl = torch.utils.data.DataLoader(
    train_ds,
    batch_size=CFG.train_batch_size,
    shuffle=True,
    num_workers=1,
)

In [11]:
# Plain Adam over all model parameters, learning rate taken from the config.
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=CFG.lr,
)

In [12]:
# Training is switched off for now: the notebook currently only measures validation loss.
# Set to True to run engine.train_fn before validating.
run_training = False
if run_training:
    engine.train_fn(model, optimizer, train_dl, device, scheduler=None)

In [13]:
# Validation split, wrapped exactly like the training split.
valid_ds = dataset.TranslationDataset(
    x_valid, y_valid, tokenizer
)

In [14]:
# Validation loader: fixed order (shuffle=False is the DataLoader default, stated explicitly).
valid_dl = torch.utils.data.DataLoader(
    valid_ds,
    batch_size=CFG.valid_batch_size,
    shuffle=False,
    num_workers=1,
)

In [15]:
# Evaluate the model on the validation split (no training cell is active above).
# NOTE(review): no visible cell calls model.to(device); presumably engine.valid_fn moves
# the model and batches to `device` itself — TODO confirm.
# NOTE(review): this call printed `torch.Size([13, 512, 63430])` in its output, which looks
# like a leftover debug print inside engine.valid_fn.
valid_loss = engine.valid_fn(model, valid_dl, device)

HBox(children=(FloatProgress(value=0.0, max=7.0), HTML(value='')))


torch.Size([13, 512, 63430])


In [16]:
# Show the validation loss returned by engine.valid_fn.
# NOTE(review): the observed ~10.90 is close to ln(63430) ≈ 11.06, the loss of a uniform
# guess over a 63430-wide last dimension (see the shape printed above), assuming this is
# token-level cross-entropy. That is suspiciously high for a pretrained seq2seq model;
# check that label padding is ignored (e.g. set to -100) and that labels line up with the
# decoder outputs.
valid_loss

tensor(10.8985)