In [1]:
import torch
import torch.nn as nn

import random
import numpy as np

from transformers import BertTokenizer, BertModel
from torchtext import data

import time

In [2]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [3]:
len(tokenizer.vocab)

30522

In [4]:
tokens = tokenizer.tokenize('Hello WORLD how ARE yoU?')

print(tokens)

['hello', 'world', 'how', 'are', 'you', '?']


In [5]:
indexes = tokenizer.convert_tokens_to_ids(tokens)

print(indexes)

[7592, 2088, 2129, 2024, 2017, 1029]


In [6]:
init_token = tokenizer.cls_token
eos_token = tokenizer.sep_token
pad_token = tokenizer.pad_token
unk_token = tokenizer.unk_token

print(init_token, eos_token, pad_token, unk_token)

[CLS] [SEP] [PAD] [UNK]


In [7]:
init_token_idx = tokenizer.convert_tokens_to_ids(init_token)
eos_token_idx = tokenizer.convert_tokens_to_ids(eos_token)
pad_token_idx = tokenizer.convert_tokens_to_ids(pad_token)
unk_token_idx = tokenizer.convert_tokens_to_ids(unk_token)

print(init_token_idx, eos_token_idx, pad_token_idx, unk_token_idx)

101 102 0 100


In [8]:
init_token_idx = tokenizer.cls_token_id
eos_token_idx = tokenizer.sep_token_id
pad_token_idx = tokenizer.pad_token_id
unk_token_idx = tokenizer.unk_token_id

print(init_token_idx, eos_token_idx, pad_token_idx, unk_token_idx)

101 102 0 100


In [9]:
max_input_length = tokenizer.max_model_input_sizes['bert-base-uncased']

print(max_input_length)

512


In [10]:
def tokenize_and_cut(sentence):
    tokens = tokenizer.tokenize(sentence) 
    tokens = tokens[:max_input_length-2]
    return tokens

In [11]:
TEXT = data.Field(batch_first = True,
                  use_vocab = False,
                  tokenize = tokenize_and_cut,
                  preprocessing = tokenizer.convert_tokens_to_ids,
                  init_token = init_token_idx,
                  eos_token = eos_token_idx,
                  pad_token = pad_token_idx,
                  unk_token = unk_token_idx)

LABEL = data.LabelField(dtype = torch.float)

In [12]:
fields = [('text', TEXT), ('label', LABEL)]

In [13]:
train_data = data.TabularDataset.splits(
                path = '',
                train = 'qnli.csv',
                format = 'csv',
                fields = fields,
                skip_header = True)

train_data , valid_data = train_data[0].split(split_ratio=0.96,
                                             random_state = random.seed(1234))

In [14]:
print(f"Number of training examples: {len(train_data)}")
print(f"Number of validation examples: {len(valid_data)}")

Number of training examples: 67176
Number of validation examples: 2799


In [15]:
print(vars(train_data.examples[6]))

{'text': [2054, 6433, 2000, 17886, 2050, 2043, 2009, 5829, 2000, 1996, 2235, 20014, 4355, 3170, 1029, 102, 2009, 2003, 1037, 2485, 5662, 1997, 1037, 18847, 12617, 12412, 4308, 1006, 1041, 1012, 1043, 1012, 1010, 2216, 1999, 4286, 2030, 14695, 1007, 1010, 1998, 17886, 2050, 2003, 13995, 2182, 1999, 2172, 1996, 2168, 2126, 1012], 'label': 'not_entailment'}


In [16]:
tokens = tokenizer.convert_ids_to_tokens(vars(train_data.examples[6])['text'])

print(tokens)

['what', 'happens', 'to', 'digest', '##a', 'when', 'it', 'moves', 'to', 'the', 'small', 'int', '##est', '##ine', '?', '[SEP]', 'it', 'is', 'a', 'close', 'equivalent', 'of', 'a', 'mono', '##gas', '##tric', 'stomach', '(', 'e', '.', 'g', '.', ',', 'those', 'in', 'humans', 'or', 'pigs', ')', ',', 'and', 'digest', '##a', 'is', 'processed', 'here', 'in', 'much', 'the', 'same', 'way', '.']


In [17]:
LABEL.build_vocab(train_data)

In [18]:
print(LABEL.vocab.stoi)

defaultdict(None, {'not_entailment': 0, 'entailment': 1})


In [19]:
BATCH_SIZE = 24

device = torch.device('cuda')

train_iterator, valid_iterator = data.BucketIterator.splits(
    (train_data, valid_data), 
    batch_size = BATCH_SIZE, 
    device = device,
    sort_key=lambda x: len(x.text))

