In [2]:
import json
import pathlib

import numpy as np
import sklearn
import yaml
from sklearn.preprocessing import normalize
from numba import jit


from utils import get_weight_path_in_current_system


In [3]:
def _load_yaml(path):
    """Parse the YAML file at ``path`` and return its contents."""
    # FullLoader is kept from the original code (these are locally produced run
    # configs); switch to yaml.safe_load if they can come from untrusted sources.
    with open(path) as f:
        return yaml.load(f, Loader=yaml.FullLoader)


def load_features(
    datasets=("cifar10", "cifar100", "ag_news"),
    epochs=(500, 500, 100),
    results_root="../results",
) -> dict:
    """Load saved (unnormalised) training features for every analysis run.

    For each dataset, every ``config.yaml`` found (recursively) under
    ``<results_root>/<dataset>/analysis/save_unnormalised_feature/`` describes one
    analysis run. From it we read the seed and whether the projection head was
    used. Its ``experiment.target_weight_file`` points at a self-supervised
    checkpoint; the Hydra config stored next to that checkpoint
    (``.hydra/config.yaml``) supplies the number of mini-batches and, optionally,
    an augmentation type.

    Args:
        datasets: dataset directory names under ``results_root``.
        epochs: checkpoint epoch to load for each entry of ``datasets``
            (matched by position).
        results_root: root directory holding the per-dataset results.

    Returns:
        Nested dict
        ``features[d_name][extractor][seed][num_mini_batches] = (X_train_0, X_train_1, y_train)``
        where ``d_name`` is the dataset name, suffixed with ``-<augmentation_type>``
        when the self-supervised config defines one, and ``extractor`` is
        ``"Head"`` or ``"Without Head"``. The arrays are the saved
        ``feature.0`` / ``feature.1`` / ``label`` ``.npy`` files (presumably two
        augmented views plus labels -- TODO confirm). If two runs map to the same
        key, the later one overwrites the earlier.

    Raises:
        ValueError: if ``datasets`` and ``epochs`` have different lengths.
    """
    datasets = tuple(datasets)
    epochs = tuple(epochs)
    if len(datasets) != len(epochs):
        # zip() would otherwise silently drop the unmatched datasets.
        raise ValueError(
            "datasets and epochs must have the same length, got {} and {}".format(
                len(datasets), len(epochs)
            )
        )

    features = {}
    for dataset, epoch in zip(datasets, epochs):
        base_dir = pathlib.Path(results_root) / dataset / "analysis" / "save_unnormalised_feature"
        # Saved file names drop the "10"/"100" suffix (cifar10/cifar100 -> "cifar").
        d = dataset.replace("100", "").replace("10", "")

        for config_path in base_dir.glob("**/config.yaml"):
            analysis_config = _load_yaml(config_path)
            seed = analysis_config["experiment"]["seed"]
            extractor = "Head" if analysis_config["experiment"]["use_projection_head"] else "Without Head"

            # Config of the self-supervised run that produced the analysed checkpoint.
            self_sup_path = pathlib.Path(
                get_weight_path_in_current_system(analysis_config["experiment"]["target_weight_file"])
            ).parent
            self_sup_config = _load_yaml(self_sup_path / ".hydra" / "config.yaml")
            num_mini_batches = self_sup_config["experiment"]["batches"]

            path = config_path.parent.parent
            y_train = np.load(path / "epoch_{}-{}.pt.label.train.npy".format(epoch, d))
            X_train_0 = np.load(path / "epoch_{}-{}.pt.feature.0.train.npy".format(epoch, d))
            X_train_1 = np.load(path / "epoch_{}-{}.pt.feature.1.train.npy".format(epoch, d))

            # augmentation_type is read from the self-supervised run's config,
            # not from the analysis config.
            d_name = dataset
            if "augmentation_type" in self_sup_config["dataset"]:
                d_name = "{}-{}".format(dataset, self_sup_config["dataset"]["augmentation_type"])

            by_seed = features.setdefault(d_name, {}).setdefault(extractor, {}).setdefault(seed, {})
            by_seed[num_mini_batches] = (X_train_0, X_train_1, y_train)

    return features


In [4]:
# features[d_name][extractor][seed][num_mini_batches] -> (X_train_0, X_train_1, y_train)
features: dict = load_features()

