# Drive Setup

In [None]:
# Environment setup (Kaggle working/input directories)
########################


# Set up directories
# BASE_PATH is where outputs are written; DATASET_DIR is where the datasets are read from.
import os
BASE_PATH = '/kaggle/working'
MODEL_SAVE_DIR = os.path.join(BASE_PATH, 'trained_models')
DATASET_DIR = '/kaggle/input'
MVTEC_DATA_DIR = os.path.join(DATASET_DIR, 'mvtec-ad')
BTAD_DATA_DIR = os.path.join(DATASET_DIR, 'btad-beantech-anomaly-detection')

# Create directories if they don't exist
# NOTE(review): the two dataset directories live under DATASET_DIR, which is presumably a
# read-only mount; if so these calls only succeed when the datasets are already attached - confirm.
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
os.makedirs(MVTEC_DATA_DIR, exist_ok=True)
os.makedirs(BTAD_DATA_DIR, exist_ok=True)

# USE_DYT selects the DyT layer instead of nn.LayerNorm inside PreNorm;
# DYT_INIT_A is the default initial value of DyT's alpha parameter.
USE_DYT = True
DYT_INIT_A = 0.5

# Set num_workers for data loaders
# This dict is unpacked (**NUM_WORKERS_PARAM) into every DataLoader built in this file.
NUM_WORKERS_PARAM = {
    'num_workers': 0  # 0 = load batches in the main process (original note: set to 0 for Colab to avoid crashes)
}


# Imports

In [None]:
# Imports
########################

# Standard libraries
import random
import argparse
from collections import OrderedDict
from itertools import chain
import time

# Data manipulation and analysis
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, roc_curve, auc, precision_recall_curve
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score
from sklearn.metrics import confusion_matrix, classification_report

# Image processing and visualization
import matplotlib.pyplot as plt
import seaborn as sns
import skimage.io as io
from skimage.io import imread
from skimage.measure import label, regionprops
from skimage import measure
from PIL import Image
import cv2
from scipy.ndimage import gaussian_filter, median_filter

# PyTorch and deep learning
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim import Adam
import torch.utils.data as data
from torch.utils.tensorboard import SummaryWriter
import torchvision.models as models
import torchvision.utils as utils
from torchvision import transforms
from einops import rearrange, repeat
import wandb


# Globals

In [None]:
# Globals
########################

class Config():
    """Runtime configuration flags.

    Attributes set in ``__init__``:
        USE_CUDA: True when PyTorch reports an available CUDA device.
    """

    def __init__(self):
        cuda_available = torch.cuda.is_available()
        self.USE_CUDA = cuda_available

# Patch-count floor for the Vision Transformer (ViT.__init__ asserts num_patches > this value)
MIN_NUM_PATCHES = 16

# Visualization settings: histogram colours per class and the seaborn style name
PLOT_COLORS = dict(normal="green", anomaly="red")
PLOT_STYLE = "whitegrid"

# Utils

In [None]:
# Utils
########################

def Normalise(score_map):
    """Min-max normalise a score map to the range [0, 1].

    Parameters
    ----------
    score_map : ndarray or tensor
        Any object exposing ``.max()`` / ``.min()`` and elementwise arithmetic.

    Returns
    -------
    scores : ``(score_map - min) / (max - min)``.
        A constant map (max == min) yields all zeros instead of the NaNs a
        0/0 division would produce.
    """
    max_score = score_map.max()
    min_score = score_map.min()
    value_range = max_score - min_score
    if value_range == 0:
        # Flat map: nothing to stretch, and dividing would give NaN everywhere.
        # Multiplying by 0.0 keeps the container type and yields float zeros.
        return (score_map - min_score) * 0.0
    return (score_map - min_score) / value_range

def Mean_var(score_map):
    """Return the ``(mean, variance)`` pair of a score map, both computed with NumPy."""
    return np.mean(score_map), np.var(score_map)

def Filter(score_map, type=0):
    """
    Apply filtering to score map
    Parameters
    ----------
    score_map : score map as tensor or ndarray
    type : Int, optional
            Selects the filter:
            0 = Gaussian (sigma=4)
            1 = Median (3x3 window)

    Returns
    -------
    score: Filtered score

    Raises
    ------
    ValueError
        If ``type`` is neither 0 nor 1 (previously this surfaced as an
        UnboundLocalError on the return statement).
    """
    if type == 0:
        return gaussian_filter(score_map, sigma=4)
    if type == 1:
        return median_filter(score_map, size=3)
    raise ValueError(f"Unknown filter type {type!r}; expected 0 (Gaussian) or 1 (Median)")

def Binarization(mask, thres=0., type=0):
    """Threshold a mask.

    type 0: values above ``thres`` become 1., everything else 0.
    type 1: values above ``thres`` are kept as they are, everything else becomes 0.
    Any other ``type`` returns ``mask`` untouched.
    """
    if type not in (0, 1):
        return mask
    kept_value = mask if type == 1 else 1.
    return np.where(mask > thres, kept_value, 0.)

def plot(image, grnd_truth, score):
    """Show the input image, its ground-truth mask and the anomaly score side by side.

    ``image[0]`` is displayed after ``permute(1, 2, 0)``; ``grnd_truth`` has its
    two leading axes squeezed; ``score`` is drawn as-is with a colorbar.
    """
    plt.figure(figsize=(15, 5))
    panels = (
        (131, image[0].permute(1, 2, 0), 'Original Image'),
        (132, grnd_truth.squeeze(0).squeeze(0), 'Ground Truth'),
        (133, score, 'Anomaly Score'),
    )
    for position, picture, heading in panels:
        plt.subplot(position)
        plt.imshow(picture)
        plt.title(heading)
    # The colorbar attaches to the last image drawn, i.e. the score panel.
    plt.colorbar()
    plt.tight_layout()
    plt.pause(1)
    plt.show()

def plot_enhanced(image, grnd_truth, score, threshold=None, save_path=None):
    """Four-panel visualisation: image, ground truth, score heatmap and, optionally, the thresholded score.

    The fourth panel is only drawn when ``threshold`` is given; the figure is
    written to ``save_path`` when that is truthy, then shown.
    """
    def draw_panel(position, picture, heading, **imshow_kwargs):
        # One axis-free subplot; returns the image artist for an optional colorbar.
        plt.subplot(position)
        artist = plt.imshow(picture, **imshow_kwargs)
        plt.title(heading)
        plt.axis('off')
        return artist

    with sns.plotting_context("talk"):
        sns.set_style(PLOT_STYLE)
        plt.figure(figsize=(18, 5))

        draw_panel(141, image[0].permute(1, 2, 0), 'Original Image')
        draw_panel(142, grnd_truth.squeeze(0).squeeze(0), 'Ground Truth', cmap='gray')
        heatmap = draw_panel(143, score, 'Anomaly Score', cmap='jet')
        plt.colorbar(heatmap, fraction=0.046, pad=0.04)

        if threshold is not None:
            draw_panel(144, Binarization(score, threshold),
                       f'Binary Result (t={threshold:.3f})', cmap='gray')

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, bbox_inches='tight')
        plt.show()

def plot_roc_curve(fpr, tpr, auc_score, save_path=None):
    """Draw the ROC curve (with the AUC in the legend) plus the chance diagonal.

    The figure is saved to ``save_path`` when that is truthy, then shown.
    """
    with sns.plotting_context("talk"):
        sns.set_style(PLOT_STYLE)
        plt.figure(figsize=(8, 8))
        axes = plt.gca()
        axes.plot(fpr, tpr, lw=2, color='darkorange', label=f'ROC curve (AUC = {auc_score:.3f})')
        # Dashed diagonal = performance of a random classifier.
        axes.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
        axes.set_xlim([0.0, 1.0])
        axes.set_ylim([0.0, 1.05])
        axes.set_xlabel('False Positive Rate')
        axes.set_ylabel('True Positive Rate')
        axes.set_title('Receiver Operating Characteristic')
        axes.legend(loc="lower right")
        if save_path:
            plt.savefig(save_path, bbox_inches='tight')
        plt.show()

def plot_precision_recall_curve(precision, recall, pr_auc, save_path=None):
    """Draw the precision-recall curve with a shaded area and the AUC in the legend.

    The figure is saved to ``save_path`` when that is truthy, then shown.
    """
    with sns.plotting_context("talk"):
        sns.set_style(PLOT_STYLE)
        plt.figure(figsize=(8, 8))
        axes = plt.gca()
        axes.plot(recall, precision, lw=2, color='green', label=f'PR curve (AUC = {pr_auc:.3f})')
        axes.fill_between(recall, precision, alpha=0.2, color='green')
        axes.set_xlim([0.0, 1.0])
        axes.set_ylim([0.0, 1.05])
        axes.set_xlabel('Recall')
        axes.set_ylabel('Precision')
        axes.set_title('Precision-Recall Curve')
        axes.legend(loc="lower left")
        if save_path:
            plt.savefig(save_path, bbox_inches='tight')
        plt.show()

def plot_confusion_matrix(y_true, y_pred, save_path=None):
    """Render the confusion matrix of ``y_true`` vs ``y_pred`` as an annotated seaborn heatmap.

    Both axes use the tick labels 'Normal' / 'Anomaly'. The figure is saved to
    ``save_path`` when that is truthy, then shown.
    """
    class_names = ['Normal', 'Anomaly']
    with sns.plotting_context("talk"):
        sns.set_style(PLOT_STYLE)
        matrix = confusion_matrix(y_true, y_pred)
        plt.figure(figsize=(8, 8))
        sns.heatmap(
            matrix,
            annot=True,
            fmt='d',
            cmap='Blues',
            xticklabels=class_names,
            yticklabels=class_names,
        )
        plt.xlabel('Predicted')
        plt.ylabel('True')
        plt.title('Confusion Matrix')
        if save_path:
            plt.savefig(save_path, bbox_inches='tight')
        plt.show()

