In [1]:
# Notebook-wide random seed, shared by the dataset split further down.
seed = 1

In [2]:
import os

# Sanity check: list the parent of the working directory (the repository root, judging by the output).
os.listdir(os.pardir)

['.git',
 '.gitignore',
 '.idea',
 'data',
 'info_nas',
 'info_nas.egg-info',
 'LICENSE',
 'notebooks',
 'README.md',
 'setup.py']

In [3]:
from nasbench import api

# Load the NAS-Bench-101 records from the only108 tfrecord file; this took ~33 s (see log below).
nasbench_path = '../data/nasbench_only108.tfrecord'
nb = api.NASBench(nasbench_path)

Loading dataset from file... This may take a few minutes...
Instructions for updating:
Use eager execution and: 
`tf.data.TFRecordDataset(path)`
Loaded dataset in 33 seconds


In [4]:
import torch
from info_nas.datasets.arch2vec_dataset import get_labeled_unlabeled_datasets

# Keep the datasets on the CPU: passing a CUDA device here would move the whole dataset
# to GPU memory at once. Individual batches are moved with `.to(device)` where needed.
data_device = torch.device('cpu')

# Fraction of networks that receive labeled (input/output feature-map) data. With ~381k
# training networks, 1e-5 selects only a handful (4 train / 1 valid in the log below).
percent_labeled = 0.00001

labeled, unlabeled = get_labeled_unlabeled_datasets(
    nb,
    percent_labeled=percent_labeled,
    device=data_device,
    seed=seed,
)

Loading nasbench dataset (arch2vec) from ../data/nb_dataset.json
Split the dataset (percent labeled = 1e-05) - 4/381261 labeled networks chosen from the train set, 1/42363 labeled networks chosen from the validation set.
Processing labeled nets for the training set...
Loading labeled dataset from ../data/train_labeled.pt.
Processing labeled nets for the validation set...
Loading labeled dataset from ../data/valid_labeled.pt.


## Test dataset shapes

In [5]:
# Which splits / fields does the labeled dataset provide?
labeled.keys()

dict_keys(['train_io', 'train_net', 'valid_io', 'valid_net'])

In [6]:
# Shapes of the labeled input/output data. An older dataset version returned a sequence here
# (see the output below), while the cells further down index it as a dict with
# 'inputs'/'outputs' keys. Iterating such a dict yields string keys and fails on `.shape`,
# so handle both layouts.
labeled_train_io = labeled['train_io']
if isinstance(labeled_train_io, dict):
    train_io_shapes = {key: getattr(value, 'shape', None) for key, value in labeled_train_io.items()}
else:
    train_io_shapes = [item.shape for item in labeled_train_io]
train_io_shapes

[(4000,), torch.Size([4000, 128, 32, 32]), torch.Size([4000, 512, 8, 8])]

In [7]:
# Shapes of the labeled network tensors: index 0 = adjacency matrices, index 1 = operations.
[tensor.shape for tensor in labeled['train_net']]

[torch.Size([4000, 7, 7]), torch.Size([4000, 7, 5])]

In [8]:
# Keys of the unlabeled dataset; 'n_train' / 'n_val' presumably hold the split sizes — TODO confirm.
unlabeled.keys()

dict_keys(['train', 'n_train', 'val', 'n_val'])

In [9]:
# Per-network tensor shapes of the first unlabeled training sample (fields 1 and 2).
tuple(unlabeled['train'][field][0].shape for field in (1, 2))

(torch.Size([7, 7]), torch.Size([7, 5]))

In [10]:
# Number of unlabeled training networks (matches the 381261 reported by the split log).
len(unlabeled['train'][1])

381261

## Test model shapes

In [4]:
from arch2vec.extensions.get_nasbench101_model import get_arch2vec_model

# Build an arch2vec model with default arguments; `opt` is presumably its optimizer — TODO confirm.
model, opt = get_arch2vec_model()

In [11]:
from arch2vec.extensions.get_nasbench101_model import get_nasbench_datasets

