In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import sys
sys.path.append('/Users/guydavidson/projects/simple-relational-reasoning/')

import torch
from torch import nn
import torch.nn.functional as F

from pytorch_lightning import Trainer

In [3]:
from simple_relational_reasoning.datagen import object_fields
from simple_relational_reasoning.datagen import object_gen
from simple_relational_reasoning.models import MLPModel

# Create field configurations and an object generator

In [4]:
cfgs = ( 
    object_gen.FieldConfig('x', 'int_position'), 
    object_gen.FieldConfig('y', 'int_position'), 
    object_gen.FieldConfig('color', 'one_hot', dict(n_types=4))
)

gen = object_gen.ObjectGenerator(10, cfgs, object_gen.adjacent_relation_evaluator, 
                                 object_dtype=torch.float, label_dtype=torch.long)
X, y = gen(20)
X.shape, y.shape

(torch.Size([20, 10, 6]), torch.Size([20]))

# Create a model

In [5]:
model = MLPModel(gen, 10, torch.sigmoid, batch_size=2**10,
                 train_epoch_size=2**14, validation_epoch_size=2**11)

In [6]:
trainer = Trainer()

In [7]:
trainer.fit(model)

INFO:lightning:
  | Name                  | Type     | Params
-----------------------------------------------
0 | embedding_layer       | Linear   | 70    
1 | prediction_layer      | Linear   | 202   
2 | prediction_activation | Identity | 0     


HBox(children=(FloatProgress(value=0.0, description='Validation sanity check', layout=Layout(flex='2'), max=5.…

{'val_loss': tensor(0.7334), 'val_acc': tensor(0.2910)}


HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=1.0), HTML(value='')), …



HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6848), 'val_acc': tensor(0.5610)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6386), 'val_acc': tensor(0.7749)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6002), 'val_acc': tensor(0.8057)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.5719), 'val_acc': tensor(0.8027)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.5618), 'val_acc': tensor(0.7861)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.5279), 'val_acc': tensor(0.8125)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.5159), 'val_acc': tensor(0.8125)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.5188), 'val_acc': tensor(0.7969)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.5083), 'val_acc': tensor(0.8037)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.5085), 'val_acc': tensor(0.7988)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.5077), 'val_acc': tensor(0.7979)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.5213), 'val_acc': tensor(0.7866)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.5198), 'val_acc': tensor(0.7861)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.5097), 'val_acc': tensor(0.7944)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.4833), 'val_acc': tensor(0.8145)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.5111), 'val_acc': tensor(0.7935)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.4960), 'val_acc': tensor(0.8042)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.5021), 'val_acc': tensor(0.7993)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.5142), 'val_acc': tensor(0.7900)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.4986), 'val_acc': tensor(0.8022)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.5027), 'val_acc': tensor(0.7988)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.4830), 'val_acc': tensor(0.8125)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.4920), 'val_acc': tensor(0.8076)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.5044), 'val_acc': tensor(0.7974)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.4898), 'val_acc': tensor(0.8081)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.5090), 'val_acc': tensor(0.7959)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.4840), 'val_acc': tensor(0.8135)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.4860), 'val_acc': tensor(0.8105)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.4989), 'val_acc': tensor(0.8018)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.5042), 'val_acc': tensor(0.7983)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.4850), 'val_acc': tensor(0.8115)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.5298), 'val_acc': tensor(0.7808)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.5142), 'val_acc': tensor(0.7905)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.4996), 'val_acc': tensor(0.8008)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.4822), 'val_acc': tensor(0.8130)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.5089), 'val_acc': tensor(0.7944)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.4870), 'val_acc': tensor(0.8096)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.4930), 'val_acc': tensor(0.8047)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.4832), 'val_acc': tensor(0.8115)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.5033), 'val_acc': tensor(0.7983)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.5041), 'val_acc': tensor(0.7983)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.4804), 'val_acc': tensor(0.8140)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.4971), 'val_acc': tensor(0.8027)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.4909), 'val_acc': tensor(0.8081)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.5103), 'val_acc': tensor(0.7935)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.4970), 'val_acc': tensor(0.8032)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.4906), 'val_acc': tensor(0.8081)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.4986), 'val_acc': tensor(0.8018)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.4949), 'val_acc': tensor(0.8042)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.5061), 'val_acc': tensor(0.7964)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.5044), 'val_acc': tensor(0.7974)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.5015), 'val_acc': tensor(0.8013)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.5191), 'val_acc': tensor(0.7871)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.4933), 'val_acc': tensor(0.8057)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.5039), 'val_acc': tensor(0.7983)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.5005), 'val_acc': tensor(0.7998)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.5009), 'val_acc': tensor(0.7993)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.5095), 'val_acc': tensor(0.7939)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.4948), 'val_acc': tensor(0.8037)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.5069), 'val_acc': tensor(0.7959)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.5015), 'val_acc': tensor(0.8003)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.5135), 'val_acc': tensor(0.7920)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.5107), 'val_acc': tensor(0.7935)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.5010), 'val_acc': tensor(0.8008)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.4976), 'val_acc': tensor(0.8022)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.4997), 'val_acc': tensor(0.8013)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.4989), 'val_acc': tensor(0.8008)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.4917), 'val_acc': tensor(0.8062)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.5055), 'val_acc': tensor(0.7964)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.4865), 'val_acc': tensor(0.8096)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.5245), 'val_acc': tensor(0.7837)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.5158), 'val_acc': tensor(0.7896)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.4997), 'val_acc': tensor(0.8008)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.4820), 'val_acc': tensor(0.8135)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.5190), 'val_acc': tensor(0.7871)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.4997), 'val_acc': tensor(0.8013)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.5004), 'val_acc': tensor(0.8008)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.4923), 'val_acc': tensor(0.8066)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.4945), 'val_acc': tensor(0.8047)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.4901), 'val_acc': tensor(0.8076)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.5001), 'val_acc': tensor(0.8008)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.5050), 'val_acc': tensor(0.7979)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.5092), 'val_acc': tensor(0.7939)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.5127), 'val_acc': tensor(0.7915)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.5152), 'val_acc': tensor(0.7896)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.4862), 'val_acc': tensor(0.8101)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.5020), 'val_acc': tensor(0.7988)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.4775), 'val_acc': tensor(0.8164)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.5134), 'val_acc': tensor(0.7910)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.5067), 'val_acc': tensor(0.7954)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.5000), 'val_acc': tensor(0.8003)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.5015), 'val_acc': tensor(0.7993)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.5092), 'val_acc': tensor(0.7944)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.4984), 'val_acc': tensor(0.8018)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.5149), 'val_acc': tensor(0.7896)}