def plot_score_distributions(normal_scores, anomaly_scores, threshold=None, save_path=None):
    """Plot density histograms of normal and anomaly scores, with an optional threshold line.

    Parameters
    ----------
    normal_scores, anomaly_scores : sequences of scores, one histogram each.
    threshold : optional float; drawn as a dashed vertical line when given.
    save_path : optional path; the figure is saved there when truthy.

    Both histograms are drawn with ``stat="density"`` so the bars match the
    'Density' axis label and the two classes stay comparable when they contain
    different numbers of samples (histplot's default statistic is the raw count).
    """
    with sns.plotting_context("talk"):
        sns.set_style(PLOT_STYLE)
        plt.figure(figsize=(10, 6))
        sns.histplot(normal_scores, color=PLOT_COLORS["normal"], label="Normal",
                     alpha=0.6, kde=True, stat="density")
        sns.histplot(anomaly_scores, color=PLOT_COLORS["anomaly"], label="Anomaly",
                     alpha=0.6, kde=True, stat="density")

        if threshold is not None:
            plt.axvline(x=threshold, color='black', linestyle='--', label=f'Threshold: {threshold:.3f}')

        plt.xlabel('Anomaly Score')
        plt.ylabel('Density')
        plt.title('Distribution of Anomaly Scores')
        plt.legend()

        if save_path:
            plt.savefig(save_path, bbox_inches='tight')
        plt.show()

def visualize_regions(image, score_map, threshold, min_area=100, save_path=None):
    """Show the image next to a copy annotated with boxes around detected anomaly regions.

    The score map is binarised at ``threshold``, connected components are
    labelled, and components with an area below ``min_area`` are dropped.
    Each remaining component gets a red bounding box and an area caption.
    The figure is saved to ``save_path`` when that is truthy, then shown.
    """
    components = measure.label(Binarization(score_map, threshold))
    kept = [props for props in measure.regionprops(components) if props.area >= min_area]

    with sns.plotting_context("talk"):
        sns.set_style(PLOT_STYLE)
        plt.figure(figsize=(12, 8))

        # Left panel: untouched image
        plt.subplot(121)
        plt.imshow(image[0].permute(1, 2, 0))
        plt.title('Original Image')
        plt.axis('off')

        # Right panel: same image with one box + caption per surviving region
        plt.subplot(122)
        plt.imshow(image[0].permute(1, 2, 0))

        for props in kept:
            top, left, bottom, right = props.bbox
            box = plt.Rectangle((left, top), right - left, bottom - top,
                                fill=False, edgecolor='red', linewidth=2)
            plt.gca().add_patch(box)
            plt.text(left, top - 5, f"Area: {props.area}",
                     color='white', fontsize=9, backgroundcolor='red')

        plt.title(f'Detected Anomalies (n={len(kept)})')
        plt.axis('off')

        plt.tight_layout()
        if save_path:
            plt.savefig(save_path, bbox_inches='tight')
        plt.show()

def binImage(heatmap, thres=0):
    """Binarize a heatmap with ``cv2.threshold`` using THRESH_BINARY + THRESH_OTSU (max value 255).

    NOTE(review): with THRESH_OTSU set, OpenCV presumably derives the threshold
    itself, so ``thres`` may have no effect - confirm against the OpenCV docs.
    """
    flags = cv2.THRESH_BINARY + cv2.THRESH_OTSU
    binarized = cv2.threshold(heatmap, thres, 255, flags)[1]
    return binarized

def selectMaxConnect(heatmap):
    """Return an integer 0/1 mask of the largest connected component of ``heatmap``.

    Components are found with ``label(..., connectivity=2, background=0)``.
    On ties the lowest label wins; when there is no component at all the
    result is all zeros.
    """
    components, count = label(heatmap, connectivity=2, background=0, return_num=True)
    best_label, best_size = 0, 0
    for candidate in range(1, count + 1):
        size = np.sum(components == candidate)
        if size > best_size:
            best_label, best_size = candidate, size
    if best_size == 0:
        # No foreground component: compare against a label that never occurs.
        selected = (components == -1)
    else:
        selected = (components == best_label)
    # bool + 0 converts the mask to integers.
    return selected + 0

# Initialize weight function
def initialize_weights(*models):
    """Initialize network weights in place.

    For every module reachable from each model:
      * ``nn.Conv2d`` / ``nn.Linear``: Kaiming-normal weights, zero bias (if any).
      * ``nn.BatchNorm2d``: weight set to 1 and bias to 0 when present;
        layers built with ``affine=False`` have neither and are left alone
        (previously they raised AttributeError on ``None``).
    """
    for model in models:
        for module in model.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.BatchNorm2d):
                if module.weight is not None:
                    module.weight.data.fill_(1)
                if module.bias is not None:
                    module.bias.data.zero_()

def add_noise(latent, noise_type="gaussian", sd=0.2):
    """
    Add noise to latent features

    Arguments:
    'latent' (tensor)   : features to perturb; the noise is created on the same
                          device and with the same dtype as this tensor.
    'noise_type' (string):
        'gaussian' : additive zero-mean Gaussian noise with standard deviation 'sd'.
        'speckle'  : multiplicative noise, out = latent + n * latent, where n is
                     standard-normal noise ('sd' is not used for this type).
    'sd' (float) : standard deviation used for generating Gaussian noise (>= 0)

    Returns:
    The noisy tensor (a new tensor; 'latent' is not modified).

    Raises:
    ValueError if 'sd' is negative or 'noise_type' is not recognised.
    """
    if sd < 0.0:
        raise ValueError(f"sd must be non-negative, got {sd}")

    if noise_type == "gaussian":
        # randn_like follows latent's device, so this also works without CUDA,
        # and scaling a standard normal handles sd == 0 cleanly.
        noise = torch.randn_like(latent) * sd
        return latent + noise

    if noise_type == "speckle":
        noise = torch.randn_like(latent)
        return latent + latent * noise

    raise ValueError(f"Unknown noise_type {noise_type!r}; expected 'gaussian' or 'speckle'")

def ran_generator(length, shots=1):
    """Draw ``shots`` distinct indices from ``range(length)`` (sampling without replacement)."""
    population = range(length)
    return random.sample(population, shots)

