## Preamble

In [None]:
%matplotlib inline

In [None]:
import torch
from torch.autograd import Variable
from sklearn.decomposition import PCA
import torch.nn.functional as F
import torch.utils.data as Data

import matplotlib.pyplot as plt
import matplotlib.animation as animation

import numpy as np
import imageio
from tqdm import tqdm

In [None]:
# Global Matplotlib style for the whole notebook: LaTeX-rendered serif text
# and ffmpeg as the animation writer.
plt.rcParams.update({
    "animation.writer": "ffmpeg",
    "font.family": "serif",  # use serif/main font for text elements
    "font.size": 12,
    "text.usetex": True,     # use inline math for ticks
    "pgf.rcfonts": False,    # don't setup fonts from rc parameters
    "hist.bins": 20,         # default number of bins in histograms
    # The preamble must be ONE string: passing a list was deprecated in
    # Matplotlib 3.1 and is no longer supported from 3.3 on.  The lines are
    # therefore joined with newlines.
    "pgf.preamble": "\n".join([
        r"\usepackage{units}",          # load additional packages
        r"\usepackage{metalogo}",
        r"\usepackage{unicode-math}",   # unicode math setup
        r"\setmathfont{xits-math.otf}",
        r"\setmainfont{DejaVu Serif}",  # serif font via preamble
        r"\usepackage{color}",
    ]),
})

## Neural network architecture and initialisation

In [None]:
class Net(torch.nn.Module):
    r"""
    One-hidden-layer network computing ``predict(activation(hidden(z)))``.

    The activation is ReLU unless another module/callable is passed through
    the ``activation`` keyword argument.
    """
    def __init__(self, n_feature, n_hidden, n_output, init_scale=1, bias_hidden=True, bias_output=False, balanced=True, clipping=False, **kwargs):
        r"""
        n_feature: dimension of input
        n_hidden: number of hidden neurons
        n_output: dimension of output
        init_scale: hidden weights (and hidden biases) are initialized ~ N(0, init_scale^2/(n_hidden*n_feature));
            when `balanced` is False the output weights are ~ N(0, init_scale^2/n_hidden), as is the output bias
        bias_hidden: if True, use bias parameters in hidden layer. Use no bias otherwise
        bias_output: if True, use bias parameters in output layer. Use no bias otherwise
        balanced: if True, use a balanced initialisation: every output weight attached to hidden neuron j
            is a random sign times the norm of (w_j, b_j), the hidden weights and bias of that neuron
        clipping: if True, rescale each hidden neuron so that ||(w_j, b_j, a_j)|| lies in
            [minclip, maxclip]/sqrt(n_hidden), where a_j are the output weights attached to neuron j.
            The output bias belongs to no hidden neuron and is left untouched.
        **kwargs:
            minclip: lower clipping radius (default init_scale/10)
            maxclip: upper clipping radius (default 10*init_scale)
            activation: activation of the hidden layer (default torch.nn.ReLU())
        """
        super().__init__()
        self.init_scale = init_scale

        hidden_std = init_scale/np.sqrt(n_hidden*n_feature)
        output_std = init_scale/np.sqrt(n_hidden)

        self.hidden = torch.nn.Linear(n_feature, n_hidden, bias=bias_hidden)   # hidden layer with rescaled init
        torch.nn.init.normal_(self.hidden.weight.data, std=hidden_std)
        if bias_hidden:
            torch.nn.init.normal_(self.hidden.bias.data, std=hidden_std)

        self.predict = torch.nn.Linear(n_hidden, n_output, bias=bias_output)   # output layer with rescaled init
        if balanced: # balanced initialisation
            neuron_norms = self._neuron_norms(include_output=False)
            random_signs = 2*torch.bernoulli(0.5*torch.ones_like(self.predict.weight.data)) - 1
            # (n_output, n_hidden) * (n_hidden,): column j is scaled by the norm of neuron j
            self.predict.weight.data = random_signs*neuron_norms
        else:
            torch.nn.init.normal_(self.predict.weight.data, std=output_std)
        if bias_output:
            torch.nn.init.normal_(self.predict.bias.data, std=output_std)

        if clipping:
            neuron_norms = self._neuron_norms(include_output=True)
            ra = kwargs.get('minclip', init_scale/10)/np.sqrt(n_hidden)
            rb = kwargs.get('maxclip', 10*init_scale)/np.sqrt(n_hidden)
            clipped_norms = torch.clip(neuron_norms, min=ra, max=rb)
            # A neuron with zero norm has no direction to rescale: keep it as is
            # rather than turning its parameters into NaN through 0/0.
            m_weights = torch.where(neuron_norms > 0,
                                    clipped_norms/neuron_norms,
                                    torch.ones_like(neuron_norms))
            self.hidden.weight.data *= m_weights.unsqueeze(1)
            self.predict.weight.data *= m_weights
            if bias_hidden:
                self.hidden.bias.data *= m_weights

        self.activation = kwargs.get('activation', torch.nn.ReLU()) # activation of hidden layer

    def _neuron_norms(self, include_output):
        """
        Return a tensor of shape (n_hidden,) with the Euclidean norm of the parameters of each hidden neuron:
        its input weights, its bias (if the hidden layer has one) and, when `include_output` is True,
        the output weights attached to it.
        """
        per_neuron = [self.hidden.weight.data]
        if self.hidden.bias is not None:
            per_neuron.append(self.hidden.bias.data.unsqueeze(1))
        if include_output:
            # predict.weight has shape (n_output, n_hidden): transpose to get one row per hidden neuron
            per_neuron.append(self.predict.weight.data.t())
        return torch.cat(per_neuron, dim=1).norm(dim=1)

    def forward(self, z):
        """Apply the hidden layer, the activation, then the linear output layer."""
        z = self.activation(self.hidden(z))
        z = self.predict(z)             # linear output
        return z

