In [1]:
import json
import torch
import torch.nn as nn
import random
from tqdm import tqdm
from transformers import GPT2LMHeadModel, GPT2Tokenizer

In [2]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

In [3]:
dev_dict = json.load(open('../data/coqa-dev-v1.0.json', encoding='utf8'))

In [4]:
train_list = json.load(open('../data/qa_train_list.json', encoding='utf8'))
adversarial_list = json.load(open('../data/adversarial_claims.json', encoding='utf8'))[:200]
dev_list = json.load(open('../data/qa_dev_list.json', encoding='utf8'))

In [5]:
train_list = train_list + adversarial_list
random.shuffle(train_list)

In [6]:
def chunks(lst, n):
    for i in range(0, len(lst), n):
        yield lst[i:i + n]

def batchify(data, n):
    len_dict = {}
    for item in data:
        length = item.shape[1]
        try:
            len_dict[length].append(item)
        except:
            len_dict[length] = [item]

    batch_chunks = []
    for k in len_dict.keys():
        vectors = len_dict[k]
        batch_chunks += chunks(vectors, n)

    batches = []
    for chunk in batch_chunks:
        inputs = torch.stack([item[0] for item in chunk])
        batches.append((inputs))

    return batches

In [7]:
_limit = 1024
data = []
total_trimmed = 0
for item in train_list:
    tokens = tokenizer.encode(item, return_tensors='pt')
    if tokens.shape[1] > _limit:
        tokens = tokens[:, :_limit]
        total_trimmed += 1
    data.append(tokens)
print(f'Trimmed {total_trimmed} out of {len(train_list)}')

Token indices sequence length is longer than the specified maximum sequence length for this model (1030 > 1024). Running this sequence through the model will result in indexing errors


Trimmed 52 out of 7399


In [8]:
train_batches = batchify(data, 1)

In [9]:
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.cuda()
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)

In [10]:
def train(train_model, batches, optimizer, criterion):
    total_loss = 0.
    for i, batch in tqdm(enumerate(batches), total=len(batches)):
        model.train()
        inputs = batch
        optimizer.zero_grad()
        loss = train_model(inputs.cuda(), labels=inputs.cuda())[0]
        loss.backward()
        torch.nn.utils.clip_grad_norm_(train_model.parameters(), 0.5)
        optimizer.step()
        total_loss += loss.item()

    return total_loss / len(batches)

def test(test_model, batches):
    total_loss = 0.
    for i, batch in tqdm(enumerate(batches), total=len(batches)):
        test_model.eval()
        inputs = batch
        loss = test_model(inputs, labels=inputs)[0]
        total_loss += loss.item()

    return total_loss / len(batches)

In [11]:
from torch.optim.lr_scheduler import StepLR

random.shuffle(train_batches)
scheduler = StepLR(optimizer, step_size=2, gamma=0.8)
for epoch in range(10):
    random.shuffle(train_batches)
    loss = train(model, train_batches, optimizer, criterion)
    print('Epoch:', epoch, 'Loss:', loss)
    torch.save({'epoch': epoch,
                'model_state_dict': model.state_dict()},
                'save_adversarial' + str(epoch))
    scheduler.step()

100%|██████████| 7399/7399 [14:45<00:00,  8.36it/s]


Epoch: 0 Loss: 2.381719716820689


100%|██████████| 7399/7399 [14:48<00:00,  8.33it/s]


Epoch: 1 Loss: 2.236366303100153


100%|██████████| 7399/7399 [14:49<00:00,  8.32it/s]


Epoch: 2 Loss: 2.143880487551704


100%|██████████| 7399/7399 [14:50<00:00,  8.31it/s]


Epoch: 3 Loss: 2.082029467556086


100%|██████████| 7399/7399 [14:51<00:00,  8.30it/s]


Epoch: 4 Loss: 2.0208299141918458


100%|██████████| 7399/7399 [14:51<00:00,  8.30it/s]


Epoch: 5 Loss: 1.9773719848322053


100%|██████████| 7399/7399 [14:51<00:00,  8.30it/s]


Epoch: 6 Loss: 1.9325691349547625


100%|██████████| 7399/7399 [14:52<00:00,  8.29it/s]


Epoch: 7 Loss: 1.8998581303714304


100%|██████████| 7399/7399 [14:31<00:00,  8.49it/s]


Epoch: 8 Loss: 1.8648103845894441


100%|██████████| 7399/7399 [14:54<00:00,  8.27it/s]


Epoch: 9 Loss: 1.840191237170288


