In [None]:
%reload_ext autoreload
%autoreload 2
import os

# Expose only GPU 0 to this process; this runs before `torch` is imported in
# the following cell.
os.environ.update(CUDA_VISIBLE_DEVICES="0")

In [None]:
import math

import pytorch_lightning as pl
import torch
from munch import Munch
from torch.utils.data import DataLoader, random_split

from qsr_learning.data import DRLDataset
from qsr_learning.models import DRLNet

In [None]:
# Root configuration namespace; the `dataset`, `data_loader` and `model`
# sections are attached to it in later cells.
config = Munch()

# Dataset

In [None]:
from qsr_learning.entity import emoji_names

In [None]:
# Keyword arguments forwarded verbatim to `DRLDataset(**config.dataset)`.
config.dataset = Munch(
    {
        "entity_names": ["octopus", "trophy"],
        "relation_names": ["left_of", "right_of"],
        "num_entities": 2,
        "frame_of_reference": "intrinsic",
        # Identical lower and upper bounds: width and height are fixed at 32.
        "w_range": (32, 32),
        "h_range": (32, 32),
        # math.tau == 2 * math.pi, i.e. a full turn.
        "theta_range": (0, math.tau),
        "add_bbox": False,
        "add_front": False,
        "transform": None,
        "canvas_size": (224, 224),
        # Sized to cover the train + validation + test split made afterwards.
        "num_samples": 100_000 + 10_000 + 10_000,
    }
)

In [None]:
# Build the full dataset once, then carve it into train / validation / test
# subsets. The explicitly seeded generator keeps the split reproducible.
dataset = DRLDataset(**config.dataset)
train_dataset, validation_dataset, test_dataset = random_split(
    dataset=dataset,
    lengths=[100_000, 10_000, 10_000],
    generator=torch.Generator().manual_seed(0),
)

# Data Loader

In [None]:
# Loader settings; the validation loader reuses them with `shuffle` overridden.
config.data_loader = Munch(
    {
        "batch_size": 256,
        "shuffle": True,
        "num_workers": 4,
        "pin_memory": True,
    }
)

In [None]:
# The training loader takes the shared settings as-is; the validation loader
# takes a copy with shuffling switched off so its batches follow dataset order.
train_loader = DataLoader(dataset=train_dataset, **config.data_loader)
validation_loader = DataLoader(
    dataset=validation_dataset, **dict(config.data_loader, shuffle=False)
)

In [None]:
# Keyword arguments forwarded verbatim to `DRLNet(**config.model)`.
config.model = Munch(
    vision_model="resnet18",
    # 3 channels followed by the canvas height and width.
    image_size=(3, *config.dataset.canvas_size),
    # One embedding per entry in the dataset's vocabulary.
    num_embeddings=len(dataset.word2idx),
    embedding_dim=10,
    # Number of elements in the question tensor of the first sample.
    question_len=dataset[0][1].shape.numel(),
)

model = DRLNet(**config.model)
lightning_checkpoint_path = (
    "lightning_logs/version_29/checkpoints/epoch=99-step=19599.ckpt"
)
# `map_location="cpu"` lets a checkpoint saved from a GPU run be read without
# needing that GPU; the model itself is still on the CPU at this point.
# NOTE: `torch.load` unpickles the file, so only load checkpoints you trust.
model.load_state_dict(
    torch.load(lightning_checkpoint_path, map_location="cpu")["state_dict"]
)
# Inference mode; the trailing semicolon suppresses the module repr output.
model.eval();

## Evaluate on manual datasets

In [None]:
from pprint import pprint

from ipywidgets import interact
from PIL import Image

from qsr_learning.data import Question, draw_entities
from qsr_learning.entity import Entity
from qsr_learning.relation import above, below, left_of, right_of


