In [1]:
import sys
# Make the local helper modules under ./utils/ importable.
# Guarded so that re-running this cell does not keep appending duplicate
# entries to sys.path.
if './utils/' not in sys.path:
    sys.path.append('./utils/')

In [2]:
from datasets import load_dataset, interleave_datasets, logging
from tqdm import tqdm
import numpy as np
import sys
import pandas as pd
from torch import nn
from collections import Counter
import torch
from transformers import AdamW
from transformers import BertTokenizer, BertForMaskedLM, BertModel, AutoModelWithLMHead, AutoModel, AutoTokenizer
from fewshot import truncate_text_for_prompt, evaluate, tokenizer_and_numericalize, get_accuracy
from prompt_model import TransformerForPrompting

# Restrict the `datasets` library's log output to errors only.
logging.set_verbosity(logging.ERROR)


# Tokenizer for the same `bert-base-uncased` checkpoint that the model cell loads.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# model = BertModel.from_pretrained('bert-base-uncased')
dataset = load_dataset("boolq")


# Prefer the second GPU (the original hard-coded choice), but fall back to the
# first GPU or the CPU so that later `.to(device)` calls do not fail on
# machines with fewer than two CUDA devices.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Suffix interpolated into the checkpoint file names in the save/load cells.
# NOTE(review): the meaning of 2000 is not visible in this notebook — confirm.
dt = 2000

# Splits fed to the tokenisation cell: train for fitting, validation for evaluation.
for_prompt_training = dataset['train']
for_prompt_testing = dataset['validation']

In [3]:
# Vocabulary ids of the candidate answer words, shown in the order
# true / false / yes / no.
tuple(tokenizer.convert_tokens_to_ids(word) for word in ('true', 'false', 'yes', 'no'))

(2995, 6270, 2748, 2053)

In [4]:
%%time
def _tokenize_split(split):
    """Map `tokenizer_and_numericalize` over `split` and expose the result as torch tensors.

    The mapping function receives `truncate_text_for_prompt` and `tokenizer`
    as keyword arguments and runs in 32 worker processes. Only the columns
    consumed downstream are kept in the torch-formatted view.
    """
    encoded = split.map(
        tokenizer_and_numericalize,
        fn_kwargs={'truncate_text_for_prompt': truncate_text_for_prompt, 'tokenizer': tokenizer},
        num_proc=32,
    )
    encoded.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'label'])
    return encoded


train_ds = _tokenize_split(for_prompt_training)
test_ds = _tokenize_split(for_prompt_testing)

