In [1]:
import torch
import torch.nn as nn

import random

from transformers import BertTokenizer, BertModel
from torchtext import data

In [2]:
# Load the pretrained tokenizer for the same 'bert-base-uncased' checkpoint
# that the model defined later in this notebook is built on.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [3]:
# Number of entries in the tokenizer's vocabulary.
len(tokenizer.vocab)

30522

In [4]:
# The uncased tokenizer lower-cases its input, so mixed-case text comes back
# normalized (see the printed result).
tokens = tokenizer.tokenize('Hello WORLD how ARE yoU?')

print(tokens)

['hello', 'world', 'how', 'are', 'you', '?']


In [5]:
# Map each token string to its integer id in the tokenizer's vocabulary.
indexes = tokenizer.convert_tokens_to_ids(tokens)

print(indexes)

[7592, 2088, 2129, 2024, 2017, 1029]


In [6]:
# String forms of the tokenizer's special tokens. Their integer ids are looked
# up in the next cells and are what the torchtext Field actually uses.
init_token = tokenizer.cls_token
eos_token = tokenizer.sep_token
pad_token = tokenizer.pad_token
unk_token = tokenizer.unk_token

print(init_token, eos_token, pad_token, unk_token)

[CLS] [SEP] [PAD] [UNK]


In [7]:
# First way to get the special-token ids: convert the token strings.
init_token_idx = tokenizer.convert_tokens_to_ids(init_token)
eos_token_idx = tokenizer.convert_tokens_to_ids(eos_token)
pad_token_idx = tokenizer.convert_tokens_to_ids(pad_token)
unk_token_idx = tokenizer.convert_tokens_to_ids(unk_token)

print(init_token_idx, eos_token_idx, pad_token_idx, unk_token_idx)

101 102 0 100


In [8]:
# Second way: read the ids straight from the tokenizer's *_token_id attributes.
# This rebinds the variables from the previous cell; the printed values are
# the same in both cells.
init_token_idx = tokenizer.cls_token_id
eos_token_idx = tokenizer.sep_token_id
pad_token_idx = tokenizer.pad_token_id
unk_token_idx = tokenizer.unk_token_id

print(init_token_idx, eos_token_idx, pad_token_idx, unk_token_idx)

101 102 0 100


In [9]:
# Maximum input size the tokenizer lists for 'bert-base-uncased'. It is used
# below to truncate token lists before the special tokens are added.
max_input_length = tokenizer.max_model_input_sizes['bert-base-uncased']

print(max_input_length)

512


In [10]:
def tokenize_and_cut(sentence):
    """Tokenize ``sentence`` and truncate the resulting token list.

    The list is capped at ``max_input_length - 2`` tokens, which leaves room
    for the init and eos special tokens that are added around every sequence
    elsewhere in this notebook.
    """
    return tokenizer.tokenize(sentence)[:max_input_length - 2]

In [11]:
# Text field: each example is tokenized with tokenize_and_cut and then
# converted to BERT vocabulary ids by the `preprocessing` hook. torchtext is
# told not to build a vocabulary of its own (use_vocab = False), which is why
# the init/eos/pad/unk tokens are supplied as integer ids rather than strings.
TEXT = data.Field(batch_first = True,
                  use_vocab = False,
                  tokenize = tokenize_and_cut,
                  preprocessing = tokenizer.convert_tokens_to_ids,
                  init_token = init_token_idx,
                  eos_token = eos_token_idx,
                  pad_token = pad_token_idx,
                  unk_token = unk_token_idx)

# Labels are numericalized as floats.
# NOTE(review): presumably to match a single-logit binary loss; the training
# code is not part of this notebook, so confirm.
LABEL = data.LabelField(dtype = torch.float)

In [12]:
# Attribute name for each column of the CSV, paired with the Field that
# processes it.
fields = [('text', TEXT), ('label', LABEL)]

In [13]:
# Load the QNLI examples. `splits` returns a tuple with one dataset per file
# that was requested; only `train` is given here, so element 0 is the full set.
full_data = data.TabularDataset.splits(
    path = '',
    train = 'qnli.csv',
    format = 'csv',
    fields = fields,
    skip_header = True)[0]

