In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os
from pathlib import Path
import sys


# Absolute path two directory levels above the current working directory
# (presumably the repository root, so that the `nnood` imports below resolve — confirm).
module_path = os.path.abspath(os.path.join('../..'))

# Only append once, so re-running this cell does not add duplicate entries to sys.path.
if module_path not in sys.path:
    sys.path.append(module_path)

from nnood.preprocessing.normalisation import IMAGENET_STATS
from nnood.paths import raw_data_base, preprocessed_data_base

def wl_to_lh(window, level):
    """Convert a window/level pair into ``(low, high)`` intensity limits.

    ``window`` is the width of the intensity range and ``level`` is its centre,
    so the limits are ``level - window / 2`` and ``level + window / 2``.
    """
    half_width = window / 2
    return level - half_width, level + half_width

def display_image(img, phys_size=None, window=None, level=None, existing_ax=None):
    """Show ``img`` with window/level intensity scaling.

    Args:
        img: Image array passed straight to ``Axes.imshow``.
        phys_size: Optional pair; when given, the image extent is set to
            ``(0, phys_size[0], phys_size[1], 0)``.
        window: Width of the displayed intensity range. Defaults to the full
            range of ``img`` (max - min).
        level: Centre of the displayed intensity range. Defaults to the
            midpoint of the window above the image minimum.
        existing_ax: Axes to draw on. When ``None`` a new 14x8 figure is
            created and shown immediately with ``plt.show()``.
    """
    if window is None:
        window = np.max(img) - np.min(img)
    if level is None:
        level = window / 2 + np.min(img)

    intensity_limits = wl_to_lh(window, level)

    # When the caller supplies axes they stay in charge of showing the figure.
    owns_figure = existing_ax is None
    if owns_figure:
        _, target_ax = plt.subplots(figsize=(14, 8))
    else:
        target_ax = existing_ax

    if phys_size is None:
        extent = None
    else:
        extent = (0, phys_size[0], phys_size[1], 0)

    target_ax.imshow(img, clim=intensity_limits, extent=extent)

    if owns_figure:
        plt.show()
        
def print_stats(arr):
    """Print mean and standard deviation, then min and max, then the shape of ``arr``."""
    mean_value, std_value = np.mean(arr), np.std(arr)
    print(mean_value, ', ', std_value)

    value_range = (np.min(arr), '-', np.max(arr))
    print(*value_range)

    print(arr.shape)
        
# Entries of IMAGENET_STATS in key order, so they can be looked up by channel index.
# `unnormalise` reads each entry as [0] = offset and [1] = scale (presumably mean and std).
imagenet_channels_stats = [IMAGENET_STATS[channel_key] for channel_key in IMAGENET_STATS.keys()]

def unnormalise(image):
    """Undo the per-channel normalisation of a channels-first image.

    Channel ``i`` is multiplied by ``imagenet_channels_stats[i][1]`` and then
    offset by ``imagenet_channels_stats[i][0]`` (presumably std and mean).
    The restored channels are stacked back along axis 0 into a new array.
    """
    restored_channels = []
    for channel_idx in range(image.shape[0]):
        channel_stats = imagenet_channels_stats[channel_idx]
        restored_channels.append(image[channel_idx] * channel_stats[1] + channel_stats[0])
    return np.stack(restored_channels)

def get_fig_ax(ncols, nrows):
    """Create a ``nrows`` x ``ncols`` subplot grid, allotting 5x5 inches per cell.

    Returns whatever ``plt.subplots`` returns: the figure and its axes.
    """
    inches_per_cell = 5
    figure_size = (ncols * inches_per_cell, nrows * inches_per_cell)
    return plt.subplots(nrows=nrows, ncols=ncols, figsize=figure_size)


In [None]:
from nnood.paths import default_data_identifier, default_plans_identifier
from nnood.utils.file_operations import load_pickle
from nnood.training.dataloading.dataset_loading import load_dataset_filenames, load_npy_or_npz

