# train mvtec

I want a simpler training script that integrates with wandb and can later be adapted to unetdd.

# consts

In [1]:
from copy import deepcopy
from typing import List
from fcdd.datasets.bases import TorchvisionDataset
from fcdd.datasets.cifar import ADCIFAR10
from fcdd.datasets.fmnist import ADFMNIST
from fcdd.datasets.imagenet import ADImageNet
from fcdd.datasets.mvtec import ADMvTec
from fcdd.datasets.pascal_voc import ADPascalVoc
from fcdd.datasets.image_folder import ADImageFolderDataset
from fcdd.datasets.image_folder_gtms import ADImageFolderDatasetGTM

# Dataset identifiers accepted by the --dataset option.
DS_CHOICES = (
    'mnist',
    'cifar10',
    'fmnist',
    'mvtec',
    'imagenet',
    'pascalvoc',
)

# Preprocessing / augmentation pipelines accepted by the --preproc option.
PREPROC_CHOICES = (
    'lcn',
    'lcnaug1',
    'aug1',
    'aug1_blackcenter',
    'aug1_blackcenter_inverted',
    'none',
)

# Kinds of artificial anomalies accepted by the --supervise-mode option.
SUPERVISE_MODES = (
    'unsupervised',
    'other',
    'noise',
    'malformed_normal',
    'malformed_normal_gt',
)

# Noise sources accepted by the --noise-mode option (kept as a list).
NOISE_MODES = [
    # synthetic anomalies
    'gaussian',
    'uniform',
    'blob',
    'mixed_blob',
    'solid',
    'confetti',
    # Outlier Exposure
    'imagenet',
    'imagenet22k',
    'cifar100',
    'emnist',
    # Outlier Exposure, online supervision only
    'mvtec',
    'mvtec_gt',
]

def str_labels(dataset_name: str) -> List[str]:
    """
    Returns the human-readable class names of a dataset, ordered by class index.
    A fresh list is built on every call, so callers may mutate the result freely.
    :param dataset_name: one of 'cifar10', 'fmnist', 'mvtec', 'imagenet', 'pascalvoc'.
    :return: list of class-name strings.
    :raises KeyError: if no labels are registered for ``dataset_name``
        (e.g. 'mnist', although it is listed in DS_CHOICES).
    """
    # Only touch (and deep-copy) the ImageNet labels when they are actually requested,
    # instead of rebuilding that copy on every lookup of an unrelated dataset.
    if dataset_name == 'imagenet':
        return deepcopy(ADImageNet.ad_classes)
    return {
        'cifar10': [
            'airplane', 'automobile', 'bird', 'cat', 'deer',
            'dog', 'frog', 'horse', 'ship', 'truck'
        ],
        'fmnist': [
            't-shirt/top', 'trouser', 'pullover', 'dress', 'coat',
            'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot'
        ],
        'mvtec': [
            'bottle', 'cable', 'capsule', 'carpet', 'grid', 'hazelnut', 'leather',
            'metal_nut', 'pill', 'screw', 'tile', 'toothbrush', 'transistor',
            'wood', 'zipper'
        ],
        'pascalvoc': ['horse'],
    }[dataset_name]


def no_classes(dataset_name: str) -> int:
    """
    Returns the number of classes of a dataset, i.e. the length of its label list.
    :param dataset_name: dataset identifier accepted by :func:`str_labels`.
    :return: number of classes.
    :raises KeyError: if :func:`str_labels` has no entry for ``dataset_name``.
    """
    # str_labels is a function and must be called; subscripting it (str_labels[...])
    # always raised "TypeError: 'function' object is not subscriptable".
    return len(str_labels(dataset_name))

# args

In [2]:
import os.path as pt
from argparse import ArgumentParser

import numpy as np
from fcdd.models import choices


def default_parser_config(parser: ArgumentParser) -> ArgumentParser:
    """
    Defines all the arguments for running an FCDD experiment.
    :param parser: instance of an ArgumentParser.
    :return: the parser with added arguments
    """

    # define directories for datasets and logging
    parser.add_argument(
        '--logdir', type=str, default=pt.join('..', '..', 'data', 'results', 'fcdd_{t}'),
        help='Directory where log data is to be stored. The pattern {t} is replaced by the start time. '
             'Defaults to ../../data/results/fcdd_{t}. '
    )
    parser.add_argument(
        '--logdir-suffix', type=str, default='',
        help='String suffix for log directory, again {t} is replaced by the start time. '
    )
    parser.add_argument(
        '--datadir', type=str, default=pt.join('..', '..', 'data', 'datasets'),
        help='Directory where datasets are found or to be downloaded to. Defaults to ../../data/datasets.',
    )
    parser.add_argument(
        '--readme', type=str, default='',
        help='Some notes to be stored in the automatically created config.txt configuration file.'
    )

    # training parameters
    parser.add_argument('-b', '--batch-size', type=int, default=128)
    parser.add_argument('-e', '--epochs', type=int, default=200)
    parser.add_argument('-w', '--workers', type=int, default=4)
    parser.add_argument('-lr', '--learning_rate', type=float, default=1e-3)
    parser.add_argument('-wd', '--weight-decay', type=float, default=1e-6)
    parser.add_argument(
        '--optimizer-type', type=str, default='sgd', choices=['sgd', 'adam'],
        help='The type of optimizer. Defaults to "sgd". '
    )
    parser.add_argument(
        '--scheduler-type', type=str, default='lambda', choices=['lambda', 'milestones'],
        help='The type of learning rate scheduler. Either "lambda", which reduces the learning rate each epoch '
             'by a certain factor, or "milestones", which sets the learning rate to certain values at certain '
             'epochs. Defaults to "lambda"'
    )
    parser.add_argument(
        '--lr-sched-param', type=float, nargs='*', default=[0.985],
        help='Sequence of learning rate scheduler parameters. '
             'For the "lambda" scheduler, just one parameter is allowed, '
             'which sets the factor the learning rate is reduced per epoch. '
             'For the "milestones" scheduler, at least two parameters are needed, '
             'the first determining the factor by which the learning rate is reduced at each milestone, '
             'and the others being each a milestone. For instance, "0.1 100 200 300" reduces the learning rate '
             'by 0.1 at epoch 100, 200, and 300. '
    )
    parser.add_argument(
        '--load', type=str, default=None,
        help='Path to a file that contains a snapshot of the network model. '
             'When given, the network loads the found weights and state of the training. '
             'If epochs are left to be trained, the training is continued. '
             'Note that only one snapshot is given, thus using a runner that trains for multiple different classes '
             'to be nominal is not applicable. '
    )
    # NOTE(review): the default 'custom' is not part of DS_CHOICES. argparse does not validate defaults
    # against `choices`, so 'custom' is only reachable by omitting --dataset; passing it explicitly is rejected.
    parser.add_argument('-d', '--dataset', type=str, default='custom', choices=DS_CHOICES)
    parser.add_argument(
        '-n', '--net', type=str, default='FCDD_CNN224_VGG_F', choices=choices(),
        help='Chooses a network architecture to train.'
    )
    parser.add_argument(
        '--preproc', type=str, default='aug1', choices=PREPROC_CHOICES,
        help='Determines the kind of preprocessing pipeline (augmentations and such). '
             'Have a look at the code (dataset implementation, e.g. fcdd.datasets.cifar.py) for details.'
    )
    parser.add_argument(
        '--acc-batches', type=int, default=1,
        help='To speed up data loading, '
             'this determines the number of batches that are accumulated to be used for training. '
             'For instance, acc_batches=2 iterates the data loader two times, concatenates the batches, and '
             'passes the result to the further training procedure. This has no impact on the performance '
             'if the batch size is reduced accordingly (e.g. one half in this example), '
             'but can decrease training time. '
    )
    parser.add_argument('--no-bias', dest='bias', action='store_false', help='Uses no bias in network layers.')
    parser.add_argument('--cpu', dest='cuda', action='store_false', help='Trains on CPU only.')

    # artificial anomaly settings
    parser.add_argument(
        '--supervise-mode', type=str, default='noise', choices=SUPERVISE_MODES,
        help='This determines the kind of artificial anomalies. '
             '"unsupervised" uses no anomalies at all. '
             '"other" uses ground-truth anomalies. '
             '"noise" uses pure noise images or Outlier Exposure. '
             '"malformed_normal" adds noise to nominal images to create malformed nominal anomalies. '
             '"malformed_normal_gt" is like malformed_normal, but with ground-truth anomaly heatmaps for training. '
    )
    parser.add_argument(
        '--noise-mode', type=str, default='imagenet22k', choices=NOISE_MODES,
        help='The type of noise used when artificial anomalies are activated. Dataset names refer to OE. '
             'See fcdd.datasets.noise_modes.py.'
    )
    # np.infty was removed in NumPy 2.0 (AttributeError there); np.inf is the same float value.
    parser.add_argument(
        '--oe-limit', type=int, default=np.inf,
        help='Determines the amount of different samples used for Outlier Exposure. '
             'Has no impact on synthetic anomalies.'
    )
    parser.add_argument(
        '--offline-supervision', dest='online_supervision', action='store_false',
        help='Instead of sampling artificial anomalies during training by having a 50%% chance to '
             'replace nominal samples, this mode samples them once at the start of the training and adds them to '
             'the training set. '
             'This yields less performance and higher RAM utilization, but reduces the training time. '
    )
    parser.add_argument(
        '--nominal-label', type=int, default=0,
        help='Determines the label that marks nominal samples. '
             'Note that this is not the class that is considered nominal! '
             'For instance, class 5 is the nominal class, which is labeled with the nominal label 0.'
    )

    # heatmap generation parameters
    parser.add_argument(
        '--blur-heatmaps', dest='blur_heatmaps', action='store_true',
        help='Blurs heatmaps, like done for the explanation baseline experiments in the paper.'
    )
    parser.add_argument(
        '--gauss-std', type=float, default=10,
        help='Sets a constant value for the standard deviation of the Gaussian kernel used for upsampling and '
             'blurring.'
    )
    parser.add_argument(
        '--quantile', type=float, default=0.97,
        help='The quantile that is used to normalize the generated heatmap images. '
             'This is explained in the Appendix of the paper.'
    )
    parser.add_argument(
        '--resdown', type=int, default=64,
        help='Sets the maximum resolution of logged images (per heatmap), images will be downsampled '
             'if they exceed this threshold. For instance, resdown=64 makes every image of heatmaps contain '
             'individual heatmaps and inputs of width 64 and height 64 at most.'
    )
    parser.add_argument(
        '--no-test', dest="test", action="store_false",
        help='If set then the model will not be tested at the end of the training. It will by default.'
    )
    return parser