## Generate data

In [None]:
# Fix the random seed: the teacher weights and the inputs drawn below are
# consumed from the same generator, in this order.
torch.manual_seed(4)

d = 150        # input dimension
n = 75         # number of training samples
m_teacher = 6  # number of neurons of the teacher network

## Teacher network: bias-free, non-balanced Gaussian initialisation, no clipping
teacher = Net(
    n_feature=d,
    n_hidden=m_teacher,
    n_output=1,
    init_scale=1,
    balanced=False,
    clipping=False,
    bias_hidden=False,
)

# Gaussian inputs, labelled by the teacher; the labels carry no autograd graph
x = torch.randn(n, d)
with torch.no_grad():
    y = teacher(x)

In [None]:
# Show the teacher's output weights and the flattened labels, separated by a dashed rule
print(teacher.predict.weight, '-'*20, y.reshape(-1), sep='\n')

## Visualisation functions

In [None]:
def multicolor_label(ax,list_of_strings,list_of_colors,axis='x',anchorpad=0,**kw):
    """this function creates axes labels with multiple colors
    ax: specifies the axes object where the labels should be drawn
    list_of_strings: a list of all of the text items
    list_of_colors: a corresponding list of colors for the strings (list_of_colors[k] colors list_of_strings[k])
    axis:'x', 'y', or 'both' and specifies which label(s) should be drawn
    anchorpad: padding passed to the AnchoredOffsetbox that holds the label
    **kw: extra text properties forwarded to every TextArea

    The anchor positions and the separations between the pieces are hard-coded below."""
    from matplotlib.offsetbox import AnchoredOffsetbox, TextArea, HPacker, VPacker

    # (text, color) pairs in reading order; pairing first keeps each color attached to its own string
    pieces = list(zip(list_of_strings, list_of_colors))

    # x-axis label
    if axis=='x' or axis=='both':
        boxes = [TextArea(text, textprops=dict(color=color, ha='left',va='bottom',**kw))
                    for text,color in pieces]
        xbox = HPacker(children=boxes,align="center",pad=0, sep=60)
        anchored_xbox = AnchoredOffsetbox(loc=3, child=xbox, pad=anchorpad,frameon=False,bbox_to_anchor=(0.27, -0.18),
                                          bbox_transform=ax.transAxes, borderpad=0.)
        ax.add_artist(anchored_xbox)

    # y-axis label
    if axis=='y' or axis=='both':
        # The pieces are packed in reverse order for the rotated label. The colors must be
        # reversed together with the strings, otherwise they end up on the wrong text.
        boxes = [TextArea(text, textprops=dict(color=color, ha='left',va='bottom',rotation=90,**kw))
                     for text,color in reversed(pieces)]
        ybox = VPacker(children=boxes,align="center", pad=0, sep=5)
        anchored_ybox = AnchoredOffsetbox(loc=3, child=ybox, pad=anchorpad, frameon=False, bbox_to_anchor=(-0.10, 0.2),
                                          bbox_transform=ax.transAxes, borderpad=0.)
        ax.add_artist(anchored_ybox)