In [12]:
_question_prompt = '\nQ: '
_answer_prompt = '\nA: '
    
def get_text_up_to_question_number(text, number):
    pos = text.find(_answer_prompt)
    for _ in range(number):
        pos = text.find(_answer_prompt, pos + 1)
    return text[0:pos + 1]
    
def get_answers_number(text):
    return text.count(_answer_prompt)

def get_answer_number(text, number):
    pos = text.find(_answer_prompt)
    for _ in range(number):
        pos = text.find(_answer_prompt, pos + 1)
    end = text.find('\n', pos + len(_answer_prompt))
    return text[pos + len(_answer_prompt):end]

def get_question_number(text, number):
    pos = text.find(_question_prompt)
    for _ in range(number):
        pos = text.find(_question_prompt, pos + 1)
    end = text.find('\n', pos + len(_question_prompt))
    return text[pos + len(_question_prompt):end]

def get_all_answers(dev_dict, dev_index):
    answers = [[item['input_text'] for item in dev_dict['data'][dev_index]['answers']]]
    answers += [[item['input_text'] for item in dev_dict['data'][dev_index]['additional_answers'][str(index)]] for index in range(3)]
    return [list(set([answers[j][i] for j in range(len(answers))])) for i in range(len(answers[0]))]

In [13]:
def generate_answer_number(model, text, number):
    prompt = get_text_up_to_question_number(text, number)
    tokens = tokenizer.encode(prompt, return_tensors='pt')
    _length = 20
    tokens_length = tokens.shape[1]
    output = model.generate(
             tokens,
             max_length=tokens_length + _length,
             temperature=0,
             pad_token_id=50256
    )
    output = tokenizer.decode(output[0], skip_special_tokens=True)
    return get_answer_number(output, number)

In [14]:
def compute_accuracy_of_model_with_all_answers(model):
    total_number_of_questions = 0
    correct_answers = 0
    wrong_predictions = []

    dlist = dev_list[:10]
    for index, text in tqdm(enumerate(dlist), total=len(dlist)):
        total_questions = get_answers_number(text)
        all_answers = get_all_answers(dev_dict, index)
        for number in range(total_questions):
            prediction = generate_answer_number(model, text, number)
            prediction = prediction.replace('.', '').replace('"', '')
            it_was_answered = False
            for label in all_answers[number]:
                label = label.replace('.', '').replace('"', '')
                if prediction.lower() == label.lower():
                    correct_answers += 1
                    it_was_answered = True
                    break
                elif prediction.lower() in label.lower():
                    correct_answers += 1
                    it_was_answered = True
                    break
                elif label.lower() in prediction.lower():
                    correct_answers += 1
                    it_was_answered = True
                    break
                else:
                    wrong_predictions.append({'label': label, 'prediction': prediction})
            if not it_was_answered:
                print('No Answer for number', number)
            total_number_of_questions += 1

    return correct_answers/total_number_of_questions, wrong_predictions

In [15]:
def get_text_from_data_item(item, max_num_questions=0, question_number=-1, last_question=True):
    text = 'In the text below two people are discussing a story.\n\n'
    text += 'Story:\n' + item['story'] + '\n\n'
    text += 'Discussion:\n'
    text += '\n'.join(['Q: ' + q['input_text'] 
                       + '\nA: ' + a['input_text'] 
                       for q, a in zip(item['questions'][max(0,question_number-max_num_questions):question_number+1], 
                                       item['answers'][max(0,question_number-max_num_questions):question_number+1]) 
                      ])
    if not last_question:
        text = '\n'.join(text.split('\n')[:-1]) + '\n'
    return text

In [16]:
def generate_answer(model, prompt):
    model.eval()
    tokens = tokenizer.encode(prompt, return_tensors='pt')
    _length = 50
    tokens_length = tokens.shape[1]
    if tokens_length + _length > 1024:
        return ''
    output = model.generate(
             tokens.cuda(),
             max_length=tokens_length + _length,
             #temperature=0,
             pad_token_id=50256
    )
    output = tokenizer.decode(output[0], skip_special_tokens=True)
    offset = len(prompt)
    start = offset + 1
    end = output.find('\n', start)
    return output[start:end].split(':')[-1].strip()

