In [1]:
import torch
import torch.nn as nn

import random
import numpy as np

from transformers import BertTokenizer, BertModel
from torchtext import data

import time

In [2]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [3]:
len(tokenizer.vocab)

30522

In [4]:
tokens = tokenizer.tokenize('Hello WORLD how ARE yoU?')

print(tokens)

['hello', 'world', 'how', 'are', 'you', '?']


In [5]:
indexes = tokenizer.convert_tokens_to_ids(tokens)

print(indexes)

[7592, 2088, 2129, 2024, 2017, 1029]


In [6]:
init_token = tokenizer.cls_token
eos_token = tokenizer.sep_token
pad_token = tokenizer.pad_token
unk_token = tokenizer.unk_token

print(init_token, eos_token, pad_token, unk_token)

[CLS] [SEP] [PAD] [UNK]


In [7]:
init_token_idx = tokenizer.convert_tokens_to_ids(init_token)
eos_token_idx = tokenizer.convert_tokens_to_ids(eos_token)
pad_token_idx = tokenizer.convert_tokens_to_ids(pad_token)
unk_token_idx = tokenizer.convert_tokens_to_ids(unk_token)

print(init_token_idx, eos_token_idx, pad_token_idx, unk_token_idx)

101 102 0 100


In [8]:
init_token_idx = tokenizer.cls_token_id
eos_token_idx = tokenizer.sep_token_id
pad_token_idx = tokenizer.pad_token_id
unk_token_idx = tokenizer.unk_token_id

print(init_token_idx, eos_token_idx, pad_token_idx, unk_token_idx)

101 102 0 100


In [9]:
max_input_length = tokenizer.max_model_input_sizes['bert-base-uncased']

print(max_input_length)

512


In [10]:
def tokenize_and_cut(sentence):
    tokens = tokenizer.tokenize(sentence) 
    tokens = tokens[:max_input_length-2]
    return tokens

In [11]:
TEXT = data.Field(batch_first = True,
                  use_vocab = False,
                  tokenize = tokenize_and_cut,
                  preprocessing = tokenizer.convert_tokens_to_ids,
                  init_token = init_token_idx,
                  eos_token = eos_token_idx,
                  pad_token = pad_token_idx,
                  unk_token = unk_token_idx)

LABEL = data.LabelField(dtype = torch.float)

In [12]:
fields = [('text', TEXT), ('label', LABEL)]

In [13]:
train_data = data.TabularDataset.splits(
                path = '',
                train = 'qnli.csv',
                format = 'csv',
                fields = fields,
                skip_header = True)

train_data , valid_data = train_data[0].split(split_ratio=0.98,
                                             random_state = random.seed(10))

In [14]:
print(f"Number of training examples: {len(train_data)}")
print(f"Number of validation examples: {len(valid_data)}")

Number of training examples: 85731
Number of validation examples: 1750


In [15]:
print(vars(train_data.examples[10]))

{'text': [2129, 2116, 6926, 1005, 1055, 2020, 1999, 1996, 2095, 2456, 2883, 1029, 102, 2004, 1997, 1996, 2883, 1997, 2456, 1010, 2045, 2024, 6391, 1010, 5511, 2549, 2111, 1010, 4008, 1010, 4749, 2581, 3911, 1010, 1998, 2385, 1010, 6255, 2629, 2945, 1999, 1996, 2103, 1012], 'label': 'entailment'}


In [16]:
tokens = tokenizer.convert_ids_to_tokens(vars(train_data.examples[10])['text'])

print(tokens)

['how', 'many', 'citizen', "'", 's', 'were', 'in', 'the', 'year', '2000', 'census', '?', '[SEP]', 'as', 'of', 'the', 'census', 'of', '2000', ',', 'there', 'are', '84', ',', '08', '##4', 'people', ',', '44', ',', '49', '##7', 'households', ',', 'and', '16', ',', '77', '##5', 'families', 'in', 'the', 'city', '.']


In [17]:
LABEL.build_vocab(train_data)

In [18]:
print(LABEL.vocab.stoi)

defaultdict(None, {'not_entailment': 0, 'entailment': 1})


In [19]:
BATCH_SIZE = 24

device = torch.device('cuda')

train_iterator, valid_iterator = data.BucketIterator.splits(
    (train_data, valid_data), 
    batch_size = BATCH_SIZE, 
    device = device,
    sort_key=lambda x: len(x.text))

In [20]:
class BERTSentiment(nn.Module):
    def __init__(self):
        
        super().__init__()
        
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        
        embedding_dim = self.bert.config.to_dict()['hidden_size']

        self.output = nn.Linear(embedding_dim, 1)
                
    def forward(self, text):
        
        #text = [batch size, sent len]
        
        embedded = self.bert(text)[0]
        
        logits = embedded[:,0,:]
        final_logits = self.output(logits)
        
        return final_logits

