# Datasets

## MVTec LOCO Dataset

In [None]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image

class MVTecDataset(Dataset):
    """MVTec LOCO AD dataset.

    What ``__getitem__`` yields depends on ``phase``:

    * ``'train'`` / ``'validation'``: images from ``<root>/<category>/<phase>/good``;
      returns the (optionally transformed) image.
    * ``'combined_validation'``: images from ``validation/good`` followed by
      ``test/good``; returns the (optionally transformed) image.
    * ``'test'``: images from ``test/<anomaly_type>`` for each requested anomaly type;
      returns a dict with ``'image'``, ``'gt_mask'``, ``'anomaly_type'`` and, when
      ``clip_transform`` is given, ``'clip_image'``.
    """

    ALL_CATEGORIES = ('breakfast_box', 'juice_bottle', 'pushpins', 'screw_bag', 'splicing_connectors')
    _PHASES = ('train', 'validation', 'combined_validation', 'test')
    _DEFAULT_ANOMALY_TYPES = ('good', 'logical_anomalies', 'structural_anomalies')

    def __init__(self, root_dir, category="all", phase="train", transform=None, anomaly_types=None, clip_transform=None):
        """
        Args:
            root_dir (str): Root directory of the MVTec LOCO dataset.
            category (str): A single category name, or ``'all'`` to load every category.
            phase (str): One of ``'train'``, ``'validation'``, ``'combined_validation'`` or ``'test'``.
            transform (callable, optional): Transform applied to each image and, in the test
                phase, to the ground-truth mask of anomalous images.
            anomaly_types (list[str], optional): Test phase only. Sub-folders of ``test`` to load;
                defaults to ``['good', 'logical_anomalies', 'structural_anomalies']``.
                Missing sub-folders are skipped. Ignored for other phases.
            clip_transform (callable, optional): Test phase only. If given, the image transformed
                with it is additionally returned under ``'clip_image'``.

        Raises:
            ValueError: If ``phase`` is not one of the supported values.
        """
        if phase not in self._PHASES:
            msg = f"Unknown phase {phase!r}; expected one of {self._PHASES}"
            raise ValueError(msg)
        self.root_dir = root_dir
        self.transform = transform
        self.phase = phase
        self.anomaly_types = anomaly_types
        self.categories = [category] if category != "all" else list(self.ALL_CATEGORIES)
        self.data = []
        self.clip_transform = clip_transform
        self._load_dataset()

    @staticmethod
    def _list_pngs(directory):
        """Return the paths of all ``.png`` files directly inside ``directory``, sorted by name."""
        # Sorting makes the index -> file mapping reproducible (os.listdir order is arbitrary).
        return [os.path.join(directory, name) for name in sorted(os.listdir(directory)) if name.endswith('.png')]

    @staticmethod
    def _mask_path(img_path, anomaly_type):
        """Return the ground-truth mask path for a test image.

        ``<root>/<category>/test/<anomaly_type>/<stem>.png`` maps to
        ``<root>/<category>/ground_truth/<anomaly_type>/<stem>/000.png``. The path is rebuilt
        from its components, so ``'test'`` or ``'.png'`` appearing elsewhere in the root
        directory is left untouched.
        """
        stem = os.path.splitext(os.path.basename(img_path))[0]
        category_root = os.path.dirname(os.path.dirname(os.path.dirname(img_path)))
        return os.path.join(category_root, 'ground_truth', anomaly_type, stem, '000.png')

    def _load_dataset(self):
        """Populate ``self.data`` with image paths (or ``(path, anomaly_type)`` tuples in the test phase)."""
        for category in self.categories:
            category_root = os.path.join(self.root_dir, category)
            if self.phase in ('train', 'validation'):
                self.data.extend(self._list_pngs(os.path.join(category_root, self.phase, 'good')))
            elif self.phase == 'combined_validation':
                for split in ('validation', 'test'):
                    self.data.extend(self._list_pngs(os.path.join(category_root, split, 'good')))
            else:  # 'test' (phase validated in __init__)
                if self.anomaly_types is not None:
                    anomaly_list = self.anomaly_types
                else:
                    anomaly_list = list(self._DEFAULT_ANOMALY_TYPES)
                for anomaly_type in anomaly_list:
                    anomaly_dir = os.path.join(category_root, 'test', anomaly_type)
                    if os.path.exists(anomaly_dir):
                        self.data.extend((path, anomaly_type) for path in self._list_pngs(anomaly_dir))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if self.phase != 'test':
            image = Image.open(self.data[idx]).convert('RGB')
            return self.transform(image) if self.transform else image

        img_path, anomaly_type = self.data[idx]
        pil_image = Image.open(img_path).convert('RGB')
        image = self.transform(pil_image) if self.transform else pil_image
        # "good" images get an all-zero mask sized like the transformed image when known.
        mask_size = image.shape[-2:] if isinstance(image, torch.Tensor) else (256, 256)
        gt_mask = torch.zeros(mask_size)
        if anomaly_type != 'good':
            mask = Image.open(self._mask_path(img_path, anomaly_type)).convert('L')
            if self.transform:
                gt_mask = self.transform(mask)
            gt_mask = gt_mask.squeeze()
        sample = {
            'image': image,
            'gt_mask': gt_mask,
            'anomaly_type': anomaly_type,
        }
        if self.clip_transform:
            # Reuse the already-decoded image instead of reading the file a second time.
            sample['clip_image'] = self.clip_transform(pil_image)
        return sample

## ImageNet Dataset

In [None]:
import os
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image

class ImagenetDataset(Dataset):
    """Unlabeled image dataset over ``<root_dir>/train/<class_dir>/*.JPEG``.

    Used as the auxiliary ImageNet-style image source: class directories are
    flattened and ``__getitem__`` returns only the (optionally transformed) image.
    """

    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir (str): Dataset root containing a ``train`` directory with one
                sub-directory per class.
            transform (callable, optional): Transform applied to each loaded image.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.data = []
        self._load_dataset()

    def _load_dataset(self):
        """Collect the paths of all ``.JPEG`` files in the class sub-directories of ``train``."""
        train_path = os.path.join(self.root_dir, 'train')
        # Sorted listings make the index -> file mapping reproducible (os.listdir order is arbitrary).
        for category in sorted(os.listdir(train_path)):
            category_path = os.path.join(train_path, category)
            if not os.path.isdir(category_path):
                continue
            self.data.extend(
                os.path.join(category_path, img_name)
                for img_name in sorted(os.listdir(category_path))
                if img_name.endswith('.JPEG')
            )

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path = self.data[idx]
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image

## Definitions

In [None]:
# Images per batch for the DataLoaders created in this notebook.
batch_size: int = 6

In [None]:
image_size = (256, 256)

# Location and category of the MVTec LOCO data, shared by every split below.
mvtec_root = 'mvtec_loco_anomaly_detection'
mvtec_category = 'breakfast_box'

# Dataset and DataLoader setup
# MVTec images (and test masks) are only resized and converted to tensors.
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
])

# Auxiliary ImageNet images: resize to twice image_size, grayscale with p=0.3,
# then center-crop back to image_size.
data_transforms_imagenet = transforms.Compose(
    [
        transforms.Resize((image_size[0] * 2, image_size[1] * 2)),
        transforms.RandomGrayscale(p=0.3),
        transforms.CenterCrop((image_size[0], image_size[1])),
        transforms.ToTensor(),
    ],
)

train_dataset = MVTecDataset(root_dir=mvtec_root, category=mvtec_category, phase='train', transform=transform)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

validation_dataset = MVTecDataset(root_dir=mvtec_root, category=mvtec_category, phase='validation', transform=transform)
combined_validation_dataset = MVTecDataset(root_dir=mvtec_root, category=mvtec_category, phase='combined_validation', transform=transform)
validation_loader = DataLoader(dataset=validation_dataset, batch_size=batch_size, shuffle=False)
combined_validation_loader = DataLoader(dataset=combined_validation_dataset, batch_size=batch_size, shuffle=False)