class SSIM(torch.nn.Module):
    """Structural similarity between two image batches, computed with a Gaussian window.

    The window is cached on the instance and rebuilt whenever the channel count
    or tensor type of the input differs from the cached one.
    """

    def __init__(self, window_size=11, size_average=True):
        super().__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = self.create_window(window_size, self.channel)

    def gaussian(self, window_size, sigma):
        """Return a 1-D Gaussian of length ``window_size`` that sums to 1."""
        centre = window_size // 2
        denominator = float(2 * sigma ** 2)
        weights = [np.exp(-(pos - centre) ** 2 / denominator) for pos in range(window_size)]
        gauss = torch.Tensor(weights)
        return gauss / gauss.sum()

    def create_window(self, window_size, channel):
        """Build a (channel, 1, window_size, window_size) Gaussian kernel (sigma = 1.5)."""
        column = self.gaussian(window_size, 1.5).unsqueeze(1)
        # Outer product of the 1-D kernel with itself gives the 2-D kernel.
        kernel = column.mm(column.t()).float()
        kernel = kernel.unsqueeze(0).unsqueeze(0)
        return kernel.expand(channel, 1, window_size, window_size).contiguous()

    def _ssim(self, img1, img2, window, window_size, channel, size_average=True):
        """Compute the SSIM map and reduce it to a scalar or to one value per image."""
        half = window_size // 2

        def local_mean(tensor):
            # Per-channel (grouped) Gaussian-weighted mean.
            return F.conv2d(tensor, window, padding=half, groups=channel)

        mu1 = local_mean(img1)
        mu2 = local_mean(img2)

        mu1_sq = mu1.pow(2)
        mu2_sq = mu2.pow(2)
        mu1_mu2 = mu1 * mu2

        sigma1_sq = local_mean(img1 * img1) - mu1_sq
        sigma2_sq = local_mean(img2 * img2) - mu2_sq
        sigma12 = local_mean(img1 * img2) - mu1_mu2

        C1 = 0.01 ** 2
        C2 = 0.03 ** 2

        numerator = (2 * mu1_mu2 + C1) * (2 * sigma12 + C2)
        denominator = (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
        ssim_map = numerator / denominator

        if not size_average:
            # Successive means over axis 1 collapse channel, height and width.
            return ssim_map.mean(1).mean(1).mean(1)
        return ssim_map.mean()

    def forward(self, img1, img2):
        """Return the SSIM of ``img1`` and ``img2`` (4-D inputs)."""
        _, channel, _, _ = img1.size()

        cache_is_valid = channel == self.channel and self.window.data.type() == img1.data.type()
        if not cache_is_valid:
            rebuilt = self.create_window(self.window_size, channel)
            if img1.is_cuda:
                rebuilt = rebuilt.cuda(img1.get_device())
            self.window = rebuilt.type_as(img1)
            self.channel = channel

        return self._ssim(img1, img2, self.window, self.window_size, channel, self.size_average)

def log_gaussian(x, mean, logvar):
    '''
    Gaussian log-likelihood of x under (mean, logvar).

    x gains a trailing axis and is expanded to logvar's shape; the per-element
    term logvar + (x - mean)^2 / exp(logvar) is summed over axis 2, a single
    log(2*pi) is added, and the total is scaled by -0.5.
    '''
    target = x.unsqueeze(-1).expand_as(logvar)
    squared_error = (target - mean) ** 2  # broadcasting handles multiple samples
    per_element = logvar + squared_error / (torch.exp(logvar))
    summed = per_element.sum(2)
    return -0.5 * (np.log(2 * np.pi) + summed)

def log_gmm(x, means, logvars, weights, total=True):
    '''
    Weighted negative Gaussian log-likelihoods of a mixture.

    Each negative log-likelihood from log_gaussian is multiplied by weights;
    with total=True the result is summed over axis 2, otherwise the weighted
    terms are returned unreduced.
    '''
    weighted_nll = weights * (-log_gaussian(x, means, logvars))
    return torch.sum(weighted_nll, 2) if total else weighted_nll

def mdn_loss_function(x, means, logvars, weights, test=False):
    '''
    Mixture-density loss built on log_gmm.

    test=True  : return the log_gmm output unchanged.
    test=False : sum that output over axis 1 and return the mean of the sums.
    '''
    mixture_terms = log_gmm(x, means, logvars, weights)
    if test:
        return mixture_terms
    return torch.mean(torch.sum(mixture_terms, 1))


# Data

In [None]:
# Data
########################

def read_files(root, d, product, data_motive='train', use_good=True, normal=True):
    '''
    Collect the image file names of one split of one product folder.

    Parameters:
        root : root directory of mvtech images
        d : name of one entry inside root (normally a product folder)
        product : requested product name; when it is "all" the folder is only
                  scanned and reported via print, and None is returned
        data_motive : Can be 'train' or 'test' or 'ground_truth'
        use_good : To use the data in the good folder
        normal : Signify if the normal images are included

    For data_motive == 'test' the sub-folders are filtered as follows:
        use_good == False and normal != True -> every folder except 'good'
        use_good == True  and normal == True -> only the 'good' folder
        any other combination                -> nothing is collected

    Returns:
        OrderedDict mapping each sub-folder path to the list of file names in
        it, or None when product == "all", when d is not a directory, or when
        the requested split folder does not exist.

    Folder and file listings are sorted: os.listdir returns entries in
    arbitrary order, and callers pair images with masks purely by position.
    '''
    product_dir = os.path.join(root, d)
    # os.walk yields nothing for a plain file (e.g. a readme beside the product
    # folders); the default keeps next() from raising StopIteration.
    subdirs = next(os.walk(product_dir), (None, [], []))[1]
    if data_motive not in subdirs:
        return None

    d_in = data_motive
    split_dir = os.path.join(product_dir, d_in)
    im_pt = OrderedDict()

    # Equality (not identity) comparisons are kept so non-bool flags behave as before.
    anomalous_only = (use_good == False) and (normal != True)
    good_only = (use_good == True) and (normal == True)

    for i in sorted(os.listdir(split_dir)):
        tr_img_pth = os.path.join(split_dir, i)
        if not os.path.isdir(tr_img_pth):
            continue

        if data_motive == 'test':
            if anomalous_only and i == 'good':
                print(f'the good images for {d_in} images of {i} {d} is not included in the test anomalous data')
                continue
            wanted = (anomalous_only and i != 'good') or (good_only and i == 'good')
            if not wanted:
                continue
        elif data_motive not in ('train', 'ground_truth'):
            # Unknown split names never collected anything.
            continue

        images = sorted(os.listdir(tr_img_pth))
        im_pt[tr_img_pth] = images
        print(f'total {d_in} images of {i} {d} are: {len(images)}')

    if product == "all":
        return None
    return im_pt

def load_images(path, image_name):
    """Read the file ``image_name`` inside directory ``path`` with ``imread`` and return the result."""
    full_path = os.path.join(path, image_name)
    return imread(full_path)

def Train_data(root, product='bottle', use_good=True):
    '''
    Return the training listing (folder path -> image names) for one product.

    With product == "all" every entry of root is scanned through read_files
    (which only prints counts) and None is returned. Otherwise the listing of
    the entry whose name equals product is returned, or None if there is no
    such entry. use_good is accepted but not used here.
    '''
    entries = os.listdir(root)

    if product == "all":
        for entry in entries:
            read_files(root, entry, product, data_motive='train')
        return None

    for entry in entries:
        if product == entry:
            return read_files(root, entry, product, data_motive='train')
    return None

def Test_anom_data(root, product='bottle', use_good=False):
    '''
    Return the anomalous test listing (folder path -> image names) for one product.

    With product == "all" every entry of root is scanned through read_files
    (which only prints counts) and None is returned. Otherwise the listing of
    the entry whose name equals product is returned, or None if there is no
    such entry. read_files is always called with normal=False.
    '''
    entries = os.listdir(root)

    if product == "all":
        for entry in entries:
            read_files(root, entry, product, data_motive='test', use_good=use_good, normal=False)
        return None

    for entry in entries:
        if product == entry:
            return read_files(root, entry, product, data_motive='test', use_good=use_good, normal=False)
    return None

def Test_anom_mask(root, product='bottle', use_good=False):
    '''
    Return the ground-truth mask listing (folder path -> mask names) for one product.

    With product == "all" every entry of root is scanned through read_files
    (which only prints counts) and None is returned; this branch now scans the
    'ground_truth' split like the single-product branch (it used to scan
    'test'). Otherwise the listing of the entry whose name equals product is
    returned, or None if there is no such entry.
    '''
    dir = os.listdir(root)

    for d in dir:
        if product == "all":
            read_files(root, d, product, data_motive='ground_truth', use_good=use_good, normal=False)
        elif product == d:
            pth_img_dict = read_files(root, d, product, data_motive='ground_truth', use_good=use_good, normal=False)
            return pth_img_dict

def Test_normal_data(root, product='bottle', use_good=True):
    '''
    Return the normal ('good') test listing for one product.

    product == 'all' is rejected with a printed message and None. Otherwise
    the listing comes from read_files with use_good=True and normal=True
    (the use_good argument of this function is not forwarded); None is
    returned when root has no entry named product.
    '''
    if product == 'all':
        print('Please choose a valid product. Normal test data can be seen product wise')
        return None

    if product not in os.listdir(root):
        return None
    return read_files(root, product, product, data_motive='test', use_good=True, normal=True)

def Process_mask(mask):
    """Set every strictly positive mask value to 1 (other values are kept) and return a tensor."""
    binary = np.where(mask > 0., 1, mask)
    return torch.tensor(binary)

class Mvtec:
    """Load one MVTec product fully into memory and expose four DataLoaders.

    Parameters:
        batch_size : batch size used by every loader
        root : dataset root containing one folder per product
        product : product folder name; 'all' is rejected with a printed message
                  and no loaders are created

    Attributes created for a valid product:
        train_loader      : normal training images with all-zero masks (shuffled)
        test_anom_loader  : anomalous test images with their ground-truth masks
        test_norm_loader  : normal test images with all-zero masks
        validation_loader : up to 10 random normal + up to 10 random anomalous test samples

    Raises:
        ValueError if any of the four listings cannot be found for the product.

    Anomalous images and masks are paired positionally (zip), so the image and
    mask listings must come back in matching order.
    """

    def __init__(self, batch_size, root=MVTEC_DATA_DIR, product='bottle'):
        self.root = root
        self.batch = batch_size
        self.product = product
        torch.manual_seed(123)

        if self.product == 'all':
            print('--------Please select a valid product.......See Train_data function-----------')
        else:
            # Importing all the image_path dictionaries for test and train data
            train_path_images = Train_data(root=self.root, product=self.product)
            test_anom_path_images = Test_anom_data(root=self.root, product=self.product)
            test_anom_mask_path_images = Test_anom_mask(root=self.root, product=self.product)
            test_norm_path_images = Test_normal_data(root=self.root, product=self.product)

            # Fail with a readable message instead of an AttributeError on None below.
            listings = (
                ('train', train_path_images),
                ('test (anomalous)', test_anom_path_images),
                ('ground_truth', test_anom_mask_path_images),
                ('test (good)', test_norm_path_images),
            )
            missing = [name for name, listing in listings if listing is None]
            if missing:
                raise ValueError(
                    f"No data found for product '{self.product}' under '{self.root}': missing {missing}")

            # Image Transformation
            T = transforms.Compose([
                transforms.ToPILImage(),
                transforms.Resize((550, 550)),
                transforms.CenterCrop(512),
                transforms.ToTensor(),
            ])

            train_normal_image = torch.stack([T(load_images(j, i)) for j in train_path_images.keys() for i in train_path_images[j]])
            test_anom_image = torch.stack([T(load_images(j, i)) for j in test_anom_path_images.keys() for i in test_anom_path_images[j]])
            test_normal_image = torch.stack([T(load_images(j, i)) for j in test_norm_path_images.keys() for i in test_norm_path_images[j]])

            # Normal samples get all-zero masks with the images' spatial size.
            train_normal_mask = torch.zeros(train_normal_image.size(0), 1, train_normal_image.size(2), train_normal_image.size(3))
            test_normal_mask = torch.zeros(test_normal_image.size(0), 1, test_normal_image.size(2), test_normal_image.size(3))
            # Masks go through the same transform as the images, then are binarized.
            test_anom_mask = torch.stack([Process_mask(T(load_images(j, i))) for j in test_anom_mask_path_images.keys() for i in test_anom_mask_path_images[j]])

            train_normal = tuple(zip(train_normal_image, train_normal_mask))
            test_anom = tuple(zip(test_anom_image, test_anom_mask))
            test_normal = tuple(zip(test_normal_image, test_normal_mask))

            print(f' --Size of {self.product} train loader: {train_normal_image.size()}--')
            if test_anom_image.size(0) == test_anom_mask.size(0):
                print(f' --Size of {self.product} test anomaly loader: {test_anom_image.size()}--')
            else:
                print(f'[!Info] Size Mismatch between Anomaly images {test_anom_image.size()} and Masks {test_anom_mask.size()} Loaded')
            print(f' --Size of {self.product} test normal loader: {test_normal_image.size()}--')

            # Create validation set
            # Sample sizes are clamped so splits with fewer than 10 images do not raise.
            num = ran_generator(len(test_anom), min(10, len(test_anom)))
            val_anom = [test_anom[i] for i in num]
            num = ran_generator(len(test_normal), min(10, len(test_normal)))
            val_norm = [test_normal[j] for j in num]
            val_set = [*val_norm, *val_anom]
            print(f' --Total Image in {self.product} Validation loader: {len(val_set)}--')

            # Final Data Loader - worker count comes from NUM_WORKERS_PARAM
            self.train_loader = torch.utils.data.DataLoader(train_normal, batch_size=batch_size, shuffle=True, **NUM_WORKERS_PARAM)
            self.test_anom_loader = torch.utils.data.DataLoader(test_anom, batch_size=batch_size, shuffle=False, **NUM_WORKERS_PARAM)
            self.test_norm_loader = torch.utils.data.DataLoader(test_normal, batch_size=batch_size, shuffle=False, **NUM_WORKERS_PARAM)
            self.validation_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=False, **NUM_WORKERS_PARAM)

