In [None]:
# Disable periodic autosave for this notebook session.
%autosave 0

import torch

# NOTE(review): `seed` is not referenced by any visible cell — the seeding calls
# below hard-code 1 (random.seed, torch.manual_seed, random_state=1). Consider
# wiring them to this constant.
seed = 1

In [None]:
import os


# Root directory of all data files used by this notebook.
data_dir = '../data/'
# Not the cell's last expression, so nothing is displayed; the call only
# serves to raise early if `data_dir` does not exist.
os.listdir(data_dir)


def from_data(path, data=False):
    """Return `path` resolved under `data_dir`.

    Args:
        path: file or directory name relative to the data root.
        data: when True, resolve inside the nested 'data' subfolder instead.

    Returns:
        The joined path string.
    """
    subfolder = 'data' if data else ''
    return os.path.join(data_dir, subfolder, path)

In [None]:
# Fall back to CPU when CUDA is unavailable: a bare torch.device('cuda') only
# fails later, when tensors are actually moved to it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
from nasbench import api

# Path to the NAS-Bench record (the file name suggests only 108-epoch results).
nasbench_path = from_data('nasbench_only108.tfrecord')
# Benchmark handle `nb`, passed to the dataset-building calls below.
nb = api.NASBench(nasbench_path)

In [None]:
import os

model_dir = os.path.join(data_dir, 'out_0534eeefa12ca7c2b177541ad24929f1')

# Full paths of the sub-directories of `model_dir`. Sorted because
# os.listdir returns entries in arbitrary order, which would otherwise make
# the pretraining loop's processing order non-deterministic.
subdirs = sorted(
    os.path.join(model_dir, name)
    for name in os.listdir(model_dir)
    if os.path.isdir(os.path.join(model_dir, name))
)

In [None]:
from info_nas.datasets.io.create_dataset import dataset_from_pretrained
from nasbench_pytorch.datasets.cifar10 import prepare_dataset
import torch
import random

seed_dir = os.path.join(data_dir, 'seed_experiment_models')
# exist_ok avoids the check-then-create race of exists() + mkdir().
os.makedirs(seed_dir, exist_ok=True)

# Set to True to (re)generate the per-network `net_seeds_*.pt` datasets.
pretrain = False
cifar_batch = 128

if pretrain:
    for sd in subdirs:
        base_path = os.path.basename(sd)
        print(base_path)

        # Reseed before building the loader for every network so each dataset
        # is produced from identically seeded RNG state.
        random.seed(1)
        torch.manual_seed(1)
        cifar = prepare_dataset(cifar_batch, validation_size=1000, num_workers=4,
                                root=from_data('cifar'), random_state=1)

        dataset_from_pretrained(sd, nb, cifar, os.path.join(seed_dir, f'net_seeds_{base_path}.pt'))

In [None]:
from info_nas.datasets.arch2vec_dataset import prepare_labeled_dataset

all_labeled = {}
seed_dir = os.path.join(data_dir, 'seed_experiment_models')

# Only `.pt` files are loadable datasets; sorted so processing (and printing)
# order is deterministic instead of depending on os.listdir order.
net_files = sorted(name for name in os.listdir(seed_dir) if name.endswith('.pt'))

for net_path in net_files:
    print(net_path)

    net_path_pt = os.path.join(seed_dir, net_path)

    labeled, _ = prepare_labeled_dataset(net_path_pt, nb, nb_dataset=from_data('nb_dataset.json'),
                                         dataset=from_data('cifar'), remove_labeled=False)

    # Keyed by file name, e.g. 'net_seeds_1.pt'.
    all_labeled[net_path] = labeled

In [None]:
# Inspect the 'outputs' entry of one network's labeled dataset.
# NOTE(review): displays the whole object; prefer `.shape` or a slice if large.
all_labeled['net_seeds_1.pt']['outputs']

In [None]:
from info_nas.datasets.io.semi_dataset import labeled_network_dataset
from info_nas.datasets.io.transforms import get_transforms

