In [1]:
import json
import torch
import torch.nn as nn
import random
from tqdm import tqdm
from transformers import BartForConditionalGeneration, BartTokenizer

In [2]:
import ipywidgets as widgets

In [3]:
# Load the pretrained 'facebook/bart-large' tokenizer and conditional-generation
# model. Both names are used as notebook globals by the cells below.
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")

In [4]:
# Read the training examples; the context manager closes the file handle
# (the previous `json.load(open(...))` left it open until garbage collection).
with open('../data/qa_train_BART.json', encoding='utf8') as train_file:
    train_dict = json.load(train_file)
# Shuffle in place, then keep only the first 10000 examples.
# NOTE(review): no random seed is set in this notebook, so the selected subset
# differs between runs.
random.shuffle(train_dict)
train_dict = train_dict[:10000]
# Development examples are loaded in full.
with open('../data/qa_dev_BART.json', encoding='utf8') as dev_file:
    dev_dict = json.load(dev_file)

In [5]:
def chunks(lst, n):
    """Lazily yield consecutive slices of ``lst`` that each hold at most ``n`` items.

    The slices are taken at offsets 0, n, 2n, ...; the final slice may be
    shorter than ``n``.
    """
    offsets = range(0, len(lst), n)
    yield from (lst[offset:offset + n] for offset in offsets)

def batchify(data, n):
    """Stack examples of identical sequence lengths into batches of at most ``n``.

    Args:
        data: iterable of ``(inputs, labels)`` pairs. Each element is indexed as
            ``x.shape[1]`` and ``x[0]``, i.e. it is treated as a tensor whose
            first dimension is a leading batch axis.
        n: maximum number of examples per batch; must be at least 1.

    Returns:
        A list of ``(inputs, labels)`` tuples. Within one tuple every example
        shares the same ``(input length, label length)`` pair, so the rows can
        be stacked with ``torch.stack`` without padding.

    Raises:
        ValueError: if ``n`` is smaller than 1.
    """
    if n < 1:
        # A negative n used to yield no batches at all, silently dropping data.
        raise ValueError('batch size n must be at least 1')

    # Group examples by (input length, label length); insertion order is kept.
    groups = {}
    for item in data:
        shape_key = (item[0].shape[1], item[1].shape[1])
        groups.setdefault(shape_key, []).append(item)

    batches = []
    for group in groups.values():
        # Slice each group into runs of at most n examples.
        for start in range(0, len(group), n):
            chunk = group[start:start + n]
            # item[k][0] drops the leading axis before stacking the rows.
            inputs = torch.stack([item[0][0] for item in chunk])
            labels = torch.stack([item[1][0] for item in chunk])
            batches.append((inputs, labels))

    return batches

In [6]:
# Maximum number of token ids allowed per encoded sequence.
_limit = 1024
data = []
total_skipped = 0
for item in train_dict:
    input_tokens = tokenizer.encode(item['inputs'], return_tensors='pt')
    output_tokens = tokenizer.encode(item['labels'], return_tensors='pt')
    # Examples whose input or label exceeds the limit are dropped, not
    # truncated. The earlier truncation assignments were dead code because
    # `continue` ran right after them.
    if input_tokens.shape[1] > _limit or output_tokens.shape[1] > _limit:
        total_skipped += 1
        continue
    data.append((input_tokens, output_tokens))
print(f'Skipped {total_skipped} out of {len(train_dict)}')

Token indices sequence length is longer than the specified maximum sequence length for this model (1090 > 1024). Running this sequence through the model will result in indexing errors


Skipped 16 out of 10000


In [7]:
# Shuffle the encoded (input, label) pairs in place before batching.
random.shuffle(data)

In [8]:
# Batch size 1: every batch holds a single example.
train_batches = batchify(data, 1)

In [9]:
# Move the model to the GPU and create the optimisation objects.
model.cuda()
# NOTE: `criterion` is passed to train(), but train() never calls it; the loss
# used there is the first element of the model's own output.
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)

