In [None]:
%reload_ext autoreload
%autoreload 2

# Prepare a simple but non-trivial dataset

## Fix the location of the first entity and use only one pair of relations (`left_of` / `right_of`).

In [None]:
import math
from copy import deepcopy
from pprint import pprint

from munch import Munch, unmunchify
from qsr_learning.entity import Entity

# Entities that are handed to the dataset config as `fixed_entities`:
# a single 32x32 octopus at p=(96, 96), unrotated, in the absolute frame.
fixed_entities = []
fixed_entities.append(
    Entity(
        name="octopus",
        frame_of_reference="absolute",
        p=(96, 96),
        theta=0,
        size=(32, 32),
    )
)

## Generate only questions whose head (first entity of the triple) is the trophy

In [None]:
def filter_fn(question):
    """Return True only for questions whose ``head`` is the trophy."""
    head = question.head
    return head == "trophy"

In [None]:
# Experiment configuration: `data` (one entry per split), `model` and `train`.
config = Munch(
    data=Munch(
        train=Munch(
            entity_names=["octopus", "trophy"],
            relation_names=["left_of", "right_of"],
            num_entities=2,
            fixed_entities=fixed_entities,  # fix certain entities
            frame_of_reference="absolute",
            w_range=(32, 32),
            h_range=(32, 32),
            theta_range=(0, math.pi),
            num_samples=2 ** 10,
            filter_fn=filter_fn,  # Filter questions
            shuffle=True,
            random_seed=0,
        )
    )
)

# The validation split copies the training settings, then uses fewer samples
# and a different random seed.
config.data.validation = deepcopy(config.data.train)
config.data.validation.update(num_samples=2 ** 7, random_seed=1)

# Keyword arguments for the model constructor.
config.model = Munch(
    ent_dim=10,
    rel_dim=10,
    cnn_model="resnet18",
    pretrained=True,
)

# Optimisation settings read by the training loop.
config.train = Munch(
    batch_size=128,
    num_epochs=10,
    freeze="all",
    lr=0.001,
)

In [None]:
def pretty(d, indent=0):
    """Print a nested dict as a tab-indented tree.

    Keys whose value is a dict are printed on their own line and their content
    is printed one level deeper; every other entry is printed as
    ``key: value`` with the key left-aligned in a 20-character column.

    Args:
        d: mapping to print (nested mappings must be ``dict`` instances).
        indent: number of leading tabs for this level.
    """
    prefix = "\t" * indent
    for key, value in d.items():
        if not isinstance(value, dict):
            print(f"{prefix}{str(key):20}: {value!s}")
            continue
        print(f"{prefix}{key!s}")
        pretty(value, indent + 1)

# Convert the Munch config to plain dicts and print every setting.
pretty(unmunchify(config))

## Prepare a simple model

In [None]:
from copy import deepcopy

import torch
import torch.nn as nn
import torchvision


