# Generate and save heatmaps

In [1]:
# make a cell print all the outputs instead of just the last one
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

## Evaluation

### Set up

In [2]:
%cd "C:\Users\Public\Documents\DIMA\fcdd\python\analyse"

c:\Users\Public\Documents\DIMA\fcdd\python\analyse


In [3]:
from pathlib import Path

FIGS_DIR = Path(".") / "figs"
FIGS_DIR.mkdir(exist_ok=True)

DATA_DIR = Path(".") / "data" # est-ce le bon dossier ?
DATA_DIR.mkdir(exist_ok=True)

SNAPSHOTS_DIR = Path("../../data")
assert SNAPSHOTS_DIR.exists()

MVTECAD_DIR= Path("../../data/datasets/mvtec") # est-ce le bon dossier ?
assert MVTECAD_DIR.exists()

RECORDS_FPATH = DATA_DIR / "snapshot.pt"

### Get snapshots

In [4]:
import numpy as np
from numpy import ndarray
from pathlib import Path
import json
import re
from datetime import timedelta, datetime
from typing import Dict, Union, List
import copy


# this is to get the strings associated to the classes in the fcdd code
# copied from: fcdd/python/fcdd/datasets/__init__.py
# inside function `str_labels`
# commit: 9f268d8fd2fee33a5c5f38cdfb781da927bdb614
CLASS_LABELS = {
    'cifar10': ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'],
    'fmnist': [
        't-shirt/top', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot'
    ],
    'mvtec': [
        'bottle', 'cable', 'capsule', 'carpet', 'grid', 'hazelnut', 'leather',
        'metal_nut', 'pill', 'screw', 'tile', 'toothbrush', 'transistor',
        'wood', 'zipper'
    ],
    # 'imagenet': deepcopy(ADImageNet.ad_classes),
    # this one forwards to: fcdd/python/fcdd/datasets/imagenet.py
    # in: ADImageNet.ad_classes
    # at the same commit as above
    'imagenet': ['acorn', 'airliner', 'ambulance', 'American alligator', 'banjo', 'barn', 'bikini', 'digital clock',
                  'dragonfly', 'dumbbell', 'forklift', 'goblet', 'grand piano', 'hotdog', 'hourglass', 'manhole cover',
                  'mosque', 'nail', 'parking meter', 'pillow', 'revolver', 'dial telephone', 'schooner',
                  'snowmobile', 'soccer ball', 'stingray', 'strawberry', 'tank', 'toaster', 'volcano'],
    'pascalvoc': ['horse'],
}

def get_classes_labels_order(dataset: str) -> List[str]:
    return copy.deepcopy(CLASS_LABELS[dataset])

def get_class_label(class_dirname: str, dataset: str) -> str:
    return CLASS_LABELS[dataset][int(class_dirname.lstrip("normal_"))]

class MissingFileInExperiment(FileNotFoundError):
    pass

class UnfinishedExperiment(Exception):
    pass

def get_snapshots(experiment_dir: Path) -> Dict[str, Union[str, float, ndarray]]:
    """
    :param experiment_dir: path to the experiment directory
    """
    snapshots_dir = experiment_dir
    
    snapshots = []
    for snashotpath in snapshots_dir.glob("*.pt"): # modif dans .glob
        snapshot_name = snashotpath.name
        # epoch = int(snashotpath.stem.split("=")[1]) 
        snapshots.append({
            "fpath": snashotpath,
            "snapshot_name": snapshot_name,
            # "epoch": epoch,
        })
        
    return snapshots


def get_all(path: Path, dataset: str) -> List[Dict[str, Union[str, float, ndarray]]]:
    
    """
    :param path: a folder that contains dirs like 'fcdd_20211220193242_fmnist_' 
                 a whole experiment on a dataset with all iterations and nominal classes inside, 
                 the structure should look like
                 
                path/
                path/normal_0
                path/normal_1
                ...
                path/normal_9/
                path/normal_9/it_0
                ...
                path/normal_9/it_4/roc.json
    """
    assert path.is_dir()
    assert dataset in CLASS_LABELS
    
    snapshots = []

    experiment_snapshots = get_snapshots(path)

    for snap in experiment_snapshots:
        # print
        snapshots.append({
            **snap,
            **{
                "rootdir": path,
                # "rundir": rundir,
                # "classdir": classdir,
                # "iterdir": iterdir,
                "rundir_name": "snapshot",
                # "classdir_name": classdir.name,
                "class_idx": 1,
                "class_label": 'carpet',
                # "iterdir_name": iterdir.name,
                "iter_idx": 1,
                "epoch": 1,
            },
        })

    return snapshots


