In [8]:
%load_ext autoreload
%autoreload 2
import numpy as np
import torch
import torchvision.transforms.functional as TF
import matplotlib.pyplot as plt
import os

from data.mnist_novelty import MNIST_OneClass
from data.cifar_novelty import Cifar10_OneClass
from hugeica import *

import torchvision
from torchvision.transforms import transforms
from sklearn.metrics import roc_auc_score

np.random.seed(252525)
torch.manual_seed(252525)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Downloading bedroom val set


  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 4561k  100 4561k    0     0  2786k      0  0:00:01  0:00:01 --:--:-- 2784k


In [None]:
import subprocess
from urllib.request import Request, urlopen
from os.path import join
import zipfile

def list_categories():
    url = 'http://dl.yf.io/lsun/categories.txt'
    with urlopen(Request(url)) as response:
        return response.read().decode().strip().split('\n')


def download(out_dir, category, set_name):
    url = 'http://dl.yf.io/lsun/scenes/{category}_' \
          '{set_name}_lmdb.zip'.format(**locals())
    if set_name == 'test':
        out_name = 'test_lmdb.zip'
        url = 'http://dl.yf.io/lsun/scenes/{set_name}_lmdb.zip'
    else:
        out_name = '{category}_{set_name}_lmdb.zip'.format(**locals())
    out_path = join(out_dir, out_name)
    cmd = ['curl', url, '-o', out_path]
    print('Downloading', category, set_name, 'set')
    subprocess.call(cmd)
    
list_categories()
download("./data", "bedroom", "val")

zipf = zipfile.ZipFile("./data/bedroom_val_lmdb.zip")
zipf.extractall("./data")

In [9]:
svhn = torchvision.datasets.SVHN("./data", split='train', transform=transforms.ToTensor(), target_transform=None, download=True)
cifar10 = torchvision.datasets.CIFAR10("./data", train=True, transform=transforms.ToTensor(), target_transform=None, download=True)
cifar10_test = torchvision.datasets.CIFAR10("./data", train=False, transform=transforms.ToTensor(), target_transform=None, download=True)
lsun =  torchvision.datasets.LSUN(root='./data', classes=['bedroom_val'], transform=transforms.Compose([transforms.CenterCrop((200, 200)), transforms.Resize((32,32)), transforms.ToTensor()])) 

Using downloaded and verified file: ./data/train_32x32.mat
Files already downloaded and verified
Files already downloaded and verified


In [10]:
def preprocess(X, X_in, X_out, norm_contrast=True, DC=True, channels=None):
    
    #DEQUANTIZE
    X_, _ = dequantize(X) 
    
    # CONTRAST
    if norm_contrast:
        X_, _ = to_norm_contrast(X_, DC=DC, channels=channels)
    mean, std = np.zeros(X_.mean(0).shape), X_.std()

    # SCALE
    X_, _ = scale(X_, mean, std)
    
    X_in_, _ = dequantize(X_in)
    if norm_contrast:
        X_in_, _ = to_norm_contrast(X_in_, DC=DC, channels=channels)
    X_in_, _ = scale(X_in_, mean, std)

    X_out_, _ = dequantize(X_out)
    if norm_contrast:
        X_out_, _ = to_norm_contrast(X_out_, DC=DC, channels=channels)
    
    X_out_, _ = scale(X_out_, mean, std)
    return X_, X_in_, X_out_, mean, std

def augment(X):
    X = torch.stack( [ transforms.ToTensor()( TF.rotate(transforms.ToPILImage()(  torch.from_numpy(x.reshape(3,32,32))  ), [90, 180, 270, 0][np.random.randint(4)]) )  for x in X ] )
    return X.reshape(len(X), -1)

def augment_rot(X, rot = [0, 90, 180, 270]):
    X_aug = []
    for x in X:
        for r in rot:
             X_aug += [transforms.ToTensor()( TF.rotate(transforms.ToPILImage()(  torch.from_numpy(x.reshape(3,32,32))  ) , r) )]
    X = torch.stack( X_aug)
    return X.reshape(len(X), -1)

def augment_trans(X, trans = range(-8, 8)):
    X = torch.stack( [torch.from_numpy(  np.roll(x.reshape(3, 32, 32), int(np.random.uniform(-2, 2)), 1))  for x in X] )
    X = torch.stack( [torch.from_numpy(  np.roll(x.reshape(3, 32, 32), int(np.random.uniform(-2, 2)), 2))  for x in X] )
    return X.reshape(len(X), -1)

