In [1]:
# Standard library
import collections
import itertools
import json
import math
import numbers
import os
import pdb  # NOTE(review): appears unused in the cells shown; leftover debugging import?
import pickle
import re
import sys
import time

# Third-party
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchvision.models as models
from IPython.display import clear_output
from PIL import Image
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets
# True when a CUDA GPU is available; used below to choose the DataLoader's num_workers.
cuda = torch.cuda.is_available()
cuda

False

In [2]:
def update_progress(progress, bar_length=20):
    """Redraw a one-line text progress bar in the notebook output cell.

    Args:
        progress: completion fraction. Any real number is accepted (int, float,
            numpy scalars such as np.float32, ...). Values are clamped to [0, 1];
            non-numeric values and NaN are displayed as 0%.
        bar_length: number of characters in the bar (default 20).
    """
    # numbers.Real also covers numpy scalar types that are not float subclasses
    # (e.g. np.float32), which previously fell through and were shown as 0%.
    if isinstance(progress, numbers.Real):
        progress = float(progress)
    else:
        progress = 0.0
    # NaN fails every comparison, so clamping alone would keep it and then crash in round().
    if math.isnan(progress):
        progress = 0.0
    progress = min(max(progress, 0.0), 1.0)

    block = int(round(bar_length * progress))
    clear_output(wait=True)  # replace the previous bar instead of appending a new line
    text = "Progress: [{0}] {1:.1f}%".format( "#" * block + "-" * (bar_length - block), progress * 100)
    print(text)

#### Fetch Data

In [3]:
# Root of the ADARI dataset checkout; every path below is built relative to it,
# so relocating the data only requires changing this one constant.
ADARI_DIR = "../ADARI"

# IMAGES
im_path_fur = os.path.join(ADARI_DIR, "images/furniture/v2/thumbs/small") # small are 64x64, medium 256x256 and large 512x512

# JSON_FILES
data_path_fur = os.path.join(ADARI_DIR, "json_files/cleaned/ADARI_v2/furniture_v2_c.json")

# WORD EMBEDDINGS
word_embeddings_path = os.path.join(ADARI_DIR, "word_embeddings/fur_5c_50d_sk_glove_ft.json")

# FILES FOR DATALOADER (all live in the same directory)
ADARI_DSET_DIR = os.path.join(ADARI_DIR, "json_files/ADARI_FUR_images_sentences_words")
dset_words_p = os.path.join(ADARI_DSET_DIR, "ADARI_v2_FUR_images_words.json")
dset_sentences_p = os.path.join(ADARI_DSET_DIR, "ADARI_v2_FUR_images_sentences.json")
dset_sentences_POS_p = os.path.join(ADARI_DSET_DIR, "ADARI_v2_FUR_images_sentences_tokenized.json")

### Load files

In [4]:
# Generic JSON loader used for the word-embedding and dataset files below.
def open_json(path):
    """Load and return the JSON document stored at `path`.

    The file is decoded as UTF-8 (the encoding JSON requires per RFC 8259)
    rather than the platform default, and the `with` block guarantees the
    handle is closed even if parsing raises.
    """
    with open(path, encoding="utf-8") as f:
        return json.load(f)

def flatten(S):
    """Return a new flat list of all non-list items in the (possibly nested) list `S`.

    Items appear in depth-first, left-to-right order. Only `list` instances are
    expanded; strings, tuples, dicts, etc. are kept as single items.

    Implemented with an explicit stack of iterators. The previous recursive
    version recursed once per element (and copied the tail at every step), so
    any list longer than roughly the interpreter's recursion limit, ~1000 items,
    raised RecursionError and the work was O(n^2).
    """
    flat = []
    stack = [iter(S)]
    while stack:
        for item in stack[-1]:
            if isinstance(item, list):
                # Descend into the sublist; resume the current level after it is exhausted.
                stack.append(iter(item))
                break
            flat.append(item)
        else:
            # Current level fully consumed.
            stack.pop()
    return flat

#### Word embeddings 

In [5]:
# Load the word-embedding JSON file configured in `word_embeddings_path`.
word_embs = open_json(word_embeddings_path)

#### Dataset / dataloader source files

