In [1]:
### 필요한 라이브러리 준비

import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import cv2
import copy

from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from ipywidgets import interact

In [2]:
# Root directory of the test split; the helpers below join it with per-emotion sub-directory names.
test_data_dir = './DataSet/archive/test'

# Class names for the 7 emotions; indexed with the model's predicted label when building plot titles.
# NOTE(review): the order presumably has to match the label indices used at training time - confirm.
feelings_list = ['angry', 'disgust', 'fear', 'happy', 'neutral', 'sad', 'surprise']

In [3]:
# NOTE(review): not referenced by any other cell shown here; preprocess_image resizes to 224x224 instead.
IMAGE_SIZE = 48

### Helper that collects image file paths into a list
def list_image_file(data_dir,sub_dir):
    """Return the files under ``data_dir/sub_dir`` as paths relative to ``data_dir``.

    Args:
        data_dir: Root directory of the dataset split.
        sub_dir: Name of the class sub-directory inside ``data_dir``.

    Returns:
        A list of ``os.path.join(sub_dir, file_name)`` strings, sorted by file
        name. Only regular files are included; nested directories are skipped.

    Raises:
        OSError: If ``data_dir/sub_dir`` cannot be listed (propagated from
            ``os.listdir``).
    """
    images_dir = os.path.join(data_dir,sub_dir)

    # os.listdir returns entries in arbitrary order; sorting makes a given
    # index refer to the same image on every run and every machine.
    return [
        os.path.join(sub_dir,file_name)
        for file_name in sorted(os.listdir(images_dir))
        # A nested directory is not an image and would fail later when read.
        if os.path.isfile(os.path.join(images_dir,file_name))
    ]

In [4]:
# Collect the relative image paths of every emotion class in the test split.
# The target names and the folder names are listed in the same order.
(
    test_angry_imgs,
    test_disgust_imgs,
    test_fear_imgs,
    test_happy_imgs,
    test_neutral_imgs,
    test_sad_imgs,
    test_surprise_imgs,
) = (
    list_image_file(test_data_dir,emotion_dir)
    for emotion_dir in ('angry', 'disgust', 'fear', 'happy', 'neutral', 'sad', 'surprise')
)

In [5]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [6]:
### Test image preprocessing

def preprocess_image(image):
    """Turn one image into a normalized single-image batch on ``device``.

    The image is converted to a tensor, resized to 224x224, normalized with
    mean 0.5 / std 0.5 on each of three channels, and given a leading batch
    axis.

    Args:
        image: Image accepted by ``transforms.ToTensor``; callers in this
            notebook pass the RGB array returned by ``get_RGB_image``.

    Returns:
        Tensor of shape (1, C, 224, 224) moved to ``device``.
    """
    to_model_input = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((224, 224)),
        transforms.Normalize(mean = [0.5, 0.5, 0.5], std = [0.5, 0.5, 0.5]),
    ])

    # (C,H,W) -> (1,C,H,W): the model expects a batch axis in front.
    batched = to_model_input(image).unsqueeze(0)
    return batched.to(device)

In [7]:
### Prediction helper

def model_predict(image,model):
    """Predict the class index for a single image.

    Args:
        image: Image accepted by ``preprocess_image``.
        model: Callable that maps the preprocessed (1,C,H,W) batch to
            per-class scores with the class axis at dim 1.

    Returns:
        Index of the highest-scoring class as a Python ``int``.
    """
    tensor_image = preprocess_image(image) # (1,C,H,W) input batch

    # Inference only: disable autograd so no graph is built or kept in memory.
    with torch.no_grad():
        prediction = model(tensor_image)

    # dim = 1 is the class axis, so this yields one label per batch entry: shape (1,)
    _, batch_labels = torch.max(prediction,dim = 1)

    # squeeze(0) removes the batch axis (it reduces the rank, it does not add
    # one), leaving a 0-d tensor that .item() converts to a Python int.
    return batch_labels.squeeze(0).item()

In [8]:
### Model factory
# Loads the stock VGG19 and replaces its head

