In [1]:
import json
import torch
import torch.nn as nn
import random
from tqdm import tqdm
from transformers import GPT2LMHeadModel, GPT2Tokenizer

In [2]:
# GPT-2 tokenizer shared by the training, generation and statement cells below.
tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model_name_or_path='gpt2')

In [3]:
# Raw CoQA dev set; `with` closes the file instead of leaking the handle.
# NOTE(review): dev_dict is not used by any later cell visible here.
with open('../data/coqa-dev-v1.0.json', encoding='utf8') as fh:
    dev_dict = json.load(fh)

In [4]:
# Training texts (one string per example, tokenized in the cell below).
with open('../data/qa_train_list.json', encoding='utf8') as fh:
    train_list = json.load(fh)

In [5]:
# Dev texts in the same format as train_list; `with` closes the file handle.
with open('../data/qa_dev_list.json', encoding='utf8') as fh:
    dev_list = json.load(fh)

In [6]:
def chunks(lst, n):
    """Yield consecutive slices of ``lst`` of length ``n`` (the last may be shorter)."""
    for i in range(0, len(lst), n):
        yield lst[i:i + n]


def batchify(data, n):
    """Group equal-length token tensors into stacked batches of at most ``n`` items.

    Sequences are bucketed by length so a batch never needs padding.

    Args:
        data: iterable of token tensors; assumes each is shaped
            ``(1, seq_len)`` as returned by ``tokenizer.encode(..., return_tensors='pt')``.
        n: maximum batch size.

    Returns:
        List of ``(batch, seq_len)`` tensors. Buckets appear in order of the
        first occurrence of each length in ``data``.
    """
    by_length = {}
    for item in data:
        # setdefault replaces a bare try/except that could hide unrelated errors.
        by_length.setdefault(item.shape[1], []).append(item)

    batches = []
    for vectors in by_length.values():
        for chunk in chunks(vectors, n):
            # item[0] drops the leading singleton dimension before stacking.
            batches.append(torch.stack([item[0] for item in chunk]))
    return batches

In [7]:
# Tokenize every training text, clipping anything longer than 1024 tokens
# (the model's maximum, per the tokenizer warning in this cell's output).
_limit = 1024
data = []
total_trimmed = 0
for text in train_list:
    encoded = tokenizer.encode(text, return_tensors='pt')
    too_long = encoded.shape[1] > _limit
    if too_long:
        encoded = encoded[:, :_limit]
    total_trimmed += int(too_long)
    data.append(encoded)
print(f'Trimmed {total_trimmed} out of {len(train_list)}')

Token indices sequence length is longer than the specified maximum sequence length for this model (1088 > 1024). Running this sequence through the model will result in indexing errors


Trimmed 52 out of 7199


In [7]:
# Batch size 1: each training example is its own batch.
train_batches = batchify(data, n=1)

In [3]:
# GPT-2 base model on the GPU, fine-tuned in place by the training loop below.
model = GPT2LMHeadModel.from_pretrained('gpt2').cuda()
# NOTE(review): `criterion` is accepted but never used by train(); the loss is
# taken from the model output instead. Kept only so existing calls still work.
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=2e-5)

In [6]:
def train(train_model, batches, optimizer, criterion):
    """Run one epoch of language-model fine-tuning and return the mean batch loss.

    Args:
        train_model: model called as ``train_model(ids, labels=ids)``; element 0
            of its output is used as the loss.
        batches: list of token-id tensors, one batch each.
        optimizer: optimizer over ``train_model``'s parameters.
        criterion: unused; kept for call compatibility (the loss comes from
            the model output).

    Returns:
        Mean loss over ``batches``, or 0.0 when ``batches`` is empty.
    """
    if len(batches) == 0:
        return 0.
    # Switch the model actually being trained (not a global) to train mode, once.
    train_model.train()
    # Follow the model's device rather than hard-coding CUDA.
    device = next(train_model.parameters()).device
    total_loss = 0.
    for batch in tqdm(batches, total=len(batches)):
        inputs = batch.to(device)
        optimizer.zero_grad()
        loss = train_model(inputs, labels=inputs)[0]
        loss.backward()
        torch.nn.utils.clip_grad_norm_(train_model.parameters(), 0.5)
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(batches)

