In [1]:
import torch
import json
import numpy as np

from PIL import Image
from utility import ViltImageSetProcessor
from transformers import ViltImageProcessor, BertTokenizer
from models import MultiviewViltForQuestionAnswering, MultiviewViltModel
from torch import nn
from copy import deepcopy
from isvqa_data_setup import ISVQA
from torch.utils.data import DataLoader
from engine import max_to_one_hot
from collections import Counter

In [2]:
# Training split of ISVQA, built from the QA file, the nuScenes camera samples and the answers file.
isvqa = ISVQA(
    qa_path="/home/nikostheodoridis/isvqa/train_set.json",
    nuscenes_path="/home/nikostheodoridis/nuscenes/samples",
    answers_path="/home/nikostheodoridis/isvqa/answers.json",
)



In [3]:
# NOTE(review): positional args are unlabelled here; 768 matches the hidden size of
# last_hidden_state — confirm the remaining arguments against the MultiviewViltModel signature.
model = MultiviewViltModel(6, 210, 768, True, True, True)
model = model.to("cuda")

In [4]:
# One sample per batch, in dataset order, for quick manual inspection.
loader = DataLoader(dataset=isvqa, shuffle=False, batch_size=1)

In [5]:
# Take only the first (inputs, target) mini-batch for a forward-pass sanity check.
first_batch_iter = iter(loader)
batch = next(first_batch_iter)
del first_batch_iter

In [7]:
# Forward-only shape check: inference_mode disables autograd recording, so the activations
# of this forward pass are not kept around for a backward pass that never happens.
# NOTE(review): `model` was never put in eval mode here, so dropout (if any) is active —
# irrelevant for a shape check, but don't reuse these activations as real features.
with torch.inference_mode():
    hidden_state_shape = model(**batch[0]).last_hidden_state.shape
hidden_state_shape

torch.Size([1, 1300, 768])

In [3]:
# First training sample: an (inputs mapping, target) pair, as unpacked by the loaders below.
first_sample = isvqa[0]
first_sample

(OrderedDict([('input_ids',
               tensor([ 101, 2024, 2045, 5581, 5563, 2006, 1996, 2217, 1997, 1996, 2395,  102,
                          0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
                          0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
                          0,    0,    0,    0], device='cuda:0')),
              ('token_type_ids',
               tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')),
              ('attention_mask',
               tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')),
              ('pixel_values',
               tensor([[[[-0.2549, -0.2471, -0.2392,  ...,  0.5451,  0.5059,  0.4824],
                         [-0.2627, -0.2549, -0.2392,  ...,  0.55

In [16]:
def val_step(model, loader, acc_fn, answ_len):
    """
    Validate the model by going through every mini-batch in the validation dataloader once.

    Args:
        model: called as ``model(**X, labels=y)``; its output must expose ``.loss``
            (a scalar tensor) and ``.logits``. It is switched to eval mode and left there.
        loader: iterable yielding ``(X, y)`` pairs, where ``X`` is a mapping of keyword
            inputs for ``model`` and ``y`` holds the targets, one row per sample.
        acc_fn: called as ``acc_fn(pred, y, answ_len)`` with ``pred`` produced by
            ``max_to_one_hot(outputs.logits)``; assumed to return the fraction of correct
            samples in the mini-batch (as ``accuracy`` in this notebook does).
        answ_len: forwarded unchanged to ``acc_fn``.

    Returns:
        (avg_loss, avg_acc): per-sample weighted averages over the whole loader. Weighting
        by batch size keeps a smaller final mini-batch from counting as much as a full one.

    Raises:
        ValueError: if the loader yields no samples.
    """
    print("\tValidating...")
    model.eval()

    loss_sum = 0.0      # sum over batches of (batch loss * batch size)
    correct_sum = 0.0   # sum over batches of (batch accuracy * batch size) = number of correct samples
    n_samples = 0

    with torch.inference_mode():
        for X, y in loader:
            outputs = model(**X, labels=y)
            pred = max_to_one_hot(outputs.logits)
            acc = acc_fn(pred, y, answ_len)

            batch_size = len(y)
            # NOTE(review): assumes outputs.loss is a mean over the samples of the batch — confirm
            # against the model's loss reduction.
            loss_sum += outputs.loss.item() * batch_size
            correct_sum += acc * batch_size
            n_samples += batch_size

    if n_samples == 0:
        raise ValueError("val_step: loader yielded no samples")

    avg_loss = loss_sum / n_samples
    avg_acc = correct_sum / n_samples

    return avg_loss, avg_acc

# Copy: compare an untrained model against a trained checkpoint on the validation set

In [2]:
# NOTE(review): positional args are unlabelled; 768 matches the classifier input size used
# for the replacement head — confirm the rest against the MultiviewViltForQuestionAnswering signature.
model = MultiviewViltForQuestionAnswering(6, 210, 768, True, False, False)
model = model.to("cuda")



In [3]:
# Replacement classification head: hidden -> 2x hidden -> one logit per answer.
# NUM_ANSWERS matches the answ_len=429 passed to val_step below.
HIDDEN_SIZE = 768
NUM_ANSWERS = 429
classifier_head = nn.Sequential(
    nn.Linear(HIDDEN_SIZE, 2 * HIDDEN_SIZE),
    nn.LayerNorm(2 * HIDDEN_SIZE),
    nn.GELU(),
    nn.Linear(2 * HIDDEN_SIZE, NUM_ANSWERS),
)
model.model.classifier = classifier_head.to("cuda")

In [4]:
# Independent copy of the freshly built model: `trained_model` receives the checkpoint
# weights in the next cell, while `model` keeps its initial weights as a baseline.
# deepcopy duplicates every parameter, so this doubles the model's GPU memory.
trained_model = deepcopy(model)

In [5]:
# weights_only=True restricts unpickling to tensors and plain containers (a state dict needs
# nothing more), and map_location pins the tensors to the GPU regardless of the saving device.
CHECKPOINT_PATH = "/home/nikostheodoridis/Trained Models/2024-07-08 00:07:49/model.pth"
trained_state_dict = torch.load(CHECKPOINT_PATH, map_location="cuda", weights_only=True)
trained_model.load_state_dict(trained_state_dict)

<All keys matched successfully>

In [6]:
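# Disabled sanity check: the parameters are only identical right after deepcopy,
# i.e. before the checkpoint is loaded into `trained_model`.
# for p1, p2 in zip(model.parameters(), trained_model.parameters()):
#     assert torch.equal(p1, p2)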

In [7]:
# Validation split: same camera samples and answers file as training, different QA file.
val_set = ISVQA(
    nuscenes_path="/home/nikostheodoridis/nuscenes/samples",
    answers_path="/home/nikostheodoridis/isvqa/answers.json",
    qa_path="/home/nikostheodoridis/isvqa/val_set.json",
)

In [8]:
# Fixed order (no shuffling) so both models are evaluated on identical batches.
val_loader = DataLoader(dataset=val_set, shuffle=False, batch_size=6)

In [9]:
# targets = []
# untrained_predictions = []
# trained_predictions = []

# model.eval()
# trained_model.eval()
# for i in range(2576):
#     inputs, target = val_set[i]

#     targets.append(target)

#     with torch.inference_mode():
        


In [10]:
def accuracy(predictions: torch.Tensor, targets: torch.Tensor, answers_len: int) -> float:
    """
    Fraction of rows in which `predictions` and `targets` agree in exactly `answers_len` positions.

    When `answers_len` equals the number of columns, this is exact-match accuracy: a row
    counts only if the whole predicted vector equals the target vector.

    Raises ZeroDivisionError when `predictions` has no rows.
    """
    agreeing_per_row = (predictions == targets).sum(dim=1)
    rows_fully_matching = (agreeing_per_row == answers_len).sum()
    return rows_fully_matching.item() / len(predictions)

In [11]:
# Baseline: the model without the checkpoint weights.
untrained_loss, untrained_acc = val_step(model=model, loader=val_loader, acc_fn=accuracy, answ_len=429)

	Validating...


In [14]:
# Same evaluation for the model holding the loaded checkpoint weights.
trained_loss, trained_acc = val_step(model=trained_model, loader=val_loader, acc_fn=accuracy, answ_len=429)

	Validating...


In [13]:
# Loss on the first line, accuracy on the second.
print(untrained_loss, untrained_acc, sep="\n")

303.8099614342978
0.0011627906976744186


In [15]:
# Loss on the first line, accuracy on the second.
print(trained_loss, trained_acc, sep="\n")

2.5099486532945967
0.6003875968992246


In [20]:
# JSON text is UTF-8 by specification; read it explicitly as such instead of relying on
# the platform's default locale encoding.
with open("/home/nikostheodoridis/isvqa/answers_counter.json", encoding="utf-8") as f:
    answers_cnt = json.load(f)

In [35]:
# Counter's repr lists the answers from most to least frequent.
answer_counts = Counter(answers_cnt)
answer_counts

Counter({'yes': 13564,
         'no': 3734,
         'one': 3663,
         'white': 2893,
         'two': 2544,
         'red': 1205,
         'black': 1046,
         'blue': 1025,
         'three': 986,
         'green': 968,
         'yellow': 930,
         'orange': 782,
         'four': 529,
         'night': 464,
         'rainy': 434,
         'gray': 365,
         'black and white': 345,
         'silver': 287,
         'zero': 254,
         'five': 218,
         'orange and white': 178,
         'six': 156,
         'left': 151,
         'none': 147,
         'ahead': 147,
         'right': 141,
         'fedex': 136,
         'brown': 134,
         'cloudy': 129,
         'slow': 119,
         'ups': 114,
         'bus': 112,
         'raining': 111,
         'wet': 111,
         'sunny': 107,
         'ryder': 95,
         'urban': 93,
         'twenty-three': 92,
         'stop': 89,
         'day': 77,
         'hump': 69,
         'brick': 69,
         'red and white': 57,

In [4]:
# Start: verify that no training record also appears in the test split (data-leakage check).
# Prints True when the splits are disjoint, False otherwise.
import json
import random

train_path = "/home/nikostheodoridis/nuscenes-qa/train_set.json"
val_path = "/home/nikostheodoridis/nuscenes-qa/val_set.json"
test_path = "/home/nikostheodoridis/nuscenes-qa/test_set.json"


def _load_json(path):
    """Load a UTF-8 encoded JSON file."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def _record_key(record):
    """Hashable canonical form of a JSON-decoded record; dict key order is ignored."""
    return json.dumps(record, sort_keys=True)


train_data = _load_json(train_path)
val_data = _load_json(val_path)  # TODO: val_data is loaded but not included in the overlap check
test_data = _load_json(test_path)

# A set of canonical keys makes this O(len(train) + len(test)) instead of scanning the whole
# test list for every training record. Caveat: canonical JSON distinguishes 1 from 1.0 (and
# True from 1), which Python's == would treat as equal.
test_keys = {_record_key(record) for record in test_data}
no_train_test_overlap = not any(_record_key(record) in test_keys for record in train_data)
print(no_train_test_overlap)

True
