In [1]:
import json
import torch
import torch.nn as nn
import random
from tqdm import tqdm
from transformers import GPT2LMHeadModel, GPT2Tokenizer

In [2]:
# Load the pretrained 'gpt2' tokenizer; later cells read this module-level name directly.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

In [3]:
# Read coqa-dev-v1.0.json into dev_dict. The context manager closes the file
# handle, which the previous bare open() call left open.
with open('../data/coqa-dev-v1.0.json', encoding='utf8') as dev_file:
    dev_dict = json.load(dev_file)

In [4]:
# Read the train/dev text lists. Each file is opened in a context manager so the
# handle is closed as soon as the JSON has been parsed (the bare open() calls
# used before never closed them).
with open('../data/qa_train_list.json', encoding='utf8') as train_file:
    train_list = json.load(train_file)
with open('../data/qa_dev_list.json', encoding='utf8') as dev_list_file:
    dev_list = json.load(dev_list_file)

In [5]:
def chunks(lst, n):
    """Lazily yield consecutive slices of ``lst`` that hold at most ``n`` items.

    The last slice is shorter when ``len(lst)`` is not a multiple of ``n``.
    """
    offsets = range(0, len(lst), n)
    yield from (lst[offset:offset + n] for offset in offsets)

def batchify(data, n):
    """Group tensors by length and stack them into batches of at most ``n`` rows.

    Args:
        data: iterable of tensors. ``item.shape[1]`` is used as the grouping key
            and ``item[0]`` is the row that gets stacked, so each item is
            presumably shaped ``(1, seq_len)`` -- TODO confirm against callers.
        n: maximum number of rows per batch.

    Returns:
        A list of stacked tensors, one per batch. A batch only ever contains
        rows that share the same ``shape[1]``, so no padding is required.
    """
    # Bucket the items by length. ``setdefault`` replaces the former bare
    # ``except:`` clause, which could also hide unrelated errors.
    len_dict = {}
    for item in data:
        len_dict.setdefault(item.shape[1], []).append(item)

    batches = []
    for vectors in len_dict.values():
        for chunk in chunks(vectors, n):
            batches.append(torch.stack([item[0] for item in chunk]))

    return batches

In [6]:
# Maximum number of tokens kept per example; the tokenizer warnings printed by
# this cell report 1024 as the model's maximum sequence length.
_limit = 1024
data = []
# Nothing is ever dropped by this loop, so this counter stays at 0.
total_skipped = 0
# Over-long examples are cut to _limit tokens instead; count them so the
# truncation is no longer silent.
total_truncated = 0
for item in train_list:
    tokens = tokenizer.encode(item, return_tensors='pt')
    if tokens.shape[1] > _limit:
        tokens = tokens[:, :_limit]
        total_truncated += 1
    data.append(tokens)
print(f'Skipped {total_skipped} out of {len(train_list)}')
print(f'Truncated {total_truncated} out of {len(train_list)} to {_limit} tokens')

