In [1]:
from info_nas.datasets.networks.pretrained import pretrain_network_cifar
from info_nas.datasets.networks.utils import load_nasbench
from nasbench_pytorch.datasets.cifar10 import prepare_dataset

In [3]:
# Settings shared by the data-loading call below.
seed = 42
batch_size = 32
num_workers = 2

# Prepare the CIFAR-10 data under ../data/ with 1000 validation samples;
# the split is seeded so that re-running the cell reproduces it.
dataset = prepare_dataset(
    batch_size,
    root='../data/',
    validation_size=1000,
    random_state=seed,
    num_workers=num_workers,
)

# The result unpacks into six values; presumably (loader, size) pairs for the
# train / validation / test splits -- confirm against prepare_dataset.
train, n_train, val, n_val, test, n_test = dataset

--- Preparing CIFAR10 Data ---
Files already downloaded and verified
Files already downloaded and verified
--- CIFAR10 Data Prepared ---


In [11]:
import os

# Directory holding the checkpoint archives (one .tar per file in the
# recorded listing). The trailing slash matters: later cells build paths by
# plain string concatenation.
checkpoint_dir = '../data/checkpoints/nasbench_10-epochs_cifar-10/'

# Show which checkpoints are available.
os.listdir(checkpoint_dir)

['0000cb5372fad9d62c470df699ac6d52.tar',
 '0000fa05697179112aaf69c2f0a51a0f.tar',
 '0001a2f6c8977346ccd12fa0c435bf42.tar',
 '0001b322e18b86665c067ffa09e46897.tar']

In [2]:
from nasbench import api

# Load the NAS-Bench tfrecord through the nasbench API.
# This is slow: the recorded run reports 148 seconds for this call.
nasbench_path = '../data/nasbench_only108.tfrecord'
nb = api.NASBench(nasbench_path)

Loading dataset from file... This may take a few minutes...
Instructions for updating:
Use eager execution and: 
`tf.data.TFRecordDataset(path)`
Loaded dataset in 148 seconds


In [13]:
import torch
from info_nas.datasets.networks.utils import load_trained_net

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Restore one saved network from the checkpoint directory; the loader also
# returns a second value that is kept in `info`.
net_path = f'{checkpoint_dir}0000cb5372fad9d62c470df699ac6d52.tar'
net, info = load_trained_net(net_path, nb)

# Move the restored network to the selected device and display it.
net = net.to(device)
net

In [14]:
# Fetch exactly one batch from the validation loader to use as a fixed probe
# input for the cells below. `next(iter(...))` states that intent directly
# instead of starting a loop and breaking out of it on the first iteration,
# and it fails loudly (StopIteration) if the loader is empty rather than
# leaving `inputs` / `targets` undefined.
inputs, targets = next(iter(val))
inputs, targets = inputs.to(device), targets.to(device)

print(inputs.shape)
print(targets.shape)

torch.Size([32, 3, 32, 32])
torch.Size([32])


In [17]:
# Load the same NAS-Bench file again, this time through the project helper.
# The result is indexed and sliced like a sequence in later cells.
# NOTE(review): this is a second multi-minute load of `nasbench_path`
# (141 seconds in the recorded run) on top of the `api.NASBench` load above;
# check whether the helper can reuse the already-loaded `nb` object.
nasbench = load_nasbench(nasbench_path)

Loading dataset from file... This may take a few minutes...
Loaded dataset in 141 seconds


In [25]:
from nasbench_pytorch.model import Network as NBNetwork

# Build a fresh (untrained) network from the first NAS-Bench record.
# Elements 2 and 1 of the record are passed as the architecture spec
# (presumably adjacency matrix and operation list -- confirm against
# load_nasbench); 10 is the second constructor argument (presumably the
# number of CIFAR-10 classes).
#
# Note: this rebinds `net`, replacing the pretrained network loaded earlier.
first_record = nasbench[0]

# The probe batch `inputs` was moved to `device`, so the network has to live
# on the same device; otherwise the forward pass fails with a device mismatch
# whenever CUDA is available.
net = NBNetwork((first_record[2], first_record[1]), 10).to(device)

