In [None]:
seed = 1

from nasbench import api

nasbench_path = '../data/nasbench_only108.tfrecord'
nb = api.NASBench(nasbench_path)

import numpy as np
import torch
from info_nas.datasets.arch2vec_dataset import get_labeled_unlabeled_datasets

# Seed the global RNGs so the shuffled DataLoaders created later give the same
# batch order on every Restart & Run All.
torch.manual_seed(seed)
np.random.seed(seed)

#torch.backends.cudnn.benchmark = True
# Device for the model; fall back to CPU when no GPU is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Keep the dataset itself on the CPU (device=None). Passing the GPU device would
# move the whole dataset to CUDA at once. Later cells move each batch to
# `device` right before calling the model, and they call `.numpy()` on batch
# tensors, which only works for CPU tensors.
dataset, _ = get_labeled_unlabeled_datasets(nb, device=None, seed=seed,
                                            train_pretrained=None,
                                            valid_pretrained=None,
                                            train_labeled_path='../data/train_long.pt',
                                            valid_labeled_path='../data/valid_long.pt')

In [None]:
from info_nas.datasets.io.semi_dataset import labeled_network_dataset
from scripts.train_vae import get_transforms

# Training split with the default output scaling.
# NOTE(review): the positional flags are defined by scripts.train_vae.get_transforms.
# The third one presumably selects the scaling axis, because the axis-0 variant
# further down passes 0 there. Confirm before relying on it.
transforms = get_transforms(
    '../data/scales/scale-train-include_bias.pickle', True, None, True,
)
labeled = labeled_network_dataset(dataset['train'], transforms=transforms)

In [None]:
# Show the shape of the output tensor (index 3) of the first sample only.
b = next(iter(labeled), None)
if b is not None:
    print(b[3].shape)

In [None]:
import numpy as np

# Materialise every training output tensor (index 3) as a numpy array.
np_outs = []
for sample in labeled:
    np_outs.append(sample[3].detach().numpy())

In [None]:
# Element-wise mean over all training outputs. It is used as a
# constant-prediction baseline in the loss cells that follow.
labeled_mean = torch.Tensor(np.mean(np_outs, axis=0))
labeled_mean.shape

In [None]:
# Element-wise standard deviation over all training outputs, i.e. the spread around labeled_mean.
np.std(np_outs, axis=0)

In [None]:
import torch
import torch.nn as nn

# Constant-prediction baseline on the training set: predict labeled_mean for
# every sample and measure the loss against the true outputs (index 3).
gen = torch.utils.data.DataLoader(labeled, batch_size=32, shuffle=True, num_workers=0)
loss = nn.MSELoss()  # swap in nn.L1Loss() for an L1 baseline

losses = []
for b in gen:
    # Broadcast the mean to the actual batch size. The last batch is smaller
    # than 32 whenever len(labeled) is not a multiple of 32, and a fixed
    # (32, F) reference would then fail (or silently broadcast) in the loss.
    ref = labeled_mean.expand_as(b[3])
    l = loss(ref, b[3])
    losses.append(l.item())

In [None]:
# Mean of the per-batch losses. A smaller final batch gets the same weight as a full one.
np.mean(losses)

In [None]:
# Validation split: same pipeline, but using the validation scale file.
# NOTE(review): labeled_mean_val is estimated on the very data the baseline is
# later scored on, so that baseline is optimistic.
transforms_val = get_transforms(
    '../data/scales/scale-valid-include_bias.pickle', True, None, True,
)
labeled_val = labeled_network_dataset(dataset['valid'], transforms=transforms_val)

val_outs = [sample[3].detach().numpy() for sample in labeled_val]
labeled_mean_val = torch.Tensor(np.mean(val_outs, axis=0))

In [None]:
# Constant-prediction baseline on the validation set, using the validation mean.
gen2 = torch.utils.data.DataLoader(labeled_val, batch_size=32, shuffle=False, num_workers=0)

# NOTE(review): this uses L1 while the training-set baseline uses MSE, so the two
# numbers are not directly comparable. Switch one of them before comparing.
loss = nn.L1Loss()

losses = []
for b in gen2:
    # Match the reference to the actual batch size. This scores the final,
    # smaller batch instead of skipping it.
    ref = labeled_mean_val.expand_as(b[3])
    l = loss(ref, b[3])
    losses.append(l.item())

np.mean(losses)

In [None]:
check_path = '../data/vae_checkpoints/2021-11-07_17-05-53/model_dense_epoch-29.pt'
# NOTE(review): `trained_checkpoint` is not used by the visible cells. The model is
# loaded from `check_path` directly by load_extended_vae in the next cell.
# torch.load unpickles the file, so only load checkpoints from trusted sources.
trained_checkpoint = torch.load(check_path, map_location=device)

