In [258]:
import os, yaml, matplotlib.pyplot as plt, numpy as np, collections, torch, PIL
from PIL import Image
def load_process_data(main_dir):
    """Aggregate the per-run Sinkhorn-divergence curves stored under ``main_dir``.

    Every sub-directory of ``main_dir`` is treated as one run. The method name
    is the part of the directory name before the first underscore (the whole
    name when it contains no underscore), and the run's curve is read from
    ``sinkhorn_div.npy`` inside that directory. Entries of ``main_dir`` that
    are not directories are ignored.

    Parameters
    ----------
    main_dir : str
        Directory whose sub-directories each hold the results of one run.

    Returns
    -------
    means : dict
        Method name -> mean of the stacked runs along axis 0.
    stds : dict
        Method name -> sample standard deviation (``ddof = 1``) of the stacked
        runs along axis 0. With a single run numpy returns NaN here.
    final_value : dict
        Method name -> last entry of the corresponding mean.

    Raises
    ------
    OSError
        Propagated from ``os.listdir`` / ``np.load`` when ``main_dir`` or a
        run's ``sinkhorn_div.npy`` cannot be read, so incomplete runs are not
        dropped silently.
    """
    value_w2 = collections.defaultdict(list)
    # sorted() keeps the stacking order independent of the filesystem listing order
    for _dir in sorted(os.listdir(main_dir)):
        run_dir = os.path.join(main_dir, _dir)
        # stray files (notes, OS metadata, ...) are not runs
        if not os.path.isdir(run_dir):
            continue
        # split() keeps the full name when there is no underscore, whereas
        # slicing with find() == -1 would silently drop the last character
        method_name = _dir.split('_', 1)[0]
        value_w2[method_name].append(np.load(os.path.join(run_dir, 'sinkhorn_div.npy')))
    # process data
    stds, means, final_value = {}, {}, {}
    for method_name, runs in value_w2.items():
        stacked = np.array(runs)  # first axis indexes the runs
        stds[method_name] = np.std(stacked, axis = 0, ddof = 1)
        means[method_name] = np.mean(stacked, axis = 0)
        final_value[method_name] = means[method_name][-1]
    return means, stds, final_value

### Iter

In [259]:
main_dir = '../results_save/sketching_iter'
save_dir = '../figures/sketching'
# Plot styling is read from a YAML file; this cell uses its 'order', 'label',
# 'color' and 'linestyle' entries. The context manager closes the file handle
# right after parsing instead of leaving it open.
# NOTE(review): the settings file is assumed to be trusted; consider
# yaml.safe_load if it only holds plain mappings - TODO confirm.
with open('../load_results/plot_settings.yaml') as settings_file:
    settings = yaml.load(settings_file, Loader = yaml.FullLoader)
# process_data
os.makedirs(save_dir, exist_ok = True)  # no exists()/makedirs() race
means, stds, final = load_process_data(main_dir)
# save the final result, one line per method, in the order given by the settings
with open(os.path.join(save_dir, 'sketching_iter_results.txt'), mode = 'w') as f:
    for method_name in settings['order']:
        if method_name not in final:
            continue
        f.write('{}: {:.3e} \n'.format(settings['label'][method_name], final[method_name]))
# plot the result: one panel per group of methods
# (subplot id, methods drawn in it, right end of the x-axis, y-limits)
panels = (
    (121, ('MMDF', 'MMDFCA', 'MMDFDK'), 5000, (1e-4, 1.2e-3)),
    (122, ('SD', 'SDCA', 'SDDK'), 300, (1e-6, 1.2e-5)),
)
plt.figure(figsize = (7.2 * 2, 4.8 * 1))
for subplot_id, panel_methods, x_max, y_limits in panels:
    plt.subplot(subplot_id)
    for method_name in panel_methods:
        if method_name not in means:
            continue
        # the curve is spread evenly over [0, x_max]
        x_grid = np.linspace(0, x_max, len(means[method_name]))
        plt.errorbar(
            x_grid, means[method_name], yerr = stds[method_name], capsize = 3,
            color = settings['color'][method_name],
            linestyle = settings['linestyle'][method_name],
            label = settings['label'][method_name],
            alpha = 1.0, linewidth = 1.5
        )
    plt.ylim(*y_limits)
    plt.yscale('log')
    plt.xscale('linear')
    plt.xlabel('Iterations', {'size': 12})
    plt.ylabel('Sinkhorn Divergence', {'size': 12})
    plt.tick_params(labelsize = 12)
    plt.legend(fontsize = 14)