Token indices sequence length is longer than the specified maximum sequence length for this model (1088 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1306 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1140 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1037 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1363 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence leng

Token indices sequence length is longer than the specified maximum sequence length for this model (1148 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1108 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1112 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1155 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1040 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence leng

Skipped 0 out of 7199


In [7]:
# Batch size 1: every encoded example ends up in a batch of its own.
train_batches = batchify(data, 1)

In [17]:
# Start from the pretrained 'gpt2' weights and move the model to the GPU.
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.cuda()
# NOTE(review): train() accepts `criterion` but never calls it -- the loss is taken
# from the model's own output (labels=inputs) -- so this object appears to be unused.
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)

Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.0.attn.masked_bias', 'h.1.attn.masked_bias', 'h.2.attn.masked_bias', 'h.3.attn.masked_bias', 'h.4.attn.masked_bias', 'h.5.attn.masked_bias', 'h.6.attn.masked_bias', 'h.7.attn.masked_bias', 'h.8.attn.masked_bias', 'h.9.attn.masked_bias', 'h.10.attn.masked_bias', 'h.11.attn.masked_bias', 'lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
def train(train_model, batches, optimizer, criterion):
    """Run one optimisation pass over ``batches`` and return the mean loss.

    Args:
        train_model: model called as ``train_model(inputs, labels=inputs)``; the
            first element of its return value is used as the loss.
        batches: sequence of input tensors (``len()`` is taken on it).
        optimizer: optimizer that updates ``train_model``'s parameters.
        criterion: unused; kept so existing call sites keep working.

    Returns:
        The summed batch losses divided by ``len(batches)``.

    Raises:
        ZeroDivisionError: if ``batches`` is empty.
    """
    # Put the model that was passed in into training mode. The previous code
    # called ``model.train()`` on the notebook-level global instead.
    train_model.train()
    # Follow the model's own device rather than hard-coding CUDA, so a model
    # that was never moved to the GPU can be trained as well.
    device = next(train_model.parameters()).device
    total_loss = 0.
    for batch in tqdm(batches, total=len(batches)):
        # One transfer per batch; the same tensor serves as input and labels.
        inputs = batch.to(device)
        optimizer.zero_grad()
        loss = train_model(inputs, labels=inputs)[0]
        loss.backward()
        # Clip the global gradient norm to 0.5 before the update.
        torch.nn.utils.clip_grad_norm_(train_model.parameters(), 0.5)
        optimizer.step()
        total_loss += loss.item()

    return total_loss / len(batches)

def test(test_model, batches):
    total_loss = 0.
    for i, batch in tqdm(enumerate(batches), total=len(batches)):
        test_model.eval()
        inputs = batch
        loss = test_model(inputs, labels=inputs)[0]
        total_loss += loss.item()

    return total_loss / len(batches)

In [10]:
from torch.optim.lr_scheduler import StepLR

# NOTE(review): the batches are shuffled again at the top of every epoch, so this
# extra shuffle only changes how much random state has been consumed.
random.shuffle(train_batches)
# Multiply the learning rate by 0.8 every 2 scheduler steps (one step per epoch below).
scheduler = StepLR(optimizer, step_size=2, gamma=0.8)
for epoch in range(10):
    random.shuffle(train_batches)
    loss = train(model, train_batches, optimizer, criterion)
    print('Epoch:', epoch, 'Loss:', loss)
    # Checkpoint after every epoch; the file name is 'save_small' followed by the epoch index.
    torch.save({'epoch': epoch,
                'model_state_dict': model.state_dict()},
                'save_small' + str(epoch))
    scheduler.step()

100%|██████████| 7199/7199 [13:58<00:00,  8.59it/s]


Epoch: 0 Loss: 2.3709521861784424


100%|██████████| 7199/7199 [13:51<00:00,  8.66it/s]


Epoch: 1 Loss: 2.228626253257478


100%|██████████| 7199/7199 [13:56<00:00,  8.61it/s]


Epoch: 2 Loss: 2.139353451670533


100%|██████████| 7199/7199 [14:02<00:00,  8.54it/s]


Epoch: 3 Loss: 2.0787513495104264


100%|██████████| 7199/7199 [14:05<00:00,  8.51it/s]


Epoch: 4 Loss: 2.0206317302998213


100%|██████████| 7199/7199 [14:05<00:00,  8.52it/s]


Epoch: 5 Loss: 1.9785397347766072


100%|██████████| 7199/7199 [14:48<00:00,  8.10it/s]


Epoch: 6 Loss: 1.935774866545891


100%|██████████| 7199/7199 [14:48<00:00,  8.10it/s]


Epoch: 7 Loss: 1.9040290086380987


100%|██████████| 7199/7199 [14:05<00:00,  8.52it/s]


Epoch: 8 Loss: 1.870218997902714


100%|██████████| 7199/7199 [13:45<00:00,  8.72it/s]


Epoch: 9 Loss: 1.846056071316671


In [16]:
# Second run: rebuild the model, restore the weights saved as 'save_small6' and
# keep training with a much smaller learning rate (1e-8).
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.cuda()
checkpoint = torch.load('save_small' + str(6))
model.load_state_dict(checkpoint['model_state_dict'])
# NOTE(review): train() never uses `criterion`; see its body.
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-8)
random.shuffle(train_batches)
# StepLR is imported by the earlier training cell, which must have run first.
scheduler = StepLR(optimizer, step_size=2, gamma=0.8)

for epoch in range(10):
    random.shuffle(train_batches)
    loss = train(model, train_batches, optimizer, criterion)
    scheduler.step()
    print('Epoch:', epoch, 'Loss:', loss)
    # Checkpoints of this run are written as 'save_superfine' followed by the epoch index.
    torch.save({'epoch': epoch,
                'model_state_dict': model.state_dict()},
                'save_superfine' + str(epoch))

