# Variables

In [None]:
# Install the Python dependencies listed in requirements.txt.
# NOTE(review): this cell runs before Google Drive is mounted and before any
# %cd, so requirements.txt is resolved against the kernel's starting working
# directory (on Colab presumably /content) — TODO confirm the file is placed
# there, or install from the UCAD project directory once Drive is mounted.
%pip install -r requirements.txt

In [None]:
# General — every path below is built on top of the Colab Drive mount point.

drivePath = "/content/drive"
myDrive = f"{drivePath}/MyDrive/colabStuff"  # /content/drive/MyDrive/colabStuff

# UCAD project directory and its segment_anything folder (dataset_sam.py lives there).
ucadPath = f"{myDrive}/UCAD"  # /content/drive/MyDrive/colabStuff/UCAD
ucadSegment = f"{ucadPath}/segment_anything"  # /content/drive/MyDrive/colabStuff/UCAD/segment_anything

In [None]:
# Dataset MVTec LOCO AD

dataset_link = "https://www.mydrive.ch/shares/48237/1b9106ccdfbb09a0c414bd49fe44a14a/download/430647091-1646842701/mvtec_loco_anomaly_detection.tar.xz"
# The archive name is the last path segment of the download URL, i.e. the file
# the download cell leaves in myDrive and the extraction cell reads back.
dataset_file_name = dataset_link.rsplit("/", 1)[-1]
# Full category list:
# ('breakfast_box', 'juice_bottle', 'pushpins', 'screw_bag', 'splicing_connectors')
# Only the first three categories are used for this run.
dataset_labels = ("breakfast_box", "juice_bottle", "pushpins")

In [None]:
# Dataset MVTec AD
#
# Alternative dataset configuration, kept commented out. Uncommenting these
# lines switches the notebook to MVTec AD: this cell runs after the MVTec LOCO
# AD cell, so its assignments overwrite dataset_link, dataset_file_name and
# dataset_labels. Keep only one of the two dataset_labels lines active.

# dataset_link = "https://www.mydrive.ch/shares/38536/3830184030e49fe74747669442f0f282/download/420938113-1629952094/mvtec_anomaly_detection.tar.xz"
# dataset_file_name = "mvtec_anomaly_detection.tar.xz"
# dataset_labels = ('bottle', 'cable', 'capsule', 'carpet', 'grid', 'hazelnut', 'leather', 'metal_nut', 'pill', 'screw', 'tile', 'toothbrush', 'transistor', 'wood', 'zipper')
# dataset_labels = ('bottle', 'cable', 'capsule')

In [None]:
# Dataset general — the downloaded archive is extracted into dataset_path.
dataset_name = 'mvtec2d'
dataset_path = f"{myDrive}/{dataset_name}"  # /content/drive/MyDrive/colabStuff/mvtec2d

In [None]:
# Sam (Segment Anything, ViT-B): checkpoint URL, preprocessing script and paths.
sam_link = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"

sam_path = f"{ucadSegment}/dataset_sam.py"
# The checkpoint is stored under the URL's last path segment inside myDrive,
# which is where `wget -P {myDrive}` saves it.
sam_checkpoint_path = f"{myDrive}/{sam_link.rsplit('/', 1)[-1]}"
# Separate copy of the dataset that dataset_sam.py is run on.
sam_b_my_path = f"{myDrive}/{dataset_name}-sam-b"

# Initial setup

In [None]:
# Mount Google Drive so that every path under drivePath becomes reachable.
from google.colab import drive
drive.mount(mountpoint=drivePath)

## Dataset

In [None]:
# DOWNLOAD THE DATA INTO MY DRIVE
!wget {dataset_link} -P {myDrive}

In [None]:
# !tar xzf ./{dataset_name}-sam-b.tar.gz
# !tar xf {myDrive}/{dataset_name}-sam-b.tar.xz
# !unzip {myDrive}/{dataset_name}-sam-b.zip -d {dataset_path}

!mkdir -p {dataset_path}
!tar xf {myDrive}/{dataset_file_name} -C {dataset_path}

## SAM checkpoint

In [None]:
!wget {sam_link} -P {myDrive}

# Prepare for training

In [None]:
!cp -r {dataset_path} {sam_b_my_path}

In [None]:
# Run UCAD's dataset_sam.py over the dataset copy in sam_b_my_path, using the
# ViT-B checkpoint stored at sam_checkpoint_path.
# The %cd into segment_anything is presumably needed so the script can import
# its sibling modules — TODO confirm. It also changes the working directory for
# every later cell until the next %cd.
%cd {ucadSegment}
!python3 {sam_path} --sam_type 'vit_b' --sam_checkpoint {sam_checkpoint_path} --data_path {sam_b_my_path}

# Training

In [None]:
import os