In [21]:
model = BERTSentiment().cuda()

In [22]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 109,483,009 trainable parameters


In [23]:
import torch.optim as optim

optimizer = optim.Adam(model.parameters(), lr=0.00002)

In [24]:
criterion = nn.BCEWithLogitsLoss().cuda()

In [25]:
def binary_accuracy(preds, y):
    rounded_preds = torch.round(torch.sigmoid(preds))
    correct = (rounded_preds == y).float() #convert into float for division 
    acc = correct.sum() / len(correct)
    return acc

In [26]:
N_EPOCHS = 2

best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):
    
    start_time = time.time()
    # TRAIN
    ############################################################################
    epoch_loss = 0
    epoch_acc = 0
    
    model.train()
    
    for batch in train_iterator:
        
        optimizer.zero_grad()
        
        predictions = model(batch.text).squeeze(1)
        
        loss = criterion(predictions, batch.label)
        
        acc = binary_accuracy(predictions, batch.label)
        
        loss.backward()
        
        optimizer.step()
        
        epoch_loss += loss.item()
        epoch_acc += acc.item()
        
    train_loss, train_acc = epoch_loss / len(train_iterator), epoch_acc / len(train_iterator)
    ############################################################################
    
    # VALID
    ############################################################################
    epoch_loss = 0
    epoch_acc = 0
    
    model.eval()
    
    with torch.no_grad():
    
        for batch in valid_iterator:

            predictions = model(batch.text).squeeze(1)
            
            loss = criterion(predictions, batch.label)
            
            acc = binary_accuracy(predictions, batch.label)

            epoch_loss += loss.item()
            epoch_acc += acc.item()
        
    valid_loss, valid_acc = epoch_loss / len(valid_iterator), epoch_acc / len(valid_iterator)
    ############################################################################
    end_time = time.time()
    
    print(f"EPOCH : {epoch+1} | TIME : {end_time-start_time:.2f}")
    print(f"TRAIN LOSS : {train_loss:.2f}\tTRAIN ACC : {train_acc*100:.2f}")
    print(f"VALID LOSS : {valid_loss:.2f}\tVALID ACC : {valid_acc*100:.2f}\n")
    torch.save(model.state_dict(), f'saved_models/qnli_model_{epoch+1}.pt')

EPOCH : 1 | TIME : 1782.03
TRAIN LOSS : 0.36	TRAIN ACC : 84.26
VALID LOSS : 0.28	VALID ACC : 88.35

EPOCH : 2 | TIME : 1792.61
TRAIN LOSS : 0.20	TRAIN ACC : 92.14
VALID LOSS : 0.29	VALID ACC : 88.86



In [27]:
def predict_sentiment(model, tokenizer, sentence):
    model.eval()
    tokens = tokenizer.tokenize(sentence)
    tokens = tokens[:max_input_length-2]
    indexed = [init_token_idx] + tokenizer.convert_tokens_to_ids(tokens) + [eos_token_idx]
    tensor = torch.LongTensor(indexed).to(device)
    tensor = tensor.unsqueeze(0)
    prediction = torch.sigmoid(model(tensor))
    return prediction.item()

In [28]:
idxs = random.sample(range(0,len(valid_data.examples)),20)
for i in idxs:
    txt = ' '.join(tokenizer.convert_ids_to_tokens(vars(valid_data.examples[i])['text']))
    label = vars(valid_data.examples[i])['label']
    pred = predict_sentiment(model, tokenizer, txt)
    
    print(f"TEXT\n{txt}")
    print(f"LABEL : {label}")
    print(f"PREDICTION : {pred}\n\n")

TEXT
who rendered a france 2 camera person unconscious ? [SEP] a camera ##man for france 2 was struck in the face by a police officer , knocked unconscious , and had to be sent to hospital .
LABEL : entailment
PREDICTION : 0.965059757232666


TEXT
what items were found on the korean farm site ? [SEP] the farm was dated between 360 ##0 and 3000 b . c . pottery , stone projectile points , and possible houses were also found .
LABEL : entailment
PREDICTION : 0.9859975576400757


TEXT
what is the heat generated from a concentrating solar power system used for ? [SEP] in all of these systems a working fluid is heated by the concentrated sunlight , and is then used for power generation or energy storage .
LABEL : not_entailment
PREDICTION : 0.9740356206893921


TEXT
what is the eastern ##most river in punjab ? [SEP] the capital and largest city is lahore which was the historical capital of the wider punjab region .
LABEL : not_entailment
PREDICTION : 0.003199330996721983


TEXT
when was the 