class BTADDataset(data.Dataset):
    """Dataset of (image, mask) pairs for BTAD.

    Images and masks may each be given as file paths (loaded lazily) or as
    ready-made tensors. A missing mask - either ``masks is None`` or a ``None``
    entry for an individual sample - yields an all-zero mask with the image's
    spatial size.
    """

    def __init__(self, images, masks=None, resize=True):
        """
        Args:
            images: List of image tensors or paths
            masks: List of mask tensors, paths or None entries (optional)
            resize: Whether to resize images loaded from paths to 300x300
        """
        self.images = images
        self.masks = masks
        self.resize = resize

        self.resize_transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((300, 300)),  # Adjusted size
            transforms.ToTensor()])
        self.default_transform = transforms.Compose([transforms.ToTensor()])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, i):
        """Return ``(image, mask)`` for sample ``i``."""
        source = self.images[i]
        if isinstance(source, str):
            # Load image from path
            image_ = io.imread(source)

            # Convert grayscale to RGB if needed
            if len(image_.shape) < 3:
                image_ = np.stack((image_,) * 3, axis=-1)

            if self.resize:
                image = self.resize_transform(image_)
            else:
                image = self.default_transform(image_)
        else:
            # Image is already a tensor
            image = source

        mask_source = None if self.masks is None else self.masks[i]
        if mask_source is None:
            # No mask for this sample (normal data): use an empty mask so that
            # batches mixing normal and anomalous samples can still be collated.
            mask = torch.zeros(1, image.size(1), image.size(2))
        elif isinstance(mask_source, str):
            # Load mask from path
            # NOTE(review): the mask keeps its on-disk resolution and value range even
            # when the image was resized - confirm downstream code accounts for this.
            mask_ = io.imread(mask_source, as_gray=True)
            mask = torch.from_numpy(mask_).unsqueeze(0).float()
        else:
            # Mask is already a tensor
            mask = mask_source

        return image, mask

_BTAD_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp')


def _list_btad_images(folder):
    """Return the full paths of the image files directly inside ``folder``, sorted by name."""
    return [
        os.path.join(folder, name)
        for name in sorted(os.listdir(folder))
        if name.endswith(_BTAD_IMAGE_EXTENSIONS)
    ]


def _find_btad_mask(gt_folder, image_file):
    """Return the path in ``gt_folder`` sharing ``image_file``'s base name (any known extension), or None."""
    base_name = os.path.splitext(image_file)[0]
    for ext in ['.png', '.jpg', '.jpeg', '.bmp']:
        candidate = os.path.join(gt_folder, base_name + ext)
        if os.path.exists(candidate):
            return candidate
    return None


def load_btad_dataset(root_dir, product, batch_size, mode='train', resize=True, use_normal=True, anomalous=False):
    """
    Build a DataLoader over one BTAD product folder.

    Expected layout: <root_dir>/<product>/{train,test}/{ok,ko} plus, for masks,
    <root_dir>/<product>/ground_truth/ko.

    Args:
        root_dir: Root directory of BTAD dataset
        product: Product category (e.g., '01', '02', '03')
        batch_size: Batch size for data loader
        mode: 'train' or 'test'
        resize: Whether to resize images
        use_normal: Include normal ('ok') samples
        anomalous: Include anomalous ('ko') samples

    Returns:
        DataLoader for the specified dataset (shuffled only in 'train' mode).
        File listings are sorted, so the unshuffled test order is the same on
        every run instead of depending on os.listdir's arbitrary order.

    Raises:
        ValueError: if the product folder does not exist.
    """
    # Path to the product folder; reject unknown products up front
    product_path = os.path.join(root_dir, product)
    if not os.path.exists(product_path):
        available_products = os.listdir(root_dir)
        raise ValueError(f"Product '{product}' not found. Available products: {available_products}")

    image_paths = []
    mask_paths = []

    if mode == 'train':
        # Training images carry no masks; mask_paths stays empty.
        train_path = os.path.join(product_path, 'train')
        if os.path.exists(train_path):
            # In BTAD, normal samples are labeled as 'ok'
            ok_path = os.path.join(train_path, 'ok')
            if os.path.exists(ok_path) and use_normal:
                image_paths.extend(_list_btad_images(ok_path))

            # In BTAD, anomalous samples are labeled as 'ko'
            ko_path = os.path.join(train_path, 'ko')
            if os.path.exists(ko_path) and anomalous:
                image_paths.extend(_list_btad_images(ko_path))
        else:
            print(f"Warning: Train path {train_path} does not exist.")

    elif mode == 'test':
        test_path = os.path.join(product_path, 'test')
        gt_path = os.path.join(product_path, 'ground_truth')

        if os.path.exists(test_path):
            # Normal samples (ok): no mask on disk, recorded as None
            ok_path = os.path.join(test_path, 'ok')
            if os.path.exists(ok_path) and use_normal:
                for img_path in _list_btad_images(ok_path):
                    image_paths.append(img_path)
                    mask_paths.append(None)

            # Anomalous samples (ko): look up the mask with the same base name
            ko_path = os.path.join(test_path, 'ko')
            if os.path.exists(ko_path) and anomalous:
                # The ground-truth folder check does not depend on the image, so do it once.
                gt_ko_path = os.path.join(gt_path, 'ko')
                has_ground_truth = os.path.exists(gt_path) and os.path.exists(gt_ko_path)

                for img_path in _list_btad_images(ko_path):
                    image_paths.append(img_path)
                    if not has_ground_truth:
                        # No ground truth directory
                        mask_paths.append(None)
                        continue

                    img_file = os.path.basename(img_path)
                    mask_path = _find_btad_mask(gt_ko_path, img_file)
                    if mask_path is None:
                        print(f"Warning: No mask found for {img_file}")
                    mask_paths.append(mask_path)
        else:
            print(f"Warning: Test path {test_path} does not exist.")

    print(f"Found {len(image_paths)} images for {product} ({mode}, {'ok' if use_normal else ''}{'+ko' if anomalous else ''})")

    # If no masks were found or provided, set to None
    if len(mask_paths) == 0 or all(m is None for m in mask_paths):
        mask_paths = None

    # Create dataset
    dataset = BTADDataset(image_paths, mask_paths, resize)
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=(mode == 'train'), **NUM_WORKERS_PARAM)

    return data_loader


# Network

In [None]:
# Network
########################

class DyT(nn.Module):
    """ Dynamic Tanh Layer: gamma * tanh(alpha * x) + beta with learnable alpha, gamma, beta """
    def __init__(self, dim, init_a=DYT_INIT_A):
        super().__init__()
        # Scalar alpha (0.5 as in the LeCun paper, per the original note)
        self.alpha = nn.Parameter(torch.tensor(init_a, dtype=torch.float32))
        # Per-feature scale: as in the paper but with gamma slightly smaller (0.9)
        self.gamma = nn.Parameter(torch.full((dim,), 0.9, dtype=torch.float32))
        # Per-feature shift starting at zero, as in the paper
        self.beta = nn.Parameter(torch.zeros(dim, dtype=torch.float32))
        # gamma and beta have shape (D,), which PyTorch broadcasts against a
        # (B, N, D) input in the elementwise ops of forward().

    def forward(self, x):
        # x shape: (batch_size, num_patches, embed_dim)
        squashed = torch.tanh(self.alpha * x)
        return self.gamma * squashed + self.beta

class Residual(nn.Module):
    """Skip connection: returns ``fn(x, **kwargs) + x``."""
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        branch_output = self.fn(x, **kwargs)
        return branch_output + x

class PreNorm(nn.Module):
    """Normalise the input, then apply ``fn``.

    The normalisation layer is DyT when the module-level USE_DYT flag is set,
    otherwise nn.LayerNorm.
    """
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = DyT(dim) if USE_DYT else nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        normalised = self.norm(x)
        return self.fn(normalised, **kwargs)

