## This notebook reproduces the results of Winn and Muresan.
https://arxiv.org/pdf/1909.07586.pdf

In [1]:
import os
import pickle
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import matplotlib.colors as colors
plt.rcParams['figure.figsize'] = (5.0, 0.8) 
import matplotlib.patches as mpatches
from util.color_util import *
import pickle
from random import shuffle
import torch.optim as optim
import colorsys
from model.WM18 import *
from numpy import dot
from numpy.linalg import norm
from scipy import spatial
from colormath.color_objects import sRGBColor, LabColor
from colormath.color_conversions import convert_color
from colormath.color_diff import delta_e_cie2000
from skimage import io, color
import random
from tabulate import tabulate

In [2]:
RGB = True  # True: use the color dicts as loaded; False: convert every color to HSV


def _load_pickle(path):
    """Return the object unpickled from *path*, closing the file afterwards.

    NOTE(review): pickle can execute arbitrary code while loading; only point
    this at trusted data files.
    """
    with open(path, "rb") as fh:
        return pickle.load(fh)


def _rgb_dict_to_hsv(cdict_rgb):
    """Return a new dict with each value converted RGB -> HSV and wrapped in a torch tensor."""
    return {c: torch.tensor(colors.rgb_to_hsv(rgb)) for c, rgb in cdict_rgb.items()}


# load triples
triple_train = _load_pickle("../munroe/triple_train.p")
triple_dev = _load_pickle("../munroe/triple_dev.p")
triple_test = _load_pickle("../munroe/triple_test.p")

# load colors
cdict_train_rgb = _load_pickle("../munroe/cdict_train.p")
cdict_dev_rgb = _load_pickle("../munroe/cdict_dev.p")
cdict_test_rgb = _load_pickle("../munroe/cdict_test.p")

if RGB:
    cdict_train = cdict_train_rgb
    cdict_dev = cdict_dev_rgb
    cdict_test = cdict_test_rgb
else:
    cdict_train = _rgb_dict_to_hsv(cdict_train_rgb)
    cdict_dev = _rgb_dict_to_hsv(cdict_dev_rgb)
    cdict_test = _rgb_dict_to_hsv(cdict_test_rgb)

# load embeddings for this dataset only
embeddings = _load_pickle("../munroe/glove_color.p")

# generate test sets
test_set = generate_test_set(triple_train, triple_test)

## Reproducing the WM18 Model

In [3]:
mse = nn.MSELoss(reduction='none')
cos = nn.CosineSimilarity(dim=1)


def colorLoss(source, target, wg):
    """Return the WM18 training loss, summed over the batch.

    For each sample the loss is the cosine distance between the predicted
    shift ``wg`` and the true shift ``target - source``, plus the squared
    Euclidean error between ``target`` and ``source + wg``.

    Args:
        source: base colors. ``cos`` reduces over dim 1 while the squared
            error sums over dim -1, so these agree only for 2-D
            (batch, color_dim) inputs. Assumed 2-D; confirm against callers.
        target: target colors, same shape as ``source``.
        wg: predicted color shift, same shape as ``source``.

    Returns:
        A scalar tensor.
    """
    cosine_distance = 1 - cos(wg, target - source)
    # reduction='none' keeps per-element errors; summing the color axis gives
    # the squared Euclidean distance per sample.
    squared_error = mse(target, source + wg).sum(dim=-1)
    return (cosine_distance + squared_error).sum()

In [14]:
net = WM18(color_dim=3)
NUM_EPOCHE = 500
RETRAIN = True
FOURIER_TRANSFORM = False
MODEL_PATH = "./save_model/wm18_model.pth"

