# Imports

In [29]:
# Imports
########################

# Standard libraries
import os
import time
import random
from collections import OrderedDict

# Detect the hosting platform from the prefixes of its environment variables.
IS_COLAB = any(var.startswith("COLAB_") for var in os.environ)
IS_KAGGLE = any(var.startswith("KAGGLE_") for var in os.environ)

# Kaggle takes precedence: IS_COLAB is forced to False so the Colab branch
# below is skipped whenever Kaggle variables are present.
if IS_KAGGLE:
  IS_COLAB = False
  print("Running in Kaggle Notebook...")
  print("Installing libraries...")
  # Set here, i.e. before torch is imported further down in this cell.
  os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
  # lpips is installed without its dependencies; wandb/seaborn are upgraded.
  %pip install lpips --no-deps -q
  %pip install wandb seaborn -qU
  print("Libraries installed.")

# Same installation steps as the Kaggle branch, for Google Colab.
if IS_COLAB:
  print("Running in Google Colab...")
  print("Installing libraries...")
  os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
  %pip install lpips --no-deps -q
  %pip install wandb seaborn -qU
  print("Libraries installed.")

# NOTE: this `else` pairs with `if IS_COLAB`, so the message is printed on
# Kaggle as well as on a local machine (nothing is installed in that case).
else:
  print("Not running in Google Colab.")

# Data manipulation and analysis
import numpy as np
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.metrics import roc_auc_score, roc_curve, auc, precision_recall_curve
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score

# Image processing and visualization
import seaborn as sns
import skimage.io as io
from skimage import measure
from skimage.io import imread
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter

# PyTorch and deep learning
import torch
import torch.nn as nn
from lpips import LPIPS
from torch.optim import Adam
import torch.utils.data as data
import torch.nn.functional as F
import torchvision.utils as utils
from torchvision import transforms
from torch.autograd import Variable
from einops import rearrange, repeat

# Environment
import wandb
import kagglehub


Running in Kaggle Notebook...
Installing libraries...
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Libraries installed.
Not running in Google Colab.


# Setup

## Drive Setup

In [30]:
# Google Drive Setup
########################

# Resolve the working directories for the detected platform.
# Every branch defines BASE_PATH, DATASET_DIR, MVTEC_DATA_DIR and
# BTAD_DATA_DIR: the dataset classes use the last two as default arguments,
# so leaving them undefined (as the local branch used to) raises NameError
# as soon as those classes are defined.
if IS_KAGGLE:
    BASE_PATH = '/kaggle/working'
    DATASET_DIR = '/kaggle/input'
    MVTEC_DATA_DIR = os.path.join(DATASET_DIR, 'mvtec-ad')
    BTAD_DATA_DIR = os.path.join(DATASET_DIR, 'btad-beantech-anomaly-detection')

elif IS_COLAB:
    from google.colab import drive
    print("Mounting Google Drive...")
    drive.mount('/content/drive')
    print("Google Drive mounted.")
    BASE_PATH = '/content/drive/MyDrive/Colab Notebooks/Computer Vision/'
    DATASET_DIR = os.path.join(BASE_PATH, 'datasets')
    MVTEC_DATA_DIR = os.path.join(DATASET_DIR, 'mvtec')
    BTAD_DATA_DIR = os.path.join(DATASET_DIR, 'btad')

else:
    # Local run: anchor everything to the current working directory.
    # (os.path.dirname(os.curdir) evaluates to '', which hid the real location.)
    BASE_PATH = os.path.abspath(os.curdir)
    DATASET_DIR = os.path.join(BASE_PATH, 'datasets')
    # Same folder layout as the Colab branch.
    MVTEC_DATA_DIR = os.path.join(DATASET_DIR, 'mvtec')
    BTAD_DATA_DIR = os.path.join(DATASET_DIR, 'btad')

# Trained checkpoints always live next to the working directory.
MODEL_SAVE_DIR = os.path.join(BASE_PATH, 'trained_models')

# Download MVTEC and BTAD dataset
# MVTEC_DATA_DIR = kagglehub.dataset_download("ipythonx/mvtec-ad", force_download=False)
# BTAD_DATA_DIR = kagglehub.dataset_download("thtuan/btad-beantech-anomaly-detection", force_download=False)

# Create directories if they don't exist
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
# os.makedirs(MVTEC_DATA_DIR, exist_ok=True)
# os.makedirs(BTAD_DATA_DIR, exist_ok=True)

# Set num_workers for data loaders (unpacked into every DataLoader call)
NUM_WORKERS_PARAM = {
    'num_workers': 0  # Set to 0 for Colab to avoid crashes
}


## Device Setup

In [31]:
# Reproducibility: seed every random number generator used by the notebook
SEED = 71
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Pick the compute device, preferring the GPU when one is visible
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
if DEVICE.type != 'cuda':
    print("CUDA is not available. Using CPU.")
    torch.backends.cudnn.enabled = False
else:
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
    torch.cuda.manual_seed_all(SEED)
    # Deterministic cuDNN kernels, no autotuning
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


Using GPU: Tesla T4


# Globals

In [32]:
# Globals
########################

####################
# Script Arguments #
####################
TRAIN_MODEL = True      # True: train+test, False: test
USING_MVTEC = False     # True: MVTec AD, False: BTAD
NUM_EPOCHS = 400
USING_DYT = True        # True: DyT normalization, False: LayerNorm
LOG_LPIPS = False
USING_LW_SSIM = True
USING_RANDOM_MASK = True

COMPARE_MODELS = False #TODO?

# Model Configuration Parameters
BATCH_SIZE = 8
PATCH_SIZE = 64
NUM_GAUSSIANS = 150
LEARNING_RATE = 1e-4

LAMBDA_MSE = 5
DYT_INIT_A = 0.5
LAMBDA_LPIPS = 0        # we only log it
LAMBDA_SSIM = 0.5
NOISE_DECAY = True
WEIGHT_DECAY = 1e-4
NOISE_DECAY_FACTOR = 1  # handles noise decaying, better not touch


# --- Derived Configurations ---
NORM_STR = "DyT" if USING_DYT else "LayerNorm"
DATASET_STR = "MVTEC" if USING_MVTEC else "BTAD"
DATASET_CAT_STR = "bottle" if USING_MVTEC else '01' # Example: 'bottle', 'cable', 'hazelnut' for MVTec; '01', '02', '03' for BTAD

# Run name: <norm>_<dataset>_e<epochs>_<ssim kind>_<mask kind>_<platform>.
# The platform suffix distinguishes local runs instead of labelling every
# non-Kaggle run "colab", matching the "computed_by" entry logged to wandb.
RUN_NAME = "_".join([
    NORM_STR,
    DATASET_STR,
    f"e{NUM_EPOCHS}",
    "lwssim" if USING_LW_SSIM else "ssim",
    "randommask" if USING_RANDOM_MASK else "no_randommask",
    "kaggle" if IS_KAGGLE else ("colab" if IS_COLAB else "local"),
])

# Checkpoint files, keyed by normalization type, dataset and product category
model_path = os.path.join(MODEL_SAVE_DIR, f'VT_AE_{NORM_STR}_{DATASET_STR}_{DATASET_CAT_STR}.pt')
g_path = os.path.join(MODEL_SAVE_DIR, f'G_estimate_{NORM_STR}_{DATASET_STR}_{DATASET_CAT_STR}.pt')


# Minimum number of patches for Vision Transformer
MIN_NUM_PATCHES = 16

# Visualization settings
PLOT_COLORS = {"normal": "green", "anomaly": "red"}
PLOT_STYLE = "whitegrid"

# Time measurements
TIME_IN_NORM = 0
NORM_CALLS = 0

# Utils

## Weights & Biases

In [33]:
def setup_wandb():
    """
    Log in to Weights & Biases and start a run named RUN_NAME.

    The API key is read from the platform's secret store (Colab user data or
    Kaggle secrets). On any other machine it is taken from the WANDB_API_KEY
    environment variable when set; otherwise `None` is passed so that wandb
    can fall back to its own stored credentials instead of this function
    failing with a KeyError.

    The run configuration records the notebook's global settings.
    """

    # Get the wandb API key from secrets or environment variables
    if IS_COLAB:
        from google.colab import userdata
        wandbkey = userdata.get('WANDB_API_KEY')

    elif IS_KAGGLE:
        from kaggle_secrets import UserSecretsClient
        user_secrets = UserSecretsClient()
        wandbkey = user_secrets.get_secret("WANDB_API_KEY")

    else:
        # May be None: wandb.login then resolves credentials on its own.
        wandbkey = os.environ.get('WANDB_API_KEY')

    # Initialize wandb
    wandb.login(key=wandbkey)
    wandb.init(
            entity = "CV_albe_gab_kri",
            project = "CV_Final_runs",
            name = RUN_NAME,
            config={
                # Dataset Configuration
                "dataset": DATASET_STR,
                "product_category": DATASET_CAT_STR,
                "using_mvtec": USING_MVTEC,

                # Model Architecture
                "patch_size": PATCH_SIZE,
                "num_gaussians": NUM_GAUSSIANS,

                # Training Hyperparameters
                "num_epochs": NUM_EPOCHS,
                "learning_rate": LEARNING_RATE,
                "weight_decay": WEIGHT_DECAY,
                "batch_size": BATCH_SIZE,

                # Loss Function Weights
                "lambda_mse": LAMBDA_MSE,
                "lambda_ssim": LAMBDA_SSIM,
                "lambda_lpips": LAMBDA_LPIPS,

                # Normalization and Activation
                "normalization_type": NORM_STR,
                "dyt_init_a": DYT_INIT_A,

                # Loss Function Configuration
                "using_lw_ssim": USING_LW_SSIM,
                "ssim_type": "LWSSIM" if USING_LW_SSIM else "Standard SSIM",
                "log_lpips": LOG_LPIPS,

                # Data Augmentation and Noise
                "using_random_mask": USING_RANDOM_MASK,
                "noise_decay": NOISE_DECAY,
                "noise_decay_factor": NOISE_DECAY_FACTOR,

                # Model Paths for Reproducibility
                "base_path": BASE_PATH,
                "model_save_dir": MODEL_SAVE_DIR,
                "dataset_dir": DATASET_DIR,

                # Hardware Configuration
                "device": str(DEVICE),
                "seed": SEED,

                # Environment Detection
                "computed_by" : "COLAB" if IS_COLAB else ("KAGGLE" if IS_KAGGLE else "LOCAL"),

                # Training Configuration Flags
                "train_model": TRAIN_MODEL,
                "compare_models": COMPARE_MODELS,
            })


## Plots

In [34]:
def plot(image, grnd_truth, score):
    """Show the input image, its ground-truth mask and the anomaly score map side by side.

    `image[0]` is permuted from channel-first to channel-last for display and
    the two leading dimensions of `grnd_truth` are squeezed away. A colorbar
    is attached to the score panel.
    """
    plt.figure(figsize=(15, 5))

    # (subplot position, panel content, panel title), drawn left to right
    panels = (
        (131, lambda: image[0].permute(1, 2, 0), 'Original Image'),
        (132, lambda: grnd_truth.squeeze(0).squeeze(0), 'Ground Truth'),
        (133, lambda: score, 'Anomaly Score'),
    )
    for position, make_content, caption in panels:
        plt.subplot(position)
        plt.imshow(make_content())
        plt.title(caption)

    # The colorbar refers to the last image drawn, i.e. the score map
    plt.colorbar()
    plt.tight_layout()
    plt.pause(1)
    plt.show()

def plot_enhanced(image, grnd_truth, reconstructed, score, threshold=None, save_path=None):
    """Five-panel figure: input, reconstruction, ground truth, score heatmap and thresholded mask.

    The last panel is drawn only when `threshold` is given. The figure is
    written to `save_path` when one is supplied, then shown.
    """

    def _panel(position, content, caption, **imshow_kwargs):
        # Draw one axis-free panel and hand back the image artist
        plt.subplot(position)
        artist = plt.imshow(content, **imshow_kwargs)
        plt.title(caption)
        plt.axis('off')
        return artist

    with sns.plotting_context("talk"):
        sns.set_style(PLOT_STYLE)
        plt.figure(figsize=(24, 8))

        # Original image
        _panel(151, image[0].permute(1, 2, 0), 'Original Image')

        # Reconstruction, clipped to the displayable [0, 1] range
        _panel(152, reconstructed[0].permute(1, 2, 0).clip(0,1), 'Reconstructed Image')

        # Ground truth
        _panel(153, grnd_truth.squeeze(0).squeeze(0), 'Ground Truth', cmap='gray')

        # Anomaly score with its own colorbar
        heatmap = _panel(154, score, 'Anomaly Score', cmap='jet')
        plt.colorbar(heatmap, fraction=0.046, pad=0.04)

        # Thresholded result (if threshold provided)
        if threshold is not None:
            binary_mask = np.where(score > threshold, 1., 0.)
            _panel(155, binary_mask, f'Binary Result (t={threshold:.3f})', cmap='gray')

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, bbox_inches='tight')
        plt.show()

