In [None]:
import pandas as pd
import numpy as np
import os
import cv2
import random
from tqdm.notebook import tqdm

import torch
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torch.nn.functional as F
import albumentations
import timm

import matplotlib.pyplot as plt
%matplotlib inline

from sklearn.model_selection import GroupKFold

from warnings import filterwarnings
filterwarnings("ignore")

# Use the GPU when one is available, otherwise fall back to the CPU so the notebook
# still runs (slowly) on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
image_size = 512   # images are resized to image_size x image_size
batch_size = 16
n_worker = 4       # DataLoader worker processes
init_lr = 3e-4
n_epochs = 6
fold_id = 0        # which GroupKFold fold is held out as the validation set
thres = 0.5        # cosine-similarity threshold used when searching for similar images
seed = 42

# Seed Python's RNG so the random sample of query images visualised later is reproducible.
random.seed(seed)

# Candidate thresholds 0.3, 0.4, ..., 0.9 to validate (np.arange excludes the stop
# value 1). A float step in np.arange can accumulate rounding error, so the values
# are rounded to one decimal to make them exact-looking and safe to compare/print.
search_space = np.round(np.arange(0.3, 1, 0.1), 1)

backbone_name = 'resnet18'
weight_dir = './weights/resnet18_512_epoch5.pth'   # path to the trained checkpoint file
data_dir = './data/'

In [None]:
# Training metadata, plus the on-disk location of every image under <data_dir>/train_images.
df_train_all = pd.read_csv(os.path.join(data_dir, 'train.csv'))
image_root = os.path.join(data_dir, 'train_images')
df_train_all['file_path'] = [os.path.join(image_root, name) for name in df_train_all['image']]

In [None]:
# 5-fold split grouped by label_group: all postings of one label_group land in the
# same fold, so validation groups are never seen in training.
gkf = GroupKFold(n_splits=5)
df_train_all['fold'] = -1
fold_col = df_train_all.columns.get_loc('fold')
for fold, (train_idx, valid_idx) in enumerate(gkf.split(df_train_all, groups=df_train_all.label_group)):
    # split() yields *positional* indices, so assign positionally; .loc would only be
    # correct while the frame happens to have a default RangeIndex.
    df_train_all.iloc[valid_idx, fold_col] = fold

assert (df_train_all['fold'] >= 0).all(), "every row should be assigned to a fold"

df_train = df_train_all[df_train_all['fold'] != fold_id]
df_valid = df_train_all[df_train_all['fold'] == fold_id]

In [None]:
# Validation-time preprocessing: square resize, then albumentations' default Normalize.
transforms_valid = albumentations.Compose(
    [
        albumentations.Resize(height=image_size, width=image_size),
        albumentations.Normalize(),
    ]
)

In [None]:
class SHOPEEDataset(Dataset):
    """Image dataset over a SHOPEE metadata DataFrame.

    Each item is the image at ``row.file_path`` converted to RGB and returned as a
    float tensor in channel-first (C, H, W) order. In ``'test'`` mode only the image
    is returned; in any other mode the pair ``(image, label_group)`` is returned.

    Args:
        df: DataFrame with a ``file_path`` column (and ``label_group`` unless
            ``mode == 'test'``). Its index is reset so rows are addressed 0..n-1.
        mode: ``'test'`` to return images only; anything else also returns labels.
        transform: optional albumentations-style callable, invoked as
            ``transform(image=img)`` and expected to return a dict whose
            ``'image'`` entry is an (H, W, C) array.
    """

    def __init__(self, df, mode, transform=None):
        self.df = df.reset_index(drop=True)
        self.mode = mode
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        row = self.df.loc[index]
        img = cv2.imread(row.file_path)
        if img is None:
            # cv2.imread reports missing/unreadable files by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {row.file_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        if self.transform is not None:
            img = self.transform(image=img)['image']
        # HWC -> CHW for PyTorch. Done unconditionally so the output layout does not
        # depend on whether a transform was supplied.
        img = img.transpose(2, 0, 1)

        if self.mode == 'test':
            return torch.tensor(img).float()
        return torch.tensor(img).float(), torch.tensor(row.label_group)

In [None]:
class ArcFaceClassifier(nn.Module):
    """Cosine-similarity classification head.

    Holds a learnable weight matrix ``W`` of shape (in_features, output_classes).
    The forward pass L2-normalises each input row and each weight column, so every
    output entry is the cosine similarity between an input and a class weight
    vector. No angular margin is applied in this module.
    """

    def __init__(self, in_features, output_classes):
        super().__init__()
        weight = torch.Tensor(in_features, output_classes)
        nn.init.kaiming_uniform_(weight)
        self.W = nn.Parameter(weight)

    def forward(self, x):
        """Return the (batch, output_classes) matrix of cosine similarities."""
        unit_inputs = F.normalize(x, p=2, dim=1)
        unit_weights = F.normalize(self.W, p=2, dim=0)
        return torch.matmul(unit_inputs, unit_weights)
    