def default_parser_config_mvtec(parser: ArgumentParser) -> ArgumentParser:
    """
    Extends :func:`default_parser_config` for MVTec-AD experiments.
    Overrides a set of generic defaults with MVTec-specific values and adds two options:
    ``--it`` (number of runs per class) and ``--cls-restrictions`` (subset of nominal classes).
    :param parser: instance of an ArgumentParser.
    :return: the parser with the generic arguments, MVTec defaults, and the extra arguments.
    """
    mvtec_defaults = {
        'batch_size': 16,
        'acc_batches': 8,
        'supervise_mode': 'malformed_normal',
        'gauss_std': 12,
        'weight_decay': 1e-4,
        'epochs': 200,
        'preproc': 'lcnaug1',
        'quantile': 0.99,
        'net': 'FCDD_CNN224_VGG_F',
        'dataset': 'mvtec',
        'noise_mode': 'confetti',
    }
    parser = default_parser_config(parser)
    parser.set_defaults(**mvtec_defaults)

    parser.add_argument(
        '--it', type=int, default=5,
        help='Number of runs per class with different random seeds.',
    )
    parser.add_argument(
        '--cls-restrictions', type=int, nargs='+', default=None,
        help='Run only training sessions for some of the classes being nominal.',
    )
    return parser


In [3]:
# Command-line parser for MVTec-AD runs: the generic FCDD options plus the MVTec-specific defaults.
parser = default_parser_config_mvtec(
    ArgumentParser(
        description=(
            "\n"
            "    Train a neural network module as explained in the `Explainable Deep Anomaly Detection` paper.\n"
            "    Train FCDD, and log achieved scores, metrics, plots, and heatmaps\n"
            "    for both test and training data. \n"
            "    "
        )
    )
)

In [4]:
import time
from datetime import datetime
from pathlib import Path


def time_format(i: float) -> str:
    """Render the POSIX timestamp ``i`` (seconds since the epoch) as a local-time ``YYYYmmddHHMMSS`` string."""
    moment = datetime.fromtimestamp(i)
    return f"{moment:%Y%m%d%H%M%S}"


def args_post_parse(args_):
    """
    Post-processes parsed arguments in place and returns the same namespace.

    - ``logdir`` becomes a :class:`pathlib.Path` whose last component is prefixed with ``<dataset>_``,
      suffixed with ``logdir_suffix`` (if that attribute exists), and has every ``{t}`` replaced
      by the formatted start time.
    - ``log_start_time`` (int seconds since the epoch) and ``log_start_time_str``
      (:func:`time_format` of it) are added.
    - ``logdir_suffix`` is removed from the namespace once it has been consumed.

    :param args_: an argparse Namespace holding at least ``logdir`` and ``dataset``.
    :return: ``args_`` itself.
    """
    args_.logdir = Path(args_.logdir)

    # Stored both as int and as string for compatibility with setup_trainer.
    args_.log_start_time = int(time.time())
    args_.log_start_time_str = time_format(args_.log_start_time)

    # Consume the optional suffix; an absent attribute contributes nothing.
    suffix = vars(args_).pop('logdir_suffix', '')
    pattern = f"{args_.dataset}_" + args_.logdir.name + suffix
    args_.logdir = args_.logdir.parent / pattern.replace('{t}', args_.log_start_time_str)

    return args_

# setup

In [5]:
from collections import namedtuple

# Bundle of everything trainer_setup() returns for a training run.
# Field order matters for callers that unpack it positionally.
TrainSetup = namedtuple(
    "TrainSetup",
    "net dataset_loaders opt sched logger device quantile resdown gauss_std blur_heatmaps",
)

def trainer_setup(
    dataset: str,
    datadir: str,
    logdir: str,
    net: str,
    bias: bool,
    learning_rate: float,
    weight_decay: float,
    lr_sched_param: List[float],
    batch_size: int,
    optimizer_type: str,
    scheduler_type: str,
    preproc: str,
    supervise_mode: str,
    nominal_label: int,
    online_supervision: bool,
    oe_limit: int,
    noise_mode: str,
    workers: int,
    quantile: float,
    resdown: int,
    gauss_std: float,
    blur_heatmaps: bool,
    cuda: bool,
    config: str,
    log_start_time: int = None,
    normal_class: int = 0,
) -> TrainSetup:
    """
    Creates a complete setup for training, given all necessary parameters from a runner (see fcdd.runners.bases.py).
    This creates the logger, dataset, data loaders, network, optimizer, and learning rate scheduler.
    Side effects: the network parameters and ``config`` are saved via ``logger.save_params``, and a
    'ds_preview' image of training samples (20 per row header) is written via ``logger.imsave``.
    :param dataset: dataset identifier string (see :data:`DS_CHOICES`).
    :param datadir: directory where the datasets are found or to be downloaded to.
    :param logdir: directory where log data is to be stored.
    :param net: network model identifier string (see :func:`fcdd.models.choices`).
    :param bias: whether to use bias in the network layers.
    :param learning_rate: initial learning rate.
    :param weight_decay: weight decay (L2 penalty) regularizer.
    :param lr_sched_param: learning rate scheduler parameters. Format depends on the scheduler type.
        For 'milestones' needs to have at least two elements, the first corresponding to the factor
        the learning rate is decreased by at each milestone, the rest corresponding to milestones (epochs).
        For 'lambda' needs to have exactly one element, i.e. the factor the learning rate is decreased by
        at each epoch.
    :param batch_size: batch size, i.e. number of data samples that are returned per iteration of the data loader.
    :param optimizer_type: optimizer type, needs to be one of {'sgd', 'adam'}.
    :param scheduler_type: learning rate scheduler type, needs to be one of {'lambda', 'milestones'}.
    :param preproc: data preprocessing pipeline identifier string (see :data:`PREPROC_CHOICES`).
    :param supervise_mode: the type of generated artificial anomalies (see :data:`SUPERVISE_MODES` and
        :meth:`fcdd.datasets.bases.TorchvisionDataset._generate_artificial_anomalies_train_set`).
    :param nominal_label: the label that is to be returned to mark nominal samples.
    :param online_supervision: whether to sample anomalies online in each epoch,
        or offline before training (same for all epochs in this case).
    :param oe_limit: limits the number of different anomalies in case of Outlier Exposure (defined in noise_mode).
    :param noise_mode: the type of noise used (see :data:`NOISE_MODES` and :mod:`fcdd.datasets.noise_mode`).
    :param workers: how many subprocesses to use for data loading.
    :param quantile: the quantile that is used to normalize the generated heatmap images.
    :param resdown: the maximum resolution of logged images, images will be downsampled if necessary.
    :param gauss_std: a constant value for the standard deviation of the Gaussian kernel used for upsampling and
        blurring, the default value is determined by :func:`fcdd.datasets.noise.kernel_size_to_std`.
    :param blur_heatmaps: whether to blur heatmaps.
    :param cuda: whether to use the GPU ('cuda:0'); otherwise the CPU is used.
    :param config: some config text that is to be stored in the config.txt file.
    :param log_start_time: the start time of the experiment.
    :param normal_class: the class that is to be considered nominal.
    :return: a :class:`TrainSetup` named tuple with the network (already moved to the device), the data loaders,
        optimizer, scheduler, logger, device, and the heatmap parameters (quantile, resdown, gauss_std,
        blur_heatmaps) passed through unchanged.
    :raises AssertionError: if ``supervise_mode`` or ``noise_mode`` is unknown.
    """
    assert supervise_mode in SUPERVISE_MODES, f'unknown supervise mode: {supervise_mode}'
    assert noise_mode in NOISE_MODES, f'unknown noise mode: {noise_mode}'

    device = torch.device('cuda:0' if cuda else 'cpu')

    logger = Logger(
        logdir=logdir,
        exp_start_time=log_start_time,
    )

    ds = load_dataset(
        dataset_name=dataset,
        data_path=datadir,
        normal_class=normal_class,
        preproc=preproc,
        supervise_mode=supervise_mode,
        noise_mode=noise_mode,
        online_supervision=online_supervision,
        nominal_label=nominal_label,
        oe_limit=oe_limit,
        logger=logger,
    )

    loaders = ds.loaders(
        batch_size=batch_size,
        num_workers=workers
    )

    net = load_nets(name=net, in_shape=ds.shape, bias=bias)
    net = net.to(device)

    optimizer, scheduler = pick_opt_sched(
        net=net,
        lr=learning_rate,
        wdk=weight_decay,
        sched_params=lr_sched_param,
        opt=optimizer_type,
        sched=scheduler_type,
    )

    logger.save_params(net, config)

    # Row headers of the preview: the class with the smaller label value comes first
    # (nominal first when the dataset has no `nominal_label` attribute).
    # NOTE(review): assumes ds.preview returns its rows in label order -- confirm in the dataset implementation.
    if not hasattr(ds, 'nominal_label') or ds.nominal_label < ds.anomalous_label:
        ds_order = ['norm', 'anom']
    else:
        ds_order = ['anom', 'norm']

    images = ds.preview(percls=20, train=True)

    if isinstance(ds.train_set, GTMapADDataset):
        # Ground-truth datasets get a blank separator header followed by one 'gt' header per class row.
        rowheaders = [*ds_order, '', *['gtno' if s == 'norm' else 'gtan' for s in ds_order]]
    else:
        rowheaders = ds_order

    logger.imsave(
        name='ds_preview',
        tensors=torch.cat([*images]),
        nrow=images.size(1),
        rowheaders=rowheaders,
    )

    return TrainSetup(
        net=net,
        dataset_loaders=loaders,
        opt=optimizer,
        sched=scheduler,
        logger=logger,
        device=device,
        quantile=quantile,
        resdown=resdown,
        gauss_std=gauss_std,
        blur_heatmaps=blur_heatmaps,
    )

