In [None]:

import torch
import torch.nn.functional as F
from models import Create_nets
from datasets import Get_dataloader
#from options import TrainOptions
from torchvision import models
import os
from PIL import Image
import torchvision.transforms.functional as TF
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# The argparse-based `TrainOptions().parse()` from options.py errors inside the Jupyter kernel,
# so the options are re-declared as a plain class below instead.

class TrainOptions:
    """Hard-coded stand-in for the argparse-based ``options.TrainOptions``.

    Every option has the default shown below. Individual values can be overridden with
    keyword arguments, e.g. ``TrainOptions(dataset_name="screw", data_root="/data/mvtec/")``.
    Unknown option names raise ``TypeError`` so that a typo does not silently create a new,
    unused attribute.
    """

    def __init__(self, **overrides):
        # Experiment bookkeeping / schedule
        self.exp_name = "Exp0-r18"
        self.epoch_start = 0
        self.epoch_num = 150
        self.factor = 1
        self.seed = 233
        self.num_row = 4
        self.activation = 'gelu'
        self.unalign_test = False
        # Dataset location — machine-specific absolute path; override with data_root=...
        self.data_root = '/home/bule/projects/datasets/mvtec_anomaly_detection/'
        self.dataset_name = "cable"
        self.batch_size = 2
        # Optimiser hyper-parameters (b1/b2 presumably Adam betas — confirm in the training script)
        self.lr = 1e-4
        self.b1 = 0.5
        self.b2 = 0.999
        self.n_cpu = 8
        # Output directories
        self.image_result_dir = 'result_images'
        self.model_result_dir = 'saved_models'
        self.validation_image_dir = 'validation_images'

        for name, value in overrides.items():
            if not hasattr(self, name):
                raise TypeError(f"TrainOptions() got an unexpected option {name!r}")
            setattr(self, name, value)

    def parse(self):
        """Return ``self`` so code written for the argparse version (``TrainOptions().parse()``) still works."""
        return self

# Instantiate the notebook's option set; later cells read hyper-parameters and paths from `args`.
args = TrainOptions()

# Seed torch's global RNG once, up front, so any random preprocessing below is reproducible.
torch.manual_seed(seed=args.seed)

In [None]:
# Defect types of every MVTec AD category, i.e. the sub-folder names under `<category>/test/`
# besides `good`.
anomaly_categories = dict(
    bottle=['broken_large', 'broken_small', 'contamination'],
    cable=[
        'bent_wire', 'cable_swap', 'combined', 'cut_inner_insulation',
        'cut_outer_insulation', 'missing_cable', 'missing_wire', 'poke_insulation',
    ],
    capsule=['crack', 'faulty_imprint', 'poke', 'scratch', 'squeeze'],
    carpet=['color', 'cut', 'hole', 'metal_contamination', 'thread'],
    grid=['bent', 'broken', 'glue', 'metal_contamination', 'thread'],
    hazelnut=['crack', 'cut', 'hole', 'print'],
    leather=['color', 'cut', 'fold', 'glue', 'poke'],
    metal_nut=['bent', 'color', 'flip', 'scratch'],
    pill=[
        'color', 'combined', 'contamination', 'crack',
        'faulty_imprint', 'pill_type', 'scratch',
    ],
    screw=['manipulated_front', 'scratch_head', 'scratch_neck', 'thread_side', 'thread_top'],
    tile=['crack', 'glue_strip', 'gray_stroke', 'oil', 'rough'],
    toothbrush=['defective'],
    transistor=['bent_lead', 'cut_lead', 'damaged_case', 'misplaced'],
    wood=['color', 'combined', 'hole', 'liquid', 'scratch'],
    zipper=[
        'broken_teeth', 'combined', 'fabric_border', 'fabric_interior',
        'split_teeth', 'rough', 'squeezed_teeth',
    ],
)

In [None]:


# MVTec AD object category and defect type to inspect
category = "screw"
anomaly_category = "manipulated_front"
# Fail early on a typo instead of with a missing-file error when the image is loaded
assert anomaly_category in anomaly_categories[category], (
    f"{anomaly_category!r} is not a defect type of {category!r}; "
    f"choose one of {anomaly_categories[category]}"
)