def test(test_model, batches):
    total_loss = 0.
    for i, batch in tqdm(enumerate(batches), total=len(batches)):
        test_model.eval()
        inputs = batch
        loss = test_model(inputs, labels=inputs)[0]
        total_loss += loss.item()

    return total_loss / len(batches)

In [10]:
from torch.optim.lr_scheduler import StepLR

# Seed the RNGs behind batch order (random) and dropout (torch) so a re-run
# reproduces the same epochs.
SEED = 0
NUM_EPOCHS = 10
random.seed(SEED)
torch.manual_seed(SEED)

# Multiply the learning rate by 0.8 every 2 epochs.
scheduler = StepLR(optimizer, step_size=2, gamma=0.8)
for epoch in range(NUM_EPOCHS):
    # Fresh batch order at the start of every epoch, including the first.
    random.shuffle(train_batches)
    loss = train(model, train_batches, optimizer, criterion)
    print('Epoch:', epoch, 'Loss:', loss)
    # One checkpoint per epoch: save_small0, save_small1, ...
    torch.save({'epoch': epoch,
                'model_state_dict': model.state_dict()},
               'save_small' + str(epoch))
    scheduler.step()

100%|██████████| 7199/7199 [13:32<00:00,  8.86it/s]


Epoch: 0 Loss: 2.371021380919949


  0%|          | 21/7199 [00:02<14:00,  8.54it/s]


KeyboardInterrupt: 

## Testing generated answers with an entailment model

In [96]:
# Reload the epoch-6 fine-tuned weights into `model` and keep it on the GPU.
device = "cuda"
checkpoint = torch.load('save_small%d' % 6)
model.load_state_dict(checkpoint['model_state_dict'])
_ = model.to(device)

In [97]:
def inference(model, tokens, length):
    """Extend ``tokens`` with at most ``length`` new tokens via ``model.generate``.

    Returns the generated id tensor (prompt ids included). The tokenizer's EOS
    id is used as the padding id.
    """
    target_length = tokens.shape[1] + length
    return model.generate(
        tokens.to(device),
        pad_token_id=tokenizer.eos_token_id,
        max_length=target_length,
    )

In [201]:
def generate_answer(model, prompt, length=10):
    """Continue ``prompt`` until a newline appears and return the answer on that line.

    Generation runs in rounds of ``length`` tokens, each round continuing from
    the previous output. It stops at the first newline, when a round adds no
    text (e.g. the model emitted EOS), or when the next round would exceed the
    model's context window.

    Args:
        model: causal LM, used through ``inference()``.
        prompt: text to continue.
        length: new tokens requested per round.

    Returns:
        The continuation up to its first newline (or the whole continuation if
        none was produced), with everything up to the last ``':'`` removed and
        surrounding whitespace stripped.
    """
    model.eval()
    # Context window from the model config (GPT-2: n_positions), 1024 otherwise.
    max_positions = getattr(getattr(model, 'config', None), 'n_positions', 1024)
    with torch.no_grad():
        tokens = tokenizer.encode(prompt, return_tensors='pt')
        text = prompt
        start = len(prompt)
        while "\n" not in text[start:]:
            if tokens.shape[1] + length > max_positions:
                break
            output = inference(model, tokens, length)
            decoded = tokenizer.decode(output[0], skip_special_tokens=True)
            # Without new text every further round would be identical: stop
            # instead of looping forever.
            if len(decoded) <= len(text):
                break
            text += decoded[len(text):]
            tokens = output

    end = text.find('\n', start)
    if end == -1:  # no newline generated: keep the whole continuation
        end = len(text)
    return text[start:end].split(':')[-1].strip()

In [5]:
# Monte-Carlo dropout sampling: dropout stays active during generation, so
# copies of the same prompt in one batch can decode to different answers.
# TODO: stop at the first newline (variable-length generation, as in
# generate_answer()) instead of always spending max_new_tokens.

