In [1]:
import json
import torch
import torch.nn as nn
import random
from tqdm import tqdm
from transformers import GPT2LMHeadModel, GPT2Tokenizer

In [2]:
# Pretrained GPT-2 tokenizer; every later cell encodes and decodes text through it.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

In [3]:
# Sanity check: number of entries in the tokenizer vocabulary.
len(tokenizer.get_vocab())

50257

In [4]:
# Token id written over randomly selected input positions during training.
# Ask the tokenizer for the id directly: the position of a key inside the dict
# returned by get_vocab() only equals the token id if that dict happens to be
# ordered by id.
mask_token = tokenizer.convert_tokens_to_ids("//")
# Fail loudly if the lookup fell back to another token instead of "//".
assert tokenizer.convert_ids_to_tokens(mask_token) == "//", 'token "//" is not in the vocabulary'

In [5]:
# Dev-set JSON; later cells read dev_dict['data'][i] for stories, questions and answers.
# The context manager closes the file handle instead of leaking it.
with open('../data/coqa-dev-v1.0.json', encoding='utf8') as dev_file:
    dev_dict = json.load(dev_file)

In [6]:
# Training texts (each entry is tokenized as one sequence further down).
# The context manager closes the file handle instead of leaking it.
with open('../data/qa_train_list.json', encoding='utf8') as train_file:
    train_list = json.load(train_file)

In [7]:
# Dev texts used by the evaluation functions below.
# The context manager closes the file handle instead of leaking it.
with open('../data/qa_dev_list.json', encoding='utf8') as dev_list_file:
    dev_list = json.load(dev_list_file)

In [8]:
def chunks(lst, n):
    """Lazily yield consecutive slices of ``lst`` holding at most ``n`` items each."""
    offsets = range(0, len(lst), n)
    yield from (lst[offset:offset + n] for offset in offsets)

def batchify(data, n):
    """Group equally long sequences into batches of at most ``n`` rows.

    Args:
        data: iterable of tensors; ``item.shape[1]`` is used as the sequence
            length and ``item[0]`` as the row to stack (i.e. items shaped
            ``(1, seq_len)``, as produced by the tokenization cell).
        n: maximum number of sequences per batch.

    Returns:
        A list of stacked tensors, one per batch.  Batches follow the order in
        which each distinct length first appears in ``data``.
    """
    # Bucket by length so that rows inside a batch can be stacked without padding.
    by_length = {}
    for item in data:
        by_length.setdefault(item.shape[1], []).append(item)

    batches = []
    for same_length_items in by_length.values():
        for chunk in chunks(same_length_items, n):
            batches.append(torch.stack([item[0] for item in chunk]))

    return batches

In [9]:
# Encode every training text and cut anything longer than `_limit` tokens
# (the maximum sequence length the tokenizer reports for this model).
_limit = 1024
data = []
total_trimmed = 0
for item in train_list:
    # `tokens` has shape (1, sequence_length).
    tokens = tokenizer.encode(item, return_tensors='pt')
    if tokens.shape[1] > _limit:
        # Keep only the first `_limit` tokens and count the truncation.
        tokens = tokens[:, :_limit]
        total_trimmed += 1
    data.append(tokens)
print(f'Trimmed {total_trimmed} out of {len(train_list)}')

Token indices sequence length is longer than the specified maximum sequence length for this model (1088 > 1024). Running this sequence through the model will result in indexing errors


Trimmed 52 out of 7199


In [10]:
# Batch size 1: every training sequence becomes its own batch.
train_batches = batchify(data, 1)

In [11]:
# Pretrained GPT-2 language model, moved to the GPU for fine-tuning.
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.cuda()
# NOTE(review): `criterion` is handed to train() below, but train() takes the
# loss from the model's own output and never calls it.
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=5e-6)

In [12]:
# Inspect the token embedding table of the model.
model.transformer.wte

Embedding(50257, 768)

