In [1]:
import json
import torch
import torch.nn as nn
import random
from tqdm import tqdm
from transformers import GPT2LMHeadModel, GPT2Tokenizer

In [2]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

In [3]:
dev_dict = json.load(open('../data/coqa-dev-v1.0.json', encoding='utf8'))

In [4]:
train_list = json.load(open('../data/qa_train_list.json', encoding='utf8'))
dev_list = json.load(open('../data/qa_dev_list.json', encoding='utf8'))

In [5]:
def chunks(lst, n):
    for i in range(0, len(lst), n):
        yield lst[i:i + n]

def batchify(data, n):
    len_dict = {}
    for item in data:
        length = item.shape[1]
        try:
            len_dict[length].append(item)
        except:
            len_dict[length] = [item]

    batch_chunks = []
    for k in len_dict.keys():
        vectors = len_dict[k]
        batch_chunks += chunks(vectors, n)

    batches = []
    for chunk in batch_chunks:
        inputs = torch.stack([item[0] for item in chunk])
        batches.append((inputs))

    return batches

In [6]:
_limit = 1024
data = []
total_trimmed = 0
for item in train_list:
    tokens = tokenizer.encode(item, return_tensors='pt')
    if tokens.shape[1] > _limit:
        tokens = tokens[:, :_limit]
        total_trimmed += 1
    data.append(tokens)
print(f'Trimmed {total_trimmed} out of {len(train_list)}')

Token indices sequence length is longer than the specified maximum sequence length for this model (1088 > 1024). Running this sequence through the model will result in indexing errors


Trimmed 52 out of 7199


In [7]:
train_batches = batchify(data, 1)

In [6]:
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.cuda()
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)

In [9]:
def train(train_model, batches, optimizer, criterion):
    total_loss = 0.
    for i, batch in tqdm(enumerate(batches), total=len(batches)):
        model.train()
        inputs = batch
        optimizer.zero_grad()
        loss = train_model(inputs.cuda(), labels=inputs.cuda())[0]
        loss.backward()
        torch.nn.utils.clip_grad_norm_(train_model.parameters(), 0.5)
        optimizer.step()
        total_loss += loss.item()

    return total_loss / len(batches)

def test(test_model, batches):
    total_loss = 0.
    for i, batch in tqdm(enumerate(batches), total=len(batches)):
        test_model.eval()
        inputs = batch
        loss = test_model(inputs, labels=inputs)[0]
        total_loss += loss.item()

    return total_loss / len(batches)

In [None]:
from torch.optim.lr_scheduler import StepLR

random.shuffle(train_batches)
scheduler = StepLR(optimizer, step_size=2, gamma=0.8)
for epoch in range(10):
    random.shuffle(train_batches)
    loss = train(model, train_batches, optimizer, criterion)
    print('Epoch:', epoch, 'Loss:', loss)
    torch.save({'epoch': epoch,
                'model_state_dict': model.state_dict()},
                'save_small' + str(epoch))
    scheduler.step()

In [7]:
_question_prompt = '\nQ: '
_answer_prompt = '\nA: '
    
def get_text_up_to_question_number(text, number):
    pos = text.find(_answer_prompt)
    for _ in range(number):
        pos = text.find(_answer_prompt, pos + 1)
    return text[0:pos + 1]
    
def get_answers_number(text):
    return text.count(_answer_prompt)

def get_answer_number(text, number):
    pos = text.find(_answer_prompt)
    for _ in range(number):
        pos = text.find(_answer_prompt, pos + 1)
    end = text.find('\n', pos + len(_answer_prompt))
    return text[pos + len(_answer_prompt):end]

def get_question_number(text, number):
    pos = text.find(_question_prompt)
    for _ in range(number):
        pos = text.find(_question_prompt, pos + 1)
    end = text.find('\n', pos + len(_question_prompt))
    return text[pos + len(_question_prompt):end]

def get_all_answers(dev_dict, dev_index):
    answers = [[item['input_text'] for item in dev_dict['data'][dev_index]['answers']]]
    answers += [[item['input_text'] for item in dev_dict['data'][dev_index]['additional_answers'][str(index)]] for index in range(3)]
    return [list(set([answers[j][i] for j in range(len(answers))])) for i in range(len(answers[0]))]

In [8]:
def generate_answer_number(model, text, number):
    prompt = get_text_up_to_question_number(text, number)
    tokens = tokenizer.encode(prompt, return_tensors='pt')
    _length = 20
    tokens_length = tokens.shape[1]
    output = model.generate(
             tokens,
             max_length=tokens_length + _length,
             temperature=0,
             pad_token_id=50256
    )
    output = tokenizer.decode(output[0], skip_special_tokens=True)
    return get_answer_number(output, number)