plt.savefig(os.path.join(save_dir, 'sketching_iter.pdf'), dpi = 120, bbox_inches = 'tight')
plt.close()

### Time

In [260]:
main_dir = '../results_save/sketching_time'
save_dir = '../figures/sketching'
# Plot styling is read from a YAML file; this cell uses its 'order', 'label',
# 'color' and 'linestyle' entries. The context manager closes the file handle
# right after parsing instead of leaving it open.
# NOTE(review): the settings file is assumed to be trusted; consider
# yaml.safe_load if it only holds plain mappings - TODO confirm.
with open('../load_results/plot_settings.yaml') as settings_file:
    settings = yaml.load(settings_file, Loader = yaml.FullLoader)
# process_data
os.makedirs(save_dir, exist_ok = True)  # no exists()/makedirs() race
means, stds, final = load_process_data(main_dir)
# save the final result, one line per method, in the order given by the settings
with open(os.path.join(save_dir, 'sketching_time_results.txt'), mode = 'w') as f:
    for method_name in settings['order']:
        if method_name not in final:
            continue
        f.write('{}: {:.3e} \n'.format(settings['label'][method_name], final[method_name]))
# plot the result: one panel per group of methods
# (subplot id, methods drawn in it, right end of the x-axis, y-limits)
panels = (
    (121, ('MMDF', 'MMDFCA', 'MMDFDK'), 360, (1e-4, 1.2e-3)),
    (122, ('SD', 'SDCA', 'SDDK'), 360, (8e-7, 1.2e-5)),
)
plt.figure(figsize = (7.2 * 2, 4.8 * 1))
for subplot_id, panel_methods, x_max, y_limits in panels:
    plt.subplot(subplot_id)
    for method_name in panel_methods:
        if method_name not in means:
            continue
        # the curve is spread evenly over [0, x_max]
        x_grid = np.linspace(0, x_max, len(means[method_name]))
        plt.errorbar(
            x_grid, means[method_name], yerr = stds[method_name], capsize = 3,
            color = settings['color'][method_name],
            linestyle = settings['linestyle'][method_name],
            label = settings['label'][method_name],
            alpha = 1.0, linewidth = 1.5
        )
    plt.ylim(*y_limits)
    plt.yscale('log')
    plt.xscale('linear')
    plt.xlabel('Time (s)', {'size': 12})
    plt.ylabel('Sinkhorn Divergence', {'size': 12})
    plt.tick_params(labelsize = 12)
    plt.legend(fontsize = 14)
plt.savefig(os.path.join(save_dir, 'sketching_time.pdf'), dpi = 120, bbox_inches = 'tight')
plt.close()

### Particles

In [261]:
main_dir = '../results_save/sketching_iter'
save_dir = '../figures/sketching'
color_map = 'viridis'
# Only the 'label' entry of the settings is used here (for the output file names).
# The context manager closes the settings file right after parsing.
# NOTE(review): the settings file is assumed to be trusted; consider
# yaml.safe_load if it only holds plain mappings - TODO confirm.
with open('../load_results/plot_settings.yaml') as settings_file:
    settings = yaml.load(settings_file, Loader = yaml.FullLoader)
