In [4]:
import os
import pickle
import sys
sys.path.append("../") # adding root folder to the path

import torch 
import torchvision
from torchvision import transforms
from torchvision.models import *
from torchvision import datasets
from tqdm import tqdm

from MRL import *
from imagenetv2_pytorch import ImageNetV2Dataset
from argparse import ArgumentParser
from utils import *

# ---- Run configuration -------------------------------------------------------
# The three path settings at the bottom are empty placeholders; fill them in
# before running the cells below.
BATCH_SIZE = 1024           # DataLoader batch size
IMG_SIZE = 256              # argument given to transforms.Resize
CENTER_CROP_SIZE = 224      # argument given to transforms.CenterCrop

# Nested feature sizes: powers of two from 8 (2**3) through 2048 (2**11).
# Adjust the exponent range here to use a different nesting list.
NESTING_LIST = [1 << exponent for exponent in range(3, 12)]

ROOT = ""                   # dataset root; the 'train/' sub-folder is read from it
model_weight_path = ""      # checkpoint location handed to get_ckpt
output_dir = ""             # destination for the pickles and .pack files

In [5]:
# --- Model --------------------------------------------------------------------
# Build a ResNet-50 without pretrained weights. `weights=None` alone expresses
# that; the former extra positional `False` was the deprecated `pretrained`
# flag, which duplicated the request and triggered deprecation warnings.
model = resnet50(weights=None)
# Project helper called with extract_ft=True (presumably switches the network to
# emit nested features for every size in NESTING_LIST -- confirm in MRL/utils).
model = load_from_old_ckpt(model, False, NESTING_LIST, extract_ft=True)
apply_blurpool(model)  # return value is not used, so this is expected to modify `model` in place
model.load_state_dict(get_ckpt(model_weight_path)) # Since our models have a torch DDP wrapper, we modify keys to exclude first 7 chars.
model = model.cuda()
model.eval()

# --- Data ---------------------------------------------------------------------
# Deterministic preprocessing: resize, centre crop, tensor conversion, then
# per-channel normalisation.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
test_transform = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.CenterCrop(CENTER_CROP_SIZE),
    transforms.ToTensor(),
    normalize,
])

# os.path.join copes with ROOT given with or without a trailing separator
# (plain string concatenation produced e.g. "/data/imagenettrain/").
dataset = torchvision.datasets.ImageFolder(os.path.join(ROOT, 'train/'), transform=test_transform)
# Inverse of class_to_idx: integer target -> class folder name.
idx_to_class = {v: k for k, v in dataset.class_to_idx.items()}
# shuffle=False keeps the extracted vectors in dataset order.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=8, shuffle=False)

In [None]:
# Fail fast: make sure the output directory exists *before* the long forward
# pass, so that a missing directory cannot throw away the extracted features
# when the first file is opened for writing.
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

# ft_to_vecs[num_feat]                 -> list of every feature vector of that size
# ft_to_label_to_vecs[num_feat][label] -> the same vectors grouped by class name
ft_to_vecs = {ft : [] for ft in NESTING_LIST}
ft_to_label_to_vecs = {ft : {k : [] for k in dataset.class_to_idx} for ft in NESTING_LIST}

with torch.no_grad():
    for img_inputs, labels in tqdm(dataloader):
        logits = model(img_inputs.cuda())
        # One bulk conversion of the batch targets instead of .item() per element.
        label_names = [idx_to_class[label_idx] for label_idx in labels.tolist()]

        # logits[i] is paired with NESTING_LIST[i] and iterated one sample at a time.
        for i, num_feat in enumerate(NESTING_LIST):
            for ft_vec, label in zip(logits[i], label_names):
                # Per-vector copy to host memory: every stored tensor owns its
                # own storage, which keeps the pickles below compact.
                ft_vec = ft_vec.cpu()
                ft_to_vecs[num_feat].append(ft_vec)
                ft_to_label_to_vecs[num_feat][label].append(ft_vec)

# Raw Python-side dumps of both dictionaries.
with open(f"{output_dir}/ft_to_vecs.pkl", "wb") as file:
    pickle.dump(ft_to_vecs, file)

with open(f"{output_dir}/ft_to_label_to_vecs.pkl", "wb") as file:
    pickle.dump(ft_to_label_to_vecs, file)

# Per-size export through the project helper save_fvecs: one file holding all
# vectors of a size, plus one file per class in its own sub-directory.
for num_feat, vecs in ft_to_vecs.items():
    ft_dir = f"{output_dir}/ft_size_{num_feat}"
    os.makedirs(ft_dir, exist_ok=True)
    save_fvecs(f"{ft_dir}/ft_{num_feat}.pack", vecs)

    for label, label_vecs in ft_to_label_to_vecs[num_feat].items():
        out_dir = f"{ft_dir}/{label}"
        os.makedirs(out_dir, exist_ok=True)
        save_fvecs(f"{out_dir}/{label}.pack", label_vecs)

 73%|█████████████████████████████████████████████████████████████████████████████████████████████████████████▋                                      | 919/1252 [47:36<08:47,  1.58s/it]