In [None]:
from torch import nn
from torch.utils.data import DataLoader
from transformers import AutoModel, AlbertTokenizerFast

from data.utils import get_data, collate_fn
from data import SNLICrossEncoderDataset, SNLIBiEncoderDataset

from model import NLICrossEncoder, NLIBiEncoder

from training import Trainer
from training.utils import unfreeze_parameters

# Data

In [None]:
# Load the raw SNLI 1.0 splits as DataFrames.
# NOTE(review): the result is unpacked as (train, test, val); confirm that this is the
# order get_data actually returns, since swapping val/test would go unnoticed.
snli_splits = get_data('snli_1.0')
train, test, val = snli_splits

In [None]:
# The three NLI classes; their position in this list becomes the integer class index.
LABELS = ['entailment', 'contradiction', 'neutral']
NUM_LABELS = len(LABELS)

# Drop rows whose target is not one of LABELS.
# `.copy()` makes each split an independent frame rather than a slice of the raw data.
# Without it, the later `split.target = ...` assignments raise SettingWithCopyWarning
# and may not behave as intended.
train = train[train.target.isin(LABELS)].copy()
val = val[val.target.isin(LABELS)].copy()
test = test[test.target.isin(LABELS)].copy()

In [None]:
# Encode string labels as integer class indices (their position in LABELS).
target2idx = {label: idx for idx, label in enumerate(LABELS)}

# Map each split's *own* labels. The previous version assigned train's mapped labels
# to val and test, and index alignment silently filled them with training labels/NaN.
# `.assign` returns a new frame instead of writing into a possibly-sliced one.
train, val, test = (
    split.assign(target=split.target.map(target2idx))
    for split in (train, val, test)
)

# Labels outside LABELS map to NaN. So does re-running this cell on already-encoded ints.
assert all(split.target.notna().all() for split in (train, val, test)), \
    'unmapped target labels (is this cell being re-run on already-encoded targets?)'

In [None]:
# Spot-check the first rows of the training split after label encoding.
train.head(5)

In [None]:
# Fast tokenizer for the same 'albert-base-v2' checkpoint that the encoder is loaded from.
albert_tokenizer = AlbertTokenizerFast.from_pretrained(
    pretrained_model_name_or_path='albert-base-v2',
)

In [None]:
# Build one dataset per split. The validation dataset must come from `val`, not `train`;
# otherwise the validation metrics only measure fit on the training data.
# NOTE(review): batching appears to happen inside the dataset (batch_size here, while the
# DataLoaders below use batch_size=1). Confirm this in data.SNLIBiEncoderDataset.
DATASET_BATCH_SIZE = 64
train_dataset = SNLIBiEncoderDataset(
    albert_tokenizer, train.sentence1, train.sentence2, train.target,
    batch_size=DATASET_BATCH_SIZE,
)
val_dataset = SNLIBiEncoderDataset(
    albert_tokenizer, val.sentence1, val.sentence2, val.target,
    batch_size=DATASET_BATCH_SIZE,
)

In [None]:
# batch_size=1 because each dataset item is presumably already a full batch (see the
# dataset cell). Only the training loader shuffles; validation order stays fixed.
train_loader = DataLoader(
    train_dataset,
    batch_size=1,
    collate_fn=collate_fn,
    shuffle=True,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=1,
    collate_fn=collate_fn,
    shuffle=False,
)

# Model

In [None]:
# Load the pretrained encoder from the same checkpoint as the tokenizer.
# NOTE(review): unfreeze_parameters presumably marks all encoder weights as trainable
# (requires_grad=True) so the whole encoder is fine-tuned. Confirm in training.utils.
albert = AutoModel.from_pretrained(pretrained_model_name_or_path='albert-base-v2')
unfreeze_parameters(albert)

In [None]:
def albert_pooler(encoder_outputs):
    """Select the `pooler_output` field from the encoder's output object."""
    return encoder_outputs.pooler_output


# Bi-encoder NLI classifier over NUM_LABELS classes. It uses the encoder's pooled output
# as the sentence representation.
model = NLIBiEncoder(albert, NUM_LABELS, albert_pooler)

In [None]:
# File paths handed to Trainer for the model and optimizer state.
# NOTE(review): presumably these are the checkpoint targets. Confirm that the 'albert/'
# directory exists or is created by Trainer.
CHECKPOINT_PATH = 'albert/albert_snli.pth'
OPTIMIZER_PATH = 'albert/albert_optimizer.pth'

# Multi-class cross-entropy loss over the NLI classes.
criterion = nn.CrossEntropyLoss()
trainer = Trainer(
    model,
    criterion,
    CHECKPOINT_PATH,
    OPTIMIZER_PATH,
    lr=1e-4,
    device='cuda:0',
)

# Training

In [None]:
# Fine-tune on the training loader, evaluating on the validation loader.
NUM_EPOCHS = 30
trainer.train(train_loader, val_loader, num_epochs=NUM_EPOCHS)