In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import numpy as np
import pandas as pd
import cv2
import json
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from PIL import Image, ImageOps
import time

from models import CAE_v1 
from models import CAE_v2
from models import Classifier_v1
from models import Classifier_v2
import method as MM
import Grad_CAM 

import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, accuracy_score, ConfusionMatrixDisplay, classification_report

# Global plotting look: seaborn "white" style in the "talk" context, then
# force a 12pt base font and 12pt axis/tick labels on top of the theme.
sns.set_theme(style="white", context="talk")
plt.rc('font', size=12)
plt.rc(('axes', 'xtick', 'ytick'), labelsize=12)

# CUDA device handle used by the notebook; the device name is queried right
# away, so this cell fails fast when no GPU is available.
device = torch.device("cuda")
print(torch.cuda.get_device_name(device))
print(f"Using device: {device}")

In [None]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import Subset,ConcatDataset, DataLoader


# Per-channel normalisation statistics (labelled "val mean" / "val std" by
# the author).
mean = [0.7101, 0.4827, 0.3970]
std = [0.2351, 0.2195, 0.1862]

# Plain pipeline: tensor conversion followed by normalisation.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

# Augmenting pipeline: random flips, colour jitter and grayscale conversion,
# then the same tensor conversion / normalisation as above.
_augmentations = [
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomGrayscale(p=0.75),
]
ag_transform = transforms.Compose(
    _augmentations
    + [transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)]
)

# Root folders of the three dataset splits.
path_train = 'output_dataset/train'
path_val = 'output_dataset/val'
path_test = 'output_dataset/test'


def _image_folder(root, tf):
    """Create an ImageFolder for `root` with transform `tf` and print its class map."""
    folder = datasets.ImageFolder(root=root, transform=tf)
    print(folder.class_to_idx)
    return folder


# The training set is the concatenation of a plain and an augmented view of
# the same training folder.
trainset_1 = _image_folder(path_train, transform)
trainset_2 = _image_folder(path_train, ag_transform)
trainset = ConcatDataset([trainset_1, trainset_2])

valset = _image_folder(path_val, transform)
testset = _image_folder(path_test, transform)

partition = dict(train=trainset, val=valset, test=testset)

# Folder holding the saved models and their JSON files.
path = 'model_save/'


In [None]:
path='model_save/'


def _load_eval_model(file_stem):
    """Load ``path + file_stem + '.pth'`` with ``torch.load`` and switch it to eval mode."""
    # NOTE(review): torch.load unpickles the file, so only checkpoints from a
    # trusted source should be loaded here.
    loaded = torch.load(path + file_stem + '.pth')
    loaded.eval()
    return loaded


# Saved models, loaded in the same order as before.
densenet = _load_eval_model('densenet')
efficientnet = _load_eval_model('efficientnet')
resnet = _load_eval_model('resnet')
vgg = _load_eval_model('vgg')
SWIN_T = _load_eval_model('SWIN_T')
Cnn_based_cae_v1 = _load_eval_model('CNN_based_on_CAE_v1')
Cnn_based_cae_v2 = _load_eval_model('CNN_based_on_CAE_v2')

name_list=["CNN_based_on_CAE_v1","CNN_based_on_CAE_v2","densenet","efficientnet","resnet","vgg","SWIN_T"]
model_list=[Cnn_based_cae_v1,Cnn_based_cae_v2,densenet,efficientnet,resnet,vgg,SWIN_T]

# Names stored under 'model' for the two CAE-based classifiers.
display_names = {
    "CNN_based_on_CAE_v1": "CNN_based_on_CAE",
    "CNN_based_on_CAE_v2": "Residual_Feature_Classifier",
}

