In [1]:
import json
import torch
import torch.nn as nn
import random
from tqdm import tqdm
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config

In [2]:
# GPT-2 BPE tokenizer, read as a module-level global by the generation helper below.
tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model_name_or_path='gpt2')

# Question answering: generate several candidate answers per question and score groups of similar answers

In [10]:
def _extract_answer(continuation):
    """Pull the answer text out of one generated continuation.

    When the prompt ends with an unanswered ``Q: ...`` line (as produced by
    ``get_text_from_data_item(..., last_question=False)``), the continuation is
    expected to start with ``A: <answer>`` followed by a newline and further turns.
    Leading newlines are skipped, only the first line is kept, and everything up to
    the first ':' (the "A" speaker tag) is removed. Splitting on the first ':' only
    keeps answers such as "3:30" intact; a line without ':' is returned whole.
    """
    first_line = continuation.lstrip('\n').split('\n', 1)[0]
    return first_line.split(':', 1)[-1].strip()


def generate_multiple_answers(model, prompt, num_replicas=25, max_answer_tokens=50):
    """Generate ``num_replicas`` candidate answers to the question ending ``prompt``.

    The prompt is replicated ``num_replicas`` times and decoded in one batch with
    the model in train mode (dropout active). ``generate`` receives no sampling
    arguments, so this is presumably what makes the otherwise identical replicas
    decode differently. The model's previous train/eval mode is restored afterwards.

    Args:
        model: GPT-2 style LM exposing ``generate``, ``config`` and ``device``.
            Text is encoded/decoded with the module-level ``tokenizer``.
        prompt: text ending with an unanswered question turn.
        num_replicas: number of candidate answers to generate.
        max_answer_tokens: upper bound on new tokens generated per replica.

    Returns:
        A list of ``num_replicas`` answer strings, or ``[]`` when the prompt plus
        ``max_answer_tokens`` does not fit in the model's position limit.
    """
    tokens = tokenizer.encode(prompt, return_tensors='pt')
    prompt_length = tokens.shape[1]
    # Positional-embedding limit of the model (1024 for GPT-2).
    max_positions = getattr(model.config, 'n_positions', 1024)
    if prompt_length + max_answer_tokens > max_positions:
        return []
    tokens = tokens.repeat(num_replicas, 1).to(model.device)

    was_training = model.training
    model.train()
    try:
        with torch.no_grad():
            output = model.generate(
                tokens,
                max_length=prompt_length + max_answer_tokens,
                pad_token_id=tokenizer.eos_token_id,
            )
    finally:
        model.train(was_training)

    # Decode only the newly generated tokens; slicing the full decoded text by
    # len(prompt) breaks whenever decode does not reproduce the prompt exactly.
    return [
        _extract_answer(tokenizer.decode(row[prompt_length:], skip_special_tokens=True))
        for row in output
    ]

In [14]:
# GPT-2 configuration with explicit dropout probabilities (embedding, attention, residual).
config = GPT2Config(
    embd_pdrop=0.1,
    attn_pdrop=0.1,
    resid_pdrop=0.1,
)