In [13]:
# Replace the attention mask buffer (`attn.bias`) of transformer block 11 with
# an all-ones tensor, keeping a copy of the original so it can be restored
# before generation.
original_causal_masks = []
for index in range(11, -1, -1):
    # Only the first iteration (index == 11) gets past this guard, so just one
    # block is modified.  `index` stays defined at module level after the loop
    # (holding 10) and is read by a later inspection cell.
    if index < 11:
        break

    causal_mask = model.transformer.h[index].attn.bias
    original_causal_masks.append(causal_mask.clone())
    model.transformer.h[index].attn.bias = torch.ones_like(causal_mask)

In [14]:
# Leftover inspection of the function object; evaluating it has no effect on later cells.
torch.rand_like

<function _VariableFunctionsClass.rand_like>

In [15]:
import random

def mask_some_tokens(tokens, ratio, mask_id=None):
    """Return a copy of ``tokens`` with a random subset replaced by a mask id.

    Each position is replaced independently when a uniform draw from [0, 1)
    is less than or equal to ``ratio``; the input tensor is left untouched.

    Args:
        tokens: integer tensor of token ids.
        ratio: masking threshold; larger values mask more positions.
        mask_id: id written at masked positions.  Defaults to the module-level
            ``mask_token``.

    Returns:
        A new tensor with the same shape and dtype as ``tokens``.
    """
    if mask_id is None:
        mask_id = mask_token
    # rand_like cannot sample into an integer dtype, hence the explicit double.
    draws = torch.rand_like(tokens, dtype=torch.double)
    # masked_fill accepts a plain Python scalar on every PyTorch version,
    # unlike torch.where with a scalar third argument.
    return tokens.masked_fill(draws <= ratio, mask_id)

In [16]:
def train(train_model, batches, optimizer, criterion):
    """Run one training epoch over ``batches`` and return the mean loss.

    Every batch is corrupted with ``mask_some_tokens`` (ratio 0.1) while the
    uncorrupted ids are passed as labels, so the loss comes from the model's
    own output.

    Args:
        train_model: model to optimize; called as ``train_model(inputs, labels=...)``
            and expected to return the loss as its first output.
        batches: sequence of token-id tensors.
        optimizer: optimizer stepping ``train_model``'s parameters.
        criterion: unused; kept so existing callers keep working.

    Returns:
        Sum of the per-batch losses divided by ``len(batches)``.

    Raises:
        ZeroDivisionError: if ``batches`` is empty.
    """
    # Switch the model that was passed in (not the notebook-level `model`) to
    # training mode, once, and follow whatever device it lives on.
    train_model.train()
    device = next(train_model.parameters()).device

    total_loss = 0.
    for batch in tqdm(batches, total=len(batches)):
        labels = batch.clone()
        inputs = mask_some_tokens(batch, ratio=0.1)
        optimizer.zero_grad()
        loss = train_model(inputs.to(device), labels=labels.to(device))[0]
        loss.backward()
        # Clip the global gradient norm to 0.5 before the update.
        torch.nn.utils.clip_grad_norm_(train_model.parameters(), 0.5)
        optimizer.step()
        total_loss += loss.item()

    return total_loss / len(batches)

def test(test_model, batches):
    """Return the mean language-modeling loss of ``test_model`` over ``batches``.

    Args:
        test_model: model called as ``test_model(inputs, labels=inputs)`` and
            expected to return the loss as its first output.
        batches: sequence of token-id tensors.

    Returns:
        Sum of the per-batch losses divided by ``len(batches)``.

    Raises:
        ZeroDivisionError: if ``batches`` is empty.
    """
    test_model.eval()
    # Inputs must live on the same device as the model's weights.
    device = next(test_model.parameters()).device

    total_loss = 0.
    # No gradients are needed for evaluation; skipping them saves memory.
    with torch.no_grad():
        for batch in tqdm(batches, total=len(batches)):
            inputs = batch.to(device)
            loss = test_model(inputs, labels=inputs)[0]
            total_loss += loss.item()

    return total_loss / len(batches)

In [17]:
from torch.optim.lr_scheduler import StepLR