class FeedForward(nn.Module):
    """Two-layer MLP: Linear(dim -> hidden_dim), GELU, Linear(hidden_dim -> dim)."""
    def __init__(self, dim, hidden_dim):
        super().__init__()
        expand = nn.Linear(dim, hidden_dim)
        project = nn.Linear(hidden_dim, dim)
        # Kept as a Sequential named `net` so parameter names stay net.0.* / net.2.*
        self.net = nn.Sequential(expand, nn.GELU(), project)

    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    """Multi-head self-attention over a sequence of tokens.

    Args:
        dim: Embedding size of each token; it is split evenly across the
            heads, so it must be divisible by ``heads``.
        heads: Number of attention heads.
    """

    def __init__(self, dim, heads=8):
        super().__init__()
        self.heads = heads
        # NOTE: the scores are scaled by the full embedding size rather than
        # the per-head size. Kept as-is so previously trained weights behave
        # identically.
        self.scale = dim ** -0.5

        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
        self.to_out = nn.Linear(dim, dim)

    def forward(self, x, mask=None):
        """Attend over ``x``.

        Args:
            x: Tensor of shape (batch, tokens, dim).
            mask: Optional boolean tensor. It is flattened from dim 1 onwards
                and one ``True`` entry is prepended (for the extra leading
                token), after which its length must equal ``tokens``. Its
                batch size may be 1 or equal to that of ``x``. A query/key
                pair is scored only when both tokens are ``True``.

        Returns:
            Tensor of shape (batch, tokens, dim).
        """
        b, n, _ = x.shape
        h = self.heads

        # (b, n, 3 * dim) -> three tensors of shape (b, h, n, d)
        q, k, v = (
            t.reshape(b, n, h, -1).permute(0, 2, 1, 3)
            for t in self.to_qkv(x).chunk(3, dim=-1)
        )

        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale

        if mask is not None:
            mask = F.pad(mask.flatten(1), (1, 0), value=True)
            assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
            # Shape (b, 1, n, n): the singleton axis broadcasts over heads.
            # A (b, n, n) mask would line its batch axis up with the head
            # axis of ``dots`` and only work when the mask batch size is 1.
            pair_mask = mask[:, None, :, None] & mask[:, None, None, :]
            dots.masked_fill_(~pair_mask, -torch.finfo(dots.dtype).max)

        attn = dots.softmax(dim=-1)
        out = torch.einsum('bhij,bhjd->bhid', attn, v)
        # (b, h, n, d) -> (b, n, h * d)
        out = out.permute(0, 2, 1, 3).reshape(b, n, -1)
        return self.to_out(out)

class Transformer(nn.Module):
    """Stack of ``depth`` encoder layers.

    Each layer is a residual pre-norm attention block followed by a residual
    pre-norm feed-forward block.

    Args:
        dim: Token embedding size.
        depth: Number of layers.
        heads: Number of attention heads per layer.
        mlp_dim: Hidden width of each feed-forward block.
    """

    def __init__(self, dim, depth, heads, mlp_dim):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.ModuleList([
                Residual(PreNorm(dim, Attention(dim, heads=heads))),
                Residual(PreNorm(dim, FeedForward(dim, mlp_dim))),
            ])
            for _ in range(depth)
        ])

    def forward(self, x, mask=None):
        # The mask is only consumed by the attention half of each layer.
        for attention_block, feed_forward_block in self.layers:
            x = feed_forward_block(attention_block(x, mask=mask))
        return x

class ViT(nn.Module):
    """Vision Transformer encoder that returns one embedding per image patch.

    Args:
        image_size: Side length of the (square) input image.
        patch_size: Side length of each square patch; must divide
            ``image_size``.
        num_classes: Accepted for interface compatibility; not used here.
        dim: Token embedding size.
        depth: Number of transformer layers.
        heads: Number of attention heads.
        mlp_dim: Hidden width of the feed-forward blocks.
        channels: Number of input image channels.
    """

    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels=3):
        super().__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        patches_per_side = image_size // patch_size
        num_patches = patches_per_side ** 2
        assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for attention to be effective (at least 16). Try decreasing your patch size'
        patch_dim = channels * patch_size ** 2

        self.patch_size = patch_size

        # Learnable tensors are created in a fixed order so that seeded
        # initialisation stays reproducible.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.patch_to_embedding = nn.Linear(patch_dim, dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))

        self.transformer = Transformer(dim, depth, heads, mlp_dim)

        self.to_cls_token = nn.Identity()

    def forward(self, img, mask=None):
        side = self.patch_size

        # Cut the image into non-overlapping side x side patches and flatten
        # each patch into one vector.
        tokens = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=side, p2=side)
        tokens = self.patch_to_embedding(tokens)
        batch, n, _ = tokens.shape

        # Prepend the class token, then add positional embeddings.
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=batch)
        tokens = torch.cat((cls_tokens, tokens), dim=1)
        tokens = tokens + self.pos_embedding[:, :(n + 1)]

        encoded = self.transformer(tokens, mask)

        # Drop the class token (index 0) and return only the patch tokens.
        return self.to_cls_token(encoded[:, 1:, :])

class Unity(nn.Module):
    """A single 2-D convolution to 512 channels followed by ReLU.

    Args:
        ks: Kernel size of the convolution.
        in_ch: Number of input channels.
    """

    def __init__(self, ks, in_ch=512):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=in_ch, out_channels=512, kernel_size=ks)

    def forward(self, x):
        features = self.conv(x)
        # In-place ReLU on the freshly created convolution output.
        return F.relu(features, inplace=True)

class Spatial_Scorer(nn.Module):
    """MLP that maps a flattened feature vector to one score in [-1, 1].

    Args:
        in_dim: Size of the flattened input vector.
        test: When true, the custom weight initialisation is skipped.
    """

    def __init__(self, in_dim=512, test=False):
        super().__init__()
        self.test = test

        first = nn.Linear(in_dim, 256)
        second = nn.Linear(256, 128)
        head = nn.Linear(128, 1)
        self.layers = nn.Sequential(
            first, nn.ReLU(True),
            second, nn.ReLU(True),
            head, nn.Tanh(),
        )

        if self.test:
            return

        print("Initializing Spatial scorer network...")
        initialize_weights(self.layers)

    def forward(self, x):
        # Collapse everything except the batch dimension.
        flat = x.view(x.size(0), -1)
        return self.layers(flat)

class DigitCaps(nn.Module):
    """Capsule layer with three rounds of dynamic routing.

    Args:
        out_num_caps: Number of output capsules.
        in_num_caps: Number of input capsules.
        in_dim_caps: Length of each input capsule vector.
        out_dim_caps: Length of each output capsule vector.
        decode_idx: Index of the output capsule to select; ``-1`` selects,
            per sample, the capsule with the largest norm.
    """

    def __init__(self, out_num_caps=1, in_num_caps=8*8*64, in_dim_caps=8, out_dim_caps=512, decode_idx=-1):
        super().__init__()

        # Kept as an attribute for backward compatibility. Work tensors in
        # ``forward`` follow the input's device rather than conf.USE_CUDA.
        self.conf = Config()
        self.in_dim_caps = in_dim_caps
        self.in_num_caps = in_num_caps
        self.out_dim_caps = out_dim_caps
        self.out_num_caps = out_num_caps
        self.decode_idx = decode_idx
        self.W = nn.Parameter(0.01 * torch.randn(out_num_caps, in_num_caps, out_dim_caps, in_dim_caps))

    def forward(self, x):
        """Route the input capsules to the output capsules.

        Args:
            x: Tensor of shape (batch, in_num_caps, in_dim_caps).

        Returns:
            Tuple ``(t, outputs)``: ``t`` has shape (batch, 1, out_dim_caps)
            and holds the selected capsule; ``outputs`` has shape
            (batch, out_num_caps, out_dim_caps) and holds every capsule.
        """
        # Prediction vectors: (batch, out_num_caps, in_num_caps, out_dim_caps)
        x_hat = torch.squeeze(torch.matmul(self.W, x[:, None, :, :, None]), dim=-1)
        # Routing logits are refined without back-propagating through them.
        x_hat_detached = x_hat.detach()

        # Routing logits: (batch, out_num_caps, in_num_caps), created on the
        # same device as the data so the layer works on CPU and GPU alike.
        b = torch.zeros(
            x.size(0), self.out_num_caps, self.in_num_caps,
            device=x_hat.device, dtype=x_hat.dtype,
        )

        # Routing algorithm
        num_iters = 3
        for i in range(num_iters):
            c = F.softmax(b, dim=1)  # coupling coefficients, same shape as b
            if i == num_iters - 1:
                # Only the final pass carries gradients back to W.
                outputs = self.squash(torch.sum(c[:, :, :, None] * x_hat, dim=-2, keepdim=True))
            else:
                outputs = self.squash(torch.sum(c[:, :, :, None] * x_hat_detached, dim=-2, keepdim=True))
                b = b + torch.sum(outputs * x_hat_detached, dim=-1)

        # (batch, out_num_caps, 1, out_dim_caps) -> (batch, out_num_caps, out_dim_caps)
        outputs = torch.squeeze(outputs, dim=-2)

        if self.decode_idx == -1:
            # Choose the longest vector as the one to decode.
            lengths = torch.sqrt((outputs ** 2).sum(2))
            lengths = F.softmax(lengths, dim=1)
            _, max_length_indices = lengths.max(dim=1)
        else:
            # Always choose the same capsule.
            max_length_indices = torch.full(
                (outputs.size(0),), self.decode_idx,
                dtype=torch.long, device=outputs.device,
            )

        # One-hot rows that keep the selected capsule and zero out the rest.
        one_hot = torch.eye(self.out_num_caps, device=outputs.device, dtype=outputs.dtype)
        one_hot = one_hot.index_select(dim=0, index=max_length_indices)
        t = (outputs * one_hot[:, :, None]).sum(dim=1).unsqueeze(1)

        return t, outputs

    def squash(self, input_tensor):
        """Rescale vectors along the last axis to norm ``|v|^2 / (1 + |v|^2)``.

        An all-zero vector maps to zero instead of the NaN that 0/0 would
        produce.
        """
        squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)
        # The floor is the smallest normal float, so it only takes effect
        # when the squared norm has underflowed to (sub-normal) zero.
        floor = torch.finfo(squared_norm.dtype).tiny
        norm = torch.sqrt(squared_norm.clamp_min(floor))
        return squared_norm * input_tensor / ((1. + squared_norm) * norm)
