In [None]:
import torch
from torch import nn
from torch.optim.lr_scheduler import LambdaLR

from train_utils import *
from train_xor import *

import numpy as np

import matplotlib
import matplotlib.pyplot as plt

%config InlineBackend.figure_format = 'retina'

In [None]:
#prepare training parameters
# Scalar settings shared by the configuration and training cells below.

d = 2        # forwarded as 'd' inside the data-generation parameters
n = 1 << 8   # 256; used both as the sample count 'n' and as the loader batch size
m = 4        # not referenced by the cells shown here -- TODO confirm intended use
epochs = 1 << 15  # 32768; passed to get_trained_model and used to size the logging buffers
# Prefer the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def my_lr(epoch):
    """Piecewise learning-rate multiplier (supplied as ``lr_lambda`` to the scheduler).

    Phases, by epoch index:

    * ``epoch < 12``             -> constant ``2**9``
    * ``12 <= epoch < 8192``     -> constant ``1``
    * ``8192 <= epoch < 16384``  -> linear ramp ``1 + 2**5 * (epoch / 8192 - 1)``,
      rising from ``1`` towards ``1 + 2**5``
    * ``epoch >= 16384``         -> ``1 + 2**5 * 2**((epoch - 16384) / 2**9)``,
      i.e. the part above 1 doubles every ``2**9`` epochs

    The last two pieces agree at ``epoch == 16384`` (both give ``33``).

    Args:
        epoch: Epoch index handed in by the scheduler.

    Returns:
        The multiplier for that epoch (``int`` in the two constant phases,
        ``float`` afterwards).
    """
    boost_end = 12
    ramp_start = 8192
    ramp_end = 16384

    if epoch < boost_end:
        factor = 2**9
    elif epoch < ramp_start:
        factor = 1
    elif epoch < ramp_end:
        factor = 1 + 2**5 * (epoch / ramp_start - 1)
    else:
        factor = 1 + 2**5 * 2**((epoch - ramp_end)/2**9)
    return factor

# Dataset and loader configuration. The batch size equals the sample count n,
# so every batch holds the whole training set.
data_params = dict(
    n=n,
    data_fn=gen_sym_data,
    data_params=dict(d=d, mu=1, delta=0.01, sigma=0.005),
)
loader_params = dict(batch_size=n, num_workers=0)

#initializations for different figures
# Keep exactly one `sc` line active; the trailing tag names the figure it belongs to.
sc = np.array([1, 0.1, 0.001, 0.01]) * 0.0001 #evo1
#sc = np.array([1, 0.1, 0.01, 0.001]) * 0.0001 #evo2
#sc = np.array([1, 0.01, 0.1, 0.001]) * 0.0001 #evo3

# 4x2 initial weight matrix: rows lie on +x, +y, -x, -y with magnitudes sc[0..3].
init_weight = torch.tensor([
    [sc[0], 0.0],
    [0.0, sc[1]],
    [-sc[2], 0.0],
    [0.0, -sc[3]],
])
model_params = dict(init_weight=init_weight)

loss_params = dict()
optimizer_params = dict(momentum=0.0, weight_decay=0.0, nesterov=False)
scheduler_params = dict(lr_lambda=my_lr)
correction_params = dict(lr=2**(-7))

# Uninitialised (epochs+1, 4, 2) buffer handed to the validation step as
# 'weights_path'; later cells read it back as path[:, k, :].
# Presumably log_epoch writes the weights into it once per epoch -- TODO confirm.
path = np.empty(shape=(epochs + 1, 4, 2), dtype=np.float64)
train_params = dict(
    val_epoch_params=dict(weights_path=path),
    get_loss_fn=get_loss_one_dim,
    train_epoch_fn=train_epoch_epoch_sch,
    val_epoch_fn=log_epoch,
)

train_data_seed, model_seed = 100, 2

# Factory callables forwarded to get_trained_model as keyword arguments.
train_kwargs = dict(
    data_fn=gen_same_data,
    model_fn=gen_2d_model,
    loss_fn=nn.BCEWithLogitsLoss,
    scheduler_fn=LambdaLR,
    correct_training_params_fn=my_correct_training_params,
)

In [None]:
#train model
# Arguments are positional, so their order must match get_trained_model's
# signature (defined outside this notebook). model_seed is passed for each of
# the last three positional slots, exactly as before.

model = get_trained_model(
    epochs,
    data_params, loader_params,
    model_params, loss_params,
    optimizer_params, scheduler_params, correction_params,
    train_params,
    train_data_seed,
    model_seed, model_seed, model_seed,
    **train_kwargs
)

# Per-epoch angle of each weight row, one column per row. The arctan2 arguments
# are arranged so that a row equal to its init_weight direction
# (+x, +y, -x, -y for rows 0..3) has angle 0.
angles = np.stack(
    [
        np.arctan2(path[:, 0, 1], path[:, 0, 0]),
        np.arctan2(path[:, 1, 0], path[:, 1, 1]),
        np.arctan2(path[:, 2, 1], -path[:, 2, 0]),
        np.arctan2(path[:, 3, 0], -path[:, 3, 1]),
    ],
    axis=1,
)

