In [1]:
import json
import torch
import torch.nn as nn
import random
from tqdm import tqdm
from transformers import GPT2LMHeadModel, GPT2Tokenizer

In [2]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

In [117]:
dev_dict = json.load(open('../data/coqa-dev-v1.0.json', encoding='utf8'))

In [119]:
train_list = json.load(open('../data/qa_train_list.json', encoding='utf8'))

JSONDecodeError: Expecting value: line 1 column 1 (char 0)

In [120]:
dev_list = json.load(open('../data/qa_dev_list.json', encoding='utf8'))

In [3]:
def chunks(lst, n):
    for i in range(0, len(lst), n):
        yield lst[i:i + n]

def batchify(data, n):
    len_dict = {}
    for item in data:
        length = item.shape[1]
        try:
            len_dict[length].append(item)
        except:
            len_dict[length] = [item]

    batch_chunks = []
    for k in len_dict.keys():
        vectors = len_dict[k]
        batch_chunks += chunks(vectors, n)

    batches = []
    for chunk in batch_chunks:
        inputs = torch.stack([item[0] for item in chunk])
        batches.append((inputs))

    return batches

In [None]:
_limit = 1024
data = []
total_trimmed = 0
for item in train_list:
    tokens = tokenizer.encode(item, return_tensors='pt')
    if tokens.shape[1] > _limit:
        tokens = tokens[:, :_limit]
        total_trimmed += 1
    data.append(tokens)
print(f'Trimmed {total_trimmed} out of {len(train_list)}')

In [7]:
train_batches = batchify(data, 1)

In [4]:
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.cuda()
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)

In [6]:
def train(train_model, batches, optimizer, criterion):
    total_loss = 0.
    for i, batch in tqdm(enumerate(batches), total=len(batches)):
        model.train()
        inputs = batch
        optimizer.zero_grad()
        loss = train_model(inputs.cuda(), labels=inputs.cuda())[0]
        loss.backward()
        torch.nn.utils.clip_grad_norm_(train_model.parameters(), 0.5)
        optimizer.step()
        total_loss += loss.item()

    return total_loss / len(batches)

def test(test_model, batches):
    total_loss = 0.
    for i, batch in tqdm(enumerate(batches), total=len(batches)):
        test_model.eval()
        inputs = batch
        loss = test_model(inputs, labels=inputs)[0]
        total_loss += loss.item()

    return total_loss / len(batches)

In [10]:
from torch.optim.lr_scheduler import StepLR

random.shuffle(train_batches)
scheduler = StepLR(optimizer, step_size=2, gamma=0.8)
for epoch in range(10):
    random.shuffle(train_batches)
    loss = train(model, train_batches, optimizer, criterion)
    print('Epoch:', epoch, 'Loss:', loss)
    torch.save({'epoch': epoch,
                'model_state_dict': model.state_dict()},
                'save_small' + str(epoch))
    scheduler.step()

100%|██████████| 7199/7199 [13:32<00:00,  8.86it/s]


Epoch: 0 Loss: 2.371021380919949


  0%|          | 21/7199 [00:02<14:00,  8.54it/s]


KeyboardInterrupt: 

In [5]:
_question_prompt = '\nQ: '
_answer_prompt = '\nA: '
    
def get_text_up_to_question_number(text, number):
    pos = text.find(_answer_prompt)
    for _ in range(number):
        pos = text.find(_answer_prompt, pos + 1)
    return text[0:pos + 1]
    
def get_answers_number(text):
    return text.count(_answer_prompt)

def get_answer_number(text, number):
    pos = text.find(_answer_prompt)
    for _ in range(number):
        pos = text.find(_answer_prompt, pos + 1)
    end = text.find('\n', pos + len(_answer_prompt))
    return text[pos + len(_answer_prompt):end]

def get_question_number(text, number):
    pos = text.find(_question_prompt)
    for _ in range(number):
        pos = text.find(_question_prompt, pos + 1)
    end = text.find('\n', pos + len(_question_prompt))
    return text[pos + len(_question_prompt):end]

def get_all_answers(dev_dict, dev_index):
    answers = [[item['input_text'] for item in dev_dict['data'][dev_index]['answers']]]
    answers += [[item['input_text'] for item in dev_dict['data'][dev_index]['additional_answers'][str(index)]] for index in range(3)]
    return [list(set([answers[j][i] for j in range(len(answers))])) for i in range(len(answers[0]))]