class decoder2(nn.Module):
    """Transposed-convolution decoder producing a 3-channel image in [-1, 1].

    Five ConvTranspose2d + BatchNorm2d + ReLU stages are followed by a final
    ConvTranspose2d and a Tanh.

    Args:
        in_channels: Number of channels of the input feature map.
    """

    def __init__(self, in_channels):
        super().__init__()
        # (in_channels, out_channels, kernel_size, stride, padding) per stage
        stage_specs = [
            (in_channels, 16, 3, 2, 1),
            (16, 32, 9, 3, 1),
            (32, 32, 7, 5, 1),
            (32, 16, 9, 2, 0),
            (16, 8, 6, 1, 0),
        ]

        modules = []
        for c_in, c_out, kernel, stride, padding in stage_specs:
            modules.append(nn.ConvTranspose2d(c_in, c_out, kernel, stride=stride, padding=padding))
            modules.append(nn.BatchNorm2d(c_out, affine=True))
            modules.append(nn.ReLU(True))

        # Output head: back to three channels, squashed to [-1, 1].
        modules.append(nn.ConvTranspose2d(8, 3, 11, stride=1))
        modules.append(nn.Tanh())

        self.decoder2 = nn.Sequential(*modules)

    def forward(self, x):
        return self.decoder2(x)