# NOTE(review): this shuffle is repeated immediately at the top of the first epoch below.
random.shuffle(train_batches)
# Learning-rate schedule: scale the rate by 0.8 every 2 scheduler steps (one step per epoch).
scheduler = StepLR(optimizer, step_size=2, gamma=0.8)
for epoch in range(8):
    # Visit the batches in a new random order each epoch.
    random.shuffle(train_batches)
    loss = train(model, train_batches, optimizer, criterion)
    print('Epoch:', epoch, 'Loss:', loss)
    # One checkpoint file per epoch: save_small0 ... save_small7.
    torch.save({'epoch': epoch,
                'model_state_dict': model.state_dict()},
                'save_small' + str(epoch))
    scheduler.step()

100%|███████████████████████████████████████████████████████████████████████████████████████████████| 7199/7199 [13:13<00:00,  9.08it/s]


Epoch: 0 Loss: 1.0636040196557528


100%|███████████████████████████████████████████████████████████████████████████████████████████████| 7199/7199 [13:13<00:00,  9.08it/s]


Epoch: 1 Loss: 0.48959086684696146


100%|███████████████████████████████████████████████████████████████████████████████████████████████| 7199/7199 [13:07<00:00,  9.14it/s]


Epoch: 2 Loss: 0.4137509625282432


100%|███████████████████████████████████████████████████████████████████████████████████████████████| 7199/7199 [13:08<00:00,  9.13it/s]


Epoch: 3 Loss: 0.38410387325326606


100%|███████████████████████████████████████████████████████████████████████████████████████████████| 7199/7199 [13:09<00:00,  9.12it/s]


Epoch: 4 Loss: 0.3661047230683063


100%|███████████████████████████████████████████████████████████████████████████████████████████████| 7199/7199 [13:10<00:00,  9.11it/s]


Epoch: 5 Loss: 0.35104502990670067


100%|███████████████████████████████████████████████████████████████████████████████████████████████| 7199/7199 [13:20<00:00,  9.00it/s]


Epoch: 6 Loss: 0.34063764571613664


100%|███████████████████████████████████████████████████████████████████████████████████████████████| 7199/7199 [13:42<00:00,  8.75it/s]


Epoch: 7 Loss: 0.33304477801073223


In [18]:
# Markers (including the leading newline) that introduce a question / an
# answer line in the dialogue texts; the helpers below search for them.
_question_prompt = '\nQ: '
_answer_prompt = '\nA: '
    
def get_text_up_to_question_number(text, number, answer_prompt=None):
    """Return ``text`` cut just before the answer to question ``number`` (0-based).

    The returned prefix ends with the newline that starts the answer marker,
    so the answer itself is excluded.  An empty string is returned when
    ``text`` contains ``number`` or fewer answer markers.

    Args:
        text: dialogue text containing answer markers.
        number: 0-based index of the answer to stop at.
        answer_prompt: marker to search for; defaults to the module-level
            ``_answer_prompt``.
    """
    marker = _answer_prompt if answer_prompt is None else answer_prompt
    pos = text.find(marker)
    for _ in range(number):
        if pos == -1:
            # Stop here: searching again from pos + 1 == 0 would wrap around
            # and silently match the first answer once more.
            break
        pos = text.find(marker, pos + 1)
    # pos == -1 yields text[0:0], i.e. the empty string.
    return text[0:pos + 1]
    
def get_answers_number(text):
    """Count how many answer markers (``_answer_prompt``) occur in ``text``."""
    # Splitting on the marker produces one more piece than there are markers.
    pieces = text.split(_answer_prompt)
    return len(pieces) - 1

