In [None]:
import torch
from PIL import Image
import torch
import torchvision
from torchvision import transforms
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Load the pretrained CNN (ResNet-152) from the torchvision hub.
# .eval(): a freshly built nn.Module is in training mode, where BatchNorm normalizes
# with the statistics of the current batch (here a single image) and updates its
# running stats on every forward pass -- features would be wrong and drift between calls.
# Assigning the result also keeps the notebook from dumping the full model repr.
model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet152', pretrained=True).eval()

In [None]:
# Feature extraction
def extract_feature(filepath, image_name):
    """Run one image through the global ``model`` and return its softmax output vector.

    Args:
        filepath: Directory containing the image (trailing separator optional).
        image_name: File name of the image inside ``filepath``.

    Returns:
        1-D tensor: softmax over the model's output for this image. It lives on
        CUDA when a GPU is available, otherwise on the CPU.
    """
    import os

    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # os.path.join works whether or not `filepath` ends with a separator.
    # convert('RGB'): grayscale / RGBA / palette images would otherwise not match
    # Normalize's 3-channel mean/std. The `with` block releases the file handle.
    with Image.open(os.path.join(filepath, image_name)) as input_image:
        input_tensor = preprocess(input_image.convert('RGB'))
    input_batch = input_tensor.unsqueeze(0)  # add a batch dimension: the model expects a mini-batch

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    input_batch = input_batch.to(device)
    model.to(device)
    # BatchNorm must use its running statistics, not those of a batch of one image.
    model.eval()

    with torch.no_grad():  # inference only: no autograd graph needed
        output = model(input_batch)

    return torch.nn.functional.softmax(output[0], dim=0)

In [None]:
from os import listdir
from os.path import isfile, join
# Trailing slash matches the later cells, which build paths as filepath + name.
filepath = 'C:/바탕 화면/playdata/ice/eclipse/Devils_Project/frontend/src/assets/img/studio/'
# os.listdir returns entries in arbitrary, platform-dependent order. Later cells pick
# images by position (files[1001], files[1002]), so sort to make that reproducible.
files = sorted(f for f in listdir(filepath) if isfile(join(filepath, f)))

In [None]:
# Load the server images: extract a feature vector for each of the first N_INDEXED files.
import time

N_INDEXED = 1000  # only files[:N_INDEXED] are indexed; files[1001] / files[1002] are queried later
start = time.perf_counter()
filepath = 'C:/바탕 화면/playdata/ice/eclipse/Devils_Project/frontend/src/assets/img/studio/'
# Slicing (rather than range(1000) indexing) avoids an IndexError when the
# folder holds fewer than N_INDEXED files.
feature_list = [
    {'name': name, 'vector': extract_feature(filepath, name)}
    for name in files[:N_INDEXED]
]
print(len(feature_list))
print("time :", time.perf_counter() - start)  # elapsed seconds = now - start

In [None]:
def get_sim_pic(filepath, file, top_k=5):
    """Return the names of the ``top_k`` indexed images most similar to one image.

    Similarity is the cosine similarity between the query image's vector from
    ``extract_feature`` and each ``'vector'`` in the global ``feature_list``.
    The top-ranked rows (name + score) are printed as a side effect.

    Args:
        filepath: Directory of the query image.
        file: File name of the query image.
        top_k: Number of matches to return (default 5).

    Returns:
        pandas Series of image names, most similar first; its index is the
        position of each entry in ``feature_list``. Empty if ``feature_list`` is empty.
    """
    target_img = extract_feature(filepath, file)
    cos = torch.nn.CosineSimilarity(dim=-1)
    sim_df = pd.DataFrame(
        [
            {'name': feature['name'], 'sim_degree': float(cos(target_img, feature['vector']))}
            for feature in feature_list
        ],
        # Explicit columns so an empty feature_list still yields a sortable frame.
        columns=['name', 'sim_degree'],
    )
    # Sort once and reuse it for both the printout and the return value.
    top = sim_df.sort_values(by='sim_degree', ascending=False)[:top_k]
    print(top)
    return top['name']

In [None]:
def show_pics(result_list, target_file, image_dir=None):
    """Display the query image, a separator line, then each matched image.

    Every image is preceded by its file name on stdout.

    Args:
        result_list: Iterable of image file names, e.g. the Series returned by get_sim_pic.
        target_file: File name of the query image.
        image_dir: Directory holding the images. Defaults to the global ``filepath``.
    """
    import os

    if image_dir is None:
        image_dir = filepath

    def _show(name):
        # Print the name and draw the image; `with` releases the file handle afterwards.
        print(name)
        with Image.open(os.path.join(image_dir, name)) as img:
            plt.imshow(img)
            plt.show()

    _show(target_file)
    print('-' * 30)
    for img_name in result_list:
        _show(img_name)

In [None]:
# Look up the closest matches for an image outside the indexed range.
query_index = 1001  # the indexing cell covers the first 1000 files, so this one is unseen
file = files[query_index]
result_list = get_sim_pic(filepath, file)
show_pics(result_list, file)

In [None]:
# Fetch the image requested from the front-end screen and extract its features.
# A fixed image is used for now -- remove this selection once the server is connected.
query_name = files[1002]
print(query_name)
target_img = extract_feature(filepath, query_name)

In [None]:
# Similarity measurement: cosine similarity between the query features and every indexed image.
cos = torch.nn.CosineSimilarity(dim=-1)
sim_list = [
    {'name': feature['name'], 'sim_degree': float(cos(target_img, feature['vector']))}
    for feature in feature_list
]
# Printing every entry (one per indexed image) floods the notebook output;
# show the count and a short preview instead.
print(len(sim_list))
print(sim_list[:5])

In [None]:
# Rank every indexed image by similarity and keep the five best-matching names.
df = pd.DataFrame(sim_list)
ranked = df.sort_values(by='sim_degree', ascending=False)
result_list = ranked['name'].head(5)
result_list

In [None]:
# Show the query image followed by its top matches. This cell used to be a
# line-for-line copy of show_pics' body, so call the function instead.
show_pics(result_list, files[1002])

In [None]:

# (Stray cell: every figure above already calls plt.show(), so a bare plt.show() here was a no-op.)

In [None]:
# Load every image under `path` with the same preprocessing as extract_feature.
# NOTE: torchvision's ImageFolder expects one sub-directory per class under `path`.
path = '../../frontend/src/assets/img'
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
testset = torchvision.datasets.ImageFolder(path, transform=preprocess)
testloader = torch.utils.data.DataLoader(dataset=testset, shuffle=False)
device = torch.device('cpu')
image_list = []
for image, label in testloader:
    # Keep every preprocessed batch (DataLoader's default batch_size is 1), not just the last.
    image_list.append(image.to(device))
# Number of collected batches. Unlike len(image), this is well-defined for an empty dataset.
print(len(image_list))
    