In [1]:
import json
import torch
import torch.nn as nn
import random
from tqdm import tqdm
from transformers import BartForConditionalGeneration, BartTokenizer

In [2]:
import ipywidgets as widgets

In [3]:
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

In [22]:
train_dict = json.load(open('../data/qa_train_BART.json', encoding='utf8'))
dev_dict = json.load(open('../data/qa_dev_BART.json', encoding='utf8'))

In [5]:
def chunks(lst, n):
    for i in range(0, len(lst), n):
        yield lst[i:i + n]

def batchify(data, n):
    len_dict = {}
    for item in data:
        in_length = item[0].shape[1]
        out_length = item[1].shape[1]
        length = (in_length, out_length)
        try:
            len_dict[length].append(item)
        except:
            len_dict[length] = [item]

    batch_chunks = []
    for k in len_dict.keys():
        vectors = len_dict[k]
        batch_chunks += chunks(vectors, n)

    batches = []
    for chunk in batch_chunks:
        inputs = torch.stack([item[0][0] for item in chunk])
        labels = torch.stack([item[1][0] for item in chunk])
        batches.append((inputs, labels))

    return batches

In [6]:
_limit = 1024
data = []
total_skipped = 0
for item in train_dict:
    input_tokens = tokenizer.encode(item['inputs'], return_tensors='pt')
    output_tokens = tokenizer.encode(item['labels'], return_tensors='pt')
    if input_tokens.shape[1] > _limit:
        input_tokens = input_t okens[:, :_limit]
        total_skipped += 1
        continue
    if output_tokens.shape[1] > _limit:
        output_tokens = output_tokens[:, :_limit]
        total_skipped += 1
        continue
    data.append((input_tokens, output_tokens))
print(f'Skipped {total_skipped} out of {len(train_dict)}')

Token indices sequence length is longer than the specified maximum sequence length for this model (1026 > 1024). Running this sequence through the model will result in indexing errors


Skipped 172 out of 101448


In [7]:
random.shuffle(data)

In [8]:
train_batches = batchify(data, 2)

In [9]:
model.cuda()
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)

In [10]:
def train(train_model, batches, optimizer, criterion):
    total_loss = 0.
    for i, batch in tqdm(enumerate(batches), total=len(batches)):
        model.train()
        inputs = batch[0]
        labels = batch[1]
        optimizer.zero_grad()
        loss = train_model(inputs.cuda(), labels=labels.cuda())[0]
        loss.backward()
        torch.nn.utils.clip_grad_norm_(train_model.parameters(), 0.5)
        optimizer.step()
        total_loss += loss.item()

    return total_loss / len(batches)

def test(test_model, batches):
    total_loss = 0.
    for i, batch in tqdm(enumerate(batches), total=len(batches)):
        test_model.eval()
        inputs = batch
        loss = test_model(inputs, labels=inputs)[0]
        total_loss += loss.item()

    return total_loss / len(batches)

In [11]:
from torch.optim.lr_scheduler import StepLR

random.shuffle(train_batches)
scheduler = StepLR(optimizer, step_size=2, gamma=0.8)
for epoch in range(10):
    random.shuffle(train_batches)
    loss = train(model, train_batches, optimizer, criterion)
    print('Epoch:', epoch, 'Loss:', loss)
    torch.save({'epoch': epoch,
                'model_state_dict': model.state_dict()},
                'save_small' + str(epoch))
    scheduler.step()

100%|██████████| 52893/52893 [1:17:30<00:00, 11.37it/s]


Epoch: 0 Loss: 1.405121539219829


100%|██████████| 52893/52893 [1:17:37<00:00, 11.36it/s]


Epoch: 1 Loss: 1.1591838920990236


100%|██████████| 52893/52893 [1:17:32<00:00, 11.37it/s]


Epoch: 2 Loss: 0.9497297893278275


100%|██████████| 52893/52893 [1:17:27<00:00, 11.38it/s]


Epoch: 3 Loss: 0.8411330554772071


100%|██████████| 52893/52893 [1:14:10<00:00, 11.88it/s]


Epoch: 4 Loss: 0.6874568890077913


100%|██████████| 52893/52893 [1:14:22<00:00, 11.85it/s]


Epoch: 5 Loss: 0.5952435001647443


100%|██████████| 52893/52893 [1:13:29<00:00, 12.00it/s]