In [None]:
def save_single_frame(fig, arts, frame_number, filename=None):
    """save as a pdf a single frame of an animation
    fig: the figure to save
    arts: list of frames, each frame being the list of artists drawn for it
    frame_number: index (into arts) of the specific frame to save
    filename: output path; defaults to "frame_<frame_number>.pdf"

    Raises IndexError if frame_number is not a valid index of arts; in that case
    no artist visibility is changed and nothing is saved.
    After the call only the artists of the selected frame are visible.
    """
    # look the frame up first so that an invalid index fails before anything is hidden
    selected = arts[frame_number]
    # make sure everything is hidden
    for frame_arts in arts:
        for art in frame_arts:
            art.set_visible(False)
    # make the artists of the requested frame visible
    for art in selected:
        art.set_visible(True)
    if filename is None:
        filename = "frame_{}.pdf".format(frame_number)
    fig.savefig(filename)

## Custom PCA

In [None]:
def svd_flip(u, v, u_based_decision=True):
    """Make the signs of an SVD factorisation deterministic.

    Each column of u and the matching row of v are multiplied by the same
    sign, chosen so that the reference entry (the entry of largest absolute
    value in the column of u, or in the row of v) becomes positive.
    Both arrays are modified in place and returned.

    Parameters
    ----------
    u : ndarray
        Left factor, as returned by `linalg.svd` or
        :func:`~sklearn.utils.extmath.randomized_svd`; its inner dimension
        matches v so that `np.dot(u * s, v)` is defined.
    v : ndarray
        Right factor with the same convention. It is really `vt` in scipy's
        terminology: its rows are the right singular vectors.
    u_based_decision : bool, default=True
        If True, the reference entries are taken from the columns of u.
        Otherwise they are taken from the rows of v. Which one to use is
        generally algorithm dependent.

    Returns
    -------
    u_adjusted, v_adjusted : the input arrays, sign-corrected in place.
    """
    if u_based_decision:
        # reference entry of each column of u
        pivot_rows = np.abs(u).argmax(axis=0)
        flip = np.sign(u[pivot_rows, np.arange(u.shape[1])])
    else:
        # reference entry of each row of v
        pivot_cols = np.abs(v).argmax(axis=1)
        flip = np.sign(v[np.arange(v.shape[0]), pivot_cols])

    # same sign on column k of u and on row k of v, so the product u @ diag(s) @ v is unchanged
    u *= flip
    v *= flip[:, np.newaxis]
    return u, v