In [6]:
# Per-image annotation files. dset_words feeds the im2idx / im_words lookups
# built below; im2idx itself is constructed in a later cell, not loaded from disk.
dset_words = open_json(dset_words_p)
dset_sents = open_json(dset_sentences_p)
dset_sents_tokenized = open_json(dset_sentences_POS_p)

#### Build lookup dictionaries: `im2idx` (image name → index) and, via a small parser, `im_words` (image name → list of words)

In [7]:
def is_date(word):
    """Tell whether `word` contains an English ordinal such as "1st", "2nd", "3rd" or "4th".

    Matching is case-insensitive and the ordinal may appear anywhere inside
    `word` (e.g. "21ST" and "on4th" both count). Used to drop date-like tokens.
    """
    ordinal_pattern = r"[0-9]+(?:st|[nr]d|th)"
    return re.search(ordinal_pattern, word, flags=re.I) is not None

In [8]:
# Lookups consumed by the dataset / dataloader below:
#   im2idx   : image name -> position of that image in dset_words
#   im_words : image name -> flattened word list with noise tokens removed
im2idx = dict()
im_words = dict()

# Temp lists (same order as dset_words)
image_names = list(dset_words.keys())
words = list(dset_words.values())

# Keep a word unless it is a stray '"the' / '"The' token, has length <= 1,
# or contains an ordinal such as "4th".
for i, (im, word_groups) in enumerate(zip(image_names, words)):
    words_list = flatten(list(word_groups.values()))
    im_words[im] = [
        w for w in words_list
        if w not in ('"the', '"The') and len(w) > 1 and not is_date(w)
    ]
    im2idx[im] = i

In [9]:
img_size = 64  # side (px) of the square crop fed to the CNN; matches the "small" 64x64 thumbnails
class ImageDataset(Dataset):
    """Map-style dataset yielding (image_tensor, image_index) pairs.

    Args:
        dataset: dict keyed by image file name; only its keys (and their order) are used.
        im2idx: dict mapping image file name -> integer index returned with the image.
        path_to_images: directory containing the image files.
        train: currently unused; kept for interface compatibility.
        transform: optional callable applied to the RGB PIL image. Defaults to
            Resize(img_size) -> CenterCrop(img_size) -> ToTensor(), built once
            at construction time from the current value of `img_size`.
    """
    def __init__(self, dataset, im2idx, path_to_images, train=True, transform=None):
        self.img_path = path_to_images
        self.data = dataset
        self.im2idx = im2idx
        self.images = list(dataset.keys())
        # Build the transform once instead of on every __getitem__ call.
        # NOTE(review): torchvision's pretrained ResNets expect ImageNet mean/std
        # normalization; pass a transform that ends in transforms.Normalize(...)
        # if embeddings should match the pretraining input distribution.
        if transform is None:
            transform = transforms.Compose([
                transforms.Resize(img_size),
                transforms.CenterCrop(img_size),
                transforms.ToTensor(),
            ])
        self.transform = transform

    def __len__(self):
        # Same as len(self.data) at construction; self.images is what __getitem__ indexes.
        return len(self.images)

    def __getitem__(self, index):
        image_name = self.images[index]
        idx = self.im2idx[image_name]

        name = os.path.join(self.img_path, image_name)
        # The context manager closes the underlying file handle. convert("RGB")
        # forces 3 channels so grayscale / RGBA / palette images yield a
        # 3xHxW tensor like every other sample, instead of breaking batching.
        with Image.open(name) as raw_img:
            rgb_img = raw_img.convert("RGB")

        return self.transform(rgb_img), idx
                
def collate(sequence):
    """Collate (image, idx) samples produced by ImageDataset.__getitem__ into a batch.

    Args:
        sequence: list of length batch_size; each item is an (image, idx) tuple,
            where image is a (C, H, W) tensor (3x64x64 with ImageDataset's
            default transform) and idx is the integer image index.

    Returns:
        images: (batch, C, H, W) tensor.
        idxs:   (batch,) int64 tensor of image indices.
    """
    # Stack along a new batch dimension. The previous view(-1, 3, 64, 64) + cat
    # hard-coded the image size. A 3x128x128 image was silently split into 4 fake
    # "images", other sizes or non-contiguous tensors raised, and any change to
    # img_size broke batching.
    images = torch.stack([sample[0] for sample in sequence], dim=0)

    # No padding needed: exactly one scalar index per sample.
    idxs = torch.tensor([sample[1] for sample in sequence], dtype=torch.long)

    return images, idxs


