In [1]:
import numpy as np


def _normalize_per_hash(vals, hashes, normalize=False, axis=None):
    """Rescale ``vals`` in place, separately for every unique entry of ``hashes``.

    Args:
        vals: numpy array whose first axis lines up with ``hashes``. It is
            modified in place, so it should have a floating-point dtype
            (results written into an integer array would be truncated).
        hashes: numpy array with one group identifier per row of ``vals``.
        normalize: if True, standardize each group to ``(x - mean) / std``;
            otherwise divide each group by its maximum.
        axis: axis (of the per-group sub-array) along which mean/std/max are
            computed; ``None`` reduces over the whole group.

    Note:
        A group with zero std (or zero maximum) produces inf/nan values,
        following the usual numpy division semantics.
    """
    for net_hash in np.unique(hashes):
        mask = hashes == net_hash
        # Boolean indexing returns a copy, so the rescaled group has to be
        # assigned back explicitly at the end of the iteration.
        group = vals[mask]

        # keepdims=True keeps the statistics broadcastable against `group`
        # for every choice of `axis` (e.g. axis=1 on a 2-D array).
        if normalize:
            mean = np.mean(group, axis=axis, keepdims=True)
            std = np.std(group, axis=axis, keepdims=True)
            group = (group - mean) / std
        else:
            group = group / np.max(group, axis=axis, keepdims=True)

        vals[mask] = group


def scale(values, hashes, labels, per_label=False, inplace=False, normalize=False, axis=None):
    """Scale ``values`` per network hash, optionally also split by label.

    Args:
        values: numpy array of values; the first axis lines up with ``hashes``
            and ``labels``.
        hashes: numpy array with one hash per row of ``values``.
        labels: numpy array with one label per row of ``values``.
        per_label: if True, rows are first grouped by label and the per-hash
            scaling is applied inside each label group.
        inplace: if True, ``values`` itself is overwritten (it should then be
            a floating-point array); otherwise a copy is scaled and the input
            is left untouched.
        normalize: if True, standardize to zero mean / unit std; otherwise
            divide by the maximum.
        axis: axis passed to the mean/std/max reductions.

    Returns:
        The scaled array (the same object as ``values`` when ``inplace``).
    """
    if not inplace:
        # astype always copies; integer input is promoted to float so that the
        # scaled results are not truncated when written back.
        dtype = values.dtype if np.issubdtype(values.dtype, np.inexact) else float
        values = values.astype(dtype)

    if per_label:
        for label in np.unique(labels):
            labelmap = labels == label

            # values[labelmap] is a copy: scale it, then store it back.
            label_vals = values[labelmap]
            _normalize_per_hash(label_vals, hashes[labelmap], normalize=normalize, axis=axis)
            values[labelmap] = label_vals
    else:
        _normalize_per_hash(values, hashes, normalize=normalize, axis=axis)

    return values

In [2]:
import numpy as np

from info_nas.datasets.io.semi_dataset import labeled_network_dataset
from info_nas.datasets.io.transforms import SortByWeights


def get_dataset_for_plotting(dataset, include_bias=True, fixed_label=None, print_freq=50000,
                             return_all_features=False):
    """Collect the outputs, labels, hashes and reference ids of a dataset as numpy arrays.

    The dataset is wrapped by ``labeled_network_dataset`` with a
    ``SortByWeights`` transform and iterated once.

    Args:
        dataset: dataset object passed on to ``labeled_network_dataset``.
        include_bias: forwarded to ``SortByWeights``.
        fixed_label: forwarded to ``SortByWeights``.
        print_freq: the running item index is printed every ``print_freq`` items.
        return_all_features: forwarded to ``SortByWeights``.

    Returns:
        Tuple ``(outputs, labels, hashes, ref_ids)`` of numpy arrays with one
        entry per dataset item.
    """
    sorter = SortByWeights(include_bias=include_bias, fixed_label=fixed_label,
                           return_all_features=return_all_features)

    train_dataset = labeled_network_dataset(dataset, transforms=sorter,
                                            return_hash=True, return_ref_id=True)

    outputs, labels, hashes, ref_ids = [], [], [], []

    # Report the real length; the dataset might not define __len__, in which
    # case the progress counter below is the only size indication.
    try:
        print("Dataset length:", len(train_dataset))
    except TypeError:
        print("Dataset length: unknown")

    for i, b in enumerate(train_dataset):
        if i % print_freq == 0:
            print(i)

        # Positions 3, 4, 7 and 8 of each item hold the output, label, hash and
        # reference id. The output must expose `.cpu().numpy()` (presumably a
        # torch tensor -- confirm against the dataset implementation).
        outputs.append(b[3].cpu().numpy())
        labels.append(b[4])
        hashes.append(b[7])
        ref_ids.append(b[8])

    return np.array(outputs), np.array(labels), np.array(hashes), np.array(ref_ids)

