In [1]:
# Re-import edited modules automatically (e.g. simple_relational_reasoning)
# so code changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures (such as the training curves drawn by PlotLogger) inline.
%matplotlib inline

In [2]:
import os
import sys
sys.path.append(os.path.abspath('..'))

import numpy as np
import matplotlib.pyplot as plt

import torch
from torch import nn
import torch.nn.functional as F
from torchsummary import summary

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

In [3]:
# Seed the random number generators up front so Restart & Run All reproduces the
# same data and weight initialisations. NumPy's global RNG is seeded as well in
# case the data generators draw from it (not visible from this notebook).
# The last statement returns None, so no Generator repr is displayed.
SEED = 33
torch.manual_seed(SEED)
np.random.seed(SEED)

<torch._C.Generator at 0x7f4d27f84390>

In [4]:
from simple_relational_reasoning.datagen import *
from simple_relational_reasoning.models import *

In [5]:
from pytorch_lightning.loggers import LightningLoggerBase, rank_zero_only
from collections import defaultdict


class PrintLogger(LightningLoggerBase):
    """Lightning logger that prints hyperparameters and metrics to stdout.

    Each ``log_metrics`` call prints one line of the form
    ``<step>: key1: v1, key2: v2`` with keys sorted and values to 4 decimals.
    """

    def __init__(self):
        super().__init__()

    @property
    def name(self):
        return 'PrintLogger'

    @property
    def experiment(self):
        # ``name`` is a property that already evaluates to a str; calling it
        # (``self.name()``) would raise TypeError: 'str' object is not callable.
        return self.name

    @property
    def version(self):
        return '0.0.1'

    @rank_zero_only
    def log_hyperparams(self, params):
        """Print the hyperparameters object as-is."""
        print(f'Hyperparameters:\n{params}')

    @rank_zero_only
    def log_metrics(self, metrics, step):
        """Print every metric of one logging call on a single line, keys sorted.

        Args:
            metrics: mapping of metric name -> numeric value; None or empty is ignored.
            step: step index Lightning passes along; printed as the line prefix.
        """
        if not metrics:
            return
        out = ', '.join(f'{key}: {metrics[key]:.4f}' for key in sorted(metrics))
        print(f'{step}: {out}')

    def save(self):
        # Nothing is buffered, so there is nothing to persist.
        pass

    @rank_zero_only
    def finalize(self, status):
        # Metrics are printed eagerly; no end-of-training work is needed.
        pass
    
    