# get_all(path=SNAPSHOTS_DIR, dataset="mvtec")

In [5]:
import copy
import pickle

try:
    del records
    
except NameError:
    pass

try:
    print("loading records")
    with RECORDS_FPATH.open("rw") as f:
        records = pickle.load(f)
        
except:
    
    print("couldn't find records, recomputing")
    records = get_all(path=SNAPSHOTS_DIR, dataset="mvtec")

    # print("saving records")
    # with RECORDS_FPATH.open("wb") as f:
    #     pickle.dump(records, f)
        
f"{len(records)=}"

loading records
couldn't find records, recomputing


'len(records)=1'

In [6]:
import pandas as pd
index_cols = ["rundir_name", "class_label", "iter_idx", "epoch"]
# drop_cols = ["rundir", "classdir", "iterdir", "classdir_name", "iterdir_name"]
df_snapshots = pd.DataFrame.from_records(data=records).set_index(index_cols)
df_snapshots.head(5)

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,fpath,snapshot_name,rootdir,class_idx
rundir_name,class_label,iter_idx,epoch,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
snapshot,carpet,1,1,..\..\data\snapshot.pt,snapshot.pt,..\..\data,1


### Get masks

In [None]:
# # A travailler

# IMAGES_GLOB = "*.png"

# imgs = []

# for classdir in MVTECAD_DIR.glob("*"):
    
#     if not classdir.is_dir():
#         continue    
    
#     print(f"{classdir.name=}")
    
#     testdir = classdir / "test"
#     traindir = classdir / "train"
    
#     if not testdir.exists() and not traindir.exists():
#         continue
    
#     print(f"{testdir.name=}")    
    
#     for typedir in testdir.glob("*"):
        
#         if not typedir.is_dir():
#             continue
        
#         print(f"{typedir.name=}")
        
#         img_paths = list(typedir.glob(IMAGES_GLOB))
        
#         if len(img_paths) == 0:
#             print("empty dir")
#             continue  
        
#         print(f"{len(img_paths)=}")      
        
#         for imgpath in img_paths:
#             imgs.append({
#                 "imgpath": imgpath,
#                 "class": classdir.name,
#                 "type": typedir.name,
#                 "set": "test",
#                 "imgidx": int(imgpath.stem),
#             })
            
#     print(f"{traindir.name=}")
    
#     img_paths = list((traindir / "good").glob(IMAGES_GLOB))
    
#     if len(img_paths) == 0:
#         print("empty dir")

#     else:
#         for imgpath in img_paths:
#             imgs.append({
#                 "imgpath": imgpath,
#                 "class": classdir.name,
#                 "type": "good",
#                 "set": "train",
#                 "imgidx": int(imgpath.stem),
#             })    
    
#     print(30 * "-")

# print(f"{len(imgs)=}")

In [7]:
MASKS_GLOB = "*.png"

masks = []

# We only want the "carpet" class
# If you want all the classes, use a for loop (see notebook 010-generate-predictions-from-snapshots)

classdir = MVTECAD_DIR / "carpet"

print(f"{classdir.name=}")

groundtruthdir = classdir / "ground_truth"

print(f"{groundtruthdir.name=}")

for typedir in groundtruthdir.glob("*"):
    
    if not typedir.is_dir():
        continue
    
    print(f"{typedir.name=}")
    
    masks_paths = list(typedir.glob(MASKS_GLOB))
    
    if len(masks_paths) == 0:
        print("empty dir")
        continue
    
    print(f"{len(masks_paths)=}")
    
    for maskpath in masks_paths:
        masks.append({
            "mask_path": maskpath.resolve(),
            "class": classdir.name,
            "type": typedir.name,
            "set": "ground_truth",
            "mask_idx": int(maskpath.stem[:3]),
        })
        

print(30 * "-")

print(f"{len(masks)=}")

classdir.name='carpet'
groundtruthdir.name='ground_truth'
typedir.name='color'
len(masks_paths)=19
typedir.name='cut'
len(masks_paths)=17
typedir.name='hole'
len(masks_paths)=17
typedir.name='metal_contamination'
len(masks_paths)=17
typedir.name='thread'
len(masks_paths)=19
------------------------------
len(masks)=89


