In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import sys
sys.path.append('/Users/guydavidson/projects/simple-relational-reasoning/')

import torch
from torch import nn
import torch.nn.functional as F

from pytorch_lightning import Trainer

In [3]:
from simple_relational_reasoning.datagen import object_fields
from simple_relational_reasoning.datagen import object_gen
from simple_relational_reasoning.models import MLPModel

# Create field configurations and an object generator

In [4]:
cfgs = ( 
    object_gen.FieldConfig('x', 'int_position', dict(max_coord=20)), 
    object_gen.FieldConfig('y', 'int_position', dict(max_coord=20)), 
    object_gen.FieldConfig('color', 'one_hot', dict(n_types=4))
)

gen = object_gen.BalancedBatchObjectGenerator(20, cfgs, object_gen.adjacent_relation_evaluator, 
                                 object_dtype=torch.float, label_dtype=torch.long)
X, y = gen(20)
X.shape, y.shape, y.sum()

(torch.Size([20, 20, 6]), torch.Size([20]), tensor(10))

# Create a model

In [5]:
model = MLPModel(gen, 10, torch.sigmoid, batch_size=2**10, lr=1e-3,
                 train_epoch_size=2**14, validation_epoch_size=2**11)

In [6]:
trainer = Trainer()

In [7]:
trainer.fit(model)

INFO:lightning:
  | Name                  | Type     | Params
-----------------------------------------------
0 | embedding_layer       | Linear   | 70    
1 | prediction_layer      | Linear   | 402   
2 | prediction_activation | Identity | 0     


HBox(children=(FloatProgress(value=0.0, description='Validation sanity check', layout=Layout(flex='2'), max=5.…

{'val_loss': tensor(0.6964), 'val_acc': tensor(0.5010)}


HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=1.0), HTML(value='')), …



HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6942), 'val_acc': tensor(0.5088)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6942), 'val_acc': tensor(0.4956)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6952), 'val_acc': tensor(0.4941)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6949), 'val_acc': tensor(0.4790)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6966), 'val_acc': tensor(0.4897)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6957), 'val_acc': tensor(0.4912)}


HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=2.0, style=Prog…

{'val_loss': tensor(0.6956), 'val_acc': tensor(0.4941)}


INFO:lightning:Detected KeyboardInterrupt, attempting graceful shutdown...





1

In [8]:
# from pytorch_lightning.loggers import LightningLoggerBase, rank_zero_only

# class PrintLogger(LightningLoggerBase):
    
#     def __init__(self):
#         super(PrintLogger, self).__init__()
    
#     @property
#     def name(self):
#         return 'Test'
    
#     @property
#     def experiment(self):
#         return self.name()
    
#     @property
#     def version(self):
#         return '0.0.1'
    
#     @rank_zero_only
#     def log_hyperparams(self, params):
#         # params is an argparse.Namespace
#         # your code to record hyperparameters goes here
#         pass

#     @rank_zero_only
#     def log_metrics(self, metrics, step):
#         print(f'{step}: {metrics}')

#     def save(self):
#         # Optional. Any code necessary to save logger data goes here
#         pass

#     @rank_zero_only
#     def finalize(self, status):
#         # Optional. Any code that needs to be run after training
#         # finishes goes here
#         pass


# trainer = Trainer(logger=[PrintLogger()])

# Scratch

In [9]:
F.cross_entropy?

In [10]:
nn.Identity()

Identity()

In [11]:
torch.rand(10, 2).argmax(1).shape

torch.Size([10])

In [12]:
torch.eq?

In [13]:
torch.rand(10, 2).view(5, -1).shape

torch.Size([5, 4])

In [14]:
r = torch.rand(10)
b = r > 0.5
r[b].shape

torch.Size([4])

In [15]:
r[b.long().bool()].shape

torch.Size([4])

In [16]:
~b, b

(tensor([ True, False, False,  True,  True, False, False,  True,  True,  True]),
 tensor([False,  True,  True, False, False,  True,  True, False, False, False]))

In [17]:
l = [torch.rand(7), torch.rand(7), torch.rand(7)]

In [18]:
torch.cat(l).shape

torch.Size([21])

In [19]:
r

tensor([0.1440, 0.7089, 0.7682, 0.4631, 0.4353, 0.5172, 0.6408, 0.4616, 0.2146,
        0.4052])

In [20]:
def foo(x=5):
    r = torch.rand(x)
    print(r)
    return r[torch.randperm(5)]


foo()

tensor([0.2999, 0.4186, 0.5046, 0.0083, 0.6585])


tensor([0.2999, 0.5046, 0.6585, 0.4186, 0.0083])