In [None]:
from __future__ import print_function, division
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler

import numpy as np
import torchvision
from torchvision import datasets,models, transforms
from torchvision.utils import save_image
import matplotlib.pyplot as plt
import time
import os
import copy

import cv2

plt.ion()

In [None]:
# Input images are resized with transforms.Resize(RESIZE_SIZE) below.
RESIZE_SIZE = 128
# Number of classes to classify (size of the model's output layer).
CLASS_NUM = 8
EPOCH = 25

# Dataset root. Alternative roots:
#   './data/RaFD+PE'
#   './data/RaFD+PE+WEB'
data_dir = './data/ALL'

# Per-source roots; each is expected to hold 'train' / 'val' / 'test' sub-folders.
PE_data_dir = '{}/PE'.format(data_dir)
RaFD_data_dir = '{}/RaFD'.format(data_dir)
GAN_data_dir = '{}/GAN'.format(data_dir)

# Output directory for the normalized tensor images written by save_tensor_image.
save_dir = "./result/tensor_image/normalization/grayscale/"

# Load Dataset
## Dataset Transform

In [None]:
# Per-channel normalization statistics used by every transforms.Normalize in this notebook.
# NOTE(review): despite the name, these are the ImageNet statistics that torchvision's
# pretrained models expect, not values computed from RaFD — confirm this is intended.
RaFD_MEAN, RaFD_STD = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

In [None]:
def _phase_transforms():
    """Build a fresh ``{'train', 'val', 'test'} -> transforms.Compose`` dict.

    Every phase resizes with ``RESIZE_SIZE``, converts to a tensor and normalizes with
    ``RaFD_MEAN`` / ``RaFD_STD``. Only ``'train'`` adds a random horizontal flip for
    augmentation.

    NOTE(review): ``transforms.Resize`` with a single int scales the shorter edge and keeps
    the aspect ratio, so sources with different aspect ratios produce differently-sized
    tensors that cannot be stacked into one batch — confirm all datasets share one ratio.
    """
    def pipeline(augment):
        steps = [transforms.Resize(RESIZE_SIZE)]
        if augment:
            steps.append(transforms.RandomHorizontalFlip())
        steps += [transforms.ToTensor(), transforms.Normalize(RaFD_MEAN, RaFD_STD)]
        return transforms.Compose(steps)

    return {'train': pipeline(True), 'val': pipeline(False), 'test': pipeline(False)}


# One independent dict per data source (RaFD, PE, GAN); they are currently identical.
RaFD_transforms = _phase_transforms()
PE_transforms = _phase_transforms()
GAN_transforms = _phase_transforms()


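### Without GAN data (RaFD + PE)

Run **either** this section **or** the "With GAN data" section below, not both. Both define the same variables (`image_datasets`, `dataloaders`, `dataset_sizes`, `class_names`, …), so whichever runs last wins. The GAN accuracy cells further down need `gan_dataloaders`, which only the "With GAN data" section defines.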

In [None]:
# Combined dataset without GAN images: RaFD + PE, concatenated per split.
rafd_datasets = {x: datasets.ImageFolder(os.path.join(RaFD_data_dir, x), RaFD_transforms[x])
                 for x in ['train', 'val', 'test']}
pe_datasets = {x: datasets.ImageFolder(os.path.join(PE_data_dir, x), PE_transforms[x])
               for x in ['train', 'val', 'test']}

class_names = rafd_datasets['train'].classes

# ImageFolder assigns label indices from its own sorted list of class sub-folders.
# Concatenating sources and sharing one class_names is only valid if every split of
# every source has exactly the same class list; otherwise labels silently mismatch.
for _phase in ['train', 'val', 'test']:
    for _source, _dataset in (('RaFD', rafd_datasets[_phase]), ('PE', pe_datasets[_phase])):
        assert _dataset.classes == class_names, '{}/{} classes {} != RaFD/train classes {}'.format(
            _source, _phase, _dataset.classes, class_names)
# The model head and the per-class accuracy reports are sized by CLASS_NUM.
assert len(class_names) == CLASS_NUM, 'found {} classes but CLASS_NUM = {}'.format(
    len(class_names), CLASS_NUM)

image_datasets = {x: rafd_datasets[x] + pe_datasets[x] for x in ['train', 'val', 'test']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                             shuffle=True, num_workers=4)
               for x in ['train', 'val', 'test']}

dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val', 'test']}

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Per-source loaders with the same settings as the combined `dataloaders`,
# used by the per-dataset accuracy reports.
rafd_dataloaders = {
    phase: torch.utils.data.DataLoader(
        rafd_datasets[phase], batch_size=4, shuffle=True, num_workers=4)
    for phase in ('train', 'val', 'test')
}
pe_dataloaders = {
    phase: torch.utils.data.DataLoader(
        pe_datasets[phase], batch_size=4, shuffle=True, num_workers=4)
    for phase in ('train', 'val', 'test')
}


### With GAN data (RaFD + PE + GAN)

In [None]:
# Combined dataset with GAN images: RaFD + PE + GAN, concatenated per split.
rafd_datasets = {x: datasets.ImageFolder(os.path.join(RaFD_data_dir, x), RaFD_transforms[x])
                 for x in ['train', 'val', 'test']}
pe_datasets = {x: datasets.ImageFolder(os.path.join(PE_data_dir, x), PE_transforms[x])
               for x in ['train', 'val', 'test']}
gan_datasets = {x: datasets.ImageFolder(os.path.join(GAN_data_dir, x), GAN_transforms[x])
                for x in ['train', 'val', 'test']}

class_names = rafd_datasets['train'].classes

# ImageFolder assigns label indices from its own sorted list of class sub-folders.
# Concatenating sources and sharing one class_names is only valid if every split of
# every source has exactly the same class list; otherwise labels silently mismatch.
for _phase in ['train', 'val', 'test']:
    for _source, _dataset in (('RaFD', rafd_datasets[_phase]),
                              ('PE', pe_datasets[_phase]),
                              ('GAN', gan_datasets[_phase])):
        assert _dataset.classes == class_names, '{}/{} classes {} != RaFD/train classes {}'.format(
            _source, _phase, _dataset.classes, class_names)
# The model head and the per-class accuracy reports are sized by CLASS_NUM.
assert len(class_names) == CLASS_NUM, 'found {} classes but CLASS_NUM = {}'.format(
    len(class_names), CLASS_NUM)

image_datasets = {x: rafd_datasets[x] + pe_datasets[x] + gan_datasets[x] for x in ['train', 'val', 'test']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                             shuffle=True, num_workers=4)
               for x in ['train', 'val', 'test']}

dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val', 'test']}

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Per-source loaders with the same settings as the combined `dataloaders`,
# used by the per-dataset accuracy reports (including the GAN one).
rafd_dataloaders = {
    phase: torch.utils.data.DataLoader(
        rafd_datasets[phase], batch_size=4, shuffle=True, num_workers=4)
    for phase in ('train', 'val', 'test')
}
pe_dataloaders = {
    phase: torch.utils.data.DataLoader(
        pe_datasets[phase], batch_size=4, shuffle=True, num_workers=4)
    for phase in ('train', 'val', 'test')
}
gan_dataloaders = {
    phase: torch.utils.data.DataLoader(
        gan_datasets[phase], batch_size=4, shuffle=True, num_workers=4)
    for phase in ('train', 'val', 'test')
}

In [None]:
dataset_sizes  # number of samples per split ('train' / 'val' / 'test') in the combined dataset

## Visualize datasets

In [None]:
def imshow_tensor(inp, title = None, mean = None, std = None) :
    """Display a channels-first image tensor with matplotlib.

    Args:
        inp: 3-D tensor laid out as (C, H, W); it is transposed to (H, W, C) for
            ``plt.imshow``. It may live on any device and may require grad.
        title: optional figure title (anything ``plt.title`` accepts).
        mean, std: optional per-channel statistics that were passed to
            ``transforms.Normalize``. When given, the normalization is undone
            (``x * std + mean``) so colors look natural. When omitted, the
            normalized values are shown as-is (the original behavior).

    Values are clipped to [0, 1] before display.
    """
    img = inp.detach().cpu().numpy().transpose((1, 2, 0))
    # Undo Normalize; the per-channel stats broadcast over the trailing channel axis.
    if std is not None:
        img = img * np.asarray(std)
    if mean is not None:
        img = img + np.asarray(mean)
    img = np.clip(img, 0, 1)
    plt.imshow(img)

    if title is not None :
        plt.title(title)
    # Brief pause so the figure is drawn/updated before execution continues.
    plt.pause(0.0001)

In [None]:
## Visualize one (shuffled) training batch as an image grid titled with its labels.
## The tensors are shown normalized and clipped to [0, 1], so colors differ from the originals.
inputs, classes = next(iter(dataloaders['train']))
out = torchvision.utils.make_grid(inputs)
batch_titles = [class_names[int(label)] for label in classes]
imshow_tensor(out, title = batch_titles)

### Save Tensor Image

In [None]:
def save_tensor_image(_phase = 'train') :
    """Write every image of ``dataloaders[_phase]`` to ``save_dir/<_phase>/`` as JPEG files.

    Files are named ``<_phase>_<index>.jpg``. The index counts images in loader iteration
    order and is zero-padded to 3 digits. The loaders in this notebook are built with
    ``shuffle=True``, so the numbering does not follow dataset order. The saved images
    are the transformed (normalized) tensors.

    Args:
        _phase: key into the notebook-level ``dataloaders`` dict (e.g. 'train').
    """
    out_dir = os.path.join(save_dir, _phase)
    # Create the target directory up front; saving into a missing directory fails.
    os.makedirs(out_dir, exist_ok=True)

    num_images = 0
    with torch.no_grad():
        for inputs, _labels in dataloaders[_phase]:
            # Saving happens on the CPU, so the batch is never moved to `device`.
            for image in inputs:
                file_name = _phase + '_' + str(num_images).zfill(3) + '.jpg'
                save_image(image, os.path.join(out_dir, file_name))
                num_images += 1

In [None]:
# Dump the normalized tensors of every split to disk, in train / test / val order.
for _split in ('train', 'test', 'val'):
    save_tensor_image(_split)

# Train Model

In [None]:
def train_model(model, criterion, optimizer, scheduler, num_epochs = 25,
                loaders = None, sizes = None, target_device = None) :
    """Train ``model`` with a validation pass after every epoch; return it with its best weights.

    Args:
        model: ``nn.Module`` to train. It is modified in place and also returned.
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: learning-rate scheduler, stepped exactly once per epoch after the
            training phase (so e.g. StepLR's ``step_size`` counts epochs).
        num_epochs: number of train/val epochs.
        loaders: dict with 'train' and 'val' iterables of ``(inputs, labels)`` batches.
            Defaults to the notebook-level ``dataloaders``.
        sizes: dict mapping 'train' / 'val' to the number of samples, used to average
            loss and accuracy. Defaults to the notebook-level ``dataset_sizes``.
        target_device: device batches are moved to. Defaults to the notebook-level ``device``.

    Returns:
        ``model`` with the weights of the epoch with the highest validation accuracy loaded.
        The model is left in eval mode (the last phase run is validation).
    """
    loaders = dataloaders if loaders is None else loaders
    sizes = dataset_sizes if sizes is None else sizes
    target_device = device if target_device is None else target_device

    since = time.time()

    # state_dict() returns references to the live tensors, so snapshot with deepcopy;
    # otherwise the "best" weights would silently track the current ones.
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs) :
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Alternate a training pass and a validation pass every epoch.
        for phase in ['train', 'val'] :
            is_train = phase == 'train'
            model.train(is_train)

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in loaders[phase]:
                inputs = inputs.to(target_device)
                labels = labels.to(target_device)

                optimizer.zero_grad()

                # Build the autograd graph only when we will backpropagate.
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    preds = outputs.argmax(dim=1)
                    loss = criterion(outputs, labels)

                    if is_train :
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * inputs.size(0)
                running_corrects += (preds == labels).sum().item()

            if is_train :
                # Exactly one scheduler step per epoch, after the optimizer updates.
                scheduler.step()

            epoch_loss = running_loss / sizes[phase]
            epoch_acc = running_corrects / sizes[phase]

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                 phase, epoch_loss, epoch_acc))

            # Keep a copy of the weights from the best validation epoch so far.
            if phase == 'val' and epoch_acc > best_acc :
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print('Training Complete in {:.0f}m {:.0f}s'.format(
            time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:.4f}'.format(best_acc))

    # Load the best model weights.
    model.load_state_dict(best_model_wts)
    return model

## Finetuning ConvNet
### Load Pretrained model

In [None]:
# Loss used for training.
criterion = nn.CrossEntropyLoss()

# ImageNet-pretrained ResNet-18 with its final layer replaced by a fresh CLASS_NUM-way head.
# CLASS_NUM is expected to equal len(class_names).
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favor of `weights=`.
model_ft = models.resnet18(pretrained = True)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(in_features=num_ftrs, out_features=CLASS_NUM)
model_ft = model_ft.to(device)

# SGD with momentum; StepLR multiplies the learning rate by 0.1 every 7 scheduler steps.
optimizer_ft = optim.SGD(params=model_ft.parameters(), lr=0.001, momentum=0.9)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer=optimizer_ft, step_size=7, gamma=0.1)

### Train and Evaluate

In [None]:
# Train and keep the weights from the best validation epoch.
model_ft = train_model(model=model_ft, criterion=criterion, optimizer=optimizer_ft,
                       scheduler=exp_lr_scheduler, num_epochs=EPOCH)

# Visualize Result
## Visualize Result with accuracy

In [None]:
## Print overall and per-class accuracy for each class / dataset.

def visualize_accuracy(model, DATALOADERS, _phase='val') :
    """Print the overall and per-class accuracy of ``model`` on ``DATALOADERS[_phase]``.

    Args:
        model: classifier; the predicted class is the argmax over dim 1 of its output.
        DATALOADERS: dict of loaders keyed by split name, each yielding ``(inputs, labels)``.
        _phase: which split to evaluate (e.g. 'val' or 'test').

    Labels index the notebook-level ``class_names`` and must lie in ``range(CLASS_NUM)``.
    Per-class accuracy is the fraction of samples *labelled* with that class that were
    predicted correctly. The model's train/eval mode is restored on exit, even if
    evaluation raises. Prints only; returns ``None``.
    """
    was_training = model.training
    model.eval()
    total = 0
    correct = 0
    class_correct = [0] * CLASS_NUM
    class_total = [0] * CLASS_NUM

    try:
        with torch.no_grad():
            for inputs, labels in DATALOADERS[_phase]:
                # Compare as Python ints on the CPU instead of element-wise tensor checks.
                preds = model(inputs.to(device)).argmax(dim=1).cpu()
                for pred, label in zip(preds.tolist(), labels.tolist()):
                    total += 1
                    class_total[label] += 1
                    if pred == label:
                        correct += 1
                        class_correct[label] += 1
    finally:
        model.train(mode = was_training)

    # An empty split would otherwise raise ZeroDivisionError.
    if total == 0:
        print(">> Total Accuracy : n/a (no samples)")
    else:
        print(">> Total Accuracy : {:.4f}".format(correct / total * 100.0))

    print(">> CLASS Accuracy")
    for i in range(CLASS_NUM) :
        class_acc = class_correct[i] / class_total[i] * 100 if class_total[i] else 0.0
        print("- {} : {:.4f} ({} outof {})".format(
            class_names[i], class_acc, class_correct[i], class_total[i]))

### VAL ACCURACY

In [None]:
# Validation accuracy per data source, then on the combined validation set.
# NOTE: requires `gan_dataloaders`, which only the "with gan data" cells define; after
# running only the "without gan data" cells this raises NameError at the GAN section.
print("SEPARATE VAL ACCURACY\n")

print("RaFD ACCURACY")
visualize_accuracy(model_ft, rafd_dataloaders, 'val')

print("\nPE ACCURACY")
visualize_accuracy(model_ft, pe_dataloaders, 'val')

print("\nGAN ACCURACY")
visualize_accuracy(model_ft, gan_dataloaders, 'val')

print("\nTOTAL ACCURACY")
visualize_accuracy(model_ft, dataloaders, 'val')

### TEST ACCURACY

In [None]:
# Test accuracy per data source, then on the combined test set.
# NOTE: requires `gan_dataloaders`, which only the "with gan data" cells define; after
# running only the "without gan data" cells this raises NameError at the GAN section.
print("SEPARATE TEST ACCURACY\n")

print("RaFD ACCURACY")
visualize_accuracy(model_ft, rafd_dataloaders, 'test')

print("\nPE ACCURACY")
visualize_accuracy(model_ft, pe_dataloaders, 'test')

print("\nGAN ACCURACY")
visualize_accuracy(model_ft, gan_dataloaders, 'test')

print("\nTOTAL ACCURACY")
visualize_accuracy(model_ft, dataloaders, 'test')

## Visualize Model with Image

In [None]:
## Show each test image with its classification result.
def visualize_model(model, num_images = 6, _phase='val') :
    """Display images of ``dataloaders[_phase]`` with their predictions and print accuracy.

    Args:
        model: classifier; the predicted class is the argmax over dim 1 of its output.
        num_images: maximum number of images to display. ``<= 0`` displays every image.
            Accuracy statistics always cover the whole split, regardless of this limit.
        _phase: key into the notebook-level ``dataloaders`` dict (e.g. 'val' or 'test').

    Correct predictions are titled "predicted: X". Wrong ones are titled
    "predicted: X | answer: Y". Labels index the notebook-level ``class_names`` and must
    lie in ``range(CLASS_NUM)``. The model's train/eval mode is restored on exit.
    """
    was_training = model.training
    model.eval()
    images_so_far = 0
    correct = 0
    class_correct = [0] * CLASS_NUM
    class_total = [0] * CLASS_NUM
    plt.figure()

    try:
        with torch.no_grad():
            for inputs, labels in dataloaders[_phase]:
                preds = model(inputs.to(device)).argmax(dim=1).cpu()
                for image, pred, label in zip(inputs, preds.tolist(), labels.tolist()):
                    images_so_far += 1
                    is_correct = pred == label

                    # Accuracy statistics over every image in the split.
                    class_total[label] += 1
                    if is_correct:
                        correct += 1
                        class_correct[label] += 1

                    # Draw only the first num_images images (all of them when num_images <= 0).
                    if 0 < num_images < images_so_far:
                        continue

                    ax = plt.subplot(1, 1, 1)
                    ax.axis('off')
                    if is_correct :
                        ax.set_title('predicted: {}'.format(class_names[pred]))
                    else :
                        ax.set_title('predicted: {} | answer: {}'.format(
                                    class_names[pred], class_names[label]))
                    imshow_tensor(image)
    finally:
        model.train(mode = was_training)

    # Guard against an empty split (ZeroDivisionError otherwise).
    if images_so_far == 0:
        print(">> Total Accuracy : n/a (no samples)")
    else:
        print(">> Total Accuracy : {:.4f}".format(correct / images_so_far * 100.0))

    print(">> CLASS Accuracy")
    for i in range(CLASS_NUM) :
        # Classes with no samples report 0.0 instead of dividing by zero.
        class_acc = class_correct[i] / class_total[i] * 100 if class_total[i] else 0.0
        print("- {} : {:.4f} ({} outof {})".format(
            class_names[i], class_acc, class_correct[i], class_total[i]))

In [None]:
visualize_model(model_ft, num_images = -1, _phase = 'val')  # show every validation image

In [None]:
visualize_model(model_ft, num_images = -1, _phase = 'test')  # show every test image