def build_vgg19_based_model(pretrained = True):
    """Build a VGG19 whose pooling and classifier are replaced for 7 classes.

    Args:
        pretrained: Forwarded to ``models.vgg19``. The default ``True`` keeps
            the previous behavior (start from the pretrained weights); pass
            ``False`` to skip them when a checkpoint is loaded right after.

    Returns:
        The model, moved to the first CUDA GPU if available, else the CPU.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = models.vgg19(pretrained = pretrained)

    # Average every feature map down to 1x1 so the head receives 512 values per image.
    model.avgpool = nn.AdaptiveAvgPool2d(output_size = (1,1))
    model.classifier = nn.Sequential(
        nn.Flatten(), # (B,512,1,1) -> (B,512)
        nn.Linear(512,256), # 512 -> 256
        nn.ReLU(),
        nn.Dropout(0.1), # regularization against overfitting
        nn.Linear(256,7), # 256 -> 7 (one output per emotion class)
        # Explicit class axis: the implicit-dim form is deprecated and warns.
        # For the 2-D (B,7) input here, dim = 1 is what the implicit form chose.
        nn.Softmax(dim = 1)
    )

    return model.to(device)

In [9]:
### Load the trained model

# map_location remaps tensors that were saved from a GPU, so the checkpoint
# also loads on a CPU-only machine (device falls back to "cpu" in that case).
ckpt = torch.load('./best_model/model_57.pth', map_location = device)

model = build_vgg19_based_model()
model.load_state_dict(ckpt)
# Switch layers such as Dropout to inference behavior; as the last expression
# of the cell, the returned module is also what the cell displays.
model.eval()



VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padd

In [10]:
### Load an image file as a 3-channel RGB array

def get_RGB_image(data_dir,file_name):
    """Read ``data_dir/file_name`` with OpenCV and return it in RGB order.

    Args:
        data_dir: Directory the path is relative to.
        file_name: Path of the image file relative to ``data_dir``.

    Returns:
        The decoded image converted from OpenCV's BGR order to RGB.

    Raises:
        FileNotFoundError: If OpenCV could not read the file (missing path,
            unreadable file, or unsupported/corrupt image data).
    """
    image_file = os.path.join(data_dir,file_name) # full path to the image
    image = cv2.imread(image_file)

    # cv2.imread reports failure by returning None instead of raising; fail
    # here with the offending path rather than with an opaque cvtColor error.
    if image is None:
        raise FileNotFoundError(f'Could not read image file: {image_file}')

    return cv2.cvtColor(image,cv2.COLOR_BGR2RGB) # BGR -> RGB

In [11]:
# Size of the smallest emotion class; used as the upper bound of the browsable index range.
min_num_files = min(map(len, (
    test_angry_imgs,
    test_disgust_imgs,
    test_fear_imgs,
    test_happy_imgs,
    test_neutral_imgs,
    test_sad_imgs,
    test_surprise_imgs,
)))

In [12]:
### Visualize prediction results

@interact(index = (0, min_num_files - 1))
def show_result(index = 0):
    """Show the ``index``-th test image of every emotion with its prediction.

    One image per emotion class is loaded, classified with the global
    ``model``, and drawn in a single row of seven panels. Each panel title
    shows the predicted class name next to the ground-truth class.

    Args:
        index: Position within each per-class file list; the slider limits it
            to ``0 .. min_num_files - 1`` so it is valid for every class.
    """
    # (ground-truth label for the title, file list of that class), left to right.
    samples = [
        ('Angry', test_angry_imgs),
        ('Disgust', test_disgust_imgs),
        ('Fear', test_fear_imgs),
        ('Happy', test_happy_imgs),
        ('Neutral', test_neutral_imgs),
        ('Sad', test_sad_imgs),
        ('Surprise', test_surprise_imgs),
    ]

    # Load and classify everything first so a failure leaves no half-drawn figure.
    results = []
    for gt_label, image_files in samples:
        image = get_RGB_image(test_data_dir, image_files[index])
        results.append((gt_label, image, model_predict(image, model)))

    # One row, one panel per emotion class.
    plt.figure(figsize=(21, 15))
    for position, (gt_label, image, predicted) in enumerate(results, start = 1):
        plt.subplot(1, len(results), position)
        plt.title(f'Pred: {feelings_list[predicted]} | GT: {gt_label}')
        plt.imshow(image)
    plt.tight_layout()

interactive(children=(IntSlider(value=0, description='index', max=109), Output()), _dom_classes=('widget-inter…