# Library Load

In [1]:
# 필요한 PyTorch 라이브러리 불러오기
import os
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.utils import save_image

import PIL
from PIL import Image
import matplotlib.pyplot as plt

import cv2
import numpy as np
import pandas as pd

import copy

from sklearn.cluster import KMeans
from sklearn.decomposition import PCA

In [2]:
# Fix NumPy's global RNG so any NumPy-based randomness below is reproducible.
np.random.seed(seed=42)

In [3]:
# Restrict this process to the first GPU. The variable is only honoured if it
# is set before CUDA is initialised (i.e. before the first torch.cuda call).
os.environ.update({'CUDA_VISIBLE_DEVICES': '0'})

In [4]:
# Select the compute device: the GPU when CUDA is available, else the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cuda')

# Data Load

In [5]:
# Directory holding the style images.
style_img_path = './style_img/'
# os.listdir returns entries in arbitrary, filesystem-dependent order. Later
# cells refer to images by position (e.g. index 112), so sort for a stable,
# reproducible order. Hidden entries (e.g. .ipynb_checkpoints, .DS_Store) and
# sub-directories are skipped because they cannot be opened as images.
style_img_list = sorted(
    entry
    for entry in os.listdir(style_img_path)
    if not entry.startswith('.')
    and os.path.isfile(os.path.join(style_img_path, entry))
)
len(style_img_list)

114

In [6]:
# Display the style image file names.
style_img_list

['b01.jpg',
 'b02.jpg',
 'b03.jpg',
 'b04.jpg',
 'b05.jpg',
 'b06.jpg',
 'b07.jpg',
 'b08.jpg',
 'b09.jpg',
 'b10.jpg',
 'b100.jpg',
 'b101.jpg',
 'b102.jpg',
 'b103.jpg',
 'b104.jpg',
 'b105.jpg',
 'b106.jpg',
 'b107.jpg',
 'b108.jpg',
 'b109.jpg',
 'b11.jpg',
 'b110.jpg',
 'b111.jpg',
 'b112.jpg',
 'b113.jpg',
 'b114.jpg',
 'b12.jpg',
 'b13.jpg',
 'b14.jpg',
 'b15.jpg',
 'b16.jpg',
 'b17.jpg',
 'b18.jpg',
 'b19.jpg',
 'b20.jpg',
 'b21고누놀이.jpg',
 'b22금강사군첩.jpg',
 'b23기와이기.jpg',
 'b24길쌈.jpg',
 'b25나들이.jpg',
 'b26나룻배.jpg',
 'b27논갈이.jpg',
 'b28늦은 밤 피리부는 선인.jpg',
 'b29담배썰기.jpg',
 'b30대장간.jpg',
 'b31벼타작.jpg',
 'b32병진년화첩-옥순봉도.jpg',
 'b33빨래터.jpg',
 'b34소림명월도.jpg',
 'b35송호도.jpg',
 'b36신선과 사슴.jpg',
 'b37신행길.jpg',
 'b38연광정연회도.jpg',
 'b39자리짜기.jpg',
 'b40장터길.jpg',
 'b41점심.jpg',
 'b42주막.jpg',
 'b43처용무.jpg',
 'b44편자박기.jpg',
 'b45하화청정도.jpg',
 'b46행상.jpg',
 'b47활쏘기.jpg',
 'b48황묘농접도.jpg',
 'b49.jpg',
 'b50.jpg',
 'b51.jpg',
 'b52.jpg',
 'b53.jpg',
 'b54.jpg',
 'b55.jpg',
 'b56.jpg',
 'b57.jpg',
 'b58.

# Gram Matrix Extraction

In [7]:
# Load an ImageNet-pretrained VGG19 and keep only its convolutional feature
# extractor (`.features`), moved to `device` and put in eval mode.
# NOTE: newer torchvision releases deprecate `pretrained=True` in favour of
# the `weights=` argument.
cnn = models.vgg19(pretrained=True).features
cnn = cnn.to(device)
cnn.eval()
print(cnn)

Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace=True)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): ReLU(inplace=True)
  (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): ReLU(inplace=True)
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): ReLU(inplace=True)
  (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (13): ReLU(inplace=True)
  (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): ReLU(inplace=True)
  (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (17): ReLU(inplace=True)
  (18): MaxPoo

In [8]:
# Load an image file and convert it to a batched torch.Tensor.
def image_loader(img_path, imsize):
    """Load an image file as a batched float tensor on ``device``.

    Args:
        img_path: Path of the image file to open.
        imsize: Target size forwarded to ``transforms.Resize`` (an int, or an
            ``(height, width)`` pair).

    Returns:
        A tensor of shape ``(1, 3, H, W)`` with values in ``[0, 1]``, on
        ``device`` with dtype ``torch.float``.
    """
    loader = transforms.Compose([
        transforms.Resize(imsize),  # resize the image
        transforms.ToTensor(),      # PIL image -> torch.Tensor, [0, 255] -> [0, 1]
    ])
    # `with` closes the file handle deterministically. convert('RGB') forces 3
    # channels: grayscale/RGBA/palette images would otherwise yield 1 or 4
    # channels, which does not match the 3-channel normalisation / VGG input.
    with PIL.Image.open(img_path) as raw:
        image = raw.convert('RGB')
    # Add the leading batch dimension expected by the network.
    image = loader(image).unsqueeze(0)
    return image.to(device, torch.float)  # move to the selected device


# Display an image stored as a torch.Tensor.
def imshow(tensor):
    """Show a ``(1, C, H, W)`` or ``(C, H, W)`` image tensor with matplotlib.

    Values are expected in ``[0, 1]``; anything outside that range is clamped
    before conversion. The tensor is detached first, so tensors that require
    grad (e.g. an image being optimised) can be shown as well. The caller's
    tensor is never modified.
    """
    # Detach (ToPILImage converts to numpy, which fails on tensors that track
    # gradients), move to CPU for matplotlib, and clone to protect the caller.
    image = tensor.detach().cpu().clone()
    # Drop the batch dimension (no-op when dim 0 is not of size 1).
    image = image.squeeze(0)
    # ToPILImage maps floats via x * 255 -> uint8 without saturating, so
    # out-of-range values would produce colour artefacts; clamp first.
    image = image.clamp(0, 1)
    image = transforms.ToPILImage()(image)
    plt.imshow(image)
    plt.show()

In [9]:
# Per-channel (RGB) mean / std expected by torchvision's pretrained models
# (the standard ImageNet statistics), created directly on `device`.
cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406], device=device)
cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225], device=device)