def get_answer_number(text, number, answer_prompt=None):
    """Return the text of answer ``number`` (0-based) found in ``text``.

    The answer runs from the end of its marker to the next newline, or to the
    end of ``text`` when no newline follows.  An empty string is returned when
    ``text`` contains ``number`` or fewer answer markers.

    Args:
        text: dialogue text containing answer markers.
        number: 0-based index of the wanted answer.
        answer_prompt: marker to search for; defaults to the module-level
            ``_answer_prompt``.
    """
    marker = _answer_prompt if answer_prompt is None else answer_prompt
    pos = text.find(marker)
    for _ in range(number):
        if pos == -1:
            # Searching again from pos + 1 == 0 would wrap around to the first answer.
            break
        pos = text.find(marker, pos + 1)
    if pos == -1:
        return ''

    start = pos + len(marker)
    end = text.find('\n', start)
    if end == -1:
        # No trailing newline: slicing with -1 would cut off the last character.
        end = len(text)
    return text[start:end]

def get_question_number(text, number, question_prompt=None):
    """Return the text of question ``number`` (0-based) found in ``text``.

    The question runs from the end of its marker to the next newline, or to
    the end of ``text`` when no newline follows.  An empty string is returned
    when ``text`` contains ``number`` or fewer question markers.

    Args:
        text: dialogue text containing question markers.
        number: 0-based index of the wanted question.
        question_prompt: marker to search for; defaults to the module-level
            ``_question_prompt``.
    """
    marker = _question_prompt if question_prompt is None else question_prompt
    pos = text.find(marker)
    for _ in range(number):
        if pos == -1:
            # Searching again from pos + 1 == 0 would wrap around to the first question.
            break
        pos = text.find(marker, pos + 1)
    if pos == -1:
        return ''

    start = pos + len(marker)
    end = text.find('\n', start)
    if end == -1:
        # No trailing newline: slicing with -1 would cut off the last character.
        end = len(text)
    return text[start:end]

def get_all_answers(dev_dict, dev_index):
    """Collect the distinct reference answers for every question of one dev entry.

    Args:
        dev_dict: loaded dev-set JSON; ``dev_dict['data'][dev_index]`` must have
            an ``'answers'`` list and may have an ``'additional_answers'`` dict
            mapping annotator keys to parallel answer lists.
        dev_index: index of the entry inside ``dev_dict['data']``.

    Returns:
        One list per question holding the unique ``'input_text'`` values given
        by all annotators, primary answer first, in a reproducible order.
    """
    entry = dev_dict['data'][dev_index]
    # Any number of extra annotators is accepted (including none), not exactly three.
    additional = entry.get('additional_answers', {})
    annotations = [entry['answers']] + [additional[key] for key in sorted(additional)]
    per_annotator = [[item['input_text'] for item in answers] for answers in annotations]

    # dict.fromkeys removes duplicates while keeping first-seen order, so the
    # result does not vary between runs the way iteration over a set does.
    return [list(dict.fromkeys(answers[i] for answers in per_annotator))
            for i in range(len(per_annotator[0]))]

In [19]:
def generate_answer_number(model, text, number):
    """Generate the model's answer to question ``number`` (0-based) of ``text``.

    The dialogue is cut just before the wanted answer, up to 20 new tokens are
    generated greedily, and the answer is read back out of the decoded text.

    Args:
        model: generation-capable model exposing ``generate`` and ``device``.
        text: full dialogue text containing question and answer markers.
        number: 0-based index of the answer to generate.

    Returns:
        The generated answer line as returned by ``get_answer_number``.
    """
    prompt = get_text_up_to_question_number(text, number)
    tokens = tokenizer.encode(prompt, return_tensors='pt')
    _length = 20
    tokens_length = tokens.shape[1]
    output = model.generate(
        # The encoded prompt is created on the CPU; move it to the model's device.
        tokens.to(model.device),
        max_length=tokens_length + _length,
        # Greedy decoding. Temperature only applies when sampling and must be
        # strictly positive, so it is not passed here.
        do_sample=False,
        pad_token_id=50256
    )
    output = tokenizer.decode(output[0], skip_special_tokens=True)
    return get_answer_number(output, number)