Epoch: 6 Loss: 0.4727807852065438


100%|██████████| 52893/52893 [1:13:24<00:00, 12.01it/s]


Epoch: 7 Loss: 0.3982763123362173


100%|██████████| 52893/52893 [1:13:15<00:00, 12.03it/s]


Epoch: 8 Loss: 0.30694361956989796


100%|██████████| 52893/52893 [1:13:09<00:00, 12.05it/s]


Epoch: 9 Loss: 0.2541343873214737


In [12]:
_question_prompt = '\nQ: '
_answer_prompt = '\nA: '
    
def get_text_up_to_question_number(text, number):
    pos = text.find(_answer_prompt)
    for _ in range(number):
        pos = text.find(_answer_prompt, pos + 1)
    return text[0:pos + 1]
    
def get_answers_number(text):
    return text.count(_answer_prompt)

def get_answer_number(text, number):
    pos = text.find(_answer_prompt)
    for _ in range(number):
        pos = text.find(_answer_prompt, pos + 1)
    end = text.find('\n', pos + len(_answer_prompt))
    return text[pos + len(_answer_prompt):end]

def get_question_number(text, number):
    pos = text.find(_question_prompt)
    for _ in range(number):
        pos = text.find(_question_prompt, pos + 1)
    end = text.find('\n', pos + len(_question_prompt))
    return text[pos + len(_question_prompt):end]

def get_all_answers(dev_dict, dev_index):
    answers = [[item['input_text'] for item in dev_dict['data'][dev_index]['answers']]]
    answers += [[item['input_text'] for item in dev_dict['data'][dev_index]['additional_answers'][str(index)]] for index in range(3)]
    return [list(set([answers[j][i] for j in range(len(answers))])) for i in range(len(answers[0]))]

In [13]:
def generate_answer_number(model, text, number):
    prompt = get_text_up_to_question_number(text, number)
    tokens = tokenizer.encode(prompt, return_tensors='pt')
    _length = 20
    tokens_length = tokens.shape[1]
    output = model.generate(
             tokens,
             max_length=tokens_length + _length,
             temperature=0,
             pad_token_id=50256
    )
    output = tokenizer.decode(output[0], skip_special_tokens=True)
    return get_answer_number(output, number)

In [14]:
def compute_accuracy_of_model_with_all_answers(model):
    total_number_of_questions = 0
    correct_answers = 0
    wrong_predictions = []

    dlist = dev_list[:10]
    for index, text in tqdm(enumerate(dlist), total=len(dlist)):
        total_questions = get_answers_number(text)
        all_answers = get_all_answers(dev_dict, index)
        for number in range(total_questions):
            prediction = generate_answer_number(model, text, number)
            prediction = prediction.replace('.', '').replace('"', '')
            it_was_answered = False
            for label in all_answers[number]:
                label = label.replace('.', '').replace('"', '')
                if prediction.lower() == label.lower():
                    correct_answers += 1
                    it_was_answered = True
                    break
                elif prediction.lower() in label.lower():
                    correct_answers += 1
                    it_was_answered = True
                    break
                elif label.lower() in prediction.lower():
                    correct_answers += 1
                    it_was_answered = True
                    break
                else:
                    wrong_predictions.append({'label': label, 'prediction': prediction})
            if not it_was_answered:
                print('No Answer for number', number)
            total_number_of_questions += 1

    return correct_answers/total_number_of_questions, wrong_predictions

In [15]:
def get_text_from_data_item(item, question_number=-1, last_question=True):
    text = ''
    text += 'Story:\n' + item['story'] + '\n\n'
    text += 'Discussion:\n'
    text += '\n'.join(['Q: ' + q['input_text'] 
                       + '\nA: ' + a['input_text'] 
                       for q, a in zip(item['questions'][:question_number+1], 
                                       item['answers'][:question_number+1]) 
                      ])
    if not last_question:
        text = '\n'.join(text.split('\n')[:-1]) + '\n'
    return text

In [16]:
def generate_answer(model, prompt):
    tokens = tokenizer.encode(prompt, return_tensors='pt')
    _length = 50
    tokens_length = tokens.shape[1]
    if tokens_length + _length > 1024:
        return ''
    output = model.generate(
             tokens.cuda(),
             max_length=tokens_length + _length,
             #temperature=0,
             pad_token_id=50256
    )
    output = tokenizer.decode(output[0], skip_special_tokens=True)
    return output

