In [None]:
# Load the IPython autoreload extension (or reload it if it is already loaded).
%reload_ext autoreload
# Mode 2: re-import all modules before each cell runs, so edits to local
# packages are picked up without restarting the kernel.
%autoreload 2

In [None]:
import math
from copy import deepcopy

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torchvision
from munch import Munch
from qsr_learning.data import DRLDataset
from qsr_learning.entity import emoji_names
from sklearn.metrics import accuracy_score

In [None]:
# Build the dataset used throughout this notebook. The keyword arguments are
# grouped by theme; their exact semantics are defined by `DRLDataset`.
dataset = DRLDataset(
    # Vocabulary: which entities and spatial relations may appear.
    entity_names=["octopus", "trophy"],
    relation_names=["left_of", "right_of"],
    # Scene composition.
    num_entities=2,
    fixed_entities=None,
    frame_of_reference="absolute",
    # Entity geometry: fixed 32x32 size, orientation anywhere in [0, 2*pi].
    w_range=(32, 32),
    h_range=(32, 32),
    theta_range=(0.0, math.tau),
    # Dataset size and reproducibility.
    num_samples=128,
    num_questions_per_image=1,
    random_seed=0,
)

In [None]:
class DRLNet(pl.LightningModule):
    """Lightning module pairing a frozen torchvision image encoder with a
    (currently placeholder) question encoder.

    Args:
        num_embeddings: Number of rows of the token embedding table
            (forwarded to ``nn.Embedding``).
        embedding_dim: Width of each token embedding.
        vision_model: Name of a model constructor in ``torchvision.models``
            (e.g. ``"resnet18"``); it is loaded with pretrained weights.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, vision_model: str):
        super().__init__()

        self.embedding_dim = embedding_dim

        # Image encoder: keep every child module of the backbone except the
        # last three (for a torchvision ResNet: layer4, avgpool and fc).
        resnet = getattr(torchvision.models, vision_model)(pretrained=True)
        self.image_encoder = nn.Sequential(*deepcopy(list(resnet.children())[:-3]))
        del resnet
        # Freeze the image encoder weights
        for param in self.image_encoder.parameters():
            param.requires_grad = False

        # Question encoder (identity placeholder for now)
        self.question_encoder = nn.Identity()

        # Fusion (identity placeholder for now)
        self.fusion = nn.Identity()
        self.criterion = nn.BCELoss()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        # Maps flattened image features to the flattened question feature size.
        self.fc = nn.Linear(
            self.image_feature_size.numel(), self.question_feature_size.numel()
        )

    @property
    def image_feature_size(self):
        """Shape of the image encoder output for one dummy 3x224x224 image.

        The probe runs in eval mode and without autograd so that the random
        dummy input neither updates the pretrained BatchNorm running
        statistics nor builds a graph. The encoder's previous train/eval
        mode is restored afterwards.
        """
        image = torch.rand((1, 3, 224, 224), device=self.device)
        was_training = self.image_encoder.training
        self.image_encoder.eval()
        try:
            with torch.no_grad():
                return self.image_encoder(image).shape
        finally:
            self.image_encoder.train(was_training)

    @property
    def question_feature_size(self):
        """Shape of the question encoder output for a dummy input of shape
        ``(1, 3, embedding_dim)``."""
        question = torch.ones((1, 3, self.embedding_dim), device=self.device)
        with torch.no_grad():
            return self.question_encoder(question).shape

    def forward(self, images, questions):
        """Predict answers for a batch of images and questions.

        Raises:
            NotImplementedError: Always; the forward pass has not been
                written yet. Failing here is clearer than returning ``None``
                and crashing later inside the loss.
        """
        raise NotImplementedError("DRLNet.forward is not implemented yet")

    def training_step(self, batch, batch_idx):
        """Compute the BCE loss for one batch and log loss and accuracy.

        Args:
            batch: An ``(images, questions, answers)`` triple.
            batch_idx: Index of the batch within the epoch (unused).

        Returns:
            The training loss for this batch.
        """
        # Lightning puts the whole module in train mode before fitting; the
        # frozen encoder must stay in eval mode (BatchNorm running stats).
        self.image_encoder.eval()

        # Make prediction
        images, questions, answers = batch
        preds = self(images, questions)
        loss = self.criterion(preds, answers)

        # Logging
        self.log("train_loss", loss)
        # `preds` are probabilities (BCELoss), so threshold at 0.5 before
        # comparing with the targets. Computed in torch so it also works for
        # tensors on an accelerator.
        # NOTE(review): assumes `answers` holds 0/1 targets — confirm.
        correct = (preds.detach() >= 0.5) == (answers >= 0.5)
        self.log("train_accuracy", correct.float().mean())
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over the trainable parameters only."""
        # Make sure to filter the parameters based on `requires_grad`
        return torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.parameters())
        )


# Instantiate the network on a ResNet-18 backbone with a 10-entry,
# 10-dimensional token embedding table.
model = DRLNet(
    vision_model="resnet18",
    num_embeddings=10,
    embedding_dim=10,
)

In [None]:
# Display the image encoder's output shape for a single dummy 3x224x224 image.
model.image_feature_size

In [None]:
# Display the question encoder's output shape for a dummy (1, 3, embedding_dim) input.
model.question_feature_size

In [None]:
# The model defines no `train_dataloader` hook, so the training data has to be
# handed to `fit` explicitly; without it Lightning refuses to start.
# NOTE(review): assumes `DRLDataset` is a map-style torch Dataset yielding
# (image, question, answer) items, as unpacked in `training_step` — confirm.
train_loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)

trainer = pl.Trainer()
trainer.fit(model, train_loader)