In [None]:
%reload_ext autoreload
%autoreload 2

In [None]:
import math
import random
from copy import deepcopy

import numpy as np
import pytorch_lightning as pl
import qsr_learning.relation
import torch
import torch.nn as nn
import torchvision
from munch import Munch
from qsr_learning.data import DRLDataset
from qsr_learning.entity import emoji_names
from sklearn.metrics import accuracy_score

In [None]:
from typing import Callable, Dict, List, Tuple

from qsr_learning.data import draw_entities
from qsr_learning.entity import Entity
from qsr_learning.relation import above, below, left_of, right_of


def inside_canvas(entity: Entity, canvas_size: Tuple[int, int]) -> bool:
    """Return True when every bounding-box vertex lies strictly inside the canvas.

    :param entity: object exposing a ``bbox`` array whose column 0 holds x and
        column 1 holds y coordinates.
    :param canvas_size: (width, height); the accepted region is the open
        rectangle (0, width) x (0, height).
    """
    width, height = canvas_size[0], canvas_size[1]
    xs = entity.bbox[:, 0]
    ys = entity.bbox[:, 1]
    x_ok = all((xs > 0) & (xs < width))
    y_ok = all((ys > 0) & (ys < height))
    return x_ok and y_ok


def generate_entities(
    entity_names,
    num_entities: int = 5,
    frame_of_reference: str = "absolute",
    w_range: Tuple[int, int] = (32, 32),
    h_range: Tuple[int, int] = (32, 32),
    theta_range: Tuple[float, float] = (0.0, 2 * math.pi),
    canvas_size: Tuple[int, int] = (224, 224),
    relations: Tuple[Callable[[Entity, Entity], bool], ...] = (
        left_of,
        right_of,
        above,
        below,
    ),
) -> List[Entity]:
    """Create ``num_entities`` randomly placed entities that lie fully inside the canvas.

    Names are drawn without replacement from ``entity_names``. Each entity gets a
    rotation drawn uniformly from ``theta_range``, a position drawn uniformly
    over the canvas and an integer size drawn from ``w_range`` x ``h_range``.
    An entity whose bounding box is not strictly inside the canvas is redrawn.

    :param entity_names: candidate entity names. The caller's sequence is not modified.
    :param num_entities: number of entities to create.
    :param frame_of_reference: forwarded to ``Entity``.
    :param w_range: inclusive (min, max) width, passed to ``random.randint``.
    :param h_range: inclusive (min, max) height, passed to ``random.randint``.
    :param theta_range: (min, max) rotation, passed to ``random.uniform``.
    :param canvas_size: (width, height)
    :param relations: currently unused; kept for backward compatibility. An
        immutable tuple is used as the default to avoid a shared mutable default.
    :return: a list of ``num_entities`` entities.
    :raises ValueError: if ``num_entities`` exceeds ``len(entity_names)``. The
        previous version silently returned fewer entities.
    """
    if num_entities > len(entity_names):
        raise ValueError(
            f"num_entities={num_entities} exceeds the "
            f"{len(entity_names)} available entity names"
        )
    # random.sample picks distinct positions without touching the caller's list.
    names = random.sample(list(entity_names), num_entities)

    entities = []
    for name in names:
        # Rejection-sample each entity on its own. The entities are drawn
        # independently, so this gives the same distribution as redrawing the
        # whole set until every entity fits. Its cost grows linearly rather than
        # exponentially with num_entities.
        while True:
            theta = random.uniform(*theta_range)
            p = (random.uniform(0, canvas_size[0]), random.uniform(0, canvas_size[1]))
            entity = Entity(
                name=name,
                frame_of_reference=frame_of_reference,
                p=p,
                theta=theta,
                size=(random.randint(*w_range), random.randint(*h_range)),
            )
            if inside_canvas(entity, canvas_size):
                break
        entities.append(entity)
    return entities

In [None]:
import math
import random
from collections import namedtuple
from copy import deepcopy

from qsr_learning.relation import above, below, left_of, right_of

Question = namedtuple("Question", "head relation tail")
# Stand-in for a dataset instance. The Munch lets the module-level `gen_sample`
# prototype read its settings as `self.<attr>`.
self = Munch(
    entity_names=["octopus", "trophy"],
    num_entities=2,
    frame_of_reference="absolute",
    w_range=(32, 32),
    h_range=(32, 32),
    add_bbox=False,
    theta_range=(0, 2 * math.pi),
    canvas_size=(224, 224),
    relations=[above, below, left_of, right_of],
)