In [None]:
from arch2vec.extensions.get_nasbench101_model import get_arch2vec_model
from info_nas.models.utils import load_extended_vae


# Build the base arch2vec model, then load the extended VAE from `check_path` around it.
# `optimizer` is not used afterwards in this notebook.
# NOTE(review): 513 matches the per-sample output width used in the baseline cells;
# the meaning of 3 is defined by load_extended_vae. Confirm before changing it.
model, optimizer = get_arch2vec_model(device=device)
model, _ = load_extended_vae(check_path, [model, 3, 513], device=device)

In [None]:
how_many = 20  # number of batches to visualise

orig = []
pred = []

# Inference only. eval() turns off training-mode layer behaviour (e.g. dropout),
# and no_grad() avoids building an autograd graph for every forward pass.
model.eval()
with torch.no_grad():
    for i, b in enumerate(gen):
        if i >= how_many:
            break

        # Batches are (adj, ops, inputs, outputs), as unpacked in later cells.
        # The model takes (ops, adj, inputs), and its last output is the
        # prediction that is compared against b[3].
        res = model(b[1].to(device), b[0].to(device), b[2].to(device))
        pred.append(res[-1].detach().cpu().numpy())
        orig.append(b[3].numpy())

orig = np.vstack(orig)
pred = np.vstack(pred)

## Prediction vs original comparison

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib notebook

plt.close()
plt.figure(figsize=(5,5))
plt.title("Original outputs")
sns.heatmap(orig, vmax=10)
plt.tight_layout()
plt.show()

plt.figure(figsize=(5,5))
plt.title("Predicted outputs")
sns.heatmap(pred, vmax=10)
plt.tight_layout()
plt.show()

In [None]:
how_many = 20  # number of validation batches to visualise

orig = []
pred = []

# Inference only. eval() turns off training-mode layer behaviour (e.g. dropout),
# and no_grad() avoids building an autograd graph for every forward pass.
model.eval()
with torch.no_grad():
    for i, b in enumerate(gen2):
        if i >= how_many:
            break

        # Model takes (ops, adj, inputs). Its last output is the prediction for b[3].
        res = model(b[1].to(device), b[0].to(device), b[2].to(device))
        pred.append(res[-1].detach().cpu().numpy())
        orig.append(b[3].numpy())

orig = np.vstack(orig)
pred = np.vstack(pred)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib notebook

plt.close()
plt.figure(figsize=(5,5))
plt.title("Original outputs - val")
sns.heatmap(orig, vmax=10)
plt.tight_layout()
plt.show()


plt.figure(figsize=(5,5))
plt.title("Predicted outputs - val")
sns.heatmap(pred, vmax=10)
plt.tight_layout()
plt.show()

In [None]:
# Collect true and predicted outputs for every training sample whose model
# input (b[2]) equals the input of the very first sample seen.
orig = []
pred = []

first = None

# Inference only: disable training-mode layers and autograd tracking.
model.eval()
print(len(gen))
with torch.no_grad():
    for i, b in enumerate(gen):
        if i % 100 == 0:
            print(i)

        if first is None:
            first = b[2][0]

        res = model(b[1].to(device), b[0].to(device), b[2].to(device))

        pbatch = res[-1].detach().cpu().numpy()
        obatch = b[3].numpy()

        for ins, o, p in zip(b[2], obatch, pbatch):
            if (ins == first).all():
                pred.append(p)
                orig.append(o)

In [None]:
# Stack the per-sample vectors collected for the single input image.
orig_im = np.array(orig)
pred_im = np.array(pred)

for heatmap_title, heatmap_data in (("Original outputs - same image", orig_im),
                                    ("Predicted outputs - same image", pred_im)):
    fig, ax = plt.subplots(figsize=(5, 5))
    ax.set_title(heatmap_title)
    sns.heatmap(heatmap_data, vmax=10, ax=ax)
    fig.tight_layout()
    plt.show()

In [None]:
# Collect true and predicted outputs for every training sample that belongs to
# the same network (identical adjacency b[0] and ops b[1]) as the first sample seen.
orig = []
pred = []

first = None
model.eval()

print(len(gen))
# no_grad: this is a full pass over the training set purely for inference.
with torch.no_grad():
    for i, b in enumerate(gen):
        if i % 100 == 0:
            print(i)

        if first is None:
            first = b[0][0], b[1][0]

        res = model(b[1].to(device), b[0].to(device), b[2].to(device))

        pbatch = res[-1].detach().cpu().numpy()
        obatch = b[3].numpy()

        for adj, ops, o, p in zip(b[0], b[1], obatch, pbatch):
            if (adj == first[0]).all() and (ops == first[1]).all():
                pred.append(p)
                orig.append(o)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib notebook

