In [None]:
# Reload imported project modules (e.g. `qsr_learning`) before each cell runs,
# so edits to the package are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [None]:
import os

import torch
from ipywidgets import interact
# from qsr_learning.data.data import draw, generate_objects

# Restrict this process to GPU 0; this only takes effect if it is set before
# CUDA is initialised in the kernel.
os.environ.update(CUDA_VISIBLE_DEVICES="0")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Test the `Entity` Class

In [None]:
import math

from PIL import Image
from qsr_learning.entity import Entity

In [None]:
# Draw two entities (with bounding boxes and orientation markers) on a grey
# RGBA canvas. `theta` is given in degrees and converted to radians.
canvas = Image.new("RGBA", (224, 224), (127, 127, 127, 127))
entity1 = Entity(
    name="octopus",
    frame_of_reference="absolute",
    p=(30, 30),
    theta=0 / 360 * 2 * math.pi,
    size=(32, 32),
)
entity1.draw(canvas, show_bbox=True, orientation_marker=True)
entity2 = Entity(
    name="trophy",
    frame_of_reference="absolute",
    p=(60, 60),
    theta=90 / 360 * 2 * math.pi,
    size=(32, 32),
)
entity2.draw(canvas, show_bbox=True, orientation_marker=True)
# End the cell with the canvas so the drawing is rendered as the cell output
# regardless of what `Entity.draw` returns.
canvas

# Test the Relations

In [None]:
from ipywidgets import interact
from qsr_learning.relation import above, below, left_of, right_of


@interact(
    frame_of_reference=(0, 1),
    x1=(0, 150),
    y1=(0, 150),
    theta1=(0, 360),
    x2=(0, 150),
    y2=(0, 150),
    theta2=(0, 360),
)
def test_spatial_relations(
    frame_of_reference=0, x1=64, y1=64, theta1=0, x2=128, y2=128, theta2=150
):
    """Interactively place an octopus and a trophy and list their relations.

    ``frame_of_reference`` selects ``"absolute"`` (0) or ``"intrinsic"`` (1).
    ``theta1``/``theta2`` are in degrees and converted to radians for
    `Entity`. Both entities are drawn with bounding boxes and orientation
    markers and the canvas is displayed. Then every relation that holds is
    printed, first with the octopus as subject, then with the trophy.
    """
    reference_frame = {0: "absolute", 1: "intrinsic"}[frame_of_reference]
    canvas = Image.new("RGBA", (224, 224), (127, 127, 127, 127))

    placements = [
        ("octopus", x1, y1, theta1),
        ("trophy", x2, y2, theta2),
    ]
    drawn = []
    for name, x, y, degrees in placements:
        entity = Entity(
            name=name,
            frame_of_reference=reference_frame,
            p=(x, y),
            theta=degrees / 360 * 2 * math.pi,
            size=(32, 32),
        )
        entity.draw(canvas, show_bbox=True, orientation_marker=True)
        drawn.append(entity)
    entity1, entity2 = drawn
    display(canvas)

    relations = [left_of, right_of, above, below]
    for subject, other in ((entity1, entity2), (entity2, entity1)):
        for relation in relations:
            if relation(subject, other):
                print(subject.name, relation.__name__, other.name)

# Test the Samples

In [None]:
from qsr_learning.data import draw_entities, generate_entities, generate_questions
from qsr_learning.entity import emoji_names
from qsr_learning.relation import above, below, left_of, right_of

# Relations for which positive/negative questions are generated.
relations = [left_of, right_of, above, below]

# NOTE(review): no random seed is passed here, so the scene presumably changes
# on every run — confirm whether `generate_entities` accepts a seed.
entities = generate_entities(
    entity_names=emoji_names,  # or a fixed subset, e.g. ["octopus", "trophy"]
    num_entities=2,
    frame_of_reference="absolute",
    w_range=(16, 64),
    h_range=(16, 64),
)
image = draw_entities(entities, show_bbox=True, orientation_marker=False)
display(image)

positive_questions, negative_questions = generate_questions(entities, relations)
display(positive_questions, negative_questions)

# Test the Dataset

In [None]:
import torch
import torchvision
from qsr_learning.data import DRLDataset, get_mean_and_std
from torch.utils.data import DataLoader

# Scene parameters for the dataset smoke test. These constants are what is
# actually passed to `DRLDataset`, so they cannot drift from the call. The
# relation list matches the two relations the dataset is built with.
entity_names = ["octopus", "trophy"]
relation_names = ["left_of", "right_of"]
num_entities = 2
frame_of_reference = "absolute"


drl_dataset = DRLDataset(
    entity_names=entity_names,
    relation_names=relation_names,
    num_entities=num_entities,
    frame_of_reference=frame_of_reference,
    num_samples=8,
    w_range=(32, 32),  # min == max: every entity gets the same width
    h_range=(32, 32),  # ... and the same height
    theta_range=(0.0, 2 * math.pi),  # radians: a full turn of orientations
    random_seed=42,  # presumably makes generated scenes reproducible
)

# Loader over the 8-sample toy dataset (not iterated in this cell).
loader = DataLoader(dataset=drl_dataset, batch_size=8)

