In [None]:
# Notebook setup: load IPython's autoreload extension in mode 2 (re-import
# modules before executing each cell) and render matplotlib figures inline.
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import os
import sys
sys.path.append(os.path.abspath('..'))
sys.path.append(os.path.abspath('../run'))
import random
from collections import defaultdict

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.collections import PatchCollection

from IPython.display import display, Markdown, Latex

import torch
from torch import nn
import torch.nn.functional as F

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

In [None]:
# Seed every RNG library this notebook imports, not only torch, so that a
# fresh "Restart & Run All" starts from the same random state.
# NOTE(review): whether the project's data generators draw from `random`,
# numpy or torch is not visible here; seeding all three covers each case.
SEED = 33
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
from simple_relational_reasoning.datagen import *
from simple_relational_reasoning.models import *

import run
from defaults import FIELD_CONFIGURATIONS

In [None]:
from pytorch_lightning.loggers import LightningLoggerBase # , rank_zero_only
from collections import defaultdict


class PrintLogger(LightningLoggerBase):
    """Lightning logger that writes hyperparameters and metrics to stdout.

    Nothing is persisted: `save` and `finalize` are intentionally no-ops.
    The `rank_zero_only` decorator is not applied to any method (it was
    disabled in the original code as well).
    """

    def __init__(self):
        super().__init__()

    @property
    def name(self):
        """Identifier reported for this logger."""
        return 'PrintLogger'

    @property
    def experiment(self):
        """There is no backing experiment object; report the logger name instead."""
        # `name` is a property, so it must be read rather than called.
        # The previous `self.name()` raised TypeError ('str' is not callable).
        return self.name

    @property
    def version(self):
        """Fixed version string for this logger."""
        return '0.0.1'

    def log_hyperparams(self, params):
        """Print the hyperparameters object as-is."""
        print(f'Hyperparameters:\n{params}')

    def log_metrics(self, metrics, step):
        """Print one line of metrics for `step`, sorted by metric name.

        Args:
            metrics: mapping of metric name to a value formattable with
                `.4f`; `None` or an empty mapping prints nothing.
            step: step index printed as the line prefix.
        """
        if metrics is None or len(metrics.keys()) == 0:
            return

        formatted = [f'{key}: {metrics[key]:.4f}' for key in sorted(metrics.keys())]
        out = ', '.join(formatted)
        print(f'{step}: {out}')

    def save(self):
        """No-op: this logger keeps no state to save."""
        pass

    def finalize(self, status):
        """No-op: nothing to clean up once training finishes."""
        pass
    
    
