In [None]:
%reload_ext autoreload
%autoreload 2

import torch
from munch import Munch
from qsr_learning.train import train

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
from qsr_learning.entity import Entity

# A single 32x32 "trophy" entity at position p=(112, 112) with theta=0, expressed
# in the absolute frame of reference. It is passed below as a fixed entity for
# both data splits. (NOTE(review): whether p is the centre or a corner is
# defined by qsr_learning.entity.Entity — confirm there.)
trophy = Entity(
    name="trophy",
    p=(112, 112),
    size=(32, 32),
    theta=0,
    frame_of_reference="absolute",
)

In [None]:
# Experiment configuration passed to `train` and reused by the inspection cell.
# The train and validation splits share every data setting except their size
# and random seed. Both are therefore generated from one template, so editing a
# shared setting cannot make the two splits silently diverge.
# Variants previously sketched here (not enabled):
#   - entity_names=emoji_names  (a larger entity list, defined elsewhere)
#   - theta_range=(0, 2 * math.pi)  (was suggested for the validation split)
config = Munch(
    data=Munch(
        {
            phase: Munch(
                entity_names=["octopus", "trophy"],
                relation_names=["left_of"],
                num_entities=2,
                fixed_entities=[trophy],
                frame_of_reference="absolute",
                w_range=(32, 32),
                h_range=(32, 32),
                theta_range=(0, 0),
                num_samples=num_samples,
                shuffle=True,
                random_seed=random_seed,
            )
            # (phase, num_samples, random_seed): the only per-split differences.
            # Each iteration builds fresh lists, so the splits share no mutable state.
            for phase, num_samples, random_seed in (
                ("train", 2 ** 10, 0),
                ("validation", 2 ** 4, 1),
            )
        }
    ),
    model=Munch(ent_dim=10, rel_dim=10, cnn_model="resnet18", pretrained=True),
    train=Munch(batch_size=128, num_epochs=100, freeze="all", lr=0.001),
)

In [None]:
# Train using the settings in `config` on the device selected in the first cell.
train(config, device)

In [None]:
from ipywidgets import interact
from munch import Munch
from PIL import Image
from qsr_learning.data import DRLDataset
from torch.utils.data import DataLoader

# Build datasets and loaders from the same `config.data` / `config.train`
# settings that were passed to `train`, so that inspected samples match the
# configured data.
phases = ["train", "validation"]
datasets = Munch({phase: DRLDataset(**config.data[phase]) for phase in phases})
data_loader = Munch(
    {
        phase: DataLoader(
            datasets[phase],
            batch_size=config.train.batch_size,
            num_workers=4,
        )
        for phase in phases
    }
)

# Take only the first training batch for inspection. Unlike `for ...: break`,
# `next(iter(...))` fails loudly on an empty loader. It never leaves `batch`
# undefined or silently stale from a previous run of this cell.
batch = next(iter(data_loader["train"]))


@interact(row=(0, len(batch[0]) - 1))
def batch2sample(row):
    """Show one sample from the first training ``batch`` for visual inspection.

    Driven by an ``ipywidgets`` slider whose range spans every row of ``batch``.
    Reads the module-level ``batch`` and ``datasets`` built in the previous cell.

    Args:
        row: Index of the sample within ``batch``.

    Displays the de-normalized image, then prints the decoded
    (head, relation, tail) question and its ground-truth answer.
    """
    dataset = datasets["train"]
    image_tensor = batch[0][row]
    question_tensor = batch[1][row]
    answer_tensor = batch[2][row]

    # Invert the per-channel normalization back to 0..255 pixel values.
    # Assumes the dataset stores (pixel / 255 - mean) / std — TODO confirm in
    # DRLDataset. Round before casting, because a plain float -> uint8 cast
    # truncates (127.9999 -> 127). Clamp so float error or out-of-gamut values
    # cannot wrap around.
    mean = dataset.mean.view(-1, 1, 1)
    std = dataset.std.view(-1, 1, 1)
    pixels = (255 * (std * image_tensor + mean)).round().clamp(0, 255)
    # CHW tensor -> HWC array, the layout PIL expects.
    image = Image.fromarray(pixels.permute(1, 2, 0).numpy().astype("uint8"))

    # The question is encoded as (head entity, relation, tail entity) indices.
    head, relation, tail = question_tensor.tolist()
    question = (dataset.idx2ent[head], dataset.idx2rel[relation], dataset.idx2ent[tail])
    answer = bool(answer_tensor)

    display(image)
    print(question)
    print("Ground truth: ", answer)