INFO:lightning:Detected KeyboardInterrupt, attempting graceful shutdown...





1

Error in callback <bound method AutoreloadMagics.post_execute_hook of <autoreload.AutoreloadMagics object at 0x11075dc18>> (for post_execute):


KeyboardInterrupt: 

In [8]:
# from pytorch_lightning.loggers import LightningLoggerBase, rank_zero_only

# class PrintLogger(LightningLoggerBase):
    
#     def __init__(self):
#         super(PrintLogger, self).__init__()
    
#     @property
#     def name(self):
#         return 'Test'
    
#     @property
#     def experiment(self):
#         return self.name()
    
#     @property
#     def version(self):
#         return '0.0.1'
    
#     @rank_zero_only
#     def log_hyperparams(self, params):
#         # params is an argparse.Namespace
#         # your code to record hyperparameters goes here
#         pass

#     @rank_zero_only
#     def log_metrics(self, metrics, step):
#         print(f'{step}: {metrics}')

#     def save(self):
#         # Optional. Any code necessary to save logger data goes here
#         pass

#     @rank_zero_only
#     def finalize(self, status):
#         # Optional. Any code that needs to be run after training
#         # finishes goes here
#         pass


# trainer = Trainer(logger=[PrintLogger()])

# Scratch

In [9]:
F.cross_entropy?

In [10]:
nn.Identity()

Identity()

In [11]:
torch.rand(10, 2).argmax(1).shape

torch.Size([10])

In [12]:
torch.eq?

In [13]:
torch.rand(10, 2).view(5, -1).shape

torch.Size([5, 4])