In [None]:
class MyPCA:
    """
    SVD-based PCA following sklearn's implementation, except that centering the data is optional.

    After `fit`, the first `n_components` principal directions are stored as rows of `components_`,
    together with `explained_variance_`, `explained_variance_ratio_`, `singular_values_`
    and `noise_variance_`.
    """
    def __init__(self, n_components, centered=False):
        """
        n_components: number of principal components to keep
        centered: if True, subtract the per-feature mean before the SVD (and in `transform`)
        """
        self.n_components = n_components
        self.centered_ = centered

    def fit(self, X):
        """
        Compute the principal components of X.
        Assumes observations in X are passed as rows of a numpy array.
        X itself is never modified.
        """
        n_samples, n_features = X.shape

        # Center data (out of place: the caller's array must not be altered,
        # and an in-place subtraction would fail on integer input)
        if self.centered_:
            self.mean_ = np.mean(X, axis=0)
            X = X - self.mean_

        U, S, Vt = np.linalg.svd(X, full_matrices=False)
        # flip eigenvectors' sign to enforce deterministic output
        U, Vt = svd_flip(U, Vt)

        components_ = Vt

        # Get variance explained by singular values
        explained_variance_ = (S**2) / (n_samples - 1)
        total_var = explained_variance_.sum()
        explained_variance_ratio_ = explained_variance_ / total_var
        singular_values_ = S.copy()  # Store the singular values.

        # Average variance of the discarded components. When every component is
        # kept there is nothing to average: use 0 (as sklearn does) instead of the
        # NaN and RuntimeWarning produced by the mean of an empty slice.
        discarded_variance = explained_variance_[self.n_components:]
        if discarded_variance.size:
            self.noise_variance_ = discarded_variance.mean()
        else:
            self.noise_variance_ = 0.0

        self.n_samples_, self.n_features_ = n_samples, n_features
        self.components_ = components_[:self.n_components]
        self.explained_variance_ = explained_variance_[:self.n_components]
        self.explained_variance_ratio_ = explained_variance_ratio_[:self.n_components]
        self.singular_values_ = singular_values_[:self.n_components]

    def transform(self, X):
        """
        Project the rows of X on the fitted components (after subtracting the fitted mean
        when the PCA is centered). Returns an array with one column per kept component.
        """
        if self.centered_:
            X = X - self.mean_
        X_transformed = np.dot(X, self.components_.T)
        return X_transformed

## Training

In [None]:
# init network: 200 hidden neurons, no hidden bias, balanced and very small initialisation
net = Net(n_feature=d, n_hidden=200, n_output=1, init_scale=1e-20, balanced=True, bias_hidden=False)     # define the network

optimizer = torch.optim.SGD(net.parameters(), lr=0.001) # plain gradient descent: every step uses the whole dataset
loss_func = torch.nn.MSELoss(reduction='mean')  # mean squared error

n_samples = x.shape[0]
n_iterations = 200000 # number of gradient descent steps

loss = torch.Tensor(np.array([0]))
previous_loss = torch.Tensor(np.array([np.inf]))  # np.inf: the alias np.infty was removed in NumPy 2.0

losses = []

# plot parameters
# Snapshots are taken at steps 0 and 1, then whenever it == int(last_iter*iter_geom)+1,
# i.e. the gap between two saved steps grows geometrically with ratio iter_geom.
iter_geom = 1.1
last_iter = 0
frame = 0
weights = []
signs = []
iters = []

# train the network
for it in tqdm(range(n_iterations)):
    prediction = net(x)
    loss = loss_func(prediction, y)
    if (it<2 or it==int(last_iter*iter_geom)+1): # save net weights
        # copies are required: .numpy() shares memory with the parameters updated below
        weights.append(net.hidden.weight.data.detach().numpy().copy())
        # 1 for a positive output weight, 0 for a negative one, 0.5 for an exactly zero one
        signs.append(net.predict.weight.data.heaviside(torch.as_tensor(float(0.5))).reshape(-1).numpy().copy())
        iters.append(it)
        last_iter = it
    losses.append(loss.item())  # plain Python float, loss evaluated before this step's update
    optimizer.zero_grad()   # clear gradients for next train
    loss.backward()         # backpropagation, compute gradients
    optimizer.step()        # descent step

## Loss profile

In [None]:
# Training loss against the gradient-descent step, saved as a pdf named after n and d
loss_fig, loss_ax = plt.subplots()
loss_ax.plot(losses, lw=3)
loss_ax.set_ylim(bottom=0)
#loss_ax.set_xlim(left=0, right=100000)
loss_ax.set_ylabel(r'$L(\theta)$', fontsize=20)
loss_ax.set_xlabel('Iterations', fontsize=20)
loss_ax.grid(alpha=0.2)
loss_fig.tight_layout()
loss_fig.savefig('loss_profile_n{}_d{}.pdf'.format(n,d))
plt.show()

