<table class="ee-notebook-buttons" align="left"><td>
<a target="_blank"  href="https://colab.research.google.com/github/eywalker/LVIV-2021/blob/main/notebooks/DeepLearing%20in%20Neuroscience.ipynb">
    <img src="https://www.tensorflow.org/images/colab_logo_32px.png" /> Run in Google Colab</a>
</td><td>
<a target="_blank"  href="https://github.com/eywalker/LVIV-2021/blob/main/notebooks/DeepLearing%20in%20Neuroscience.ipynb"><img width=32px src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" /> View source on GitHub</a></td></table>

# Welcome to Deep Learning in Neuroscience by Edgar Y. Walker

This is a Jupyter notebook to accompany the course on "Deep Learning in Neuroscience" taught as part of the Lviv Data Science Summer School 2021. This notebook as well as any other relevant information can be found in the [GitHub repository](https://github.com/eywalker/lviv-2021)!

In this course, we will learn how deep learning is used in neuroscience, specifically to build models of how neurons respond to complex sensory inputs such as natural images. We will start with a short neuroscience primer. We will then get our hands dirty by taking real neuronal responses recorded from mouse primary visual cortex (V1) while the mouse views natural images, and developing a model to predict these responses. By the end of this course, you will have gained some basic familiarity with using deep learning models to predict the responses of 1000s of neurons to natural images!

## Preparing the environment

#### <font color='red'>NOTE: Please run this section at the very beginning of the first session!</font>

Before we dive in, learn how deep learning is used in neuroscience, and train our first neural predictive model, we need to install some prerequisite packages and download some neuronal data!

### Getting the code

We are going to primarily use [PyTorch](https://pytorch.org) to build, train and evaluate our deep learning models and I am going to assume some familiarity with PyTorch already.

To handle the dataset containing neuronal activity, we are going to make our life easier by using a few existing libraries. I have prepared a library called [lviv2021](https://github.com/eywalker/lviv2021). This library depends on [neuralpredictors](https://github.com/sinzlab/neuralpredictors), a collection of PyTorch layers, tools and other utilities that are helpful for training networks to predict neuronal responses.

Let's go ahead and install this inside the Colab environment.

In [None]:
# Install PyTorch dependency
!pip3 install torch==1.9.0+cu102 torchvision==0.10.0+cu102 -f https://download.pytorch.org/whl/torch_stable.html
    
# Install the course library and its dependencies
!pip3 install git+https://github.com/eywalker/lviv-2021.git

### Getting the dataset

We are going to use the dataset made available for our recent paper [Lurz et al. ICLR 2021](https://github.com/sinzlab/Lurz_2020_code), predicting responses of mouse visual cortex to natural images. 

The dataset can take anywhere from 5 to 10 minutes to download, so please be sure to **run the following at the very beginning of the session!** We are going to first spend some time learning the basics of computational neuroscience in the study of system identification. It is best to let the download run while we go over the neuroscience primer so that it is ready when we come back here to get our hands dirty!

To download the data, simply execute the following cell, and let it run till completion.

In [None]:
!git clone https://gin.g-node.org/cajal/Lurz2020.git /content/data

# Developing models of neural population responses to natural images

Now that you have been primed with just enough background neuroscience, let's get our hands dirty and try to build our first neural predictive models.

As part of the setup, we have downloaded a 2-photon imaging dataset recorded from mouse primary visual cortex while the mouse was presented with 1000s of natural images (if you have not done so yet, please do so immediately by stepping through the beginning sections of this notebook).

In [1]:
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
import matplotlib.pyplot as plt

## Navigating the neuroscience data

As with any data science project, you must start by understanding your data! Let's take some time to navigate the data you downloaded.

In [2]:
ls ./data/static20457-5-9-preproc0/

change.log  config.json  data/  meta/


In [None]:
ls ./data/static20457-5-9-preproc0/data

In [None]:
ls ./data/static20457-5-9-preproc0/data/responses | head -30

In [None]:
ls ./data/static20457-5-9-preproc0/data/images | head -30

You can see that both images and responses are contained in collections of `numpy` files named like `1.npy` or `31.npy`. The number here corresponds to a specific **trial**, that is, a single image presentation during an experiment.

Let's take a look at some of these files.

### Loading data files one at a time

Let's pick a trial and load its image as well as the response.

In [None]:
trial_idx = 1100
trial_image = np.load(f'./data/static20457-5-9-preproc0/data/images/{trial_idx}.npy')
trial_responses = np.load(f'./data/static20457-5-9-preproc0/data/responses/{trial_idx}.npy')

The image is shaped as $\text{channel} \times \text{height} \times \text{width}$

In [None]:
trial_image.shape

In [None]:
plt.imshow(trial_image.squeeze(), cmap='gray', vmin=0, vmax=255)
plt.axis('off')

In contrast, the shape of `trial_responses` is simply the number of neurons.

In [None]:
trial_responses.shape

In [None]:
trial_responses.min() # responses are practically always >= 0

In [None]:
trial_responses.max()

In [None]:
fig, ax = plt.subplots(1, 1, dpi=150)
ax.hist(trial_responses, 100);

You can see that most neurons' responses stay very close to 0, signifying no activity.

### Loading the entire dataset

While we can inspect the image and the corresponding neural population responses one image at a time, this is quite cumbersome and also impractical for use in network training. Fortunately, the `lviv` package provides us with a convenience function that will help to load the entire dataset as PyTorch dataloaders.

In [3]:
from lviv.dataset import load_dataset

As we prepare the dataloaders, we get to specify the batch size.

In [4]:
dataloaders = load_dataset(path = '/content/data/static20457-5-9-preproc0', batch_size=30)

The function returns a dictionary of three dataloaders, one each for the training, validation, and test sets.

In [None]:
dataloaders

Let's specifically look at the training set dataloader.

In [13]:
train_loader = dataloaders['train']

Total number of images can be checked as follows:

In [14]:
len(train_loader.sampler)

4472

We can inspect what it returns per batch:

In [15]:
images, responses = next(iter(train_loader))

In [16]:
images.shape

torch.Size([60, 1, 36, 64])

In [17]:
responses.shape

torch.Size([60, 5335])

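As expected, the first dimension of both tensors is the batch size, and the responses cover all 5335 neurons. (The output shown here was recorded with a batch size of 60; with `batch_size=30` as set above, the first dimension will be 30.)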

Similar inspection can be done on the **validation** and **testing** dataloaders.

In [18]:
# number of images in validation set
len(dataloaders['validation'].sampler)

522

In [19]:
# number of images in test set
len(dataloaders['test'].sampler)

999

You might think that the test set has a lot of images, but this is because it consists of repeated presentations of a smaller set of images.

Some additional trial information can be observed by accessing the underlying PyTorch dataset object and looking at the `trial_info`. Note that this is not part of the standard PyTorch dataset/dataloader interface, but rather a feature specifically provided by the library!

In [20]:
# Access to the dataset object that underlies all dataloaders
testset = dataloaders['test'].dataset

In [21]:
test_trials = np.where(testset.trial_info['tiers'] == 'test')[0]

In [22]:
image_ids = testset.trial_info['frame_image_id']

In [23]:
np.unique(image_ids[test_trials])

array([ 104,  128,  183,  355,  479,  483,  656,  803,  830,  936, 1201,
       1494, 1596, 1652, 1656, 1682, 1731, 1756, 1796, 2005, 2008, 2014,
       2159, 2214, 2389, 2586, 2710, 2746, 2747, 2803, 2816, 2825, 2954,
       3018, 3107, 3144, 3163, 3372, 3427, 3438, 3487, 3507, 3562, 3702,
       3847, 3924, 4231, 4295, 4373, 4397, 4400, 4430, 4594, 4619, 4667,
       4674, 4717, 4739, 4782, 4812, 4814, 4821, 4923, 4953, 5034, 5128,
       5166, 5225, 5264, 5288, 5322, 5334, 5399, 5402, 5504, 5640, 5671,
       5679, 5754, 5782, 6013, 6034, 6066, 6082, 6205, 6238, 6248, 6490,
       6562, 6773, 6790, 6831, 6886, 7017, 7028, 7107, 7119, 7120, 7154,
       7495])

In [24]:
len(np.unique(image_ids[test_trials]))

100

So you can see that the test set consists of 100 unique images, each repeated up to 10 times.
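
To check how many times each test image was actually repeated, `np.unique` can also return counts. The following is a minimal sketch on made-up arrays that stand in for `trial_info['tiers']` and `trial_info['frame_image_id']`; the values are illustrative only and not taken from the dataset.

```python
import numpy as np

# Toy stand-ins for trial_info['tiers'] and trial_info['frame_image_id'];
# the values are made up for illustration.
tiers = np.array(['train', 'test', 'test', 'validation', 'test', 'test', 'train', 'test'])
image_ids = np.array([11, 104, 128, 7, 104, 128, 42, 104])

# Select the test trials, then count how often each image ID occurs among them.
test_trials = np.where(tiers == 'test')[0]
unique_ids, repeats = np.unique(image_ids[test_trials], return_counts=True)

for image_id, n in zip(unique_ids, repeats):
    print(f'image {image_id}: shown {n} times')
```

Applied to the real `trial_info` arrays, the same two lines should give the repeat count for each of the unique test images.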

In [25]:
testset.trial_info.keys()

['trial_idx',
 'session',
 'frame_trial_ts',
 'frame_last_flip',
 'frame_image_id',
 'frame_image_class',
 'frame_pre_blank_period',
 'condition_hash',
 'tiers',
 'animal_id',
 'scan_idx',
 'frame_presentation_time']

In [26]:
testset.trial_info.frame_image_id  # gives information about presented image ID

array([1301, 5927, 3982, ...,  464,  819, 3025])

In [27]:
testset

FileTreeDataset /content/data/static20457-5-9-preproc0 (n=5993 items)
	images, responses

In [28]:
dataloaders

{'train': <torch.utils.data.dataloader.DataLoader at 0x7fbcf66025b0>,
 'validation': <torch.utils.data.dataloader.DataLoader at 0x7fbcf66024f0>,
 'test': <torch.utils.data.dataloader.DataLoader at 0x7fbcf66021f0>}

In [29]:
len(dataloaders['validation'].dataset.trial_info.frame_image_id)  # all dataloaders share one dataset, so this counts all trials

5993

# Modeling the neuronal responses

Now that we have successfully loaded the dataset and inspected its contents, it's time for us to start **modeling** the responses.

We will start by building a very basic **Linear-Nonlinear model**, which is nothing more than a shallow neural network with a single linear layer followed by an activation function.

## Linear-Nonlinear (LN) model

### Background

Arguably one of the simplest models of a neuron's response to a stimulus is the **linear-nonlinear (LN) model**.

Given an image $I \in \mathbb{R}^{h\,\times\,w}$ where $h$ and $w$ are the height and the width of the image, respectively, let us collapse the image into a vector $x \in \mathbb{R}^{hw}$.

A single neuron's response $r$ under the linear-nonlinear model can then be expressed as:
$$
r = a(w^\top x + b),
$$
where $w \in \mathbb{R}^{hw}$ and $b \in \mathbb{R}$ are the **weight** and **bias**, and $a:\mathbb{R}\to\mathbb{R}$ is a scalar **activation function**.

We can in fact extend this to capture the responses of all $N$ neurons simultaneously as:

$$
\mathbf{r} = a(\mathbf{W} x + \mathbf{b}),
$$
where $\mathbf{r} \in \mathbb{R}^{N}$, $\mathbf{W} \in \mathbb{R}^{N\,\times\,hw}$ and $\mathbf{b} \in \mathbb{R}^{N}$.

Hence, each neuron weights each pixel of the image according to its weight vector $w$ (a row of $\mathbf{W}$), which characterizes how much that neuron "cares" about each pixel.

The nonlinear activation function $a(\cdot)$ ensures, among other things, that the output of the network stays above 0. In fitting neuronal responses, we tend to use $a(x) = ELU(x) + 1$ where ELU (Exponential Linear Unit) is defined as follows:

$$
    ELU(x) = 
\begin{cases}
    e^x - 1, & x \lt 0 \\
    x,   & x \ge 0
\end{cases}
$$

In [None]:
# Plotting ELU function
x = np.linspace(-2, 2)
plt.plot(x, F.elu(torch.Tensor(x)))
plt.axhline(0, c='r', ls='--')

We shift it by 1 to ensure it always remains positive.

In [None]:
# Plotting ELU+1 function
x = np.linspace(-2, 2)
plt.plot(x, F.elu(torch.Tensor(x))+1)
plt.axhline(0, c='r', ls='--')
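
Putting the pieces together, the population equation $\mathbf{r} = a(\mathbf{W} x + \mathbf{b})$ with $a(x) = ELU(x) + 1$ can be sketched in a few lines of NumPy. This is only an illustration on random toy data: the helper names `elu_plus_one` and `ln_population_response`, the number of neurons, and the random weights are all assumptions made for this example.

```python
import numpy as np

def elu_plus_one(z):
    """ELU(z) + 1: equals exp(z) for z < 0 and z + 1 otherwise, so it is always positive."""
    # Clamp the argument of exp so the unused branch cannot overflow.
    return np.where(z < 0, np.exp(np.minimum(z, 0.0)), z + 1.0)

def ln_population_response(image, W, b):
    """Population LN response r = a(W x + b) for a single image."""
    x = image.reshape(-1)          # collapse the (h, w) image into a vector of length h*w
    return elu_plus_one(W @ x + b)

rng = np.random.default_rng(0)
h, w, n_neurons = 36, 64, 5        # toy number of neurons
image = rng.normal(size=(h, w))
W = rng.normal(scale=1e-3, size=(n_neurons, h * w))   # one row of weights per neuron
b = np.zeros(n_neurons)

r = ln_population_response(image, W, b)
print(r.shape, r.min() > 0)
```

Each row of `W` plays the role of one neuron's weight vector, and the output is positive for every neuron regardless of the weights.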

Overall, it can be seen that a linear-nonlinear model is nothing more than a single linear layer on the flattened image, followed by a nonlinear activation. Now let's go ahead and implement our LN model in PyTorch!

### Implementation

We therefore go ahead and implement a simple network consisting of a linear layer followed by an ELU + 1 activation.

In [5]:
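class Linear(nn.Module):
    """Linear-nonlinear (LN) model: batch norm -> linear layer -> ELU + 1."""

    def __init__(
        self,
        input_height,
        input_width,
        n_neurons,
        momentum=0.1,
        init_std=1e-3,
        gamma=0.0,
    ):
        super().__init__()
        # Normalize the single-channel input image (no learnable scale/shift)
        self.bn = nn.BatchNorm2d(1, momentum=momentum, affine=False)
        # One weight per pixel and one bias for each neuron
        self.linear = nn.Linear(input_height * input_width, n_neurons)
        self.gamma = gamma  # strength of the L1 penalty on the weights
        self.init_std = init_std
        self.initialize()

    def initialize(self, std=None):
        # Draw the initial weights from a zero-mean normal with a small std
        if std is None:
            std = self.init_std
        nn.init.normal_(self.linear.weight.data, std=std)

    def forward(self, x):
        x = self.bn(x)
        x = self.linear(x.flatten(1))  # flatten (channel, height, width) into one vector
        return nn.functional.elu(x) + 1  # ELU + 1 keeps the predictions positive

    def regularizer(self):
        # L1 penalty on the linear weights, scaled by gamma
        return self.gamma * self.linear.weight.abs().sum()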


And that's it! We have now designed our first network model of the neuron's responses!

**BONUS**: notice that we used a batch normalization (BN) layer right before the linear layer? This empirically helps to stabilize the training, making us less sensitive to the weight and bias initialization. You could totally implement and train an LN network without such a BN layer and you are more than welcome to try! However, if you do, be very aware of the network weight initialization and the choice of learning rate during the training.

Finally, let's instantiate the model before we move on to the next step of training the model!

In [6]:
ln_model = Linear(input_height=36, input_width=64, n_neurons=5335, gamma=0.1)

### Training the network

Now that we have a candidate model designed, it's time to train it. While we could use the standard set of optimizers provided by PyTorch to implement our own training routine, here we are provided with a convenience function `train_model` that handles a lot of the training boilerplate.

In [7]:
from lviv.trainers import train_model

Critically, `train_model` sets up training based on the **Poisson loss** and also performs early stopping based on the **correlation** of the predicted neuronal responses with the actual neuronal responses on the **validation set**. Let's now talk briefly about our objective (loss) function of choice in training neuron response models: the Poisson loss.

### Mathematical aside: Poisson Loss

#### How we are **actually** modeling the noisy neuronal responses

The use of **Poisson loss** follows from the assumption that, *conditioned on the stimulus*, the neurons' responses follow an **independent Poisson** distribution. That is, given an input image $x$, the population response $\mathbf{r}$ is distributed as:

$$
p(\mathbf{r} \mid x) = \prod_{i=1}^N \text{Poiss}(r_i; \lambda_i(x))
$$

where $r_i$ is the response of the $i^\text{th}$ neuron in the population $\mathbf{r}$. The parameter $\lambda_i$ of the Poisson distribution controls its **average value**. We write $\lambda_i(x)$ to indicate that the average response of each neuron is expected to vary *as a function of the input image*. We can express this as:

$$
\mathbb{E}[r_i|x] = \lambda_i(x)
$$

In fact, it is precisely this function $\lambda_i(x)$ that we are modeling using LN models and, in the next step, more complex neural networks. In other words, we are learning $\lambda_i(x) = f_i(x, \theta)$, where $\theta$ denotes the trainable parameters of the model.

Putting it all together, this means that our model $\mathbf{f}(x, \theta)$ is really modeling the average activity of the neurons,

$$
\mathbb{E}[\mathbf{r}|x] = \mathbf{f}(x, \theta)
$$

while we assume that the responses are distributed according to **independent Poisson** distributions around the average responses predicted by our model $\mathbf{f}(x, \theta)$.
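
To get a feel for this assumption, the sketch below draws repeated responses to a single image from independent Poisson distributions and checks that the empirical mean of each neuron approaches its rate. The three rates are hypothetical values chosen for illustration, not estimates from the recorded data.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical mean responses lambda_i(x) of three neurons to one image.
rates = np.array([0.5, 2.0, 8.0])

# Draw many repeats of the same image; each neuron is sampled independently.
n_repeats = 20000
responses = rng.poisson(lam=rates, size=(n_repeats, rates.size))

print('empirical mean    :', responses.mean(axis=0))
print('empirical variance:', responses.var(axis=0))
```

For a Poisson distribution the variance equals the mean, so under this model neurons with larger average responses are also noisier.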

#### Deriving the objective function

The Poisson distribution is defined as follows:

$$
p(r) = \text{Poiss}(r; \lambda) = \frac{e^{-\lambda}\lambda^{r}}{r!}
$$

During training, we want to adjust the model parameters $\theta$ to maximize the probability of observing the response $\mathbf{r}$ to a known image $x$. This is achieved by **maximizing** the log-likelihood function $\log p(\mathbf{r}|x, \theta)$, or equivalently by **minimizing the negative log-likelihood function** as the objective function $L(x, \mathbf{r}, \theta)$:

$$
\begin{align}
L(x, \mathbf{r}, \theta) &= -\log p(\mathbf{r}|x, \theta) \\
&= -\log \prod_i \text{Poiss}(r_i; f_i(x, \theta)) \\
&= -\sum_i \log \frac{e^{-f_i(x, \theta)}f_i(x, \theta)^{r_i}}{r_i!} \\
&= \sum_i \left(f_i(x, \theta) - r_i \log f_i(x, \theta) + \log r_i! \right)
\end{align}
$$


During the optimization, we seek the $\theta$ that minimizes the loss $L$. Note that since the term $\log r_i!$ does not depend on $\theta$, it can be safely dropped from the Poisson loss. Hence you will commonly see the following expression as the definition of the **Poisson loss**:

$$
L_\text{Poiss}(x, \mathbf{r}, \theta) = \sum_i \left(f_i(x, \theta) - r_i \log f_i(x, \theta)\right)
$$
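
As a quick numerical check of the argument above, the following NumPy sketch computes both the full negative log-likelihood and the Poisson loss for two sets of made-up predictions. The function names and all numbers are assumptions for illustration; this is not the loss implementation used by `train_model`.

```python
import math
import numpy as np

def poisson_loss(pred, target):
    """Poisson loss sum_i (f_i - r_i log f_i), without the log r_i! term."""
    return np.sum(pred - target * np.log(pred))

def poisson_nll(pred, target):
    """Full negative log-likelihood, including the log r_i! term."""
    # lgamma(r + 1) equals log(r!) and also works for non-integer r.
    log_factorial = np.array([math.lgamma(r + 1.0) for r in target])
    return poisson_loss(pred, target) + np.sum(log_factorial)

r = np.array([0.0, 1.0, 3.0, 2.0])          # toy observed responses
f_a = np.array([0.5, 1.2, 2.5, 2.0])        # toy predictions of model A
f_b = np.array([1.0, 1.0, 1.0, 1.0])        # toy predictions of model B

# The two objectives differ by the same constant for every model,
# so they rank models (and parameters) identically.
gap_a = poisson_nll(f_a, r) - poisson_loss(f_a, r)
gap_b = poisson_nll(f_b, r) - poisson_loss(f_b, r)
print(gap_a, gap_b)
print(poisson_loss(f_a, r) < poisson_loss(f_b, r))
```

Note that the predictions must be strictly positive for the logarithm to be defined, which is one reason the models in this notebook end with an ELU + 1 activation.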

### Performing the training

Now that we have the theoretical foundation for the training and the choice of the objective function under our belt, let's go ahead and train the network. Because the function `train_model` handles a lot under the hood, training a model is as easy as invoking the function and passing it the model to be trained and the dataloaders!
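
To make the idea of early stopping on a validation score a little more concrete, here is a generic, patience-based sketch in plain Python. It is not the implementation inside `train_model`: the function `train_with_early_stopping`, its arguments `train_one_epoch` and `validate`, and the scripted scores are all hypothetical, and the sketch leaves out learning-rate scheduling and saving or restoring the best model state.

```python
def train_with_early_stopping(train_one_epoch, validate, patience=5, max_epochs=100):
    """Generic patience-based early stopping on a validation score (higher is better)."""
    best_score, best_epoch, epochs_without_improvement = float('-inf'), 0, 0
    for epoch in range(1, max_epochs + 1):
        train_one_epoch()
        score = validate()
        if score > best_score:
            # New best score: remember it and reset the patience counter.
            best_score, best_epoch = score, epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break  # no improvement for `patience` epochs in a row
    return best_score, best_epoch

# Toy run: a scripted sequence of validation correlations instead of a real model.
scores = iter([0.06, 0.07, 0.06, 0.07, 0.08, 0.07, 0.07, 0.07, 0.07, 0.07, 0.07])
best_score, best_epoch = train_with_early_stopping(lambda: None, lambda: next(scores))
print(best_score, best_epoch)
```

In this toy run the best score is reached in epoch 5, and training stops once five further epochs bring no improvement.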

In [8]:
score, output, model_state = train_model(model=ln_model, dataloader=dataloaders)

correlation -0.0019075753
poisson_loss 9395851.0


Epoch 1: 100%|██████████| 150/150 [00:03<00:00, 44.69it/s]

[001|00/05] ---> 0.05955984443426132





correlation 0.059559844
poisson_loss 3384139.8


Epoch 2: 100%|██████████| 150/150 [00:00<00:00, 176.45it/s]


[002|00/05] ---> 0.06592471897602081
correlation 0.06592472
poisson_loss 3274338.0


Epoch 3: 100%|██████████| 150/150 [00:00<00:00, 177.13it/s]


[003|01/05] -/-> 0.06381010264158249
correlation 0.0638101
poisson_loss 3274534.0


Epoch 4: 100%|██████████| 150/150 [00:00<00:00, 174.58it/s]


[004|01/05] ---> 0.07152469456195831
correlation 0.071524695
poisson_loss 3243809.5


Epoch 5: 100%|██████████| 150/150 [00:00<00:00, 175.18it/s]


[005|00/05] ---> 0.07718321681022644
correlation 0.07718322
poisson_loss 3174631.8


Epoch 6: 100%|██████████| 150/150 [00:00<00:00, 176.82it/s]


[006|00/05] ---> 0.07769027352333069
correlation 0.07769027
poisson_loss 3153348.0


Epoch 7: 100%|██████████| 150/150 [00:00<00:00, 176.00it/s]


[007|01/05] -/-> 0.07006139308214188
correlation 0.07006139
poisson_loss 3368900.8


Epoch 8: 100%|██████████| 150/150 [00:00<00:00, 176.07it/s]


[008|02/05] -/-> 0.06875057518482208
correlation 0.068750575
poisson_loss 3328448.2


Epoch 9: 100%|██████████| 150/150 [00:00<00:00, 177.04it/s]


[009|03/05] -/-> 0.06966907531023026
correlation 0.069669075
poisson_loss 3304112.5


Epoch 10: 100%|██████████| 150/150 [00:00<00:00, 176.75it/s]


[010|04/05] -/-> 0.0717167928814888
correlation 0.07171679
poisson_loss 3497721.5


Epoch 11: 100%|██████████| 150/150 [00:00<00:00, 173.33it/s]


[011|05/05] -/-> 0.07496411353349686
Restoring best model after lr decay! 0.074964 ---> 0.077690
correlation 0.07769027
poisson_loss 3153348.0


Epoch 12: 100%|██████████| 150/150 [00:00<00:00, 173.67it/s]


Epoch    12: reducing learning rate of group 0 to 1.5000e-03.
[012|01/05] -/-> 0.07143863290548325
correlation 0.07143863
poisson_loss 3309589.8


Epoch 13: 100%|██████████| 150/150 [00:00<00:00, 174.99it/s]


[013|01/05] ---> 0.10395169258117676
correlation 0.10395169
poisson_loss 2500118.0


Epoch 14: 100%|██████████| 150/150 [00:00<00:00, 175.65it/s]


[014|01/05] -/-> 0.10067544877529144
correlation 0.10067545
poisson_loss 2479766.2


Epoch 15: 100%|██████████| 150/150 [00:00<00:00, 177.24it/s]


[015|02/05] -/-> 0.10142096132040024
correlation 0.10142096
poisson_loss 2414144.5


Epoch 16: 100%|██████████| 150/150 [00:00<00:00, 176.30it/s]


[016|03/05] -/-> 0.10160883516073227
correlation 0.101608835
poisson_loss 2430222.5


Epoch 17: 100%|██████████| 150/150 [00:00<00:00, 176.67it/s]


[017|04/05] -/-> 0.09739978611469269
correlation 0.097399786
poisson_loss 2449914.8


Epoch 18: 100%|██████████| 150/150 [00:00<00:00, 170.65it/s]


[018|05/05] -/-> 0.09415452182292938
Restoring best model after lr decay! 0.094155 ---> 0.103952
correlation 0.10395169
poisson_loss 2500118.0


Epoch 19: 100%|██████████| 150/150 [00:00<00:00, 176.22it/s]


Epoch    19: reducing learning rate of group 0 to 4.5000e-04.
[019|01/05] -/-> 0.09806245565414429
correlation 0.098062456
poisson_loss 2533733.0


Epoch 20: 100%|██████████| 150/150 [00:00<00:00, 174.56it/s]


[020|01/05] ---> 0.11151435226202011
correlation 0.11151435
poisson_loss 2310042.0


Epoch 21: 100%|██████████| 150/150 [00:00<00:00, 175.49it/s]


[021|01/05] -/-> 0.10864903032779694
correlation 0.10864903
poisson_loss 2329606.5


Epoch 22: 100%|██████████| 150/150 [00:00<00:00, 176.26it/s]


[022|02/05] -/-> 0.10798943787813187
correlation 0.10798944
poisson_loss 2349926.2


Epoch 23: 100%|██████████| 150/150 [00:00<00:00, 176.19it/s]


[023|03/05] -/-> 0.10652267187833786
correlation 0.10652267
poisson_loss 2274603.0


Epoch 24: 100%|██████████| 150/150 [00:00<00:00, 175.45it/s]


[024|04/05] -/-> 0.10494276136159897
correlation 0.10494276
poisson_loss 2338915.5


Epoch 25: 100%|██████████| 150/150 [00:00<00:00, 177.24it/s]


[025|05/05] -/-> 0.1067856028676033
Restoring best model after lr decay! 0.106786 ---> 0.111514
Restoring best model! 0.111514 ---> 0.111514


### Analyzing the trained network

Woohoo! We have now successfully trained our very first LN model on real neuronal responses! But really, how good is the model?

During the training, the `train_model` function iteratively reported two values: the loss function (Poisson loss) value and the average correlation. 

But what is this correlation? It's simply the correlation computed between our predicted neuronal responses $\hat{r}_i$ and the actual neuronal responses $r_i$ across images in the validation set. We then average this value **across neurons** to get the average correlation.

Being a correlation, the highest possible value is of course 1.0, but in practice this is never reached, partly because our model is imperfect but more fundamentally because the neurons' responses are noisy. Because of the noise, even a perfect model would never reach a correlation of 1.0.
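
The sketch below illustrates both points on simulated data: it computes the correlation across images separately for each neuron, and it shows that a "perfect" model, one that outputs the true mean rates, still falls short of 1.0 when the responses carry Poisson noise. The helper `per_neuron_correlation` and the simulated rates are assumptions for this example and not the code used by `train_model`.

```python
import numpy as np

def per_neuron_correlation(pred, target, eps=1e-12):
    """Pearson correlation across images, computed separately for each neuron.

    pred, target: arrays of shape (n_images, n_neurons).
    """
    pred_c = pred - pred.mean(axis=0, keepdims=True)
    target_c = target - target.mean(axis=0, keepdims=True)
    cov = (pred_c * target_c).mean(axis=0)
    return cov / (pred_c.std(axis=0) * target_c.std(axis=0) + eps)

rng = np.random.default_rng(0)
n_images, n_neurons = 500, 4
true_rates = rng.uniform(0.5, 5.0, size=(n_images, n_neurons))  # toy mean responses
noisy_responses = rng.poisson(true_rates)                        # Poisson noise around them

# Even a "perfect" model that outputs the true mean rates stays below 1.0.
corr = per_neuron_correlation(true_rates, noisy_responses)
print(corr, corr.mean())
```

Taking `corr.mean()` gives the single average-correlation number of the kind reported during training.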

<font color='red'>
    NOTE TO SELF: Add more here probably plotting some scatter plot for an example neuron, histogram of correlation scores both done on the testset.
</font>

## Going beyond Linear-Nonlinear model by using CNN

We saw that a simple LN model can be trained to achieve above chance performance in predicting the responses of mouse V1 neurons to natural images. But we certainly must be able to do better than that, right?

In the past decade, what has really driven system identification of visual neurons has been the use of convolutional neural networks (CNNs). Below, we will try out a very simple CNN to see if we can already reach better performance than the LN model.

<font color='green'>
    NOTE to collaborators: 
    Please add a simpler implementation of CNN. Ideally it would train just as fast as the simple fully connected linear model given above. 
</font>

In [18]:
from collections import OrderedDict
class CNN(nn.Module):
    def __init__(
        self,
        input_height,
        input_width,
        n_neurons,
        momentum=0.1,
        init_std=1e-3,
        gamma=0.1,
        hidden_channels=8,
    ):
        super(CNN, self).__init__()
        self.init_std = init_std
        self.gamma = gamma

        # CNN core
        self.cnn_core = nn.Sequential(
            OrderedDict(
                [
                    ("conv1", nn.Conv2d(1, hidden_channels, 15, padding=15 // 2, bias=False)),
                    ("bn1", nn.BatchNorm2d(hidden_channels, momentum=momentum)),
                    ("elu1", nn.ELU()),
                    ("conv2", nn.Conv2d(hidden_channels, hidden_channels, 13, padding=13 // 2, bias=False)),
                    ("bn2", nn.BatchNorm2d(hidden_channels, momentum=momentum)),
                    ("elu2", nn.ELU()),
                    ("conv3", nn.Conv2d(hidden_channels, hidden_channels, 13, padding=13 // 2, bias=False)),
                    ("bn3", nn.BatchNorm2d(hidden_channels, momentum=momentum)),
                    ("elu3", nn.ELU()),
                    ("conv4", nn.Conv2d(hidden_channels, hidden_channels, 13, padding=13 // 2, bias=False)),
                    ("bn4", nn.BatchNorm2d(hidden_channels, momentum=momentum)),
                    ("elu4", nn.ELU()),
                ]
            )
        )

        # Fully connected readout
        self.readout = nn.Sequential(
            OrderedDict(
                [
                    ('fc_ro', nn.Linear(input_height * input_width * hidden_channels, n_neurons)),
                    ('bn_ro', nn.BatchNorm1d(n_neurons, momentum=momentum)),
                ]
            )
        )


    def initialize(self, std=None):
        if std is None:
            std = self.init_std
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                nn.init.normal_(m.weight.data, std=std)

    def forward(self, x):
        x = self.cnn_core(x)
        x = x.view(x.size(0), -1)
        x = self.readout(x)
        return nn.functional.elu(x) + 1
    
    def regularizer(self):
        return self.readout[0].weight.abs().sum() * self.gamma
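
The readout above expects `input_height * input_width * hidden_channels` inputs, which only works because every convolution preserves the spatial size of the image. A small sketch of the standard output-size formula for a convolution with dilation 1 (the helper `conv_output_size` is introduced here just for illustration) confirms this for the kernel sizes used in the core:

```python
def conv_output_size(size, kernel, padding, stride=1):
    """Spatial output size of a convolution along one dimension (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1

height, width, hidden_channels = 36, 64, 8

# Kernel sizes used by conv1..conv4 in the CNN above, each with padding = kernel // 2.
for kernel in (15, 13, 13, 13):
    height = conv_output_size(height, kernel, padding=kernel // 2)
    width = conv_output_size(width, kernel, padding=kernel // 2)

readout_inputs = height * width * hidden_channels
print(height, width, readout_inputs)
```

With an odd kernel size and `padding = kernel // 2`, the output has the same height and width as the input, so the flattened core output has one value per pixel and channel.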


In [19]:
cnn_model = CNN(input_height=36, input_width=64, n_neurons=5335)
score, output, model_state = train_model(model=cnn_model, dataloader=dataloaders)

correlation -0.0016252714
poisson_loss 3999782.0


Epoch 1: 100%|██████████| 150/150 [00:05<00:00, 27.86it/s]


[001|00/05] ---> 0.08700721710920334
correlation 0.08700722
poisson_loss 2136667.2


Epoch 2: 100%|██████████| 150/150 [00:05<00:00, 27.90it/s]


[002|00/05] ---> 0.12622417509555817
correlation 0.12622418
poisson_loss 1907168.5


Epoch 3: 100%|██████████| 150/150 [00:05<00:00, 27.85it/s]


[003|00/05] ---> 0.14798565208911896
correlation 0.14798565
poisson_loss 1882363.8


Epoch 4: 100%|██████████| 150/150 [00:05<00:00, 27.77it/s]


[004|00/05] ---> 0.1612660437822342
correlation 0.16126604
poisson_loss 1846741.8


Epoch 5: 100%|██████████| 150/150 [00:05<00:00, 27.82it/s]


[005|00/05] ---> 0.18843691051006317
correlation 0.18843691
poisson_loss 1812948.1


Epoch 6: 100%|██████████| 150/150 [00:05<00:00, 27.79it/s]


[006|00/05] ---> 0.20305706560611725
correlation 0.20305707
poisson_loss 1804047.0


Epoch 7: 100%|██████████| 150/150 [00:05<00:00, 27.83it/s]


[007|00/05] ---> 0.21175087988376617
correlation 0.21175088
poisson_loss 1786120.8


Epoch 8: 100%|██████████| 150/150 [00:05<00:00, 27.83it/s]


[008|00/05] ---> 0.2153705209493637
correlation 0.21537052
poisson_loss 1803845.0


Epoch 9: 100%|██████████| 150/150 [00:05<00:00, 27.82it/s]


[009|01/05] -/-> 0.21321766078472137
correlation 0.21321766
poisson_loss 1801716.8


Epoch 10: 100%|██████████| 150/150 [00:05<00:00, 27.83it/s]


[010|02/05] -/-> 0.20773832499980927
correlation 0.20773832
poisson_loss 1819412.0


Epoch 11: 100%|██████████| 150/150 [00:05<00:00, 27.82it/s]


[011|03/05] -/-> 0.20487961173057556
correlation 0.20487961
poisson_loss 1854615.2


Epoch 12: 100%|██████████| 150/150 [00:05<00:00, 27.82it/s]


[012|04/05] -/-> 0.19463412463665009
correlation 0.19463412
poisson_loss 1899826.2


Epoch 13: 100%|██████████| 150/150 [00:05<00:00, 27.81it/s]


[013|05/05] -/-> 0.18218643963336945
Restoring best model after lr decay! 0.182186 ---> 0.215371
correlation 0.21537052
poisson_loss 1803845.0


Epoch 14: 100%|██████████| 150/150 [00:05<00:00, 27.87it/s]


Epoch    14: reducing learning rate of group 0 to 1.5000e-03.
[014|01/05] -/-> 0.21291543543338776
correlation 0.21291544
poisson_loss 1814011.0


Epoch 15: 100%|██████████| 150/150 [00:05<00:00, 27.84it/s]


[015|02/05] -/-> 0.21255838871002197
correlation 0.21255839
poisson_loss 1802551.0


Epoch 16: 100%|██████████| 150/150 [00:05<00:00, 27.84it/s]


[016|03/05] -/-> 0.20717494189739227
correlation 0.20717494
poisson_loss 1828587.0


Epoch 17: 100%|██████████| 150/150 [00:05<00:00, 27.85it/s]


[017|04/05] -/-> 0.19936375319957733
correlation 0.19936375
poisson_loss 1855954.0


Epoch 18: 100%|██████████| 150/150 [00:05<00:00, 27.85it/s]


[018|05/05] -/-> 0.19151179492473602
Restoring best model after lr decay! 0.191512 ---> 0.215371
correlation 0.21537052
poisson_loss 1803845.0


Epoch 19: 100%|██████████| 150/150 [00:05<00:00, 27.80it/s]


[019|01/05] -/-> 0.2124868631362915
correlation 0.21248686
poisson_loss 1791308.9


Epoch 20: 100%|██████████| 150/150 [00:05<00:00, 27.81it/s]


Epoch    20: reducing learning rate of group 0 to 4.5000e-04.
[020|02/05] -/-> 0.21149353682994843
correlation 0.21149354
poisson_loss 1804995.8


Epoch 21: 100%|██████████| 150/150 [00:05<00:00, 27.80it/s]


[021|03/05] -/-> 0.2077120542526245
correlation 0.20771205
poisson_loss 1812928.5


Epoch 22: 100%|██████████| 150/150 [00:05<00:00, 27.86it/s]


[022|04/05] -/-> 0.20526854693889618
correlation 0.20526855
poisson_loss 1815964.8


Epoch 23: 100%|██████████| 150/150 [00:05<00:00, 27.84it/s]


[023|05/05] -/-> 0.20167630910873413
Restoring best model after lr decay! 0.201676 ---> 0.215371
Restoring best model! 0.215371 ---> 0.215371


## Trying out the State-of-the-Art (SOTA) model

We now have some sense of how to train linear and nonlinear network models to predict V1 neuron responses to natural images, and we just saw that the nonlinear network brings a significant improvement in performance over the LN network.

You might now be wondering, how good can we get? To get a sense of this, let's go ahead and train a state-of-the-art (SOTA) network model for mouse V1 responses to natural images as published in our recent work in [Lurz et al. ICLR 2021](https://github.com/sinzlab/Lurz_2020_code).

To keep things simple, I have provided the network implementation in the `lviv` package, so you can build the model just by invoking a function!

In [None]:
from lviv.models import build_lurz2020_model
model_config = {'init_mu_range': 0.55,
                'init_sigma': 0.4,
                'input_kern': 15,
                'hidden_kern': 13,
                'gamma_input': 1.0,
                'grid_mean_predictor': {'type': 'cortex',
                                        'input_dimensions': 2,
                                        'hidden_layers': 0,
                                        'hidden_features': 0,
                                        'final_tanh': False},
                'gamma_readout': 2.439
               }

sota_model = build_lurz2020_model(**model_config, dataloaders=dataloaders, seed=1234)

In [None]:
score, output, model_state = train_model(model=sota_model, dataloader=dataloaders)

It turns out that we can also build a *linearized* version of the SOTA model. This effectively removes all nonlinear operations within the network except for the very last nonlinear activation, turning the network into an **LN model**, but with a more complex architecture.
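
Why does removing the intermediate nonlinearities give an LN model? A stack of linear stages is itself a single linear map. The NumPy sketch below shows this for two toy stages with made-up sizes; it illustrates the principle only and does not reproduce the architecture built by `build_lurz2020_model`.

```python
import numpy as np

rng = np.random.default_rng(0)
n_pixels, n_hidden, n_neurons = 12, 6, 3     # toy sizes

W1 = rng.normal(size=(n_hidden, n_pixels))   # first linear stage (no nonlinearity after it)
W2 = rng.normal(size=(n_neurons, n_hidden))  # second linear stage
x = rng.normal(size=n_pixels)                # flattened toy image

stacked = W2 @ (W1 @ x)        # pass the image through both stages
W_equiv = W2 @ W1              # a single equivalent weight matrix
collapsed = W_equiv @ x

print(np.allclose(stacked, collapsed))
```

Applying the final activation to either result gives the same output, so the linearized network computes an LN model whose weights are parameterized by a deeper architecture.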

In [None]:
linear_model_config = dict(model_config) # copy the config
linear_model_config['linear'] = True # set linear to True to make it a LN model!

sota_ln_model = build_lurz2020_model(**linear_model_config, dataloaders=dataloaders, seed=1234)

In [None]:
score, output, model_state = train_model(model=sota_ln_model, dataloader=dataloaders)

# Analyzing the trained model to gain insights into the brain

<font color='green'>
    NOTE to collaborators: 
    Please provide code for generating gradient receptive field and MEI for the sota networks. By this point, they should have `sota_model` and `sota_ln_model` corresponding to the best nonlinear and linear model based on the model architecture as found in Lurz et al. 2021.
</font>