def get_labeled_data(data):
    """Wrap a labeled network dataset in an ordered DataLoader.

    Args:
        data: a labeled dataset as produced by `prepare_labeled_dataset`.

    Returns:
        torch DataLoader (batch size 32, 4 workers, no shuffling) over a
        `labeled_network_dataset` that also returns reference ids.
    """
    pipeline = get_transforms(
        from_data('scales/scales/scale-train-include_bias.pickle'),
        True, None, True,
        scale_whole_path=None,
    )
    # Keep only the first and third transforms of the returned pipeline.
    # NOTE(review): the reason the middle transform is dropped is not visible here.
    pipeline.transforms = [pipeline.transforms[idx] for idx in (0, 2)]

    dataset = labeled_network_dataset(data, transforms=pipeline, return_ref_id=True)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=32,
        shuffle=False,
        num_workers=4,
    )

In [None]:
import numpy as np

def get_pred_and_orig(gen, model=None, print_freq=1000, model_device=None):
    """Collect targets (and optionally model predictions) from a batch iterable.

    Args:
        gen: iterable of batch dicts with keys 'adj', 'ops', 'input', 'output',
            'weights', 'bias', 'label', 'hash' and 'ref_id'.
        model: optional callable invoked as ``model(ops, adj, input)``; the last
            element of its return value is taken as the prediction.
        print_freq: print a progress line every ``print_freq`` batches;
            0 or None disables progress output.
        model_device: device the model inputs are moved to. Defaults to the
            notebook-global ``device``; only consulted when ``model`` is given.

    Returns:
        ``(orig, info, weights, labels)`` when ``model`` is None, otherwise
        ``(orig, pred, info, weights, labels)``. ``orig``, ``pred`` and
        ``weights`` are stacked along axis 0, ``labels`` is concatenated, and
        ``info`` holds one dict of 'label'/'hash'/'ref_id' per batch.

    Raises:
        ValueError: if ``gen`` yields no batches.
    """
    if model is not None and model_device is None:
        model_device = device

    orig, pred, info, weights, labels = [], [], [], [], []

    for i, batch in enumerate(gen):
        if print_freq and i % print_freq == 0:
            print(f"Batch {i}")

        info.append({w: batch[w] for w in ['label', 'hash', 'ref_id']})

        if model is not None:
            # Inference only: no autograd graph is needed (the result is
            # detached anyway, so the values are unchanged).
            with torch.no_grad():
                res = model(batch['ops'].to(model_device),
                            batch['adj'].to(model_device),
                            batch['input'].to(model_device))
            pred.append(res[-1].detach().cpu().numpy())
        orig.append(batch['output'].numpy())
        # Append the bias as an extra trailing column of the weight matrix.
        weights.append(np.concatenate([batch['weights'], batch['bias'][:, :, np.newaxis]], axis=-1))
        labels.append(batch['label'].numpy())

    if not orig:
        raise ValueError("gen yielded no batches")

    orig = np.vstack(orig)
    weights = np.vstack(weights)
    labels = np.hstack(labels)

    if model is None:
        return orig, info, weights, labels

    pred = np.vstack(pred)
    return orig, pred, info, weights, labels

In [None]:
# Per-network feature tuples (orig, info, weights, labels), keyed like
# `all_labeled`. Built in one expression so no single-letter loop temporaries
# leak into the kernel namespace.
features = {
    net_name: get_pred_and_orig(get_labeled_data(labeled))
    for net_name, labeled in all_labeled.items()
}

In [None]:
def feat_norm(f, axis=1, no_div=False):
    """Standardize `f`: subtract the mean and, unless `no_div`, divide by the std.

    Args:
        f: numpy array of any rank.
        axis: axis (int, negative int or tuple) along which the mean/std are
            computed; None uses the statistics of the whole array.
        no_div: if True, only center the data and skip the std division.

    Returns:
        Array of the same shape as `f`.

    Note:
        A slice with zero standard deviation yields NaN/inf (division by 0).
    """
    if axis is None:
        res = (f - np.mean(f))
        return res if no_div else res / np.std(f)

    # keepdims keeps the reduced axis as length 1, so the statistics broadcast
    # back against `f` for any rank and any (including negative) axis.
    res = f - np.mean(f, axis=axis, keepdims=True)
    return res if no_div else res / np.std(f, axis=axis, keepdims=True)