class ResnetArcFace(nn.Module):
    """timm backbone + (global average pool, flatten, BatchNorm1d) neck + cosine head.

    Args:
        backbone: timm model name; defaults to the notebook-level ``backbone_name``.
        n_classes: number of output classes; defaults to
            ``df_train.label_group.nunique()`` evaluated at construction time. Must
            match the checkpoint that will be loaded.
        pretrained: forwarded to ``timm.create_model``. Pass ``False`` when the
            weights are about to be replaced via ``load_state_dict`` to skip
            downloading the pretrained weights.
    """

    def __init__(self, backbone=None, n_classes=None, pretrained=True):
        super().__init__()
        if backbone is None:
            backbone = backbone_name
        if n_classes is None:
            n_classes = df_train.label_group.nunique()
        self.backbone = timm.create_model(backbone, pretrained=pretrained)
        # width of the backbone's final feature map, read from its original classifier
        embedding_size = self.backbone.get_classifier().in_features
        # NOTE(review): assumes forward_features returns a spatial (B, C, H, W) map.
        self.after_conv = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.BatchNorm1d(embedding_size))
        self.classifier = ArcFaceClassifier(embedding_size, n_classes)

    def forward(self, x, output_embs=False):
        """Return cosine class scores, or L2-normalised embeddings if ``output_embs``."""
        embeddings = self.after_conv(self.backbone.forward_features(x))
        if output_embs:
            return F.normalize(embeddings)
        return self.classifier(embeddings)

In [None]:
model = ResnetArcFace()
model.to(device);

# Apply the trained weights. map_location lets a checkpoint saved during a GPU run be
# loaded when `device` is the CPU (and vice versa). torch.load unpickles the file, so
# only point weight_dir at checkpoints you trust.
model.load_state_dict(torch.load(weight_dir, map_location=device))

Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /root/.cache/torch/hub/checkpoints/resnet18-5c106cde.pth


<All keys matched successfully>

In [None]:
# Validation images in 'test' mode (images only, no labels). shuffle=False keeps the
# batches in df_valid row order so the embeddings can be attached row-by-row later.
dataset_valid = SHOPEEDataset(df_valid, mode='test', transform=transforms_valid)
valid_loader = DataLoader(
    dataset_valid,
    batch_size=batch_size,
    shuffle=False,
    num_workers=n_worker,
)

In [None]:
def get_embeddings(data_loader, net=None):
    """Compute embeddings for every image batch yielded by ``data_loader``.

    Args:
        data_loader: iterable of image-batch tensors, e.g. a DataLoader over a
            dataset in ``'test'`` mode (which yields images only).
        net: model to run; defaults to the notebook-level ``model``. It must accept
            an ``output_embs=True`` keyword. It is switched to eval mode and left
            in eval mode afterwards.

    Returns:
        numpy array with one embedding row per input image, in loader order.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    if net is None:
        net = model
    net.eval()
    # send inputs to wherever the model's weights live
    target_device = next(net.parameters()).device
    batches = []
    with torch.no_grad():
        for images in tqdm(data_loader):
            feats = net(images.to(target_device), output_embs=True)
            batches.append(feats.cpu())
    if not batches:
        raise ValueError("data_loader yielded no batches; nothing to embed")
    return torch.cat(batches).numpy()

In [None]:
# Embed every validation image (one row per df_valid row, in loader order).
embs = get_embeddings(data_loader=valid_loader)

  0%|          | 0/429 [00:00<?, ?it/s]

In [None]:
# (number of validation images, embedding size); the saved run shows (6851, 512)
np.shape(embs)

(6851, 512)

In [None]:
# Re-index df_valid as 0..n-1 and store each image's embedding (as a Python list) in an
# 'embs' column, row-aligned with the order the loader produced them in.
df_valid = df_valid.reset_index(drop=True).assign(embs=embs.tolist())

In [None]:
def show_image(file_path, title):
    """Draw the image at ``file_path`` on the current matplotlib axes with ``title``, axes hidden.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``file_path``.
    """
    img = cv2.imread(file_path)
    if img is None:
        # cv2.imread returns None (rather than raising) for missing/unreadable files.
        raise FileNotFoundError(f"Could not read image: {file_path}")
    plt.title(title)
    plt.axis('off')
    # OpenCV decodes to BGR; matplotlib expects RGB.
    plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

# Visual spot-check of the embeddings: for randomly chosen validation images, show the
# query next to up to `top_k` other postings whose cosine similarity exceeds `thres`,
# most similar first.
n_queries = min(100, len(df_valid))  # random.sample raises ValueError if k > population
top_k = 5                            # matches shown per query (figure has 1 + top_k panels)

# Stack the per-row embeddings once so each query is one matrix-vector product instead
# of a Python loop over every row. The model returns F.normalize()d embeddings, so these
# dot products are cosine similarities.
emb_matrix = np.array(df_valid['embs'].tolist())
posting_ids = df_valid['posting_id'].to_numpy()


def _top_matches(query_idx, matrix, ids, threshold, k):
    """Return up to ``k`` ``(row_index, cosine_sim)`` pairs for row ``query_idx``.

    Rows sharing the query's posting_id are excluded, only similarities strictly above
    ``threshold`` are kept, and results are ordered by similarity, highest first
    (ties keep DataFrame order).
    """
    sims = matrix @ matrix[query_idx]
    candidates = np.flatnonzero((ids != ids[query_idx]) & (sims > threshold))
    ranked = candidates[np.argsort(-sims[candidates], kind='stable')][:k]
    return [(int(r), float(sims[r])) for r in ranked]


for i in random.sample(range(len(df_valid)), n_queries):
    search_row = df_valid.iloc[i]
    plt.figure(figsize=(18, 3))
    plt.subplot(1, top_k + 1, 1)
    show_image(search_row.file_path, 'Input Image')

    matches = _top_matches(i, emb_matrix, posting_ids, thres, top_k)
    for j, (df_i, cosine_sim) in enumerate(matches):
        plt.subplot(1, top_k + 1, j + 2)
        show_image(df_valid.iloc[df_i].file_path, 'Searched Image')
    plt.show()

Output hidden; open in https://colab.research.google.com to view.