In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up live, without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [None]:
import torch
import torch.nn as nn
from munch import Munch
from ray import tune


def report_result(epoch, phases, result, data_loader):
    """Print one epoch's per-sample loss and accuracy for every phase.

    Args:
        epoch: Epoch index, stored under the ``"epoch"`` key of the log.
        phases: Iterable of phase names (strings).
        result: Mapping from phase name to an object exposing the running
            counters ``total_loss`` and ``num_correct``.
        data_loader: Mapping from phase name to an object whose ``dataset``
            attribute supports ``len()``; that length is the normaliser.

    Returns:
        None. The assembled log dictionary is printed to stdout.
    """
    log = {"epoch": epoch}
    for phase in phases:
        stats = result[phase]
        num_samples = len(data_loader[phase].dataset)
        # Both metrics are normalised by the size of the phase's dataset.
        log[phase + "_loss"] = stats.total_loss / num_samples
        log[phase + "_accuracy"] = stats.num_correct / num_samples
    # Ray Tune reporting is currently disabled: tune.report(**log)
    print(log)


def step(model, criterion, optimizer, phase, batch, result, freeze, device):
    """Run one batch through the model and update the running statistics.

    In the ``"train"`` phase the model is put in training mode (optionally
    keeping parts of it in eval mode, see ``freeze``), gradients are computed
    on the batch-mean loss and the optimizer takes one step. In any other
    phase the model is only evaluated, with gradient tracking switched off.

    Args:
        model: Network called as ``model(images, questions)``. Must expose an
            ``image_encoder`` sub-module when ``freeze == "all"``.
        criterion: Loss called as ``criterion(out, answers.float())``. It is
            expected to return the loss *summed* over the batch (e.g.
            ``reduction="sum"``), because the value is divided by the batch
            size here before back-propagation.
        optimizer: Optimizer stepped in the ``"train"`` phase; unused otherwise.
        phase: ``"train"`` enables gradients and parameter updates; any other
            value runs evaluation only.
        batch: Iterable of exactly three tensors ``(images, questions,
            answers)``; each is moved to ``device``.
        result: Mapping from phase name to an object with the numeric
            attributes ``total_loss`` and ``num_correct``, updated in place.
        freeze: ``"all"`` keeps ``model.image_encoder`` in eval mode while
            training; ``"bn"`` keeps every ``nn.BatchNorm2d`` layer in eval
            mode; any other value leaves the whole model in training mode.
        device: Target device for the batch tensors.

    Returns:
        None.
    """
    training = phase == "train"
    if training:
        model.train()
        if freeze == "all":
            model.image_encoder.eval()
        if freeze == "bn":  # Common in training object detection models
            for layer in model.modules():
                if isinstance(layer, nn.BatchNorm2d):
                    layer.eval()
    else:
        model.eval()

    images, questions, answers = (item.to(device) for item in batch)
    batch_size = images.shape[0]
    model.zero_grad()

    # Scope the grad-mode switch to this call. Setting it globally would leave
    # autograd disabled for the rest of the session after an evaluation step.
    with torch.set_grad_enabled(training):
        out = model(images, questions)
        summed_loss = criterion(out, answers.float())
        if training:
            # Optimise the batch-mean loss so the gradient scale does not
            # depend on the batch size.
            (summed_loss / batch_size).backward()
            optimizer.step()

    # Accumulate the *summed* loss: the epoch total is later divided by the
    # dataset size, so adding per-batch means would normalise twice.
    result[phase].total_loss += summed_loss.item()
    result[phase].num_correct += ((out > 0.5) == answers).sum().item()

# Training

In [None]:
import math

import torch
import torchvision
from munch import Munch
from torch.utils.data import DataLoader

from qsr_learning.data import DRLDataset, get_mean_and_std
from qsr_learning.entity import emoji_names
from qsr_learning.models import HadarmardFusionNet

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
def _data_split(num_samples, shuffle):
    """Return the dataset settings for one split.

    Train and validation share every setting except ``num_samples`` and
    ``shuffle``. A fresh ``Munch`` (with fresh list objects) is built on each
    call so the two splits never share mutable state.
    """
    return Munch(
        # Alternative vocabulary: entity_names=emoji_names
        entity_names=["octopus", "trophy"],
        relation_names=["left_of", "right_of"],
        num_entities=2,
        frame_of_reference="absolute",
        w_range=(32, 32),
        h_range=(32, 32),
        theta_range=(0, 2 * math.pi),
        num_samples=num_samples,
        shuffle=shuffle,
    )