# Export the run configuration to child processes started with `!`.
# NOTE(review): the training cell interpolates the Python variables directly
# ({dataset_flags}, {dataset_path}), so no visible cell reads these environment
# variables — presumably kept for run_ucad.py or manual shell use; confirm
# before removing.
os.environ['DATA_PATH'] = f"{myDrive}/{dataset_name}"  # /content/drive/MyDrive/colabStuff/mvtec2d

# One "-d <category>" flag per selected category, space separated.
dataset_flags = " ".join(f"-d {label}" for label in dataset_labels)
os.environ.update(DATASET_FLAGS=dataset_flags, UCAD_PATH=ucadPath)

In the command below, the `--log_project` flag sets the name of the folder where the results are saved, so runs with different settings or datasets can be kept in separate folders.

In [None]:
%cd /content/drive/MyDrive/colabStuff/UCAD/
!CUDA_VISIBLE_DEVICES=0 python3 ./run_ucad.py --gpu 0 --epochs_num 5 --seed 0 --memory_size 196 --log_group IM224_UCAD_L5_P01_D1024_M196 --save_segmentation_images --log_project Final_LOCOAD_512_5_epo results ucad -b wideresnet50 -le layer2 -le layer3 --faiss_on_gpu --pretrain_embed_dimension 1024 --target_embed_dimension 1024 --anomaly_scorer_num_nn 1 --patchsize 1 sampler -p 0.1 approx_greedy_coreset dataset --batch_size 8 --num_workers 12 --resize 224 --imagesize 224 {dataset_flags} mvtec {dataset_path}

In [None]:
# Optional cleanup, kept commented out on purpose.
# WARNING: these paths are relative to the working directory. After the
# training cell the cwd is ucadPath, so `%cd ..` moves to myDrive and
# `rm -r UCAD` would delete the UCAD project directory from Google Drive.
# %cd ..
# !ls
# !rm -r UCAD
# !rm -r mvtec_ad_2
# !rm -r sample_data
# !rm -r results
# !rm -r results_nolimit

In [None]:
import matplotlib.pyplot as plt
from sklearn import metrics
import numpy as np


# Image-level ROC results hard-coded from an earlier run (presumably copied
# from UCAD's evaluation output — TODO record which run/dataset they belong to).
# 'fpr' and 'tpr' are equal-length arrays tracing the curve from (0, 0) to (1, 1);
# 'auroc' is the stored area-under-curve value for that run.
# NOTE(review): the fpr steps are multiples of 1/138 and the tpr steps multiples
# of 1/172, which suggests 138 normal and 172 anomalous test images — inferred,
# confirm.
image_metric = {
    'auroc': 0.6931664981462756,
    'fpr': np.array([
        0., 0., 0., 0.00724638, 0.00724638,
        0.02173913, 0.02173913, 0.03623188, 0.03623188, 0.04347826,
        0.04347826, 0.07246377, 0.07246377, 0.0942029, 0.0942029,
        0.10869565, 0.10869565, 0.12318841, 0.12318841, 0.13768116,
        0.13768116, 0.14492754, 0.14492754, 0.15942029, 0.15942029,
        0.16666667, 0.16666667, 0.17391304, 0.17391304, 0.18115942,
        0.18115942, 0.19565217, 0.19565217, 0.22463768, 0.22463768,
        0.23188406, 0.23188406, 0.24637681, 0.24637681, 0.25362319,
        0.25362319, 0.26086957, 0.26086957, 0.30434783, 0.30434783,
        0.3115942, 0.3115942, 0.31884058, 0.31884058, 0.34782609,
        0.34782609, 0.36231884, 0.36231884, 0.37681159, 0.37681159,
        0.38405797, 0.38405797, 0.39130435, 0.39130435, 0.41304348,
        0.41304348, 0.42028986, 0.42028986, 0.42753623, 0.42753623,
        0.43478261, 0.43478261, 0.47101449, 0.47101449, 0.47826087,
        0.47826087, 0.49275362, 0.49275362, 0.50724638, 0.50724638,
        0.52173913, 0.52173913, 0.53623188, 0.53623188, 0.54347826,
        0.54347826, 0.56521739, 0.56521739, 0.58695652, 0.58695652,
        0.60144928, 0.60144928, 0.61594203, 0.61594203, 0.63768116,
        0.63768116, 0.70289855, 0.70289855, 0.7173913, 0.7173913,
        0.72463768, 0.72463768, 0.73913043, 0.73913043, 0.76086957,
        0.76086957, 0.79710145, 0.79710145, 0.80434783, 0.80434783,
        0.8115942, 0.8115942, 0.84782609, 0.84782609, 0.86231884,
        0.86231884, 0.89130435, 0.89130435, 0.89855072, 0.89855072,
        0.93478261, 0.93478261, 0.94202899, 0.94202899, 0.95652174,
        0.95652174, 1.
    ]),
    'tpr': np.array([
        0., 0.00581395, 0.02906977, 0.02906977, 0.06976744,
        0.06976744, 0.12209302, 0.12209302, 0.20930233, 0.20930233,
        0.22093023, 0.22093023, 0.23837209, 0.23837209, 0.24418605,
        0.24418605, 0.25581395, 0.25581395, 0.28488372, 0.28488372,
        0.29069767, 0.29069767, 0.3372093, 0.3372093, 0.35465116,
        0.35465116, 0.38372093, 0.38372093, 0.38953488, 0.38953488,
        0.43604651, 0.43604651, 0.44186047, 0.44186047, 0.45930233,
        0.45930233, 0.46511628, 0.46511628, 0.47674419, 0.47674419,
        0.48837209, 0.48837209, 0.54651163, 0.54651163, 0.58139535,
        0.58139535, 0.59302326, 0.59302326, 0.59883721, 0.59883721,
        0.62209302, 0.62209302, 0.6627907, 0.6627907, 0.6744186,
        0.6744186, 0.68023256, 0.68023256, 0.68604651, 0.68604651,
        0.72093023, 0.72093023, 0.73255814, 0.73255814, 0.73837209,
        0.73837209, 0.74418605, 0.74418605, 0.76162791, 0.76162791,
        0.76744186, 0.76744186, 0.77906977, 0.77906977, 0.79069767,
        0.79069767, 0.80813953, 0.80813953, 0.81395349, 0.81395349,
        0.81976744, 0.81976744, 0.83139535, 0.83139535, 0.85465116,
        0.85465116, 0.87209302, 0.87209302, 0.87790698, 0.87790698,
        0.88953488, 0.88953488, 0.89534884, 0.89534884, 0.90697674,
        0.90697674, 0.9127907, 0.9127907, 0.91860465, 0.91860465,
        0.9244186, 0.9244186, 0.93023256, 0.93023256, 0.94186047,
        0.94186047, 0.94767442, 0.94767442, 0.95930233, 0.95930233,
        0.96511628, 0.96511628, 0.97093023, 0.97093023, 0.97674419,
        0.97674419, 0.98837209, 0.98837209, 0.99418605, 0.99418605,
        1., 1.
    ])
}


