In [1]:
# Standard library
import json
import os

# Third-party
import clip
import cv2 as cv
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sn
import torch
import torch.nn as nn
import torchvision.transforms as T
from sklearn.metrics import confusion_matrix
from torch.utils.data import DataLoader
from tqdm import tqdm

# Project-local (star imports keep their original relative order, since the
# later one can shadow names exported by the earlier one)
from models import *
from deepfashion_loader import *

%matplotlib widget

In [2]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
    torch.cuda.empty_cache()

# Configuration params
cfg = {'root'                : '../ifetch',
       'arch'                : 'resnet50',
       'train_dict_path'     : './preprocessed/deepfashion_inshop_train.pt',
       'train_features_path' : None, #'./preprocessed/deepfashion_inshop_train_' + 'resnet50' + '.pt',
       'dev_dict_path'       : './preprocessed/deepfashion_inshop_val.pt',
       'dev_features_path'   : None, #'./preprocessed/deepfashion_inshop_val_' + 'resnet50' + '.pt',
       'taxonomy_path'       : './deepfashion_taxonomy_hyperbolic.pt',
       'bbox_path'           : '../ifetch/deepfashion/in_shop/list_bbox_inshop.txt',
       'weights_path'        : './exp_hyperbolic_2/best_weights.pth',
       'batch_size'          : 1024}