class Net1(nn.Module):
    """Answer (head, relation, tail) questions about an image with an MLP head.

    Image features from a truncated torchvision CNN are projected by two linear
    layers, concatenated with the question embedding and passed through two
    further linear layers that end in a sigmoid.

    Args:
        dataset: object exposing ``idx2ent`` and ``idx2rel``; only their lengths
            are used, to size the embedding tables.
        ent_dim: width of each entity embedding.
        rel_dim: width of the relation embedding.
        cnn_model: name of a model constructor in ``torchvision.models``.
        pretrained: forwarded to that constructor as ``pretrained=``.
    """

    def __init__(
        self,
        dataset,
        ent_dim,
        rel_dim,
        cnn_model: str,
        pretrained,
    ):
        super().__init__()
        backbone = getattr(torchvision.models, cnn_model)(pretrained=pretrained)
        # Keep all but the last three child modules of the backbone.
        self.image_encoder = nn.Sequential(*deepcopy(list(backbone.children())[:-3]))
        del backbone
        # Get the size of the image features
        output_size = self._probe_feature_size()

        question_dim = 2 * ent_dim + rel_dim
        self.ent_embedding = nn.Embedding(len(dataset.idx2ent), ent_dim)
        self.rel_embedding = nn.Embedding(len(dataset.idx2rel), rel_dim)
        self.fc1 = nn.Linear(output_size, question_dim)
        self.fc2 = nn.Linear(question_dim, question_dim)
        self.fc3 = nn.Linear(2 * question_dim, 2 * question_dim)
        self.fc4 = nn.Linear(2 * question_dim, 1)

        nn.init.xavier_uniform_(self.ent_embedding.weight)
        nn.init.xavier_uniform_(self.rel_embedding.weight)

    def _probe_feature_size(self):
        """Return the flattened encoder output length for one 3x224x224 image.

        The probe runs in eval mode: in training mode a forward pass would
        update the BatchNorm running statistics of the (possibly pretrained)
        encoder with random noise. Every module's mode is restored afterwards.
        """
        device = next(self.image_encoder.parameters()).device
        modes = [(module, module.training) for module in self.image_encoder.modules()]
        self.image_encoder.eval()
        try:
            with torch.no_grad():
                probe = torch.rand((1, 3, 224, 224), device=device)
                return self.image_encoder(probe).flatten().shape[0]
        finally:
            for module, mode in modes:
                module.training = mode

    def forward(self, images, questions):
        """Score each (image, question) pair.

        Args:
            images: image batch; the encoder output per image must flatten to
                the size probed in ``__init__`` with a 3x224x224 input.
            questions: 2-D integer tensor whose columns 0, 1 and 2 hold the
                head entity, relation and tail entity indices.

        Returns:
            1-D tensor with one sigmoid output per batch element.
        """
        # flatten (unlike view) also accepts a non-contiguous encoder output.
        image_features = self.image_encoder(images).flatten(start_dim=1)
        head_features = self.ent_embedding(questions[:, 0])
        relation_features = self.rel_embedding(questions[:, 1])
        tail_features = self.ent_embedding(questions[:, 2])
        question_features = torch.cat(
            (head_features, relation_features, tail_features), dim=-1
        )
        projected = self.fc2(self.fc1(image_features).relu())
        out = torch.cat((projected, question_features), dim=-1)
        out = self.fc3(out).sigmoid()
        out = self.fc4(out).view(-1)
        return out.sigmoid()

In [None]:
from copy import deepcopy

import torch
import torch.nn as nn
import torchvision


class Net2(nn.Module):
    """Answer (head, relation, tail) questions with a dot-product score.

    Image features from a truncated torchvision CNN are projected by two linear
    layers to the width of the question embedding; the score is the sigmoid of
    the dot product between the projected image and the question embedding.

    Args:
        dataset: object exposing ``idx2ent`` and ``idx2rel``; only their lengths
            are used, to size the embedding tables.
        ent_dim: width of each entity embedding.
        rel_dim: width of the relation embedding.
        cnn_model: name of a model constructor in ``torchvision.models``.
        pretrained: forwarded to that constructor as ``pretrained=``.
    """

    def __init__(
        self,
        dataset,
        ent_dim,
        rel_dim,
        cnn_model: str,
        pretrained,
    ):
        super().__init__()
        backbone = getattr(torchvision.models, cnn_model)(pretrained=pretrained)
        # Keep all but the last three child modules of the backbone.
        self.image_encoder = nn.Sequential(*deepcopy(list(backbone.children())[:-3]))
        del backbone
        # Get the size of the image features
        output_size = self._probe_feature_size()

        question_dim = 2 * ent_dim + rel_dim
        self.ent_embedding = nn.Embedding(len(dataset.idx2ent), ent_dim)
        self.rel_embedding = nn.Embedding(len(dataset.idx2rel), rel_dim)
        self.fc1 = nn.Linear(output_size, question_dim)
        self.fc2 = nn.Linear(question_dim, question_dim)

        nn.init.xavier_uniform_(self.ent_embedding.weight)
        nn.init.xavier_uniform_(self.rel_embedding.weight)

    def _probe_feature_size(self):
        """Return the flattened encoder output length for one 3x224x224 image.

        The probe runs in eval mode: in training mode a forward pass would
        update the BatchNorm running statistics of the (possibly pretrained)
        encoder with random noise. Every module's mode is restored afterwards.
        """
        device = next(self.image_encoder.parameters()).device
        modes = [(module, module.training) for module in self.image_encoder.modules()]
        self.image_encoder.eval()
        try:
            with torch.no_grad():
                probe = torch.rand((1, 3, 224, 224), device=device)
                return self.image_encoder(probe).flatten().shape[0]
        finally:
            for module, mode in modes:
                module.training = mode

    def forward(self, images, questions):
        """Score each (image, question) pair.

        Args:
            images: image batch; the encoder output per image must flatten to
                the size probed in ``__init__`` with a 3x224x224 input.
            questions: 2-D integer tensor whose columns 0, 1 and 2 hold the
                head entity, relation and tail entity indices.

        Returns:
            1-D tensor with one sigmoid output per batch element.
        """
        # flatten (unlike view) also accepts a non-contiguous encoder output.
        image_features = self.image_encoder(images).flatten(start_dim=1)
        head_features = self.ent_embedding(questions[:, 0])
        relation_features = self.rel_embedding(questions[:, 1])
        tail_features = self.ent_embedding(questions[:, 2])
        question_features = torch.cat(
            (head_features, relation_features, tail_features), dim=-1
        )
        projected = self.fc2(self.fc1(image_features).relu())
        # Dot product between the projected image and the question embedding.
        out = (projected * question_features).sum(-1)
        return out.sigmoid()

