In [1]:
import warnings 
warnings.filterwarnings('ignore')
import torch 
import torch.nn as nn 
import torchvision 
import cv2 
import numpy as np 
import matplotlib.pyplot as plt 
from PIL import Image
from tqdm import tqdm 
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader,Dataset
import os 
import yaml 
from sklearn.metrics import roc_curve,auc
from sklearn.metrics import f1_score
from sklearn.metrics import classification_report

from src import Convolution_Auto_Encoder, Mnist_Dataset,MVtecADDataset,Datadir_init
from src import MVtecEncoder,MVtecDecoder,Convolution_Auto_Encoder


In [2]:
cfg = yaml.load(open('./Save_models/MVtecAD4/config.yaml','r'), Loader=yaml.FullLoader)

In [3]:
def preprocess(cfg,augmentation=None):
    #mk save dir 
    try:
        os.mkdir(f"./Save_models/{cfg['save_dir']}")
    except:
        pass
    torch.manual_seed(cfg['seed'])
    data_dir = cfg['Dataset_dir']
    Data_dir = Datadir_init()
    train_dirs = Data_dir.train_load()
    test_dirs,test_labels = Data_dir.test_load()

    train_dset = MVtecADDataset(cfg,train_dirs,Augmentation=augmentation)
    test_dset = MVtecADDataset(cfg,test_dirs,test_labels,Augmentation=augmentation)

    train_loader = DataLoader(train_dset,batch_size=cfg['batch_size'],shuffle=True)
    test_loader = DataLoader(test_dset,batch_size=cfg['batch_size'],shuffle=False)
    return train_loader,test_loader 

In [71]:
#변수 리스트 
normal_vecs = [] 
normal_labels = [] 
#데이터 및 모델 로드 
train_loader,test_loader  = preprocess(cfg)    
model = torch.load(f"./Save_models/{cfg['save_dir']}/best.pt").to(cfg['device'])
encoder = model.encoder
#Inference 
for img,label in train_loader:
    img = img.to(cfg['device']).type(torch.float32)
    
    with torch.no_grad():
        normal_vec =  encoder(img)
        
    normal_vecs.extend(normal_vec.detach().cpu().numpy())
    normal_labels.extend(label.detach().cpu().numpy())
    
normal_vecs = np.array(normal_vecs)
normal_labels = np.array(normal_labels)


In [72]:
from sklearn.preprocessing import MinMaxScaler
minmax = MinMaxScaler()
normalized_vecs = minmax.fit_transform(normal_vecs)

In [73]:
from sklearn.svm import OneClassSVM

model = OneClassSVM()
model.fit(normalized_vecs)

OneClassSVM()

In [83]:
#변수 리스트 
test_vecs = [] 
test_labels = [] 
#데이터 및 모델 로드 
train_loader,test_loader  = preprocess(cfg)    

#Inference 
for img,label in test_loader:
    img = img.to(cfg['device']).type(torch.float32)
    
    with torch.no_grad():
        test_vec =  encoder(img)
        
    test_vecs.extend(test_vec.detach().cpu().numpy())
    test_labels.extend(label.detach().cpu().numpy())
    
test_vecs = np.array(test_vecs)
test_labels = np.array(test_labels)
test_labels = np.where(test_labels==0,test_labels,-1) # Anomaly:-1, normal:0
test_labels = np.where(test_labels==1,test_labels,1) # Anomaly:-1, normal:1

In [84]:
normalized_test_vecs = minmax.transform(test_vecs)
y_pred = model.predict(normalized_test_vecs)

In [85]:
print(classification_report(y_pred,test_labels))

              precision    recall  f1-score   support

          -1       0.00      0.00      0.00        40
           1       0.64      1.00      0.78        70

    accuracy                           0.64       110
   macro avg       0.32      0.50      0.39       110
weighted avg       0.40      0.64      0.49       110



1이 normal인지 anomaly인지 확인 필요 