In [14]:
X_cifar = torch.stack([x for x,y in cifar10]).numpy()
X_cifar = X_cifar[np.random.permutation(int(0.1*len(X_cifar)))]
X_cifar_test = torch.stack([x for x,y in cifar10_test]).numpy()
X_svhn  = torch.stack([x for x,y in svhn]).numpy()
X_lsun  = torch.stack([x for x,y in lsun]).numpy()

In [16]:
log_full = []
for ood in [X_svhn, X_lsun]:

    #X_ = augment_rot(X_)
    #X_ = augment_trans(X_)
    
    #X_valid_ = augment_rot(X_valid_)
    #X_valid_ = augment_trans(X_valid_)
    
    #_test_ = augment_rot(X_test_)
    #X_test_ = augment_trans(X_test_)

    hyp2 = SFA.hyperparameter_search(X_cifar, X_cifar_test, ood, 
                      patch_size=range(14,31,4),
                      n_components=["q90"], 
                      stride=[2], 
                      shape=(3,32,32), 
                      bs=10000, 
                      lr=1e-4,
                      epochs=20,
                      norm=[2],
                      mode="ta",
                      max_components=256,
                      remove_components=[0],
                      use_conv=False,
                      logging=1, 
                      aucs=["mean"]) 
    log_full.append(hyp2)
    print(clazz)
    
concat = pd.concat(log_full)
concat["class"] = np.repeat(np.arange(len(log_full)), len(log_full[0]))
concat.to_csv(f"./experiments/cifar10_ood_hyperparameter_search_q90_ta.csv")

# Fit SpatialICA(q90).


The boolean parameter 'some' has been replaced with a string parameter 'mode'.
Q, R = torch.qr(A, some)
should be replaced with
Q, R = torch.linalg.qr(A, 'reduced' if some else 'complete') (Triggered internally at  /opt/conda/conda-bld/pytorch_1634272204863/work/aten/src/ATen/native/BatchLinearAlgebra.cpp:1937.)


# Fit HugeICA((500000, 588, 256), device='cuda', bs=10000)
  4%|████████▋                                                                                                                                                                                                                 | 2/50 [00:00<00:15,  3.20it/s]

torch.linalg.solve has its arguments reversed and does not return the LU factorization.
To get the LU factorization see torch.lu, which can be used with torch.lu_solve or torch.lu_unpack.
X = torch.solve(B, A).solution
should be replaced with
X = torch.linalg.solve(A, B) (Triggered internally at  /opt/conda/conda-bld/pytorch_1634272204863/work/aten/src/ATen/native/BatchLinearAlgebra.cpp:766.)


 98%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋    | 49/50 [00:06<00:00,  7.16it/s]
Ep.  0 - -1.0039 - validation (loss/white/kurt/mi/logp): -1.0036 / 0.02 / 6.66 / 1.2195 / 0.3690 (eval took: 0.0s)
# Re-Fit SpatialICA(16).
# Fit HugeICA((500000, 588, 16), device='cuda', bs=10000)
 98%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋    | 49/50 [00:05<00:00,  8.74it/s]
Ep.  0 - -0.9529 - validation (loss/white/kurt/mi/logp): -0.9550 / 0.04 / 4.18 / 0.0537 / 0.3661 (eval took: 0.0s)
# Compute ICA metrics.
# Fit SFA(16).
# Compute information measures
# Compute AUCs
# Compute Spread
# Compute Entropy
# Fit SpatialICA(q90).
# Fit HugeICA((329472, 972, 256), d

In [21]:
log = pd.read_csv("./experiments/cifar10_ood_hyperparameter_search_q90_ta.csv")

df = []
for i in range(2):
    l = log[log["class"] == i]
    l = l[l["nor"] == 2]
    df.append((l.sort_values("mean", ascending=False)["mean"].head(1).item(),
              l.sort_values("negH_sum", ascending=False)["patch_size"].head(1).item(),
              l.sort_values("negH_sum", ascending=False)["n_components"].head(1).item(),
              l.sort_values("negH_sum", ascending=False)["negH_sum"].head(1).item()))

df.append(tuple(np.asarray(df).mean(0)))
pd.DataFrame(df, columns=["score", "patch_size", "k", "-H"] )

Unnamed: 0,score,patch_size,k,-H
0,0.446164,26.0,36.0,0.056304
1,0.514605,26.0,36.0,0.055665
2,0.480384,26.0,36.0,0.055984