# One image each for inspecting: good train, good test and anomalous test
GOOD_IMAGE_FILE = '003.png'
ANOMALY_IMAGE_FILE = '004.png'
good_img_train_path = os.path.join(args.data_root, category, 'train', 'good', GOOD_IMAGE_FILE)
good_img_test_path = os.path.join(args.data_root, category, 'test', 'good', GOOD_IMAGE_FILE)
anomaly_img_test_path = os.path.join(args.data_root, category, 'test', anomaly_category, ANOMALY_IMAGE_FILE)

In [None]:
## Load a saved model to inspect

SAVE_PATH = './inspects'  # NOTE(review): not used by any visible cell

## Load the pretrained transformer model
transformer = Create_nets(args).to(device)
# Same path as before with the default options: ./Exp0-r18-<category>/saved_models/checkpoint.pth.
# map_location allows a checkpoint saved on a GPU to be loaded when `device` is the CPU.
checkpoint_path = f'./{args.exp_name}-{category}/{args.model_result_dir}/checkpoint.pth'
checkpoint = torch.load(checkpoint_path, map_location=device)
transformer.load_state_dict(checkpoint['transformer'])
# Inference only: switch any dropout / normalisation layers to evaluation behaviour
transformer.eval()

# Frozen ImageNet backbone; forward hooks collect intermediate activations
backbone = models.resnet18(pretrained=True).to(device)
backbone.eval()
outputs = []


def hook(module, input, output):
    """Forward hook: append `output` to the list currently bound to the global name `outputs`.

    The global is looked up at call time, so later cells can reset it with `outputs = []`
    before each forward pass.
    """
    outputs.append(output)


# Activations arrive in this order on every forward pass: layer1, layer2, layer3
backbone.layer1[-1].register_forward_hook(hook)
backbone.layer2[-1].register_forward_hook(hook)
backbone.layer3[-1].register_forward_hook(hook)

def embedding_concat(x, y):
    """Concatenate a finer feature map `x` with a coarser map `y` along the channel axis.

    `y` is replicated over every s x s patch of `x` (s = int(H1 / H2)), so the result keeps
    the spatial size of `x`. Channel order of the result: the C1 channels of `x`, then the
    C2 channels of `y`.

    Args:
        x: tensor of shape (B, C1, H1, W1).
        y: tensor of shape (B, C2, H2, W2); assumes H1 / H2 == W1 / W2 is an integer — TODO confirm
           for other backbones.

    Returns:
        Tensor of shape (B, C1 + C2, H1, W1), created on the device of the inputs.
    """
    B, C1, H1, W1 = x.size()
    _, C2, H2, W2 = y.size()
    s = int(H1 / H2)
    # Split x into non-overlapping s x s patches: (B, C1, s*s, H2, W2)
    x = F.unfold(x, kernel_size=s, dilation=1, stride=s)
    x = x.view(B, C1, -1, H2, W2)
    # Pair every patch position with the same coarse vector of y (broadcast instead of a Python loop)
    y = y.unsqueeze(2).expand(-1, -1, x.size(2), -1, -1)
    z = torch.cat((x, y), dim=1)
    # Fold the patches back onto the (H1, W1) grid
    z = z.view(B, -1, H2 * W2)
    z = F.fold(z, kernel_size=s, output_size=(H1, W1), stride=s)
    return z

