In [1]:
import argparse
import os
import numpy as np
from tqdm import tqdm
import faiss
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from dataloader import Data, DataQuery
from model import Extractor, MemoryBlock
from argument_parser import add_base_args, add_eval_args
from utils import split_labels,  compute_NDCG, get_target_attr
import constants as C

In [2]:
# Disabled alternative with 12 entries, kept for reference only:
#attr_nums = [16, 17, 19, 14, 10, 15, 2, 11, 16, 7, 9, 15]
# Per-group sizes handed to Extractor; the same three numbers drive the
# column split of the saved score matrix further down (37 + 81 + 84 = 202).
attr_num = [
    37,
    81,
    84,
]

In [3]:
# Build the extractor on an AlexNet backbone; its weights are restored from
# a checkpoint in the following cell.
model = Extractor(
    attr_num,
    backbone='alexnet',
    dim_chunk=340,
)

In [4]:
# Restore the trained extractor weights.
# map_location='cpu' makes the checkpoint loadable regardless of which device
# it was saved from; the model is still on the CPU here and is moved to the
# GPU in a later cell, so the restored values are the same.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
model.load_state_dict(
    torch.load('../models/DeepFashion/extractor_best.pkl', map_location='cpu')
)

<All keys matched successfully>

In [5]:
# Dataset root directory. The default is the original location; set the
# HM_DATA_ROOT environment variable to run the notebook on another machine
# without editing this cell. Joining with "" guarantees a trailing separator,
# matching the format of the original hard-coded string.
file_root = os.path.join(
    os.environ.get("HM_DATA_ROOT", "/home/jameslee/cmu/large-scale/fashion-iq/hm_data/"),
    "",
)
# Image directory lives directly under the dataset root (derived instead of
# repeating the absolute path a second time).
img_root_path = os.path.join(file_root, "images", "")