KeyboardInterrupt: 

In [7]:
# Markers (including the leading newline) that the text helpers below search for
# to locate each question line and each answer line.
_question_prompt = '\nQ: '
_answer_prompt = '\nA: '
    
def get_text_up_to_question_number(text, number):
    """Return ``text`` cut right before the answer with zero-based index ``number``.

    The result keeps the newline that starts the ``_answer_prompt`` marker and
    drops everything after it, i.e. it ends just where that answer would begin.

    Args:
        text: dialogue text containing ``_answer_prompt`` markers.
        number: zero-based index of the answer to stop at.

    Returns:
        The prefix of ``text``, or ``''`` when ``text`` has no answer with that
        index.
    """
    pos = text.find(_answer_prompt)
    for _ in range(number):
        if pos == -1:
            # Stop here: searching again from pos + 1 == 0 would wrap around
            # and wrongly land on the first answer.
            break
        pos = text.find(_answer_prompt, pos + 1)
    if pos == -1:
        return ''
    return text[0:pos + 1]
    
def get_answers_number(text):
    """Count how many times the ``_answer_prompt`` marker occurs in ``text``."""
    answer_count = text.count(_answer_prompt)
    return answer_count

def get_answer_number(text, number):
    """Return the answer with zero-based index ``number`` from ``text``.

    The answer is the text between the matching ``_answer_prompt`` marker and
    the next newline, or the end of ``text`` when no newline follows.

    Args:
        text: dialogue text containing ``_answer_prompt`` markers.
        number: zero-based index of the answer to extract.

    Returns:
        The answer string, or ``''`` when ``text`` has no answer with that index.
    """
    pos = text.find(_answer_prompt)
    for _ in range(number):
        if pos == -1:
            # Searching again from pos + 1 == 0 would wrap around to the first answer.
            break
        pos = text.find(_answer_prompt, pos + 1)
    if pos == -1:
        return ''
    start = pos + len(_answer_prompt)
    end = text.find('\n', start)
    if end == -1:
        # Last line of the text: slicing with -1 would drop its final character.
        end = len(text)
    return text[start:end]

def get_question_number(text, number):
    """Return the question with zero-based index ``number`` from ``text``.

    The question is the text between the matching ``_question_prompt`` marker
    and the next newline, or the end of ``text`` when no newline follows.

    Args:
        text: dialogue text containing ``_question_prompt`` markers.
        number: zero-based index of the question to extract.

    Returns:
        The question string, or ``''`` when ``text`` has no question with that
        index.
    """
    pos = text.find(_question_prompt)
    for _ in range(number):
        if pos == -1:
            # Searching again from pos + 1 == 0 would wrap around to the first question.
            break
        pos = text.find(_question_prompt, pos + 1)
    if pos == -1:
        return ''
    start = pos + len(_question_prompt)
    end = text.find('\n', start)
    if end == -1:
        # Last line of the text: slicing with -1 would drop its final character.
        end = len(text)
    return text[start:end]

def get_all_answers(dev_dict, dev_index):
    """Collect the distinct reference answers for every turn of one dev item.

    Reads ``dev_dict['data'][dev_index]`` and combines its ``'answers'`` list
    with the three lists stored under ``'additional_answers'`` keys ``'0'``,
    ``'1'`` and ``'2'``.

    Returns:
        One list per turn holding the de-duplicated ``'input_text'`` values for
        that turn. The order inside each list is whatever the set yields.
    """
    entry = dev_dict['data'][dev_index]
    per_annotator = [[turn['input_text'] for turn in entry['answers']]]
    for annotator in range(3):
        extra_turns = entry['additional_answers'][str(annotator)]
        per_annotator.append([turn['input_text'] for turn in extra_turns])

    merged = []
    for turn_index in range(len(per_annotator[0])):
        variants = {answers[turn_index] for answers in per_annotator}
        merged.append(list(variants))
    return merged