# Plot the stored image-level ROC curve. Passing roc_auc makes the legend show
# the AUROC value, and the title lets the figure stand on its own when skimmed.
# RocCurveDisplay.plot() labels the axes itself and returns the display object.
roc_display = metrics.RocCurveDisplay(
    fpr=image_metric['fpr'],
    tpr=image_metric['tpr'],
    roc_auc=image_metric['auroc'],
).plot()
roc_display.ax_.set_title('Image-level ROC curve')
plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Summary metrics of three UCAD runs, hard-coded from earlier results.
# NOTE(review): the settings behind "LOCOAD 512" are not recorded here —
# TODO document them next to the numbers.
ad_set = {
    'instance_auroc': 0.858,
    'full_pixel_auroc': 0.960,
    'anomaly_pixel_auroc': 0.954,
    'image_ap': 0.922,
    'pixel_ap': 0.470,
    'pixel_pro': 0.806,
}

locoad_set = {
    'instance_auroc': 0.752,
    'full_pixel_auroc': 0.773,
    'anomaly_pixel_auroc': 0.757,
    'image_ap': 0.832,
    'pixel_ap': 0.038,
    'pixel_pro': 0.480,
}

locoad_512_set = {
    'instance_auroc': 0.761,
    'full_pixel_auroc': 0.837,
    'anomaly_pixel_auroc': 0.820,
    'image_ap': 0.836,
    'pixel_ap': 0.047,
    'pixel_pro': 0.572,
}


# The metric order of ad_set defines the x axis. Each series is looked up by
# metric name rather than taken positionally from .values(), so the bars stay
# aligned with their tick labels even if a dict lists its keys in another order.
labels = list(ad_set.keys())
ad_set_values = [ad_set[metric] for metric in labels]
locoad_set_values = [locoad_set[metric] for metric in labels]
locoad_512_set_values = [locoad_512_set[metric] for metric in labels]

x = np.arange(len(labels))
width = 0.25  # three bars per metric, placed at -width, 0 and +width around each tick

fig, ax = plt.subplots(figsize=(14, 8))

rects1 = ax.bar(x - width, ad_set_values, width,
                label='AD Set', color='skyblue')
rects2 = ax.bar(x, locoad_set_values, width,
                label='LOCOAD Set', color='royalblue')
rects3 = ax.bar(x + width, locoad_512_set_values, width,
                label='LOCOAD 512 Set', color='navy')

ax.set_ylabel('Score', fontsize=12)
ax.set_title('Performance comparison of the datasets',
             fontsize=16, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels(labels, rotation=45, ha="right")
ax.legend()

# Print each score above its bar.
ax.bar_label(rects1, padding=3, fmt='%.3f')
ax.bar_label(rects2, padding=3, fmt='%.3f')
ax.bar_label(rects3, padding=3, fmt='%.3f')

fig.tight_layout()

# Saved relative to the current working directory (the last %cd target).
plt.savefig('performance_comparison.png', dpi=300)

plt.show()