# Evaluation order does not change the ROC-AUC, so keep the test order deterministic.
test_dataset = MVTecDataset(root_dir=mvtec_root, category=mvtec_category, phase='test', transform=transform)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

imagenet_dataset = ImagenetDataset(root_dir='imagenette2', transform=data_transforms_imagenet)
imagenet_loader = DataLoader(dataset=imagenet_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)

# EfficientAD

## Pytorch definition

In [None]:
"""Torch model for student, teacher and autoencoder model in EfficientAd."""

import logging
import math
from enum import Enum

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F  # noqa: N812
from torchvision import transforms

# Module-level logger for the model utilities below (map_norm_quantiles logs through it).
logger = logging.getLogger(__name__)


def imagenet_norm_batch(x: torch.Tensor) -> torch.Tensor:
    """Normalize a batch of images with the ImageNet channel mean and std.

    Args:
        x (torch.Tensor): Input with RGB channels on dim 1, e.g. ``(N, 3, H, W)``.
            An unbatched ``(3, H, W)`` tensor broadcasts to ``(1, 3, H, W)``.

    Returns:
        torch.Tensor: ``(x - mean) / std`` on ``x``'s device. Floating-point inputs keep
        their dtype; integer inputs are computed in float32.
    """
    # Build the constants in x's floating dtype so e.g. float16 inputs are not silently
    # promoted to float32; integer inputs fall back to float32 (an integer dtype would
    # truncate the constants to zero).
    dtype = x.dtype if x.is_floating_point() else torch.float32
    mean = torch.tensor([0.485, 0.456, 0.406], device=x.device, dtype=dtype)[None, :, None, None]
    std = torch.tensor([0.229, 0.224, 0.225], device=x.device, dtype=dtype)[None, :, None, None]
    return (x - mean) / std


def reduce_tensor_elems(tensor: torch.Tensor, m: int = 2**24) -> torch.Tensor:
    """Flatten ``tensor`` and, if it is too large, subsample it to ``m`` elements.

    ``torch.quantile`` accepts at most 2**24 elements
    (https://github.com/pytorch/pytorch/blob/b9f81a483a7879cd3709fd26bcec5f1ee33577e6/aten/src/ATen/native/Sorting.cpp#L291),
    so larger inputs are replaced by a random subset of ``m`` of their elements.

    Args:
        tensor (torch.Tensor): Tensor of any shape.
        m (int): Maximum number of elements to keep. Defaults to ``2**24``.

    Returns:
        torch.Tensor: 1-D tensor with ``min(tensor.numel(), m)`` elements.
    """
    flat = tensor.flatten()
    total = flat.shape[0]
    if total <= m:
        return flat
    # Random selection without replacement, generated on the data's own device.
    keep = torch.randperm(total, device=flat.device)[:m]
    return flat[keep]


class EfficientAdModelSize(str, Enum):
    """Size variants of the EfficientAd teacher/student networks.

    Because of the ``str`` mixin, members compare equal to their string values,
    e.g. ``EfficientAdModelSize.M == "medium"``.
    """

    M = "medium"  # selects MediumPatchDescriptionNetwork
    S = "small"  # selects SmallPatchDescriptionNetwork


class SmallPatchDescriptionNetwork(nn.Module):
    """Small patch description network (PDN).

    Layout: ImageNet normalization -> conv 4x4 -> ReLU -> avg-pool 2x2 -> conv 4x4 -> ReLU
    -> avg-pool 2x2 -> conv 3x3 -> ReLU -> conv 4x4.

    Args:
        out_channels (int): number of channels produced by the final convolution
        padding (bool): if truthy, pad the convolutional and pooling layers; otherwise no
            padding is used anywhere. Defaults to ``False``.
    """

    def __init__(self, out_channels: int, padding: bool = False) -> None:
        super().__init__()
        on = 1 if padding else 0  # multiplier that switches every padding amount on or off
        # Attribute names and creation order define the state_dict keys (pretrained
        # weights) and the parameter order, so they are kept unchanged.
        self.conv1 = nn.Conv2d(3, 128, kernel_size=4, stride=1, padding=3 * on)
        self.conv2 = nn.Conv2d(128, 256, kernel_size=4, stride=1, padding=3 * on)
        self.conv3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=on)
        self.conv4 = nn.Conv2d(256, out_channels, kernel_size=4, stride=1, padding=0)
        self.avgpool1 = nn.AvgPool2d(kernel_size=2, stride=2, padding=on)
        self.avgpool2 = nn.AvgPool2d(kernel_size=2, stride=2, padding=on)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on a batch of RGB images.

        Args:
            x (torch.Tensor): Input batch.

        Returns:
            torch.Tensor: Feature map with ``out_channels`` channels.
        """
        h = imagenet_norm_batch(x)
        # Two conv -> ReLU -> pool stages, then two plain convolutions.
        for conv, pool in ((self.conv1, self.avgpool1), (self.conv2, self.avgpool2)):
            h = pool(F.relu(conv(h)))
        h = F.relu(self.conv3(h))
        return self.conv4(h)


class MediumPatchDescriptionNetwork(nn.Module):
    """Medium patch description network (PDN).

    Layout: ImageNet normalization -> conv 4x4 -> ReLU -> avg-pool 2x2 -> conv 4x4 -> ReLU
    -> avg-pool 2x2 -> conv 1x1 -> ReLU -> conv 3x3 -> ReLU -> conv 4x4 -> ReLU -> conv 1x1.

    Args:
        out_channels (int): number of channels produced by the last two convolutions
        padding (bool): if truthy, pad the convolutional and pooling layers; otherwise no
            padding is used anywhere. Defaults to ``False``.
    """

    def __init__(self, out_channels: int, padding: bool = False) -> None:
        super().__init__()
        on = 1 if padding else 0  # multiplier that switches every padding amount on or off
        # Attribute names and creation order define the state_dict keys (pretrained
        # weights) and the parameter order, so they are kept unchanged.
        self.conv1 = nn.Conv2d(3, 256, kernel_size=4, stride=1, padding=3 * on)
        self.conv2 = nn.Conv2d(256, 512, kernel_size=4, stride=1, padding=3 * on)
        self.conv3 = nn.Conv2d(512, 512, kernel_size=1, stride=1, padding=0)
        self.conv4 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=on)
        self.conv5 = nn.Conv2d(512, out_channels, kernel_size=4, stride=1, padding=0)
        self.conv6 = nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.avgpool1 = nn.AvgPool2d(kernel_size=2, stride=2, padding=on)
        self.avgpool2 = nn.AvgPool2d(kernel_size=2, stride=2, padding=on)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on a batch of RGB images.

        Args:
            x (torch.Tensor): Input batch.

        Returns:
            torch.Tensor: Feature map with ``out_channels`` channels.
        """
        h = imagenet_norm_batch(x)
        # Two conv -> ReLU -> pool stages ...
        for conv, pool in ((self.conv1, self.avgpool1), (self.conv2, self.avgpool2)):
            h = pool(F.relu(conv(h)))
        # ... three conv -> ReLU layers, and a final linear 1x1 projection.
        for conv in (self.conv3, self.conv4, self.conv5):
            h = F.relu(conv(h))
        return self.conv6(h)