class PlotLogger(LightningLoggerBase):
    """Lightning logger that records metrics in memory and plots them at the end.

    Args:
        metric_groups: which metrics share a subplot. Either a mapping
            ``{title: (metric_key, ...)}`` (the title is shown on the subplot), or
            a sequence whose items are a tuple/list of metric keys (untitled
            subplot) or a single metric-key string (its own subplot, titled
            with that key).
        plot_grid: ``(n_rows, n_cols)`` subplot layout; defaults to a single row
            with one column per group.
        cmap: matplotlib colormap name; the j-th (sorted) metric of a group is
            drawn with ``cmap(j)``.
        ax_width, ax_height: size in inches of each subplot.
        print_final: if True, ``finalize`` prints the last logged value of
            every metric before plotting.
    """

    def __init__(self, metric_groups, plot_grid=None, cmap='Dark2',
                 ax_width=6, ax_height=4, print_final=True):
        super().__init__()
        self.metric_groups = metric_groups
        if plot_grid is None:
            # max(1, ...) keeps the figure size non-zero when no groups are given.
            plot_grid = (1, max(1, len(self.metric_groups)))

        self.plot_grid = plot_grid
        # metric key -> list of values, in the order log_metrics received them
        self.metrics = defaultdict(list)
        self.cmap = plt.get_cmap(cmap)
        self.ax_width = ax_width
        self.ax_height = ax_height
        self.print_final = print_final

    @property
    def name(self):
        return 'PlotLogger'

    @property
    def experiment(self):
        # ``name`` is a property that already evaluates to a str; calling it
        # (``self.name()``) would raise TypeError: 'str' object is not callable.
        return self.name

    @property
    def version(self):
        return '0.0.1'

    @rank_zero_only
    def log_hyperparams(self, params):
        # Hyperparameters are not plotted.
        pass

    @rank_zero_only
    def log_metrics(self, metrics, step):
        """Append each metric value to its history (``step`` is not recorded)."""
        if not metrics:
            return
        for key, value in metrics.items():
            self.metrics[key].append(value)

    def save(self):
        # Everything is kept in memory; nothing to persist.
        pass

    def _iter_groups(self):
        """Yield ``(title_or_None, metric_keys)`` for each subplot, in order."""
        is_mapping = hasattr(self.metric_groups, 'keys')
        for group_or_name in self.metric_groups:
            if isinstance(group_or_name, str):
                if is_mapping:
                    title, group = group_or_name, self.metric_groups[group_or_name]
                else:
                    # A bare metric key in a sequence gets its own subplot.
                    title, group = group_or_name, (group_or_name,)
            else:
                title, group = None, group_or_name
            # A single key given as a plain string must not be split into characters.
            if isinstance(group, str):
                group = (group,)
            yield title, group

    @rank_zero_only
    def finalize(self, status):
        """Optionally print final metric values, then draw one subplot per group."""
        if self.print_final:
            # Skip metrics with no recorded values so [-1] cannot fail.
            out = ', '.join(f'{key}: {values[-1]:.4f}'
                            for key, values in sorted(self.metrics.items()) if values)
            print(f'Final values: {out}')

        n_rows, n_cols = self.plot_grid
        fig = plt.figure(figsize=(self.ax_width * n_cols, self.ax_height * n_rows))

        for i, (title, group) in enumerate(self._iter_groups()):
            ax = fig.add_subplot(n_rows, n_cols, i + 1)
            plotted_any = False
            for j, key in enumerate(sorted(group)):
                # .get() avoids inserting empty lists into the defaultdict for
                # metrics that were never logged.
                values = self.metrics.get(key)
                if not values:
                    continue
                ax.plot(values, color=self.cmap(j), label=key)
                plotted_any = True

            if plotted_any:
                ax.legend(loc='best')
            # NOTE(review): x is the index of the log_metrics call that carried this
            # key; it equals the epoch only if the key is logged once per epoch.
            ax.set_xlabel('Epochs')
            if title is not None:
                ax.set_title(title.title())

        plt.show()

In [6]:
# Per-object fields: integer x / y positions plus 2-way one-hot colour and shape
# (6 features per object, as the output shape below shows).
cfgs = tuple(
    object_gen.FieldConfig(field_name, field_kind, field_kwargs)
    for field_name, field_kind, field_kwargs in (
        ('x', 'int_position', {'max_coord': 20}),
        ('y', 'int_position', {'max_coord': 20}),
        ('color', 'one_hot', {'n_types': 2}),
        ('shape', 'one_hot', {'n_types': 2}),
    )
)

# Alternatives tried: object_gen.ObjectGenerator instead of the balanced generator;
# IdenticalObjectsRelation or ObjectCountRelation as the relation; max_recursion_depth=100.
gen = object_gen.SmartBalancedBatchObjectGenerator(
    8, cfgs, MultipleDAdjacentRelation,
    object_dtype=torch.float, label_dtype=torch.long,
)

X, y = gen(2 ** 14)
# Shapes, count of positive labels, and the positive fraction.
X.shape, y.shape, y.sum(), y.sum().float() / y.shape[0]

(torch.Size([16384, 8, 6]), torch.Size([16384]), tensor(8192), tensor(0.5000))

In [None]:
# Train a CNN on each relation with 8 objects per scene, early-stopping on val_loss.
N_OBJECTS = 8
MAX_EPOCHS = 1000
EARLY_STOP_PATIENCE = 20

# Alternative models (not run here), with the same batch/epoch settings:
#   RelationNetModel(gen, object_pair_layer_sizes=[32], combined_object_layer_sizes=[32], ...)
#   TransformerModel(gen, embedding_size=8, mlp_sizes=[32], ...)

use_gpu = int(torch.cuda.is_available())  # loop-invariant: compute once

