In [1]:
import torch
import torch.nn as nn
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
import glob
import json
import os
from tqdm.notebook import tqdm
import numpy as np
from torchvision import transforms
from PIL import Image
import cv2
import matplotlib.pyplot as plt
import torch.nn.functional as F

from model.triplet_model import FeatureAggregator

os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

In [3]:
class CustomDataset(Dataset):
    def __init__(self, img_dir: str, frame_ids: list, json_path: str, inp_size:int, orig_im_size=(1080, 1920)) -> None:
        self.img_dir = img_dir
        self.frame_ids = frame_ids
        with open(json_path, 'r') as f:
            self.data = json.load(f)
        self.inp_size = inp_size
        self.orig_h = orig_im_size[0]
        self.orig_w = orig_im_size[1]
        
    def __getitem__(self, idx: int) -> int:
        frame_id = self.frame_ids[idx]
        frame_info = self.data[frame_id.replace('_', '/')+'.jpg']
        
        jk_records = []
        sdcl_records = []
        cap_records = []
        bbox_records = []
        jk_imgs = []
        sdcl_imgs = []
        cap_imgs = []
        jk_bboxes = []
        sdcl_bboxes = []
        cap_bboxes = []
        sample_ids = []
        for grp_id, detect_results in frame_info.items():
            sample_ids.append(idx)
            
            # Create placeholders for non-existing data
            jk_records.append(0)
            sdcl_records.append(0)
            cap_records.append(0)
            jk_imgs.append(torch.zeros((3, self.inp_size, self.inp_size), dtype=torch.float32))
            sdcl_imgs.append(torch.zeros((3, self.inp_size, self.inp_size), dtype=torch.float32))
            cap_imgs.append(torch.zeros((3, self.inp_size, self.inp_size), dtype=torch.float32))
            jk_bboxes.append([-1, -1, -1, -1])
            sdcl_bboxes.append([-1, -1, -1, -1])
            cap_bboxes.append([-1, -1, -1, -1])
            
            for obj, bbox in detect_results.items():
                bbox = list(map(float, bbox))
                bbox[0] = bbox[0]/self.orig_w
                bbox[1] = bbox[1]/self.orig_h
                bbox[2] = bbox[2]/self.orig_w
                bbox[3] = bbox[3]/self.orig_h
                img = cv2.imread(os.path.join(self.img_dir, frame_id, obj, f'{grp_id}.jpg'))
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                img = Image.fromarray(img)
                img = self.preprocess(img)
                if obj == 'jockey':
                    jk_imgs[-1] = img
                    jk_bboxes[-1] = bbox
                    jk_records[-1] = 1
                if obj == 'sdcl':
                    sdcl_imgs[-1] = img
                    sdcl_bboxes[-1] = bbox
                    sdcl_records[-1] = 1
                if obj == 'cap':
                    cap_imgs[-1] = img
                    cap_bboxes[-1] = bbox
                    cap_records[-1] = 1
                    
        jk_imgs = torch.stack(jk_imgs, dim=0)
        sdcl_imgs = torch.stack(sdcl_imgs, dim=0)
        cap_imgs = torch.stack(cap_imgs, dim=0)
        jk_bboxes = torch.tensor(jk_bboxes)
        sdcl_bboxes = torch.tensor(sdcl_bboxes)
        cap_bboxes = torch.tensor(cap_bboxes)
                
        return jk_imgs, sdcl_imgs, cap_imgs, jk_bboxes, sdcl_bboxes, cap_bboxes, jk_records, sdcl_records, cap_records, sample_ids
    
    def preprocess(self, img: Image.Image) -> torch.tensor:
        w, h = img.size
        long_edge = max(w, h)
        resize_ratio = self.inp_size / long_edge
        resize_shape = (round(h*resize_ratio), round(w*resize_ratio))
        w_diff, h_diff = (self.inp_size - resize_shape[1]), (self.inp_size - resize_shape[0])
        l_pad = w_diff//2
        r_pad = w_diff - l_pad
        t_pad = h_diff//2
        b_pad = h_diff - t_pad
        padding = (l_pad, t_pad, r_pad, b_pad)

        transform = transforms.Compose([
            transforms.Resize(resize_shape),  # interpolation `BILINEAR` is applied by default
            transforms.Pad(padding=padding, fill=0, padding_mode='constant'),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        image = transform(img)

        return image
    
    def collate_fn(self, batch_data: list) -> tuple:
        jk_imgs, sdcl_imgs, cap_imgs, jk_bboxes, sdcl_bboxes, cap_bboxes, jk_records, sdcl_records, cap_records, sample_ids = zip(*batch_data)
        jk_imgs = torch.cat(jk_imgs)
        sdcl_imgs = torch.cat(sdcl_imgs)
        cap_imgs = torch.cat(cap_imgs)
        jk_bboxes = torch.cat(jk_bboxes)
        sdcl_bboxes = torch.cat(sdcl_bboxes)
        cap_bboxes = torch.cat(cap_bboxes)
        jk_records = [i for r in jk_records for i in r]
        sdcl_records = [i for r in sdcl_records for i in r]
        cap_records = [i for r in cap_records for i in r]
        sample_ids = [i for r in sample_ids for i in r]
        return jk_imgs, sdcl_imgs, cap_imgs, jk_bboxes, sdcl_bboxes, cap_bboxes, jk_records, sdcl_records, cap_records, sample_ids
    
    def __len__(self):
        return len(self.frame_ids)

In [4]:
def gen_triplet_dists(
    jk_features: torch.tensor, 
    sdcl_features: torch.tensor, 
    cap_features: torch.tensor, 
    jk_records: list, 
    sdcl_records: list, 
    cap_records: list, 
    sample_ids: list, 
    dist_metric: str
) -> tuple(torch.tensor, torch.tensor):
    '''
    s - len(sample_ids)
    
    jk_features (torch.float32): [s, f_dim]
    sdcl_features (torch.float32): [s, f_dim]
    cap_features (torch.float32): [s, f_dim]
    jk_records (List): length s
    sdcl_records (List): length s
    cap_records (List): length s
    sample_ids (List): length s
    dist_metric: 'cosine' or 'l2'
    '''
    anc = []
    pos = []
    neg = []
    for i, s_id in enumerate(sample_ids):
        jk_exists = jk_records[i]
        sdcl_exists = sdcl_records[i]
        cap_exists = cap_records[i]
        same_s_idcs = np.where(np.array(sample_ids) == s_id)[0]
        if len(same_s_idcs) < 2:
            continue
        
        if jk_exists and sdcl_exists:
            anc.append(jk_features[i])
            pos.append(sdcl_features[i])
            neg_id = np.random.choice(np.where(same_s_idcs != i)[0])
            obj = np.random.choice(['jk', 'sdcl', 'cap'])
            if obj == 'jk':
                neg.append(jk_features[neg_id])
            elif obj == 'sdcl':
                neg.append(sdcl_features[neg_id])
            elif obj == 'cap':
                neg.append(cap_features[neg_id])
                
        if sdcl_exists and cap_exists:
            anc.append(sdcl_features[i])
            pos.append(cap_features[i])
            neg_id = np.random.choice(np.where(same_s_idcs != i)[0])
            obj = np.random.choice(['jk', 'sdcl', 'cap'])
            if obj == 'jk':
                neg.append(jk_features[neg_id])
            elif obj == 'sdcl':
                neg.append(sdcl_features[neg_id])
            elif obj == 'cap':
                neg.append(cap_features[neg_id])
                
        if jk_exists and cap_exists:
            anc.append(jk_features[i])
            pos.append(cap_features[i])
            neg_id = np.random.choice(np.where(same_s_idcs != i)[0])
            obj = np.random.choice(['jk', 'sdcl', 'cap'])
            if obj == 'jk':
                neg.append(jk_features[neg_id])
            elif obj == 'sdcl':
                neg.append(sdcl_features[neg_id])
            elif obj == 'cap':
                neg.append(cap_features[neg_id])
    
    # torch.stack to avoid losing grad history
    anc = torch.stack(anc)
    pos = torch.stack(pos)
    neg = torch.stack(neg)
    
    if dist_metric == 'cosine':
        cos = nn.CosineSimilarity(dim=1, eps=1e-6)
        pos_dists = 1 - cos(anc, pos)
        neg_dists = 1 - cos(anc, neg)
    elif dist_metric == 'l2':
        l2 = nn.PairwiseDistance(p=2.0, eps=1e-06, keepdim=True)
        pos_dists = l2(anc, pos)
        neg_dists = l2(anc, neg)
    
    return pos_dists, neg_dists

In [6]:
json_path = './data/extracted_data.json'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
inp_size = 64
batch_size = 32
feat_dim = 256
lr = 0.001
n_epochs = 1000
dist_metric = 'l2'  # 'cosine' or 'l2'
margin = 250 if dist_metric == 'l2' else 2

model_dst = f'./triplet_models/triplet_{dist_metric}_{feat_dim}'
if not os.path.exists(model_dst):
    os.makedirs(model_dst)

# Train
train_img_dir = './data/all_extracted/train'
train_img_paths = glob.glob('./data/all_extracted/train/*/*/*.jpg')
train_frame_ids = list(set(map(lambda x: x.split('/')[-3], train_img_paths)))
train_set = CustomDataset(train_img_dir, train_frame_ids, json_path, inp_size)
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=4, collate_fn=train_set.collate_fn,
                          drop_last=True, prefetch_factor=batch_size//2)

# val
val_img_dir = './data/all_extracted/val'
val_img_paths = glob.glob('./data/all_extracted/val/*/*/*.jpg')
val_frame_ids = list(set(map(lambda x: x.split('/')[-3], val_img_paths)))
val_set = CustomDataset(val_img_dir, val_frame_ids, json_path, inp_size)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=True, num_workers=4, collate_fn=val_set.collate_fn,
                        drop_last=True, prefetch_factor=batch_size//2)

In [7]:
img_encoder = models.resnet18(pretrained=True).to(device)
img_encoder.train()
feat_agg = FeatureAggregator(img_encoder, feat_dim).to(device)

loss_fn = nn.MarginRankingLoss(margin=margin)
optim = torch.optim.Adam(feat_agg.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, 5, eta_min=0, last_epoch=-1, verbose=False)

In [8]:
def validate(val_loader, model):
    model.eval()
    val_pos_dist = 0
    val_neg_dist = 0
    val_dist = 0
    for bi, data in enumerate(val_loader):
        jk_imgs, sdcl_imgs, cap_imgs, jk_bboxes, sdcl_bboxes, cap_bboxes, jk_records, sdcl_records, cap_records, sample_ids = data
        jk_imgs, sdcl_imgs, cap_imgs, jk_bboxes, sdcl_bboxes, cap_bboxes = list(map(lambda x: x.to(device), [jk_imgs, sdcl_imgs, cap_imgs, jk_bboxes, sdcl_bboxes, cap_bboxes]))

        with torch.no_grad():
            jk_feat = model(jk_imgs, jk_bboxes)
            sdcl_feat = model(sdcl_imgs, sdcl_bboxes)
            cap_feat = model(cap_imgs, cap_bboxes)

        pos_dists, neg_dists = gen_triplet_dists(jk_feat, sdcl_feat, cap_feat, jk_records, sdcl_records, cap_records, sample_ids, dist_metric)

        val_pos_dist += pos_dists.cpu().mean().item()
        val_neg_dist += neg_dists.cpu().mean().item()
        val_dist += abs(pos_dists.cpu() - neg_dists.cpu()).mean().item()
        
    val_pos_dist /= (bi+1)
    val_neg_dist /= (bi+1)
    val_dist /= (bi+1)
    
    return val_pos_dist, val_neg_dist, val_dist

In [None]:
best_val_dist = 0
for ep in range(n_epochs):
    ep_loss = 0
    train_pos_dist = 0
    train_neg_dist = 0
    train_dist = 0
    pbar = tqdm(train_loader, desc=f'Epoch {ep+1}/{n_epochs}')
    feat_agg.train()
    for bi, data in enumerate(pbar):
        jk_imgs, sdcl_imgs, cap_imgs, jk_bboxes, sdcl_bboxes, cap_bboxes, jk_records, sdcl_records, cap_records, sample_ids = data
        jk_imgs, sdcl_imgs, cap_imgs, jk_bboxes, sdcl_bboxes, cap_bboxes = list(map(lambda x: x.to(device), [jk_imgs, sdcl_imgs, cap_imgs, jk_bboxes, sdcl_bboxes, cap_bboxes]))

        jk_feat = feat_agg(jk_imgs, jk_bboxes)
        sdcl_feat = feat_agg(sdcl_imgs, sdcl_bboxes)
        cap_feat = feat_agg(cap_imgs, cap_bboxes)

        pos_dists, neg_dists = gen_triplet_dists(jk_feat, sdcl_feat, cap_feat, jk_records, sdcl_records, cap_records, sample_ids, dist_metric)
        loss = loss_fn(pos_dists, neg_dists, torch.ones_like(pos_dists)*-1)

        optim.zero_grad()
        loss.backward()
        optim.step()

        ep_loss += loss.detach().item()
        train_pos_dist += pos_dists.cpu().mean().item()
        train_neg_dist += neg_dists.cpu().mean().item()
        train_dist += abs(pos_dists.cpu() - neg_dists.cpu()).mean()

        postfix = {
            'Ep Loss': f'{ep_loss/(bi+1):.2f}',
            'Tr Pos Dist': f'{train_pos_dist/(bi+1):.4f}',
            'Tr Neg Dist': f'{train_neg_dist/(bi+1):.4f}',
            'Tr Dist': f'{train_dist/(bi+1):.4f}'
        }
        pbar.set_postfix(postfix)
        
    scheduler.step()
        
    val_pos_dist, val_neg_dist, val_dist = validate(val_loader, feat_agg)
    
    if val_dist > best_val_dist:
        best_val_dist = val_dist
        torch.save(feat_agg, os.path.join(model_dst, f'feat_agg_{val_dist:.4f}.pth'))
        print(f'Model saved at epoch {ep}')
    
    print(f'Tr Loss: {ep_loss/(bi+1):.2f}\nTr Pos Dist: {train_pos_dist/(bi+1):.4f} | Tr Neg Dist: {train_neg_dist/(bi+1):.4f} | \
Tr Dist: {train_dist/(bi+1):.4f}\nVal Pos Dist: {val_pos_dist:.4f} | Val Neg Dist: {val_neg_dist:.4f} | Val Dist: {val_dist:.4f} (best: {best_val_dist:.4f})')

Epoch 1/1000:   0%|          | 0/250 [00:00<?, ?it/s]

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


Model saved at epoch 0
Tr Loss: 9.29
Tr Pos Dist: 11.7811 | Tr Neg Dist: 136.4640 | Tr Dist: 125.1837
Val Pos Dist: 34.3587 | Val Neg Dist: 194.5626 | Val Dist: 162.9382 (best: 162.9382)


Epoch 2/1000:   0%|          | 0/250 [00:00<?, ?it/s]