## Double descent in a two-layer neural network
This notebook contains the relevant code for the following figures in the paper "*Early stopping in deep networks: Double descent and how to eliminate it*":

- Figure 3

In [None]:
#!/usr/bin/env python
# coding: utf-8

%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt

import argparse
import os
import datetime
import pathlib
import random
import json
import numpy as np
import math

import torch

import sys
sys.path.append('code/')
from linear_utils import linear_model
from train_utils import save_config, prune_data

In [None]:
# argument written in command line format
cli_args = '--seed 12 --save-results --jacobian --risk-loss L1 -t 50000 -w 1 0.1 --lr 0.0001 0.0001 -d 100 -n 100 --hidden 50 --sigmas 1 --coupled_noise --sigma_noise 10.0 2.0 --no-bias --linear --transform-data'


In [None]:
"""
A fully-connected ReLU network with one hidden layer, trained to predict y from x
by minimizing the MSE loss.
"""

# Build the CLI parser. The notebook feeds it the `cli_args` string instead of
# sys.argv, so the same flags work here and in a script version of this code.
parser = argparse.ArgumentParser(description='CLI parameters for training')

# --- run / bookkeeping options ---
parser.add_argument('--root', type=str, default='', metavar='DIR',
                    help='Root directory')
# The default must be an int: argparse applies `type` only to string defaults,
# so a float literal such as 1e4 would leak through unconverted.
parser.add_argument('-t', '--iterations', type=int, default=10000, metavar='ITERATIONS',
                    help='Iterations (default: 1e4)')
parser.add_argument('-n', '--samples', type=int, default=100, metavar='N',
                    help='Number of samples (default: 100)')
parser.add_argument('--print-freq', type=int, default=100,
                    help='CLI output printing frequency (default: 100)')
parser.add_argument('--gpu', type=int, default=None,
                    help='Number of GPUS to use')
parser.add_argument('--seed', type=int, default=None,
                    help='Random seed')

# --- network architecture ---
parser.add_argument('-d', '--dim', type=int, default=50, metavar='DIMENSION',
                    help='Feature dimension (default: 50)')
parser.add_argument('--hidden', type=int, default=200, metavar='DIMENSION',
                    help='Hidden layer dimension (default: 200)')
parser.add_argument('--no-bias', action='store_true', default=False,
                    help='Do not use bias')
parser.add_argument('--linear', action='store_true', default=False,
                    help='Linear activation function')

# --- data model, initialization scale and optimization ---
parser.add_argument('--sigmas', type=str, default=None,
                    help='Sigmas')
parser.add_argument('-r', '--s-range', nargs='*', type=float,
                    help='Range for sigmas')
parser.add_argument('-w', '--scales', nargs='*', type=float,
                    help='scale of the weights')
# NOTE: with nargs='*' a value given on the command line is always a list,
# while the untouched default stays a scalar float.
parser.add_argument('--lr', type=float, default=1e-4, nargs='*', metavar='LR',
                    help='learning rate (default: 1e-4)')
parser.add_argument('--normalized', action='store_true', default=False,
                    help='normalize sample norm across features')
parser.add_argument('--risk-loss', type=str, default='MSE', metavar='LOSS',
                    help='Loss for validation')

# --- analysis / output switches ---
parser.add_argument('--jacobian', action='store_true', default=False,
                    help='compute the SVD of the jacobian of the network')
parser.add_argument('--save-results', action='store_true', default=False,
                    help='Save the results for plots')
parser.add_argument('--coupled_noise', action='store_true', default=False,
                    help='Couple noise in output to large eigenvalues.')
parser.add_argument('--sigma_noise', nargs='*', type=float, default=0.0,
                    help='Output noise.')
parser.add_argument('--pcs', type=int, default=None,
                    help='Number of PCs to use in data.')
parser.add_argument('--transform-data', action='store_true', default=False,
                    help='Use data in transformed space')
parser.add_argument('--details', type=str, metavar='N',
                    default='no_detail_given',
                    help='details about the experimental setup')


# Parse the command-line style configuration string defined earlier in the notebook.
args = parser.parse_args(cli_args.split())

# directories
# `--root` selects the base directory; without it, the parent of the current
# working directory is used.
root = pathlib.Path(args.root) if args.root else pathlib.Path.cwd().parent