In [17]:
def generate_multiple_answers(model, prompt, num_replicas=6):
    tokens = tokenizer.encode(prompt, return_tensors='pt')
    _length = 50
    tokens_length = tokens.shape[1]
    if tokens_length + _length > 1024:
        return ''
    
    outputs = []
    for _ in range(num_replicas):
        output = model.generate(
             tokens.cuda(),
             max_length=tokens_length + _length,
             pad_token_id=50256
        )
        output = tokenizer.decode(output[0], skip_special_tokens=True)
        offset = len(prompt)
        start = offset + 1
        end = output.find('\n', start)
        outputs.append(output[start:end].split(':')[-1].strip())
        
    return list(set(outputs))

In [18]:
checkpoint = torch.load('save_small' + str(1))
model.load_state_dict(checkpoint['model_state_dict'])
_ = model.train()

In [19]:
doc=0
number = 6
generate_answer(model, dev_dict[doc + number]['inputs'])

'she was a can'

In [27]:
def compute_accuracy_of_model(model):
    total_number_of_questions = 0
    correct_answers = 0
    wrong_predictions = []

    false_positives = []
    dlist = dev_dict[:1000]
    for item in tqdm(dlist, total=len(dlist)):
        small_text = item['inputs']
        prediction = generate_answer(model, small_text)
        if not prediction:
            print('NO PREDICTION!!')
            continue
        prediction = prediction.replace('.', '').replace('"', '')
        it_was_answered = False
        for label in item['labels']:
            label = label.replace('.', '').replace('"', '')

            if prediction.lower() != 'unknown' and label.lower() == 'unknown':
                false_positives.append(prediction)

            if prediction.lower() == label.lower():
                correct_answers += 1
                it_was_answered = True
                break
            elif prediction.lower() in label.lower():
                correct_answers += 1
                it_was_answered = True
                break
            elif label.lower() in prediction.lower():
                correct_answers += 1
                it_was_answered = True
                break
            else:
                wrong_predictions.append({'label': label, 'prediction': prediction})
        total_number_of_questions += 1

    return correct_answers/total_number_of_questions, wrong_predictions, false_positives

In [28]:
false_positives = None
wrong_answers = None

for index in range(10):
    checkpoint = torch.load('save_small' + str(index))
    model.load_state_dict(checkpoint['model_state_dict'])
    accuracy, wrong_answers, false_positives = compute_accuracy_of_model(model.cuda())
    print('Epoch:', index, ' with accuracy:', accuracy)

100%|██████████| 1000/1000 [00:49<00:00, 20.24it/s]


Epoch: 0  with accuracy: 0.545


100%|██████████| 1000/1000 [00:51<00:00, 19.55it/s]


Epoch: 1  with accuracy: 0.569


100%|██████████| 1000/1000 [00:54<00:00, 18.44it/s]


Epoch: 2  with accuracy: 0.617


100%|██████████| 1000/1000 [00:52<00:00, 19.04it/s]


Epoch: 3  with accuracy: 0.577


100%|██████████| 1000/1000 [00:53<00:00, 18.56it/s]


Epoch: 4  with accuracy: 0.614


100%|██████████| 1000/1000 [00:57<00:00, 17.28it/s]


Epoch: 5  with accuracy: 0.634


100%|██████████| 1000/1000 [00:58<00:00, 17.05it/s]


Epoch: 6  with accuracy: 0.623


100%|██████████| 1000/1000 [00:58<00:00, 17.12it/s]


Epoch: 7  with accuracy: 0.639


100%|██████████| 1000/1000 [00:59<00:00, 16.70it/s]


Epoch: 8  with accuracy: 0.646


100%|██████████| 1000/1000 [01:00<00:00, 16.61it/s]

Epoch: 9  with accuracy: 0.626





In [None]:
wrong_answers

In [None]:
def clean_answer(answer):
    answer = answer.lower()
    answer = answer.strip()
    answer = answer.replace('.', '')
    return answer