In [11]:
import matplotlib.pyplot as plt
import seaborn as sns


def filter_data(outputs, labels, hashes, ref_ids, label=None, net_hash=None, ref=None):
    """Select the rows that match every given label / hash / reference id.

    Args:
        outputs: numpy array of outputs, one row per sample.
        labels: numpy array of labels, aligned with ``outputs``.
        hashes: numpy array of network hashes, aligned with ``outputs``.
        ref_ids: numpy array of reference ids, aligned with ``outputs``.
        label: keep only rows with this label; ``None`` disables the filter.
        net_hash: keep only rows with this hash; ``None`` disables the filter.
        ref: keep only rows with this reference id; ``None`` disables the filter.

    Returns:
        Tuple ``(outputs, labels, hashes, ref_ids)`` restricted to the rows
        that satisfy all active filters.
    """
    keep = np.ones(np.shape(hashes), dtype=bool)

    # The conditions are AND-ed one at a time. np.logical_and only takes two
    # inputs -- a third positional argument is the `out` buffer, not another
    # condition -- so the three masks cannot be combined in a single call.
    if net_hash is not None:
        keep &= hashes == net_hash
    if label is not None:
        keep &= labels == label
    if ref is not None:
        keep &= ref_ids == ref

    return outputs[keep], labels[keep], hashes[keep], ref_ids[keep]

In [12]:
def plot_heatmap(values, n_features=None, figsize=None, fname=None):
    """Draw ``values`` as a seaborn heatmap and optionally save the figure.

    Args:
        values: data handed to ``sns.heatmap``.
        n_features: if given, only ``values[:n_features]`` (the leading
            entries along the first axis) is plotted.
        figsize: figure size forwarded to ``plt.figure``.
        fname: if given, the figure is written to this path.
    """
    if n_features is not None:
        values = values[:n_features]

    plt.figure(figsize=figsize)
    sns.heatmap(values)
    plt.tight_layout()

    # Write the file while the figure is still open, i.e. before show/close.
    if fname is not None:
        plt.savefig(fname)

    plt.show()
    plt.close()

In [None]:
def get_net_diff(hash_1, hash_2, o_data, label=None):
    """Return the summed per-row mean squared difference between two networks' outputs.

    Args:
        hash_1: hash selecting the first network's rows.
        hash_2: hash selecting the second network's rows.
        o_data: sequence ``(outputs, labels, hashes, ref_ids)`` that is
            unpacked into ``filter_data``.
        label: optional label filter applied to both selections.

    Returns:
        ``sum(mean((a - b) ** 2, axis=1))`` for the two filtered output arrays.
        The subtraction is element-wise, so both selections must have
        broadcast-compatible shapes.
    """
    first = filter_data(*o_data, net_hash=hash_1, label=label)[0]
    second = filter_data(*o_data, net_hash=hash_2, label=label)[0]

    squared_diff = (first - second) ** 2
    per_row_mse = np.mean(squared_diff, axis=1)

    return np.sum(per_row_mse)


def get_error_matrix(outputs, labels, hashes, ref_ids, label=None, print_freq=10):
    """Build the matrix of ``get_net_diff`` values for every ordered pair of hashes.

    Args:
        outputs: numpy array of outputs.
        labels: numpy array of labels, aligned with ``outputs``.
        hashes: numpy array of network hashes, aligned with ``outputs``.
        ref_ids: numpy array of reference ids, aligned with ``outputs``.
        label: optional label filter forwarded to ``get_net_diff``.
        print_freq: the row index is printed every ``print_freq`` rows.

    Returns:
        Square array whose entry ``[i, j]`` is the difference between the
        i-th and j-th hash in ``np.unique(hashes)`` order.
    """
    o_data = [outputs, labels, hashes, ref_ids]

    all_hashes = np.unique(hashes)
    n_nets = len(all_hashes)
    errors = np.zeros((n_nets, n_nets))

    for row, row_hash in enumerate(all_hashes):
        if row % print_freq == 0:
            print(row)

        # Fill one complete row per outer iteration.
        errors[row, :] = [
            get_net_diff(row_hash, col_hash, o_data, label=label)
            for col_hash in all_hashes
        ]

    return errors