class Normalization(nn.Module):
    """Normalise an image tensor channel-wise: ``(img - mean) / std``.

    ``mean`` and ``std`` hold one value per channel; they are reshaped to
    ``(C, 1, 1)`` so they broadcast over ``(N, C, H, W)`` or ``(C, H, W)``
    inputs.
    """

    def __init__(self, mean, std):
        super().__init__()
        # Register as buffers rather than plain attributes so that `.to(...)`,
        # `.cuda()`, `.double()` etc. on this module - or on an enclosing
        # nn.Sequential - move/cast them along with the rest of the model.
        # They remain readable as `self.mean` / `self.std`.
        self.register_buffer('mean', torch.as_tensor(mean).clone().view(-1, 1, 1))
        self.register_buffer('std', torch.as_tensor(std).clone().view(-1, 1, 1))

    def forward(self, img):
        return (img - self.mean) / self.std

In [10]:
def gram_matrix(input):
    """Return the normalised Gram matrix of a 4-D feature-map tensor.

    ``input`` has shape ``(a, b, c, d)``: batch size, number of feature maps,
    and the two spatial dimensions of each map. Every (batch, feature-map)
    pair is flattened into one row, so the result is an ``(a*b, a*b)`` matrix
    of inner products, divided by the total element count ``a*b*c*d``.
    """
    batch, channels, height, width = input.size()
    # One row per feature map, one column per spatial position.
    rows = input.view(batch * channels, height * width)
    # All pairwise inner products in a single matrix multiplication.
    gram = rows @ rows.t()
    return gram / (batch * channels * height * width)

In [11]:
# Rebind `cnn` to a deep copy; the layers registered into `model` below come
# from this copy.
cnn = copy.deepcopy(cnn)
normalization = Normalization(cnn_normalization_mean, cnn_normalization_std).to(device)

# The model starts with input normalisation, so raw images can be fed in.
model = nn.Sequential(normalization)

# Walk every layer of the CNN and re-register it under a readable name. The
# counter `i` is only bumped at convolutions, so ReLU / pooling / batch-norm
# layers share the index of the convolution that precedes them.
i = 0
for layer in cnn.children():
    if isinstance(layer, nn.Conv2d):
        i += 1
        prefix = 'conv'
    elif isinstance(layer, nn.ReLU):
        prefix = 'relu'
        # Use an out-of-place ReLU so the preceding layer's output is not
        # overwritten in place.
        layer = nn.ReLU(inplace=False)
    elif isinstance(layer, nn.MaxPool2d):
        prefix = 'pool'
    elif isinstance(layer, nn.BatchNorm2d):
        prefix = 'bn'
    else:
        raise RuntimeError(f'Unrecognized layer: {layer.__class__.__name__}')

    name = f'{prefix}_{i}'
    model.add_module(name, layer)
    print(model)
    print('=' * 60)