def plot_roc_curve(fpr, tpr, auc_score, color='darkorange', label='', title='Receiver Operating Characteristic', save_path=None):
    """Plot a ROC curve with its AUC in the legend.

    Parameters:
        fpr, tpr : false/true positive rates of the curve
        auc_score : area under the curve, shown with three decimals
        color : colour of the curve and of the filled area below it
        label : optional curve name, prefixed to the AUC in the legend
                (an empty label leaves the legend entry as '(AUC = ...)')
        title : figure title
        save_path : when given, the figure is saved there before being shown
    """
    # `label` used to be accepted but ignored; strip() keeps the legend
    # unchanged when no label is passed.
    legend_entry = f'{label} (AUC = {auc_score:.3f})'.strip()

    with sns.plotting_context("talk"):
        sns.set_style(PLOT_STYLE)
        plt.figure(figsize=(8, 8))
        plt.plot(fpr, tpr, lw=2, color=color, label=legend_entry)
        plt.fill_between(fpr, tpr, alpha=0.2, color=color)
        # Diagonal reference line from (0, 0) to (1, 1)
        plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title(title)
        plt.legend(loc="lower right")
        if save_path:
            plt.savefig(save_path, bbox_inches='tight')
        plt.show()

def plot_precision_recall_curve(precision, recall, pr_auc, color='green', label='', title='Precision-Recall Curve', save_path=None):
    """Plot a precision-recall curve with its AUC in the legend.

    Parameters:
        precision, recall : points of the curve (recall on the x axis)
        pr_auc : area under the curve, shown with three decimals
        color : colour of the curve and of the filled area below it
        label : optional curve name, prefixed to the AUC in the legend
                (an empty label leaves the legend entry as '(AUC = ...)')
        title : figure title
        save_path : when given, the figure is saved there before being shown
    """
    # `label` used to be accepted but ignored; strip() keeps the legend
    # unchanged when no label is passed.
    legend_entry = f'{label} (AUC = {pr_auc:.3f})'.strip()

    with sns.plotting_context("talk"):
        sns.set_style(PLOT_STYLE)
        plt.figure(figsize=(8, 8))
        plt.plot(recall, precision, lw=2, color=color, label=legend_entry)
        plt.fill_between(recall, precision, alpha=0.2, color=color)
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('Recall')
        plt.ylabel('Precision')
        plt.title(title)
        plt.legend(loc="lower left")
        if save_path:
            plt.savefig(save_path, bbox_inches='tight')
        plt.show()

def plot_confusion_matrix(y_true, y_pred, save_path=None, normalize=False):
    """Plot (optionally normalized) confusion matrix using seaborn.

    Parameters:
        y_true, y_pred : true and predicted labels; the tick labels assume a
                         binary encoding with 0 = Normal and 1 = Anomaly
        save_path : when given, the figure is saved there before being shown
        normalize : when True each row is divided by its total, so cells show
                    the fraction of each true class
    """
    with sns.plotting_context("talk"):
        sns.set_style(PLOT_STYLE)
        # Fix the label set so the matrix is always 2x2 and lines up with the
        # two tick labels, even when one class is absent from the inputs.
        cm = confusion_matrix(y_true, y_pred, labels=[0, 1])

        if normalize:
            row_totals = cm.sum(axis=1, keepdims=True)
            # A class with no true samples has a zero row total; leave its
            # row at 0 instead of producing NaN from a division by zero.
            cm = np.divide(
                cm.astype('float'),
                row_totals,
                out=np.zeros(cm.shape, dtype=float),
                where=row_totals != 0,
            )

        plt.figure(figsize=(8, 8))
        sns.heatmap(cm, annot=True, fmt='.2f' if normalize else 'd', cmap='Blues',
                    xticklabels=['Normal', 'Anomaly'],
                    yticklabels=['Normal', 'Anomaly'])
        plt.xlabel('Predicted')
        plt.ylabel('True')
        plt.title('Confusion Matrix' + (' (Normalized)' if normalize else ''))
        if save_path:
            plt.savefig(save_path, bbox_inches='tight')
        plt.show()

def plot_score_distributions(normal_scores, anomaly_scores, threshold=None, save_path=None):
    """Plot density histograms of normal and anomaly scores, with an optional threshold line.

    Parameters:
        normal_scores, anomaly_scores : score collections for the two classes
        threshold : when given, drawn as a dashed vertical line
        save_path : when given, the figure is saved there before being shown
    """
    with sns.plotting_context("talk"):
        sns.set_style(PLOT_STYLE)
        plt.figure(figsize=(10, 6))
        # stat="density" makes the bars match the 'Density' axis label (the
        # histplot default draws raw counts) and keeps the two classes
        # comparable when they contain different numbers of samples.
        sns.histplot(normal_scores, color=PLOT_COLORS["normal"], label="Normal", alpha=0.6, kde=True, stat="density")
        sns.histplot(anomaly_scores, color=PLOT_COLORS["anomaly"], label="Anomaly", alpha=0.6, kde=True, stat="density")

        if threshold is not None:
            plt.axvline(x=threshold, color='black', linestyle='--', label=f'Threshold: {threshold:.3f}')

        plt.xlabel('Anomaly Score')
        plt.ylabel('Density')
        plt.title('Distribution of Anomaly Scores')
        plt.legend()

        if save_path:
            plt.savefig(save_path, bbox_inches='tight')
        plt.show()

def visualize_regions(image, score_map, threshold, min_area=100, save_path=None):
    """Draw bounding boxes around connected score regions above `threshold`.

    The score map is binarized, split into connected components, and every
    component whose area is at least `min_area` is outlined on a copy of the
    input image. The untouched image is shown next to it.
    """
    # Connected components of the thresholded score map
    binary_mask = np.where(score_map > threshold, 1., 0.)
    components = measure.regionprops(measure.label(binary_mask))

    # Keep only components that are large enough
    filtered_regions = []
    for candidate in components:
        if candidate.area >= min_area:
            filtered_regions.append(candidate)

    with sns.plotting_context("talk"):
        sns.set_style(PLOT_STYLE)
        plt.figure(figsize=(12, 8))

        # Left panel: original image
        plt.subplot(121)
        plt.imshow(image[0].permute(1, 2, 0))
        plt.title('Original Image')
        plt.axis('off')

        # Right panel: same image with the detections drawn on top
        plt.subplot(122)
        plt.imshow(image[0].permute(1, 2, 0))

        for region in filtered_regions:
            # regionprops bounding boxes are (min_row, min_col, max_row, max_col)
            top, left, bottom, right = region.bbox
            outline = plt.Rectangle((left, top), right - left, bottom - top,
                                    fill=False, edgecolor='red', linewidth=2)
            plt.gca().add_patch(outline)

            # Area annotation just above the box
            plt.text(left, top - 5, f"Area: {region.area}",
                     color='white', fontsize=9, backgroundcolor='red')

        plt.title(f'Detected Anomalies (n={len(filtered_regions)})')
        plt.axis('off')

        plt.tight_layout()
        if save_path:
            plt.savefig(save_path, bbox_inches='tight')
        plt.show()


## Network Utils

In [35]:
def initialize_weights(*models):
    """Initialize network weights using kaiming normal.

    For every Conv2d and Linear module the weight is drawn with
    `kaiming_normal_` and the bias (if any) is zeroed. For every BatchNorm2d
    module the weight is set to 1 and the bias to 0; modules created with
    `affine=False` have neither parameter and are left untouched instead of
    raising an AttributeError.

    Parameters:
        *models : any number of nn.Module instances, initialized in place
    """
    for model in models:
        for module in model.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                # weight/bias are None when the layer has no affine parameters
                if module.weight is not None:
                    nn.init.ones_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)


def process_mask(mask):
    """Binarize a mask: every strictly positive value becomes 1.

    Values that are not greater than zero are passed through unchanged.
    The result is returned as a new torch tensor.
    """
    is_foreground = mask > 0.
    binarized = np.where(is_foreground, 1, mask)
    return torch.tensor(binarized)

# Datasets

## MVTEC

In [36]:
class MVTEC:
    """
    In-memory loader for a single MVTec AD product category.

    Constructing the object reads every train image, every normal test image
    and every anomalous test image that has a ground-truth mask, and exposes
    four DataLoaders: `train_loader`, `test_anom_loader`, `test_norm_loader`
    and `validation_loader` (up to 10 normal and 10 anomalous test samples
    drawn at random).
    """

    def __init__(self, batch_size, root=MVTEC_DATA_DIR, product='bottle'):
        """
        Parameters:
            batch_size : batch size used by every DataLoader
            root : directory holding one sub-directory per product
            product : name of the product sub-directory to load

        Raises:
            ValueError : if the train or test folder of `product` cannot be
                         found under `root`
        """
        self.root = root
        self.batch = batch_size
        self.product = product

        if self.product == 'all':
            # Loading every product at once is not supported; no loaders are created.
            print('--------Please select a valid product.......See Train_data function-----------')
            return

        # Importing all the image_path dictionaries for test and train data
        train_path_images = self.Train_data()
        test_norm_path_images = self.Test_normal_data()
        if train_path_images is None or test_norm_path_images is None:
            # Both helpers return None when the product (or its train/test
            # folder) does not exist; fail with a readable message instead of
            # an AttributeError on `.keys()`.
            raise ValueError(
                f"Could not find train/test data for product '{self.product}' under {self.root}"
            )
        test_anom_image_paths, test_anom_mask_paths = self.load_test_anom_images_and_masks()

        # Image Transformation
        # NOTE(review): an image read as a 2-D (grayscale) array would come out
        # of this pipeline as a 1-channel tensor, whereas BTADDataset expands
        # grayscale to RGB - confirm the model copes with such categories.
        T = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((550, 550)),
            transforms.CenterCrop(512),
            transforms.ToTensor(),
        ])

        train_normal_image = torch.stack([
            T(self.load_images(folder, name))
            for folder in train_path_images.keys()
            for name in train_path_images[folder]
        ])
        test_normal_image = torch.stack([
            T(self.load_images(folder, name))
            for folder in test_norm_path_images.keys()
            for name in test_norm_path_images[folder]
        ])

        # Normal samples get an all-zero, single-channel mask of matching size
        train_normal_mask = torch.zeros(train_normal_image.size(0), 1, train_normal_image.size(2), train_normal_image.size(3))
        test_normal_mask = torch.zeros(test_normal_image.size(0), 1, test_normal_image.size(2), test_normal_image.size(3))

        test_anom_image = torch.stack([
            T(self.load_images(os.path.dirname(p), os.path.basename(p)))
            for p in test_anom_image_paths
        ])

        # Masks go through the same transform as the images, then are binarized
        test_anom_mask = torch.stack([
            process_mask(T(self.load_images(os.path.dirname(p), os.path.basename(p))))
            for p in test_anom_mask_paths
        ])

        train_normal = tuple(zip(train_normal_image, train_normal_mask))
        test_anom = tuple(zip(test_anom_image, test_anom_mask))
        test_normal = tuple(zip(test_normal_image, test_normal_mask))

        print(f' --Size of {self.product} train loader: {train_normal_image.size()}--')
        if test_anom_image.size(0) == test_anom_mask.size(0):
            print(f' --Size of {self.product} test anomaly loader: {test_anom_image.size()}--')
        else:
            print(f'[!Info] Size Mismatch between Anomaly images {test_anom_image.size()} and Masks {test_anom_mask.size()} Loaded')
        print(f' --Size of {self.product} test normal loader: {test_normal_image.size()}--')

        # Create validation set (ran_generator caps the draw at the set size)
        val_anom = [test_anom[i] for i in self.ran_generator(len(test_anom), 10)]
        val_norm = [test_normal[j] for j in self.ran_generator(len(test_normal), 10)]
        val_set = [*val_norm, *val_anom]
        print(f' --Total Image in {self.product} Validation loader: {len(val_set)}--')

        # Final Data Loader - Updated with num_workers=0
        self.train_loader = torch.utils.data.DataLoader(train_normal, batch_size=batch_size, shuffle=True, **NUM_WORKERS_PARAM)
        self.test_anom_loader = torch.utils.data.DataLoader(test_anom, batch_size=batch_size, shuffle=False, **NUM_WORKERS_PARAM)
        self.test_norm_loader = torch.utils.data.DataLoader(test_normal, batch_size=batch_size, shuffle=False, **NUM_WORKERS_PARAM)
        self.validation_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=False, **NUM_WORKERS_PARAM)

    def ran_generator(self, length, shots=1):
        """
        Generate random, distinct indices in range(length).

        At most `length` indices are returned, so asking for more samples
        than are available no longer makes `random.sample` raise.
        """
        return random.sample(range(0, length), min(shots, length))

    def read_files(self, d, data_motive='train', use_good=True, normal=True):
        '''
        Return the image folders of one split of a product and the files in them

        Parameters:
            d : name of the product directory inside the root directory
            data_motive : Can be 'train' or 'test' or 'ground_truth'
            use_good : To use the data in the good folder
            normal : Signify if the normal images are included

        Returns:
            OrderedDict mapping each selected folder path to the sorted list
            of file names it contains; None when the split folder does not
            exist or when the product is "all"
        '''
        product_dir = os.path.join(self.root, d)
        for d_in in next(os.walk(product_dir))[1]:
            split_dir = os.path.join(product_dir, d_in)
            if d_in != data_motive or not os.path.isdir(split_dir):
                continue

            im_pt = OrderedDict()
            # Sorted listings keep the sample order reproducible across machines
            for i in sorted(os.listdir(split_dir)):
                tr_img_pth = os.path.join(split_dir, i)
                if not os.path.isdir(tr_img_pth):
                    continue

                include = False
                if data_motive in ('train', 'ground_truth'):
                    include = True
                elif data_motive == 'test':
                    if (use_good == False) and normal != True:
                        # Anomalous selection: every folder except 'good'
                        if i == 'good':
                            print(f'the good images for {d_in} images of {i} {d} is not included in the test anomalous data')
                        else:
                            include = True
                    elif (use_good == True) and (i == 'good') and (normal == True):
                        # Normal selection: only the 'good' folder
                        include = True

                if include:
                    images = sorted(os.listdir(tr_img_pth))
                    im_pt[tr_img_pth] = images
                    print(f'total {d_in} images of {i} {d} are: {len(images)}')

            if self.product == "all":
                return
            return im_pt

    def load_images(self, path, image_name):
        """Load image from path"""
        return imread(os.path.join(path, image_name))

    def Train_data(self, use_good=True):
        '''
        Return the path of the train directory and list of train images

        `use_good` is accepted for interface compatibility and not used.
        Returns None when the product is "all" or is not found in the root.
        '''
        for d in os.listdir(self.root):
            if self.product == "all":
                self.read_files(d, data_motive='train')
            elif self.product == d:
                return self.read_files(d, data_motive='train')

    def load_test_anom_images_and_masks(self):
        """
        Return two lists: paths of anomalous test images and their corresponding masks.
        Assumes:
          - image: '000.png'
          - mask:  '000_mask.png'

        Images without a mask file are skipped with a warning, so both lists
        always have the same length and order. Directory listings are sorted
        so that order is reproducible.
        """
        image_paths = []
        mask_paths = []

        if self.product not in os.listdir(self.root):
            return image_paths, mask_paths

        test_dir = os.path.join(self.root, self.product, 'test')
        gt_dir = os.path.join(self.root, self.product, 'ground_truth')

        for defect_type in sorted(os.listdir(test_dir)):
            if defect_type == 'good':
                continue

            defect_image_dir = os.path.join(test_dir, defect_type)
            defect_mask_dir = os.path.join(gt_dir, defect_type)
            if not os.path.isdir(defect_image_dir):
                # Stray file next to the defect folders
                continue

            for img_file in sorted(os.listdir(defect_image_dir)):
                img_path = os.path.join(defect_image_dir, img_file)
                mask_file = img_file.replace('.png', '_mask.png')
                mask_path = os.path.join(defect_mask_dir, mask_file)

                if not os.path.exists(mask_path):
                    print(f"[!] Warning: No mask found for image {img_file}")
                    continue

                image_paths.append(img_path)
                mask_paths.append(mask_path)

        return image_paths, mask_paths

    def Test_normal_data(self, use_good=True):
        '''
        Return path and images for normal test data

        Returns None when the product is "all" or is not found in the root.
        '''
        if self.product == 'all':
            print('Please choose a valid product. Normal test data can be seen product wise')
            return

        for d in os.listdir(self.root):
            if self.product == d:
                return self.read_files(d, data_motive='test', use_good=True, normal=True)