In [8]:
def generate_answer_number(model, text, number):
    """Generate the model's answer for the question with index ``number``.

    The prompt is ``text`` cut right before that answer; up to 20 new tokens are
    generated and the answer with the same index is read back from the decoded
    output.

    Args:
        model: model exposing ``parameters()`` and ``generate(...)``.
        text: full dialogue text.
        number: zero-based index of the answer to generate.

    Returns:
        Whatever ``get_answer_number`` extracts from the decoded output.
    """
    prompt = get_text_up_to_question_number(text, number)
    tokens = tokenizer.encode(prompt, return_tensors='pt')
    _length = 20
    tokens_length = tokens.shape[1]
    # Send the prompt to the device the model lives on; it used to stay on the
    # CPU even when the model had been moved to the GPU.
    device = next(model.parameters()).device
    # ``temperature=0`` was dropped: it is not a valid (strictly positive)
    # temperature, and generate_answer() already runs without it.
    output = model.generate(
             tokens.to(device),
             max_length=tokens_length + _length,
             pad_token_id=50256
    )
    output = tokenizer.decode(output[0], skip_special_tokens=True)
    return get_answer_number(output, number)

In [9]:
def compute_accuracy_of_model_with_all_answers(model, limit=10):
    """Score ``model`` on the first ``limit`` dev dialogues using the full text as prompt.

    A prediction counts as correct when, after removing ``.`` and ``"`` and
    lower-casing, it contains or is contained in at least one reference answer.

    Args:
        model: model handed to ``generate_answer_number``.
        limit: number of entries of ``dev_list`` to evaluate (default 10, the
            value that used to be hard-coded); ``None`` evaluates all of them.

    Returns:
        ``(accuracy, wrong_predictions)``. ``wrong_predictions`` gets one
        ``{'label', 'prediction'}`` entry for every reference answer that was
        compared and did not match. Accuracy is ``0.0`` when no question was
        evaluated.
    """
    total_number_of_questions = 0
    correct_answers = 0
    wrong_predictions = []

    dlist = dev_list[:limit]
    for index, text in tqdm(enumerate(dlist), total=len(dlist)):
        total_questions = get_answers_number(text)
        all_answers = get_all_answers(dev_dict, index)
        for number in range(total_questions):
            prediction = generate_answer_number(model, text, number)
            prediction = prediction.replace('.', '').replace('"', '')
            guess = prediction.lower()
            it_was_answered = False
            for label in all_answers[number]:
                label = label.replace('.', '').replace('"', '')
                gold = label.lower()
                # Equality is a special case of containment, so one test covers
                # the three identical branches used before. Empty strings are
                # excluded because '' is a substring of everything, which made
                # an empty prediction count as correct.
                if guess and gold and (guess in gold or gold in guess):
                    correct_answers += 1
                    it_was_answered = True
                    break
                wrong_predictions.append({'label': label, 'prediction': prediction})
            if not it_was_answered:
                print('No Answer for number', number)
            total_number_of_questions += 1

    if total_number_of_questions == 0:
        # Nothing was evaluated; avoid dividing by zero.
        return 0.0, wrong_predictions
    return correct_answers / total_number_of_questions, wrong_predictions

In [10]:
def get_text_from_data_item(item, max_num_questions=0, question_number=-1, last_question=True):
    """Render a data item as the story followed by a window of its Q/A turns.

    Args:
        item: dict with a ``'story'`` string and parallel ``'questions'`` /
            ``'answers'`` lists whose entries carry ``'input_text'``.
        max_num_questions: how many turns before ``question_number`` to include.
        question_number: index of the last turn to include.
        last_question: when false, the final line of the rendered text is
            removed and the text ends with a newline instead.

    Returns:
        The rendered text.
    """
    text = 'In the text below two people are discussing a story.\n\n'
    text += 'Story:\n' + item['story'] + '\n\n'
    text += 'Discussion:\n'

    # Window of turns: from (question_number - max_num_questions), floored at 0,
    # up to and including question_number.
    first_turn = max(0, question_number - max_num_questions)
    last_turn = question_number + 1
    questions = item['questions'][first_turn:last_turn]
    answers = item['answers'][first_turn:last_turn]

    turns = []
    for question, answer in zip(questions, answers):
        turns.append('Q: ' + question['input_text'] + '\nA: ' + answer['input_text'])
    text += '\n'.join(turns)

    if last_question:
        return text
    # Cut everything after the last newline, then end the text with a newline.
    return text.rsplit('\n', 1)[0] + '\n'