class PlotLogger(LightningLoggerBase):
    """Lightning logger that accumulates metrics and plots them in `finalize`.

    Args:
        metric_groups: either a mapping of group title to an iterable of
            metric names, or a sequence whose items are iterables of metric
            names (untitled) or single metric names (titled by that name).
            Each group is drawn on its own axes.
        plot_grid: `(n_rows, n_cols)` subplot layout; defaults to a single
            row with one column per group.
        cmap: name of the matplotlib colormap used to colour the lines.
        ax_width: width of each subplot, passed to `figsize`.
        ax_height: height of each subplot, passed to `figsize`.
        print_final: if true, print the last logged value of every metric
            before plotting.
    """

    def __init__(self, metric_groups, plot_grid=None, cmap='Dark2',
                 ax_width=6, ax_height=4, print_final=True):
        super().__init__()
        self.metric_groups = metric_groups
        if plot_grid is None:
            plot_grid = (1, len(self.metric_groups))

        self.plot_grid = plot_grid
        # metric name -> list of every value logged for it, in call order
        self.metrics = defaultdict(list)
        self.cmap = plt.get_cmap(cmap)
        self.ax_width = ax_width
        self.ax_height = ax_height
        self.print_final = print_final

    @property
    def name(self):
        """Identifier reported for this logger."""
        return 'PlotLogger'

    @property
    def experiment(self):
        """There is no backing experiment object; report the logger name instead."""
        # `name` is a property, so it must be read rather than called.
        # The previous `self.name()` raised TypeError ('str' is not callable).
        return self.name

    @property
    def version(self):
        """Fixed version string for this logger."""
        return '0.0.1'

    def log_hyperparams(self, params):
        """Hyperparameters are ignored by this logger."""
        pass

    def log_metrics(self, metrics, step):
        """Append every value in `metrics` to that metric's history.

        `step` is not stored; the plot's x axis is simply the position of
        each value in the history.
        """
        if metrics is None or len(metrics.keys()) == 0:
            return

        for key in metrics.keys():
            self.metrics[key].append(metrics[key])

    def save(self):
        """No-op: the history only lives in memory."""
        pass

    def _named_groups(self):
        """Yield `(title_or_None, metric_names)` for every configured group."""
        groups = self.metric_groups
        is_mapping = hasattr(groups, 'keys')
        for entry in groups:
            if not isinstance(entry, str):
                # An iterable of metric names without a title.
                yield None, entry
            elif is_mapping:
                yield entry, groups[entry]
            else:
                # A bare metric name in a plain sequence: a one-line group
                # titled by the metric itself (previously a TypeError).
                yield entry, (entry,)

    def finalize(self, status):
        """Optionally print final metric values, then plot every group.

        `status` is accepted for interface compatibility and not used.
        """
        if self.print_final:
            # Skip empty histories so a metric with no values cannot raise
            # IndexError on `[-1]`.
            out = ', '.join([f'{key}: {self.metrics[key][-1]:.4f}'
                             for key in sorted(self.metrics.keys())
                             if self.metrics[key]])
            print(f'Final values: {out}')

        # plot_grid is (n_rows, n_cols)
        fig = plt.figure(figsize=(self.ax_width * self.plot_grid[1],
                                  self.ax_height * self.plot_grid[0]))

        for i, (name, group) in enumerate(self._named_groups()):
            ax = fig.add_subplot(*self.plot_grid, i + 1)
            for j, key in enumerate(sorted(group)):
                # `.get` avoids inserting an empty list into the defaultdict
                # for a metric that was never logged; it is drawn as an
                # empty series instead.
                ax.plot(self.metrics.get(key, []), color=self.cmap(j), label=key)

            ax.legend(loc='best')
            # NOTE(review): the x axis is the index of each logged value;
            # the 'Epochs' label assumes one log call per epoch - confirm.
            ax.set_xlabel('Epochs')
            if name is not None:
                ax.set_title(name.title())

        plt.show()

In [None]:
# For every (num_objects, relation name) pair, build a balanced object
# generator, draw one sample of 100 labelled examples from it, and record
# which example indices are positive / negative. All dicts share the same
# (num_objects, relation) keys.
X = {}
y = {}
object_generators = {}
positive_indices = {}
negative_indices = {}

for num_objects in (5, 10):
    for relation, relation_class in run.RELATION_NAMES_TO_CLASSES.items():
        key = (num_objects, relation)

        object_generator = object_gen.SmartBalancedBatchObjectGenerator(
            num_objects, run.FIELD_CONFIGURATIONS['default'], relation_class)
        object_generators[key] = object_generator

        X[key], labels = object_generator(100)
        y[key] = labels.bool()

        # squeeze(-1) drops only the trailing index column produced by
        # nonzero(); a bare squeeze() would also collapse the result to a
        # 0-d tensor when exactly one example matches.
        # (assumes the labels are 1-D, one per example - TODO confirm)
        positive_indices[key] = torch.nonzero(y[key]).squeeze(-1)
        negative_indices[key] = torch.nonzero(~y[key]).squeeze(-1)
                   


In [None]:
# Train one CNN per (num_objects, relation) generator built above, with
# early stopping on validation loss. The two string blocks inside the loop
# are disabled alternative models (relation net, transformer) kept for
# reference; as bare string expressions they have no effect at runtime.
for key, gen in object_generators.items():
    print(key)
    """
    model = RelationNetModel(gen, 
    #                          embedding_size=8,
                             object_pair_layer_sizes=[32], # 32],
    #                          object_pair_layer_sizes=[256, 256, 256],
    #                          combined_object_layer_sizes=[256, 256],
                             combined_object_layer_sizes=[32],
                             # prediction_sizes=hidden_sizes, prediction_activation_class=nn.ReLU,
                             batch_size=2**10, lr=1e-3, 
                             train_epoch_size=2**14, validation_epoch_size=2**14,
    #                          train_epoch_size=2**10, validation_epoch_size=2**10,
                             regenerate_every_epoch=False)
    """

    """
    model = TransformerModel(gen, 
                             embedding_size=8,
                             mlp_sizes=[32],
                             batch_size=2**10, lr=1e-3, 
                             train_epoch_size=2**14, validation_epoch_size=2**14,
                             regenerate_every_epoch=False
                            )
    """
    # Train and validation sets are separate datasets of the same size,
    # both backed by this generator.
    epoch_size = 2 ** 14
    train_dataset = ObjectGeneratorDataset(gen, epoch_size)
    validation_dataset = ObjectGeneratorDataset(gen, epoch_size)

    model = CNNModel(
        gen,
        conv_sizes=[16, 16],
        conv_output_size=256,
        batch_size=2 ** 10,
        lr=1e-3,
        train_epoch_size=epoch_size,
        validation_epoch_size=epoch_size,
        regenerate_every_epoch=False,
        train_dataset=train_dataset,
        validation_dataset=validation_dataset,
    )

    # 1 when a CUDA device is available, otherwise 0.
    use_gpu = int(torch.cuda.is_available())
    loggers = [
        PrintLogger(),
        PlotLogger({'loss': ('train_loss', 'val_loss'),
                    'accuracy': ('train_acc', 'val_acc')}),
    ]
    trainer = Trainer(
        gpus=use_gpu,
        max_epochs=1000,
        logger=loggers,
        early_stop_callback=EarlyStopping('val_loss', patience=5),
    )
    trainer.fit(model)


