# Load Feature Extractor

In [73]:
import torch
def load_model_feature_extractor(model_arch='resnet18', repo='pytorch/vision:v0.6.0'):
    """Load a pretrained hub model and drop its last child module.

    Parameters
    ----------
    model_arch : str
        Entry-point name inside the hub repository (e.g. 'resnet18').
    repo : str
        ``torch.hub`` repository spec to load from; defaults to the
        repository this notebook was originally written against.

    Returns
    -------
    torch.nn.Sequential
        All child modules of the loaded model except the last one, put into
        eval mode.
    """
    model = torch.hub.load(repo, model_arch, pretrained=True)
    # Keep every child except the last one (for ResNets this is presumably the
    # final classification `fc` layer), so the output is the penultimate
    # representation rather than class logits.
    feature_extractor = torch.nn.Sequential(*list(model.children())[:-1])
    # nn.Modules start in training mode. In that mode BatchNorm normalizes with
    # the statistics of the current batch, which would make one color's
    # features depend on the other colors it is batched with. eval() switches
    # to the pretrained running statistics.
    feature_extractor.eval()
    return feature_extractor

# Build the module-level extractor used by the feature-extraction loop later in
# the notebook. torch.hub may download the weights on first use; later runs
# reuse the local hub cache (see the cell output).
feature_extractor = load_model_feature_extractor()

Using cache found in /Users/samuelkahn/.cache/torch/hub/pytorch_vision_v0.6.0


# Load Data

In [23]:
import utils
import os
from torch_color_describer import (ContextualColorDescriber, create_example_dataset)
from colors import ColorsCorpusReader
from sklearn.model_selection import train_test_split

utils.fix_random_seeds()

# Full filtered colors corpus, relative to the notebook's working directory.
COLORS_SRC_FILENAME = os.path.join("data", "colors", "filteredCorpus.csv")

dev_corpus = ColorsCorpusReader(COLORS_SRC_FILENAME, word_count=2, normalize_colors=True)
dev_examples = list(dev_corpus.read())

# Unzip the (colors, utterance) pairs into two parallel tuples.
dev_rawcols, dev_texts = zip(*((example.colors, example.contents) for example in dev_examples))

# NOTE(review): no explicit random_state is passed, so reproducibility of this
# split relies on utils.fix_random_seeds() seeding NumPy's global RNG — confirm.
(dev_rawcols_train, dev_rawcols_test,
 dev_texts_train, dev_texts_test) = train_test_split(dev_rawcols, dev_texts)

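# Convert Color Representations

Convert each HLS color to RGB and tile it into a solid-color 224×224 image tensor that the feature extractor can consume.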

In [55]:
import colorsys

# Convert from HLS to RGB
def convert_color_to_rgb(color):
    """Convert one color from HLS to an (r, g, b) tuple via ``colorsys``.

    Only the first three entries of ``color`` are read. They are passed, in
    order, as hue, lightness and saturation to ``colorsys.hls_to_rgb``.
    NOTE(review): assumes the corpus stores colors in HLS order with
    components in [0, 1] — confirm against ColorsCorpusReader.
    """
    hue, lightness, saturation = color[0], color[1], color[2]
    return colorsys.hls_to_rgb(hue, lightness, saturation)

def convert_to_imagenet_input(hsl, size=224, normalize=False):
    """Render one color as a single solid-color image batch for an ImageNet CNN.

    Parameters
    ----------
    hsl : sequence of float
        Color handed to ``convert_color_to_rgb``, which reads it as
        (hue, lightness, saturation) despite this parameter's name.
        NOTE(review): assumes components are already in [0, 1] — confirm
        against ColorsCorpusReader(normalize_colors=True).
    size : int
        Height and width of the generated image (default 224, as before).
    normalize : bool
        If True, apply the ImageNet per-channel mean/std normalization that
        torchvision's pretrained weights expect. Defaults to False to keep
        the original (unnormalized) output.

    Returns
    -------
    torch.Tensor
        Tensor of shape (1, 3, size, size) in (N, C, H, W) layout, using the
        default floating dtype. Every pixel of channel c holds the c-th RGB
        component (normalized if requested).
    """
    rgb = convert_color_to_rgb(hsl)
    # Force a float dtype: hls_to_rgb may return ints (e.g. saturation 0 with an
    # integer lightness), which would otherwise yield an int64 image.
    channels = torch.tensor(rgb, dtype=torch.get_default_dtype()).view(1, 3, 1, 1)
    if normalize:
        mean = torch.tensor([0.485, 0.456, 0.406], dtype=channels.dtype).view(1, 3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225], dtype=channels.dtype).view(1, 3, 1, 1)
        channels = (channels - mean) / std
    # Tile each channel value across the spatial dims. repeat() (unlike
    # expand()) returns a contiguous tensor that owns its memory.
    return channels.repeat(1, 1, size, size)

def convert_color_tuple(colors):
    """Convert every color of every example with ``convert_to_imagenet_input``.

    Parameters
    ----------
    colors : iterable of iterables
        One inner iterable of colors per example.

    Returns
    -------
    list of list
        Same nesting as ``colors``, with each color replaced by its converted
        image tensor.
    """
    converted_colors = []
    for example_colors in colors:
        converted_colors.append([convert_to_imagenet_input(color) for color in example_colors])
    return converted_colors


# Convert every example's colors into a list of image tensors, one list per
# example. NOTE(review): this uses the full `dev_rawcols`, not the train/test
# splits created earlier — confirm that is intended.
converted_data = convert_color_tuple(dev_rawcols)

# Extract Features

Each batch consists of the three colors belonging to a single example.

In [105]:
def extract_features_from_batch(extractor, examples):
    """Run ``extractor`` on a batch of images and return one flat vector per image.

    Parameters
    ----------
    extractor : callable
        Module mapping a batch tensor to an (N, ...) feature tensor.
    examples : torch.Tensor
        Batch of input images; presumably (N, 3, H, W) — TODO confirm.

    Returns
    -------
    torch.Tensor
        Shape (N, D), where D is the product of all non-batch output dims.
        For a ResNet with its last child removed, the raw output is presumably
        (N, 512, 1, 1), giving D == 512 — TODO confirm.
    """
    # Inference only. Without no_grad, every call records an autograd graph
    # that stays alive as long as the returned features are kept.
    with torch.no_grad():
        output = extractor(examples)
    # flatten(1) matches the old reshape((N, C)) when spatial dims are 1x1, and
    # still works when they are not (where reshape((N, C)) would raise).
    return output.flatten(start_dim=1)





In [None]:
# Extract features for every example. The three colors of one example form one
# batch, so each stored entry is one (3, D) feature tensor.
extracted_features = []
# Inference only: no autograd graphs are kept alive alongside the stored features.
with torch.no_grad():
    for batch_number, colors in enumerate(converted_data, start=1):
        cols_batch = torch.cat((colors[0],
                                colors[1],
                                colors[2]))
        batch_extraction = extract_features_from_batch(feature_extractor, cols_batch)
        # Store the batch result. (The list previously appended itself, so
        # every extracted feature was discarded.)
        extracted_features.append(batch_extraction)
        if batch_number % 100 == 0:
            print(f"Running batch number: {batch_number}")

Running batch number: 100