In [20]:
def compute_accuracy_of_model_with_all_answers(model):
    """Score ``model`` on the first 10 dev dialogues against every reference answer.

    For each question the answer from ``generate_answer_number`` is compared
    with all reference answers from ``get_all_answers`` after removing '.' and
    '"' and ignoring case.  A question counts as correct when the prediction
    equals a reference, or when one non-empty string contains the other.

    Args:
        model: model handed to ``generate_answer_number``.

    Returns:
        ``(accuracy, wrong_predictions)``: the fraction of questions answered
        correctly (0.0 when no question was scored) and a list of
        ``{'label', 'prediction'}`` dicts, one per reference answer of every
        question that matched none of its references.
    """
    total_number_of_questions = 0
    correct_answers = 0
    wrong_predictions = []

    dlist = dev_list[:10]
    for index, text in tqdm(enumerate(dlist), total=len(dlist)):
        total_questions = get_answers_number(text)
        all_answers = get_all_answers(dev_dict, index)
        for number in range(total_questions):
            prediction = generate_answer_number(model, text, number)
            prediction = prediction.replace('.', '').replace('"', '')
            guess = prediction.lower()
            misses = []
            it_was_answered = False
            for label in all_answers[number]:
                label = label.replace('.', '').replace('"', '')
                reference = label.lower()
                # '' is a substring of every string, so containment only
                # counts when both sides are non-empty.
                contained = bool(guess) and bool(reference) and (
                    guess in reference or reference in guess)
                if guess == reference or contained:
                    correct_answers += 1
                    it_was_answered = True
                    break
                misses.append({'label': label, 'prediction': prediction})
            if not it_was_answered:
                # Log misses only for questions that matched no reference, so
                # correctly answered questions never show up as wrong.
                wrong_predictions.extend(misses)
                print('No Answer for number', number)
            total_number_of_questions += 1

    if total_number_of_questions == 0:
        return 0.0, wrong_predictions
    return correct_answers / total_number_of_questions, wrong_predictions

In [21]:
def get_text_from_data_item(item, max_num_questions=0, question_number=-1, last_question=True):
    """Render one data entry as a story followed by a Q/A discussion.

    Args:
        item: dict with a ``'story'`` string and parallel ``'questions'`` /
            ``'answers'`` lists whose elements carry an ``'input_text'``.
        max_num_questions: how many turns before ``question_number`` to include.
        question_number: index of the last turn to include.
        last_question: when False, the final line of the rendered text is
            dropped and the result ends with a newline.

    Returns:
        The rendered prompt text.
    """
    text = 'In the text below two people are discussing a story.\n\n'
    text += 'Story:\n' + item['story'] + '\n\n'
    text += 'Discussion:\n'

    # Window of dialogue turns ending at (and including) `question_number`.
    first_turn = max(0, question_number - max_num_questions)
    stop_turn = question_number + 1
    questions = item['questions'][first_turn:stop_turn]
    answers = item['answers'][first_turn:stop_turn]

    turns = []
    for question, answer in zip(questions, answers):
        turns.append('Q: ' + question['input_text'] + '\nA: ' + answer['input_text'])
    text += '\n'.join(turns)

    if last_question:
        return text
    # Drop everything after the last newline, then terminate with a newline.
    return text.rsplit('\n', 1)[0] + '\n'

In [22]:
def generate_answer(model, prompt, to_print=False):
    """Generate a continuation of ``prompt`` and return its first answer line.

    Args:
        model: generation-capable model exposing ``generate`` and ``device``.
        prompt: text to continue.
        to_print: when True, print the full decoded output.

    Returns:
        The first generated line after the prompt, reduced to the part after
        its last ':' and stripped of surrounding whitespace.  An empty string
        is returned when prompt plus generation budget would exceed 1024 tokens.
    """
    tokens = tokenizer.encode(prompt, return_tensors='pt')
    _length = 50
    tokens_length = tokens.shape[1]
    # Refuse prompts that leave no room for `_length` new tokens within 1024.
    if tokens_length + _length > 1024:
        return ''
    output = model.generate(
        tokens.to(model.device),
        max_length=tokens_length + _length,
        pad_token_id=50256
    )
    output = tokenizer.decode(output[0], skip_special_tokens=True)
    if to_print:
        print(output)
    # NOTE(review): assumes the decoded output starts with `prompt` verbatim — TODO confirm.
    # Skip the first generated character and read up to the next newline.
    start = len(prompt) + 1
    end = output.find('\n', start)
    if end == -1:
        # No newline was generated: slicing with -1 would cut off the last character.
        end = len(output)
    return output[start:end].split(':')[-1].strip()