# trainer

## BaseADTrainer

In [13]:
import collections
import os.path as pt
from abc import abstractmethod, ABC
from typing import List, Tuple
from torch import Tensor
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.optimizer import Optimizer
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
import collections
import os.path as pt
from abc import abstractmethod, ABC
from typing import List, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from fcdd.datasets.bases import GTMapADDataset
from fcdd.datasets.noise import kernel_size_to_std
from fcdd.models.bases import BaseNet, ReceptiveNet
from fcdd.training import balance_labels
from fcdd.util.logging import colorize as colorize_img, Logger
from kornia import gaussian_blur2d
from sklearn.metrics import roc_auc_score, roc_curve
from torch import Tensor
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.optimizer import Optimizer
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset



class BaseADTrainer(ABC):
    
    def __init__(
        self, 
        net: BaseNet, 
        opt: Optimizer, 
        sched: _LRScheduler, 
        dataset_loaders: Tuple[DataLoader, DataLoader],
        logger: Logger, 
        objective: str, 
        gauss_std: float, 
        quantile: float, 
        resdown: int, 
        blur_heatmaps=False,
        device='cuda:0',
        **kwargs
    ):
        """
        Anomaly detection trainer that defines a test phase where scores are computed and heatmaps are generated.
        The train method is modified to be able to handle ground-truth maps.
        :param net: some neural network instance
        :param opt: optimizer.
        :param sched: learning rate scheduler.
        :param dataset_loaders:
        :param logger: some logger.
        :param device: some torch device, either cpu or gpu.
        :param gauss_std: a constant value for the standard deviation of the Gaussian kernel used for upsampling and
            blurring, the default value is determined by :func:`fcdd.datasets.noise.kernel_size_to_std`.
        :param quantile: the quantile that is used to normalize the generated heatmap images.
        :param resdown: the maximum resolution of logged images, images will be downsampled if necessary.
        :param blur_heatmaps: whether to blur heatmaps.
        """
        self.net = net
        self.opt = opt
        self.sched = sched
        self.train_loader, self.test_loader = dataset_loaders
        self.logger = logger
        self.device = device
        self.objective = objective
        self.gauss_std = gauss_std
        self.quantile = quantile
        self.resdown = resdown
        self.blur_heatmaps = blur_heatmaps
        
    @abstractmethod
    def loss(self, outs: Tensor, ins: Tensor, labels: Tensor, gtmaps: Tensor = None, reduce='mean'):
        pass
        
    def load(self, path: str, cpu=False) -> int:
        """ Loads a snapshot of the training state, including network weights """
        if cpu:
            snapshot = torch.load(path, map_location=torch.device('cpu'))
        else:
            snapshot = torch.load(path)
        net_state = snapshot.pop('net', None)
        opt_state = snapshot.pop('opt', None)
        sched_state = snapshot.pop('sched', None)
        epoch = snapshot.pop('epoch', None)
        if net_state is not None and self.net is not None:
            self.net.load_state_dict(net_state)
        if opt_state is not None and self.opt is not None:
            self.opt.load_state_dict(opt_state)
        if sched_state is not None and self.sched is not None:
            self.sched.load_state_dict(sched_state)
        print('Loaded {}{}{} with starting epoch {} for {}'.format(
            'net_state, ' if net_state else '', 'opt_state, ' if opt_state else '',
            'sched_state' if sched_state else '', epoch, str(self.__class__)[8:-2]
        ))
        return epoch

    def anomaly_score(self, loss: Tensor) -> Tensor:
        """ This assumes the loss is already the anomaly score. If this is not the case, reimplement the method! """
        return loss

    def reduce_ascore(self, ascore: Tensor) -> Tensor:
        """ Reduces the anomaly score to be a score per image (detection). """
        return ascore.reshape(ascore.size(0), -1).mean(1)

    def reduce_pixelwise_ascore(self, ascore: Tensor) -> Tensor:
        """ Reduces the anomaly score to be a score per pixel (explanation). """
        return ascore.mean(1).unsqueeze(1)

    def train(self, epochs: int, acc_batches=1, wandb=None) -> BaseNet:
        """
        Does epochs many full iterations of the training data loader and trains the network with self.loss.
        Supports ground-truth maps (when the training set is a GTMapADDataset). Whenever a batch contains more
        than one distinct label, it logs the losses of nominal (label 0) and anomalous (label != 0) samples
        separately. It can also accumulate several loader batches before each optimisation step.
        The learning rate scheduler is stepped once per epoch.
        :param epochs: number of full data loader iterations to train.
        :param acc_batches: To speed up data loading, this determines the number of batches that are accumulated
            before forwarded through the network. For instance, acc_batches=2 iterates the data loader two times,
            concatenates the batches, and passes this to the network. This has no impact on the performance
            if the batch size is reduced accordingly (e.g. one half in this example), but can decrease training time.
            The last batch of an epoch always triggers a step with whatever has been accumulated so far.
        :param wandb: optional object with a ``log(dict)`` method (e.g. the ``wandb`` module or a run).
            If given, the epoch, batch index, their relative progress, and the scalar loss are logged after
            every optimisation step.
        :return: the trained network
        """
        # isinstance first, so that non-numeric values raise AssertionError rather than TypeError
        assert isinstance(acc_batches, int) and acc_batches > 0

        self.net = self.net.to(self.device).train()
        n_batches = len(self.train_loader)

        for epoch in range(epochs):

            acc_data, acc_counter = [], 1

            for n_batch, data in enumerate(self.train_loader):

                # Collect loader batches until acc_batches are available (or the epoch's last batch arrives),
                # then concatenate them field-wise into one batch.
                if acc_counter < acc_batches and n_batch < n_batches - 1:
                    acc_data.append(data)
                    acc_counter += 1
                    continue
                elif acc_batches > 1:
                    acc_data.append(data)
                    data = [torch.cat(d) for d in zip(*acc_data)]
                    acc_data, acc_counter = [], 1

                if isinstance(self.train_loader.dataset, GTMapADDataset):
                    inputs, labels, gtmaps = data
                    gtmaps = gtmaps.to(self.device)
                else:
                    inputs, labels = data
                    gtmaps = None

                inputs = inputs.to(self.device)
                # Labels are combined with the network outputs inside the loss, so they must live on the
                # same device as the inputs/gtmaps (previously they stayed on the CPU).
                labels = labels.to(self.device)
                self.opt.zero_grad()
                outputs = self.net(inputs)
                loss = self.loss(outputs, inputs, labels, gtmaps)
                loss.backward()
                self.opt.step()
                with torch.no_grad():
                    info = {}
                    if len(set(labels.tolist())) > 1:
                        swloss = self.loss(outputs, inputs, labels, gtmaps, reduce='none')
                        swloss = swloss.reshape(swloss.size(0), -1).mean(-1)
                        info = {'err_normal': swloss[labels == 0].mean(),
                                'err_anomalous': swloss[labels != 0].mean()}
                    self.logger.log(
                        epoch,
                        n_batch,
                        n_batches,
                        loss,
                        infoprint='LR {} ID {}{}'.format(
                            ['{:.0e}'.format(p['lr']) for p in self.opt.param_groups],
                            str(self.__class__)[8:-2],
                            ' NCLS {}'.format(self.train_loader.dataset.normal_classes)
                            if hasattr(self.train_loader.dataset, 'normal_classes') else ''
                        ),
                        info=info
                    )
                    if wandb is not None:
                        wandb.log(dict(
                            epoch=epoch,
                            epoch_percent=epoch / epochs,
                            n_batch=n_batch,
                            n_batch_percent=n_batch / n_batches,
                            loss=loss.item(),
                        ))
            self.sched.step()

        return self.net

    def test(self, specific_viz_ids: Tuple[List[int], List[int]] = (), train_data=True, subdir='.') -> dict:
        """
        Evaluates the network in eval mode: gathers labels, losses, anomaly scores, inputs, outputs,
        ground-truth maps and gradients via :meth:`_gather_data`, generates heatmap pictures,
        and computes scores for the test data.

        Steps:
            1.  If ``train_data`` is set, the training loader is gathered and :meth:`heatmap_generation` is
                called with ``name='train_heatmaps'``; otherwise only a "SKIPPED" message is printed.
            2.  The test loader is gathered. If the test dataset defines ``fixed_random_order``, all gathered
                data is permuted by that order. This is not allowed for datasets with ground-truth maps,
                which triggers an assertion.
            3.  :meth:`heatmap_generation` is called for the test data with ``name='test_heatmaps'``.
            4.  :meth:`score` is computed for the test data under ``torch.no_grad()`` and returned.

        NOTE(review): the upstream FCDD docstring describes the generated pictures in more detail. They show
        the first 20 nominal/anomalous samples and the most nominal/anomalous rated ones, plus locally and
        semi-globally normalized test heatmaps stored under 'tim'. The score is described as detection
        AUC/ROC, with explanation ROC when ground-truth maps exist. Those details belong to
        :meth:`heatmap_generation` and :meth:`score`, which are not part of this section -- confirm there.

        :param specific_viz_ids: meant to select extra (nominal, anomalous) sample indices for dedicated
            heatmaps. NOTE: this implementation currently ignores it.
        :param train_data: whether to also gather the training data and generate its heatmaps.
        :param subdir: passed through to :meth:`score`.
        :return: the result of :meth:`score` for the test data (upstream documented as a dictionary of ROC
            results, each of the form {'tpr': [], 'fpr': [], 'ths': [], 'auc': ..., ...}).
        """
        self.net = self.net.to(self.device).eval()

        if train_data:
            self.logger.print('Test training data...', fps=False)
            labels, _, anomaly_scores, imgs, _, gtmaps, grads = self._gather_data(self.train_loader)
            self.heatmap_generation(labels, anomaly_scores, imgs, gtmaps, grads, name='train_heatmaps',)
        else:
            self.logger.print('Test training data SKIPPED', fps=False)

        self.logger.print('Test test data...', fps=False)
        labels, loss, anomaly_scores, imgs, outputs, gtmaps, grads = self._gather_data(
            self.test_loader,
        )

        def reorder(labels: List[int], loss: Tensor, anomaly_scores: Tensor, imgs: Tensor, outputs: Tensor,
                    gtmaps: Tensor, grads: Tensor, ds: Dataset = None
                    ) -> Tuple[List[int], Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
            """ returns all inputs in an identical new order if the dataset offers a predefined (random) order """
            if ds is not None and hasattr(ds, 'fixed_random_order'):
                assert gtmaps is None, \
                    'original gtmaps loaded in score do not know order! Hence reordering is not allowed for GT datasets'
                o = ds.fixed_random_order
                labels = labels[o] if isinstance(labels, (Tensor, np.ndarray)) else np.asarray(labels)[o].tolist()
                loss, anomaly_scores, imgs = loss[o], anomaly_scores[o], imgs[o]
                outputs = outputs[o]  # gtmaps is None here (asserted above), so it needs no reordering
                grads = grads[o] if grads is not None else None
            return labels, loss, anomaly_scores, imgs, outputs, gtmaps, grads

        labels, loss, anomaly_scores, imgs, outputs, gtmaps, grads = reorder(
            labels, loss, anomaly_scores, imgs, outputs, gtmaps, grads, ds=self.test_loader.dataset
        )
        self.heatmap_generation(labels, anomaly_scores, imgs, gtmaps, grads, name='test_heatmaps',)

        with torch.no_grad():
            sc = self.score(labels, anomaly_scores, imgs, outputs, gtmaps, grads, subdir=subdir)
        return sc

    def _gather_data(self, loader: DataLoader,
                     gather_all=False) -> Tuple[List[int], Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
        """
        Runs the network over every batch of ``loader`` and collects the results on the CPU.
        The forward pass per batch is chosen as follows:
            - ``gather_all``: a gradient-free pass for outputs, loss and scores, followed by a gradient pass on a
              fresh copy of the inputs that is used only for its gradients;
            - ``self.objective == 'hsc'``: a single gradient pass that yields everything including gradients;
            - otherwise: a single gradient-free pass (no gradients are collected).
        A progress line is printed through the logger after every batch.
        :param loader: data loader yielding (inputs, labels), or (inputs, labels, gtmaps) for a GTMapADDataset.
        :param gather_all: whether to additionally compute gradients regardless of the objective.
        :return: tuple (labels as a flat Python list, loss, anomaly scores, inputs, outputs,
            ground-truth maps or None, gradients or None); tensors are concatenated along dimension 0.
        """
        label_list = []
        loss_parts, score_parts, image_parts, output_parts = [], [], [], []
        gtmap_parts, grad_parts = [], []

        for n_batch, batch in enumerate(loader):
            if isinstance(loader.dataset, GTMapADDataset):
                inputs, labels, gtmaps = batch
                gtmap_parts.append(gtmaps)
            else:
                inputs, labels = batch

            # untouched copy, so the gradient pass can start again from clean inputs
            pristine = inputs.detach().clone()
            inputs = inputs.to(self.device)

            if gather_all:
                outputs, loss, anomaly_score, _ = self._regular_forward(inputs, labels)
                inputs = pristine.clone().to(self.device)
                _, _, _, grads = self._grad_forward(inputs, labels)
            elif self.objective == 'hsc':
                outputs, loss, anomaly_score, grads = self._grad_forward(inputs, labels)
            else:
                outputs, loss, anomaly_score, grads = self._regular_forward(inputs, labels)

            label_list.extend(labels.detach().cpu().tolist())
            for parts, tensor in (
                (loss_parts, loss),
                (score_parts, anomaly_score),
                (image_parts, inputs),
                (output_parts, outputs),
            ):
                parts.append(tensor.detach().cpu())
            if grads is not None:
                grad_parts.append(grads.detach().cpu())

            ncls_info = (
                ' NCLS {}'.format(loader.dataset.normal_classes)
                if hasattr(loader.dataset, 'normal_classes') else ''
            )
            self.logger.print(
                'TEST {:04d}/{:04d} ID {}{}'.format(n_batch, len(loader), str(self.__class__)[8:-2], ncls_info),
                fps=True
            )

        images = torch.cat(image_parts)
        outputs_all = torch.cat(output_parts)
        gtmaps_all = torch.cat(gtmap_parts) if gtmap_parts else None
        losses = torch.cat(loss_parts)
        scores = torch.cat(score_parts)
        grads_all = torch.cat(grad_parts) if grad_parts else None
        return label_list, losses, scores, images, outputs_all, gtmaps_all, grads_all

    def _regular_forward(self, inputs: Tensor, labels: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """
        Gradient-free forward pass.
        :return: (network outputs, unreduced loss, anomaly scores, None). The last slot mirrors the
            gradient slot returned by ``_grad_forward``.
        """
        with torch.no_grad():
            net_out = self.net(inputs)
            per_elem_loss = self.loss(net_out, inputs, labels, reduce='none')
            score = self.anomaly_score(per_elem_loss)
        return net_out, per_elem_loss, score, None

    def _grad_forward(self, inputs: Tensor, labels: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """
        Forward pass with autograd tracking enabled on ``inputs``, used to obtain input-gradient heatmaps.
        ``inputs.requires_grad`` is switched on for the pass and switched off afterwards.
        :return: (network outputs, unreduced loss, anomaly scores, gradient heatmap from net.get_grad_heatmap).
        """
        inputs.requires_grad = True
        net_out = self.net(inputs)
        per_elem_loss = self.loss(net_out, inputs, labels, reduce='none')
        score = self.anomaly_score(per_elem_loss)
        grad_map = self.net.get_grad_heatmap(per_elem_loss, inputs)
        inputs.requires_grad = False
        # presumably get_grad_heatmap back-propagates and leaves gradients on the parameters; clear them
        self.opt.zero_grad()
        return net_out, per_elem_loss, score, grad_map

    def score(self, labels: List[int], ascores: Tensor, imgs: Tensor, outs: Tensor, gtmaps: Tensor = None,
              grads: Tensor = None, subdir='.') -> dict:
        """
        Computes the ROC curve and the AUC for the sample-wise detection performance.
        If ground-truth maps are available, also computes them for the pixel-wise explanation performance.
        :param labels: binary sample labels (sklearn's roc_curve treats 1 as the positive, i.e. anomalous, class).
        :param ascores: anomaly scores. They are reduced per sample via self.reduce_ascore for the sample-wise ROC.
        :param imgs: input images (not used for scoring).
        :param outs: outputs of the neural network (not used for scoring).
        :param gtmaps: ground-truth maps (can be None). Only used to decide whether a pixel-wise ROC is
            computed. The maps actually compared against are loaded from
            ``self.test_loader.dataset.dataset.get_original_gtmaps_normal_class()``.
        :param grads: gradients of anomaly scores w.r.t. inputs (can be None). If given, they are used instead
            of the pixel-wise anomaly scores for the pixel-wise ROC.
        :param subdir: subdirectory to store the data in (plots and numbers).
        :return: None if ascores contain NaN values. Otherwise {'roc': ..., 'gtmap_roc': ...}, where each ROC
            result is a dict of the form {'tpr': ..., 'fpr': ..., 'ths': ..., 'auc': float}.
            'gtmap_roc' is None when no pixel-wise ROC was computed: either there are no gtmaps, or an
            AssertionError was raised while computing it.
        """
        # Logging
        self.logger.print('Computing test score...')
        if torch.isnan(ascores).sum() > 0:
            self.logger.logtxt('Could not compute test scores, since anomaly scores contain nan values!!!', True)
            return None
        red_ascores = self.reduce_ascore(ascores).tolist()
        std = self.gauss_std

        # Overall ROC for sample-wise anomaly detection
        fpr, tpr, thresholds = roc_curve(labels, red_ascores)
        roc_score = roc_auc_score(labels, red_ascores)
        roc_res = {'tpr': tpr, 'fpr': fpr, 'ths': thresholds, 'auc': roc_score}
        self.logger.single_plot(
            'roc_curve', tpr, fpr, xlabel='false positive rate', ylabel='true positive rate',
            legend=['auc={}'.format(roc_score)], subdir=subdir
        )
        self.logger.single_save('roc', roc_res, subdir=subdir)
        self.logger.logtxt('##### ROC TEST SCORE {} #####'.format(roc_score), print=True)

        # GTMAPS pixel-wise anomaly detection = explanation performance
        gtmap_roc_res = None
        use_grads = grads is not None
        if gtmaps is not None:
            try:
                self.logger.print('Computing GT test score...')
                ascores = self.reduce_pixelwise_ascore(ascores) if not use_grads else grads
                gtmaps = self.test_loader.dataset.dataset.get_original_gtmaps_normal_class()
                if isinstance(self.net, ReceptiveNet):  # Receptive field upsampling for FCDD nets
                    ascores = self.net.receptive_upsample(ascores, std=std)
                # Further upsampling for original dataset size
                ascores = torch.nn.functional.interpolate(ascores, (gtmaps.shape[-2:]))
                # Flatten into numpy arrays rather than Python lists: there is one entry per pixel of every map,
                # so lists of boxed numbers are very memory-hungry. float64 keeps the values identical to what
                # .tolist() produced.
                flat_gtmaps = gtmaps.reshape(-1).int().cpu().numpy()
                flat_ascores = ascores.reshape(-1).detach().double().cpu().numpy()

                gtfpr, gttpr, gtthresholds = roc_curve(flat_gtmaps, flat_ascores)
                gt_roc_score = roc_auc_score(flat_gtmaps, flat_ascores)
                gtmap_roc_res = {'tpr': gttpr, 'fpr': gtfpr, 'ths': gtthresholds, 'auc': gt_roc_score}
                self.logger.single_plot(
                    'gtmap_roc_curve', gttpr, gtfpr, xlabel='false positive rate', ylabel='true positive rate',
                    legend=['auc={}'.format(gt_roc_score)], subdir=subdir
                )
                self.logger.single_save(
                    'gtmap_roc', gtmap_roc_res, subdir=subdir
                )
                self.logger.logtxt('##### GTMAP ROC TEST SCORE {} #####'.format(gt_roc_score), print=True)
            except AssertionError as e:
                self.logger.warning(f'Skipped computing the gtmap ROC score. {str(e)}')

        return {'roc': roc_res, 'gtmap_roc': gtmap_roc_res}

    def heatmap_generation(
        self, 
        labels: List[int], 
        ascores: Tensor, 
        imgs: Tensor, 
        gtmaps: Tensor = None, 
        grads: Tensor = None, 
        show_per_cls: int = 20,
        name='heatmaps', 
        subdir='.'
    ):
        """
        Selects samples and delegates to the picture helpers to store heatmap visualizations.
        (1) A grid picture via ``_create_heatmaps_picture``. Skipped with a warning if some label has fewer
            than 2 samples.
        (2) Unless 'train' is in ``name``: one single-row "paper" picture per label via
            ``_create_singlerow_heatmaps_picture``. Samples are picked equidistantly from most nominal to most
            anomalous.
        :param labels: per-sample labels.
        :param ascores: anomaly scores; ranked via self.reduce_ascore summed per sample.
        :param imgs: input images.
        :param gtmaps: ground-truth maps (can be None).
        :param grads: gradient heatmaps (can be None).
        :param show_per_cls: maximum number of samples per label in the grid picture. A third of it, at least 1,
            is used for the paper pictures.
        :param name: base name of the stored pictures.
        :param subdir: subdirectory to store the pictures in.
        """
        minsamples = min(collections.Counter(labels).values())
        lbls = torch.IntTensor(labels)
        classes = sorted(set(labels))

        def rank_by_score(rascores: Tensor) -> List[int]:
            # all sample indices, ordered from lowest to highest summed (reduced) anomaly score
            return np.argsort(rascores.detach().reshape(rascores.size(0), -1).sum(1)).tolist()

        if minsamples < 2:
            self.logger.warning(
                f"Heatmap '{name}' cannot be generated. For some labels there are too few samples!", unique=False
            )
        else:
            this_show_per_cls = min(show_per_cls, minsamples)
            if this_show_per_cls % 2 != 0:
                this_show_per_cls -= 1
            # Evaluation Picture with 4 rows. Each row splits into 4 subrows with input-output-heatmap-gtm:
            # (1) 20 first nominal samples (2) 20 first anomalous samples
            # (3) 10 most nominal nominal samples - 10 most anomalous nominal samples
            # (4) 10 most nominal anomalies - 10 most anomalous anomalies
            idx = []
            for l in classes:
                idx.extend((lbls == l).nonzero().squeeze(-1).tolist()[:this_show_per_cls])
            ranking = rank_by_score(self.reduce_ascore(ascores))  # label-independent, computed once
            k = max(this_show_per_cls // 2, 1)
            for l in classes:
                lid = set((lbls == l).nonzero().squeeze(-1).tolist())
                sort = [i for i in ranking if i in lid]
                idx.extend([*sort[:k], *sort[-k:]])
            self._create_heatmaps_picture(
                idx, name, imgs.shape, subdir, this_show_per_cls, imgs, ascores, grads, gtmaps, labels
            )

        # Concise paper picture: Samples grow from most nominal to most anomalous (equidistant).
        # 2 versions: with local normalization and semi-global normalization
        if 'train' not in name:
            res = self.resdown * 2  # increase resolution limit because there are only a few heatmaps shown here
            ranking = rank_by_score(self.reduce_ascore(ascores))
            inpshp = imgs.shape
            for l in classes:
                # non-empty, because l is taken from the labels themselves
                lid = set((lbls == l).nonzero().squeeze(-1).tolist())
                # at least one split: np.array_split rejects 0 sections (show_per_cls < 3 used to crash)
                k = max(1, min(show_per_cls // 3, len(lid)))
                sort = [i for i in ranking if i in lid]
                splits = np.array_split(sort, k)
                # the last split contributes its most anomalous sample; for k == 1 that is the only pick
                idx = [s[int(n / (k - 1) * len(s)) if n != len(splits) - 1 else -1] for n, s in enumerate(splits)]
                self.logger.logtxt(
                    'Interpretation visualization paper image {} indicies for label {}: {}'
                    .format('{}_paper_lbl{}'.format(name, l), l, idx)
                )
                self._create_singlerow_heatmaps_picture(
                    idx, name, inpshp, l, subdir, res, imgs, ascores, grads, gtmaps, labels
                )

    def _create_heatmaps_picture(self, idx: List[int], name: str, inpshp: torch.Size, subdir: str,
                                 nrow: int, imgs: Tensor, ascores: Tensor, grads: Tensor, gtmaps: Tensor,
                                 labels: List[int], norm: str = 'global'):
        """
        Creates a picture of inputs, heatmaps (based on ascores unless the objective is 'hsc', and additionally
        on grads if grads is not None), and ground-truth maps (if not None, otherwise omitted).
        Each row contains nrow many samples and only one of {inputs, heatmaps, ground-truth maps}.
        For each chunk of nrow samples, the order of rows is (1) inputs (2) heatmaps (3) ground-truth maps
        (4) blank. For instance, for 20 samples and nrow=10, the picture would show:
            - 10 inputs
            - 10 corresponding heatmaps
            - 10 corresponding ground-truth maps
            - blank
            - 10 inputs
            - 10 corresponding heatmaps
            - 10 corresponding ground-truth maps
            - blank
        :param idx: limit the inputs (and corresponding other rows) to these indices.
        :param name: name to be used to store the picture (the norm is appended).
        :param inpshp: the input shape (heatmaps will be resized to this).
        :param subdir: some subdirectory to store the data in.
        :param nrow: number of images per row.
        :param imgs: the input images.
        :param ascores: anomaly scores.
        :param grads: gradients.
        :param gtmaps: ground-truth maps.
        :param labels: labels of all samples; used to balance the reference for 'global' normalization.
        :param norm: what type of normalization to apply.
            'local': normalizes each heatmap w.r.t. itself only.
            'global': normalizes each heatmap w.r.t. all heatmaps available (without taking idx into account),
                though it is ensured to consider equally many anomalous and nominal samples (if there are e.g. more
                nominal samples, randomly chosen nominal samples are ignored to match the correct amount).
            'semi_global': normalizes each heatmap w.r.t. all heatmaps chosen in idx.
        """
        number_of_rows = int(np.ceil(len(idx) / nrow))
        # 'semi_global' becomes 'global' normalization with the idx-selected heatmaps as reference
        inner_norm = norm.replace('semi_', '')

        # Row-independent work is done once. In particular the normalization references are shared by all
        # rows, so the whole picture uses one reference even if balance_labels subsamples randomly.
        sel_imgs = imgs[idx]
        sel_ascores, ascore_ref = None, None
        if self.objective != 'hsc':
            sel_ascores = ascores[idx]
            ascore_ref = balance_labels(ascores, labels, False) if norm == 'global' else sel_ascores
        sel_grads, grad_ref = None, None
        if grads is not None:
            sel_grads = grads[idx]
            grad_ref = balance_labels(grads, labels, False) if norm == 'global' else sel_grads
        sel_gtmaps = gtmaps[idx] if gtmaps is not None else None

        rows = []
        for s in range(number_of_rows):
            lo, hi = s * nrow, s * nrow + nrow
            rows.append(self._image_processing(sel_imgs[lo:hi], inpshp, maxres=self.resdown, qu=1))
            if self.objective != 'hsc':
                rows.append(
                    self._image_processing(
                        sel_ascores[lo:hi], inpshp, maxres=self.resdown, qu=self.quantile,
                        colorize=True, ref=ascore_ref, norm=inner_norm,
                    )
                )
            if grads is not None:
                rows.append(
                    self._image_processing(
                        sel_grads[lo:hi], inpshp, self.blur_heatmaps,
                        self.resdown, qu=self.quantile,
                        colorize=True, ref=grad_ref, norm=inner_norm,
                    )
                )
            if gtmaps is not None:
                rows.append(
                    self._image_processing(
                        sel_gtmaps[lo:hi], inpshp, maxres=self.resdown, norm=None
                    )
                )
            rows.append(torch.zeros_like(rows[-1]))  # blank separator row
        name = '{}_{}'.format(name, norm)
        self.logger.imsave(name, torch.cat(rows), nrow=nrow, scale_mode='none', subdir=subdir)

    def _create_singlerow_heatmaps_picture(self, idx: List[int], name: str, inpshp: torch.Size, lbl: int, subdir: str,
                                           res: int, imgs: Tensor, ascores: Tensor, grads: Tensor, gtmaps: Tensor,
                                           labels: List[int]):
        """
        Creates a compact picture for the samples in idx, stored twice: once with 'local' and once with 'global'
        normalization of the heatmaps.
        Rows: (1) inputs, (2) ascore heatmaps (unless the objective is 'hsc'), (3) gradient heatmaps (if grads is
        not None), (4) ground-truth maps (if gtmaps is not None).
        Each version is saved both as a stacked tensor (single_save under 'tims/<subdir>') and as an image.
        :param idx: indices of the samples to show.
        :param name: base name used to store the picture.
        :param inpshp: the input shape (heatmaps will be resized to this).
        :param lbl: label of the samples, only used for naming.
        :param subdir: some subdirectory to store the data in.
        :param res: maximum allowed resolution in pixels (images are downsampled if they exceed this threshold).
        :param imgs: the input images.
        :param ascores: anomaly scores.
        :param grads: gradients.
        :param gtmaps: ground-truth maps.
        :param labels: labels of all samples; used to balance the reference for 'global' normalization.
        """
        for norm in ('local', 'global'):
            use_global = norm == 'global'
            rows = [self._image_processing(imgs[idx], inpshp, maxres=res, qu=1)]
            if self.objective != 'hsc':
                rows.append(
                    self._image_processing(
                        ascores[idx], inpshp, maxres=res, colorize=True,
                        ref=balance_labels(ascores, labels, False) if use_global else None,
                        norm=norm,
                    )
                )
            if grads is not None:
                rows.append(
                    self._image_processing(
                        grads[idx], inpshp, self.blur_heatmaps, res, colorize=True,
                        ref=balance_labels(grads, labels, False) if use_global else None,
                        norm=norm,
                    )
                )
            if gtmaps is not None:
                rows.append(self._image_processing(gtmaps[idx], inpshp, maxres=res, norm=None))
            grid = torch.cat(rows)
            picture_name = '{}_paper_{}_lbl{}'.format(name, norm, lbl)
            self.logger.single_save(picture_name, torch.stack(rows), subdir=pt.join('tims', subdir))
            self.logger.imsave(picture_name, grid, nrow=len(idx), scale_mode='none', subdir=subdir)

    def _image_processing(self, imgs: Tensor, input_shape: torch.Size, blur: bool = False, maxres: int = 64,
                          qu: float = None, norm: str = 'local', colorize: bool = False, ref: Tensor = None,
                          cmap: str = 'jet') -> Tensor:
        """
        Applies basic image processing in this order: upsampling to the input resolution (receptive nets only),
        optional Gaussian blurring, downsampling to at most maxres, optional normalization, and finally either
        pseudo-coloring or expansion of single-channel images to 3 channels. Can be used to create
        pseudocolored heatmaps!
        The input is detached and cloned first, so the in-place normalization helpers never touch the
        caller's tensor.
        :param imgs: a 4D tensor of images (n x c x h x w).
        :param input_shape: the 4D shape of the inputs the data loader returns. Only its spatial part (h x w)
            is compared against imgs.
        :param blur: whether to blur the images. The kernel size comes from the receptive field of the network
            (or of a derived FCDD net) and is made odd. The original note says this has little effect for FCDD
            anomaly scores, which are presumably already upsampled with a Gaussian kernel.
        :param maxres: maximum allowed resolution in pixels. Images are downsampled (nearest) if they exceed
            this threshold; that path requires square images.
        :param norm: what type of normalization to apply.
            None: no normalization.
            'local': normalizes each image w.r.t. itself only.
            'global': normalizes each image w.r.t. to ref (ref defaults to imgs).
        :param qu: quantile used for normalization, qu=1 yields the typical 0-1 normalization.
            Defaults to self.quantile when None.
        :param colorize: whether to colorize, using a colormap, the channel-mean of the images
            (-> pseudocolored heatmaps!).
        :param ref: a tensor of images used for global normalization (defaults to imgs).
            NOTE(review): ref is used as given. It does not go through the upsampling/blur/downsampling applied
            to imgs, so its value range may differ from the processed imgs — confirm this is intended.
        :param cmap: the colormap that is used to colorize grayscaled images.
        :return: transformed tensor of images. Without colorize, single-channel images are repeated to
            3 channels and multi-channel images are returned unchanged in channel count.
        """
        imgs = imgs.detach().clone()
        assert imgs.dim() == len(input_shape) == 4  # n x c x h x w
        std = self.gauss_std
        if qu is None:
            qu = self.quantile

        # upsample if necessary (img.shape != input_shape)
        if imgs.shape[2:] != input_shape[2:]:
            # only receptive nets know how to map their low-resolution outputs back to input resolution
            assert isinstance(self.net, ReceptiveNet), \
                'Some images are not of full resolution, and network is not a receptive net. This should not occur! '
            imgs = self.net.receptive_upsample(imgs, reception=True, std=std)

        # blur if requested
        if blur:
            # r: receptive-field size used as the blur kernel size
            if isinstance(self.net, ReceptiveNet):
                r = self.net.reception['r']
            elif self.objective == 'hsc':
                # non-receptive HSC net: take r from an FCDD version of the net built for the same input shape
                r = self.net.fcdd_cls(self.net.in_shape, bias=True).reception['r']
            elif self.objective == 'ae':
                enc = self.net.encoder
                if isinstance(enc, ReceptiveNet):
                    r = enc.reception['r']
                else:
                    r = enc.fcdd_cls(enc.in_shape, bias=True).reception['r']
            else:
                raise NotImplementedError()
            r = (r - 1) if r % 2 == 0 else r  # force an odd kernel size
            std = std or kernel_size_to_std(r)  # derive sigma from the kernel size if no gauss_std is configured
            imgs = gaussian_blur2d(imgs, (r,) * 2, (std,) * 2)

        # downsample if resolution exceeds the limit given with maxres
        if maxres < max(imgs.shape[2:]):
            assert imgs.shape[-2] == imgs.shape[-1], 'Image provided is no square!'
            imgs = F.interpolate(imgs, (maxres, maxres), mode='nearest')

        # apply requested normalization
        if norm is not None:
            # any other norm value raises KeyError here
            apply_norm = {
                'local': self.__local_norm, 'global': self.__global_norm
            }
            imgs = apply_norm[norm](imgs, qu, ref)

        # if image is grayscaled, colorize, i.e. provide a pseudocolored heatmap!
        if colorize:
            imgs = imgs.mean(1).unsqueeze(1)  # collapse channels to one (n x 1 x h x w) before applying the colormap
            imgs = colorize_img([imgs, ], norm=False, cmap=cmap)[0]
        else:
            imgs = imgs.repeat(1, 3, 1, 1) if imgs.size(1) == 1 else imgs

        return imgs

    @staticmethod
    def __global_norm(imgs: Tensor, qu: float, ref: Tensor = None) -> Tensor:
        """
        Applies a global normalization of tensor. The minimum of ref maps to 0 and the qu-quantile of ref
        (after shifting ref to a minimum of zero) maps to 1; results are clamped to [0, 1]. Uses a non-linear
        normalization based on quantiles as explained in the appendix of the paper.
        Note: imgs is shifted and scaled in place; the clamped result is returned as a new tensor.
        :param imgs: images tensor
        :param qu: quantile used, in (0, 1]; qu=1 uses the maximum of ref.
        :param ref: if this is None, normalizes w.r.t. to imgs, otherwise normalizes w.r.t. to ref.
        :return: normalized tensor with values in [0, 1]. A zero quantile (e.g. a constant ref) would turn 0/0
            into NaN; NaN entries are set to 0 instead.
        """
        ref = ref if ref is not None else imgs
        imgs.sub_(ref.min())
        ref = ref.sub(ref.min())
        flat_ref = ref.reshape(-1)
        # kthvalue is 1-based: guard against k == 0 when qu * numel < 1
        k = max(1, int(qu * flat_ref.size(0)))
        quantile = flat_ref.kthvalue(k)[0]  # qu% are below that
        imgs.div_(quantile)  # (1 - qu)% values will end up being out of scale ( > 1)
        plosses = imgs.clamp(0, 1)  # clamp those
        plosses.masked_fill_(torch.isnan(plosses), 0)  # 0 / 0 for a zero quantile
        return plosses

    @staticmethod
    def __local_norm(imgs: Tensor, qu: float, ref: Tensor = None) -> Tensor:
        """
        Applies a local normalization of tensor. For each element along dim 0, its minimum maps to 0 and its
        qu-quantile maps to 1; results are clamped to [0, 1]. Uses a non-linear normalization based on
        quantiles as explained in the appendix of the paper.
        Note: imgs is shifted and scaled in place before clamping.
        :param imgs: images tensor
        :param qu: quantile used, in (0, 1]; qu=1 uses each element's maximum.
        :param ref: ignored; only present so both normalizations share one call signature.
        :return: normalized tensor with values in [0, 1]. A constant element (zero quantile) would turn 0/0
            into NaN; NaN entries are set to 0 instead.
        """
        per_sample = (...,) + (None,) * (imgs.dim() - 1)  # broadcast per-element scalars over remaining dims
        imgs.sub_(imgs.reshape(imgs.size(0), -1).min(1)[0][per_sample])
        flat = imgs.reshape(imgs.size(0), -1)
        # kthvalue is 1-based: guard against k == 0 when qu * numel < 1
        k = max(1, int(qu * flat.size(1)))
        quantile = flat.kthvalue(k, dim=1)[0]  # qu% are below that
        imgs.div_(quantile[per_sample])
        imgs = imgs.clamp(0, 1)  # clamp those
        imgs.masked_fill_(torch.isnan(imgs), 0)  # 0 / 0 for constant elements
        return imgs


## FCDDTrainer

In [14]:
class FCDDTrainer(BaseADTrainer):
    """
    Trainer implementing the FCDD objective on top of BaseADTrainer.

    NOTE(review): a later cell runs `from fcdd.training.fcdd import FCDDTrainer`, which rebinds this name to the
    library class (the training log prints `ID fcdd.training.fcdd.FCDDTrainer`). This notebook class is
    presumably not the one used by `run_one` unless this cell is re-executed afterwards — confirm.
    """

    def loss(self, outs: Tensor, ins: Tensor, labels: Tensor, gtmaps: Tensor = None, reduce='mean'):
        """
        Computes the FCDD loss.
        :param reduce: 'mean' returns a scalar; 'none' returns the unreduced loss.
        :raises NotImplementedError: if self.objective is not 'fcdd'.
        """
        assert reduce in ['mean', 'none']
        if self.objective not in ['fcdd']:
            raise NotImplementedError('Objective {} is not defined yet.'.format(self.objective))
        return self.__fcdd_loss(outs, ins, labels, gtmaps, reduce)

    def __fcdd_loss(self, outs: Tensor, ins: Tensor, labels: Tensor, gtmaps: Tensor, reduce: str):
        # pseudo-Huber transform of the outputs: sqrt(outs^2 + 1) - 1
        loss = (outs ** 2 + 1).sqrt() - 1
        if gtmaps is None:
            # the semi-supervised term only applies when the batch mixes different labels
            if len(set(labels.tolist())) > 1:
                loss = self.__supervised_loss(loss, labels)
        elif isinstance(self.net, FCDDNet):
            loss = self.__gt_loss(loss, gtmaps)
        if reduce == 'mean':
            return loss.mean()
        return loss

    def __supervised_loss(self, loss: Tensor, labels: Tensor):
        """
        In training mode: averages the loss per sample. Nominal samples (label 0) keep that mean; anomalous
        samples (label 1) get -log(1 - exp(-mean)), with 1e-31 added inside the log for stability.
        In evaluation mode the loss is returned unchanged.
        """
        if not self.net.training:
            return loss
        per_sample = loss.reshape(labels.size(0), -1).mean(-1)
        nominal = per_sample[labels == 0]
        anomalous = (-(((1 - (-per_sample[labels == 1]).exp()) + 1e-31).log()))
        per_sample[(1 - labels).nonzero().squeeze()] = nominal
        per_sample[labels.nonzero().squeeze()] = anomalous
        return per_sample

    def __gt_loss(self, loss: Tensor, gtmaps: Tensor):
        """
        In training mode: upsamples the loss with the net's receptive upsampling (presumably to the resolution
        of gtmaps). Returns, per sample, the pixel-mean of loss * (1 - gtmap), plus -log(1 - exp(-mean(loss *
        gtmap))) for samples containing at least one anomalous pixel (gtmap == 1).
        In evaluation mode the loss is returned unchanged.
        """
        if not self.net.training:
            return loss
        std = self.gauss_std
        upsampled = self.net.receptive_upsample(loss, reception=True, std=std, cpu=False)
        nominal_part = (upsampled * (1 - gtmaps)).view(upsampled.size(0), -1).mean(-1)
        has_anomalous_pixels = ((gtmaps == 1).view(gtmaps.size(0), -1).sum(-1) > 0)
        anomalous_part = torch.zeros_like(nominal_part)
        if has_anomalous_pixels.sum() > 0:
            masked = (upsampled * gtmaps)[has_anomalous_pixels]
            anomalous_part[has_anomalous_pixels] = (
                -(((1 - (-masked.view(masked.size(0), -1).mean(-1)).exp()) + 1e-31).log())
            )
        return nominal_part + anomalous_part

# train

In [15]:
def use_wandb() -> bool:
    """Tell whether wandb logging is enabled, i.e. whether the WANDB env var is a non-zero integer (default "0")."""
    flag = os.environ.get("WANDB", "0")
    return int(flag) != 0

In [21]:
import os.path as pt
from typing import List
import torch
import torch.optim as optim
from fcdd.datasets import load_dataset
from fcdd.datasets.bases import GTMapADDataset
from fcdd.models import load_nets
from fcdd.models.bases import BaseNet
from fcdd.util.logging import Logger
from fcdd.training.setup import pick_opt_sched
import json
import os.path as pt
import re
import time
import traceback
from collections import defaultdict
from copy import deepcopy
from argparse import ArgumentParser
from fcdd.training.fcdd import FCDDTrainer


# Result of one run_one() call; the field names mirror the keys of the dict returned by trainer.test().
RunResults = namedtuple("RunResults", "roc gtmap_roc")


def run_one(**kwargs):
    """
    Sets up, trains and (optionally) tests one model for a single normal class / iteration.

    kwargs should contain all parameters of the setup function in training.setup. It must also contain the keys
    consumed or removed here: ``readme``, ``epochs``, ``test``, ``log_start_time_str`` and
    ``normal_class_label``. Optional keys: ``acc_batches`` (default 1) and ``load`` (path to a model snapshot).
    ``logdir`` must be a pathlib.Path; its parent directory names form the wandb run name.

    When use_wandb() is true, a wandb run is started first and finished on every exit path.
    Logs and a model snapshot are always written at the end, also when training or testing fails.

    NOTE(review): ``FCDDTrainer`` is looked up as a module-level name at call time. The cell defining this
    function also runs ``from fcdd.training.fcdd import FCDDTrainer``, which rebinds it to the library class
    (the training log prints ``ID fcdd.training.fcdd.FCDDTrainer``). The notebook-defined trainer is therefore
    presumably not used here unless its cell is re-executed afterwards — confirm.

    :return: RunResults(roc, gtmap_roc), taken from trainer.test(), or from trainer.res when testing is skipped.
    """
    logdir = kwargs["logdir"]
    # Evaluate the flag once so that init, the trainer hand-off and finish all agree with each other.
    wandb_enabled = use_wandb()
    wandb = None

    if wandb_enabled:
        import wandb
        wandb.init(
            name=f"{logdir.parent.parent.name}.{logdir.parent.name}.{logdir.name}",
            project="fcdd-train-mvtec-dev00", 
            entity="mines-paristech-cmm",
            config=kwargs,
        )

    try:
        # Preprocessing sits inside the try so a missing key cannot leave the wandb run open.
        kwargs["logdir"] = str(logdir.absolute())
        kwargs["datadir"] = str(Path(kwargs["datadir"]).absolute())
        readme = kwargs.pop("readme")
        kwargs['config'] = f'{json.dumps(kwargs)}\n\n{readme}'

        acc_batches = kwargs.pop('acc_batches', 1)
        epochs = kwargs.pop('epochs')
        load_snapshot = kwargs.pop('load', None)  # pre-trained model, path to model snapshot
        test = kwargs.pop("test")

        # not parameters of trainer_setup
        del kwargs["log_start_time_str"]
        del kwargs["normal_class_label"]

        # formerly: setup = trainer_setup(**kwargs); trainer = SuperTrainer(**setup)
        setup: TrainSetup = trainer_setup(**kwargs)
        trainer = FCDDTrainer(
            net=setup.net,
            opt=setup.opt,
            sched=setup.sched,
            dataset_loaders=setup.dataset_loaders,
            logger=setup.logger,
            gauss_std=setup.gauss_std,
            quantile=setup.quantile,
            resdown=setup.resdown,
            blur_heatmaps=setup.blur_heatmaps,
            device=setup.device,
            objective="fcdd",  # hardcoded because this is not relevant anymore
        )
        # resuming from a snapshot continues at the epoch stored in it
        epoch_start = 0 if load_snapshot is None else trainer.load(load_snapshot)

    except BaseException:
        if wandb_enabled:
            wandb.finish()
        raise

    try:
        # formerly: trainer.train(epochs, load, acc_batches=acc_batches)
        trainer.train(
            epochs=epochs - epoch_start, 
            acc_batches=acc_batches,
            wandb=wandb,  # None when wandb is disabled
        )

        if test and (epochs > 0 or load_snapshot is not None):
            ret = trainer.test()  # keys = {roc, gtmap_roc}
        else:
            ret = trainer.res  # keys = {roc, gtmap_roc}

        return RunResults(
            roc=ret["roc"],
            gtmap_roc=ret["gtmap_roc"],
        )

    except BaseException:
        setup.logger.printlog += traceback.format_exc()
        raise  # the re-raise is executed after the 'finally' clause

    finally:
        # The original BaseRunner avoided a finally here ("breaks debugger"); it is kept on purpose so logs and
        # the snapshot are always written. The nested finally closes the wandb run even if logging fails.
        try:
            setup.logger.log_prints()
            setup.logger.save()
            setup.logger.plot()
            setup.logger.snapshot(trainer.net, trainer.opt, trainer.sched, epochs)
        finally:
            if wandb_enabled:
                wandb.finish()
            
                    
def run(**kwargs) -> List[dict]:
    """
    Calls run_one for every selected normal class and iteration.

    Required kwargs: ``logdir`` (a pathlib.Path), ``dataset`` and ``it`` (number of iterations).
    Optional kwargs:
      - ``cls_restrictions``: class indices to run; default (or empty) means all classes of the dataset.
      - ``its_restrictions``: iteration indices to run; default (or empty) means range(it).
    All other kwargs are forwarded to run_one. ``logdir`` is replaced by ``<logdir>/normal_<c>/it_<i>``, and
    ``normal_class`` / ``normal_class_label`` are set per class.

    :return: one dict per run: {'class_idx': c, 'it': i, 'results': RunResults}.
    """
    original_logdir = kwargs['logdir']
    dataset = kwargs['dataset']
    class_labels = str_labels(dataset)

    cls_restrictions = kwargs.pop("cls_restrictions", None)
    # Count the classes directly: no_classes() subscripts the function str_labels[...] and raises TypeError.
    classes = cls_restrictions or range(len(class_labels))

    number_it = kwargs.pop('it')
    its_restrictions = kwargs.pop("its_restrictions", None)
    its = its_restrictions or range(number_it)

    results = []
    for c in classes:
        cls_logdir = original_logdir / f'normal_{c}'
        kwargs['normal_class'] = c
        kwargs['normal_class_label'] = class_labels[c]

        for i in its:
            it_logdir = cls_logdir / f'it_{i}'
            res = run_one(**{**kwargs, 'logdir': it_logdir})  # overwrite logdir
            results.append(dict(class_idx=c, it=i, results=res))

    return results


# launch

In [22]:
import os
import shlex

# NOTE(review): CUDA_VISIBLE_DEVICES presumably only takes effect if CUDA was not initialized yet in this kernel.
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
os.environ["WANDB"] = "1"  # read by use_wandb()

ARG_STRING = "--cls-restrictions 0 --it 1 --epochs 3"

# shlex.split tolerates repeated whitespace and quoted values; str.split(" ") would produce "" arguments.
args = parser.parse_args(args=shlex.split(ARG_STRING))
args = args_post_parse(args)

results = run(**vars(args))



Files already downloaded.
Loading dataset from /home/bertoldo/repos/fcdd/python/dev/../../data/datasets/mvtec/admvtec_240x240.pt...
Dataset complete.
Files already downloaded.
Loading dataset from /home/bertoldo/repos/fcdd/python/dev/../../data/datasets/mvtec/admvtec_240x240.pt...
Dataset complete.
Successfully saved code at /home/bertoldo/repos/fcdd/python/dev/../../data/results/mvtec_fcdd_20220420165008/normal_0/it_0/./src.tar.gz
Generating dataset preview...
Dataset preview generated.
EPOCH 00 NBAT 0007/0131 ERR nan ERR_NORMAL nan ERR_ANOMALOUS nan INFO LR ['1e-03'] ID fcdd.training.fcdd.FCDDTrainer
EPOCH 00 NBAT 0023/0131 ERR nan ERR_NORMAL nan ERR_ANOMALOUS nan INFO LR ['1e-03'] ID fcdd.training.fcdd.FCDDTrainer
EPOCH 00 NBAT 0031/0131 ERR nan ERR_NORMAL nan ERR_ANOMALOUS nan INFO LR ['1e-03'] ID fcdd.training.fcdd.FCDDTrainer
EPOCH 00 NBAT 0039/0131 ERR nan ERR_NORMAL nan ERR_ANOMALOUS nan INFO LR ['1e-03'] ID fcdd.training.fcdd.FCDDTrainer
EPOCH 00 NBAT 0047/0131 ERR nan ERR_NOR

In [None]:
# this was in run_seeds
# for key in results:
#     plot_many_roc(
#         logdir.replace('{t}', kwargs["log_start_time_str"], results[key],
#         labels=its, mean=True, name=key
#     )
    
# return {key: mean_roc(val) for key, val in results.items()}

In [None]:
# this was in run_classes

        # this was in the finally of the class loop
            # print('Plotting ROC for completed classes up to {}...'.format(c))
            # for key in results:
            #     plot_many_roc(
            #         logdir.replace('{t}', kwargs["log_start_time_str"], results[key],
            #         labels=str_labels(kwargs['dataset']), mean=True, name=key
            #     )
                
    # for key in results:
    #     plot_many_roc(
    #         logdir.replace('{t}', kwargs["log_start_time_str"], results[key],
    #         labels=str_labels(kwargs['dataset']), mean=True, name=key
    #     )