In [None]:
# Loss computed at the last iteration of the training loop (before that iteration's descent step)
print(loss)

## Neuron alignment visualisation

In [None]:
## PCA to represent the neurons in 2D
# Uncentered PCA fitted on the last saved snapshot of the hidden-layer weights.
# All d components are kept (the explained-variance plot uses every one of them);
# the polar plot only reads the first two coordinates of the projection.

pca = MyPCA(n_components=d, centered=False)
pca.fit(weights[-1])

In [None]:
# Build one animation frame per saved snapshot: every hidden neuron is a star in
# polar coordinates, placed from its first two PCA coordinates.
ims = []
fig = plt.figure("Neuron alignment")
plt.ioff()

# Cosmetics
c1 = 'tab:green' # color of left axis
c2 = 'tab:blue' # color of right axis
# keys are the values stored in `signs`: 0 (negative output weight), 0.5 (zero), 1 (positive)
color_map = {0 : 'firebrick',
             0.5 : 'black',
             1 : 'darkviolet'}

#plt.subplots_adjust(left=0.15, right=0.85)

ax = fig.add_subplot(111, projection='polar') # polar coordinates
ax.set_rorigin(-5e-2) # set inner circle for 0 norm vectors
ax.set_theta_zero_location("E")
ax.yaxis.set_ticklabels([])


#######
for i, (w, s, it) in enumerate(zip(weights, signs, iters)):
    # color of stars given their output layer sign
    # (the loop variable is not called `d`, to avoid shadowing the input dimension)
    c = [color_map[sign] for sign in s]
    w1 = pca.transform(w) # projection on the principal components; only the first two are drawn
    # arctan2 is defined in every quadrant and when the first coordinate is 0, where the former
    # arctan(y/x) divided by zero. The +pi offset reproduces the orientation of the former
    # arctan(y/x) + pi*heaviside(x, 0) formula (same angles modulo 2*pi).
    theta = np.arctan2(w1[:,1], w1[:,0]) + np.pi
    radius = np.linalg.norm(w1[:,:2], axis=1)
    im = ax.scatter(theta, radius, animated=True, c=c, marker='*')
    t1 = ax.annotate("iteration: "+str(it),(0.1,0.95),xycoords='figure fraction',annotation_clip=False) # add text
    t2 = ax.annotate("frame: "+str(i),(0.8,0.95),xycoords='figure fraction',annotation_clip=False) # add text
    ims.append([im,t1,t2])

ani = animation.ArtistAnimation(fig, ims, interval=100, repeat=False)
plt.close()

In [None]:
# No writer is passed here, so Matplotlib uses the "animation.writer" rcParam (set to "ffmpeg" in the preamble)
ani.save('highdim_alignment.mp4', fps=10, dpi=120) # save animation as .mp4

## Save specific frames

In [None]:
# Drop the animation object before exporting single frames from `fig`
# NOTE(review): presumably so that it no longer controls the artists' visibility - confirm
del ani

In [None]:
# 108 indexes `ims` (one entry per saved snapshot); it must be smaller than len(ims), otherwise IndexError
save_single_frame(fig, ims, 108) # save specific frame of animation as .pdf

## Distribution of the PCA explained variance

In [None]:
# Share of the variance carried by each principal component, saved as a pdf named after n and d
variance_ratios = pca.explained_variance_ratio_
axis_label_size = 20
plt.plot(variance_ratios, '+')
plt.xlabel("Component number", fontsize=axis_label_size)
plt.ylabel("Explained variance ratio", fontsize=axis_label_size)
plt.grid(alpha=0.2)
plt.tight_layout()
plt.savefig('explained_variance_n{}_d{}.pdf'.format(n,d))
plt.show()