def gen_sample():
    """Generate one (image, question, answer) sample from the module-level ``self`` config.

    Draws a scene with ``generate_entities``, an ordered (head, tail) pair and a
    0/1 answer, then picks a relation whose truth value for the pair matches
    that answer. The scene is redrawn when the pair satisfies no relation, or
    when no relation has the truth value the drawn answer requires.

    :return: ``(image, question, answer)``. ``image`` comes from
        ``draw_entities``, ``question`` is a ``Question`` of names and
        ``answer`` is 0 or 1.
    """
    while True:
        entities = generate_entities(
            self.entity_names,
            self.num_entities,
            self.frame_of_reference,
            w_range=self.w_range,
            h_range=self.h_range,
            theta_range=self.theta_range,
            canvas_size=self.canvas_size,
            relations=self.relations,
        )
        head, tail = random.sample(entities, 2)
        answer = random.randint(0, 1)
        # Lists (not sets) preserve the order of self.relations, so random.choice
        # below is reproducible under a fixed seed. A set of functions iterates
        # in hash order, which can differ between runs.
        satisfied = [relation for relation in self.relations if relation(head, tail)]
        unsatisfied = [
            relation for relation in self.relations if relation not in satisfied
        ]
        candidates = satisfied if answer else unsatisfied
        # Keep the original requirement of at least one satisfied relation. Also
        # require a candidate for the drawn answer: random.choice([]) used to
        # raise IndexError when answer == 0 and every relation held.
        if satisfied and candidates:
            break
    relation = random.choice(candidates)
    image = draw_entities(
        entities, canvas_size=self.canvas_size, show_bbox=self.add_bbox
    )
    question = Question(head.name, relation.__name__, tail.name)
    return image, question, answer


# Draw one prototype sample and display the image, its question and the answer as a bool.
image, question, answer = gen_sample()
display(image, question, bool(answer))

In [None]:
from collections import namedtuple

from torch.utils.data import Dataset

# Redefined so this cell runs on its own; same name and fields as before.
Question = namedtuple("Question", "head relation tail")