# Preprocessed dataset inspected in the cells below.
dataset_name = 'mvtec_ad_bottle'

# Both the plans file and the stage-0 data folder live under the dataset's preprocessed folder.
dataset_root = Path(preprocessed_data_base) / dataset_name
plans_file = dataset_root / default_plans_identifier
test_image_dir = dataset_root / (default_data_identifier + '_stage0')

plans = load_pickle(plans_file)

# The plans record which sample identifiers belong to the dataset.
sample_identifiers = plans['dataset_properties']['sample_identifiers']
dataset = load_dataset_filenames(test_image_dir, sample_identifiers)

In [None]:
# Load the data file of sample 'normal_001'. The 'r' argument is forwarded to
# load_npy_or_npz (presumably a read-only memory-map mode — confirm against that helper).
test_img = load_npy_or_npz(dataset['normal_001']['data_file'], 'r')

In [None]:
from skimage import filters
from skimage.segmentation import flood
from skimage.morphology import binary_opening
from skimage.measure import label

In [None]:
def make_hypersphere_mask(radius: int, dims: int):
    """Build a boolean mask of a filled hypersphere.

    The result has ``dims`` axes, each of length ``2 * radius + 1``, and is True
    wherever the squared distance from the centre element is at most ``radius ** 2``.
    """
    axis_offsets = np.arange(-radius, radius + 1)
    # One coordinate grid per dimension, all built from the same centred offsets.
    coord_grids = np.meshgrid(*(axis_offsets for _ in range(dims)))
    squared_distance = np.sum([grid ** 2 for grid in coord_grids], axis=0)
    return squared_distance <= radius ** 2

# Sanity check: display a radius-2, 4-dimensional mask as 0/1 integers.
make_hypersphere_mask(2, 4).astype(int)

In [None]:
import itertools