In [8]:
import pandas as pd
df_masks = pd.DataFrame.from_records(data=masks).set_index(["class", "type", "set", "mask_idx"])
df_masks

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,mask_path
class,type,set,mask_idx,Unnamed: 4_level_1
carpet,color,ground_truth,0,C:\Users\Public\Documents\DIMA\fcdd\data\datas...
carpet,color,ground_truth,1,C:\Users\Public\Documents\DIMA\fcdd\data\datas...
carpet,color,ground_truth,2,C:\Users\Public\Documents\DIMA\fcdd\data\datas...
carpet,color,ground_truth,3,C:\Users\Public\Documents\DIMA\fcdd\data\datas...
carpet,color,ground_truth,4,C:\Users\Public\Documents\DIMA\fcdd\data\datas...
carpet,...,...,...,...
carpet,thread,ground_truth,14,C:\Users\Public\Documents\DIMA\fcdd\data\datas...
carpet,thread,ground_truth,15,C:\Users\Public\Documents\DIMA\fcdd\data\datas...
carpet,thread,ground_truth,16,C:\Users\Public\Documents\DIMA\fcdd\data\datas...
carpet,thread,ground_truth,17,C:\Users\Public\Documents\DIMA\fcdd\data\datas...


### Get datasets images

In [None]:
# # A travailler

# IMAGES_GLOB = "*.png"

# imgs = []

# for classdir in MVTECAD_DIR.glob("*"):
    
#     if not classdir.is_dir():
#         continue    
    
#     print(f"{classdir.name=}")
    
#     testdir = classdir / "test"
#     traindir = classdir / "train"
    
#     if not testdir.exists() and not traindir.exists():
#         continue
    
#     print(f"{testdir.name=}")    
    
#     for typedir in testdir.glob("*"):
        
#         if not typedir.is_dir():
#             continue
        
#         print(f"{typedir.name=}")
        
#         img_paths = list(typedir.glob(IMAGES_GLOB))
        
#         if len(img_paths) == 0:
#             print("empty dir")
#             continue  
        
#         print(f"{len(img_paths)=}")      
        
#         for imgpath in img_paths:
#             imgs.append({
#                 "imgpath": imgpath,
#                 "class": classdir.name,
#                 "type": typedir.name,
#                 "set": "test",
#                 "imgidx": int(imgpath.stem),
#             })
            
#     print(f"{traindir.name=}")
    
#     img_paths = list((traindir / "good").glob(IMAGES_GLOB))
    
#     if len(img_paths) == 0:
#         print("empty dir")

#     else:
#         for imgpath in img_paths:
#             imgs.append({
#                 "imgpath": imgpath,
#                 "class": classdir.name,
#                 "type": "good",
#                 "set": "train",
#                 "imgidx": int(imgpath.stem),
#             })    
    
#     print(30 * "-")

# print(f"{len(imgs)=}")

In [9]:
IMAGES_GLOB = "*.png"

imgs = []

# We only want the "carpet" class
# If you want all the classes, use a for loop (see notebook 010-generate-predictions-from-snapshots)
classdir = MVTECAD_DIR / "carpet"


print(f"{classdir.name=}")

testdir = classdir / "test"
traindir = classdir / "train"

print(f"{testdir.name=}")

for typedir in testdir.glob("*"):
    
    if not typedir.is_dir():
        continue
    
    print(f"{typedir.name=}")
    
    img_paths = list(typedir.glob(IMAGES_GLOB))
    
    if len(img_paths) == 0:
        print("empty dir")
        continue  
    
    print(f"{len(img_paths)=}")      
    
    for imgpath in img_paths:
        imgs.append({
            "imgpath": imgpath.resolve(),
            "class": classdir.name,
            "type": typedir.name,
            "set": "test",
            "imgidx": int(imgpath.stem),
        })
        
print(f"{traindir.name=}")

img_paths = list((traindir / "good").glob(IMAGES_GLOB))

if len(img_paths) == 0:
    print("empty dir")

else:
    for imgpath in img_paths:
        imgs.append({
            "imgpath": imgpath.resolve(),
            "class": classdir.name,
            "type": "good",
            "set": "train",
            "imgidx": int(imgpath.stem),
        })    

print(30 * "-")

print(f"{len(imgs)=}")

