# Testing and comparing the probes for BERT and VisualBERT

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import pandas as pd
import numpy as np
from models import Probe
from dataset import create_df, get_gold_data, get_bert_embedding_dict, get_visual_bert_embedding_dict, get_lists_and_dicts

In [None]:
# Device string handed to the `.to(device)` calls that place the probes below.
device = 'cpu'

In [None]:
# Load the annotation table and shuffle its rows with a fixed seed.
df = create_df('../data/affordance_annotations.txt')
shuffled_df = df.sample(frac=1, random_state=42).reset_index(drop=True)

# Vocabulary lists plus the word <-> index lookup tables.
unique_objects, unique_affordances, word_to_index, index_to_word = get_lists_and_dicts(df)

# Row-wise split of the shuffled table: 42 rows for training, the next 10
# for validation, and whatever remains for testing.
n_train_rows = 42
n_val_rows = 10
train_pairs = get_gold_data(shuffled_df.iloc[:n_train_rows])
val_pairs = get_gold_data(shuffled_df.iloc[n_train_rows:n_train_rows + n_val_rows])
test_pairs = get_gold_data(shuffled_df.iloc[n_train_rows + n_val_rows:])

# Both embedding builders receive every pair, wrapped in a one-element list.
# NOTE(review): the extra list nesting is kept as in the original call —
# confirm it is what the dataset helpers expect.
bert_word_to_embedding = get_bert_embedding_dict(
    [train_pairs + val_pairs + test_pairs]
)
visual_bert_word_to_embedding = get_visual_bert_embedding_dict(
    [train_pairs + val_pairs + test_pairs]
)

In [None]:
# Per-object baseline: add up every non-string cell of each annotation row
# under the object named in the row's first cell.
baseline_dict_objects = dict.fromkeys(unique_objects, 0)
for _, row in df.iterrows():
    # Explicit positional access: `row[0]` on a label-indexed Series relies on
    # a deprecated positional fallback in recent pandas versions.
    object_name = row.iloc[0]
    for value in row:
        # String cells (e.g. the object name itself) carry no annotation value.
        if not isinstance(value, str):
            baseline_dict_objects[object_name] += value

# Take the grand total from the raw sums before they become percentages.
baseline_total_objects = sum(baseline_dict_objects.values())

# NOTE(review): 15 and 62 are hard-coded; presumably the number of
# affordances and the number of objects — confirm against the data file.
for object_name, raw_sum in baseline_dict_objects.items():
    baseline_dict_objects[object_name] = np.round((raw_sum * 100)/15, 2)

baseline_total_objects = np.round((baseline_total_objects/(15*62))*100,2)
# Printed value is the complement of the overall percentage computed above.
print(f'{100-baseline_total_objects} %')

In [None]:
# Per-affordance baseline: add up each affordance column over all rows.
baseline_dict_affordances = dict.fromkeys(unique_affordances, 0)

for _, row in df.iterrows():
    for affordance_name in baseline_dict_affordances:
        baseline_dict_affordances[affordance_name] += row[affordance_name]

# Take the grand total from the raw sums before they become percentages.
baseline_total_affordances = sum(baseline_dict_affordances.values())

# NOTE(review): 62 and 15 are hard-coded; presumably the number of objects
# and the number of affordances — confirm against the data file.
for affordance_name, raw_sum in baseline_dict_affordances.items():
    baseline_dict_affordances[affordance_name] = np.round((raw_sum * 100)/62, 2)

baseline_total_affordances = np.round((baseline_total_affordances/(15*62))*100,2)
# Report the total computed in this cell; the original printed
# `baseline_total_objects` here, a leftover from the per-object cell.
print(f'{100-baseline_total_affordances} %')

## Testing the BERT Probe on test data

In [None]:
# Fix torch's RNG seed, then build the probe and restore its trained weights.
torch.manual_seed(0)
bert_probe = Probe().to(device)
# `map_location` remaps the stored tensors onto `device`, so a checkpoint
# saved from a different device (e.g. a GPU) still loads here.
bert_probe.load_state_dict(
    torch.load(
        "../model_bert_probe|epochs_2000|batch_size_64|learning_rate_0.005",
        map_location=device,
    )
)

In [None]:
# Loss applied to the probe output during evaluation.
criterion = nn.NLLLoss()