In [None]:
# Build and label the dataset of the separate test network, seeding and
# preparing CIFAR with the same arguments as the pretraining cell.
other_net = from_data('test_net')
other_net_pt = from_data('other_net.pt')

cifar_batch = 128

random.seed(1)
torch.manual_seed(1)
cifar = prepare_dataset(
    cifar_batch,
    validation_size=1000,
    num_workers=4,
    root=from_data('cifar'),
    random_state=1,
)

dataset_from_pretrained(other_net, nb, cifar, other_net_pt)

other_net_labeled, _ = prepare_labeled_dataset(
    other_net_pt,
    nb,
    device=torch.device('cpu'),
    nb_dataset=from_data('nb_dataset.json'),
    dataset=from_data('cifar'),
    remove_labeled=False,
)

In [None]:
# Features of the separate test network, unpacked in the same
# (orig, info, weights, labels) order as the per-seed entries of `features`.
gen = get_labeled_data(other_net_labeled)
o_test, i_test, w_test, lab_test = get_pred_and_orig(gen)

In [None]:
# NOTE(review): displays the whole array; prefer `o_test.shape` or a slice.
o_test

In [None]:
def mult_by_weights(feats, ws, labs):
    """Scale each sample's features by the sorted weight row of its own label.

    Args:
        feats: features with one row per sample.
        ws: weights indexed as ``ws[sample, label, ...]``; each last-axis row is
            sorted in descending order before use. Not modified.
        labs: integer label per sample (length = number of samples).

    Returns:
        ``feats * sorted_ws[i, labs[i]]`` for every sample ``i``.
        NOTE(review): relies on numpy broadcasting between `feats` and the
        selected weight rows — confirm the shapes agree for your data.
    """
    # Descending sort along the last axis (negate, ascending sort, negate).
    ws = -np.sort(-ws, axis=-1)
    labs = np.asarray(labs)
    # One row index per sample, derived from the data instead of a fixed count.
    mult_out = feats * ws[np.arange(len(labs)), labs]
    return mult_out

In [None]:
# The two seed networks compared below. Defined here so this cell works on a
# fresh kernel (the later cell that also sets them runs after this one).
net_1 = features['net_seeds_0.pt']
net_2 = features['net_seeds_6.pt']

# Weight-scaled features per network, centred per row (no std division).
mw = feat_norm(mult_by_weights(net_1[0], net_1[2], net_1[3]), axis=1, no_div=True)
mw2 = feat_norm(mult_by_weights(net_2[0], net_2[2], net_2[3]), axis=1, no_div=True)
mw_test = feat_norm(mult_by_weights(o_test, w_test, lab_test), axis=1, no_div=True)

In [None]:
# Imported here as well: on a fresh kernel this cell runs before the cell
# that imports seaborn/matplotlib, so `sns`/`plt` would otherwise be undefined.
import matplotlib.pyplot as plt
import seaborn as sns

top_k = 5

vmin = 0
vmax = 0.5
xlim = (0, 0.5)
ylim = (0, 5000)

binwidth = 0.025


def _plot_sq_diff(a, b, title):
    """Plot heatmap + histogram of the squared difference of `a` and `b` (first top_k columns)."""
    sq_diff = np.square(a - b)
    plt.title(title)
    sns.heatmap(sq_diff[:, :top_k], vmin=vmin, vmax=vmax)
    plt.show()
    plt.title(title)
    sns.histplot(sq_diff[:, :top_k].flatten(), binwidth=binwidth)
    plt.xlim(*xlim)
    plt.ylim(*ylim)
    plt.show()
    return sq_diff


for weighted in (mw, mw2, mw_test):
    sns.heatmap(weighted[:, :top_k])
    plt.show()

_plot_sq_diff(mw, mw2, 'Same net, different seeds')
mw_diff = _plot_sq_diff(mw, mw_test, 'Different nets')

In [None]:
# Imported here too so the cell runs on a fresh kernel (the dedicated
# seaborn/matplotlib import cell comes later in the notebook).
import matplotlib.pyplot as plt
import seaborn as sns