In [11]:
def generate_answer(model, prompt):
    """Generate a continuation of ``prompt`` and return its first line as the answer.

    Args:
        model: model exposing ``parameters()`` and ``generate(...)``.
        prompt: text to continue.

    Returns:
        The text between ``len(prompt) + 1`` and the next newline (or the end of
        the output), reduced to what follows its last ``':'`` and stripped.
        ``''`` when the prompt plus 50 new tokens would exceed 1024 tokens.
    """
    tokens = tokenizer.encode(prompt, return_tensors='pt')
    _length = 50
    tokens_length = tokens.shape[1]
    if tokens_length + _length > 1024:
        return ''
    # Use the model's own device instead of hard-coding CUDA so that a model
    # left on the CPU works too.
    device = next(model.parameters()).device
    output = model.generate(
             tokens.to(device),
             max_length=tokens_length + _length,
             pad_token_id=50256
    )
    output = tokenizer.decode(output[0], skip_special_tokens=True)
    start = len(prompt) + 1
    end = output.find('\n', start)
    if end == -1:
        # No newline was generated: take the rest of the output. Slicing with
        # -1 would silently drop the answer's last character.
        end = len(output)
    # NOTE(review): split(':')[-1] also truncates answers that themselves
    # contain a colon (e.g. a time); left unchanged here.
    return output[start:end].split(':')[-1].strip()

In [12]:
def compute_accuracy_of_model(model, limit=10):
    """Score ``model`` on the first ``limit`` dev dialogues with a short Q/A history.

    Each question is asked with a prompt built by ``get_text_from_data_item``
    (at most 5 previous turns). A prediction counts as correct when, after
    removing ``.`` and ``"`` and lower-casing, it contains or is contained in
    at least one reference answer.

    Args:
        model: model handed to ``generate_answer``.
        limit: number of entries of ``dev_list`` to evaluate (default 10, the
            value that used to be hard-coded); ``None`` evaluates all of them.

    Returns:
        ``(accuracy, wrong_predictions, false_positives)``.
        ``wrong_predictions`` gets one ``{'label', 'prediction'}`` entry per
        compared reference answer that did not match; ``false_positives``
        collects predictions other than 'unknown' made where a reference answer
        is 'unknown'. Accuracy is ``0.0`` when no question was scored.
    """
    total_number_of_questions = 0
    correct_answers = 0
    wrong_predictions = []
    false_positives = []

    dlist = dev_list[:limit]
    for index, text in tqdm(enumerate(dlist), total=len(dlist)):
        total_questions = get_answers_number(text)
        all_answers = get_all_answers(dev_dict, index)
        for number in range(total_questions):
            small_text = get_text_from_data_item(dev_dict['data'][index],
                                                 max_num_questions=5,
                                                 question_number=number,
                                                 last_question=False)
            prediction = generate_answer(model, small_text)
            if not prediction:
                # NOTE(review): questions without a prediction are left out of
                # the denominator as well, as before.
                print('NO PREDICTION!!')
                continue
            prediction = prediction.replace('.', '').replace('"', '')
            guess = prediction.lower()
            for label in all_answers[number]:
                label = label.replace('.', '').replace('"', '')
                gold = label.lower()

                if guess != 'unknown' and gold == 'unknown':
                    false_positives.append(prediction)

                # Equality is a special case of containment, so one test covers
                # the three identical branches used before. Empty strings are
                # excluded because '' is a substring of everything, which made
                # a prediction such as '.' count as correct.
                if guess and gold and (guess in gold or gold in guess):
                    correct_answers += 1
                    break
                wrong_predictions.append({'label': label, 'prediction': prediction})
            total_number_of_questions += 1

    if total_number_of_questions == 0:
        # Nothing was scored; avoid dividing by zero.
        return 0.0, wrong_predictions, false_positives
    return correct_answers / total_number_of_questions, wrong_predictions, false_positives

In [16]:
# Restore the weights stored in the 'save_small1' checkpoint file.
checkpoint = torch.load('save_small' + str(1))
model.load_state_dict(checkpoint['model_state_dict'])
# NOTE(review): this leaves the model in training mode for the generation cells
# that follow; presumably intentional so that repeated generations can differ -- confirm.
_ = model.train()

