# Measuring model complexity

### Intro
In this assignment we are going to implement some complexity measures. A complexity measure is a magical black-box function that takes some information about a neural network and a training dataset and tells us how well the model should generalize. We have encountered several generalization bounds during the course that can serve as complexity measures, but they typically do not work well in practice. In fact, the primary goal of this assignment is to figure out how well current complexity measures work (if they work at all).

### Task description
Since we haven't polished this assignment yet, we decided to loosen the requirements. We provide all the training/visualization code (which is actually the most time-consuming part), and you only have to implement the complexity measures. There are 3 complexity measures for you to implement:

* [2 points] Spectral Complexity
* [2 points] Mean Normalized Margin
* [1 point] Noise Sensitivity

We are going to investigate them in different setups:
* Depending on the network width
* Depending on the number of corrupted samples (i.e. samples with random labels)

For convenience, we use a fully-connected NN with 1 hidden layer in all the experiments.

**Important requirements**:
* Your results *must be reproducible* (that's why we are so paranoid about fixing random seeds)
* All the cells in your notebook must be executable in sequential order
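
One way to keep the reproducibility requirement manageable is to gather the seeding in a single helper and call it before each experiment. The sketch below is only an illustration: the helper name `fix_seeds` is hypothetical, and it covers just the Python and NumPy generators, so the torch generators would still need to be seeded as in the setup cell.

```python
import random

import numpy as np


def fix_seeds(seed: int) -> None:
    """Hypothetical helper: reseed the Python and NumPy global generators."""
    random.seed(seed)
    np.random.seed(seed)


fix_seeds(42)
first_draw = (random.random(), np.random.rand(3))

fix_seeds(42)
second_draw = (random.random(), np.random.rand(3))

# Reseeding with the same value reproduces exactly the same draws.
assert first_draw[0] == second_draw[0]
assert np.array_equal(first_draw[1], second_draw[1])
```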

### Complexity Measures

#### Spectral Complexity
Check formula (1.2) of [this paper](https://arxiv.org/abs/1706.08498). We want you to implement this formula with all reference matrices set to zero.
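
As a rough illustration of the spectral complexity from formula (1.2): with zero reference matrices it reads $R_\mathcal{A} = \left(\prod_i \rho_i \|A_i\|_\sigma\right)\left(\sum_i \|A_i^\top\|_{2,1}^{2/3} / \|A_i\|_\sigma^{2/3}\right)^{3/2}$, where $\rho_i$ are the Lipschitz constants of the activations. The sketch below works on plain NumPy arrays rather than on a `Trainer`. The function name `spectral_complexity`, the `(out_features, in_features)` weight layout, the Lipschitz constants defaulting to 1 (which holds for ReLU), and ignoring the biases are all assumptions of this sketch, not part of the provided code.

```python
import numpy as np


def spectral_complexity(weights, lipschitz_constants=None):
    """Sketch of formula (1.2) with all reference matrices set to zero.

    `weights` is a list of 2-D arrays of shape (out_features, in_features).
    `lipschitz_constants` defaults to 1 for every layer (an assumption; true for ReLU).
    """
    if lipschitz_constants is None:
        lipschitz_constants = [1.0] * len(weights)

    # Spectral norm: the largest singular value of each weight matrix.
    spectral_norms = [np.linalg.norm(w, ord=2) for w in weights]
    # ||A^T||_{2,1}: sum of the L2 norms of the columns of A^T, i.e. of the rows of A.
    norms_21 = [np.linalg.norm(w, axis=1).sum() for w in weights]

    product = np.prod([rho * s for rho, s in zip(lipschitz_constants, spectral_norms)])
    ratio_sum = sum((n21 / s) ** (2 / 3) for n21, s in zip(norms_21, spectral_norms))
    return float(product * ratio_sum ** 1.5)


# Sanity check on a case that can be worked out by hand: a single 2x2 identity layer.
single_identity = spectral_complexity([np.eye(2)])
```

For the single identity layer the spectral norm is 1 and the (2,1)-norm is 2, so the value is $(2^{2/3})^{3/2} = 2$. For a PyTorch model, one would first collect the weight matrices of the linear layers and convert them to NumPy arrays.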

#### Mean Normalized Margin

Mean Normalized Margin is our ad-hoc complexity measure that is built upon Spectral Complexity.
We compute it in the following way:
* let $R_\mathcal{A}$ be the spectral complexity of the model $\mathcal{A}$
* let $X$ be the training dataset, containing $n$ samples
* let $m_i$ be the margin for a training sample $x_i$, i.e.
$$
m_i = l_{y_i} - \max_{c \neq y_i}l_c,
$$
where $l_c$ is the logit for class $c$ produced by our model.
So, the margin is just the difference between the "true" logit and the largest logit among the other classes (our "second guess" whenever the prediction is correct).
The idea is that when our model is very confident in its predictions, all margins are large.
If our model perfectly fits the training dataset, then $m_i$ is positive for all training pairs $(x_i, y_i)$.

We define the normalized margin $\hat{m}_i$ to be:

$$
\hat{m}_i = m_i \cdot \frac{\sqrt{n}}{R_\mathcal{A} \cdot \|X\|_F}
$$

The Mean Normalized Margin $M_\mathcal{A}$ is then just the mean of the normalized margins over our training dataset:

\begin{equation}
\begin{split}
M_\mathcal{A} &= \frac{1}{n}\sum_{i=1}^n \hat{m}_i \\
%&= \frac{1}{\sqrt{n} \cdot R_\mathcal{A} \cdot \|X\|_2} \cdot \sum_{i=1}^n m_i
\end{split}
\end{equation}
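
The computation can be sketched on plain NumPy arrays, taking the margin as the true-class logit minus the largest logit among the other classes. The function name `mean_normalized_margin` and its arguments (a logit matrix, integer labels, the inputs, and an already computed spectral complexity) are hypothetical; how to obtain them from a `Trainer` is left open here.

```python
import numpy as np


def mean_normalized_margin(logits, labels, inputs, spectral_complexity):
    """Sketch: mean of m_i * sqrt(n) / (R_A * ||X||_F) over the dataset.

    logits: array of shape (n, num_classes); labels: integer array of shape (n,);
    inputs: array whose first dimension is n; spectral_complexity: the scalar R_A.
    """
    logits = np.asarray(logits, dtype=float)
    labels = np.asarray(labels)
    n = logits.shape[0]
    rows = np.arange(n)

    true_logits = logits[rows, labels]
    # Mask out the true class so that the max runs over c != y_i only.
    others = logits.copy()
    others[rows, labels] = -np.inf
    margins = true_logits - others.max(axis=1)

    # Frobenius norm of the dataset, with every sample flattened into a row.
    data_norm = np.linalg.norm(np.asarray(inputs, dtype=float).reshape(n, -1))
    scale = np.sqrt(n) / (spectral_complexity * data_norm)
    return float(np.mean(margins * scale))


# Tiny hand-checkable example: margins are 2 and 3, ||X||_F = 5, n = 2, R_A = 1.
example = mean_normalized_margin(
    logits=[[3, 1, 0], [0, 2, 5]],
    labels=[0, 2],
    inputs=[[3, 0], [0, 4]],
    spectral_complexity=1.0,
)
```

In the example the mean margin is 2.5 and the scale is $\sqrt{2}/5$, so the result is $\sqrt{2}/2$.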

#### Noise Sensitivity
The idea is that if our model is resistant to noise, then it should generalize better.
We check this in a very straightforward way: just add noise to the inputs and see how the outputs change.
More precisely, we compute Noise Sensitivity $N_\mathcal{A}$ in the following way:

$$
N_\mathcal{A} = \frac{1}{n}\sum_{i=1}^n \| f(x_i) - f(x_i + \varepsilon_i) \|_2^2,
$$
where:
* $f(x_i)$ is a vector of logits
* $\{\varepsilon_i\}_{i=1}^n$ is a dataset of noise vectors sampled from $\mathcal{N}(0, I)$
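
A minimal sketch of this formula, assuming the model is available as a plain function `f` that maps an `(n, d)` array of inputs to an `(n, k)` array of logits. The name `noise_sensitivity`, the `seed` argument, and the use of a local NumPy generator (so that the noise is reproducible) are choices of this sketch; `std=1.0` corresponds to noise sampled from $\mathcal{N}(0, I)$.

```python
import numpy as np


def noise_sensitivity(f, inputs, std=1.0, seed=0):
    """Sketch: average squared L2 distance between f(x_i) and f(x_i + eps_i)."""
    inputs = np.asarray(inputs, dtype=float)
    # A local generator keeps the noise reproducible without touching global state.
    rng = np.random.default_rng(seed)
    noise = rng.normal(0.0, std, size=inputs.shape)

    diff = f(inputs) - f(inputs + noise)
    # Squared L2 norm per sample, then the mean over the dataset.
    return float(np.mean(np.sum(diff ** 2, axis=1)))


# A function that ignores its input is perfectly insensitive to noise.
constant_value = noise_sensitivity(lambda x: np.zeros((x.shape[0], 3)), np.ones((4, 2)))
```

For the identity function the measure reduces to the average squared norm of the noise itself, which is close to `d * std ** 2` for a large number of samples.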

### Tips
* You can delete all the provided code if it's more convenient for you to do everything from scratch
* Check your implementation twice
* It takes 5-10 hours for all the experiments to finish, so keep this in mind when planning your time
* It feels like the most difficult part of this assignment is getting familiar with the experiment logic (and the code provided :|). It should be relatively easy to implement the measures afterwards
* Reduce `NUM_ITERS_FOR_CONVERGENCE` value while debugging to make the experiments run faster
* Running on a GPU will be 50-100% faster

In [None]:
import sys; sys.path.append('.')
import random

import numpy as np
import torch
import matplotlib.pyplot as plt

from utils.trainer import Trainer


SEED = 42
DATA_DIR = './data'
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' # TIP: the GPU is used automatically when one is available
NUM_ITERS_FOR_CONVERGENCE = 20000

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

## Complexity measures

In [None]:
from typing import List
import numpy as np
from torch import Tensor


class ComplexityMeasure:
    def __init__(self, name:str, value:float):
        self.name = name
        self.value = value
        
    def __str__(self) -> str:
        return f'{self.name}: {self.value}'
    
    def __repr__(self) -> str:
        return f'{self.name}: {self.value}'


def compute_accuracy(trainer:Trainer) -> ComplexityMeasure:
    return ComplexityMeasure('Test accuracy', trainer.compute_test_accuracy())
    
    
def compute_weights_norm(trainer:Trainer, p=2) -> ComplexityMeasure:
    weights = torch.cat([param.data.cpu().view(-1) for param in trainer.model.parameters()])
    
    return ComplexityMeasure(f'L{p}-norm', torch.norm(weights, p).item())


def compute_spectral_complexity(trainer:Trainer) -> ComplexityMeasure:
    #########################################
    ### Your code here. Difficulty: ☕☕ ###
    ########################################
    pass


def compute_mean_normalized_margin(trainer:Trainer) -> ComplexityMeasure:
    #########################################
    ### Your code here. Difficulty: ☕☕ ###
    ########################################
    pass


def compute_noise_sensitivity(trainer:Trainer, std:float=1.0, num_noised_points:int=1) -> ComplexityMeasure:
    #######################################
    ### Your code here. Difficulty: ☕ ###
    ######################################
    pass

## [Part 1] Measuring complexity for a normal dataset for different widths

In [None]:
base_config = {
    'max_num_iters': NUM_ITERS_FOR_CONVERGENCE,
    'model_type': 'DenseModel',
    'batch_size': 500,
    'device': DEVICE,
    'data_dir': DATA_DIR,
}
widths = [10, 100, 250, 500, 1500, 5000]
width_configs = [{'model_config': {'width': w}} for w in widths]
configs = [{**base_config, **c} for c in width_configs]
different_width_trainers = [Trainer(c) for c in configs]
different_width_test_accs = [t.run_training(True).compute_test_accuracy() for t in different_width_trainers]

In [None]:
complexity_measure_calculators = [
    compute_accuracy,
    compute_weights_norm,
    compute_spectral_complexity,
    compute_mean_normalized_margin,
    compute_noise_sensitivity
]

In [None]:
different_width_complexities = [[fn(t) for t in different_width_trainers] for fn in complexity_measure_calculators]

In [None]:
def visualize_complexities(x_values:List, complexities:List[List[ComplexityMeasure]], n_cols, n_rows, xlabel:str=''):
    assert len(x_values) == len(complexities[0])
    assert all(len(set(c.name for c in cs)) == 1 for cs in complexities), \
        "Each complexities row should correspond to a single complexity measure"
    
    _, subplots = plt.subplots(n_rows, n_cols, figsize=(n_cols * 6, n_rows * 4))
    if n_rows > 1: subplots = [p for row in subplots for p in row]
    
    for cs, subplot in zip(complexities, subplots):
        subplot.set_title(cs[0].name)
        subplot.plot(x_values, [c.value for c in cs])
        subplot.set_xlabel(xlabel)
        #subplot.set_ylabel(f'{cs[0].name}')
        subplot.grid()
        
    # Hiding blank subplots for nicer visualization
    for empty_subplot in subplots[len(complexities):]:
        empty_subplot.set_axis_off()
        
    plt.subplots_adjust(hspace=0.3)

In [None]:
visualize_complexities(widths, different_width_complexities, 3, 2, xlabel='Width')

**Comment on plots you've obtained:**

Does accuracy reach an asymptote for large widths? Do the complexity measures reach an asymptote? Do these complexity measures increase or decrease with width? Does this seem natural?

## [Part 2] Complexity measures for spoiled datasets with different amounts of spoiled samples

### [Part 2.1] Checking whether complexity measures converge to some value

In [None]:
from tqdm import tqdm

max_num_iters = NUM_ITERS_FOR_CONVERGENCE
freq_iters = 500
num_good_points = 1000
num_shuffled_labels_to_add = [0, 500, 2500, 5000, 50000]
measurement_iters = [s * freq_iters for s in range(1, max_num_iters // freq_iters + 1)]
base_config = {
    'num_good_points': num_good_points,
    'max_num_iters': max_num_iters,
    'model_type': 'DenseModel',
    'batch_size': 500,
    'model_config': {'width': 1024},
    'device': DEVICE,
    'data_dir': DATA_DIR,
}

print('Constructing trainers...')
configs = [{**base_config, **{'num_bad_points': n}} for n in num_shuffled_labels_to_add]
trainers = [Trainer(c) for c in tqdm(configs)]
complexity_measurements = {n:[] for n in num_shuffled_labels_to_add}

def track_complexities(trainer:Trainer):
    if trainer.num_iters_done % freq_iters != 0: return
    cs = [fn(trainer) for fn in complexity_measure_calculators]
    complexity_measurements[trainer.config['num_bad_points']].append(cs)

for trainer in trainers:
    trainer.on_iter_done_callbacks.append(track_complexities)
    trainer.run_training(True)

In [None]:
n_rows = 2
n_cols = 3
num_measures = len(complexity_measure_calculators)
bad_points_proportions = [n / (n + num_good_points) for n in num_shuffled_labels_to_add]

_, subplots = plt.subplots(n_rows, n_cols, figsize=(n_cols * 6, n_rows * 4))
if n_rows > 1: subplots = [p for row in subplots for p in row]

for c_idx, subplot in zip(range(num_measures), subplots):
    complexities = [[compls[c_idx] for compls in complexity_measurements[n]] for n in num_shuffled_labels_to_add]

    subplot.set_title(complexities[0][0].name)

    for bad_points_prop, cs in zip(bad_points_proportions, complexities):
        assert len(set(c.name for c in cs)) == 1, "Wrong format for `complexities` argument"
        subplot.plot(measurement_iters, [c.value for c in cs], label=f'Bad labels proportion: {bad_points_prop:0.2f}')

    subplot.set_xlabel('Iteration')
    subplot.legend()
    subplot.grid()
    
for subplot in subplots[num_measures:]:
    subplot.set_axis_off()

**Comment on plots you've obtained:**

Does accuracy reach an asymptote with iterations? Which complexity measures (if any) reach an asymptote?

### [Part 2.2] Final values

In [None]:
bad_points_proportions = [round(p, 2) for p in bad_points_proportions]
final_measurements = [complexity_measurements[n][-1] for n in num_shuffled_labels_to_add]
final_measurements = np.array(final_measurements).transpose().tolist()
visualize_complexities(bad_points_proportions, final_measurements, 3, 2, xlabel='Bad points proportion')

**Comment on plots you've obtained:**

Do accuracy / complexity measures increase or decrease with bad points proportion?

## Summarize your observations here:

Which complexity measures converge, and how are they correlated with test accuracy? Does this behavior seem desirable or not?
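
If you want to put a number on "correlated", one option is a rank correlation between the values of a measure and the test accuracies across runs. The sketch below implements Kendall's tau (the tau-a variant, which does not correct for ties) in pure Python; the function name `kendall_tau` is hypothetical and this is just one possible choice of statistic, not something the assignment prescribes.

```python
from itertools import combinations


def kendall_tau(xs, ys):
    """Kendall rank correlation (tau-a) between two equally long sequences."""
    assert len(xs) == len(ys) and len(xs) > 1
    score = 0
    for (x1, y1), (x2, y2) in combinations(zip(xs, ys), 2):
        product = (x1 - x2) * (y1 - y2)
        # +1 for a concordant pair, -1 for a discordant pair, 0 for a tie.
        score += (product > 0) - (product < 0)
    num_pairs = len(xs) * (len(xs) - 1) // 2
    return score / num_pairs


# Toy sequences (not experimental results) to show the two extremes.
same_order = kendall_tau([1, 2, 3], [10, 20, 30])
reversed_order = kendall_tau([1, 2, 3], [30, 20, 10])
```

With the rows produced earlier in the notebook, one could pass `[c.value for c in row]` for a measure row and for the accuracy row.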