In [6]:
def generate_answer_number(model, text, number):
    prompt = get_text_up_to_question_number(text, number)
    tokens = tokenizer.encode(prompt, return_tensors='pt')
    _length = 20
    tokens_length = tokens.shape[1]
    output = model.generate(
             tokens,
             max_length=tokens_length + _length,
             temperature=0,
             pad_token_id=50256
    )
    output = tokenizer.decode(output[0], skip_special_tokens=True)
    return get_answer_number(output, number)

In [7]:
def compute_accuracy_of_model_with_all_answers(model):
    total_number_of_questions = 0
    correct_answers = 0
    wrong_predictions = []

    dlist = dev_list[:10]
    for index, text in tqdm(enumerate(dlist), total=len(dlist)):
        total_questions = get_answers_number(text)
        all_answers = get_all_answers(dev_dict, index)
        for number in range(total_questions):
            prediction = generate_answer_number(model, text, number)
            prediction = prediction.replace('.', '').replace('"', '')
            it_was_answered = False
            for label in all_answers[number]:
                label = label.replace('.', '').replace('"', '')
                if prediction.lower() == label.lower():
                    correct_answers += 1
                    it_was_answered = True
                    break
                elif prediction.lower() in label.lower():
                    correct_answers += 1
                    it_was_answered = True
                    break
                elif label.lower() in prediction.lower():
                    correct_answers += 1
                    it_was_answered = True
                    break
                else:
                    wrong_predictions.append({'label': label, 'prediction': prediction})
            if not it_was_answered:
                print('No Answer for number', number)
            total_number_of_questions += 1

    return correct_answers/total_number_of_questions, wrong_predictions

In [8]:
def get_text_from_data_item(item, max_num_questions=0, question_number=-1, last_question=True):
    text = 'In the text below two people are discussing a story.\n\n'
    text += 'Story:\n' + item['story'] + '\n\n'
    text += 'Discussion:\n'
    text += '\n'.join(['Q: ' + q['input_text'] 
                       + '\nA: ' + a['input_text'] 
                       for q, a in zip(item['questions'][max(0,question_number-max_num_questions):question_number+1], 
                                       item['answers'][max(0,question_number-max_num_questions):question_number+1]) 
                      ])
    if not last_question:
        text = '\n'.join(text.split('\n')[:-1]) + '\n'
    return text

In [9]:
def generate_answer(model, prompt):
    tokens = tokenizer.encode(prompt, return_tensors='pt')
    _length = 50
    tokens_length = tokens.shape[1]
    if tokens_length + _length > 1024:
        return ''
    output = model.generate(
             tokens.cuda(),
             max_length=tokens_length + _length,
             #temperature=0,
             pad_token_id=50256
    )
    output = tokenizer.decode(output[0], skip_special_tokens=True)
    offset = len(prompt)
    start = offset + 1
    end = output.find('\n', start)
    return output[start:end].split(':')[-1].strip()

In [10]:
def generate_answer_with_typical_decoding(model, prompt):
    tokens = tokenizer.encode(prompt, return_tensors='pt')
    _length = 50
    tokens_length = tokens.shape[1]
    if tokens_length + _length > 1024:
        return ''
    generated_entropy = 0
    while tokens.shape[-1] < tokens_length + _length:
        new_tokens = model(tokens.cuda())
        normalized = torch.nn.functional.log_softmax(new_tokens.logits, dim=-1)
        p = torch.exp(normalized)
        entropy = -(normalized * p).nansum(-1, keepdim=True)
        shifted_scores = torch.abs(normalized + entropy)
        pred_ids = torch.argmin(shifted_scores, dim=-1)
        last_token = pred_ids[:, -1].cpu().detach()
        tokens = torch.cat([tokens, torch.tensor([[last_token]])], dim=-1)
        last_output = tokenizer.decode(last_token, skip_special_tokens=True)
        if last_output == '\n':
            break
        
    generated_output = tokens[:, tokens_length:]
    output = tokenizer.decode(generated_output[0], skip_special_tokens=True)
    end = output.find('\n')
    return output[:end].replace('A: ', '').strip()

