In [1]:
%load_ext autoreload
%autoreload 2
import numpy as np
import torch
import torchvision.transforms.functional as TF
import matplotlib.pyplot as plt
import os

from data.mnist_novelty import MNIST_OneClass
from data.cifar_novelty import Cifar10_OneClass
from hugeica import *

from torchvision.transforms import transforms
from sklearn.metrics import roc_auc_score

np.random.seed(252525)
torch.manual_seed(252525)


<torch._C.Generator at 0x7f46dc191610>

In [32]:
def preprocess(X, X_in, X_out, norm_contrast=True, DC=True, channels=None):
    
    #DEQUANTIZE
    X_, _ = dequantize(X) 
    
    # CONTRAST
    if norm_contrast:
        X_, _ = to_norm_contrast(X_, DC=DC, channels=channels)
    mean, std = np.zeros(X_.mean(0).shape), X_.std()

    # SCALE
    X_, _ = scale(X_, mean, std)
    
    X_in_, _ = dequantize(X_in)
    if norm_contrast:
        X_in_, _ = to_norm_contrast(X_in_, DC=DC, channels=channels)
    X_in_, _ = scale(X_in_, mean, std)

    X_out_, _ = dequantize(X_out)
    if norm_contrast:
        X_out_, _ = to_norm_contrast(X_out_, DC=DC, channels=channels)
    
    X_out_, _ = scale(X_out_, mean, std)
    return X_, X_in_, X_out_, mean, std

def augment(X):
    X = torch.stack( [ transforms.ToTensor()( TF.rotate(transforms.ToPILImage()(  torch.from_numpy(x.reshape(3,32,32))  ), [90, 180, 270, 0][np.random.randint(4)]) )  for x in X ] )
    return X.reshape(len(X), -1)

def augment_rot(X, rot = [0, 90, 180, 270], random=True):
    X_aug = []
    for x in X:
        for r in rot:
            if random:
                r = rot[np.random.randint(4)]
            X_aug += [transforms.ToTensor()( TF.rotate(transforms.ToPILImage()(  torch.from_numpy(x.reshape(3,32,32))  ) , r) )]
            if random:
                break             
    X = torch.stack( X_aug)
    return X.reshape(len(X), -1)

def augment_trans(X, trans = range(-2, 2)):
    X = torch.stack( [torch.from_numpy(  np.roll(x.reshape(3, 32, 32), int(np.random.uniform(*trans)), 1))  for x in X] )
    X = torch.stack( [torch.from_numpy(  np.roll(x.reshape(3, 32, 32), int(np.random.uniform(*trans)), 2))  for x in X] )
    return X.reshape(len(X), -1)

In [33]:
log_full = []
for clazz in range(0,10):

    test_classes = [0,1,2,3,4,5,6,7,8,9]
    test_classes.remove(clazz)
    shape = (3, 32, 32)
    trans = [transforms.ToTensor()]
    
    np.random.seed(32)
    X_, X_valid_, X_test_ = Cifar10_OneClass(train_classes=[clazz], test_classes=test_classes, z_normalize=False, balance=True, transform=transforms.Compose(trans))[0]
    
    X_ = augment_rot(X_)
    #X_ = augment_trans(X_)
    
    X_valid_ = augment_rot(X_valid_)
    #X_valid_ = augment_trans(X_valid_)
    
    test_ = augment_rot(X_test_)
    #X_test_ = augment_trans(X_test_)

    hyp2 = SFA.hyperparameter_search(X_, X_valid_, X_test_, 
                      patch_size=range(14,31,4),
                      n_components=["q90"], 
                      stride=[2], 
                      shape=(3,32,32), 
                      bs=10000, 
                      lr=1e-4,
                      epochs=20,
                      norm=[2],
                      mode="ta",
                      max_components=256,
                      remove_components=[0],
                      use_conv=False,
                      logging=1, 
                      aucs=["mean"]) 
    log_full.append(hyp2)
    print(clazz)
    
    
concat = pd.concat(log_full)
concat["class"] = np.repeat(np.arange(len(log_full)), len(log_full[0]))

Files already downloaded and verified
Files already downloaded and verified
Cifar10_OneClass(z_normalize=False, train_classes=[0], test_classes=[1, 2, 3, 4, 5, 6, 7, 8, 9], data_train=(5000, 3072), data_test_inliers=(1000, 3072), data_test_outliers=(1000, 3072))
# Fit SpatialICA(q90).
# Fit HugeICA((500000, 588, 256), device='cuda', bs=10000)
 98%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋    | 49/50 [00:06<00:00,  8.15it/s]
Ep.  0 - -0.9848 - validation (loss/white/kurt/mi/logp): -1.0033 / 0.02 / 4.62 / 0.8171 / 0.3362 (eval took: 0.0s)
# Re-Fit SpatialICA(40).
# Fit HugeICA((500000, 588, 40), device='cuda', bs=10000)
 98%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████

In [35]:
concat.to_csv(f"./experiments/cifar10_hyperparameter_search_q90_ta_aug_rot_no_duplicates.csv")