In [17]:
def compute_accuracy_of_model(model):
    total_number_of_questions = 0
    correct_answers = 0
    wrong_predictions = []

    false_positives = []
    dlist = dev_list[:100]
    for index, text in tqdm(enumerate(dlist), total=len(dlist)):

        all_answers = get_all_answers(dev_dict, index)
        total_questions = len(all_answers)        
        
        for number in range(total_questions):
            small_text = get_text_from_data_item(dev_dict['data'][index], 
                                                 max_num_questions=5,
                                                 question_number=number,
                                                 last_question=False)
            prediction = generate_answer(model, small_text)
            if not prediction:
                print('NO PREDICTION!!')
                continue
            prediction = prediction.replace('.', '').replace('"', '')
            it_was_answered = False
            for label in all_answers[number]:
                label = label.replace('.', '').replace('"', '')

                if prediction.lower() != 'unknown' and label.lower() == 'unknown':
                    false_positives.append(prediction)
                
                if prediction.lower() == label.lower():
                    correct_answers += 1
                    it_was_answered = True
                    break
                elif prediction.lower() in label.lower():
                    correct_answers += 1
                    it_was_answered = True
                    break
                elif label.lower() in prediction.lower():
                    correct_answers += 1
                    it_was_answered = True
                    break
                else:
                    wrong_predictions.append({'label': label, 'prediction': prediction})
            total_number_of_questions += 1

    return correct_answers/total_number_of_questions, wrong_predictions, false_positives

In [18]:
false_positives = None
wrong_answers = None

for index in range(10):
    checkpoint = torch.load('save_adversarial' + str(index))
    model.load_state_dict(checkpoint['model_state_dict'])
    accuracy, wrong_answers, false_positives = compute_accuracy_of_model(model.cuda())
    print('Epoch:', index, ' with accuracy:', accuracy)

 27%|██▋       | 27/100 [03:21<10:52,  8.94s/it]

NO PREDICTION!!
NO PREDICTION!!


 64%|██████▍   | 64/100 [08:07<04:05,  6.83s/it]

NO PREDICTION!!


 80%|████████  | 80/100 [10:02<02:28,  7.40s/it]

NO PREDICTION!!


100%|██████████| 100/100 [12:28<00:00,  7.48s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Epoch: 0  with accuracy: 0.64914502849905


 27%|██▋       | 27/100 [03:19<10:48,  8.89s/it]

NO PREDICTION!!


 64%|██████▍   | 64/100 [08:04<04:04,  6.79s/it]

NO PREDICTION!!


100%|██████████| 100/100 [12:23<00:00,  7.44s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Epoch: 1  with accuracy: 0.6805819101834282


 27%|██▋       | 27/100 [03:19<10:48,  8.89s/it]

NO PREDICTION!!


100%|██████████| 100/100 [12:47<00:00,  7.68s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Epoch: 2  with accuracy: 0.690897597977244


 27%|██▋       | 27/100 [03:25<10:50,  8.91s/it]

NO PREDICTION!!


100%|██████████| 100/100 [12:36<00:00,  7.56s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Epoch: 3  with accuracy: 0.6883691529709229


 27%|██▋       | 27/100 [03:19<10:47,  8.87s/it]

NO PREDICTION!!


100%|██████████| 100/100 [12:24<00:00,  7.45s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Epoch: 4  with accuracy: 0.6978508217446271


 27%|██▋       | 27/100 [03:21<10:54,  8.97s/it]

NO PREDICTION!!


100%|██████████| 100/100 [12:28<00:00,  7.49s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Epoch: 5  with accuracy: 0.6972187104930467


 27%|██▋       | 27/100 [03:20<10:49,  8.90s/it]

NO PREDICTION!!


100%|██████████| 100/100 [12:25<00:00,  7.46s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Epoch: 6  with accuracy: 0.7092288242730721


 27%|██▋       | 27/100 [03:22<10:56,  8.99s/it]

NO PREDICTION!!


100%|██████████| 100/100 [12:30<00:00,  7.51s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Epoch: 7  with accuracy: 0.7073324905183312


 27%|██▋       | 27/100 [03:21<10:53,  8.95s/it]

NO PREDICTION!!


100%|██████████| 100/100 [12:29<00:00,  7.50s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Epoch: 8  with accuracy: 0.7035398230088495


 27%|██▋       | 27/100 [03:22<10:56,  9.00s/it]

NO PREDICTION!!


100%|██████████| 100/100 [12:29<00:00,  7.50s/it]

Epoch: 9  with accuracy: 0.6991150442477876





## Results

with 500 adv: epoch 7, accuracy 0.7079646017699115

with 1000 adv: Epoch 6, 0.7010113780025284
 
with 200 adv: Epoch 6, 0.7092288242730721

