In [19]:
# dependency package import

import os

import math
import time
import random
import shutil
from datetime import datetime
from pathlib import Path

from glob import glob

import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import pandas as pd
import numpy as np

from PIL import Image
from tqdm.auto import tqdm

In [20]:
# Evaluation image list: read with pd.read_csv; its NAME entries are joined onto BASE_DIR.
BASE_DIR = '/data/kts123/aihub/reid'
test_imgs = '/data/kts123/aihub/reid/img_list_test.txt'

# Trained model checkpoint path.
# NOTE(review): the file name says resnet50, but MetricModel builds
# models.resnext50_32x4d — confirm the checkpoint matches that architecture.
weight_path = '/home/kts123/gc2021/3차/track3/arcface/checkpoints_res50_base/scheduler_resnet50_99.pth'

# Constants used by later cells that were not defined in any cell of this
# notebook, so a fresh-kernel "Run All" raised NameError.
# Root of the <class>/<image>.jpg tree read by EvalDataset.
# NOTE(review): assumed to be the 'eval_imgs' tree the copy cell builds — confirm.
CROP_IMG_DIR = 'eval_imgs'
# Side length of the square center crop fed to the network.
# TODO confirm: 224 (the usual ImageNet size) is assumed; it must match training.
IMG_SIZE = 224

In [29]:
# Materialise the evaluation split as eval_imgs/<5-digit class index>/<file name>,
# i.e. one sub-directory per KLS_IDX value, copying each image from BASE_DIR.
df = pd.read_csv(test_imgs)
for kls, df_i in tqdm(df.groupby('KLS_IDX')):
    dst_dir = Path('eval_imgs') / f'{kls:05d}'
    dst_dir.mkdir(parents=True, exist_ok=True)
    for name in df_i['NAME']:
        # NAME is appended to BASE_DIR textually (not via Path join) on purpose:
        # it keeps the original string concatenation semantics.
        src = f'{BASE_DIR}/{name}'
        dst = dst_dir / Path(name).name
        shutil.copy(src, dst)

  0%|          | 0/500 [00:00<?, ?it/s]

In [30]:
class EvalDataset(Dataset):
    """Triplet dataset over a ``root_dir/<identity>/<image>.jpg`` folder tree.

    Each image's parent directory name is its identity. Identities are mapped
    to integer labels in sorted name order, so labels are stable across runs.
    Item ``idx`` is an (anchor, positive, negative) triple:
    - the anchor is image ``idx``;
    - the positive is a different image with the same label;
    - the negative is an image with a different label.
    Both the positive and the negative are drawn uniformly at random.

    Args:
        root_dir: directory whose sub-directories each hold one identity's
            ``.jpg`` files.
        trsf: callable applied to each RGB PIL image (e.g. a torchvision
            transform).

    Attributes:
        imgs: list of ``[label, path]`` pairs, ordered by path.
        trsf: the transform passed in.
        num_labels: number of distinct identities.
    """

    def __init__(self, root_dir, trsf):
        paths = sorted(glob(f'{root_dir}/*/*.jpg'))

        # Iterating a set gives an order that varies with hash seeding; sorting
        # keeps the identity -> label mapping identical between runs. The dict
        # gives O(1) lookups instead of list.index's O(num_labels) scan.
        pnames = sorted({Path(p).parent.name for p in paths})
        label_of = {pname: label for label, pname in enumerate(pnames)}

        self.imgs = [[label_of[Path(p).parent.name], p] for p in paths]
        self.trsf = trsf
        self.num_labels = len(pnames)

        # label -> indices into self.imgs, so positives are drawn directly.
        self._indices_by_label = {}
        for i, (label, _) in enumerate(self.imgs):
            self._indices_by_label.setdefault(label, []).append(i)

    def select_random(self, idx, is_positive=True):
        """Return the path of a random image to pair with image ``idx``.

        Args:
            idx: index of the anchor image in ``self.imgs``.
            is_positive: if truthy, pick another image with the anchor's label;
                otherwise pick an image with a different label.

        Raises:
            ValueError: if no such image exists. This happens when the anchor's
                identity has a single image, or the dataset has a single
                identity. The error is raised instead of looping forever.
        """
        pivot_label, pivot_path = self.imgs[idx]

        if is_positive:
            candidates = [i for i in self._indices_by_label[pivot_label]
                          if self.imgs[i][1] != pivot_path]
            if not candidates:
                raise ValueError(
                    f'no positive for {pivot_path!r}: its identity has a single image')
            return self.imgs[random.choice(candidates)][1]

        if self.num_labels < 2:
            raise ValueError(
                f'no negative for {pivot_path!r}: dataset has a single identity')
        # At least one other label exists, so this loop terminates. Rejection
        # sampling over the whole dataset is uniform over all negatives.
        while True:
            label, path = random.choice(self.imgs)
            if label != pivot_label:
                return path

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        """Return ``([anchor, positive, negative] transformed, their paths)``."""
        _, pivot_path = self.imgs[idx]
        paths = [
            pivot_path,
            self.select_random(idx, is_positive=True),
            self.select_random(idx, is_positive=False),
        ]
        imgs = []
        for path in paths:
            # convert('RGB') always yields 3 channels, so grayscale/CMYK JPEGs
            # no longer break a 3-channel Normalize. It also loads the pixels,
            # so the file handle can be closed here instead of leaking.
            with Image.open(path) as img:
                imgs.append(self.trsf(img.convert('RGB')))
        return imgs, paths

