In [None]:
import numpy as np
%matplotlib inline
import matplotlib.pyplot as plt
import os
import re
import json
import collections
from collections import namedtuple
import glob

In [None]:
# Field list describing hyper_params.json. The loader below does not use it (it
# builds each record type from the file's own keys); kept in case other code
# refers to this name.
HyperParams = collections.namedtuple("HyperParams", "lr, beta1, l1_weight, gan_weight, ngf, ndf")

# Each run directory is assumed to be named like cp_ngf16_lr0.001_bs64 and to
# contain hyper_params.json and output.log.


def load_hyper_params(filename):
    """Parse one hyper_params.json file into a namedtuple.

    Every JSON object (nested ones included) becomes a namedtuple whose fields
    are the object's keys. ``rename=True`` replaces keys that are not valid
    Python identifiers (e.g. ``class``, ``batch-size``) with positional names
    such as ``_2`` instead of raising ``ValueError``.
    """
    with open(filename) as fd:
        return json.load(
            fd,
            object_hook=lambda d: namedtuple("HyperParams", d.keys(), rename=True)(*d.values()),
        )


# Map run directory -> its hyper-parameters. glob's order is filesystem
# dependent, so sort it to make loading (and the printed log) deterministic.
params = {}
for filename in sorted(glob.glob("../data/output/*/*/hyper_params.json")):
    print("loading ", filename)
    params[os.path.dirname(filename)] = load_hyper_params(filename)

In [None]:
def parseLog(dirname):
    """Extract the logged loss values from ``<dirname>/output.log``.

    Each line is scanned independently for ``g_loss_L1 = <num>``,
    ``d_loss = <num>`` and ``g_loss_GAN = <num>``. A line may contribute to
    any subset of the three series. If a key appears more than once on a
    line, its last occurrence wins.

    Args:
        dirname: run directory containing ``output.log``.

    Returns:
        ``(xs, g_loss_L1, d_loss, g_loss_GAN)`` as numpy arrays, where ``xs`` is
        ``0 .. len(g_loss_L1) - 1``. The loss arrays are collected
        independently, so they share a length only if every line carrying one
        loss also carries the others.
    """
    filename = os.path.join(dirname, 'output.log')

    # Accept integers, leading-dot decimals, signs and exponents ("1", ".5",
    # "-0.3", "2.5e-03"). A plain \d*\.\d+ would truncate "2.5e-03" to 2.5 and
    # skip integer values entirely, misaligning the series.
    number = r"([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)"
    keys = ("g_loss_L1", "d_loss", "g_loss_GAN")
    patterns = {key: re.compile(key + r"\s=\s" + number) for key in keys}
    series = {key: [] for key in keys}

    # `with` guarantees the log file is closed.
    with open(filename) as log_file:
        for line in log_file:
            for key, pattern in patterns.items():
                found = pattern.findall(line)
                if found:
                    series[key].append(float(found[-1]))

    g_loss_L1 = np.array(series["g_loss_L1"])
    xs = np.arange(len(g_loss_L1))
    return xs, g_loss_L1, np.array(series["d_loss"]), np.array(series["g_loss_GAN"])

In [None]:
def plot_learning_rate_change(batch_size, ngf):
    """Plot generator L1 loss for every run with the given batch size and ngf.

    One curve is drawn per run, in ascending learning-rate order. Runs come
    from the global ``params`` (run directory -> hyper-parameters).

    Args:
        batch_size: batch size to select, matched against a ``_bs<batch_size>``
            token in the run directory path.
        ngf: value of ``hyper_params.ngf`` to select.
    """
    # The (?![0-9]) guard stops "_bs64" from also matching "_bs640".
    bs_token = re.compile(r"_bs{0}(?![0-9])".format(re.escape(str(batch_size))))
    runs = sorted(
        ((dirname, hyp) for dirname, hyp in params.items()
         if bs_token.search(dirname) and hyp.ngf == ngf),
        key=lambda item: item[1].lr,
    )
    if not runs:
        # Avoid drawing an empty figure with an empty legend.
        print("no runs found for batch size={0}, ngf={1}".format(batch_size, ngf))
        return

    _, ax = plt.subplots()
    for dirname, hyp in runs:
        xs, g_loss_L1, _, _ = parseLog(dirname)
        # Label at plot time so legend entries always match their curves.
        ax.plot(xs, g_loss_L1, label="lr={0}".format(hyp.lr))
    ax.legend()
    ax.set_title("Generator Loss for batch size={0}, ngf={1}".format(batch_size, ngf))
    ax.set_xlabel("log entry")
    ax.set_ylabel("g_loss_L1")
    plt.show()

In [None]:
# Learning-rate sweep: batch size 64, ngf 32.
plot_learning_rate_change(batch_size=64, ngf=32)

In [None]:
# Learning-rate sweep: batch size 512, ngf 32.
plot_learning_rate_change(batch_size=512, ngf=32)



In [None]:
# Learning-rate sweep: batch size 1024, ngf 32.
plot_learning_rate_change(batch_size=1024, ngf=32)

In [None]:
# Learning-rate sweep: batch size 64, ngf 16.
plot_learning_rate_change(batch_size=64, ngf=16)

In [None]:
# Learning-rate sweep: batch size 512, ngf 16.
plot_learning_rate_change(batch_size=512, ngf=16)

In [None]:
# Learning-rate sweep: batch size 1024, ngf 16.
plot_learning_rate_change(batch_size=1024, ngf=16)

In [None]:
def plot_ngf_change(batch_size, lr):
    """Plot generator L1 loss for every run with the given batch size and lr.

    One curve is drawn per run, in ascending ngf order. Runs come from the
    global ``params`` (run directory -> hyper-parameters).

    Args:
        batch_size: batch size to select, matched against a ``_bs<batch_size>``
            token in the run directory path.
        lr: learning rate to select. It is compared to ``hyper_params.lr`` with a
            tight relative tolerance rather than ``==``, so round-off in stored
            values does not drop a run.
    """
    # The (?![0-9]) guard stops "_bs64" from also matching "_bs640".
    bs_token = re.compile(r"_bs{0}(?![0-9])".format(re.escape(str(batch_size))))
    runs = sorted(
        ((dirname, hyp) for dirname, hyp in params.items()
         if bs_token.search(dirname) and np.isclose(hyp.lr, lr, rtol=1e-9, atol=0.0)),
        key=lambda item: item[1].ngf,
    )
    if not runs:
        # Avoid drawing an empty figure with an empty legend.
        print("no runs found for batch size={0}, lr={1}".format(batch_size, lr))
        return

    _, ax = plt.subplots()
    for dirname, hyp in runs:
        xs, g_loss_L1, _, _ = parseLog(dirname)
        # Label at plot time so legend entries always match their curves.
        ax.plot(xs, g_loss_L1, label="ngf={0}".format(hyp.ngf))
    ax.legend()
    ax.set_title("Generator Loss for batch size={0}, lr={1}".format(batch_size, lr))
    ax.set_xlabel("log entry")
    ax.set_ylabel("g_loss_L1")
    plt.show()

In [None]:
# ngf sweep: batch size 512, learning rate 4e-5.
plot_ngf_change(batch_size=512, lr=4e-5)

In [None]:
# ngf sweep: batch size 512, learning rate 2e-4.
plot_ngf_change(batch_size=512, lr=2e-4)

In [None]:
# ngf sweep: batch size 512, learning rate 1e-3.
plot_ngf_change(batch_size=512, lr=1e-3)