# One example per test triple, laid out as
# (object embedding, affordance embedding, label, object index, affordance index).
test_data = [
    (
        bert_word_to_embedding[obj_word],
        bert_word_to_embedding[affordance_word],
        label,
        word_to_index[obj_word],
        word_to_index[affordance_word],
    )
    for obj_word, affordance_word, label in test_pairs
]
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

In [None]:
# Evaluate the BERT probe on the test data: running mean loss, overall
# accuracy, per-object / per-affordance accuracy and confusion counts.
test_loss = 0
bert_probe.eval()

total = 0
correct = 0

# Objects and affordances share the same per-word tallies; they are split
# into two reports after the loop.
per_word_total = dict.fromkeys(bert_word_to_embedding, 0)
per_word_correct = dict.fromkeys(bert_word_to_embedding, 0)

tp_bert = 0
fp_bert = 0
tn_bert = 0
fn_bert = 0

# The batch counter has a dedicated name: it is the divisor of the running
# mean loss printed per batch, so no inner loop may overwrite it (the
# original reused `i` for the confusion-count loop, corrupting the printout).
for batch_idx, batch in enumerate(test_dataloader):

    obj = batch[0]
    affordance = batch[1]
    target = batch[2]

    with torch.no_grad():

        output = bert_probe(obj, affordance)

        bert_loss = criterion(output, target)
        test_loss += bert_loss.item()

        # Calculate total accuracy
        total += len(obj)

        prediction = torch.argmax(output, dim=1)
        correct_predictions = torch.eq(prediction, target)
        correct += correct_predictions.sum().item()

        # Convert to plain Python lists once per batch instead of once per
        # element inside the loops below.
        hits = correct_predictions.tolist()
        target_list = target.tolist()
        prediction_list = prediction.tolist()

        # Calculate per-object and per-affordance accuracy
        objects = [index_to_word[idx] for idx in batch[3].tolist()]
        affordances = [index_to_word[idx] for idx in batch[4].tolist()]

        for words in (objects, affordances):
            for word, hit in zip(words, hits):
                if hit:
                    per_word_correct[word] += 1
                per_word_total[word] += 1

        # Calculate tp,fp,tn,fn
        for gold, predicted in zip(target_list, prediction_list):
            if gold == 1 and predicted == 1:
                tp_bert += 1
            elif gold == 0 and predicted == 1:
                fp_bert += 1
            elif gold == 1 and predicted == 0:
                fn_bert += 1
            elif gold == 0 and predicted == 0:
                tn_bert += 1

        # Mean loss over the batches processed so far.
        print('>', np.round(test_loss/(batch_idx+1), 4))

accuracy_bert_probe = correct / total
# Words that never occurred in the test data are left out of the reports.
per_object_accuracy_bert_probe = {word : (per_word_correct[word] / per_word_total[word]) for word in unique_objects if per_word_total[word] > 0}
per_affordance_accuracy_bert_probe = {word : (per_word_correct[word] / per_word_total[word]) for word in unique_affordances if per_word_total[word] > 0}

print()
print(f'Total accuracy BERT probe: {np.round(accuracy_bert_probe * 100, 2)} %')
print()

print('Per-object accuracy BERT probe:')
for k,v in per_object_accuracy_bert_probe.items():
    print(f'{k} : {np.round(v * 100, 2)} %')
print()

print('Per-affordance accuracy BERT probe:')
for k,v in per_affordance_accuracy_bert_probe.items():
    print(f'{k} : {np.round(v * 100, 2)} %')

In [None]:
# Accuracy from the confusion counts: correct decisions over all decisions.
n_correct_bert = tp_bert + tn_bert
n_decisions_bert = n_correct_bert + fp_bert + fn_bert
accuracy_bert = n_correct_bert / n_decisions_bert
print(f'{np.round(accuracy_bert * 100, 2)}%')

In [None]:
# Precision = TP / (TP + FP). If the probe never predicted the positive class
# the ratio is undefined; report 0.0 instead of raising ZeroDivisionError.
predicted_positive_bert = tp_bert + fp_bert
precision_bert = tp_bert / predicted_positive_bert if predicted_positive_bert else 0.0
print(f'{np.round(precision_bert * 100, 2)}%')

In [None]:
# Recall = TP / (TP + FN). If there are no positive targets the ratio is
# undefined; report 0.0 instead of raising ZeroDivisionError.
gold_positive_bert = tp_bert + fn_bert
recall_bert = tp_bert / gold_positive_bert if gold_positive_bert else 0.0
print(f'{np.round(recall_bert * 100, 2)}%')