def load_image(filename, crop_size=256, aligned=True, img_size=280):
    """Load an image file and turn it into a normalised backbone input tensor.

    Args:
        filename: path of the image file.
        crop_size: aligned mode: size passed to `TF.resize` (an int matches the shorter edge);
            unaligned mode: side length of the random square crop.
        aligned: if True, only resize. If False, resize to `img_size`, rotate by a random
            angle in [-10, 10] degrees and take a random `crop_size` x `crop_size` crop
            (presumably drawn from torch's global RNG, which the config cell seeds).
        img_size: resize target applied before rotating/cropping in unaligned mode.

    Returns:
        A float tensor of shape (3, H, W) normalised with the ImageNet mean/std.
    """
    bicubic = transforms.InterpolationMode.BICUBIC
    with Image.open(filename) as src:
        # convert() returns an independent image, so the file can be closed right away
        img = src.convert('RGB')

    if aligned:
        img = TF.resize(img, crop_size, bicubic)
    else:
        img = TF.resize(img, img_size, bicubic)
        angle = transforms.RandomRotation.get_params([-10, 10])
        img = TF.rotate(img, angle, fill=(0,))
        i, j, h, w = transforms.RandomCrop.get_params(img, output_size=(crop_size, crop_size))
        img = TF.crop(img, i, j, h, w)

    img = TF.to_tensor(img)  # float tensor in [0, 1], shape (C, H, W)
    return TF.normalize(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    
def plot_images(images):
    """Show a stack of 2-D maps (e.g. individual feature channels) in a near-square grid.

    Args:
        images: sequence/array where `images[i]` is a 2-D map; assumed shape (N, H, W).
            Nothing is drawn when N == 0.
    """
    num_images = len(images)
    if num_images == 0:
        return
    num_rows = int(num_images ** 0.5)
    # Round up so that no image is dropped when N is not a perfect square
    num_cols = (num_images + num_rows - 1) // num_rows
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid
    fig, axes = plt.subplots(nrows=num_rows, ncols=num_cols, figsize=(20, 20), squeeze=False)
    for i, ax in enumerate(axes.flat):
        if i < num_images:
            ax.imshow(images[i], cmap='viridis')
        ax.axis('off')  # also hides the unused trailing cells of the grid
    plt.subplots_adjust(wspace=0.01, hspace=0.05)  # tight spacing between subplots
    plt.show()
    
def plot_multi_map(resdict, title, index, channel_splits=(64, 192)):
    """Plot one image next to its reconstruction-error maps.

    Panels, left to right: the original image, the channel-summed squared error divided by
    the stored std map, the stored `distances` map, then the channel-summed squared error over
    all channels and over each backbone-layer channel group.

    Args:
        resdict: dict with keys 'paths', 'featuremaps', 'recons', 'stds', 'distances', each
            indexable by `index` (as built by the inference cell).
        title: text appended to the title of the original-image panel.
        index: which entry of `resdict` to plot.
        channel_splits: channel indices separating the layer1 / layer2 / layer3 parts of the
            concatenated embedding. The default (64, 192) matches the ResNet-18 hooks
            (layer1: 64, layer2: 128, layer3: 256 channels).
    """
    # squeeze() drops the batch axis of size 1 -> (C, H, W)
    feature = resdict['featuremaps'][index].squeeze().cpu().numpy()
    recon = resdict['recons'][index].squeeze().cpu().numpy()
    squared_difftotal = (feature - recon) ** 2
    first_split, second_split = channel_splits
    squared_diffs = [
        squared_difftotal,
        squared_difftotal[:first_split],              # layer1 channels
        squared_difftotal[first_split:second_split],  # layer2 channels
        squared_difftotal[second_split:],             # layer3 channels
    ]
    titles = ['Total Squared Difference', 'High Squared Difference', 'Mid Squared Difference', 'Low Squared Difference']

    fig, axes = plt.subplots(1, len(squared_diffs) + 3, figsize=(20, 4))
    axes[0].imshow(Image.open(resdict['paths'][index]))
    axes[0].set_title(f'Original {title}')
    axes[0].axis('off')

    axes[1].imshow(squared_difftotal.sum(axis=0) / resdict['stds'][index])
    axes[1].set_title('Total Squared Difference/ stds')
    axes[1].axis('off')

    axes[2].imshow(resdict['distances'][index])
    axes[2].set_title('distances (paper)')
    axes[2].axis('off')

    for ax, squared_diff, panel_title in zip(axes[3:], squared_diffs, titles):
        ax.imshow(squared_diff.sum(axis=0))  # sum over channels -> (H, W) map
        ax.set_title(panel_title)
        ax.axis('off')
    plt.tight_layout()
    plt.show()

Define some images for inference

In [None]:
path_list = [good_img_train_path, good_img_test_path, anomaly_img_test_path]
# Use the deterministic (aligned) preprocessing unless the options ask for unaligned testing:
# with a random rotation/crop the error maps would not line up with the raw images that the
# plotting cells show next to them.
img_list = [load_image(path, aligned=not args.unalign_test) for path in path_list]

featuremaps = []  # concatenated backbone embeddings, one (1, C, H, W) tensor per image
recons = []       # transformer reconstructions of those embeddings
stds = []         # the transformer's `std` output, squeezed, as numpy arrays
distances = []    # channel-wise L2 reconstruction error divided by |std|, as numpy arrays
with torch.no_grad():
    for img in img_list:
        outputs = []  # the forward hooks append the layer1/2/3 activations to this global list
        backbone(img.unsqueeze(0).to(device))
        embedding = embedding_concat(embedding_concat(outputs[0], outputs[1]), outputs[2])
        recon, std = transformer(embedding)
        featuremaps.append(embedding)
        recons.append(recon)
        stds.append(std.squeeze().cpu().numpy())

        # Per-position anomaly distance: L2 norm over channels of the reconstruction error, over |std|
        dist = torch.norm(recon - embedding, p=2, dim=1, keepdim=True).div(std.abs())
        distances.append(dist.squeeze().cpu().numpy())

        # TODO: visualise the scores; this could possibly be done better with other methods

resdict = {'paths': path_list, 'featuremaps': featuremaps, 'recons': recons, 'stds': stds, 'distances': distances}

In [None]:
# TODO: understand how the predictions are computed in the paper

# dist = torch.norm(recon - outputs, p = 2, dim = 1, keepdim = True).div(std.abs())



# dist = dist.view(batch_size, 1, width, height)

# patch_normed_score = []
# for j in range(4):
#     patch_size = pow(4, j)
#     patch_score = F.conv2d(input=dist, 
#         weight=(torch.ones(1,1,patch_size,patch_size) / (patch_size*patch_size)).to(device), 
#         bias=None, stride=patch_size, padding=0, dilation=1)
#     patch_score = F.avg_pool2d(dist,patch_size,patch_size)
#     patch_score = F.interpolate(patch_score, (width,height), mode='bilinear')
#     patch_normed_score.append(patch_score)
    
# score = torch.zeros(batch_size,1,64,64).to(device)


# for j in range(4):
#     score = embedding_concat(score, patch_normed_score[j])

# score = F.conv2d(input=score, weight=torch.tensor([[[[0.0]],[[0.25]],[[0.25]],[[0.25]],[[0.25]]]]).to(device), bias=None, stride=1, padding=0, dilation=1)

# score = F.interpolate(score, (ground_truth.size(2),ground_truth.size(3)), mode='bilinear')

# heatmap = score.repeat(1,3,1,1)


## Inspect results

Original images

In [None]:
# Side-by-side view of the three raw images used for inference (same order as `path_list`)
fig, axes = plt.subplots(1, 3, figsize=(10, 4))
for ax, img_path, img_title in zip(axes, resdict['paths'],
                                   ['Good Train Image', 'Good Test Image', 'Anomaly Test Image']):
    ax.imshow(Image.open(img_path))
    ax.set_title(img_title)
    ax.axis('off')
plt.tight_layout()
plt.show()

In [None]:
# Error maps for each inference image, in the order of `path_list`
for map_title, map_index in (("train good", 0),
                             ("test good", 1),
                             (f"test anomaly {anomaly_category}", 2)):
    plot_multi_map(resdict, title=map_title, index=map_index)


In [None]:
# Squared reconstruction error of the good *train* image (resdict index 0), split by backbone layer.
# Channel layout of the concatenated ResNet-18 embedding:
#   layer1 -> channels [0, 64), layer2 -> [64, 192), layer3 -> [192, 448)
squared_difftotal = (resdict['featuremaps'][0].squeeze().cpu().numpy()
                     - resdict['recons'][0].squeeze().cpu().numpy()) ** 2
squared_diffhigh = squared_difftotal[:64]     # layer1
squared_diffmid = squared_difftotal[64:192]   # layer2
squared_difflow = squared_difftotal[192:]     # layer3

squared_diffs = [squared_difftotal, squared_diffhigh, squared_diffmid, squared_difflow]
titles = ['Total Squared Difference', 'High Squared Difference', 'Mid Squared Difference', 'Low Squared Difference']
fig, axes = plt.subplots(1, len(squared_diffs) + 1, figsize=(20, 4))
axes[0].imshow(Image.open(resdict['paths'][0]))
axes[0].set_title('Original')
axes[0].axis('off')
for i, squared_diff in enumerate(squared_diffs):
    axes[i+1].imshow(squared_diff.sum(axis=0))  # sum over channels -> (H, W) map
    axes[i+1].set_title(titles[i])
    axes[i+1].axis('off')
plt.tight_layout()
plt.show()

# layer1 channels one by one: features, reconstructions, squared error
plot_images(resdict['featuremaps'][0][:, :64].squeeze().cpu().numpy())
plot_images(resdict['recons'][0][:, :64].squeeze().cpu().numpy())
plot_images(squared_diffhigh)

In [None]:
# Squared reconstruction error of the good *test* image (resdict index 1), split by backbone layer.
# Channel layout of the concatenated ResNet-18 embedding:
#   layer1 -> channels [0, 64), layer2 -> [64, 192), layer3 -> [192, 448)
squared_difftotal = (resdict['featuremaps'][1].squeeze().cpu().numpy()
                     - resdict['recons'][1].squeeze().cpu().numpy()) ** 2
squared_diffhigh = squared_difftotal[:64]     # layer1
squared_diffmid = squared_difftotal[64:192]   # layer2
squared_difflow = squared_difftotal[192:]     # layer3

squared_diffs = [squared_difftotal, squared_diffhigh, squared_diffmid, squared_difflow]
titles = ['Total Squared Difference', 'High Squared Difference', 'Mid Squared Difference', 'Low Squared Difference']
fig, axes = plt.subplots(1, len(squared_diffs) + 1, figsize=(20, 4))
axes[0].imshow(Image.open(resdict['paths'][1]))
axes[0].set_title('Original')
axes[0].axis('off')
for i, squared_diff in enumerate(squared_diffs):
    axes[i+1].imshow(squared_diff.sum(axis=0))  # sum over channels -> (H, W) map
    axes[i+1].set_title(titles[i])
    axes[i+1].axis('off')
plt.tight_layout()
plt.show()

# layer1 channels one by one: features, reconstructions, squared error
plot_images(resdict['featuremaps'][1][:, :64].squeeze().cpu().numpy())
plot_images(resdict['recons'][1][:, :64].squeeze().cpu().numpy())
plot_images(squared_diffhigh)

In [None]:
# Squared reconstruction error of the *anomalous* test image (resdict index 2), split by backbone layer.
# Channel layout of the concatenated ResNet-18 embedding:
#   layer1 -> channels [0, 64), layer2 -> [64, 192), layer3 -> [192, 448)
squared_difftotal = (resdict['featuremaps'][2].squeeze().cpu().numpy()
                     - resdict['recons'][2].squeeze().cpu().numpy()) ** 2
squared_diffhigh = squared_difftotal[:64]     # layer1
squared_diffmid = squared_difftotal[64:192]   # layer2
squared_difflow = squared_difftotal[192:]     # layer3

squared_diffs = [squared_difftotal, squared_diffhigh, squared_diffmid, squared_difflow]
titles = ['Total Squared Difference', 'High Squared Difference', 'Mid Squared Difference', 'Low Squared Difference']
fig, axes = plt.subplots(1, len(squared_diffs) + 1, figsize=(20, 4))
axes[0].imshow(Image.open(resdict['paths'][2]))
axes[0].set_title('Original')
axes[0].axis('off')
for i, squared_diff in enumerate(squared_diffs):
    axes[i+1].imshow(squared_diff.sum(axis=0))  # sum over channels -> (H, W) map
    axes[i+1].set_title(titles[i])
    axes[i+1].axis('off')
plt.tight_layout()
plt.show()

# layer1 channels one by one: features, reconstructions, squared error
plot_images(resdict['featuremaps'][2][:, :64].squeeze().cpu().numpy())
plot_images(resdict['recons'][2][:, :64].squeeze().cpu().numpy())
plot_images(squared_diffhigh)

In [None]:
# Channel-summed squared reconstruction error of the anomalous test image (resdict index 2)
squared_diff = (
    resdict['featuremaps'][2].squeeze().cpu().numpy()
    - resdict['recons'][2].squeeze().cpu().numpy()
) ** 2

plt.imshow(sum(squared_diff))