## This notebook reproduces the RGB-based model of Han et al.
https://arxiv.org/pdf/1909.07586.pdf

In [7]:
import os
import pickle
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import matplotlib.colors as colors
plt.rcParams['figure.figsize'] = (5.0, 0.8) 
import matplotlib.patches as mpatches
from util.color_util import *
import pickle
from random import shuffle
import torch.optim as optim
import colorsys
from model.HSC19 import *
from numpy import dot
from numpy.linalg import norm
from scipy import spatial
from colormath.color_objects import sRGBColor, LabColor
from colormath.color_conversions import convert_color
from colormath.color_diff import delta_e_cie2000
from skimage import io, color
import random
from tabulate import tabulate

In [8]:
# Experiment switches.
RGB = True      # True: use the pickled colour dicts as loaded; False: convert them with rgb_to_hsv.
EXTEND = True   # True: load the full triple files; False: load the "*_reduce" variants.

# load triples
# The two variants differ only by a "_reduce" suffix in the file name.  Each
# file is read inside a `with` block so its handle is closed as soon as the
# pickle has been loaded.
# NOTE(review): pickle.load can execute arbitrary code -- only point these
# paths at trusted files.
_triple_suffix = "" if EXTEND else "_reduce"
with open(f"../munroe/triple_train{_triple_suffix}.p", "rb") as _fh:
    triple_train = pickle.load(_fh)
with open(f"../munroe/triple_dev{_triple_suffix}.p", "rb") as _fh:
    triple_dev = pickle.load(_fh)
with open(f"../munroe/triple_test{_triple_suffix}.p", "rb") as _fh:
    triple_test = pickle.load(_fh)
    
# load colors
# Each pickle is read inside a `with` block so the file handle is closed
# right after loading.
# NOTE(review): pickle.load can execute arbitrary code -- trusted files only.
with open("../munroe/cdict_train.p", "rb") as _fh:
    cdict_train_rgb = pickle.load(_fh)
with open("../munroe/cdict_dev.p", "rb") as _fh:
    cdict_dev_rgb = pickle.load(_fh)
with open("../munroe/cdict_test.p", "rb") as _fh:
    cdict_test_rgb = pickle.load(_fh)


def _to_hsv(rgb_dict):
    """Return a new dict with every value of *rgb_dict* converted to an HSV tensor.

    Each value is passed through matplotlib's ``colors.rgb_to_hsv`` and wrapped
    in ``torch.tensor``; the keys are kept unchanged.
    """
    # NOTE(review): the dtype of the resulting tensors follows whatever
    # rgb_to_hsv returns -- confirm it matches what the model expects.
    return {key: torch.tensor(colors.rgb_to_hsv(rgb)) for key, rgb in rgb_dict.items()}


if RGB:
    # RGB mode: use the loaded dictionaries directly (same objects, no copy).
    cdict_train = cdict_train_rgb
    cdict_dev = cdict_dev_rgb
    cdict_test = cdict_test_rgb
else:
    # HSV mode: build converted copies, leaving the *_rgb dictionaries intact.
    cdict_train = _to_hsv(cdict_train_rgb)
    cdict_dev = _to_hsv(cdict_dev_rgb)
    cdict_test = _to_hsv(cdict_test_rgb)

# load embeddings for this dataset only
# Read inside a `with` block so the file handle is closed after loading.
# NOTE(review): pickle.load can execute arbitrary code -- trusted files only.
with open("../munroe/glove_color.p", "rb") as _fh:
    embeddings = pickle.load(_fh)

# generate test sets
# generate_test_set is given both the training and the test triples.
# NOTE(review): presumably it uses the training triples to decide which test
# items count as seen/unseen -- confirm in its definition.
test_set = generate_test_set(triple_train, triple_test)

In [9]:
# Building blocks of the loss: un-reduced (element-wise) squared error, and
# cosine similarity taken along dim 1.
mse = nn.MSELoss(reduction='none')
cos = nn.CosineSimilarity(dim=1)