In [None]:
# F1 is the harmonic mean of precision and recall. When both are zero the
# formula divides by zero, so report 0.0 in that case.
precision_recall_sum_bert = recall_bert + precision_bert
f1_bert = (
    (2 * recall_bert * precision_bert) / precision_recall_sum_bert
    if precision_recall_sum_bert
    else 0.0
)
print(f'{np.round(f1_bert * 100, 2)}%')

## Testing the BERT Probe on seen objects

In [None]:
# For three object pairs, print the BERT probe's predicted class for every
# affordance: the pair header, then the first object's prediction, then the
# second object's prediction, then a blank line.
with torch.no_grad():
    for first_word, second_word in (
        ('sickle', 'carving knife'),
        ('banjo', 'guitar'),
        ('small boat', 'kayak'),
    ):
        for affordance in unique_affordances:
            first_vec = bert_word_to_embedding[first_word].unsqueeze(0)
            affordance_vec = bert_word_to_embedding[affordance].unsqueeze(0)
            output = bert_probe(first_vec, affordance_vec)
            print(f'{first_word}, {second_word}')
            print(f'{affordance}: {torch.argmax(output)}')

            second_vec = bert_word_to_embedding[second_word].unsqueeze(0)
            affordance_vec = bert_word_to_embedding[affordance].unsqueeze(0)
            output = bert_probe(second_vec, affordance_vec)
            print(f'{affordance}: {torch.argmax(output)}')
            print()

## Testing the VisualBERT Probe on test data

In [None]:
# Build a fresh probe and restore the weights trained on VisualBERT embeddings.
visual_bert_probe = Probe()
# `map_location` remaps the stored tensors onto `device`, so a checkpoint
# saved from a different device (e.g. a GPU) still loads here.
visual_bert_probe.load_state_dict(
    torch.load(
        "model_visual_bert_probe|epochs_2000|batch_size_64|learning_rate_0.005",
        map_location=device,
    )
)

In [None]:
# Evaluate the VisualBERT probe on the test data: running mean loss, overall
# accuracy, per-object / per-affordance accuracy and confusion counts.
visual_bert_probe = visual_bert_probe.to(device)
test_loss = 0
criterion = nn.NLLLoss()
visual_bert_probe.eval()

# The existing `test_dataloader` holds BERT embeddings in a five-field layout,
# so it can neither feed this probe nor be indexed at positions 5 and 6 as the
# original loop did. Build a VisualBERT loader with the same layout:
# (object embedding, affordance embedding, label, object index, affordance index).
visual_bert_test_data = [
    (
        visual_bert_word_to_embedding[obj_word],
        visual_bert_word_to_embedding[affordance_word],
        label,
        word_to_index[obj_word],
        word_to_index[affordance_word],
    )
    for obj_word, affordance_word, label in test_pairs
]
visual_bert_test_dataloader = DataLoader(visual_bert_test_data, batch_size=64, shuffle=True)

total = 0
correct = 0

# Objects and affordances share the same per-word tallies; they are split
# into two reports after the loop.
per_word_total = dict.fromkeys(visual_bert_word_to_embedding, 0)
per_word_correct = dict.fromkeys(visual_bert_word_to_embedding, 0)

tp_visual_bert = 0
fp_visual_bert = 0
tn_visual_bert = 0
fn_visual_bert = 0

# The batch counter has a dedicated name: it is the divisor of the running
# mean loss printed per batch, so no inner loop may overwrite it (the
# original reused `i` for the confusion-count loop, corrupting the printout).
for batch_idx, batch in enumerate(visual_bert_test_dataloader):

    obj = batch[0]
    affordance = batch[1]
    target = batch[2]

    with torch.no_grad():

        output = visual_bert_probe(obj, affordance)

        visual_bert_loss = criterion(output, target)
        test_loss += visual_bert_loss.item()

        # Calculate total accuracy
        total += len(obj)

        prediction = torch.argmax(output, dim=1)
        correct_predictions = torch.eq(prediction, target)
        correct += correct_predictions.sum().item()

        # Convert to plain Python lists once per batch instead of once per
        # element inside the loops below.
        hits = correct_predictions.tolist()
        target_list = target.tolist()
        prediction_list = prediction.tolist()

        # Calculate per word accuracy
        objects = [index_to_word[idx] for idx in batch[3].tolist()]
        affordances = [index_to_word[idx] for idx in batch[4].tolist()]

        for words in (objects, affordances):
            for word, hit in zip(words, hits):
                if hit:
                    per_word_correct[word] += 1
                per_word_total[word] += 1

        # Calculate tp,fp,tn,fn
        for gold, predicted in zip(target_list, prediction_list):
            if gold == 1 and predicted == 1:
                tp_visual_bert += 1
            elif gold == 0 and predicted == 1:
                fp_visual_bert += 1
            elif gold == 1 and predicted == 0:
                fn_visual_bert += 1
            elif gold == 0 and predicted == 0:
                tn_visual_bert += 1

        # Mean loss over the batches processed so far.
        print('>', np.round(test_loss/(batch_idx+1), 4))