In [6]:
# Preprocessing applied to every gallery image: resize to a square of side
# C.TARGET_IMAGE_SIZE, convert to a tensor, then normalise with the mean/std
# constants from the project's constants module.
gallery_transform = transforms.Compose([
    transforms.Resize((C.TARGET_IMAGE_SIZE, C.TARGET_IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=C.IMAGE_MEAN, std=C.IMAGE_STD),
])

# Gallery dataset opened in 'test' mode.
gallery_data = Data(file_root, img_root_path, gallery_transform, mode='test')

In [7]:
# Batches must come out in dataset order: later cells pair row i of the
# extracted scores with line i of imgs_test.txt purely by position.
# NOTE(review): presumably Data(mode='test') enumerates that same file -- confirm.
in_order_sampler = torch.utils.data.SequentialSampler(gallery_data)
gallery_loader = torch.utils.data.DataLoader(
    gallery_data,
    batch_size=64,
    shuffle=False,
    sampler=in_order_sampler,
    num_workers=16,
    drop_last=False,  # keep the final partial batch so no image is skipped
)

In [8]:
# Move the model to the GPU and switch it to evaluation mode; both calls
# return the module itself, so the cell still displays the model summary.
model.cuda().eval()

Extractor(
  (backbone): AlexNet(
    (features): Sequential(
      (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
      (1): ReLU(inplace=True)
      (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
      (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (4): ReLU(inplace=True)
      (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
      (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (7): ReLU(inplace=True)
      (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (9): ReLU(inplace=True)
      (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (11): ReLU(inplace=True)
      (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
    (classifier): Sequential(
      (0): Dropout(p=0.5, inplace=False)
    

In [12]:
# Diagnostic: report whether PyTorch can see a CUDA device in this kernel.
torch.cuda.is_available()

True

In [13]:
# Diagnostic: number of CUDA devices visible to this kernel.
torch.cuda.device_count()

8

In [37]:
def _batch_rows(parts):
    """Concatenate model outputs along dim 1, normalise, return a NumPy array.

    ``F.normalize`` is called with its defaults, as in the original cell.
    ``squeeze()`` is kept so existing results are unchanged, but on its own it
    would also remove the batch axis whenever a batch contains a single image,
    yielding a 1-D array among 2-D ones. That axis is restored here so every
    batch contributes an array whose first axis is the batch.
    """
    joined = F.normalize(torch.cat(parts, 1))
    rows = joined.squeeze()
    if joined.shape[0] == 1:
        # Single-image batch: put the batch axis back.
        rows = rows.unsqueeze(0)
    return rows.cpu().numpy()


# Per-batch arrays, appended in loader order.
gallery_feat = []
gallery_attr = []
with torch.no_grad():  # inference only, no autograd bookkeeping
    for img, _ in tqdm(gallery_loader):
        img = img.cuda()
        dis_feat, attr_out = model(img)
        gallery_feat.append(_batch_rows(dis_feat))
        gallery_attr.append(_batch_rows(attr_out))


  0%|                                                                 | 0/1643 [00:00<?, ?it/s][A
  0%|                                                       | 1/1643 [00:03<1:27:35,  3.20s/it][A
  0%|▏                                                      | 5/1643 [00:03<1:01:24,  2.25s/it][A
  1%|▎                                                        | 9/1643 [00:03<43:07,  1.58s/it][A
  1%|▍                                                       | 13/1643 [00:03<30:20,  1.12s/it][A
  1%|▌                                                       | 17/1643 [00:05<25:04,  1.08it/s][A
  1%|▋                                                       | 20/1643 [00:05<17:51,  1.52it/s][A
  1%|▊                                                       | 24/1643 [00:05<12:42,  2.12it/s][A
  2%|▉                                                       | 27/1643 [00:05<09:10,  2.94it/s][A
  2%|█                                                       | 31/1643 [00:05<06:36,  4.06it/s][A
  2%|█▏  

 18%|█████████▊                                             | 294/1643 [00:43<03:26,  6.53it/s][A
 18%|█████████▉                                             | 298/1643 [00:44<02:43,  8.21it/s][A
 18%|██████████                                             | 302/1643 [00:44<02:04, 10.73it/s][A
 19%|██████████▏                                            | 305/1643 [00:45<04:28,  4.97it/s][A
 19%|██████████▎                                            | 308/1643 [00:45<04:04,  5.45it/s][A
 19%|██████████▍                                            | 310/1643 [00:46<03:15,  6.82it/s][A
 19%|██████████▌                                            | 314/1643 [00:46<02:41,  8.22it/s][A
 19%|██████████▋                                            | 318/1643 [00:46<02:03, 10.70it/s][A
 20%|██████████▋                                            | 321/1643 [00:47<04:21,  5.05it/s][A
 20%|██████████▊                                            | 324/1643 [00:48<03:56,  5.58it/s][A
 20%|█████

 34%|██████████████████▋                                    | 557/1643 [01:20<02:03,  8.79it/s][A
 34%|██████████████████▊                                    | 561/1643 [01:21<02:58,  6.07it/s][A
 34%|██████████████████▊                                    | 563/1643 [01:21<02:38,  6.82it/s][A
 34%|██████████████████▉                                    | 565/1643 [01:22<03:44,  4.79it/s][A
 35%|███████████████████                                    | 569/1643 [01:22<02:45,  6.48it/s][A
 35%|███████████████████▏                                   | 573/1643 [01:22<02:03,  8.63it/s][A
 35%|███████████████████▎                                   | 577/1643 [01:23<02:54,  6.12it/s][A
 35%|███████████████████▍                                   | 579/1643 [01:23<02:41,  6.58it/s][A
 35%|███████████████████▍                                   | 581/1643 [01:24<03:58,  4.46it/s][A
 36%|███████████████████▌                                   | 585/1643 [01:24<02:54,  6.05it/s][A
 36%|█████

 50%|███████████████████████████▍                           | 820/1643 [01:57<01:48,  7.56it/s][A
 50%|███████████████████████████▌                           | 822/1643 [01:58<03:51,  3.55it/s][A
 50%|███████████████████████████▋                           | 826/1643 [01:58<02:47,  4.87it/s][A
 51%|███████████████████████████▊                           | 830/1643 [01:58<02:03,  6.59it/s][A
 51%|███████████████████████████▉                           | 833/1643 [01:59<02:17,  5.87it/s][A
 51%|████████████████████████████                           | 837/1643 [02:00<02:57,  4.53it/s][A
 51%|████████████████████████████▏                          | 841/1643 [02:00<02:10,  6.14it/s][A
 51%|████████████████████████████▎                          | 845/1643 [02:00<01:37,  8.19it/s][A
 52%|████████████████████████████▍                          | 849/1643 [02:01<01:48,  7.35it/s][A
 52%|████████████████████████████▌                          | 853/1643 [02:03<02:37,  5.03it/s][A
 52%|█████

 69%|█████████████████████████████████████                 | 1129/1643 [02:41<01:20,  6.42it/s][A
 69%|█████████████████████████████████████▏                | 1133/1643 [02:41<00:59,  8.51it/s][A
 69%|█████████████████████████████████████▎                | 1137/1643 [02:41<00:45, 11.03it/s][A
 69%|█████████████████████████████████████▌                | 1141/1643 [02:43<01:41,  4.95it/s][A
 70%|█████████████████████████████████████▋                | 1145/1643 [02:43<01:14,  6.67it/s][A
 70%|█████████████████████████████████████▊                | 1149/1643 [02:43<00:55,  8.83it/s][A
 70%|█████████████████████████████████████▊                | 1152/1643 [02:43<00:45, 10.79it/s][A
 70%|█████████████████████████████████████▉                | 1155/1643 [02:43<00:36, 13.25it/s][A
 70%|██████████████████████████████████████                | 1158/1643 [02:45<01:57,  4.11it/s][A
 71%|██████████████████████████████████████▏               | 1162/1643 [02:45<01:26,  5.59it/s][A
 71%|█████

 88%|███████████████████████████████████████████████▊      | 1453/1643 [03:26<00:22,  8.53it/s][A
 89%|███████████████████████████████████████████████▉      | 1457/1643 [03:26<00:16, 11.03it/s][A
 89%|████████████████████████████████████████████████      | 1461/1643 [03:28<00:37,  4.80it/s][A
 89%|████████████████████████████████████████████████▏     | 1465/1643 [03:28<00:27,  6.46it/s][A
 89%|████████████████████████████████████████████████▎     | 1469/1643 [03:28<00:20,  8.52it/s][A
 90%|████████████████████████████████████████████████▍     | 1473/1643 [03:29<00:15, 11.02it/s][A
 90%|████████████████████████████████████████████████▌     | 1477/1643 [03:30<00:33,  5.00it/s][A
 90%|████████████████████████████████████████████████▋     | 1481/1643 [03:30<00:24,  6.75it/s][A
 90%|████████████████████████████████████████████████▊     | 1485/1643 [03:31<00:17,  8.91it/s][A
 91%|████████████████████████████████████████████████▉     | 1489/1643 [03:31<00:13, 11.54it/s][A
 91%|█████

In [41]:
# Sanity check: shape of the first batch's attribute array
# (one row per image in the batch, one column per concatenated output).
gallery_attr[0].shape

(64, 202)

In [42]:
# Stack all per-batch arrays into one matrix with a single call.
# Calling np.vstack inside a loop re-copies the growing matrix on every
# iteration (quadratic in the number of batches); passing the whole list
# produces the same result in one pass.
feats = np.vstack(gallery_attr)
print(feats.shape)

(105100, 202)


In [43]:
# Persist the stacked matrix so later cells can reload it without repeating
# the extraction loop.
# NOTE(review): despite the file name, `feats` is built from `gallery_attr`
# (the attribute outputs), not from the `gallery_feat` embeddings.
np.save("visual_features.npy", feats)

In [52]:
# Reload the matrix written by the np.save cell.
# NOTE(review): the stored values went through F.normalize over the
# concatenated outputs; whether they can be read as calibrated confidences
# depends on the model and is not visible here. The per-group argmax taken
# below is unaffected by that positive rescaling.
confidence = np.load("visual_features.npy")

In [55]:
# Column boundaries between the three groups: running totals of the group
# widths, without the final total (the last group takes the remaining columns).
group_edges = np.cumsum(attr_num)[:-1]
conf_1, conf_2, conf_3 = np.split(confidence, group_edges, axis=1)

conf_1.shape, conf_2.shape, conf_3.shape

((105100, 37), (105100, 81), (105100, 84))

In [59]:
# For each group keep, per row, the column index holding the largest score.
feat_1, feat_2, feat_3 = (
    scores.argmax(axis=1) for scores in (conf_1, conf_2, conf_3)
)

In [61]:
# Put the three per-row index vectors side by side: one row per image,
# one column per group.
index_columns = [feat_1, feat_2, feat_3]
vis_feat = np.stack(index_columns, axis=1)
print(vis_feat.shape)

(105100, 3)


In [47]:
# Read the test-split listing, one entry per line (newlines are kept, as
# readlines() returns them). The context manager closes the file even if
# reading fails; the path is built from `file_root` so the dataset location
# is defined in one place only.
with open(os.path.join(file_root, "imgs_test.txt"), "r") as f:
    Lines = f.readlines()

In [64]:
# Reduce each listing entry to the text before its first "." (drops the file
# extension). Whitespace is stripped first so that an entry containing no "."
# does not keep its trailing newline inside the id. The list length is left
# unchanged because rows are matched to the score matrix by position.
Lines = [entry.strip().split(".")[0] for entry in Lines]
print(Lines[0])

0685601024


In [66]:
# Ids as a NumPy array; the printed count should equal the number of rows in
# the score matrix, since the two are combined by position afterwards.
article_ids = np.array(Lines)
print(len(article_ids))

105100


In [50]:
import pandas as pd

In [68]:
# One row per listed image: its id plus the winning index of each group.
# Rows are aligned purely by position.
df = pd.DataFrame(
    data={
        'article_id': article_ids,
        'vis_feat_1': feat_1,
        'vis_feat_2': feat_2,
        'vis_feat_3': feat_3,
    }
)

In [73]:
# Write the table to disk without the positional index column.
# NOTE: article ids are strings with leading zeros (e.g. "0685601024");
# read the file back with dtype={'article_id': str} to keep them intact.
df.to_csv("visual_feat_attr.csv", index=False)