In [1]:
%load_ext autoreload
%autoreload 2
import pandas as pd
import numpy as np
import torch
import matplotlib.pyplot as plt
import os

from data.digital_printing_defects import DigitalPrintingDefects
from hugeica import *

from torchvision.transforms import transforms
from sklearn.metrics import roc_auc_score

In [None]:
# Hyperparameter search over the SFA patch size, run once per image class of
# the DigitalPrintingDefects data set. The accumulated results are written to
# a CSV after every class, so an interrupted run still leaves the finished
# classes on disk.
np.random.seed(32)  # fixes the subsampling permutations drawn below

log_full = []
for clazz in ["IMG_4", "IMG_6", "IMG_8", "IMG_10", "IMG_11", "IMG_12", "IMG_16", "IMG_20", "IMG_22", "IMG_25"]:
    trans = [transforms.ToTensor()]
    # Element 0 of the data set unpacks into the three splits used below
    # (named train / validation / test by this notebook).
    X_, X_valid_, X_test_ = DigitalPrintingDefects(
        clazz,
        z_normalize=False,
        transform=transforms.Compose(trans),
        patch_size=96,
        stride=96,
    )[0]

    # Randomly keep at most 1000 samples per split. Each permutation must be
    # drawn over the length of the split it indexes: the validation split was
    # previously permuted over len(X_test_), which goes out of bounds when the
    # test split is longer and never samples the tail when it is shorter.
    X_ = X_[np.random.permutation(range(len(X_)))[:1000]]
    X_valid_ = X_valid_[np.random.permutation(range(len(X_valid_)))[:1000]]
    X_test_ = X_test_[np.random.permutation(range(len(X_test_)))[:1000]]

    hyp = SFA.hyperparameter_search(
        X_, X_valid_, X_test_,
        patch_size=range(16, 95, 8),
        n_components=["q90"],
        stride=[2],
        shape=(3, 96, 96),
        max_components=256,
        bs=10000,
        epochs=10,
        norm=[2],
        lr=1e-2,
        compute_bpd=False,
        mode="ta",
        use_conv=True,
        aucs=["mean"],
        logging=1,
    )

    log_full.append(hyp)

    # Checkpoint everything gathered so far. "class" is the position of the
    # image class in the loop (0, 1, 2, ...). Repeat counts are taken per
    # result frame so the labels stay aligned even if the searches return
    # different numbers of rows.
    concat = pd.concat(log_full)
    concat["class"] = np.repeat(np.arange(len(log_full)), [len(h) for h in log_full])
    # The file name carries the class processed last, so the file written in
    # the final iteration holds the results of all classes.
    concat.to_csv(f"./experiments/digital_printing_hyperparameter_search_ta_q90_{clazz}.csv")

Loading train data.
Loading inlier data.
Anomaly is not in the center of the patch (45)
122 patches skipped because error was too close to the border.
Loading outliers data.
Anomaly is not in the center of the patch (45)
122 patches skipped because error was too close to the border.
DigitalPrintingDefects(image=IMG_4, error=2, dual=True, mode=rgb, patch_size=96, stride=96, data_train=(7100, 27648), data_test_inliers=(3308, 27648), data_test_outliers=(3308, 27648))
# Fit SpatialICA(q90).
# Fit HugeICA((1000, 27648, 256), device='cuda', bs=5)
100%|█████████▉| 199/200 [01:59<00:00,  1.67it/s]
Ep.  0 - -1.0058 - validation (loss/white/kurt/mi/logp): -1.0081 / 0.00 / 0.63 / 0.0268 / 0.4052 (eval took: 1.0s)
# Re-Fit SpatialICA(18).
# Fit HugeICA((1000, 27648, 18), device='cuda', bs=5)
100%|█████████▉| 199/200 [01:22<00:00,  2.42it/s]
Ep.  0 - -0.9885 - validation (loss/white/kurt/mi/logp): -0.9890 / 0.02 / 0.34 / 0.0006 / 0.4051 (eval took: 0.1s)
Ep.  1 - -0.9908 - validation (loss/white/ku

  k      = model.change_variance_.max()/model.change_variance_.min()


 99%|█████████▉| 142/143 [03:43<00:01,  1.57s/it]
Ep.  0 - -1.0017 - validation (loss/white/kurt/mi/logp): -1.0076 / 0.00 / 0.52 / 0.0211 / 0.4072 (eval took: 0.9s)
# Re-Fit SpatialICA(34).
# Fit HugeICA((1001, 27648, 34), device='cuda', bs=7)
 99%|█████████▉| 142/143 [03:15<00:01,  1.38s/it]
Ep.  0 - -0.9890 - validation (loss/white/kurt/mi/logp): -0.9943 / 0.01 / 0.36 / 0.0014 / 0.4062 (eval took: 0.2s)
Ep.  1 - -0.9947 - validation (loss/white/kurt/mi/logp): -0.9946 / 0.01 / 0.37 / 0.0015 / 0.4060 (eval took: 0.2s)
Ep.  2 - -0.9951 - validation (loss/white/kurt/mi/logp): -0.9949 / 0.01 / 0.37 / 0.0015 / 0.4058 (eval took: 0.1s)
Ep.  3 - -0.9953 - validation (loss/white/kurt/mi/logp): -0.9950 / 0.01 / 0.38 / 0.0015 / 0.4057 (eval took: 0.1s)
Ep.  4 - -0.9953 - validation (loss/white/kurt/mi/logp): -0.9950 / 0.01 / 0.38 / 0.0015 / 0.4056 (eval took: 0.1s)
Ep.  5 - -0.9954 - validation (loss/white/kurt/mi/logp): -0.9950 / 0.01 / 0.38 / 0.0016 / 0.4055 (eval took: 0.2s)
Ep.  6 - -0.9954

In [1]:


# Summarise the hyperparameter search: for each of the 10 class indices, take
# the configuration with the largest "negH_sum" (among rows with nor == 2 and
# remove_components >= 0), print its patch size, component count and negH_sum,
# and collect its "mean" column value. The collected values are finally printed
# one per line separated by " &", followed by their average.
log = pd.read_csv("./experiments/digital_printing_hyperparameter_search_ta_q90_IMG_25.csv")

res = []
for class_idx in range(10):
    # One combined mask selects the same rows, in the same order, as filtering
    # on the three conditions one after another.
    keep = (log["class"] == class_idx) & (log["nor"] == 2) & (log["remove_components"] >= 0)
    # Sort once and keep only the top row; every reported value comes from it.
    best = log[keep].sort_values("negH_sum", ascending=False).head(1)

    res.append(best["mean"].item())
    print(best["patch_size"].item(),
          best["n_components"].item(),
          best["negH_sum"].item())

scores = np.asarray(res)
table_rows = " &\n".join(f"{value:.3f}" for value in scores.round(3))
print(table_rows, "& \n", scores.mean().round(3))

80.0 77.0 0.043001287
88.0 61.0 0.024408061
80.0 63.0 0.02460811
56.0 21.0 0.026292354
88.0 67.0 0.068674795
72.0 67.0 0.21009833
88.0 45.0 0.016352198999999998
88.0 25.0 0.012398221000000001
56.0 107.0 0.22272813
72.0 59.0 0.045595344
0.924 &
0.616 &
0.678 &
0.620 &
0.562 &
0.683 &
0.642 &
0.548 &
0.843 &
0.756 & 
 0.687