### N_OBJECTS = 4

# Train a small CNN for each of three relation classes on a freshly built
# generator, plotting (but not printing) the metrics.
# NOTE(review): `N_OBJECTS` and `cfgs` are not assigned anywhere above this
# point in the visible notebook (`N_OBJECTS` is only set in a later cell and
# the assignment above is commented out), so this loop depends on state from
# another cell or a star import - confirm before relying on Run All.
RELATION_CLASSES = (MultipleDAdjacentRelation, ColorAboveColorRelation, ObjectCountRelation)

for relation_class in RELATION_CLASSES:
    print(relation_class.__name__)
    gen = object_gen.SmartBalancedBatchObjectGenerator(
        N_OBJECTS, cfgs, relation_class,
        object_dtype=torch.float,
        label_dtype=torch.long,
        max_recursion_depth=100,
    )

    # The two string blocks below are disabled alternative models kept for
    # reference; as bare string expressions they have no effect at runtime.
    """
    model = RelationNetModel(gen, 
    #                          embedding_size=8,
                             object_pair_layer_sizes=[32], # 32],
    #                          object_pair_layer_sizes=[256, 256, 256],
    #                          combined_object_layer_sizes=[256, 256],
                             combined_object_layer_sizes=[32],
                             # prediction_sizes=hidden_sizes, prediction_activation_class=nn.ReLU,
                             batch_size=2**10, lr=1e-3, 
                             train_epoch_size=2**14, validation_epoch_size=2**14,
    #                          train_epoch_size=2**10, validation_epoch_size=2**10,
                             regenerate_every_epoch=False)
    """

    """
    model = TransformerModel(gen, 
                             embedding_size=8,
                             mlp_sizes=[32],
                             batch_size=2**10, lr=1e-3, 
                             train_epoch_size=2**14, validation_epoch_size=2**14,
                             regenerate_every_epoch=False
                            )
    """
    model = CNNModel(
        gen,
        conv_output_size=16,
        batch_size=2 ** 10,
        lr=1e-3,
        train_epoch_size=2 ** 14,
        validation_epoch_size=2 ** 14,
        regenerate_every_epoch=False,
    )

    # 1 when a CUDA device is available, otherwise 0.
    use_gpu = int(torch.cuda.is_available())
    loggers = [
        PlotLogger({'loss': ('train_loss', 'val_loss'),
                    'accuracy': ('train_acc', 'val_acc')}),
    ]
    trainer = Trainer(
        gpus=use_gpu,
        max_epochs=1000,
        logger=loggers,
        early_stop_callback=EarlyStopping('val_loss', patience=10),
    )
    trainer.fit(model)


In [None]:
# Deliberate stop: this always raises AssertionError, so "Run All" halts
# here and the cells below only run when executed by hand.
assert False

