In [1]:
%reload_ext autoreload
%autoreload 2

import os

import math
import time
import random
from datetime import datetime
from pathlib import Path

from glob import glob

import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import numpy as np

from PIL import Image
from tqdm import tqdm
import metric_eval

# Width of the embedding vector: used as out_features of the backbone's fc head.
METRIC_DIM = 512
# Side length (pixels) used by the Resize / CenterCrop evaluation transforms.
IMG_SIZE = 256

# Index of the CUDA device to run on when CUDA is available.
GPU_ID = 0
# Fall back to the CPU when CUDA is not available, instead of building a
# "cuda:N" device that only fails later when the model is moved onto it.
device = torch.device(f"cuda:{GPU_ID}" if torch.cuda.is_available() else "cpu")


# Architecture of the embedding network. No pretrained weights are requested
# here; MetricModel attaches a METRIC_DIM-wide fc head and loads the trained
# checkpoint from WEIGHTS_PATH.
BACKBONE = models.resnet50()


In [2]:
# Checkpoint (state dict for the backbone) that MetricModel loads.
WEIGHTS_PATH = 'weights_G1/20220213_212547_ep(004)_acc(0.0617).pt'
# Root of the evaluation images, laid out as <CROP_IMG_DIR>/<identity>/*.png.
CROP_IMG_DIR = './reid_data/eval'

In [3]:
class EvalDataset(Dataset):
    """Triplet dataset used to evaluate a metric-learning model.

    Images are discovered as ``{root_dir}/<identity>/*.png``; the name of the
    parent directory is the identity of the image. Every item is a
    (pivot, positive, negative) triplet: the positive is a different image of
    the pivot's identity, the negative is an image of another identity. Both
    are drawn at random on each access.

    Attributes:
        imgs: list of ``[label, path]`` pairs, ordered by path.
        trsf: callable applied to every loaded PIL image.
        num_labels: number of distinct identities found under ``root_dir``.
    """

    def __init__(self, root_dir, trsf):
        """Scan ``root_dir`` and assign an integer label to every identity.

        Args:
            root_dir: directory containing one sub-directory per identity.
            trsf: transform applied to each PIL image in ``__getitem__``.
        """
        img_paths = sorted(glob(f'{root_dir}/*/*.png'))

        # Sort the identity names so label ids are stable across runs
        # (iteration order of a set of strings is not), and use a dict so
        # each lookup is O(1) instead of a linear list.index() scan.
        pnames = sorted({Path(p).parent.name for p in img_paths})
        label_of = {name: label for label, name in enumerate(pnames)}

        self.imgs = [[label_of[Path(p).parent.name], p] for p in tqdm(img_paths)]
        self.trsf = trsf
        self.num_labels = len(pnames)
        self._build_label_index()

    def _build_label_index(self):
        """Group image paths by label so positives can be drawn directly."""
        paths_by_label = {}
        for label, path in self.imgs:
            paths_by_label.setdefault(label, []).append(path)
        self._paths_by_label = paths_by_label

    def select_random(self, idx, is_positive=True):
        """Return the path of a random partner image for item ``idx``.

        Args:
            idx: index of the pivot item in ``self.imgs``.
            is_positive: if true, pick another image with the pivot's label;
                otherwise pick an image with a different label.

        Returns:
            Path of the selected image (never the pivot's own path).

        Raises:
            ValueError: if no suitable image exists (the pivot's label has no
                other image, or the dataset has fewer than two labels).
        """
        pivot_label, pivot_path = self.imgs[idx]

        if is_positive:
            # Draw straight from the pivot's own identity rather than
            # rejection-sampling the whole dataset.
            candidates = [p for p in self._paths_by_label[pivot_label] if p != pivot_path]
            if not candidates:
                raise ValueError(
                    f'no positive sample for {pivot_path!r}: '
                    f'label {pivot_label} has no other image'
                )
            return random.choice(candidates)

        # Without a second label the rejection loop below would never end.
        if len(self._paths_by_label) < 2:
            raise ValueError(
                f'no negative sample for {pivot_path!r}: '
                'dataset contains fewer than two labels'
            )
        while True:
            label, path = random.choice(self.imgs)
            if label != pivot_label:
                return path

    def __len__(self):
        """Number of images (each one serves as a pivot once)."""
        return len(self.imgs)

    def _load(self, path):
        """Open ``path``, force 3-channel RGB and apply the transform."""
        # The context manager closes the file handle; convert() returns a
        # loaded copy, so the image stays usable after the file is closed.
        with Image.open(path) as img:
            return self.trsf(img.convert('RGB'))

    def __getitem__(self, idx):
        """Return ``(imgs, paths)`` for the triplet anchored at ``idx``.

        ``imgs`` and ``paths`` are 3-element lists ordered
        [pivot, positive, negative].
        """
        _, pivot_path = self.imgs[idx]
        pos_path = self.select_random(idx, is_positive=True)
        neg_path = self.select_random(idx, is_positive=False)
        paths = [pivot_path, pos_path, neg_path]
        imgs = [self._load(p) for p in paths]
        return imgs, paths

