In [1]:
import json
import torch
import torch.nn as nn
import random
from tqdm import tqdm
from transformers import GPT2LMHeadModel, GPT2Tokenizer

In [2]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

In [3]:
dev_dict = json.load(open('../data/coqa-dev-v1.0.json', encoding='utf8'))

In [4]:
train_list = json.load(open('../data/qa_train_list.json', encoding='utf8'))

In [5]:
dev_list = json.load(open('../data/qa_dev_list.json', encoding='utf8'))

In [6]:
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.cuda()
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)

## Testing with entailment

In [7]:
checkpoint = torch.load('save_small' + str(6))
model.load_state_dict(checkpoint['model_state_dict'])
device = "cuda"
_ = model.to(device)

In [8]:
def inference(model, tokens, length):
    return model.generate(
        tokens.to(device),
        max_length=tokens.shape[1] + length,
        pad_token_id=tokenizer.eos_token_id,
    )

In [9]:
def generate_answer(model, prompt, topk=10):
    model.eval()
    with torch.no_grad():
        tokens = tokenizer.encode(prompt, return_tensors='pt')
        text = prompt
        start = len(prompt)
        while "\n" not in text[start:]:
            output = inference(model, tokens, length)
            decoded = tokenizer.decode(output[0], skip_special_tokens=True)
            text += decoded[len(text):]
            tokens = output

    end = text.find('\n', start)
    return text[start:end].split(':')[-1].strip()

In [10]:
%env CUDA_LAUNCH_BLOCKING=1

env: CUDA_LAUNCH_BLOCKING=1


In [11]:
def generate_all_answers(model, prompt, num_replicas=10):
    model.eval()
    
    answers_and_scores = [("", 1) for _ in range(num_replicas*num_replicas)]
    
    texts = [prompt] * num_replicas
    start = len(prompt)
    with torch.no_grad():
        while any("\n" not in text[start:] for text in texts):
            tokens = tokenizer.batch_encode_plus(texts, return_tensors='pt').input_ids
            tokens = tokens.to(device)
            output = model(tokens)
            for answer_index in range(num_replicas):
                probs = torch.softmax(output.logits[answer_index][-1], dim=-1)
                indices = torch.topk(probs, k=num_replicas).indices
                for output_index, token_index in enumerate(indices):
                    p = probs[token_index]
                    total_index = answer_index * num_replicas + output_index
                    new_tokens = tokens[answer_index].tolist() + [int(token_index)]
                    decoded_text = tokenizer.decode(new_tokens)
                    new_score =  answers_and_scores[total_index][1] * float(p)
                    answers_and_scores[total_index] = decoded_text, new_score
            
            new_answers_and_scores = []
            already_answered = set()
            for answer, score in answers_and_scores:
                if answer in already_answered:
                    continue
                
                new_answers_and_scores.append((answer, score))
                already_answered.add(answer)
            
            new_answers_and_scores = sorted(new_answers_and_scores, key=lambda x: -x[1])
            texts = [new_answers_and_scores[index][0] for index in range(num_replicas)]
            for i in range(num_replicas):
                for j in range(num_replicas):
                    answers_and_scores[i * j] = new_answers_and_scores[j]

    scores = []
    for index in range(num_replicas):
        end = texts[index].find('\n', start)
        texts[index] = texts[index][start:end].split(':')[-1].strip()
        scores.append(answers_and_scores[index][1])

    return texts, scores

In [12]:
story = "Albert says: 'My father's ship is called st. George, mine is Sir George'."
question = "How is albert's ship called?"

In [13]:
%%time
prompt = f"""
In the text below two people are discussing a story.

Story:
{story}

Discussion:
Q: {question}
A: 
""".strip()

answers, scores = generate_all_answers(model, prompt, num_replicas=4)

CPU times: user 63.3 ms, sys: 821 µs, total: 64.1 ms
Wall time: 63.7 ms


In [14]:
answers

['st. George', 'St. George', 'Sir George', 'George']

In [15]:
scores

[0.29918041831489056,
 0.29318870276476167,
 0.29318870276476167,
 0.29318870276476167]

