# Text entailment using BERT model

In [3]:
import torch
import pandas as pd
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torch.utils.data.dataset import T_co
from transformers import BertTokenizer, BertForSequenceClassification, AdamW, BertConfig

In [4]:
DATA_DEV = './data/snli_1.0_dev.txt'
DATA_TEST = './data/snli_1.0_test.txt'
DATA_TRAIN = './data/snli_1.0_train.txt'

BATCH_SIZE = 50

LABELS = {'neutral': 0, 'contradiction': 1, 'entailment': 2}

In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [6]:
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')

In [22]:
class SNLIDataset(Dataset):
    def __init__(self, csv_file: str):
        self.data = pd.read_csv(csv_file, sep='\t')
        self.data = self.data[['gold_label', 'sentence1', 'sentence2']]
        self.data = self.data[self.data['gold_label'] != '-']
        self.data.dropna(inplace=True)

        self.sentences: pd.Series = '[CLS]' + self.data['sentence1'] + '[SEP]' + self.data['sentence2'] + '[SEP]'

        # find the max length
        # tokens = [tokenizer.tokenize(sen) for sen in self.sentences.values]
        # self.max_len = max([len(t_list) for t_list in tokens])
        self.max_len = 200

        self.labels = [LABELS[v] for v in self.data['gold_label'].values]

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index: int) -> T_co:
        sentence = self.sentences.iloc[index]

        encoded = tokenizer.encode_plus(
            sentence,
            return_tensors='pt',
            add_special_tokens=True,
            return_attention_mask=True,
            padding='max_length',
            max_length=self.max_len,
        )

        return encoded['input_ids'], encoded['attention_mask'], self.labels[index]

In [23]:
train = SNLIDataset(csv_file=DATA_TRAIN)
dev = SNLIDataset(csv_file=DATA_DEV)
test = SNLIDataset(csv_file=DATA_TEST)

In [24]:
train_loader = DataLoader(train, batch_size=BATCH_SIZE)
dev_loader = DataLoader(dev, batch_size=BATCH_SIZE)
test_loader = DataLoader(test, batch_size=BATCH_SIZE)

In [25]:
model = BertForSequenceClassification.from_pretrained('bert-base-cased', num_labels=3, output_attentions=False, output_hidden_states=False)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForSequenceClassification: ['cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at b

In [26]:
model.to(device)

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

In [40]:
max_epoch = 3
train_loss = []
eval_loss = []

In [41]:
lr = 1e-3
optimizer = AdamW(model.parameters(), lr=lr)

In [42]:
for epoch in range(max_epoch):
    print(f'Epoch {epoch + 1}/{max_epoch}')
    
    # # train
    # model.train()
    # train_loss = 0

    # for batch in train_loader:
    #     batch_ids = batch[0].to(device)
    #     batch_masks = batch[1].to(device)
    #     batch_labels = batch[2].to(device)

        # optimizer.zero_grad()
        # loss, logits = model(
        #     batch_ids, 
        #     # token_type_ids = None, 
        #     # attention_mask = batch_masks, 
        #     labels = batch_labels)
        # loss, _ = model(batch_ids, token_type_ids = None, attention_mask=batch_masks, labels=batch_labels)
    #     train_loss += loss.item()
    #     loss.backward()
    #     torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    #     optimizer.step()

    # avg_train_loss = train_loss / len(train_loader)
    # train_loss.append(avg_train_loss)
    # print(f'loss {avg_train_loss}')

Epoch 1/3
Epoch 2/3
Epoch 3/3
