In [4]:
import torch
import matplotlib.pyplot as plt
plt.rcParams['pdf.fonttype'] = 42
plt.rcParams['ps.fonttype'] = 42
import numpy as np
import torch.nn as nn
import os
from EigenvalueTools import calc_game_eigs, calc_full_game_eigs
from MoG import kde, dim_vis, real_builder_circle, real_builder_diamond
from ObjectiveTools import objective_function
from GameLosses import gan_model_mog, gan_loss_mog, gan_loss_mog_fixed
import glob
import scipy.stats as stats
import seaborn as sns
from seaborn import xkcd_rgb as xkcd
sns.set_style('whitegrid')
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [35]:
def save_mog_results(config, alg, size, steps, track_eigs, show=False):

    np.random.seed(0)
    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
    path = os.path.join(os.getcwd(), 'MoGResults')
    fig_dir = os.path.join(path, 'Figs', config+'_'+alg)
    if not os.path.exists(fig_dir): os.makedirs(fig_dir)
    data_dir = os.path.join(path, 'Data', config+'_'+alg)
    if not os.path.exists(data_dir): os.makedirs(data_dir)

    n_latent = 16
    n_hidden = 32
    n_out = 2

    if config == 'diamond':
        activation_function = nn.Tanh()
        real_builder = real_builder_diamond
        objective = 'gan'
    elif config == 'circle':
        activation_function = nn.ReLU()
        real_builder = real_builder_circle
        objective = 'nsgan'
        
    if alg == 'stack':
        color = xkcd['orange']
    elif alg == 'simgrad':
        color = '#1f77b4'
        
    bbox=[-2, 2, -2, 2]  
    regularization = 1

    G,D = gan_model_mog(n_latent=n_latent, n_out=n_out, n_hidden=n_hidden, \
                        num_gen_layers=2,
                        num_disc_layers=1,
                        activation_function=activation_function)

    device = "cpu"
    G.to(device)
    D.to(device)


    real_data = torch.from_numpy(real_builder(size)).float().to(device) 
    kde(real_data.detach().numpy().T, show=show, save=os.path.join(fig_dir, 'real_data.png'), bbox=bbox)
    G_latent = torch.randn(size, n_latent).to(device)
    
    for step, track_eig in zip(steps, track_eigs):
        
        if track_eig:
            track_A = []
            track_D = []
            track_J = []
            track_SC = [] 
        
        track_kl = []
        
        file_paths = glob.glob(os.path.join(path, config + '_' + alg + '*', 'Checkpoint', 'checkpoint_step_'+str(step)))
        file_paths = sorted(file_paths, key=lambda x: int(x.split('/')[-3].split(config+'_'+ alg)[1].split('_')[0]))

        for count, file_path in enumerate(file_paths):
            checkpoint = torch.load(file_path, 
                       map_location=torch.device('cpu'))

            G.load_state_dict(checkpoint['state_gen'])
            D.load_state_dict(checkpoint['state_dis'])
            
            if track_eig:
                G_loss, D_loss = gan_loss_mog_fixed(G=G,D=D,G_latent=G_latent,real_data=real_data,objective=objective, device=device)
                A_eigs, D_eigs, D_reg_eigs, J_eigs, SC_reg_eigs, SC_reg_eigs2 = calc_full_game_eigs([G_loss, D_loss], [G,D], regularization)

                track_A.append(A_eigs)
                track_D.append(D_eigs)
                track_J.append(J_eigs)
                track_SC.append(SC_reg_eigs2)

            G_generated = G(G_latent.to(device)).cpu().detach().numpy()
            fake_kernel = stats.gaussian_kde(G_generated.T)
            real_kernel = stats.gaussian_kde(real_data.detach().numpy().T)
            xx, yy = np.mgrid[bbox[0]:bbox[1]:300j, bbox[2]:bbox[3]:300j]
            positions = np.vstack([xx.ravel(), yy.ravel()])
            kl = stats.entropy(pk=fake_kernel(positions), qk=real_kernel(positions))        
            track_kl.append(kl)

            kde(G_generated.T, show=show, save=os.path.join(fig_dir, 'generator_sample_'+str(step)+'_run_'+str(count)+'.png'), bbox=bbox)
            dim_vis(D, show=show, save=os.path.join(fig_dir, 'discriminator_sample_'+str(step)+'_run_'+str(count)+'.png'), bbox=bbox)

        kl_ranks = sorted([(i, val) for i, val in enumerate(track_kl)], key=lambda x: x[1])

        if track_eig:
            track_A = np.array(track_A)
            track_D = np.array(track_D)
            track_J = np.array(track_J)
            track_SC = np.array(track_SC)
            num_runs = len(file_paths)
            
            best_idx = kl_ranks[0][0]
            mid_idx = kl_ranks[4][0]
            
            prefixes = ['A_eigs_best_', 'A_eigs_mid_', 'D_eigs_best_', 'D_eigs_mid_', 'J_eigs_best_', 'J_eigs_mid_', 'SC_eigs_best_', 'SC_eigs_mid_']
            arr_list = [track_A, track_A, track_D, track_D, track_J, track_J, track_SC, track_SC]
            index_list = [best_idx, mid_idx, best_idx, mid_idx, best_idx, mid_idx, best_idx, mid_idx]
            
            def plot_eigs(num1, num2, prefix, arr, idx):
                fig, ax = plt.subplots(1, 1, figsize=(8, 4.5))
                ax.bar(range(num1+num2), np.concatenate([arr[idx][:num1], arr[idx][-num2:]]), color=color)
                ax.tick_params(labelsize=30)
                ax.set_xticklabels([])
                plt.tight_layout()
                plt.savefig(os.path.join(fig_dir, prefix+str(idx)+'_'+str(step)+'.png'),bbox_inches='tight')
                plt.close()
                
            num1 = 5
            num2 = 15
            
            for prefix, arr, idx in zip(prefixes, arr_list, index_list):
                plot_eigs(num1, num2, prefix, arr, idx)
            
            prefixes = ['A_eigs_all_', 'D_eigs_all_', 'J_eigs_all_', 'SC_eigs_all_']
            arr_list = [track_A, track_D, track_J, track_SC]
            
            def plot_min_max_eigs(prefix, arr):
                fig, ax = plt.subplots(1, 1, figsize=(8, 4.5))
                ax.bar(range(1, num_runs+1), arr[:, 0], color=xkcd['black'])
                ax.bar(range(1, num_runs+1), arr[:, -1], color=color)
                ax.tick_params(labelsize=22)
                ax.set_xticks(range(1, num_runs+1))
                ax.set_xlabel('Run Number', fontsize=30)
                plt.tight_layout()
                plt.savefig(os.path.join(fig_dir, prefix+str(step)+'.png'),bbox_inches='tight')
                plt.close()
            
            for prefix, arr in zip(prefixes, arr_list):
                plot_min_max_eigs(prefix, arr)
            
            
        track_kl = np.array(track_kl)

        if track_eig:
            np.save(os.path.join(data_dir, 'A_eigs_'+str(step)+'.npy'), track_A)
            np.save(os.path.join(data_dir, 'D_eigs_'+str(step)+'.npy'), track_D)
            np.save(os.path.join(data_dir, 'J_eigs_'+str(step)+'.npy'), track_J)
            np.save(os.path.join(data_dir, 'SC_eigs_'+str(step)+'.npy'), track_SC)
            np.save(os.path.join(data_dir, 'kl_'+str(step)+'.npy'), track_kl)
            
        np.save(os.path.join(data_dir, 'file_paths_'+str(step)+'.npy'), file_paths)
        np.save(os.path.join(data_dir, 'kl_ranks_'+str(step)+'.npy'), kl_ranks)