class Encoder(nn.Module):
    """Autoencoder encoder.

    Five stride-2 4x4 convolutions (each followed by ReLU) and a final 8x8 convolution
    without activation; all layers have 64 output channels except the first two (32).
    """

    def __init__(self) -> None:
        super().__init__()
        # Names and order define the state_dict keys, so they are kept unchanged.
        self.enconv1 = nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1)
        self.enconv2 = nn.Conv2d(32, 32, kernel_size=4, stride=2, padding=1)
        self.enconv3 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)
        self.enconv4 = nn.Conv2d(64, 64, kernel_size=4, stride=2, padding=1)
        self.enconv5 = nn.Conv2d(64, 64, kernel_size=4, stride=2, padding=1)
        self.enconv6 = nn.Conv2d(64, 64, kernel_size=8, stride=1, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a batch of images into the latent map.

        Args:
            x (torch.Tensor): Input batch.

        Returns:
            torch.Tensor: Latent map with 64 channels.
        """
        for layer in (self.enconv1, self.enconv2, self.enconv3, self.enconv4, self.enconv5):
            x = F.relu(layer(x))
        return self.enconv6(x)


class Decoder(nn.Module):
    """Autoencoder Decoder model.

    Upsamples the encoder's latent map through six stages of
    bilinear interpolation -> 4x4 conv -> ReLU -> dropout(p=0.2). A final
    interpolation to the target feature-map size is followed by a 3x3 conv + ReLU
    and a 3x3 output conv.

    Args:
        out_channels (int): number of channels of the final convolution
        padding (int): truthy when the PDNs use padding; selects the final output size
            (see ``forward``)
    """

    def __init__(self, out_channels: int, padding: int, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.padding = padding
        # use ceil to match output shape of PDN
        # (the ceil is applied in forward() when computing last_upsample)
        # 4x4 convs with padding=2 grow each spatial dim by 1 pixel.
        self.deconv1 = nn.Conv2d(64, 64, kernel_size=4, stride=1, padding=2)
        self.deconv2 = nn.Conv2d(64, 64, kernel_size=4, stride=1, padding=2)
        self.deconv3 = nn.Conv2d(64, 64, kernel_size=4, stride=1, padding=2)
        self.deconv4 = nn.Conv2d(64, 64, kernel_size=4, stride=1, padding=2)
        self.deconv5 = nn.Conv2d(64, 64, kernel_size=4, stride=1, padding=2)
        self.deconv6 = nn.Conv2d(64, 64, kernel_size=4, stride=1, padding=2)
        # 3x3 convs with padding=1 keep the spatial size.
        self.deconv7 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.deconv8 = nn.Conv2d(64, out_channels, kernel_size=3, stride=1, padding=1)
        self.dropout1 = nn.Dropout(p=0.2)
        self.dropout2 = nn.Dropout(p=0.2)
        self.dropout3 = nn.Dropout(p=0.2)
        self.dropout4 = nn.Dropout(p=0.2)
        self.dropout5 = nn.Dropout(p=0.2)
        self.dropout6 = nn.Dropout(p=0.2)

    def forward(self, x: torch.Tensor, image_size: tuple[int, int] | torch.Size) -> torch.Tensor:
        """Perform a forward pass through the network.

        Args:
            x (torch.Tensor): Latent batch from the encoder.
            image_size (tuple): (height, width) of the input images; all intermediate
                and final spatial sizes are derived from it.

        Returns:
            torch.Tensor: Output with ``out_channels`` channels and spatial size
            ``last_upsample`` (``ceil(size / 4)`` with padding, ``ceil(size / 4) - 8``
            without).
        """
        # Final spatial size, intended to match the PDN feature-map size for the
        # padded / unpadded configuration.
        last_upsample = (
            math.ceil(image_size[0] / 4) if self.padding else math.ceil(image_size[0] / 4) - 8,
            math.ceil(image_size[1] / 4) if self.padding else math.ceil(image_size[1] / 4) - 8,
        )
        # Each stage interpolates to a fixed fraction of the image size; the following
        # 4x4 conv then adds 1 pixel (e.g. size // 64 - 1 becomes size // 64).
        x = F.interpolate(x, size=(image_size[0] // 64 - 1, image_size[1] // 64 - 1), mode="bilinear")
        x = F.relu(self.deconv1(x))
        x = self.dropout1(x)
        x = F.interpolate(x, size=(image_size[0] // 32, image_size[1] // 32), mode="bilinear")
        x = F.relu(self.deconv2(x))
        x = self.dropout2(x)
        x = F.interpolate(x, size=(image_size[0] // 16 - 1, image_size[1] // 16 - 1), mode="bilinear")
        x = F.relu(self.deconv3(x))
        x = self.dropout3(x)
        x = F.interpolate(x, size=(image_size[0] // 8, image_size[1] // 8), mode="bilinear")
        x = F.relu(self.deconv4(x))
        x = self.dropout4(x)
        x = F.interpolate(x, size=(image_size[0] // 4 - 1, image_size[1] // 4 - 1), mode="bilinear")
        x = F.relu(self.deconv5(x))
        x = self.dropout5(x)
        x = F.interpolate(x, size=(image_size[0] // 2 - 1, image_size[1] // 2 - 1), mode="bilinear")
        x = F.relu(self.deconv6(x))
        x = self.dropout6(x)
        # Resample to the target size; the size-preserving 3x3 convs keep it.
        x = F.interpolate(x, size=last_upsample, mode="bilinear")
        x = F.relu(self.deconv7(x))
        return self.deconv8(x)


class AutoEncoder(nn.Module):
    """EfficientAd autoencoder: ImageNet normalization, then ``Encoder``, then ``Decoder``.

    Args:
       out_channels (int): number of channels of the decoder output
       padding (int): forwarded to ``Decoder``; selects its output resolution
    """

    def __init__(self, out_channels: int, padding: int, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Registration order (encoder, then decoder) fixes the state_dict/parameter order.
        self.encoder = Encoder()
        self.decoder = Decoder(out_channels, padding)

    def forward(self, x: torch.Tensor, image_size: tuple[int, int] | torch.Size) -> torch.Tensor:
        """Reconstruct feature maps for a batch of images.

        Args:
            x (torch.Tensor): Input batch.
            image_size (tuple): size of input images, passed to the decoder.

        Returns:
            torch.Tensor: Decoder output.
        """
        latent = self.encoder(imagenet_norm_batch(x))
        return self.decoder(latent, image_size)


class EfficientAdModel(nn.Module):
    """EfficientAd model.

    Combines a teacher PDN (always evaluated under ``torch.no_grad``), a student PDN with
    ``2 * teacher_out_channels`` output channels, and an autoencoder. The first half of
    the student channels is trained against the teacher, the second half against the
    autoencoder. ``mean_std`` holds per-channel teacher statistics and ``quantiles`` holds
    anomaly-map normalization quantiles. Both start all-zero and are only applied once
    any entry is non-zero (see ``is_set``).

    Args:
        teacher_out_channels (int): number of output channels of the pre-trained teacher model
        model_size (EfficientAdModelSize): size of student and teacher model
        padding (bool): use padding in convolutional layers
            Defaults to ``False``.
        pad_maps (bool): relevant only if ``padding`` is ``False``. In that case ``pad_maps=True``
            pads the output anomaly maps so that their size matches the size in the
            ``padding=True`` case. Defaults to ``True``.

    Raises:
        ValueError: If ``model_size`` is not a known size.
    """

    def __init__(
        self,
        teacher_out_channels: int,
        model_size: EfficientAdModelSize = EfficientAdModelSize.S,
        padding: bool = False,
        pad_maps: bool = True,
    ) -> None:
        super().__init__()

        self.pad_maps = pad_maps
        # Remembered so forward() only pads the maps when the PDNs are unpadded.
        self.padding = padding
        self.teacher: MediumPatchDescriptionNetwork | SmallPatchDescriptionNetwork
        self.student: MediumPatchDescriptionNetwork | SmallPatchDescriptionNetwork

        if model_size == EfficientAdModelSize.M:
            self.teacher = MediumPatchDescriptionNetwork(out_channels=teacher_out_channels, padding=padding).eval()
            self.student = MediumPatchDescriptionNetwork(out_channels=teacher_out_channels * 2, padding=padding)

        elif model_size == EfficientAdModelSize.S:
            self.teacher = SmallPatchDescriptionNetwork(out_channels=teacher_out_channels, padding=padding).eval()
            self.student = SmallPatchDescriptionNetwork(out_channels=teacher_out_channels * 2, padding=padding)

        else:
            msg = f"Unknown model size {model_size}"
            raise ValueError(msg)

        self.ae: AutoEncoder = AutoEncoder(out_channels=teacher_out_channels, padding=padding)
        self.teacher_out_channels: int = teacher_out_channels

        self.mean_std: nn.ParameterDict = nn.ParameterDict(
            {
                "mean": torch.zeros((1, self.teacher_out_channels, 1, 1)),
                "std": torch.zeros((1, self.teacher_out_channels, 1, 1)),
            },
        )

        self.quantiles: nn.ParameterDict = nn.ParameterDict(
            {
                "qa_st": torch.tensor(0.0),
                "qb_st": torch.tensor(0.0),
                "qa_ae": torch.tensor(0.0),
                "qb_ae": torch.tensor(0.0),
            },
        )

    def is_set(self, p_dic: nn.ParameterDict) -> bool:
        """Check if any of the parameters in the parameter dictionary is set (non-zero).

        Args:
            p_dic (nn.ParameterDict): Parameter dictionary.

        Returns:
            bool: ``True`` if at least one entry has a non-zero sum.
        """
        return any(value.sum() != 0 for _, value in p_dic.items())

    def choose_random_aug_image(self, image: torch.Tensor) -> torch.Tensor:
        """Apply a randomly chosen brightness/contrast/saturation adjustment to ``image``.

        Args:
            image (torch.Tensor): Input image batch.

        Returns:
            Tensor: Augmented image.
        """
        transform_functions = [
            transforms.functional.adjust_brightness,
            transforms.functional.adjust_contrast,
            transforms.functional.adjust_saturation,
        ]
        rng = np.random.default_rng()
        # Sample an augmentation coefficient λ from the uniform distribution U(0.8, 1.2)
        coefficient = rng.uniform(0.8, 1.2)
        transform_function = transform_functions[rng.integers(len(transform_functions))]
        return transform_function(image, coefficient)

    def _teacher_features(self, x: torch.Tensor) -> torch.Tensor:
        """Return teacher features for ``x`` without gradients, standardized once ``mean_std`` is set."""
        with torch.no_grad():
            features = self.teacher(x)
            if self.is_set(self.mean_std):
                features = (features - self.mean_std["mean"]) / self.mean_std["std"]
        return features

    def forward(
        self,
        batch: torch.Tensor,
        batch_imagenet: torch.Tensor | None = None,
        normalize: bool = True,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor] | dict[str, torch.Tensor]:
        """Perform the forward-pass of the EfficientAd models.

        Args:
            batch (torch.Tensor): Input images.
            batch_imagenet (torch.Tensor): ImageNet batch; required in training mode.
                Defaults to None.
            normalize (bool): Normalize anomaly maps with ``quantiles`` (when set) or not.

        Returns:
            In training mode, the tuple ``(loss_st, loss_ae, loss_stae)``. In eval mode, a dict
            with ``"anomaly_map"``, ``"map_st"`` and ``"map_ae"``, each resized to the input size.

        Raises:
            ValueError: If called in training mode without ``batch_imagenet``.
        """
        image_size = batch.shape[-2:]
        teacher_output = self._teacher_features(batch)

        student_output = self.student(batch)
        distance_st = torch.pow(teacher_output - student_output[:, : self.teacher_out_channels, :, :], 2)

        if self.training:
            if batch_imagenet is None:
                msg = "batch_imagenet is required in training mode (it drives the student penalty loss)."
                raise ValueError(msg)

            # Student loss: mean over the hardest 0.1% of teacher/student distances ...
            distance_st = reduce_tensor_elems(distance_st)
            d_hard = torch.quantile(distance_st, 0.999)
            loss_hard = torch.mean(distance_st[distance_st >= d_hard])
            # ... plus a penalty on the student's teacher-half outputs for ImageNet images.
            student_output_penalty = self.student(batch_imagenet)[:, : self.teacher_out_channels, :, :]
            loss_penalty = torch.mean(student_output_penalty**2)
            loss_st = loss_hard + loss_penalty

            # Autoencoder and Student AE Loss on an augmented copy of the batch.
            aug_img = self.choose_random_aug_image(batch)
            ae_output_aug = self.ae(aug_img, image_size)
            teacher_output_aug = self._teacher_features(aug_img)
            student_output_ae_aug = self.student(aug_img)[:, self.teacher_out_channels :, :, :]

            distance_ae = torch.pow(teacher_output_aug - ae_output_aug, 2)
            distance_stae = torch.pow(ae_output_aug - student_output_ae_aug, 2)

            loss_ae = torch.mean(distance_ae)
            loss_stae = torch.mean(distance_stae)
            return (loss_st, loss_ae, loss_stae)

        # Eval mode.
        with torch.no_grad():
            ae_output = self.ae(batch, image_size)

            map_st = torch.mean(distance_st, dim=1, keepdim=True)
            map_stae = torch.mean(
                (ae_output - student_output[:, self.teacher_out_channels :]) ** 2,
                dim=1,
                keepdim=True,
            )

        # Padding only compensates for unpadded PDNs; padded PDNs already produce full-size maps.
        if self.pad_maps and not self.padding:
            map_st = F.pad(map_st, (4, 4, 4, 4))
            map_stae = F.pad(map_stae, (4, 4, 4, 4))
        map_st = F.interpolate(map_st, size=image_size, mode="bilinear")
        map_stae = F.interpolate(map_stae, size=image_size, mode="bilinear")

        if self.is_set(self.quantiles) and normalize:
            map_st = 0.1 * (map_st - self.quantiles["qa_st"]) / (self.quantiles["qb_st"] - self.quantiles["qa_st"])
            map_stae = 0.1 * (map_stae - self.quantiles["qa_ae"]) / (self.quantiles["qb_ae"] - self.quantiles["qa_ae"])

        map_combined = 0.5 * map_st + 0.5 * map_stae
        return {"anomaly_map": map_combined, "map_st": map_st, "map_ae": map_stae}

In [None]:
import tqdm

@torch.no_grad()
def teacher_channel_mean_std(model, dataloader: DataLoader, device: torch.device | None = None) -> dict[str, torch.Tensor]:
    """Calculate the channel-wise mean and std of the teacher model's activations.

    Adapted from https://math.stackexchange.com/a/2148949 (single pass using running
    sums of y and y**2).

    Args:
        model: EfficientAd model whose ``teacher`` is evaluated.
        dataloader (DataLoader): Dataloader yielding image batches.
        device (torch.device, optional): Device the batches are moved to. Defaults to the
            device of the teacher's parameters.

    Returns:
        dict[str, torch.Tensor]: ``"mean"`` and ``"std"``, each shaped ``(1, C, 1, 1)``.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    if device is None:
        device = next(model.teacher.parameters()).device

    count = 0  # number of spatial positions seen per channel
    channel_sum: torch.Tensor | None = None
    channel_sum_sqr: torch.Tensor | None = None

    for batch in tqdm.tqdm(dataloader, desc="Calculate teacher channel mean & std", position=0, leave=True):
        y = model.teacher(batch.to(device))
        if channel_sum is None:
            num_channels = y.shape[1]
            channel_sum = torch.zeros((num_channels,), dtype=torch.float32, device=y.device)
            channel_sum_sqr = torch.zeros((num_channels,), dtype=torch.float32, device=y.device)

        count += y[:, 0].numel()
        channel_sum += torch.sum(y, dim=[0, 2, 3])
        channel_sum_sqr += torch.sum(y**2, dim=[0, 2, 3])

    if channel_sum is None:
        msg = "dataloader yielded no batches; cannot compute teacher statistics."
        raise ValueError(msg)

    channel_mean = channel_sum / count
    # E[y^2] - E[y]^2 can dip slightly below zero through float rounding; clamp so sqrt
    # does not produce NaN.
    channel_var = (channel_sum_sqr / count - channel_mean**2).clamp(min=0)

    channel_std = torch.sqrt(channel_var).float()[None, :, None, None]
    channel_mean = channel_mean.float()[None, :, None, None]

    return {"mean": channel_mean, "std": channel_std}

