# bias-variance tradeoff

- https://en.wikipedia.org/wiki/Bias%E2%80%93variance_tradeoff
- https://stats.stackexchange.com/questions/129478/when-is-the-bootstrap-estimate-of-bias-valid
- https://johanndejong.wordpress.com/2016/12/18/bias-variance-decomposition/

So:
- we sample some datapoints from a population
- we fit parameters to these sampled datapoints
- then we use these parameters to create a prediction for each input $x_n$
- we fix the input datapoints $\mathcal{X} = \{x_1, x_2, \dots, x_N \}$ and only resample the targets; otherwise we could not average the predictions for a given $x_n$ across datasets
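A minimal NumPy sketch of this setup. It assumes a linear ground truth $y^* = a x + b$ with additive Gaussian noise, using the same values as the notebook cell further down. The inputs here are evenly spaced rather than random; either choice works as long as they are drawn once and reused. `sample_datasets` is a hypothetical helper, not a library function.

```python
import numpy as np

def sample_datasets(x, f, noise_std, num_datasets, seed=0):
    """Hypothetical helper: draw num_datasets noisy target vectors for the same inputs x.

    Returns y_star, shape (N,), and Y, shape (M, N), where row m holds
    dataset m's targets y_{m,n} = f(x_n) + Gaussian noise.
    """
    rng = np.random.default_rng(seed)
    y_star = f(x)                                          # noise-free ground truth, shared by all datasets
    noise = rng.normal(0.0, noise_std, size=(num_datasets, x.shape[0]))
    return y_star, y_star + noise                          # (N,) broadcasts against (M, N)

x = np.linspace(0.0, 1.0, 10)                              # fixed inputs, reused by every dataset
y_star, Y = sample_datasets(x, lambda t: 1.23 * t + 0.34, noise_std=0.51, num_datasets=1000)
print(y_star.shape, Y.shape)  # (10,) (1000, 10)
```

Only the noise differs between rows of `Y`. This is what makes "the prediction at $x_n$, averaged over datasets" well defined.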

So, we have:

- $N$ samples in each dataset
- $M$ datasets

Input data $\mathcal{X}$ is common across all datasets:

- $\mathcal{X} = \{ x_{1}, x_{2}, \dots, x_{N} \}$

This means that the noise-free ground truth $\mathcal{Y}^*$ is also common across all datasets:

- $\mathcal{Y}^* = \{ y^*_{1}, y^*_{2}, \dots, y^*_{N} \}$

Then, the following are per-dataset:

- targets $\mathcal{Y}_m = \{ y_{m,1}, y_{m,2}, \dots, y_{m,N} \}$
- parameters $\theta_m$ (fitted to $\mathcal{X}$ and $\mathcal{Y}_m$ above)
- predictions $\hat{\mathcal{Y}}_m = \{ \hat{y}_{m,1}, \hat{y}_{m,2}, \dots, \hat{y}_{m,N} \}$
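The sketch below shows the per-dataset step. It assumes polynomial models fitted by least squares, as in the notebook experiment, and `fit_and_predict` is a hypothetical helper. Because $\mathcal{X}$ is shared, a single design matrix serves every dataset. `np.linalg.lstsq` can then solve all $M$ fits in one call by treating the columns of $Y^\top$ as separate right-hand sides.

```python
import numpy as np

def fit_and_predict(x, Y, degree):
    """Hypothetical helper: fit one polynomial per dataset (row of Y) by least squares.

    Returns theta, shape (M, degree + 1), and predictions, shape (M, N).
    """
    # Design matrix with columns x^0, x^1, ..., x^degree, shared by every dataset.
    Phi = np.vander(x, degree + 1, increasing=True)        # (N, K)
    # Each column of Y.T is one dataset's targets: solve Phi @ theta_m ≈ y_m for all m.
    theta, *_ = np.linalg.lstsq(Phi, Y.T, rcond=None)       # (K, M)
    preds = (Phi @ theta).T                                  # (M, N); row m is \hat{Y}_m
    return theta.T, preds

rng = np.random.default_rng(0)
x = np.linspace(0.0, 1.0, 10)
Y = 1.23 * x + 0.34 + rng.normal(0.0, 0.51, size=(5, 10))  # 5 noisy datasets, same x
theta, preds = fit_and_predict(x, Y, degree=1)
print(theta.shape, preds.shape)  # (5, 2) (5, 10)
```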


Then conceptually:

- bias and variance are first calculated for each datapoint $x_n$, and then averaged over all datapoints
- for each datapoint, the bias is the squared difference between the ground truth and the expected prediction, i.e. the prediction averaged over the $M$ datasets (strictly speaking, this is the *squared* bias):

$$
\text{bias}_n = \left(
    y^*_n - \frac{1}{M} \sum_{m=1}^M \hat{y}_{m,n}
\right)^2
$$
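Here is a worked example with made-up numbers. Take $M = 3$ datasets, and suppose the three fitted models predict $\hat{y}_{1,n} = 0.9$, $\hat{y}_{2,n} = 1.2$ and $\hat{y}_{3,n} = 1.5$ at a point where $y^*_n = 1.0$. Then:

$$
\text{bias}_n = \left( 1.0 - \frac{0.9 + 1.2 + 1.5}{3} \right)^2 = (1.0 - 1.2)^2 = 0.04
$$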

Averaging over all $N$ datapoints, the bias is:

$$
\text{bias} = \frac{1}{N} \sum_{n=1}^N
    \left(
        y_n^* - \frac{1}{M} \sum_{m=1}^M \hat{y}_{m,n}
    \right)^2
$$
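The next block is a direct NumPy translation of this formula. It assumes the predictions are stored in an array `preds` of shape `(M, N)`, with row $m$ holding $\hat{\mathcal{Y}}_m$. `avg_bias` is a hypothetical name.

```python
import numpy as np

def avg_bias(y_star, preds):
    """Squared bias averaged over datapoints.

    y_star: (N,) ground truth; preds: (M, N), row m = predictions of the model fitted to dataset m.
    """
    mean_pred = preds.mean(axis=0)        # (1/M) sum_m \hat{y}_{m,n}, one value per datapoint
    bias_n = (y_star - mean_pred) ** 2    # per-datapoint bias_n
    return bias_n.mean()                  # (1/N) sum_n bias_n

y_star = np.array([1.0, 2.0])
preds = np.array([[0.9, 2.0],
                  [1.2, 2.0],
                  [1.5, 2.0]])            # M = 3 datasets, N = 2 datapoints
print(avg_bias(y_star, preds))            # approximately 0.02, since bias_n = [0.04, 0.0]
```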

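Similarly, the variance is the spread of each datapoint's predictions around their mean across datasets, averaged over all datapoints: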

$$
\text{variance} = \frac{1}{N} \sum_{n=1}^N \frac{1}{M} \sum_{m=1}^M
    \left(
        \hat{y}_{m,n} - \frac{1}{M} \sum_{m'=1}^M \hat{y}_{m',n}
    \right)^2
$$
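The variance term can be sketched the same way, again assuming an `(M, N)` array `preds` (`avg_variance` is a hypothetical name). Note the normalisation. `np.var` defaults to `ddof=0`, which matches the $\frac{1}{M}$ in the formula. `torch.var` and `Tensor.var` instead default to the unbiased $\frac{1}{M-1}$ estimator, so their result is larger by a factor of $\frac{M}{M-1}$.

```python
import numpy as np

def avg_variance(preds):
    """Variance of the predictions across datasets, averaged over datapoints."""
    var_n = preds.var(axis=0, ddof=0)     # (1/M) sum_m (\hat{y}_{m,n} - mean_m)^2 per datapoint
    return var_n.mean()                   # (1/N) sum_n var_n

preds = np.array([[0.9, 2.0],
                  [1.2, 2.0],
                  [1.5, 2.0]])            # M = 3 datasets, N = 2 datapoints
print(avg_variance(preds))                        # 1/M version, as in the formula
print(preds.var(axis=0, ddof=1).mean())           # 1/(M-1) version, for comparison
```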

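Lastly, the mse is computed for each dataset over all its datapoints and then averaged over the datasets. Note that it is measured against the ground truth $y^*_n$, not the noisy targets. Swapping the order of the two sums puts it in the same per-datapoint form as the bias and variance: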

$$
\begin{aligned}
\text{mse} &= \frac{1}{M} \sum_{m=1}^M \frac{1}{N} \sum_{n=1}^N
    \left( y^*_n - \hat{y}_{m,n} \right)^2 \\
&= \frac{1}{N} \sum_{n=1}^N \frac{1}{M} \sum_{m=1}^M
    \left( y^*_n - \hat{y}_{m,n} \right)^2
\end{aligned}
$$
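Because this mse is measured against the ground truth, it splits exactly into bias plus variance. Write $\bar{y}_n = \frac{1}{M}\sum_m \hat{y}_{m,n}$ for the average prediction at $x_n$, and expand $(y^*_n - \bar{y}_n + \bar{y}_n - \hat{y}_{m,n})^2$. The cross term vanishes because $\sum_m (\bar{y}_n - \hat{y}_{m,n}) = 0$. The block below is a quick numerical check on arbitrary made-up predictions, assuming the $\frac{1}{M}$ variance (`ddof=0`).

```python
import numpy as np

rng = np.random.default_rng(0)
M, N = 50, 8
y_star = rng.normal(size=N)                       # arbitrary ground truth
preds = y_star + 0.3 + rng.normal(size=(M, N))    # arbitrary biased, noisy predictions

mean_pred = preds.mean(axis=0)                    # \bar{y}_n
bias = np.mean((y_star - mean_pred) ** 2)
variance = np.mean(preds.var(axis=0))             # ddof=0, i.e. the 1/M normalisation
mse = np.mean((y_star - preds) ** 2)              # the double average; order of the sums does not matter

print(np.isclose(mse, bias + variance))           # True
```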

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch

N = 10                # points per dataset (also tried: 2000)
num_runs = 1000       # number of datasets M
true_a = 1.23         # ground-truth slope
true_b = 0.34         # ground-truth intercept
true_epsilon = 0.51   # noise standard deviation (also tried: 0.23, 0.0)

print('ground truth a=%s b=%s e=%s' % (true_a, true_b, true_epsilon))

def generate_X():
    # Fixed seed, so every dataset shares the same inputs X, shape (N, 1).
    torch.manual_seed(123)
    X = torch.rand(N, 1)
    return X

def generate_Y(X, seed):
    # Noisy targets y = a*x + b + eps, eps ~ N(0, true_epsilon^2); one seed per dataset.
    torch.manual_seed(seed)
    Y = X.view(-1) * true_a + true_b + torch.randn(N) * true_epsilon
    return Y

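def generate_features(X, order):
    # Polynomial features [1, x, x^2, ..., x^order], shape (N, order + 1).
    X2 = torch.zeros(N, order + 1)
    for k in range(order + 1):
        X2[:, k] = X[:, 0].pow(k)
    return X2

def fit(X, Y):
    # Ordinary least squares via the normal equations: W = (X^T X)^{-1} X^T Y, shape (K, 1).
    W = (X.transpose(0, 1) @ X)
    W = W.inverse()
    W = W @ X.transpose(0, 1)
    W = W @ Y.view(-1, 1)
    return W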

def calc_stats(Y_star, Y, preds):
    # Y_star: (N,) ground truth; Y: (N,) noisy targets to score the predictions against;
    # preds: list of M prediction vectors, one per fitted dataset.
    P = torch.stack(preds)                       # (M, N)
    pred_avg = P.mean(0)                         # average prediction per datapoint
    bias = ((pred_avg - Y_star) ** 2).mean()     # (squared) bias, averaged over datapoints
    variance = P.var(0, unbiased=False).mean()   # 1/M normalisation, matching the formula
    mse = ((P - Y) ** 2).mean()                  # averaged over datasets and datapoints
    return bias, variance, mse

def run(order):
    print('')
    print('order %s' % order)
    preds = []
    X = generate_X()
    Y_star = (X * true_a + true_b).view(-1)      # noise-free ground truth
    X2 = generate_features(X=X, order=order)
    Ws = torch.zeros(num_runs, order + 1)
    for i in range(num_runs):
        Y = generate_Y(X=X, seed=i)              # dataset i: same X, fresh noise
        W = fit(X2, Y)
        Ws[i] = W.view(-1)                       # (K, 1) -> (K,) to match a row of Ws
        pred = X2 @ W
        preds.append(pred.view(-1))
    # Independent noisy targets (seed num_runs is not used by any training set), so
    # mse - bias - variance is a rough estimate of the noise variance true_epsilon ** 2.
    Y_test = generate_Y(X=X, seed=num_runs)
    bias, variance, mse = calc_stats(Y_star=Y_star, Y=Y_test, preds=preds)
    noise = mse - bias - variance
    W_avg = Ws.mean(0)
    print('W_avg', W_avg.view(1, -1))
    print('bias %.3f' % bias, 'variance %.4f' % variance,
          'noise %.4f' % noise, 'mse %.4f' % mse)

run(order=0)
run(order=1)
run(order=2)
run(order=3)
# run(order=4)
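The same experiment can also be sketched in vectorized NumPy. This is an alternative to the torch loop above, not a drop-in replacement. It makes three assumptions. First, it fits with `np.linalg.lstsq` instead of the explicit normal-equation inverse. Second, it scores each fitted model against its own independent noisy test draw, so `mse - bias - variance` estimates the noise variance $\sigma^2$ = `noise_std ** 2`. Third, it uses the $\frac{1}{M}$ variance. `decompose` and its parameters are hypothetical names.

```python
import numpy as np

def decompose(degree, num_points=10, num_datasets=2000,
              a=1.23, b=0.34, noise_std=0.51, seed=0):
    rng = np.random.default_rng(seed)
    x = rng.uniform(0.0, 1.0, num_points)              # fixed inputs shared by every dataset
    y_star = a * x + b                                  # noise-free ground truth
    Phi = np.vander(x, degree + 1, increasing=True)     # polynomial features, shape (N, K)

    # M training sets: same x, fresh noise each time; fit all of them at once.
    Y_train = y_star + rng.normal(0.0, noise_std, (num_datasets, num_points))
    theta, *_ = np.linalg.lstsq(Phi, Y_train.T, rcond=None)
    preds = (Phi @ theta).T                             # shape (M, N)

    # One independent noisy test draw per dataset, never used for fitting.
    Y_test = y_star + rng.normal(0.0, noise_std, (num_datasets, num_points))

    bias = np.mean((y_star - preds.mean(axis=0)) ** 2)
    variance = np.mean(preds.var(axis=0))               # ddof=0, the 1/M normalisation
    mse = np.mean((Y_test - preds) ** 2)
    noise = mse - bias - variance                       # should be close to noise_std ** 2
    return bias, variance, noise, mse

for degree in range(4):
    bias, variance, noise, mse = decompose(degree)
    print(f'order {degree}: bias {bias:.4f} variance {variance:.4f} '
          f'noise {noise:.4f} mse {mse:.4f}')
```

With the correct model (order 1), the bias comes out close to zero. The variance is roughly $\sigma^2 K / N$ for $K$ fitted parameters. Order 0 has lower variance but a large bias, and order 3 adds variance without reducing the bias.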


## Derivation

"The derivation of the bias-variance decomposition for squared error proceeds as follows. For notational convenience, abbreviate $f = f(x)$ and $\hat{f} = \hat{f}(x)$. First recall that 