# random.seed() returns None, so passing its result as `random_state` only made
# the split reproducible through the seeding side effect. Seed explicitly and
# pass the resulting generator state instead; the split produced is the same.
random.seed(10)
train_data, valid_data = full_data.split(split_ratio = 0.9,
                                         random_state = random.getstate())

In [14]:
# Report the sizes of the two splits produced by the 0.9 split ratio above.
print(f"Number of training examples: {len(train_data)}")
print(f"Number of validation examples: {len(valid_data)}")

Number of training examples: 78733
Number of validation examples: 8748


In [15]:
# Inspect one processed example: 'text' is already a list of vocabulary ids,
# while 'label' is still the raw label string.
print(vars(train_data.examples[6]))

{'text': [2029, 5040, 2001, 2034, 2000, 2022, 7183, 1999, 1996, 2047, 2806, 2013, 10630, 2692, 2000, 13138, 2487, 1029, 102, 2053, 14001, 5040, 1010, 2036, 1999, 2605, 1010, 2387, 1996, 5700, 6503, 1997, 1037, 14584, 1997, 2019, 2972, 5040, 1999, 1996, 2047, 2806, 2013, 10630, 2692, 2000, 13138, 2487, 1012], 'label': 'entailment'}


In [16]:
# Turn the ids back into token strings to check the preprocessing. The printed
# result shows a '[SEP]' between the question and the candidate sentence.
tokens = tokenizer.convert_ids_to_tokens(vars(train_data.examples[6])['text'])

print(tokens)

['which', 'cathedral', 'was', 'first', 'to', 'be', 'rebuilt', 'in', 'the', 'new', 'style', 'from', '115', '##0', 'to', '123', '##1', '?', '[SEP]', 'no', '##yon', 'cathedral', ',', 'also', 'in', 'france', ',', 'saw', 'the', 'earliest', 'completion', 'of', 'a', 'rebuilding', 'of', 'an', 'entire', 'cathedral', 'in', 'the', 'new', 'style', 'from', '115', '##0', 'to', '123', '##1', '.']


In [17]:
# Build the label-string -> index mapping from the training split only.
LABEL.build_vocab(train_data)

In [18]:
# Show which integer index each label string was assigned.
print(LABEL.vocab.stoi)

defaultdict(None, {'not_entailment': 0, 'entailment': 1})


In [19]:
BATCH_SIZE = 24

# Use the GPU when one is available and fall back to the CPU otherwise, so the
# notebook still runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# `sort_key` lets the iterator group examples by the length of their id list,
# which keeps the amount of padding per batch small.
train_iterator, valid_iterator = data.BucketIterator.splits(
    (train_data, valid_data),
    batch_size = BATCH_SIZE,
    device = device,
    sort_key = lambda x: len(x.text))

In [20]:
class BERTSentiment(nn.Module):
    """Pretrained BERT encoder followed by a single-logit linear head.

    The attribute names ``bert`` and ``output`` are part of the saved
    state-dict keys and must not be renamed, otherwise existing checkpoints
    stop loading.
    """

    def __init__(self):

        super().__init__()

        self.bert = BertModel.from_pretrained('bert-base-uncased')

        # Width of the encoder's hidden states, read directly from the config
        # instead of round-tripping through a dict.
        hidden_size = self.bert.config.hidden_size

        self.output = nn.Linear(hidden_size, 1)

    def forward(self, text, attention_mask=None):
        """Return one raw (pre-sigmoid) logit per sequence.

        Args:
            text: token ids, [batch size, sent len].
            attention_mask: optional mask with the same shape as ``text``,
                forwarded to BERT so that padded positions can be ignored.
                Leaving it as ``None`` reproduces the previous behaviour, in
                which every position is attended to.

        Returns:
            Tensor of shape [batch size, 1].
        """
        # hidden_states = [batch size, sent len, hidden size]
        hidden_states = self.bert(text, attention_mask=attention_mask)[0]

        # Hidden state at position 0 (where the init token is placed) is used
        # as the representation of the whole sequence.
        first_token_state = hidden_states[:, 0, :]

        return self.output(first_token_state)

