In [1]:
from torch import nn
from torch.utils.data import DataLoader
from transformers import AutoModel, AlbertTokenizerFast

from data.utils import get_data, collate_fn_buckets
from data import SNLICrossEncoderBucketingDataset, SNLIBiEncoderBucketingDataset

from model import NLICrossEncoder, NLIBiEncoder

from training import Trainer
from training.utils import unfreeze_parameters

# Data

In [2]:
train, test, val = get_data('snli_1.0')

In [3]:
LABELS = ['entailment', 'contradiction', 'neutral']
NUM_LABELS = len(LABELS)

train = train[train.target.isin(set(LABELS))]
val = val[val.target.isin(set(LABELS))]
test = test[test.target.isin(set(LABELS))]

In [4]:
train.head()

Unnamed: 0,sentence1,sentence2,target
0,A person on a horse jumps over a broken down a...,A person is training his horse for a competition.,neutral
1,A person on a horse jumps over a broken down a...,"A person is at a diner, ordering an omelette.",contradiction
2,A person on a horse jumps over a broken down a...,"A person is outdoors, on a horse.",entailment
3,Children smiling and waving at camera,They are smiling at their parents,neutral
4,Children smiling and waving at camera,There are children present,entailment


In [5]:
target2idx = {l: i for i, l in enumerate(LABELS)}
train.target = train.target.map(target2idx)
val.target = val.target.map(target2idx)
test.target = test.target.map(target2idx)

In [6]:
albert_tokenizer = AlbertTokenizerFast.from_pretrained('albert-base-v2')

In [7]:
train_dataset = SNLIBiEncoderBucketingDataset(albert_tokenizer, train.sentence1, train.sentence2, train.target, batch_size=48)
val_dataset = SNLIBiEncoderBucketingDataset(albert_tokenizer, val.sentence1, val.sentence2, val.target, batch_size=48)

100%|██████████| 1098734/1098734 [01:39<00:00, 11065.87it/s]
100%|██████████| 19648/19648 [00:02<00:00, 9592.41it/s] 


In [8]:
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, collate_fn=collate_fn_buckets)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn_buckets)

# Model

In [9]:
albert = AutoModel.from_pretrained('albert-base-v2')
unfreeze_parameters(albert)

In [10]:
model = NLIBiEncoder(albert, NUM_LABELS, lambda x: x.pooler_output)

In [11]:
criterion = nn.CrossEntropyLoss()
trainer = Trainer(model, criterion, 'albert/v1/albert_snli.pt', 'albert/v1/albert_optimizer.pt', lr=1e-4, device='cuda:0')

# Training

In [12]:
trainer.train(train_loader, val_loader, num_epochs=30)

Current progress:   0%|          | 0/11446 [00:00<?, ?it/s]

Epoch 1/30


Mean loss: 0.7507698660492897. Current progress: : 100%|██████████| 11446/11446 [30:31<00:00,  6.25it/s]


Mean training loss: 0.8181385246790437. Mean validation loss: 0.7315447595061325.
Training accuracy: 0.6384074762408372. Validation accuracy: 0.6774226384364821.


Mean loss: 0.7508805348277092. Current progress: :   0%|          | 1/11446 [00:00<32:38,  5.84it/s]

Epoch 2/30


Mean loss: 0.7003002393245698. Current progress: : 100%|██████████| 11446/11446 [30:18<00:00,  6.30it/s]


Mean training loss: 0.7109451954277916. Mean validation loss: 0.7010768849675248.
Training accuracy: 0.6941061257774858. Validation accuracy: 0.7001221498371335.


Mean loss: 0.7002300676107407. Current progress: :   0%|          | 1/11446 [00:00<32:37,  5.85it/s]

Epoch 3/30


Mean loss: 0.6453728113770485. Current progress: : 100%|██████████| 11446/11446 [30:05<00:00,  6.34it/s]


Mean training loss: 0.6575446358898002. Mean validation loss: 0.6577041307600533.
Training accuracy: 0.7249361537915455. Validation accuracy: 0.7259771986970684.


Mean loss: 0.6450433875918389. Current progress: :   0%|          | 1/11446 [00:00<31:23,  6.08it/s]

Epoch 4/30


Mean loss: 0.6088032639026641. Current progress: : 100%|██████████| 11446/11446 [30:07<00:00,  6.33it/s]
Mean loss: 0.6087459021806717. Current progress: :   0%|          | 1/11446 [00:00<26:04,  7.31it/s]

Mean training loss: 0.60605473937208. Mean validation loss: 0.6581013076189087.
Training accuracy: 0.7514503055334594. Validation accuracy: 0.7313721498371335.
Epoch 5/30


Mean loss: 0.5646379224658012. Current progress: : 100%|██████████| 11446/11446 [30:40<00:00,  6.22it/s]
Mean loss: 0.5646630463600159. Current progress: :   0%|          | 1/11446 [00:00<30:04,  6.34it/s]

Mean training loss: 0.5554217946766589. Mean validation loss: 0.6873039284857309.
Training accuracy: 0.7772873143090139. Validation accuracy: 0.730150651465798.
Epoch 6/30


Mean loss: 0.5264265257716179. Current progress: : 100%|██████████| 11446/11446 [30:15<00:00,  6.31it/s]


Mean training loss: 0.5080134919121219. Mean validation loss: 0.7082920808617662.
Training accuracy: 0.7997404285295622. Validation accuracy: 0.7251628664495114.
Validation performance has started degrading. Performing early stopping.


NLIBiEncoder(
  (premise_encoder): AlbertModel(
    (embeddings): AlbertEmbeddings(
      (word_embeddings): Embedding(30000, 128, padding_idx=0)
      (position_embeddings): Embedding(512, 128)
      (token_type_embeddings): Embedding(2, 128)
      (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0, inplace=False)
    )
    (encoder): AlbertTransformer(
      (embedding_hidden_mapping_in): Linear(in_features=128, out_features=768, bias=True)
      (albert_layer_groups): ModuleList(
        (0): AlbertLayerGroup(
          (albert_layers): ModuleList(
            (0): AlbertLayer(
              (full_layer_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (attention): AlbertAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True