In [1]:
import torch
import torch.nn as nn
import numpy as np

import torchvision
import torchvision.transforms as transforms
from PIL import Image
import os
from datetime import datetime
import sys
from sklearn.utils import shuffle

import matplotlib.pyplot as plt

from torch.utils.data import DataLoader, random_split
from utils import MVTecMetaDataset, MVTecDataset
from models import MVTecLearner

device = 'cuda:0'
data_dir = './data/data'

In [2]:
def criterion(x, y):
    remove_zero_losses = (x!=y)
    x = x[remove_zero_losses]
    y = y[remove_zero_losses]
    loss = -(x.log() * y + (1 - x).log() * (1 - y))
    return loss.mean()

In [3]:
def TestDataset(target_class, num_train_points):

    normal_list_dir = [os.path.join(data_dir, target_class, 'train', 'good'), os.path.join(data_dir, target_class, 'test', 'good')]

    test_dir = os.path.join(data_dir, target_class, 'test')
    test_subfolders = next(os.walk(test_dir))[1]

    abnormal_list_dir=[]

    for item in test_subfolders:
        if item != 'good':
            abnormal_list_dir.append(os.path.join(data_dir, target_class, 'test', item))

    NormalImgs = MVTecMetaDataset(normal_list_dir)
    AbnormalImgs = MVTecMetaDataset(abnormal_list_dir)
    
    train_normal_set, val_normal_set = random_split(NormalImgs, [num_train_points, len(NormalImgs)-num_train_points])    
    train_abnormal_set, val_abnormal_set = random_split(AbnormalImgs, [num_train_points, len(AbnormalImgs)-num_train_points])
    
    return train_normal_set, train_abnormal_set, val_normal_set, val_abnormal_set

In [4]:
def evaluateMvtecMaml(target_class, num_train_points, num_grad_update=1, maml=True):
    
    # num_grad_update = K  for evaluation

    mvtec_learner = MVTecLearner(device=device)

    if maml:
        if target_class == 'capsule':
            loadpath = "./mvtec_saves/mvtec_targetcapsule_n2_k5_lr1e-07_final.pth"
        elif target_class == 'pill':
            loadpath = "./mvtec_saves/mvtec_targetpill_n2_k5_lr1e-07_final.pth"
        elif target_class == 'zipper':
            loadpath = "./mvtec_saves/mvtec_targetzipper_n2_k5_lr1e-07_final.pth"
        elif target_class == 'metal_nut':
            loadpath = "./mvtec_saves/mvtec_targetmetal_nut_n2_k5_lr1e-07_final.pth"
        mvtec_learner.load_state_dict(torch.load(loadpath, map_location=device))    

    batch_size      = 2*num_train_points
    lr_a            = 0.01

    train_normal_set, train_abnormal_set, val_normal_set, val_abnormal_set  = TestDataset(target_class, num_train_points)
    
    count_correct_pred = 0
    count_total_pred   = 0
    
    # 2.1 sample K datapoints from Ti
    
    num_train_batch = 5
    
    normal_sampler = DataLoader(train_normal_set, batch_size=int(num_train_points/num_train_batch), shuffle=True)
    abnormal_sampler = DataLoader(train_abnormal_set, batch_size=int(num_train_points/num_train_batch), shuffle=True)    
    
    fast_weights = mvtec_learner.copy_model_weights()

    for j in range(num_grad_update):
        for i in range(num_train_batch):
            normal_imgs, _ = next(iter(normal_sampler))
            abnormal_imgs, _ = next(iter(abnormal_sampler))                    
            
            X_batch = torch.cat((normal_imgs, abnormal_imgs), dim=0).to(device)            
            Y_batch = torch.tensor(np.concatenate((np.zeros([len(normal_imgs)]),np.ones([len(abnormal_imgs)]))), dtype=torch.float, device=device).view(-1,1)
            Y_pred = mvtec_learner.forward_fast_weights(X_batch, fast_weights)
            
            train_loss = criterion(Y_pred, Y_batch)
            grad = torch.autograd.grad(train_loss, fast_weights, create_graph=True)
            fast_weights = mvtec_learner.update_fast_grad(fast_weights, grad, lr_a)  

    # 3. evaluation
    
    num_eval_batch = 4
    
    normal_sampler = DataLoader(val_normal_set, batch_size=int(len(val_normal_set)/num_eval_batch), shuffle=False)
    abnormal_sampler = DataLoader(val_abnormal_set, batch_size=int(len(val_abnormal_set)/num_eval_batch), shuffle=False)
    
    correct_sum=0
    total_sum = 0
    
    for i in range(num_eval_batch):
        normal_imgs, _ = next(iter(normal_sampler))
        abnormal_imgs, _ = next(iter(abnormal_sampler))    

        X_batch_eval = torch.cat((normal_imgs, abnormal_imgs), dim=0).to(device)
        Y_batch_eval = torch.tensor(np.concatenate((np.zeros([len(normal_imgs)]),np.ones([len(abnormal_imgs)]))), dtype=torch.float, device=device).view(-1,1)

        Y_pred_eval = mvtec_learner.forward_fast_weights(X_batch_eval, fast_weights)
        Y_pred_eval = (Y_pred_eval > 0.5).float()

        corr_pred  = (Y_batch_eval == Y_pred_eval).int().sum().item()
        total_pred = len(Y_batch_eval)
        
        correct_sum+=corr_pred
        total_sum+=total_pred

    print("PREDICTION ACCURACY = {}".format(correct_sum/total_sum))

In [5]:
evaluateMvtecMaml(target_class='capsule', num_train_points= 5, num_grad_update= 1, maml=True)

PREDICTION ACCURACY = 0.5617647058823529


In [6]:
evaluateMvtecMaml('capsule', 5,1, maml=False)

PREDICTION ACCURACY = 0.3058823529411765


In [7]:
evaluateMvtecMaml('zipper', 5,1, maml=True)

PREDICTION ACCURACY = 0.7021276595744681


In [8]:
evaluateMvtecMaml('zipper', 5,1, maml=False)

PREDICTION ACCURACY = 0.42819148936170215


In [9]:
evaluateMvtecMaml('pill', 5,1, maml=True)

PREDICTION ACCURACY = 0.6792452830188679


In [10]:
evaluateMvtecMaml('pill', 5,1, maml=False)

PREDICTION ACCURACY = 0.3938679245283019


In [11]:
evaluateMvtecMaml('metal_nut', 5,1, maml=True)

PREDICTION ACCURACY = 0.4691358024691358


In [12]:
evaluateMvtecMaml('metal_nut', 5,1, maml=False)

PREDICTION ACCURACY = 0.28703703703703703
