In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import numpy as np
import pandas as pd
import cv2
import json
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from PIL import Image, ImageOps

from models import CAE_v1 
from models import CAE_v2
from models import Classifier_v1
from models import Classifier_v2
import method as MM

import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, accuracy_score, ConfusionMatrixDisplay, classification_report

# Seaborn base look: white background with the "talk" scaling context.
sns.set_theme(context="talk", style="white")

# Matplotlib text sizes: default font, axis labels, and x/y tick labels all set to 12.
plt.rcParams.update({
    'font.size': 12,
    'axes.labelsize': 12,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
})

# Select the CUDA device and report which GPU it refers to.
device = torch.device("cuda")
print(torch.cuda.get_device_name(device))
print(f"Using device: {device}")

In [None]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import Subset,ConcatDataset, DataLoader


# Per-channel normalization statistics (labelled by the author as the
# validation-set mean and std).
mean = [0.7101, 0.4827, 0.3970]
std = [0.2351, 0.2195, 0.1862]

# Plain pipeline: tensor conversion followed by normalization.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)]
)

# Augmenting pipeline: random flips and colour perturbations are applied
# before the same tensor conversion and normalization.
ag_transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),  # left-right flip
        transforms.RandomVerticalFlip(),  # up-down flip
        # brightness, contrast, saturation and hue jitter
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.RandomGrayscale(p=0.75),  # random grayscale conversion
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
)

# Root folders of the three dataset splits.
path_train = 'output_dataset/train'
path_val = 'output_dataset/val'
path_test = 'output_dataset/test'

# The training folder is read twice -- once through each pipeline -- and the
# two views are concatenated into a single training set.
trainset_1 = datasets.ImageFolder(path_train, transform=transform)
print(trainset_1.class_to_idx)

trainset_2 = datasets.ImageFolder(path_train, transform=ag_transform)
print(trainset_2.class_to_idx)

trainset = ConcatDataset((trainset_1, trainset_2))

# Validation and test splits only use the plain pipeline.
valset = datasets.ImageFolder(path_val, transform=transform)
print(valset.class_to_idx)
testset = datasets.ImageFolder(path_test, transform=transform)
print(testset.class_to_idx)

# Split name -> dataset mapping handed to the training/evaluation helpers.
partition = dict(train=trainset, val=valset, test=testset)


In [None]:
# Shared accumulator that later cells pass to the training helpers.
results = list()



In [None]:
# Release cached GPU memory before the training run starts.
torch.cuda.empty_cache()

# Instantiate the CAE_v2 model and hand it to the project's training helper.
# The literal 50 is presumably the number of epochs -- confirm against
# method.CAE_train_eval.
CAE_v2_model = CAE_v2.CAE()
_cae_history = MM.CAE_train_eval(CAE_v2_model, 50, partition, results)

# The helper's return value unpacks into the training and validation loss series.
CAE_v2_train_losses, CAE_v2_val_losses = _cae_history


In [None]:
# Plot the CAE training and validation loss series side by side.
f = plt.figure(figsize=(10, 5))

# (label, series) pairs in left-to-right subplot order.
loss_curves = [
    ('train_losses', CAE_v2_train_losses),
    ('val_losses', CAE_v2_val_losses),
]

# The loop variables are deliberately not named `type`: a notebook cell runs
# at global scope, so binding `type` here would shadow the builtin for every
# cell executed afterwards.
for num, (curve_name, losses) in enumerate(loss_curves, start=1):
    ax = f.add_subplot(1, 2, num)
    ax.plot(losses, label=curve_name)

    ax.set_xlabel('Epoch')
    ax.set_ylabel('mse')
    ax.set_title(curve_name)
    ax.grid()
    ax.legend()

# Keep the loss histories together with the model name.
cae_results = [{
    'model': 'CAE_v2',
    'train_losses': CAE_v2_train_losses,
    'val_losses': CAE_v2_val_losses,
}]


In [None]:
import json
# Directory for saved models and result files. Note that the active
# torch.load call below does not use this variable; it spells the directory
# out in its own string literal.
path='model_save/'

# Disabled save step: when uncommented, these lines dump the CAE loss history
# to JSON and pickle the whole CAE model object to disk.
#with open(path+cae_results[0]['model']+'.json', 'w') as file:
#        json.dump(cae_results[0], file)
        
#torch.save(CAE_v2_model, 'model_save/cae_v2.pth')
# Rebind CAE_v2_model to the object stored in the checkpoint file, replacing
# whatever CAE_v2_model referred to before this cell ran.
# NOTE(review): the commented-out save above stores the full model object, not
# a state_dict. Newer PyTorch releases default torch.load to weights_only=True,
# which rejects such files -- confirm the installed version, or pass
# weights_only=False explicitly for this trusted local file.
CAE_v2_model = torch.load('model_save/cae_v2.pth')


In [None]:
# Build the v2 classifier around the CAE model and move it to the GPU.
CNN_based_on_CAE_v2 = Classifier_v2.CAEClassifier(CAE_v2_model).to("cuda")

# Train and evaluate through the project helper; `results` is passed in so the
# helper can record its output there (presumably -- confirm in method.py).
MM.model_train_eval(
    CNN_based_on_CAE_v2,
    'CNN_based_on_CAE_v2',
    50,
    partition,
    results,
)


In [None]:
# Pass the accumulated `results` list to the plotting helper from the project's
# `method` module (presumably draws loss and accuracy curves per entry --
# confirm in method.py).
MM.plot_loss_and_accuracy(results)


In [None]:
import json

# Directory for saved models and result files.
path = 'model_save/'


def _flatten_batches(batches):
    """Return a sequence of tensors as one flat Python list.

    A non-empty list/tuple whose items are all tensors is concatenated, moved
    to the CPU, flattened and converted with ``tolist()``. Any other value --
    in particular a list that an earlier run of this cell already converted,
    or an empty list -- is returned unchanged, so the conversion can be
    repeated safely.
    """
    is_tensor_batches = (
        isinstance(batches, (list, tuple))
        and len(batches) > 0
        and all(torch.is_tensor(batch) for batch in batches)
    )
    if is_tensor_batches:
        return torch.cat(batches).cpu().reshape(-1).tolist()
    return batches


# Persist the trained classifier, reload it from that file, and evaluate the
# reloaded copy in eval mode.
# NOTE(review): this pickles and unpickles the full model object. Newer PyTorch
# releases default torch.load to weights_only=True, which rejects such files --
# confirm the installed version.
torch.save(CNN_based_on_CAE_v2, path + 'CNN_based_on_CAE_v2.pth')
CNN_based_on_CAE_v2 = torch.load(path + 'CNN_based_on_CAE_v2.pth')
CNN_based_on_CAE_v2.eval()
test_acc, test_pred, test_real = MM.test(CNN_based_on_CAE_v2, partition)
print(test_acc)

# Write one JSON file per entry in `results`, named after the entry's 'model'.
# NOTE(review): the export reads 'test_pred'/'test_real' from each results
# entry; the test_pred/test_real returned by MM.test above are not used here --
# confirm that is intended.
for record in results:
    # Entries are still updated in place, as before, but the conversion is now
    # idempotent: re-running this cell no longer fails in torch.cat on values
    # that are already plain lists.
    record['test_pred'] = _flatten_batches(record['test_pred'])
    record['test_real'] = _flatten_batches(record['test_real'])

    with open(path + record['model'] + '.json', 'w') as file:
        json.dump(record, file)
        