class DRLDataset(Dataset):
    """On-the-fly dataset of (image, question, answer) samples about spatial relations.

    Every ``__getitem__`` call generates a fresh random scene. The index is
    ignored, and ``num_samples`` only sets the value returned by ``__len__``.
    """

    def __init__(
        self,
        entity_names,
        relation_names,
        num_entities,
        fixed_entities,
        frame_of_reference,
        w_range,
        h_range,
        theta_range,
        num_samples,
        filter_fn=None,
        show_bbox=False,
        orientation_marker=False,
        transform=None,
        num_questions_per_image=1,
        random_seed=0,
    ):
        """
        :param entity_names: names of the entities that may appear in a scene.
        :param relation_names: names of relation functions in ``qsr_learning.relation``.
        :param num_entities: number of entities drawn per image.
        :param fixed_entities: predefined entities (stored, not used yet).
        :param frame_of_reference: forwarded to ``generate_entities``.
        :param w_range: (min, max) entity width, forwarded to ``generate_entities``.
        :param h_range: (min, max) entity height, forwarded to ``generate_entities``.
        :param theta_range: (min, max) rotation, forwarded to ``generate_entities``.
        :param num_samples: value returned by ``__len__``.
        :param filter_fn: stored, not used yet.
        :param show_bbox: forwarded to ``draw_entities``.
        :param orientation_marker: forwarded to ``draw_entities``.
        :param transform: image transform. When falsy, a ToTensor + Normalize
            pipeline is built from ``get_mean_and_std``.
        :param num_questions_per_image: the (maximal) number of questions generated for each image.
        :param random_seed: seeds the global ``random``, ``numpy`` and ``torch`` RNGs.
        """
        super().__init__()
        self.entity_names = entity_names
        self.relations = [
            getattr(qsr_learning.relation, relation_name)
            for relation_name in relation_names
        ]
        self.num_entities = num_entities
        self.fixed_entities = fixed_entities  # predefined entities (not used yet)
        # Honour the caller's ranges; they were previously overwritten by constants.
        self.w_range = w_range
        self.h_range = h_range
        self.frame_of_reference = frame_of_reference
        self.theta_range = theta_range
        self.num_samples = num_samples
        self.filter_fn = filter_fn
        self.show_bbox = show_bbox
        self.orientation_marker = orientation_marker

        if not transform:
            # NOTE(review): get_mean_and_std is not defined in any visible cell;
            # confirm which module provides it.
            self.mean, self.std = get_mean_and_std(
                entity_names,
                relation_names,
                num_entities,
                w_range,
                h_range,
            )
            self.transform = torchvision.transforms.Compose(
                [
                    torchvision.transforms.ToTensor(),
                    torchvision.transforms.Normalize(self.mean, self.std),
                ]
            )
        else:
            self.transform = transform

        self.num_questions_per_image = num_questions_per_image
        random.seed(random_seed)
        np.random.seed(random_seed)
        torch.manual_seed(random_seed)

        self.idx2ent, self.ent2idx = {}, {}
        for idx, entity_name in enumerate(sorted(entity_names)):
            self.idx2ent[idx] = entity_name
            self.ent2idx[entity_name] = idx

        self.idx2rel, self.rel2idx = {}, {}
        for idx, relation_name in enumerate(sorted(relation_names)):
            self.idx2rel[idx] = relation_name
            self.rel2idx[relation_name] = idx

        # Joint vocabulary for question tokens: sorted entity names first, then
        # sorted relation names. gen_sample encodes questions with it.
        # NOTE(review): assumed vocabulary layout; confirm it matches the
        # model's embedding table.
        self.word2idx = {
            word: idx
            for idx, word in enumerate(sorted(entity_names) + sorted(relation_names))
        }

        # Per-index containers; no visible code populates them yet.
        self.image: Dict[int, "PIL.Image.Image"] = {}
        self.questions: Dict[int, Question] = {}
        self.answers: Dict[int, int] = {}

    def __getitem__(self, idx):
        """Return a freshly generated sample; ``idx`` is ignored."""
        return self.gen_sample()

    def __len__(self):
        return self.num_samples

    def gen_sample(self):
        """Generate one sample.

        :return: ``(image, question, answer)``. ``image`` is ``self.transform``
            applied to the rendered scene. ``question`` is an int64 tensor of
            ``word2idx`` indices for (head, relation, tail). ``answer`` is a
            0-dim tensor holding 0 or 1.
        """
        entities = generate_entities(
            entity_names=self.entity_names,
            num_entities=self.num_entities,
            frame_of_reference=self.frame_of_reference,
            w_range=self.w_range,
            h_range=self.h_range,
            theta_range=self.theta_range,
        )

        image = draw_entities(
            entities,
            show_bbox=self.show_bbox,
            orientation_marker=self.orientation_marker,
        )
        # Flatten the RGBA rendering onto an opaque black background.
        # NOTE(review): `Image` (PIL) is not imported in any visible cell.
        background = Image.new("RGBA", image.size, (0, 0, 0))
        image = Image.alpha_composite(background, image).convert("RGB")

        question, answer = self._gen_qa(entities)
        question_tensor = torch.tensor(
            [self.word2idx[word] for word in question], dtype=torch.int64
        )
        return self.transform(image), question_tensor, torch.tensor(answer)

    def _gen_qa(self, entities):
        """Sample a (head, relation, tail) question about two entities and its 0/1 answer.

        The answer is drawn uniformly first, then a relation whose truth value
        for (head, tail) equals that answer is chosen. If no relation matches,
        the pair and the answer are redrawn.

        :raises ValueError: if the dataset has no relations (the loop could
            never terminate).
        """
        if not self.relations:
            raise ValueError("at least one relation is required to generate questions")
        while True:
            head, tail = random.sample(entities, 2)
            answer = random.randint(0, 1)
            candidates = [
                relation
                for relation in self.relations
                if bool(relation(head, tail)) == bool(answer)
            ]
            if candidates:
                relation = random.choice(candidates)
                return Question(head.name, relation.__name__, tail.name), answer

In [None]:
# Small two-entity, two-relation dataset.
# Keyword arguments are listed in the same order as the DRLDataset signature.
dataset = DRLDataset(
    entity_names=["octopus", "trophy"],
    relation_names=["left_of", "right_of"],
    num_entities=2,
    fixed_entities=None,
    frame_of_reference="absolute",
    w_range=(32, 32),
    h_range=(32, 32),
    theta_range=(0.0, 2 * math.pi),
    num_samples=128,
    num_questions_per_image=1,
    random_seed=0,
)