# Use the notebook-wide `seed` (instead of a hard-coded 1) so this split is seeded the same
# way as the labeled/unlabeled split above.
nb_dataset = get_nasbench_datasets('../data/nb_dataset.json', batch_size=None, seed=seed)

In [23]:
# Encode a small batch of architectures in eval mode to inspect the latent shapes.
preview_batch_size = 32
model.eval()

batch_adj = nb_dataset['train'][1][:preview_batch_size]  # adjacency matrices
batch_ops = nb_dataset['train'][2][:preview_batch_size]  # operation matrices

# NOTE(review): relies on the model's private `_encoder` attribute.
mu, logvar = model._encoder(batch_ops, batch_adj)
z = model.reparameterize(mu, logvar)

In [25]:
shape_summary = f"mu shape: {mu.shape}, logvar shape: {logvar.shape}, z shape: {z.shape}"
print(shape_summary)

mu shape: torch.Size([32, 7, 16]), logvar shape: torch.Size([32, 7, 16]), z shape: torch.Size([32, 7, 16])


In [34]:
import torch.nn as nn

# Tiny head on the flattened latent code: all non-batch elements of z -> 5 features.
m = nn.Sequential(nn.Flatten(), nn.Linear(z.shape[1:].numel(), 5), nn.ReLU())
m(z).shape

torch.Size([32, 5])

In [43]:
# Test unsqueeze + 1x1 conv: unsqueeze(0) turns mu (32, 7, 16 above) into a single 4-D
# "image" whose channel axis is the original batch axis, so in_channels must equal the
# batch size. Derive it from mu instead of hard-coding 32.
conv = nn.Conv2d(mu.shape[0], 8, kernel_size=1, padding=0)
conv(mu.unsqueeze(0)).shape

torch.Size([1, 8, 7, 16])

In [74]:
# Repeat and concat: append a constant feature (value 3) along the last axis of mu.
# torch.full creates the column with mu's dtype and device, so the concat also works when
# mu lives on a GPU (torch.Tensor([3]) always uses the default CPU float tensor type).
repeated = torch.full((mu.shape[0], mu.shape[1], 1), 3.0, dtype=mu.dtype, device=mu.device)
print(repeated.shape)

torch.cat([mu, repeated], dim=-1).shape

torch.Size([32, 7, 1])


torch.Size([32, 7, 17])

## Extended models

In [5]:
import torch
from arch2vec.extensions.get_nasbench101_model import get_arch2vec_model, get_nasbench_datasets

# `device` is not defined by any earlier cell (the Training section sets it only further
# down), so define it here; otherwise Restart & Run All fails with NameError in this cell
# and in the following `.to(device)` calls.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model, opt = get_arch2vec_model(device=device)

In [6]:
# Shapes of the labeled input and output feature maps.
for io_key in ('inputs', 'outputs'):
    print(labeled['train_io'][io_key].shape)

torch.Size([4000, 128, 32, 32])
torch.Size([4000, 512, 8, 8])


In [7]:
from info_nas.models.conv_embeddings import SimpleConvModel

# Wrap the arch2vec model in the conv extension. NOTE(review): 128 and 512 match the channel
# dims of the labeled inputs/outputs printed above (presumably in/out channels) — confirm.
extended_model = SimpleConvModel(model, 128, 512)
extended_model = extended_model.to(device)

In [8]:
# First 32 labeled samples: input/output feature maps plus the adjacency/operation tensors.
first_batch = slice(0, 32)
in_batch = labeled['train_io']['inputs'][first_batch]
out_batch = labeled['train_io']['outputs'][first_batch]

batch_adj = labeled['train_net'][0][first_batch]
batch_ops = labeled['train_net'][1][first_batch]

In [9]:
# Forward pass of the extended model on one batch (graph inputs + input feature maps).
ops_recon, adj_recon, mu, logvar, outputs = extended_model(
    batch_ops.to(device),
    batch_adj.to(device),
    in_batch.to(device),
)

