In [None]:
# Automatically reload edited modules (e.g. the qsr_learning package) before each cell executes.
%reload_ext autoreload
%autoreload 2
import os

# Restrict this kernel to GPU 0. This has to be set before torch initializes CUDA,
# which is why it happens here, ahead of the torch import in the next cell.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [None]:
import math

import pytorch_lightning as pl
import torch
from munch import Munch
from torch.utils.data import DataLoader, random_split

from qsr_learning.data import DRLDataset
from qsr_learning.models import DRLNet

In [None]:
# Root experiment configuration; later cells attach the dataset, data_loader,
# model and trainer sub-configs to it.
config = Munch(dict())

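# Dataset

Twenty entities are sampled from the emoji set, and two of them are held out as a pair. The training split uses all twenty entities, with that pair passed as `excluded_combinations`. The validation split contains only the held-out pair, so it is intended to measure generalization to a combination not seen together during training.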

In [None]:
import random

from qsr_learning.entity import emoji_names

# Sample the entity vocabulary for this run from a dedicated, explicitly seeded RNG.
# Runs are logged under the commit SHA, so the chosen entities (and therefore the
# train/validation split) must be reproducible across kernel restarts. Sorting
# first makes the draw independent of emoji_names' iteration order, and keeps
# random.sample working in case emoji_names is a set rather than a sequence.
ENTITY_SAMPLING_SEED = 0
entity_rng = random.Random(ENTITY_SAMPLING_SEED)

entity_names = entity_rng.sample(sorted(emoji_names), k=20)
# Two of the sampled entities are held out as a pair: excluded from the training
# split and used on their own for the validation split.
excluded_pair = entity_rng.sample(entity_names, k=2)

In [None]:
# DRLDataset settings shared by both splits. The per-split arguments
# (entity_names, excluded_combinations and the num_samples / root_seed
# overrides) are merged on top of these when each split is constructed.
config.dataset = Munch(
    {
        "relation_names": ["left_of", "right_of", "above", "below"],
        "num_entities": 2,
        # the author notes "absolute" as the alternative frame of reference
        "frame_of_reference": "intrinsic",
        "w_range": (32, 32),
        "h_range": (32, 32),
        "theta_range": (0, 2 * math.pi),
        "add_bbox": False,
        "add_front": False,
        "transform": None,
        "canvas_size": (224, 224),
        "num_samples": 100_000,
        "root_seed": 0,
    }
)

In [None]:
# Training split: every sampled entity, with the held-out pair passed as
# `excluded_combinations`. Keyword arguments override the shared config.dataset.
# NOTE(review): excluded_combinations receives a flat two-name list, not a list of
# pairs — confirm this matches what DRLDataset expects.
train_dataset = DRLDataset(
    **dict(
        config.dataset,
        entity_names=entity_names,
        excluded_combinations=excluded_pair,
        num_samples=10 ** 5,
        root_seed=0,
    )
)

In [None]:
# Validation split: only the two held-out entities, with nothing excluded.
# Its root seed starts at the training sample count, presumably so that the two
# splits draw from non-overlapping seed ranges — TODO confirm against DRLDataset.
validation_dataset = DRLDataset(
    **dict(
        config.dataset,
        entity_names=excluded_pair,
        excluded_combinations=[],
        num_samples=10 ** 4,
        root_seed=train_dataset.num_samples,
    )
)

# Data Loader

In [None]:
# DataLoader settings shared by both splits (the validation loader overrides
# `shuffle`). The worker count is capped at the number of available CPUs:
# requesting more workers than cores makes PyTorch warn and can slow down or
# stall data loading. os.cpu_count() may return None, hence the fallback to 1.
config.data_loader = Munch(
    batch_size=256,
    shuffle=True,
    num_workers=min(16, os.cpu_count() or 1),
    pin_memory=True,
)

In [None]:
# One loader per split; validation reuses the shared settings but keeps a fixed
# sample order.
train_loader = DataLoader(dataset=train_dataset, **config.data_loader)
validation_loader = DataLoader(
    dataset=validation_dataset,
    **dict(config.data_loader, shuffle=False),
)

# Model

In [None]:
# Model hyperparameters. The vocabulary size comes from the training set's word2idx.
# The question length is the element count of item [1] of the first training
# sample (presumably the encoded question tensor).
num_question_tokens = len(train_dataset.word2idx)
first_question_len = train_dataset[0][1].shape.numel()
config.model = Munch(
    vision_model="resnet18",
    # leading 3 = channel dimension, followed by the canvas height/width
    image_size=(3, *config.dataset.canvas_size),
    num_embeddings=num_question_tokens,
    embedding_dim=10,
    question_len=first_question_len,
)

In [None]:
# Seed Python, NumPy and torch RNGs before building the model. This makes weight
# initialisation reproducible for a given commit. It also fixes the training
# loader's shuffle order, which DataLoader draws from torch's global RNG when
# iteration starts.
MODEL_SEED = 0
pl.seed_everything(MODEL_SEED)
model = DRLNet(**config.model)

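# Trainer

Training only starts from a clean git working tree, and TensorBoard logs are written under `lightning_logs/<commit SHA>`.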

In [None]:
# Trainer settings: a single GPU, 32-bit precision, up to 100 epochs, and the full
# train/validation sets each epoch (a limit_* fraction of 1.0 means all batches).
config.trainer = Munch(
    {
        "gpus": 1,
        "max_epochs": 100,
        "precision": 32,
        "limit_train_batches": 1.0,
        "limit_val_batches": 1.0,
    }
)

In [None]:
from pathlib import Path

import git
from git.exc import RepositoryDirtyError
from pytorch_lightning import loggers

# Refuse to train from uncommitted changes. The run is identified by the current
# commit SHA, so a dirty working tree would make that identifier misleading.
notebook_dir = Path(".").absolute()
repo = git.Repo(notebook_dir, search_parent_directories=True)
if repo.is_dirty():
    raise RepositoryDirtyError(repo, "Have you forgotten to commit the changes?")

# Log to lightning_logs/<sha>; the empty experiment name skips the per-experiment
# sub-directory. NOTE(review): re-running the same commit reuses this directory.
sha = repo.head.object.hexsha
tb_logger = loggers.TensorBoardLogger(
    save_dir="lightning_logs",
    name="",
    version=sha,
)
trainer = pl.Trainer(**dict(config.trainer, logger=tb_logger))
trainer.fit(model, train_loader, validation_loader)