## BTAD

In [37]:
class BTAD:
    """
    BTAD counterpart of the MVTEC class, exposing the same loader attributes
    (`train_loader`, `test_norm_loader`, `test_anom_loader`,
    `validation_loader`) so both datasets can be used interchangeably.
    """
    def __init__(self, batch_size, root=BTAD_DATA_DIR, product='01'):
        self.root = root
        self.batch = batch_size
        self.product = product

        # Only the three BTAD products are accepted; otherwise no loader is built
        if self.product not in ['01', '02', '03']:
            print(f'--------Please select a valid product: 01, 02, or 03. Got: {self.product}-----------')
            return

        print(f"Loading BTAD dataset for product {self.product}...")

        def _build_loader(mode, use_normal, anomalous):
            # All loaders share root, product, batch size and resizing
            return load_btad_dataset(
                root_dir=self.root,
                product=self.product,
                batch_size=batch_size,
                mode=mode,
                resize=True,
                use_normal=use_normal,
                anomalous=anomalous
            )

        # Training data (normal samples only)
        self.train_loader = _build_loader('train', True, False)

        # Normal samples from the test set
        self.test_norm_loader = _build_loader('test', True, False)

        # Anomalous samples from the test set, with ground truth masks
        self.test_anom_loader = _build_loader('test', False, True)

        # Validation currently reuses the normal test loader (same object);
        # a dedicated split of the training data could replace it later.
        self.validation_loader = self.test_norm_loader

        print(f"BTAD dataset for product {self.product} loaded successfully!")
        print(f"Train loader: Ready")
        print(f"Test normal loader: Ready")
        print(f"Test anomalous loader: Ready")
        print(f"Validation loader: Using test normal data")


class BTADDataset(data.Dataset):
    """
    Dataset of (image, mask) pairs for BTAD, shaped like the MVTEC samples.

    `images` and `masks` may hold file paths (loaded lazily on access) or
    ready-made tensors (returned as they are). A missing mask is replaced by
    an all-zero mask of shape (1, H, W) taken from the image.
    """
    def __init__(self, images, masks=None, resize=True):
        """
        Parameters:
            images : sequence of image paths or tensors
            masks : optional sequence aligned with `images`; entries are mask
                    paths, tensors, or None for samples without a mask
            resize : when True, loaded files are resized to 512x512
        """
        self.images = images
        self.masks = masks
        self.resize = resize

        self.resize_transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((512, 512)),  # Match VT-ADL Paper
            transforms.ToTensor()
        ])

        self.default_transform = transforms.Compose([transforms.ToTensor()])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, i):
        """Return the (image, mask) pair at index `i`."""
        # Load and process image exactly like MVTec
        if isinstance(self.images[i], str):
            image_ = io.imread(self.images[i])

            # Convert grayscale to RGB if needed
            if len(image_.shape) < 3:
                image_ = np.stack((image_,) * 3, axis=-1)

            transform = self.resize_transform if self.resize else self.default_transform
            image = transform(image_)
        else:
            image = self.images[i]

        # Process masks to match MVTec's structure
        if self.masks is not None and self.masks[i] is not None:
            if isinstance(self.masks[i], str):
                mask_ = io.imread(self.masks[i], as_gray=True)

                # as_gray converts colour files to float64, which ToPILImage
                # rejects. Binarize such masks to uint8 first; the result
                # after process_mask is the same 0/1 mask.
                if mask_.dtype != np.uint8:
                    mask_ = np.where(mask_ > 0, 255, 0).astype(np.uint8)

                # Same transform pipeline as the image, so both end up with
                # identical spatial dimensions
                transform = self.resize_transform if self.resize else self.default_transform
                mask = transform(mask_)

                # Apply process_mask function exactly like MVTec does
                mask = process_mask(mask)
            else:
                mask = self.masks[i]
        else:
            # Zero mask with one channel and the image's height and width,
            # i.e. one sample of MVTec's torch.zeros(N, 1, H, W)
            mask = torch.zeros(1, image.size(1), image.size(2))

        return image, mask

def load_btad_dataset(root_dir, product, batch_size, mode='train', resize=True, use_normal=True, anomalous=False):
    """
    Build a DataLoader over one BTAD product split.

    The product folder is looked up directly under `root_dir` and, failing
    that, one directory level deeper. `mode='train'` reads the 'train' split;
    any other value reads 'test'. `use_normal` adds the 'ok' images and
    `anomalous` adds the 'ko' images; test-time 'ko' images are paired with
    the same-named file in 'ground_truth/ko' when one exists. Only the train
    loader shuffles.

    Raises:
        ValueError : if the product folder cannot be found
    """
    image_suffixes = ('.png', '.jpg', '.jpeg', '.bmp')

    def _image_files(folder):
        # Sorted image file names in `folder`; empty when the folder is missing
        if not os.path.isdir(folder):
            return []
        return [name for name in sorted(os.listdir(folder)) if name.lower().endswith(image_suffixes)]

    # Attempt direct path, then search one level deeper
    product_path = os.path.join(root_dir, product)
    if not os.path.isdir(product_path):
        for sub in sorted(os.listdir(root_dir)):
            sub_path = os.path.join(root_dir, sub)
            candidate = os.path.join(sub_path, product)
            if os.path.isdir(candidate):
                print(f"Found product '{product}' under nested directory '{sub}'.")
                root_dir, product_path = sub_path, candidate
                break

    # Final check
    if not os.path.isdir(product_path):
        available = [d for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))]
        raise ValueError(f"Product '{product}' not found in {root_dir}. Available products: {available}")

    # Split folder and (test only) ground-truth folder
    is_training = (mode == 'train')
    if is_training:
        base_img_dir = os.path.join(product_path, 'train')
        gt_available = False
    else:
        base_img_dir = os.path.join(product_path, 'test')
        gt_dir = os.path.join(product_path, 'ground_truth', 'ko')
        gt_available = os.path.isdir(gt_dir)

    image_paths = []
    mask_paths = []

    # Normal samples never carry a mask
    if use_normal:
        ok_dir = os.path.join(base_img_dir, 'ok')
        for f in _image_files(ok_dir):
            image_paths.append(os.path.join(ok_dir, f))
            mask_paths.append(None)

    # Anomalous samples: masks are only looked up for the test split
    if anomalous:
        ko_dir = os.path.join(base_img_dir, 'ko')
        for f in _image_files(ko_dir):
            image_paths.append(os.path.join(ko_dir, f))
            if is_training:
                mask_paths.append(None)
                continue

            # Ground truth mask: same base name, first matching extension
            mask_file_base = os.path.splitext(f)[0]
            mask = None
            if gt_available:
                for ext in image_suffixes:
                    candidate = os.path.join(gt_dir, mask_file_base + ext)
                    if os.path.isfile(candidate):
                        mask = candidate
                        break
            if mask is None:
                print(f"Warning: No mask for {f}")
            mask_paths.append(mask)

    print(f"Found {len(image_paths)} images for product {product} ({mode}, ok={use_normal}, ko={anomalous})")

    # Create dataset and dataloader
    return torch.utils.data.DataLoader(
        BTADDataset(image_paths, mask_paths, resize),
        batch_size=batch_size,
        shuffle=is_training,
        **NUM_WORKERS_PARAM
    )


# Losses

## SSIM

