In [None]:
import sys
import os

import matplotlib.pyplot as plt
import matplotlib
import pandas as pd
import numpy as np

In [None]:
# Directory containing the one-layer experiment CSVs (lgrid.csv, lgrid_pc.csv) loaded below.
RESULTS_DIR = "results/one_layer_results"

In [None]:
# Load the learning-rate sweep results; the CSV has no header row.
data = pd.read_csv(
    os.path.join(RESULTS_DIR, "lgrid.csv"),
    header=None,
)
print(data.shape)

In [None]:
# Copy the frame into a plain float64 NumPy array for slicing/plotting.
data_arr = np.array(data.to_numpy(), dtype=np.float64)

In [None]:
lr1 = 0.05
lr2 = [0.0001, 0.0002, 0.0005, 0.0008, 0.001, 0.002, 0.005, 0.008, 0.01, 0.02, 0.05, 0.08, 0.1, 0.2, 0.5, 0.8, 1.0, 2.0, 5.0]

# Heatmap of the results grid: x axis = configuration (row of data_arr, labelled by
# eta_v / eta_W), y axis = epoch (column of data_arr), latest epoch at the top.
# lr1 is eta_W (see suptitle); lr2 holds eta_v values, presumably one per row of data_arr.
fig, ax = plt.subplots()

# Plot
num_samples = 500
# Geometrically spaced epoch indices from 1 to the last epoch column. int() truncation
# repeats early epochs, which stretches them out like a log-scaled epoch axis.
# Use the real number of epoch columns rather than assuming exactly 1000.
num_epochs = data_arr.shape[1]
geo_samples = [int(i) for i in np.geomspace(1, num_epochs - 1, num=num_samples)]
im = ax.imshow(np.flip(data_arr[:, geo_samples].transpose(), axis=0), interpolation='none', aspect='auto', vmax=1.5, vmin=0)
fig.colorbar(im, ax=ax)

# Label every other configuration with its ratio eta_v / eta_W.
xticks = np.arange(0, data_arr.shape[0], 2)
ax.set_xticks(xticks)
ax.set_xticklabels([round(lr / lr1, 3) for lr in lr2[:data_arr.shape[0]:2]])

# Because of the flip above, image row r shows epoch geo_samples[num_samples - 1 - r].
yticks = np.arange(0, num_samples, 50)
ax.set_yticks(yticks)
ax.set_yticklabels([geo_samples[num_samples - 1 - r] for r in yticks])

fig.suptitle(fr"$\eta_{{\mathbf{{W}}}}={lr1}$")
ax.set_xlabel(r"$\eta_{\mathbf{v}} / \eta_{\mathbf{W}}$")
ax.set_ylabel("Epoch")

# savefig does not create missing directories.
os.makedirs("plots", exist_ok=True)
fig.savefig("plots/one_layer_im.pdf")

In [None]:
# MSE curves for every other configuration of the learning-rate sweep, log-scaled epoch axis.
fig, ax = plt.subplots()
ax.set_xscale('log')

for i in range(0, data_arr.shape[0], 2):
    # lr2[i] is the learning rate of v (eta_v); eta_W is fixed at lr1 (see suptitle),
    # so the legend must name eta_v, not eta_W.
    ax.plot(data_arr[i, :], label=fr"$\eta_{{\mathbf{{v}}}}={lr2[i]}$")

ax.set_xlabel("Epoch")
ax.set_ylabel("MSE")
ax.set_ylim([0, 2])
ax.legend(loc=1)

fig.suptitle(fr"$\eta_{{\mathbf{{W}}}}={lr1}$")
# savefig does not create missing directories.
os.makedirs("plots", exist_ok=True)
fig.savefig("plots/one_layer_ind.pdf")

In [None]:
# "Removing PCs" experiment: load its results (CSV without a header row) and
# convert them to a float64 array.
# NOTE: this overwrites `data` / `data_arr` from the learning-rate sweep; re-run the
# earlier loading cells before re-plotting those figures.
data = pd.read_csv(
    os.path.join(RESULTS_DIR, "lgrid_pc.csv"),
    header=None,
)
print(data.shape)

data_arr = np.array(data.to_numpy(), dtype=np.float64)

In [None]:
# "Lower lambda" values of the PC-removal sweep (0.1, 0.2, ..., 1.9),
# presumably one per row of data_arr; both learning rates are set to lr.
ll = np.linspace(0.1, 1.9, 19)
lr = 0.1

# Heatmap: x axis = configuration (row of data_arr), y axis = epoch, latest epoch at the top.
fig, ax = plt.subplots()

# Plot
num_samples = 500
# Geometrically spaced epoch indices from 1 to the last epoch column. int() truncation
# repeats early epochs, which stretches them out like a log-scaled epoch axis.
# Use the real number of epoch columns rather than assuming exactly 1000.
num_epochs = data_arr.shape[1]
geo_samples = [int(i) for i in np.geomspace(1, num_epochs - 1, num=num_samples)]
im = ax.imshow(np.flip(data_arr[:, geo_samples].transpose(), axis=0), interpolation='none', aspect='auto', vmax=1.5, vmin=0)
fig.colorbar(im, ax=ax)

# Label every other configuration with its lambda value.
xticks = np.arange(0, data_arr.shape[0], 2)
ax.set_xticks(xticks)
ax.set_xticklabels([round(l, 1) for l in ll[:data_arr.shape[0]:2]])

# Because of the flip above, image row r shows epoch geo_samples[num_samples - 1 - r].
yticks = np.arange(0, num_samples, 50)
ax.set_yticks(yticks)
ax.set_yticklabels([geo_samples[num_samples - 1 - r] for r in yticks])

fig.suptitle(fr"$\eta_{{\mathbf{{W}}}}={lr}, \eta_{{\mathbf{{v}}}}={lr}$")
ax.set_xlabel(r"Lower $\lambda$")
ax.set_ylabel("Epoch")

# savefig does not create missing directories.
os.makedirs("plots", exist_ok=True)
fig.savefig("plots/one_layer_pcs_im.pdf")

In [None]:
# MSE curves for every other lambda value of the PC-removal sweep, log-scaled epoch axis.
fig, ax = plt.subplots()
ax.set_xscale('log')

for i in range(0, data_arr.shape[0], 2):
    # ll comes from np.linspace, so raw values can print as e.g. 0.30000000000000004;
    # format to one decimal (ll is a 0.1-spaced grid) to match the heatmap tick labels.
    ax.plot(data_arr[i, :], label=fr"Lower $\lambda={ll[i]:.1f}$")

ax.set_xlabel("Epoch")
ax.set_ylabel("MSE")
ax.set_ylim([0, 2])
ax.legend(loc=1)

fig.suptitle(fr"$\eta_{{\mathbf{{W}}}}={lr}, \eta_{{\mathbf{{v}}}}={lr}$")
# savefig does not create missing directories.
os.makedirs("plots", exist_ok=True)
fig.savefig("plots/one_layer_pcs_ind.pdf")