In [None]:
class DRLNet(pl.LightningModule):
    """Binary question-answering model over images and (head, relation, tail) questions.

    Work in progress: the question encoder and the fusion step are still
    ``nn.Identity`` placeholders, and :meth:`forward` is not implemented yet.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, vision_model: str):
        """
        :param num_embeddings: vocabulary size of the question-token embedding.
        :param embedding_dim: size of each token embedding.
        :param vision_model: name of a ``torchvision.models`` constructor
            (e.g. ``"resnet18"``). It is called with ``pretrained=True``.
        """
        super().__init__()

        self.embedding_dim = embedding_dim

        # Image encoder: the pretrained model without its last three children.
        resnet = getattr(torchvision.models, vision_model)(pretrained=True)
        self.image_encoder = nn.Sequential(*deepcopy(list(resnet.children())[:-3]))
        del resnet
        # Freeze the image encoder weights
        for param in self.image_encoder.parameters():
            param.requires_grad = False

        # Question encoder (placeholder).
        # TODO: replace with a real encoder (e.g. nn.TransformerEncoder). The
        # nn.TransformerEncoderLayer(d_model=512, nhead=8) built here before was
        # never used, so it has been dropped.
        self.question_encoder = nn.Identity()

        # Fusion
        self.fusion = nn.Identity()
        self.criterion = nn.BCELoss()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.fc = nn.Linear(
            self.image_feature_size.numel(), self.question_feature_size.numel()
        )

    @property
    def image_feature_size(self) -> torch.Size:
        """Shape of the image encoder's output for one random 3x224x224 image.

        The probe runs in eval mode and without autograd, so any BatchNorm
        layers in the encoder keep their running statistics and no graph is
        built. The encoder's previous train/eval mode is restored afterwards.
        """
        was_training = self.image_encoder.training
        self.image_encoder.eval()
        try:
            with torch.no_grad():
                image = torch.rand((1, 3, 224, 224), device=self.device)
                return self.image_encoder(image).shape
        finally:
            self.image_encoder.train(was_training)

    @property
    def question_feature_size(self) -> torch.Size:
        """Shape of the question encoder's output for a (1, 3, embedding_dim) dummy input."""
        with torch.no_grad():
            question = torch.ones((1, 3, self.embedding_dim), device=self.device)
            return self.question_encoder(question).shape

    def forward(self, images, questions):
        """Predict answer probabilities for a batch. Not implemented yet."""
        # Raise explicitly. The previous `pass` returned None, which only failed
        # later inside the loss with a misleading error.
        raise NotImplementedError("DRLNet.forward is not implemented yet")

    def training_step(self, batch, batch_idx):
        """Compute BCE loss and accuracy for one ``(images, questions, answers)`` batch."""
        # Keep the frozen image encoder (and any BatchNorm layers in it) in eval mode.
        self.image_encoder.eval()

        # Make prediction
        images, questions, answers = batch
        preds = self(images, questions)
        # BCELoss needs floating-point targets of the same dtype as the predictions.
        targets = answers.to(preds.dtype)
        loss = self.criterion(preds, targets)

        # Accuracy on probabilities thresholded at 0.5. sklearn's accuracy_score
        # rejects continuous predictions and needs detached CPU data.
        # NOTE(review): assumes preds and answers share a shape; confirm once
        # forward is implemented.
        with torch.no_grad():
            accuracy = ((preds > 0.5).to(targets.dtype) == targets).float().mean()

        # Logging
        self.log("train_loss", loss)
        self.log("train_accuracy", accuracy)
        return loss

    def configure_optimizers(self):
        """Adam over the trainable parameters only; the image encoder is frozen."""
        # `self.parameters` must be called: the bound method itself is not iterable.
        trainable = [p for p in self.parameters() if p.requires_grad]
        return torch.optim.Adam(trainable)


# Build the model. DRLNet calls the torchvision constructor with pretrained=True,
# which presumably downloads weights on first use.
model = DRLNet(num_embeddings=10, embedding_dim=10, vision_model="resnet18")

In [None]:
# Output shape of the frozen image encoder for one random 3x224x224 image.
model.image_feature_size

In [None]:
# Output shape of the question encoder for a (1, 3, embedding_dim) dummy input.
model.question_feature_size

In [None]:
# NOTE(review): no dataloader is passed here, DRLNet defines no `train_dataloader`,
# and DRLNet.forward is not implemented, so fit() presumably cannot train yet.
# Pass e.g. a DataLoader over `dataset` once forward exists.
trainer = pl.Trainer()
trainer.fit(model)