classdir.name='carpet'
testdir.name='test'
typedir.name='color'
len(img_paths)=19
typedir.name='cut'
len(img_paths)=17
typedir.name='good'
len(img_paths)=28
typedir.name='hole'
len(img_paths)=17
typedir.name='metal_contamination'
len(img_paths)=17
typedir.name='thread'
len(img_paths)=19
traindir.name='train'
------------------------------
len(imgs)=397


In [10]:
import pandas as pd
imgs = pd.DataFrame.from_records(data=imgs).set_index(["class", "set", "type", "imgidx"])
imgs

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,imgpath
class,set,type,imgidx,Unnamed: 4_level_1
carpet,test,color,0,C:\Users\Public\Documents\DIMA\fcdd\data\datas...
carpet,test,color,1,C:\Users\Public\Documents\DIMA\fcdd\data\datas...
carpet,test,color,2,C:\Users\Public\Documents\DIMA\fcdd\data\datas...
carpet,test,color,3,C:\Users\Public\Documents\DIMA\fcdd\data\datas...
carpet,test,color,4,C:\Users\Public\Documents\DIMA\fcdd\data\datas...
carpet,...,...,...,...
carpet,train,good,275,C:\Users\Public\Documents\DIMA\fcdd\data\datas...
carpet,train,good,276,C:\Users\Public\Documents\DIMA\fcdd\data\datas...
carpet,train,good,277,C:\Users\Public\Documents\DIMA\fcdd\data\datas...
carpet,train,good,278,C:\Users\Public\Documents\DIMA\fcdd\data\datas...


### Select images

In [11]:
%matplotlib inline

In [12]:
from pandas import DataFrame

selected_class = "carpet"

class_imgs = imgs.loc[selected_class]

def get_indices_summary(df_: DataFrame) -> DataFrame:
    return df_.reset_index()[df_.index.names].groupby(df_.index.names[0:2]).agg(["min", "max"])

get_indices_summary(class_imgs)

Unnamed: 0_level_0,Unnamed: 1_level_0,imgidx,imgidx
Unnamed: 0_level_1,Unnamed: 1_level_1,min,max
set,type,Unnamed: 2_level_2,Unnamed: 3_level_2
test,color,0,18
test,cut,0,16
test,good,0,27
test,hole,0,16
test,metal_contamination,0,16
test,thread,0,18
train,good,0,279


In [13]:
class_imgs.index.get_level_values(1).unique()

Index(['color', 'cut', 'good', 'hole', 'metal_contamination', 'thread'], dtype='object', name='type')

In [14]:
from pandas import DataFrame
import time
from pathlib import Path
from typing import List


def create_tmpdir(img_fpaths: List[Path], tmp_dpath: Path):

    tmp_img_dir = tmp_dpath / tmp_dpath.name
    tmp_img_dir.mkdir(parents=True)

    for img_idx, img_fpath in enumerate(img_fpaths):
        # create a file name with the parents names
        symlink_name = f"{img_idx:05}-{img_fpath.parent.parent.parent.name}-{img_fpath.parent.parent.name}-{img_fpath.parent.name}-{img_fpath.name}"
        symlink_fpath = tmp_img_dir / symlink_name
        symlink_fpath.symlink_to(img_fpath)

tmp_dpath = DATA_DIR / f"tmp_{int(time.time())}"
tmp_dpath = tmp_dpath.absolute()

create_tmpdir(
    img_fpaths = list(class_imgs["imgpath"]),
    tmp_dpath = tmp_dpath,
)

In [15]:
img_names = []
img_fpaths = list(class_imgs["imgpath"])
for img_idx, img_fpath in enumerate(img_fpaths):
    # create a file name with the parents names
    name = f"{img_idx:05}-{img_fpath.parent.parent.parent.name}-{img_fpath.parent.parent.name}-{img_fpath.parent.name}-{img_fpath.name}"
    img_names.append(name)

### Generate explanations

In [16]:
import torch
import torchvision.transforms as transforms
from torch.utils.data.dataloader import DataLoader

In [17]:
%load_ext autoreload

In [18]:
%autoreload 2
from fcdd.training.fcdd import FCDDTrainer
from fcdd.models.fcdd_cnn_224 import FCDD_CNN224_VGG_F
from fcdd.datasets.image_folder import ImageFolder
from fcdd.datasets.preprocessing import local_contrast_normalization
from fcdd.util.logging import Logger