In [10]:
def train(train_model, batches, optimizer, criterion):
    """Run one optimisation pass over ``batches`` and return the mean loss.

    Args:
        train_model: model invoked as ``train_model(inputs, labels=labels)``;
            the first element of its output is used as the loss.
        batches: sized sequence of batches, each indexed as
            ``batch[0]`` (inputs) and ``batch[1]`` (labels).
        optimizer: optimizer whose ``zero_grad``/``step`` drive the update.
        criterion: unused. The loss comes from the model output; the parameter
            is kept so existing calls keep working.

    Returns:
        The sum of the per-batch loss values divided by ``len(batches)``.

    Raises:
        ZeroDivisionError: if ``batches`` is empty.
    """
    # Put the model that is actually being trained into training mode. This
    # used to call the notebook-global `model.train()` instead of the argument.
    train_model.train()
    # Send tensors to wherever the model's parameters live.
    device = next(train_model.parameters()).device

    total_loss = 0.
    for batch in tqdm(batches, total=len(batches)):
        inputs = batch[0]
        labels = batch[1]
        optimizer.zero_grad()
        loss = train_model(inputs.to(device), labels=labels.to(device))[0]
        loss.backward()
        # Clip the gradient norm to 0.5 before the parameter update.
        torch.nn.utils.clip_grad_norm_(train_model.parameters(), 0.5)
        optimizer.step()
        total_loss += loss.item()

    return total_loss / len(batches)

def test(test_model, batches):
    """Return the mean loss of ``test_model`` over ``batches`` without training.

    Args:
        test_model: model invoked as ``test_model(inputs, labels=labels)``;
            the first element of its output is used as the loss.
        batches: sized sequence of batches. A batch may be an
            ``(inputs, labels)`` tuple/list, as produced by ``batchify``, or a
            single tensor that is used as both inputs and labels (the original
            calling convention).

    Returns:
        The sum of the per-batch loss values divided by ``len(batches)``.

    Raises:
        ZeroDivisionError: if ``batches`` is empty.
    """
    test_model.eval()
    device = next(test_model.parameters()).device

    total_loss = 0.
    # Evaluation needs no gradients; skipping the autograd graph saves memory.
    with torch.no_grad():
        for batch in tqdm(batches, total=len(batches)):
            if isinstance(batch, (tuple, list)):
                inputs, labels = batch[0], batch[1]
            else:
                inputs = labels = batch
            loss = test_model(inputs.to(device), labels=labels.to(device))[0]
            total_loss += loss.item()

    return total_loss / len(batches)

In [11]:
from torch.optim.lr_scheduler import StepLR

# Training driver: 10 epochs over `train_batches`, one checkpoint per epoch.
random.shuffle(train_batches)
# StepLR multiplies the learning rate by gamma (0.8) every step_size (2) calls
# to scheduler.step(); step() is called once per epoch below.
scheduler = StepLR(optimizer, step_size=2, gamma=0.8)
for epoch in range(10):
    # Re-shuffle the batch order at the start of every epoch.
    random.shuffle(train_batches)
    loss = train(model, train_batches, optimizer, criterion)
    print('Epoch:', epoch, 'Loss:', loss)
    # The checkpoint stores only the epoch index and the model weights;
    # optimizer and scheduler state are not saved.
    torch.save({'epoch': epoch,
                'model_state_dict': model.state_dict()},
                'save_small' + str(epoch))
    scheduler.step()

  2%|▏         | 237/9984 [00:40<27:57,  5.81it/s]