In [39]:
def generate_multiple_answers(model, prompt, num_replicas=6):
    """Generate ``num_replicas`` continuations of ``prompt`` and return the distinct answers.

    Args:
        model: model exposing ``parameters()`` and ``generate(...)``.
        prompt: text to continue.
        num_replicas: how many times ``generate`` is called.

    Returns:
        The distinct answers in first-seen order. An empty list when the prompt
        plus 50 new tokens would exceed 1024 tokens (this used to be ``''``,
        which made the return type inconsistent).
    """
    tokens = tokenizer.encode(prompt, return_tensors='pt')
    _length = 50
    tokens_length = tokens.shape[1]
    if tokens_length + _length > 1024:
        return []

    # Loop-invariant work: move the prompt to the model's device once and
    # compute where the answer starts in the decoded output.
    tokens = tokens.to(next(model.parameters()).device)
    start = len(prompt) + 1

    outputs = []
    for _ in range(num_replicas):
        output = model.generate(
             tokens,
             max_length=tokens_length + _length,
             pad_token_id=50256
        )
        output = tokenizer.decode(output[0], skip_special_tokens=True)
        end = output.find('\n', start)
        if end == -1:
            # No newline was generated: keep the rest of the output instead of
            # dropping its last character.
            end = len(output)
        outputs.append(output[start:end].split(':')[-1].strip())

    # dict.fromkeys de-duplicates like set() but keeps a deterministic order.
    return list(dict.fromkeys(outputs))

In [40]:
# Index of the dev item (doc) and of the question within it (number) to try out.
doc=0
number = 0
# Build the prompt for that question with up to 5 earlier turns; last_question=False
# removes the final line of the rendered text so the model has to produce it.
small_text = get_text_from_data_item(dev_dict['data'][doc], 
                                     max_num_questions=5, 
                                     question_number=number,
                                     last_question=False)
generate_multiple_answers(model, small_text)

['white', 'orange']

In [24]:
# Module-level holders so the results of the last evaluated checkpoint can be
# inspected in later cells.
false_positives = None
wrong_answers = None

# Evaluate each of the ten 'save_small' checkpoints in turn.
for index in range(10):
    checkpoint = torch.load('save_small' + str(index))
    model.load_state_dict(checkpoint['model_state_dict'])
    accuracy, wrong_answers, false_positives = compute_accuracy_of_model(model.cuda())
    print('Epoch:', index, ' with accuracy:', accuracy)