Token indices sequence length is longer than the specified maximum sequence length for this model (908 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (560 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (619 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (599 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (733 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for thi

CPU times: user 18.4 s, sys: 1.13 s, total: 19.6 s
Wall time: 22.2 s


In [5]:
# Training batches are reshuffled on every pass over the loader.
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=8, shuffle=True)
# Evaluation must iterate over the held-out split (`test_ds`), not the training
# data; the original built this loader from `train_ds`, so the reported
# "evaluation" metrics were really training metrics. Shuffling is disabled so
# that the batch composition is the same on every evaluation run.
test_loader = torch.utils.data.DataLoader(test_ds, batch_size=16, shuffle=False)

In [7]:
# Number of distinct values in the training split's `answer` column.
output_dim = len(set(dataset['train']['answer']))

# Pretrained masked-LM BERT wrapped by the project's prompting model;
# freeze=False is passed through to TransformerForPrompting.
transformer = BertForMaskedLM.from_pretrained('bert-base-uncased')
model = TransformerForPrompting(
    transformer,
    output_dim=output_dim,
    freeze=False,
).to(device)

# Cross-entropy over the task logits, placed on the same device as the model.
loss_fn = nn.CrossEntropyLoss().to(device)

optim = AdamW(model.parameters(), lr=5e-5)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [8]:
epochs = 1

# Running histories across all batches (and all epochs); the progress bar
# shows their mean so far.
acc_total = []
lm_loss_total = []  # never appended to in this loop; kept so the name stays defined
tl_total = []
for epoch in range(epochs):
    # Re-enter training mode at the start of every epoch. The original called
    # model.train() once before the loop, so any epoch after the first would
    # inherit whatever mode `evaluate` leaves the model in.
    # NOTE(review): `evaluate` lives in fewshot.py and is not visible here —
    # presumably it switches the model to eval mode; confirm.
    model.train()

    # setup loop with TQDM and dataloader
    loop = tqdm(train_loader, leave=True)
    loop.set_description(f'Epoch {epoch}')
    for batch in loop:
        # reset gradients accumulated by the previous step
        optim.zero_grad()
        # pull all tensor batches required for training
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        task_labels = batch['label'].to(device)

        # outputs[0] is scored against the task labels below
        outputs = model(input_ids, attention_mask)
        task_loss = loss_fn(outputs[0], task_labels)

        loss = task_loss
        accuracy = get_accuracy(outputs[0], task_labels)
        # back-propagate and update parameters
        loss.backward()
        optim.step()

        acc_total.append(accuracy.item())
        tl_total.append(task_loss.item())

        # print relevant info to progress bar
        loop.set_postfix(task_loss=np.mean(tl_total), accuracy = np.mean(acc_total))

    # End-of-epoch evaluation on `test_loader`.
    epoch_losses_tl, epoch_accs = evaluate(test_loader, model, loss_fn, device)
    print(f"Evaluation Accuracy: {np.mean(epoch_accs)}  Evaluation loss Task: {np.mean(epoch_losses_tl)}")
    

Epoch 0:  45%|████▍     | 530/1179 [04:09<05:06,  2.12it/s, accuracy=0.629, task_loss=0.659]


KeyboardInterrupt: 

In [None]:
# Persist the current weights; `dt` is interpolated into the file name.
# NOTE(review): the target directory is named `imdb` although this notebook
# loads the boolq dataset — confirm the path is intended.
checkpoint_path = f'/hd2/prompting_models/imdb/fine_tune_with_correct_model_{dt}.pt'
torch.save(model.state_dict(), checkpoint_path)

In [18]:
# Restore previously saved weights into `model`.
# map_location remaps the stored tensors onto `device`, so a checkpoint that
# was written from a different GPU can still be loaded on this machine.
# NOTE(review): this reads `transformer_with_correct_model_{dt}.pt`, whereas
# the save cell writes `fine_tune_with_correct_model_{dt}.pt` — confirm that
# loading a different checkpoint is intended.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
loaded_state = torch.load(
    f'/hd2/prompting_models/imdb/transformer_with_correct_model_{dt}.pt',
    map_location=device,
)
model.load_state_dict(loaded_state)

<All keys matched successfully>

In [None]:
# epoch_losses_lm, epoch_losses_tl, epoch_accs = evaluate(test_loader, model, loss_fn, device)
# np.mean(epoch_accs)

In [None]:
# def pattern_verbalizer1(question, passage, answer, for_eval=False):
#     return f"Based on the following passage: {passage} {question}. {tokenizer.mask_token}"
    
# def pattern_verbalizer2(question, passage, answer, for_eval=False):
#     return f"From the following passage: {passage} {question}. {tokenizer.mask_token}" 
    
# def pattern_verbalizer3(question, passage, answer, for_eval=False):
#     return f"{passage} Question: {question}. Answer: {tokenizer.mask_token}" 
    
# def pattern_verbalizer4(question, passage, answer, for_eval=False):
#     return f"Question: {question}? Passage: {passage} Answer: {tokenizer.mask_token}" 

In [None]:
# def get_votes_for_pattern(ds = dataset['validation'], pattern_function=pattern_verbalizer1):
#     val1 = ds.map(tokenizer_and_numericalize, fn_kwargs={'pattern_func': pattern_function, 'truncate_text_for_prompt': truncate_text_for_prompt}, num_proc=32)
#     val1.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'label'])
#     val_data = val1.map(create_lm_labels, fn_kwargs={'subset': 'test'}, num_proc=32)
#     val_data.set_format(type='torch', columns=['label', 'input_ids', 'token_type_ids', 'attention_mask', 'lm_labels'], device=device)
#     val_loader = torch.utils.data.DataLoader(val_data, batch_size=16, shuffle=True)
#     res = []
#     dict_reverse = {'false': False, 'true': True, 'no': False, 'yes': True}
#     for example in tqdm(val_data):
#         with torch.no_grad():
#             input_ids = example['input_ids'].unsqueeze(0).to(device)
#             attn = example['attention_mask'].unsqueeze(0).to(device)
#             mask_token_index = torch.where(input_ids == tokenizer.mask_token_id)[1]
#             token_logits = model(input_ids, attn)[1]
#             mask_token_logits = token_logits[0, mask_token_index, :]
#             res.append((bool(example['label'].item()), dict_reverse[tokenizer.decode([mask_token_logits.squeeze().argmax()])]))
#     accuracy = sum(int(a == b) for a,b in res)/len(res)
#     return res, accuracy

In [None]:
# res1, accuracy1 = get_votes_for_pattern(pattern_function=pattern_verbalizer1)
# res2, accuracy2 = get_votes_for_pattern(pattern_function=pattern_verbalizer2)
# res3, accuracy3 = get_votes_for_pattern(pattern_function=pattern_verbalizer3)
# res4, accuracy4 = get_votes_for_pattern(pattern_function=pattern_verbalizer4)
# accuracy1, accuracy2, accuracy3, accuracy4

In [None]:
# cl1 = [j for i, j in res1]
# cl2 = [j for i, j in res2]
# cl3 = [j for i, j in res3]
# cl4 = [j for i, j in res4]
# truth = [i for i, j in res1]

# result = pd.DataFrame({'truth': truth, 'pattern1': cl1, 'pattern2': cl2, 'pattern3': cl3, 'pattern4': cl4 })
# result['majority'] = result.apply(lambda x: Counter([x['pattern1'], x['pattern2'], x['pattern3'], x['pattern4']]).most_common(1)[0][0],axis=1)
# acc = sum(result['truth']==result['majority'])/len(result)
# acc

In [None]:
# `acc` is assigned only by the majority-vote cell above, which is currently
# commented out. Guard the lookup so a fresh Restart & Run All does not stop
# with a NameError; the value is still displayed whenever it exists.
acc if 'acc' in globals() else None