if RETRAIN:
    # Main training loop. Set RETRAIN = False to skip it and load the
    # checkpoint written by an earlier run instead.
    optimizer = optim.Adam(net.parameters(), lr=0.001)
    debug = False
    sample_per_color = 1                               # Set to 1 to be as the same as the original paper

    for i in range(NUM_EPOCHE):
        net.train()
        loss = 0.0
        batches = generate_batch(cdict_train, triple_train, embeddings,
                                 sample_per_color=sample_per_color,
                                 # keep training features consistent with predict_color below
                                 fourier=FOURIER_TRANSFORM)
        for batch_index, batch in enumerate(batches):
            (batch_emb1, batch_emb2, batch_base_color,
             batch_base_color_raw, batch_target_color) = batch
            pred = net(batch_emb1, batch_emb2, batch_base_color)
            wg = pred - batch_base_color_raw           # calculate the wg for the loss to use
            batch_loss = colorLoss(batch_base_color_raw, batch_target_color, wg)
            batch_loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            # .item() yields a plain float; summing the tensor itself would
            # keep extending the autograd graph for the whole epoch.
            loss += batch_loss.item()
            if debug:
                print(f"Batch: {batch_index+1}, train loss:{batch_loss.item()}")
        if i % 10 == 0 or i == NUM_EPOCHE - 1:
            print(f"Epoche: {i+1}, train loss:{loss}")
    # save the literal speaker to disk (create the folder on a fresh checkout)
    os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
    checkpoint = {"model": net.state_dict(), "name": "wm18"}
    torch.save(checkpoint, MODEL_PATH)
else:
    checkpoint = torch.load(MODEL_PATH)
    net.load_state_dict(checkpoint['model'])

Epoche: 1, train loss:358.961181640625
Epoche: 11, train loss:249.33009338378906
Epoche: 21, train loss:200.93533325195312
Epoche: 31, train loss:172.26718139648438
Epoche: 41, train loss:150.99639892578125
Epoche: 51, train loss:133.40618896484375
Epoche: 61, train loss:118.75057983398438
Epoche: 71, train loss:106.88467407226562
Epoche: 81, train loss:97.76130676269531
Epoche: 91, train loss:89.46797943115234
Epoche: 101, train loss:81.13385009765625
Epoche: 111, train loss:76.76742553710938
Epoche: 121, train loss:71.72525787353516
Epoche: 131, train loss:65.91504669189453
Epoche: 141, train loss:63.24393081665039
Epoche: 151, train loss:59.372764587402344
Epoche: 161, train loss:56.02265167236328
Epoche: 171, train loss:53.20472717285156
Epoche: 181, train loss:51.52340316772461
Epoche: 191, train loss:48.59611511230469
Epoche: 201, train loss:46.76017761230469
Epoche: 211, train loss:45.8540153503418
Epoche: 221, train loss:43.098480224609375
Epoche: 231, train loss:42.23448181152

In [17]:
# Run the trained model over every test condition (seen/unseen pairs, bases,
# modifiers, fully unseen, overall); the sample counts below are this cell's output.
net_predict = predict_color(
    net,
    test_set,
    cdict_test,
    embeddings,
    sample_per_color=1,
    fourier=FOURIER_TRANSFORM,
)

predict seen_pair set with 312 samples.
predict unseen_pair set with 18 samples.
predict unseen_base set with 62 samples.
predict unseen_mod set with 41 samples.
predict unseen_fully set with 17 samples.
predict overall set with 450 samples.


In [18]:
# Score the predictions per test condition; the cosine / delta_E table
# below is this cell's output.
evaluation_metrics = evaluate_color(
    net_predict,
)

condition     cosine (std)    delta_E (std)
------------  --------------  ---------------
seen_pair     0.946 (0.138)   4.314 (3.135)
unseen_pair   0.845 (0.339)   4.972 (2.489)
unseen_base   0.799 (0.396)   8.735 (5.791)
unseen_mod    0.593 (0.466)   11.587 (5.677)
unseen_fully  0.519 (0.599)   13.865 (8.430)
overall       0.873 (0.298)   5.973 (5.052)


In [None]:
# Visual spot check: pick one random triple from a test condition and plot the
# ground-truth color shift followed by the model's predicted shift.
condition = "seen_pair"
candidates = list(net_predict[condition].keys())
triple_sample = random.choice(candidates)
print(triple_sample)
plt.rcParams['figure.figsize'] = (5.0, 0.8) 
sample = net_predict[condition][triple_sample]
base_color = sample["base"][0]
for shifted in (sample["true"][0], sample["pred"][0]):
    plot_color_change_raw(base_color, shifted - base_color, strength=2)