In [1]:
import torch, torchvision
import json
import torch.nn.functional as F
import PIL
from PIL import Image
from torch.utils.data import Dataset, DataLoader

import os

# Run on the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Data locations. The defaults are the original author's local paths; on another
# machine set the ADARI_IMG_PATH / ADARI_SENTS_PATH environment variables instead
# of editing the notebook.
# Directory containing all raw images.
IMG_PATH = os.environ.get("ADARI_IMG_PATH", "/home/alex/CMU/777/ADARI/v2/full")
# JSON file whose top-level keys are used as the image file names to encode.
SENTS_PATH = os.environ.get(
    "ADARI_SENTS_PATH", "/home/alex/CMU/777/ADARI/ADARI_furniture_sents.json"
)

def open_json(path):
    """Read the JSON file at ``path`` and return the decoded Python object.

    The file is read as UTF-8 (the encoding JSON text is required to use) and is
    closed even if decoding raises.
    """
    with open(path, encoding="utf-8") as f:
        return json.load(f)

def save_json(file_path, data):
    """Serialize ``data`` as JSON into ``file_path``, overwriting any existing file.

    The file is closed even if serialization raises (e.g. on a non-JSON-serializable
    value), so no handle is leaked.
    """
    with open(file_path, "w", encoding="utf-8") as out_file:
        json.dump(data, out_file)

In [23]:
class EncoderCNN(torch.nn.Module):
    """Frozen feature extractor built from torchvision's pretrained ResNet-50.

    The classification head (the backbone's final child, ``fc``) is dropped, so
    ``forward`` returns the output of the last remaining layer instead of class
    logits.
    """

    def __init__(self):
        super().__init__()
        backbone = torchvision.models.resnet50(pretrained=True)
        # Switch the backbone to inference mode (BatchNorm uses running statistics).
        # NOTE: calling ``.train()`` on this module later would undo this.
        backbone.eval()
        # Keep every child except the trailing fc (classification) layer.
        *feature_layers, _fc = backbone.children()
        self.resnet = torch.nn.Sequential(*feature_layers)

    @torch.no_grad()
    def forward(self, images):
        """Return backbone features for ``images`` without recording gradients."""
        return self.resnet(images)
    
    
class ImageDataset(Dataset):
    """Dataset that loads images and cuts each one into a grid of square patches.

    Each item is ``(patches, name)``:
      * ``patches`` maps ``(row, col)`` grid coordinates to a tensor of shape
        ``(3, patch_size, patch_size)``;
      * ``name`` is the image path, built as ``path_to_images + "/" + file name``.

    Args:
        path_to_images: directory containing the image files.
        im_names: sequence of image file names, relative to ``path_to_images``.
        patch_size: side length in pixels of each square patch.
        img_size: images are resized (shorter side) and center-cropped to
            ``img_size x img_size`` before patching.
    """

    def __init__(self, path_to_images, im_names, patch_size = 8, img_size = 64):
        self.img_path = path_to_images
        self.images = im_names
        self.patch_size = patch_size
        self.img_size = img_size
        # Build the preprocessing pipeline once instead of on every __getitem__ call.
        # NOTE(review): ToTensor only scales pixels to [0, 1]; torchvision's pretrained
        # ImageNet weights expect mean/std normalization (transforms.Normalize) —
        # confirm whether un-normalized input is intended for the encoder.
        self.transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize(img_size),
            torchvision.transforms.CenterCrop(img_size),
            torchvision.transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.images)

    def _pad_to_multiple(self, img):
        """Zero-pad a (C, H, W) tensor so H and W become multiples of patch_size."""
        p = self.patch_size
        # Amount needed to reach the *next* multiple of p (0 when already divisible).
        pad_h = -img.shape[1] % p
        pad_w = -img.shape[2] % p
        # F.pad takes (left, right, top, bottom) for the last two dims; any odd
        # leftover pixel goes on the right/bottom.
        return F.pad(img, (pad_w // 2, pad_w - pad_w // 2,
                           pad_h // 2, pad_h - pad_h // 2))

    def __getitem__(self, index):
        name = self.img_path + "/" + self.images[index]

        with Image.open(name) as raw:
            # convert("RGB") maps grayscale / RGBA / palette images to the 3 channels
            # the ResNet encoder expects; it also forces the lazy load so the file
            # can be closed when the `with` exits.
            img = self.transform(raw.convert("RGB"))

        img = self._pad_to_multiple(img)

        p = self.patch_size
        patches = {}
        for i in range(img.shape[1] // p):
            for j in range(img.shape[2] // p):
                patches[(i, j)] = img[:, i * p:(i + 1) * p, j * p:(j + 1) * p]

        return patches, name

In [24]:

# Every image named in the sentences file becomes one dataset item; the loader
# walks them in file order, one image per batch. With batch_size=1 the default
# collate wraps each item's path string in a 1-element list.
im2sents = open_json(SENTS_PATH)
image_names = list(im2sents)
dataset = ImageDataset(IMG_PATH, image_names)
encoder = EncoderCNN().to(device)
dataloader = DataLoader(dataset, batch_size=1, shuffle=False, drop_last=False)



In [None]:
# Encode every patch of every image with the frozen encoder and write the features
# to JSON. Output schema: {image_path: {"<row>_<col>": [float, ...], ...}, ...}
image_embeddings = dict()  # image path -> {patch key -> feature list}
encoder.eval()
with torch.no_grad():
    for patch_batch, names in dataloader:
        # default_collate turns the single path string into a 1-element list
        # (batch_size=1); unwrap it so it can be used as a dict key.
        image_name = names[0]
        coords = list(patch_batch)
        if not coords:
            image_embeddings[image_name] = {}
            continue
        # Each collated patch has a leading batch dim of 1; concatenating gives one
        # (num_patches, C, p, p) batch, so the encoder runs once per image rather
        # than once per patch.
        stacked = torch.cat([patch_batch[c] for c in coords]).to(device)
        features = encoder(stacked).flatten(start_dim=1).cpu()
        # JSON objects need string keys and plain Python values: encode the
        # (row, col) tuple as "row_col" and each feature tensor as a list of floats.
        image_embeddings[image_name] = {
            f"{row}_{col}": feat.tolist()
            for (row, col), feat in zip(coords, features)
        }

with open("fashionbert_resnet_patches8x8.json", "w") as f:
    json.dump(image_embeddings, f)