In [9]:
def compute_accuracy_of_model_with_all_answers(model):
    total_number_of_questions = 0
    correct_answers = 0
    wrong_predictions = []

    dlist = dev_list[:10]
    for index, text in tqdm(enumerate(dlist), total=len(dlist)):
        total_questions = get_answers_number(text)
        all_answers = get_all_answers(dev_dict, index)
        for number in range(total_questions):
            prediction = generate_answer_number(model, text, number)
            prediction = prediction.replace('.', '').replace('"', '')
            it_was_answered = False
            for label in all_answers[number]:
                label = label.replace('.', '').replace('"', '')
                if prediction.lower() == label.lower():
                    correct_answers += 1
                    it_was_answered = True
                    break
                elif prediction.lower() in label.lower():
                    correct_answers += 1
                    it_was_answered = True
                    break
                elif label.lower() in prediction.lower():
                    correct_answers += 1
                    it_was_answered = True
                    break
                else:
                    wrong_predictions.append({'label': label, 'prediction': prediction})
            if not it_was_answered:
                print('No Answer for number', number)
            total_number_of_questions += 1

    return correct_answers/total_number_of_questions, wrong_predictions

In [10]:
def get_text_from_data_item(item, max_num_questions=0, question_number=-1, last_question=True):
    text = 'In the text below two people are discussing a story.\n\n'
    text += 'Story:\n' + item['story'] + '\n\n'
    text += 'Discussion:\n'
    text += '\n'.join(['Q: ' + q['input_text'] 
                       + '\nA: ' + a['input_text'] 
                       for q, a in zip(item['questions'][max(0,question_number-max_num_questions):question_number+1], 
                                       item['answers'][max(0,question_number-max_num_questions):question_number+1]) 
                      ])
    if not last_question:
        text = '\n'.join(text.split('\n')[:-1]) + '\n'
    return text

In [56]:
def get_text_from_data_item_and_story(story, item, max_num_questions=0, question_number=-1, last_question=True):
    text = 'In the text below two people are discussing a story.\n\n'
    text += 'Story:\n' + story + '\n\n'
    text += 'Discussion:\n'
    text += '\n'.join(['Q: ' + q['input_text'] 
                       + '\nA: ' + a['input_text'] 
                       for q, a in zip(item['questions'][max(0,question_number-max_num_questions):question_number+1], 
                                       item['answers'][max(0,question_number-max_num_questions):question_number+1]) 
                      ])
    if not last_question:
        text = '\n'.join(text.split('\n')[:-1]) + '\n'
    return text

In [11]:
def generate_answer(model, prompt):
    tokens = tokenizer.encode(prompt, return_tensors='pt')
    _length = 50
    tokens_length = tokens.shape[1]
    if tokens_length + _length > 1024:
        return ''
    output = model.generate(
             tokens.cuda(),
             max_length=tokens_length + _length,
             #temperature=0,
             pad_token_id=50256
    )
    output = tokenizer.decode(output[0], skip_special_tokens=True)
    offset = len(prompt)
    start = offset + 1
    end = output.find('\n', start)
    return output[start:end].split(':')[-1].strip()

# Simplify text according to question

In [90]:
import re
from nltk.tokenize.texttiling import TextTilingTokenizer

In [91]:
def clean_chapter(text):
    text = re.sub("\[.*?\]", "", text)
    text = re.sub("\(.*?\)", "", text)
    return text

In [123]:
def get_chapters_from_nltk(text_file):
    tt = TextTilingTokenizer(w=1, k=2)
    paragraphs = tt.tokenize(text_file)
    paragraphs = [p.replace('\r\n', '\n') for p in paragraphs]
    paragraphs = [clean_chapter(p) for p in paragraphs]
    return paragraphs

In [93]:
from sentence_transformers import SentenceTransformer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sentence_model = SentenceTransformer('msmarco-distilbert-base-v3')
sentence_model = sentence_model.to(device)

In [94]:
from gensim.models import KeyedVectors

In [95]:
def get_embeddings_from_text(text):
    outputs = sentence_model.encode(text)
    return outputs

In [154]:
def get_sentence_retriever_from_text(text):
    retriever = KeyedVectors(768)
    
    try:
        sentences = get_chapters_from_nltk(text)
    except ValueError:
        embeddings = get_embeddings_from_text(text)
        retriever.add_vector(text, embeddings)
        return retriever
    for sentence in sentences:
        embeddings = get_embeddings_from_text(sentence)
        retriever.add_vector(sentence, embeddings)
        
    return retriever

In [334]:
def get_documents_and_scores(retriever, query):
    embeddings = get_embeddings_from_text(query)
    results = retriever.similar_by_vector(embeddings, topn=5)
    return results

In [335]:
def get_text_from_retriever(retriever, query):
    docs_and_scores = get_documents_and_scores(retriever, query)
    return '. '.join(item[0] for item in docs_and_scores).strip() + '.'

In [336]:
dev_dict['data'][0]['questions'][0]['input_text']

'What color was Cotton?'

In [337]:
def get_text_for_index(index):
    return dev_dict['data'][index]['story']

def get_question_for_index_and_number(index, number):
    return dev_dict['data'][index]['questions'][number]['input_text']