In [21]:
# Number of fine-tuned checkpoints that make up the ensemble.
N_MODELS = 2

models = []

for i in range(N_MODELS):
    # Place the model on the notebook's `device` rather than calling .cuda()
    # unconditionally.
    new_model = BERTSentiment().to(device)
    # map_location remaps the stored tensors onto `device`, so checkpoints that
    # were saved from a GPU can also be loaded on a machine without one.
    # Checkpoint files are named qnli_model_1.pt, qnli_model_2.pt, ...
    state_dict = torch.load(f'saved_models/qnli_model_{i+1}.pt',
                            map_location = device)
    new_model.load_state_dict(state_dict)
    models.append(new_model)

In [22]:
def model_ensemble_output(models, tensor):
    
    outputs = []
    for i in range(len(models)):
        outputs.append(models[i](tensor))
        
    return sum(outputs)

In [23]:
def predict_sentiment(models, tokenizer, sentence):
    """Score one sentence with the ensemble and return a value in [0, 1].

    Args:
        models: the models to ensemble; each is switched to eval mode.
        tokenizer: tokenizer used to turn ``sentence`` into vocabulary ids.
        sentence: raw input string.

    Returns:
        float: sigmoid of the summed model outputs.
        NOTE(review): with the label vocabulary printed in this notebook
        ('entailment' -> 1), values near 1 presumably mean "entailment";
        the training code is not shown here, so confirm.
    """
    for model in models:
        model.eval()

    # Truncate so that the sequence still fits once init/eos tokens are added.
    tokens = tokenizer.tokenize(sentence)[:max_input_length - 2]
    indexed = [init_token_idx] + tokenizer.convert_tokens_to_ids(tokens) + [eos_token_idx]

    # The outer list adds the batch dimension: tensor = [1, sent len]
    tensor = torch.tensor([indexed], dtype = torch.long, device = device)

    # Inference only: skip building the autograd graph for every call.
    with torch.no_grad():
        prediction = torch.sigmoid(model_ensemble_output(models, tensor))
    return prediction.item()

In [24]:
# Show predictions for (at most) 30 randomly chosen validation examples.
n_samples = min(30, len(valid_data.examples))
idxs = random.sample(range(len(valid_data.examples)), n_samples)
for i in idxs:
    example = valid_data.examples[i]
    # Rebuild the text from the stored ids. Word pieces must be merged back
    # into whole words ('115', '##0' -> '1150') before the text is tokenized
    # again in predict_sentiment; joining the raw pieces with spaces would make
    # the tokenizer split '##0' into '#', '#', '0' and feed the models ids that
    # differ from the stored example.
    pieces = tokenizer.convert_ids_to_tokens(example.text)
    txt = tokenizer.convert_tokens_to_string(pieces)
    label = example.label
    pred = predict_sentiment(models, tokenizer, txt)

    print(f"TEXT\n{txt}")
    print(f"LABEL : {label}")
    print(f"PREDICTION : {pred}\n\n")

TEXT
greece ' s constitution has how many articles ? [SEP] the constitution , which consists of 120 articles , provides for a separation of powers into executive , legislative , and judicial branches , and grants extensive specific guarantees ( further reinforced in 2001 ) of civil liberties and social rights .
LABEL : entailment
PREDICTION : 0.9999871253967285


TEXT
what year did american idol begin airing ? [SEP] it began airing on fox on june 11 , 2002 , as an addition to the idols format based on the british series pop idol and has since become one of the most successful shows in the history of american television .
LABEL : entailment
PREDICTION : 0.999969482421875


TEXT
what two general aviation airports are operated by the san diego regional airport authority ? [SEP] it is operated by an independent agency , the san diego regional airport authority .
LABEL : not_entailment
PREDICTION : 0.0002637379802763462


TEXT
what are the two sub - signals in each frequency band referred t