def _get_quantiles_of_maps(
    maps: list[torch.Tensor],
    device: torch.device | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Calculate 90% and 99.5% quantiles of the given anomaly maps.

    If the total number of elements in the given maps is larger than 16777216 (2**24),
    the quantiles are computed on a random subset of the elements.

    Args:
        maps (list[torch.Tensor]): List of anomaly maps (concatenated along dim 0).
        device (torch.device, optional): Device for the returned scalars. Defaults to the
            device of the first map.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: Two scalars - the 90% and the 99.5% quantile.

    Raises:
        ValueError: If ``maps`` is empty.
    """
    if not maps:
        msg = "Cannot compute quantiles of an empty list of maps."
        raise ValueError(msg)
    if device is None:
        device = maps[0].device
    maps_flat = reduce_tensor_elems(torch.cat(maps))
    qa = torch.quantile(maps_flat, q=0.9).to(device)
    qb = torch.quantile(maps_flat, q=0.995).to(device)
    return qa, qb

@torch.no_grad()
def map_norm_quantiles(model, dataloader: DataLoader, device: torch.device | None = None) -> dict[str, torch.Tensor]:
    """Calculate 90% and 99.5% quantiles of the student(st) and autoencoder(ae) anomaly maps.

    The model must already be in eval mode (only then does its forward return the map
    dict). Each batch is run through the model as a whole.

    Args:
        model: EfficientAd model in eval mode.
        dataloader (DataLoader): Dataloader yielding image batches, or dicts with an
            ``'image'`` entry (as the test-phase dataset produces).
        device (torch.device, optional): Device to run on. Defaults to the device of the
            model's parameters.

    Returns:
        dict[str, torch.Tensor]: ``qa_st``/``qb_st`` (90%/99.5% quantiles of the student map)
        and ``qa_ae``/``qb_ae`` (same for the student-autoencoder map).
    """
    if device is None:
        device = next(model.parameters()).device
    maps_st = []
    maps_ae = []
    logger.info("Calculate Validation Dataset Quantiles")
    for batch in tqdm.tqdm(dataloader, desc="Calculate Validation Dataset Quantiles", position=0, leave=True):
        images = batch["image"] if isinstance(batch, dict) else batch
        output = model(images.to(device), normalize=False)
        maps_st.append(output["map_st"])
        maps_ae.append(output["map_ae"])

    qa_st, qb_st = _get_quantiles_of_maps(maps_st)
    qa_ae, qb_ae = _get_quantiles_of_maps(maps_ae)
    return {"qa_st": qa_st, "qa_ae": qa_ae, "qb_st": qb_st, "qb_ae": qb_ae}

## Tools for training

In [None]:
import torch
import torch.optim as optim
from pathlib import Path

def train_model(model, train_loader, optimizer, scheduler, device, aux_loader=None):
    """Train ``model`` for one epoch over ``train_loader``.

    Each training batch is paired with a batch from the auxiliary (ImageNet) loader. That
    loader is restarted whenever it runs out, so it may be shorter than ``train_loader``.
    ``optimizer.step()`` and ``scheduler.step()`` are both called once per batch.

    Args:
        model: EfficientAd model; in training mode its forward returns three losses.
        train_loader: Iterable of training image batches.
        optimizer: Optimizer over the trainable parameters.
        scheduler: LR scheduler, stepped once per batch.
        device: Device the batches are moved to.
        aux_loader (iterable, optional): Source of auxiliary image batches. Defaults to the
            notebook-level ``imagenet_loader``.

    Returns:
        float: Sum of the per-batch losses over the epoch.
    """
    if aux_loader is None:
        aux_loader = imagenet_loader
    model.train()
    aux_iter = iter(aux_loader)
    total_loss = 0.0
    for counter, batch_images in enumerate(train_loader):
        batch_images = batch_images.to(device)
        try:
            imagenet_batch = next(aux_iter)
        except StopIteration:
            # Auxiliary loader exhausted: start a new pass instead of aborting the epoch.
            aux_iter = iter(aux_loader)
            imagenet_batch = next(aux_iter)
        imagenet_batch = imagenet_batch.to(device)

        loss_st, loss_ae, loss_stae = model(batch_images, imagenet_batch)

        loss = loss_st + loss_ae + loss_stae
        if counter % 10 == 0:
            print(f'Batch {counter}, Training Loss: {loss.item()}')
        total_loss += loss.item()

        # Zero gradients, perform a backward pass, and update the weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

    print(f'Total Training Loss in Epoch: {total_loss}')
    return total_loss

def validate_model(model, validation_loader, device):
    """Return the sum of per-image mean anomaly scores on ``validation_loader`` (lower is better).

    The model is switched to eval mode and run with ``normalize=False``. Each image's
    score is the mean of its anomaly map.

    Args:
        model: EfficientAd model.
        validation_loader: Iterable of image batches (anomaly-free images).
        device: Device the batches are moved to.

    Returns:
        float: Summed score over all validation images.
    """
    model.eval()
    loss = 0.0
    with torch.no_grad():
        for batch_images in validation_loader:
            batch_images = batch_images.to(device)
            anomaly_maps = model(batch_images, normalize=False)["anomaly_map"]
            # Average over every non-batch dim. Unlike squeeze(), this also works for a
            # final batch holding a single image.
            scores = anomaly_maps.flatten(start_dim=1).mean(dim=1)
            loss += scores.sum().item()
        print(f'Validation Loss: {loss}')
    return loss

## Training

In [None]:
# Prefer Apple's MPS backend (the original target), then CUDA, and fall back to CPU
# so the notebook still runs on machines without MPS.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
use_pretrained = True
pretrained_path = 'efficientad_model_medium_5_3.9978248327970505.pth'

In [None]:
# Parameters
num_epochs = 80
learning_rate = 0.0001
weight_decay = 0.00001

model_size = EfficientAdModelSize.M

model = EfficientAdModel(
    teacher_out_channels=384,
    model_size=model_size,
    padding=False,
    pad_maps=True,
).to(device)

# Only the student and the autoencoder are optimized; the teacher is not.
optimizer = optim.Adam(
    [*model.student.parameters(), *model.ae.parameters()],
    lr=learning_rate,
    weight_decay=weight_decay,
)
# Cut the LR by 10x after 95% of all scheduler steps (one step per training batch).
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=int(0.95 * num_epochs * len(train_loader)),
    gamma=0.1,
)

if not use_pretrained:
    # Fresh run: load the pretrained teacher and compute its channel statistics.
    teacher_path = Path("./pre_trained/") / "efficientad_pretrained_weights" / f"pretrained_teacher_{model_size.value}.pth"
    model.teacher.load_state_dict(torch.load(teacher_path, map_location=torch.device(device)))
    channel_mean_std = teacher_channel_mean_std(model, train_loader)
    model.mean_std.update(channel_mean_std)
else:
    # Resume from a full EfficientAd checkpoint stored in ./results.
    model.load_state_dict(torch.load(os.path.join('results', pretrained_path), map_location=torch.device(device)))

In [None]:
best_validation_loss = float('inf')
for epoch in range(num_epochs):
    print(f'Epoch {epoch + 1}/{num_epochs}')
    train_model(model, train_loader, optimizer, scheduler, device)
    loss = validate_model(model, combined_validation_loader, device)
    if not loss < best_validation_loss:
        continue
    # New lowest validation score: remember it and checkpoint the weights.
    best_validation_loss = loss
    torch.save(model.state_dict(), f"efficientad_model_{model_size.value}_{epoch + 1}_{loss}.pth")

# Always keep the final weights as well.
torch.save(model.state_dict(), f"efficientad_model_{model_size.value}_{num_epochs}.pth")

## Model testing

In [None]:
from sklearn.metrics import roc_auc_score

def test_model(model, test_loader, device='mps', norm_loader=None):
    """Evaluate an EfficientAD model and return image-level AUROC in percent.

    Before scoring, the model's anomaly-map normalisation quantiles are refreshed
    via ``map_norm_quantiles`` on ``norm_loader``. Each test image is scored by the
    maximum value of its anomaly map; every ``anomaly_type`` other than ``'good'``
    is counted as anomalous (label 1).

    Args:
        model: EfficientAD model whose output dict contains ``"anomaly_map"``.
        test_loader: yields dicts with ``'image'`` tensors and ``'anomaly_type'`` labels.
        device: device the image tensors are moved to.
        norm_loader: loader used to recompute the normalisation quantiles;
            defaults to the notebook-level ``validation_loader``.

    Returns:
        ``roc_auc_score * 100``.
    """
    model.eval()
    if norm_loader is None:
        norm_loader = validation_loader  # notebook-level global, as before
    map_norm_q = map_norm_quantiles(model, norm_loader)
    model.quantiles.update(map_norm_q)

    y_true = []
    y_score = []
    with torch.no_grad():
        for batch_images in test_loader:
            images = batch_images['image'].to(device)
            anomaly_maps = model(images)["anomaly_map"]
            # Image-level score: maximum over the spatial anomaly map.
            y_score.extend(anomaly_map.squeeze().amax(dim=(0, 1)).item() for anomaly_map in anomaly_maps)
            y_true.extend(1 if anomaly_type != 'good' else 0 for anomaly_type in batch_images['anomaly_type'])
    auc = roc_auc_score(y_true, y_score)
    return auc * 100

### Test a single model

In [None]:
# breakfast_box test split with good, logical and structural anomaly images.
# shuffle=False keeps evaluation order reproducible, matching the other test
# loaders (AUROC itself does not depend on order).
test_dataset = MVTecDataset(root_dir='mvtec_loco_anomaly_detection', category='breakfast_box', phase='test', transform=transform, anomaly_types=['logical_anomalies', 'good', 'structural_anomalies'])
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Load one stored checkpoint and print its image-level AUROC (percent) on test_loader.
model_path = os.path.join(
    'results',
    'efficientad_model_medium_2_3.864782866090536.pth',
)
model.load_state_dict(
    torch.load(model_path, map_location=torch.device(device))
)
model.to(device)
model.eval()
accuracy = test_model(model, test_loader, device=device)
print(f'AUC: {accuracy}')

### Test the results folder

In [None]:
# Evaluate every checkpoint file stored directly in results/ on test_loader.
# Sub-directories of results/ (e.g. per-category folders) are skipped.
model_list = sorted(os.listdir('results'))

for model_path in model_list:
    checkpoint_path = os.path.join('results', model_path)
    # Test the joined path: the bare entry name would be resolved relative to
    # the current working directory, not relative to results/.
    if os.path.isdir(checkpoint_path):
        continue
    try:
        model.load_state_dict(torch.load(checkpoint_path, map_location=torch.device(device)))
        model.to(device)
        model.eval()
        accuracy = test_model(model, test_loader, device=device)
        print(f'Model: {model_path}, AUC: {accuracy}')
    except Exception as e:
        # Best effort: report the failing checkpoint and continue with the rest.
        print(f'Error in model {model_path}: {e}')

### Test structural anomalies

In [None]:
# screw_bag test split restricted to structural anomalies plus good images.
test_dataset = MVTecDataset(
    root_dir='mvtec_loco_anomaly_detection',
    category='screw_bag',
    phase='test',
    anomaly_types=['structural_anomalies', 'good'],
    transform=transform,
)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Image-level AUROC (percent) of the epoch-28 checkpoint on the structural split.
model_path = os.path.join(
    'results',
    'efficientad_model_medium_28_2.9415077567100525.pth',
)
model.load_state_dict(
    torch.load(model_path, map_location=torch.device(device))
)
model.to(device)
model.eval()
accuracy = test_model(model, test_loader, device=device)
print(f'AUC: {accuracy}')

### Test logical anomalies

In [None]:
# screw_bag test split restricted to logical anomalies plus good images.
test_dataset = MVTecDataset(
    root_dir='mvtec_loco_anomaly_detection',
    category='screw_bag',
    phase='test',
    anomaly_types=['logical_anomalies', 'good'],
    transform=transform,
)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Image-level AUROC (percent) of the epoch-28 checkpoint on the logical split.
model_path = os.path.join(
    'results',
    'efficientad_model_medium_28_2.9415077567100525.pth',
)
model.load_state_dict(
    torch.load(model_path, map_location=torch.device(device))
)
model.to(device)
model.eval()
accuracy = test_model(model, test_loader, device=device)
print(f'AUC: {accuracy}')

# CLIP model

In [None]:
import clip

# Pick the compute device: prefer MPS, then CUDA, then CPU, instead of assuming
# an MPS backend exists.
if getattr(torch.backends, 'mps', None) is not None and torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# `preprocess` is the image transform used for the CLIP datasets in this notebook.
clip_model, preprocess = clip.load("ViT-L/14@336px", device=device)

In [None]:
# screw_bag data for the CLIP branch. Train/validation images are transformed with
# CLIP's `preprocess`; the test set uses `transform` for the EfficientAD view and
# `preprocess` as its clip_transform.
clip_train_dataset = MVTecDataset(
    root_dir='mvtec_loco_anomaly_detection',
    category='screw_bag',
    phase='train',
    transform=preprocess,
)
clip_train_loader = DataLoader(clip_train_dataset, batch_size=batch_size, shuffle=True)

clip_validation_dataset = MVTecDataset(
    root_dir='mvtec_loco_anomaly_detection',
    category='screw_bag',
    phase='validation',
    transform=preprocess,
)
clip_validation_loader = DataLoader(clip_validation_dataset, batch_size=batch_size, shuffle=True)

clip_test_dataset = MVTecDataset(
    root_dir='mvtec_loco_anomaly_detection',
    category='screw_bag',
    phase='test',
    clip_transform=preprocess,
    transform=transform,
)
clip_test_loader = DataLoader(clip_test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
def normalize_with_validation(scores, mean_val, std_val, ratio=0.05):
    """Standardise scores against validation statistics, clamp at zero, then scale.

    Computes ``where(z < 0, 0, z) * ratio`` element-wise with
    ``z = (scores - mean_val) / std_val``, so scores at or below the validation
    mean map to 0.

    Args:
        scores: scalar, list or numpy array of raw anomaly scores. It is converted
            with ``np.asarray`` so plain Python lists work even when ``mean_val``
            is a Python float.
        mean_val: mean of the scores on the validation set.
        std_val: standard deviation of the scores on the validation set.
        ratio: multiplicative weight applied after clamping.

    Returns:
        numpy array (or 0-d array for scalar input) of normalised scores.
    """
    normalized_scores = (np.asarray(scores) - mean_val) / std_val
    normalized_scores = np.where(normalized_scores < 0, 0, normalized_scores)
    return normalized_scores * ratio

## CLIP method 1

In [None]:
def compute_mean_and_cov(model, dataloader):
    """Encode every image in ``dataloader`` with ``model`` and fit a Gaussian.

    Each batch is moved to the notebook-level ``device`` and passed through
    ``model.encode_image``; the running number of encoded images is printed
    after every batch.

    Returns:
        ``(mean_embedding, cov_embedding)``. The covariance is computed with
        ``np.cov(..., rowvar=False)``, i.e. rows are samples and columns are
        embedding dimensions.
    """
    collected = []
    with torch.no_grad():
        for images in dataloader:
            batch_embeddings = model.encode_image(images.to(device)).cpu().numpy()
            collected.extend(batch_embeddings)
            print(f'Images done: {len(collected)}')

    fitted_mean = np.mean(collected, axis=0)
    fitted_cov = np.cov(collected, rowvar=False)
    return fitted_mean, fitted_cov

# Mean vector and covariance of CLIP embeddings of the screw_bag training images.
mean_embedding, cov_embedding = compute_mean_and_cov(model=clip_model, dataloader=clip_train_loader)

In [None]:
def _regularized_inverse(cov_embedding, epsilon=1e-5):
    """Return ``inv(cov_embedding + epsilon * I)``; the ridge keeps a singular covariance invertible."""
    regularized_cov = cov_embedding + np.eye(cov_embedding.shape[0]) * epsilon
    return np.linalg.inv(regularized_cov)


def compute_mahalanobis_distance(embedding, mean_embedding, cov_embedding, inv_cov=None):
    """Mahalanobis distance of ``embedding`` from a Gaussian (mean, covariance).

    Args:
        embedding: 1-D embedding vector.
        mean_embedding: mean vector of the reference embeddings.
        cov_embedding: covariance matrix of the reference embeddings.
        inv_cov: optional precomputed ``inv(cov_embedding + 1e-5 * I)``. Pass it
            when scoring many embeddings so the matrix is inverted only once;
            when omitted it is computed here (same result as before).

    Returns:
        ``sqrt(diff @ inv_cov @ diff)`` with ``diff = embedding - mean_embedding``.
    """
    if inv_cov is None:
        inv_cov = _regularized_inverse(cov_embedding)
    diff = embedding - mean_embedding
    return np.sqrt(np.dot(np.dot(diff, inv_cov), diff))


def compute_distance_mean_std(model, dataloader, mean_embedding, cov_embedding):
    """Mean and standard deviation of Mahalanobis distances over ``dataloader``.

    Images are moved to the notebook-level ``device`` and encoded with
    ``model.encode_image``.

    Returns:
        ``(mean, std)`` of the per-image distances.
    """
    inv_cov = _regularized_inverse(cov_embedding)  # invert once, not per embedding
    distances = []
    with torch.no_grad():
        for batch in dataloader:
            images = batch.to(device)
            embeddings = model.encode_image(images).cpu().numpy()
            distances.extend(
                compute_mahalanobis_distance(embedding, mean_embedding, cov_embedding, inv_cov)
                for embedding in embeddings
            )
    return np.mean(distances), np.std(distances)


def compute_mahalanobis_distances(test_loader, mean_embedding, cov_embedding, model=None):
    """Per-image Mahalanobis distances for every image in ``test_loader``.

    Each batch element is converted with ``Image.fromarray(img.numpy())``, passed
    through the notebook-level CLIP ``preprocess``, and encoded with
    ``model.encode_image``.

    Args:
        model: CLIP model used for encoding. Defaults to the notebook-level
            ``clip_model``. The previous implementation used the global ``model``,
            which in this notebook is the EfficientAD network rather than CLIP.

    Returns:
        list of distances, one per image.
    """
    if model is None:
        model = clip_model
    inv_cov = _regularized_inverse(cov_embedding)  # invert once, not per embedding
    distances = []
    with torch.no_grad():
        for batch in test_loader:
            images_preprocessed = torch.stack(
                [preprocess(Image.fromarray(img.numpy())) for img in batch]
            ).to(device)
            embeddings = model.encode_image(images_preprocessed).cpu().numpy()
            distances.extend(
                compute_mahalanobis_distance(embedding, mean_embedding, cov_embedding, inv_cov)
                for embedding in embeddings
            )
    return distances

In [None]:
def CLIP_method1_compute_parameters(model, validation_loader, train_loader):
    """Fit a Gaussian to CLIP train embeddings and return validation distance statistics.

    Note the argument order: ``(model, validation_loader, train_loader)``.

    Returns:
        ``(distances_mean, distances_std)`` of the Mahalanobis distances of the
        validation images. The fitted mean/covariance are not returned.
    """
    mean_embedding, cov_embedding = compute_mean_and_cov(model, train_loader)
    return compute_distance_mean_std(model, validation_loader, mean_embedding, cov_embedding)

def CLIP_method1(model, clip_images, parameters):
    """Score a batch of CLIP-preprocessed images by normalised Mahalanobis distance.

    Args:
        model: CLIP model providing ``encode_image``.
        clip_images: batch tensor already on the model's device.
        parameters: either ``(distances_mean, distances_std)`` — then the
            notebook-level ``mean_embedding`` / ``cov_embedding`` are used — or
            ``(distances_mean, distances_std, mean_embedding, cov_embedding)``.

    Returns:
        numpy array from ``normalize_with_validation`` (default ratio).
    """
    if len(parameters) == 4:
        distances_mean, distances_std, fit_mean, fit_cov = parameters
    else:
        distances_mean, distances_std = parameters
        fit_mean, fit_cov = mean_embedding, cov_embedding
    # no_grad: .numpy() cannot be called on a tensor that requires grad
    # (CLIP_method2 does the same).
    with torch.no_grad():
        embeddings = model.encode_image(clip_images).cpu().numpy()
    distances = np.asarray([compute_mahalanobis_distance(embedding, fit_mean, fit_cov) for embedding in embeddings])
    return normalize_with_validation(distances, distances_mean, distances_std)

## CLIP method 2

In [None]:
from sklearn.mixture import GaussianMixture

def _encode_all(model, loader):
    """Return a list of CLIP embeddings (numpy rows), one per image in ``loader``."""
    collected = []
    with torch.no_grad():
        for batch in loader:
            collected.extend(model.encode_image(batch.to(device)).cpu().numpy())
    return collected

def CLIP_method2_compute_parameters(model, train_loader, validation_loader, optimal_n_components=1):
    """Fit a Gaussian mixture on CLIP train embeddings and score the validation set.

    The anomaly score of an embedding is its negative GMM log-likelihood
    (``-gmm.score_samples``).

    Returns:
        ``(gmm, std, mean)`` — the fitted mixture and the standard deviation and
        mean of the validation anomaly scores (in that order).
    """
    train_embeddings = _encode_all(model, train_loader)
    validation_embeddings = _encode_all(model, validation_loader)

    mixture = GaussianMixture(n_components=optimal_n_components)
    mixture.fit(train_embeddings)

    validation_scores = -mixture.score_samples(validation_embeddings)
    return mixture, np.std(validation_scores), np.mean(validation_scores)

def CLIP_method2(model, images, parameters, ratio=0.05):
    """Score images by the negative GMM log-likelihood of their CLIP embeddings.

    ``parameters`` is the ``(gmm, std, mean)`` tuple from
    ``CLIP_method2_compute_parameters``. The raw scores are standardised with the
    validation mean/std, clamped at zero and multiplied by ``ratio`` via
    ``normalize_with_validation``.
    """
    mixture, validation_std, validation_mean = parameters
    with torch.no_grad():
        features = model.encode_image(images).cpu().numpy()
    negative_log_likelihood = -mixture.score_samples(features)
    return normalize_with_validation(negative_log_likelihood, validation_mean, validation_std, ratio)

## Test CLIP

In [None]:
from sklearn.metrics import roc_auc_score

def test_clip_model(clip_model, test_loader, device='mps'):
    """Evaluate the CLIP-only GMM scorer and return image-level AUROC in percent.

    GMM parameters are refit on every call from the notebook-level
    ``clip_train_loader`` / ``clip_validation_loader``. Per-batch scores, labels
    and the running image count are printed. Any ``anomaly_type`` other than
    ``'good'`` is counted as anomalous.
    """
    gmm_parameters = CLIP_method2_compute_parameters(clip_model, clip_train_loader, clip_validation_loader)
    labels_all, scores_all = [], []
    with torch.no_grad():
        for batch in test_loader:
            scores = CLIP_method2(clip_model, batch['clip_image'].to(device), gmm_parameters)
            print(f"Anomality scores CLIP: {scores}")
            print(f'Anomaly type: {batch["anomaly_type"]}')

            labels = [1 if kind != 'good' else 0 for kind in batch['anomaly_type']]
            labels_all += labels
            print(f"y_true: {labels}")

            scores_all += list(scores)

            print(f'Batches done: {len(labels_all)}\n\n')
    return roc_auc_score(labels_all, scores_all) * 100

In [None]:
# CLIP-only image-level AUROC (percent) on the screw_bag test split.
accuracy = test_clip_model(clip_model=clip_model, test_loader=clip_test_loader, device=device)
accuracy

## Test full architecture

In [None]:
# Re-run after changing the datasets: recomputes the EfficientAD anomaly-map
# quantiles from validation_loader and refits the CLIP GMM parameters that
# test_full_model reads from the global `parameters`.
map_norm_q = map_norm_quantiles(model, validation_loader)
model.quantiles.update(map_norm_q)
parameters = CLIP_method2_compute_parameters(
    model=clip_model,
    train_loader=clip_train_loader,
    validation_loader=clip_validation_loader,
)

In [None]:
from sklearn.metrics import roc_auc_score

def test_full_model(efficient_model, clip_model, test_loader, device='mps', ratio=0.05, verbose=True, clip_parameters=None):
    """Evaluate the EfficientAD + CLIP ensemble and return image-level AUROC in percent.

    Each image's score is the maximum of its EfficientAD anomaly map (negative
    values clamped to 0) plus the normalised CLIP GMM score from ``CLIP_method2``,
    which is scaled by ``ratio``. Any ``anomaly_type`` other than ``'good'`` is
    counted as anomalous.

    Args:
        efficient_model: EfficientAD model whose output dict contains ``"anomaly_map"``.
        clip_model: CLIP model passed to ``CLIP_method2``.
        test_loader: yields dicts with ``'image'``, ``'clip_image'`` and ``'anomaly_type'``.
        device: device the image tensors are moved to.
        ratio: weight of the normalised CLIP score.
        verbose: print per-batch scores, labels and the running image count.
        clip_parameters: ``(gmm, std, mean)`` from ``CLIP_method2_compute_parameters``;
            defaults to the notebook-level ``parameters``.

    Returns:
        ``roc_auc_score * 100``.
    """
    if clip_parameters is None:
        clip_parameters = parameters  # notebook-level global, as before
    efficient_model.eval()
    y_true = []
    y_score = []
    with torch.no_grad():
        for batch_images in test_loader:
            images = batch_images['image'].to(device)
            anomaly_maps = efficient_model(images)["anomaly_map"]
            anomaly_scores_EAD = [anomaly_map.squeeze().amax(dim=(0, 1)).item() for anomaly_map in anomaly_maps]
            # Clamp negatives (and NaN) to 0.
            anomaly_scores_EAD = [score if score > 0 else 0 for score in anomaly_scores_EAD]

            clip_images = batch_images['clip_image'].to(device)
            clip_scores = CLIP_method2(clip_model, clip_images, clip_parameters, ratio=ratio)

            labels = [1 if anomaly_type != 'good' else 0 for anomaly_type in batch_images['anomaly_type']]
            y_true += labels
            y_score += [ead_score + clip_score for ead_score, clip_score in zip(anomaly_scores_EAD, clip_scores)]

            if verbose:
                print(f"Anomality scores EffecientAD: {anomaly_scores_EAD}")
                print(f"Anomality scores CLIP: {clip_scores}")
                print(f'Anomaly type: {batch_images["anomaly_type"]}')
                print(f"y_true: {labels}")
                # Printed after this batch is counted, as in test_clip_model.
                print(f'Batches done: {len(y_true)}\n\n')
    auc = roc_auc_score(y_true, y_score)
    return auc * 100

In [None]:
# Full screw_bag test split (logical, good, structural), with both the
# EfficientAD `transform` and the CLIP `preprocess` views.
test_dataset = MVTecDataset(
    root_dir='mvtec_loco_anomaly_detection',
    category='screw_bag',
    phase='test',
    anomaly_types=['logical_anomalies', 'good', 'structural_anomalies'],
    transform=transform,
    clip_transform=preprocess,
)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
model_size = EfficientAdModelSize.M

# Rebuild the medium EfficientAD network and load the screw_bag checkpoint.
model = EfficientAdModel(
    teacher_out_channels=384,
    model_size=model_size,
    padding=False,
    pad_maps=True,
).to(device)
model.load_state_dict(
    torch.load(
        os.path.join('results', 'screw_bag', 'efficientad_model_medium_screw_bag.pth'),
        map_location=torch.device(device),
    )
)
model.to(device)
model.eval()

In [None]:
# Ensemble (EfficientAD + CLIP) image-level AUROC in percent, default ratio.
accuracy = test_full_model(efficient_model=model, clip_model=clip_model, test_loader=test_loader, device=device)
accuracy

In [None]:
# Sweep the weight of the CLIP score in the ensemble in steps of 0.01
# (0.05 is test_full_model's default; the list previously had the typo 0.5).
for ratio in [0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1]:
    accuracy = test_full_model(model, clip_model, test_loader, device=device, ratio=ratio, verbose=False)
    print(f'Ratio: {ratio}, AUC: {accuracy}')

In [None]:
# screw_bag test split with good and logical-anomaly images only, including
# the CLIP `preprocess` view.
test_dataset = MVTecDataset(
    root_dir='mvtec_loco_anomaly_detection',
    category='screw_bag',
    phase='test',
    anomaly_types=['good', 'logical_anomalies'],
    transform=transform,
    clip_transform=preprocess,
)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Recompute EfficientAD's anomaly-map normalisation quantiles from
# validation_loader and store them on the currently loaded model.
map_norm_q = map_norm_quantiles(model, validation_loader)
model.quantiles.update(map_norm_q)

In [None]:
model.eval()

img_idx = 331

dataset = test_dataset

image = dataset[img_idx]['image']

# Map values below this threshold are zeroed so the overlays only show hot spots.
MAP_THRESHOLD = 0.10

with torch.no_grad():
    print(f"Image path {dataset.data[img_idx]}")
    img = torch.unsqueeze(image, 0).to(device)
    outputs = model(img)
    # `anomaly_map` instead of `map` so the builtin map() is not shadowed.
    anomaly_map = outputs['anomaly_map'].squeeze().cpu().numpy()
    map_st = outputs['map_st'].squeeze().cpu().numpy()
    map_ae = outputs['map_ae'].squeeze().cpu().numpy()

# Threshold each map against its own values (map_ae was previously masked with
# the combined anomaly map).
anomaly_map = np.where(anomaly_map < MAP_THRESHOLD, 0, anomaly_map)
map_st = np.where(map_st < MAP_THRESHOLD, 0, map_st)
map_ae = np.where(map_ae < MAP_THRESHOLD, 0, map_ae)

for values in (anomaly_map, map_st, map_ae):
    print(np.max(values), np.min(values))

import matplotlib.pyplot as plt

# One figure per map: the input image with the thresholded map overlaid.
background = image.permute(1, 2, 0)
for overlay in (anomaly_map, map_st, map_ae):
    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    ax.imshow(background)
    ax.imshow(overlay, cmap='hot', alpha=0.5)
    plt.show()