accuracy_visual_bert_probe = correct / total
# Words that never occurred in the test data are left out of the reports.
per_object_accuracy_visual_bert_probe = {word : (per_word_correct[word] / per_word_total[word]) for word in unique_objects if per_word_total[word] > 0}
per_affordance_accuracy_visual_bert_probe = {word : (per_word_correct[word] / per_word_total[word]) for word in unique_affordances if per_word_total[word] > 0}

print(f'Total accuracy VisualBERT probe: {np.round(accuracy_visual_bert_probe * 100, 2)} %')
print()

print('Per-object accuracy VisualBERT probe:')
for k,v in per_object_accuracy_visual_bert_probe.items():
    print(f'{k} : {np.round(v * 100, 2)} %')
print()

print('Per-affordance accuracy VisualBERT probe:')
for k,v in per_affordance_accuracy_visual_bert_probe.items():
    print(f'{k} : {np.round(v * 100, 2)} %')

In [None]:
# Accuracy from the confusion counts: correct decisions over all decisions.
n_correct_visual_bert = tp_visual_bert + tn_visual_bert
n_decisions_visual_bert = n_correct_visual_bert + fp_visual_bert + fn_visual_bert
accuracy_visual_bert = n_correct_visual_bert / n_decisions_visual_bert
print(f'{np.round(accuracy_visual_bert * 100, 2)}%')

In [None]:
# Precision = TP / (TP + FP). If the probe never predicted the positive class
# the ratio is undefined; report 0.0 instead of raising ZeroDivisionError.
predicted_positive_visual_bert = tp_visual_bert + fp_visual_bert
precision_visual_bert = (
    tp_visual_bert / predicted_positive_visual_bert
    if predicted_positive_visual_bert
    else 0.0
)
print(f'{np.round(precision_visual_bert * 100, 2)}%')

In [None]:
# Recall = TP / (TP + FN). If there are no positive targets the ratio is
# undefined; report 0.0 instead of raising ZeroDivisionError.
gold_positive_visual_bert = tp_visual_bert + fn_visual_bert
recall_visual_bert = (
    tp_visual_bert / gold_positive_visual_bert
    if gold_positive_visual_bert
    else 0.0
)
print(f'{np.round(recall_visual_bert * 100, 2)}%')

In [None]:
# F1 is the harmonic mean of precision and recall. When both are zero the
# formula divides by zero, so report 0.0 in that case.
precision_recall_sum_visual_bert = recall_visual_bert + precision_visual_bert
f1_visual_bert = (
    (2 * recall_visual_bert * precision_visual_bert) / precision_recall_sum_visual_bert
    if precision_recall_sum_visual_bert
    else 0.0
)
print(f'{np.round(f1_visual_bert * 100, 2)}%')

## Testing the VisualBERT Probe on seen objects

In [None]:
# For three object pairs, print the VisualBERT probe's predicted class for
# every affordance: the pair header, then the first object's prediction, then
# the second object's prediction, then a blank line.
with torch.no_grad():
    for first_word, second_word in (
        ('sickle', 'carving knife'),
        ('banjo', 'guitar'),
        ('small boat', 'kayak'),
    ):
        for affordance in unique_affordances:
            first_vec = visual_bert_word_to_embedding[first_word].unsqueeze(0)
            affordance_vec = visual_bert_word_to_embedding[affordance].unsqueeze(0)
            output = visual_bert_probe(first_vec, affordance_vec)
            print(f'{first_word}, {second_word}')
            print(f'{affordance}: {torch.argmax(output)}')

            second_vec = visual_bert_word_to_embedding[second_word].unsqueeze(0)
            affordance_vec = visual_bert_word_to_embedding[affordance].unsqueeze(0)
            output = visual_bert_probe(second_vec, affordance_vec)
            print(f'{affordance}: {torch.argmax(output)}')
            print()

# Comparison of the results