# Each run writes into its own timestamped folder under <root>/results/two_layer_nn/.
# The path is built from `root` so that `--root` is actually honoured (the
# default location is unchanged).
current_date = datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S')
args.outpath = root / 'results' / 'two_layer_nn' / current_date

if args.save_results:
    args.outpath.mkdir(exist_ok=True, parents=True)

# Seed every random number generator in use so that runs are reproducible.
if args.seed is not None:
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

device = torch.device('cpu')
# device = torch.device('cuda') # Uncomment this to run on GPU

In [None]:
d_out = 1      # dimension of y

# Sample the training set from the linear model.
# `normalized` is taken from the CLI flag rather than hard-coded to False, so
# that `--normalized` has an effect (the flag defaults to False, so the default
# behaviour is unchanged).
# NOTE(review): linear_model is defined in code/linear_utils.py, not visible
# here - confirm its `normalized` argument matches the flag's meaning.
lin_model = linear_model(
    args.dim,
    sigma_noise=args.sigma_noise,
    beta=None,
    normalized=args.normalized,
    sigmas=args.sigmas,
    s_range=args.s_range,
    coupled_noise=args.coupled_noise,
    transform_data=args.transform_data,
)
Xs, ys = lin_model.sample(args.samples)
Xs = torch.Tensor(Xs).to(device)
# targets are stored as a column: one row per sample, d_out entries per row
ys = torch.Tensor(ys.reshape((-1, d_out))).to(device)

# Optionally restrict the training inputs to `--pcs` principal components.
if args.pcs:
    Xs = prune_data(Xs, args.pcs)

# Sample a second, independent set for the empirical risk calculation.
Xt, yt = lin_model.sample(args.samples)
Xt = torch.Tensor(Xt).to(device)
yt = torch.Tensor(yt.reshape((-1, d_out))).to(device)

In [None]:
# Training objective: squared error summed over all samples.
loss_fn = torch.nn.MSELoss(reduction='sum')

# Risk used for evaluation: mean absolute error when `--risk-loss L1` is
# requested, otherwise the training objective itself is reused.
if args.risk_loss == 'L1':
    risk_fn = torch.nn.L1Loss(reduction='mean')
else:
    risk_fn = loss_fn

In [None]:
def get_jacobian_two_layer(X, y, model, crit):
    """Return the per-sample Jacobian of the model output w.r.t. its weight matrices.

    For every sample the model output is back-propagated with a seed of ones,
    and the gradients of all parameters with more than one dimension (weight
    matrices; 1-D parameters such as biases are skipped) are flattened and
    concatenated in ``model.parameters()`` order.

    Args:
        X: iterable of input samples (e.g. a 2-D tensor, one sample per row).
        y: iterable of targets. Only used to bound the iteration via ``zip``;
            the target values themselves are never read.
        model: torch module mapping one sample to its output.
        crit: unused; kept so existing callers keep working.

    Returns:
        numpy array with one row per sample and one column per weight entry.
    """
    rows = []
    for cx, _ in zip(X, y):
        model.zero_grad()
        out = model(cx)
        # Seed the backward pass from the output itself: this matches the
        # output's shape, dtype and device regardless of how `y` is laid out.
        out.backward(torch.ones_like(out))

        weight_grads = [
            # move to host memory first so this also works for models on an accelerator
            p.grad.detach().cpu().numpy().reshape(-1)
            for p in model.parameters()
            if p.grad is not None and p.dim() > 1
        ]
        # np.concatenate copies, so each row is independent of later backward passes
        rows.append(np.concatenate(weight_grads))
    return np.array(rows)

In [None]:
# Two layer neural network in pytorch: input -> hidden -> scalar output.
# Both linear layers share the bias setting; the hidden non-linearity is the
# identity when `--linear` is given and ReLU otherwise.
use_bias = not args.no_bias
hidden_activation = torch.nn.Identity() if args.linear else torch.nn.ReLU()

model = torch.nn.Sequential(
    torch.nn.Linear(args.dim, args.hidden, bias=use_bias),
    hidden_activation,
    torch.nn.Linear(args.hidden, 1, bias=use_bias),
).to(device)


