In [1]:
import os
import umap
import numba 

import torch
import numpy as np

import seaborn as sns
import torchvision.transforms as T
from torch.utils.data import DataLoader
from tqdm import tqdm

import matplotlib.pyplot as plt

from models import *
from deepfashion_loader import *

%matplotlib widget

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
    torch.cuda.empty_cache()
        
# Configuration params
cfg = {'root'                : '../ifetch',
       'arch'                : 'resnet50',
       'train_dict_path'     : './preprocessed/deepfashion_inshop_train.pt',
       'train_features_path' : None, #'./preprocessed/deepfashion_inshop_train_' + 'resnet50' + '.pt',
       'dev_dict_path'       : './preprocessed/deepfashion_inshop_val.pt',
       'dev_features_path'   : None, #'./preprocessed/deepfashion_inshop_val_' + 'resnet50' + '.pt',
       'taxonomy_path'       : './deepfashion_taxonomy_hyperbolic.pt',
       'bbox_path'           : '../ifetch/deepfashion/in_shop/list_bbox_inshop.txt',
       'batch_size'          : 512}

    
data_T = T.Compose([T.ToTensor(), T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

train_samples = DeepFashionData(root=cfg['root'],
                                data_dict_path=cfg['train_dict_path'],
                                features_path=cfg['train_features_path'],
                                taxonomy_path=cfg['taxonomy_path'],
                                bbox_path=cfg['bbox_path'],
                                transforms=data_T)

train_loader = DataLoader(train_samples,
                          batch_size=cfg['batch_size'],
                          shuffle=True,
                          collate_fn=train_samples.collate_fn)

dev_samples = DeepFashionData(root=cfg['root'],
                              data_dict_path=cfg['dev_dict_path'],
                              features_path=cfg['dev_features_path'],
                              taxonomy_path=cfg['taxonomy_path'],
                              transforms=data_T,
                              bbox_path=cfg['bbox_path'])

dev_loader = DataLoader(dev_samples,
                        batch_size=cfg['batch_size'],
                        shuffle=False,
                        collate_fn=dev_samples.collate_fn)

model = HyperbolicFeat().to(device);
model.load_state_dict(torch.load('./exp_hyperbolic_3/best_weights.pth'))
model.eval()

HyperbolicFeat(
  (backbone): Resnet50Feat(
    (backbone): ResNet(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu):

In [3]:
embeddings = []
categories = []
genders    = []

with torch.no_grad():
    for i_batch, batch_dict in enumerate( tqdm(train_loader) ):
        node_emb = batch_dict['node_embed'].to(device)
        cat      = batch_dict['cat'].to(device)
        img      = batch_dict['img'].to(device)
        gender   = batch_dict['gender'].to(device)
        embed    = model(img)

        categories.append(cat.unsqueeze(1))
        embeddings.append(embed)
        genders.append(gender)
        
embeddings = torch.cat(embeddings).cpu().numpy()
categories = torch.cat(categories).cpu().numpy()
genders    = torch.cat(genders).cpu().numpy()

100%|██████████| 51/51 [01:10<00:00,  1.38s/it]


In [4]:
@numba.njit()
def h_dist(a,b):
    return np.arccosh(-(a @ b) + 2.0 * a[-1] * b[-1])

reducer = umap.UMAP(n_neighbors=120, min_dist=1.0, spread=1.0, metric=h_dist)

u = reducer.fit_transform(embeddings)

  warn(


In [5]:
full_label = []
for i in range(categories.shape[0]):
    full_label.append(dev_samples.voc['gender']._idx2word[genders[i]] + "/" + dev_samples.voc['cat']._idx2word[categories[i].item()])

uniques = np.unique(full_label)
N       = len(uniques)    

plt.figure(figsize=(10,10))
sns.set_theme(style='dark')
with sns.color_palette('tab20', N):
    for k in range(N):
        ii = [i for i, s in enumerate(full_label) if s==uniques[k]]
        plt.scatter(u[ii,0], u[ii,1], s=7.0, edgecolor='none', label=full_label[ii[0]]);

lgd = plt.legend(fontsize=6, loc="lower left", markerscale=3, frameon=False)
#plt.axis('off')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [19]:
uniques

array(['MEN/Denim', 'MEN/Jackets_Vests', 'MEN/Pants', 'MEN/Shirts_Polos',
       'MEN/Shorts', 'MEN/Suiting', 'MEN/Sweaters',
       'MEN/Sweatshirts_Hoodies', 'MEN/Tees_Tanks',
       'WOMEN/Blouses_Shirts', 'WOMEN/Cardigans', 'WOMEN/Denim',
       'WOMEN/Dresses', 'WOMEN/Graphic_Tees', 'WOMEN/Jackets_Coats',
       'WOMEN/Leggings', 'WOMEN/Pants', 'WOMEN/Rompers_Jumpsuits',
       'WOMEN/Shorts', 'WOMEN/Skirts', 'WOMEN/Sweaters',
       'WOMEN/Sweatshirts_Hoodies', 'WOMEN/Tees_Tanks'], dtype='<U25')