## Train the simple model

In [None]:
import torch
import torch.nn as nn
from munch import Munch
from qsr_learning.data import DRLDataset
from qsr_learning.train import report_result, step, train
from torch.utils.data import DataLoader
from tqdm.auto import trange

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): leftover call with a (config, device) signature — presumably the
# imported qsr_learning.train.train, superseded by the local train() below; confirm.
# train(config, device)


def train(Module, config, device):
    """Train and validate a model class on the datasets described by ``config``.

    This definition replaces the ``train`` imported from ``qsr_learning.train``
    in the same cell.

    Args:
        Module: model class, instantiated as
            ``Module(train_dataset, **config.model)``.
        config: Munch with ``data`` (one entry per phase, expanded into
            ``DRLDataset``), ``model``, and ``train`` (``batch_size``, ``lr``,
            ``num_epochs``, ``freeze``).
        device: torch device the model is moved to; also passed to ``step``.

    Side effects:
        Calls ``report_result`` after every epoch and overwrites ``model.pt``
        in the working directory with the model's ``state_dict``. Returns None.
    """
    phases = ["train", "validation"]

    # Build every dataset first, then one loader per dataset.
    datasets = Munch()
    for phase in phases:
        datasets[phase] = DRLDataset(**config.data[phase])

    loaders = Munch()
    for phase in phases:
        loaders[phase] = DataLoader(
            datasets[phase],
            batch_size=config.train.batch_size,
            num_workers=4,
        )

    net = Module(datasets.train, **config.model)
    net.to(device)
    loss_fn = nn.BCELoss(reduction="sum")
    optimizer = torch.optim.Adam(net.parameters(), lr=config.train.lr)

    result = Munch()
    for epoch in trange(config.train.num_epochs):
        for phase in phases:
            # Fresh per-epoch accumulators; `result` is handed to step(),
            # which presumably updates them in place.
            result[phase] = Munch(total_loss=0, num_correct=0)
            for batch in loaders[phase]:
                step(
                    net,
                    loss_fn,
                    optimizer,
                    phase,
                    batch,
                    result,
                    config.train.freeze,
                    device,
                )
        report_result(epoch, phases, result, loaders)
        # Checkpoint after every epoch (always the same file).
        torch.save(net.state_dict(), "model.pt")

In [None]:
# Train Net1 with the shared config; train() saves the weights to model.pt after every epoch.
train(Net1, config, device)

In [None]:
# Train Net2 with the same config. NOTE: train() always writes to model.pt,
# so this run overwrites the checkpoint saved by the Net1 run.
train(Net2, config, device)

## Test the dataset

In [None]:
from ipywidgets import interact
from munch import Munch
from PIL import Image
from qsr_learning.data import DRLDataset
from torch.utils.data import DataLoader

