In [3]:
from transformers import AutoTokenizer, CanineForMultipleChoice
import torch
import random
import numpy as np
from google.colab import drive

drive.mount('/content/drive')

tokenizer = AutoTokenizer.from_pretrained("google/canine-s")
model = CanineForMultipleChoice.from_pretrained("google/canine-s")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


Some weights of CanineForMultipleChoice were not initialized from the model checkpoint at google/canine-s and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
import torch.optim as optim

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)
model.to(device)

# Load the data
SP_all = np.load('/content/drive/My Drive/C5470prj/BrainTeaser/data/SP-train.npy', allow_pickle=True)
WP_all = np.load('/content/drive/My Drive/C5470prj/BrainTeaser/data/WP-train.npy', allow_pickle=True)

SP_train, SP_test = train_test_split(SP_all, test_size = 0.2, random_state=42)
SP_train, SP_val  = train_test_split(SP_train, test_size = 0.25, random_state=42)
WP_train, WP_test = train_test_split(WP_all, test_size = 0.2, random_state=42)
WP_train, WP_val  = train_test_split(WP_train, test_size = 0.25, random_state=42)

# train
num_epochs = 10

optimizer = optim.Adam(model.parameters(), lr=8e-5)
#optimizer = optim.RMSprop(model.parameters(), lr=0.01)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.1)
model.to(device)
model.train()

for epoch in range(num_epochs):
    shuffle(SP_train, random_state=epoch)
    shuffle(WP_train, random_state=epoch)
    total_loss = 0

    bad_qs_train = 0
    correct_answers_train = 0
    model.train()
    for data in np.concatenate((SP_test, WP_test)):
        optimizer.zero_grad()

        id = data['id']
        question = data['question']
        choice0 = data['choice_list'][0]
        choice1 = data['choice_list'][1]
        choice2 = data['choice_list'][2]
        choice3 = data['choice_list'][3]
        labels = torch.tensor(0).unsqueeze(0).to(device)  # choice i is correct, batch size 1
        encoding = tokenizer([question, question, question, question], [choice0, choice1, choice2, choice3], return_tensors="pt", padding=True)
        outputs = model(**{k: v.to(device).unsqueeze(0) for k, v in encoding.items()}, labels=labels)  # batch size is 1

        loss = outputs.loss
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

        if data['answer'] not in data['choice_list']:
            bad_qs_train += 1
        else:
            correct_index = data['choice_list'].index(data['answer'])
            logits = outputs.logits
            predicted_answer = np.argmax(F.softmax(logits).cpu().detach().numpy())
            if predicted_answer == correct_index:
                correct_answers_train += 1
    scheduler.step()
    num_qs_train = len(np.concatenate((SP_train, WP_train)))

    #val
    model.eval()
    with torch.no_grad():
        total_loss_val = 0
        bad_qs_val = 0
        correct_answers_val = 0
        for data in np.concatenate((SP_val, WP_val)):
            id = data['id']
            question = data['question']
            choice0 = data['choice_list'][0]
            choice1 = data['choice_list'][1]
            choice2 = data['choice_list'][2]
            choice3 = data['choice_list'][3]
            labels = torch.tensor(0).unsqueeze(0).to(device)  # choice i is correct, batch size 1
            encoding = tokenizer([question, question, question, question], [choice0, choice1, choice2, choice3], return_tensors="pt", padding=True)
            outputs = model(**{k: v.to(device).unsqueeze(0) for k, v in encoding.items()}, labels=labels)  # batch size is 1

            loss = outputs.loss
            total_loss_val += loss.item()

            if data['answer'] not in data['choice_list']:
                bad_qs_val += 1
            else:
                correct_index = data['choice_list'].index(data['answer'])
                logits = outputs.logits
                predicted_answer = np.argmax(F.softmax(logits).cpu().detach().numpy())
                if predicted_answer == correct_index:
                    correct_answers_val += 1
    num_qs_val = len(np.concatenate((SP_val, WP_val)))

    print(f"Epoch: {epoch + 1}, Train Loss: {total_loss / num_qs_train}, Train Acc: {correct_answers_train / (num_qs_train - bad_qs_train)}")
    print(f"          Val Loss: {total_loss_val / num_qs_val}, Val Acc: {correct_answers_val / (num_qs_val - bad_qs_val)}")



cuda


  predicted_answer = np.argmax(F.softmax(logits).cpu().detach().numpy())
  predicted_answer = np.argmax(F.softmax(logits).cpu().detach().numpy())


Epoch: 1, Train Loss: 0.46922535653467534, Train Acc: 0.08719851576994433
          Val Loss: 1.3686579143144808, Val Acc: 0.4088397790055249
Epoch: 2, Train Loss: 0.45893661633685784, Train Acc: 0.07606679035250463
          Val Loss: 1.3862034932025888, Val Acc: 0.24861878453038674
Epoch: 3, Train Loss: 0.46759292615784537, Train Acc: 0.09461966604823747
          Val Loss: 1.385454057329926, Val Acc: 0.2430939226519337
Epoch: 4, Train Loss: 0.46483285162183974, Train Acc: 0.07235621521335807
          Val Loss: 1.179926059481518, Val Acc: 0.32044198895027626
Epoch: 5, Train Loss: 0.4302246452205711, Train Acc: 0.11317254174397032
          Val Loss: 1.3826506696847263, Val Acc: 0.281767955801105
Epoch: 6, Train Loss: 0.4444693381035769, Train Acc: 0.09647495361781076
          Val Loss: 1.357872890339372, Val Acc: 0.2154696132596685
