In [32]:
import os
import pickle
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import matplotlib.colors as colors
plt.rcParams['figure.figsize'] = (5.0, 0.8) 
import matplotlib.patches as mpatches
from util.color_util import *
import pickle
from random import shuffle
import torch.optim as optim
import colorsys
from model.RSA import *
from model.WM18 import *
from numpy import dot
from numpy.linalg import norm
from scipy import spatial
from colormath.color_objects import sRGBColor, LabColor
from colormath.color_conversions import convert_color
from colormath.color_diff import delta_e_cie2000
from skimage import io, color
import random
from tabulate import tabulate

In [33]:
RGB = True
EXTEND = True
NUM_EPOCHE = 500
RETRAIN = True
FOURIER_TRANSFORM = False
MODEL_NAME = "literal_listener_wm18"
SAMPLE_PER_COLOR = 1
LISTENER = True
COLOR_DIM = 54 if FOURIER_TRANSFORM else 3

In [34]:
# load triples
if EXTEND:
    triple_train = pickle.load( open( "../munroe/triple_train.p", "rb" ) )
    triple_dev = pickle.load( open( "../munroe/triple_dev.p", "rb" ) )
    triple_test = pickle.load( open( "../munroe/triple_test.p", "rb" ) )
else:
    triple_train = pickle.load( open( "../munroe/triple_train_reduce.p", "rb" ) )
    triple_dev = pickle.load( open( "../munroe/triple_dev_reduce.p", "rb" ) )
    triple_test = pickle.load( open( "../munroe/triple_test_reduce.p", "rb" ) )
    
# load colors
cdict_train_rgb = pickle.load( open( "../munroe/cdict_train.p", "rb" ) )
cdict_dev_rgb = pickle.load( open( "../munroe/cdict_dev.p", "rb" ) )
cdict_test_rgb = pickle.load( open( "../munroe/cdict_test.p", "rb" ) )

cdict_train = dict()
cdict_dev = dict()
cdict_test = dict()

if RGB:
    cdict_train = cdict_train_rgb
    cdict_dev = cdict_dev_rgb
    cdict_test = cdict_test_rgb
else:
    for c in cdict_train_rgb.keys():
        cdict_train[c] = torch.tensor(colors.rgb_to_hsv(cdict_train_rgb[c]))
    for c in cdict_dev_rgb.keys():
        cdict_dev[c] = torch.tensor(colors.rgb_to_hsv(cdict_dev_rgb[c]))
    for c in cdict_test_rgb.keys():
        cdict_test[c] = torch.tensor(colors.rgb_to_hsv(cdict_test_rgb[c]))

# load embeddings for this dataset only
embeddings = pickle.load( open( "../munroe/glove_color.p", "rb" ) )

# generate test sets
test_set = generate_test_set_inverse(triple_train, triple_test)

In [35]:
mse = nn.MSELoss(reduction = 'none')
cos = nn.CosineSimilarity(dim=1)
colorLoss = lambda source, target, wg: ((1-cos(wg, source-target)) + mse(source, target+wg).sum(dim=-1)).sum()

In [36]:
# net = LiteralListener(color_dim=COLOR_DIM)
net = WM18(color_dim=COLOR_DIM)
if RETRAIN:
    '''
    Skip this as you dont have to retrain!
    Main training loop
    '''
    optimizer = optim.Adam(net.parameters(), lr=0.001)
    debug = False
    sample_per_color = SAMPLE_PER_COLOR

    for i in range(NUM_EPOCHE):
        net.train()
        loss = 0.0
        batch_num = 0
        batch_index = 0
        for batch_emb1, batch_emb2, batch_base_color, batch_base_color_raw, batch_target_color in \
            generate_batch(cdict_train, triple_train, embeddings,
                           sample_per_color=sample_per_color,
                           fourier=FOURIER_TRANSFORM,
                           listener=LISTENER):
            pred = net(batch_emb1, batch_emb2, batch_base_color)
            wg = batch_base_color_raw - pred           # calculate the wg for the loss to use
            batch_loss = colorLoss(batch_base_color_raw, batch_target_color, wg)
            loss += batch_loss
            batch_num += batch_emb1.shape[0]           # sum up total sample size
            batch_loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            if debug:
                print(f"Batch: {batch_index+1}, train loss:{batch_loss.detach().numpy()}")
            batch_index += 1
        if i % 100 == 0:
            print(f"Epoche: {i+1}, train loss:{loss.detach().numpy()}")
    # save the literal speaker to disk
    checkpoint = {"model" : net.state_dict(), "name" : MODEL_NAME}
    torch.save(checkpoint, "./save_model/" + MODEL_NAME + ".pth")