class VT_AE(nn.Module):
    """Vision-Transformer auto-encoder: ViT encoder, capsule layer, decoder.

    Args:
        image_size: Side length of the square input images.
        patch_size: Side length of each square patch.
        num_classes: Forwarded to ``ViT``.
        dim: Token embedding size.
        depth: Number of transformer layers.
        heads: Number of attention heads.
        mlp_dim: Hidden width of the transformer feed-forward blocks.
        train: When true, the weights are initialised with
            ``initialize_weights`` and noise is added to the encoded features
            on every forward pass.
    """

    def __init__(self, image_size=512,
                patch_size=64,
                num_classes=1,
                dim=512,
                depth=6,
                heads=8,
                mlp_dim=1024,
                train=True):

        super().__init__()
        self.vt = ViT(
            image_size=image_size,
            patch_size=patch_size,
            num_classes=num_classes,
            dim=dim,
            depth=depth,
            heads=heads,
            mlp_dim=mlp_dim)

        self.decoder = decoder2(8)
        self.Digcap = DigitCaps(in_num_caps=((image_size//patch_size)**2)*8*8, in_dim_caps=8)

        # All-True patch mask. Registered as a buffer so that it follows the
        # module through .cuda()/.to() instead of being pinned to CUDA at
        # construction time. Non-persistent keeps it out of the state_dict,
        # so existing checkpoints still load.
        grid = image_size // patch_size
        self.register_buffer('mask', torch.ones(1, grid, grid, dtype=torch.bool), persistent=False)

        # Named ``Train`` because ``train`` is already an nn.Module method.
        self.Train = train

        if self.Train:
            print("\nInitializing network weights...")
            initialize_weights(self.vt, self.decoder)

    def forward(self, x):
        """Encode and reconstruct a batch of images.

        Returns:
            Tuple ``(encoded, recons)``: the patch embeddings from the ViT
            (after noise when constructed with ``train=True``) and the
            decoder output.
        """
        b = x.size(0)
        encoded = self.vt(x, self.mask)
        # Controlled by the constructor flag, not by nn.Module.training.
        if self.Train:
            encoded = add_noise(encoded)
        # Split every patch embedding into capsules of length 8.
        encoded1, vectors = self.Digcap(encoded.view(b, encoded.size(1)*8*8, -1))
        recons = self.decoder(encoded1.view(b, -1, 8, 8))

        return encoded, recons

class MDN(nn.Module):
    """Mixture density head: three bias-free linear maps over the features.

    Args:
        input_dim: Size used for the feature axis when reshaping the mean and
            variance outputs.
        out_dim: Feature count predicted per mixture component.
        layer_size: Size of the last dimension of the input tensor.
        coefs: Number of mixture components.
        test: Stored on the instance; not used inside this class.
        sd: Stored on the instance; not used inside this class.
    """

    def __init__(self, input_dim=512, out_dim=512, layer_size=512, coefs=10, test=False, sd=0.5):
        super().__init__()
        self.in_features = input_dim

        # Mixture weights, component means and isotropic independent variances.
        self.pi = nn.Linear(layer_size, coefs, bias=False)
        self.mu = nn.Linear(layer_size, out_dim * coefs, bias=False)
        self.sigma_sq = nn.Linear(layer_size, out_dim * coefs, bias=False)

        self.out_dim = out_dim
        self.coefs = coefs
        self.test = test
        self.sd = sd

    def forward(self, x):
        """Return ``(pi, mu, sigma_sq)`` for a (batch, tokens, features) input."""
        # Clamp from below at machine epsilon before the linear maps.
        floor = np.finfo(float).eps
        x = torch.clamp(x, floor)

        per_component_shape = (x.size(0), x.size(1), self.in_features, -1)

        pi = F.softmax(self.pi(x), dim=-1)
        # Softplus keeps the predicted variances positive.
        sigma_sq = F.softplus(self.sigma_sq(x)).view(*per_component_shape)
        mu = self.mu(x).view(*per_component_shape)
        return pi, mu, sigma_sq


# Train

In [None]:
# Train
########################

def train_model(args):
    """
    Train the anomaly detection model

    Trains a VT_AE auto-encoder together with an MDN head, logs to wandb and
    TensorBoard, and saves the weights of the epoch with the lowest average
    loss to MODEL_SAVE_DIR.

    Args:
        args: Command line arguments. Must provide ``dataset``, ``product``,
            ``epochs``, ``batch_size``, ``learning_rate`` and ``patch_size``.

    Raises:
        ValueError: If ``args.dataset`` is neither 'mvtec' nor 'btad'.
    """
    print(f"\n{'='*20} Training on {args.dataset} dataset {'='*20}")
    print(f"Product: {args.product}, Epochs: {args.epochs}, Learning Rate: {args.learning_rate}")

    # Reject unsupported datasets before any logging run is opened.
    dataset_name = args.dataset.lower()
    if dataset_name not in ('mvtec', 'btad'):
        raise ValueError(f"Dataset {args.dataset} not supported")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
    model_path = os.path.join(MODEL_SAVE_DIR, f'VT_AE_{args.dataset}_{args.product}.pt')
    g_path = os.path.join(MODEL_SAVE_DIR, f'G_estimate_{args.dataset}_{args.product}.pt')

    # Initialize wandb. The API key is read from the environment instead of
    # being stored in the source; with key=None wandb falls back to its own
    # credential lookup.
    wandb.login(key=os.environ.get('WANDB_API_KEY'))
    wandb.init(
        project=f"anomaly-detection-{args.dataset}",
        config={
            "dataset": args.dataset,
            "product": args.product,
            "epochs": args.epochs,
            "batch_size": args.batch_size,
            "learning_rate": args.learning_rate,
            "patch_size": args.patch_size
        }
    )

    # Initialize tensorboard
    writer = SummaryWriter()

    try:
        # Initialize SSIM loss
        ssim_loss = SSIM().to(device)

        # Load dataset and pick its training loader once.
        if dataset_name == 'mvtec':
            dataset = Mvtec(args.batch_size, root=MVTEC_DATA_DIR, product=args.product)
            train_loader = dataset.train_loader
        else:
            train_loader = load_btad_dataset(os.path.join(BTAD_DATA_DIR, 'train.csv'), args.batch_size)

        # Initialize models
        model = VT_AE(patch_size=args.patch_size, train=True).to(device)
        G_estimate = MDN().to(device)

        # Initialize optimizer
        optimizer = Adam(list(model.parameters()) + list(G_estimate.parameters()),
                        lr=args.learning_rate, weight_decay=0.0001)

        # Set models to train mode
        model.train()
        G_estimate.train()

        # Training loop
        minloss = float('inf')
        best_epoch = 0
        global_step = 0  # one TensorBoard step per optimisation step

        print('\nNetwork training started...')
        for epoch in range(args.epochs):
            epoch_losses = []
            epoch_reconstruction_losses = []
            epoch_ssim_losses = []
            epoch_mdn_losses = []

            for images, _ in train_loader:
                # Handle single channel images
                if images.size(1) == 1:
                    images = images.repeat(1, 3, 1, 1)

                # Move the batch to the device once and reuse it below.
                images = images.to(device)

                # Zero gradients
                optimizer.zero_grad()

                # Forward pass through models
                vector, reconstructions = model(images)
                pi, mu, sigma = G_estimate(vector)

                # Calculate losses
                reconstruction_loss = F.mse_loss(reconstructions, images, reduction='mean')  # Reconstruction Loss
                ssim_loss_val = -ssim_loss(images, reconstructions)  # Structural similarity loss
                mdn_loss = mdn_loss_function(vector, mu, sigma, pi)  # Mixture density network loss

                # Total loss
                total_loss = 5 * reconstruction_loss + 0.5 * ssim_loss_val + mdn_loss

                # Backpropagate and update weights
                total_loss.backward()
                optimizer.step()

                # Store losses as Python floats
                total_value = total_loss.item()
                reconstruction_value = reconstruction_loss.item()
                ssim_value = ssim_loss_val.item()
                mdn_value = mdn_loss.item()
                epoch_losses.append(total_value)
                epoch_reconstruction_losses.append(reconstruction_value)
                epoch_ssim_losses.append(ssim_value)
                epoch_mdn_losses.append(mdn_value)

                # Log to tensorboard, one point per batch
                writer.add_scalar('reconstruction_loss', reconstruction_value, global_step)
                writer.add_scalar('ssim_loss', ssim_value, global_step)
                writer.add_scalar('mdn_loss', mdn_value, global_step)
                writer.add_histogram('feature_vectors', vector, global_step)
                global_step += 1

            # Calculate epoch average loss
            avg_epoch_loss = np.mean(epoch_losses)
            avg_reconstruction_loss = np.mean(epoch_reconstruction_losses)
            avg_ssim_loss = np.mean(epoch_ssim_losses)
            avg_mdn_loss = np.mean(epoch_mdn_losses)

            # Log to wandb
            wandb.log({
                "epoch": epoch,
                "total_loss": avg_epoch_loss,
                "reconstruction_loss": avg_reconstruction_loss,
                "ssim_loss": avg_ssim_loss,
                "mdn_loss": avg_mdn_loss
            })

            # Log reconstructed images
            if epoch % 10 == 0:
                # Get a sample batch for visualization
                sample_images, _ = next(iter(train_loader))
                if sample_images.size(1) == 1:
                    sample_images = sample_images.repeat(1, 3, 1, 1)

                with torch.no_grad():
                    _, sample_reconstructions = model(sample_images.to(device))

                # Create comparison grid: originals on top, reconstructions below
                comparison = torch.cat([sample_images[:4], sample_reconstructions[:4].cpu()])
                grid = utils.make_grid(comparison, nrow=4)

                # Log to tensorboard and wandb
                writer.add_image('Reconstructions', grid, epoch)
                wandb.log({"reconstructions": wandb.Image(grid)})

            # Log to tensorboard
            writer.add_scalar('Average_Epoch_Loss', avg_epoch_loss, epoch)

            print(f'Epoch {epoch+1}/{args.epochs}, Loss: {avg_epoch_loss:.6f}')

            # Save best model
            if avg_epoch_loss < minloss:
                minloss = avg_epoch_loss
                best_epoch = epoch
                torch.save(model.state_dict(), model_path)
                torch.save(G_estimate.state_dict(), g_path)
                print(f"Saved best model at epoch {epoch+1} with loss {minloss:.6f}")

                # Save to wandb
                wandb.save(model_path)
                wandb.save(g_path)

            # Print best epoch info
            print(f'Best epoch so far: {best_epoch+1} with loss: {minloss:.6f}')
    finally:
        # Finish logging even when training is interrupted by an error.
        writer.close()
        wandb.finish()

    print(f"\nTraining completed! Best model saved at epoch {best_epoch+1} with loss {minloss:.6f}")
    print(f"Model saved at: {model_path}")
    print(f"MDN model saved at: {g_path}")


# Evaluation

In [None]:
# Evaluation
########################

def calculate_threshold(model, G_estimate, data_loaders, patch_size, fpr_threshold=0.3, image_size=512):
    """
    Calculate threshold for anomaly detection

    Builds a pixel-level score map for every batch, computes the ROC curve of
    those scores against the ground-truth masks, and returns the ROC
    threshold at the last operating point whose false positive rate does not
    exceed ``fpr_threshold``.

    Args:
        model: Trained VT_AE model
        G_estimate: Trained MDN model
        data_loaders: List of data loaders to use for threshold calculation;
            each yields ``(images, masks)`` batches
        patch_size: Patch size used in the model
        fpr_threshold: False positive rate threshold
        image_size: Side length the score maps are upsampled to

    Returns:
        threshold: Calculated threshold
    """
    # Run on whichever device the model already lives on.
    device = next(model.parameters()).device
    grid = image_size // patch_size
    upsample = torch.nn.UpsamplingBilinear2d((image_size, image_size))

    score_chunks = []
    mask_chunks = []

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for data_loader in data_loaders:
            for images, masks in data_loader:
                # Handle single channel images
                if images.size(1) == 1:
                    images = images.repeat(1, 3, 1, 1)

                # Forward pass
                vector, _ = model(images.to(device))
                pi, mu, sigma = G_estimate(vector)

                # Calculate MDN loss (anomaly score)
                mdn_loss = mdn_loss_function(vector, mu, sigma, pi, test=True)

                # One score per patch, laid out on the patch grid.
                patch_scores = mdn_loss.detach().cpu().numpy().reshape(-1, 1, grid, grid)
                score_map = upsample(torch.tensor(patch_scores))
                # Filter is given a NumPy array, matching how test_model calls it.
                score_map = Filter(score_map.numpy(), type=0)

                # Flatten per batch so batches of different sizes can be joined.
                score_chunks.append(np.asarray(score_map).ravel())
                mask_chunks.append(masks.cpu().numpy().ravel())

    # Flatten scores and masks
    scores = np.concatenate(score_chunks)
    labels = np.concatenate(mask_chunks)

    # Calculate ROC curve and find threshold
    fpr, _, thresholds = roc_curve(labels, scores)
    within_budget = np.flatnonzero(fpr <= fpr_threshold)
    threshold = thresholds[within_budget[-1]]

    return threshold

def _score_batch(model, G_estimate, images, patch_size, upsample, device):
    """Run one batch through both networks and build its anomaly score map.

    Returns:
        Tuple ``(mdn_loss, score_map)``: the tensor returned by
        ``mdn_loss_function(..., test=True)`` and the result of ``Filter``
        applied to the upsampled per-patch scores.
    """
    vector, _ = model(images.to(device))
    pi, mu, sigma = G_estimate(vector)
    mdn_loss = mdn_loss_function(vector, mu, sigma, pi, test=True)

    grid = 512 // patch_size
    patch_scores = mdn_loss.detach().cpu().numpy().reshape(-1, 1, grid, grid)
    score_map = Filter(upsample(torch.tensor(patch_scores)).numpy(), type=0)
    return mdn_loss, score_map


def test_model(args):
    """
    Test the anomaly detection model

    Loads the saved VT_AE and MDN weights, derives a threshold from the test
    loaders, scores the normal and anomalous test images, and writes metrics
    and plots to a results directory under MODEL_SAVE_DIR and to wandb.

    Args:
        args: Command line arguments. Must provide ``dataset``, ``product``
            and ``patch_size``.

    Raises:
        NotImplementedError: If ``args.dataset`` is 'btad'.
        ValueError: If ``args.dataset`` is not a supported dataset.
    """
    print(f"\n{'='*20} Testing on {args.dataset} dataset {'='*20}")
    print(f"Product: {args.product}, Patch Size: {args.patch_size}")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Initialize wandb
    wandb.login()
    wandb.init(
        project=f"anomaly-detection-{args.dataset}-test",
        config={
            "dataset": args.dataset,
            "product": args.product,
            "patch_size": args.patch_size
        }
    )

    # The try/finally guarantees the wandb run is closed on every exit path,
    # including the early return below and any exception.
    try:
        # Create results directory
        results_dir = os.path.join(MODEL_SAVE_DIR, f"results_{args.dataset}_{args.product}")
        os.makedirs(results_dir, exist_ok=True)

        # Load dataset
        dataset_name = args.dataset.lower()
        if dataset_name == 'mvtec':
            data = Mvtec(1, root=MVTEC_DATA_DIR, product=args.product)
        elif dataset_name == 'btad':
            # For BTAD, we need different handling for test
            raise NotImplementedError("BTAD test data loading not implemented yet")
        else:
            raise ValueError(f"Dataset {args.dataset} not supported")

        # Load models
        model = VT_AE(patch_size=args.patch_size, train=False).to(device)
        G_estimate = MDN().to(device)

        # Load saved weights
        model_path = os.path.join(MODEL_SAVE_DIR, f'VT_AE_{args.dataset}_{args.product}.pt')
        g_path = os.path.join(MODEL_SAVE_DIR, f'G_estimate_{args.dataset}_{args.product}.pt')

        try:
            model.load_state_dict(torch.load(model_path, map_location=device))
            G_estimate.load_state_dict(torch.load(g_path, map_location=device))
            print(f"Models loaded from {model_path} and {g_path}")
        except Exception as e:
            # Best effort: report the problem and stop without raising.
            print(f"Error loading models: {e}")
            return

        # Set models to eval mode
        model.eval()
        G_estimate.eval()

        upsample = torch.nn.UpsamplingBilinear2d((512, 512))

        # Lists to store results
        normal_losses = []
        anomaly_losses = []
        normal_scores = []
        anomaly_scores = []

        # Lists to store full pixel-level data
        full_score_maps = []
        full_masks = []

        all_y_true = []
        all_y_pred = []

        with torch.no_grad():
            # Calculate threshold
            print("Calculating threshold...")
            test_loaders = [data.test_norm_loader, data.test_anom_loader]
            threshold = calculate_threshold(model, G_estimate, test_loaders, args.patch_size)
            print(f"Threshold: {threshold}")

            # Log threshold to wandb
            wandb.config.update({"threshold": threshold})

            # Evaluate on test data
            print("Evaluating on test data...")

            # Process normal test data
            for images, masks in data.test_norm_loader:
                if images.size(1) == 1:
                    images = images.repeat(1, 3, 1, 1)

                mdn_loss, score_map = _score_batch(
                    model, G_estimate, images, args.patch_size, upsample, device)
                normal_losses.append(mdn_loss.sum().item())

                # Store full score map and corresponding zero mask for normal samples
                full_score_maps.append(score_map[0][0])
                full_masks.append(np.zeros_like(score_map[0][0]))

                score_val = np.max(score_map)
                normal_scores.append(score_val)

                # Binary prediction (0: normal, 1: anomaly)
                all_y_true.append(0)  # Ground truth: normal
                all_y_pred.append(1 if score_val > threshold else 0)  # Prediction based on threshold

            # Process anomalous test data
            for idx, (images, masks) in enumerate(data.test_anom_loader):
                if images.size(1) == 1:
                    images = images.repeat(1, 3, 1, 1)

                mdn_loss, score_map = _score_batch(
                    model, G_estimate, images, args.patch_size, upsample, device)
                anomaly_losses.append(mdn_loss.sum().item())

                # Store full score map and corresponding mask
                full_score_maps.append(score_map[0][0])
                full_masks.append(masks.squeeze(0).squeeze(0).cpu().numpy())

                score_val = np.max(score_map)
                anomaly_scores.append(score_val)

                # Binary prediction (0: normal, 1: anomaly)
                all_y_true.append(1)  # Ground truth: anomaly
                all_y_pred.append(1 if score_val > threshold else 0)  # Prediction based on threshold

                # Visualize some results and save
                if idx % 5 == 0:
                    # Enhanced visualization
                    save_path = os.path.join(results_dir, f"anomaly_{idx}.png")
                    plot_enhanced(images, masks, score_map[0][0], threshold, save_path)

                    # Region visualization
                    region_path = os.path.join(results_dir, f"anomaly_regions_{idx}.png")
                    visualize_regions(images, score_map[0][0], threshold, min_area=100, save_path=region_path)

                    # Log to wandb
                    wandb.log({
                        f"anomaly_sample_{idx}": wandb.Image(save_path),
                        f"anomaly_regions_{idx}": wandb.Image(region_path)
                    })

        # Calculate metrics
        print("\nCalculating evaluation metrics...")

        # Convert to numpy arrays
        y_true = np.array(all_y_true)
        y_pred = np.array(all_y_pred)
        roc_labels = np.concatenate((np.zeros(len(normal_losses)), np.ones(len(anomaly_losses))))
        roc_scores = np.concatenate((normal_losses, anomaly_losses))

        # Image-level ROC AUC (scored by the summed MDN loss per image)
        image_auc = roc_auc_score(roc_labels, roc_scores)

        # Pixel-level ROC AUC over all score maps and masks.
        # NOTE(review): the messages below label this value "PRO", but it is
        # computed with roc_auc_score over pixels - confirm the intended metric.
        pixel_scores = np.concatenate([s.flatten() for s in full_score_maps])
        pixel_masks = np.concatenate([m.flatten() for m in full_masks])
        pro_score = roc_auc_score(pixel_masks, pixel_scores)

        # Precision-Recall AUC
        precision, recall, _ = precision_recall_curve(roc_labels, roc_scores)
        pr_auc = auc(recall, precision)

        # Additional metrics
        f1 = f1_score(y_true, y_pred)
        accuracy = accuracy_score(y_true, y_pred)
        precision_score_val = precision_score(y_true, y_pred)
        recall_score_val = recall_score(y_true, y_pred)

        # Generate ROC curve
        fpr, tpr, _ = roc_curve(roc_labels, roc_scores)

        # Plot and save ROC curve
        roc_path = os.path.join(results_dir, "roc_curve.png")
        plot_roc_curve(fpr, tpr, image_auc, roc_path)

        # Plot and save PR curve
        pr_path = os.path.join(results_dir, "pr_curve.png")
        plot_precision_recall_curve(precision, recall, pr_auc, pr_path)

        # Plot and save confusion matrix
        cm_path = os.path.join(results_dir, "confusion_matrix.png")
        plot_confusion_matrix(y_true, y_pred, cm_path)

        # Plot and save score distributions
        dist_path = os.path.join(results_dir, "score_distributions.png")
        plot_score_distributions(normal_scores, anomaly_scores, threshold, dist_path)

        # Log results to wandb
        wandb.log({
            "image_level_auc": image_auc,
            "pixel_level_auc": pro_score,
            "pr_auc": pr_auc,
            "f1_score": f1,
            "accuracy": accuracy,
            "precision": precision_score_val,
            "recall": recall_score_val,
            "roc_curve": wandb.Image(roc_path),
            "pr_curve": wandb.Image(pr_path),
            "confusion_matrix": wandb.Image(cm_path),
            "score_distributions": wandb.Image(dist_path)
        })

        # Log classification report
        cls_report = classification_report(y_true, y_pred, target_names=['Normal', 'Anomaly'])
        with open(os.path.join(results_dir, "classification_report.txt"), "w") as f:
            f.write(cls_report)

        print(f"\nResults for {args.dataset} - {args.product}:")
        print(f"Image-level AUC: {image_auc:.4f}")
        print(f"Pixel-level AUC (PRO): {pro_score:.4f}")
        print(f"Precision-Recall AUC: {pr_auc:.4f}")
        print(f"F1 Score: {f1:.4f}")
        print(f"Accuracy: {accuracy:.4f}")
        print(f"Precision: {precision_score_val:.4f}")
        print(f"Recall: {recall_score_val:.4f}")
        print("\nClassification Report:")
        print(cls_report)
    finally:
        # Finish wandb
        wandb.finish()

# Main

In [None]:
def main():
    """Entry point: build the run configuration, seed the RNGs, and dispatch.

    In Google Colab the argument defaults are used as-is. Elsewhere the
    command line is parsed; inside any other IPython/Jupyter kernel, unknown
    arguments (such as the kernel's own ``-f <connection file>``) are ignored
    instead of aborting the run.
    """
    # get_ipython only exists inside IPython/Jupyter; a plain interpreter
    # raises NameError for it.
    try:
        shell = get_ipython()
    except NameError:
        shell = None
    in_colab = 'google.colab' in str(shell)

    parser = argparse.ArgumentParser(description='Anomaly Detection with Vision Transformer')

    # Dataset parameters
    parser.add_argument('--dataset', type=str, default='mvtec', choices=['mvtec', 'btad'],
                        help='Dataset to use (mvtec or btad)')
    parser.add_argument('--data_path', type=str, default=MVTEC_DATA_DIR,
                        help='Path to dataset (defaults to Google Drive path)')
    parser.add_argument('--product', type=str, default='bottle',
                        help='Product category for MVTec dataset')

    # Model parameters
    parser.add_argument('--patch_size', type=int, default=64,
                        help='Patch size for Vision Transformer')

    # Training parameters
    parser.add_argument('--mode', type=str, default='train', choices=['train', 'test'],
                        help='Mode: train or test')
    parser.add_argument('--batch_size', type=int, default=8,
                        help='Batch size for training')
    parser.add_argument('--learning_rate', type=float, default=0.0001,
                        help='Learning rate')
    parser.add_argument('--epochs', type=int, default=100,
                        help='Number of epochs to train')

    if in_colab:
        # Parsing an empty argument list yields exactly the defaults above.
        args = parser.parse_args([])
        print("\nRunning in Colab environment. Using default values:")
        print(f"Dataset: {args.dataset}")
        print(f"Product: {args.product}")
        print(f"Mode: {args.mode}")
        print(f"Patch size: {args.patch_size}")
        print(f"Batch size: {args.batch_size}")
        print(f"Learning rate: {args.learning_rate}")
        print(f"Epochs: {args.epochs}")
        print("\nTo change these values, set them directly before running:")
        print("Example: args.dataset = 'btad'; args.product = '01'")

    else:
        if shell is not None:
            # Notebook kernels put their own flags into sys.argv.
            args, _ignored = parser.parse_known_args()
        else:
            args = parser.parse_args()

        # Print configuration
        print("\nConfiguration:")
        print(f"Mode: {args.mode}")
        print(f"Dataset: {args.dataset}")
        print(f"Data path: {args.data_path}")
        print(f"Product: {args.product}")
        print(f"Patch size: {args.patch_size}")
        if args.mode == 'train':
            print(f"Batch size: {args.batch_size}")
            print(f"Learning rate: {args.learning_rate}")
            print(f"Epochs: {args.epochs}")

    # Set random seeds for reproducibility
    random.seed(123)
    np.random.seed(123)
    torch.manual_seed(123)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(123)

    # Check CUDA availability
    if torch.cuda.is_available():
        print(f"Using GPU: {torch.cuda.get_device_name(0)}")
    else:
        print("CUDA is not available. Using CPU.")

    # Run appropriate function based on mode
    if args.mode == 'train':
        train_model(args)
    elif args.mode == 'test':
        test_model(args)

# Function to run directly in Colab
def run_in_colab(dataset='mvtec', product='bottle', mode='train', patch_size=64,
                batch_size=8, learning_rate=0.0001, epochs=100):
    """
    Run training or testing directly from a notebook cell

    Args:
        dataset: Dataset name ('mvtec' or 'btad')
        product: Product category
                MVTec: 'bottle', 'cable', etc.
                BTAD: '01', '02', '03'
        mode: 'train' or 'test'
        patch_size: Patch size for Vision Transformer
        batch_size: Batch size for training
        learning_rate: Learning rate
        epochs: Number of epochs
    """
    # Anything other than MVTec falls back to the BTAD directory.
    if dataset.lower() == 'mvtec':
        data_path = MVTEC_DATA_DIR
    else:
        data_path = BTAD_DATA_DIR

    # Same attribute set that the command-line parser would produce.
    args = argparse.Namespace(
        dataset=dataset,
        data_path=data_path,
        product=product,
        patch_size=patch_size,
        mode=mode,
        batch_size=batch_size,
        learning_rate=learning_rate,
        epochs=epochs,
    )

    # Print configuration
    print("\nConfiguration:")
    summary = (
        ("Dataset", args.dataset),
        ("Product", args.product),
        ("Mode", args.mode),
        ("Patch size", args.patch_size),
        ("Batch size", args.batch_size),
        ("Learning rate", args.learning_rate),
        ("Epochs", args.epochs),
    )
    for label, value in summary:
        print(f"{label}: {value}")

    # Set random seeds for reproducibility
    seed = 123
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    cuda_available = torch.cuda.is_available()
    if cuda_available:
        torch.cuda.manual_seed_all(seed)

    # Report which device will be used
    if cuda_available:
        print(f"Using GPU: {torch.cuda.get_device_name(0)}")
    else:
        print("CUDA is not available. Using CPU.")

    # Run appropriate function based on mode
    if args.mode == 'train':
        train_model(args)
        return
    if args.mode == 'test':
        test_model(args)
        return
    print(f"Invalid mode: {args.mode}. Choose 'train' or 'test'.")

# Notebook entry point: the uncommented call below runs as soon as this cell
# is executed. Swap which of the two lines is commented out to train instead
# of test.
#run_in_colab(dataset='mvtec', product='bottle', mode='train')
run_in_colab(dataset='mvtec', product='bottle', mode='test')
