In [None]:
# preamble
import os
import sys
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np
import time

print('Now: ' + time.strftime('%c'))
print(sys.version)

# Known experiment groups: name -> (run suffixes, legend labels, one label per suffix).
# Select one by setting EXPERIMENT instead of toggling commented-out blocks.
BATCH_SIZE_LABELS = ['bs = 10', 'bs = 50', 'bs = 100', 'bs = 500']
EXPERIMENTS = {
    'pseudo-llh':   (['0',  '1',  '2',  '3'],  BATCH_SIZE_LABELS),
    'distillation': (['10', '11', '12', '13'], BATCH_SIZE_LABELS),
    'symmetric':    (['20', '21', '22', '23'], BATCH_SIZE_LABELS),
    'cooperative':  (['30', '31', '32', '33'], BATCH_SIZE_LABELS),
    # every method at bs = 500
    'bs = 500':     (['3', '13', '23', '33'],
                     ['discriminative', 'distillation', 'symmetric', 'cooperative']),
    'baseline':     (['40'], ['optimal']),
}
EXPERIMENT = 'cooperative'

# Fresh list copies so later edits to suffixes/descr cannot leak back into EXPERIMENTS.
suffixes = list(EXPERIMENTS[EXPERIMENT][0])
descr = list(EXPERIMENTS[EXPERIMENT][1])

nfiles = len(suffixes)
names = ['./logs/log-' + s + '.txt' for s in suffixes]

# default colors from T10-palette
colors = ['tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:purple', 'tab:brown', 'tab:pink', 'tab:gray', 'tab:olive', 'tab:cyan']

# Fail fast on a mis-edited experiment table: a label-count mismatch would silently
# mislabel curves in the legends, and too many runs would raise IndexError on colors[i].
assert len(descr) == nfiles, 'each run suffix needs exactly one legend label'
assert nfiles <= len(colors), 'more runs than distinct plot colors'

# Fields kept from each whitespace-separated log line. Within each loaded array the
# plotting/timing cells read: 0 = time, 1 = iteration (x axis), 2 = clen, 3 = closs,
# 4/5 = accuracy, 6 = xlen, 7 = xloss.  NOTE(review): confirm against the logger.
LOG_COLUMNS = (1, 3, 5, 7, 9, 11, 13, 15)

# load data: one 2-D array per run (ndmin=2 keeps single-row logs two-dimensional)
data = [np.loadtxt(name, comments='#', delimiter=' ', ndmin=2, usecols=LOG_COLUMNS)
        for name in names]

# common plot parameters
plt.rcParams.update({'font.size': 16})
# Upper x-axis limit for the iteration-based panels, e.g. 20000; any value <= 0 means no limit.
maxx = -1


In [None]:
# Classification
# Top row: clen / closs per iteration from the main logs (legend on closs).
# Bottom-left: the two accuracy columns per run (column 4 solid, column 5 dashed).
# Bottom-right: the separately logged c-losses.

plt.rcParams["figure.figsize"] = (30,12)
fig, ax = plt.subplots(2, 2)
fig.suptitle('Classification (c)')

# (axes, log column, y label) for the single-curve panels of the top row;
# x is always column 1 of each log.
c_panels = ((ax[0, 0], 2, 'clen'), (ax[0, 1], 3, 'closs'))
for panel, column, ylabel in c_panels:
    for i in range(nfiles):
        run = data[i]
        panel.plot(run[:, 1], run[:, column], color=colors[i])
    panel.set_ylabel(ylabel)
    panel.grid(True)
    if maxx > 0:
        panel.set_xlim((0, maxx))

ax[0, 1].legend(descr)

acc_panel = ax[1, 0]
for i in range(nfiles):
    run = data[i]
    # a single call draws two lines in the run's color: column 4 solid, column 5 dashed
    acc_panel.plot(run[:, 1], run[:, 4], run[:, 1], run[:, 5], '--', color=colors[i])
acc_panel.set_ylabel('accuracy')
acc_panel.grid(True)
if maxx > 0:
    acc_panel.set_xlim((0, maxx))
acc_panel.set_ylim((0.8, 1))

lossnames = ['./logs/loss_c-' + s + '.txt' for s in suffixes]
loss_panel = ax[1, 1]
for i in range(nfiles):
    losses = np.loadtxt(lossnames[i])
    loss_panel.plot(losses, color=colors[i])
loss_panel.set_ylabel('c-losses')
loss_panel.grid(True);


In [None]:
# Image model
# Panels: xlen and xloss per iteration from the main logs, then the separately
# logged x-losses; the run legend sits on the last panel.

plt.rcParams["figure.figsize"] = (30,6)
fig, ax = plt.subplots(1, 3)
fig.suptitle('Image model (x)')

# (axes, log column, y label) for the iteration-based panels; x is column 1.
x_panels = ((ax[0], 6, 'xlen'), (ax[1], 7, 'xloss'))
for panel, column, ylabel in x_panels:
    for i in range(nfiles):
        run = data[i]
        panel.plot(run[:, 1], run[:, column], color=colors[i])
    panel.set_ylabel(ylabel)
    panel.grid(True)
    if maxx > 0:
        panel.set_xlim((0, maxx))

lossnames = ['./logs/loss_x-' + s + '.txt' for s in suffixes]
loss_panel = ax[2]
for i in range(nfiles):
    losses = np.loadtxt(lossnames[i])
    loss_panel.plot(losses, color=colors[i])
loss_panel.set_ylabel('x-losses')
loss_panel.grid(True)

# trailing ';' keeps the Legend repr out of the cell output
loss_panel.legend(descr);


In [None]:
# reconstructed images
# One row per run: the saved reconstruction image, labelled with the run's legend text.

plt.rcParams["figure.figsize"] = (16, 6*nfiles)

fig, ax = plt.subplots(nfiles,1)

imgfiles = ['./images/img_' + s + '.png' for s in suffixes]

# plt.subplots returns a bare Axes when nfiles == 1 and an array of Axes otherwise;
# np.atleast_1d yields a uniform 1-D sequence, so one loop covers both cases
# (replaces the duplicated nfiles == 1 special-case branch).
for i, panel in enumerate(np.atleast_1d(ax)):
    panel.imshow(mpimg.imread(imgfiles[i]))
    panel.set_xlabel(descr[i])


In [None]:
# times
# Throughput per run from its last logged row: column 0 is read as elapsed time and
# column 1 as the iteration counter.
# NOTE(review): it/sec = last_it / last_time assumes both start at 0 — confirm.

for i in range(nfiles):
    if data[i].size == 0:
        # An empty log previously crashed here with IndexError on row -1.
        print(descr[i], '(it/sec): no logged rows')
        continue
    # Row count from the array shape, not size // 8, so this stays correct if the
    # set of loaded columns changes.
    num = data[i].shape[0]
    last_time = data[i][num-1,0]
    last_it = data[i][num-1,1]
    speed = last_it/last_time
    print(descr[i], '(it/sec):' , speed, '\t', last_it, 'iterations')