In [None]:
# Importing libraries
import torch
import os, yaml
from timm import create_model
from data import create_dataset
from models import MemSeg

import ipywidgets as widgets
import matplotlib.pyplot as plt
from IPython.display import display, clear_output

import warnings
warnings.filterwarnings("ignore")

In [None]:
# Dataset
def load_dataset(object_name, config=None):
    """Build the test split for one object category and store it in the global ``testset``.

    Args:
        object_name: Dataset target/category passed to ``create_dataset``
            (e.g. 'capsule').
        config: Parsed YAML config with a 'DATASET' section. Defaults to the
            global ``cfg`` that ``load_model`` sets.

    Returns:
        The created test dataset (also assigned to the global ``testset``).

    Raises:
        RuntimeError: If no config was passed and ``load_model`` has not run yet.
    """
    global testset
    if config is None:
        # Fall back to the config loaded by load_model(); fail loudly instead of
        # with an opaque NameError when the cells were run out of order.
        config = globals().get('cfg')
        if config is None:
            raise RuntimeError('No config available: call load_model() first or pass config=...')
    dataset_cfg = config['DATASET']

    testset = create_dataset(
        datadir                = dataset_cfg['datadir'],
        target                 = object_name,
        train                  = False,
        resize                 = dataset_cfg['resize'],
        texture_source_dir     = dataset_cfg['texture_source_dir'],
        structure_grid_size    = dataset_cfg['structure_grid_size'],
        transparency_range     = dataset_cfg['transparency_range'],
        perlin_scale           = dataset_cfg['perlin_scale'],
        min_perlin_scale       = dataset_cfg['min_perlin_scale'],
        perlin_noise_threshold = dataset_cfg['perlin_noise_threshold'],
    )
    return testset

# Model
def load_model(object_name, device='cpu'):
    """Load a trained MemSeg checkpoint together with its YAML config.

    Sets the globals ``cfg`` (parsed ``./configs/<category>.yaml``) and ``model``.

    Args:
        object_name: Checkpoint folder name under ``./saved_model/original/``
            (e.g. 'MemSeg-Original-capsule'); its last '-'-separated token
            selects the config file.
        device: Device the memory bank tensors, state dict and model are loaded onto.

    Returns:
        The MemSeg model, moved to ``device`` and switched to eval mode.
    """
    global model, cfg
    checkpoint_dir = f'./saved_model/original/{object_name}'

    # Use a context manager so the config file handle is closed.
    with open(f'./configs/{object_name.split("-")[-1]}.yaml', 'r') as config_file:
        cfg = yaml.load(config_file, Loader=yaml.FullLoader)

    # map_location lets checkpoints saved on a GPU be loaded on a CPU-only machine.
    # NOTE(review): memory_bank.pt holds a whole pickled object rather than a
    # state_dict; newer torch versions default torch.load to weights_only=True,
    # which may reject it — confirm against the pinned torch version.
    memory_bank = torch.load(f'{checkpoint_dir}/memory_bank.pt', map_location=device)
    memory_bank.device = device
    for k in memory_bank.memory_information.keys():
        memory_bank.memory_information[k] = memory_bank.memory_information[k].to(device)

    # The state dict loaded below overwrites these weights.
    encoder = create_model(
        cfg['MODEL']['feature_extractor_name'],
        pretrained=True,
        features_only=True,
    )
    model = MemSeg(memory_module=memory_bank, encoder=encoder)
    model.load_state_dict(torch.load(f'{checkpoint_dir}/best_model.pt', map_location=device))
    model.to(device)
    # This notebook only runs inference: put any BatchNorm/Dropout layers in eval mode.
    model.eval()

    return model

def minmax_scaling(img):
    """Linearly rescale a tensor to the full uint8 range [0, 255] for display.

    Args:
        img: Tensor of any shape.

    Returns:
        A uint8 tensor of the same shape. A constant input (max == min) maps to
        all zeros instead of producing NaN from a division by zero.
    """
    lo, hi = img.min(), img.max()
    span = hi - lo
    if span == 0:
        # Avoid 0/0 -> NaN, whose cast to uint8 is undefined.
        return torch.zeros_like(img, dtype=torch.uint8)
    return (((img - lo) / span) * 255).to(torch.uint8)