In [26]:
# Plain forward pass of the probe batch; gradients are not needed for a
# shape inspection, so tracking is disabled.
with torch.no_grad():
    outputs = net(inputs)
    
# Display the output shape (one row of 10 values per input in the recorded run).
outputs.shape

torch.Size([32, 10])

In [27]:
# Collect the intermediate outputs reported by the network for the probe batch.
with torch.no_grad():
    out_list = net.get_cell_outputs(inputs, return_inputs=False)

# Print one shape per collected output. A plain loop is used instead of a
# list comprehension run only for its side effects, which built a throwaway
# list of None and needed a trailing print() to keep it out of the cell output.
for cell_out in out_list:
    print(cell_out.shape)

torch.Size([32, 128, 32, 32])
torch.Size([32, 128, 32, 32])
torch.Size([32, 128, 32, 32])
torch.Size([32, 256, 16, 16])
torch.Size([32, 256, 16, 16])
torch.Size([32, 256, 16, 16])
torch.Size([32, 512, 8, 8])
torch.Size([32, 512, 8, 8])
torch.Size([32, 512, 8, 8])
torch.Size([32, 10])



In [28]:
# Same probe as before, but also request the matching inputs, so each
# collected output can be shown next to the tensor it was computed from.
with torch.no_grad():
    in_list, out_list = net.get_cell_outputs(inputs, return_inputs=True)

# Print "input shape -> output shape" for every pair. A plain loop replaces
# the side-effect-only list comprehension and its trailing print().
for cell_in, cell_out in zip(in_list, out_list):
    print(cell_in.shape, ' -> ', cell_out.shape)

torch.Size([32, 128, 32, 32])  ->  torch.Size([32, 128, 32, 32])
torch.Size([32, 128, 32, 32])  ->  torch.Size([32, 128, 32, 32])
torch.Size([32, 128, 32, 32])  ->  torch.Size([32, 128, 32, 32])
torch.Size([32, 128, 16, 16])  ->  torch.Size([32, 256, 16, 16])
torch.Size([32, 256, 16, 16])  ->  torch.Size([32, 256, 16, 16])
torch.Size([32, 256, 16, 16])  ->  torch.Size([32, 256, 16, 16])
torch.Size([32, 256, 8, 8])  ->  torch.Size([32, 512, 8, 8])
torch.Size([32, 512, 8, 8])  ->  torch.Size([32, 512, 8, 8])
torch.Size([32, 512, 8, 8])  ->  torch.Size([32, 512, 8, 8])
torch.Size([32, 512, 8, 8])  ->  torch.Size([32, 10])



In [44]:
# Sanity check: for the first 150 NAS-Bench records, build the network and
# verify that its per-cell input/output shapes match those of the reference
# network inspected above (`in_list` / `out_list`).
for net_idx, record in enumerate(nasbench[:150]):
    # Lightweight progress indicator.
    if net_idx % 5 == 0:
        print(net_idx)

    # Keep the network on the same device as the probe batch `inputs`.
    next_net = NBNetwork((record[2], record[1]), 10).to(device)

    with torch.no_grad():
        n_in, n_out = next_net.get_cell_outputs(inputs, return_inputs=True)

    # zip() stops at the shorter sequence, so a network reporting fewer (or
    # more) cells would otherwise pass the shape checks unnoticed.
    assert len(n_in) == len(in_list), f'network {net_idx}: input count differs'
    assert len(n_out) == len(out_list), f'network {net_idx}: output count differs'

    # Distinct loop names: the original inner loop reused `i`, rebinding the
    # enumerate index to a tensor.
    for ref_in, cand_in in zip(in_list, n_in):
        assert ref_in.shape == cand_in.shape, f'network {net_idx}: input shape differs'

    for ref_out, cand_out in zip(out_list, n_out):
        assert ref_out.shape == cand_out.shape, f'network {net_idx}: output shape differs'

0
5
10
15
20
25
30
35
40
45
50
55
60
65
70
75
80
85
90
95
100
105
110
115
120
125
130
135
140
145


In [43]:
# Total number of records returned by load_nasbench (423624 in the recorded
# run); the shape check above only covers the first 150 of them.
len(nasbench)

423624