In [10]:
# Per-sample output shape should match the target output feature maps (512, 8, 8 above).
outputs.shape

torch.Size([32, 512, 8, 8])

### Scratch checks: broadcasting the latent code over the input feature maps

In [None]:
# Encode the batch with the base model, project the latent code with the extended model,
# then tile it over the spatial size (H, W) of the input feature maps.
_, _, _, _, z = model(batch_ops.to(device), batch_adj.to(device))
z = extended_model.process_z(z)
z = z[..., None, None].repeat(1, 1, *in_batch.shape[2:4])
z.shape

In [None]:
# Concatenate the tiled latent code with the input feature maps along the channel axis.
# z was computed on `device` while in_batch is a slice of the CPU dataset, so move
# in_batch to z's device first; otherwise torch.cat fails when running on CUDA.
torch.cat([z, in_batch.to(z.device)], dim=1).shape

## Training

In [6]:
import torch
from info_nas.trainer import train

# Train on the GPU when one is available (with cuDNN autotuning enabled);
# otherwise pass device=None to `train`.
cuda_available = torch.cuda.is_available()
if cuda_available:
    torch.backends.cudnn.benchmark = True
device = torch.device('cuda') if cuda_available else None

model = train(labeled, unlabeled, nb, device=device, n_workers=4, n_val_workers=4, k=100)

epoch 0: batch 0 / 12040: loss: 4.45273
epoch 0: labeled batches: 1, unlabeled batches: 0
epoch 0: batch 1000 / 12040: loss: 0.52838
epoch 0: labeled batches: 100, unlabeled batches: 901
epoch 0: batch 2000 / 12040: loss: 0.43532
epoch 0: labeled batches: 100, unlabeled batches: 1901
epoch 0: batch 3000 / 12040: loss: 0.37327
epoch 0: labeled batches: 100, unlabeled batches: 2901
epoch 0: batch 4000 / 12040: loss: 0.34162
epoch 0: labeled batches: 100, unlabeled batches: 3901
epoch 0: batch 5000 / 12040: loss: 0.34020
epoch 0: labeled batches: 100, unlabeled batches: 4901
epoch 0: batch 6000 / 12040: loss: 0.31405
epoch 0: labeled batches: 100, unlabeled batches: 5901
epoch 0: batch 7000 / 12040: loss: 0.30694
epoch 0: labeled batches: 100, unlabeled batches: 6901
epoch 0: batch 8000 / 12040: loss: 0.25705
epoch 0: labeled batches: 100, unlabeled batches: 7901
epoch 0: batch 9000 / 12040: loss: 0.25581
epoch 0: labeled batches: 100, unlabeled batches: 8901
epoch 0: batch 10000 / 12040:

epoch 5: batch 8000 / 12040: loss: 0.21803
epoch 5: labeled batches: 100, unlabeled batches: 7901
epoch 5: batch 9000 / 12040: loss: 0.23295
epoch 5: labeled batches: 100, unlabeled batches: 8901
epoch 5: batch 10000 / 12040: loss: 0.22239
epoch 5: labeled batches: 100, unlabeled batches: 9901
epoch 5: batch 11000 / 12040: loss: 0.25254
epoch 5: labeled batches: 100, unlabeled batches: 10901
epoch 5: batch 12000 / 12040: loss: 0.22176
epoch 5: labeled batches: 100, unlabeled batches: 11901
Ratio of valid decodings from the prior: 0.4758
Ratio of unique decodings from the prior: 0.9891
validation set: acc_ops:96.4623, mean_corr_adj:93.0655, mean_fal_pos_adj:8.8744, acc_adj:95.5599
epoch 5: average loss 0.23573
epoch 6: batch 0 / 12040: loss: 0.57729
epoch 6: labeled batches: 1, unlabeled batches: 0
epoch 6: batch 1000 / 12040: loss: 0.24431
epoch 6: labeled batches: 100, unlabeled batches: 901
epoch 6: batch 2000 / 12040: loss: 0.22061
epoch 6: labeled batches: 100, unlabeled batches: 1