def generate_multiple_answers_with_dropout(model, prompt, num_replicas=25, max_new_tokens=50):
    """Sample ``num_replicas`` answers to ``prompt`` with dropout enabled.

    Args:
        model: causal LM with a ``generate`` method.
        prompt: text ending where the answer should begin (e.g. ``"... A:"``).
        num_replicas: number of copies of the prompt generated in one batch.
        max_new_tokens: token budget per answer.

    Returns:
        List of ``num_replicas`` answers. Each is the continuation up to its
        first newline, with everything up to the last ``':'`` removed and
        whitespace stripped. An empty list if prompt + budget would exceed the
        model's context window.
    """
    tokens = tokenizer.encode(prompt, return_tensors='pt')
    tokens_length = tokens.shape[1]
    # Context window from the model config (GPT-2: n_positions), 1024 otherwise.
    max_positions = getattr(getattr(model, 'config', None), 'n_positions', 1024)
    if tokens_length + max_new_tokens > max_positions:
        return []

    device = next(model.parameters()).device
    was_training = model.training
    model.train()  # dropout on: the only source of variation between replicas
    try:
        with torch.no_grad():
            output = model.generate(
                tokens.repeat(num_replicas, 1).to(device),
                max_length=tokens_length + max_new_tokens,
                pad_token_id=tokenizer.eos_token_id,
            )
    finally:
        # Restore the caller's mode rather than leaving dropout switched on.
        model.train(was_training)

    # Start exactly at len(prompt): an immediate "\n" then yields an empty
    # answer instead of the following line; .strip() drops the leading space.
    start = len(prompt)
    answers = []
    for sequence in output:
        text = tokenizer.decode(sequence, skip_special_tokens=True)
        end = text.find('\n', start)
        if end == -1:  # budget ran out before a newline: keep everything
            end = len(text)
        answers.append(text[start:end].split(':')[-1].strip())
    return answers

In [48]:
def generate_multiple_answers(model, prompt, num_replicas, length=100):
    """Return ``num_replicas`` beam-search continuations of ``prompt``.

    Args:
        model: causal LM with a ``generate`` method.
        prompt: text to continue.
        num_replicas: beam width and number of returned sequences.
        length: maximum number of new tokens.

    Returns:
        List of strings. Each is the text after ``prompt`` up to the first
        newline or ``'Q:'``, whichever comes first, with surrounding whitespace
        stripped. An empty list if prompt + ``length`` would exceed the model's
        context window.
    """
    device = next(model.parameters()).device
    tokens = tokenizer(prompt, return_tensors='pt').to(device)
    tokens_length = tokens.input_ids.shape[1]
    # Context window from the model config (GPT-2: n_positions), 1024 otherwise.
    max_positions = getattr(getattr(model, 'config', None), 'n_positions', 1024)
    if tokens_length + length > max_positions:
        return []
    # Beam search should be deterministic: make sure dropout is off even if the
    # model was left in train mode.
    model.eval()
    generated_ids = model.generate(
        **tokens,
        num_beams=num_replicas,
        num_return_sequences=num_replicas,
        max_length=tokens_length + length,
        no_repeat_ngram_size=5,
        pad_token_id=tokenizer.eos_token_id,
    )
    generated_sentences = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    start = len(prompt)
    sentences = []
    for text in generated_sentences:
        # str.find returns -1 for an absent marker; ignore those so min() does
        # not pick -1 and chop the last character off the text.
        stops = [pos for pos in (text.find('\n', start), text.find('Q:', start)) if pos != -1]
        end = min(stops, default=len(text))
        sentences.append(text[start:end].strip())
    return sentences

In [195]:
# Second GPT-2 copy, used by get_statement_from_question_and_answer() to turn
# a question/answer pair into a statement.
checkpoint = torch.load('save_statement' + '2')
statement_model = GPT2LMHeadModel.from_pretrained('gpt2').cuda()
statement_model.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [9]:
# Per `story`, the expected answer is "Sir George"; "St. George" is the distractor.
question = ("How is albert's "
            "ship called?")

In [10]:
# Toy story with two similarly named ships (Albert's own vs. his father's).
story = (
    "Albert says: 'My father's ship is called st. George, "
    "mine is Sir George'."
)

In [208]:
%time
prompt = f"""
In the text below two people are discussing a story.

Story:
{story}

Discussion:
Q: {question}
A: 
""".strip()

answers = generate_multiple_answers_with_dropout(model, prompt, num_replicas=15)
#answers = generate_multiple_answers(model, prompt, num_replicas=50)

CPU times: user 2 µs, sys: 0 ns, total: 2 µs
Wall time: 4.77 µs


