# NX-414 - Mini-project

Group members: Kolly Florian, Mikami Sarah, Montlahuc Louise

## Project description
The objectives of the project are:
- Predict neural activity using linear regression from images and from neural network layers
- Quantify the goodness of the model
- Compare the results across the network layers and between trained/random neural network
- Predict the neural activity using a neural network in a data-driven approach
- Develop the most accurate model for predicting IT neural activity

Specifically, we use the data from the following [paper](https://www.jneurosci.org/content/jneuro/35/39/13402.full.pdf). In the behavioral experiment, non-human primates were shown images while neural activity was recorded from the inferior temporal (IT) cortex with multielectrode arrays. In the provided data, the neural activity and the images are already pre-processed: for each image, we have the corresponding average firing rate (between 70 and 170 ms) of each neuron.

## Imports

In [None]:
# ALL NECESSARY IMPORTS
from abc import ABC
import inspect
import h5py
import os

import gdown
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset, Dataset
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR, SequentialLR
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.preprocessing import LabelEncoder
from fvcore.common.registry import Registry
from torchvision.models import resnext101_32x8d, ResNeXt101_32X8D_Weights

## Models interface

For genericity and reusability, we define an interface for all models.

In [None]:
# IModel INTERFACE
class IModel(ABC, nn.Module):
    """
    Abstract base class for a model.
    This class defines the interface that all model classes must implement.
    All models inheriting from this class should have a self.model attribute!
    """
    def __init__(self):
        super().__init__()
        self.PCs = dict()
        self.PCA = None
        self.ACTs = dict()

    def forward(self, images):
        return self.model(images)
    
    def get_layers(self):
        """
        Returns the layers on which to do the linear probing.
        """
        layers = []
        layers_name = [name for name, _ in self.model.named_children()]
        for name in layers_name[-4:]:
            module = self.model.get_submodule(name)
            layers.append((name, module))
        return layers
        
    def get_activations(self, hook_name):
        """
        Returns the activations of the model.
        The hook_name can be 'all' for all activations or 'pca' for 1000 principal components.
        """
        if hook_name == 'all':
            return self.ACTs
        elif hook_name == 'pca':
            return self.PCs
        else:
            raise ValueError("Invalid hook name. Use 'all' or 'pca'.")
        
    def reset_activations(self):
        """
        Resets the activations of the model.
        """
        self.PCs = dict()
        self.ACTs = dict()

    def _get_PCs_hook(self, module, input, output, layer_name):
        print('Layer:', layer_name)
        activations = output.detach().cpu().numpy().reshape(output.shape[0], -1)
        print('Activations shape:', activations.shape)
        pca = PCA(n_components=1000)
        print('PCA type:', type(pca).__name__)  # PCA objects have no .type() method
        self.PCA = pca
        pca_features = pca.fit_transform(activations)
        print('Principal components shape:', pca_features.shape)
        self.PCs[layer_name] = pca_features

    def _get_activations_hook(self, module, input, output, layer_name):
        activations = output.detach().cpu().numpy().reshape(output.shape[0], -1)
        self.ACTs[layer_name] = activations
    
    def register_hook(self, hook_name):
        """
        Registers a hook to the model.
        The hook can be 'all' for all activations or 'pca' for 1000 principal components.
        """
        # Validate up front so that `handle` is never left undefined
        if hook_name not in ('all', 'pca'):
            raise ValueError("Invalid hook name. Use 'all' or 'pca'.")
        handles = []
        for name, layer in self.get_layers():
            if hook_name == 'all':
                handle = layer.register_forward_hook(lambda m, i, o, n=name: self._get_activations_hook(m, i, o, n))
            else:
                handle = layer.register_forward_hook(lambda m, i, o, n=name: self._get_PCs_hook(m, i, o, n))
            handles.append(handle)
        return handles
    
    def change_head(self, layer, num_classes):
        """
        Sets a final head (classification or regression) after the indicated layer.
        """
        return ModifiedModel(self.model, layer, num_classes)
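
The hook mechanism used by `register_hook` can be illustrated on a much smaller network. The following is only a minimal sketch: the toy model, its layer names, and the `make_hook` helper are assumptions made for illustration and are not part of the project code. It shows how a forward hook captures the flattened output of each layer during a forward pass.

```python
import torch
from torch import nn

# Toy stand-in for a backbone; the layer names are hypothetical.
toy = nn.Sequential()
toy.add_module("conv", nn.Conv2d(3, 4, kernel_size=3, padding=1))
toy.add_module("relu", nn.ReLU())
toy.add_module("pool", nn.AdaptiveAvgPool2d(2))

activations = {}

def make_hook(layer_name):
    # Bind the layer name now; the hook itself only receives (module, inputs, output).
    def hook(module, inputs, output):
        flat = output.detach().cpu().numpy().reshape(output.shape[0], -1)
        activations[layer_name] = flat
    return hook

# Register one forward hook per child layer and keep the handles.
handles = [layer.register_forward_hook(make_hook(name))
           for name, layer in toy.named_children()]

with torch.no_grad():
    toy(torch.randn(5, 3, 8, 8))  # a batch of 5 random "images"

# Remove the hooks once the activations are collected.
for handle in handles:
    handle.remove()

print({name: act.shape for name, act in activations.items()})
```

Keeping the handles returned by `register_forward_hook` makes it possible to detach the hooks afterwards, which is presumably why `register_hook` returns them.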

For fine-tuning models, we create a class that extends our generic model interface `IModel` and contains both the original model and the modified model. As we are testing multiple models and multiple fine-tuning methods, our goal is to stay as generic as possible.

In [None]:
# ModifiedModel class
class ModifiedModel(IModel):
    def __init__(self, base_model, insert_after, num_classes):
        super().__init__()
        self.base_model = base_model
        self.insert_after = insert_after
        self.num_classes = num_classes

        # Extract layers up to the insertion point
        self.features = nn.Sequential()
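        for name, module in base_model.named_children():
            self.features.add_module(name, module)
            if name == insert_after:
                self.layer = (name, module)
                break
        else:
            # The loop finished without finding the requested layer
            raise ValueError(f"Layer '{insert_after}' not found in the base model.")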

        # TODO testing freezing the layers
        # for param in self.features.parameters():
        #     param.requires_grad = False

        # Determine input dim for new head
        dummy_input = torch.randn(1, 3, 224, 224)
        with torch.no_grad():
            out = self._forward_features(dummy_input)
        out = out.view(out.size(0), -1)
        head_in_features = out.shape[1]

        # Define new head (classification or regression)
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(head_in_features, 1024),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, num_classes)
        )

    def _forward_features(self, x):
        for name, module in self.features.named_children():
            if isinstance(module, nn.ModuleList):
                for submodule in module:
                    x = submodule(x)
            else:
                x = module(x)
        return x

    def forward(self, x):
        x = self._forward_features(x)
        x = self.fc(x)
        return x
    
    def get_layers(self):
        # Return a list of (name, module) pairs, as expected by register_hook
        return [self.layer]

We then create a build function that takes a model name and a set of parameters, and returns the corresponding model.

In [None]:
MODEL_REGISTRY = Registry("MODEL")
MODEL_REGISTRY.__doc__ = """
Registry for models.

The registered object will be called with `obj()`.
The call should return a `nn.Module` object.
"""

def accepts_seed(cls):
    init = cls.__init__
    sig = inspect.signature(init)
    return 'seed' in sig.parameters

def make_model(name, seed):
    """
    Builds a registered model.
    Args:
        name (string): name of the model to build.
        seed (int): random seed, passed to the model only if its constructor accepts one.
    Returns:
        model (nn.Module): the built model.
    """
    model = MODEL_REGISTRY.get(name)
    if accepts_seed(model):
        model = model(seed)
    else:
        model = model()
    
    return model
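
The build logic above can be tried without `fvcore` or PyTorch. The sketch below is an illustration under stated assumptions: `_REGISTRY`, `register`, `build`, and the two toy classes are hypothetical stand-ins (a plain dictionary replaces fvcore's `Registry`). It shows how a seed is forwarded only to classes whose constructor declares one.

```python
import inspect

# Hypothetical minimal registry, standing in for fvcore's Registry.
_REGISTRY = {}

def register(cls):
    """Class decorator that stores the class under its own name."""
    _REGISTRY[cls.__name__] = cls
    return cls

def accepts_seed(cls):
    """True if the class constructor declares a `seed` parameter."""
    return "seed" in inspect.signature(cls.__init__).parameters

def build(name, seed):
    """Instantiate a registered class, passing `seed` only when accepted."""
    cls = _REGISTRY[name]
    return cls(seed) if accepts_seed(cls) else cls()

@register
class Pretrained:
    # A pretrained model needs no seed: its weights are fixed.
    def __init__(self):
        self.seed = None

@register
class RandomInit:
    # A randomly initialised model takes a seed for reproducibility.
    def __init__(self, seed):
        self.seed = seed

print(build("Pretrained", 0).seed)  # None
print(build("RandomInit", 0).seed)  # 0
```

With this design, the caller can always pass a seed and does not need to know which models actually use it.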

## Loading the data

Let's now set up the data loading. We start with some utility functions that were provided with the project.

In [None]:
### Utils
def download_data(path_to_data):
    # Create the target directory itself (not only its parent) if it is missing
    os.makedirs(path_to_data, exist_ok=True)
    output = "IT_data.h5"
    data_path = os.path.join(path_to_data, output)
    if not os.path.exists(data_path):
        url = "https://drive.google.com/file/d/1s6caFNRpyR9m7ZM6XEv_e8mcXT3_PnHS/view?usp=share_link"
        gdown.download(url, data_path, quiet=False, fuzzy=True)

def load_it_data(path_to_data):
    """ Load IT data

    Args:
        path_to_data (str): Path to the data

    Returns:
        tuple (x8): Stimulus train/val/test; objects list train/val/test; spikes train/val
    """

    datafile = h5py.File(os.path.join(path_to_data,'IT_data.h5'), 'r')

    stimulus_train = datafile['stimulus_train'][()]
    spikes_train = datafile['spikes_train'][()]
    objects_train = datafile['object_train'][()]
    
    stimulus_val = datafile['stimulus_val'][()]
    spikes_val = datafile['spikes_val'][()]
    objects_val = datafile['object_val'][()]
    
    stimulus_test = datafile['stimulus_test'][()]
    objects_test = datafile['object_test'][()]

    ### Decode back object type to latin
    objects_train = [obj_tmp.decode("latin-1") for obj_tmp in objects_train]
    objects_val = [obj_tmp.decode("latin-1") for obj_tmp in objects_val]
    objects_test = [obj_tmp.decode("latin-1") for obj_tmp in objects_test]

    return stimulus_train, stimulus_val, stimulus_test, objects_train, objects_val, objects_test, spikes_train, spikes_val

Let's now create a function for retrieving the data we are interested in.

In [None]:
def get_data():
    """Get the data from the IT dataset.

    Returns:
        tuple: tuples (stimulus, objects, spikes) for training and validation sets.
    """
    stimulus_train, stimulus_val, stimulus_test, objects_train, objects_val, objects_test, spikes_train, spikes_val = load_it_data('./data/')
    return (stimulus_train, objects_train, spikes_train), (stimulus_val, objects_val, spikes_val)

## Testing basic models

We start by analyzing the results obtained with simple models. Here we test linear and ridge regression.

In [None]:
# TODO
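
As a rough illustration of the two baselines named above, the following sketch fits linear and ridge regression in closed form with NumPy. It is not the implementation used in this project: the `fit_ridge` helper, the regularisation strength, and the synthetic data (which stands in for image features and firing rates) are all assumptions made for the example.

```python
import numpy as np

def fit_ridge(X, Y, alpha=1.0):
    """Closed-form ridge regression; returns weights and intercept."""
    x_mean, y_mean = X.mean(axis=0), Y.mean(axis=0)
    Xc, Yc = X - x_mean, Y - y_mean
    # Solve (Xc^T Xc + alpha * I) W = Xc^T Yc for W.
    gram = Xc.T @ Xc + alpha * np.eye(X.shape[1])
    W = np.linalg.solve(gram, Xc.T @ Yc)
    b = y_mean - x_mean @ W
    return W, b

# Synthetic stand-in data (NOT the IT recordings):
# 200 samples, 20 features, 5 "neurons".
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 20))
W_true = rng.normal(size=(20, 5))
Y = X @ W_true + 0.1 * rng.normal(size=(200, 5))

# alpha = 0 reduces ridge regression to ordinary least squares.
W_ols, b_ols = fit_ridge(X[:150], Y[:150], alpha=0.0)
W_ridge, b_ridge = fit_ridge(X[:150], Y[:150], alpha=10.0)

# Compare both fits on the 50 held-out samples.
for label, (W, b) in {"linear": (W_ols, b_ols), "ridge": (W_ridge, b_ridge)}.items():
    mse = np.mean((X[150:] @ W + b - Y[150:]) ** 2)
    print(label, "held-out MSE:", mse)
```

On the real data, the number of pixels or activations far exceeds the number of stimuli, so the regularised variant (or a PCA step beforehand) is the more realistic choice.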

## Loading the best model

Our best $R^2$ score was obtained by fine-tuning a pretrained ResNeXt model in a data-driven way.

In [None]:
@MODEL_REGISTRY.register()
class ResNeXt(IModel):
    def __init__(self):
        super(ResNeXt, self).__init__()
        self.model = resnext101_32x8d(weights=ResNeXt101_32X8D_Weights.IMAGENET1K_V1)

base_model = make_model('ResNeXt', 0)
model = base_model.change_head('layer4', 168)

We can now load our saved weights:

In [None]:
weights = torch.load('best_model.pt', map_location='cpu')  # load on CPU regardless of where the weights were saved
model.load_state_dict(weights)

## Inference time!

Let's try our model on the validation data and check the $R^2$ score we obtain.

In [None]:
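train_data, val_data = get_data()
model.eval()  # disable dropout and use running batch-norm statistics
with torch.no_grad():
    out = model(torch.from_numpy(val_data[0]).float())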
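
One possible way to turn the predictions into an $R^2$ score is sketched below. This is a minimal example under assumptions: the `r2_per_neuron` helper is hypothetical, and random arrays of shape (stimuli, neurons) stand in for the recorded and predicted firing rates. The score is computed separately for each neuron and then averaged.

```python
import numpy as np

def r2_per_neuron(y_true, y_pred):
    """Coefficient of determination computed separately for each column (neuron)."""
    ss_res = np.sum((y_true - y_pred) ** 2, axis=0)
    ss_tot = np.sum((y_true - y_true.mean(axis=0)) ** 2, axis=0)
    return 1.0 - ss_res / ss_tot

# Synthetic stand-ins with a hypothetical shape (stimuli, neurons).
rng = np.random.default_rng(0)
y_true = rng.normal(size=(100, 168))
y_pred = y_true + 0.5 * rng.normal(size=(100, 168))

scores = r2_per_neuron(y_true, y_pred)
print("mean R^2 across neurons:", scores.mean())
```

In the notebook, `y_true` would correspond to `val_data[2]` and `y_pred` to the model output converted to a NumPy array (for example with `out.detach().cpu().numpy()`).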