In [338]:
retriever = get_sentence_retriever_from_text(get_text_for_index(0))

In [339]:
print(get_text_from_retriever(retriever, get_question_for_index_and_number(0, 0)))

Once upon a time, in a barn near a farm house, there lived a little white kitten named Cotton. Cotton lived high up in a nice warm place above the barn where all of the farmer's horses slept. But Cotton wasn't alone in her little home above the barn, oh no. She shared her hay bed with her mommy and 5 other sisters. All of her sisters were cute and fluffy, like Cotton. But she was the only white one in the bunch. The rest of her sisters were all orange with beautiful white tiger stripes like Cotton's mommy. Being different made Cotton quite sad. She often wished she looked like the rest of her family. So one day, when Cotton found a can of the old farmer's orange paint, she used it to paint herself like them. When her mommy and sisters found her they started laughing..  

"What are you doing, Cotton?!" 

"I only wanted to be more like you". 

Cotton's mommy rubbed her face on Cotton's and said "Oh Cotton, but your fur is so pretty and special, like you. We would never want you to be any

## Actually test the model

In [340]:
def compute_accuracy_of_model(model):
    total_number_of_questions = 0
    correct_answers = 0
    wrong_predictions = []

    false_positives = []
    dlist = dev_list[:100]
    for index, text in tqdm(enumerate(dlist), total=len(dlist)):

        all_answers = get_all_answers(dev_dict, index)
        total_questions = len(all_answers)        
        
        for number in range(total_questions):
            retriever = get_sentence_retriever_from_text(get_text_for_index(index))
            story = get_text_from_retriever(retriever, get_question_for_index_and_number(index, number))
            small_text = get_text_from_data_item_and_story(story, dev_dict['data'][index],
                                     max_num_questions=5,
                                     question_number=number,
                                     last_question=False)
            #print(small_text)
            prediction = generate_answer(model, small_text)
            #print('prediction', prediction)
            if not prediction:
                print('NO PREDICTION!!')
                prediction = 'unknown'
            prediction = prediction.replace('.', '').replace('"', '')
            it_was_answered = False
            for label in all_answers[number]:
                label = label.replace('.', '').replace('"', '')

                if prediction.lower() != 'unknown' and label.lower() == 'unknown':
                    false_positives.append(prediction)
                
                if prediction.lower() == label.lower():
                    correct_answers += 1
                    it_was_answered = True
                    break
                elif prediction.lower() in label.lower():
                    correct_answers += 1
                    it_was_answered = True
                    break
                elif label.lower() in prediction.lower():
                    correct_answers += 1
                    it_was_answered = True
                    break
                else:
                    wrong_predictions.append({'label': label, 'prediction': prediction})
            total_number_of_questions += 1

    return correct_answers/total_number_of_questions, wrong_predictions, false_positives

In [341]:
false_positives = None
wrong_answers = None

for index in range(6,7):
    checkpoint = torch.load('save_small' + str(index))
    model.load_state_dict(checkpoint['model_state_dict'])
    accuracy, wrong_answers, false_positives = compute_accuracy_of_model(model.cuda())
    print('Epoch:', index, ' with accuracy:', accuracy)

  3%|▎         | 3/100 [00:18<10:36,  6.56s/it]

NO PREDICTION!!
NO PREDICTION!!


 15%|█▌        | 15/100 [01:40<09:27,  6.68s/it]

NO PREDICTION!!
NO PREDICTION!!


 27%|██▋       | 27/100 [03:18<11:00,  9.04s/it]

NO PREDICTION!!


100%|██████████| 100/100 [12:23<00:00,  7.43s/it]

Epoch: 6  with accuracy: 0.6866708780795957





In [333]:
wrong_answers

[{'label': "the old farmer's", 'prediction': 'her mommy'},
 {'label': 'the farmer', 'prediction': 'her mommy'},
 {'label': "the farmer's", 'prediction': 'her mommy'},
 {'label': 'they started laughing', 'prediction': 'laughed'},
 {'label': 'rubbed her face', 'prediction': 'laughed'},
 {'label': 'started laughing', 'prediction': 'laughed'},
 {'label': 'into a big bucket of water', 'prediction': 'a bucket of water'},
 {'label': 'dropped her into a big bucket of water',
  'prediction': 'a bucket of water'},
 {'label': 'the bottle', 'prediction': 'It was hard and clear'},
 {'label': 'a bottle', 'prediction': 'It was hard and clear'},
 {'label': 'Sharkie', 'prediction': 'a friend'},
 {'label': 'Asta', 'prediction': 'a friend'},
 {'label': 'Asta', 'prediction': 'a friend'},
 {'label': 'no', 'prediction': 'Yes'},
 {'label': 'No', 'prediction': 'Yes'},
 {'label': "They took the note to Asta's papa",
  'prediction': "took it to Asta's papa"},
 {'label': 'unknown', 'prediction': "took it to Asta

### Best result with best_answers=3