def get_object_mask(img):
    """Estimate a boolean foreground (object) mask for a channels-first image.

    For every channel the Sobel edge magnitude is flood-filled from random seed
    points near each image corner; the union of those fills is that channel's
    background estimate. A pixel counts as background only when every channel
    marks it as background. The largest connected non-background region is kept
    and smoothed with a binary opening.

    Seed points are drawn from NumPy's global random state, so the result can
    differ between calls unless that state is seeded.

    Args:
        img: Array whose first axis is the channel axis; the remaining axes are
            treated as spatial dimensions.

    Returns:
        Boolean array with the spatial shape of ``img``.
    """
    num_channels = img.shape[0]
    # Spatial shape, excluding the channels dimension
    image_shape = np.array(img.shape)[1:]
    num_spatial_dims = len(image_shape)

    # 5% of each dimension length, but at least one pixel: a zero-length corner would
    # give an empty patch (NaN tolerance) and make np.random.randint(lb, lb) raise.
    corner_lens = np.maximum(image_shape // 20, 1)

    # Used for upper_bounds, so use dim lengths
    img_corners = list(itertools.product(*[[0, s] for s in image_shape]))

    # Per corner, the (lower, upper) index range of the corner patch along each dimension.
    all_corner_ranges = [[(0, l) if c == 0 else (c - l, c) for c, l in zip(c_coord, corner_lens)]
                         for c_coord in img_corners]
    corner_patch_slices = [tuple(slice(lb, ub) for lb, ub in cr) for cr in all_corner_ranges]

    # np.median returns a float, so cast to keep the footprint radius an integer.
    opening_structure_r = max(int(np.median(image_shape) // 500), 1)
    opening_structure = make_hypersphere_mask(opening_structure_r, num_spatial_dims)
    num_corner_seed_points = 3 ** num_spatial_dims

    masks = []

    for i in range(num_channels):

        sobel_channel = filters.sobel(img[i])

        # Union of all flood fills for this channel, accumulated in place rather than
        # stacking every individual fill mask in memory.
        channel_bg_mask = np.zeros(sobel_channel.shape, dtype=bool)

        for c_r, c_s in zip(all_corner_ranges, corner_patch_slices):

            # Edge-magnitude variation inside the corner patch sets how far the fill may spread.
            patch_tolerance = sobel_channel[c_s].std()

            for _ in range(num_corner_seed_points):
                random_c = tuple(np.random.randint(lb, ub) for lb, ub in c_r)

                channel_bg_mask |= flood(sobel_channel, random_c, tolerance=patch_tolerance)

        masks.append(channel_bg_mask)

    # Background only where every channel agrees; everything else is candidate foreground.
    bg_mask = np.all(np.stack(masks), axis=0)
    fg_mask = np.logical_not(bg_mask)

    # Label connected foreground regions and measure their sizes
    label_fg_mask = label(fg_mask)
    region_sizes = np.bincount(label_fg_mask.flatten())
    # Zero size of background, so is ignored when finding biggest region
    region_sizes[0] = 0
    # NOTE(review): if no foreground was found, argmax yields label 0 and the whole image is
    # selected; this matches the previous behaviour — confirm it is the intended fallback.
    biggest_region_label = np.argmax(region_sizes)

    # Apply binary opening to mask of largest object, to smooth out edges
    return binary_opening(label_fg_mask == biggest_region_label, footprint=opening_structure)

In [None]:
# Visualise the object mask computed for `test_img`, cast from bool to int for display.
display_image(get_object_mask(test_img).astype(int))

In [None]:
from nnood.data.dataset_conversion.convert_mvtec import CLASS_NAMES, HAS_UNIFORM_BACKGROUND
from time import time

from nnood.preprocessing.foreground_mask import get_object_mask

test_classes = ['toothbrush'] # HAS_UNIFORM_BACKGROUND

num_examples = 40
# One row per example, two columns: unnormalised image on the left, object mask on the right.
fig, ax = get_fig_ax(2, len(test_classes) * num_examples)

start = time()

for c_n_i, class_name in enumerate(test_classes):

    curr_dataset_name = 'mvtec_ad_' + class_name
    curr_dataset_dir = Path(preprocessed_data_base) / curr_dataset_name
    curr_plans_file = curr_dataset_dir / default_plans_identifier
    curr_image_dir = curr_dataset_dir / (default_data_identifier + '_stage0')

    curr_plans = load_pickle(curr_plans_file)
    curr_sample_ids = curr_plans['dataset_properties']['sample_identifiers']
    curr_dataset = load_dataset_filenames(curr_image_dir, curr_sample_ids)

    test_imgs = []
    for sample_idx in range(num_examples):
        sample_file = curr_dataset[f'normal_{sample_idx:03d}']['data_file']
        test_imgs.append(load_npy_or_npz(sample_file, 'r'))

    test_imgs_masks = [get_object_mask(sample_img).astype(int) for sample_img in test_imgs]

    first_row = c_n_i * num_examples
    for j, (sample_img, sample_mask) in enumerate(zip(test_imgs, test_imgs_masks)):
        row_axes = ax[first_row + j]
        # Move channels last so imshow can render the unnormalised image.
        display_image(np.moveaxis(unnormalise(sample_img), 0, -1), existing_ax=row_axes[0])
        display_image(sample_mask, existing_ax=row_axes[1])

print('took ', time() - start)

In [None]:
from skimage.io import imread
from scipy.stats import energy_distance

# Compare the intensity distributions of two raw training images of the 'grid' class.
grid_train_dir = Path(raw_data_base, 'mvtec_ad_grid', 'imagesTr')

test_img = imread(grid_train_dir / 'normal_000_0000.png')
test_img2 = imread(grid_train_dir / 'normal_001_0000.png')

flat_pixels, flat_pixels2 = test_img.flatten(), test_img2.flatten()
energy_distance(flat_pixels, flat_pixels2)

In [None]:


# Per-class pair of logistic parameters; the second entry is what the later cells compare
# against the measured energy distances.
INTENSITY_LOGISTIC_PARAMS = {'bottle':(1/12, 24), 'cable':(1/12, 24), 'capsule':(1/2, 4), 'hazelnut':(1/12, 24), 'metal_nut':(1/3, 7), 
            'pill':(1/3, 7), 'screw':(1, 3), 'toothbrush':(1/6, 15), 'transistor':(1/6, 15), 'zipper':(1/6, 15),
            'carpet':(1/3, 7), 'grid':(1/3, 7), 'leather':(1/3, 7), 'tile':(1/3, 7), 'wood':(1/6, 15)}

# Classes for which only the object-mask pixels are compared.
has_uniform_background = ['bottle', 'capsule', 'hazelnut', 'metal_nut', 'pill', 'screw', 'toothbrush', 'zipper']

num_test_samples = 10

# Mean pairwise energy distance per class, in CLASS_NAMES order
# (despite the name, these are not standard deviations).
class_stds = []

for class_name in CLASS_NAMES:

    print('Starting ', class_name)

    curr_dataset_name = 'mvtec_ad_' + class_name

    raw_data_folder = Path(raw_data_base, curr_dataset_name, 'imagesTr')

    # Zero-pad the index to three digits so the pattern stays valid for 10 or more samples.
    init_imgs = [imread(raw_data_folder / f'normal_{i:03d}_0000.png') for i in range(num_test_samples)]

    # Bring every image to channels-first layout (greyscale images get a singleton channel axis).
    test_imgs = [img[None] if len(img.shape) == 2 else np.moveaxis(img, -1, 0) for img in init_imgs]

    # Average over channels to obtain one intensity image per sample.
    avg_imgs = [np.mean(img, axis=0) for img in test_imgs]

    if class_name in has_uniform_background:
        # Keep only the pixels inside each sample's object mask.
        avg_imgs = [a_i[get_object_mask(i)] for a_i, i in zip(avg_imgs, test_imgs)]

    avg_imgs_flat = [i.flatten() for i in avg_imgs]

    # Energy distance is symmetric, so each unordered pair is evaluated once; the mean is
    # the same as when both (i, j) and (j, i) are included.
    num_imgs = len(avg_imgs_flat)
    energy_distances = [energy_distance(avg_imgs_flat[i], avg_imgs_flat[j])
                        for i in range(num_imgs)
                        for j in range(i + 1, num_imgs)]

    class_stds.append(np.mean(energy_distances))

print(class_stds)

In [None]:
# Show, per class: its name, its mean pairwise energy distance, and the second entry of
# its INTENSITY_LOGISTIC_PARAMS pair.
list(zip(CLASS_NAMES, class_stds, [INTENSITY_LOGISTIC_PARAMS[c_n][1] for c_n in CLASS_NAMES]))

In [None]:
from nnood.data.dataset_conversion.convert_mvtec import OBJECTS, TEXTURES

plt.xlabel('Average energy distance between images (filtering foreground if possible)')
plt.ylabel('Logistic param 1')

# class name -> (mean energy distance, second logistic parameter)
std_dict = {}
for i, c_n in enumerate(CLASS_NAMES):
    std_dict[c_n] = (class_stds[i], INTENSITY_LOGISTIC_PARAMS[c_n][1])

# exclude bottle as rotation invariant
unaligned_objects = ['hazelnut', 'metal_nut', 'screw']
has_foreground = [cn for cn in has_uniform_background if cn not in unaligned_objects]
other_objects = [cn for cn in OBJECTS if cn not in unaligned_objects and cn not in has_foreground]

# (classes in the group, matplotlib format string, legend label)
plot_groups = [
    (TEXTURES, 'gx', 'Textures'),
    (unaligned_objects, 'ro', 'Unaligned'),
    (has_foreground, 'bo', 'Has foreground'),
    (other_objects, 'go', 'Other objects'),
]

for cat, m, l in plot_groups:
    group_distances = [std_dict[c_n][0] for c_n in cat]
    group_params = [std_dict[c_n][1] for c_n in cat]
    plt.plot(group_distances, group_params, m, label=l)

plt.legend()