## TEST MODELS INITIALISATION

In [1]:
import sys
sys.path.append('../')

import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data_utils

import datasets
import models
import importlib
# Re-execute the local `datasets` and `models` packages so that source edits
# made since they were first imported are picked up without restarting the
# kernel. NOTE(review): importlib.reload only re-runs the named module itself;
# submodules imported inside these packages are presumably not refreshed --
# confirm, or switch to `%autoreload 2`.
importlib.reload(datasets)
importlib.reload(models)

<module 'models' from '/Users/francesco/Projects/rhm-master/random-hierarchy-model/models/../models/__init__.py'>

## IMPORT A HIERARCHICAL DATASET TO TEST THE MODELS

In [9]:
# Hyper-parameters of the Random Hierarchy Model (RHM).
v = 16  # vocabulary size (number of features)
m = 4   # number of synonyms (features multiplicity)
L = 4   # number of layers
n = 16  # number of classes
s = 2   # tuple size (number of branches of the tree)

input_size = s**L # number of pixels, actual input size is (input_size x num_features) because of one-hot encoding

# Total number of data: n * m**(1 + s + ... + s**(L-1)).
# The exponent (s**L - 1)/(s - 1) is a geometric sum and therefore always an
# exact integer, so floor division is used to keep num_data an exact Python
# int (true division would silently turn it into a float).
num_data = n * m**((s**L - 1) // (s - 1))

seed_rules = 12345678 # seed of the random hierarchy model
train_size = 1024 # size of the training set
test_size = 0 # size of the test set

# To generate the full dataset instead of a sample, use
# train_size=num_data together with test_size=0.
rhm_config = dict(
    num_features=v,         # vocabulary size
    num_synonyms=m,         # features multiplicity
    num_layers=L,           # number of layers
    num_classes=n,          # number of classes
    tuple_size=s,           # number of branches of the tree
    seed_rules=seed_rules,
    train_size=train_size,
    test_size=test_size,
    input_format='onehot',
    whitening=0,            # 1 to whiten the input
)
dataset = datasets.RandomHierarchyModel(**rhm_config)

# Inputs and labels used by every model check further down in the notebook.
x, y = dataset.features, dataset.labels
print('input: tensor of size', x.shape)
print('outputs: tensor of size', y.shape)
print('total dataset size:', num_data)

input: tensor of size torch.Size([1024, 16, 16])
outputs: tensor of size torch.Size([1024])
total dataset size: 17179869184.0


In [10]:
depth = 4
width = 512

# Fully-connected network; it is fed the flattened input, i.e. input_size*v
# entries per sample.
model_fcn = models.MLP( input_size*v, width, n, depth)
print(model_fcn)

# Report the shape of every hidden layer's weight and of the readout.
for i in range(depth):
    print(f'{i+1}-th layer weights, size:', model_fcn.hidden[i][0].weight.size())
print('readout weights, size:', model_fcn.readout.size())

# This forward pass is only a shape check, so skip building the autograd graph.
with torch.no_grad():
    model_y = model_fcn(x.flatten(start_dim=1))
print(model_y.size())

# Total number of parameters of the model.
param_count = sum(p.numel() for p in model_fcn.parameters())
print(param_count)

MLP(
  (hidden): Sequential(
    (0): Sequential(
      (0): MyLinear()
      (1): ReLU()
    )
    (1): Sequential(
      (0): MyLinear()
      (1): ReLU()
    )
    (2): Sequential(
      (0): MyLinear()
      (1): ReLU()
    )
    (3): Sequential(
      (0): MyLinear()
      (1): ReLU()
    )
  )
)
1-th layer weights, size: torch.Size([512, 256])
2-th layer weights, size: torch.Size([512, 512])
3-th layer weights, size: torch.Size([512, 512])
4-th layer weights, size: torch.Size([512, 512])
readout weights, size: torch.Size([512, 16])
torch.Size([1024, 16])
925696


In [16]:
depth = 4
width = 724

# Hierarchical CNN built from the input size, tuple size and vocabulary size.
model_cnn = models.hCNN( input_size, s, v, width, n, depth)

print(model_cnn)

# Report the shape of every hidden layer's filter and of the readout.
for i in range(depth):
    print(f'{i+1}-th layer weights, size:', model_cnn.hidden[i][0].filter.size())
print('readout weights, size:', model_cnn.readout.size())

# Shape checks only, so skip building the autograd graph.
with torch.no_grad():
    # Apply the hidden layers one at a time (each followed by a ReLU) to
    # inspect the size of every intermediate representation.
    model_y = x
    for i in range(depth):
        model_y = model_cnn.hidden[i][0](model_y).relu()
        print(f'{i+1}-th hidden rep. size:', model_y.size())
    # Full forward pass through the model.
    model_y = model_cnn(x)
print(model_y.size())

# Total number of parameters of the model.
param_count = sum(p.numel() for p in model_cnn.parameters())
print(param_count)

hCNN(
  (hidden): Sequential(
    (0): Sequential(
      (0): MyConv1d()
      (1): ReLU()
    )
    (1): Sequential(
      (0): MyConv1d()
      (1): ReLU()
    )
    (2): Sequential(
      (0): MyConv1d()
      (1): ReLU()
    )
    (3): Sequential(
      (0): MyConv1d()
      (1): ReLU()
    )
  )
)
1-th layer weights, size: torch.Size([724, 16, 2])
2-th layer weights, size: torch.Size([724, 724, 2])
3-th layer weights, size: torch.Size([724, 724, 2])
4-th layer weights, size: torch.Size([724, 724, 2])
readout weights, size: torch.Size([724, 16])
1-th hidden rep. size: torch.Size([1024, 724, 8])
2-th hidden rep. size: torch.Size([1024, 724, 4])
3-th hidden rep. size: torch.Size([1024, 724, 2])
4-th hidden rep. size: torch.Size([1024, 724, 1])
torch.Size([1024, 16])
3179808


In [17]:
depth = 4
width = 512

# Hierarchical locally-connected network (hLCN).
# NOTE(review): the variable is still called `model_cnn`, so this assignment
# replaces the hCNN instance created earlier in the notebook. The name is kept
# in case other cells rely on it; consider renaming it to `model_lcn`.
model_cnn = models.hLCN( input_size, s, v, width, n, depth)

print(model_cnn)

# Report the shape of every hidden layer's filter and of the readout.
for i in range(depth):
    print(f'{i+1}-th layer weights, size:', model_cnn.hidden[i][0].filter.size())
print('readout weights, size:', model_cnn.readout.size())

# Shape checks only, so skip building the autograd graph.
with torch.no_grad():
    # Apply the hidden layers one at a time (each followed by a ReLU) to
    # inspect the size of every intermediate representation.
    model_y = x
    for i in range(depth):
        model_y = model_cnn.hidden[i][0](model_y).relu()
        print(f'{i+1}-th hidden rep. size:', model_y.size())
    # Full forward pass through the model.
    model_y = model_cnn(x)
print(model_y.size())

# Total number of parameters of the model.
param_count = sum(p.numel() for p in model_cnn.parameters())
print(param_count)

hLCN(
  (hidden): Sequential(
    (0): Sequential(
      (0): MyLoc1d()
      (1): ReLU()
    )
    (1): Sequential(
      (0): MyLoc1d()
      (1): ReLU()
    )
    (2): Sequential(
      (0): MyLoc1d()
      (1): ReLU()
    )
    (3): Sequential(
      (0): MyLoc1d()
      (1): ReLU()
    )
  )
)
1-th layer weights, size: torch.Size([512, 16, 8, 2])
2-th layer weights, size: torch.Size([512, 512, 4, 2])
3-th layer weights, size: torch.Size([512, 512, 2, 2])
4-th layer weights, size: torch.Size([512, 512, 1, 2])
readout weights, size: torch.Size([512, 16])
1-th hidden rep. size: torch.Size([1024, 512, 8])
2-th hidden rep. size: torch.Size([1024, 512, 4])
3-th hidden rep. size: torch.Size([1024, 512, 2])
4-th hidden rep. size: torch.Size([1024, 512, 1])
torch.Size([1024, 16])
3809280


In [None]:
embedding_dim = 256
num_heads = 8
depth = 3

# Multi-layer attention model.
# NOTE(review): this cell has no execution count and the output stored with it
# (e.g. `Embedding(8, 256)`) does not seem to match the current v and
# input_size -- presumably it comes from an older run; re-run to refresh it.
model_mla = models.MLA( v, input_size, embedding_dim, num_heads, depth)
print(model_mla)
print('embedding:', model_mla.token_embedding.size())
print('readout size:', model_mla.readout.size())

# Shape check only, so skip building the autograd graph.
with torch.no_grad():
    # The model is fed x with its last two axes swapped
    # (presumably it expects the position axis before the feature axis -- TODO confirm).
    model_y = model_mla(x.transpose(1,2))
print(model_y.size())

MLA(
  (position_embedding): Embedding(8, 256)
  (blocks): Sequential(
    (0): AttentionBlock(
      (sa): MultiHeadAttention()
    )
    (1): AttentionBlock(
      (sa): MultiHeadAttention()
    )
    (2): AttentionBlock(
      (sa): MultiHeadAttention()
    )
  )
)
embedding: torch.Size([256, 8])
readout size: torch.Size([8, 256])
torch.Size([1024, 8])
