In [1]:
%reload_ext autoreload
%autoreload 2

import copy
import math
import os
import random
import time
from datetime import datetime
from glob import glob
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

import metric_eval

METRIC_DIM = 512  # output width of the embedding layer that replaces the backbone's fc
IMG_SIZE = 256  # side length used by the resize / centre-crop transforms

GPU_ID = 0
# Use the requested GPU when CUDA is present, otherwise fall back to CPU so the
# notebook still runs (slowly) on a machine without a GPU instead of failing
# when the model is moved to the device.
device = torch.device(f"cuda:{GPU_ID}" if torch.cuda.is_available() else "cpu")


# ResNet-50 built with torchvision's default constructor arguments; the
# evaluation weights are loaded into it later from a checkpoint file.
BACKBONE = models.resnet50()


In [2]:
# Checkpoint file handed to MetricModel.
WEIGHTS_PATH = 'weights_G1/20220213_212547_ep(004)_acc(0.0617).pt'
# Root directory handed to EvalDataset; images are looked up one directory level below it.
CROP_IMG_DIR = './reid_data/eval'

In [3]:
class EvalDataset(Dataset):
    """Image dataset that labels each PNG by the name of its parent directory.

    Files are discovered with the pattern ``{root_dir}/*/*.png``. The distinct
    parent-directory names are sorted and numbered from 0, so a given directory
    receives the same integer label on every run.

    Args:
        root_dir: directory whose immediate sub-directories contain the PNG files.
        trsf: callable applied to each loaded PIL image in ``__getitem__``.

    Attributes:
        imgs: list of ``[label, path]`` pairs, ordered by sorted path.
        trsf: the transform passed to the constructor.
        num_labels: number of distinct parent directories found.
    """

    def __init__(self, root_dir, trsf):
        img_paths = sorted(glob(f'{root_dir}/*/*.png'))

        # Sorting the unique directory names makes the label numbering
        # deterministic (iteration order of a set of strings can change between
        # interpreter runs); the dict turns each lookup into O(1) instead of a
        # linear list scan per image.
        pnames = sorted({Path(p).parent.name for p in img_paths})
        label_of = {name: idx for idx, name in enumerate(pnames)}

        self.imgs = [[label_of[Path(p).parent.name], p] for p in tqdm(img_paths)]
        self.trsf = trsf
        self.num_labels = len(pnames)

    def __len__(self):
        """Return the number of images found under ``root_dir``."""
        return len(self.imgs)

    def __getitem__(self, idx):
        """Return ``(transformed_image, label, path)`` for the image at ``idx``."""
        label, path = self.imgs[idx]
        # Convert to RGB so grayscale, palette or RGBA PNGs still reach the
        # transform with three channels; the context manager closes the file
        # handle as soon as the pixels have been decoded.
        with Image.open(path) as raw:
            img = raw.convert('RGB')
        img = self.trsf(img)
        return img, label, path

In [4]:
# Preprocessing applied to every evaluation image, in the order listed.
data_transforms = transforms.Compose([
    # The 1.0 factor leaves the resize target equal to IMG_SIZE.
    transforms.Resize(size=int(IMG_SIZE * 1.0)),
    transforms.CenterCrop(size=IMG_SIZE),
    # NOTE(review): this second resize looks redundant after the centre crop
    # to the same size -- confirm before removing.
    transforms.Resize(size=IMG_SIZE),
    transforms.ToTensor(),
    # Per-channel mean / std; presumably the usual ImageNet statistics -- verify
    # against the values used when the checkpoint was trained.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [5]:
dataset = EvalDataset(CROP_IMG_DIR, data_transforms)
# Evaluation gains nothing from shuffling; a fixed order means the embeddings,
# labels and paths gathered from this loader line up identically on every run.
dataloader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=4)
print(len(dataset), dataset.num_labels)

100%|██████████| 5000/5000 [00:00<00:00, 77019.34it/s]

5000 500





In [6]:
class MetricModel(nn.Module):
    """Inference-only wrapper that maps image batches to normalised embeddings.

    A private copy of the module-level ``BACKBONE`` has its ``fc`` layer
    replaced by a linear layer with ``METRIC_DIM`` outputs, the state dict at
    ``weights_path`` is loaded into it, and the model is moved to ``device``
    and put in eval mode.

    Args:
        weights_path: path to a state dict (as written by ``torch.save``) whose
            keys match the backbone with the replaced ``fc`` layer.
        device: ``torch.device`` the model is moved to; also kept as
            ``self.device`` so callers can move inputs to the same place.
    """

    def __init__(self, weights_path, device):
        super().__init__()

        # Copy instead of editing BACKBONE in place: otherwise every instance
        # would alias (and keep mutating) the same shared global module.
        backbone = copy.deepcopy(BACKBONE)
        backbone.fc = nn.Linear(in_features=backbone.fc.in_features, out_features=METRIC_DIM)
        self.backbone = backbone
        self.device = device

        # Load onto CPU first so a checkpoint saved from a different GPU index
        # (or on a GPU machine) still loads; self.to(device) places it afterwards.
        # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
        weights = torch.load(weights_path, map_location='cpu')
        self.backbone.load_state_dict(weights)
        self.to(device)
        self.eval()

    def forward(self, input_imgs):
        """Return the backbone output for ``input_imgs`` after ``F.normalize``.

        Runs under ``torch.no_grad()``, so the result carries no autograd graph.
        ``F.normalize`` is called with its default arguments.
        """
        with torch.no_grad():
            x = self.backbone(input_imgs)
            x = F.normalize(x)
        return x
    
# Build the embedding model from the configured checkpoint on the selected device.
model = MetricModel(weights_path=WEIGHTS_PATH, device=device)

In [7]:
# Accumulators for the whole evaluation set, filled batch by batch.
total_outs = []
total_labels = []
total_paths = []

for imgs, labels, paths in tqdm(dataloader):
    imgs = imgs.to(model.device)
    outs = model(imgs)

    # Move each batch of embeddings to CPU as soon as it is produced so device
    # memory holds one batch at a time rather than the whole evaluation set.
    total_outs.append(outs.cpu())
    total_paths.extend(paths)
    total_labels.extend(labels.tolist())

# Stack the per-batch embeddings into a single CPU tensor.
total_outs = torch.cat(total_outs)

100%|██████████| 79/79 [00:07<00:00, 10.08it/s]


In [8]:
# Score the collected embeddings against their labels with the project's
# metric_eval helper, then print the result to four decimal places.
# NOTE(review): how calc_f1 defines the score is not visible in this notebook --
# see metric_eval for the definition.
f1_score = metric_eval.calc_f1(total_outs, total_labels)
print(f'f1_score:{f1_score:.04f}')

100%|██████████| 5000/5000 [01:15<00:00, 66.44it/s]

f1_score:0.8436