def visualize_output(idx):
    """Run the loaded model on test sample ``idx`` and plot input, ground truth and prediction.

    Reads the globals ``testset`` (set by ``load_dataset``) and ``model`` (set by
    ``load_model``).

    Args:
        idx: Index into ``testset``; each item unpacks to (image, mask, target).
    """
    input_i, mask_i, target_i = testset[idx]

    # Inference only: skip autograd bookkeeping, and feed the input on the same
    # device as the model's parameters.
    device = next(model.parameters()).device
    with torch.no_grad():
        output_i = model(input_i.unsqueeze(0).to(device))
    # Softmax over dim 1 (channels); channel 1 is shown as the predicted mask.
    # Move back to CPU so matplotlib can render it.
    output_i = torch.nn.functional.softmax(output_i, dim=1).cpu()

    fig, ax = plt.subplots(1, 4, figsize=(15, 10))

    # Input is channel-first; permute to (H, W, C) for imshow.
    input_img = minmax_scaling(input_i.permute(1, 2, 0))
    pred_mask = output_i[0][1]

    ax[0].imshow(input_img)
    ax[0].set_title('Input: {}'.format('Normal' if target_i == 0 else 'Abnormal'))
    ax[1].imshow(mask_i, cmap='gray')
    ax[1].set_title('Ground Truth')
    ax[2].imshow(pred_mask, cmap='gray')
    ax[2].set_title('Predicted Mask')
    ax[3].imshow(input_img, alpha=1)
    ax[3].imshow(pred_mask, cmap='gray', alpha=0.5)
    ax[3].set_title('Input X Predicted Mask')

    plt.show()

In [None]:
# Widgets

# Each selectable model is a checkpoint folder under MODEL_ROOT (load_model reads
# <MODEL_ROOT>/<name>/memory_bank.pt and best_model.pt), so list only
# directories, sorted because os.listdir order is arbitrary.
MODEL_ROOT = './saved_model/original/'
DEFAULT_MODEL = 'MemSeg-Original-capsule'
available_models = sorted(
    name for name in os.listdir(MODEL_ROOT)
    if os.path.isdir(os.path.join(MODEL_ROOT, name))
)

model_list = widgets.Dropdown(
    options=available_models,
    # Dropdown rejects a value that is not among its options, so fall back to
    # the first checkpoint (or no selection) when the default is missing.
    value=DEFAULT_MODEL if DEFAULT_MODEL in available_models
          else (available_models[0] if available_models else None),
    description='Model:',
    disabled=False,
)
model_button = widgets.Button(description="Select Model")
output = widgets.Output()


@output.capture()
def on_model_button_clicked(b):
    """Load the selected checkpoint and its test split, then show an image picker.

    Bound to ``model_button``; ``b`` is the clicked Button and is unused. Anything
    displayed here is routed into the ``output`` widget by the decorator.
    """
    clear_output(wait=True)
    object_name = model_list.value
    if object_name is None:
        print('No saved models found.')
        return

    # load_model must run first: it sets the global `cfg` that load_dataset reads.
    load_model(object_name=object_name)
    # The last '-'-separated token of the checkpoint name (e.g. 'capsule' in
    # 'MemSeg-Original-capsule') is the dataset category.
    load_dataset(object_name=object_name.split('-')[-1])

    if not testset.file_list:
        # A Dropdown with value=0 and no options would raise.
        print(f'No test images found for {object_name}.')
        return

    # Visualization: each option is labelled by its file path; its value is the sample index.
    file_list = widgets.Dropdown(
        options=[(file_path, i) for i, file_path in enumerate(testset.file_list)],
        value=0,
        description='image:',
    )
    widgets.interact(visualize_output, idx=file_list)


model_button.on_click(on_model_button_clicked)
display(widgets.HBox([model_list, model_button]), output)

### Run locally
Serve this notebook as a standalone web app with Voilà:

```bash
voila "inference.ipynb" --port 8866 --Voila.ip 127.0.0.1
```