@interact(
    frame_of_reference=(0, 1),
    x1=(0, 190),
    y1=(0, 190),
    theta1=(0, 360),
    x2=(0, 190),
    y2=(0, 190),
    theta2=(0, 360),
)
def test_spatial_relations(
    frame_of_reference=1, x1=64, y1=64, theta1=0, x2=128, y2=128, theta2=150
):
    """Render a hand-placed octopus/trophy scene and query the model about it.

    Every argument is driven by an ipywidgets slider. The scene is drawn,
    displayed, passed through ``dataset.transform`` and fed to ``model``
    together with the question "octopus right_of trophy"; the ground truth,
    the prediction, the sigmoid score and whether they agree are printed.

    Args:
        frame_of_reference: 0 selects "absolute", 1 selects "intrinsic".
        x1, y1: passed as ``p`` to the octopus entity.
        theta1: octopus orientation in degrees (converted to radians).
        x2, y2: passed as ``p`` to the trophy entity.
        theta2: trophy orientation in degrees (converted to radians).
    """
    frame_name = {0: "absolute", 1: "intrinsic"}[frame_of_reference]
    entity1 = Entity(
        name="octopus",
        frame_of_reference=frame_name,
        p=(x1, y1),
        theta=theta1 / 360 * 2 * math.pi,
        size=(32, 32),
    )
    entity2 = Entity(
        name="trophy",
        frame_of_reference=frame_name,
        p=(x2, y2),
        theta=theta2 / 360 * 2 * math.pi,
        size=(32, 32),
    )
    # NOTE(review): the scene is drawn with add_bbox=True while config.dataset
    # uses add_bbox=False, so the model is shown bounding boxes here -- confirm
    # that this is intended.
    image = draw_entities([entity1, entity2], add_bbox=True)
    # Flatten the transparent drawing onto a black background before dropping
    # the alpha channel.
    background = Image.new("RGBA", image.size, (0, 0, 0))
    image = Image.alpha_composite(background, image).convert("RGB")
    display(image)
    image_t = dataset.transform(image)

    # Only `right_of` is queried; add further relation functions to this list
    # to ask more questions about the same scene.
    questions = []
    answers = []
    for relation in [right_of]:
        questions.append(Question(entity1.name, relation.__name__, entity2.name))
        answers.append(relation(entity1, entity2))

    for question, answer in zip(questions, answers):
        question_t = torch.tensor([dataset.word2idx[word] for word in question])
        with torch.no_grad():
            pred_t = model(image_t.unsqueeze(0), question_t.unsqueeze(0))
        # The model output is treated as a logit: squash it once and reuse the
        # probability for both the score and the thresholded prediction.
        prob_t = pred_t.sigmoid()
        score = prob_t.item()
        pred = bool(prob_t.round())
        print(
            f"\n{question.head:7} {question.relation:8} {question.tail:7}\n\n"
            f"Ground Truth: {answer}\n"
            f"Prediction  : {pred}\n"
            f"Score       : {score:3.2f}\n"
            f"Correct     : {answer==pred:1}"
        )

## Display incorrect predictions

TODO: Getting the samples that were predicted incorrectly does not work yet.

In [None]:
from tqdm.auto import tqdm

# Collect the positions (within `validation_dataset`) of every sample the model
# gets wrong. The positions are valid because `validation_loader` does not
# shuffle, so batches arrive in dataset order.

# Prefer the GPU, but fall back to the CPU so the cell still runs without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
idx_incorrect = []
model.to(device)
# Running count of samples seen so far. Deriving the start index from
# `i * batch_size` is wrong whenever the final batch is shorter than the rest.
offset = 0
# `len(validation_loader)` is the exact number of batches.
with torch.no_grad(), tqdm(total=len(validation_loader)) as pbar:
    for image, question, answer in validation_loader:
        batch_size = answer.shape[0]
        logits = model(image.to(device), question.to(device))
        # Flatten and bring the predictions back to the CPU so they line up
        # element-wise with `answer` and can index the CPU `idx` tensor.
        pred = logits.sigmoid().round().reshape(-1).cpu()
        mismatch = pred != answer.reshape(-1).to(pred.dtype)
        idx = torch.arange(offset, offset + batch_size)
        idx_incorrect.extend(idx[mismatch].tolist())
        offset += batch_size
        pbar.update(1)
# Later cells run the model on CPU tensors, so move it back.
model.to(torch.device("cpu"))

from ipywidgets import interact
from PIL import Image

subset = validation_dataset


# Clamp the slider maximum so the widget is still valid when nothing was
# predicted incorrectly.
@interact(idx=(0, max(len(idx_incorrect) - 1, 0)))
def display_sample(idx=0):
    """Show one incorrectly predicted validation sample.

    Args:
        idx: position in ``idx_incorrect`` (not in the dataset); the stored
            value is then used to index ``subset``.

    Displays the de-normalised image and prints the question words, the
    ground-truth answer and the model's prediction.
    """
    if not idx_incorrect:
        print("No incorrect predictions to display.")
        return
    sample_idx = idx_incorrect[idx]
    image_t, question_t, answer_t = subset[sample_idx]
    with torch.no_grad():
        pred_t = model(image_t.unsqueeze(0), question_t.unsqueeze(0))
    # Undo the normalisation (presumably the per-channel statistics applied by
    # the dataset's transform -- confirm), scale to 0..255 and reorder from
    # channel-first to channel-last for PIL.
    image = Image.fromarray(
        (255 * (dataset.std.view(-1, 1, 1) * image_t + dataset.mean.view(-1, 1, 1)))
        .permute(1, 2, 0)
        .numpy()
        .astype("uint8")
    )
    head, relation, tail = question_t.tolist()
    question = (
        dataset.idx2word[head],
        dataset.idx2word[relation],
        dataset.idx2word[tail],
    )
    answer = bool(answer_t)
    # The model emits a logit; apply the sigmoid before thresholding, as every
    # other prediction in this notebook does.
    pred = bool(pred_t.sigmoid().round())
    display(image)
    print(question)
    print("Ground truth: ", answer)
    print("Prediction: ", pred)