def colorLoss(source, target, wg):
    """Return the summed direction-plus-distance loss for a batch.

    For every row the loss is the sum of two terms:

    * ``1 - cos(wg, target - source)`` -- how far the predicted offset ``wg``
      points away from the true offset ``target - source``;
    * the squared error between ``target`` and ``source + wg``, summed over
      the last dimension.

    The per-row values are then summed into a single 0-dim tensor.

    Args:
        source: base colours.  Assumed to be shaped (batch, color_dim), since
            the cosine similarity is taken along dim 1 -- confirm with caller.
        target: target colours, same shape as ``source``.
        wg: predicted offset from ``source``, same shape as ``source``.

    Returns:
        A 0-dim tensor holding the total loss over the batch.
    """
    direction_term = 1 - cos(wg, target - source)
    distance_term = mse(target, source + wg).sum(dim=-1)
    return (direction_term + distance_term).sum()

In [10]:
# Model and run configuration.
net = HSC19_RGB(color_dim=3)
NUM_EPOCHE = 1000            # number of passes over the training batches
RETRAIN = False              # False: load the saved checkpoint instead of training
FOURIER_TRANSFORM = False    # forwarded as `fourier=` to batch generation and prediction
MODEL_NAME = "hsc19_rgb"

if RETRAIN:
    # Main training loop.
    # Skip this as you dont have to retrain: with RETRAIN = False the saved
    # checkpoint is loaded in the else-branch below.
    optimizer = optim.Adam(net.parameters(), lr=0.001)
    debug = False
    sample_per_color = 1                               # Set to 1 to be as the same as the original paper

    for i in range(NUM_EPOCHE):
        net.train()
        # Start from a 0-dim tensor so the epoch report below also works when
        # the generator yields no batches.
        loss = torch.zeros(())
        batch_num = 0
        batch_index = 0
        for (batch_emb1, batch_emb2, batch_base_color,
             batch_base_color_raw, batch_target_color) in generate_batch(
                 cdict_train, triple_train, embeddings,
                 sample_per_color=sample_per_color,
                 fourier=FOURIER_TRANSFORM):           # same flag that prediction uses
            pred = net(batch_emb1, batch_emb2, batch_base_color)
            wg = pred - batch_base_color_raw           # calculate the wg for the loss to use
            batch_loss = colorLoss(batch_base_color_raw, batch_target_color, wg)
            # Accumulate a detached copy: adding the live tensor would keep
            # every batch's autograd history referenced until the epoch ends.
            loss = loss + batch_loss.detach()
            batch_num += batch_emb1.shape[0]           # sum up total sample size
            batch_loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            if debug:
                print(f"Batch: {batch_index+1}, train loss:{batch_loss.detach().numpy()}")
            batch_index += 1
        if i % 100 == 0:
            print(f"Epoche: {i+1}, train loss:{loss.detach().numpy()}")
    # save the literal speaker to disk
    # Create the target directory first so torch.save does not fail on a
    # fresh checkout.
    os.makedirs("./save_model", exist_ok=True)
    checkpoint = {"model" : net.state_dict(), "name" : MODEL_NAME}
    torch.save(checkpoint, "./save_model/" + MODEL_NAME + ".pth")
else:
    # Restore the weights written by a previous training run.
    checkpoint = torch.load("./save_model/" + MODEL_NAME + ".pth")
    net.load_state_dict(checkpoint['model'])
# NOTE(review): net.eval() is never called in this cell; if the model has
# dropout/batch-norm layers, confirm that predict_color switches to eval mode.

In [11]:
# Run the loaded/trained model over test_set.  The "predict ... set with N
# samples." lines under this cell are the output recorded from this call.
net_predict = predict_color(
    net,
    test_set,
    cdict_test,
    embeddings,
    sample_per_color=5,
    fourier=FOURIER_TRANSFORM,
)

predict seen_pair set with 312 samples.
predict unseen_pair set with 18 samples.
predict unseen_base set with 62 samples.
predict unseen_mod set with 41 samples.
predict unseen_fully set with 17 samples.
predict overall set with 450 samples.


In [12]:
# Score the predictions.  The table under this cell (cosine and delta_E with
# their std, one row per condition) is the output recorded from this call.
evaluation_metrics = evaluate_color(net_predict)

condition     cosine (std)    delta_E (std)
------------  --------------  ---------------
seen_pair     0.881 (0.163)   5.358 (2.815)
unseen_pair   0.626 (0.512)   6.782 (2.502)
unseen_base   0.689 (0.478)   9.657 (6.201)
unseen_mod    0.382 (0.424)   13.402 (4.694)
unseen_fully  0.252 (0.586)   13.029 (6.174)
overall       0.771 (0.350)   7.073 (4.776)