# Experiment configuration: dataset splits, model hyper-parameters and
# optimisation settings.
config = Munch(
    data=Munch(
        train=_data_split(num_samples=2**14, shuffle=True),
        validation=_data_split(num_samples=64, shuffle=False),
    ),
    model=Munch(
        ent_dim=10,
        rel_dim=10,
        cnn_model="resnet18",
        pretrained=False,
    ),
    train=Munch(
        batch_size=128,
        num_epochs=100,
        freeze="all",
        lr=0.001,
    ),
)

In [None]:
phases = ["train", "validation"]

# One dataset per phase, each built from its own entry under ``config.data``.
datasets = Munch({split: DRLDataset(**config.data[split]) for split in phases})

# One loader per phase; every phase uses the training batch size.
# NOTE(review): no ``shuffle`` argument is given to the DataLoader, so batches
# follow dataset order. The ``shuffle`` flag in the config goes to DRLDataset
# only - confirm that is where shuffling is meant to happen.
data_loader = Munch(
    {
        split: DataLoader(
            dataset=datasets[split],
            batch_size=config.train.batch_size,
            num_workers=4,
        )
        for split in phases
    }
)

In [None]:
from tqdm.auto import trange

In [None]:
# Build the network from the training dataset and the model hyper-parameters,
# then move it to the selected device.
model = HadarmardFusionNet(datasets.train, **config.model)
model.to(device)

# Binary cross-entropy summed over the batch, optimised with Adam.
criterion = nn.BCELoss(reduction="sum")
optimizer = torch.optim.Adam(model.parameters(), lr=config.train.lr)

result = Munch()
for epoch in trange(config.train.num_epochs):
    for phase in phases:
        # Fresh counters for every phase of every epoch; ``step`` updates
        # them in place.
        result[phase] = Munch(total_loss=0, num_correct=0)
        for batch in data_loader[phase]:
            step(
                model=model,
                criterion=criterion,
                optimizer=optimizer,
                phase=phase,
                batch=batch,
                result=result,
                freeze=config.train.freeze,
                device=device,
            )
    # Report once per epoch, after all phases have been processed.
    report_result(epoch, phases, result, data_loader)

# NOTE(review): this rebuilds the datasets and loaders exactly as an earlier
# cell already did, after training has finished. It looks like a copy-paste
# leftover, but it does replace the objects that later cells read - confirm
# whether the regeneration is intentional before removing it.
phases = ["train", "validation"]
datasets = Munch(dict((name, DRLDataset(**config.data[name])) for name in phases))
data_loader = Munch(
    dict(
        (
            name,
            DataLoader(
                datasets[name],
                batch_size=config.train.batch_size,
                num_workers=4,
            ),
        )
        for name in phases
    )
)

# Test

In [None]:
# Fetch the first training batch for inspection. ``next(iter(...))`` raises
# StopIteration on an empty loader, whereas a ``for ... break`` loop would
# silently leave a stale ``batch`` from an earlier cell in scope.
batch = next(iter(data_loader["train"]))

from PIL import Image


def batch2sample(batch, row, dataset):
    """Display one sample of a batch as an image, a question and an answer.

    Args:
        batch: Indexable of three batched tensors ``(images, questions,
            answers)``. ``images[row]`` is treated as channel-first (it is
            permuted to channel-last for PIL) and ``questions[row]`` must hold
            exactly three indices ``(head, relation, tail)``.
        row: Index of the sample inside the batch.
        dataset: Object providing ``mean`` and ``std`` tensors (reshaped to
            one value per channel) used to undo the normalisation, and the
            lookup tables ``idx2ent`` and ``idx2rel``.

    Returns:
        None. The image, the decoded question triple and the boolean answer
        are shown with IPython's ``display``.
    """
    mean = dataset.mean.view(-1, 1, 1)
    std = dataset.std.view(-1, 1, 1)
    # Undo the normalisation, then clamp and round before the uint8 cast.
    # A bare cast truncates (254.9 -> 254) and wraps or is undefined for
    # values that float error pushes outside [0, 255].
    pixels = (std * batch[0][row] + mean).clamp(0, 1).mul(255).round()
    image = Image.fromarray(pixels.permute(1, 2, 0).numpy().astype("uint8"))
    head, relation, tail = batch[1][row].tolist()
    question = (dataset.idx2ent[head], dataset.idx2rel[relation], dataset.idx2ent[tail])
    answer = bool(batch[2][row])
    display(image, question, answer)


# Show the sample at row 6 of the fetched batch, decoded with the training
# dataset's statistics and vocabularies.
batch2sample(batch=batch, row=6, dataset=datasets["train"])