In [40]:
import os
import random
import torch
from nltk.translate.bleu_score import sentence_bleu
from tqdm import tqdm
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
from transformers import EncoderDecoderModel, BertTokenizer
from dataset import Oxford2019Dataset

In [13]:
model_loc = 'models/dpp_5_epochs/'
model_type = 'bert-base-uncased'
data_loc = './data'

In [4]:
tokenizer = BertTokenizer.from_pretrained(model_type)

In [5]:
def make_data_loader(filename: str, batch = 24, file_loc: str = os.path.join(data_loc, 'Oxford-2019')) -> Dataset:
    dataset = Oxford2019Dataset(data_loc=os.path.join(file_loc, filename))
    # data_loader = DataLoader(dataset, batch_size=batch, shuffle=True, pin_memory=True)
    return dataset

In [6]:
test_set = make_data_loader('test.txt')

In [43]:
def bleu_scores(model, test_set, sample_size = None):
    sample_ids = range(len(test_set))
    if sample_size is not None:
        sample_ids = random.sample(sample_ids, sample_size)

    scores = []
    for id in tqdm(sample_ids):
        word, example, definition, _ = test_set[id]
        input_ids = torch.tensor(tokenizer.encode(example, add_special_tokens=True)).unsqueeze(0)
        generated = model.generate(input_ids, decoder_start_token_id=model.config.decoder.pad_token_id)
        generated_def = tokenizer.decode(generated.squeeze(), skip_special_tokens=True)

        a = generated_def.split()
        b = definition.split()

        max_length = max(len(a), len(b))

        # a.extend(['<pad>']*(max_length-len(a)))
        # b.extend(['<pad>']*(max_length-len(b)))

        score = sentence_bleu(a, b)
        scores.append(score)
    return scores


def explore(model, test_set, sample_size = None):
    sample_ids = range(len(test_set))
    if sample_size is not None:
        sample_ids = random.sample(sample_ids, sample_size)

    result = []
    for id in tqdm(sample_ids):
        word, example, definition, _ = test_set[id]
        input_ids = torch.tensor(tokenizer.encode(example, add_special_tokens=True)).unsqueeze(0)
        generated = model.generate(input_ids, decoder_start_token_id=model.config.decoder.pad_token_id)
        generated_def = tokenizer.decode(generated.squeeze(), skip_special_tokens=True)

        result.append((word, generated_def, definition))
    return result

In [9]:
model = EncoderDecoderModel.from_pretrained(model_loc)
model.eval()

  0%|          | 0/100 [00:01<?, ?it/s]


NameError: name 'bleu_score' is not defined

In [57]:
bleu_scores_5_epochs = bleu_scores(model, test_set[:100])
sum((s for s in bleu_scores_5_epochs if s != 0))

  7%|▋         | 7/100 [00:07<01:36,  1.04s/it]


KeyboardInterrupt: 

In [54]:
model_10_epochs = EncoderDecoderModel.from_pretrained('models/dpp_10_epochs/')

In [60]:
explore(model_10_epochs, test_set[10:20])

100%|██████████| 10/10 [00:09<00:00,  1.04it/s]


[('anaemic', 'a person who is not in the same way', 'suffering from anaemia'),
 ('anaemic', 'a person who is not in the same way', 'suffering from anaemia'),
 ('horde',
  'a person who is a person who is a person who is not a person who is not a',
  'an army or tribe of nomadic warriors'),
 ('horde',
  'a person who is a person who is a person who is not a person who is not a',
  'an army or tribe of nomadic warriors'),
 ('horde',
  'a person who is a person who is a person who is not a person who is not a',
  'an army or tribe of nomadic warriors'),
 ('order',
  'a person who is a person who is a religious church',
  'a society of knights bound by a common rule of life and having a combined military and monastic character'),
 ('order',
  'a person who is a person who is a religious church',
  'a society of knights bound by a common rule of life and having a combined military and monastic character'),
 ('order',
  'a person who is a person who is a religious church',
  'a society of kn

In [61]:
bleu_scores_10_epochs = bleu_scores(model_10_epochs, test_set[:100])
sum((s for s in bleu_scores_10_epochs if s != 0))

100%|██████████| 100/100 [01:30<00:00,  1.10it/s]


0

In [8]:
model_20_epochs = EncoderDecoderModel.from_pretrained('models/dpp_20_epochs/')

In [56]:
explore(model_20_epochs, test_set, 20)

100%|██████████| 20/20 [00:19<00:00,  1.04it/s]


[('aesthetically',
  'of a person having a particular quality',
  'with regard to beauty'),
 ('monody',
  'a musical instrument with a long, narrow neck, used for playing music or music in a particular',
  'an ode sung by a single actor in a greek tragedy'),
 ('caress',
  'a small, narrow, narrow piece of something',
  'touch or stroke gently or lovingly'),
 ('library',
  'a book or other item of information that is used to provide a particular information',
  'a collection of films , recorded music , etc , organized systematically and kept for research or borrowing'),
 ('dura',
  'a small, narrow piece of material, typically one that is attached to the body of a person',
  'the tough outermost membrane enveloping the brain and spinal cord'),
 ('stick',
  'make a person or animal more attractive or attractive',
  'adhere or cling to something'),
 ('waistline',
  "a narrow strip of cloth or other material used to cover a person's body",
  "the measurement around a person 's body at the 

In [51]:
sample_size = 1000
bleu_scores_20_epochs = bleu_scores(model_20_epochs, test_set, sample_size)
print(sum(bleu_scores_20_epochs) / float(sample_size))

100%|██████████| 1000/1000 [16:06<00:00,  1.03it/s]


6.595307732167172e-232


In [14]:
model_20_epochs_large_lr = EncoderDecoderModel.from_pretrained('models/dpp_20_epochs_high_lr')

In [15]:
explore(model_20_epochs_large_lr, test_set[60:80])

100%|██████████| 20/20 [00:19<00:00,  1.02it/s]


[('puppetry',
  'the the the the the the the the the the the the the the the the the the the',
  'pretence'),
 ('gyp',
  'a a a a a a a a a a a a a a a a a a a',
  'an act of cheating someone'),
 ('gyp',
  'a a a a a a a a a a a a a a a a a a a',
  'an act of cheating someone'),
 ('gyp',
  'a a a a a a a a a a a a a a a a a a a',
  'an act of cheating someone'),
 ('interdisciplinary',
  'having having having having having having having having having having having having having having having having having having having',
  'relating to more than one branch of knowledge'),
 ('interdisciplinary',
  'having having having having having having having having having having having having having having having having having having having',
  'relating to more than one branch of knowledge'),
 ('interdisciplinary',
  'having having having having having having having having having having having having having having having having having having having',
  'relating to more than one branch of knowledge