In [23]:
# Start from a fresh pretrained GPT-2 on the GPU and load the weights saved
# after the first training epoch (checkpoint file 'save_small0').
model = GPT2LMHeadModel.from_pretrained('gpt2')
_ = model.cuda()
# NOTE(review): torch.load unpickles the file; only load checkpoints you produced yourself.
checkpoint = torch.load('save_small' + str(0))
model.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [24]:
def restore_causal_masks(model, original_masks):
    """Put a copy of the first saved mask back on transformer block 11.

    Args:
        model: model whose ``transformer.h[11].attn.bias`` is overwritten.
        original_masks: sequence whose first element is the mask to restore.

    Returns:
        The same ``model`` object, modified in place.
    """
    # Clone so that later changes to the model never touch the saved mask.
    restored_mask = original_masks[0].clone()
    last_block = model.transformer.h[11]
    last_block.attn.bias = restored_mask
    return model

In [25]:
# Put the saved original mask back on block 11 of the freshly loaded model before generating.
model = restore_causal_masks(model, original_causal_masks)

In [26]:
# Smoke test: build the prompt for the second question of the first dev entry
# (with the final answer line removed) and print the model's full continuation.
small_text = get_text_from_data_item(dev_dict['data'][0],
                                     max_num_questions=8,
                                     question_number=1,
                                     last_question=False)

generate_answer(model, small_text, to_print=True)

In the text below two people are discussing a story.

Story:
Once upon a time, in a barn near a farm house, there lived a little white kitten named Cotton. Cotton lived high up in a nice warm place above the barn where all of the farmer's horses slept. But Cotton wasn't alone in her little home above the barn, oh no. She shared her hay bed with her mommy and 5 other sisters. All of her sisters were cute and fluffy, like Cotton. But she was the only white one in the bunch. The rest of her sisters were all orange with beautiful white tiger stripes like Cotton's mommy. Being different made Cotton quite sad. She often wished she looked like the rest of her family. So one day, when Cotton found a can of the old farmer's orange paint, she used it to paint herself like them. When her mommy and sisters found her they started laughing. 

"What are you doing, Cotton?!" 

"I only wanted to be more like you". 

Cotton's mommy rubbed her face on Cotton's and said "Oh Cotton, but your fur is so pret

'a barn near a farm house'

In [27]:
# Inspect block 11 (the one whose mask was edited) by explicit index instead
# of relying on an `index` variable leaked by an earlier cell's loop.
model.transformer.h[11].attn.is_cross_attention

False

In [28]:
# Print the attention mask buffer of every block, from block 11 down to 0, to
# check what they currently hold.  `index` is left at 0 at module level.
for index in range(11, -1, -1):
    print(model.transformer.h[index].attn.bias)

tensor([[[[1, 0, 0,  ..., 0, 0, 0],
          [1, 1, 0,  ..., 0, 0, 0],
          [1, 1, 1,  ..., 0, 0, 0],
          ...,
          [1, 1, 1,  ..., 1, 0, 0],
          [1, 1, 1,  ..., 1, 1, 0],
          [1, 1, 1,  ..., 1, 1, 1]]]], device='cuda:0', dtype=torch.uint8)
tensor([[[[1, 0, 0,  ..., 0, 0, 0],
          [1, 1, 0,  ..., 0, 0, 0],
          [1, 1, 1,  ..., 0, 0, 0],
          ...,
          [1, 1, 1,  ..., 1, 0, 0],
          [1, 1, 1,  ..., 1, 1, 0],
          [1, 1, 1,  ..., 1, 1, 1]]]], device='cuda:0', dtype=torch.uint8)