# Euclidean norm of every weight row at every epoch -> shape (epochs+1, 4).
scales = np.sum(path**2, axis=2)**0.5
# Weight of the model's first sub-module, moved to the CPU (not used by the cells shown here).
fin_weight = model[0].weight.cpu()

In [None]:
#plot figure 1
# Switch to the PGF backend and let LaTeX (pdflatex) typeset all figure text,
# so the saved .pgf files can be included directly in a LaTeX document.

matplotlib.use("pgf")
matplotlib.rcParams['pgf.texsystem'] = 'pdflatex'
matplotlib.rcParams['text.usetex'] = True

def plot_four_times(ax, x, s, e, title, legend, font=7, label=6, width=0.5):
    """Draw columns 0-3 of ``x`` over the row window ``[s, e)`` on ``ax``.

    Args:
        ax: Matplotlib axes to draw on.
        x: 2-D array indexed as ``x[row, k]``; only ``k = 0..3`` are plotted.
        s: First row of the window (inclusive).
        e: End row of the window (exclusive). The x-axis shows the absolute
            row indices ``s, ..., e - 1``.
        title: Axes title.
        legend: If truthy, add a legend with the entries ``k=1`` .. ``k=4``.
        font: Font size of the title.
        label: Font size of the tick labels and legend entries.
        width: Line width of ticks and spines; curves are drawn at
            ``1.5 * width`` and ticks are ``4 * width`` long.
    """
    # The x-axis values and the row window are the same for all four curves.
    row_index = np.arange(e - s) + s
    window = x[s:e]
    curve_width = width * 1.5
    for k in range(4):
        ax.plot(row_index, window[:, k], label=f"k={k+1}", linewidth=curve_width)

    ax.set_title(title, fontsize=font)
    #ax.set_xlabel("Epoch", labelpad=0, fontsize=12)
    if legend:
        ax.legend(fontsize=label)

    ax.tick_params(axis='both', labelsize=label, length=width*4, width=width)
    # Thin the axes frame to the same line width as the ticks.
    for spine in ax.spines.values():
        spine.set_linewidth(width)

# Figure 1: a 2x4 grid. Row 0 shows the angles and row 1 the scales; each
# column is one epoch window.
fig, ax = plt.subplots(2, 4, constrained_layout=True)
fig.set_size_inches(h=2.6, w=6.8)

# (start, end) windows with exclusive end: the full run, then three zoomed segments.
windows = [(0, epochs+1), (0, 3584), (3584, 15872), (15872, epochs+1)]

for i, (title, x) in enumerate(zip(["Angles", "Scales"], [angles, scales])):
    for j, (s, e) in enumerate(windows):
        # Only the first column carries a legend.
        plot_four_times(ax[i][j], x, s, e, title, j==0)

# Save and close through the explicit figure handle. The former mid-cell
# `plt.gcf()` only returned the current figure and discarded it, so it is gone.
fig.savefig('test.pgf', bbox_inches='tight', pad_inches=0)
plt.close(fig)

In [None]:
#plot first appendix figure
# A 4x2 grid: column 0 shows the angles and column 1 the scales; each row is
# one epoch window.

fig, ax = plt.subplots(4, 2, constrained_layout=True)
fig.set_size_inches(h=8.1, w=6.8)

# (start, end) windows with exclusive end: the full run, then three zoomed segments.
windows = [(0, epochs+1), (0, 3584), (3584, 15872), (15872, epochs+1)]

for i, (title, x) in enumerate(zip(["Angles", "Scales"], [angles, scales])):
    for j, (s, e) in enumerate(windows):
        # Only the first row carries a legend; fonts and lines are larger than in figure 1.
        plot_four_times(ax[j][i], x, s, e, title, j==0, 9, 8, 0.7)

# Save and close through the explicit figure handle. The former mid-cell
# `plt.gcf()` only returned the current figure and discarded it, so it is gone.
fig.savefig('evo1.pgf', bbox_inches='tight', pad_inches=0)
plt.close(fig)

In [None]:
#plot second and third appendix figure
# Same 4x2 layout as the first appendix figure (column 0 = angles,
# column 1 = scales, one row per epoch window), slightly taller.

fig, ax = plt.subplots(4, 2, constrained_layout=True)
fig.set_size_inches(h=8.2, w=6.8)

# (start, end) windows with exclusive end: the full run, then three zoomed segments.
windows = [(0, epochs+1), (0, 3584), (3584, 15872), (15872, epochs+1)]

for i, (title, x) in enumerate(zip(["Angles", "Scales"], [angles, scales])):
    for j, (s, e) in enumerate(windows):
        plot_four_times(ax[j][i], x, s, e, title, j==0, 9, 8, 0.7)

# Saving is left disabled. This cell draws whatever run is currently held in
# `angles`/`scales`, so enable the line (with the matching file name) only after
# retraining with the evo2 or evo3 `sc` initialisation.
# The former mid-cell `plt.gcf()` was a discarded no-op and has been removed.
#fig.savefig('evo2.pgf', bbox_inches='tight', pad_inches=0)
plt.close(fig)