100%|██████████| 10/10 [01:03<00:00,  6.35s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 0  with accuracy: 0.6808510638297872


  0%|          | 0/10 [00:03<?, ?it/s]


KeyboardInterrupt: 

In [22]:
def clean_answer(answer):
    """Lower-case ``answer``, trim surrounding whitespace, then delete every '.'."""
    return answer.lower().strip().replace('.', '')

In [23]:
def compute_accuracy_of_model_only_y_or_n(model, limit=10):
    """Score ``model`` on the first ``limit`` dev dialogues, counting only yes/no predictions.

    Questions whose cleaned prediction is neither 'yes' nor 'no' are skipped and
    do not enter the denominator. A kept prediction counts as correct when,
    after removing ``.`` and ``"`` and lower-casing, it contains or is contained
    in at least one reference answer.

    Args:
        model: model handed to ``generate_answer``.
        limit: number of entries of ``dev_list`` to evaluate (default 10, the
            value that used to be hard-coded); ``None`` evaluates all of them.

    Returns:
        ``(accuracy, wrong_predictions, false_positives)`` with the same layout
        as ``compute_accuracy_of_model``. Accuracy is ``0.0`` when no yes/no
        prediction was produced (this used to raise ZeroDivisionError).
    """
    total_number_of_questions = 0
    correct_answers = 0
    wrong_predictions = []
    false_positives = []

    dlist = dev_list[:limit]
    for index, text in tqdm(enumerate(dlist), total=len(dlist)):
        total_questions = get_answers_number(text)
        all_answers = get_all_answers(dev_dict, index)
        for number in range(total_questions):
            small_text = get_text_from_data_item(dev_dict['data'][index],
                                                 max_num_questions=5,
                                                 question_number=number,
                                                 last_question=False)
            prediction = generate_answer(model, small_text)
            if clean_answer(prediction) not in ['yes', 'no']:
                continue
            prediction = prediction.replace('.', '').replace('"', '')
            guess = prediction.lower()
            for label in all_answers[number]:
                label = label.replace('.', '').replace('"', '')
                gold = label.lower()

                if guess != 'unknown' and gold == 'unknown':
                    false_positives.append(prediction)

                # Equality is a special case of containment, so one test covers
                # the three identical branches used before.
                # NOTE(review): 'no' is a substring of 'unknown', so a 'no'
                # prediction still matches an 'unknown' label; left unchanged.
                if guess and gold and (guess in gold or gold in guess):
                    correct_answers += 1
                    break
                wrong_predictions.append({'label': label, 'prediction': prediction})
            total_number_of_questions += 1

    if total_number_of_questions == 0:
        # No yes/no prediction was scored; avoid dividing by zero.
        return 0.0, wrong_predictions, false_positives
    return correct_answers / total_number_of_questions, wrong_predictions, false_positives

In [24]:
# Module-level holders so the results of the last evaluated checkpoint can be
# inspected in later cells.
false_positives = None
wrong_answers = None

# Evaluate each of the ten 'save_small' checkpoints on yes/no predictions only.
for index in range(10):
    checkpoint = torch.load('save_small' + str(index))
    model.load_state_dict(checkpoint['model_state_dict'])
    accuracy, wrong_answers, false_positives = compute_accuracy_of_model_only_y_or_n(model.cuda())
    print('Epoch:', index, ' with accuracy:', accuracy)

100%|██████████| 10/10 [01:03<00:00,  6.38s/it]


Epoch: 0  with accuracy: 0.7352941176470589


100%|██████████| 10/10 [01:03<00:00,  6.36s/it]


Epoch: 1  with accuracy: 0.8


100%|██████████| 10/10 [01:03<00:00,  6.30s/it]


Epoch: 2  with accuracy: 0.7428571428571429


100%|██████████| 10/10 [01:02<00:00,  6.29s/it]


Epoch: 3  with accuracy: 0.7428571428571429


100%|██████████| 10/10 [01:04<00:00,  6.47s/it]


Epoch: 4  with accuracy: 0.7428571428571429


100%|██████████| 10/10 [01:05<00:00,  6.52s/it]


Epoch: 5  with accuracy: 0.7142857142857143


100%|██████████| 10/10 [01:04<00:00,  6.45s/it]


Epoch: 6  with accuracy: 0.7714285714285715


100%|██████████| 10/10 [01:03<00:00,  6.38s/it]


Epoch: 7  with accuracy: 0.7142857142857143


100%|██████████| 10/10 [01:05<00:00,  6.56s/it]


Epoch: 8  with accuracy: 0.7714285714285715


100%|██████████| 10/10 [01:04<00:00,  6.47s/it]

Epoch: 9  with accuracy: 0.7428571428571429





In [25]:
# Display the false positives left behind by the most recently run evaluation cell.
false_positives

['they took it to the bottom of the ocean',
 'yes',
 'it lost nearly all of its healthy qualities']

In [26]:
# Display the label/prediction mismatches left behind by the most recently run evaluation cell.
wrong_answers

[{'label': 'white', 'prediction': 'orange'},
 {'label': 'with her mommy and 5 sisters',
  'prediction': 'her mommy and 5 other sisters'},
 {'label': 'she painted herself', 'prediction': 'painted herself like them'},
 {'label': 'paint herself like them',
  'prediction': 'painted herself like them'},
 {'label': 'the farmer', 'prediction': "Cotton's mommy"},
 {'label': "the farmer's", 'prediction': "Cotton's mommy"},
 {'label': "the old farmer's", 'prediction': "Cotton's mommy"},
 {'label': 'a bottle', 'prediction': 'It was hard and clear'},
 {'label': 'the bottle', 'prediction': 'It was hard and clear'},
 {'label': 'Asta', 'prediction': 'Sharkie'},
 {'label': 'no', 'prediction': 'Yes'},
 {'label': 'No', 'prediction': 'Yes'},
 {'label': "They took the note to Asta's papa",
  'prediction': 'they took it to the bottom of the ocean'},
 {'label': 'unknown', 'prediction': 'they took it to the bottom of the ocean'},
 {'label': "They took the note to Asta's papa",
  'prediction': 'they took it t

In [21]:
# Evaluate the 'save_superfine2' checkpoint.
for index in range(2, 3):
    checkpoint = torch.load('save_superfine' + str(index))
    model.load_state_dict(checkpoint['model_state_dict'])
    # compute_accuracy_of_model returns (accuracy, wrong_predictions, false_positives);
    # unpacking it into two names raised ValueError, so the extra values go into *_.
    accuracy, *_ = compute_accuracy_of_model(model.cuda())
    print('Epoch:', index, ' with accuracy:', accuracy)

  0%|          | 0/10 [00:04<?, ?it/s]


KeyboardInterrupt: 

In [None]:
# Evaluate whichever weights `model` currently holds and display the returned tuple.
compute_accuracy_of_model(model)

In [None]:
def compute_prediction_list_from_model(model, limit=None):
    """Generate one prediction record per dev question.

    Args:
        model: model handed to ``generate_answer``.
        limit: number of entries of ``dev_list`` to process; ``None`` means all.

    Returns:
        A list of ``{'id', 'turn_id', 'answer'}`` dicts. ``'answer'`` is
        ``'unknown'`` whenever ``generate_answer`` returns an empty string.
    """
    prediction_list = []
    dlist = dev_list[:limit]
    for index, text in tqdm(enumerate(dlist), total=len(dlist)):
        item = dev_dict['data'][index]
        questions = item.get('questions', [])
        total_questions = get_answers_number(text)
        for number in range(total_questions):
            small_text = get_text_from_data_item(item,
                                     max_num_questions=5,
                                     question_number=number,
                                     last_question=False)
            prediction = generate_answer(model, small_text)
            if not prediction:
                prediction = 'unknown'
            # Use the turn id stored with the question: CoQA numbers turns from
            # 1, so the zero-based loop index emitted before did not line up
            # with the reference data. Fall back to number + 1 when the
            # question carries no 'turn_id'.
            if number < len(questions) and 'turn_id' in questions[number]:
                turn_id = questions[number]['turn_id']
            else:
                turn_id = number + 1
            prediction_list.append({
                'id': item['id'],
                'turn_id': turn_id,
                'answer': prediction,
            })

    return prediction_list

In [None]:
# Restore the 'save_small7' checkpoint, predict an answer for every dev question
# and write the records to ../data/predictions.json.
checkpoint = torch.load('save_small' + str(7))
model.load_state_dict(checkpoint['model_state_dict'])
prediction_list = compute_prediction_list_from_model(model)
# The context manager flushes and closes the file; the bare open(..., 'w') used
# before never closed it, so the JSON could stay partially written.
with open('../data/predictions.json', 'w', encoding='utf8') as predictions_file:
    json.dump(prediction_list, predictions_file)

In [None]:
# Fine-tune starting from the 'save_squad2_1' checkpoint and evaluate after every epoch.
model = GPT2LMHeadModel.from_pretrained('gpt2')
# Move the model to the GPU like the other training cells do: train() and
# generate_answer() send their inputs to CUDA, which fails against a CPU model.
model.cuda()
checkpoint = torch.load('save_squad2_' + str(1))
model.load_state_dict(checkpoint['model_state_dict'])
# NOTE(review): train() never uses `criterion`; see its body.
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
for epoch in range(0, 6):
    random.shuffle(train_batches)
    loss = train(model, train_batches, optimizer, criterion)
    # compute_accuracy_of_model returns (accuracy, wrong_predictions, false_positives);
    # unpacking it into two names raised ValueError.
    accuracy, *_ = compute_accuracy_of_model(model)
    print('Epoch:', epoch, 'Loss:', loss)
    print('Epoch:', epoch, ' with accuracy:', accuracy)
    torch.save({'epoch': epoch,
                'model_state_dict': model.state_dict()},
                'save_coqa_' + str(epoch))

In [None]:
# Evaluate the 'save_superfine0' checkpoint.
model = GPT2LMHeadModel.from_pretrained('gpt2')
# generate_answer() sends the prompt to CUDA, so the model has to be there too.
model.cuda()
checkpoint = torch.load('save_superfine'+str(0))
model.load_state_dict(checkpoint['model_state_dict'])
# compute_accuracy_of_model returns (accuracy, wrong_predictions, false_positives);
# unpacking it into two names raised ValueError.
accuracy, *_ = compute_accuracy_of_model(model)

In [None]:
# Display the accuracy computed by the preceding evaluation cell.
accuracy

In [None]:
# Evaluate the 'save2' checkpoint and display its mismatches.
model = GPT2LMHeadModel.from_pretrained('gpt2')
# generate_answer() sends the prompt to CUDA, so the model has to be there too.
model.cuda()
checkpoint = torch.load('save' + str(2))
model.load_state_dict(checkpoint['model_state_dict'])
# compute_accuracy_of_model returns (accuracy, wrong_predictions, false_positives);
# unpacking it into two names raised ValueError.
accuracy, wrong_predictions, *_ = compute_accuracy_of_model(model)
wrong_predictions

In [None]:
# Display the accuracy computed by the preceding evaluation cell.
accuracy

In [None]:
# Display the distinct reference answers for every turn of the first dev item.
get_all_answers(dev_dict, 0)