In [38]:
class SSIM(torch.nn.Module):
    """Single-scale structural similarity (SSIM) between two image batches.

    Local statistics are computed with a Gaussian window (sigma 1.5) of side
    ``window_size`` applied as a grouped convolution. The window is cached on
    the instance and rebuilt inside ``forward`` whenever the channel count or
    the tensor type of the input differs from the cached one.

    Args:
        window_size: Side length of the square Gaussian window.
        size_average: If True, ``forward`` returns the mean of the whole SSIM
            map; otherwise the map is reduced with three successive
            ``mean(1)`` calls, leaving one value per batch element.
    """

    def __init__(self, window_size=11, size_average=True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        # Start from a single-channel window; forward() replaces it on demand.
        self.channel = 1
        self.window = self.create_window(window_size, self.channel)

    def gaussian(self, window_size, sigma):
        """Return a 1-D Gaussian of length ``window_size`` that sums to 1."""
        taps = []
        for x in range(window_size):
            taps.append(np.exp(-(x - window_size//2)**2/float(2*sigma**2)))
        kernel_1d = torch.Tensor(taps)
        return kernel_1d / kernel_1d.sum()

    def create_window(self, window_size, channel):
        """Build a ``(channel, 1, window_size, window_size)`` Gaussian kernel."""
        column = self.gaussian(window_size, 1.5).unsqueeze(1)
        # Outer product of the 1-D kernel with itself gives the 2-D kernel.
        kernel_2d = column.mm(column.t()).float()
        kernel_2d = kernel_2d.unsqueeze(0).unsqueeze(0)
        return kernel_2d.expand(channel, 1, window_size, window_size).contiguous()

    def _ssim(self, img1, img2, window, window_size, channel, size_average=True):
        """Compute the SSIM map of ``img1`` vs ``img2`` and reduce it."""
        pad = window_size // 2

        def local_mean(t):
            # Per-channel Gaussian-weighted local average.
            return F.conv2d(t, window, padding=pad, groups=channel)

        mu1 = local_mean(img1)
        mu2 = local_mean(img2)

        mu1_sq = mu1.pow(2)
        mu2_sq = mu2.pow(2)
        mu1_mu2 = mu1 * mu2

        sigma1_sq = local_mean(img1 * img1) - mu1_sq
        sigma2_sq = local_mean(img2 * img2) - mu2_sq
        sigma12 = local_mean(img1 * img2) - mu1_mu2

        # Stabilising constants of the SSIM formula.
        C1 = 0.01 ** 2
        C2 = 0.03 ** 2

        numerator = (2 * mu1_mu2 + C1) * (2 * sigma12 + C2)
        denominator = (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
        ssim_map = numerator / denominator

        if not size_average:
            return ssim_map.mean(1).mean(1).mean(1)
        return ssim_map.mean()

    def forward(self, img1, img2):
        """Return the SSIM of two 4-D image batches with matching shapes."""
        _, channel, _, _ = img1.size()

        cache_is_valid = (channel == self.channel
                          and self.window.data.type() == img1.data.type())
        if cache_is_valid:
            window = self.window
        else:
            # Rebuild the window for the new channel count / tensor type and
            # remember it for subsequent calls.
            window = self.create_window(self.window_size, channel)
            if img1.is_cuda:
                window = window.cuda(img1.get_device())
            window = window.type_as(img1)

            self.window = window
            self.channel = channel

        return self._ssim(img1, img2, window, self.window_size, channel, self.size_average)


## Level Weighted SSIM (LWSSIM)

In [39]:
#NEW!
class LWSSIM(nn.Module):
    """
    Level-Weighted Structural Similarity (LWSSIM) score.

    For every window size in ``filter_sizes`` the luminance (l), contrast (c)
    and structure (s) terms of SSIM are computed with a Gaussian window
    (sigma 1.5) and combined as ``alpha * l + beta * c * s``. The per-size maps
    are blended with the normalised ``filter_weights`` and averaged into a
    single scalar. The result is a similarity (higher means more similar), so
    callers that want a loss must negate it themselves.

    Args:
        filter_sizes: Window sizes, one per level. Must not be empty.
        filter_weights: Optional per-level weights (same length as
            ``filter_sizes``); normalised to sum to 1. Equal weights if None.
        data_range: Inputs are divided by this value when it is not 1.0.
        alpha: Multiplier of the luminance term.
        beta: Multiplier of the contrast * structure term.
        gamma: Stored on the module but not used by the current formula.
        C1, C2, C3: Stabilising constants; ``C3`` defaults to ``C2 / 2``.
        eps: Lower bound applied to the local variances before the sqrt.

    Raises:
        ValueError: if ``filter_sizes`` is empty, or ``filter_weights`` is not
            a 1-D sequence of the same length, or its entries sum to zero.
    """
    def __init__(self, filter_sizes=(11, 9, 7, 5, 3), filter_weights=None,
                 data_range=1.0, alpha=1.0, beta=1.0, gamma=1.0,
                 C1=0.01**2, C2=0.03**2, C3=None, eps=1e-8):
        super(LWSSIM, self).__init__()

        # Filter sizes for different scales (copied so the caller's sequence
        # is never shared with the module).
        self.filter_sizes = list(filter_sizes)
        self.num_levels = len(self.filter_sizes)
        if self.num_levels == 0:
            raise ValueError("filter_sizes must contain at least one window size")

        # Default to equal weighting if not specified
        if filter_weights is None:
            self.filter_weights = torch.ones(self.num_levels) / self.num_levels
        else:
            self.filter_weights = torch.tensor(filter_weights)
            # A mismatch would otherwise either fail late in forward() or
            # silently use weights that no longer sum to 1.
            if self.filter_weights.dim() != 1 or self.filter_weights.numel() != self.num_levels:
                raise ValueError(
                    f"filter_weights must have one entry per filter size "
                    f"(expected {self.num_levels}, got shape {tuple(self.filter_weights.shape)})"
                )
            total = self.filter_weights.sum()
            if float(total) == 0.0:
                raise ValueError("filter_weights must not sum to zero")
            # Normalize weights to sum to 1
            self.filter_weights = self.filter_weights / total

        # The buffer copy follows the module across devices.
        self.register_buffer('weights', self.filter_weights)

        # Component weights
        self.alpha = alpha  # Luminance weight
        self.beta = beta    # Weight of the contrast * structure product
        self.gamma = gamma  # Kept for API compatibility; unused in the formula

        # Constants to avoid division by zero
        self.C1 = C1
        self.C2 = C2
        self.C3 = C3 if C3 is not None else C2/2

        # Lower bound for the local variances
        self.eps = eps

        self.data_range = data_range

        # One Gaussian window per size, registered as a buffer named
        # ``window_<size>``.
        for size in self.filter_sizes:
            window = self._create_gaussian_window(size)
            self.register_buffer(f'window_{size}', window)

    def _create_gaussian_window(self, window_size, sigma=1.5):
        """
        Create a 2D Gaussian window of shape [1, 1, window_size, window_size].
        """
        # Coordinates centred on the middle tap
        coords = torch.arange(window_size, dtype=torch.float)
        coords -= window_size // 2

        # 1D Gaussian, normalised
        gauss = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
        gauss /= gauss.sum()

        # 2D Gaussian kernel by outer product
        kernel = gauss.unsqueeze(0) * gauss.unsqueeze(1)

        # Normalize
        kernel /= kernel.sum()

        # Reshape to [1, 1, window_size, window_size] for conv2d operation
        return kernel.unsqueeze(0).unsqueeze(0)

    def _get_window(self, window_size):
        """Get the Gaussian window buffer registered for the given size."""
        return getattr(self, f'window_{window_size}')

    def _compute_ssim_components(self, x, y, window_size):
        """
        Compute the luminance, contrast and structure maps of SSIM for one
        window size. Returns the tuple ``(l, c, s)``.
        """
        window = self._get_window(window_size)
        padding = window_size // 2

        # Expand window to match input channels (grouped conv, one per channel)
        channel_window = window.expand(x.shape[1], 1, window_size, window_size)

        # Local means for each channel
        mu_x = F.conv2d(x, channel_window, padding=padding, groups=x.shape[1])
        mu_y = F.conv2d(y, channel_window, padding=padding, groups=y.shape[1])

        mu_x_sq = mu_x ** 2
        mu_y_sq = mu_y ** 2
        mu_xy = mu_x * mu_y

        # Local variances and covariance
        sigma_x_sq = F.conv2d(x ** 2, channel_window, padding=padding, groups=x.shape[1]) - mu_x_sq
        sigma_y_sq = F.conv2d(y ** 2, channel_window, padding=padding, groups=y.shape[1]) - mu_y_sq
        sigma_xy = F.conv2d(x * y, channel_window, padding=padding, groups=x.shape[1]) - mu_xy

        # Round-off can make the variances slightly negative; clamping to eps
        # keeps the sqrt (and its gradient) finite.
        sigma_x_sq = torch.clamp(sigma_x_sq, min=self.eps)
        sigma_y_sq = torch.clamp(sigma_y_sq, min=self.eps)

        # Standard deviations of the clamped variances
        sigma_x = torch.sqrt(sigma_x_sq)
        sigma_y = torch.sqrt(sigma_y_sq)

        # Luminance comparison
        l = (2 * mu_xy + self.C1) / (mu_x_sq + mu_y_sq + self.C1)

        # Contrast comparison
        c = (2 * sigma_x * sigma_y + self.C2) / (sigma_x_sq + sigma_y_sq + self.C2)

        # Structure comparison (C3 stabilises the denominator)
        s_numerator = sigma_xy + self.C3
        s_denominator = sigma_x * sigma_y + self.C3
        s = s_numerator / s_denominator

        return l, c, s

    def _compute_lwssim_single_scale(self, x, y, window_size):
        """
        Compute the Level-Weighted SSIM map for a single window size.
        """
        l, c, s = self._compute_ssim_components(x, y, window_size)

        # Key difference from standard SSIM:
        # luminance is added to, not multiplied with, the contrast*structure term
        lwssim = self.alpha * l + self.beta * c * s

        return lwssim

    def forward(self, x, y):
        """
        Compute the scalar LWSSIM score between two 4-D image batches.
        """
        # Bring inputs to the [0, 1] range
        if self.data_range != 1.0:
            x = x / self.data_range
            y = y / self.data_range

        # Clamp values for stability
        x = torch.clamp(x, min=0, max=1)
        y = torch.clamp(y, min=0, max=1)

        # Weighted sum of the per-size LWSSIM maps
        multi_scale_lwssim = 0.0

        for i, window_size in enumerate(self.filter_sizes):
            lwssim_val = self._compute_lwssim_single_scale(x, y, window_size)

            weight = self.weights[i]
            multi_scale_lwssim += weight * lwssim_val

        # Average across spatial dimensions
        multi_scale_lwssim = multi_scale_lwssim.mean([2, 3])

        # Average across batch and channels
        lwssim_score = multi_scale_lwssim.mean()

        return lwssim_score



## Per Region Overlap (PRO) Score

In [40]:
def calculate_pro_score(anomaly_maps, ground_truth_masks, num_thresholds=100, max_fpr=1):
    """
    Region-overlap score of anomaly maps against ground-truth masks.

    Thresholds are taken from the pixel-level ROC curve of all maps (sub-sampled
    to at most ``num_thresholds`` entries). For each threshold whose FPR does
    not exceed ``max_fpr``, every connected ground-truth region (8-connectivity)
    is compared with the binarised map of its image, and the overlaps of all
    regions are averaged.

    NOTE(review): the overlap computed here is intersection-over-union between
    the whole binarised prediction and a single region. The usual PRO
    definition divides the intersection by the region area only - confirm
    which one is intended before comparing with published numbers.

    Args:
        anomaly_maps: Sequence of anomaly score maps (NumPy arrays).
        ground_truth_masks: Sequence of ground-truth masks, paired with
            ``anomaly_maps``; masks whose maximum is 0 are ignored.
        num_thresholds: Maximum number of thresholds to evaluate.
        max_fpr: Evaluation stops at the first threshold whose FPR exceeds this.

    Returns:
        pro_score: Plain mean of ``pro_curve`` (0.0 when the curve is empty).
        pro_curve: NumPy array with the mean region overlap per threshold.
    """
    pairs = list(zip(anomaly_maps, ground_truth_masks))
    if not pairs:
        # Nothing to evaluate; np.concatenate would raise on an empty list.
        return 0.0, np.array([])

    # Flatten scores and labels for the pixel-level ROC curve
    all_scores = np.concatenate([score_map.flatten() for score_map, _ in pairs])
    all_labels = np.concatenate([gt_mask.flatten() for _, gt_mask in pairs])

    # Get thresholds from ROC curve
    fpr, _, thresholds = roc_curve(all_labels, all_scores)

    # Sub-sample thresholds (uniformly in index, keeping fpr aligned)
    if len(thresholds) > num_thresholds:
        sampled_indices = np.linspace(0, len(thresholds) - 1, num_thresholds, dtype=int)
        thresholds = thresholds[sampled_indices]
        fpr = fpr[sampled_indices]

    # Connected components do not depend on the threshold, so label every
    # anomalous mask once instead of once per threshold.
    labeled_images = []
    for score_map, gt_mask in pairs:
        # Skip if no ground truth anomalies
        if gt_mask.max() == 0:
            continue
        labeled_gt, num_regions = measure.label(gt_mask, return_num=True, connectivity=2)
        labeled_images.append((score_map, labeled_gt, num_regions))

    pro_curve = []

    for i, threshold in enumerate(thresholds):
        if fpr[i] > max_fpr:
            break

        overlaps = []

        for score_map, labeled_gt, num_regions in labeled_images:
            # Binary prediction
            prediction = (score_map >= threshold).astype(np.uint8)

            # Calculate overlap for each region
            for region_id in range(1, num_regions + 1):
                region_mask = (labeled_gt == region_id)

                # Overlap as IoU of the full prediction with this region
                intersection = np.logical_and(prediction, region_mask).sum()
                union = np.logical_or(prediction, region_mask).sum()

                if union > 0:
                    overlaps.append(intersection / union)

        # Average overlap at this threshold
        pro_curve.append(np.mean(overlaps) if overlaps else 0.0)

    # Summarise the curve by its mean
    pro_score = np.mean(pro_curve) if pro_curve else 0.0

    return pro_score, np.array(pro_curve)


# Network

## DyT

In [41]:
class DyT(nn.Module):
    """Dynamic Tanh layer computing ``gamma * tanh(alpha * x) + beta``.

    ``alpha`` is a learnable scalar initialised to ``init_a``; ``gamma`` and
    ``beta`` are learnable vectors of length ``dim`` initialised to 0.9 and 0.
    Being 1-D, they broadcast over the last axis of the input.
    """
    def __init__(self, dim, init_a=DYT_INIT_A):
        super().__init__()
        # Author's note: alpha = 0.5 follows the DyT paper, while gamma is
        # deliberately started slightly below the paper's value; beta follows
        # the paper.
        alpha_init = torch.tensor(init_a, dtype=torch.float32)
        gamma_init = torch.ones(dim, dtype=torch.float32) * 0.9
        beta_init = torch.zeros(dim, dtype=torch.float32)

        self.alpha = nn.Parameter(alpha_init)
        self.gamma = nn.Parameter(gamma_init)
        self.beta = nn.Parameter(beta_init)

    def forward(self, x):
        # Per the original note, x is (batch_size, num_patches, embed_dim).
        squashed = torch.tanh(self.alpha * x)
        return self.gamma * squashed + self.beta

## Residual Connection

In [42]:
class Residual(nn.Module):
    """Skip connection wrapper: returns ``fn(x, **kwargs) + x``."""
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        # Keyword arguments (e.g. an attention mask) are forwarded untouched.
        branch_out = self.fn(x, **kwargs)
        return branch_out + x

## PreNorm

In [43]:
class PreNorm(nn.Module):
    """Apply a normalisation layer to the input, then the wrapped module.

    The normalisation is ``DyT(dim)`` when the module-level flag ``USING_DYT``
    is truthy, otherwise ``nn.LayerNorm(dim)``.

    Every forward pass adds the wall-clock time spent in the normalisation
    call to the module-level ``TIME_IN_NORM`` and increments ``NORM_CALLS``.
    Both globals must exist before the first forward pass.
    """
    def __init__(self, dim, fn):
        super().__init__()
        if USING_DYT:
            self.norm = DyT(dim)
        else:
            self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        global TIME_IN_NORM
        global NORM_CALLS

        # Time only the normalisation itself; including self.fn would charge
        # the whole attention / MLP sub-layer to the norm statistic.
        # NOTE(review): no device synchronisation is done here, so on an
        # asynchronous backend this is presumably host-side time only.
        start_time = time.perf_counter()
        normed = self.norm(x)
        TIME_IN_NORM += time.perf_counter() - start_time
        NORM_CALLS += 1

        return self.fn(normed, **kwargs)

## FeedForward

In [44]:
class FeedForward(nn.Module):
    """Two-layer MLP: Linear(dim -> hidden_dim), GELU, Linear(hidden_dim -> dim)."""
    def __init__(self, dim, hidden_dim):
        super().__init__()
        expand = nn.Linear(dim, hidden_dim)
        project = nn.Linear(hidden_dim, dim)
        # Kept as a Sequential under ``net`` so parameter names stay
        # ``net.0.*`` and ``net.2.*``.
        self.net = nn.Sequential(expand, nn.GELU(), project)

    def forward(self, x):
        out = self.net(x)
        return out

## Attention

In [45]:
class Attention(nn.Module):
    """Multi-head self-attention over a token sequence.

    Args:
        dim: Embedding size of the tokens (must be divisible by ``heads``
            for the head split to succeed).
        heads: Number of attention heads.

    ``forward(x, mask=None)`` takes ``x`` of shape (batch, tokens, dim) and
    returns a tensor of the same shape. ``mask`` is an optional boolean tensor
    with one entry per token *excluding* the first one: it is flattened from
    dim 1 on and a leading ``True`` is prepended for the first token. Its
    batch dimension may be 1 (shared by all samples) or equal to the batch
    size. A query/key pair is attended only when both tokens are ``True``.
    """
    def __init__(self, dim, heads=8):
        super().__init__()
        self.heads = heads
        # Scores are scaled by the full embedding size, not the per-head size.
        self.scale = dim ** -0.5

        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
        self.to_out = nn.Linear(dim, dim)

    def forward(self, x, mask=None):
        h = self.heads
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv)

        # (batch, heads, query, key) similarity scores
        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale

        if mask is not None:
            # Most negative finite value, so fully masked rows stay NaN-free.
            mask_value = -torch.finfo(dots.dtype).max
            mask = F.pad(mask.flatten(1), (1, 0), value=True)
            assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
            # Build the pairwise mask as (batch, 1, query, key) so that it
            # broadcasts over the head axis. A (batch, query, key) mask would
            # line its batch axis up with the head axis of ``dots``.
            pair_mask = mask[:, None, None, :] & mask[:, None, :, None]
            dots.masked_fill_(~pair_mask, mask_value)
            del mask, pair_mask

        attn = dots.softmax(dim=-1)
        out = torch.einsum('bhij,bhjd->bhid', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return out


## Transformer

In [46]:
class Transformer(nn.Module):
    """Stack of ``depth`` pre-norm encoder layers.

    Each layer is a pair of residual blocks: self-attention followed by a
    feed-forward network. Only the attention block receives ``mask``.
    """
    def __init__(self, dim, depth, heads, mlp_dim):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.ModuleList([
                Residual(PreNorm(dim, Attention(dim, heads=heads))),
                Residual(PreNorm(dim, FeedForward(dim, mlp_dim))),
            ])
            for _ in range(depth)
        ])

    def forward(self, x, mask=None):
        for layer in self.layers:
            attention_block, feedforward_block = layer
            x = attention_block(x, mask=mask)
            x = feedforward_block(x)
        return x

## ViT Encoder

In [47]:
class ViTEncoder(nn.Module):
    """Vision-Transformer encoder that returns one embedding per image patch.

    The image is cut into non-overlapping ``patch_size`` x ``patch_size``
    patches, each patch is linearly embedded, a learnable class token is
    prepended, positional embeddings are added and the sequence goes through
    the transformer. The class token is dropped from the output, so the result
    has shape (batch, num_patches, dim).

    ``num_classes`` is accepted but not used by this encoder.
    """
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels=3):
        super().__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        patches_per_side = image_size // patch_size
        num_patches = patches_per_side ** 2
        patch_dim = channels * patch_size ** 2
        assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for attention to be effective (at least 16). Try decreasing your patch size'

        self.patch_size = patch_size

        # +1 position for the class token.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.patch_to_embedding = nn.Linear(patch_dim, dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))

        self.transformer = Transformer(dim, depth, heads, mlp_dim)

        self.to_cls_token = nn.Identity()

    def forward(self, img, mask=None):
        p = self.patch_size

        # (b, c, H, W) -> (b, num_patches, p*p*c)
        patches = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
        tokens = self.patch_to_embedding(patches)
        b, n, _ = tokens.shape

        # Prepend one class token per sample, then add positional embeddings.
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
        tokens = torch.cat((cls_tokens, tokens), dim=1)
        tokens = tokens + self.pos_embedding[:, :(n + 1)]

        encoded = self.transformer(tokens, mask)

        # Drop the class token and keep only the patch embeddings.
        return self.to_cls_token(encoded[:, 1:, :])

## DigitCaps

In [48]:
class DigitCaps(nn.Module):
    """Capsule layer with three iterations of dynamic routing.

    Args:
        out_num_caps: Number of output capsules.
        in_num_caps: Number of input capsules.
        in_dim_caps: Dimension of each input capsule.
        out_dim_caps: Dimension of each output capsule.
        decode_idx: Index of the output capsule returned as the first result;
            ``-1`` selects, per sample, the capsule with the largest norm.

    ``forward(x)`` expects ``x`` of shape (batch, in_num_caps, in_dim_caps)
    and returns ``(t, outputs)`` where ``outputs`` has shape
    (batch, out_num_caps, out_dim_caps) and ``t`` is the selected capsule with
    shape (batch, 1, out_dim_caps). All helper tensors are created on the
    device of the input.
    """
    def __init__(self, out_num_caps=1, in_num_caps=8*8*64, in_dim_caps=8, out_dim_caps=512, decode_idx=-1):
        super(DigitCaps, self).__init__()

        self.in_dim_caps = in_dim_caps
        self.in_num_caps = in_num_caps
        self.out_dim_caps = out_dim_caps
        self.out_num_caps = out_num_caps
        self.decode_idx = decode_idx
        self.W = nn.Parameter(0.01 * torch.randn(out_num_caps, in_num_caps, out_dim_caps, in_dim_caps))

    def forward(self, x):
        # x: (batch, in_num_caps, in_dim_caps)
        x_hat = torch.squeeze(torch.matmul(self.W, x[:, None, :, :, None]), dim=-1)
        # x_hat: (batch, out_num_caps, in_num_caps, out_dim_caps)
        # The detached copy is used for the routing updates so that gradients
        # only flow through the last iteration.
        x_hat_detached = x_hat.detach()
        device = x_hat.device

        # Routing logits: (batch, out_num_caps, in_num_caps)
        b = torch.zeros(x.size(0), self.out_num_caps, self.in_num_caps, device=device)

        # Routing algorithm
        num_iters = 3
        for i in range(num_iters):
            # Coupling coefficients, normalised over the output capsules
            c = F.softmax(b, dim=1)
            if i == num_iters - 1:
                # outputs: (batch, out_num_caps, 1, out_dim_caps)
                outputs = self.squash(torch.sum(c[:, :, :, None] * x_hat, dim=-2, keepdim=True))
            else:
                outputs = self.squash(torch.sum(c[:, :, :, None] * x_hat_detached, dim=-2, keepdim=True))
                # Agreement between predictions and current outputs
                b = b + torch.sum(outputs * x_hat_detached, dim=-1)

        # Remove the singleton in_num_caps axis
        outputs = torch.squeeze(outputs, dim=-2)

        if self.decode_idx == -1:
            # Choose the longest vector as the one to decode
            classes = torch.sqrt((outputs ** 2).sum(2))
            classes = F.softmax(classes, dim=1)
            _, max_length_indices = classes.max(dim=1)
        else:
            # Always choose the same capsule. The index tensor is created
            # directly on the right device (a bare ``.to()`` call whose result
            # is discarded does not move anything).
            max_length_indices = torch.full(
                (outputs.size(0),), self.decode_idx, dtype=torch.long, device=device
            )

        # One-hot rows selecting the chosen capsule of every sample
        masked = torch.eye(self.out_num_caps, device=device)
        masked = masked.index_select(dim=0, index=max_length_indices)
        t = (outputs * masked[:, :, None]).sum(dim=1).unsqueeze(1)

        return t, outputs

    def squash(self, input_tensor, eps=1e-8):
        """Squash vectors along the last axis to a norm in [0, 1).

        ``eps`` is added inside the square root so that an all-zero vector
        maps to zero instead of NaN (0 / 0).
        """
        squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)
        output_tensor = squared_norm * input_tensor / ((1. + squared_norm) * torch.sqrt(squared_norm + eps))
        return output_tensor

## Convolutional Decoder

In [49]:
class ConvolutionalDecoder(nn.Module):
    """Transposed-convolution decoder producing a 3-channel image in [-1, 1].

    Five ``ConvTranspose2d -> BatchNorm2d -> ReLU`` stages are followed by a
    final ``ConvTranspose2d`` to 3 channels and a ``Tanh``.
    """
    def __init__(self, in_channels):
        super(ConvolutionalDecoder, self).__init__()

        # (in_ch, out_ch, kernel, stride, padding) of every stage that is
        # followed by BatchNorm + ReLU, in execution order.
        hidden_stages = (
            (in_channels, 16, 3, 2, 1),
            (16, 32, 9, 3, 1),
            (32, 32, 7, 5, 1),
            (32, 16, 9, 2, 0),
            (16, 8, 6, 1, 0),
        )

        layers = []
        for c_in, c_out, kernel, stride, pad in hidden_stages:
            layers.append(nn.ConvTranspose2d(in_channels=c_in, out_channels=c_out,
                                             kernel_size=kernel, stride=stride, padding=pad))
            layers.append(nn.BatchNorm2d(c_out, affine=True))
            layers.append(nn.ReLU(True))

        # Output head: map to RGB and squash with tanh.
        layers.append(nn.ConvTranspose2d(8, 3, 11, stride=1))
        layers.append(nn.Tanh())

        self.decoder = nn.Sequential(*layers)

    def forward(self, x):
        return self.decoder(x)

## VT AutoEncoder

In [50]:
class VT_AutoEncoder(nn.Module):
    """ViT encoder + capsule layer + convolutional decoder auto-encoder.

    Args:
        image_size, patch_size, num_classes, dim, depth, heads, mlp_dim:
            Forwarded to ``ViTEncoder``.
        train: When True the encoder weights are initialised via
            ``initialize_weights`` and ``forward`` adds latent noise (and,
            if enabled, random patch masking). This flag is independent of
            ``nn.Module.train()/eval()``.
        use_mask: When True (and ``train`` is True) a fresh random patch mask
            is drawn on every forward pass; otherwise the all-ones mask is
            used.

    ``forward(x)`` returns ``(encoded, recons)``: the patch embeddings coming
    out of the encoder (after noise in training mode) and the reconstructed
    image from the decoder.
    """
    def __init__(self, image_size=512,
                patch_size=64,
                num_classes=1,
                dim=512,
                depth=6,
                heads=8,
                mlp_dim=1024,
                train=True, use_mask=USING_RANDOM_MASK):

        super(VT_AutoEncoder, self).__init__()
        self.vt = ViTEncoder(
            image_size=image_size,
            patch_size=patch_size,
            num_classes=num_classes,
            dim=dim,
            depth=depth,
            heads=heads,
            mlp_dim=mlp_dim)

        self.decoder = ConvolutionalDecoder(8)
        self.Digcap = DigitCaps(in_num_caps=((image_size//patch_size)**2)*8*8, in_dim_caps=8)
        # All-True patch mask of shape (1, patches_per_side, patches_per_side)
        self.fixed_mask = torch.ones(1, image_size//patch_size, image_size//patch_size).bool().to(DEVICE)
        self.Train = train
        # Previously accepted but ignored; now actually controls masking.
        self.use_mask = use_mask

        if self.Train:
            print("\nInitializing network weights...")
            initialize_weights(self.vt, self.decoder)

    def forward(self, x):
        b = x.size(0)
        if self.Train:
            if self.use_mask:
                # Draw standard-normal values per patch and keep (True) only
                # the patches whose value is greater than 0.4.
                gaussian_mask = torch.randn_like(self.fixed_mask.float())
                mask = gaussian_mask > 0.4
            else:
                mask = self.fixed_mask
            encoded = self.vt(x, mask)
            encoded = self.add_noise(encoded, decaying=NOISE_DECAY)

        else:
            encoded = self.vt(x, self.fixed_mask)

        # Each patch embedding is split into 8*8 capsules before routing.
        encoded1, vectors = self.Digcap(encoded.view(b, encoded.size(1)*8*8, -1))
        recons = self.decoder(encoded1.view(b, -1, 8, 8))

        return encoded, recons

    def add_noise(self, latent, sd=0.2, decaying=False):
        """Add zero-mean Gaussian noise with standard deviation ``sd``.

        With ``decaying=True`` the noise is scaled by
        ``1000 / (1000 + NOISE_DECAY_FACTOR)`` and the module-level counter
        ``NOISE_DECAY_FACTOR`` is incremented by one on every call.
        """
        global NOISE_DECAY_FACTOR
        assert sd >= 0.0
        n = torch.distributions.Normal(torch.tensor([0.]), torch.tensor([sd]))
        # sample() appends a trailing axis of size 1 (the distribution's batch
        # shape), which is squeezed away before moving the noise to DEVICE.
        noise = n.sample(latent.size()).squeeze(-1).to(DEVICE)

        if decaying:
            noise = noise * (1000/(1000 + NOISE_DECAY_FACTOR))
            NOISE_DECAY_FACTOR += 1

        latent = latent + noise
        return latent

## Mixture Density Network

In [51]:

class MDN(nn.Module):
    """Mixture density head over patch embeddings.

    Three bias-free linear layers map every input vector to mixture weights
    (``pi``), component means (``mu``) and a softplus-activated spread term
    (``sigma_sq``) for ``coefs`` components. ``test`` and ``sd`` are stored on
    the module but not read by the methods below.
    """
    def __init__(self, input_dim=512, out_dim=512, layer_size=512, coefs=NUM_GAUSSIANS, test=False, sd=0.5):
        super(MDN, self).__init__()
        self.in_features = input_dim
        self.out_dim = out_dim
        self.coefs = coefs
        self.test = test
        self.sd = sd

        self.pi = nn.Linear(layer_size, coefs, bias=False)  # mixture weights
        self.mu = nn.Linear(layer_size, out_dim * coefs, bias=False)  # means
        self.sigma_sq = nn.Linear(layer_size, out_dim * coefs, bias=False)  # spread per dimension

    def forward(self, x):
        """Return ``(pi, mu, sigma_sq)`` for a (batch, tokens, features) input."""
        # Inputs are clamped from below at machine epsilon before the heads.
        tiny = np.finfo(float).eps
        x = torch.clamp(x, tiny)

        batch, tokens = x.size(0), x.size(1)

        pi = F.softmax(self.pi(x), dim=-1)
        # The last axis of mu / sigma_sq indexes the mixture component.
        sigma_sq = F.softplus(self.sigma_sq(x)).view(batch, tokens, self.in_features, -1)
        mu = self.mu(x).view(batch, tokens, self.in_features, -1)
        return pi, mu, sigma_sq


    def negative_log_likelihood(self, x, means, logvars, weights, test=False):
        """Weighted Gaussian negative log-likelihood of ``x``.

        ``logvars`` is clamped to [-10, 10] and treated as a log-variance.
        With ``test=True`` the per-token values (batch, tokens) are returned;
        otherwise they are summed over tokens and averaged over the batch.
        """
        eps = 1e-8

        target = x.unsqueeze(-1).expand_as(logvars)
        logvars = torch.clamp(logvars, min=-10, max=10)
        squared_error = (target - means) ** 2

        # Sum over the feature axis, leaving (batch, tokens, components).
        per_component = (logvars + squared_error / (torch.exp(logvars) + eps)).sum(2)
        log_p = -0.5 * (np.log(2 * np.pi) + per_component)

        # Weight each component and sum over the component axis.
        weighted_log_p = torch.sum(- weights * log_p, 2)

        if test:
            return weighted_log_p
        return torch.mean(torch.sum(weighted_log_p, 1))

# Train

In [52]:
# Train
########################

def train_model(save_best=False):
    """
    Train the anomaly detection model (VT auto-encoder + MDN head).

    Args:
        save_best: If True, a checkpoint is written only when the average
            epoch loss improves on the best value seen so far. If False, a
            checkpoint is written every 25 epochs and after the last epoch.

    Side effects: resets the global norm-timing counters, logs to wandb,
    prints progress and writes the two state dicts to ``model_path`` and
    ``g_path``.
    """
    # Reset the global time / call counters of the normalisation layers
    global TIME_IN_NORM
    global NORM_CALLS
    TIME_IN_NORM = 0
    NORM_CALLS = 0

    print(f"\n{'='*20} Training on {DATASET_STR} dataset {'='*20}")
    print(f"Product: {DATASET_CAT_STR}, Epochs: {NUM_EPOCHS}, Learning Rate: {LEARNING_RATE}")

    # NOTE(review): wandb is used below but not initialised in this function;
    # presumably a run is started elsewhere - confirm.

    # Initialize SSIM loss
    if USING_LW_SSIM:
        ssim = LWSSIM().to(DEVICE)
    else:
        ssim = SSIM().to(DEVICE)

    lpips_model = LPIPS(net='vgg').to(DEVICE) if LOG_LPIPS else None

    # Load dataset
    if DATASET_STR == 'MVTEC':
        dataset = MVTEC(BATCH_SIZE, root=MVTEC_DATA_DIR, product=DATASET_CAT_STR)
    elif DATASET_STR == 'BTAD':
        dataset = BTAD(BATCH_SIZE, root=BTAD_DATA_DIR, product=DATASET_CAT_STR)
    else:
        raise ValueError(f"Dataset {DATASET_STR} not supported")

    # Initialize models
    model = VT_AutoEncoder(patch_size=PATCH_SIZE, train=True, use_mask=USING_RANDOM_MASK).to(DEVICE)
    G_estimate = MDN().to(DEVICE)

    # Initialize optimizer over both networks
    optimizer = Adam(list(model.parameters()) + list(G_estimate.parameters()),
                    lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

    # Set models to train mode
    model.train()
    G_estimate.train()

    def _save_checkpoint(message):
        # Write both state dicts, report it, and mirror the files to wandb
        # when running on Kaggle/Colab.
        # NOTE(review): model_path and g_path are not defined in this
        # function; presumably module-level paths - confirm.
        os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
        torch.save(model.state_dict(), model_path)
        torch.save(G_estimate.state_dict(), g_path)
        print(message)

        if IS_KAGGLE or IS_COLAB:
            wandb.save(model_path)
            wandb.save(g_path)

    # Training loop
    minloss = float('inf')

    print('\nNetwork training started...')
    start_time = time.time()
    for epoch in range(NUM_EPOCHS):
        epoch_losses = []
        epoch_mse_losses = []
        epoch_ssim_losses = []
        epoch_mdn_losses = []
        epoch_lpips_losses = []

        # Both MVTec and BTAD objects have a .train_loader attribute
        train_loader = dataset.train_loader

        for images, _ in train_loader:
            # Replicate single-channel images to three channels
            if images.size(1) == 1:
                images = torch.stack([images, images, images]).squeeze(2).permute(1, 0, 2, 3)

            # Move the batch to the device once and reuse it for every loss
            images = images.to(DEVICE)

            # Zero gradients
            optimizer.zero_grad()

            # Forward pass through models
            vector, reconstructions = model(images)
            pi, mu, sigma = G_estimate(vector)

            # Calculate losses
            mse_loss = F.mse_loss(reconstructions, images, reduction='mean')  # Reconstruction loss
            ssim_loss = -ssim(images, reconstructions)  # Structural similarity loss (negated score)
            mdn_loss = G_estimate.negative_log_likelihood(vector, mu, sigma, pi)  # Mixture density network loss

            if LOG_LPIPS:
                lpips_loss = lpips_model(reconstructions, images).mean()  # LPIPS loss
            else:
                lpips_loss = torch.tensor(0.0).to(DEVICE)

            # Total loss
            total_loss = LAMBDA_MSE * mse_loss + LAMBDA_SSIM * ssim_loss + LAMBDA_LPIPS * lpips_loss + mdn_loss

            # Store loss
            epoch_losses.append(total_loss.item())
            epoch_mse_losses.append(mse_loss.item())
            epoch_ssim_losses.append(ssim_loss.item())
            epoch_mdn_losses.append(mdn_loss.item())
            epoch_lpips_losses.append(lpips_loss.item())

            # Backpropagate and update weights
            total_loss.backward()
            optimizer.step()

        # Calculate epoch average loss
        avg_epoch_loss = np.mean(epoch_losses)
        avg_mse_loss = np.mean(epoch_mse_losses)
        avg_ssim_loss = np.mean(epoch_ssim_losses)
        avg_mdn_loss = np.mean(epoch_mdn_losses)
        avg_lpips_loss = np.mean(epoch_lpips_losses)

        # Log to wandb
        wandb.log({
            "epoch": epoch,
            "total_loss": avg_epoch_loss,
            "mse_loss": avg_mse_loss,
            "ssim_loss": avg_ssim_loss,
            "mdn_loss": avg_mdn_loss,
            "lpips_loss": avg_lpips_loss
        })

        # Log reconstructed images every 10 epochs
        if epoch % 10 == 0:
            # Get a sample batch for visualization
            sample_images, _ = next(iter(train_loader))
            if sample_images.size(1) == 1:
                sample_images = torch.stack([sample_images, sample_images, sample_images]).squeeze(2).permute(1, 0, 2, 3)

            with torch.no_grad():
                _, sample_reconstructions = model(sample_images.to(DEVICE))

            # Top row: inputs, bottom row: reconstructions
            comparison = torch.cat([sample_images[:4], sample_reconstructions[:4].cpu()])
            grid = utils.make_grid(comparison, nrow=4)

            wandb.log({"reconstructions": wandb.Image(grid)})

        print(f'Epoch {epoch+1}/{NUM_EPOCHS}, Loss: {avg_epoch_loss:.3f}')

        if save_best:
            # Best-only mode: never fall through to the periodic save, which
            # would overwrite the best checkpoint with a worse one.
            if avg_epoch_loss < minloss:
                minloss = avg_epoch_loss
                _save_checkpoint(f"Saved best model at epoch {epoch+1} with loss {minloss:.3f}")
        elif (epoch % 25 == 0 and epoch != 0) or (epoch == NUM_EPOCHS - 1):
            _save_checkpoint(f"Saved model at epoch {epoch+1} with loss {avg_epoch_loss:.3f}")

    end_time = time.time()
    elapsed_time = end_time - start_time
    print(f"\nTraining completed in {elapsed_time:.2f} seconds.")
    print(f"Time per normalization {NORM_STR}: {TIME_IN_NORM:.8f} seconds")
    print(f"Model saved at: {os.path.join(MODEL_SAVE_DIR, f'VT_AE_{DATASET_STR}_{DATASET_CAT_STR}.pt')}")
    print(f"MDN model saved at: {os.path.join(MODEL_SAVE_DIR, f'G_estimate_{DATASET_STR}_{DATASET_CAT_STR}.pt')}")


# Evaluation

## Image-Level Threshold

In [53]:
def calculate_image_level_threshold(model, G_estimate, test_norm_loader, test_anom_loader, patch_size, target_fpr=0.05):
    """
    Calculate an image-level threshold based on np.max(score_map) for each image.

    Each image is scored with the maximum of its MDN negative-log-likelihood map
    after the per-patch values are bilinearly upsampled to 512x512 and smoothed
    with a Gaussian filter (sigma=4). The threshold is then read off the ROC
    curve built from those per-image scores.

    Args:
        model: Trained VT_AE model.
        G_estimate: Trained MDN model.
        test_norm_loader: DataLoader for normal test/validation images.
        test_anom_loader: DataLoader for anomalous test/validation images.
        patch_size: Patch size used by the model.
        target_fpr (float, optional): The desired maximum false positive rate at the image level.
                                     If None, uses Youden's J statistic to find the threshold.
                                     Defaults to 0.05.

    Returns:
        float: Calculated image-level threshold. The default 0.5 is returned
        when no scores were collected or when only one class is present
        (a ROC curve is undefined in that case).
    """
    model.eval()
    G_estimate.eval()

    # The upsampling layer has no parameters, so a single instance serves every batch.
    upsample_layer = torch.nn.UpsamplingBilinear2d((512, 512))
    patches_per_side = 512 // patch_size

    def _max_scores_for_batch(images):
        """Return one score per image: the max of its smoothed, upsampled MDN loss map."""
        if images.size(1) == 1:  # Handle single channel images by replicating to 3 channels
            images = torch.stack([images, images, images]).squeeze(2).permute(1, 0, 2, 3)
        images = images.to(DEVICE)

        with torch.no_grad():
            vector, _ = model(images)
            pi, mu, sigma = G_estimate(vector)
            mdn_loss = G_estimate.negative_log_likelihood(vector, mu, sigma, pi, test=True)

        # Reshape to (batch_size, 1, H_patch, W_patch), then upsample to (batch_size, 1, 512, 512)
        patch_scores = mdn_loss.detach().cpu().numpy().reshape(-1, 1, patches_per_side, patches_per_side)
        upsampled = upsample_layer(torch.tensor(patch_scores))

        # Smooth each image's 2-D map on its own so neighbouring images never mix.
        return [
            np.max(gaussian_filter(upsampled[batch_idx, 0].cpu().numpy(), sigma=4))
            for batch_idx in range(images.size(0))
        ]

    def _youden_threshold(fpr, tpr, thresholds, description):
        """Pick the threshold maximising Youden's J (TPR - FPR) and report it."""
        best_idx = np.argmax(tpr - fpr)
        print(f"Threshold ({description}): {thresholds[best_idx]:.4f} (FPR: {fpr[best_idx]:.4f}, TPR: {tpr[best_idx]:.4f})")
        return thresholds[best_idx]

    image_max_scores = []
    image_true_labels = []  # 0 for normal, 1 for anomaly

    # Process normal images (their masks are not needed: the label is always 0)
    print("Calculating threshold: Processing normal samples...")
    for images, _ in test_norm_loader:
        batch_scores = _max_scores_for_batch(images)
        image_max_scores.extend(batch_scores)
        image_true_labels.extend([0] * len(batch_scores))

    # Process anomalous images
    print("Calculating threshold: Processing anomalous samples...")
    for images, masks in test_anom_loader:
        batch_scores = _max_scores_for_batch(images)
        image_max_scores.extend(batch_scores)
        for batch_idx in range(len(batch_scores)):
            # An image from anom_loader is considered anomalous only if its mask has any positive pixels
            true_mask_for_image = masks[batch_idx, 0].cpu().numpy()
            image_true_labels.append(1 if np.sum(true_mask_for_image) > 0 else 0)

    if not image_max_scores:
        print("Warning: No scores collected for threshold calculation. Returning a default threshold (e.g., 0.5).")
        return 0.5

    # roc_curve yields NaN rates when only one class is present; argmax over
    # NaNs would silently select an arbitrary threshold, so bail out explicitly.
    if len(set(image_true_labels)) < 2:
        print("Warning: ROC curve could not be computed properly. Using default threshold.")
        return 0.5

    fpr, tpr, roc_thresholds = roc_curve(image_true_labels, image_max_scores)

    if target_fpr is None:
        return _youden_threshold(fpr, tpr, roc_thresholds, "Youden's J")

    # Find thresholds where actual_fpr <= target_fpr
    candidate_indices = np.where(fpr <= target_fpr)[0]
    if len(candidate_indices) == 0:
        # No threshold meets target_fpr, fall back to Youden's J
        print(f"Warning: No threshold found for image-level target FPR <= {target_fpr}. Falling back to Youden's J.")
        return _youden_threshold(fpr, tpr, roc_thresholds, "Youden's J fallback")

    # Among the candidates, pick the one with the highest TPR
    best_candidate_idx = candidate_indices[np.argmax(tpr[candidate_indices])]
    chosen_threshold = roc_thresholds[best_candidate_idx]
    actual_fpr_val = fpr[best_candidate_idx]
    actual_tpr_val = tpr[best_candidate_idx]
    print(f"Threshold (for image-level FPR <= {target_fpr}): {chosen_threshold:.4f} (Actual FPR: {actual_fpr_val:.4f}, TPR: {actual_tpr_val:.4f})")

    return chosen_threshold

## Pixel-Level Threshold

In [54]:
def calculate_threshold(model, G_estimate, data_loaders, patch_size, fpr_threshold=0.3):
    """
    Calculate threshold for anomaly detection

    Builds a pixel-level ROC curve from the smoothed, 512x512 upsampled MDN
    negative-log-likelihood maps against the ground-truth masks, then returns
    the last ROC threshold whose false positive rate does not exceed
    ``fpr_threshold``.

    Args:
        model: Trained VT_AE model
        G_estimate: Trained MDN model
        data_loaders: List of data loaders to use for threshold calculation
        patch_size: Patch size used in the model
        fpr_threshold: False positive rate threshold

    Returns:
        threshold: Calculated threshold

    Raises:
        ValueError: If the data loaders yield no batches.
    """
    # Inference only: use eval-mode layers and skip autograd bookkeeping.
    model.eval()
    G_estimate.eval()

    # The upsampling layer has no parameters, so a single instance serves every batch.
    upsample = torch.nn.UpsamplingBilinear2d((512, 512))
    patches_per_side = 512 // patch_size

    pixel_scores = []
    pixel_labels = []

    with torch.no_grad():
        for data_loader in data_loaders:
            for images, masks in data_loader:
                # Handle single channel images
                if images.size(1) == 1:
                    images = torch.stack([images, images, images]).squeeze(2).permute(1, 0, 2, 3)

                # Forward pass
                vector, _ = model(images.to(DEVICE))
                pi, mu, sigma = G_estimate(vector)

                # Calculate MDN loss (anomaly score)
                mdn_loss = G_estimate.negative_log_likelihood(vector, mu, sigma, pi, test=True)

                # Per-patch scores -> (B, 1, H_patch, W_patch) -> (B, 1, 512, 512)
                patch_scores = mdn_loss.detach().cpu().numpy().reshape(-1, 1, patches_per_side, patches_per_side)
                score_map = upsample(torch.tensor(patch_scores)).numpy()

                # Smooth along the two spatial axes only; a scalar sigma would also
                # blur across the batch axis and mix scores of different images.
                score_map = gaussian_filter(score_map, sigma=(0, 0, 4, 4))

                # Flatten per batch so any batch size (including a short last batch) works.
                pixel_scores.append(score_map.ravel())
                pixel_labels.append(masks.cpu().numpy().ravel())

    if not pixel_scores:
        raise ValueError("calculate_threshold: the data loaders yielded no batches")

    # Flatten scores and masks
    scores = np.concatenate(pixel_scores)
    labels = np.concatenate(pixel_labels)

    # Calculate ROC curve and find threshold: fpr is non-decreasing, so the last
    # index satisfying the bound is the most permissive threshold within it.
    fpr, tpr, thresholds = roc_curve(labels, scores)
    eligible = np.where(fpr <= fpr_threshold)[0]
    threshold = thresholds[eligible[-1]]

    return threshold


## Evaluation function

In [55]:
# Evaluation
########################

def test_model():
    """
    Test the anomaly detection model

    Loads the saved VT_AutoEncoder and MDN weights from the module-level
    ``model_path`` / ``g_path``, derives an image-level decision threshold from
    the test loaders, scores every test image, and writes plots, a
    classification report and summary metrics to the results directory and to
    wandb.

    Returns:
        None. Returns early, after printing the error, when the checkpoints
        cannot be loaded.

    Raises:
        ValueError: If ``DATASET_STR`` is neither 'MVTEC' nor 'BTAD'.
    """
    # we reset the global variables for time and calls for norm layer
    global TIME_IN_NORM
    global NORM_CALLS
    TIME_IN_NORM = 0
    NORM_CALLS = 0

    print(f"\n{'='*20} Testing on {DATASET_STR} dataset {'='*20}")
    print(f"Product: {DATASET_CAT_STR}, Patch Size: {PATCH_SIZE}")

    # Create results directory
    results_dir = os.path.join(MODEL_SAVE_DIR, f"results_{DATASET_STR}_{DATASET_CAT_STR}")
    os.makedirs(results_dir, exist_ok=True)

    # TODO (original note, Italian: "BRILLARE"): SSIM loss initialisation is currently disabled
    # # Initialize SSIM loss
    # if USING_LW_SSIM:
    #     ssim = LWSSIM().to(DEVICE)
    # else:
    #     ssim = SSIM().to(DEVICE)

    # Load dataset (first argument is presumably the batch size -- TODO confirm;
    # the code below indexes score maps with [0][0], i.e. one image per batch).
    # Named `dataset` rather than `data` to avoid shadowing the
    # `torch.utils.data as data` import inside this function.
    if DATASET_STR == 'MVTEC':
        dataset = MVTEC(1, root=MVTEC_DATA_DIR, product=DATASET_CAT_STR)
    elif DATASET_STR == 'BTAD':
        # For BTAD, we need different handling for test
        dataset = BTAD(1, root=BTAD_DATA_DIR, product=DATASET_CAT_STR)
    else:
        raise ValueError(f"Dataset {DATASET_STR} not supported")

    # Load models
    model = VT_AutoEncoder(patch_size=PATCH_SIZE, train=False).to(DEVICE)
    G_estimate = MDN().to(DEVICE)

    try:
        # map_location lets a checkpoint saved on GPU be restored on a CPU-only runtime.
        model.load_state_dict(torch.load(model_path, map_location=DEVICE))
        G_estimate.load_state_dict(torch.load(g_path, map_location=DEVICE))
        print(f"Models loaded from {model_path} and {g_path}")
    except Exception as e:
        print(f"Error loading models: {e}")
        return

    # Set models to eval mode
    model.eval()
    G_estimate.eval()

    # Calculate threshold
    print("Calculating threshold...")
    test_loaders = [dataset.test_norm_loader, dataset.test_anom_loader]
    # The pixel-level threshold is only printed for reference; it is replaced
    # by the image-level threshold right below.
    threshold = calculate_threshold(model, G_estimate, test_loaders, PATCH_SIZE)
    print(f"VINTAGE!!! Threshold: {threshold}")
    # Pass the specific normal and anomaly loaders
    threshold = calculate_image_level_threshold(
        model,
        G_estimate,
        dataset.test_norm_loader, # Pass the normal loader
        dataset.test_anom_loader, # Pass the anomaly loader
        PATCH_SIZE,
        target_fpr=None #0.05        # Adjust target_fpr as needed (e.g., 0.05 or 0.1 for image-level)
                              # Or pass None to use Youden's J statistic by default
    )
    # Log threshold to wandb
    wandb.config.update({"threshold": threshold})

    # The upsampling layer has no parameters, so a single instance serves every batch.
    upsample = torch.nn.UpsamplingBilinear2d((512, 512))
    patches_per_side = 512 // PATCH_SIZE

    def _score_batch(images):
        """Run one batch through both models.

        Returns (images, reconstructions, mdn_loss, score_map), where `images`
        is the (possibly channel-replicated) input and `score_map` is the
        Gaussian-smoothed, 512x512 upsampled MDN loss as a numpy array.
        """
        if images.size(1) == 1:
            images = torch.stack([images, images, images]).squeeze(2).permute(1, 0, 2, 3)

        vector, reconstructions = model(images.to(DEVICE))
        pi, mu, sigma = G_estimate(vector)
        mdn_loss = G_estimate.negative_log_likelihood(vector, mu, sigma, pi, test=True)

        # Generate score map
        norm_score = mdn_loss.detach().cpu().numpy().reshape(-1, 1, patches_per_side, patches_per_side)
        score_map = gaussian_filter(upsample(torch.tensor(norm_score)).numpy(), sigma=4)
        return images, reconstructions, mdn_loss, score_map

    # Evaluate on test data
    print("Evaluating on test data...")
    with torch.no_grad():
        # Lists to store results
        normal_losses = []
        anomaly_losses = []
        normal_scores = []
        anomaly_scores = []

        # Lists to store full pixel-level data
        full_score_maps = []
        full_masks = []

        all_y_true = []
        all_y_pred = []

        start_time = time.time()
        # Process normal test data
        for images, masks in dataset.test_norm_loader:
            images, reconstructions, mdn_loss, score_map = _score_batch(images)
            normal_losses.append(mdn_loss.sum().item())

            # Store full score map and corresponding zero mask for normal samples
            full_score_maps.append(score_map[0][0])
            full_masks.append(np.zeros_like(score_map[0][0]))

            score_val = np.max(score_map)
            normal_scores.append(score_val)

            # Binary prediction (0: normal, 1: anomaly)
            all_y_true.append(0)  # Ground truth: normal
            all_y_pred.append(1 if score_val > threshold else 0)  # Prediction based on threshold

        # Process anomalous test data
        for idx, (images, masks) in enumerate(dataset.test_anom_loader):
            images, reconstructions, mdn_loss, score_map = _score_batch(images)
            anomaly_losses.append(mdn_loss.sum().item())

            # Store full score map and corresponding mask
            full_score_maps.append(score_map[0][0])
            full_masks.append(masks.squeeze(0).squeeze(0).cpu().numpy())

            score_val = np.max(score_map)
            anomaly_scores.append(score_val)

            # Binary prediction (0: normal, 1: anomaly)
            all_y_true.append(1)  # Ground truth: anomaly
            all_y_pred.append(1 if score_val > threshold else 0)  # Prediction based on threshold

            # Visualize some results and save
            if idx % 5 == 0:
                # Enhanced visualization
                save_path = os.path.join(results_dir, f"anomaly_{idx}.png")
                plot_enhanced(images, masks, reconstructions.cpu(), score_map[0][0], threshold, save_path)

                # Region visualization
                region_path = os.path.join(results_dir, f"anomaly_regions_{idx}.png")
                visualize_regions(images, score_map[0][0], threshold, min_area=100, save_path=region_path)

                # Log to wandb
                wandb.log({
                    f"anomaly_sample_{idx}": wandb.Image(save_path),
                    f"anomaly_regions_{idx}": wandb.Image(region_path)
                })
        end_time = time.time()
        elapsed_time = end_time - start_time
        print(f"Evaluation completed in {elapsed_time:.2f} seconds.")
        print(f"Total Time for normalization {NORM_STR}: {TIME_IN_NORM:.4f} seconds")
        # NORM_CALLS was reset to 0 above and stays 0 if nothing incremented it
        # during evaluation; avoid dividing by zero in that case.
        if NORM_CALLS:
            print(f"Time per normalization {NORM_STR}: {TIME_IN_NORM/NORM_CALLS:.8f} seconds")
        else:
            print(f"Time per normalization {NORM_STR}: n/a (no timed normalization calls recorded)")
        # Calculate metrics
        print("\nCalculating evaluation metrics...")

        # Convert to numpy arrays
        y_true = np.array(all_y_true)
        y_pred = np.array(all_y_pred)
        # Image-level ROC/PR use the summed MDN loss per image as the score,
        # whereas y_pred above is based on the max of the smoothed score map.
        roc_labels = np.concatenate((np.zeros(len(normal_losses)), np.ones(len(anomaly_losses))))
        roc_scores = np.concatenate((normal_losses, anomaly_losses))

        # Image-level ROC AUC
        image_auc = roc_auc_score(roc_labels, roc_scores)

        # Calculate pixel-level ROC AUC  and PRO score
        # Flatten and concatenate all score maps and masks
        pixel_scores = np.concatenate([s.flatten() for s in full_score_maps])
        pixel_masks = np.concatenate([m.flatten() for m in full_masks])

        # Now the dimensions will match
        pixel_auc = roc_auc_score(pixel_masks, pixel_scores)
        pro_score, _ = calculate_pro_score(anomaly_maps = full_score_maps, ground_truth_masks = full_masks)

        # Precision-Recall AUC
        img_precision, img_recall, _ = precision_recall_curve(roc_labels, roc_scores)
        img_pr_auc = auc(img_recall, img_precision)

        pixel_precision, pixel_recall, _ = precision_recall_curve(pixel_masks, pixel_scores)
        pr_auc = auc(pixel_recall, pixel_precision)

        # Additional metrics
        f1 = f1_score(y_true, y_pred)
        accuracy = accuracy_score(y_true, y_pred)
        precision_score_val = precision_score(y_true, y_pred)
        recall_score_val = recall_score(y_true, y_pred)

        # Generate ROC curve
        img_fpr, img_tpr, _ = roc_curve(roc_labels, roc_scores)
        pixel_fpr, pixel_tpr, _ = roc_curve(pixel_masks, pixel_scores)

        # Plot and save img ROC curve
        img_roc_path = os.path.join(results_dir, "img_roc_curve.png")
        plot_roc_curve(img_fpr, img_tpr, image_auc, color='darkorange', title='Image ROC Curve',save_path=img_roc_path)

        # Plot and save pixel ROC curve
        pixel_roc_path = os.path.join(results_dir, "pixel_roc_curve.png")
        plot_roc_curve(pixel_fpr, pixel_tpr, pixel_auc, color='purple', title='Pixel ROC Curve', save_path=pixel_roc_path)

        # Plot and save img PR curve
        img_pr_path = os.path.join(results_dir, "img_pr_curve.png")
        plot_precision_recall_curve(img_precision, img_recall, img_pr_auc, color='green', title='Image PR Curve', save_path=img_pr_path)

        # Plot and save pixel PR curve
        pixel_pr_path = os.path.join(results_dir, "pixel_pr_curve.png")
        plot_precision_recall_curve(pixel_precision, pixel_recall, pr_auc, color='blue', title='Pixel PR Curve', save_path=pixel_pr_path)

        # Plot and save confusion matrix
        cm_path = os.path.join(results_dir, "confusion_matrix.png")
        plot_confusion_matrix(y_true, y_pred, cm_path, normalize=True)

        # Plot and save score distributions
        dist_path = os.path.join(results_dir, "score_distributions.png")
        plot_score_distributions(normal_scores, anomaly_scores, threshold, dist_path)

        # Log results to wandb
        wandb.summary["img_level_auc"] = image_auc
        wandb.summary["pixel_level_auc"] = pixel_auc
        wandb.summary["img_pr_auc"] = img_pr_auc
        wandb.summary["f1_score"] = f1
        wandb.summary["accuracy"] = accuracy
        wandb.summary["precision"] = precision_score_val
        wandb.summary["recall"] = recall_score_val
        wandb.summary["pro_score"] = pro_score

        wandb.log({
            "img_roc_curve": wandb.Image(img_roc_path),
            "img_pr_curve": wandb.Image(img_pr_path),
            "pixel_roc_curve": wandb.Image(pixel_roc_path),
            "pixel_pr_curve": wandb.Image(pixel_pr_path),
            "confusion_matrix": wandb.Image(cm_path),
            "score_distributions": wandb.Image(dist_path)
        })

        # Log classification report
        cls_report = classification_report(y_true, y_pred, target_names=['Normal', 'Anomaly'])
        with open(os.path.join(results_dir, "classification_report.txt"), "w") as f:
            f.write(cls_report)

        print(f"\nResults for {NORM_STR} - {DATASET_STR} - {DATASET_CAT_STR}:")
        print(f"Per Region Overlap (PRO) Score: {pro_score:.4f}")
        print(f"Image-level AUC: {image_auc:.4f}")
        print(f"Pixel-level AUC [Real PRO]: {pixel_auc:.4f}")
        print(f"Precision-Recall AUC: {pr_auc:.4f}")
        print(f"F1 Score: {f1:.4f}")
        print(f"Accuracy: {accuracy:.4f}")
        print(f"Precision: {precision_score_val:.4f}")
        print(f"Recall: {recall_score_val:.4f}")
        print("\nClassification Report:")
        print(cls_report)

    # Finish wandb
    wandb.finish()

# Main

In [None]:
def main():
    """Print the active path and run configuration, then run the wandb-tracked pipeline.

    Training is executed only when TRAIN_MODEL is truthy; evaluation always runs.
    """
    divider = "-" * 30

    # (label, value) pairs, listed in the order they are printed.
    path_settings = [
        ("BASE_PATH", BASE_PATH),
        ("MODEL_SAVE_DIR", MODEL_SAVE_DIR),
        ("DATASET_DIR", DATASET_DIR),
        ("MVTEC_DATA_DIR", MVTEC_DATA_DIR),
        ("BTAD_DATA_DIR", BTAD_DATA_DIR),
    ]
    run_settings = [
        ("Normalization", NORM_STR),
        ("Dataset", DATASET_STR),
        ("Category", DATASET_CAT_STR),
        ("Epochs", NUM_EPOCHS),
        ("Train Model", TRAIN_MODEL),
        ("Compare Models", COMPARE_MODELS),
        ("Using LW SSIM", USING_LW_SSIM),
        ("Using Mask", USING_RANDOM_MASK),
        ("Logging LPIPS", LOG_LPIPS),
        ("Noise Decay", NOISE_DECAY),
        ("Noise Decay Factor", NOISE_DECAY_FACTOR),
    ]

    sections = (
        ("Paths Configuration:", path_settings),
        ("Run Configuration:", run_settings),
    )
    for heading, settings in sections:
        print(divider)
        print(heading)
        for label, value in settings:
            print(f"{label}: {value}")
    print(divider)

    # Setup wandb
    setup_wandb()

    # Train first when requested, then always evaluate.
    if TRAIN_MODEL:
        train_model(save_best=False)
    test_model()

    wandb.finish()

# Run the pipeline defined by main() when this cell executes.
# wandb authorization page (presumably where the API key used by setup_wandb()
# is obtained -- TODO confirm): https://wandb.ai/authorize?ref=models
main()

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


------------------------------
Paths Configuration:
BASE_PATH: /kaggle/working
MODEL_SAVE_DIR: /kaggle/working/trained_models
DATASET_DIR: /kaggle/input
MVTEC_DATA_DIR: /kaggle/input/mvtec-ad
BTAD_DATA_DIR: /kaggle/input/btad-beantech-anomaly-detection
------------------------------
Run Configuration:
Normalization: DyT
Dataset: BTAD
Category: 01
Epochs: 400
Train Model: True
Compare Models: False
Using LW SSIM: True
Using Mask: True
Logging LPIPS: False
Noise Decay: True
Noise Decay Factor: 1
------------------------------



Product: 01, Epochs: 400, Learning Rate: 0.0001
Loading BTAD dataset for product 01...
Found product '01' under nested directory 'BTech_Dataset_transformed'.
Found 400 images for product 01 (train, ok=True, ko=False)
Found product '01' under nested directory 'BTech_Dataset_transformed'.
Found 21 images for product 01 (test, ok=True, ko=False)
Found product '01' under nested directory 'BTech_Dataset_transformed'.
Found 49 images for product 01 (test, ok=False, ko=True)
BTAD dataset for product 01 loaded successfully!
Train loader: Ready
Test normal loader: Ready
Test anomalous loader: Ready
Validation loader: Using test normal data

Initializing network weights...

Network training started...
Epoch 1/400, Loss: 40898.105
