In [1]:
import torch
from FPdataset import fixationPredictionDataset, collate_batch
import random
from IPython.core.display import display, HTML
from torch.utils.data import DataLoader
import torch.nn.functional as F

FPdataset = fixationPredictionDataset('TRT', average=True)

model, mapping = torch.load('./FPmodels/model-%s-%s-uniD.pt'%('simpleLSTM', FPdataset.feature))
model.eval()

BATCH_SIZE=32 
device='cpu'
test_loader = DataLoader(FPdataset.test_set, BATCH_SIZE, shuffle=False, collate_fn=collate_batch)

with torch.no_grad():
    test_loss = []
    test_acc = []
    test_acc1 = []
    for i, (input_ids, labels) in enumerate(test_loader):
        input_ids = input_ids.to(device=device)
        labels = labels.to(device=device)

        pred, _ =  model(input_ids)
        pred = pred[input_ids != FPdataset.mapping.pad_id]
        labels = labels[input_ids != FPdataset.mapping.pad_id]

        loss = F.l1_loss(pred, labels.float())
        test_loss.append(loss.item())

        acc = (torch.round(pred).long() == labels).float().mean()
        test_acc.append(acc.item())

        accWithin1 = ((torch.round(pred) <= labels.float()+1) & (torch.round(pred) >= labels.float()-1) ).float().mean()
        test_acc1.append(accWithin1.item())

    test_loss = sum(test_loss) / len(test_loss)
    test_acc = sum(test_acc) / len(test_acc)
    test_acc1 = sum(test_acc1) / len(test_acc1)

print('test loss: %.4f  test acc: %.4f  test acc (+-1): %.4f '%(test_loss, test_acc, test_acc1))
    

test loss: 1.5112  test acc: 0.3056  test acc (+-1): 0.5936 


In [4]:
def highlighter(word, value):
    return '<span style="background-color:rgba(40,116,166,%.2f)">'%(value/20) +word+ '</span>' 

html = ''
for n in range(5):
    with torch.no_grad():
        i = random.randint(0, len(FPdataset.test_set)-1)
        input_tokens = FPdataset.test_set.input_tokens[i]
        labels = FPdataset.test_set.labels[i]
        input_ids = torch.LongTensor(FPdataset.test_set.input_ids[i]).unsqueeze(0)
        pred = torch.round(model(input_ids)[0].squeeze()).long().tolist()
    print(pred)
    print(labels)

    html += '<h5>Prediction: </h5>' + ' '.join([highlighter(w,v) for w,v in zip(input_tokens, pred)])
    html += '<h5>True: </h5>' + ' '.join([highlighter(w,v) for w,v in zip(input_tokens, labels)])
    html += '<hr>'

display(HTML(html))


[7, 8, 10, 6, 5, 2, 9, 7, 3, 5, 5, 8, 7, 3, 10, 2, 0]
[6, 6, 11, 7, 4, 5, 8, 7, 5, 7, 6, 7, 8, 6, 10, 8, 0]
[2, 8, 10, 1, 4, 4, 9, 0, 9, 3, 4, 10, 1, 4, 10, 0, 10, 2, 7, 8, 0]
[2, 5, 10, 2, 4, 2, 4, 0, 10, 1, 2, 10, 0, 3, 10, 0, 10, 2, 6, 4, 0]
[2, 5, 6, 1, 2, 3, 7, 6, 6, 1, 6, 0, 9, 9, 7, 0, 2, 6, 2, 3, 5, 1, 2, 6, 0]
[2, 0, 2, 2, 0, 1, 8, 4, 8, 2, 8, 0, 11, 11, 9, 0, 4, 3, 0, 3, 1, 0, 1, 5, 0]
[2, 8, 0, 5, 4, 7, 2, 3, 6, 2, 6, 9, 3, 5, 7, 6, 1, 5, 0, 2, 7, 3, 9, 0, 5, 6, 10, 0, 2, 10, 2, 5, 6, 0]
[2, 11, 0, 3, 5, 9, 5, 0, 5, 0, 7, 9, 3, 5, 10, 5, 0, 7, 0, 4, 9, 2, 10, 0, 7, 3, 8, 0, 0, 11, 0, 6, 8, 0]
[0, 8, 4, 3, 3, 6, 2, 3, 5, 4, 3, 6, 3, 9, 9, 2, 5, 9, 0, 11, 1, 3, 0, 0, 3, 10, 9, 4, 6, 9, 4, 2, 9, 9, 0, 0, 1, 0, 7, 10, 1, 6, 5, 6, 2, 3, 7, 0, 0, 0, 2, 10, 3, 7, 0, 6, 6, 9, 3, 4, 3, 10, 1, 4, 8, 0, 0, 0, 10, 6, 0, 0, 0, 8, 8, 6, 4, 6, 3, 3, 0, 0, 0, 2, 8, 8, 2, 3, 2, 9, 0, 6, 9, 10, 9, 1, 0, 0, 0, 6, 0, 7, 10, 10, 9, 0, 3, 0, 6, 7, 0, 8, 0, 5, 10, 1, 3, 8, 7, 10, 6, 10, 0, 2, 9, 0