In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import sys
sys.path.append('/Users/guydavidson/projects/simple-relational-reasoning/')

import torch
from torch import nn
import torch.nn.functional as F

from pytorch_lightning import Trainer

In [3]:
from simple_relational_reasoning.datagen import object_fields
from simple_relational_reasoning.datagen import object_gen
from simple_relational_reasoning.models import MLPAverageModel

# Create field configurations and an object generator

In [4]:
cfgs = ( 
    object_gen.FieldConfig('x', 'int_position'), 
    object_gen.FieldConfig('y', 'int_position'), 
    object_gen.FieldConfig('color', 'one_hot', dict(n_types=4))
)

gen = object_gen.ObjectGenerator(10, cfgs, object_gen.adjacent_relation_evaluator, 
                                 object_dtype=torch.float)
X, y = gen(20)
X.shape, y.shape

(torch.Size([20, 10, 6]), torch.Size([20]))

# Create a model

In [5]:
model = MLPAverageModel(gen, 10, torch.sigmoid, batch_size=2**10,
                        train_epoch_size=2**14, validation_epoch_size=2**11)

In [6]:
trainer = Trainer()

In [7]:
trainer.fit(model)

INFO:lightning:
  | Name                 | Type   | Params
--------------------------------------------
0 | representation_layer | Linear | 70    
1 | prediction_layer     | Linear | 11    


HBox(children=(FloatProgress(value=0.0, description='Validation sanity check', layout=Layout(flex='2'), max=5.…

{'val_loss': tensor(0.2314), 'val_acc': tensor(0.7856)}


HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=1.0), HTML(value='')), …



HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.2287), 'val_acc': tensor(0.8091)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.2270), 'val_acc': tensor(0.8110)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.2261), 'val_acc': tensor(0.8062)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.2248), 'val_acc': tensor(0.8018)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.2250), 'val_acc': tensor(0.7876)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.2212), 'val_acc': tensor(0.8130)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.2218), 'val_acc': tensor(0.7969)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.2200), 'val_acc': tensor(0.7988)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.2192), 'val_acc': tensor(0.7959)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.2187), 'val_acc': tensor(0.7905)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.2177), 'val_acc': tensor(0.7891)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.2139), 'val_acc': tensor(0.8091)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.2139), 'val_acc': tensor(0.8013)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.2122), 'val_acc': tensor(0.8032)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.2104), 'val_acc': tensor(0.8081)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.2111), 'val_acc': tensor(0.7959)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.2114), 'val_acc': tensor(0.7871)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.2103), 'val_acc': tensor(0.7856)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.2093), 'val_acc': tensor(0.7871)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.2045), 'val_acc': tensor(0.8081)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.2035), 'val_acc': tensor(0.8071)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.2028), 'val_acc': tensor(0.8052)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.2027), 'val_acc': tensor(0.8003)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.2002), 'val_acc': tensor(0.8086)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.1989), 'val_acc': tensor(0.8096)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.2004), 'val_acc': tensor(0.7964)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.1980), 'val_acc': tensor(0.8047)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.1963), 'val_acc': tensor(0.8066)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.1947), 'val_acc': tensor(0.8101)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.1945), 'val_acc': tensor(0.8066)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.1929), 'val_acc': tensor(0.8081)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.1919), 'val_acc': tensor(0.8096)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.1890), 'val_acc': tensor(0.8169)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.1950), 'val_acc': tensor(0.7876)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.1902), 'val_acc': tensor(0.8047)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.1930), 'val_acc': tensor(0.7896)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.1905), 'val_acc': tensor(0.7964)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.1939), 'val_acc': tensor(0.7803)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.1895), 'val_acc': tensor(0.7939)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.1833), 'val_acc': tensor(0.8135)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.1876), 'val_acc': tensor(0.7944)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.1879), 'val_acc': tensor(0.7915)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.1862), 'val_acc': tensor(0.7939)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.1852), 'val_acc': tensor(0.7959)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.1819), 'val_acc': tensor(0.8047)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.1794), 'val_acc': tensor(0.8096)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.1791), 'val_acc': tensor(0.8081)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.1831), 'val_acc': tensor(0.7930)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.1778), 'val_acc': tensor(0.8081)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.1792), 'val_acc': tensor(0.8013)}


INFO:lightning:Detected KeyboardInterrupt, attempting graceful shutdown...





1

In [None]:
# from pytorch_lightning.loggers import LightningLoggerBase, rank_zero_only

# class PrintLogger(LightningLoggerBase):
    
#     def __init__(self):
#         super(PrintLogger, self).__init__()
    
#     @property
#     def name(self):
#         return 'Test'
    
#     @property
#     def experiment(self):
#         return self.name()
    
#     @property
#     def version(self):
#         return '0.0.1'
    
#     @rank_zero_only
#     def log_hyperparams(self, params):
#         # params is an argparse.Namespace
#         # your code to record hyperparameters goes here
#         pass

#     @rank_zero_only
#     def log_metrics(self, metrics, step):
#         print(f'{step}: {metrics}')

#     def save(self):
#         # Optional. Any code necessary to save logger data goes here
#         pass

#     @rank_zero_only
#     def finalize(self, status):
#         # Optional. Any code that needs to be run after training
#         # finishes goes here
#         pass


# trainer = Trainer(logger=[PrintLogger()])