Epoch: 7, Train Loss: 0.45104863985821053, Train Acc: 0.09461966604823747
          Val Loss: 1.388925182226613, Val Acc: 0.2541436464088398
Epoch: 8, T

In [3]:
# testing
model.eval()

correct_answers = 0
question_count = 0
SP_count = 0
SP_correct = 0
WP_count = 0
WP_correct = 0
wp = False
sp = False

#with torch.no_grad():
for data in np.concatenate((SP_test, WP_test)):
    wp = False
    sp = False

    id = data['id']
    if '_' in id: # exclude reconstructed questions during eval
        continue
    if "WP" in id:
        wp = True
    if "SP" in id:
        sp = True

    question_count += 1
    if wp:
      WP_count += 1
    if sp:
      SP_count += 1

    question = data['question']
    choice0 = data['choice_list'][0]
    choice1 = data['choice_list'][1]
    choice2 = data['choice_list'][2]
    choice3 = data['choice_list'][3]
    print("processing data id: " + id + " length: " + str(len(question)))
    labels = torch.tensor(0).unsqueeze(0).to(device)  # choice i is correct, batch size 1

    encoding = tokenizer([question, question, question, question], [choice0, choice1, choice2, choice3], return_tensors="pt", padding=True)
    outputs = model(**{k: v.to(device).unsqueeze(0) for k, v in encoding.items()}, labels=labels)  # batch size is 1

    #print(f"outputs: {outputs}" )
    # sanitize dataset
    if data['answer'] not in data['choice_list']:
        question_count -= 1
        if wp:
            WP_count -= 1
        if sp:
            SP_count -= 1
        continue

    correct_index = data['choice_list'].index(data['answer'])

    # the linear classifier still needs to be trained
    loss = outputs.loss
    logits = outputs.logits

    # Find the index of the correct answer in the choice order
    predicted_answer = np.argmax(F.softmax(logits).cpu().detach().numpy())
    if predicted_answer == correct_index:
        correct_answers += 1
        if wp:
            WP_correct += 1
        if sp:
            SP_correct += 1
    print(f'''\nevaluated data id: {id}, question: {question},\n predicted: {predicted_answer} {data['choice_list'][predicted_answer]}
correct: {correct_index} {data['choice_list'][correct_index]}\ntotal acc: {correct_answers / question_count}''')
    if SP_count != 0:
        print(f"SP acc: {SP_correct / SP_count} correct: {SP_correct}  count: {SP_count}")
    if WP_count != 0:
        print(f"WP acc: {WP_correct / WP_count} correct: {WP_correct}  count: {WP_count}")

# Calculate accuracy
accuracy = correct_answers / question_count
print(f"Total Accuracy: {accuracy:.2f}")

processing data id: SP-203 length: 49

evaluated data id: SP-203, question: How many birth days does the average person have?,
 predicted: 0 People may celebrate their birthdays annually, so it depends on their life span.
correct: 1 They technically only have one birth day in their lifetime.
total acc: 0.0
SP acc: 0.0 correct: 0  count: 1
processing data id: SP-29 length: 141

evaluated data id: SP-29, question: Four men were in a boat on the lake. The boat turns over, and all four men sink to the bottom of the lake, yet not a single man got wet! Why?,
 predicted: 0 The lake was frozen, none of the men got wet despite sinking to the bottom.
correct: 1 Because they were all married and not single.
total acc: 0.0
SP acc: 0.0 correct: 0  count: 2
processing data id: SP-185 length: 107

evaluated data id: SP-185, question: How could a man go outside in the pouring rain without protection, and not have a hair on his head get wet?,
 predicted: 0 
The man was lucky enough to avoid all the rai

  predicted_answer = np.argmax(F.softmax(logits).cpu().detach().numpy())



evaluated data id: SP-85, question: You are driving a bus. When you begin your route, there is an old woman named Mrs. Smith and a young boy named Raymond are on the bus. At the first stop, the old woman leaves, and a salesman, named Ed, enters. At the next stop, Jack and his sister Jill get on, as well as three women with shopping bags. The bus travels fifteen minutes, then stops and Raymond gets off and a man and his wife get on. Next, a woman with a bird in a cage gets on the bus. Who is bus driver?
,
 predicted: 0 Raymond.
correct: 3 None of above.
total acc: 0.3
SP acc: 0.3 correct: 3  count: 10
processing data id: SP-206 length: 153

evaluated data id: SP-206, question: There is a pink single-story house and everything in it is pink. The doors are pink, the windows are pink and the top is pink. What color are the stairs?,
 predicted: 0 There are no stairs in a single story house.
correct: 0 There are no stairs in a single story house.
total acc: 0.36363636363636365
SP acc: 0.363

##REsults
####

SP acc: 0.26666666666666666 correct: 44  count: 165
WP acc: 0.25 correct: 32  count: 128
Total Accuracy: 0.26

10 epoch, 5e5
SP acc: 0.25 correct: 5  count: 20
WP acc: 0.3333333333333333 correct: 4  count: 12
Total Accuracy: 0.28

5 epoch, 8e6, scheduler
SP acc: 0.3076923076923077 correct: 12  count: 39
WP acc: 0.26666666666666666 correct: 8  count: 30
Total Accuracy: 0.29

rms 0.01
SP acc: 0.28205128205128205 correct: 11  count: 39
WP acc: 0.36666666666666664 correct: 11  count: 30
Total Accuracy: 0.32



 #### TODO:

few-shot!