# Generate one scene directly from the dataset and show its image.
scene = drl_dataset.generate_scene()  # -> (image, questions, answers)
image, questions, answers = scene
display(image)


def tensor2question(question, drl_dataset):
    """Decode an encoded ``(head, relation, tail)`` question into names.

    Args:
        question: three integer indices, either a length-3 tensor or any
            3-element sequence of ints.
        drl_dataset: object exposing ``idx2ent`` and ``idx2rel`` lookups
            indexed by integer ids.

    Returns:
        ``(head_entity_name, relation_name, tail_entity_name)``.
    """
    # Convert every element to a plain int. Tensor elements hash by identity,
    # so using them directly as dict keys fails if the lookups are dicts.
    # int() also works for list-based lookups and plain int sequences.
    e1, r, e2 = (int(index) for index in question)
    return drl_dataset.idx2ent[e1], drl_dataset.idx2rel[r], drl_dataset.idx2ent[e2]


def tensor2answer(answer):
    """Convert an encoded answer (single-element tensor or scalar) to a bool."""
    return True if answer else False


# Pair each decoded (head, relation, tail) question with its boolean answer.
decoded_questions = [tensor2question(question, drl_dataset) for question in questions]
decoded_answers = [tensor2answer(answer) for answer in answers]
list(zip(decoded_questions, decoded_answers))

# Test a Trained Model

In [None]:
import math

import torch
import torch.nn as nn
import torchvision
from munch import Munch
from torch.utils.data import DataLoader

from qsr_learning.data import DRLDataset, get_mean_and_std
from qsr_learning.entity import emoji_names
from qsr_learning.models import HadarmardFusionNet

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def step(model, criterion, optimizer, phase, batch, result, freeze, device):
    """Run one forward pass (plus backward/update in the train phase) on a batch.

    Args:
        model: network called as ``model(images, questions)``. It must expose
            an ``image_encoder`` submodule when ``freeze == "all"``.
        criterion: loss module. The loss is divided by the batch size before
            backprop, which assumes a summed reduction (e.g.
            ``nn.BCELoss(reduction="sum")``).
        optimizer: stepped only when ``phase == "train"``. It may be ``None``
            for other phases.
        phase: ``"train"`` enables gradients and parameter updates. Any other
            value runs in eval mode without gradients.
        batch: ``(images, questions, answers)`` tensors.
        result: mapping ``phase -> obj`` whose ``total_loss`` and
            ``num_correct`` attributes are incremented in place.
            ``total_loss`` accumulates the *summed* loss, so dividing it by the
            dataset size gives the mean per-sample loss.
        freeze: ``"all"`` keeps the image encoder in eval mode. ``"bn"`` keeps
            every ``BatchNorm2d`` layer in eval mode during training.
        device: device the batch tensors are moved to.
    """
    is_train = phase == "train"
    if is_train:
        model.train()
        if freeze == "all":
            model.image_encoder.eval()
        if freeze == "bn":  # Common in training object detection models
            for layer in model.modules():
                if isinstance(layer, nn.BatchNorm2d):
                    layer.eval()
    else:
        model.eval()
    images, questions, answers = (item.to(device) for item in batch)
    batch_size = images.shape[0]
    # Scope the grad mode to this step. A bare set_grad_enabled(False) call
    # would leave autograd disabled for the rest of the kernel session.
    with torch.set_grad_enabled(is_train):
        model.zero_grad()
        out = model(images, questions)
        loss = criterion(out, answers.float()) / batch_size
        # Undo the per-batch averaging when accumulating, so total_loss divided
        # by the dataset size (as in report_result) is a true per-sample mean.
        result[phase].total_loss += loss.item() * batch_size
        result[phase].num_correct += ((out > 0.5) == answers).sum().item()
        if is_train:
            loss.backward()
            optimizer.step()


def report_result(epoch, phases, result, data_loader):
    """Print the per-sample mean loss and accuracy of each phase for an epoch.

    Args:
        epoch: label stored under the ``"epoch"`` key.
        phases: phase names to report, in output order.
        result: mapping ``phase -> obj`` with ``total_loss`` and
            ``num_correct`` totals for the epoch.
        data_loader: mapping ``phase -> loader``. The size of
            ``loader.dataset`` is the denominator for both metrics.
    """
    log = {"epoch": epoch}
    for phase in phases:
        num_samples = len(data_loader[phase].dataset)
        stats = result[phase]
        log[phase + "_loss"] = stats.total_loss / num_samples
        log[phase + "_accuracy"] = stats.num_correct / num_samples
    # tune.report(**log)  # presumably a Ray Tune hook; currently disabled
    print(log)