In [11]:
def compute_accuracy_of_model(model):
    total_number_of_questions = 0
    correct_answers = 0
    wrong_predictions = []

    false_positives = []
    dlist = dev_list[:100]
    for index, text in tqdm(enumerate(dlist), total=len(dlist)):

        all_answers = get_all_answers(dev_dict, index)
        total_questions = len(all_answers)        
        
        for number in range(total_questions):
            small_text = get_text_from_data_item(dev_dict['data'][index], 
                                                 max_num_questions=8,
                                                 question_number=number,
                                                 last_question=False)
            prediction = generate_answer(model, small_text)
            if not prediction:
                print('NO PREDICTION!!')
                continue
            prediction = prediction.replace('.', '').replace('"', '')
            it_was_answered = False
            for label in all_answers[number]:
                label = label.replace('.', '').replace('"', '')

                if prediction.lower() != 'unknown' and label.lower() == 'unknown':
                    false_positives.append(prediction)
                
                if prediction.lower() == label.lower():
                    correct_answers += 1
                    it_was_answered = True
                    break
                elif prediction.lower() in label.lower():
                    correct_answers += 1
                    it_was_answered = True
                    break
                elif label.lower() in prediction.lower():
                    correct_answers += 1
                    it_was_answered = True
                    break
                else:
                    wrong_predictions.append({'label': label, 'prediction': prediction})
            total_number_of_questions += 1

    return correct_answers/total_number_of_questions, wrong_predictions, false_positives

In [136]:
false_positives = None
wrong_answers = None

for index in range(6,7):
    checkpoint = torch.load('save_small' + str(index))
    model.load_state_dict(checkpoint['model_state_dict'])
    accuracy, wrong_answers, false_positives = compute_accuracy_of_model(model.cuda())
    print('Epoch:', index, ' with accuracy:', accuracy)

 27%|██████████████████████████▍                                                                       | 27/100 [02:31<08:10,  6.72s/it]

NO PREDICTION!!


100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [09:21<00:00,  5.61s/it]

Epoch: 6  with accuracy: 0.702275600505689





In [137]:
wrong_answers