In [36]:
size = 4096
steps = [60000]
track_eigs = [True]
save_mog_results('diamond', 'simgrad', size, steps, track_eigs)
save_mog_results('diamond', 'stack', size, steps, track_eigs)

size = 4096
steps = [10000, 20000, 40000, 60000]
track_eigs = [False, False, False, True]
save_mog_results('circle', 'simgrad', size, steps, track_eigs)
save_mog_results('circle', 'stack', size, steps, track_eigs)

In [43]:
alg = 'simgrad'
config = 'diamond'
path = os.path.join(os.getcwd(), 'MoGResults')
data_dir = os.path.join(path, 'Data', config+'_'+alg)
kl_1 = np.load(os.path.join(path, data_dir, 'kl_ranks_60000.npy'))

In [44]:
alg = 'stack'
config = 'diamond'
path = os.path.join(os.getcwd(), 'MoGResults')
data_dir = os.path.join(path, 'Data', config+'_'+alg)
kl_2 = np.load(os.path.join(path, data_dir, 'kl_ranks_60000.npy'))

In [45]:
alg = 'simgrad'
config = 'circle'
path = os.path.join(os.getcwd(), 'MoGResults')
data_dir = os.path.join(path, 'Data', config+'_'+alg)
kl_3 = np.load(os.path.join(path, data_dir, 'kl_ranks_60000.npy'))

In [46]:
alg = 'stack'
config = 'circle'
path = os.path.join(os.getcwd(), 'MoGResults')
data_dir = os.path.join(path, 'Data', config+'_'+alg)
kl_4 = np.load(os.path.join(path, data_dir, 'kl_ranks_60000.npy'))

In [74]:
kls = [kl_1[:, 1], kl_2[:, 1], kl_3[:, 1], kl_4[:, 1]]

In [78]:
for kl_ in kls:
    print(kl_.mean(), kl_.std()/np.sqrt(10), kl_)

0.24146381957647453 0.04552455557049911 [0.06 0.07 0.11 0.13 0.19 0.26 0.32 0.35 0.46 0.47]
0.08140712915483082 0.02536214449485558 [0.04 0.04 0.05 0.05 0.05 0.05 0.06 0.07 0.09 0.32]
0.10056611134825191 0.04483815885506369 [0.01 0.03 0.04 0.05 0.06 0.07 0.07 0.08 0.5 ]
0.05864831104675753 0.023848842794551954 [0.01 0.01 0.01 0.01 0.02 0.03 0.04 0.04 0.2  0.21]