# Rebuild the datasets and loaders at notebook level (train() keeps its own
# copies local) so individual batches can be inspected below.
phases = ["train", "validation"]
datasets = Munch((phase, DRLDataset(**config.data[phase])) for phase in phases)
data_loader = Munch(
    (
        phase,
        DataLoader(
            datasets[phase],
            batch_size=config.train.batch_size,
            num_workers=4,
        ),
    )
    for phase in phases
)

# Fetch the first training batch once; the viewer below keeps its own
# reference to it.
batch = next(iter(data_loader["train"]))


def _make_batch_viewer(dataset, batch):
    """Return a ``row -> None`` callback that displays one sample of ``batch``.

    ``dataset`` and ``batch`` are captured in the closure, so the widget keeps
    showing this split even if the global ``batch`` is rebound by another cell.

    Args:
        dataset: provides ``mean``/``std`` used to undo the normalisation and
            the ``idx2ent``/``idx2rel`` lookups used to decode the question.
        batch: indexable as ``batch[0]`` (images), ``batch[1]`` (questions)
            and ``batch[2]`` (answers).
    """
    # Per-channel statistics reshaped to broadcast over channel-first images.
    std = dataset.std.view(-1, 1, 1)
    mean = dataset.mean.view(-1, 1, 1)

    def batch2sample(row):
        """Display image, question triple and ground-truth answer of ``row``."""
        image_tensor, question_tensor, answer_tensor = (
            batch[0][row],
            batch[1][row],
            batch[2][row],
        )
        # Undo the normalisation. Round and clamp before the uint8 cast so
        # float error can neither truncate a pixel to the next lower value
        # nor push it outside 0..255.
        pixels = (255 * (std * image_tensor + mean)).round().clamp(0, 255)
        image = Image.fromarray(pixels.permute(1, 2, 0).numpy().astype("uint8"))
        head, relation, tail = question_tensor.tolist()
        question = (
            dataset.idx2ent[head],
            dataset.idx2rel[relation],
            dataset.idx2ent[tail],
        )
        answer = bool(answer_tensor)
        display(image)
        print(question)
        print("Ground truth: ", answer)

    return batch2sample


# Slider over the first eight samples (fewer if the batch is smaller).
batch2sample = interact(
    _make_batch_viewer(datasets["train"], batch),
    row=(0, min(7, len(batch[0]) - 1)),
)

In [None]:
# Fetch the first validation batch once; the viewer below keeps its own
# reference to it.
batch = next(iter(data_loader["validation"]))


def _make_batch_viewer(dataset, batch):
    """Return a ``row -> None`` callback that displays one sample of ``batch``.

    ``dataset`` and ``batch`` are captured in the closure, so the widget keeps
    showing this split even if the global ``batch`` is rebound by another cell.

    Args:
        dataset: provides ``mean``/``std`` used to undo the normalisation and
            the ``idx2ent``/``idx2rel`` lookups used to decode the question.
        batch: indexable as ``batch[0]`` (images), ``batch[1]`` (questions)
            and ``batch[2]`` (answers).
    """
    # Per-channel statistics reshaped to broadcast over channel-first images.
    std = dataset.std.view(-1, 1, 1)
    mean = dataset.mean.view(-1, 1, 1)

    def batch2sample(row):
        """Display image, question triple and ground-truth answer of ``row``."""
        image_tensor, question_tensor, answer_tensor = (
            batch[0][row],
            batch[1][row],
            batch[2][row],
        )
        # Undo the normalisation. Round and clamp before the uint8 cast so
        # float error can neither truncate a pixel to the next lower value
        # nor push it outside 0..255.
        pixels = (255 * (std * image_tensor + mean)).round().clamp(0, 255)
        image = Image.fromarray(pixels.permute(1, 2, 0).numpy().astype("uint8"))
        head, relation, tail = question_tensor.tolist()
        question = (
            dataset.idx2ent[head],
            dataset.idx2rel[relation],
            dataset.idx2ent[tail],
        )
        answer = bool(answer_tensor)
        display(image)
        print(question)
        print("Ground truth: ", answer)

    return batch2sample


# Slider over the first eight samples (fewer if the batch is smaller).
batch2sample = interact(
    _make_batch_viewer(datasets["validation"], batch),
    row=(0, min(7, len(batch[0]) - 1)),
)