In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import sys
sys.path.append('/Users/guydavidson/projects/simple-relational-reasoning/')

import torch
from torch import nn
import torch.nn.functional as F

from pytorch_lightning import Trainer

In [3]:
from simple_relational_reasoning.datagen import object_fields
from simple_relational_reasoning.datagen import object_gen
from simple_relational_reasoning.models import MLPModel

# Create field configurations and an object generator

In [4]:
cfgs = ( 
    object_gen.FieldConfig('x', 'int_position', dict(max_coord=40)), 
    object_gen.FieldConfig('y', 'int_position', dict(max_coord=40)), 
    object_gen.FieldConfig('color', 'one_hot', dict(n_types=4))
)

gen = object_gen.SmartBalancedBatchObjectGenerator(20, cfgs, object_gen.adjacent_relation_evaluator, 
                                                         object_gen.adjacent_relation_balancer,
                                                         object_dtype=torch.float, label_dtype=torch.long)

X, y = gen(20)
X.shape, y.shape, y.sum()

(torch.Size([20, 20, 6]), torch.Size([20]), tensor(10))

# Create a model

In [5]:
model = MLPModel(gen, 10, torch.sigmoid, batch_size=2**10, lr=1e-3,
                 train_epoch_size=2**14, validation_epoch_size=2**11)

In [6]:
trainer = Trainer()

In [None]:
trainer.fit(model)

INFO:lightning:
  | Name                  | Type     | Params
-----------------------------------------------
0 | embedding_layer       | Linear   | 70    
1 | prediction_layer      | Linear   | 402   
2 | prediction_activation | Identity | 0     