total_result=[]
for model_name, model in zip(name_list, model_list):
    # Per-model JSON stored next to the checkpoint; it is extended with the
    # test-set results below, so it must decode to a dict.
    with open(path+model_name+'.json', 'r') as file:
        loaded_data = json.load(file)

    # CUDA work is queued asynchronously, so synchronise before reading the
    # clock on both sides; this matches the dedicated timing loop later in
    # the notebook and keeps pending GPU work out of (or inside) the interval.
    torch.cuda.synchronize()
    start = time.time()
    test_acc,test_pred,test_real = MM.test(model, partition)
    torch.cuda.synchronize()
    end = time.time()

    if model_name in display_names:
        loaded_data['model'] = display_names[model_name]

    loaded_data['test_acc']=test_acc
    loaded_data['test_pred']=test_pred
    loaded_data['test_real']=test_real
    loaded_data['inference_time']=end - start
    loaded_data['num_parameter']=sum(p.numel() for p in model.parameters())
    total_result.append(loaded_data)

MM.plot_loss_and_accuracy(total_result)

In [None]:
from sklearn.metrics import f1_score, accuracy_score, precision_score

# For every evaluated model print accuracy, F1 and precision (in that order),
# each rounded to four decimals and prefixed with the model's name.
for result in total_result:
    prefix = result['model']+": "
    # Concatenate the stored tensors once and reuse them for all metrics.
    y_true = torch.cat(result['test_real']).cpu().detach()
    y_pred = torch.cat(result['test_pred']).cpu().detach()
    for metric in (accuracy_score, f1_score, precision_score):
        print(prefix, round(metric(y_true, y_pred),4))
    print()
    
    

In [None]:
model_list=[Cnn_based_cae_v1,Cnn_based_cae_v2,densenet,efficientnet,resnet,vgg,SWIN_T]
name_list=["CNN_based_on_CAE_v1","CNN_based_on_CAE_v2","densenet","efficientnet","resnet","vgg","SWIN_T"]

# One list of wall-clock timings per model, keyed by the names in name_list.
inference_time = {model_name: [] for model_name in name_list}

for infer_model, name in zip(model_list, name_list):
    timings = inference_time[name]
    # Ten repetitions per model.
    for _ in range(10):
        torch.cuda.empty_cache()

        # Synchronise before and after the call so the measured interval
        # covers the GPU work issued by it.
        torch.cuda.synchronize()
        start = time.time()

        MM.test_for_inference_time(infer_model, partition)

        torch.cuda.synchronize()
        end = time.time()
        timings.append(end - start)
    print(timings)
        

In [None]:
# Summary table: one column per model (labelled with the name stored under
# 'model'), rows for mean / std of the measured inference times and the test
# accuracy.
#
# The 'model' entry of the first two results is overwritten with a display
# name when total_result is built, so it is not a key of `inference_time`
# (which is keyed by name_list). The timings are therefore looked up through
# name_list, which has the same order as total_result, while the stored
# 'model' name is used only as the column label.
total_info=pd.DataFrame(columns=[result['model'] for result in total_result],
                        index=['mean','std','acc'])

for timing_key, result in zip(name_list, total_result):
    label = result['model']
    timings = inference_time[timing_key]
    total_info.loc['mean', label] = np.mean(timings)
    total_info.loc['std', label] = np.std(timings)
    total_info.loc['acc', label] = result['test_acc']

sns.set_theme(style="darkgrid")
# relplot is a figure-level function that creates its own figure, so no
# explicit plt.figure() is opened here (it would only emit an empty figure).
sns.relplot(data=total_info.transpose().reset_index(drop=False),
                x='mean',y='acc',style='index',hue="index",s=300)
total_info

In [None]:

# Plot the loss curves stored in CAE_v2.json side by side: 'train_losses' in
# the left panel and 'val_losses' in the right one.
with open(path+'CAE_v2.json', 'r') as file:
    CAE_loss = json.load(file)

f=plt.figure(figsize=(10,5))
# The loop variable is deliberately not called `type`: at notebook top level
# that would rebind the builtin for every cell executed afterwards.
for num, curve_name in enumerate(('train_losses', 'val_losses'), start=1):
    ax=f.add_subplot(1,2,num)
    ax.plot(CAE_loss[curve_name], label=curve_name)

    ax.set_xlabel('Epoch')
    ax.set_ylabel('mse')
    ax.set_title(curve_name)
    ax.grid()
    ax.legend()


