##  Sum of bias-variance tradeoffs
This notebook contains the relevant code for the following figures in the paper "*Early stopping in deep networks: Double descent and how to eliminate it*":

- Figure 2

In [3]:
import numpy as np
import matplotlib as mpl
# Select the Tk GUI backend; this must happen before pyplot is imported below.
# NOTE(review): 'tkagg' needs Tk and a display -- confirm this is intended when
# the notebook is run headless or with inline figures.
mpl.use('tkagg')
from scipy.stats import ortho_group
from tqdm import tqdm
import matplotlib.pyplot as plt

# Seed NumPy's global RNG. Note that get_data() reseeds it with its own `seed`
# argument on every call.
np.random.seed(1234)

In [4]:
def get_modulation_matrix(d, p, k):
    """Build a random d x d modulation matrix with a two-level spectrum.

    The result is ``U . S . V^T`` where ``U`` and ``V^T`` are random
    orthogonal matrices drawn with ``ortho_group.rvs`` and ``S`` is diagonal:
    its first ``p`` diagonal entries are 1 and the remaining ``d - p`` entries
    are ``1 / k``.

    Args:
        d: size of the (square) matrix.
        p: number of leading singular values left at 1.
        k: ratio between the large (1) and the small (1 / k) singular values.

    Returns:
        The d x d array ``F``.
    """
    # Draw order is part of the behaviour: left factor first, then the right one.
    left_basis = ortho_group.rvs(d)
    right_basis_t = ortho_group.rvs(d)

    # Start from the identity and shrink only the trailing (d - p) block.
    spectrum = np.identity(d)
    spectrum[p:, p:] *= 1 / k

    return left_basis.dot(spectrum.dot(right_basis_t))


# Implements the teacher and generates the data
def get_data(seed, n, d, p, k, noise, n_test=1000):
    """Implement the teacher and generate the student's train/test data.

    Latent inputs ``Z`` are i.i.d. Gaussian scaled by ``1 / sqrt(d)``. The
    teacher is a random linear map ``w``; training targets are ``Z w`` plus
    Gaussian noise, test targets are noiseless. The student does not see ``Z``
    directly but the modulated inputs ``X = Z F``.

    Side effect: reseeds NumPy's *global* RNG with ``seed``.

    Args:
        seed: value passed to ``np.random.seed`` before any sampling.
        n: number of training samples.
        d: input dimension.
        p: forwarded to ``get_modulation_matrix``.
        k: forwarded to ``get_modulation_matrix``.
        noise: standard deviation of the noise added to the training targets.
        n_test: number of test samples (defaults to the previous fixed 1000).

    Returns:
        Tuple ``(X, y, X_test, y_test, F, w)`` with shapes
        ``(n, d)``, ``(n, 1)``, ``(n_test, d)``, ``(n_test, 1)``,
        ``(d, d)`` and ``(d, 1)``.
    """
    np.random.seed(seed)
    scale = np.sqrt(d)

    # The sampling order below (Z, Z_test, w, noise, F) fixes the random
    # stream for a given seed -- do not reorder.
    Z = np.random.randn(n, d) / scale
    Z_test = np.random.randn(n_test, d) / scale

    # teacher
    w = np.random.randn(d, 1)
    y = np.dot(Z, w)
    # The noise is drawn even when `noise` is 0 so the RNG state seen by the
    # modulation matrix below does not depend on the noise level.
    y = y + noise * np.random.randn(*y.shape)
    # test data is noiseless
    y_test = np.dot(Z_test, w)

    # the modulation matrix that controls students access to the data
    F = get_modulation_matrix(d, p, k)

    # Per sample x = F^T z; with samples stored as rows this is X = Z F.
    X = np.dot(Z, F)
    X_test = np.dot(Z_test, F)

    return X, y, X_test, y_test, F, w


def get_RQ(w_hat, F, w, d):
    """Return the pair (R, Q) for a student weight vector.

    With ``v = F w_hat`` (the student's weights mapped through ``F``):

    * ``R = v^T w / d``: the alignment between the teacher and the student.
    * ``Q = v^T v / d``: the student's modulated norm.

    Both values are returned as plain Python scalars.
    """
    # Modulated student weights, shared by both quantities.
    modulated = np.dot(F, w_hat)

    alignment = np.dot(modulated.T, w)
    sq_norm = np.dot(modulated.T, modulated)

    return alignment.item() / d, sq_norm.item() / d

In [5]:
# d: dimension of the input / teacher weight space
# (presumably passed as the `d` argument of the helpers defined above)
d = 100
# p: number of fast learning dimensions
# (in get_modulation_matrix the first p singular values are 1, the rest 1 / k)
p = 70
# k: kappa -> the condition number of the modulation matrix, F
k = 100
# standard deviation of the noise added to the teacher output
noise = 0.0
# L2 regularization coefficient
# NOTE(review): not referenced by the helper functions above -- presumably
# consumed by training code elsewhere in the notebook; confirm.
l2 = 0.0

In [None]:
# Build the modulation matrix from the configured dimension d, number of
# unit singular values p and condition number k. The function has three
# required positional parameters; calling it with none raises TypeError.
F = get_modulation_matrix(d, p, k)