In [1]:
import numpy as np
import os
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

In [2]:
def load_opt_arrays(dir="arr_2/", opt="sgd adagrad adam"):
    """Load the saved ``_g`` / ``_d`` arrays for each optimizer name.

    For every name in ``opt`` two dictionary entries are produced,
    ``"<name>_g"`` and ``"<name>_d"``, read from ``<dir>/<name>_g.npy`` and
    ``<dir>/<name>_d.npy`` respectively. A file that does not exist yields an
    empty array for its key.

    Args:
        dir: Directory containing the ``.npy`` files. It is created when it
            does not exist yet. A trailing path separator is optional; an
            empty string means the current working directory.
        opt: Whitespace-separated optimizer names.

    Returns:
        dict mapping ``"<name>_g"`` / ``"<name>_d"`` to numpy arrays.
    """
    # Guard against dir == "": os.makedirs("") raises, and "" already refers
    # to the current directory for os.path.join below.
    if dir and not os.path.exists(dir):
        os.makedirs(dir)

    opt_methods = dict()
    # split() with no argument collapses repeated / leading / trailing
    # whitespace, so no empty optimizer name (and bogus "_g" key) appears.
    for method in opt.split():
        for suffix in ("_g", "_d"):
            key = method + suffix
            # os.path.join works whether or not dir ends with a separator.
            path = os.path.join(dir, key + ".npy")
            if os.path.isfile(path):
                opt_methods[key] = np.load(path)
            else:
                opt_methods[key] = np.array([])
    return opt_methods

def save_opt_methods(opt_methods, dir="arr_0/"):
    """Write every array in ``opt_methods`` to ``<dir>/<key>.npy``.

    Args:
        opt_methods: dict mapping a key (used as the file stem) to the array
            that should be stored.
        dir: Target directory. It is created when it does not exist yet. A
            trailing path separator is optional; an empty string means the
            current working directory.
    """
    # Create the target directory first so a fresh run does not fail inside
    # np.save; the emptiness guard avoids os.makedirs("") raising.
    if dir and not os.path.exists(dir):
        os.makedirs(dir)
    for key, values in opt_methods.items():
        # os.path.join keeps the file inside dir even without a trailing "/".
        np.save(os.path.join(dir, key + ".npy"), values)

# Build the notebook-wide dictionary of per-optimizer arrays, spelling out the
# loader's default directory and optimizer list for the reader.
opt_methods = load_opt_arrays(dir="arr_2/", opt="sgd adagrad adam")

In [3]:
# Inspect the "adam_g" series: the full array, its last three values, and its
# length, each printed on its own.
adam_g = opt_methods["adam_g"]
for shown in (adam_g, adam_g[-3:], len(adam_g)):
    print(shown)

[ 0.36009717  0.12011404  0.16176152  0.14423005  0.07751623  0.32306975
  0.12437218  0.01175684  0.01499612  0.02090667  0.08413048  0.28970531
  0.01811509  0.00917316  0.14867638  0.025017  ]
[ 0.00917316  0.14867638  0.025017  ]
16


In [4]:
# Persist every loaded array into save_opt_methods' default target directory.
# NOTE(review): the arrays are loaded from "arr_2/" but written to "arr_0/" —
# confirm that the differing directories are intended.
save_opt_methods(opt_methods, dir="arr_0/")

In [4]:
def _plot_gradient_panel(position, title, opt_methods, suffix, step):
    """Draw one subplot with every series whose key ends in ``suffix``."""
    plt.subplot(position)
    plt.title(title)
    plt.xlabel("Iterations")
    for key, val in opt_methods.items():
        # Match on the suffix only: a substring test would also pick up
        # optimizer names that merely contain "_d" / "_g" somewhere.
        if not key.endswith(suffix):
            continue
        x = np.arange(len(val)) * step
        # Strip just the suffix so names containing "_" keep their full label.
        plt.plot(x, val, label=key[:-len(suffix)])
    plt.legend()


def print_opt_methods(opt_methods, dir="out_2/", step=1000):
    """Plot the ``_d`` and ``_g`` series and save them as ``Gradients.png``.

    The figure has two stacked panels: keys ending in ``"_d"`` are drawn in
    the "Gradientsum Discriminator" panel, keys ending in ``"_g"`` in the
    "Gradientsum Generator" panel. Each series is labelled with its key minus
    that suffix.

    Args:
        opt_methods: dict mapping ``"<name>_d"`` / ``"<name>_g"`` to 1-D
            sequences of values.
        dir: Output directory for ``Gradients.png``. It is created when it
            does not exist yet; a trailing path separator is optional.
        step: Multiplier applied to each value's index to obtain its x
            position on the "Iterations" axis (defaults to the previously
            hard-coded 1000).
    """
    plt.figure(1, figsize=(10, 10))
    # Figure 1 is reused between calls; clear it so that re-running does not
    # stack duplicate lines and legend entries on top of the previous plot.
    plt.clf()

    _plot_gradient_panel(211, "Gradientsum Discriminator", opt_methods, "_d", step)
    _plot_gradient_panel(212, "Gradientsum Generator", opt_methods, "_g", step)

    # savefig does not create missing directories.
    if dir and not os.path.exists(dir):
        os.makedirs(dir)
    plt.savefig(os.path.join(dir, "Gradients.png"))
    plt.show()
    
# Render both gradient-sum panels and write the figure into the plotting
# function's default output directory.
print_opt_methods(opt_methods, dir="out_2/")