In [16]:
def generate_multiple_answers_with_dropout(model, prompt, num_replicas=25, length=5):
    model.train()
    outputs = []
    start = len(prompt)
    with torch.no_grad():
        tokens = tokenizer.encode(prompt, return_tensors='pt')
        tokens = tokens.repeat(num_replicas,1)
        texts = [prompt] * num_replicas
        while any("\n" not in text[start:] for text in texts):
            if tokens.shape[1] + length > 1024:
                break
            
            output = model.generate(
                tokens.cuda(),
                max_length=tokens.shape[1] + length,
                pad_token_id=50256
            )
            texts = tokenizer.batch_decode(output, skip_special_tokens=True)
            tokens = output
            
    for index in range(num_replicas):
        end = texts[index].find('\n', start)
        texts[index] = texts[index][start:end].split(':')[-1].strip()

    return texts

In [17]:
story = "Albert says: 'My father's ship is called st. George, mine is Sir George'."
question = "How is albert's ship called?"

In [18]:
%%time
prompt = f"""
In the text below two people are discussing a story.

Story:
{story}

Discussion:
Q: {question}
A: 
""".strip()

answers = generate_multiple_answers_with_dropout(model, prompt, num_replicas=15)

CPU times: user 75.9 ms, sys: 0 ns, total: 75.9 ms
Wall time: 75.5 ms


In [19]:
answers

['St. George',
 'st. George',
 'George',
 'st. George',
 'st. George',
 "George's",
 'St. George',
 'St. George',
 'St. George',
 'st. George',
 'St. George',
 'St. George',
 'St. George',
 'St. George',
 'st. George']

In [20]:
def generate_multiple_answers(model, prompt, num_replicas, length=100):
    tokens = tokenizer(prompt, return_tensors='pt').to("cuda")
    tokens_length = tokens.input_ids.shape[1]
    if tokens_length + length > 1024:
        return ''
    generated_ids = model.generate(**tokens,
                                   num_beams=num_replicas,
                                   num_return_sequences=num_replicas,
                                   max_length=tokens_length + length,
                                  )
    generated_sentences = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    sentences = []
    for index, item in enumerate(generated_sentences):
        output = generated_sentences[index]
        offset = len(prompt)
        start = offset + 1
        end = min(output.find('\n', start), output.find('Q:', start))
        sentences.append(item[start: end])
    
    return sentences

In [21]:
statement_model = GPT2LMHeadModel.from_pretrained('gpt2')
statement_model.cuda()
checkpoint = torch.load('save_statement' + str(0))
statement_model.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [22]:
story = "Albert says: 'My father's ship is called st. George, mine is Sir George'."
question = "How is albert's ship called?"

In [23]:
%%time
def get_answer_prompt(story, question):
    return f"""
In the text below two people are discussing a story.

Story:
{story}

Discussion:
Q: {question}
A: 
""".strip()

prompt = get_answer_prompt(story, question)
#answers = generate_multiple_answers_with_dropout(model, prompt, num_replicas=15)
answers = generate_multiple_answers(model, prompt, num_replicas=50)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


CPU times: user 1.56 s, sys: 5.97 ms, total: 1.57 s
Wall time: 1.57 s


In [24]:
import torch

from typing import Dict, List
from transformers import AutoTokenizer, AutoModelForSequenceClassification

_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Entailer:
    def __init__(self):
        model_name = "MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli"
        self._tokenizer = AutoTokenizer.from_pretrained(model_name)
        self._model = AutoModelForSequenceClassification.from_pretrained(model_name).to(
            _device
        )

    def get_relation(self, premise: str, hypothesis: str) -> Dict[str, float]:
        with torch.no_grad():
            encodings = self._tokenizer(
                premise, hypothesis, truncation=True, return_tensors="pt"
            )
            output = self._model(encodings["input_ids"].to(_device))
            prediction = torch.softmax(output["logits"][0], -1).tolist()
            label_names = ["entailment", "neutral", "contradiction"]
            prediction = {name: float(pred) for pred, name in zip(prediction, label_names)}
            return prediction

    def entails(self, premise: str, hypothesis: str, threshold=0.5) -> bool:
        prediction = self.get_relation(premise, hypothesis)
        if prediction["entailment"] > threshold:
            return True
        
        return prediction["entailment"] > threshold

    def batch_entails(self, premise: List[str], hypothesis: List[str], threshold=0.5) -> bool:
        with torch.no_grad():
            encodings = self._tokenizer.batch_encode_plus([lhs + " [SEP] " + rhs 
                                                           for lhs, rhs in zip(premise, hypothesis)], 
                                                          return_tensors="pt", 
                                                          padding=True
            )
            output = self._model(**encodings.to(_device))
            prediction = torch.softmax(output["logits"], dim=-1)
            return (prediction[:, 0] > threshold).tolist()


