In [None]:
### 필요한 라이브러리 준비

import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import cv2
import copy

from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from ipywidgets import interact

In [None]:
# Root directory of the test split; each emotion has its own sub-folder under it.
test_data_dir = './DataSet/archive/test'

# The 7 emotion classes used for classification. A predicted class id is used
# directly as an index into this list, so the order matters.
feelings_list = 'angry disgust fear happy neutral sad surprise'.split()

In [None]:
# NOTE(review): presumably the native side length of the dataset images;
# it is not used by the code visible in this notebook.
IMAGE_SIZE = 48

### Helper that collects image file paths into a list
def list_image_file(data_dir, sub_dir):
    """Return the files in ``data_dir/sub_dir`` as paths relative to ``data_dir``.

    Each entry is ``os.path.join(sub_dir, name)``, so it can later be joined
    with ``data_dir`` again to obtain the full path.

    The listing is sorted because ``os.listdir`` returns entries in an
    arbitrary, filesystem-dependent order. Without sorting, the same index
    would refer to different images on different machines or runs.

    Args:
        data_dir: root directory (e.g. the test split).
        sub_dir: class sub-folder name, e.g. ``'angry'``.

    Returns:
        Sorted list of relative file paths (empty if the folder is empty).
    """
    images_dir = os.path.join(data_dir, sub_dir)
    return [os.path.join(sub_dir, name) for name in sorted(os.listdir(images_dir))]

In [None]:
# One list of relative image paths per emotion sub-folder of the test split,
# built in the order angry, disgust, fear, happy, neutral, sad, surprise.
(test_angry_imgs, test_disgust_imgs, test_fear_imgs, test_happy_imgs,
 test_neutral_imgs, test_sad_imgs, test_surprise_imgs) = [
    list_image_file(test_data_dir, emotion_dir)
    for emotion_dir in ('angry', 'disgust', 'fear', 'happy', 'neutral', 'sad', 'surprise')
]

In [None]:
### Test image preprocessing

def preprocess_image(image):
    """Turn a single image into a normalized one-image batch tensor.

    Steps: ``ToTensor`` (image -> (C, H, W) float tensor), resize to 224x224,
    then normalize every channel with mean 0.5 / std 0.5.

    Args:
        image: anything ``transforms.ToTensor`` accepts; in this notebook it
            is the RGB array returned by ``get_RGB_image``.

    Returns:
        Tensor of shape (1, C, 224, 224).
    """
    pipeline = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Resize((224, 224)),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ]
    )
    # Prepend the batch dimension: (C, H, W) -> (1, C, H, W).
    return pipeline(image).unsqueeze(0)

In [None]:
### Prediction helper

def model_predict(image, model, verbose=False):
    """Predict the class index for a single image.

    Args:
        image: image passed to ``preprocess_image`` (here: an RGB array).
        model: classifier whose output for a one-image batch is assumed to be
            (1, num_classes) scores — TODO confirm; the argmax below is taken
            over dim 1.
        verbose: if True, print the intermediate label tensors (the former
            always-on debug output). Defaults to False.

    Returns:
        int: index of the highest-scoring class (an index into ``feelings_list``).
    """
    tensor_image = preprocess_image(image)  # one-image batch (1, C, H, W)

    # Inference only: no_grad skips building the autograd graph, so activations
    # are not retained for a backward pass that never happens.
    with torch.no_grad():
        prediction = model(tensor_image)

    # dim=1 is the class axis: take the index of the max score for each batch item.
    _, pred_labels = torch.max(prediction, dim=1)
    if verbose:
        print('pred_label1: ', pred_labels)

    # Remove the batch dimension: shape (1,) -> 0-d tensor.
    pred_label = pred_labels.squeeze(0)
    if verbose:
        print('pred_label2: ', pred_label)

    return pred_label.item()

In [None]:
### Load the trained model

# The file name suggests the checkpoint was saved from a GPU run; without
# map_location, torch.load tries to restore CUDA tensors and fails on machines
# without a GPU. load_state_dict then copies the weights into the model's own
# parameters.
ckpt = torch.load('./best_model/model22_gpu.pth', map_location='cpu')

# NOTE(review): build_vgg19_based_model is not defined in this notebook as
# shown, so "Restart & Run All" fails here. Define or import it; it must build
# the same architecture the checkpoint was trained with.
model = build_vgg19_based_model()
model.load_state_dict(ckpt)
# Switch to inference mode; the trailing ';' suppresses the very long model repr.
model.eval();

In [None]:
### Load an image file as an RGB 3-D array

def get_RGB_image(data_dir, file_name):
    """Read ``data_dir/file_name`` with OpenCV and return it in RGB channel order.

    Args:
        data_dir: base directory.
        file_name: path relative to ``data_dir`` (e.g. ``'angry/<file>'``).

    Returns:
        The image as an array in RGB order (OpenCV decodes to BGR).

    Raises:
        FileNotFoundError: if OpenCV cannot read the file. ``cv2.imread`` does
            not raise on failure; it returns None, which would otherwise
            surface later as an obscure ``cv2.cvtColor`` error.
    """
    image_file = os.path.join(data_dir, file_name)
    image = cv2.imread(image_file)  # BGR array, or None if unreadable
    if image is None:
        raise FileNotFoundError(f'Could not read image file: {image_file}')
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # BGR -> RGB

In [None]:
# Limit the slider to the smallest class size so every index is valid for all
# seven emotion lists.
min_num_files = min(
    len(image_list)
    for image_list in (test_angry_imgs, test_disgust_imgs, test_fear_imgs, test_happy_imgs,
                       test_neutral_imgs, test_sad_imgs, test_surprise_imgs)
)
# An empty class folder would produce the invalid slider range (0, -1) and an
# IndexError later; fail here with a clear message instead.
if min_num_files == 0:
    raise ValueError(f'At least one emotion folder under {test_data_dir!r} contains no images.')

In [None]:
### Visualize prediction results

@interact(index = (0, min_num_files - 1))
def show_result(index = 0):
    """Show the ``index``-th test image of every emotion with the model's prediction.

    Each panel is titled ``Pred: <predicted> | GT: <true emotion>``.

    Args:
        index: position within each emotion's image list; the slider keeps it
            in ``[0, min_num_files - 1]`` so it is valid for every class.
    """
    # Image lists in the same order as feelings_list, so zip pairs each list
    # with its ground-truth emotion name.
    image_lists = [test_angry_imgs, test_disgust_imgs, test_fear_imgs, test_happy_imgs,
                   test_neutral_imgs, test_sad_imgs, test_surprise_imgs]

    # Seven panels need one consistent 2x4 grid; mixing 1x4 and 2x4 subplot
    # specs makes second-row axes overlap the first-row ones.
    fig, axes = plt.subplots(2, 4, figsize=(21, 15))
    flat_axes = axes.ravel()

    for ax, feeling, image_list in zip(flat_axes, feelings_list, image_lists):
        image = get_RGB_image(test_data_dir, image_list[index])
        prediction = model_predict(image, model)
        ax.set_title(f'Pred: {feelings_list[prediction]} | GT: {feeling.capitalize()}')
        ax.imshow(image)

    # Hide the grid cells that have no image (the 8th one).
    for ax in flat_axes[len(image_lists):]:
        ax.axis('off')

    fig.tight_layout()
    # Show explicitly so the figure renders inside the interact output area.
    plt.show()