else:
    checkpoint = torch.load("./save_model/" + MODEL_NAME + ".pth")
    net.load_state_dict(checkpoint['model'])

Epoche: 1, train loss:404.3560791015625
Epoche: 101, train loss:156.42449951171875
Epoche: 201, train loss:80.58253479003906
Epoche: 301, train loss:43.66958999633789
Epoche: 401, train loss:32.82158660888672


In [37]:
net_predict = predict_color(net, test_set, cdict_test, embeddings,
                            sample_per_color=1, fourier=FOURIER_TRANSFORM,
                            listener=LISTENER)

predict seen_pair set with 312 samples.
predict unseen_pair set with 0 samples.
predict unseen_base set with 80 samples.
predict unseen_mod set with 0 samples.
predict unseen_fully set with 58 samples.
predict overall set with 450 samples.


In [40]:
def evaluate_color(net_predict, fmt="rgb", eval_target="pred", reduced=False, listener=False):
    evaluation_metrics = dict()
    for k in net_predict:
        evaluation_metrics[k] = dict()
        # we have 2 metrices to report
        evaluation_metrics[k]["cosine"] = []
        evaluation_metrics[k]["delta_E"] = []
        for triple in net_predict[k].keys():
            true = net_predict[k][triple]["true"]
            pred = net_predict[k][triple][eval_target]
            base = net_predict[k][triple]["base"]
            if reduced:
                pred = pred.mean(dim=0).unsqueeze(dim=0)
            sample_size = pred.shape[0]
            color_sim = 0.0
            color_delta_e = 0.0
            for i in range(sample_size):
                if fmt == "rgb":
                    pred_rgb = pred[i]
                    true_rgb = true[0]
                    base_rgb = base[0]
                else:
                    pred_rgb = torch.tensor(colors.hsv_to_rgb(pred[i]))  # rgb space for target color
                    true_rgb = torch.tensor(colors.hsv_to_rgb(true[0]))  # keep consistent with previous paper
                    base_rgb = torch.tensor(colors.hsv_to_rgb(base[0]))  # rgb space for target color
                # cosine metrics
                if listener:
                    cos_sim = 1 - spatial.distance.cosine(base_rgb - pred_rgb, base_rgb - true_rgb)
                else:
                    cos_sim = 1 - spatial.distance.cosine(pred_rgb - base_rgb, true_rgb - base_rgb)
                color_sim += cos_sim
                # delta_E
                c1 = sRGBColor(rgb_r=pred_rgb[0], rgb_g=pred_rgb[1], rgb_b=pred_rgb[2])
                c2 = sRGBColor(rgb_r=true_rgb[0], rgb_g=true_rgb[1], rgb_b=true_rgb[2])
                # Convert from RGB to Lab Color Space
                color1_lab = convert_color(c1, LabColor)
                # Convert from RGB to Lab Color Space
                color2_lab = convert_color(c2, LabColor)
                delta_e = delta_e_cie2000(color1_lab, color2_lab)
                color_delta_e += delta_e

            color_sim = color_sim*1.0 / sample_size  # color avg cosine
            color_delta_e = color_delta_e*1.0 / sample_size  # color avg cosine
            
            evaluation_metrics[k]["cosine"].append(color_sim)
            evaluation_metrics[k]["delta_E"].append(color_delta_e)
            
    # display evaluation metrices accordingly
    display_list = []
    for condition in evaluation_metrics.keys():
        cosine = evaluation_metrics[condition]["cosine"]
        delta_E = evaluation_metrics[condition]["delta_E"]
        cosine_str = "%s (%s)" % ('{:.3f}'.format(np.mean(cosine)), '{:.3f}'.format(np.std(cosine), ddof=1))
        delta_E_str = "%s (%s)" % ('{:.3f}'.format(np.mean(delta_E)), '{:.3f}'.format(np.std(delta_E), ddof=1))
        row = [condition, cosine_str, delta_E_str]
        display_list.append(row)
        
    print(tabulate(display_list, headers=['condition', 'cosine (std)', 'delta_E (std)']))
    
    return evaluation_metrics


In [41]:
evaluation_metrics = evaluate_color(net_predict, listener=LISTENER)

condition     cosine (std)    delta_E (std)
------------  --------------  ---------------
seen_pair     0.920 (0.154)   5.034 (3.479)
unseen_pair   nan (nan)       nan (nan)
unseen_base   0.531 (0.573)   14.051 (11.051)
unseen_mod    nan (nan)       nan (nan)
unseen_fully  0.133 (0.562)   17.208 (9.180)
overall       0.749 (0.439)   8.206 (8.028)