In [None]:
# Evaluation preprocessing, applied in order:
# 1. resize the shorter side to 1.2 x IMG_SIZE;
# 2. take a centered IMG_SIZE x IMG_SIZE crop;
# 3. convert to a float tensor;
# 4. normalize with the ImageNet channel mean / std.
data_transforms = transforms.Compose(
    [
        transforms.Resize(int(1.2 * IMG_SIZE)),
        transforms.CenterCrop(IMG_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.485, 0.456, 0.406],  # mean (R, G, B)
            [0.229, 0.224, 0.225],  # std  (R, G, B)
        ),
    ]
)

In [None]:
# Triplet evaluation set over CROP_IMG_DIR and a shuffled, 2-worker loader.
dataset = EvalDataset(CROP_IMG_DIR, data_transforms)
dataloader = DataLoader(
    dataset,
    num_workers=2,
    batch_size=64,
    shuffle=True,
)
# Report: number of images, number of identities.
n_images, n_labels = len(dataset), dataset.num_labels
print(n_images, n_labels)

In [33]:
class MetricModel(nn.Module):
    """ResNeXt-50 (32x4d) embedding network restored from a trained checkpoint.

    The backbone's ``fc`` classifier is replaced by a linear projection to the
    metric-embedding size. ``forward`` returns L2-normalised embeddings without
    tracking gradients. The model is moved to ``device`` and switched to eval
    mode on construction.

    Args:
        weights_path: path to a ``state_dict`` for the backbone (keys such as
            ``fc.weight``). It is loaded strictly, so it must match the
            architecture exactly.
        device: torch device (or device string) used for inference.
        metric_dim: embedding size. When ``None`` (default) it is read from the
            checkpoint's ``fc.weight``, so the head always matches the weights.
    """

    def __init__(self, weights_path, device, metric_dim=None):
        super().__init__()

        # torch.load unpickles the file: only load checkpoints you trust.
        # map_location='cpu' lets a checkpoint saved on a GPU load on any
        # machine; the whole module is moved to `device` below.
        weights = torch.load(weights_path, map_location='cpu')

        if metric_dim is None:
            if 'fc.weight' not in weights:
                raise KeyError(
                    f"checkpoint {weights_path!r} has no 'fc.weight' entry; "
                    "cannot infer the embedding size (pass metric_dim)")
            # nn.Linear stores its weight as (out_features, in_features).
            metric_dim = weights['fc.weight'].shape[0]

        # Built without ImageNet weights: the strict load_state_dict below
        # overwrites every parameter and buffer, so a download would be wasted.
        backbone = models.resnext50_32x4d()
        backbone.fc = nn.Linear(in_features=backbone.fc.in_features, out_features=metric_dim)
        backbone.load_state_dict(weights)
        self.backbone = backbone

        self.device = device
        self.to(device)
        self.eval()

    def forward(self, input_imgs):
        """Embed a batch of images.

        Args:
            input_imgs: image batch tensor, presumably (batch, 3, H, W) as
                produced by the eval transforms — TODO confirm. It is moved to
                ``self.device`` here.

        Returns:
            Embeddings L2-normalised along dim 1, computed under no_grad.
        """
        with torch.no_grad():
            x = self.backbone(input_imgs.to(self.device))
            return F.normalize(x, p=2, dim=1)
