# Load Feature Extractor

In [2]:
import torch
def load_model_feature_extractor(model_arch='resnet18'):
    model = torch.hub.load('pytorch/vision:v0.6.0', model_arch, pretrained=True)
    feature_extractor = torch.nn.Sequential(*list(model.children())[:-1])
    return feature_extractor

feature_extractor = load_model_feature_extractor()

Using cache found in /Users/samuelkahn/.cache/torch/hub/pytorch_vision_v0.6.0


# Load Data

In [3]:
import utils.utils as utils
import os
from utils.torch_color_describer import (ContextualColorDescriber, create_example_dataset)
from utils.colors import ColorsCorpusReader
from sklearn.model_selection import train_test_split

utils.fix_random_seeds()
COLORS_SRC_FILENAME = os.path.join(
    "data", "colors", "filteredCorpus.csv")

dev_corpus = ColorsCorpusReader(
    COLORS_SRC_FILENAME,
    word_count=2,
    normalize_colors=True)

dev_examples = list(dev_corpus.read())
dev_rawcols, dev_texts = zip(*[[ex.colors, ex.contents] for ex in dev_examples])
dev_rawcols_train, dev_rawcols_test, dev_texts_train, dev_texts_test = \
    train_test_split(dev_rawcols, dev_texts)

# Convert Color Reps

In [4]:
import colorsys

# Convert from HLS to RGB
def convert_color_to_rgb(color):
    rgb = colorsys.hls_to_rgb(color[0],color[1],color[2])
    return rgb

def convert_to_imagenet_input(hsl):
    rgb = convert_color_to_rgb(hsl)

    r = torch.full((224,224),rgb[0]).unsqueeze(2)
    g = torch.full((224,224),rgb[1]).unsqueeze(2)
    b = torch.full((224,224),rgb[2]).unsqueeze(2)
    expanded_rep = torch.cat((r,g,b),2)
    
    expanded_rep = expanded_rep.permute(2,1,0).unsqueeze(0)
    
    return expanded_rep

def convert_color_tuple(colors):
    converted_colors = [[convert_to_imagenet_input(col) for col in cols] for cols in colors ] 
    return converted_colors


converted_data = convert_color_tuple(dev_rawcols)

In [5]:
print(converted_data)

KeyboardInterrupt: 

# Extract Features - here a batch is a set of 3 colors 

In [5]:
def extract_features_from_batch(extractor, examples):
    output = extractor(examples)
    shape = output.shape
    output = output.reshape((shape[0],shape[1]))
    return output


# Do the extraction 

In [28]:
extracted_features  = []
import sys
with torch.no_grad():
    for colors in converted_data:
        # Convert to 3x224x224 matrix
        cols_batch = torch.cat((colors[0],
                            colors[1],
                            colors[2]))
        
        # Run color through the feature extractor
        batch_extraction = extract_features_from_batch(feature_extractor,cols_batch)
        
        # convert to numpy array
        batch_extraction = batch_extraction.numpy()
        
        # append to list
        extracted_features.append(batch_extraction)
        
        length = len(extracted_features)
        
        # Print some stats 
        if length%100==0:
            total_size = sys.getsizeof(extracted_features)
            print(f"Running batch number: {length}, Size of array: {total_size/(1024**2)} Megabytes")
            



Running batch number: 100, Size of array: 0.00087738037109375 Megabytes
Running batch number: 200, Size of array: 0.0016021728515625 Megabytes
Running batch number: 300, Size of array: 0.0024261474609375 Megabytes
Running batch number: 400, Size of array: 0.0031585693359375 Megabytes
Running batch number: 500, Size of array: 0.00408172607421875 Megabytes
Running batch number: 600, Size of array: 0.0052490234375 Megabytes
Running batch number: 700, Size of array: 0.005950927734375 Megabytes
Running batch number: 800, Size of array: 0.00673675537109375 Megabytes
Running batch number: 900, Size of array: 0.00762176513671875 Megabytes
Running batch number: 1000, Size of array: 0.00861358642578125 Megabytes
Running batch number: 1100, Size of array: 0.00861358642578125 Megabytes
Running batch number: 1200, Size of array: 0.009735107421875 Megabytes
Running batch number: 1300, Size of array: 0.01099395751953125 Megabytes
Running batch number: 1400, Size of array: 0.01099395751953125 Megabyte

In [38]:
import pickle
pickle.dump( extracted_features, open( "data/colors/resnet18_color_embeddings.pickle", "wb" ) )

1000000.0