In [1]:
# Reload edited project modules automatically so changes to the local
# package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [2]:
import os
import sys
sys.path.append(os.path.abspath('..'))
sys.path.append(os.path.abspath('../run'))
import random
from collections import defaultdict
import itertools

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.collections import PatchCollection

from IPython.display import display, Markdown, Latex

import torch
from torch import nn
import torch.nn.functional as F

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

In [3]:
# Seed PyTorch's global RNG so model initialisation is repeatable across runs.
# NOTE(review): 33 duplicates the SEED constant declared in later cells --
# consider defining SEED once, before this call, and passing it here.
torch.manual_seed(33)

<torch._C.Generator at 0x7f9be76892f0>

In [4]:
from simple_relational_reasoning.datagen import *
from simple_relational_reasoning import models

import run
from quinn_defaults import prettify_class_name

In [45]:
# Names of the two size tiers used as keys of MODEL_CONFIGURATIONS.
DEFAULT_MODELS_CONFIG_KEY = 'default'
LARGER_MODELS_CONFIG_KEY = 'larger'

# Constructor keyword arguments for every model class, grouped by size tier.
# Each inner mapping goes from a model class to the kwargs passed to it
# (after the dataset argument) by the parameter-count cells below.
# In the run without a size feature, the recorded counts are roughly
# 1.7k-2.0k parameters for the default tier and 8.0k-8.7k for the larger tier.
MODEL_CONFIGURATIONS = {
    DEFAULT_MODELS_CONFIG_KEY: {
        models.CombinedObjectMLPModel: {
            'embedding_size': 8,
            'prediction_sizes': [32, 32, 16],
        },
        models.RelationNetModel: {
            'embedding_size': 8,
            'object_pair_layer_sizes': [32],
            'combined_object_layer_sizes': [32],
        },
        models.TransformerModel: {
            'embedding_size': 8,
            'transformer_mlp_sizes': [8, 8],
            'mlp_sizes': [32, 32],
        },
        # Disabled entry, kept for reference:
        # models.CNNModel: {'conv_sizes': [8, 16], 'mlp_sizes': [16, 8]},
        models.SimplifiedCNNModel: {
            'conv_sizes': [8, 16],
            'mlp_sizes': [16, 8],
        },
        models.PrediNetModel: {
            'key_size': 4,
            'num_heads': 4,
            'num_relations': 4,
            'output_hidden_size': 16,
        },
        models.PrediNetWithEmbeddingModel: {
            'embedding_size': 8,
            'key_size': 4,
            'num_heads': 2,
            'num_relations': 4,
            'output_hidden_size': 16,
        },
    },
    LARGER_MODELS_CONFIG_KEY: {
        models.CombinedObjectMLPModel: {
            'embedding_size': 16,
            'prediction_sizes': [64, 64, 32, 16],
        },
        models.RelationNetModel: {
            'embedding_size': 16,
            'object_pair_layer_sizes': [64, 32],
            'combined_object_layer_sizes': [64, 32],
        },
        models.TransformerModel: {
            'embedding_size': 16,
            'num_transformer_layers': 3,
            'num_heads': 2,
            'transformer_mlp_sizes': [16, 16],
            'mlp_sizes': [64, 32],
        },
        # Disabled entry, kept for reference:
        # models.CNNModel: {'conv_sizes': [8, 16, 32], 'mlp_sizes': [32, 32]},
        models.SimplifiedCNNModel: {
            'conv_sizes': [8, 16, 32],
            'mlp_sizes': [32, 32],
        },
        models.PrediNetModel: {
            'key_size': 8,
            'num_heads': 8,
            'num_relations': 8,
            'output_hidden_size': 32,
        },
        models.PrediNetWithEmbeddingModel: {
            'embedding_size': 16,
            'key_size': 6,
            'num_heads': 4,
            'num_relations': 6,
            'output_hidden_size': 16,
        },
    },
}


In [46]:
# Sanity check: for every configured model, print its parameter count and run
# one forward pass on random input to confirm the expected input shape works.
# Objects here come from ObjectGeneratorWithoutSize.
REFERENCE_OBJECT_SIZE = 9
TARGET_OBJECT_SIZE = 1
ADD_NEITHER = True  # NOTE(review): not read in this cell; kept in case other cells use it
X_MAX = 25
Y_MAX = 25
SEED = 33
PROP_TRAIN_REF_LOCATIONS = 0.90
N_TRAIN_TARGET_OBJECT_LOCATIONS = 7
BATCH_SIZE = 32

object_generator = ObjectGeneratorWithoutSize(SEED, REFERENCE_OBJECT_SIZE, TARGET_OBJECT_SIZE)

spatial_dataset = False
for config_name in MODEL_CONFIGURATIONS:
    print(f'CONFIGURATION NAME: {config_name}')
    for model_class, model_kwargs in MODEL_CONFIGURATIONS[config_name].items():
        model_name = prettify_class_name(model_class)

        # The dataset representation is chosen from the model's name:
        # 'simplified' for simplified models, True for other CNNs, False otherwise.
        if 'simplified' in model_name.lower():
            spatial_dataset = 'simplified'
        else:
            spatial_dataset = 'cnn' in model_name.lower()

        dataset = AboveBelowReferenceInductiveBias(object_generator, X_MAX, Y_MAX, SEED,
                                           prop_train_reference_object_locations=PROP_TRAIN_REF_LOCATIONS,
                                           n_train_target_object_locations=N_TRAIN_TARGET_OBJECT_LOCATIONS,
                                           spatial_dataset=spatial_dataset)

        model = model_class(dataset, **model_kwargs)

        # Shape of the random probe batch fed to the model below.
        if spatial_dataset == 'simplified':
            input_size = (BATCH_SIZE, 2, dataset.x_max, dataset.y_max)
        elif spatial_dataset:
            input_size = (BATCH_SIZE, model.object_size, dataset.x_max, dataset.y_max)
        else:
            input_size = (BATCH_SIZE, model.num_objects, model.object_size)

        # Generator expression: no need to materialise a list just to sum it.
        n_params = sum(p.numel() for p in model.parameters())
        print(f'\t{model_name}: {n_params}')

        # Smoke test only -- no gradients are needed, so skip autograd bookkeeping.
        sample_input = torch.rand(input_size)
        with torch.no_grad():
            y = model(sample_input)