In [None]:
# Sweep the training-set size (2**10 ... 2**16 examples) for a small
# relation-net model while keeping the validation size fixed at 2**14.
# `gen` is whichever generator an earlier cell left bound.
for train_size_power in (10, 12, 14, 16):
    model = RelationNetModel(
        gen,
        object_pair_layer_sizes=[32],
        combined_object_layer_sizes=[32],
        batch_size=2 ** 10,
        lr=1e-3,
        train_epoch_size=2 ** train_size_power,
        validation_epoch_size=2 ** 14,
        regenerate_every_epoch=False,
    )

    # 1 when a CUDA device is available, otherwise 0.
    use_gpu = int(torch.cuda.is_available())
    loggers = [
        PlotLogger({'loss': ('train_loss', 'val_loss'),
                    'accuracy': ('train_acc', 'val_acc')}),
    ]
    trainer = Trainer(
        gpus=use_gpu,
        max_epochs=1000,
        logger=loggers,
        early_stop_callback=EarlyStopping('val_loss', patience=10),
    )
    trainer.fit(model)

In [None]:
# Fit whichever `model` an earlier cell left bound, printing and plotting
# the metrics.
# NOTE(review): `max_epochs=00` is the integer 0; every other cell uses 1000,
# so this may be a truncated value - confirm the intended epoch count.
use_gpu = int(torch.cuda.is_available())
loggers = [
    PrintLogger(),
    PlotLogger({'loss': ('train_loss', 'val_loss'),
                'accuracy': ('train_acc', 'val_acc')}),
]
trainer = Trainer(
    gpus=use_gpu,
    max_epochs=00,
    logger=loggers,
    early_stop_callback=EarlyStopping('val_loss', patience=5),
)
trainer.fit(model)

In [None]:
# Shell escape: runs `git pull` in the notebook's working directory.
# NOTE(review): presumably used together with autoreload to pick up upstream
# changes to the project modules; it makes re-runs depend on remote state.
!git pull

In [None]:
# Display the repr of whichever `model` was bound most recently.
model

In [None]:
# Build an 8-object generator for MultipleDAdjacentRelation and a deeper CNN
# (three conv layers, three MLP layers) on top of it. Nothing is trained in
# this cell.
# NOTE(review): `cfgs` is not assigned in any visible cell - confirm where it
# comes from (an earlier deleted cell or a star import).
N_OBJECTS = 8

gen = object_gen.SmartBalancedBatchObjectGenerator(
    N_OBJECTS, cfgs, MultipleDAdjacentRelation,
    object_dtype=torch.float,
    label_dtype=torch.long,
    max_recursion_depth=100,
)
# The two string blocks below are disabled alternative models kept for
# reference; as bare string expressions they have no effect at runtime.
"""
model = RelationNetModel(gen, 
#                          embedding_size=8,
                         object_pair_layer_sizes=[32], # 32],
#                          object_pair_layer_sizes=[256, 256, 256],
#                          combined_object_layer_sizes=[256, 256],
                         combined_object_layer_sizes=[32],
                         # prediction_sizes=hidden_sizes, prediction_activation_class=nn.ReLU,
                         batch_size=2**10, lr=1e-3, 
                         train_epoch_size=2**14, validation_epoch_size=2**14,
#                          train_epoch_size=2**10, validation_epoch_size=2**10,
                         regenerate_every_epoch=False)
"""

"""
model = TransformerModel(gen, 
                         embedding_size=8,
                         mlp_sizes=[32],
                         batch_size=2**10, lr=1e-3, 
                         train_epoch_size=2**14, validation_epoch_size=2**14,
                         regenerate_every_epoch=False
                        )
"""
model = CNNModel(
    gen,
    conv_sizes=[16, 32, 48],
    mlp_sizes=[64, 32, 16],
    conv_output_size=192,
    conv_stride=1,
    batch_size=2 ** 10,
    lr=1e-3,
    train_epoch_size=2 ** 14,
    validation_epoch_size=2 ** 14,
    regenerate_every_epoch=False,
)

In [None]:
# Call `model.cuda()` and then `summary` with the tuple (6, 16, 16).
# NOTE(review): unlike the training cells, this is not guarded by
# torch.cuda.is_available(), so it presumably fails on a CPU-only machine.
# NOTE(review): `summary` is not imported in any visible cell - presumably a
# torchsummary-style helper taking the per-example input shape; confirm.
model.cuda()
summary(model, (6, 16, 16))

In [None]:
# Scratch calculation: displays 65536 (2**16 also appears as the largest
# training-set size in the sweep cell above).
2 ** 16

In [None]:
# Same two calls as the earlier summary cell, repeated verbatim: call
# `model.cuda()` and then `summary` with the tuple (6, 16, 16).
# NOTE(review): `summary` is not imported in any visible cell, and the
# `.cuda()` call is not guarded by torch.cuda.is_available() - confirm.
model.cuda()
summary(model, (6, 16, 16))