In [4]:
# Evaluation-time preprocessing: resize, centre-crop to IMG_SIZE, convert to a
# tensor and normalise each channel.
data_transforms = transforms.Compose(
    [
        transforms.Resize(size=int(IMG_SIZE * 1.0)),
        transforms.CenterCrop(size=IMG_SIZE),
        # NOTE(review): this second Resize follows a crop to IMG_SIZE and
        # looks redundant - confirm before removing.
        transforms.Resize(size=IMG_SIZE),
        transforms.ToTensor(),
        # NOTE(review): presumably the ImageNet channel statistics - confirm
        # they match what was used during training.
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [5]:
# EvalDataset draws positives/negatives with the `random` module and the
# loader shuffles, so without a fixed seed every run evaluates a different set
# of triplets and reports a different accuracy.
EVAL_SEED = 0
# Covers sampling done in the main process (num_workers=0).
random.seed(EVAL_SEED)
# Per the PyTorch DataLoader docs, `generator` drives the shuffle order and
# the base seed from which worker processes are seeded.
loader_rng = torch.Generator()
loader_rng.manual_seed(EVAL_SEED)

dataset = EvalDataset(CROP_IMG_DIR, data_transforms)
dataloader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4,
    generator=loader_rng,
)
print(len(dataset), dataset.num_labels)

100%|██████████| 5000/5000 [00:00<00:00, 76098.19it/s]

5000 500





In [6]:
class MetricModel(nn.Module):
    """Inference-only embedding model.

    Wraps a copy of the module-level BACKBONE whose ``fc`` layer is replaced
    by a METRIC_DIM-wide linear head, loads trained weights into it and
    returns L2-normalised embeddings.

    Attributes:
        backbone: the network that produces the raw embeddings.
        device: device the model was moved to; callers move inputs there.
    """

    def __init__(self, weights_path, device):
        """Build the network, load ``weights_path`` and switch to eval mode.

        Args:
            weights_path: path to a state dict saved for the backbone
                (including the replaced ``fc`` head).
            device: torch device to place the model on.
        """
        import copy  # local import keeps this cell self-contained

        super().__init__()

        # Work on a private copy: replacing ``fc`` and loading weights on the
        # shared BACKBONE object would make every MetricModel instance alias
        # the same module (and the same weights).
        backbone = copy.deepcopy(BACKBONE)
        backbone.fc = nn.Linear(in_features=backbone.fc.in_features, out_features=METRIC_DIM)
        self.backbone = backbone

        # Map the checkpoint to CPU first so it loads regardless of the device
        # it was saved from; the model is moved to ``device`` right after.
        weights = torch.load(weights_path, map_location='cpu')
        self.device = device
        self.backbone.load_state_dict(weights)
        self.to(device)
        self.eval()

    def forward(self, input_imgs):
        """Return unit-length embeddings for a batch of images.

        Gradients are disabled; the output is normalised along dim 1.
        """
        with torch.no_grad():
            x = self.backbone(input_imgs)
            x = F.normalize(x, dim=1)
        return x


model = MetricModel(WEIGHTS_PATH, device)

In [7]:
# Embed every triplet produced by the loader and collect the results.
total_outs = []
total_paths = []

for batch_imgs, batch_paths in tqdm(dataloader):
    # Each batch carries three image tensors and three path sequences, in
    # the [pivot, positive, negative] order returned by EvalDataset.
    embeddings = [model(part.to(model.device)) for part in batch_imgs]

    # Stack along dim 1 so each row groups one sample's three embeddings.
    total_outs.append(torch.stack(embeddings, dim=1))
    # Regroup the paths per sample: one 3-tuple for each row above.
    total_paths.extend(zip(*batch_paths))

total_outs = torch.cat(total_outs).cpu()

100%|██████████| 79/79 [00:23<00:00,  3.31it/s]


In [8]:
import pandas as pd


def _rowwise_dot(lhs, rhs):
    """Dot product of matching rows of two 2-D tensors, as a Python list."""
    return torch.bmm(lhs.unsqueeze(1), rhs.unsqueeze(2)).squeeze().tolist()


# Split the stacked embeddings into their pivot / positive / negative parts.
t_pivot, t_pos, t_neg = total_outs.unbind(dim=1)

# Turn the list of per-sample path triplets into three columns (computed once).
path_columns = list(zip(*total_paths))

df = pd.DataFrame({
    'p_pivot': path_columns[0],
    'p_pos': path_columns[1],
    'p_neg': path_columns[2],
})

# Similarity of the pivot embedding to its positive and to its negative.
df['sim_pos'] = _rowwise_dot(t_pivot, t_pos)
df['sim_neg'] = _rowwise_dot(t_pivot, t_neg)
display(df)

Unnamed: 0,p_pivot,p_pos,p_neg,sim_pos,sim_neg
0,./reid_data/eval/00814/IN_H00838_SN4_102701_23...,./reid_data/eval/00814/IN_H00838_SN4_102701_22...,./reid_data/eval/00742/IN_H00766_SN1_102401_19...,0.785030,0.118886
1,./reid_data/eval/00368/OUT_H00388_SN4_091803_2...,./reid_data/eval/00368/OUT_H00388_SN4_091803_2...,./reid_data/eval/00383/OUT_H00403_SN3_091801_2...,0.771903,-0.038777
2,./reid_data/eval/00641/OUT_H00664_SN3_100801_2...,./reid_data/eval/00641/OUT_H00664_SN3_100801_2...,./reid_data/eval/00667/OUT_H00691_SN1_100902_3...,0.921295,0.093174
3,./reid_data/eval/00202/OUT_H00221_SN1_082001_6...,./reid_data/eval/00202/OUT_H00221_SN1_082001_6...,./reid_data/eval/00705/IN_H00729_SN3_101901_11...,0.957880,-0.099255
4,./reid_data/eval/00362/OUT_H00382_SN2_091801_1...,./reid_data/eval/00362/OUT_H00382_SN2_091801_1...,./reid_data/eval/00133/OUT_H00152_SN3_081801_2...,0.554231,0.059862
...,...,...,...,...,...
4995,./reid_data/eval/00877/IN_H00901_SN1_102801_10...,./reid_data/eval/00877/IN_H00901_SN1_102801_10...,./reid_data/eval/00158/OUT_H00177_SN4_081901_1...,0.552042,-0.053040
4996,./reid_data/eval/00171/OUT_H00190_SN1_081901_1...,./reid_data/eval/00171/OUT_H00190_SN1_081901_1...,./reid_data/eval/00823/IN_H00847_SN1_102701_13...,0.782344,-0.112486
4997,./reid_data/eval/00832/IN_H00856_SN1_102701_11...,./reid_data/eval/00832/IN_H00856_SN1_102701_11...,./reid_data/eval/00123/OUT_H00142_SN1_081801_2...,0.888225,0.111939
4998,./reid_data/eval/00671/OUT_H00695_SN1_100901_3...,./reid_data/eval/00671/OUT_H00695_SN1_100901_3...,./reid_data/eval/00895/OUT_H00919_SN1_102901_2...,0.453123,-0.031296


In [9]:
# Compute accuracy
# NOTE(review): metric_eval.calc_acc is defined outside this notebook;
# presumably it compares the sim_pos and sim_neg columns of df - confirm.
acc = metric_eval.calc_acc(df)
print(f'acc:{acc:.04f}')

acc:0.9404