Sequential(
  (0): Normalization()
  (conv_1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
Sequential(
  (0): Normalization()
  (conv_1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu_1): ReLU()
)
Sequential(
  (0): Normalization()
  (conv_1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu_1): ReLU()
  (conv_2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
Sequential(
  (0): Normalization()
  (conv_1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu_1): ReLU()
  (conv_2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu_2): ReLU()
)
Sequential(
  (0): Normalization()
  (conv_1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu_1): ReLU()
  (conv_2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu_2): ReLU()
  (pool_2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=Fals

In [12]:
# Display the assembled model: normalisation followed by the renamed CNN layers.
model

Sequential(
  (0): Normalization()
  (conv_1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu_1): ReLU()
  (conv_2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu_2): ReLU()
  (pool_2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv_3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu_3): ReLU()
  (conv_4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu_4): ReLU()
  (pool_4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv_5): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu_5): ReLU()
  (conv_6): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu_6): ReLU()
  (conv_7): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu_7): ReLU()
  (conv_8): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu_8): ReLU()
  (po

In [13]:
# Truncate the model right after conv_5 (before its ReLU); the output of this
# prefix is what the Gram matrices are computed from. Looking the layer up by
# name instead of hard-coding the position (12) keeps the slice correct if the
# layer list ever changes.
_layer_names = [layer_name for layer_name, _ in model.named_children()]
feature_model = model[:_layer_names.index('conv_5') + 1]

In [14]:
# Display the truncated feature extractor used for the Gram matrices.
feature_model

Sequential(
  (0): Normalization()
  (conv_1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu_1): ReLU()
  (conv_2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu_2): ReLU()
  (pool_2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv_3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu_3): ReLU()
  (conv_4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu_4): ReLU()
  (pool_4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv_5): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)

In [15]:
gram_matrix_list = []

# Compute one Gram matrix per style image, at the image's native resolution.
# The features are only read, never back-propagated through, so disable
# autograd to avoid building a graph (large memory savings on big images).
with torch.no_grad():
    for name in tqdm(style_img_list):
        img_file = os.path.join(style_img_path, name)
        # PIL opens lazily, so this only reads the header to get (width, height);
        # the `with` makes sure the file handle is closed again.
        with PIL.Image.open(img_file) as probe:
            width, height = probe.size
        # transforms.Resize takes (height, width); passing the native size keeps
        # the resolution unchanged.
        style_img = image_loader(img_file, (height, width))

        target_feature = feature_model(style_img).detach()
        # tqdm.write prints above the progress bar instead of breaking it.
        tqdm.write(str(target_feature.shape))

        GRAM_MATRIX = gram_matrix(target_feature)
        tqdm.write(str(GRAM_MATRIX.shape))
        GRAM_MATRIX = GRAM_MATRIX.cpu().numpy()
        tqdm.write(str(GRAM_MATRIX.shape))

        gram_matrix_list.append(GRAM_MATRIX)

  0%|                                                                                          | 0/114 [00:00<?, ?it/s]

torch.Size([1, 256, 500, 322])


  2%|█▍                                                                                | 2/114 [00:03<02:42,  1.45s/it]

torch.Size([256, 256])
(256, 256)
torch.Size([1, 256, 272, 500])
torch.Size([256, 256])
(256, 256)


  3%|██▏                                                                               | 3/114 [00:03<01:40,  1.10it/s]

torch.Size([1, 256, 356, 500])
torch.Size([256, 256])
(256, 256)


  4%|██▉                                                                               | 4/114 [00:03<01:11,  1.53it/s]

torch.Size([1, 256, 405, 500])
torch.Size([256, 256])
(256, 256)


  5%|████▎                                                                             | 6/114 [00:04<00:47,  2.27it/s]

torch.Size([1, 256, 500, 487])
torch.Size([256, 256])
(256, 256)
torch.Size([1, 256, 500, 181])
torch.Size([256, 256])
(256, 256)


  6%|█████                                                                             | 7/114 [00:04<00:37,  2.85it/s]

torch.Size([1, 256, 500, 230])
torch.Size([256, 256])
(256, 256)


  7%|█████▊                                                                            | 8/114 [00:05<00:36,  2.93it/s]

torch.Size([1, 256, 426, 500])
torch.Size([256, 256])
(256, 256)


  8%|██████▍                                                                           | 9/114 [00:05<00:35,  2.96it/s]

torch.Size([1, 256, 425, 500])
torch.Size([256, 256])
(256, 256)


  9%|███████                                                                          | 10/114 [00:06<00:42,  2.42it/s]

torch.Size([1, 256, 431, 500])
torch.Size([256, 256])
(256, 256)


 10%|███████▊                                                                         | 11/114 [00:06<00:35,  2.89it/s]

torch.Size([1, 256, 500, 250])
torch.Size([256, 256])
(256, 256)


 11%|████████▌                                                                        | 12/114 [00:06<00:32,  3.18it/s]

torch.Size([1, 256, 349, 500])
torch.Size([256, 256])
(256, 256)


 11%|█████████▏                                                                       | 13/114 [00:06<00:30,  3.31it/s]

torch.Size([1, 256, 421, 500])
torch.Size([256, 256])
(256, 256)


 13%|██████████▋                                                                      | 15/114 [00:07<00:23,  4.23it/s]

torch.Size([1, 256, 348, 500])
torch.Size([256, 256])
(256, 256)
torch.Size([1, 256, 177, 500])
torch.Size([256, 256])
(256, 256)


 14%|███████████▎                                                                     | 16/114 [00:07<00:23,  4.15it/s]

torch.Size([1, 256, 500, 362])
torch.Size([256, 256])
(256, 256)
torch.Size([1, 256, 500, 229])


 15%|████████████                                                                     | 17/114 [00:07<00:21,  4.53it/s]

torch.Size([256, 256])
(256, 256)


 16%|████████████▊                                                                    | 18/114 [00:07<00:22,  4.18it/s]

torch.Size([1, 256, 425, 500])
torch.Size([256, 256])
(256, 256)


 17%|█████████████▌                                                                   | 19/114 [00:08<00:24,  3.92it/s]

torch.Size([1, 256, 420, 500])
torch.Size([256, 256])
(256, 256)


 18%|██████████████▏                                                                  | 20/114 [00:08<00:23,  3.92it/s]

torch.Size([1, 256, 411, 500])
torch.Size([256, 256])
(256, 256)
torch.Size([1, 256, 143, 123])
torch.Size([256, 256])
(256, 256)


 19%|███████████████▋                                                                 | 22/114 [00:08<00:18,  4.97it/s]

torch.Size([1, 256, 352, 500])
torch.Size([256, 256])
(256, 256)


 20%|████████████████▎                                                                | 23/114 [00:08<00:19,  4.74it/s]

torch.Size([1, 256, 343, 500])
torch.Size([256, 256])
(256, 256)


 21%|█████████████████                                                                | 24/114 [00:09<00:27,  3.27it/s]

torch.Size([1, 256, 424, 500])
torch.Size([256, 256])
(256, 256)
torch.Size([1, 256, 158, 136])
torch.Size([256, 256])
(256, 256)
torch.Size([1, 256, 152, 136])
torch.Size([256, 256])
(256, 256)
torch.Size([1, 256, 170, 141])
torch.Size([256, 256])
(256, 256)
torch.Size([1, 256, 132, 111])
torch.Size([256, 256])
(256, 256)


 25%|████████████████████▌                                                            | 29/114 [00:09<00:11,  7.16it/s]

torch.Size([1, 256, 261, 500])
torch.Size([256, 256])
(256, 256)


 26%|█████████████████████▎                                                           | 30/114 [00:09<00:13,  6.15it/s]

torch.Size([1, 256, 417, 500])
torch.Size([256, 256])
(256, 256)


 28%|██████████████████████▋                                                          | 32/114 [00:10<00:14,  5.47it/s]

torch.Size([1, 256, 500, 418])
torch.Size([256, 256])
(256, 256)
torch.Size([1, 256, 207, 500])
torch.Size([256, 256])
(256, 256)


 29%|███████████████████████▍                                                         | 33/114 [00:10<00:16,  4.77it/s]

torch.Size([1, 256, 426, 500])
torch.Size([256, 256])
(256, 256)


 30%|████████████████████████▏                                                        | 34/114 [00:11<00:18,  4.37it/s]

torch.Size([1, 256, 426, 500])
torch.Size([256, 256])
(256, 256)


 36%|█████████████████████████████▏                                                   | 41/114 [00:11<00:06, 12.00it/s]

torch.Size([1, 256, 373, 500])
torch.Size([256, 256])
(256, 256)
torch.Size([1, 256, 77, 67])
torch.Size([256, 256])
(256, 256)
torch.Size([1, 256, 56, 78])
torch.Size([256, 256])
(256, 256)
torch.Size([1, 256, 78, 67])
torch.Size([256, 256])
(256, 256)
torch.Size([1, 256, 79, 67])
torch.Size([256, 256])
(256, 256)
torch.Size([1, 256, 78, 67])
torch.Size([256, 256])
(256, 256)
torch.Size([1, 256, 78, 67])
torch.Size([256, 256])
(256, 256)
torch.Size([1, 256, 78, 67])
torch.Size([256, 256])
(256, 256)
torch.Size([1, 256, 56, 78])
torch.Size([256, 256])
(256, 256)
torch.Size([1, 256, 78, 67])
torch.Size([256, 256])
(256, 256)
torch.Size([1, 256, 78, 67])

 49%|███████████████████████████████████████▊                                         | 56/114 [00:11<00:01, 34.69it/s]


torch.Size([256, 256])
(256, 256)
torch.Size([1, 256, 78, 67])
torch.Size([256, 256])
(256, 256)
torch.Size([1, 256, 66, 78])
torch.Size([256, 256])
(256, 256)
torch.Size([1, 256, 78, 67])
torch.Size([256, 256])
(256, 256)
torch.Size([1, 256, 67, 78])
torch.Size([256, 256])
(256, 256)
torch.Size([1, 256, 78, 67])
torch.Size([256, 256])
(256, 256)
torch.Size([1, 256, 78, 47])
torch.Size([256, 256])
(256, 256)
torch.Size([1, 256, 78, 67])
torch.Size([256, 256])
(256, 256)
torch.Size([1, 256, 43, 73])
torch.Size([256, 256])
(256, 256)
torch.Size([1, 256, 78, 67])
torch.Size([256, 256])
(256, 256)
torch.Size([1, 256, 56, 78])
torch.Size([256, 256])
(256, 256)
torch.Size([1, 256, 79, 65])
torch.Size([256, 256])
(256, 256)
torch.Size([1, 256, 78, 68])
torch.Size([256, 256])
(256, 256)
torch.Size([1, 256, 79, 56])
torch.Size([256, 256])
(256, 256)
torch.Size([1, 256, 79, 67])
torch.Size([256, 256])
(256, 256)
torch.Size([1, 256, 67, 78])
torch.Size([256, 256])
(256, 256)
torch.Size([1, 256, 

 57%|██████████████████████████████████████████████▏                                  | 65/114 [00:11<00:01, 29.63it/s]

torch.Size([1, 256, 428, 500])
torch.Size([256, 256])
(256, 256)
torch.Size([1, 256, 175, 150])
torch.Size([256, 256])
(256, 256)
torch.Size([1, 256, 500, 427])
torch.Size([256, 256])
(256, 256)
torch.Size([1, 256, 500, 425])
torch.Size([256, 256])
(256, 256)
torch.Size([1, 256, 422, 500])
torch.Size([256, 256])
(256, 256)


 61%|█████████████████████████████████████████████████▋                               | 70/114 [00:13<00:03, 11.21it/s]

torch.Size([1, 256, 424, 500])
torch.Size([256, 256])
(256, 256)
torch.Size([1, 256, 500, 279])
torch.Size([256, 256])
(256, 256)
torch.Size([1, 256, 500, 319])
torch.Size([256, 256])
(256, 256)
torch.Size([1, 256, 129, 500])
torch.Size([256, 256])
(256, 256)


 65%|████████████████████████████████████████████████████▌                            | 74/114 [00:13<00:04,  9.08it/s]

torch.Size([1, 256, 397, 500])
torch.Size([256, 256])
(256, 256)
torch.Size([1, 256, 420, 500])
torch.Size([256, 256])
(256, 256)
torch.Size([1, 256, 500, 188])
torch.Size([256, 256])
(256, 256)


 68%|██████████████████████████████████████████████████████▋                          | 77/114 [00:14<00:04,  7.74it/s]

torch.Size([1, 256, 382, 500])
torch.Size([256, 256])
(256, 256)
torch.Size([1, 256, 157, 500])
torch.Size([256, 256])
(256, 256)

 69%|████████████████████████████████████████████████████████▏                        | 79/114 [00:14<00:04,  7.58it/s]


torch.Size([1, 256, 500, 222])
torch.Size([256, 256])
(256, 256)
torch.Size([1, 256, 182, 500])
torch.Size([256, 256])
(256, 256)


 71%|█████████████████████████████████████████████████████████▌                       | 81/114 [00:15<00:05,  5.94it/s]

torch.Size([1, 256, 500, 194])
torch.Size([256, 256])
(256, 256)
torch.Size([1, 256, 263, 500])
torch.Size([256, 256])
(256, 256)


 73%|██████████████████████████████████████████████████████████▉                      | 83/114 [00:15<00:05,  6.00it/s]

torch.Size([1, 256, 500, 234])
torch.Size([256, 256])
(256, 256)
torch.Size([1, 256, 500, 218])
torch.Size([256, 256])
(256, 256)


 75%|████████████████████████████████████████████████████████████▍                    | 85/114 [00:16<00:04,  6.26it/s]

torch.Size([1, 256, 185, 500])
torch.Size([256, 256])
(256, 256)
torch.Size([1, 256, 500, 253])
torch.Size([256, 256])
(256, 256)

 76%|█████████████████████████████████████████████████████████████▊                   | 87/114 [00:16<00:04,  6.39it/s]


torch.Size([1, 256, 500, 208])
torch.Size([256, 256])
(256, 256)


 78%|███████████████████████████████████████████████████████████████▏                 | 89/114 [00:16<00:04,  5.95it/s]

torch.Size([1, 256, 347, 500])
torch.Size([256, 256])
(256, 256)
torch.Size([1, 256, 500, 236])
torch.Size([256, 256])
(256, 256)


 79%|███████████████████████████████████████████████████████████████▉                 | 90/114 [00:17<00:03,  6.29it/s]

torch.Size([1, 256, 500, 191])
torch.Size([256, 256])
(256, 256)


 80%|████████████████████████████████████████████████████████████████▋                | 91/114 [00:17<00:04,  5.37it/s]

torch.Size([1, 256, 500, 403])
torch.Size([256, 256])
(256, 256)


 81%|█████████████████████████████████████████████████████████████████▎               | 92/114 [00:17<00:04,  4.94it/s]

torch.Size([1, 256, 355, 500])
torch.Size([256, 256])
(256, 256)


 82%|██████████████████████████████████████████████████████████████████               | 93/114 [00:17<00:04,  4.51it/s]

torch.Size([1, 256, 427, 500])
torch.Size([256, 256])
(256, 256)


 82%|██████████████████████████████████████████████████████████████████▊              | 94/114 [00:17<00:04,  4.70it/s]

torch.Size([1, 256, 500, 226])
torch.Size([256, 256])
(256, 256)


 83%|███████████████████████████████████████████████████████████████████▌             | 95/114 [00:18<00:05,  3.29it/s]

torch.Size([1, 256, 454, 286])
torch.Size([256, 256])
(256, 256)


 85%|████████████████████████████████████████████████████████████████████▉            | 97/114 [00:18<00:04,  4.10it/s]

torch.Size([1, 256, 346, 500])
torch.Size([256, 256])
(256, 256)
torch.Size([1, 256, 500, 226])
torch.Size([256, 256])
(256, 256)


 86%|█████████████████████████████████████████████████████████████████████▋           | 98/114 [00:19<00:03,  4.55it/s]

torch.Size([1, 256, 500, 241])
torch.Size([256, 256])
(256, 256)


 88%|██████████████████████████████████████████████████████████████████████▏         | 100/114 [00:19<00:02,  5.28it/s]

torch.Size([1, 256, 351, 500])
torch.Size([256, 256])
(256, 256)
torch.Size([1, 256, 182, 500])
torch.Size([256, 256])
(256, 256)


 89%|██████████████████████████████████████████████████████████████████████▉         | 101/114 [00:19<00:02,  4.60it/s]

torch.Size([1, 256, 406, 500])
torch.Size([256, 256])
(256, 256)


 90%|████████████████████████████████████████████████████████████████████████▎       | 103/114 [00:20<00:02,  4.90it/s]

torch.Size([1, 256, 372, 500])
torch.Size([256, 256])
(256, 256)
torch.Size([1, 256, 500, 214])
torch.Size([256, 256])
(256, 256)


 91%|████████████████████████████████████████████████████████████████████████▉       | 104/114 [00:20<00:02,  4.35it/s]

torch.Size([1, 256, 431, 500])
torch.Size([256, 256])
(256, 256)
torch.Size([1, 256, 500, 246])

 92%|█████████████████████████████████████████████████████████████████████████▋      | 105/114 [00:20<00:01,  4.78it/s]


torch.Size([256, 256])
(256, 256)


 93%|██████████████████████████████████████████████████████████████████████████▍     | 106/114 [00:20<00:02,  3.56it/s]

torch.Size([1, 256, 500, 388])
torch.Size([256, 256])
(256, 256)


 94%|███████████████████████████████████████████████████████████████████████████     | 107/114 [00:21<00:01,  3.51it/s]

torch.Size([1, 256, 500, 431])
torch.Size([256, 256])
(256, 256)
torch.Size([1, 256, 500, 251])

 96%|████████████████████████████████████████████████████████████████████████████▍   | 109/114 [00:21<00:01,  4.60it/s]


torch.Size([256, 256])
(256, 256)
torch.Size([1, 256, 500, 229])
torch.Size([256, 256])
(256, 256)


 96%|█████████████████████████████████████████████████████████████████████████████▏  | 110/114 [00:21<00:00,  4.35it/s]

torch.Size([1, 256, 395, 500])
torch.Size([256, 256])
(256, 256)


 98%|██████████████████████████████████████████████████████████████████████████████▌ | 112/114 [00:22<00:00,  5.09it/s]

torch.Size([1, 256, 347, 500])
torch.Size([256, 256])
(256, 256)
torch.Size([1, 256, 190, 500])
torch.Size([256, 256])
(256, 256)


 99%|███████████████████████████████████████████████████████████████████████████████▎| 113/114 [00:22<00:00,  4.39it/s]

torch.Size([1, 256, 404, 500])
torch.Size([256, 256])
(256, 256)
torch.Size([1, 256, 500, 264])


100%|████████████████████████████████████████████████████████████████████████████████| 114/114 [00:22<00:00,  5.03it/s]

torch.Size([256, 256])
(256, 256)





In [16]:
# One Gram matrix per style image is expected.
len(gram_matrix_list)

114

# K-Means

In [17]:
# Map each style image file name to its Gram matrix. Both lists are built in
# the same order; zip() would silently truncate if their lengths ever
# diverged, misaligning names and matrices, so check explicitly.
if len(style_img_list) != len(gram_matrix_list):
    raise ValueError(
        f'{len(style_img_list)} style images but '
        f'{len(gram_matrix_list)} Gram matrices'
    )
data = dict(zip(style_img_list, gram_matrix_list))

114it [00:00, 114035.45it/s]


In [18]:
# Display the file-name -> Gram-matrix mapping.
data

{'b01.jpg': array([[ 0.01665387, -0.00460237, -0.00438314, ..., -0.00327673,
          0.00052889, -0.00252655],
        [-0.00460237,  0.06756994,  0.01037986, ...,  0.00106901,
         -0.00134396, -0.01411994],
        [-0.00438314,  0.01037986,  0.01942222, ...,  0.0007324 ,
         -0.00193583, -0.0069167 ],
        ...,
        [-0.00327673,  0.00106901,  0.0007324 , ...,  0.03768617,
         -0.00150939,  0.00050153],
        [ 0.00052889, -0.00134396, -0.00193583, ..., -0.00150939,
          0.03574582, -0.00132985],
        [-0.00252655, -0.01411994, -0.0069167 , ...,  0.00050153,
         -0.00132985,  0.04403799]], dtype=float32),
 'b02.jpg': array([[ 0.04221762, -0.00867174, -0.01583532, ..., -0.01158446,
         -0.00271553, -0.00410667],
        [-0.00867174,  0.06652595,  0.01046396, ...,  0.01522296,
         -0.00684379,  0.02943474],
        [-0.01583532,  0.01046396,  0.03385571, ...,  0.00698373,
          0.00044413, -0.00103135],
        ...,
        [-0.01158

In [19]:
# File names in `data` insertion order (the order of style_img_list), shape (N,).
filenames = np.array(list(data.keys()))
print(filenames.shape)

# Gram matrices stacked in the same order, shape (N, C, C).
feat = np.array(list(data.values()))
print(feat.shape)

# Flatten each Gram matrix into one row vector for K-Means: (N, C*C).
# Use the actual sample count instead of a hard-coded 114 so this keeps
# working when images are added to or removed from the style directory.
feat_resh = feat.reshape(len(feat), -1)
print(feat_resh.shape)

(114,)
(114, 256, 256)
(114, 65536)


In [20]:
# Cluster the flattened Gram matrices into 3 style groups.
# - `n_jobs` was deprecated in scikit-learn 0.23 and removed in 1.0 (KMeans
#   parallelises with OpenMP threads instead), so it is no longer passed.
# - `n_init` is pinned to its historical default of 10: the default became
#   'auto' in scikit-learn 1.4, which would silently change the clustering.
kmeans = KMeans(n_clusters=3, n_init=10, random_state=42, max_iter=100)
kmeans.fit(feat_resh)



KMeans(max_iter=100, n_clusters=3, n_jobs=8, random_state=42)

In [21]:
# Cluster label of each image, in the same order as the rows of feat_resh.
kmeans.labels_

array([0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0,
       0, 0, 0, 2, 1, 2, 0, 0, 0, 2, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 1, 2, 1, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0])

In [22]:
# One label per clustered image.
len(kmeans.labels_)

114

In [23]:
# Cluster centres, in the same flattened Gram-matrix space as feat_resh.
kmeans.cluster_centers_

array([[ 0.04166927, -0.00539597, -0.01604619, ...,  0.00773096,
         0.00355729,  0.05958238],
       [ 0.13243957, -0.03661051, -0.04479899, ...,  0.09734885,
        -0.02457084,  0.27106723],
       [ 0.06503568, -0.00984312, -0.01953047, ...,  0.04698005,
        -0.02642296,  0.12430751]], dtype=float32)

In [24]:
# (n_clusters, flattened Gram-matrix length)
kmeans.cluster_centers_.shape

(3, 65536)

In [25]:
# Length of a single cluster-centre vector.
kmeans.cluster_centers_[0].shape

(65536,)

In [26]:
# Sanity check: a flattened Gram matrix has the same length as a centre vector.
gram_matrix_list[0].reshape(-1, ).shape

(65536,)

## Dissimilar style extraction

## cluster & image mapping

In [27]:
# Group file names by K-Means cluster label: {label: [file name, ...]}.
# Labels appear in order of first occurrence; files keep their input order.
groups = {}
for file, cluster in zip(filenames, kmeans.labels_):
    groups.setdefault(cluster, []).append(file)

In [28]:
# Display the cluster label -> file names mapping.
groups

{0: ['b01.jpg',
  'b02.jpg',
  'b03.jpg',
  'b04.jpg',
  'b05.jpg',
  'b06.jpg',
  'b08.jpg',
  'b09.jpg',
  'b10.jpg',
  'b100.jpg',
  'b101.jpg',
  'b102.jpg',
  'b103.jpg',
  'b104.jpg',
  'b105.jpg',
  'b106.jpg',
  'b107.jpg',
  'b108.jpg',
  'b109.jpg',
  'b110.jpg',
  'b111.jpg',
  'b112.jpg',
  'b113.jpg',
  'b14.jpg',
  'b15.jpg',
  'b16.jpg',
  'b18.jpg',
  'b19.jpg',
  'b20.jpg',
  'b49.jpg',
  'b50.jpg',
  'b51.jpg',
  'b52.jpg',
  'b53.jpg',
  'b54.jpg',
  'b55.jpg',
  'b56.jpg',
  'b57.jpg',
  'b58.jpg',
  'b59.jpg',
  'b60.jpg',
  'b61.jpg',
  'b62.jpg',
  'b63.jpg',
  'b64.jpg',
  'b65.jpg',
  'b66.jpg',
  'b67.jpg',
  'b68.jpg',
  'b69.jpg',
  'b70.jpg',
  'b71.jpg',
  'b72.jpg',
  'b73.jpg',
  'b74.jpg',
  'b76.jpg',
  'b77.jpg',
  'b78.jpg',
  'b79.jpg',
  'b80.jpg',
  'b81.jpg',
  'b82.jpg',
  'b83.jpg',
  'b84.jpg',
  'b85.jpg',
  'b86.jpg',
  'b87.jpg',
  'b88.jpg',
  'b89.jpg',
  'b90.jpg',
  'b91.jpg',
  'b92.jpg',
  'b93.jpg',
  'b94.jpg',
  'b95.jpg',
  'b96.j

## centroid style extraction

In [29]:
# For every cluster centre, find the image whose flattened Gram matrix is
# closest to it in L1 distance (sum of absolute differences), i.e. the most
# representative image of each cluster. NOTE(review): despite the "dissim"
# name these are the images *nearest* their centroid; presumably the name
# reflects that one exemplar per cluster gives mutually dissimilar styles.
dissim_img_idx = []

for cen in kmeans.cluster_centers_:
    # Distances to all images at once instead of a Python loop per image.
    l1_dist = np.abs(feat_resh - cen).sum(axis=1)
    nearest = int(np.argmin(l1_dist))
    print(nearest)
    dissim_img_idx.append(nearest)

3it [00:00, 71.20it/s]

112
48
55





In [30]:
# Row index (into feat_resh) of the image nearest each cluster centre.
dissim_img_idx

[112, 48, 55]

In [31]:
# Translate each selected row index back to its style image file name.
dissim_style_name = [style_img_list[dis] for dis in dissim_img_idx]
for style_name in dissim_style_name:
    print(style_name)

b98.jpg
b34소림명월도.jpg
b41점심.jpg


In [37]:
# Sanity check that `data` and `gram_matrix_list` are aligned: the Gram matrix
# stored under the first selected file name must equal the one at its index.
# Summing absolute differences (not the signed difference, whose positive and
# negative parts can cancel) makes a 0.0 result a genuine equality check; the
# name/index come from the previous cells instead of being hard-coded.
np.abs(data[dissim_style_name[0]] - gram_matrix_list[dissim_img_idx[0]]).sum()

0.0

## cluster & center image mapping

In [33]:
# For each selected style image, record the K-Means cluster its file belongs to.
dissim_center_img = {
    name: key
    for name in dissim_style_name
    for key, value in groups.items()
    if name in value
}

In [34]:
# Display selected file name -> cluster label.
dissim_center_img

{'b98.jpg': 0, 'b34소림명월도.jpg': 1, 'b41점심.jpg': 2}

## Show the dissimilar style images

In [35]:
#PIL.Image.open(style_img_path + style_img_list[112])

In [36]:
#PIL.Image.open(style_img_path + style_img_list[60])