In [None]:
import torch
from torch import nn

from train_utils import *
from train_xor import *

import numpy as np

import matplotlib
import matplotlib.pyplot as plt

%config InlineBackend.figure_format = 'retina'

In [None]:
# --- Problem size and training length ---
d = 2          # dimension passed as data/model 'd'; also the last axis of `path`
n = 2**9       # passed as data_params['n'] and as the loader batch size
m = 2**12      # passed as model_params['m']; also the second axis of `path`
# NOTE(review): `device` is not referenced by any visible cell.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
epochs = 2**13

# Data generated by `gen_skew_data`. The 'alpha' entry is only a default:
# the training cell sets alpha separately for every run.
data_params = {
    'n': n,
    'data_fn': gen_skew_data,
    'data_params': {
        'd': d, 'mu': 1, 'delta': 0.2, 'sigma': 0.1 / d**0.5,
        'alpha': np.pi/2}
}
# batch_size == n, i.e. a single full batch per epoch, assuming the generated
# dataset has exactly n points -- TODO confirm against gen_same_data.
loader_params = {'batch_size': n, 'num_workers': 0}

model_params = {'d': d, 'm': m, 'scale': 2**(-7)}

loss_params = {}
# No momentum, no weight decay. Presumably keyword arguments for
# torch.optim.SGD ('nesterov' is an SGD option) -- verify in train_utils.
optimizer_params = {
    'momentum': 0.0, 'weight_decay': 0.0, 'nesterov': False
}
scheduler_params = {}
correction_params = {'lr': 2**(-4)}

# Weight-trajectory buffer of shape (epochs + 1, m, d), float64 (~512 MiB).
# It is handed to the validation hook as 'weights_path' and read back by the
# analysis cells. It is filled with NaN rather than allocated with np.empty, so
# any epoch the hook fails to write shows up as NaN instead of arbitrary
# uninitialised memory.
path = np.full((epochs + 1, m, d), np.nan, dtype=np.float64)
train_params = {
    'val_epoch_params': {'weights_path': path},
    'get_loss_fn': get_loss_one_dim,
    'train_epoch_fn': train_epoch_no_sch,
    'val_epoch_fn': log_epoch
}

# Seeds for the training data and the model (the model seed is passed to
# get_trained_model several times in the training cell).
train_data_seed = 100
model_seed = 2

train_kwargs = {
    'data_fn': gen_same_data,
    'model_fn': gen_random_model,
    'loss_fn': nn.BCEWithLogitsLoss,
    # No LR scheduler: the factory ignores its arguments and returns None.
    'scheduler_fn' : lambda *x, **y: None,
    'correct_training_params_fn': my_correct_training_params
}

In [None]:
# Train one network per data-skew angle alpha. For each run, record:
#   scales[i] -- sqrt of the summed squares of `path` at every logged epoch,
#   shares[i] -- each row's share of that squared total at the final epoch,
#   angles[i] -- arctan2 of each row's (dim 1, dim 0) entries at the final epoch.
# `path` has shape (epochs + 1, m, d) (see the config cell). It is presumably
# filled in place by the validation hook through
# train_params['val_epoch_params']['weights_path'].
alphas = [np.pi/2, np.pi/3]

scales = []
angles = []
shares = []

for alpha in alphas:
    # Per-run copy of the data config, so the config cell's dict is not
    # mutated in place (re-running cells then sees the values the reader sees).
    run_data_params = {
        **data_params,
        'data_params': {**data_params['data_params'], 'alpha': alpha},
    }
    model = get_trained_model(
        epochs,
        run_data_params,
        loader_params,
        model_params,
        loss_params,
        optimizer_params,
        scheduler_params,
        correction_params,
        train_params,
        train_data_seed,
        model_seed,
        model_seed,
        model_seed,
        **train_kwargs
    )
    assert np.isfinite(path).all(), (
        "weights path contains non-finite values (unlogged epoch or divergence)")

    neuron_scales = (path**2).sum(axis=2)        # (epochs + 1, m)
    network_scales = neuron_scales.sum(axis=1)   # (epochs + 1,)
    scales.append(network_scales**0.5)

    # Only the final epoch is kept for shares and angles, so compute them on
    # that slice instead of over the whole (epochs + 1, m) history.
    final_weights = path[-1]                     # (m, d)
    shares.append(neuron_scales[-1] / network_scales[-1])
    angles.append(np.arctan2(final_weights[:, 1], final_weights[:, 0]))

In [None]:
# 2x2 figure saved to align.pdf. Column j shows run j of the training cell:
#   top row    -- scale trajectory over epochs,
#   bottom row -- final-epoch angle vs. relative scale, one point per neuron.
alpha_labels = [r"\pi/2", r"\pi/3"]  # must match the run order in the training cell
assert len(scales) == len(shares) == len(angles) == len(alpha_labels), (
    "number of training runs does not match the number of plot columns")

fig, ax = plt.subplots(2, 2, constrained_layout=True)
fig.set_size_inches(h=3.3, w=3.3)

for j, alpha_label in enumerate(alpha_labels):
    scale_ax = ax[0][j]
    scale_ax.plot(scales[j], linewidth=0.75)  # x defaults to 0..len-1 (epoch index)
    scale_ax.set_title(rf"Scale, $\alpha={alpha_label}$", fontsize=7)
    scale_ax.set_xlabel("Epoch", fontsize=6, labelpad=1)

    angle_ax = ax[1][j]
    angle_ax.scatter(shares[j], angles[j], s=1.5, linewidth=0.75)
    angle_ax.set_title(rf"Angle vs. Rel. Scale, $\alpha={alpha_label}$", fontsize=7)
    angle_ax.set_xlabel("Relative Scale", fontsize=6, labelpad=1)
    angle_ax.set_ylabel("Angle", fontsize=6, labelpad=0)

# Thin ticks and spines to suit the small (3.3 in) figure.
for panel in ax.flat:
    panel.tick_params(axis='both', labelsize=6, length=2, width=0.5)
    plt.setp(panel.spines.values(), linewidth=0.5)

# Act on this figure explicitly rather than on pyplot's "current" figure.
fig.savefig('align.pdf', bbox_inches='tight', pad_inches=0)
plt.close(fig)

In [None]:
# For the last training run, measure how far each final angle lies from the
# nearest multiple of pi/2 in [-pi, pi], then report the 95th percentile.
axis_angles = np.array([-np.pi, -np.pi/2, 0.0, np.pi/2, np.pi])
angle = angles[-1]
# Broadcast against every candidate axis angle and keep the closest distance.
angle_dev = np.abs(angle[..., None] - axis_angles).min(axis=-1)
np.quantile(angle_dev, 0.95)