In [None]:
import os

import numpy as np
import torch
import torch.nn as nn
from IPython.display import Javascript
from munch import Munch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.models import resnet18
from tqdm.auto import trange

from qsr_learning.main import (
    draw,
    emoji_names,
    generate_negative_examples,
    generate_positive_examples,
    sample_objects,
)

# Reload edited modules automatically before each cell runs, so changes to the
# imported project code are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2

# Expose only GPU 0 to this process, then use CUDA when it is available and
# fall back to the CPU otherwise. `device` is read as a global by later cells.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
class QSRData(Dataset):
    """Rendered scenes paired with true/false questions about them.

    For every image a fresh set of objects is sampled and drawn, then positive
    questions (answer ``True``) and negative questions (answer ``False``) are
    generated for it. Questions and answers are kept in parallel lists and
    point back to their image through an ``"id"`` field.
    """

    def __init__(
        self,
        num_images=4,
        num_objects=2,
        num_pos_questions_per_image=2,
        num_neg_questions_per_image=2,
    ):
        """Generate images, questions, answers and the vocabulary.

        Args:
            num_images: Number of scenes to sample and render.
            num_objects: Forwarded to ``sample_objects`` for every scene.
            num_pos_questions_per_image: Forwarded to
                ``generate_positive_examples``.
            num_neg_questions_per_image: Forwarded to
                ``generate_negative_examples``.
        """
        self.images = []
        self.questions = []
        self.answers = []
        for image_id in trange(num_images):
            scene = sample_objects(num_objects=num_objects)
            positives = generate_positive_examples(scene, num_pos_questions_per_image)
            negatives = generate_negative_examples(
                scene, positives, num_neg_questions_per_image
            )
            # Keep only the first three channels of the rendered scene
            # (presumably dropping an alpha channel -- TODO confirm).
            self.images.append(np.array(draw(scene))[:, :, :3])
            # Positives are stored first, then negatives; both carry the index
            # of the image they refer to.
            for label, examples in ((True, positives), (False, negatives)):
                for question in examples:
                    self.questions.append({"id": image_id, "question": question})
                    self.answers.append({"id": image_id, "answer": label})
        super().__init__()

        # Vocabulary: every emoji name plus the four relation names, indexed in
        # sorted order.
        vocabulary = set(emoji_names)
        vocabulary.update(("left_of", "right_of", "above", "below"))
        self.idx2word = dict(enumerate(sorted(vocabulary)))
        self.word2idx = {word: idx for idx, word in self.idx2word.items()}

        # Convert to a tensor, then normalise with mean 0.5 and std 0.5.
        self._transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize(0.5, 0.5)]
        )

    def __getitem__(self, idx):
        """Return ``(image_tensor, question_indices, answer)`` for question ``idx``."""
        entry = self.questions[idx]
        image = self._transform(self.images[entry["id"]])
        tokens = torch.tensor([self.word2idx[word] for word in entry["question"]])
        return image, tokens, self.answers[idx]["answer"]

    def __len__(self):
        """Return the number of stored questions (positive and negative, all images)."""
        return len(self.questions)

In [None]:
def format_result(phases, result, loader):
    """Format the average loss and accuracy of one or more phases.

    Args:
        phases: A phase name, or an iterable of phase names, to report.
        result: Object whose ``total_loss`` and ``num_correct`` attributes map
            phase names to accumulated totals.
        loader: Mapping from phase name to a loader exposing ``dataset``; the
            dataset length is used to turn the totals into averages.

    Returns:
        One ``"[phase] loss: ..., acc: ..."`` entry per requested phase, joined
        by ``" | "``. Phase names are truncated to five characters.
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(phases, str):
        phases = [phases]
    template = "[{phase}] loss: {loss:4.3f}, acc: {accuracy:4.3f}"
    entries = []
    for phase in phases:
        num_examples = len(loader[phase].dataset)
        entries.append(
            template.format(
                phase=phase[:5],
                loss=result.total_loss[phase] / num_examples,
                accuracy=result.num_correct[phase] / num_examples,
            )
        )
    return " | ".join(entries)

In [None]:
class Net(nn.Module):
    """Scores a (head, relation, tail) question against an image.

    The image is encoded by a pretrained ResNet-18 and projected to
    ``3 * embedding_dim`` features; the three question tokens are embedded and
    concatenated; the sigmoid of their dot product is the predicted
    probability that the answer is true.
    """

    def __init__(self, config):
        """Build the image encoder, the word embedding and the projection.

        Args:
            config: Object providing ``num_embeddings`` (vocabulary size) and
                ``embedding_dim``.
        """
        super().__init__()
        # NOTE(review): newer torchvision releases prefer the ``weights=``
        # argument over ``pretrained=True`` -- confirm against the installed
        # version before changing.
        self.image_encoder = resnet18(pretrained=True, progress=True)
        self.embedding = nn.Embedding(config.num_embeddings, config.embedding_dim)
        # Project the 1000 encoder outputs so they line up with the three
        # concatenated token embeddings.
        self.fc = nn.Linear(1000, 3 * config.embedding_dim)

        nn.init.xavier_uniform_(self.embedding.weight.data)

    def forward(self, images, questions):
        """Return the probability that each question is true for its image.

        Args:
            images: Batch of image tensors fed to the ResNet-18 encoder
                (assumed ``(batch_size, 3, H, W)`` -- TODO confirm).
            questions: Integer tensor whose columns 0, 1 and 2 hold the head,
                relation and tail vocabulary indices.

        Returns:
            Tensor of shape ``(batch_size,)`` with values in ``(0, 1)``.
        """
        # image_features.shape = (batch_size, 1000)
        image_features = self.image_encoder(images)
        # out.shape = (batch_size, 3 * embedding_dim)
        out = self.fc(image_features)
        # Each token embedding has shape (batch_size, embedding_dim).
        head_features = self.embedding(questions[:, 0])
        relation_features = self.embedding(questions[:, 1])
        tail_features = self.embedding(questions[:, 2])
        # question_features.shape = (batch_size, 3 * embedding_dim)
        question_features = torch.cat(
            (head_features, relation_features, tail_features), dim=-1
        )
        # Dot product between image and question features, per example.
        out = (out * question_features).sum(-1)
        return out.sigmoid()

    def step(self, phase, batch, result):
        """Run one batch for ``phase`` and accumulate its metrics in ``result``.

        Relies on the notebook-level globals ``device``, ``criterion`` and
        ``optimizer``. ``criterion`` is expected to return the loss *summed*
        over the batch (the notebook builds it with ``reduction="sum"``).

        Args:
            phase: ``"train"`` enables training mode, gradients and a
                parameter update; any other value runs in evaluation mode.
            batch: ``(images, questions, answers)`` tensors.
            result: Object whose ``total_loss[phase]`` and
                ``num_correct[phase]`` entries are incremented in place.
        """
        training = phase == "train"
        self.train(training)
        images, questions, answers = (item.to(device) for item in batch)
        batch_size = images.shape[0]
        # Scope the gradient mode to this call; the bare function call used
        # previously left gradients globally disabled after evaluation.
        with torch.set_grad_enabled(training):
            self.zero_grad()
            out = self(images, questions)
            summed_loss = criterion(out, answers.float())
            if training:
                # Optimise the per-example mean of the batch.
                (summed_loss / batch_size).backward()
                optimizer.step()
        # Accumulate the summed loss: the total is later divided by the
        # dataset size, so adding per-batch means would normalise twice.
        result.total_loss[phase] += summed_loss.item()
        result.num_correct[phase] += ((out > 0.5) == answers).sum().item()
# Train a model

In [None]:
# Hyper-parameters and per-split dataset sizes for the whole experiment.
config = Munch(
    embedding_dim=20,
    batch_size=16,
    num_epochs=10,
    data=Munch(
        # NOTE(review): this mixture is not read anywhere in this notebook --
        # only the per-phase entries below are passed to QSRData. Confirm
        # whether it should be forwarded to the negative-example generator.
        negative_sample_mixture=Munch(
            head=0, tail=0, relation=1, head_relation=0, tail_relation=0, head_tail=0
        ),
        # Each entry below is expanded into QSRData's keyword arguments.
        train=Munch(
            num_images=128,
            num_objects=2,
            num_pos_questions_per_image=2,
            num_neg_questions_per_image=2,
        ),
        validation=Munch(
            num_images=16,
            num_objects=2,
            num_pos_questions_per_image=2,
            num_neg_questions_per_image=2,
        ),
        test=Munch(
            num_images=16,
            num_objects=2,
            num_pos_questions_per_image=2,
            num_neg_questions_per_image=2,
        ),
    ),
)
# Build one dataset per split from the matching config.data entry.
phases = ["train", "validation", "test"]
data = Munch({phase: QSRData(**config.data[phase]) for phase in phases})
# Vocabulary size of the training split; consumed by Net's embedding layer.
config.num_embeddings = len(data.train.word2idx)
# One loader per split. All three shuffle, including validation and test: the
# aggregated metrics do not depend on order, but the first batch drawn for
# inspection differs between runs.
loader = Munch(
    {
        phase: DataLoader(
            data[phase], batch_size=config.batch_size, shuffle=True, num_workers=4
        )
        for phase in phases
    }
)

In [None]:
# Build the model, loss and optimiser. `criterion` and `optimizer` are read as
# globals by Net.step, so they must exist before the loop starts.
net = Net(config)
net.to(device)
criterion = nn.BCELoss(reduction="sum")
optimizer = torch.optim.Adam(net.parameters())
phases = ["train", "validation"]
with trange(config.num_epochs) as pbar:
    for epoch in pbar:
        # Fresh metric accumulators for every epoch.
        result = Munch(
            total_loss=Munch({phase: 0 for phase in phases}),
            num_correct={phase: 0 for phase in phases},
        )
        for phase in phases:
            for batch in loader[phase]:
                net.step(phase, batch, result)
            # Name the phase that just finished explicitly instead of handing
            # over the whole list and relying on the loop variable being
            # visible as a global inside format_result.
            print(epoch, format_result([phase], result, loader))

# Test the model

In [None]:
# Evaluate the trained network on the held-out split only, starting from
# zeroed metric accumulators.
phases = ["test"]
result = Munch(
    total_loss=Munch(dict.fromkeys(phases, 0)),
    num_correct=dict.fromkeys(phases, 0),
)

for phase in phases:
    for batch in loader[phase]:
        net.step(phase, batch, result)
print(format_result(phases, result, loader))

In [None]:
def tensor2image(x):
    """Convert a normalised ``(channels, height, width)`` tensor to a PIL image.

    Undoes a mean-0.5 / std-0.5 normalisation (the one QSRData applies) and
    maps the result back to 8-bit pixel values.

    Args:
        x: Image tensor laid out channels-first.

    Returns:
        A ``PIL.Image`` built from the height x width x channels uint8 array.
    """
    pixels = 255 * (0.5 * x.detach().cpu() + 0.5)
    # Round before the integer cast: truncation turns e.g. 199.99998 into 199.
    # Clamp so slightly out-of-range values saturate instead of wrapping.
    pixels = pixels.round().clamp(0, 255)
    return Image.fromarray(pixels.numpy().astype("uint8").transpose(1, 2, 0))


def tensor2question(x):
    """Translate a tensor of vocabulary indices into the corresponding words.

    The lookup table is the training dataset's ``idx2word`` mapping, read from
    the notebook-level ``loader``.
    """
    vocabulary = loader.train.dataset.idx2word
    words = []
    for idx in x.numpy():
        words.append(vocabulary[idx])
    return words


def tensor2answer(x):
    """Return the value held by a one-element tensor as a plain Python scalar."""
    answer = x.item()
    return answer

In [None]:
# Draw a single test batch. next(iter(...)) raises if the loader yields
# nothing, whereas a for/break loop would silently keep whatever `batch` an
# earlier cell left behind.
batch = next(iter(loader.test))
images, questions, answers = (item.to(device) for item in batch)
net.eval()
with torch.no_grad():
    # Prediction for the first example, followed by its ground-truth answer.
    print(net(images[:1], questions[:1]) > 0.5)
    print(answers[:1])

In [None]:
# Collect every word found in position 0 or 2 (head or tail) of a training
# question.
emojis = set()
for item in loader["train"].dataset.questions:
    emojis.add(item["question"][0])
    emojis.add(item["question"][2])

In [None]:
# Show the first example of the inspected batch: image, question words, answer.
display(
    tensor2image(images[0].cpu()),
    tensor2question(questions[0].cpu()),
    tensor2answer(answers[0].cpu()),
)

In [None]:
# Does this emoji occur as head or tail in any training question?
'kissing face with closed eyes' in emojis

In [None]:
# Does this emoji occur as head or tail in any training question?
'Japanese discount button' in emojis

Generate negative examples with irrelevant objects satisfying the relation

# Train Location Detection

- Input: A scene containing one object
- Output: The bounding box of the object

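# Notification

The cell below plays a short embedded sound and shows a "Cell Execution Has Finished" browser notification when it runs.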

In [None]:
Javascript(
    'var snd = new Audio("data:audio/wav;base64,//uQRAAAAWMSLwUIYAAsYkXgoQwAEaYLWfkWgAI0wWs/ItAAAGDgYtAgAyN+QWaAAihwMWm4G8QQRDiMcCBcH3Cc+CDv/7xA4Tvh9Rz/y8QADBwMWgQAZG/ILNAARQ4GLTcDeIIIhxGOBAuD7hOfBB3/94gcJ3w+o5/5eIAIAAAVwWgQAVQ2ORaIQwEMAJiDg95G4nQL7mQVWI6GwRcfsZAcsKkJvxgxEjzFUgfHoSQ9Qq7KNwqHwuB13MA4a1q/DmBrHgPcmjiGoh//EwC5nGPEmS4RcfkVKOhJf+WOgoxJclFz3kgn//dBA+ya1GhurNn8zb//9NNutNuhz31f////9vt///z+IdAEAAAK4LQIAKobHItEIYCGAExBwe8jcToF9zIKrEdDYIuP2MgOWFSE34wYiR5iqQPj0JIeoVdlG4VD4XA67mAcNa1fhzA1jwHuTRxDUQ//iYBczjHiTJcIuPyKlHQkv/LHQUYkuSi57yQT//uggfZNajQ3Vmz+Zt//+mm3Wm3Q576v////+32///5/EOgAAADVghQAAAAA//uQZAUAB1WI0PZugAAAAAoQwAAAEk3nRd2qAAAAACiDgAAAAAAABCqEEQRLCgwpBGMlJkIz8jKhGvj4k6jzRnqasNKIeoh5gI7BJaC1A1AoNBjJgbyApVS4IDlZgDU5WUAxEKDNmmALHzZp0Fkz1FMTmGFl1FMEyodIavcCAUHDWrKAIA4aa2oCgILEBupZgHvAhEBcZ6joQBxS76AgccrFlczBvKLC0QI2cBoCFvfTDAo7eoOQInqDPBtvrDEZBNYN5xwNwxQRfw8ZQ5wQVLvO8OYU+mHvFLlDh05Mdg7BT6YrRPpCBznMB2r//xKJjyyOh+cImr2/4doscwD6neZjuZR4AgAABYAAAABy1xcdQtxYBYYZdifkUDgzzXaXn98Z0oi9ILU5mBjFANmRwlVJ3/6jYDAmxaiDG3/6xjQQCCKkRb/6kg/wW+kSJ5//rLobkLSiKmqP/0ikJuDaSaSf/6JiLYLEYnW/+kXg1WRVJL/9EmQ1YZIsv/6Qzwy5qk7/+tEU0nkls3/zIUMPKNX/6yZLf+kFgAfgGyLFAUwY//uQZAUABcd5UiNPVXAAAApAAAAAE0VZQKw9ISAAACgAAAAAVQIygIElVrFkBS+Jhi+EAuu+lKAkYUEIsmEAEoMeDmCETMvfSHTGkF5RWH7kz/ESHWPAq/kcCRhqBtMdokPdM7vil7RG98A2sc7zO6ZvTdM7pmOUAZTnJW+NXxqmd41dqJ6mLTXxrPpnV8avaIf5SvL7pndPvPpndJR9Kuu8fePvuiuhorgWjp7Mf/PRjxcFCPDkW31srioCExivv9lcwKEaHsf/7ow2Fl1T/9RkXgEhYElAoCLFtMArxwivDJJ+bR1HTKJdlEoTELCIqgEwVGSQ+hIm0NbK8WXcTEI0UPoa2NbG4y2K00JEWbZavJXkYaqo9CRHS55FcZTjKEk3NKoCYUnSQ0rWxrZbFKbKIhOKPZe1cJKzZSaQrIyULHDZmV5K4xySsDRKWOruanGtjLJXFEmwaIbDLX0hIPBUQPVFVkQkDoUNfSoDgQGKPekoxeGzA4DUvnn4bxzcZrtJyipKfPNy5w+9lnXwgqsiyHNeSVpemw4bWb9psYeq//uQZBoABQt4yMVxYAIAAAkQoAAAHvYpL5m6AAgAACXDAAAAD59jblTirQe9upFsmZbpMudy7Lz1X1DYsxOOSWpfPqNX2WqktK0DMvuGwlbNj44TleLPQ+Gsfb+GOWOKJoIrWb3cIMeeON6lz2umTqMXV8Mj30yWPpjoSa9ujK8SyeJP5y5mOW1D6hvLepeveEAEDo0mgCRClOEgANv3B9a6fikgUSu/DmAMATrGx7nng5p5iimPNZsfQLYB2sDLIkzRKZOHGAaUyDcpFBSLG9MCQALg
AIgQs2YunOszLSAyQYPVC2YdGGeHD2dTdJk1pAHGAWDjnkcLKFymS3RQZTInzySoBwMG0QueC3gMsCEYxUqlrcxK6k1LQQcsmyYeQPdC2YfuGPASCBkcVMQQqpVJshui1tkXQJQV0OXGAZMXSOEEBRirXbVRQW7ugq7IM7rPWSZyDlM3IuNEkxzCOJ0ny2ThNkyRai1b6ev//3dzNGzNb//4uAvHT5sURcZCFcuKLhOFs8mLAAEAt4UWAAIABAAAAAB4qbHo0tIjVkUU//uQZAwABfSFz3ZqQAAAAAngwAAAE1HjMp2qAAAAACZDgAAAD5UkTE1UgZEUExqYynN1qZvqIOREEFmBcJQkwdxiFtw0qEOkGYfRDifBui9MQg4QAHAqWtAWHoCxu1Yf4VfWLPIM2mHDFsbQEVGwyqQoQcwnfHeIkNt9YnkiaS1oizycqJrx4KOQjahZxWbcZgztj2c49nKmkId44S71j0c8eV9yDK6uPRzx5X18eDvjvQ6yKo9ZSS6l//8elePK/Lf//IInrOF/FvDoADYAGBMGb7FtErm5MXMlmPAJQVgWta7Zx2go+8xJ0UiCb8LHHdftWyLJE0QIAIsI+UbXu67dZMjmgDGCGl1H+vpF4NSDckSIkk7Vd+sxEhBQMRU8j/12UIRhzSaUdQ+rQU5kGeFxm+hb1oh6pWWmv3uvmReDl0UnvtapVaIzo1jZbf/pD6ElLqSX+rUmOQNpJFa/r+sa4e/pBlAABoAAAAA3CUgShLdGIxsY7AUABPRrgCABdDuQ5GC7DqPQCgbbJUAoRSUj+NIEig0YfyWUho1VBBBA//uQZB4ABZx5zfMakeAAAAmwAAAAF5F3P0w9GtAAACfAAAAAwLhMDmAYWMgVEG1U0FIGCBgXBXAtfMH10000EEEEEECUBYln03TTTdNBDZopopYvrTTdNa325mImNg3TTPV9q3pmY0xoO6bv3r00y+IDGid/9aaaZTGMuj9mpu9Mpio1dXrr5HERTZSmqU36A3CumzN/9Robv/Xx4v9ijkSRSNLQhAWumap82WRSBUqXStV/YcS+XVLnSS+WLDroqArFkMEsAS+eWmrUzrO0oEmE40RlMZ5+ODIkAyKAGUwZ3mVKmcamcJnMW26MRPgUw6j+LkhyHGVGYjSUUKNpuJUQoOIAyDvEyG8S5yfK6dhZc0Tx1KI/gviKL6qvvFs1+bWtaz58uUNnryq6kt5RzOCkPWlVqVX2a/EEBUdU1KrXLf40GoiiFXK///qpoiDXrOgqDR38JB0bw7SoL+ZB9o1RCkQjQ2CBYZKd/+VJxZRRZlqSkKiws0WFxUyCwsKiMy7hUVFhIaCrNQsKkTIsLivwKKigsj8XYlwt/WKi2N4d//uQRCSAAjURNIHpMZBGYiaQPSYyAAABLAAAAAAAACWAAAAApUF/Mg+0aohSIRobBAsMlO//Kk4soosy1JSFRYWaLC4qZBYWFRGZdwqKiwkNBVmoWFSJkWFxX4FFRQWR+LsS4W/rFRb/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////VEFHAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAU291bmRib3kuZGUA
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAMjAwNGh0dHA6Ly93d3cuc291bmRib3kuZGUAAAAAAAAAACU="); snd.play(); new Notification("Cell Execution Has Finished")'
)

227