In [None]:
def inverse_nor(image, mean=(0.7101, 0.4827, 0.3970), std=(0.2351, 0.2195, 0.1862)):
    """Undo per-channel normalisation by computing ``image * std + mean``.

    Parameters
    ----------
    image : torch.Tensor
        Normalised tensor. ``mean`` and ``std`` are broadcast against it, so
        its last dimension must be compatible with them (e.g. a channel-last
        ``(H, W, 3)`` image with the default statistics).
    mean, std : sequence of float or torch.Tensor, optional
        Statistics to undo. The defaults are the values the notebook uses for
        normalisation (labelled "val mean" / "val std" there).

    Returns
    -------
    torch.Tensor
        The de-normalised tensor, on the same device as ``image``.
    """
    # Build the statistics on the device of `image` so the function also
    # works for tensors that have not been moved back to the CPU.
    mean_t = torch.as_tensor(mean, device=image.device)
    std_t = torch.as_tensor(std, device=image.device)
    return (image * std_t) + mean_t



In [None]:
torch.cuda.empty_cache()

path='model_save/'

# Model inspected with Grad-CAM.
model = torch.load(path+'CNN_based_on_CAE_v2.pth')
model.eval()

grad_cam = Grad_CAM.GradCam(model)

# One shuffled test image per batch.
testloader = torch.utils.data.DataLoader(
    partition['test'],
    batch_size=1,
    shuffle=True,
    num_workers=2,
)

total_cam, total_pred, total_label, total_images = [], [], [], []
n_collected = 0
for images, labels in testloader:
    # Only samples whose label is 1 are kept.
    if labels==1:
        images = images.to('cuda')
        labels = labels.to('cuda')
        cam, pred = grad_cam(images)

        total_cam.append(cam)
        total_label.append(labels.cpu().detach().numpy())
        # Scores of the single sample in the batch, rounded to four decimals.
        scores = pred.cpu().detach().numpy()[0]
        total_pred.append([float(f"{score:.4f}") for score in scores])
        total_images.append(images)

        n_collected += 1
        # Stop as soon as twenty samples have been gathered.
        if n_collected==20:
            break



In [None]:
# Show each collected sample next to its Grad-CAM overlay. Iterating over
# what was actually collected (instead of a fixed 20) avoids an IndexError
# when the test loader yielded fewer matching samples.
for a in range(len(total_cam)):
    f=plt.figure(figsize=(10,5))
    plt.axis('off')

    # Upscale the CAM by a factor of 32 in both directions and convert it to
    # 8 bit. NOTE(review): assumes the upscaled map matches the input image
    # size and that the CAM values lie in [0, 1]; confirm against Grad_CAM.
    cam=np.uint8(255*cv2.resize(total_cam[a].cpu().detach().numpy(),None, fx=32,fy=32, interpolation=cv2.INTER_LINEAR))
    # NOTE(review): applyColorMap returns BGR while imshow interprets RGB;
    # presumably this inversion compensates for the swapped channels; confirm.
    cam=255-cam

    heatmap = cv2.applyColorMap(cam, cv2.COLORMAP_JET)
    # (C, H, W) -> (H, W, C) for matplotlib, then undo the normalisation.
    original=total_images[a][0].permute(1,2,0).cpu().detach()
    raw_image=inverse_nor(original)

    plt.title('real:'+str(total_label[a][0])+', pred:'+str(total_pred[a]))

    ax1=f.add_subplot(1, 2, 1)
    ax1.imshow(raw_image)
    ax1.set_title('input Image')
    ax1.axis('off')

    # Right panel: input image with the semi-transparent heatmap on top.
    ax2=f.add_subplot(1, 2, 2)
    ax2.imshow(raw_image)
    ax2.imshow(heatmap, cmap='jet',  alpha=0.4)
    ax2.set_title('result Image')
    ax2.axis('off')

    f.colorbar(plt.cm.ScalarMappable(cmap='jet'), ax=[ax1, ax2],location='bottom',orientation='horizontal')
    