In [None]:
def compute_accuracy_of_model_only_y_or_n(model):
    total_number_of_questions = 0
    correct_answers = 0
    wrong_predictions = []

    false_positives = []
    dlist = dev_list[:10]
    for index, text in tqdm(enumerate(dlist), total=len(dlist)):
        total_questions = get_answers_number(text)
        all_answers = get_all_answers(dev_dict, index)
        for number in range(total_questions):
            small_text = get_text_from_data_item(dev_dict['data'][index], 
                                                 max_num_questions=5,
                                                 question_number=number,
                                                 last_question=False)
            prediction = generate_answer(model, small_text)
            if clean_answer(prediction) not in ['yes', 'no']:
                continue
            prediction = prediction.replace('.', '').replace('"', '')
            it_was_answered = False
            for label in all_answers[number]:
                label = label.replace('.', '').replace('"', '')

                if prediction.lower() != 'unknown' and label.lower() == 'unknown':
                    false_positives.append(prediction)
                
                if prediction.lower() == label.lower():
                    correct_answers += 1
                    it_was_answered = True
                    break
                elif prediction.lower() in label.lower():
                    correct_answers += 1
                    it_was_answered = True
                    break
                elif label.lower() in prediction.lower():
                    correct_answers += 1
                    it_was_answered = True
                    break
                else:
                    wrong_predictions.append({'label': label, 'prediction': prediction})
            total_number_of_questions += 1

    return correct_answers/total_number_of_questions, wrong_predictions, false_positives

In [None]:
false_positives = None
wrong_answers = None

for index in range(10):
    checkpoint = torch.load('save_small' + str(index))
    model.load_state_dict(checkpoint['model_state_dict'])
    accuracy, wrong_answers, false_positives = compute_accuracy_of_model_only_y_or_n(model.cuda())
    print('Epoch:', index, ' with accuracy:', accuracy)

In [None]:
false_positives

In [None]:
wrong_answers

In [None]:
for index in range(2, 3):
    checkpoint = torch.load('save_superfine' + str(index))
    model.load_state_dict(checkpoint['model_state_dict'])
    accuracy, _ = compute_accuracy_of_model(model.cuda())
    print('Epoch:', index, ' with accuracy:', accuracy)

In [None]:
compute_accuracy_of_model(model)

In [None]:
def compute_prediction_list_from_model(model, limit=None):
    prediction_list = []
    dlist = dev_list[:limit]
    for index, text in tqdm(enumerate(dlist), total=len(dlist)):
        total_questions = get_answers_number(text)
        for number in range(total_questions):
            small_text = get_text_from_data_item(dev_dict['data'][index], 
                                     max_num_questions=5,
                                     question_number=number,
                                     last_question=False)
            prediction = generate_answer(model, small_text)
            if not prediction:
                prediction = 'unknown'
            prediction_list.append({
                'id': dev_dict['data'][index]['id'],
                'turn_id': number,
                'answer': prediction,
            })

    return prediction_list

In [None]:
checkpoint = torch.load('save_small' + str(7))
model.load_state_dict(checkpoint['model_state_dict'])
prediction_list = compute_prediction_list_from_model(model)
json.dump(prediction_list, open('../data/predictions.json', 'w', encoding='utf8'))

In [None]:
model = GPT2LMHeadModel.from_pretrained('gpt2')
checkpoint = torch.load('save_squad2_' + str(1))
model.load_state_dict(checkpoint['model_state_dict'])
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
for epoch in range(0, 6):
    random.shuffle(train_batches)
    loss = train(model, train_batches, optimizer, criterion)
    accuracy, _ = compute_accuracy_of_model(model)
    print('Epoch:', epoch, 'Loss:', loss)
    print('Epoch:', epoch, ' with accuracy:', accuracy)
    torch.save({'epoch': epoch,
                'model_state_dict': model.state_dict()},
                'save_coqa_' + str(epoch))

In [None]:
model = GPT2LMHeadModel.from_pretrained('gpt2')
checkpoint = torch.load('save_superfine'+str(0))
model.load_state_dict(checkpoint['model_state_dict'])
accuracy, _ = compute_accuracy_of_model(model)

In [None]:
accuracy

In [None]:
model = GPT2LMHeadModel.from_pretrained('gpt2')
checkpoint = torch.load('save' + str(2))
model.load_state_dict(checkpoint['model_state_dict'])
accuracy, wrong_predictions = compute_accuracy_of_model(model)
wrong_predictions

In [None]:
accuracy

In [None]:
get_all_answers(dev_dict, 0)