# Convert images to tensors and normalise each channel with the commonly used
# ImageNet mean / std statistics.
data_T = T.Compose([T.ToTensor(), T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

# Dataset built from the dev/validation dictionary configured above.
samples = DeepFashionData(root=cfg['root'],
                          data_dict_path=cfg['dev_dict_path'],
                          features_path=cfg['dev_features_path'],
                          taxonomy_path=cfg['taxonomy_path'],
                          transforms=data_T,
                          bbox_path=cfg['bbox_path'])

# shuffle=False keeps the batches in dataset order, so embeddings collected
# from this loader line up with dataset indices.
loader = DataLoader(samples,
                    batch_size=cfg['batch_size'],
                    shuffle=False,
                    collate_fn=samples.collate_fn)

model = HyperbolicFeat()
model = model.to(device)
# map_location remaps the stored tensors onto the active device, so a
# checkpoint saved on a GPU still loads when only the CPU is available.
model.load_state_dict(torch.load(cfg['weights_path'], map_location=device))
model.eval();

In [4]:
# Draw a random sample and reload its image from disk as RGB.
i  = np.random.randint(len(samples))
im = Image.open(os.path.join(samples.root, samples.path[i])).convert(mode='RGB')

# Look up the annotation for this image; the first two fields of the record
# are skipped and the remainder is used as (left, top, right, bottom).
bbox_key = os.path.join('img', samples.gender[i], samples.cat[i], samples.id[i], samples.filename[i])
bbox     = samples.bbox[bbox_key][2:]
left, top, right, bottom = bbox[0], bbox[1], bbox[2], bbox[3]
width  = right - left
height = bottom - top

# Pad boxes narrower/shorter than 100 px symmetrically towards 100 px, then
# clamp the padded box to the image bounds.
img_w, img_h = im.size[0], im.size[1]
margin_x = max((100 - width) // 2, 0)
margin_y = max((100 - height) // 2, 0)
new_bbox = [max(left - margin_x, 0),
            max(top - margin_y, 0),
            min(right + margin_x, img_w),
            min(bottom + margin_y, img_h)]

# Annotate a NumPy copy so the PIL image used for cropping stays clean:
# red = original box, green = padded box.
imnp = np.array(im)
cv.rectangle(imnp, (left, top), (right, bottom), (255,0,0), 2)
cv.rectangle(imnp, (new_bbox[0], new_bbox[1]), (new_bbox[2], new_bbox[3]), (0,255,0), 2)

# Panels: annotated image | original crop | padded crop.
plt.close('all')
plt.figure(figsize=(10,7))

plt.subplot(1,3,1)
plt.imshow(imnp)

plt.subplot(1,3,2)
plt.imshow(im.crop(bbox))

plt.subplot(1,3,3)
plt.imshow(im.crop(new_bbox))

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<matplotlib.image.AxesImage at 0x7f441937bd90>

In [150]:
# All node labels of the taxonomy, in the mapping's own order.
proto_node_labels = list(samples.taxonomy.keys())

# Category prototypes are the nodes whose label has exactly two '/'-separated
# parts (e.g. 'WOMEN/Dresses'), i.e. exactly one '/'.
proto_cat_labels = [label for label in proto_node_labels if label.count('/') == 1]

# Stack the selected prototype vectors row-wise, in the same order as
# proto_cat_labels, and move them to the compute device.
proto_cat_embed = torch.stack([samples.taxonomy[name] for name in proto_cat_labels]).to(device)

print(proto_cat_labels)

['WOMEN/Dresses', 'WOMEN/Skirts', 'WOMEN/Blouses_Shirts', 'MEN/Sweatshirts_Hoodies', 'WOMEN/Cardigans', 'WOMEN/Jackets_Coats', 'WOMEN/Sweaters', 'WOMEN/Tees_Tanks', 'WOMEN/Shorts', 'WOMEN/Rompers_Jumpsuits', 'WOMEN/Graphic_Tees', 'WOMEN/Pants', 'MEN/Shorts', 'MEN/Sweaters', 'MEN/Denim', 'MEN/Tees_Tanks', 'WOMEN/Sweatshirts_Hoodies', 'MEN/Pants', 'WOMEN/Denim', 'MEN/Jackets_Vests', 'WOMEN/Leggings', 'MEN/Shirts_Polos', 'MEN/Suiting']


# Category Classification

In [153]:
# Nearest-prototype category classification over the whole loader.
# For every image the model embedding is compared against each category
# prototype and the closest prototype's label becomes the prediction. The
# per-batch embeddings are also collected in `embeddings`.
true_cat_labels = []
true_pid_labels = []  # not populated in this cell; retained in case other cells read it
pred_cat_labels = []

# Diagonal signature (+1, ..., +1, -1) applied before the dot product. Its
# length is taken from the prototypes, which must share the embedding size for
# the matmul below to be valid.
metric = torch.ones(proto_cat_embed.shape[-1]).to(device)
metric[-1] = -1

total = 0  # unused in this cell; retained in case other cells read it
embeddings = []

gender_idx2word = samples.voc['gender']._idx2word
cat_idx2word    = samples.voc['cat']._idx2word

with torch.no_grad():
    for batch_dict in tqdm(loader):
        batch_gender_idx = batch_dict['gender']
        batch_cat_idx    = batch_dict['cat']
        batch_img        = batch_dict['img'].to(device)

        batch_embed = model(batch_img)
        embeddings.append(batch_embed)

        # acosh is only defined on [1, inf). Floating-point round-off can push
        # the negated inner product slightly below 1, which would yield NaN
        # distances and corrupt argmin, so clamp to the domain boundary first.
        neg_inner = -torch.matmul(metric * batch_embed, proto_cat_embed.transpose(1,0))
        distances = torch.acosh(torch.clamp(neg_inner, min=1.0))

        pred_cat_idx = torch.argmin(distances, dim=1, keepdim=False)

        # Convert predictions to plain ints once instead of indexing the label
        # list with 0-d tensors; `row` also avoids clobbering the notebook's `i`.
        for row, proto_idx in enumerate(pred_cat_idx.tolist()):
            pred_cat_label = proto_cat_labels[proto_idx]
            true_cat_label = (gender_idx2word[batch_gender_idx[row].item()]
                              + '/'
                              + cat_idx2word[batch_cat_idx[row].item()])

            true_cat_labels.append(true_cat_label)
            pred_cat_labels.append(pred_cat_label)

100%|██████████| 13/13 [00:44<00:00,  3.46s/it]


In [145]:
# Distances from the first sample of the last processed batch to every
# category prototype.
distances[0]

tensor([2.3013, 2.2784, 2.2992, 4.2808, 2.2778, 2.2811, 2.2828, 0.2809, 2.2814,
        2.2798, 2.2742, 2.2816, 4.2812, 4.2807, 4.2804, 4.2833, 2.2772, 4.2810,
        2.2766, 4.2806, 2.2755, 4.2808, 4.2800], device='cuda:0')

In [141]:
# Accuracy = fraction of samples whose predicted label string equals the
# ground-truth label string.
label_matches = np.array(true_cat_labels) == np.array(pred_cat_labels)
acc = np.mean(label_matches)

print(f"Acc (%) : {100 * acc}")

Acc (%) : 77.98921661909293


In [142]:
# Confusion matrix over the category prototypes. normalize='true' divides each
# row by the number of true samples of that class, so the diagonal holds the
# per-class recall. Rows are true labels, columns are predicted labels.
# (confusion_matrix, sn and pd are imported in the notebook's import cell.)
cm = confusion_matrix(true_cat_labels, pred_cat_labels, normalize='true', labels=proto_cat_labels)
df_cm = pd.DataFrame(cm, index=proto_cat_labels, columns=proto_cat_labels)

plt.figure(figsize=(12, 9))

# NOTE: sn.set changes the global seaborn/matplotlib theme, so figures drawn
# by later cells are affected as well.
sn.set(font_scale=.5)
ax = sn.heatmap(df_cm, annot=True, fmt='.2f')
ax.set(xlabel='Predicted category',
       ylabel='True category',
       title='Row-normalised category confusion matrix');

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<AxesSubplot:>

In [154]:
# Join the per-batch embedding tensors along the batch axis into a single
# retrieval pool.
pool = torch.cat(embeddings, dim=0)

In [177]:
# Retrieval demo: embed a random query image and show its K nearest
# neighbours from `pool`.
K = 4
i = np.random.randint(len(samples))
out = samples[i]
query_img = out[-2]

# Inference only: without no_grad the model forward builds an autograd graph
# (the printed scores used to carry a grad_fn), wasting memory for no benefit.
with torch.no_grad():
    img_tensor = samples.transforms(query_img).unsqueeze(0).to(device)
    img_embed  = model(img_tensor)

    # acosh is only defined on [1, inf); clamp so that round-off in the inner
    # product cannot produce NaN scores.
    neg_inner = -torch.matmul(metric * img_embed, pool.transpose(1,0))
    scores = torch.acosh(torch.clamp(neg_inner, min=1.0)).cpu().squeeze()

# Ask for K+1 matches: the query is drawn from the same dataset the pool was
# built from, so the closest hit is expected to be the query itself.
match_scores, match_idx = torch.topk(scores, K+1, largest=False)

print(match_scores)

plt.close('all')
fig = plt.figure(figsize=(14,3))
plt.subplot(1,K+1,1)
plt.imshow(query_img)
plt.axis('off')
# Skip the first match (expected self-match) and plot the remaining K. A
# dedicated loop variable keeps `i` pointing at the sampled query index.
for panel, neighbour_idx in enumerate(match_idx[1:].tolist(), start=2):
    plt.subplot(1,K+1,panel)
    plt.imshow(samples[neighbour_idx][-2])
    plt.axis('off')

tensor([0.0000, 0.0398, 0.0554, 0.0594, 0.0620], grad_fn=<TopkBackward>)


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [136]:
# Fetch a random dataset item, print its second and first fields (category
# and gender, judging by the recorded output) and show its image.
i = np.random.randint(len(samples))
out = samples[i]
print(out[1], out[0], sep='\n')

plt.figure()
plt.imshow(out[-2])
plt.axis('off')

Tees_Tanks
WOMEN


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

(-0.5, 223.5, 223.5, -0.5)