# W2D5Tutorial1

**Week 2, Day 5: Mysteries**

**By Neuromatch Academy**

__Content creators:__ Names & Surnames

__Content reviewers:__ Names & Surnames

__Production editors:__ Names & Surnames

<br>

Acknowledgments: [ACKNOWLEDGMENT_INFORMATION]


___


# Tutorial Objectives

*Estimated timing of tutorial: [insert estimated duration of whole tutorial in minutes]*

In this tutorial, you will observe how a model's performance degrades as the distribution of the test data drifts away from the distribution it was trained on.


In [1]:
# @title Tutorial slides

# @markdown These are the slides for the videos in all tutorials today


## Uncomment the code below to display the slides

#from IPython.display import IFrame
#link_id = "<YOUR_LINK_ID_HERE>"

print("If you want to download the slides: 'Link to the slides'")
# Example: https://osf.io/download/{link_id}/

#IFrame(src=f"https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{link_id}/?direct%26mode=render", width=854, height=480)

If you want to download the slides: 'Link to the slides'


---
# Setup



In [2]:
# @title Install dependencies
# @markdown

!pip install numpy matplotlib Pillow torch torchvision transformers ipywidgets gradio trdg scikit-learn networkx pickleshare tqdm requests




In [3]:
# @title Import dependencies
# @markdown

# Standard libraries for basic operations and file handling
import random
import pickleshare
from tqdm import tqdm

# Image processing libraries for handling and manipulating image data
from PIL import Image
import matplotlib.pyplot as plt
import logging

# Deep learning libraries for model building, training, and evaluation
import torch
import torch.nn as nn
import torch.nn.functional as F

# Utility libraries for creating interactive elements and interfaces
import ipywidgets as widgets
import gradio as gr
from IPython.display import IFrame

# Libraries for graph analysis
import networkx as nx

In [4]:
# @title Figure settings
# @markdown

logging.getLogger('matplotlib.font_manager').disabled = True

%matplotlib inline
%config InlineBackend.figure_format = 'retina' # perform high-definition rendering for images and plots
plt.style.use("https://raw.githubusercontent.com/NeuromatchAcademy/course-content/main/nma.mplstyle")

# Section 1: Recurrent Independent Mechanisms

The key idea behind this section is that a central goal of machine learning is to capture the modular structure of the physical world, where complex behavior emerges from simpler subsystems that evolve largely independently. This view is closely tied to causal inference: understanding and modeling the world means identifying these autonomous mechanisms and how they combine. Because the mechanisms interact only sparsely, each one keeps working even when others change, which makes the system as a whole robust. Recurrent Independent Mechanisms (RIMs) build this principle into a recurrent network: the hidden state is split into several units that update mostly independently, and attention decides which units read the current input and lets the units exchange information only occasionally. The hypothesis is that models which capture independent, sparsely interacting mechanisms adapt and generalize better, for example to test data drawn from a different distribution than the training data.
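To make the "mostly independent, occasionally interacting" idea more concrete, here is a minimal sketch of how RIM-style input attention can select which units are active at a time step: each unit forms a query from its hidden state, attends over the real input and a null (all-zero) input, and only the `k` units that put the most weight on the real input are activated. This is a simplified, single-head illustration under those assumptions; the class name `TopKInputAttention` and all dimensions are hypothetical and are not taken from the `RIMs-Sequential-MNIST` repository.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TopKInputAttention(nn.Module):
    """Hypothetical, simplified RIM input attention with top-k unit selection."""

    def __init__(self, hidden_size, input_size, key_size, value_size, k):
        super().__init__()
        self.k = k
        self.query = nn.Linear(hidden_size, key_size)  # one query per unit, from its hidden state
        self.key = nn.Linear(input_size, key_size)
        self.value = nn.Linear(input_size, value_size)

    def forward(self, x, h):
        # x: (batch, input_size) current input; h: (batch, num_units, hidden_size)
        null = torch.zeros_like(x)
        inputs = torch.stack([null, x], dim=1)                      # (batch, 2, input_size)
        q = self.query(h)                                           # (batch, units, key_size)
        keys = self.key(inputs)                                     # (batch, 2, key_size)
        vals = self.value(inputs)                                   # (batch, 2, value_size)
        scores = q @ keys.transpose(1, 2) / keys.size(-1) ** 0.5    # (batch, units, 2)
        attn = F.softmax(scores, dim=-1)
        # Activate the k units that attend most strongly to the real input (index 1)
        topk = attn[:, :, 1].topk(self.k, dim=1).indices            # (batch, k)
        mask = torch.zeros_like(attn[:, :, 1]).scatter(1, topk, 1.0)  # (batch, units), 1 = active
        read = attn @ vals                                          # (batch, units, value_size)
        return read, mask


torch.manual_seed(0)
layer = TopKInputAttention(hidden_size=32, input_size=1, key_size=16, value_size=8, k=4)
read, mask = layer(torch.randn(3, 1), torch.randn(3, 6, 32))
print(read.shape, mask.sum(dim=1))  # torch.Size([3, 6, 8]) tensor([4., 4., 4.])
```

With 6 units and `k=4` (the `num_units` and `k` values used in the config below), four units per sample are marked active. In a full RIM, only the active units would then update their recurrent state and exchange information through a second, communication attention step, which this sketch omits.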

In [5]:
## Cloning may take around five minutes, because the repository contains a large PyTorch model checkpoint

# URL of the repository to clone
!git clone https://github.com/SamueleBolotta/RIMs-Sequential-MNIST/


fatal: destination path 'RIMs-Sequential-MNIST' already exists and is not an empty directory.


In [6]:
%cd RIMs-Sequential-MNIST

/home/samuele/Documenti/GitHub/NeuroAI_Course/tutorials/W2D5_Mysteries/RIMs-Sequential-MNIST


In [7]:
import os
import requests

from data import MnistData
from networks import MnistModel, LSTM

# Download a file, creating its parent directory if needed
def download_file(url, destination):
    print(f"Starting to download {url} to {destination}")
    os.makedirs(os.path.dirname(destination) or ".", exist_ok=True)
    response = requests.get(url, allow_redirects=True)
    response.raise_for_status()  # Fail loudly instead of saving an HTML error page as a model
    with open(destination, 'wb') as f:
        f.write(response.content)
    print(f"Successfully downloaded {url} to {destination}")

# Local paths of the pretrained checkpoints
model_path = {
    'LSTM': 'lstm_model_dir/lstm_best_model.pt',
    'RIM': 'rim_model_dir/best_model.pt'
}

# OSF URLs of the pretrained checkpoints
model_urls = {
    'LSTM': 'https://osf.io/4gajq/download',
    'RIM': 'https://osf.io/3squn/download'
}

# Download each checkpoint only if it is not already on disk
for model_key, model_url in model_urls.items():
    if not os.path.exists(model_path[model_key]):
        download_file(model_url, model_path[model_key])
        print(f"{model_key} model downloaded.")
    else:
        print(f"{model_key} model already exists. No download needed.")

LSTM model already exists. No download needed.
RIM model already exists. No download needed.


# RIMs
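The evaluation cell below counts a prediction as correct when the arg-max of the model's 10 output scores matches the label, and pools these counts over all batches. A minimal, print-free version of that computation is sketched here, assuming a model that maps a `(batch, seq_len, 1)` tensor to `(batch, 10)` scores and a hypothetical `batches` iterable of `(x, y)` pairs; the function name `accuracy_over_batches` is illustrative and not part of the repository.

```python
import torch


def accuracy_over_batches(model, batches, device="cpu"):
    """Pooled accuracy (in %) of `model` over an iterable of (x, y) batches."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in batches:
            x = torch.as_tensor(x, dtype=torch.float32, device=device)
            y = torch.as_tensor(y, device=device).long()
            preds = model(x).argmax(dim=1)  # index of the highest score for each sample
            correct += (preds == y).sum().item()
            total += y.numel()
    return 100.0 * correct / total
```

Pooling correct counts rather than averaging per-batch accuracies gives every sample the same weight, so a smaller final batch (272 samples in the runs below) is not over-weighted.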

In [8]:
# Config
config = {
    'cuda': True,
    'epochs': 200,
    'batch_size': 64,
    'hidden_size': 600,
    'input_size': 1,
    'model': 'RIM', # 'RIM' for MnistModel; the LSTM baseline below uses 'LSTM'
    'train': False, # Set to False to load the saved model
    'num_units': 6,
    'rnn_cell': 'LSTM',
    'key_size_input': 64,
    'value_size_input': 400,
    'query_size_input': 64,
    'num_input_heads': 1,
    'num_comm_heads': 4,
    'input_dropout': 0.1,
    'comm_dropout': 0.1,
    'key_size_comm': 32,
    'value_size_comm': 100,
    'query_size_comm': 32,
    'k': 4,
    'size': 14,
    'loadsaved': 1, # Ensure this is 1 to load saved model
    'log_dir': 'rim_model_dir'
}

# Choose the model
model = MnistModel(config)  # Instantiating MnistModel (RIM) with config
model_directory = model_path['RIM']

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

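# Load the pretrained weights (map_location lets a GPU-saved checkpoint load on a CPU-only machine)
saved = torch.load(model_directory, map_location=device)
model.load_state_dict(saved['net'])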

# Data
data = MnistData(config['batch_size'], (config['size'], config['size']), config['k'])

# Evaluation function
def test_model(model, loader, func):
    total_correct = 0
    total_samples = 0
    model.eval()
    
    print(f"Total validation samples: {loader.val_len()}")  # val_len() is the number of validation batches (20 here), not individual samples

    with torch.no_grad():
        for i in tqdm(range(loader.val_len())):
            test_x, test_y = func(i)
            # Print the original shapes of test_x and test_y
            print(f"Original test_x shape: {test_x.shape}, test_y shape: {test_y.shape}")
            # Adjust the device for test_x and test_y
            test_x = model.to_device(test_x)
            test_y = model.to_device(test_y).long()
            # Print the device-adjusted shapes of test_x and test_y
            print(f"Device-adjusted test_x shape: {test_x.shape}, test_y shape: {test_y.shape}")
            # Get the model output
            probs = model(test_x)
            # Print the shape of the model output
            print(f"Model output (probs) shape: {probs.shape}")

            preds = torch.argmax(probs, dim=1)
            correct = (preds == test_y).sum().item()
            total_correct += correct
            total_samples += test_y.size(0)  # Update to count total samples processed

            # Calculate and print batch accuracy along with correct predictions out of total
            batch_accuracy = correct / test_y.size(0) * 100
            print(f"Batch {i}: {correct}/{test_y.size(0)} Correct Predictions, Batch Accuracy: {batch_accuracy:.2f}%")

    # Calculate and print overall accuracy along with total correct predictions out of total samples
    overall_accuracy = total_correct / total_samples * 100
    print(f"Overall: {total_correct}/{total_samples} Correct Predictions, Overall Accuracy: {overall_accuracy:.2f}%")
    return overall_accuracy
    
# Evaluate on all three validation sets
validation_functions = [data.val_get1, data.val_get2, data.val_get3]
validation_accuracies = []

print(f"Model: {config['model']}, Device: {device}")
print(f"Configuration: {config}")

for func in validation_functions:
    accuracy = test_model(model, data, func)
    validation_accuracies.append(accuracy)

# Print accuracies for all validation sets
for i, accuracy in enumerate(validation_accuracies, 1):
    print(f'Validation Set {i} Accuracy: {accuracy:.2f}%')

Model: RIM, Device: cuda
Configuration: {'cuda': True, 'epochs': 200, 'batch_size': 64, 'hidden_size': 600, 'input_size': 1, 'model': 'RIM', 'train': False, 'num_units': 6, 'rnn_cell': 'LSTM', 'key_size_input': 64, 'value_size_input': 400, 'query_size_input': 64, 'num_input_heads': 1, 'num_comm_heads': 4, 'input_dropout': 0.1, 'comm_dropout': 0.1, 'key_size_comm': 32, 'value_size_comm': 100, 'query_size_comm': 32, 'k': 4, 'size': 14, 'loadsaved': 1, 'log_dir': 'rim_model_dir'}
Total validation samples: 20


  0%|                                                    | 0/20 [00:00<?, ?it/s]

Original test_x shape: (512, 576, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 576, 1]), test_y shape: torch.Size([512])


  5%|██▏                                         | 1/20 [00:02<00:49,  2.60s/it]

Model output (probs) shape: torch.Size([512, 10])
Batch 0: 63/512 Correct Predictions, Batch Accuracy: 12.30%
Original test_x shape: (512, 576, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 576, 1]), test_y shape: torch.Size([512])


 10%|████▍                                       | 2/20 [00:05<00:44,  2.50s/it]

Model output (probs) shape: torch.Size([512, 10])
Batch 1: 84/512 Correct Predictions, Batch Accuracy: 16.41%
Original test_x shape: (512, 576, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 576, 1]), test_y shape: torch.Size([512])


 15%|██████▌                                     | 3/20 [00:07<00:42,  2.48s/it]

Model output (probs) shape: torch.Size([512, 10])
Batch 2: 81/512 Correct Predictions, Batch Accuracy: 15.82%
Original test_x shape: (512, 576, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 576, 1]), test_y shape: torch.Size([512])


 20%|████████▊                                   | 4/20 [00:09<00:39,  2.44s/it]

Model output (probs) shape: torch.Size([512, 10])
Batch 3: 75/512 Correct Predictions, Batch Accuracy: 14.65%
Original test_x shape: (512, 576, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 576, 1]), test_y shape: torch.Size([512])


 25%|███████████                                 | 5/20 [00:12<00:36,  2.40s/it]

Model output (probs) shape: torch.Size([512, 10])
Batch 4: 77/512 Correct Predictions, Batch Accuracy: 15.04%
Original test_x shape: (512, 576, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 576, 1]), test_y shape: torch.Size([512])


 30%|█████████████▏                              | 6/20 [00:14<00:33,  2.38s/it]

Model output (probs) shape: torch.Size([512, 10])
Batch 5: 75/512 Correct Predictions, Batch Accuracy: 14.65%
Original test_x shape: (512, 576, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 576, 1]), test_y shape: torch.Size([512])


 35%|███████████████▍                            | 7/20 [00:16<00:30,  2.36s/it]

Model output (probs) shape: torch.Size([512, 10])
Batch 6: 65/512 Correct Predictions, Batch Accuracy: 12.70%
Original test_x shape: (512, 576, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 576, 1]), test_y shape: torch.Size([512])


 40%|█████████████████▌                          | 8/20 [00:19<00:28,  2.39s/it]

Model output (probs) shape: torch.Size([512, 10])
Batch 7: 73/512 Correct Predictions, Batch Accuracy: 14.26%
Original test_x shape: (512, 576, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 576, 1]), test_y shape: torch.Size([512])


 45%|███████████████████▊                        | 9/20 [00:21<00:26,  2.40s/it]

Model output (probs) shape: torch.Size([512, 10])
Batch 8: 75/512 Correct Predictions, Batch Accuracy: 14.65%
Original test_x shape: (512, 576, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 576, 1]), test_y shape: torch.Size([512])


 50%|█████████████████████▌                     | 10/20 [00:24<00:24,  2.41s/it]

Model output (probs) shape: torch.Size([512, 10])
Batch 9: 80/512 Correct Predictions, Batch Accuracy: 15.62%
Original test_x shape: (512, 576, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 576, 1]), test_y shape: torch.Size([512])


 55%|███████████████████████▋                   | 11/20 [00:26<00:21,  2.41s/it]

Model output (probs) shape: torch.Size([512, 10])
Batch 10: 69/512 Correct Predictions, Batch Accuracy: 13.48%
Original test_x shape: (512, 576, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 576, 1]), test_y shape: torch.Size([512])


 60%|█████████████████████████▊                 | 12/20 [00:28<00:19,  2.39s/it]

Model output (probs) shape: torch.Size([512, 10])
Batch 11: 69/512 Correct Predictions, Batch Accuracy: 13.48%
Original test_x shape: (512, 576, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 576, 1]), test_y shape: torch.Size([512])


 65%|███████████████████████████▉               | 13/20 [00:31<00:16,  2.41s/it]

Model output (probs) shape: torch.Size([512, 10])
Batch 12: 63/512 Correct Predictions, Batch Accuracy: 12.30%
Original test_x shape: (512, 576, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 576, 1]), test_y shape: torch.Size([512])


 70%|██████████████████████████████             | 14/20 [00:33<00:14,  2.45s/it]

Model output (probs) shape: torch.Size([512, 10])
Batch 13: 79/512 Correct Predictions, Batch Accuracy: 15.43%
Original test_x shape: (512, 576, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 576, 1]), test_y shape: torch.Size([512])


 75%|████████████████████████████████▎          | 15/20 [00:36<00:12,  2.47s/it]

Model output (probs) shape: torch.Size([512, 10])
Batch 14: 68/512 Correct Predictions, Batch Accuracy: 13.28%
Original test_x shape: (512, 576, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 576, 1]), test_y shape: torch.Size([512])


 80%|██████████████████████████████████▍        | 16/20 [00:39<00:10,  2.50s/it]

Model output (probs) shape: torch.Size([512, 10])
Batch 15: 76/512 Correct Predictions, Batch Accuracy: 14.84%
Original test_x shape: (512, 576, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 576, 1]), test_y shape: torch.Size([512])


 85%|████████████████████████████████████▌      | 17/20 [00:41<00:07,  2.51s/it]

Model output (probs) shape: torch.Size([512, 10])
Batch 16: 62/512 Correct Predictions, Batch Accuracy: 12.11%
Original test_x shape: (512, 576, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 576, 1]), test_y shape: torch.Size([512])


 90%|██████████████████████████████████████▋    | 18/20 [00:44<00:05,  2.53s/it]

Model output (probs) shape: torch.Size([512, 10])
Batch 17: 68/512 Correct Predictions, Batch Accuracy: 13.28%
Original test_x shape: (512, 576, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 576, 1]), test_y shape: torch.Size([512])


 95%|████████████████████████████████████████▊  | 19/20 [00:46<00:02,  2.51s/it]

Model output (probs) shape: torch.Size([512, 10])
Batch 18: 70/512 Correct Predictions, Batch Accuracy: 13.67%
Original test_x shape: (272, 576, 1), test_y shape: (272,)
Device-adjusted test_x shape: torch.Size([272, 576, 1]), test_y shape: torch.Size([272])


100%|███████████████████████████████████████████| 20/20 [00:48<00:00,  2.42s/it]


Model output (probs) shape: torch.Size([272, 10])
Batch 19: 41/272 Correct Predictions, Batch Accuracy: 15.07%
Overall: 1413/10000 Correct Predictions, Overall Accuracy: 14.13%
Total validation samples: 20


  0%|                                                    | 0/20 [00:00<?, ?it/s]

Original test_x shape: (512, 361, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 361, 1]), test_y shape: torch.Size([512])


  5%|██▏                                         | 1/20 [00:01<00:29,  1.54s/it]

Model output (probs) shape: torch.Size([512, 10])
Batch 0: 80/512 Correct Predictions, Batch Accuracy: 15.62%
Original test_x shape: (512, 361, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 361, 1]), test_y shape: torch.Size([512])


 10%|████▍                                       | 2/20 [00:03<00:27,  1.53s/it]

Model output (probs) shape: torch.Size([512, 10])
Batch 1: 85/512 Correct Predictions, Batch Accuracy: 16.60%
Original test_x shape: (512, 361, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 361, 1]), test_y shape: torch.Size([512])


 15%|██████▌                                     | 3/20 [00:04<00:25,  1.51s/it]

Model output (probs) shape: torch.Size([512, 10])
Batch 2: 79/512 Correct Predictions, Batch Accuracy: 15.43%
Original test_x shape: (512, 361, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 361, 1]), test_y shape: torch.Size([512])


 20%|████████▊                                   | 4/20 [00:06<00:24,  1.53s/it]

Model output (probs) shape: torch.Size([512, 10])
Batch 3: 92/512 Correct Predictions, Batch Accuracy: 17.97%
Original test_x shape: (512, 361, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 361, 1]), test_y shape: torch.Size([512])


 25%|███████████                                 | 5/20 [00:07<00:22,  1.52s/it]

Model output (probs) shape: torch.Size([512, 10])
Batch 4: 87/512 Correct Predictions, Batch Accuracy: 16.99%
Original test_x shape: (512, 361, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 361, 1]), test_y shape: torch.Size([512])


 30%|█████████████▏                              | 6/20 [00:09<00:21,  1.52s/it]

Model output (probs) shape: torch.Size([512, 10])
Batch 5: 80/512 Correct Predictions, Batch Accuracy: 15.62%
Original test_x shape: (512, 361, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 361, 1]), test_y shape: torch.Size([512])


 35%|███████████████▍                            | 7/20 [00:10<00:19,  1.51s/it]

Model output (probs) shape: torch.Size([512, 10])
Batch 6: 71/512 Correct Predictions, Batch Accuracy: 13.87%
Original test_x shape: (512, 361, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 361, 1]), test_y shape: torch.Size([512])


 40%|█████████████████▌                          | 8/20 [00:12<00:18,  1.52s/it]

Model output (probs) shape: torch.Size([512, 10])
Batch 7: 81/512 Correct Predictions, Batch Accuracy: 15.82%
Original test_x shape: (512, 361, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 361, 1]), test_y shape: torch.Size([512])


 45%|███████████████████▊                        | 9/20 [00:13<00:16,  1.51s/it]

Model output (probs) shape: torch.Size([512, 10])
Batch 8: 103/512 Correct Predictions, Batch Accuracy: 20.12%
Original test_x shape: (512, 361, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 361, 1]), test_y shape: torch.Size([512])


 50%|█████████████████████▌                     | 10/20 [00:15<00:15,  1.51s/it]

Model output (probs) shape: torch.Size([512, 10])
Batch 9: 97/512 Correct Predictions, Batch Accuracy: 18.95%
Original test_x shape: (512, 361, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 361, 1]), test_y shape: torch.Size([512])


 55%|███████████████████████▋                   | 11/20 [00:16<00:13,  1.52s/it]

Model output (probs) shape: torch.Size([512, 10])
Batch 10: 69/512 Correct Predictions, Batch Accuracy: 13.48%
Original test_x shape: (512, 361, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 361, 1]), test_y shape: torch.Size([512])


 60%|█████████████████████████▊                 | 12/20 [00:18<00:12,  1.52s/it]

Model output (probs) shape: torch.Size([512, 10])
Batch 11: 83/512 Correct Predictions, Batch Accuracy: 16.21%
Original test_x shape: (512, 361, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 361, 1]), test_y shape: torch.Size([512])


 65%|███████████████████████████▉               | 13/20 [00:19<00:10,  1.52s/it]

Model output (probs) shape: torch.Size([512, 10])
Batch 12: 80/512 Correct Predictions, Batch Accuracy: 15.62%
Original test_x shape: (512, 361, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 361, 1]), test_y shape: torch.Size([512])


 70%|██████████████████████████████             | 14/20 [00:21<00:09,  1.51s/it]

Model output (probs) shape: torch.Size([512, 10])
Batch 13: 94/512 Correct Predictions, Batch Accuracy: 18.36%
Original test_x shape: (512, 361, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 361, 1]), test_y shape: torch.Size([512])


 75%|████████████████████████████████▎          | 15/20 [00:22<00:07,  1.51s/it]

Model output (probs) shape: torch.Size([512, 10])
Batch 14: 82/512 Correct Predictions, Batch Accuracy: 16.02%
Original test_x shape: (512, 361, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 361, 1]), test_y shape: torch.Size([512])


 80%|██████████████████████████████████▍        | 16/20 [00:24<00:06,  1.51s/it]

Model output (probs) shape: torch.Size([512, 10])
Batch 15: 80/512 Correct Predictions, Batch Accuracy: 15.62%
Original test_x shape: (512, 361, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 361, 1]), test_y shape: torch.Size([512])


 85%|████████████████████████████████████▌      | 17/20 [00:25<00:04,  1.51s/it]

Model output (probs) shape: torch.Size([512, 10])
Batch 16: 78/512 Correct Predictions, Batch Accuracy: 15.23%
Original test_x shape: (512, 361, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 361, 1]), test_y shape: torch.Size([512])


 90%|██████████████████████████████████████▋    | 18/20 [00:27<00:03,  1.51s/it]

Model output (probs) shape: torch.Size([512, 10])
Batch 17: 92/512 Correct Predictions, Batch Accuracy: 17.97%
Original test_x shape: (512, 361, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 361, 1]), test_y shape: torch.Size([512])


 95%|████████████████████████████████████████▊  | 19/20 [00:28<00:01,  1.51s/it]

Model output (probs) shape: torch.Size([512, 10])
Batch 18: 89/512 Correct Predictions, Batch Accuracy: 17.38%
Original test_x shape: (272, 361, 1), test_y shape: (272,)
Device-adjusted test_x shape: torch.Size([272, 361, 1]), test_y shape: torch.Size([272])


100%|███████████████████████████████████████████| 20/20 [00:29<00:00,  1.49s/it]


Model output (probs) shape: torch.Size([272, 10])
Batch 19: 42/272 Correct Predictions, Batch Accuracy: 15.44%
Overall: 1644/10000 Correct Predictions, Overall Accuracy: 16.44%
Total validation samples: 20


  0%|                                                    | 0/20 [00:00<?, ?it/s]

Original test_x shape: (512, 256, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 256, 1]), test_y shape: torch.Size([512])


  5%|██▏                                         | 1/20 [00:01<00:20,  1.08s/it]

Model output (probs) shape: torch.Size([512, 10])
Batch 0: 76/512 Correct Predictions, Batch Accuracy: 14.84%
Original test_x shape: (512, 256, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 256, 1]), test_y shape: torch.Size([512])


 10%|████▍                                       | 2/20 [00:02<00:19,  1.08s/it]

Model output (probs) shape: torch.Size([512, 10])
Batch 1: 94/512 Correct Predictions, Batch Accuracy: 18.36%
Original test_x shape: (512, 256, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 256, 1]), test_y shape: torch.Size([512])


 15%|██████▌                                     | 3/20 [00:03<00:18,  1.08s/it]

Model output (probs) shape: torch.Size([512, 10])
Batch 2: 85/512 Correct Predictions, Batch Accuracy: 16.60%
Original test_x shape: (512, 256, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 256, 1]), test_y shape: torch.Size([512])


 20%|████████▊                                   | 4/20 [00:04<00:17,  1.08s/it]

Model output (probs) shape: torch.Size([512, 10])
Batch 3: 94/512 Correct Predictions, Batch Accuracy: 18.36%
Original test_x shape: (512, 256, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 256, 1]), test_y shape: torch.Size([512])


 25%|███████████                                 | 5/20 [00:05<00:16,  1.08s/it]

Model output (probs) shape: torch.Size([512, 10])
Batch 4: 77/512 Correct Predictions, Batch Accuracy: 15.04%
Original test_x shape: (512, 256, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 256, 1]), test_y shape: torch.Size([512])


 30%|█████████████▏                              | 6/20 [00:06<00:15,  1.08s/it]

Model output (probs) shape: torch.Size([512, 10])
Batch 5: 80/512 Correct Predictions, Batch Accuracy: 15.62%
Original test_x shape: (512, 256, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 256, 1]), test_y shape: torch.Size([512])


 35%|███████████████▍                            | 7/20 [00:07<00:14,  1.08s/it]

Model output (probs) shape: torch.Size([512, 10])
Batch 6: 77/512 Correct Predictions, Batch Accuracy: 15.04%
Original test_x shape: (512, 256, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 256, 1]), test_y shape: torch.Size([512])


 40%|█████████████████▌                          | 8/20 [00:08<00:12,  1.08s/it]

Model output (probs) shape: torch.Size([512, 10])
Batch 7: 89/512 Correct Predictions, Batch Accuracy: 17.38%
Original test_x shape: (512, 256, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 256, 1]), test_y shape: torch.Size([512])


 45%|███████████████████▊                        | 9/20 [00:09<00:11,  1.08s/it]

Model output (probs) shape: torch.Size([512, 10])
Batch 8: 94/512 Correct Predictions, Batch Accuracy: 18.36%
Original test_x shape: (512, 256, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 256, 1]), test_y shape: torch.Size([512])


 50%|█████████████████████▌                     | 10/20 [00:10<00:10,  1.10s/it]

Model output (probs) shape: torch.Size([512, 10])
Batch 9: 99/512 Correct Predictions, Batch Accuracy: 19.34%
Original test_x shape: (512, 256, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 256, 1]), test_y shape: torch.Size([512])


 55%|███████████████████████▋                   | 11/20 [00:11<00:09,  1.10s/it]

Model output (probs) shape: torch.Size([512, 10])
Batch 10: 96/512 Correct Predictions, Batch Accuracy: 18.75%
Original test_x shape: (512, 256, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 256, 1]), test_y shape: torch.Size([512])


 60%|█████████████████████████▊                 | 12/20 [00:13<00:08,  1.11s/it]

Model output (probs) shape: torch.Size([512, 10])
Batch 11: 105/512 Correct Predictions, Batch Accuracy: 20.51%
Original test_x shape: (512, 256, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 256, 1]), test_y shape: torch.Size([512])


 65%|███████████████████████████▉               | 13/20 [00:14<00:07,  1.11s/it]

Model output (probs) shape: torch.Size([512, 10])
Batch 12: 98/512 Correct Predictions, Batch Accuracy: 19.14%
Original test_x shape: (512, 256, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 256, 1]), test_y shape: torch.Size([512])


 70%|██████████████████████████████             | 14/20 [00:15<00:06,  1.11s/it]

Model output (probs) shape: torch.Size([512, 10])
Batch 13: 98/512 Correct Predictions, Batch Accuracy: 19.14%
Original test_x shape: (512, 256, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 256, 1]), test_y shape: torch.Size([512])


 75%|████████████████████████████████▎          | 15/20 [00:16<00:05,  1.10s/it]

Model output (probs) shape: torch.Size([512, 10])
Batch 14: 83/512 Correct Predictions, Batch Accuracy: 16.21%
Original test_x shape: (512, 256, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 256, 1]), test_y shape: torch.Size([512])


 80%|██████████████████████████████████▍        | 16/20 [00:17<00:04,  1.11s/it]

Model output (probs) shape: torch.Size([512, 10])
Batch 15: 76/512 Correct Predictions, Batch Accuracy: 14.84%
Original test_x shape: (512, 256, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 256, 1]), test_y shape: torch.Size([512])


 85%|████████████████████████████████████▌      | 17/20 [00:18<00:03,  1.10s/it]

Model output (probs) shape: torch.Size([512, 10])
Batch 16: 95/512 Correct Predictions, Batch Accuracy: 18.55%
Original test_x shape: (512, 256, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 256, 1]), test_y shape: torch.Size([512])


 90%|██████████████████████████████████████▋    | 18/20 [00:19<00:02,  1.09s/it]

Model output (probs) shape: torch.Size([512, 10])
Batch 17: 101/512 Correct Predictions, Batch Accuracy: 19.73%
Original test_x shape: (512, 256, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 256, 1]), test_y shape: torch.Size([512])


 95%|████████████████████████████████████████▊  | 19/20 [00:20<00:01,  1.09s/it]

Model output (probs) shape: torch.Size([512, 10])
Batch 18: 96/512 Correct Predictions, Batch Accuracy: 18.75%
Original test_x shape: (272, 256, 1), test_y shape: (272,)
Device-adjusted test_x shape: torch.Size([272, 256, 1]), test_y shape: torch.Size([272])


100%|███████████████████████████████████████████| 20/20 [00:21<00:00,  1.08s/it]

Model output (probs) shape: torch.Size([272, 10])
Batch 19: 50/272 Correct Predictions, Batch Accuracy: 18.38%
Overall: 1763/10000 Correct Predictions, Overall Accuracy: 17.63%
Validation Set 1 Accuracy: 14.13%
Validation Set 2 Accuracy: 16.44%
Validation Set 3 Accuracy: 17.63%
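The shapes printed above show what distinguishes the three validation sets: each image is flattened into a pixel-by-pixel sequence, so the sequence length is the square of the image side. Validation sets 1, 2 and 3 have lengths 576, 361 and 256, i.e. 24×24, 19×19 and 16×16 images, while the config sets `'size': 14` (14×14 = 196 steps), which suggests the model was trained on 14×14 images. Accuracy rises from 14.13% to 17.63% as the test resolution gets closer to 14×14. The sketch below shows one way to produce such sequences by resizing a batch with `torch.nn.functional.interpolate`; it is an illustration under that assumption, not the exact preprocessing performed by the repository's `MnistData` class, and the helper name `images_to_sequences` is hypothetical.

```python
import torch
import torch.nn.functional as F


def images_to_sequences(images, side):
    """Resize (batch, H, W) images to side x side and flatten them to (batch, side*side, 1)."""
    resized = F.interpolate(images.unsqueeze(1), size=(side, side),
                            mode="bilinear", align_corners=False)  # (batch, 1, side, side)
    return resized.reshape(images.size(0), side * side, 1)


batch = torch.rand(8, 28, 28)  # random stand-in for a batch of 28x28 MNIST digits
for side in (14, 16, 19, 24):
    print(side, tuple(images_to_sequences(batch, side).shape))
# 14 (8, 196, 1)
# 16 (8, 256, 1)
# 19 (8, 361, 1)
# 24 (8, 576, 1)
```

The same digit therefore becomes a longer sequence at test time, with each stroke spread over more time steps than during training, which is the distribution shift both models face here.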





# LSTM
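As a point of comparison, the LSTM baseline reads the same pixel sequence with a single monolithic recurrent state instead of several sparsely interacting units. A minimal sketch of such a sequence classifier is shown below, assuming the final hidden state is passed to a linear read-out; the class name `PixelLSTMClassifier` is hypothetical and its internals need not match the repository's `LSTM` class.

```python
import torch
import torch.nn as nn


class PixelLSTMClassifier(nn.Module):
    """Hypothetical single-LSTM classifier for pixel-by-pixel MNIST sequences."""

    def __init__(self, input_size=1, hidden_size=600, num_classes=10):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.readout = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # x: (batch, seq_len, input_size); h_n: (num_layers, batch, hidden_size)
        _, (h_n, _) = self.lstm(x)
        return self.readout(h_n[-1])  # class scores, (batch, num_classes)


model = PixelLSTMClassifier(hidden_size=32)
for seq_len in (196, 576):
    print(seq_len, tuple(model(torch.rand(4, seq_len, 1)).shape))
# 196 (4, 10)
# 576 (4, 10)
```

Because the recurrent weights are shared across time steps, the same network accepts sequences of any length, which is what makes it possible to test a model trained on 196-step sequences on 256-, 361- and 576-step ones at all; whether it generalizes to them is exactly what the evaluation measures.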

In [9]:
# Config
config = {
    'cuda': True,
    'epochs': 200,
    'batch_size': 64,
    'hidden_size': 600,
    'input_size': 1,
    'model': 'LSTM', 
    'train': False, # Set to False to load the saved model
    'num_units': 6,
    'rnn_cell': 'LSTM',
    'key_size_input': 64,
    'value_size_input': 400,
    'query_size_input': 64,
    'num_input_heads': 1,
    'num_comm_heads': 4,
    'input_dropout': 0.1,
    'comm_dropout': 0.1,
    'key_size_comm': 32,
    'value_size_comm': 100,
    'query_size_comm': 32,
    'k': 4,
    'size': 14,
    'loadsaved': 1, # Ensure this is 1 to load saved model
    'log_dir': 'rim_model_dir'
}

model = LSTM(config)  # Instantiating LSTM with config
model_directory = model_path['LSTM']

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

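# Load the pretrained weights (map_location lets a GPU-saved checkpoint load on a CPU-only machine)
saved = torch.load(model_directory, map_location=device)
model.load_state_dict(saved['net'])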

# Data
data = MnistData(config['batch_size'], (config['size'], config['size']), config['k'])

# Evaluation function: test_model() is identical to the one defined in the RIM cell above, so it is reused here
# Evaluate on all three validation sets
validation_functions = [data.val_get1, data.val_get2, data.val_get3]
validation_accuracies = []

print(f"Model: {config['model']}, Device: {device}")
print(f"Configuration: {config}")

for func in validation_functions:
    accuracy = test_model(model, data, func)
    validation_accuracies.append(accuracy)

# Print accuracies for all validation sets
for i, accuracy in enumerate(validation_accuracies, 1):
    print(f'Validation Set {i} Accuracy: {accuracy:.2f}%')

Model: LSTM, Device: cuda
Configuration: {'cuda': True, 'epochs': 200, 'batch_size': 64, 'hidden_size': 600, 'input_size': 1, 'model': 'LSTM', 'train': False, 'num_units': 6, 'rnn_cell': 'LSTM', 'key_size_input': 64, 'value_size_input': 400, 'query_size_input': 64, 'num_input_heads': 1, 'num_comm_heads': 4, 'input_dropout': 0.1, 'comm_dropout': 0.1, 'key_size_comm': 32, 'value_size_comm': 100, 'query_size_comm': 32, 'k': 4, 'size': 14, 'loadsaved': 1, 'log_dir': 'rim_model_dir'}
Total validation samples: 20
Original test_x shape: (512, 576, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 576, 1]), test_y shape: torch.Size([512])
Model output (probs) shape: torch.Size([512, 10])
Batch 0: 63/512 Correct Predictions, Batch Accuracy: 12.30%
Original test_x shape: (512, 576, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 576, 1]), test_y shape: torch.Size([512])
Model output (probs) shape: torch.Size([512, 10])
Batch 1: 69/512 Correct Predictions, Batch Accuracy: 13.48%
Original test_x shape: (512, 576, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 576, 1]), test_y shape: torch.Size([512])
Model output (probs) shape: torch.Size([512, 10])
Batch 2: 68/512 Correct Predictions, Batch Accuracy: 13.28%
Original test_x shape: (512, 576, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 576, 1]), test_y shape: torch.Size([512])
Model output (probs) shape: torch.Size([512, 10])
Batch 3: 58/512 Correct Predictions, Batch Accuracy: 11.33%
Original test_x shape: (512, 576, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 576, 1]), test_y shape: torch.Size([512])
Model output (probs) shape: torch.Size([512, 10])
Batch 4: 54/512 Correct Predictions, Batch Accuracy: 10.55%
Original test_x shape: (512, 576, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 576, 1]), test_y shape: torch.Size([512])
Model output (probs) shape: torch.Size([512, 10])
Batch 5: 69/512 Correct Predictions, Batch Accuracy: 13.48%
Original test_x shape: (512, 576, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 576, 1]), test_y shape: torch.Size([512])
Model output (probs) shape: torch.Size([512, 10])
Batch 6: 57/512 Correct Predictions, Batch Accuracy: 11.13%
Original test_x shape: (512, 576, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 576, 1]), test_y shape: torch.Size([512])
Model output (probs) shape: torch.Size([512, 10])
Batch 7: 58/512 Correct Predictions, Batch Accuracy: 11.33%
Original test_x shape: (512, 576, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 576, 1]), test_y shape: torch.Size([512])
Model output (probs) shape: torch.Size([512, 10])
Batch 8: 62/512 Correct Predictions, Batch Accuracy: 12.11%
Original test_x shape: (512, 576, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 576, 1]), test_y shape: torch.Size([512])
Model output (probs) shape: torch.Size([512, 10])
Batch 9: 66/512 Correct Predictions, Batch Accuracy: 12.89%
Original test_x shape: (512, 576, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 576, 1]), test_y shape: torch.Size([512])
Model output (probs) shape: torch.Size([512, 10])
Batch 10: 49/512 Correct Predictions, Batch Accuracy: 9.57%
Original test_x shape: (512, 576, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 576, 1]), test_y shape: torch.Size([512])
Model output (probs) shape: torch.Size([512, 10])
Batch 11: 52/512 Correct Predictions, Batch Accuracy: 10.16%
Original test_x shape: (512, 576, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 576, 1]), test_y shape: torch.Size([512])
Model output (probs) shape: torch.Size([512, 10])
Batch 12: 58/512 Correct Predictions, Batch Accuracy: 11.33%
Original test_x shape: (512, 576, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 576, 1]), test_y shape: torch.Size([512])
Model output (probs) shape: torch.Size([512, 10])
Batch 13: 55/512 Correct Predictions, Batch Accuracy: 10.74%
Original test_x shape: (512, 576, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 576, 1]), test_y shape: torch.Size([512])
Model output (probs) shape: torch.Size([512, 10])
Batch 14: 50/512 Correct Predictions, Batch Accuracy: 9.77%
Original test_x shape: (512, 576, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 576, 1]), test_y shape: torch.Size([512])
Model output (probs) shape: torch.Size([512, 10])
Batch 15: 52/512 Correct Predictions, Batch Accuracy: 10.16%
Original test_x shape: (512, 576, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 576, 1]), test_y shape: torch.Size([512])
Model output (probs) shape: torch.Size([512, 10])
Batch 16: 52/512 Correct Predictions, Batch Accuracy: 10.16%
Original test_x shape: (512, 576, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 576, 1]), test_y shape: torch.Size([512])
Model output (probs) shape: torch.Size([512, 10])
Batch 17: 48/512 Correct Predictions, Batch Accuracy: 9.38%
Original test_x shape: (512, 576, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 576, 1]), test_y shape: torch.Size([512])
Model output (probs) shape: torch.Size([512, 10])
Batch 18: 67/512 Correct Predictions, Batch Accuracy: 13.09%
Original test_x shape: (272, 576, 1), test_y shape: (272,)
Device-adjusted test_x shape: torch.Size([272, 576, 1]), test_y shape: torch.Size([272])
Model output (probs) shape: torch.Size([272, 10])
Batch 19: 25/272 Correct Predictions, Batch Accuracy: 9.19%
Overall: 1132/10000 Correct Predictions, Overall Accuracy: 11.32%
Total validation samples: 20
Original test_x shape: (512, 361, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 361, 1]), test_y shape: torch.Size([512])
Model output (probs) shape: torch.Size([512, 10])
Batch 0: 75/512 Correct Predictions, Batch Accuracy: 14.65%
Original test_x shape: (512, 361, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 361, 1]), test_y shape: torch.Size([512])
Model output (probs) shape: torch.Size([512, 10])
Batch 1: 72/512 Correct Predictions, Batch Accuracy: 14.06%
Original test_x shape: (512, 361, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 361, 1]), test_y shape: torch.Size([512])
Model output (probs) shape: torch.Size([512, 10])
Batch 2: 82/512 Correct Predictions, Batch Accuracy: 16.02%
Original test_x shape: (512, 361, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 361, 1]), test_y shape: torch.Size([512])
Model output (probs) shape: torch.Size([512, 10])
Batch 3: 68/512 Correct Predictions, Batch Accuracy: 13.28%
Original test_x shape: (512, 361, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 361, 1]), test_y shape: torch.Size([512])
Model output (probs) shape: torch.Size([512, 10])
Batch 4: 72/512 Correct Predictions, Batch Accuracy: 14.06%
Original test_x shape: (512, 361, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 361, 1]), test_y shape: torch.Size([512])
Model output (probs) shape: torch.Size([512, 10])
Batch 5: 75/512 Correct Predictions, Batch Accuracy: 14.65%
Original test_x shape: (512, 361, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 361, 1]), test_y shape: torch.Size([512])
Model output (probs) shape: torch.Size([512, 10])
Batch 6: 76/512 Correct Predictions, Batch Accuracy: 14.84%
Original test_x shape: (512, 361, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 361, 1]), test_y shape: torch.Size([512])
Model output (probs) shape: torch.Size([512, 10])
Batch 7: 82/512 Correct Predictions, Batch Accuracy: 16.02%
Original test_x shape: (512, 361, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 361, 1]), test_y shape: torch.Size([512])
Model output (probs) shape: torch.Size([512, 10])
Batch 8: 77/512 Correct Predictions, Batch Accuracy: 15.04%
Original test_x shape: (512, 361, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 361, 1]), test_y shape: torch.Size([512])
Model output (probs) shape: torch.Size([512, 10])
Batch 9: 76/512 Correct Predictions, Batch Accuracy: 14.84%
Original test_x shape: (512, 361, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 361, 1]), test_y shape: torch.Size([512])
Model output (probs) shape: torch.Size([512, 10])
Batch 10: 67/512 Correct Predictions, Batch Accuracy: 13.09%
Original test_x shape: (512, 361, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 361, 1]), test_y shape: torch.Size([512])
Model output (probs) shape: torch.Size([512, 10])
Batch 11: 76/512 Correct Predictions, Batch Accuracy: 14.84%
Original test_x shape: (512, 361, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 361, 1]), test_y shape: torch.Size([512])
Model output (probs) shape: torch.Size([512, 10])
Batch 12: 83/512 Correct Predictions, Batch Accuracy: 16.21%
Original test_x shape: (512, 361, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 361, 1]), test_y shape: torch.Size([512])
Model output (probs) shape: torch.Size([512, 10])
Batch 13: 61/512 Correct Predictions, Batch Accuracy: 11.91%
Original test_x shape: (512, 361, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 361, 1]), test_y shape: torch.Size([512])
Model output (probs) shape: torch.Size([512, 10])
Batch 14: 60/512 Correct Predictions, Batch Accuracy: 11.72%
Original test_x shape: (512, 361, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 361, 1]), test_y shape: torch.Size([512])
Model output (probs) shape: torch.Size([512, 10])
Batch 15: 74/512 Correct Predictions, Batch Accuracy: 14.45%
Original test_x shape: (512, 361, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 361, 1]), test_y shape: torch.Size([512])
Model output (probs) shape: torch.Size([512, 10])
Batch 16: 76/512 Correct Predictions, Batch Accuracy: 14.84%
Original test_x shape: (512, 361, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 361, 1]), test_y shape: torch.Size([512])
Model output (probs) shape: torch.Size([512, 10])
Batch 17: 78/512 Correct Predictions, Batch Accuracy: 15.23%
Original test_x shape: (512, 361, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 361, 1]), test_y shape: torch.Size([512])
Model output (probs) shape: torch.Size([512, 10])
Batch 18: 68/512 Correct Predictions, Batch Accuracy: 13.28%
Original test_x shape: (272, 361, 1), test_y shape: (272,)
Device-adjusted test_x shape: torch.Size([272, 361, 1]), test_y shape: torch.Size([272])
Model output (probs) shape: torch.Size([272, 10])
Batch 19: 41/272 Correct Predictions, Batch Accuracy: 15.07%
Overall: 1439/10000 Correct Predictions, Overall Accuracy: 14.39%
Total validation samples: 20
Original test_x shape: (512, 256, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 256, 1]), test_y shape: torch.Size([512])
Model output (probs) shape: torch.Size([512, 10])
Batch 0: 406/512 Correct Predictions, Batch Accuracy: 79.30%
Original test_x shape: (512, 256, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 256, 1]), test_y shape: torch.Size([512])
Model output (probs) shape: torch.Size([512, 10])
Batch 1: 384/512 Correct Predictions, Batch Accuracy: 75.00%
Original test_x shape: (512, 256, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 256, 1]), test_y shape: torch.Size([512])
Model output (probs) shape: torch.Size([512, 10])
Batch 2: 385/512 Correct Predictions, Batch Accuracy: 75.20%
Original test_x shape: (512, 256, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 256, 1]), test_y shape: torch.Size([512])
Model output (probs) shape: torch.Size([512, 10])
Batch 3: 387/512 Correct Predictions, Batch Accuracy: 75.59%
Original test_x shape: (512, 256, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 256, 1]), test_y shape: torch.Size([512])
Model output (probs) shape: torch.Size([512, 10])
Batch 4: 375/512 Correct Predictions, Batch Accuracy: 73.24%
Original test_x shape: (512, 256, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 256, 1]), test_y shape: torch.Size([512])
Model output (probs) shape: torch.Size([512, 10])
Batch 5: 398/512 Correct Predictions, Batch Accuracy: 77.73%
Original test_x shape: (512, 256, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 256, 1]), test_y shape: torch.Size([512])
Model output (probs) shape: torch.Size([512, 10])
Batch 6: 396/512 Correct Predictions, Batch Accuracy: 77.34%
Original test_x shape: (512, 256, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 256, 1]), test_y shape: torch.Size([512])
Model output (probs) shape: torch.Size([512, 10])
Batch 7: 404/512 Correct Predictions, Batch Accuracy: 78.91%
Original test_x shape: (512, 256, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 256, 1]), test_y shape: torch.Size([512])
Model output (probs) shape: torch.Size([512, 10])
Batch 8: 401/512 Correct Predictions, Batch Accuracy: 78.32%
Original test_x shape: (512, 256, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 256, 1]), test_y shape: torch.Size([512])
Model output (probs) shape: torch.Size([512, 10])
Batch 9: 418/512 Correct Predictions, Batch Accuracy: 81.64%
Original test_x shape: (512, 256, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 256, 1]), test_y shape: torch.Size([512])
Model output (probs) shape: torch.Size([512, 10])
Batch 10: 450/512 Correct Predictions, Batch Accuracy: 87.89%
Original test_x shape: (512, 256, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 256, 1]), test_y shape: torch.Size([512])
Model output (probs) shape: torch.Size([512, 10])
Batch 11: 414/512 Correct Predictions, Batch Accuracy: 80.86%
Original test_x shape: (512, 256, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 256, 1]), test_y shape: torch.Size([512])
Model output (probs) shape: torch.Size([512, 10])
Batch 12: 434/512 Correct Predictions, Batch Accuracy: 84.77%
Original test_x shape: (512, 256, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 256, 1]), test_y shape: torch.Size([512])
Model output (probs) shape: torch.Size([512, 10])
Batch 13: 445/512 Correct Predictions, Batch Accuracy: 86.91%
Original test_x shape: (512, 256, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 256, 1]), test_y shape: torch.Size([512])
Model output (probs) shape: torch.Size([512, 10])
Batch 14: 436/512 Correct Predictions, Batch Accuracy: 85.16%
Original test_x shape: (512, 256, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 256, 1]), test_y shape: torch.Size([512])
Model output (probs) shape: torch.Size([512, 10])
Batch 15: 419/512 Correct Predictions, Batch Accuracy: 81.84%
Original test_x shape: (512, 256, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 256, 1]), test_y shape: torch.Size([512])
Model output (probs) shape: torch.Size([512, 10])
Batch 16: 435/512 Correct Predictions, Batch Accuracy: 84.96%
Original test_x shape: (512, 256, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 256, 1]), test_y shape: torch.Size([512])
Model output (probs) shape: torch.Size([512, 10])
Batch 17: 462/512 Correct Predictions, Batch Accuracy: 90.23%
Original test_x shape: (512, 256, 1), test_y shape: (512,)
Device-adjusted test_x shape: torch.Size([512, 256, 1]), test_y shape: torch.Size([512])
Model output (probs) shape: torch.Size([512, 10])
Batch 18: 421/512 Correct Predictions, Batch Accuracy: 82.23%
Original test_x shape: (272, 256, 1), test_y shape: (272,)
Device-adjusted test_x shape: torch.Size([272, 256, 1]), test_y shape: torch.Size([272])
Model output (probs) shape: torch.Size([272, 10])
Batch 19: 223/272 Correct Predictions, Batch Accuracy: 81.99%
Overall: 8093/10000 Correct Predictions, Overall Accuracy: 80.93%
Validation Set 1 Accuracy: 11.32%
Validation Set 2 Accuracy: 14.39%
Validation Set 3 Accuracy: 80.93%
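It can help to line up the three overall accuracies against the input size. The sketch below copies the correct counts and sequence lengths from the log above. It assumes that each input is a square image flattened to one pixel per timestep, which is consistent with the `(batch, seq_len, 1)` shapes, and that `config['size'] = 14` is the image side the model was trained on.

```python
import math

# Correct-prediction counts and sequence lengths copied from the evaluation log above.
RESULTS = [
    ("Validation Set 1", 1132, 576),
    ("Validation Set 2", 1439, 361),
    ("Validation Set 3", 8093, 256),
]
TOTAL_SAMPLES = 19 * 512 + 272  # 19 full batches of 512 plus a final batch of 272
TRAIN_SIDE = 14                 # config['size'] (assumed to be the training image side)


def summarize(results, total):
    """Return (name, image side, overall accuracy in %) for each validation set."""
    rows = []
    for name, correct, seq_len in results:
        side = math.isqrt(seq_len)  # assumes one pixel per timestep of a square image
        rows.append((name, side, round(100 * correct / total, 2)))
    return rows


for name, side, accuracy in summarize(RESULTS, TOTAL_SAMPLES):
    print(f"{name}: {side}x{side} input (training side {TRAIN_SIDE}), accuracy {accuracy:.2f}%")
```

Read this way, the model does well only on the 16x16 set. It drops close to chance level (10% for ten digit classes) on the 19x19 and 24x24 sets, the larger shifts away from the training size. This is the degradation described in the tutorial objectives.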





---
# Section 2: Global Workspace

As we have seen, deep learning has shifted towards structured models with specialized modules that improve scalability and generalization. We can go one step further. This section analyses an approach inspired by two sources: the 1980s AI focus on modular architectures, and Global Workspace Theory from cognitive neuroscience. It uses a shared global workspace to coordinate the modules. Because specialized modules interact dynamically through this workspace, the approach promotes flexibility and systematic generalization. The key idea is that many sparsely communicating specialist modules interact via a shared working memory, so that the system as a whole behaves coherently and efficiently.

In RIMs, specialist modules share information through self-attention using pairwise interactions, in which each module attends to every other module. The shared-workspace approach replaces this with a workspace of limited capacity. At each computational step, the specialist modules compete for the opportunity to write to the workspace. The information stored there is then broadcast to all specialists simultaneously. This coordinates the modules and propagates information among them without any direct pairwise communication.
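To make the saving concrete, we can count attention scores per step. This is a back-of-the-envelope sketch that assumes one score per (reader, item) pair and ignores projection costs. Pairwise communication among N specialists needs N × N scores. Writing to M workspace slots needs M × N scores, and broadcasting back needs N × M more.

```python
def pairwise_scores(num_specialists):
    # Every specialist attends to every specialist (including itself).
    return num_specialists * num_specialists


def workspace_scores(num_specialists, num_slots):
    write = num_slots * num_specialists      # each slot scores every specialist
    broadcast = num_specialists * num_slots  # each specialist scores every slot
    return write + broadcast


for n in (5, 50, 500):
    print(f"{n} specialists: pairwise {pairwise_scores(n)}, workspace {workspace_scores(n, num_slots=4)}")
```

For 5 specialists and 4 slots, the workspace actually needs more scores (40 vs 25). The saving appears once the number of specialists exceeds twice the number of slots: with 50 specialists, the counts are 400 vs 2,500.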

## Coding Exercise: Creating a Shared Workspace

Specialists compete to write their information into the shared workspace. The competition is implemented with key-query-value attention: the attention scores determine how much each specialist's information contributes to the updated workspace.
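The competition comes from the softmax. For a given memory slot, the attention weights over all specialists sum to 1, so one specialist can only gain weight at the expense of the others. Here is a minimal pure-Python sketch for a single slot. The scores are hypothetical values, not taken from the model.

```python
import math


def softmax(scores):
    # Subtract the max for numerical stability; the weights sum to 1.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical scaled dot-product scores of one memory slot against 5 specialists.
scores = [2.0, 0.5, 0.1, -1.0, 0.3]
weights = softmax(scores)

# Raise specialist 1's score: its weight grows and every other weight shrinks.
boosted = softmax([2.0, 3.0, 0.1, -1.0, 0.3])
for i, (before, after) in enumerate(zip(weights, boosted)):
    print(f"specialist {i}: {before:.3f} -> {after:.3f}")
```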

In [10]:
_ = torch.manual_seed(42)  # Ensure reproducibility (assigning to _ suppresses the Generator repr)

In [11]:
class SharedWorkspace(nn.Module):
    
    def __init__(self, num_specialists, hidden_dim, num_memory_slots, memory_slot_dim):
        #################################################
        ## TODO for students: fill in the missing variables ##
        # Fill out function and remove
        raise NotImplementedError("Student exercise: fill in the missing variables")
        #################################################
        super().__init__()
        self.num_specialists = num_specialists
        self.hidden_dim = hidden_dim
        self.num_memory_slots = num_memory_slots
        self.memory_slot_dim = memory_slot_dim
        self.workspace_memory = nn.Parameter(torch.randn(num_memory_slots, memory_slot_dim))
        
        # Attention mechanism components for writing to the workspace
        self.key = ...
        self.query = ...
        self.value = nn.Linear(hidden_dim, memory_slot_dim)
    
    def write_to_workspace(self, specialists_states):
        #################################################
        ## TODO for students: fill in the missing variables ##
        # Fill out function and remove
        raise NotImplementedError("Student exercise: fill in the missing variables")
        #################################################
        # Flatten specialists' states if they're not already
        specialists_states = specialists_states.view(-1, self.hidden_dim)
        
        # Compute key, query, and value
        keys = self.key(specialists_states)
        query = self.query(self.workspace_memory)
        values = self.value(specialists_states)
        
        # Compute attention scores and apply softmax
        attention_scores = torch.matmul(query, keys.transpose(-2, -1)) / (self.memory_slot_dim ** 0.5)
        attention_probs = ...
        
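        # Update workspace memory with weighted sum of values
        updated_memory = torch.matmul(attention_probs, values)
        # nn.Parameter always wraps a new leaf tensor (detached from the graph), so keep a
        # detached copy as the workspace state and return the attached result instead,
        # which lets gradients flow back into the key/query/value layers
        self.workspace_memory = nn.Parameter(updated_memory.detach())

        return updated_memory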

    def forward(self, specialists_states):
        #################################################
        ## TODO for students: fill in the missing variables ##
        # Fill out function and remove
        raise NotImplementedError("Student exercise: fill in the missing variables")
        #################################################
        updated_memory = ...
        return updated_memory

# Example parameters
num_specialists = 5
hidden_dim = 10
num_memory_slots = 4
memory_slot_dim = 6

# Generate deterministic specialists' states
specialists_states = torch.randn(num_specialists, hidden_dim)

# Uncomment the code below to test your function
# workspace = SharedWorkspace(num_specialists, hidden_dim, num_memory_slots, memory_slot_dim)
# expected_output = workspace.forward(specialists_states)
# print("Expected Output:", expected_output)

In [12]:
# to remove solution

class SharedWorkspace(nn.Module):
    
    def __init__(self, num_specialists, hidden_dim, num_memory_slots, memory_slot_dim):
        super().__init__()
        self.num_specialists = num_specialists
        self.hidden_dim = hidden_dim
        self.num_memory_slots = num_memory_slots
        self.memory_slot_dim = memory_slot_dim
        self.workspace_memory = nn.Parameter(torch.randn(num_memory_slots, memory_slot_dim))

        # Attention mechanism components for writing to the workspace
        self.key = nn.Linear(hidden_dim, memory_slot_dim)
        self.query = nn.Linear(memory_slot_dim, memory_slot_dim)
        self.value = nn.Linear(hidden_dim, memory_slot_dim)

    def write_to_workspace(self, specialists_states):
        # Flatten specialists' states if they're not already
        specialists_states = specialists_states.view(-1, self.hidden_dim)

        # Compute key, query, and value
        keys = self.key(specialists_states)
        query = self.query(self.workspace_memory)
        values = self.value(specialists_states)

        # Compute attention scores and apply softmax
        attention_scores = torch.matmul(query, keys.transpose(-2, -1)) / (self.memory_slot_dim ** 0.5)
        attention_probs = F.softmax(attention_scores, dim=-1)

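        # Update workspace memory with weighted sum of values
        updated_memory = torch.matmul(attention_probs, values)
        # nn.Parameter always wraps a new leaf tensor (detached from the graph), so keep a
        # detached copy as the workspace state and return the attached result instead,
        # which lets gradients flow back into the key/query/value layers
        self.workspace_memory = nn.Parameter(updated_memory.detach())

        return updated_memory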

    def forward(self, specialists_states):
        updated_memory = self.write_to_workspace(specialists_states)
        return updated_memory

# Example parameters
num_specialists = 5
hidden_dim = 10
num_memory_slots = 4
memory_slot_dim = 6

# Generate deterministic specialists' states
specialists_states = torch.randn(num_specialists, hidden_dim)

# Uncomment the code below to test your function
# workspace = SharedWorkspace(num_specialists, hidden_dim, num_memory_slots, memory_slot_dim)
# expected_output = workspace.forward(specialists_states)
# print("Expected Output:", expected_output)
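When checking an implementation, it helps to trace the tensor shapes through `write_to_workspace` for the example parameters: 5 specialists, `hidden_dim = 10`, 4 memory slots, `memory_slot_dim = 6`. This is a shape-only sketch that mirrors the layer sizes in the solution above (`key` and `value`: `hidden_dim → memory_slot_dim`; `query`: `memory_slot_dim → memory_slot_dim`). The helper `write_shapes` is hypothetical and performs no computation.

```python
def write_shapes(num_specialists, hidden_dim, num_memory_slots, memory_slot_dim):
    """Shapes of the intermediates in write_to_workspace (no actual computation)."""
    states = (num_specialists, hidden_dim)
    keys = (states[0], memory_slot_dim)          # key: hidden_dim -> memory_slot_dim
    values = (states[0], memory_slot_dim)        # value: hidden_dim -> memory_slot_dim
    query = (num_memory_slots, memory_slot_dim)  # query applied to the workspace memory
    scores = (query[0], keys[0])                 # query @ keys.T
    memory = (scores[0], values[1])              # softmax(scores, dim=-1) @ values
    return {"keys": keys, "values": values, "query": query,
            "attention_scores": scores, "updated_memory": memory}


for name, shape in write_shapes(5, 10, 4, 6).items():
    print(f"{name}: {shape}")
```

The scores have shape `(num_memory_slots, num_specialists)`, so `softmax(..., dim=-1)` normalizes each slot's weights across the specialists. The updated memory keeps the `(num_memory_slots, memory_slot_dim)` shape of the workspace it replaces.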

Once the shared workspace has been updated with the most relevant signals, its contents are broadcast back to all specialists. Each specialist then updates its state using the broadcast information. This can involve an attention mechanism to consolidate the information and an update function (such as an LSTM or GRU step) applied to the new combined state. Let's add this method!

In [13]:
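def broadcast_from_workspace(self, specialists_states):
    # The write-step layers map hidden_dim -> memory_slot_dim (key/value) and
    # memory_slot_dim -> memory_slot_dim (query), so they cannot be reused in the read
    # direction unless hidden_dim == memory_slot_dim. Create dedicated read projections
    # on first use (nn.Module registers them as submodules on assignment; build any
    # optimizer after this first call so it sees these parameters).
    if not hasattr(self, "read_query"):
        device = self.workspace_memory.device
        self.read_query = nn.Linear(self.hidden_dim, self.memory_slot_dim).to(device)
        self.read_key = nn.Linear(self.memory_slot_dim, self.memory_slot_dim).to(device)
        self.read_value = nn.Linear(self.memory_slot_dim, self.hidden_dim).to(device)

    flat_states = specialists_states.view(-1, self.hidden_dim)  # (num_specialists, hidden_dim)

    # Each specialist queries the memory slots
    queries = self.read_query(flat_states)           # (num_specialists, memory_slot_dim)
    keys = self.read_key(self.workspace_memory)      # (num_memory_slots, memory_slot_dim)
    values = self.read_value(self.workspace_memory)  # (num_memory_slots, hidden_dim)

    # Compute attention scores for broadcasting; each specialist's weights over the slots sum to 1
    scores = torch.matmul(queries, keys.transpose(-2, -1)) / (self.memory_slot_dim ** 0.5)
    probs = F.softmax(scores, dim=-1)                # (num_specialists, num_memory_slots)

    # Update specialists' states with attention-weighted memory information
    updated_states = torch.matmul(probs, values)     # (num_specialists, hidden_dim)
    return updated_states.view_as(specialists_states)

# Assign the method to the class
SharedWorkspace.broadcast_from_workspace = broadcast_from_workspace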

This design modularizes the shared workspace. The specialists' states are first aggregated competitively into the workspace, and the consolidated information is then distributed back to them. The mechanism filters information dynamically according to the current context. It can also help the model generalize from past experience by focusing on the most relevant signals at each computational step. To integrate this into a full system, you would instantiate `SharedWorkspace` within your RIM architecture and run three steps: the specialists' initial representations are computed (Step 1), passed to the `SharedWorkspace` for competition and update (Step 2), and the updated information is broadcast back to the specialists (Step 3).
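The three-step cycle can also be traced end to end on toy numbers. The sketch below is not the `SharedWorkspace` module. It assumes identity projections, so keys, queries and values are the raw vectors, which only works when the specialist and slot dimensions match. It also uses plain Python lists instead of tensors, and the toy specialist states stand in for the output of Step 1.

```python
import math


def softmax(row):
    # Numerically stable softmax over a list of scores.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def attend(queries, keys, values):
    # Scaled dot-product attention: each query gets a softmax-weighted sum of the values.
    d = len(keys[0])
    out = []
    for q in queries:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys]
        weights = softmax(scores)
        out.append([sum(w * v[j] for w, v in zip(weights, values)) for j in range(len(values[0]))])
    return out


def workspace_cycle(specialists, memory):
    # Step 2: memory slots (queries) compete over the specialists (keys and values).
    new_memory = attend(memory, specialists, specialists)
    # Step 3: every specialist (query) reads from the updated slots (keys and values).
    broadcast = attend(specialists, new_memory, new_memory)
    return new_memory, broadcast


# Step 1 stand-in: toy states for 3 specialists of dimension 2, and a single memory slot.
specialists = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
memory = [[0.5, 0.5]]

new_memory, broadcast = workspace_cycle(specialists, memory)
print("updated memory:", new_memory)
print("broadcast to specialists:", broadcast)
```

With a single slot, every specialist receives exactly the same content, which is the limiting case of a capacity-limited broadcast. Adding slots lets different specialists read different mixtures.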

---
# Section 3: A toy model for illustrating the Global Neuronal Workspace (GNW)

In [14]:
from IPython.display import display

class SimpleGNWModel:
    def __init__(self, num_nodes=5):
        self.num_nodes = num_nodes
        self.network = nx.erdos_renyi_graph(n=num_nodes, p=0.5)
        self.activations = {node: False for node in self.network.nodes}

    def activate_node(self):
        selected_node = random.choice(list(self.network.nodes))
        self.activations[selected_node] = True

        # Simulate broadcast: activation spreads from the selected node to its direct neighbors (one hop)
        for neighbor in self.network.neighbors(selected_node):
            self.activations[neighbor] = True

    def reset_activations(self):
        self.activations = {node: False for node in self.network.nodes}

    def draw_network(self):
        color_map = ['green' if self.activations[node] else 'red' for node in self.network.nodes]
        nx.draw(self.network, node_color=color_map, with_labels=True, node_size=700)
        plt.show()

# Create a GNW model instance
gnw_model = SimpleGNWModel()

# Button to activate a node
activate_button = widgets.Button(description='Activate Node')

# Button to reset activations
reset_button = widgets.Button(description='Reset')

# Output area for the network graph
output_area = widgets.Output()

def on_activate_clicked(b):
    with output_area:
        output_area.clear_output(wait=True)
        gnw_model.activate_node()
        gnw_model.draw_network()

def on_reset_clicked(b):
    with output_area:
        output_area.clear_output(wait=True)
        gnw_model.reset_activations()
        gnw_model.draw_network()

activate_button.on_click(on_activate_clicked)
reset_button.on_click(on_reset_clicked)

display(widgets.VBox([activate_button, reset_button, output_area]))

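In `SimpleGNWModel.activate_node`, one click switches on a randomly chosen node and its direct neighbors, and nothing further. The sketch below applies the same rule deterministically. It assumes a hypothetical fixed 5-node adjacency list in place of the random Erdos-Renyi graph, and an explicitly chosen node in place of `random.choice`.

```python
# Hypothetical 5-node undirected graph as an adjacency list (stand-in for the
# random Erdos-Renyi graph built by SimpleGNWModel).
ADJACENCY = {0: [1, 2], 1: [0], 2: [0, 3], 3: [2, 4], 4: [3]}


def activate(adjacency, activations, selected):
    """Same rule as SimpleGNWModel.activate_node, with the node chosen explicitly."""
    updated = dict(activations)
    updated[selected] = True
    for neighbor in adjacency[selected]:  # one hop only: neighbors of neighbors stay off
        updated[neighbor] = True
    return updated


state = {node: False for node in ADJACENCY}
state = activate(ADJACENCY, state, selected=2)
print(sorted(n for n, on in state.items() if on))
```

Activating node 2 turns on nodes 0, 2 and 3. Nodes 1 and 4 are two hops away and stay off.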

---
# Section 4: TBD

---
# Section 5: Fill in the blank

---
# Section 6: Toy model

---
# Section 7: MNIST digits