In [10]:
# One sample per image that has a cleaned word list (the keys of im_words).
dataset = ImageDataset(
    dataset=im_words,
    im2idx=im2idx,
    path_to_images=im_path_fur,
    train=True,
)

In [11]:
class EncoderCNN(nn.Module):
    """Frozen ResNet-152 feature extractor with the final fc (classification) layer removed.

    forward(images) returns the output of ResNet-152's global average pool,
    i.e. (batch, 2048, 1, 1) for a (batch, 3, H, W) input.
    """
    def __init__(self):
        """Load the pretrained ResNet-152 and drop its top fc layer."""
        super(EncoderCNN, self).__init__()
        # NOTE(review): torchvision >= 0.13 deprecates `pretrained=True` in favour of
        # `weights=models.ResNet152_Weights.DEFAULT`; kept for older torchvision versions.
        resnet = models.resnet152(pretrained=True)
        modules = list(resnet.children())[:-1]      # delete the last fc (classification) layer.
        self.resnet = nn.Sequential(*modules)
        # Used purely as a fixed feature extractor. Freeze the weights and keep
        # BatchNorm in inference mode. In train mode BN would normalize with
        # per-batch statistics, so embeddings would depend on batch composition.
        # It would also keep updating its running stats, even under torch.no_grad().
        self.resnet.requires_grad_(False)
        self.resnet.eval()

    def train(self, mode=True):
        """Set this module's training flag while keeping the frozen backbone in eval mode."""
        super(EncoderCNN, self).train(mode)
        self.resnet.eval()
        return self

    def forward(self, images):
        """Extract feature vectors from input images (no autograd graph is built)."""
        with torch.no_grad():
            features = self.resnet(images)
        return features

In [12]:
# nn.Module instances start in training mode; switch to eval so BatchNorm uses its
# pretrained running statistics. eval() returns the module itself.
img_embedder = EncoderCNN().eval()

In [13]:
# DataLoader / runtime configuration.
batch_size = 1
num_workers = 0 if not cuda else 8   # spawn loader worker processes only when a GPU is available
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [14]:
# In-order iteration (shuffle=False) over every image, keeping the final partial batch.
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    collate_fn=collate,
    drop_last=False,
)

### Run this cell to extract image embeddings with ResNet-152

In [None]:
image_embeddings = dict() # image index (from im2idx) -> CPU tensor of shape (1, 2048)

# The encoder must live on the same device as the batches (otherwise CUDA runs fail
# with a device mismatch) and be in eval mode so BatchNorm uses running statistics.
img_embedder = img_embedder.to(device).eval()

with torch.no_grad():
    for i, (images, idx) in enumerate(dataloader):
        images = images.to(device)
        # Encode image with CNN: (batch, 2048, 1, 1) -> (batch, 2048)
        features = img_embedder(images).flatten(start_dim=1)

        # One dict entry per image, so any batch_size works. The old idx.item()
        # required batch_size == 1. Rows are copied to CPU so the dict does not pin
        # GPU memory; slicing keeps the original (1, 2048) per-image layout.
        # `idx` stays on CPU: it is only used as dict keys.
        for row, image_idx in enumerate(idx.tolist()):
            image_embeddings[image_idx] = features[row:row + 1].cpu()

        # Report after the batch is done so the bar reaches 100% on the last one.
        update_progress((i + 1) / len(dataloader))


Progress: [#-------------------] 5.4%


In [None]:
# torch tensors are not JSON-serializable (json.dump raises TypeError on them),
# so convert each embedding to nested Python lists first. JSON object keys are
# always strings, so the integer image indices are written as "0", "1", ...
embeddings_for_json = {
    image_idx: (emb.tolist() if torch.is_tensor(emb) else emb)
    for image_idx, emb in image_embeddings.items()
}
with open("resnet_image_embeddings.json", "w") as f:
    json.dump(embeddings_for_json, f)