for relation_class in (MultipleDAdjacentRelation, ColorAboveColorRelation, ObjectCountRelation):
    print(relation_class.__name__)
    gen = object_gen.SmartBalancedBatchObjectGenerator(N_OBJECTS, cfgs, relation_class,
                                                       object_dtype=torch.float, label_dtype=torch.long,
                                                       max_recursion_depth=100)
    model = CNNModel(gen,
                     conv_output_size=16,
                     batch_size=2**10, lr=1e-3,
                     train_epoch_size=2**14, validation_epoch_size=2**14,
                     regenerate_every_epoch=False)

    # A fresh PlotLogger per run, since it accumulates every metric it receives.
    # Add PrintLogger() to the list for per-step text output.
    loggers = [PlotLogger(dict(loss=('train_loss', 'val_loss'), accuracy=('train_acc', 'val_acc')))]
    trainer = Trainer(gpus=use_gpu, max_epochs=MAX_EPOCHS, logger=loggers,
                      early_stop_callback=EarlyStopping('val_loss', patience=EARLY_STOP_PATIENCE))
    trainer.fit(model)


In [None]:
# Train a CNN on each relation with 4 objects per scene, early-stopping on val_loss.
N_OBJECTS = 4
MAX_EPOCHS = 1000
EARLY_STOP_PATIENCE = 10

# Alternative models (not run here), with the same batch/epoch settings:
#   RelationNetModel(gen, object_pair_layer_sizes=[32], combined_object_layer_sizes=[32], ...)
#   TransformerModel(gen, embedding_size=8, mlp_sizes=[32], ...)

use_gpu = int(torch.cuda.is_available())  # loop-invariant: compute once

for relation_class in (MultipleDAdjacentRelation, ColorAboveColorRelation, ObjectCountRelation):
    print(relation_class.__name__)
    gen = object_gen.SmartBalancedBatchObjectGenerator(N_OBJECTS, cfgs, relation_class,
                                                       object_dtype=torch.float, label_dtype=torch.long,
                                                       max_recursion_depth=100)
    model = CNNModel(gen,
                     conv_output_size=16,
                     batch_size=2**10, lr=1e-3,
                     train_epoch_size=2**14, validation_epoch_size=2**14,
                     regenerate_every_epoch=False)

    # A fresh PlotLogger per run, since it accumulates every metric it receives.
    # Add PrintLogger() to the list for per-step text output.
    loggers = [PlotLogger(dict(loss=('train_loss', 'val_loss'), accuracy=('train_acc', 'val_acc')))]
    trainer = Trainer(gpus=use_gpu, max_epochs=MAX_EPOCHS, logger=loggers,
                      early_stop_callback=EarlyStopping('val_loss', patience=EARLY_STOP_PATIENCE))
    trainer.fit(model)


In [None]:
# Deliberate stop so "Run All" does not continue into the exploratory cells below.
# An explicit raise is not stripped under `python -O` (unlike assert) and says why it stopped.
raise RuntimeError('Stopping here on purpose: the cells below are exploratory.')

In [None]:
# Sweep the training-epoch size (2**10 ... 2**16 examples) for a RelationNet while
# keeping the validation set at 2**14 examples.
# NOTE(review): `gen` is not defined in this cell; it is whatever generator the most
# recently executed cell assigned. Set it explicitly if a specific relation is intended.
# Larger variants tried: object_pair_layer_sizes=[256, 256, 256],
# combined_object_layer_sizes=[256, 256].
use_gpu = int(torch.cuda.is_available())  # loop-invariant: compute once

for train_size_power in (10, 12, 14, 16):
    # Label each run so the resulting PlotLogger figures can be told apart.
    print(f'train_epoch_size = 2 ** {train_size_power}')
    model = RelationNetModel(gen,
                             object_pair_layer_sizes=[32],
                             combined_object_layer_sizes=[32],
                             batch_size=2**10, lr=1e-3,
                             train_epoch_size=2**train_size_power, validation_epoch_size=2**14,
                             regenerate_every_epoch=False)

    # A fresh PlotLogger per run, since it accumulates every metric it receives.
    loggers = [PlotLogger(dict(loss=('train_loss', 'val_loss'), accuracy=('train_acc', 'val_acc')))]
    trainer = Trainer(gpus=use_gpu, max_epochs=1000, logger=loggers,
                      early_stop_callback=EarlyStopping('val_loss', patience=10))
    trainer.fit(model)