os.makedirs(save_dir, exist_ok = True)  # no exists()/makedirs() race
# load and display
support, mass = {}, {}
for _dir in os.listdir(main_dir):
    # only run folders whose name contains 'S[0]' are visualised
    if 'S[0]' not in _dir:
        continue
    # split() keeps the full name when there is no underscore, whereas
    # slicing with find() == -1 would silently drop the last character
    method_name = _dir.split('_', 1)[0]
    run_dir = os.path.join(main_dir, _dir)
    # keep only the last entry along the first axis
    # (presumably the final iterate - TODO confirm against the writer)
    support[method_name] = np.load(os.path.join(run_dir, 'support.npy'))[-1]
    mass[method_name] = np.load(os.path.join(run_dir, 'mass.npy'))[-1]
for method_name, _support in support.items():
    _mass = mass[method_name]
    _mass = _mass / _mass.sum()  # normalise the weights to sum to one
    fig = plt.figure(figsize = (3.6, 4.0))
    # weighted 2-D histogram of the first two support coordinates
    plt.hist2d(
        _support[:, 0], _support[:, 1],
        bins = (80, 80), range = ((-0.05, 0.95), (-0.05, 1.05)),
        weights = _mass, cmap = color_map, density = True
    )
    plt.xlim((0, 0.9))
    plt.ylim((0, 1.0))
    plt.axis('off')
    plt.savefig(os.path.join(save_dir, '%s.png'%(settings['label'][method_name])), dpi = 120, bbox_inches = 'tight')
    plt.close(fig)  # release the figure before drawing the next method

In [262]:
# NOTE: `save_dir` and `color_map` are not defined in this cell; they are taken
# from the particle-plot cell above, which must have been run first.
img_path = '../datasets/sketching/cheetah/target.jpg'
# Copy the pixels out inside the context manager so the image file is closed
# as soon as it has been read.
with Image.open(img_path) as image:
    aspect_hw = 1.0 * image.height / image.width
    pix = np.array(image)
original = pix
# use the first channel only; a single-channel image is already 2-D
if pix.ndim == 3:
    pix = pix[:, :, 0]
pix = 255 - pix  # invert: dark pixels get large values
# create a meshgrid and interpret the image as a probability distribution on it
x_grid = torch.linspace(0, 1, steps = pix.shape[0]) # x is the height and y is the width, we will convert later
y_grid = torch.linspace(0, pix.shape[1] / pix.shape[0], steps = pix.shape[1])
x_mesh, y_mesh = torch.meshgrid(x_grid, y_grid, indexing = 'ij')
x_mesh = x_mesh.reshape(-1)
y_mesh = y_mesh.reshape(-1)
pix_arr = pix.reshape(-1)
value_thr = 50
# Select the pixels above the threshold with one boolean mask instead of a
# per-pixel Python loop; masking preserves the row-major pixel order.
keep = pix_arr > value_thr
if not keep.any():
    raise ValueError('no pixel of {} exceeds value_thr = {}'.format(img_path, value_thr))
keep_mask = torch.from_numpy(keep)
# support point = (width coordinate, flipped height coordinate)
tgt_support = torch.stack([y_mesh[keep_mask], 1 - x_mesh[keep_mask]], dim = 1)
tgt_mass = torch.from_numpy(pix_arr[keep].astype(np.float32))
tgt_mass = tgt_mass / tgt_mass.sum()  # normalise the weights to sum to one
fig = plt.figure(figsize = (3.6, 4.0))
support_np = tgt_support.numpy()
# one histogram bin per distinct coordinate value along each axis
plt.hist2d(
    support_np[:, 0], support_np[:, 1],
    bins = (len(torch.unique(tgt_support[:, 0])), len(torch.unique(tgt_support[:, 1]))),
    weights = tgt_mass.numpy(), cmap = color_map, density = True
)
plt.xlim((0, 0.9))
plt.ylim((0, 1.0))
plt.axis('off')
plt.savefig(os.path.join(save_dir, '%s.png'%('target')), dpi = 120, bbox_inches = 'tight')
plt.close(fig)
