<a href="https://colab.research.google.com/github/neuromatch/NeuroAI_Course/blob/main/tutorials/W1D3_ComparingArtificialAndBiologicalNetworks/student/W1D3_Tutorial5.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a> &nbsp; <a href="https://kaggle.com/kernels/welcome?src=https://raw.githubusercontent.com/neuromatch/NeuroAI_Course/main/tutorials/W1D3_ComparingArtificialAndBiologicalNetworks/student/W1D3_Tutorial5.ipynb" target="_parent"><img src="https://kaggle.com/static/images/open-in-kaggle.svg" alt="Open in Kaggle"/></a>

# Bonus (Tutorial 5): Dynamical Similarity Analysis (DSA)

**Week 1, Day 3: Comparing Artificial And Biological Networks**

**By Neuromatch Academy**

__Content creators:__ Mitchell Ostrow

__Content reviewers:__ Xaq Pitkow, Hlib Solodzhuk

__Production editors:__ Konstantine Tsafatinos, Ella Batty, Spiros Chavlis, Samuele Bolotta, Hlib Solodzhuk, Patrick Mineault


This short notebook expands the toolset for network comparison by looking at another important dimension of analysis: time. In particular, we want to understand how systems evolve over time and whether their dynamics are similar. The material is most closely related to [Tutorial 2](https://neuroai.neuromatch.io/tutorials/W1D3_ComparingArtificialAndBiologicalNetworks/student/W1D3_Tutorial2.html) for this day, and one of the projects on [Comparing Networks](https://neuroai.neuromatch.io/projects/project-notebooks/ComparingNetworks.html) is devoted to DSA.

We begin by looking at the representational geometry of recurrent neural networks, providing an extension to the analyses we saw in Tutorial 2. We then look closer at the details of DSA and how the analysis of model dynamics can be approached from another direction.

In [None]:
# @title Install and import feedback gadget

!pip install vibecheck --quiet

from vibecheck import DatatopsContentReviewContainer
def content_review(notebook_section: str):
    return DatatopsContentReviewContainer(
        "",  # No text prompt
        notebook_section,
        {
            "url": "https://pmyvdlilci.execute-api.us-east-1.amazonaws.com/klab",
            "name": "neuromatch_neuroai",
            "user_key": "wb2cxze8",
        },
    ).render()


feedback_prefix = "W1D3_Bonus"

In [None]:
# @title Bonus material slides

from IPython.display import IFrame
from ipywidgets import widgets
out = widgets.Output()

link_id = "8fx23"

with out:
    print(f"If you want to download the slides: https://osf.io/download/{link_id}/")
    display(IFrame(src=f"https://mfr.ca-1.osf.io/render?url=https://osf.io/{link_id}/?direct%26mode=render%26action=download%26mode=render", width=730, height=410))
display(out)

---
# Representational Geometry of Recurrent Models

Transformations of representations can occur across space and time, e.g., layers of a neural network and steps of recurrent computation.

Just as the layers in a feedforward DNN can change the representational geometry to perform a task, steps in a recurrent network can reuse the same layer to reach the same computational depth.
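
To make the idea of reusing a layer concrete, here is a minimal NumPy sketch. It is not the tutorial's model: the sizes, the random weights, and the names `ff_weights` and `rec_weight` are all hypothetical and chosen only for illustration. It contrasts a feedforward stack, which has one weight matrix per layer, with a recurrent computation that applies a single weight matrix at every step.

```python
import numpy as np

rng = np.random.default_rng(0)
n_units, n_steps = 8, 5  # hypothetical sizes, chosen only for illustration


def relu(x):
    return np.maximum(x, 0.0)


# Feedforward: a separate weight matrix for every layer.
ff_weights = [rng.normal(scale=0.3, size=(n_units, n_units)) for _ in range(n_steps)]
# Recurrent: a single weight matrix that is reused at every step.
rec_weight = rng.normal(scale=0.3, size=(n_units, n_units))

x = rng.normal(size=n_units)

# Five distinct layers applied in sequence.
h_ff = x
for W in ff_weights:
    h_ff = relu(W @ h_ff)

# The same layer applied five times; we keep every intermediate state.
h_rec = x
rec_states = []
for _ in range(n_steps):
    h_rec = relu(rec_weight @ h_rec)
    rec_states.append(h_rec)

n_params_ff = sum(W.size for W in ff_weights)
n_params_rec = rec_weight.size
print(f"feedforward parameters: {n_params_ff}, recurrent parameters: {n_params_rec}")
```

Both computations pass the input through five nonlinear transformations, but the recurrent version does so with one fifth of the weights in this toy setting.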

In this section, we look at a very simple recurrent network with only 2650 trainable parameters.

Here is a diagram of this network:

![Recurrent convolutional neural network](https://github.com/neuromatch/NeuroAI_Course/blob/main/tutorials/W1D3_ComparingArtificialAndBiologicalNetworks/static/rcnn_tutorial.png?raw=true)

In [None]:
# @title Helper functions

# TODO: 

In [None]:
# @title Grab a recurrent model

args = build_args()
train_loader, test_loader = fetch_dataloaders(args)
path = "recurrent_model.pth"
model_recurrent = torch.load(path, map_location=args.device)

<br>We can first look at the computational steps in this network. As we see below, the `conv2` operation is repeated 5 times.

In [None]:
train_nodes, _ = get_graph_node_names(model_recurrent)
print('The computational steps in the network are: \n', train_nodes)

Plotting the RDMs after each application of the `conv2` operation shows the same progressive emergence of the blockwise structure around the diagonal, mediating the correct classification in this task.
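
As a reminder of what an RDM contains, the sketch below builds one from synthetic features. The implementation of `calc_rdms` is not shown in this notebook, so the dissimilarity measure used here (1 minus the Pearson correlation) is an assumption, as are the helper name `correlation_rdm` and the toy data.

```python
import numpy as np


def correlation_rdm(features):
    """Return an (n_stimuli, n_stimuli) RDM using 1 - Pearson correlation.

    Assumption: the tutorial's `calc_rdms` may use a different measure.
    """
    flat = features.reshape(features.shape[0], -1)  # one row per stimulus
    return 1.0 - np.corrcoef(flat)


rng = np.random.default_rng(0)
# Synthetic stand-in for layer activations: 3 classes x 4 stimuli,
# each stimulus a noisy copy of its class prototype.
prototypes = rng.normal(size=(3, 16))
labels = np.repeat(np.arange(3), 4)
features = prototypes[labels] + 0.1 * rng.normal(size=(12, 16))

rdm = correlation_rdm(features)
same_class = labels[:, None] == labels[None, :]
print("mean within-class dissimilarity:", rdm[same_class].mean())
print("mean between-class dissimilarity:", rdm[~same_class].mean())
```

Because the stimuli are sorted by class, low within-class dissimilarity shows up as blocks along the diagonal of the RDM.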

In [None]:
imgs, labels = sample_images(test_loader, n=20)
return_layers = ['conv2', 'conv2_1', 'conv2_2', 'conv2_3', 'conv2_4']
model_features = extract_features(model_recurrent, imgs.to(device), return_layers)

rdms, rdms_dict = calc_rdms(model_features)
plot_rdms(rdms_dict)

We can also look at how the different dimensionality reduction techniques capture the dynamics of changing geometry.
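
For intuition about what one of these techniques does, here is a minimal sketch of PCA computed via the singular value decomposition. The helper `pca_project` and the random feature array are hypothetical; `plot_dim_reduction` may compute its embeddings differently.

```python
import numpy as np


def pca_project(features, n_components=2):
    """Project flattened features onto their top principal components."""
    flat = features.reshape(features.shape[0], -1)
    centered = flat - flat.mean(axis=0, keepdims=True)
    # Rows of vt are principal directions, ordered by decreasing variance.
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:n_components].T


rng = np.random.default_rng(0)
# Hypothetical activations with shape (stimuli, channels, height, width).
features = rng.normal(size=(50, 4, 3, 3))
embedding = pca_project(features)
print(embedding.shape)
```

Applying such a projection to the features from each recurrent step gives one 2D scatter per step, which is the kind of view plotted below.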

In [None]:
return_layers = ['conv2', 'conv2_1', 'conv2_2', 'conv2_3', 'conv2_4']

imgs, labels = sample_images(test_loader, n=50)  # grab 50 samples from the test set
model_features = extract_features(model_recurrent, imgs.to(device), return_layers)

plot_dim_reduction(model_features, labels, transformer_funcs=['PCA', 'MDS', 't-SNE'])

## Representational geometry paths for recurrent models

We can look at the model's recurrent computational steps as a path in the representational geometry space.
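
One way to read "path" here: compute an RDM after every recurrent step, then measure how far apart consecutive RDMs are. The sketch below does this for a toy recurrent map. It is an illustration under stated assumptions, not the implementation of `rep_path`, which is not shown in this notebook: the shared weights `W`, the `tanh` nonlinearity, and the correlation-based distances are all hypothetical choices.

```python
import numpy as np


def rdm_upper(features):
    """Vectorized upper triangle of a 1 - Pearson correlation RDM."""
    flat = features.reshape(features.shape[0], -1)
    rdm = 1.0 - np.corrcoef(flat)
    iu = np.triu_indices(rdm.shape[0], k=1)
    return rdm[iu]


rng = np.random.default_rng(0)
n_stimuli, n_units, n_steps = 20, 10, 5  # hypothetical sizes
W = rng.normal(scale=0.5, size=(n_units, n_units))  # hypothetical shared weights
h = rng.normal(size=(n_stimuli, n_units))

# One RDM per recurrent step, each stored as a vector.
step_rdms = []
for _ in range(n_steps):
    h = np.tanh(h @ W.T)
    step_rdms.append(rdm_upper(h))
step_rdms = np.stack(step_rdms)

# Distance between two geometries: 1 - correlation between their RDM vectors.
geometry_dist = 1.0 - np.corrcoef(step_rdms)
# Entries just above the diagonal are the lengths of consecutive path segments.
step_lengths = np.diagonal(geometry_dist, offset=1)
print("distance between consecutive steps:", np.round(step_lengths, 3))
```

A matrix like `geometry_dist` can then be embedded in two dimensions (for example with MDS) so that each step becomes a point and the sequence of steps becomes a visible path.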

In [None]:
imgs, labels = sample_images(test_loader, n=50)  # grab 50 samples from the test set
model_features_recurrent = extract_features(model_recurrent, imgs.to(device), return_layers='all')

#rdms, rdms_dict = calc_rdms(model_features)
features = {'recurrent model': model_features_recurrent}
model_colors = {'recurrent model': 'y'}

rep_path(features, model_colors, labels)

We can also look at the paths taken by the feedforward and the recurrent models and compare them.

In [None]:
imgs, labels = sample_images(test_loader, n=50)  # grab 50 samples from the test set
model_features = extract_features(model, imgs.to(device), return_layers='all')
model_features_recurrent = extract_features(model_recurrent, imgs.to(device), return_layers='all')

features = {'feedforward model': model_features, 'recurrent model': model_features_recurrent}
model_colors = {'feedforward model': 'b', 'recurrent model': 'y'}

rep_path(features, model_colors, labels)

Here we see that different models can take very different paths through the space of representational geometries to map images to labels. This is because many different functional mappings can solve the same classification task.

In [None]:
# @title Submit your feedback
content_review(f"{feedback_prefix}_recurrent_models")

We are now going to introduce an exciting new method for comparing dynamical systems: Dynamical Similarity Analysis (DSA)! Please watch the video below and follow the example code in the following subsection to reproduce some of the figures you will see in the explanatory video. Enjoy!

In [None]:
# @title Video 1: Dynamical Similarity Analysis

from ipywidgets import widgets
from IPython.display import YouTubeVideo
from IPython.display import IFrame
from IPython.display import display

class PlayVideo(IFrame):
  def __init__(self, id, source, page=1, width=400, height=300, **kwargs):
    self.id = id
    if source == 'Bilibili':
      src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'
    elif source == 'Osf':
      src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'
    else:
      # Fail early instead of hitting an unbound `src` below
      raise ValueError(f'Unsupported video source: {source}')
    super(PlayVideo, self).__init__(src, width, height, **kwargs)

def display_videos(video_ids, W=400, H=300, fs=1):
  tab_contents = []
  for i, video_id in enumerate(video_ids):
    out = widgets.Output()
    with out:
      if video_ids[i][0] == 'Youtube':
        video = YouTubeVideo(id=video_ids[i][1], width=W,
                             height=H, fs=fs, rel=0)
        print(f'Video available at https://youtube.com/watch?v={video.id}')
      else:
        video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,
                          height=H, fs=fs, autoplay=False)
        if video_ids[i][0] == 'Bilibili':
          print(f'Video available at https://www.bilibili.com/video/{video.id}')
        elif video_ids[i][0] == 'Osf':
          print(f'Video available at https://osf.io/{video.id}')
      display(video)
    tab_contents.append(out)
  return tab_contents

video_ids = [('Youtube', 'FHikIsQFQvM'), ('Bilibili', 'BV1qm421g7hV')]
tab_contents = display_videos(video_ids, W=854, H=480)
tabs = widgets.Tab()
tabs.children = tab_contents
for i in range(len(tab_contents)):
  tabs.set_title(i, video_ids[i][0])
display(tabs)

In [None]:
# @title Submit your feedback
content_review(f"{feedback_prefix}_DSA_video")