In [5]:
@jit(nopython=True, parallel=True)
def compute_bound(c, y_train, X_train_0, X_train_1):
    """Per-sample collision bound for the samples labelled ``c`` (numba nopython).

    Selects the rows of both feature matrices whose label equals ``c`` and forms
    ``sim[i, j] = X_train_0_c[i] . X_train_1_c[j]`` (a cosine similarity when the
    rows are L2-normalised, as the caller does). For every sample ``j`` it returns
    ``sum_i |sim[i, j] - sim[j, j]| / (n - 1)``. The ``i == j`` term is zero, so
    this is the mean absolute gap over the other ``n - 1`` samples of the class.

    Assumes the feature matrices are 2-D with rows aligned to ``y_train`` (TODO
    confirm). With fewer than two samples in the class, ``n - 1 == 0`` and the
    division is by zero, so callers should skip such classes. Note that the
    ``n x n`` similarity matrix is materialised in full.
    """
    in_class = y_train == c
    n = np.sum(in_class)

    view0 = X_train_0[in_class]
    view1 = X_train_1[in_class]
    sim = view0.dot(view1.T)

    # Broadcasting subtracts positive[j] = sim[j, j] from every entry of column j.
    positive = np.diag(sim)
    gap = np.abs(sim - positive)
    return gap.sum(axis=0) / (n - 1)

In [6]:
# upper_bound_collision[dataset][extractor][num_mini_batches] -> list with one value
# per seed: the mean of compute_bound's per-sample bounds over all samples.
upper_bound_collision = {}
for dataset, f_d in features.items():
    upper_bound_collision[dataset] = {}

    for head_info, f_d_h in f_d.items():
        bounds_by_neg = upper_bound_collision[dataset][head_info] = {}

        for seed, f_d_h_s in f_d_h.items():
            for neg in sorted(f_d_h_s):
                X_train_0, X_train_1, y_train = f_d_h_s[neg]

                # L2-normalise rows so dot products inside compute_bound are cosines.
                X_train_0 = normalize(X_train_0, axis=1)
                X_train_1 = normalize(X_train_1, axis=1)

                # Iterate over the labels actually present: range(n_classes) would
                # silently skip classes whose ids are not 0..C-1. Classes with fewer
                # than two samples have no negatives (compute_bound divides by n - 1).
                labels, counts = np.unique(y_train, return_counts=True)
                upper_bounds = [
                    compute_bound(c, y_train, X_train_0, X_train_1)
                    for c, count in zip(labels, counts)
                    if count >= 2
                ]

                # Per-class arrays can differ in length, so concatenate them;
                # np.array() on a ragged list raises.
                upper_bound = np.concatenate(upper_bounds).mean() if upper_bounds else np.nan
                print(dataset, head_info, seed, neg, upper_bound)
                bounds_by_neg.setdefault(neg, []).append(float(upper_bound))


cifar10 Without Head 13 32 0.20344003
cifar10 Without Head 13 64 0.20893289
cifar10 Without Head 13 128 0.21950303
cifar10 Without Head 13 256 0.23831828
cifar10 Without Head 13 512 0.37764415
cifar10 Without Head 11 32 0.20120734
cifar10 Without Head 11 64 0.21036161
cifar10 Without Head 11 128 0.21415293
cifar10 Without Head 11 256 0.24843346
cifar10 Without Head 11 512 0.3816658
cifar10 Without Head 7 32 0.2020581
cifar10 Without Head 7 64 0.20859843
cifar10 Without Head 7 128 0.21780524
cifar10 Without Head 7 256 0.24569704
cifar10 Without Head 7 512 0.37573105
cifar10 Head 13 32 0.6004361
cifar10 Head 13 64 0.60644984
cifar10 Head 13 128 0.6113785
cifar10 Head 13 256 0.6158205
cifar10 Head 13 512 0.62248313
cifar10 Head 11 32 0.60062814
cifar10 Head 11 64 0.6057505
cifar10 Head 11 128 0.60984033
cifar10 Head 11 256 0.6150309
cifar10 Head 11 512 0.6226506
cifar10 Head 7 32 0.60107994
cifar10 Head 7 64 0.60679954
cifar10 Head 7 128 0.6118673
cifar10 Head 7 256 0.6146793
cifar10 Head

In [7]:
# Persist the results; json turns non-string dict keys (e.g. num_mini_batches) into strings.
pathlib.Path("upper_bound_collision.json").write_text(json.dumps(upper_bound_collision))