orig_im = np.array(orig)
pred_im = np.array(pred)


plt.figure(figsize=(5,5))
plt.title("Original outputs - same net")
sns.heatmap(orig_im, vmax=6)
plt.tight_layout()
plt.show()

plt.figure(figsize=(5,5))
plt.title("Predicted outputs - same net")
sns.heatmap(pred_im, vmax=6)
plt.tight_layout()
plt.show()

In [None]:
import torch
import torch.nn as nn

# Predictions for the whole training set, from a fresh shuffled loader.
gen = torch.utils.data.DataLoader(labeled, batch_size=32, shuffle=True, num_workers=0)

orig = []
pred = []

# Inference only: disable training-mode layers and autograd tracking.
model.eval()
with torch.no_grad():
    for i, b in enumerate(gen):
        if i % 100 == 0:
            print(i)

        res = model(b[1].to(device), b[0].to(device), b[2].to(device))
        pred.append(res[-1].detach().cpu().numpy())
        orig.append(b[3].numpy())

# Rows are samples, columns are output features.
orig = np.vstack(orig)
pred = np.vstack(pred)

In [None]:
# Largest absolute prediction error per output feature (column) over all training samples.
np.max(np.abs(orig - pred), axis=0)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib notebook

diff = np.abs(orig - pred)

plt.figure(figsize=(5,5))
plt.title("diff")
sns.heatmap(diff, vmax=5)
plt.tight_layout()
plt.show()

In [None]:
# Smallest predicted value over the whole training set.
pred.min()

## Further scratch experiments (unpolished)

In [None]:
from info_nas.datasets.io.semi_dataset import labeled_network_dataset
from scripts.train_vae import get_transforms

# Training split with the axis-0 scale file. It gets its own names so that the
# default-scaled `transforms` / `labeled` used by earlier cells are not
# silently replaced when this cell runs.
transforms_axis0 = get_transforms('../data/scales/scale-train-include_bias-axis_0.pickle',
                                  True, 0, True)

labeled_axis0 = labeled_network_dataset(dataset['train'], transforms=transforms_axis0)
# num_workers > 0 needs the dataset tensors to live on the CPU.
gen3 = torch.utils.data.DataLoader(labeled_axis0, batch_size=32, shuffle=True, num_workers=4)


In [None]:
how_many = 20

# Outputs of the first `how_many` batches of the axis-0 scaled loader.
orig_batches = []
for batch_idx, batch in enumerate(gen3):
    if batch_idx >= how_many:
        break
    orig_batches.append(batch[3].numpy())

orig = np.vstack(orig_batches)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib notebook

plt.figure(figsize=(5,5))
plt.title("Original outputs - same image")
sns.heatmap(orig)
plt.tight_layout()
plt.show()

In [None]:
from info_nas.datasets.io.semi_dataset import labeled_network_dataset
from scripts.train_vae import get_transforms

# Rebuild the default-scaled training dataset and its shuffled loader.
transforms = get_transforms(
    '../data/scales/scale-train-include_bias.pickle', True, None, True,
)
labeled = labeled_network_dataset(dataset['train'], transforms=transforms)
gen = torch.utils.data.DataLoader(
    labeled, batch_size=32, shuffle=True, num_workers=1,
)

In [None]:
how_many = 20

# Outputs of the first `how_many` batches of the default-scaled loader.
orig_batches = []
for batch_idx, batch in enumerate(gen):
    if batch_idx >= how_many:
        break
    orig_batches.append(batch[3].numpy())

orig = np.vstack(orig_batches)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib notebook

plt.figure(figsize=(5,5))
plt.title("Original outputs - same image")
sns.heatmap(orig)
plt.tight_layout()
plt.show()

In [None]:
# Collect the axis-0 scaled outputs of every training sample belonging to the
# same network (identical adjacency b[0] and ops b[1]) as the first sample seen.
orig = []

first = None

print(len(gen3))
for i, b in enumerate(gen3):
    if i % 100 == 0:
        print(i)

    if first is None:
        first = b[0][0], b[1][0]
    first_adj, first_ops = first

    obatch = b[3].numpy()
    orig.extend(
        o for adj, ops, o in zip(b[0], b[1], obatch)
        if (adj == first_adj).all() and (ops == first_ops).all()
    )

orig = np.array(orig)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib notebook

plt.figure(figsize=(5,5))
plt.title("Original outputs")
sns.heatmap(orig)
plt.tight_layout()
plt.show()