In [15]:
# from_pretrained is a classmethod: calling it on GPT2LMHeadModel(config) built and
# discarded a random model and ignored `config`. Pass the config explicitly instead.
model = GPT2LMHeadModel.from_pretrained('gpt2', config=config)
model.cuda()
# Fine-tuned weights; the numeric suffix is presumably an epoch number — confirm.
# NOTE: torch.load unpickles arbitrary objects, so only load trusted checkpoints.
CHECKPOINT_PATH = 'save_small' + str(6)
# Load onto the CPU; load_state_dict copies the tensors into the CUDA parameters,
# so any extra entries in the checkpoint never occupy GPU memory.
checkpoint = torch.load(CHECKPOINT_PATH, map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [23]:
import numpy as np
from sentence_transformers import SentenceTransformer


# `device` must be defined here so the notebook runs on a fresh kernel; use the
# GPU when available (the GPT-2 model is also moved to CUDA).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
sentence_model = SentenceTransformer('msmarco-distilbert-base-v3')
sentence_model = sentence_model.to(device)

In [17]:
# Load the CoQA dev set and the QA dev list; `with` closes each file handle
# instead of leaking it until garbage collection.
with open('../data/coqa-dev-v1.0.json', encoding='utf8') as dev_file:
    dev_dict = json.load(dev_file)
with open('../data/qa_dev_list.json', encoding='utf8') as dev_list_file:
    dev_list = json.load(dev_list_file)

In [24]:
def get_embeddings_from_text(text):
    """Return ``sentence_model.encode(text)`` unchanged.

    ``text`` is forwarded as-is to the module-level sentence encoder; in this
    notebook it is called with a list of answer strings.
    """
    return sentence_model.encode(text)

def group_similar_answers_and_get_scores(answers, threshold=0.7):
    """Group semantically similar answers and score each group by its support.

    For every pair of positions ``i <= j`` whose embedding cosine similarity exceeds
    ``threshold``, answer ``i`` gains one unit of support, and ``j`` (when
    ``j != i``) is absorbed into ``i``'s group. Pairs at different positions holding
    the identical string are skipped. Answers absorbed by an earlier answer do not
    form groups of their own. Groups with the same members are merged.

    Args:
        answers: list of answer strings.
        threshold: cosine-similarity cut-off for treating two answers as similar.

    Returns:
        A list of ``(group, score)`` pairs sorted by descending score. ``group`` is
        a sorted tuple of distinct answer strings. Scores are support counts
        normalised to sum to 1. Empty input yields ``[]``.
    """
    if len(answers) == 0:
        return []

    embeddings = np.asarray(get_embeddings_from_text(answers), dtype=float)
    # L2-normalise rows so the matrix product below is cosine similarity.
    embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    similarity_matrix = embeddings @ embeddings.T

    n = len(answers)
    support = {}          # index -> number of answers it covers (itself included)
    superseded = set()    # indices absorbed into some earlier answer's group
    superseded_from = {}  # representative index -> indices it absorbed
    for i in range(n):
        for j in range(i, n):
            if i != j and answers[i] == answers[j]:
                continue
            if similarity_matrix[i, j] > threshold:
                support[i] = support.get(i, 0) + 1
                if i != j:
                    superseded.add(j)
                    superseded_from.setdefault(i, []).append(j)

    representatives = [i for i in support if i not in superseded]
    total_support = sum(support[i] for i in representatives)

    group_scores = {}
    for i in representatives:
        members = {answers[i]}
        members.update(answers[j] for j in superseded_from.get(i, ()))
        # Sorted tuple: set iteration order is not stable, so equal groups could
        # otherwise produce different keys and fail to merge.
        key = tuple(sorted(members))
        group_scores[key] = group_scores.get(key, 0.0) + support[i] / total_support

    return sorted(group_scores.items(), key=lambda item: -item[1])

In [18]:
def get_text_from_data_item(item, max_num_questions=0, question_number=-1, last_question=True):
    """Render a CoQA item as a story followed by a Q/A discussion.

    Args:
        item: dict with 'story' (str), 'questions' and 'answers' (lists of dicts
            each carrying an 'input_text' string).
        max_num_questions: number of turns before ``question_number`` to include.
        question_number: index of the last turn to include. Negative values count
            from the end, so the default -1 selects the last question.
        last_question: if False, the final turn's answer is omitted, so the text
            ends with ``"Q: <question>\\n"`` for the model to complete.

    Returns:
        The prompt string.
    """
    questions = item['questions']
    answers = item['answers']
    if question_number < 0:
        # Python-style negative index, so the default selects the last question
        # instead of producing an empty [start:0] slice.
        question_number += len(questions)
    start = max(0, question_number - max_num_questions)
    stop = question_number + 1
    turns = list(zip(questions[start:stop], answers[start:stop]))

    blocks = ['Q: ' + q['input_text'] + '\nA: ' + a['input_text'] for q, a in turns]
    if not last_question and turns:
        # Omit the final answer explicitly (instead of chopping the last line)
        # so a multi-line answer is removed entirely.
        blocks[-1] = 'Q: ' + turns[-1][0]['input_text'] + '\n'

    text = 'In the text below two people are discussing a story.\n\n'
    text += 'Story:\n' + item['story'] + '\n\n'
    text += 'Discussion:\n'
    return text + '\n'.join(blocks)

In [19]:
# Demo: build a prompt for question `number` of dev story `doc` (with up to 5
# earlier Q/A turns and the final answer left blank), generate candidate
# answers, and print the grouped, scored answers.
doc = 0
number = 0
small_text = get_text_from_data_item(
    dev_dict['data'][doc],
    max_num_questions=5,
    question_number=number,
    last_question=False,
)
raw_answers = generate_multiple_answers(model, small_text)
answers = group_similar_answers_and_get_scores(raw_answers)
print(small_text)
print(answers)