### Elastic Weight Consolidation

In the figure below, $\theta^{*}$ denotes the weights (i.e., parameters, or synaptic strengths) that the neural network (NN) learned to solve old task A, drawn as a single point in weight space. For example, if the network has 100 weights in total, the figure is a 2D sketch of $\mathbb{R}^{100}$. The blue horizontal arrow shows catastrophic forgetting: training on task B alone moves $\theta^{*}$ out of the region where the NN performs well on task A (grey) and into the center of the region where it performs well on task B (cream). The downward green arrow is the update to $\theta^{*}$ when every weight is held back by the same L2 penalty. This constraint is too severe, so the weights stay close to the task A solution without reaching the cream region. The desired update is the red arrow, which moves the NN weights into a region that performs well on both tasks A and B.
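Here is a minimal sketch of the three arrows. It assumes each task's loss is a quadratic bowl over just two weights, so every minimizer has a closed form. All numbers (`theta_A`, `H_A`, `theta_B`, `H_B`, `lam`) are hypothetical. Using the diagonal of task A's curvature as per-weight importance anticipates the Fisher-weighted penalty introduced in the next paragraph.

```python
import numpy as np

# Toy setup (assumption): two weights, and each task's loss is a quadratic bowl.
theta_A = np.array([0.0, 0.0])   # theta*_A: weights after training on task A
H_A = np.diag([10.0, 0.1])       # task A is very sensitive to theta[0], barely to theta[1]
theta_B = np.array([1.0, 2.0])   # minimizer of task B on its own
H_B = np.eye(2)                  # curvature of task B's loss

def quad_loss(theta, center, H):
    """0.5 * (theta - center)^T H (theta - center)."""
    d = theta - center
    return 0.5 * d @ H @ d

def minimize_with_penalty(D, lam):
    """Closed-form minimizer of L_B(theta) + lam/2 * (theta - theta_A)^T D (theta - theta_A).

    Setting the gradient to zero gives (H_B + lam*D) theta = H_B theta_B + lam*D theta_A.
    """
    return np.linalg.solve(H_B + lam * D, H_B @ theta_B + lam * D @ theta_A)

lam = 1.0
solutions = {
    "blue (L_B only)": minimize_with_penalty(np.zeros((2, 2)), lam),
    "green (L2 penalty)": minimize_with_penalty(np.eye(2), lam),
    # EWC-style: weight each parameter by its importance to task A; for a quadratic
    # loss, the diagonal of H_A stands in for the Fisher diagonal F_i.
    "red (EWC penalty)": minimize_with_penalty(np.diag(np.diag(H_A)), lam),
}

for name, theta in solutions.items():
    print(f"{name:20s} theta={np.round(theta, 3)}  "
          f"L_A={quad_loss(theta, theta_A, H_A):.3f}  "
          f"L_B={quad_loss(theta, theta_B, H_B):.3f}")
```

The three solutions compare as follows:

- **Blue:** fits task B perfectly but has the largest task A loss.
- **Green (L2):** gives up some of both tasks.
- **Red:** has a lower task A loss than the green one and also fits task B better, because it lets the weight task A barely cares about (`theta[1]`) move freely.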

How Elastic Weight Consolidation Changes the Learning of New Weights $\theta^{*}$
<p align="center">
<img src="https://raw.githubusercontent.com/clam004/intro_continual_learning/main/files/F1.large.jpg" height=500 width=500 >
</p>

EWC encourages the weights to follow the red path by modifying the loss function used to re-train a NN. The NN has already been trained to convergence on task A with loss $L_{A}$ and has settled on weights $\theta_{A}^{*}$. When re-training it on task B with loss $L_{B}$, we add a term that penalizes each weight's squared distance from its task A value, $(\theta_{i} - \theta_{A,i}^{*})^{2}$, scaled by $F_{i}$. Moving a weight is therefore expensive only when it is both far from $\theta_{A,i}^{*}$ and important to task A (high $F_{i}$):

$$L(\theta) = L_{B}(\theta) + \sum_{i} \frac{\lambda}{2} F_{i} \left(\theta_{i} - \theta_{A,i}^{*}\right)^{2}$$
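Below is a minimal NumPy sketch of the penalty term and its gradient. The gradient is what gets added to task B's gradient at each training step. The toy task B (`target_B`), the Fisher values and the learning rate are hypothetical. In a real network, `grad_B` would come from backpropagation.

```python
import numpy as np

def ewc_penalty(theta, theta_star_A, fisher_diag, lam):
    """Penalty term: sum_i lam/2 * F_i * (theta_i - theta*_{A,i})^2."""
    return 0.5 * lam * np.sum(fisher_diag * (theta - theta_star_A) ** 2)

def ewc_penalty_grad(theta, theta_star_A, fisher_diag, lam):
    """Gradient of the penalty with respect to theta: lam * F_i * (theta_i - theta*_{A,i})."""
    return lam * fisher_diag * (theta - theta_star_A)

# Hypothetical toy task B: it wants every weight shifted by +1 from theta*_A.
rng = np.random.default_rng(0)
theta_star_A = rng.normal(size=5)                      # stand-in for weights learned on task A
fisher_diag = np.array([5.0, 5.0, 0.01, 0.01, 0.01])   # weights 0 and 1 matter for task A
target_B = theta_star_A + 1.0

def loss_B(theta):
    return 0.5 * np.sum((theta - target_B) ** 2)

def grad_B(theta):
    return theta - target_B

# Plain gradient descent on L(theta) = L_B(theta) + EWC penalty, starting from theta*_A.
lam, lr = 1.0, 0.1
theta = theta_star_A.copy()
for _ in range(500):
    theta -= lr * (grad_B(theta) + ewc_penalty_grad(theta, theta_star_A, fisher_diag, lam))

print("shift from theta*_A:", np.round(theta - theta_star_A, 3))
print("total loss:", loss_B(theta) + ewc_penalty(theta, theta_star_A, fisher_diag, lam))
```

Setting the total gradient to zero gives a shift of $1/(1 + \lambda F_i)$ for each weight. The two weights with $F_i = 5$ therefore move only about 1/6 of the way toward task B's target, while the others move almost all the way.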

### But what is F? 

F is the Fisher information matrix. In the EWC paper:

"we approximate the posterior as a Gaussian distribution with mean given by the parameters $\theta_{A}^{*}$ and a diagonal precision given by the diagonal of the Fisher information matrix F. F has three key properties (20): (i) It is equivalent to the second derivative of the loss near a minimum, (ii) it can be computed from first-order derivatives alone and is thus easy to calculate even for large models, and (iii) it is guaranteed to be positive semidefinite. Note that this approach is similar to expectation propagation where each subtask is seen as a factor of the posterior (21). [...] where $L_{B}(\theta)$ is the loss for task B only, $\lambda$ sets how important the old task is compared with the new one, and $i$ labels each parameter.

When moving to a third task, task C, EWC will try to keep the network parameters close to the learned parameters of both tasks A and B. This can be enforced either with two separate penalties or as one by noting that the sum of two quadratic penalties is itself a quadratic penalty."
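The claim that two quadratic penalties combine into one can be checked parameter by parameter. Completing the square gives

$$F_{A,i}(\theta_i - \theta_{A,i}^{*})^2 + F_{B,i}(\theta_i - \theta_{B,i}^{*})^2 = (F_{A,i} + F_{B,i})(\theta_i - m_i)^2 + c_i,$$

where

$$m_i = \frac{F_{A,i}\theta_{A,i}^{*} + F_{B,i}\theta_{B,i}^{*}}{F_{A,i} + F_{B,i}}, \qquad c_i = \frac{F_{A,i}F_{B,i}}{F_{A,i}+F_{B,i}}(\theta_{A,i}^{*} - \theta_{B,i}^{*})^2.$$

The sketch below verifies this numerically. It assumes $F_{A,i} + F_{B,i} > 0$, and all inputs are random placeholders.

```python
import numpy as np

def merge_quadratic_penalties(anchor_A, F_A, anchor_B, F_B):
    """Rewrite F_A*(theta - anchor_A)^2 + F_B*(theta - anchor_B)^2 (elementwise)
    as F*(theta - m)^2 + c, returning (m, F, sum of c).

    Assumes F_A + F_B > 0 for every parameter.
    """
    F = F_A + F_B
    m = (F_A * anchor_A + F_B * anchor_B) / F
    const = np.sum(F_A * F_B * (anchor_A - anchor_B) ** 2 / F)
    return m, F, const

def penalty(theta, anchor, F, lam):
    """EWC-style penalty: sum_i lam/2 * F_i * (theta_i - anchor_i)^2."""
    return 0.5 * lam * np.sum(F * (theta - anchor) ** 2)

# Random placeholders for the weights and Fisher diagonals after tasks A and B.
rng = np.random.default_rng(1)
theta_A, theta_B = rng.normal(size=4), rng.normal(size=4)
F_A, F_B = rng.uniform(0.1, 2.0, size=4), rng.uniform(0.1, 2.0, size=4)
lam = 0.7

m, F, const = merge_quadratic_penalties(theta_A, F_A, theta_B, F_B)

theta = rng.normal(size=4)  # any candidate weights while training on task C
two_penalties = penalty(theta, theta_A, F_A, lam) + penalty(theta, theta_B, F_B, lam)
one_penalty = penalty(theta, m, F, lam) + 0.5 * lam * const
print(two_penalties, one_penalty)
```

The constant $c_i$ does not depend on $\theta$. It has no effect on gradients and can be dropped while training on task C.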

### Let's learn what F is in the following example

The example below is taken from the [Fisher-matrix-guide](https://wittman.physics.ucdavis.edu/Fisher-matrix-guide.pdf).

The toy example asks you to imagine a hotdog-and-bun universe that contains two processes: one that produces hotdogs alone, and one that produces hotdogs and buns in pairs. The guide works with two parameters, $\theta_1$ and $\theta_2$.

$\theta_1$ is the rate, per unit volume and unit time, at which hotdogs and buns are produced together. $\theta_2$ is the rate at which hotdogs are produced alone. Suppose you create a unit volume of vacuum in this space and observe it for one unit of time. You would expect to see $h = \theta_1 + \theta_2$ hotdogs and $b = \theta_1$ buns.
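The hotdog model is linear in the rates, so it can be written as a 2×2 matrix acting on $(\theta_1, \theta_2)$. Here is a minimal sketch with hypothetical helper names (`expected_counts`, `rates_from_counts`) and made-up rates:

```python
import numpy as np

# Jacobian of (h, b) with respect to (theta_1, theta_2); the model is linear,
# so this matrix is constant.
J = np.array([[1.0, 1.0],   # h = theta_1 + theta_2
              [1.0, 0.0]])  # b = theta_1

def expected_counts(theta):
    """Expected (h, b) in one unit volume over one unit time, for theta = (theta_1, theta_2)."""
    return J @ np.asarray(theta, dtype=float)

def rates_from_counts(h, b):
    """Invert the model: theta_1 = b, theta_2 = h - b."""
    return np.linalg.solve(J, np.array([h, b], dtype=float))

theta = np.array([3.0, 2.0])  # hypothetical rates: 3 pairs and 2 lone hotdogs per unit
h, b = expected_counts(theta)
print(h, b)
print(rates_from_counts(h, b))
```

With $\theta = (3, 2)$, the model predicts $h = 5$ hotdogs and $b = 3$ buns, and inverting it recovers the rates. The constant matrix `J` holds the derivatives of the observables with respect to the parameters. It reappears when computing the Fisher matrix for this model.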

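For $N$ total parameters in $\theta$, the Fisher information matrix is an $N \times N$ matrix whose entry at position $(i, j)$ is
$$
I(\theta)_{ij} = E\left[ \left( \frac{\partial}{\partial\theta_i}\log f(X;\theta) \right)\left( \frac{\partial}{\partial\theta_j}\log f(X;\theta) \right) \mid \theta\right]
$$
Here $f(X;\theta)$ is the probability density of the observed data $X$ given the parameters $\theta$, and the expectation is taken over $X$ drawn from that density.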
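To make the definition concrete for the hotdog universe, suppose that $h$ and $b$ are measured with independent Gaussian noise of standard deviations $\sigma_h$ and $\sigma_b$. This noise model and its values are assumptions made here.

The sketch estimates $I(\theta)$ directly from the definition by averaging products of scores over simulated measurements. It then compares the estimate with the closed form $I = J^{\top} \mathrm{diag}(1/\sigma^2)\, J$, which holds for Gaussian noise. Here $J$ is the Jacobian of $(h, b)$ with respect to $(\theta_1, \theta_2)$, and the rates are hypothetical.

```python
import numpy as np

J = np.array([[1.0, 1.0],     # dh/dtheta_1, dh/dtheta_2
              [1.0, 0.0]])    # db/dtheta_1, db/dtheta_2
sigma = np.array([0.5, 1.0])  # hypothetical noise std of the h and b measurements
theta = np.array([3.0, 2.0])  # hypothetical true rates

def scores(x, theta):
    """Score d/dtheta log f(x; theta) for each row x ~ N(J theta, diag(sigma^2)).

    log f = -sum_k (x_k - (J theta)_k)^2 / (2 sigma_k^2) + const, so the score is
    J^T [(x - J theta) / sigma^2]; for a batch of rows that is residual @ J.
    """
    residual = (x - J @ theta) / sigma**2
    return residual @ J

# Definition: I_ij = E[score_i * score_j], estimated by averaging over simulated data.
rng = np.random.default_rng(0)
x = J @ theta + sigma * rng.normal(size=(200_000, 2))
s = scores(x, theta)
fisher_mc = s.T @ s / len(s)

# Closed form for Gaussian noise: I = J^T diag(1 / sigma^2) J.
fisher_exact = J.T @ np.diag(1.0 / sigma**2) @ J

print(np.round(fisher_mc, 2))
print(fisher_exact)
```

With $\sigma_h = 0.5$ and $\sigma_b = 1$, the closed form is $\begin{pmatrix} 5 & 4 \\ 4 & 4 \end{pmatrix}$, and the Monte Carlo estimate lands close to it. As a sanity check, its inverse $\begin{pmatrix} 1 & -1 \\ -1 & 1.25 \end{pmatrix}$ is exactly the covariance of the direct estimates $\theta_1 = b$ and $\theta_2 = h - b$.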

See also [fisher-info-matrix](https://andrewliao11.github.io/blog/fisher-info-matrix/).
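Applying the same definition to a model's predictive distribution gives the kind of $F$ that EWC uses. This is a minimal sketch, assuming a small logistic-regression model on random data stands in for the trained network. It checks the three properties from the quoted passage:

- (ii) the Fisher matrix is built from first derivatives only;
- (i) it matches the Hessian of the loss;
- (iii) it has no negative eigenvalues.

For logistic regression, the Fisher matrix equals the Hessian of the negative log-likelihood at every $\theta$, which makes check (i) exact. For deep networks the match is only approximate, near a minimum.

```python
import numpy as np

# Hypothetical stand-in for a trained network: logistic regression p(y=1 | x) = sigmoid(x . theta).
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))        # toy inputs
theta = np.array([0.5, -1.0, 2.0])   # any parameter vector; the check below is exact here
n = len(X)

def probs(theta):
    return 1.0 / (1.0 + np.exp(-X @ theta))

def mean_nll_grad(theta, y):
    """Gradient of the mean negative log-likelihood of labels y."""
    return X.T @ (probs(theta) - y) / n

# (ii) Fisher from first derivatives only: average score outer products, where the
# score of one example is d/dtheta log p(y | x) = (y - p) x and the expectation
# over y uses the model's own probabilities, y ~ Bernoulli(p).
p = probs(theta)
fisher = np.zeros((3, 3))
for y_val, prob_y in ((1.0, p), (0.0, 1.0 - p)):
    score = (y_val - p)[:, None] * X  # shape (n, 3)
    fisher += (prob_y[:, None] * score).T @ score / n

# (i) Second derivative of the loss: Hessian of the mean NLL via central
# differences of the gradient. For logistic regression the Hessian does not
# depend on the labels, so arbitrary labels are fine here.
y = rng.integers(0, 2, size=n).astype(float)
eps = 1e-5
hessian = np.column_stack([
    (mean_nll_grad(theta + eps * e, y) - mean_nll_grad(theta - eps * e, y)) / (2 * eps)
    for e in np.eye(3)
])

print("Fisher matches Hessian:", np.allclose(fisher, hessian, atol=1e-6))  # (i)
print("min eigenvalue:", np.linalg.eigvalsh(fisher).min())                 # (iii), should be >= 0
print("diagonal F_i used by EWC:", np.round(np.diag(fisher), 4))
```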

```python
import math
import numpy as np
from matplotlib import pyplot as plt

# IPython magics: auto-reload edited modules and render plots inline in the notebook
%load_ext autoreload
%autoreload 2
%matplotlib inline
```