In [20]:
class BERTSentiment(nn.Module):
    def __init__(self):
        
        super().__init__()
        
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        
        embedding_dim = self.bert.config.to_dict()['hidden_size']

        self.output = nn.Linear(embedding_dim, 1)
                
    def forward(self, text):
        
        #text = [batch size, sent len]
        
        embedded = self.bert(text)[0]
        
        logits = embedded[:,0,:]
        final_logits = self.output(logits)
        
        return final_logits

In [21]:
model = BERTSentiment().cuda()

In [22]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 109,483,009 trainable parameters


In [23]:
import torch.optim as optim

optimizer = optim.Adam(model.parameters(), lr=0.00002)

In [24]:
criterion = nn.BCEWithLogitsLoss().cuda()

In [25]:
def binary_accuracy(preds, y):
    rounded_preds = torch.round(torch.sigmoid(preds))
    correct = (rounded_preds == y).float() #convert into float for division 
    acc = correct.sum() / len(correct)
    return acc

In [26]:
N_EPOCHS = 2

best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):
    
    start_time = time.time()
    # TRAIN
    ############################################################################
    epoch_loss = 0
    epoch_acc = 0
    
    model.train()
    
    for batch in train_iterator:
        
        optimizer.zero_grad()
        
        predictions = model(batch.text).squeeze(1)
        
        loss = criterion(predictions, batch.label)
        
        acc = binary_accuracy(predictions, batch.label)
        
        loss.backward()
        
        optimizer.step()
        
        epoch_loss += loss.item()
        epoch_acc += acc.item()
        
    train_loss, train_acc = epoch_loss / len(train_iterator), epoch_acc / len(train_iterator)
    ############################################################################
    
    # VALID
    ############################################################################
    epoch_loss = 0
    epoch_acc = 0
    
    model.eval()
    
    with torch.no_grad():
    
        for batch in valid_iterator:

            predictions = model(batch.text).squeeze(1)
            
            loss = criterion(predictions, batch.label)
            
            acc = binary_accuracy(predictions, batch.label)

            epoch_loss += loss.item()
            epoch_acc += acc.item()
        
    valid_loss, valid_acc = epoch_loss / len(valid_iterator), epoch_acc / len(valid_iterator)
    ############################################################################
    end_time = time.time()
    
    print(f"EPOCH : {epoch+1} | TIME : {end_time-start_time:.2f}")
    print(f"TRAIN LOSS : {train_loss:.2f}\tTRAIN ACC : {train_acc*100:.2f}")
    print(f"VALID LOSS : {valid_loss:.2f}\tVALID ACC : {valid_acc*100:.2f}\n")
    torch.save(model.state_dict(), f'para_rank_model_{epoch+1}.pt')

EPOCH : 1 | TIME : 1901.51
TRAIN LOSS : 0.40	TRAIN ACC : 82.01
VALID LOSS : 0.32	VALID ACC : 86.55

EPOCH : 2 | TIME : 2543.73
TRAIN LOSS : 0.24	TRAIN ACC : 90.25
VALID LOSS : 0.34	VALID ACC : 86.57



In [27]:
def predict_sentiment(model, tokenizer, sentence):
    model.eval()
    tokens = tokenizer.tokenize(sentence)
    tokens = tokens[:max_input_length-2]
    indexed = [init_token_idx] + tokenizer.convert_tokens_to_ids(tokens) + [eos_token_idx]
    tensor = torch.LongTensor(indexed).to(device)
    tensor = tensor.unsqueeze(0)
    prediction = torch.sigmoid(model(tensor))
    return prediction.item()

In [28]:
idxs = random.sample(range(0,len(valid_data.examples)),20)
for i in idxs:
    txt = ' '.join(tokenizer.convert_ids_to_tokens(vars(valid_data.examples[i])['text']))
    label = vars(valid_data.examples[i])['label']
    pred = predict_sentiment(model, tokenizer, txt)
    
    print(f"TEXT\n{txt}")
    print(f"LABEL : {label}")
    print(f"PREDICTION : {pred}\n\n")

TEXT
in what century did public drinking regulations first exist in england ? [SEP] tavern owners were required to possess a licence to sell ale , and a separate licence for di ##sti ##lled spirits .
LABEL : not_entailment
PREDICTION : 0.000944725819863379


TEXT
what did the advancement ##s during the revolution improve for people ? [SEP] the advancement ##s made a great contribution to the quality of life .
LABEL : entailment
PREDICTION : 0.7218703031539917


TEXT
the largest sector of greece ' s economy is what ? [SEP] its economy mainly comprises the service sector ( 85 . 0 % ) and industry ( 12 . 0 % ) , while agriculture makes up 3 . 0 % of the national economic output .
LABEL : entailment
PREDICTION : 0.9812381863594055


TEXT
where did the mori ##sco ##s go when they were forced out of spain ? [SEP] the crown endeavour ##ed to compensate the nobles , who had lost much of their agricultural labour force ; this harmed the economy of the city for generations to come .
LABEL : not_