In [25]:
entailer = Entailer()



In [26]:
entailer.batch_entails(["my name is alberto", "my name is John"], ["Alberto is the name", "Alberto is speaking"])

[True, False]

In [27]:
def get_discussion_from_text(text, last_n=None):
    trigger = "Discussion:\n"
    start = text.find(trigger) + len(trigger)
    end = text.rfind("\n")
    text = text[start:end]
    if last_n:        
        chunks = text.split("\nQ:")
        text = "Q:" + "\nQ:".join(chunks[-last_n:])
        return text

    return text

In [28]:
dialogue = get_discussion_from_text(dev_list[0], last_n=3)

In [29]:
def get_statement_prompt_from_dialogue_and_answer(dialogue, answer):
        return f"""
Discussion:
{dialogue.strip()}
A: {answer}

Statement:
    """.strip()

In [30]:
print(get_statement_prompt_from_dialogue_and_answer(dialogue, answer="No"))

Discussion:
Q: Where did Cotton's mother put her to clean the paint off?
A: a bucket of water
Q: What did the other cats do when Cotton emerged from the bucket of water?
A: licked her face
Q: Did they want Cotton to change the color of her fur?
A: No

Statement:


In [31]:
def get_statement_prompt_from_question_and_answer(question, answer):
    return f"""
Discussion:
Q: {question}
A: {answer}

Statement:
    """.strip()

In [32]:
def generate_answers_from_multiple_prompts(model, prompts, length=5):
    model.eval()
    starts = [len(prompt) for prompt in prompts]
    num_replicas = len(prompts)
    with torch.no_grad():
        model.eval()
        tokens = tokenizer.batch_encode_plus(prompts, padding=True, return_tensors='pt').input_ids
        texts = prompts
        while any("\n" not in text[start + 1:] for text, start in zip(texts, starts)):
            if tokens.shape[1] + length > 1024:
                break
            
            output = model.generate(
                tokens.cuda(),
                max_length=tokens.shape[1] + length,
                pad_token_id=50256
            )
            texts = tokenizer.batch_decode(output, skip_special_tokens=True)
            tokens = output

    for index in range(num_replicas):
        end = texts[index].find('\n', starts[index] + 1)
        texts[index] = texts[index][starts[index]:end].split(':')[-1].strip()

    return texts

In [33]:
def get_statement_from_question_and_answer(question, answer):
    text = f"""
Discussion:
Q: {question}
A: {answer}

Statement:
    """.strip()
    return generate_answer(statement_model, text + '\n')

In [34]:
%%time
count_dict = {}
for answer in answers:
    count_dict.setdefault(answer, 0)
    count_dict[answer] += 1

CPU times: user 13 µs, sys: 2 µs, total: 15 µs
Wall time: 17.2 µs


### Measuring Score on Dev set

In [35]:
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

In [36]:
statement_model = GPT2LMHeadModel.from_pretrained('gpt2')
statement_model.cuda()
checkpoint = torch.load('save_statement' + str(0))
statement_model.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [37]:
story = "Albert says: 'My father's ship is called st. George, mine is Sir George'."
question = "How is albert's ship called?"
prompt = get_answer_prompt(story, question)

In [38]:
_question_prompt = '\nQ: '
_answer_prompt = '\nA: '
    
def get_text_up_to_question_number(text, number):
    pos = text.find(_answer_prompt)
    for _ in range(number):
        pos = text.find(_answer_prompt, pos + 1)
    return text[0:pos + 1]
    