In [None]:
# Train the current `model` (assigned by whichever cell ran last) with both a text
# logger and a plot logger.
use_gpu = int(torch.cuda.is_available())
loggers = [
    PrintLogger(),
    PlotLogger(dict(loss=('train_loss', 'val_loss'), accuracy=('train_acc', 'val_acc'))),
]
# max_epochs must be positive for any training to happen; use the same cap as the
# other training cells and let early stopping end the run sooner.
trainer = Trainer(gpus=use_gpu, max_epochs=1000, logger=loggers,
                  early_stop_callback=EarlyStopping('val_loss', patience=5))
trainer.fit(model)

In [None]:
# NOTE(review): pulls the latest repository code mid-session. With `%autoreload 2`
# active, this can change the imported library between cells, so earlier outputs
# may not match a fresh Restart & Run All.
!git pull

In [None]:
# Display the repr of the current `model` (whichever cell assigned it last).
model

In [13]:
# Build an 8-object MultipleDAdjacentRelation generator and a deeper CNN to inspect
# with torchsummary.
N_OBJECTS = 8

gen = object_gen.SmartBalancedBatchObjectGenerator(N_OBJECTS, cfgs, MultipleDAdjacentRelation,
                                                   object_dtype=torch.float, label_dtype=torch.long,
                                                   max_recursion_depth=100)

# Alternative models (not run here), with the same batch/epoch settings:
#   RelationNetModel(gen, object_pair_layer_sizes=[32], combined_object_layer_sizes=[32], ...)
#   TransformerModel(gen, embedding_size=8, mlp_sizes=[32], ...)
model = CNNModel(gen,
                 conv_sizes=[16, 32, 48],
                 mlp_sizes=[64, 32, 16],
                 # 192 = 48 channels * 2 * 2 after the last max-pool for a 16x16 input,
                 # matching the summary output below (Linear-10: 192 * 64 + 64 params).
                 conv_output_size=192,
                 conv_stride=1,
                 batch_size=2**10, lr=1e-3,
                 train_epoch_size=2**14, validation_epoch_size=2**14,
                 regenerate_every_epoch=False)

In [14]:
# Layer-by-layer summary for a (6, 16, 16) input. torchsummary defaults to
# device="cuda", so choose the device explicitly to also run on CPU-only machines.
summary_device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(summary_device)
summary(model, (6, 16, 16), device=summary_device)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 16, 16]             880
              ReLU-2           [-1, 16, 16, 16]               0
         MaxPool2d-3             [-1, 16, 8, 8]               0
            Conv2d-4             [-1, 32, 8, 8]           4,640
              ReLU-5             [-1, 32, 8, 8]               0
         MaxPool2d-6             [-1, 32, 4, 4]               0
            Conv2d-7             [-1, 48, 4, 4]          13,872
              ReLU-8             [-1, 48, 4, 4]               0
         MaxPool2d-9             [-1, 48, 2, 2]               0
           Linear-10                   [-1, 64]          12,352
             ReLU-11                   [-1, 64]               0
           Linear-12                   [-1, 32]           2,080
             ReLU-13                   [-1, 32]               0
           Linear-14                   

In [15]:
# Scratch check: 2 ** 16 = 65536, presumably the largest train_epoch_size in the
# train-size sweep (2 ** 10 ... 2 ** 16).
2 ** 16

65536

In [12]:
# Layer-by-layer summary for a (6, 16, 16) input. torchsummary defaults to
# device="cuda", so choose the device explicitly to also run on CPU-only machines.
summary_device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(summary_device)
summary(model, (6, 16, 16), device=summary_device)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 8, 16, 16]             440
              ReLU-2            [-1, 8, 16, 16]               0
         MaxPool2d-3              [-1, 8, 8, 8]               0
            Conv2d-4             [-1, 16, 8, 8]           1,168
              ReLU-5             [-1, 16, 8, 8]               0
         MaxPool2d-6             [-1, 16, 4, 4]               0
            Linear-7                   [-1, 32]           8,224
              ReLU-8                   [-1, 32]               0
            Linear-9                   [-1, 32]           1,056
             ReLU-10                   [-1, 32]               0
           Linear-11                    [-1, 2]              66
         Identity-12                    [-1, 2]               0
Total params: 10,954
Trainable params: 10,954
Non-trainable params: 0
---------------------------------