In [19]:
min_max_l1 = [
    [(-1.3336724042892456, -1.3107913732528687, -1.2445921897888184),
     (1.3779616355895996, 1.3779616355895996, 1.3779616355895996)],
    [(-2.2404820919036865, -2.3387579917907715, -2.2896201610565186),
     (4.573435306549072, 4.573435306549072, 4.573435306549072)],
    [(-3.184587001800537, -3.164201259613037, -3.1392977237701416),
     (1.6995097398757935, 1.6011602878570557, 1.5209171772003174)],
    [(-3.0334954261779785, -2.958242416381836, -2.7701096534729004),
     (6.503103256225586, 5.875098705291748, 5.814228057861328)],
    [(-3.100773334503174, -3.100773334503174, -3.100773334503174),
     (4.27892541885376, 4.27892541885376, 4.27892541885376)],
    [(-3.6565306186676025, -3.507692813873291, -2.7635035514831543),
     (18.966819763183594, 21.64590072631836, 26.408710479736328)],
    [(-1.5192601680755615, -2.2068002223968506, -2.3948357105255127),
     (11.564697265625, 10.976534843444824, 10.378695487976074)],
    [(-1.3207964897155762, -1.2889339923858643, -1.148416519165039),
     (6.854909896850586, 6.854909896850586, 6.854909896850586)],
    [(-0.9883341193199158, -0.9822461605072021, -0.9288841485977173),
     (2.290637969970703, 2.4007883071899414, 2.3044068813323975)],
    [(-7.236185073852539, -7.236185073852539, -7.236185073852539),
     (3.3777384757995605, 3.3777384757995605, 3.3777384757995605)],
    [(-3.2036616802215576, -3.221003532409668, -3.305514335632324),
     (7.022546768188477, 6.115569114685059, 6.310940742492676)],
    [(-0.8915618658065796, -0.8669204115867615, -0.8002046346664429),
     (4.4255571365356445, 4.642300128936768, 4.305730819702148)],
    [(-1.9086798429489136, -2.0004451274871826, -1.929288387298584),
     (5.463134765625, 5.463134765625, 5.463134765625)],
    [(-2.9547364711761475, -3.17536997795105, -3.143850803375244),
     (5.305514812469482, 4.535006523132324, 3.3618252277374268)],
    [(-1.2906527519226074, -1.2906527519226074, -1.2906527519226074),
     (2.515115737915039, 2.515115737915039, 2.515115737915039)]
]

In [20]:
from fcdd.datasets import load_dataset
from fcdd.runners.run_mvtec import MvtecConfig
from argparse import ArgumentParser

In [21]:
parser = ArgumentParser()
parser = MvtecConfig()(parser)
dftargs = parser.parse_args("")

In [22]:
# todo change me
OUTPUTS_DIR = DATA_DIR.absolute()

In [23]:
# [optional] to generate heatmaps, define a logger (with the path where the heatmaps should be saved to) and a quantile
import time
exp_start_time = int(time.time())
logger = Logger(str(OUTPUTS_DIR), exp_start_time=exp_start_time)
quantile = 0.97

