In [None]:
import os
import sys
sys.path.append("../") # adding root folder to the path

import torch 
import torchvision
from torchvision import transforms
from torchvision.models import *
from torchvision import datasets
from tqdm import tqdm

from MRL import *
from imagenetv2_pytorch import ImageNetV2Dataset
from argparse import ArgumentParser
from utils import *

# Data-loading / preprocessing settings.
BATCH_SIZE = 1024
IMG_SIZE = 256  # short-side resize applied before center-cropping
CENTER_CROP_SIZE = 224

# Nested representation sizes evaluated below: 8, 16, ..., 2048 (powers of two
# by default); edit this list to change which prefixes are inspected.
NESTING_LIST = [1 << exponent for exponent in range(3, 12)]

# Paths — both must be filled in before running.
ROOT = ""  # dataset root directory; note the loader reads its train/ subfolder
model_weight_path = ""  # trained checkpoint passed to get_ckpt

In [None]:
class FeatureExtractor(nn.Module):
    """Expose a backbone's penultimate-layer output as a flat feature vector.

    Every direct child of ``base_model`` except the last one (the classifier
    head, e.g. ``fc`` for a torchvision ResNet) is chained, in order, into
    ``self.features``. ``forward`` runs that chain and flattens everything
    after the batch dimension.

    Args:
        base_model: module whose ``children()`` are applied sequentially.
    """

    def __init__(self, base_model):
        super().__init__()
        all_children = list(base_model.children())
        headless = all_children[:-1]  # drop the final (classification) layer
        self.features = nn.Sequential(*headless)

    def forward(self, x):
        """Return ``features(x)`` flattened to shape ``(batch, -1)``."""
        pooled = self.features(x)
        return pooled.flatten(start_dim=1)

# Build the ResNet-50 backbone (randomly initialised — the trained weights are
# loaded from model_weight_path below). `weights=None` alone already means
# "no pretrained weights"; the old positional `pretrained` flag is deprecated.
model = resnet50(weights=None)
# Project helper; with extract_ft=True the model presumably returns the nested
# feature vectors for NESTING_LIST instead of class logits — TODO confirm in MRL.
model = load_from_old_ckpt(model, False, NESTING_LIST, extract_ft=True)
apply_blurpool(model)
# The checkpoint was saved from a DDP-wrapped model; per the original note,
# get_ckpt strips the first 7 characters of each key so they match this
# unwrapped model.
model.load_state_dict(get_ckpt(model_weight_path))
model = model.cuda()
model.eval()



# Deterministic evaluation preprocessing: resize the short side to IMG_SIZE,
# center-crop to CENTER_CROP_SIZE, convert to a tensor, and normalize with the
# standard ImageNet per-channel mean/std.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
test_transform = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.CenterCrop(CENTER_CROP_SIZE),
    transforms.ToTensor(),
    normalize,
])

# os.path.join works whether or not ROOT ends with a separator (and with the
# empty default); plain `ROOT + 'train/'` turned "/data/imagenet" into
# "/data/imagenettrain/".
dataset = torchvision.datasets.ImageFolder(os.path.join(ROOT, 'train'), transform=test_transform)
# shuffle=False keeps batches in dataset index order so extracted features can
# be matched back to dataset.samples.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=8, shuffle=False)

In [6]:
# Sanity check on a single batch: run it through the model and confirm that
# each nested representation, truncated to its NESTING_LIST width, has exactly
# that many columns. Only one batch is needed, so take it directly instead of
# looping (and breaking) under a length-less progress bar.
with torch.no_grad():
    img_input, target = next(iter(dataloader))
    # Named "logits" historically; with extract_ft=True these appear to be
    # per-granularity feature tensors indexed by position in NESTING_LIST.
    logits = model(img_input.cuda())
    for i, num_feat in enumerate(NESTING_LIST):
        ft_vec = logits[i][:, :num_feat]
        assert ft_vec.shape[1] == num_feat, (
            f"nesting level {i}: expected {num_feat} features, got {ft_vec.shape[1]}"
        )
        print(ft_vec.shape)

0it [00:05, ?it/s]

torch.Size([1024, 8])
torch.Size([1024, 16])
torch.Size([1024, 32])
torch.Size([1024, 64])
torch.Size([1024, 128])
torch.Size([1024, 256])
torch.Size([1024, 512])
torch.Size([1024, 1024])
torch.Size([1024, 2048])