def get_answers_number(text):
    return text.count(_answer_prompt)

def get_answer_number(text, number):
    pos = text.find(_answer_prompt)
    for _ in range(number):
        pos = text.find(_answer_prompt, pos + 1)
    end = text.find('\n', pos + len(_answer_prompt))
    return text[pos + len(_answer_prompt):end]

def get_question_number(text, number):
    pos = text.find(_question_prompt)
    for _ in range(number):
        pos = text.find(_question_prompt, pos + 1)
    end = text.find('\n', pos + len(_question_prompt))
    return text[pos + len(_question_prompt):end]

def get_all_answers(dev_dict, dev_index):
    answers = [[item['input_text'] for item in dev_dict['data'][dev_index]['answers']]]
    answers += [[item['input_text'] for item in dev_dict['data'][dev_index]['additional_answers'][str(index)]] for index in range(3)]
    return [list(set([answers[j][i] for j in range(len(answers))])) for i in range(len(answers[0]))]

In [39]:
def get_text_from_data_item(item, max_num_questions=0, question_number=-1, last_question=True):
    text = 'In the text below two people are discussing a story.\n\n'
    text += 'Story:\n' + item['story'] + '\n\n'
    text += 'Discussion:\n'
    text += '\n'.join(['Q: ' + q['input_text'] 
                       + '\nA: ' + a['input_text'] 
                       for q, a in zip(item['questions'][max(0,question_number-max_num_questions):question_number+1], 
                                       item['answers'][max(0,question_number-max_num_questions):question_number+1]) 
                      ])
    if not last_question:
        text = '\n'.join(text.split('\n')[:-1]) + '\n'
    return text

In [40]:
def get_answer_with_entailment(model, statement_model, prompt, story, dialogue):
    #answers = generate_multiple_answers_with_dropout(model, prompt, num_replicas=15)
    answers, _ = generate_all_answers(model, prompt, num_replicas=10)
    statements = []
    
    valid_answers = set()
    repetitions = 0
    while not valid_answers and repetitions < 1:
        count_dict = {}
        for answer in answers:
            count_dict.setdefault(answer, 0)
            count_dict[answer] += 1
        
        answers = count_dict.keys()
        
        statement_prompts = []
        for answer in answers:
            statement_prompts.append(get_statement_prompt_from_dialogue_and_answer(dialogue, answer))
        
        statements = generate_answers_from_multiple_prompts(statement_model, statement_prompts)
        entailment_predictions = entailer.batch_entails([story] * len(statements), statements)
        for answer, answer_is_true in zip(answers, entailment_predictions):
            if answer_is_true:
                valid_answers.add(answer)
                
        repetitions += 1
            
    if not valid_answers:
        return "unknown"
        
    ranked_answers = [(answer, count_dict[answer]) for answer in valid_answers]
    return sorted(ranked_answers, key=lambda x: -x[1])[0][0]

In [60]:
index = 2
number = 1

small_text = get_text_from_data_item(dev_dict['data'][index], 
                                     max_num_questions=3,
                                     question_number=number,
                                     last_question=False)
prediction = get_answer_with_entailment(model, 
               statement_model,
               small_text,
               dev_dict['data'][index]["story"],
               get_discussion_from_text(small_text, last_n=2))

In [61]:
prediction

'I know what is inside the bag--a thermos with hot soup and a stainless-steel container'

In [64]:
get_statement_prompt_from_dialogue_and_answer(get_discussion_from_text(small_text, last_n=3), "Her mommy's")

"Discussion:\nQ:Q: Who is at the door?\nA: An elderly Chinese lady and a little boy\nQ: Is she carrying something?\nA: Her mommy's\n\nStatement:"

In [65]:
print(small_text)

In the text below two people are discussing a story.

Story:
My doorbell rings. On the step, I find the elderly Chinese lady, small and slight, holding the hand of a little boy. In her other hand, she holds a paper carrier bag. 

I know this lady. It is not her first visit. She is the boy's grandmother, and her daughter bought the house next door last October. 

Her daughter, Nicole, speaks fluent English. But she is now in Shanghai, and her parents are here with the little boy. Nicole has obviously told her mother that I am having heart surgery soon, so her mother has decided I need more nutrients. 