In [24]:
def gen_heatmaps_for_training(list_snaphots_paths, normal_class):
    
    net = FCDD_CNN224_VGG_F((3, 224, 224), bias=True)
    
    # Use the same test transform as was used for training the snapshot
    # (e.g., for mvtec, per default, the following)
    transform = transforms.Compose([
        transforms.Resize(224),
        # transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: local_contrast_normalization(x, scale='l1')),
        transforms.Normalize(
            min_max_l1[normal_class][0],
            [ma - mi for ma, mi in zip(min_max_l1[normal_class][1], min_max_l1[normal_class][0])]
        )
    ])
    
    all_all_inputs = []
    all_all_labels = []
    all_all_anomaly_scores_pixelwise = [] 
    all_all_anomaly_scores_imgwise = [] 
    
    for snapshot_fpath in list_snaphots_paths:

        trainer = FCDDTrainer(net, None, None, (None, None), logger, 'fcdd', 8, quantile, 224)
        trainer.load(str(snapshot_fpath), cpu=True)
        trainer.net.eval()

        data_set = ImageFolder(tmp_dpath, transform)
        loader = DataLoader(data_set, batch_size=8, num_workers=0)

        all_anomaly_scores, all_inputs, all_labels = [], [], []
        for inputs, labels in loader:
            # inputs = inputs.cuda()
            with torch.no_grad():
                outputs = trainer.net(inputs)
                anomaly_scores = trainer.anomaly_score(trainer.loss(outputs, inputs, labels, reduce='none'))
                anomaly_scores = trainer.net.receptive_upsample(anomaly_scores, reception=True, std=8, cpu=False)
                all_anomaly_scores.append(anomaly_scores.cpu())
                all_inputs.append(inputs.cpu())
                all_labels.append(labels)

        all_inputs = torch.cat(all_inputs)
        all_labels = torch.cat(all_labels)
        # all_anomaly_scores will be a tensor containing pixel-wise anomaly scores for all images
        anomaly_scores_pixelwise = torch.cat(all_anomaly_scores)
        anomaly_scores_imgwise = trainer.reduce_ascore(anomaly_scores_pixelwise)
        
        all_all_inputs.append(all_inputs)
        all_all_labels.append(all_labels)
        all_all_anomaly_scores_pixelwise.append(anomaly_scores_pixelwise)
        all_all_anomaly_scores_imgwise.append(anomaly_scores_imgwise)
    
    all_all_inputs = torch.cat([torch.unsqueeze(t, 0) for t in all_all_inputs], dim=0)
    all_all_labels = torch.cat([torch.unsqueeze(t, 0) for t in all_all_labels], dim=0)
    all_all_anomaly_scores_pixelwise = torch.cat([torch.unsqueeze(t, 0) for t in all_all_anomaly_scores_pixelwise], dim=0)
    all_all_anomaly_scores_imgwise = torch.cat([torch.unsqueeze(t, 0) for t in all_all_anomaly_scores_imgwise], dim=0)
    
    return all_all_inputs, all_all_labels, all_all_anomaly_scores_pixelwise, all_all_anomaly_scores_imgwise

In [25]:
index0 = df_snapshots.index.get_level_values(0).unique()[0]
df_ = df_snapshots.loc[index0]

normal_class_label = df_.index.get_level_values(0).unique()[0]
df_ = df_.loc[normal_class_label]

iter_idx = df_.index.get_level_values(0).unique()[0]
df_ = df_.loc[iter_idx]

df_ = df_.sort_index(0)

list_snaphots_paths = list(df_snapshots["fpath"])

inputs, labels, as_pixelwise, as_imgwise = gen_heatmaps_for_training(
    list_snaphots_paths=list_snaphots_paths,
    normal_class=get_classes_labels_order("mvtec").index(normal_class_label)
)

  df_ = df_.sort_index(0)


Loaded net_state, opt_state, sched_state with starting epoch 200 for fcdd.training.fcdd.FCDDTrainer


In [26]:
inputs.shape # images
as_pixelwise.shape # result pixel-wise (mask)
as_imgwise.shape # result image-wise

torch.Size([1, 397, 3, 224, 224])

torch.Size([1, 397, 1, 224, 224])

torch.Size([1, 397])

In [27]:
inputs2 = torch.squeeze(inputs)
as_pixelwise2 = torch.squeeze(as_pixelwise)
as_imgwise2 = torch.squeeze(as_imgwise)

### Save heatmaps

In [28]:
import matplotlib as mpl
from matplotlib import image

# We want to save the heatmaps in as_pixelwise2

HEATMAPS_DIR = Path(".") / "data" / "generated_heatmaps" / "carpet_train_test"
HEATMAPS_DIR.mkdir(parents=True, exist_ok=True)

as_pixelwise2_np = as_pixelwise2.numpy()
global_heatmaps_min = as_pixelwise2_np.min()
global_heatmaps_max = as_pixelwise2_np.max()

# # A améliorer: récupérer plutôt le nom du fichier
# labels_list = [f"1-color-{i:03}" for i in range(19)]
# labels_list += [f"2-cut-{i:03}" for i in range(17)]
# labels_list += [f"3-good-{i:03}" for i in range(28)]
# labels_list += [f"4-hole-{i:03}" for i in range(17)]
# labels_list += [f"5-metal-{i:03}" for i in range(17)]
# labels_list += [f"6-thread-{i:03}" for i in range(19)]
# labels_list += [f"7-good-{i:03}" for i in range(280)]

for array, name in zip(as_pixelwise2_np, img_names) :
    array = (array - global_heatmaps_min)/(global_heatmaps_max - global_heatmaps_min)
    image.imsave(HEATMAPS_DIR/name, array);