[{'label': 'paint herself like them', 'prediction': 'she painted herself'},
 {'label': 'the farmer', 'prediction': "her mommy's"},
 {'label': "the old farmer's", 'prediction': "her mommy's"},
 {'label': "the farmer's", 'prediction': "her mommy's"},
 {'label': 'they started laughing', 'prediction': 'laughed'},
 {'label': 'rubbed her face', 'prediction': 'laughed'},
 {'label': 'started laughing', 'prediction': 'laughed'},
 {'label': 'dropped her into a big bucket of water',
  'prediction': 'a bucket of water'},
 {'label': 'a bottle', 'prediction': 'It was hard and clear'},
 {'label': 'the bottle', 'prediction': 'It was hard and clear'},
 {'label': 'no', 'prediction': 'Yes'},
 {'label': 'No', 'prediction': 'Yes'},
 {'label': "They took the note to Asta's papa",
  'prediction': "took it to Asta's papa"},
 {'label': "They took the note to Asta's papa",
  'prediction': "took it to Asta's papa"},
 {'label': 'unknown', 'prediction': "took it to Asta's papa"},
 {'label': 'unknown', 'prediction'

### Using Typical decoding:
accuracy: 0.56

### max questions 3
Epoch: 6  with accuracy: 0.6955696202531646

### max questions 4
Epoch: 6  with accuracy: 0.7010113780025284

### max questions 5
Epoch: 6  with accuracy: 0.7010113780025284

### max questions 6
Epoch: 6  with accuracy: 0.7039848197343453

### max questions 7
Epoch: 6  with accuracy: 0.706700379266751

### max questions 8
Epoch: 6  with accuracy: 0.702275600505689




In [12]:
checkpoint = torch.load('save_small' + str(6))
model.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [13]:
import numpy as np

max_probs = 5

def generate_answer_and_get_confidence(model, prompt):
    tokens = tokenizer.encode(prompt, return_tensors='pt')
    _length = 50
    tokens_length = tokens.shape[1]
    if tokens_length + _length > 1024:
        return ''
    generated_entropy = 0
    while tokens.shape[-1] < tokens_length + _length:
        new_token = model(tokens.cuda())
        probs = torch.softmax(new_token.logits[:, -1, :], dim=-1)
        probs_and_indices = [(p, index) for index, p in enumerate(probs[0].cpu().detach())]
        probs_and_indices = sorted(probs_and_indices, key=lambda x: -x[0])
        probs = [item[0] for item in probs_and_indices[:max_probs]]
        generated_entropy -= np.dot(probs, np.log(probs))
        tokens = torch.cat([tokens, torch.tensor([[torch.argmax(new_token.logits[:, -1, :])]])], dim=-1)
        last_token = tokens[:, -1]
        last_output = tokenizer.decode(last_token, skip_special_tokens=True)
        if last_output == '\n':
            break
        
    print(tokens.shape)
    print(tokens_length)
    generated_output = tokens[:, tokens_length:]
    print(generated_output.shape)
    output = tokenizer.decode(generated_output[0], skip_special_tokens=True)
    end = output.find('\n')
    return output[:end].strip(), generated_entropy

In [14]:
prompt = """
In the text below two people are discussing a story.

Story:
The Wag says: "My ship is called George".

Discussion:
Q: What is the Wag's ship called?
A: 
""".strip()

generate_answer_with_typical_decoding(model, prompt)

'George'

In [26]:
prompt = """
In the text below two people are discussing a story.

Story:
The user says: "My ship is called George".

Discussion:
Q: Who is speaking?
A: 
""".strip()

generate_answer_with_typical_decoding(model, prompt)

CPU times: user 25.3 ms, sys: 13.7 ms, total: 39 ms
Wall time: 38.3 ms


'George.'

In [53]:
prompt = """
In the text below two people are discussing a story.

Story:
The Wag speaks: "My ship is called George".

Discussion:
Q: Who is speaking?
A: 
""".strip()

generate_answer_with_typical_decoding(model, prompt)

'The Waggard.'

In [54]:
%%time
generate_answer_and_get_confidence(model, prompt)

torch.Size([1, 44])
41
torch.Size([1, 3])
CPU times: user 4.47 s, sys: 0 ns, total: 4.47 s
Wall time: 4.47 s


('The Wag', 1.748390108346939)

In [55]:
def generate_answer_greedy(model, prompt, max_length= 50):
    tokens = tokenizer.encode(prompt, return_tensors='pt')
    tokens_length = tokens.shape[1]
    if tokens_length + max_length > 1024:
        return ''
    
    while tokens.shape[-1] < tokens_length + max_length:
        new_tokens = model(tokens.cuda())
        pred_ids = torch.argmax(new_tokens.logits, dim=-1)
        last_token = pred_ids[:, -1].cpu().detach()
        tokens = torch.cat([tokens, torch.tensor([[last_token]])], dim=-1)
        last_output = tokenizer.decode(last_token, skip_special_tokens=True)
        if last_output == '\n':
            break
        
    generated_output = tokens[:, tokens_length:]
    output = tokenizer.decode(generated_output[0], skip_special_tokens=True)
    end = output.find('\n')
    return output[:end].replace('A: ', '').strip()

In [56]:
%%time
generate_answer_greedy(model, prompt)

CPU times: user 22.2 ms, sys: 0 ns, total: 22.2 ms
Wall time: 22 ms


'The Wag'

#### Compute prob of specific sequences

In [66]:
def get_sequence_probability_given_prompt(model, prompt, sequence):
    tokens = tokenizer.encode(prompt, return_tensors='pt')
    sequence_tokens = tokenizer.encode(sequence, return_tensors='pt')[0]
    token_index = 0
    total_prob = 1
    while token_index < len(sequence_tokens):
        new_tokens = model(tokens.cuda())
        last_distribution = new_tokens.logits[:, -1].cpu().detach()
        probs = torch.nn.functional.softmax(last_distribution)
        total_prob *= probs[0][sequence_tokens[token_index]]
        token_index += 1
        
        pred_ids = torch.argmax(new_tokens.logits, dim=-1)
        last_token = pred_ids[:, -1].cpu().detach()
        tokens = torch.cat([tokens, torch.tensor([[last_token]])], dim=-1)

            
    return total_prob

In [75]:
prompt = """
In the text below two people are discussing a story.

Story:
The user speaks: "My ship is called George".

Discussion:
Q: Who is the user speaking to?
A: 
""".strip()

In [77]:
get_sequence_probability_given_prompt(model, prompt, "user")

  probs = torch.nn.functional.softmax(last_distribution)


tensor(6.8406e-08)

In [78]:
generate_answer_greedy(model, prompt)

'George'