RuntimeError: CUDA out of memory. Tried to allocate 198.00 MiB (GPU 0; 10.76 GiB total capacity; 8.91 GiB already allocated; 196.00 MiB free; 9.36 GiB reserved in total by PyTorch)
Exception raised from malloc at /pytorch/c10/cuda/CUDACachingAllocator.cpp:272 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x42 (0x7f8d2a19d1e2 in /home/alce/src/tacred_rel/.env/lib/python3.6/site-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0x1e64b (0x7f8d2a3f364b in /home/alce/src/tacred_rel/.env/lib/python3.6/site-packages/torch/lib/libc10_cuda.so)
frame #2: <unknown function> + 0x1f464 (0x7f8d2a3f4464 in /home/alce/src/tacred_rel/.env/lib/python3.6/site-packages/torch/lib/libc10_cuda.so)
frame #3: <unknown function> + 0x1faa1 (0x7f8d2a3f4aa1 in /home/alce/src/tacred_rel/.env/lib/python3.6/site-packages/torch/lib/libc10_cuda.so)
frame #4: at::native::empty_cuda(c10::ArrayRef<long>, c10::TensorOptions const&, c10::optional<c10::MemoryFormat>) + 0x11e (0x7f8cdbd3890e in /home/alce/src/tacred_rel/.env/lib/python3.6/site-packages/torch/lib/libtorch_cuda.so)
frame #5: <unknown function> + 0xf33949 (0x7f8cda172949 in /home/alce/src/tacred_rel/.env/lib/python3.6/site-packages/torch/lib/libtorch_cuda.so)
frame #6: <unknown function> + 0xf4d777 (0x7f8cda18c777 in /home/alce/src/tacred_rel/.env/lib/python3.6/site-packages/torch/lib/libtorch_cuda.so)
frame #7: <unknown function> + 0x10e9c7d (0x7f8d14f28c7d in /home/alce/src/tacred_rel/.env/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #8: <unknown function> + 0x10e9f97 (0x7f8d14f28f97 in /home/alce/src/tacred_rel/.env/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #9: at::empty(c10::ArrayRef<long>, c10::TensorOptions const&, c10::optional<c10::MemoryFormat>) + 0xfa (0x7f8d15033a1a in /home/alce/src/tacred_rel/.env/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #10: at::native::mm_cuda(at::Tensor const&, at::Tensor const&) + 0x6c (0x7f8cdb227ffc in /home/alce/src/tacred_rel/.env/lib/python3.6/site-packages/torch/lib/libtorch_cuda.so)
frame #11: <unknown function> + 0xf22a20 (0x7f8cda161a20 in /home/alce/src/tacred_rel/.env/lib/python3.6/site-packages/torch/lib/libtorch_cuda.so)
frame #12: <unknown function> + 0xa56530 (0x7f8d14895530 in /home/alce/src/tacred_rel/.env/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #13: at::Tensor c10::Dispatcher::call<at::Tensor, at::Tensor const&, at::Tensor const&>(c10::TypedOperatorHandle<at::Tensor (at::Tensor const&, at::Tensor const&)> const&, at::Tensor const&, at::Tensor const&) const + 0xbc (0x7f8d1507d81c in /home/alce/src/tacred_rel/.env/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #14: at::mm(at::Tensor const&, at::Tensor const&) + 0x4b (0x7f8d14fce6ab in /home/alce/src/tacred_rel/.env/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #15: <unknown function> + 0x2ed0a2f (0x7f8d16d0fa2f in /home/alce/src/tacred_rel/.env/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #16: <unknown function> + 0xa56530 (0x7f8d14895530 in /home/alce/src/tacred_rel/.env/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #17: at::Tensor c10::Dispatcher::call<at::Tensor, at::Tensor const&, at::Tensor const&>(c10::TypedOperatorHandle<at::Tensor (at::Tensor const&, at::Tensor const&)> const&, at::Tensor const&, at::Tensor const&) const + 0xbc (0x7f8d1507d81c in /home/alce/src/tacred_rel/.env/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #18: at::Tensor::mm(at::Tensor const&) const + 0x4b (0x7f8d15163cab in /home/alce/src/tacred_rel/.env/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #19: <unknown function> + 0x2d11fbb (0x7f8d16b50fbb in /home/alce/src/tacred_rel/.env/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #20: torch::autograd::generated::MmBackward::apply(std::vector<at::Tensor, std::allocator<at::Tensor> >&&) + 0x25f (0x7f8d16b6c7df in /home/alce/src/tacred_rel/.env/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #21: <unknown function> + 0x3375bb7 (0x7f8d171b4bb7 in /home/alce/src/tacred_rel/.env/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #22: torch::autograd::Engine::evaluate_function(std::shared_ptr<torch::autograd::GraphTask>&, torch::autograd::Node*, torch::autograd::InputBuffer&, std::shared_ptr<torch::autograd::ReadyQueue> const&) + 0x1400 (0x7f8d171b0400 in /home/alce/src/tacred_rel/.env/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #23: torch::autograd::Engine::thread_main(std::shared_ptr<torch::autograd::GraphTask> const&) + 0x451 (0x7f8d171b0fa1 in /home/alce/src/tacred_rel/.env/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #24: torch::autograd::Engine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) + 0x89 (0x7f8d171a9119 in /home/alce/src/tacred_rel/.env/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #25: torch::autograd::python::PythonEngine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) + 0x4a (0x7f8d2af4f4ba in /home/alce/src/tacred_rel/.env/lib/python3.6/site-packages/torch/lib/libtorch_python.so)
frame #26: <unknown function> + 0xbd6df (0x7f8d46d746df in /usr/lib/x86_64-linux-gnu/libstdc++.so.6)
frame #27: <unknown function> + 0x76db (0x7f8d4c9216db in /lib/x86_64-linux-gnu/libpthread.so.0)
frame #28: clone + 0x3f (0x7f8d4cc5aa3f in /lib/x86_64-linux-gnu/libc.so.6)


In [None]:
# Markers (leading newline included) that the text helpers below search for to
# locate question and answer lines.
_question_prompt = '\nQ: '
_answer_prompt = '\nA: '
    
def get_text_up_to_question_number(text, number, answer_prompt=None):
    """Return the prefix of ``text`` that stops right where answer ``number`` begins.

    Args:
        text: string containing answer markers.
        number: zero-based index of the answer marker to stop at.
        answer_prompt: marker to search for; defaults to the module-level
            ``_answer_prompt``.

    Returns:
        ``text`` up to and including the first character of the selected marker
        (the newline, for the default marker). An empty string is returned when
        ``text`` holds fewer than ``number + 1`` markers.
    """
    if answer_prompt is None:
        answer_prompt = _answer_prompt

    pos = text.find(answer_prompt)
    for _ in range(number):
        if pos == -1:
            # Stop here: searching again from pos + 1 == 0 would wrap around
            # and land on an earlier answer.
            break
        pos = text.find(answer_prompt, pos + 1)
    # pos == -1 gives text[0:0], i.e. the empty string.
    return text[0:pos + 1]
    
def get_answers_number(text):
    """Count the occurrences of the answer marker ``_answer_prompt`` in ``text``."""
    occurrences = text.count(_answer_prompt)
    return occurrences

def get_answer_number(text, number, answer_prompt=None):
    """Extract the text of answer ``number`` from ``text``.

    Args:
        text: string containing answer markers.
        number: zero-based index of the answer to extract.
        answer_prompt: marker that introduces an answer; defaults to the
            module-level ``_answer_prompt``.

    Returns:
        The characters between the selected marker and the next newline, or the
        end of ``text`` when no newline follows. An empty string is returned
        when ``text`` holds fewer than ``number + 1`` markers.
    """
    if answer_prompt is None:
        answer_prompt = _answer_prompt

    pos = text.find(answer_prompt)
    for _ in range(number):
        if pos == -1:
            # Avoid wrapping around to an earlier answer on the next search.
            break
        pos = text.find(answer_prompt, pos + 1)
    if pos == -1:
        return ''

    start = pos + len(answer_prompt)
    end = text.find('\n', start)
    if end == -1:
        # Last line without a trailing newline: slicing with -1 would drop the
        # final character of the answer.
        end = len(text)
    return text[start:end]

def get_question_number(text, number, question_prompt=None):
    """Extract the text of question ``number`` from ``text``.

    Args:
        text: string containing question markers.
        number: zero-based index of the question to extract.
        question_prompt: marker that introduces a question; defaults to the
            module-level ``_question_prompt``.

    Returns:
        The characters between the selected marker and the next newline, or the
        end of ``text`` when no newline follows. An empty string is returned
        when ``text`` holds fewer than ``number + 1`` markers.
    """
    if question_prompt is None:
        question_prompt = _question_prompt

    pos = text.find(question_prompt)
    for _ in range(number):
        if pos == -1:
            # Avoid wrapping around to an earlier question on the next search.
            break
        pos = text.find(question_prompt, pos + 1)
    if pos == -1:
        return ''

    start = pos + len(question_prompt)
    end = text.find('\n', start)
    if end == -1:
        # No newline after the question: keep its last character.
        end = len(text)
    return text[start:end]

def get_all_answers(dev_dict, dev_index):
    """Collect the distinct reference answers for every question of one entry.

    Args:
        dev_dict: mapping whose ``'data'`` list holds entries with an
            ``'answers'`` list and an ``'additional_answers'`` mapping keyed by
            ``'0'``, ``'1'`` and ``'2'``; every answer is a dict with an
            ``'input_text'`` field.
        dev_index: position of the entry inside ``dev_dict['data']``.

    Returns:
        One list per question, in question order, holding the distinct answer
        strings from the four answer sets. Duplicates are removed while keeping
        first-seen order, so the result is deterministic (the previous
        ``set``-based version returned them in arbitrary order).
    """
    entry = dev_dict['data'][dev_index]
    answer_sets = [[item['input_text'] for item in entry['answers']]]
    answer_sets += [
        [item['input_text'] for item in entry['additional_answers'][str(index)]]
        for index in range(3)
    ]
    # dict.fromkeys de-duplicates while preserving insertion order.
    return [
        list(dict.fromkeys(answer_set[i] for answer_set in answer_sets))
        for i in range(len(answer_sets[0]))
    ]

In [None]:
def generate_answer_number(model, text, number):
    """Generate the model's answer to question ``number`` of ``text``.

    The prompt is ``text`` cut right where answer ``number`` begins; the
    generated sequence is decoded and answer ``number`` is extracted from it.

    Args:
        model: generation model exposing ``generate`` and ``parameters``.
        text: string containing the question and answer markers.
        number: zero-based index of the question to answer.

    Returns:
        The answer string extracted from the decoded generation.
    """
    prompt = get_text_up_to_question_number(text, number)
    tokens = tokenizer.encode(prompt, return_tensors='pt')
    _length = 20  # extra tokens allowed beyond the prompt length
    tokens_length = tokens.shape[1]
    output = model.generate(
        # Put the prompt on the same device as the model's weights.
        tokens.to(next(model.parameters()).device),
        max_length=tokens_length + _length,
        # `temperature=0` was removed: it is not a valid temperature, and
        # generate_answer already has it disabled.
        # The pad id comes from the tokenizer instead of a hard-coded 50256.
        pad_token_id=tokenizer.pad_token_id,
    )
    output = tokenizer.decode(output[0], skip_special_tokens=True)
    return get_answer_number(output, number)

In [None]:
def compute_accuracy_of_model_with_all_answers(model):
    """Score generated answers against every reference answer of each question.

    A prediction counts as correct when, after removing ``.`` and ``"`` and
    lower-casing, it equals a reference answer, or when one of the two is a
    non-empty substring of the other.

    Args:
        model: generation model passed on to ``generate_answer_number``.

    Returns:
        ``(accuracy, wrong_predictions)``: the fraction of questions answered
        correctly (``0.0`` when no question was evaluated) and a list of
        ``{'label', 'prediction'}`` dicts, one for every reference answer that
        was compared and did not match.
    """
    total_number_of_questions = 0
    correct_answers = 0
    wrong_predictions = []

    # NOTE(review): `dev_list` is not defined in the visible notebook, and
    # get_all_answers indexes dev_dict['data'] while another cell slices
    # dev_dict like a list. Confirm which data this function is meant to use.
    dlist = dev_list[:10]
    for index, text in tqdm(enumerate(dlist), total=len(dlist)):
        total_questions = get_answers_number(text)
        all_answers = get_all_answers(dev_dict, index)
        for number in range(total_questions):
            prediction = generate_answer_number(model, text, number)
            prediction = prediction.replace('.', '').replace('"', '')
            predicted = prediction.lower()
            it_was_answered = False
            for label in all_answers[number]:
                label = label.replace('.', '').replace('"', '')
                expected = label.lower()
                # Substring matching needs both sides to be non-empty:
                # `'' in s` is always True, so an empty prediction or label
                # used to count as a correct answer.
                if predicted == expected or (
                    predicted and expected
                    and (predicted in expected or expected in predicted)
                ):
                    correct_answers += 1
                    it_was_answered = True
                    break
                wrong_predictions.append({'label': label, 'prediction': prediction})
            if not it_was_answered:
                print('No Answer for number', number)
            total_number_of_questions += 1

    if total_number_of_questions == 0:
        # Nothing was evaluated; avoid ZeroDivisionError.
        return 0.0, wrong_predictions
    return correct_answers/total_number_of_questions, wrong_predictions

In [None]:
def get_text_from_data_item(item, question_number=-1, last_question=True):
    """Render a story and its first ``question_number + 1`` Q/A turns as text.

    The result starts with a ``Story:`` section, followed by a ``Discussion:``
    section containing one ``Q: ...`` / ``A: ...`` pair per turn. With the
    default ``question_number=-1`` the slice is ``[:0]``, so no turn is
    included. When ``last_question`` is false, the final line of the rendered
    text is removed and a newline is appended in its place.
    """
    text = 'Story:\n' + item['story'] + '\n\n' + 'Discussion:\n'

    turn_count = question_number + 1
    turns = []
    for question, answer in zip(item['questions'][:turn_count],
                                item['answers'][:turn_count]):
        turns.append('Q: ' + question['input_text'] + '\nA: ' + answer['input_text'])
    text += '\n'.join(turns)

    if last_question:
        return text
    # The text always contains a newline, so this drops exactly the last line.
    return text.rsplit('\n', 1)[0] + '\n'

In [None]:
def generate_answer(model, prompt):
    """Generate and decode a continuation for ``prompt``.

    Args:
        model: generation model exposing ``generate`` and ``parameters``.
        prompt: text encoded with the notebook-global ``tokenizer``.

    Returns:
        The decoded generation with special tokens removed, or an empty string
        when the prompt length plus the 50-token budget exceeds 1024 tokens.
    """
    tokens = tokenizer.encode(prompt, return_tensors='pt')
    _length = 50  # extra tokens allowed beyond the prompt length
    tokens_length = tokens.shape[1]
    if tokens_length + _length > 1024:
        return ''
    output = model.generate(
        # Follow the model's device instead of assuming CUDA.
        tokens.to(next(model.parameters()).device),
        max_length=tokens_length + _length,
        # The pad id comes from the tokenizer instead of a hard-coded 50256.
        pad_token_id=tokenizer.pad_token_id,
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)

In [None]:
def generate_multiple_answers(model, prompt, num_replicas=6):
    """Generate ``num_replicas`` continuations of ``prompt`` and return the distinct answers.

    For every generation, the text after the prompt up to the next newline (or
    the end of the output) is taken, and the part after its last ``:`` is kept.

    Args:
        model: generation model exposing ``generate`` and ``parameters``.
        prompt: text encoded with the notebook-global ``tokenizer``.
        num_replicas: number of times ``generate`` is called.

    Returns:
        A list of distinct answers in first-seen order. The list is empty when
        the prompt length plus the 50-token budget exceeds 1024 tokens; this
        path used to return ``''``, which made the return type inconsistent.
    """
    tokens = tokenizer.encode(prompt, return_tensors='pt')
    _length = 50  # extra tokens allowed beyond the prompt length
    tokens_length = tokens.shape[1]
    if tokens_length + _length > 1024:
        return []

    # The device transfer is loop-invariant, so do it once.
    tokens = tokens.to(next(model.parameters()).device)
    start = len(prompt) + 1

    # NOTE(review): every replica uses identical arguments, so the outputs can
    # only differ if generation is stochastic (for example dropout left active
    # by model.train()). Confirm that this is the intended source of variety.
    outputs = []
    for _ in range(num_replicas):
        output = model.generate(
            tokens,
            max_length=tokens_length + _length,
            # The pad id comes from the tokenizer instead of a hard-coded 50256.
            pad_token_id=tokenizer.pad_token_id,
        )
        output = tokenizer.decode(output[0], skip_special_tokens=True)
        end = output.find('\n', start)
        if end == -1:
            # No newline after the answer: keep its last character.
            end = len(output)
        outputs.append(output[start:end].split(':')[-1].strip())

    # Order-preserving de-duplication keeps the result deterministic.
    return list(dict.fromkeys(outputs))

In [None]:
# Restore the weights saved in the file 'save_small1'.
checkpoint = torch.load('save_small' + str(1))
model.load_state_dict(checkpoint['model_state_dict'])
# NOTE(review): this leaves the model in training mode, yet the following cells
# only generate and evaluate. Confirm that train mode is intended here rather
# than model.eval().
_ = model.train()

In [None]:
# Spot check: generate an answer for the dev entry at index doc + number.
doc=0
number = 6
generate_answer(model, dev_dict[doc + number]['inputs'])

In [None]:
def compute_accuracy_of_model(model):
    """Score generated answers for the first 1000 entries of ``dev_dict``.

    A prediction counts as correct when, after removing ``.`` and ``"`` and
    lower-casing, it equals a reference label, or when one of the two is a
    non-empty substring of the other. Entries for which ``generate_answer``
    returns an empty string are reported and left out of the total.

    Args:
        model: generation model passed on to ``generate_answer``.

    Returns:
        ``(accuracy, wrong_predictions, false_positives)``: the fraction of
        scored entries answered correctly (``0.0`` when nothing was scored), a
        list of ``{'label', 'prediction'}`` dicts for every compared label that
        did not match, and the predictions that were not ``unknown`` while the
        compared label was ``unknown``.
    """
    total_number_of_questions = 0
    correct_answers = 0
    wrong_predictions = []

    false_positives = []
    dlist = dev_dict[:1000]
    for item in tqdm(dlist, total=len(dlist)):
        small_text = item['inputs']
        prediction = generate_answer(model, small_text)
        if not prediction:
            print('NO PREDICTION!!')
            continue
        prediction = prediction.replace('.', '').replace('"', '')
        predicted = prediction.lower()

        labels = item['labels']
        if isinstance(labels, str):
            # A single label string would otherwise be iterated character by
            # character. NOTE(review): confirm whether the dev file stores one
            # string or a list of strings per entry.
            labels = [labels]

        for label in labels:
            label = label.replace('.', '').replace('"', '')
            expected = label.lower()

            if predicted != 'unknown' and expected == 'unknown':
                false_positives.append(prediction)

            # Substring matching needs both sides to be non-empty:
            # `'' in s` is always True, so an empty prediction or label used
            # to count as a correct answer.
            if predicted == expected or (
                predicted and expected
                and (predicted in expected or expected in predicted)
            ):
                correct_answers += 1
                break
            wrong_predictions.append({'label': label, 'prediction': prediction})
        total_number_of_questions += 1

    if total_number_of_questions == 0:
        # Every entry was skipped; avoid ZeroDivisionError.
        return 0.0, wrong_predictions, false_positives
    return correct_answers/total_number_of_questions, wrong_predictions, false_positives

In [None]:
# Overwritten on every iteration below; after the loop they hold the values
# from the last checkpoint evaluated.
false_positives = None
wrong_answers = None

# Evaluate the checkpoints 'save_small0' ... 'save_small9' one after another.
for index in range(10):
    checkpoint = torch.load('save_small' + str(index))
    model.load_state_dict(checkpoint['model_state_dict'])
    # NOTE(review): the train/eval mode is not set in this cell, so evaluation
    # runs in whatever mode an earlier cell left the model in.
    accuracy, wrong_answers, false_positives = compute_accuracy_of_model(model.cuda())
    print('Epoch:', index, ' with accuracy:', accuracy)

In [None]:
# Display the wrong predictions as last assigned by the evaluation loop.
wrong_answers

In [None]:
def clean_answer(answer):
    """Normalise an answer: lower-case it, trim surrounding whitespace, drop every ``.``."""
    return answer.lower().strip().replace('.', '')

In [None]:
def compute_accuracy_of_model_only_y_or_n(model):
    """Score only the questions for which the model predicts ``yes`` or ``no``.

    Predictions whose cleaned form is neither ``yes`` nor ``no`` are skipped
    and not counted in the total. A kept prediction counts as correct when,
    after removing ``.`` and ``"`` and lower-casing, it equals a reference
    answer, or when one of the two is a non-empty substring of the other.

    Args:
        model: generation model passed on to ``generate_answer``.

    Returns:
        ``(accuracy, wrong_predictions, false_positives)``: the fraction of
        kept questions answered correctly (``0.0`` when none was kept), a list
        of ``{'label', 'prediction'}`` dicts for every compared reference
        answer that did not match, and the predictions that were not
        ``unknown`` while the compared reference answer was ``unknown``.
    """
    total_number_of_questions = 0
    correct_answers = 0
    wrong_predictions = []

    false_positives = []
    # NOTE(review): `dev_list` is not defined in the visible notebook, and
    # dev_dict is indexed here as dev_dict['data'] while another cell slices it
    # like a list. Confirm which data this function is meant to use.
    dlist = dev_list[:10]
    for index, text in tqdm(enumerate(dlist), total=len(dlist)):
        total_questions = get_answers_number(text)
        all_answers = get_all_answers(dev_dict, index)
        for number in range(total_questions):
            # The call used to pass `max_num_questions=5`, which the visible
            # get_text_from_data_item does not accept (TypeError).
            # NOTE(review): restore it if a newer helper limits the history.
            small_text = get_text_from_data_item(dev_dict['data'][index],
                                                 question_number=number,
                                                 last_question=False)
            prediction = generate_answer(model, small_text)
            if clean_answer(prediction) not in ['yes', 'no']:
                continue
            prediction = prediction.replace('.', '').replace('"', '')
            predicted = prediction.lower()
            for label in all_answers[number]:
                label = label.replace('.', '').replace('"', '')
                expected = label.lower()

                if predicted != 'unknown' and expected == 'unknown':
                    false_positives.append(prediction)

                # Substring matching needs both sides to be non-empty:
                # `'' in s` is always True, so an empty label used to count as
                # a correct answer.
                if predicted == expected or (
                    predicted and expected
                    and (predicted in expected or expected in predicted)
                ):
                    correct_answers += 1
                    break
                wrong_predictions.append({'label': label, 'prediction': prediction})
            total_number_of_questions += 1

    if total_number_of_questions == 0:
        # No yes/no prediction was produced; avoid ZeroDivisionError.
        return 0.0, wrong_predictions, false_positives
    return correct_answers/total_number_of_questions, wrong_predictions, false_positives

In [None]:
# Overwritten on every iteration below; after the loop they hold the values
# from the last checkpoint evaluated.
false_positives = None
wrong_answers = None

# Yes/no-only evaluation of the checkpoints 'save_small0' ... 'save_small9'.
for index in range(10):
    checkpoint = torch.load('save_small' + str(index))
    model.load_state_dict(checkpoint['model_state_dict'])
    # NOTE(review): the train/eval mode is not set in this cell, so evaluation
    # runs in whatever mode an earlier cell left the model in.
    accuracy, wrong_answers, false_positives = compute_accuracy_of_model_only_y_or_n(model.cuda())
    print('Epoch:', index, ' with accuracy:', accuracy)

In [None]:
# Display the false positives as last assigned by the evaluation loop.
false_positives

In [None]:
# Display the wrong predictions as last assigned by the evaluation loop.
wrong_answers