In [207]:
def get_statement_from_question_and_answer(question, answer):
    """Rewrite a question/answer pair as a statement using ``statement_model``.

    The prompt is::

        Discussion:
        Q: <question>
        A: <answer>

        Statement:

    followed by a newline. The first generated line is returned via
    ``generate_answer``.
    """
    # Lines are joined explicitly so source-code indentation cannot leak into
    # the prompt (an indented triple-quoted string prefixes lines with spaces).
    text = "\n".join([
        "Discussion:",
        f"Q: {question}",
        f"A: {answer}",
        "",
        "Statement:",
    ])
    return generate_answer(statement_model, text + '\n')

In [24]:
import torch

from typing import Dict
from transformers import AutoTokenizer, AutoModelForSequenceClassification

_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Entailer:
    """Natural-language-inference wrapper: does a premise entail a hypothesis?

    Wraps a Hugging Face sequence-classification NLI checkpoint, by default
    ``MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli``.
    """

    # Label order assumed when the checkpoint config does not name exactly
    # these three classes.
    _LABEL_NAMES = ("entailment", "neutral", "contradiction")

    def __init__(self, model_name="MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli"):
        self._tokenizer = AutoTokenizer.from_pretrained(model_name)
        self._model = AutoModelForSequenceClassification.from_pretrained(model_name).to(
            _device
        )
        # Inference only: keep dropout disabled.
        self._model.eval()

    def _label_names(self, num_labels):
        """Return label names for ``num_labels`` logits, from the model config when usable."""
        config = getattr(self._model, "config", None)
        id2label = getattr(config, "id2label", None) or {}
        names = [str(id2label.get(i, "")).lower() for i in range(num_labels)]
        if sorted(names) == sorted(self._LABEL_NAMES):
            return names
        return list(self._LABEL_NAMES[:num_labels])

    def get_relation(self, premise: str, hypothesis: str) -> Dict[str, float]:
        """Return softmax probabilities of each NLI label for (premise, hypothesis).

        Returns:
            Dict mapping "entailment", "neutral" and "contradiction" to probabilities.
        """
        encodings = self._tokenizer(
            premise, hypothesis, truncation=True, return_tensors="pt"
        )
        device = next(self._model.parameters()).device
        # Forward every tokenizer output (e.g. attention mask), not only input_ids.
        inputs = {key: value.to(device) for key, value in encodings.items()}
        # Scoring needs no autograd graph.
        with torch.no_grad():
            output = self._model(**inputs)
        probs = torch.softmax(output["logits"][0], -1).tolist()
        names = self._label_names(len(probs))
        return {name: float(prob) for prob, name in zip(probs, names)}

    def entails(self, premise: str, hypothesis: str, threshold=0.8) -> bool:
        """Return True when P(entailment) exceeds ``threshold``.

        If the pair is judged neutral with probability above ``threshold``, the
        check is retried once on a premise rewritten by
        ``_add_presuppositions_to_premise``.
        """
        prediction = self.get_relation(premise, hypothesis)
        if prediction["entailment"] > threshold:
            return True

        if prediction["neutral"] > threshold:
            premise = self._add_presuppositions_to_premise(premise)
            prediction = self.get_relation(premise, hypothesis)

        return prediction["entailment"] > threshold

    def _add_presuppositions_to_premise(self, premise):
        """Make explicit that a "user says/asks:" line is addressed to this bot."""
        premise = premise.replace("user says:", "user says to this bot:")
        premise = premise.replace("user asks:", "user asks to this bot:")
        return premise

In [25]:
# Load the NLI model once; it is reused for every statement check below.
entailer = Entailer()

In [200]:
%time
statements = []
for answer in set(answers):
    statements.append(get_statement_from_question_and_answer(question, answer))

CPU times: user 4 µs, sys: 0 ns, total: 4 µs
Wall time: 7.63 µs


In [199]:
# Show the inputs, then each generated statement with whether the story entails it.
print(story, question, sep='\n', end='\n\n')

for statement in statements:
    verdict = entailer.entails(story, statement)
    print(f'{statement} {verdict}')

Albert says: 'My father's ship is called st. George, mine is Sir George'.
How is albert's ship called?

Albert's ship is called St. George '. False
What is the name of the ship that is named St. George? False
Albert's ship is called St. George. False
Albert's ship is called Sir George '. True
Albert's ship is called George. False
Albert's ship is called the st. George. False