HBox(children=(FloatProgress(value=0.0, description='Validation sanity check', layout=Layout(flex='2'), max=5.…

{'val_loss': tensor(0.8530), 'val_acc': tensor(0.5000)}


HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=1.0), HTML(value='')), …



HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.7123), 'val_acc': tensor(0.5010)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6997), 'val_acc': tensor(0.5063)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6938), 'val_acc': tensor(0.5122)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6938), 'val_acc': tensor(0.5054)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6948), 'val_acc': tensor(0.4990)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6968), 'val_acc': tensor(0.4834)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6937), 'val_acc': tensor(0.5005)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6929), 'val_acc': tensor(0.5220)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6924), 'val_acc': tensor(0.5181)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6937), 'val_acc': tensor(0.5049)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6939), 'val_acc': tensor(0.5137)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6906), 'val_acc': tensor(0.5405)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6922), 'val_acc': tensor(0.5073)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6938), 'val_acc': tensor(0.5054)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6927), 'val_acc': tensor(0.5020)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6931), 'val_acc': tensor(0.5088)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6917), 'val_acc': tensor(0.5181)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6925), 'val_acc': tensor(0.5127)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6926), 'val_acc': tensor(0.5225)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6890), 'val_acc': tensor(0.5396)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6910), 'val_acc': tensor(0.5234)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6906), 'val_acc': tensor(0.5288)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6934), 'val_acc': tensor(0.5054)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6954), 'val_acc': tensor(0.5156)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6932), 'val_acc': tensor(0.5117)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6922), 'val_acc': tensor(0.5117)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6913), 'val_acc': tensor(0.5273)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6903), 'val_acc': tensor(0.5312)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6919), 'val_acc': tensor(0.5293)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6881), 'val_acc': tensor(0.5396)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6935), 'val_acc': tensor(0.5259)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6925), 'val_acc': tensor(0.5273)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6954), 'val_acc': tensor(0.5117)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6929), 'val_acc': tensor(0.5239)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6950), 'val_acc': tensor(0.4956)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6923), 'val_acc': tensor(0.5234)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6917), 'val_acc': tensor(0.5229)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6915), 'val_acc': tensor(0.5229)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6921), 'val_acc': tensor(0.5234)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6881), 'val_acc': tensor(0.5405)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6952), 'val_acc': tensor(0.5142)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6900), 'val_acc': tensor(0.5405)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6975), 'val_acc': tensor(0.5103)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6962), 'val_acc': tensor(0.5132)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6931), 'val_acc': tensor(0.5190)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6899), 'val_acc': tensor(0.5312)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6930), 'val_acc': tensor(0.5239)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6910), 'val_acc': tensor(0.5381)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6895), 'val_acc': tensor(0.5215)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6989), 'val_acc': tensor(0.5024)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6909), 'val_acc': tensor(0.5444)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6942), 'val_acc': tensor(0.5210)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6916), 'val_acc': tensor(0.5366)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6934), 'val_acc': tensor(0.5220)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6910), 'val_acc': tensor(0.5288)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6930), 'val_acc': tensor(0.5068)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6954), 'val_acc': tensor(0.5132)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6924), 'val_acc': tensor(0.5176)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6907), 'val_acc': tensor(0.5332)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6970), 'val_acc': tensor(0.5171)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6940), 'val_acc': tensor(0.5249)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6907), 'val_acc': tensor(0.5259)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6964), 'val_acc': tensor(0.5034)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6932), 'val_acc': tensor(0.5161)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6889), 'val_acc': tensor(0.5303)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6915), 'val_acc': tensor(0.5264)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6899), 'val_acc': tensor(0.5273)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6945), 'val_acc': tensor(0.5156)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6909), 'val_acc': tensor(0.5269)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6924), 'val_acc': tensor(0.5332)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6975), 'val_acc': tensor(0.5142)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6915), 'val_acc': tensor(0.5293)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6911), 'val_acc': tensor(0.5283)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6916), 'val_acc': tensor(0.5356)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6915), 'val_acc': tensor(0.5176)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6950), 'val_acc': tensor(0.5049)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6916), 'val_acc': tensor(0.5259)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6900), 'val_acc': tensor(0.5483)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6901), 'val_acc': tensor(0.5327)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6920), 'val_acc': tensor(0.5303)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6968), 'val_acc': tensor(0.5132)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6931), 'val_acc': tensor(0.5249)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6927), 'val_acc': tensor(0.5264)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6952), 'val_acc': tensor(0.5200)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6949), 'val_acc': tensor(0.5283)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6904), 'val_acc': tensor(0.5254)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6888), 'val_acc': tensor(0.5303)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6936), 'val_acc': tensor(0.5264)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6949), 'val_acc': tensor(0.5391)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6956), 'val_acc': tensor(0.5127)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6927), 'val_acc': tensor(0.5327)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6968), 'val_acc': tensor(0.5215)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6940), 'val_acc': tensor(0.5283)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6935), 'val_acc': tensor(0.5386)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6917), 'val_acc': tensor(0.5278)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6919), 'val_acc': tensor(0.5337)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6942), 'val_acc': tensor(0.5210)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6924), 'val_acc': tensor(0.5371)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6930), 'val_acc': tensor(0.5293)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6924), 'val_acc': tensor(0.5249)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6947), 'val_acc': tensor(0.5107)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6932), 'val_acc': tensor(0.5249)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6908), 'val_acc': tensor(0.5425)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6967), 'val_acc': tensor(0.5098)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6902), 'val_acc': tensor(0.5332)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6963), 'val_acc': tensor(0.5161)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6932), 'val_acc': tensor(0.5220)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6907), 'val_acc': tensor(0.5288)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6970), 'val_acc': tensor(0.5122)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.7017), 'val_acc': tensor(0.4941)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6932), 'val_acc': tensor(0.5298)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6952), 'val_acc': tensor(0.5205)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6912), 'val_acc': tensor(0.5181)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6939), 'val_acc': tensor(0.5327)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6890), 'val_acc': tensor(0.5361)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6923), 'val_acc': tensor(0.5322)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6921), 'val_acc': tensor(0.5356)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6953), 'val_acc': tensor(0.5215)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6954), 'val_acc': tensor(0.5127)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6945), 'val_acc': tensor(0.5244)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6978), 'val_acc': tensor(0.5269)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6934), 'val_acc': tensor(0.5215)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6946), 'val_acc': tensor(0.5171)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6925), 'val_acc': tensor(0.5312)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6952), 'val_acc': tensor(0.5234)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.7000), 'val_acc': tensor(0.4995)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6972), 'val_acc': tensor(0.5142)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6960), 'val_acc': tensor(0.5225)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6913), 'val_acc': tensor(0.5361)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6923), 'val_acc': tensor(0.5215)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6941), 'val_acc': tensor(0.5151)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6948), 'val_acc': tensor(0.5293)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6953), 'val_acc': tensor(0.5210)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6940), 'val_acc': tensor(0.5298)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6968), 'val_acc': tensor(0.5151)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6929), 'val_acc': tensor(0.5371)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6895), 'val_acc': tensor(0.5371)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6929), 'val_acc': tensor(0.5210)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6964), 'val_acc': tensor(0.5107)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6991), 'val_acc': tensor(0.5083)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6879), 'val_acc': tensor(0.5371)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6968), 'val_acc': tensor(0.5249)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6915), 'val_acc': tensor(0.5386)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6965), 'val_acc': tensor(0.5107)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6938), 'val_acc': tensor(0.5371)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6940), 'val_acc': tensor(0.5176)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6961), 'val_acc': tensor(0.5117)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6915), 'val_acc': tensor(0.5254)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6955), 'val_acc': tensor(0.5234)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6966), 'val_acc': tensor(0.5137)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6955), 'val_acc': tensor(0.5288)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6999), 'val_acc': tensor(0.5059)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6958), 'val_acc': tensor(0.5200)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6945), 'val_acc': tensor(0.5098)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6969), 'val_acc': tensor(0.5151)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6958), 'val_acc': tensor(0.5205)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6902), 'val_acc': tensor(0.5347)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6926), 'val_acc': tensor(0.5303)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6920), 'val_acc': tensor(0.5259)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6948), 'val_acc': tensor(0.5264)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6896), 'val_acc': tensor(0.5288)}