CONFIGURATION NAME: default
	combined-object-mlp: 1963
	relation-net: 1739
	transformer: 1915
	simplified-cnn: 1755
	predi-net: 1891
	predi-netwithembedding: 1707
CONFIGURATION NAME: larger
	combined-object-mlp: 7987
	relation-net: 8563
	transformer: 8243
	simplified-cnn: 8171
	predi-net: 8387
	predi-netwithembedding: 8659


In [25]:
# Sanity check: for every configured model, print its parameter count and run
# one forward pass on random input to confirm the expected input shape works.
# Objects here come from ObjectGeneratorWithSize.
REFERENCE_OBJECT_SIZE = 9
TARGET_OBJECT_SIZE = 1
ADD_NEITHER = True  # NOTE(review): not read in this cell; kept in case other cells use it
X_MAX = 25
Y_MAX = 25
SEED = 33
PROP_TRAIN_REF_LOCATIONS = 0.90
N_TRAIN_TARGET_OBJECT_LOCATIONS = 7
# Declared here as well so this cell does not depend on BATCH_SIZE leaking
# from another cell; the value matches the one used there.
BATCH_SIZE = 32

object_generator = ObjectGeneratorWithSize(SEED, REFERENCE_OBJECT_SIZE, TARGET_OBJECT_SIZE)

spatial_dataset = False
for config_name in MODEL_CONFIGURATIONS:
    print(f'CONFIGURATION NAME: {config_name}')
    for model_class, model_kwargs in MODEL_CONFIGURATIONS[config_name].items():
        model_name = prettify_class_name(model_class)

        # The dataset representation is chosen from the model's name:
        # 'simplified' for simplified models, True for other CNNs, False otherwise.
        if 'simplified' in model_name.lower():
            spatial_dataset = 'simplified'
        else:
            spatial_dataset = 'cnn' in model_name.lower()

        # NOTE(review): the grid is 2 larger per axis than X_MAX / Y_MAX in this
        # cell; the reason is not visible here -- TODO document it.
        dataset = AboveBelowReferenceInductiveBias(object_generator, X_MAX + 2, Y_MAX + 2, SEED,
                                           prop_train_reference_object_locations=PROP_TRAIN_REF_LOCATIONS,
                                           n_train_target_object_locations=N_TRAIN_TARGET_OBJECT_LOCATIONS,
                                           spatial_dataset=spatial_dataset)

        model = model_class(dataset, **model_kwargs)

        # Shape of the random probe batch fed to the model below.
        if spatial_dataset == 'simplified':
            input_size = (BATCH_SIZE, 2, dataset.x_max, dataset.y_max)
        elif spatial_dataset:
            input_size = (BATCH_SIZE, model.object_size, dataset.x_max, dataset.y_max)
        else:
            input_size = (BATCH_SIZE, model.num_objects, model.object_size)

        # Generator expression: no need to materialise a list just to sum it.
        n_params = sum(p.numel() for p in model.parameters())
        print(f'\t{model_name}: {n_params}')

        # Smoke test only -- no gradients are needed, so skip autograd bookkeeping.
        sample_input = torch.rand(input_size)
        with torch.no_grad():
            y = model(sample_input)

CONFIGURATION NAME: default
	combined-object-mlp: 1971
	relation-net: 1747
	transformer: 1923
	simplified-cnn: 1755
	predi-net: 939
	predi-netwithembedding: 691
CONFIGURATION NAME: larger
	combined-object-mlp: 8003
	relation-net: 8579
	transformer: 8259
	simplified-cnn: 8171
	predi-net: 4563
	predi-netwithembedding: 2979


In [8]:
from torch.utils.data import DataLoader

In [9]:
# Build a subsampled training split and wrap it in a DataLoader so the batch
# shapes can be inspected. `object_generator` and the upper-case constants are
# the ones defined in the parameter-count cell above.
dataset = AboveBelowReferenceInductiveBias(object_generator, X_MAX + 2, Y_MAX + 2, SEED,
                                           prop_train_reference_object_locations=PROP_TRAIN_REF_LOCATIONS,
                                           n_train_target_object_locations=N_TRAIN_TARGET_OBJECT_LOCATIONS,
                                           # Explicit instead of reusing whatever value the
                                           # model loop above happened to leave in
                                           # `spatial_dataset` (False for the last model listed).
                                           spatial_dataset=False,
                                           subsample_train_size=32)
train = dataset.get_training_dataset()
# batch_size (256) exceeds subsample_train_size (32); the recorded run yields
# a single batch of 32 examples.
train_loader = DataLoader(train, shuffle=True, batch_size=256)

In [10]:
# Walk the loader once and print the input and label tensor shapes per batch.
# The recorded run shows a single batch: inputs [32, 2, 5], labels [32].
for X, y in train_loader:
    print(X.shape, y.shape)

torch.Size([32, 2, 5]) torch.Size([32])