net_1 = features['net_seeds_0.pt']
net_2 = features['net_seeds_6.pt']

# Shared colour range so the three heatmaps are directly comparable.
heat_kw = dict(vmax=1.6, vmin=0.2)
n_cols = 10

for title, (feats, ws, labs) in [
    ('net_seeds_0', (net_1[0], net_1[2], net_1[3])),
    ('net_seeds_6', (net_2[0], net_2[2], net_2[3])),
    ('test net', (o_test, w_test, lab_test)),
]:
    plt.title(title)
    sns.heatmap(mult_by_weights(feats, ws, labs)[:, :n_cols], **heat_kw)
    plt.show()

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
sns.set()

seed_nets = ('net_seeds_0.pt', 'net_seeds_1.pt')

# Raw features of two seed networks.
for net_name in seed_nets:
    sns.heatmap(features[net_name][0])
    plt.show()

# Row-normalised features on a shared colour scale, plus the test network.
for net_feats in [features[net_name][0] for net_name in seed_nets] + [o_test]:
    sns.heatmap(feat_norm(net_feats), vmax=6, vmin=0)
    plt.show()

In [None]:
from matplotlib.colors import LogNorm

# Element-wise squared difference between the row-normalised features of the
# same net trained with two different seeds.
diff = np.square(
    feat_norm(features['net_seeds_0.pt'][0]) - feat_norm(features['net_seeds_6.pt'][0])
)
plt.title('Same net, different seeds')
sns.heatmap(diff, cmap='YlGnBu', vmax=2, vmin=0)
plt.show()

In [None]:
# Shape of the squared-difference matrix plotted above.
diff.shape

In [None]:
# Distribution of all squared differences (same net, different seeds).
sns.histplot(diff.ravel())
plt.show()

In [None]:
# Element-wise squared difference between a seed network and the test network.
diff_2 = np.square(feat_norm(features['net_seeds_0.pt'][0]) - feat_norm(o_test))
plt.title('Different nets')
sns.heatmap(diff_2, cmap='YlGnBu', vmax=2, vmin=0)
plt.show()

In [None]:
# Distribution of all squared differences (seed net vs. test net).
sns.histplot(diff_2.ravel())
plt.show()

In [None]:
top_k = 100

# `diff` and `diff_2` hold squared (not absolute) differences; the titles say so.
for sq_diff, label in ((diff, 'same net, different seeds'), (diff_2, 'different nets')):
    plt.title(f"Squared difference histogram for top {top_k} features - {label}")
    sns.histplot(sq_diff[:, :top_k].flatten())
    plt.xlim(0, 6)
    plt.show()

In [None]:
top_kk = 10
xlim = 90

# Per-image SUM (not mean) of squared differences over the first top_kk
# normalised features, so the title says "summed squared error", not "MSE".
err1 = np.square(
    feat_norm(features['net_seeds_3.pt'][0])[:, :top_kk]
    - feat_norm(features['net_seeds_6.pt'][0])[:, :top_kk]
).sum(axis=1)
plt.title(f"Summed squared error between features (per image) - same net, top {top_kk}")
sns.histplot(err1, binwidth=5)
plt.xlim(0, xlim)
print(np.mean(err1))
print(np.median(err1))
print(np.std(err1))
plt.show()

In [None]:
# Per-image SUM (not mean) of squared differences over the first top_kk
# normalised features, between a seed network and the test network.
err2 = np.square(
    feat_norm(features['net_seeds_3.pt'][0])[:, :top_kk] - feat_norm(o_test)[:, :top_kk]
).sum(axis=1)
sns.histplot(err2, binwidth=5)
plt.xlim(0, xlim)
plt.title(f"Summed squared error between features (per image) - different nets, top {top_kk}")
print(np.mean(err2))
print(np.median(err2))
print(np.std(err2))
plt.show()

In [None]:
# Per-row (per-image) mean of the raw features of seed network 1.
np.mean(features['net_seeds_1.pt'][0], axis=1)

In [None]:
# Shape check: one mean per row of the test network's features.
np.mean(o_test, axis=1).shape