In [1]:
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision import transforms, models
import torch.nn as nn
import torch
import numpy as np

import PIL
import pickle

In [2]:
class ResNet(nn.Module):
    """Wrap a torchvision-style ResNet to expose its penultimate (``avgpool``) features.

    Args:
        net: model with an ``avgpool`` child module (e.g. ``models.resnet18``).
            It is switched to eval mode and moved to ``device``.
        dim: number of channels ``avgpool`` produces (512 for ResNet-18,
            2048 for ResNet-152); used to validate the extracted embedding.
        device: where to place ``net``. Defaults to ``'cuda'``, matching the
            previous hard-coded behaviour.

    Raises:
        ValueError: if ``net`` has no ``avgpool`` submodule.
    """

    def __init__(self, net, dim, device='cuda'):
        super(ResNet, self).__init__()
        # eval(): batch-norm uses running statistics and dropout is disabled,
        # so embeddings are deterministic for a given input.
        self.net = net.eval().to(device)
        self.dim = dim
        self.penult_layer = dict(self.net.named_children()).get('avgpool')
        if self.penult_layer is None:
            raise ValueError("net has no 'avgpool' submodule to hook")

    def forward(self, x):
        """Return the flattened ``(batch, dim)`` avgpool embedding of ``x``."""
        # The original passed ``self`` twice here, which raised TypeError.
        return self.get_embedding(x)

    def get_embedding(self, x):
        """Run ``x`` through the net and return the avgpool output as ``(batch, dim)``.

        Args:
            x: input batch, already on the same device as ``self.net``.

        Returns:
            float32 tensor of shape ``(x.shape[0], self.dim)``, detached from autograd.

        Raises:
            RuntimeError: if the forward pass never executes ``avgpool``.
            ValueError: if the flattened avgpool output does not have ``self.dim`` columns.
        """
        captured = []

        def _capture(module, inputs, output):
            captured.append(output.detach())

        hook = self.penult_layer.register_forward_hook(_capture)
        try:
            # Inference only: avoid building an autograd graph for the whole net.
            with torch.no_grad():
                self.net(x)
        finally:
            # Always unregister, otherwise a failed call would leave a stale
            # hook firing on every later forward pass.
            hook.remove()

        if not captured:
            raise RuntimeError("avgpool was not executed during the forward pass")
        # If avgpool runs more than once, keep the last output (as before).
        embedding = captured[-1].to(torch.float32)
        embedding = embedding.reshape(embedding.size(0), -1)
        if embedding.size(1) != self.dim:
            raise ValueError(
                "expected embedding width {}, got {}".format(self.dim, embedding.size(1)))
        return embedding

In [3]:
image_directory = 'data/Flickr'

# Per-channel ImageNet normalisation statistics used by the pretrained ResNets.
imagenet_mean = (0.485, 0.456, 0.406)
imagenet_std = (0.229, 0.224, 0.225)

image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std),
])

image_data = ImageFolder(root=image_directory, transform=image_transform)

# File path of every sample, indexed the same way as image_data.
image_from_idx = [path for path, _class_idx in image_data.samples]

In [4]:
# Dataset size: number of (path, class) samples ImageFolder discovered under image_directory.
len(image_data)

269648

In [5]:
batch_size = 32
# shuffle=False (the default, spelled out): loader order must follow
# image_data.samples so saved feature rows, concatenated in chunk order,
# line up with image_from_idx.
image_dataloader = DataLoader(dataset=image_data, batch_size=batch_size, shuffle=False)

In [6]:
# Each wrapper hooks its model's avgpool layer; dim is that layer's channel
# count (2048 for ResNet-152, 512 for ResNet-18).
resnet152 = ResNet(net=models.resnet152(pretrained=True), dim=2048)
resnet18 = ResNet(net=models.resnet18(pretrained=True), dim=512)

In [None]:
# Extract avgpool embeddings from both ResNets for every image and checkpoint
# them to disk every 100 batches. Each chunk file is named after the last batch
# index it contains; rows still buffered when the loop ends are written by the
# next cell.
targets = []
resnet18_feats = []
resnet152_feats = []

with torch.no_grad():  # inference only: don't build autograd graphs
    for batch_idx, (data, target) in enumerate(image_dataloader):
        images = data.cuda()
        embeddings_18 = resnet18.get_embedding(images).detach().cpu()
        embeddings_152 = resnet152.get_embedding(images).detach().cpu()
        # Iterate over the rows actually present. The final batch is smaller
        # than batch_size (269648 % 32 == 16), so range(batch_size) raised
        # IndexError there.
        targets.extend(int(label) for label in target)
        resnet18_feats.extend(row.numpy() for row in embeddings_18)
        resnet152_feats.extend(row.numpy() for row in embeddings_152)
        if batch_idx and not batch_idx % 100:
            chunk_name = "batch_{}.npy".format(str(batch_idx).zfill(5))
            for subdir, rows in (("targets", targets),
                                 ("resnet18", resnet18_feats),
                                 ("resnet152", resnet152_feats)):
                np.save("data/nuswide_features/{}/{}".format(subdir, chunk_name), rows)
            targets = []
            resnet18_feats = []
            resnet152_feats = []
        

In [8]:
# Save whatever the extraction loop buffered after its last 100-batch
# checkpoint. The chunk is named len(image_dataloader), one past the largest
# batch index, so it sorts after every checkpoint file.
batch_idx = len(image_dataloader)
# Skip when nothing is buffered: np.save of an empty list writes a (0,) array,
# which could break a later np.concatenate with the 2-D feature chunks.
if targets:
    chunk_name = "batch_{}.npy".format(str(batch_idx).zfill(5))
    for subdir, rows in (("targets", targets),
                         ("resnet18", resnet18_feats),
                         ("resnet152", resnet152_feats)):
        np.save("data/nuswide_features/{}/{}".format(subdir, chunk_name), rows)