#### re-initialize the weights (regular initialization is too unstable)
# An earlier variant drew the first layer from a normal distribution with
# standard deviation scales[0] and the second from U(-scales[1], scales[1]).
# Kaiming initialization is used instead, rescaled per layer by the
# user-supplied `--scales` factors:
#   first linear layer  -> kaiming_normal_  * scales[0]
#   second linear layer -> kaiming_uniform_ * scales[1]
if args.scales:
    # Fail before touching any weights instead of half-way through.
    if len(args.scales) < 2:
        raise ValueError('--scales needs two values (one per layer), got %d'
                         % len(args.scales))

    linear_layers = [m for m in model if isinstance(m, torch.nn.Linear)]
    initializers = (torch.nn.init.kaiming_normal_, torch.nn.init.kaiming_uniform_)

    with torch.no_grad():
        # zip stops after the two initializers, so only the first two linear
        # layers are touched and any extra scale values are ignored.
        for layer, init_fn, scale in zip(linear_layers, initializers, args.scales):
            init_fn(layer.weight, a=math.sqrt(5))
            # in-place rescale; safe under no_grad and avoids reassigning .data
            layer.weight.mul_(scale)
                

# Use the same learning rate for the two layers when only one rate is available.
# The untouched argparse default is a scalar float, whereas a value given on
# the command line arrives as a list (nargs='*'), so a single rate can show up
# either as a float or as a one-element list.
if isinstance(args.lr, float):
    args.lr = [args.lr] * 2
elif len(args.lr) == 1:
    args.lr = [args.lr[0]] * 2

In [None]:
# compute the Jacobian at initialization
if args.jacobian:
    J = get_jacobian_two_layer(Xs, ys, model, loss_fn)
    # Thin SVD: only the leading min(n_samples, n_params) right singular
    # vectors are used below, so the full square V^T is never needed.
    uv, sv, vtv = np.linalg.svd(J, full_matrices=False)

    # The columns of J are the flattened weight matrices in model.parameters()
    # order (first-layer weights, then second-layer weights). Take the block
    # sizes from the model itself instead of hard-coding one architecture.
    weight_sizes = [p.numel() for p in model.parameters() if p.dim() > 1]
    n_first_layer, n_second_layer = weight_sizes[0], weight_sizes[-1]

    # For every right singular vector: norm of the part living in the first
    # layer (v1), in the second layer (v2), and of both combined (vTrec).
    v1 = np.linalg.norm(vtv[:, :n_first_layer], axis=1)
    v2 = np.linalg.norm(vtv[:, -n_second_layer:], axis=1)
    vTrec = np.linalg.norm(np.stack((v1, v2)), axis=0)

    if args.save_results:
        save_config(args)

        # the weight scales are encoded in the file name, with '.' replaced by '-'
        scale_tags = [str(scale).replace('.', '-') for scale in args.scales[:2]]
        np_save_file = args.outpath / ('two_layer_nn_jacobian_' +
                                       scale_tags[0] + '_' +
                                       scale_tags[1] +
                                       '.txt')

        # columns: singular value, first-layer norm, second-layer norm, combined norm
        np.savetxt(np_save_file,
                   np.column_stack((sv, v1, v2, vTrec)),
                   header='t vw vv vT',
                   comments='',
                   newline='\n')

In [None]:
# Plot, for every singular value of the Jacobian, how much of the matching
# right singular vector lives in each layer. The quantities only exist when
# `--jacobian` was requested, so the cell is skipped otherwise.
if args.jacobian:
    # plt.get_cmap works across Matplotlib versions; matplotlib.cm.get_cmap
    # was deprecated and later removed.
    cmap = plt.get_cmap('viridis')
    colorList = [cmap(50/1000), cmap(350/1000), cmap(650/1000)]
    labelList = [r'$W$', r'$v$', r'$W + v$']

    fig, ax = plt.subplots(figsize=(12, 8))
    ax_list = [ax]

    # one scatter series per quantity: first layer, second layer, both layers
    for norms, color, label in zip((v1, v2, vTrec), colorList, labelList):
        ax.scatter(sv, norms,
                   color=color,
                   label=label,
                   lw=4)

    ax.legend(loc=0, bbox_to_anchor=(1, 0.5), fontsize='x-large',
              frameon=True, fancybox=True, shadow=True, ncol=1)
    # NOTE(review): the label shows a squared norm, while v1/v2/vTrec are
    # computed with np.linalg.norm (not squared) - confirm which is intended.
    ax.set_ylabel(r'$\Vert v \Vert_2^2$')
    ax.set_xlabel(r'$\sigma_i$')
    ax.set_xscale('log')
    ax.set_yscale('log')
    plt.show()