In [93]:
log = pd.read_csv("./experiments/cifar10_hyperparameter_search_q90_ta_aug_rot_trans.csv")

df = []
for i in range(10):
    l = log[log["class"] == i]
    l = l[l["nor"] == 2]
    df.append((l.sort_values("negH_sum", ascending=False)["4mean"].head(1).item(),
              l.sort_values("negH_sum", ascending=False)["patch_size"].head(1).item(),
              l.sort_values("negH_sum", ascending=False)["n_components"].head(1).item(),
              l.sort_values("negH_sum", ascending=False)["negH_sum"].head(1).item()))

df.append(tuple(np.asarray(df).mean(0)))
pd.DataFrame(df, columns=["score", "patch_size", "k", "-H"] )

Unnamed: 0,score,patch_size,k,-H
0,0.613336,26.0,85.0,0.055209
1,0.651199,22.0,81.0,0.058757
2,0.474452,26.0,90.0,0.042917
3,0.625699,26.0,77.0,0.038364
4,0.584731,22.0,83.0,0.035822
5,0.588331,26.0,81.0,0.022977
6,0.588214,22.0,97.0,0.036123
7,0.679984,26.0,92.0,0.027616
8,0.774771,30.0,76.0,0.041336
9,0.704293,22.0,79.0,0.04503


In [91]:
log = pd.read_csv("./experiments/cifar10_hyperparameter_search_q90_ta_aug_rot.csv")

df = []
for i in range(10):
    l = log[log["class"] == i]
    l = l[l["nor"] == 2]
    df.append((l.sort_values("negH_sum", ascending=False)["4mean"].head(1).item(),
              l.sort_values("negH_sum", ascending=False)["patch_size"].head(1).item(),
              l.sort_values("negH_sum", ascending=False)["n_components"].head(1).item(),
              l.sort_values("negH_sum", ascending=False)["negH_sum"].head(1).item()))

df.append(tuple(np.asarray(df).mean(0)))
pd.DataFrame(df, columns=["score", "patch_size", "k", "-H"] )

Unnamed: 0,score,patch_size,k,-H
0,0.637312,26.0,84.0,0.058965
1,0.593584,18.0,63.0,0.059895
2,0.510612,26.0,88.0,0.039255
3,0.606968,26.0,75.0,0.038243
4,0.503812,22.0,81.0,0.035773
5,0.552836,26.0,79.0,0.023622
6,0.606856,22.0,95.0,0.037287
7,0.615,22.0,77.0,0.032051
8,0.733188,30.0,73.0,0.044066
9,0.662956,22.0,77.0,0.039864


In [94]:
log = pd.read_csv("./experiments/cifar10_hyperparameter_search_q90_ta_aug_rot_no_duplicates.csv")

df = []
for i in range(10):
    l = log[log["class"] == i]
    l = l[l["nor"] == 2]
    df.append((l.sort_values("negH_sum", ascending=False)["mean"].head(1).item(),
              l.sort_values("negH_sum", ascending=False)["patch_size"].head(1).item(),
              l.sort_values("negH_sum", ascending=False)["n_components"].head(1).item(),
              l.sort_values("negH_sum", ascending=False)["negH_sum"].head(1).item()))

df.append(tuple(np.asarray(df).mean(0)))
pd.DataFrame(df, columns=["score", "patch_size", "k", "-H"] )

Unnamed: 0,score,patch_size,k,-H
0,0.596504,22.0,74.0,0.055379
1,0.638222,18.0,62.0,0.064761
2,0.500649,22.0,77.0,0.046877
3,0.596611,26.0,75.0,0.039464
4,0.575776,26.0,91.0,0.038256
5,0.588081,26.0,79.0,0.025666
6,0.726227,22.0,95.0,0.037306
7,0.621574,22.0,77.0,0.035694
8,0.666348,22.0,63.0,0.04489
9,0.670834,22.0,77.0,0.037511


In [92]:
log = pd.read_csv("./experiments/cifar10_hyperparameter_search_q90_ta_aug_trans.csv")

df = []
for i in range(10):
    l = log[log["class"] == i]
    l = l[l["nor"] == 2]
    df.append((l.sort_values("negH_sum", ascending=False)["mean"].head(1).item(),
              l.sort_values("negH_sum", ascending=False)["patch_size"].head(1).item(),
              l.sort_values("negH_sum", ascending=False)["n_components"].head(1).item(),
              l.sort_values("negH_sum", ascending=False)["negH_sum"].head(1).item()))

df.append(tuple(np.asarray(df).mean(0)))
pd.DataFrame(df, columns=["score", "patch_size", "k", "-H"] )

Unnamed: 0,score,patch_size,k,-H
0,0.729301,26.0,83.0,0.05017
1,0.710757,26.0,92.0,0.040965
2,0.508655,26.0,89.0,0.042188
3,0.60903,26.0,77.0,0.041942
4,0.56664,22.0,82.0,0.029609
5,0.624324,26.0,80.0,0.025874
6,0.72119,22.0,95.0,0.040797
7,0.717727,22.0,76.0,0.028446
8,0.801793,26.0,72.0,0.039698
9,0.737923,22.0,83.0,0.041832