def make_data_config(**overrides):
    """Return a fresh `Munch` of `DRLDataset` keyword arguments.

    The train and validation splits share every setting. Building both from
    this one function keeps them from silently diverging when a value is edited
    in only one place. Pass keyword overrides for split-specific settings
    (e.g. ``theta_range=(0, 2 * math.pi)`` for validation). A new list object
    is created on every call, so mutating one split's ``entity_names`` does
    not affect the other.
    """
    data_config = Munch(
        # entity_names=emoji_names,
        entity_names=["octopus", "trophy"],
        relation_names=["left_of", "right_of"],
        num_entities=2,
        frame_of_reference="absolute",
        w_range=(32, 32),
        h_range=(32, 32),
        theta_range=(0, 0),
        num_samples=2 ** 10,
        shuffle=True,
    )
    data_config.update(overrides)
    return data_config


config = Munch(
    data=Munch(
        train=make_data_config(),
        # NOTE(review): both splits use identical settings. If DRLDataset's
        # default random_seed is fixed, train and validation may contain the
        # same scenes. Confirm, and pass distinct seeds if so.
        validation=make_data_config(),
    ),
    model=Munch(
        ent_dim=10,
        rel_dim=10,
        cnn_model="resnet18",
        pretrained=False,
    ),
    train=Munch(batch_size=128, num_epochs=100, freeze="all", lr=0.001),
)
phases = ["train", "validation"]
datasets = Munch({phase: DRLDataset(**config.data[phase]) for phase in phases})
# NOTE(review): `shuffle` from the data config goes to DRLDataset, not to
# DataLoader, so the loaders themselves iterate in order. Confirm intended.
data_loader = Munch(
    {
        phase: DataLoader(
            datasets[phase],
            batch_size=config.train.batch_size,
            num_workers=4,
        )
        for phase in phases
    }
)

# Grab one training batch for inspection. Unlike a `for ...: break` loop,
# `next(iter(...))` fails here if the loader is empty, instead of leaving
# `batch` undefined (or stale from an earlier run) for later cells.
batch = next(iter(data_loader["train"]))


# `map_location` lets a checkpoint saved on a GPU machine load on a CPU-only
# host (and vice versa).
# NOTE(review): `torch.load` unpickles the file, so only load trusted
# checkpoints. Newer PyTorch versions support `weights_only=True`.
model_state_dict = torch.load("model.pt", map_location=device)
model = HadarmardFusionNet(datasets.train, **config.model)
model.to(device)
model.load_state_dict(model_state_dict)

In [None]:
# Test Performance: evaluate the saved checkpoint on the validation split.

model_state_dict = torch.load('model.pt', map_location=device)
model = HadarmardFusionNet(datasets.train, **config.model)
model.to(device)
model.load_state_dict(model_state_dict)
criterion = nn.BCELoss(reduction="sum")

result = Munch()
phases = ["validation"]
# `no_grad` restores the previous autograd mode on exit. Evaluation therefore
# cannot leave gradients globally disabled for later cells, whatever `step`
# does internally.
with torch.no_grad():
    for phase in phases:
        result[phase] = Munch(total_loss=0, num_correct=0)
        # NOTE: this loop rebinds the notebook-global `batch`, which the
        # inspection cell below reads. After this cell it holds the last
        # validation batch.
        for batch in data_loader[phase]:
            step(
                model,
                criterion,
                None,
                phase,
                batch,
                result,
                config.train.freeze,
                device,
            )
report_result("0", phases, result, data_loader)

In [None]:
from ipywidgets import interact
from PIL import Image


# Slider bounds come from the batch actually loaded, not a hard-coded 127,
# so a batch smaller than 128 rows cannot cause an IndexError.
@interact(row=(0, len(batch[0]) - 1))
def batch2sample(row):
    """Show one sample of the notebook-global `batch` with the model's prediction.

    The image is un-normalized with the validation dataset's mean/std. The
    ``(head, relation, tail)`` question and ground-truth answer are decoded,
    and `model` is run on that single sample.

    NOTE(review): `batch` is whichever batch the most recently executed cell
    bound. That is a training batch from the data-loading cell, or the last
    validation batch after the evaluation cell. The statistics always come
    from the validation dataset. Confirm this pairing is intended.
    """
    dataset = datasets["validation"]
    image_tensor, question_tensor, answer_tensor = (
        batch[0][row],
        batch[1][row],
        batch[2][row],
    )
    pixels = 255 * (
        dataset.std.view(-1, 1, 1) * image_tensor + dataset.mean.view(-1, 1, 1)
    )
    # Clamp before the uint8 cast. Float error can push values slightly
    # outside [0, 255], and out-of-range floats are not saturated by astype.
    image = Image.fromarray(
        pixels.clamp(0, 255).permute(1, 2, 0).numpy().astype("uint8")
    )
    head, relation, tail = question_tensor.tolist()
    question = (dataset.idx2ent[head], dataset.idx2rel[relation], dataset.idx2ent[tail])
    answer = bool(answer_tensor)
    display(image)
    print(question)
    print("Ground truth: ", answer)
    # A freshly constructed/loaded module is in train mode. Switch to eval so
    # dropout/batch-norm layers (if any) behave as at inference time.
    model.eval()
    with torch.no_grad():
        prediction = model(image_tensor.to(device).unsqueeze(0), question_tensor.to(device).unsqueeze(0))
    print("Prediction: {} (p={:3.2f})".format(prediction.item() > 0.5, prediction.item()))