In [None]:
# from pytorch_lightning.loggers import LightningLoggerBase, rank_zero_only

# class PrintLogger(LightningLoggerBase):
    
#     def __init__(self):
#         super(PrintLogger, self).__init__()
    
#     @property
#     def name(self):
#         return 'Test'
    
#     @property
#     def experiment(self):
#         return self.name()
    
#     @property
#     def version(self):
#         return '0.0.1'
    
#     @rank_zero_only
#     def log_hyperparams(self, params):
#         # params is an argparse.Namespace
#         # your code to record hyperparameters goes here
#         pass

#     @rank_zero_only
#     def log_metrics(self, metrics, step):
#         print(f'{step}: {metrics}')

#     def save(self):
#         # Optional. Any code necessary to save logger data goes here
#         pass

#     @rank_zero_only
#     def finalize(self, status):
#         # Optional. Any code that needs to be run after training
#         # finishes goes here
#         pass


# trainer = Trainer(logger=[PrintLogger()])

# Scratch

In [None]:
cfgs = ( 
    object_gen.FieldConfig('x', 'int_position', dict(max_coord=40)), 
    object_gen.FieldConfig('y', 'int_position', dict(max_coord=40)), 
    object_gen.FieldConfig('color', 'one_hot', dict(n_types=4))
)

gen = object_gen.ObjectGenerator(20, cfgs, object_gen.adjacent_relation_evaluator, 
                                 object_dtype=torch.float, label_dtype=torch.long)

balanced_gen = object_gen.BalancedBatchObjectGenerator(20, cfgs, object_gen.adjacent_relation_evaluator, 
                                                         object_dtype=torch.float, label_dtype=torch.long)

smart_gen = object_gen.SmartBalancedBatchObjectGenerator(20, cfgs, object_gen.adjacent_relation_evaluator, 
                                                         object_gen.adjacent_relation_balancer,
                                                         object_dtype=torch.float, label_dtype=torch.long)

smart_gen_test = object_gen.SmartBalancedBatchObjectGeneratorTest(20, cfgs, object_gen.adjacent_relation_evaluator, 
                                                         object_gen.adjacent_relation_balancer_test,
                                                         object_dtype=torch.float, label_dtype=torch.long)

In [None]:
import timeit
batch_size = int(2 ** 14)
number = int(2 ** 6)

for g in (gen, balanced_gen, smart_gen, smart_gen_test):
    print(g.__class__.__name__, timeit.timeit(lambda: g(batch_size), number=number))

In [None]:
torch.rand(10, 2).view(5, -1).shape

In [None]:
r = torch.rand(10)
b = r > 0.5
r[b].shape

In [None]:
r[b.long().bool()].shape

In [None]:
~b, b

In [None]:
l = [torch.rand(7), torch.rand(7), torch.rand(7)]

In [None]:
torch.cat(l).shape

In [None]:
r

In [None]:
def foo(x=5):
    r = torch.rand(x)
    print(r)
    return r[torch.randperm(5)]


foo()

In [None]:
b

In [None]:
indices_to_modify = (b == False).nonzero().squeeze()
num_samples_to_modify = 3
indices_to_modify[torch.randperm(indices_to_modify.shape[0])[:num_samples_to_modify]]

In [None]:
a = torch.randint(2, (20,))

In [None]:
a[0] = False

In [None]:
x, y = torch.randperm(10)[:2]
x, y

In [None]:
torch.randint?

In [None]:
torch.rand(tuple())

In [None]:
torch.sign(torch.rand(-0.5, 0.5))