In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import os
import sys
sys.path.append(os.path.abspath('..'))
sys.path.append(os.path.abspath('../run'))
import random
from collections import defaultdict
import itertools

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.collections import PatchCollection

from IPython.display import display, Markdown, Latex

import torch
from torch import nn
import torch.nn.functional as F

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

In [3]:
# Seed every RNG this notebook imports (python `random`, numpy, torch) so that
# Restart Kernel -> Run All reproduces the same results. SEED is also handed
# explicitly to the data generators further down.
SEED = 33
random.seed(SEED)
np.random.seed(SEED)
# Trailing `;` suppresses the returned torch.Generator repr in the cell output.
torch.manual_seed(SEED);

<torch._C.Generator at 0x7fc5b19a0390>

In [4]:
from simple_relational_reasoning.datagen import *
from simple_relational_reasoning import models

import run
from quinn_defaults import prettify_class_name

In [17]:
# Names of the two model-size configurations compared in this notebook.
DEFAULT_MODELS_CONFIG_KEY = 'default'
LARGER_MODELS_CONFIG_KEY = 'larger'

# configuration name -> {model class -> constructor keyword arguments}.
# Dict insertion order is significant: it fixes the order in which models are
# visited when these mappings are iterated.
MODEL_CONFIGURATIONS = {
    DEFAULT_MODELS_CONFIG_KEY: {
        models.CombinedObjectMLPModel: {'embedding_size': 8, 'prediction_sizes': [32, 32, 16]},
        models.RelationNetModel: {
            'embedding_size': 8,
            'object_pair_layer_sizes': [32],
            'combined_object_layer_sizes': [32],
        },
        models.TransformerModel: {
            'embedding_size': 8,
            'transformer_mlp_sizes': [8, 8],
            'mlp_sizes': [32, 32],
        },
        models.CNNModel: {'conv_sizes': [8, 16], 'mlp_sizes': [16, 8]},
    },
    LARGER_MODELS_CONFIG_KEY: {
        models.CombinedObjectMLPModel: {'embedding_size': 16, 'prediction_sizes': [64, 64, 32, 16]},
        models.RelationNetModel: {
            'embedding_size': 16,
            'object_pair_layer_sizes': [64, 32],
            'combined_object_layer_sizes': [64, 32],
        },
        models.TransformerModel: {
            'embedding_size': 16,
            'num_transformer_layers': 3,
            'num_heads': 2,
            'transformer_mlp_sizes': [16, 16],
            'mlp_sizes': [64, 32],
        },
        models.CNNModel: {'conv_sizes': [8, 16, 32], 'mlp_sizes': [32, 32]},
    },
}

# Shape of the random tensor used to smoke-test each model's forward pass.
# NOTE(review): (32, 2, 5) is presumably (batch, n_objects, object_dim) and the CNN
# shape (batch, channels, height, width) -- confirm against the models module.
# NOTE(review): the CNN's 18x18 differs from the X_MAX/Y_MAX = 25 grid used for
# the dataset; confirm the CNN does not depend on the spatial size.
input_sizes = {
    **dict.fromkeys(
        (models.CombinedObjectMLPModel, models.RelationNetModel, models.TransformerModel),
        (32, 2, 5),
    ),
    models.CNNModel: (32, 5, 18, 18),
}

In [18]:
# Grid extent and object sizes for the above/below reference task.
X_MAX = Y_MAX = 25
REFERENCE_OBJECT_SIZE, TARGET_OBJECT_SIZE = 9, 1

# How object placements are split between train and test.
PROP_TRAIN_REF_LOCATIONS = 0.90
N_TRAIN_TARGET_OBJECT_LOCATIONS = 7

# NOTE(review): ADD_NEITHER is defined here but not passed to the dataset
# constructor below -- verify whether it is meant to be used.
ADD_NEITHER = True
SEED = 33

object_generator = ObjectGeneratorWithSize(SEED, REFERENCE_OBJECT_SIZE, TARGET_OBJECT_SIZE)
dataset = AboveBelowReferenceInductiveBias(
    object_generator, X_MAX, Y_MAX, SEED,
    prop_train_reference_object_locations=PROP_TRAIN_REF_LOCATIONS,
    n_train_target_object_locations=N_TRAIN_TARGET_OBJECT_LOCATIONS,
)

In [19]:
# Sanity check: build every configured model, print its parameter count, and push one
# random batch through it to confirm the shapes in `input_sizes` are accepted.
#
# fork_rng(devices=[]) saves the CPU torch RNG state on entry and restores it on exit.
# Any RNG use during model construction, plus the torch.rand draws below, therefore
# does not shift the random stream later cells see. Downstream results are the same
# whether or not this diagnostic cell is run.
with torch.random.fork_rng(devices=[]):
    for config_name in MODEL_CONFIGURATIONS:
        print(f'CONFIGURATION NAME: {config_name}')
        for model_class, model_kwargs in MODEL_CONFIGURATIONS[config_name].items():
            model_name = prettify_class_name(model_class)
            model = model_class(dataset, **model_kwargs)
            # Counts every parameter tensor, trainable or not.
            print(f'\t{model_name}: {sum(p.numel() for p in model.parameters())}')
            sample_input = torch.rand(input_sizes[model_class])
            # The forward pass is only a smoke test; no autograd graph is needed.
            with torch.no_grad():
                y = model(sample_input)

CONFIGURATION NAME: default
	combined-object-mlp: 1954
	relation-net: 1714
	transformer: 1890
	cnn: 1962
CONFIGURATION NAME: larger
	combined-object-mlp: 7986
	relation-net: 8546
	transformer: 8226
	cnn: 8354