I know what is inside the bag--a thermos with hot soup and a stainless-steel container with rice, vegetables and either chicken, meat or shrimp, sometimes with a kind of pancake. This has become an almost-daily practice. 

Communication between us is somewhat affected by the fact that she doesn't speak English and all I can say in Chinese is hello. Once, she brought an iPad as well as the

In [51]:
from fuzzywuzzy import fuzz

def compute_accuracy_of_model(model):
    total_number_of_questions = 0
    correct_answers = 0
    wrong_predictions = []

    false_positives = []
    dlist = dev_list[:10]
    for index, text in tqdm(enumerate(dlist), total=len(dlist)):

        all_answers = get_all_answers(dev_dict, index)
        total_questions = len(all_answers)        
        
        for number in range(total_questions):
            small_text = get_text_from_data_item(dev_dict['data'][index], 
                                                 max_num_questions=3,
                                                 question_number=number,
                                                 last_question=False)
            try:
                prediction = get_answer_with_entailment(model, 
                               statement_model,
                               small_text,
                               dev_dict['data'][index]["story"],
                               get_discussion_from_text(small_text, last_n=2))
            
            except:
                continue
            
            if not prediction or prediction == "unknown":
                continue
        
            prediction = prediction.replace('.', '').replace('"', '')
            it_was_answered = False
            for label in all_answers[number]:
                label = label.replace('.', '').replace('"', '')

                if prediction.lower() != 'unknown' and label.lower() == 'unknown':
                    false_positives.append(prediction)
                
                if prediction.lower() == label.lower():
                    correct_answers += 1
                    it_was_answered = True
                    break
                    
                elif prediction.lower() in label.lower():
                    correct_answers += 1
                    it_was_answered = True
                    break
                    
                elif label.lower() in prediction.lower():
                    correct_answers += 1
                    it_was_answered = True
                    break
                    
                elif entailer.entails(label, prediction):
                    correct_answers += 1
                    it_was_answered = True
                    break
                    
                elif entailer.entails(label, prediction):
                    correct_answers += 1
                    it_was_answered = True
                    break
                    
                else:
                    wrong_predictions.append({
                        'index': index,
                        'number' : number,
                        'label': label, 
                        'prediction': prediction})
                    
            total_number_of_questions += 1

    return correct_answers/total_number_of_questions, wrong_predictions, false_positives

In [52]:
compute_accuracy_of_model(model)

100%|███████████████████████████████████████████| 10/10 [08:35<00:00, 51.56s/it]


(0.6837606837606838,
 [{'index': 0,
   'number': 6,
   'label': 'paint herself like them',
   'prediction': '()'},
  {'index': 0,
   'number': 6,
   'label': 'she painted herself',
   'prediction': '()'},
  {'index': 1,
   'number': 1,
   'label': 'the bottle',
   'prediction': 'It was hard and clear'},
  {'index': 1,
   'number': 1,
   'label': 'a bottle',
   'prediction': 'It was hard and clear'},
  {'index': 1, 'number': 2, 'label': 'Asta', 'prediction': 'a friend'},
  {'index': 1, 'number': 2, 'label': 'Sharkie', 'prediction': 'a friend'},
  {'index': 1, 'number': 2, 'label': 'Asta', 'prediction': 'a friend'},
  {'index': 1,
   'number': 10,
   'label': 'unknown',
   'prediction': 'So did they open the note'},
  {'index': 2,
   'number': 1,
   'label': 'Yes',
   'prediction': 'I know what is inside the bag--a thermos with hot soup and a stainless-steel container'},
  {'index': 2,
   'number': 1,
   'label': 'a little boy',
   'prediction': 'I know what is inside the bag--a thermos 

# first 10, last 3 questions:

    without replicas: 0.770

    15 replicas, 5 repetitions, single question: 0.787

    15 replicas, 5 repetitions, using whole dialogue: 0.777
    
    Using unique beam search (deterministic, only one repetition):
    - 10 replicas, 1 repetition, last 3 question for statements: 0.628099173553719
    - 10 replicas, 1 repetition, last 2 question for statements: 0.6837606837606838