tensor([[[[1, 0, 0,  ..., 0, 0, 0],
          [1, 1, 0,  ..., 0, 0, 0],
          [1, 1, 1,  ..., 0, 0, 0],
          ...,
          [1, 1, 1,  ..., 1, 0, 0],
          [1, 1, 1,  ..., 1, 1, 0],
          [1, 1, 1,  ..., 1, 1, 1]]]], device='cuda:0', dtype=torch.uint8)
tensor([[[[1, 0, 0,  ..., 0, 0, 0],
          [1, 1, 0,  ..., 0, 0, 0],
          [1, 1, 1,  ..., 0, 0, 0],
          ...,
          [1, 1, 1,  ..., 1, 0, 0],
          [1, 1, 1,  ..., 1, 1,

In [29]:
def generate_answer_with_typical_decoding(model, prompt):
    """Generate an answer line by picking, at each step, the most "typical" token.

    At every step the token whose log-probability is closest to the negative
    entropy of the next-token distribution is appended.  Generation stops
    after 50 new tokens or as soon as a newline token is produced.

    Args:
        model: model called on token ids, returning an object with ``.logits``,
            and exposing ``device``.
        prompt: text to continue.

    Returns:
        The first generated line with any 'A: ' removed and whitespace
        stripped, or an empty string when prompt plus generation budget would
        exceed 1024 tokens.
    """
    tokens = tokenizer.encode(prompt, return_tensors='pt')
    _length = 50
    tokens_length = tokens.shape[1]
    if tokens_length + _length > 1024:
        return ''

    # Inference only: do not build an autograd graph for every decoding step.
    with torch.no_grad():
        while tokens.shape[-1] < tokens_length + _length:
            # Only the distribution at the last position decides the next token.
            logits = model(tokens.to(model.device)).logits[:, -1, :]
            normalized = torch.nn.functional.log_softmax(logits, dim=-1)
            p = torch.exp(normalized)
            entropy = -(normalized * p).nansum(-1, keepdim=True)
            # Distance between each token's log-probability and the negative entropy.
            shifted_scores = torch.abs(normalized + entropy)
            last_token = torch.argmin(shifted_scores, dim=-1).cpu()
            tokens = torch.cat([tokens, last_token.unsqueeze(-1)], dim=-1)
            if tokenizer.decode(last_token, skip_special_tokens=True) == '\n':
                break

    generated_output = tokens[:, tokens_length:]
    output = tokenizer.decode(generated_output[0], skip_special_tokens=True)
    end = output.find('\n')
    if end == -1:
        # No newline was generated: slicing with -1 would cut off the last character.
        end = len(output)
    return output[:end].replace('A: ', '').strip()

In [30]:
def compute_accuracy_of_model(model):
    """Score ``model`` on the first 100 dev dialogues, one prompt per question.

    Each question gets its own prompt (built with ``get_text_from_data_item``
    using up to 8 previous turns) and is answered with ``generate_answer``.
    Predictions and references are compared after removing '.' and '"' and
    ignoring case; a question counts as correct when the prediction equals a
    reference, or when one non-empty string contains the other.  Questions for
    which ``generate_answer`` returns an empty string are skipped and not
    counted.

    Args:
        model: model handed to ``generate_answer``.

    Returns:
        ``(accuracy, wrong_predictions, false_positives)``:
        the fraction of scored questions answered correctly (0.0 when none was
        scored); a list of ``{'label', 'prediction'}`` dicts, one per reference
        of every question that matched none of its references; and the
        predictions other than 'unknown' that were compared against an
        'unknown' reference.
    """
    total_number_of_questions = 0
    correct_answers = 0
    wrong_predictions = []

    false_positives = []
    dlist = dev_list[:100]
    for index, text in tqdm(enumerate(dlist), total=len(dlist)):

        all_answers = get_all_answers(dev_dict, index)
        total_questions = len(all_answers)

        for number in range(total_questions):
            small_text = get_text_from_data_item(dev_dict['data'][index],
                                                 max_num_questions=8,
                                                 question_number=number,
                                                 last_question=False)
            prediction = generate_answer(model, small_text)
            if not prediction:
                # Nothing was generated (e.g. the prompt was too long): skip without scoring.
                continue
            prediction = prediction.replace('.', '').replace('"', '')
            guess = prediction.lower()
            misses = []
            it_was_answered = False
            for label in all_answers[number]:
                label = label.replace('.', '').replace('"', '')
                reference = label.lower()

                if guess != 'unknown' and reference == 'unknown':
                    false_positives.append(prediction)

                # '' is a substring of every string, so containment only
                # counts when both sides are non-empty.
                contained = bool(guess) and bool(reference) and (
                    guess in reference or reference in guess)
                if guess == reference or contained:
                    correct_answers += 1
                    it_was_answered = True
                    break
                misses.append({'label': label, 'prediction': prediction})
            if not it_was_answered:
                # Log misses only for questions that matched no reference, so
                # correctly answered questions never show up as wrong.
                wrong_predictions.extend(misses)
            total_number_of_questions += 1

    if total_number_of_questions == 0:
        return 0.0, wrong_predictions, false_positives
    return correct_answers / total_number_of_questions, wrong_predictions, false_positives

In [31]:
# Fresh pretrained GPT-2 on the GPU; the evaluation loop loads each checkpoint's weights into it.
model = GPT2LMHeadModel.from_pretrained('gpt2')
_ = model.cuda()

In [32]:
# Overwritten on every iteration below; after the loop they hold the results
# of the last checkpoint evaluated.
false_positives = None
wrong_answers = None

# Evaluate the checkpoint saved after each of the 8 training epochs.
for index in range(8):
    checkpoint = torch.load('save_small' + str(index))
    model.load_state_dict(checkpoint['model_state_dict'])
    # Put the saved original mask back on block 11 before generating.
    restore_causal_masks(model, original_causal_masks)
    accuracy, wrong_answers, false_positives = compute_accuracy_of_model(model)
    print('Epoch:', index, ' with accuracy:', accuracy)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [09:18<00:00,  5.58s/it]


Epoch: 0  with accuracy: 0.22831632653061223


100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [09:12<00:00,  5.52s/it]


Epoch: 1  with accuracy: 0.26305732484076433


100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [09:05<00:00,  5.46s/it]


Epoch: 2  with accuracy: 0.30935709739019734


100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [09:16<00:00,  5.57s/it]


Epoch: 3  with accuracy: 0.3382165605095541


100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [09:26<00:00,  5.67s/it]


Epoch: 4  with accuracy: 0.3233226837060703


100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [09:13<00:00,  5.54s/it]


Epoch: 5  with accuracy: 0.3485530546623794


100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [09:20<00:00,  5.61s/it]


Epoch: 6  with accuracy: 0.3512476007677543


100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [09:09<00:00,  5.50s/it]

Epoch: 7  with accuracy: 0.36293929712460066





In [33]:
# Mismatches recorded for the last checkpoint evaluated above.
# NOTE(review): this displays the whole list; consider slicing it to keep the output short.
wrong_answers

[{'label': 'in a barn',
  'prediction': "a nice warm place above the barn where all of the farmer's horses slept"},
 {'label': 'in a barn near',
  'prediction': "a nice warm place above the barn where all of the farmer's horses slept"},
 {'label': 'orange', 'prediction': 'pink'},
 {'label': 'orange and white', 'prediction': 'pink'},
 {'label': 'orange with white tiger stripes', 'prediction': 'pink'},
 {'label': 'she painted herself',
  'prediction': 'changed her mind I like being special'},
 {'label': 'paint herself like them',
  'prediction': 'changed her mind I like being special'},
 {'label': 'the farmer', 'prediction': 'white and white'},
 {'label': "the farmer's", 'prediction': 'white and white'},
 {'label': "the old farmer's", 'prediction': 'white and white'},
 {'label': 'they started laughing',
  'prediction': 'she painted herself herself herself herself herself herself herself herself herself herself herself herself herself herself herself herself herself herself herself hersel