# W2D5Tutorial1

**Week 2, Day 5: Mysteries**

**By Neuromatch Academy**

__Content creators:__ Names & Surnames

__Content reviewers:__ Names & Surnames

__Production editors:__ Names & Surnames

<br>

Acknowledgments: [ACKNOWLEDGMENT_INFORMATION]


___


# Tutorial Objectives

*Estimated timing of tutorial: [insert estimated duration of whole tutorial in minutes]*

In this tutorial, you will observe how performance degrades as the test data distribution strays from the training distribution.


In [None]:
# @title Tutorial slides

# @markdown These are the slides for the videos in all tutorials today


## Uncomment the code below to display the slides

#from IPython.display import IFrame
#link_id = "<YOUR_LINK_ID_HERE>"

print("If you want to download the slides: 'Link to the slides'")
      # Example: https://osf.io/download/{link_id}/

#IFrame(src=f"https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{link_id}/?direct%26mode=render", width=854, height=480)

---
# Setup



In [None]:
# @title Install dependencies
# @markdown

!pip install numpy matplotlib Pillow torch torchvision transformers ipywidgets gradio trdg scikit-learn networkx pickleshare

In [None]:
# @title Import dependencies
# @markdown Enhanced organization and clarity in import statements for better readability

# Standard libraries for essential operations, random number generation, logging, and file management
import logging
import os
import random
import requests

# Data handling libraries
import pickleshare

# Scientific computing and statistical libraries for advanced mathematical operations and array handling
import numpy as np
from numpy.linalg import inv
from scipy.special import logsumexp
from scipy.stats import multivariate_normal

# Deep learning libraries for constructing and training neural networks
import torch
import torch.nn as nn
import torch.nn.functional as F

# Image processing and visualization libraries
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from mpl_toolkits.mplot3d import Axes3D

# Libraries for interactive elements in notebooks and web applications
from IPython.display import IFrame, display
import gradio as gr
import ipywidgets as widgets
from ipywidgets import interact, IntSlider

# Graph analysis library for the creation, manipulation, and study of the structure of complex networks
import networkx as nx

# Progress monitoring
from tqdm import tqdm

# Additional standard library (random is already imported above)
import copy

In [None]:
# @title Figure settings
# @markdown

logging.getLogger('matplotlib.font_manager').disabled = True

%matplotlib inline
%config InlineBackend.figure_format = 'retina' # perform high definition rendering for images and plots
plt.style.use("https://raw.githubusercontent.com/NeuromatchAcademy/course-content/main/nma.mplstyle")

In [None]:
# @title Set device (GPU or CPU). Execute `set_device()`
# especially if torch modules used.
# @markdown

# inform the user if the notebook uses GPU or CPU.

def set_device():
    """
    Determines and sets the computational device for PyTorch operations based on the availability of a CUDA-capable GPU.

    Outputs:
    - device (str): The device that PyTorch will use for computations ('cuda' or 'cpu'). This string can be directly used
    in PyTorch operations to specify the device.
    """

    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device != "cuda":
        print("GPU is not enabled in this notebook. \n"
              "If you want to enable it, in the menu under `Runtime` -> \n"
              "`Hardware accelerator.` and select `GPU` from the dropdown menu")
    else:
        print("GPU is enabled in this notebook. \n"
              "If you want to disable it, in the menu under `Runtime` -> \n"
              "`Hardware accelerator.` and select `None` from the dropdown menu")

    return device

# Section 1: Recurrent Independent Mechanisms

The central idea of this section is that the physical world has a modular structure: complex behavior emerges from simpler subsystems that evolve largely independently, and a learning system generalizes better when it captures that structure. This view aligns with causal inference, where understanding the world means identifying autonomous mechanisms and how they combine. Because these mechanisms interact only sparsely, each one keeps functioning when the others change, which makes them robust. Recurrent Independent Mechanisms (RIMs) embody this principle: the modules operate mostly independently and interact occasionally through an attention-based mechanism. This approach [https://arxiv.org/pdf/1909.10893.pdf] favors models that capture independent mechanisms with sparse interactions, potentially leading to more adaptable and generalizable AI systems.
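
To make the idea of sparse, attention-based activation concrete, here is a minimal NumPy sketch. It assumes that each module scores the real input against a "null" input and that only the `k` modules attending most strongly to the real input are updated. The function name, the random vectors, and the choice of 6 modules with `k=4` (picked to mirror `num_units` and `k` in the configuration used later) are illustrative assumptions, not the trained model's internals.

```python
import numpy as np

def select_active_modules(module_queries, input_key, null_key, k):
    """Return indices of the k modules that attend most strongly to the input.

    module_queries: (num_modules, d) array, one hypothetical query per module
    input_key, null_key: (d,) keys for the real input and for a "null" input
    """
    d = module_queries.shape[1]
    keys = np.stack([input_key, null_key])            # (2, d)
    scores = module_queries @ keys.T / np.sqrt(d)     # (num_modules, 2)
    # Softmax over the two candidates (input vs. null) for each module
    scores = scores - scores.max(axis=1, keepdims=True)
    probs = np.exp(scores) / np.exp(scores).sum(axis=1, keepdims=True)
    # Rank modules by how much attention they give to the real input
    ranking = np.argsort(-probs[:, 0])
    return np.sort(ranking[:k]), probs

rng = np.random.default_rng(0)
queries = rng.normal(size=(6, 8))   # 6 modules with 8-dimensional queries (illustrative sizes)
active, probs = select_active_modules(queries, rng.normal(size=8), np.zeros(8), k=4)
print("Active modules:", active)
```

Modules that are not selected would keep their previous state, which is one way to picture how most mechanisms stay untouched at any given step.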

In [None]:
# @title Data retrieval
# @markdown

# URL of the repository to clone
!git clone https://github.com/SamueleBolotta/RIMs-Sequential-MNIST
%cd RIMs-Sequential-MNIST

# Imports
from data import MnistData
from networks import MnistModel, LSTM

# Function to download files
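def download_file(url, destination):
    print(f"Starting to download {url} to {destination}")
    response = requests.get(url, allow_redirects=True)
    response.raise_for_status()  # Stop early on HTTP errors instead of saving an error page
    # Create the target directory in case the cloned repository does not contain it
    os.makedirs(os.path.dirname(destination) or '.', exist_ok=True)
    with open(destination, 'wb') as f:
        f.write(response.content)
    print(f"Successfully downloaded {url} to {destination}")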

# Path of the models
model_path = {
    'LSTM': 'lstm_model_dir/lstm_best_model.pt',
    'RIM': 'rim_model_dir/best_model.pt'
}

# URLs of the models
model_urls = {
    'LSTM': 'https://osf.io/4gajq/download',
    'RIM': 'https://osf.io/3squn/download'
}

# Check if model files exist, if not, download them
for model_key, model_url in model_urls.items():
    if not os.path.exists(model_path[model_key]):
        download_file(model_url, model_path[model_key])
        print(f"{model_key} model downloaded.")
    else:
        print(f"{model_key} model already exists. No download needed.")

## RIMs

The model was trained on the Sequential MNIST dataset with images of size 14x14. To assess its generalization capabilities, it is tested on 16x16 (Validation Set 1), 19x19 (Validation Set 2) and 24x24 (Validation Set 3) images.
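
A quick way to see why larger images are a harder test is to count time steps. The sketch below assumes that each pixel is fed to the network as one time step (consistent with `input_size` being 1 in the configuration), so the sequence length grows with the image area.

```python
# Assumption: each pixel is one time step (input_size = 1),
# so an H x W image becomes a sequence of H * W steps.
train_size = 14
test_sizes = [16, 19, 24]

train_len = train_size * train_size
sequence_lengths = {size: size * size for size in test_sizes}
for size, length in sequence_lengths.items():
    print(f"{size}x{size}: {length} steps ({length / train_len:.2f}x the training length)")
```

Under this assumption, the 24x24 set asks the model to process sequences almost three times longer than anything it saw during training.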

In [None]:
# Config
config = {
    'cuda': True,
    'epochs': 200,
    'batch_size': 64,
    'hidden_size': 100,
    'input_size': 1,
    'model': 'RIM', # 'RIM' for this section; the LSTM section uses 'LSTM'
    'train': False, # Set to False to load the saved model
    'num_units': 6,
    'rnn_cell': 'LSTM',
    'key_size_input': 64,
    'value_size_input': 400,
    'query_size_input': 64,
    'num_input_heads': 1,
    'num_comm_heads': 4,
    'input_dropout': 0.1,
    'comm_dropout': 0.1,
    'key_size_comm': 32,
    'value_size_comm': 100,
    'query_size_comm': 32,
    'k': 4,
    'size': 14,
    'loadsaved': 1, # Ensure this is 1 to load saved model
    'log_dir': 'rim_model_dir'
}

# Choose the model
model = MnistModel(config)  # Instantiating MnistModel (RIM) with config
model_directory = model_path['RIM']

# Set device
device = set_device()
model.to(device)

# Set the map_location based on whether CUDA is available
map_location = 'cuda' if torch.cuda.is_available() and config['cuda'] else 'cpu'

# Use torch.load with the map_location parameter
saved = torch.load(model_directory, map_location=map_location)
model.load_state_dict(saved['net'])

# Data
data = MnistData(config['batch_size'], (config['size'], config['size']), config['k'])

# Evaluation function
def test_model(model, loader, func):
    accuracy = 0
    loss = 0
    model.eval()

    print(f"Total validation samples: {loader.val_len()}")  # Print total number of validation samples

    with torch.no_grad():
        for i in tqdm(range(loader.val_len())):
            test_x, test_y = func(i)
            test_x = model.to_device(test_x)
            test_y = model.to_device(test_y).long()

            probs  = model(test_x)
            preds = torch.argmax(probs, dim=1)
            correct = preds == test_y
            accuracy += correct.sum().item()

    accuracy /= 100  # Correct count -> percentage (exact only if the validation set has 10,000 samples)
    return accuracy


# Evaluate on all three validation sets
validation_functions = [data.val_get1, data.val_get2, data.val_get3]
validation_accuracies_rim = []

print(f"Model: {config['model']}, Device: {device}")
print(f"Configuration: {config}")

for func in validation_functions:
    accuracy = test_model(model, data, func)
    validation_accuracies_rim.append(accuracy)

## LSTM

Let's now repeat the same process with LSTMs.

In [None]:
# Config
config = {
    'cuda': True,
    'epochs': 200,
    'batch_size': 64,
    'hidden_size': 100,
    'input_size': 1,
    'model': 'LSTM',
    'train': False, # Set to False to load the saved model
    'num_units': 6,
    'rnn_cell': 'LSTM',
    'key_size_input': 64,
    'value_size_input': 400,
    'query_size_input': 64,
    'num_input_heads': 1,
    'num_comm_heads': 4,
    'input_dropout': 0.1,
    'comm_dropout': 0.1,
    'key_size_comm': 32,
    'value_size_comm': 100,
    'query_size_comm': 32,
    'k': 4,
    'size': 14,
    'loadsaved': 1, # Ensure this is 1 to load saved model
    'log_dir': 'rim_model_dir'
}

model = LSTM(config)  # Instantiating LSTM with config
model_directory = model_path['LSTM']

# Set device
device = set_device()
model.to(device)

# Set the map_location based on whether CUDA is available
map_location = 'cuda' if torch.cuda.is_available() and config['cuda'] else 'cpu'

# Use torch.load with the map_location parameter
saved = torch.load(model_directory, map_location=map_location)
model.load_state_dict(saved['net'])

# Data
data = MnistData(config['batch_size'], (config['size'], config['size']), config['k'])

# Evaluation function
def test_model(model, loader, func):
    accuracy = 0
    loss = 0
    model.eval()

    print(f"Total validation samples: {loader.val_len()}")  # Print total number of validation samples

    with torch.no_grad():
        for i in tqdm(range(loader.val_len())):
            test_x, test_y = func(i)
            test_x = model.to_device(test_x)
            test_y = model.to_device(test_y).long()
            probs  = model(test_x)
            preds = torch.argmax(probs, dim=1)
            correct = preds == test_y
            accuracy += correct.sum().item()

    accuracy /= 100  # Correct count -> percentage (exact only if the validation set has 10,000 samples)
    return accuracy


# Evaluate on all three validation sets
validation_functions = [data.val_get1, data.val_get2, data.val_get3]
validation_accuracies_lstm = []

print(f"Model: {config['model']}, Device: {device}")
print(f"Configuration: {config}")

for func in validation_functions:
    accuracy = test_model(model, data, func)
    validation_accuracies_lstm.append(accuracy)

In [None]:
# Print accuracies for all validation sets (RIMs)
for i, accuracy in enumerate(validation_accuracies_rim, 1):
    print(f'Validation Set {i} Accuracy (RIMs): {accuracy:.2f}%')

# Print accuracies for all validation sets (LSTM)
for i, accuracy in enumerate(validation_accuracies_lstm, 1):
    print(f'Validation Set {i} Accuracy (LSTM): {accuracy:.2f}%')

As you can see, the two models reach similar accuracy on 16x16 images. However, RIMs generalize much more robustly to 19x19 and 24x24 images. To achieve this, the authors built recurrent networks that are modular by design: each module is independent of the others and interacts with them only sparsely through attention. In this way, each module can learn a different aspect of the environment and only needs to perform consistently on that same aspect when the environment changes.

---
# Section 2: Global Workspace

As we have seen, deep learning has shifted towards structured models with specialized modules that enhance scalability and generalization. But we can go one step further. Inspired by the 1980s AI focus on modular architectures and the Global Workspace Theory from cognitive neuroscience, the approach [https://arxiv.org/pdf/2103.01197.pdf] we are going to analyse in this section employs a shared global workspace for module coordination. It promotes flexibility and systematic generalization by allowing dynamic interactions among specialized modules. This model emphasizes the importance of having a number of sparsely communicating specialist modules interact via a shared working memory, aiming to achieve coherent and efficient behavior across the system. 

RIMs leverage a self-attention mechanism to share information among specialist modules through pairwise interactions, where each module attends to every other. This new approach instead introduces a shared workspace with limited capacity to streamline the process. At each computational step, specialist modules compete for the opportunity to write to this shared workspace. The information stored in the workspace is then broadcast to all specialists simultaneously, enhancing coordination and information flow among the modules without the need for direct pairwise communication.
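
As a rough illustration of why a limited-capacity workspace can streamline communication, the sketch below counts attention scores per step under each scheme. Treating the number of scores as a proxy for communication cost is an assumption made here for illustration, and both helper functions are hypothetical.

```python
def pairwise_interactions(num_specialists):
    # Every specialist attends to every other specialist
    return num_specialists * (num_specialists - 1)

def workspace_interactions(num_specialists, num_memory_slots):
    # Write phase: each slot scores each specialist.
    # Read phase: each specialist scores each slot.
    return 2 * num_specialists * num_memory_slots

for n in (5, 20, 100):
    print(n, pairwise_interactions(n), workspace_interactions(n, num_memory_slots=4))
```

With a fixed number of slots, the workspace count grows linearly in the number of specialists while the pairwise count grows quadratically, so the advantage only appears once there are enough specialists.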

## Coding Exercise: Creating a Shared Workspace

Specialists compete to write their information into the shared workspace. This process is guided by a key-query-value attention mechanism, where the competition is realized through attention scores determining which specialists' information is most critical to be updated in the workspace.

In [None]:
torch.manual_seed(42)  # Ensure reproducibility

In [None]:
class SharedWorkspace(nn.Module):

    def __init__(self, num_specialists, hidden_dim, num_memory_slots, memory_slot_dim):
        #################################################
        ## TODO for students: fill in the missing variables ##
        # Fill out function and remove
        raise NotImplementedError("Student exercise: fill in the missing variables")
        #################################################
        super().__init__()
        self.num_specialists = num_specialists
        self.hidden_dim = hidden_dim
        self.num_memory_slots = num_memory_slots
        self.memory_slot_dim = memory_slot_dim
        self.workspace_memory = nn.Parameter(torch.randn(num_memory_slots, memory_slot_dim))

        # Attention mechanism components for writing to the workspace
        self.key = ...
        self.query = ...
        self.value = nn.Linear(hidden_dim, memory_slot_dim)

    def write_to_workspace(self, specialists_states):
        #################################################
        ## TODO for students: fill in the missing variables ##
        # Fill out function and remove
        raise NotImplementedError("Student exercise: fill in the missing variables")
        #################################################
        # Flatten specialists' states if they're not already
        specialists_states = specialists_states.view(-1, self.hidden_dim)

        # Compute key, query, and value
        keys = self.key(specialists_states)
        query = self.query(self.workspace_memory)
        values = self.value(specialists_states)

        # Compute attention scores and apply softmax
        attention_scores = torch.matmul(query, keys.transpose(-2, -1)) / (self.memory_slot_dim ** 0.5)
        attention_probs = ...

        # Update workspace memory with weighted sum of values
        updated_memory = torch.matmul(attention_probs, values)
        self.workspace_memory = nn.Parameter(updated_memory)

        return self.workspace_memory

    def forward(self, specialists_states):
        #################################################
        ## TODO for students: fill in the missing variables ##
        # Fill out function and remove
        raise NotImplementedError("Student exercise: fill in the missing variables")
        #################################################
        updated_memory = ...
        return updated_memory

In [None]:
# to remove solution

class SharedWorkspace(nn.Module):

    def __init__(self, num_specialists, hidden_dim, num_memory_slots, memory_slot_dim):
        super().__init__()
        self.num_specialists = num_specialists
        self.hidden_dim = hidden_dim
        self.num_memory_slots = num_memory_slots
        self.memory_slot_dim = memory_slot_dim
        self.workspace_memory = nn.Parameter(torch.randn(num_memory_slots, memory_slot_dim))

        # Attention mechanism components for writing to the workspace
        self.key = nn.Linear(hidden_dim, memory_slot_dim)
        self.query = nn.Linear(memory_slot_dim, memory_slot_dim)
        self.value = nn.Linear(hidden_dim, memory_slot_dim)

    def write_to_workspace(self, specialists_states):
        # Flatten specialists' states if they're not already
        specialists_states = specialists_states.view(-1, self.hidden_dim)

        # Compute key, query, and value
        keys = self.key(specialists_states)
        query = self.query(self.workspace_memory)
        values = self.value(specialists_states)

        # Compute attention scores and apply softmax
        attention_scores = torch.matmul(query, keys.transpose(-2, -1)) / (self.memory_slot_dim ** 0.5)
        attention_probs = F.softmax(attention_scores, dim=-1)

        # Update workspace memory with weighted sum of values
        updated_memory = torch.matmul(attention_probs, values)
        self.workspace_memory = nn.Parameter(updated_memory)

        return self.workspace_memory

    def forward(self, specialists_states):
        updated_memory = self.write_to_workspace(specialists_states)
        return updated_memory

In [None]:
# Example parameters
num_specialists = 5
hidden_dim = 10
num_memory_slots = 4
memory_slot_dim = 6

# Generate specialists' states (reproducible given the seed set above)
specialists_states = torch.randn(num_specialists, hidden_dim)

workspace = SharedWorkspace(num_specialists, hidden_dim, num_memory_slots, memory_slot_dim)
expected_output = workspace(specialists_states)  # Calling the module invokes forward()
print("Expected Output:", expected_output)
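
To look at the competition itself rather than the resulting memory, the following NumPy sketch recomputes the write-step attention with random stand-in weights and reports which specialist receives the most attention from each memory slot. The weight matrices are hypothetical placeholders, not the layers of the `SharedWorkspace` instance above.

```python
import numpy as np

rng = np.random.default_rng(0)
num_specialists, hidden_dim = 5, 10
num_memory_slots, memory_slot_dim = 4, 6

# Random stand-ins for learned projections (hypothetical weights)
W_key = rng.normal(size=(hidden_dim, memory_slot_dim))
W_query = rng.normal(size=(memory_slot_dim, memory_slot_dim))

specialists = rng.normal(size=(num_specialists, hidden_dim))
memory = rng.normal(size=(num_memory_slots, memory_slot_dim))

# Scaled dot-product scores: one row per memory slot, one column per specialist
scores = (memory @ W_query) @ (specialists @ W_key).T / np.sqrt(memory_slot_dim)
scores -= scores.max(axis=-1, keepdims=True)          # numerical stability
attention = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)

# The argmax of each row is the specialist that "wins" that slot
winners = attention.argmax(axis=-1)
print("Attention (slots x specialists):\n", attention.round(2))
print("Winning specialist per slot:", winners)
```

Because the softmax runs over specialists, each row sums to one, so a specialist can only gain influence over a slot at the expense of the others.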

After updating the shared workspace with the most critical signals, this information is then broadcast back to all specialists. Each specialist updates its state using this broadcast information, which can involve an attention mechanism for consolidation and an update function (like an LSTM or GRU step) based on the new combined state. Let's add this method!

In [None]:
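def broadcast_from_workspace(self, specialists_states):
    # Broadcast updated memory to specialists.
    # Note: this reuses the write-step layers (self.query, self.key, self.value), whose
    # input sizes only line up when hidden_dim == memory_slot_dim.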
    broadcast_query = self.query(specialists_states).view(self.num_specialists, -1, self.memory_slot_dim)
    broadcast_keys = self.key(self.workspace_memory).unsqueeze(0).repeat(self.num_specialists, 1, 1)

    # Compute attention scores for broadcasting
    broadcast_attention_scores = torch.matmul(broadcast_query, broadcast_keys.transpose(-2, -1)) / (self.memory_slot_dim ** 0.5)
    broadcast_attention_probs = F.softmax(broadcast_attention_scores, dim=-1)

    # Update specialists' states with attention-weighted memory information
    broadcast_values = self.value(self.workspace_memory).unsqueeze(0).repeat(self.num_specialists, 1, 1)
    updated_states = torch.matmul(broadcast_attention_probs, broadcast_values)

    return updated_states.view_as(specialists_states)

# Assign the method to the class
SharedWorkspace.broadcast_from_workspace = broadcast_from_workspace
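
The method above reuses `self.query`, `self.key` and `self.value`, whose input sizes were chosen for the write step, so it appears to require `hidden_dim == memory_slot_dim`. One possible way around this, sketched below in NumPy, is to give the read step its own projections. The three `W_read_*` matrices are hypothetical and random, and this is an illustration under that assumption rather than the method used in the paper.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(1)
num_specialists, hidden_dim = 5, 10
num_memory_slots, memory_slot_dim = 4, 6

# Hypothetical read-side projections, separate from the write-side ones
W_read_query = rng.normal(size=(hidden_dim, memory_slot_dim))      # specialists -> queries
W_read_key = rng.normal(size=(memory_slot_dim, memory_slot_dim))   # memory -> keys
W_read_value = rng.normal(size=(memory_slot_dim, hidden_dim))      # memory -> values in hidden space

specialists = rng.normal(size=(num_specialists, hidden_dim))
memory = rng.normal(size=(num_memory_slots, memory_slot_dim))

queries = specialists @ W_read_query
keys = memory @ W_read_key
values = memory @ W_read_value

# Each specialist distributes its attention over the memory slots
read_attention = softmax(queries @ keys.T / np.sqrt(memory_slot_dim))   # (specialists, slots)
updated_specialists = read_attention @ values                            # (specialists, hidden_dim)
print(updated_specialists.shape)
```

Here the softmax runs over memory slots, so every specialist receives a mixture of the workspace contents in its own hidden dimension.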

This approach modularizes the shared workspace functionality: the specialists' states are first aggregated competitively into the workspace, and the consolidated information is then distributed back efficiently. This mechanism allows for dynamic filtering based on the current context and enhances the model's ability to generalize from past experiences by focusing on the most relevant signals at each computational step. To integrate this into a full system, you would need to instantiate this SharedWorkspace within your RIM architecture, ensuring that the initial representations of specialists are processed (Step 1), passed to the SharedWorkspace for competition and update (Step 2), and then the updated information is broadcast back to the specialists (Step 3).

---
# Section 3: A Toy Model for Illustrating GNW

Now, we outline a `SimpleGNWModel` class for simulating node activation within a network. It uses an Erdős-Rényi graph to model connections and includes methods to activate nodes, reset the network, and visualize the results. This setup provides an interactive introduction to network dynamics, making it easy to observe how activations spread across a network.

In the network visualization, the colors distinguish between active and inactive nodes. Active nodes are colored green, indicating they have been activated either directly or through their connection to another activated node. Inactive nodes are colored red, showing they have yet to be activated in the simulation. 

In [None]:
class SimpleGNWModel:
    def __init__(self, num_nodes=5):
        self.num_nodes = num_nodes
        self.network = nx.erdos_renyi_graph(n=num_nodes, p=0.5)
        self.activations = {node: False for node in self.network.nodes}

    def activate_node(self):
        selected_node = random.choice(list(self.network.nodes))
        self.activations[selected_node] = True

        # Simulate global broadcast
        for neighbor in self.network.neighbors(selected_node):
            self.activations[neighbor] = True

    def reset_activations(self):
        self.activations = {node: False for node in self.network.nodes}

    def draw_network(self):
        color_map = ['green' if self.activations[node] else 'red' for node in self.network.nodes]
        nx.draw(self.network, node_color=color_map, with_labels=True, node_size=700)
        plt.show()

# Create a GNW model instance
gnw_model = SimpleGNWModel()

# Button to activate a node
activate_button = widgets.Button(description='Activate Node')

# Button to reset activations
reset_button = widgets.Button(description='Reset')

# Output area for the network graph
output_area = widgets.Output()

def on_activate_clicked(b):
    with output_area:
        output_area.clear_output(wait=True)
        gnw_model.activate_node()
        gnw_model.draw_network()

def on_reset_clicked(b):
    with output_area:
        output_area.clear_output(wait=True)
        gnw_model.reset_activations()
        gnw_model.draw_network()

activate_button.on_click(on_activate_clicked)
reset_button.on_click(on_reset_clicked)

display(widgets.VBox([activate_button, reset_button, output_area]))

When you click on the "Activate Node" button, the code randomly selects one node in the network and activates it, changing its status to active (if it was inactive). This action also triggers a "global broadcast" effect, meaning that all of the selected node's immediate neighbors are activated as well. The network visualization then updates to reflect these changes: the activated nodes are colored green, while any nodes that remain inactive are colored red. This process visually demonstrates how activation can spread through a network, highlighting the connections between nodes and the potential influence of a single node's activation on its neighbors.
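
The same activation rule can be traced without widgets. The sketch below uses a small fixed graph, which is a hypothetical stand-in for the random Erdős-Rényi graph sampled by the class, and plain dictionaries in place of networkx.

```python
import random

# A small fixed graph (hypothetical; the class above samples a random graph instead)
neighbors = {0: [1, 2], 1: [0], 2: [0, 3], 3: [2], 4: []}
active = {node: False for node in neighbors}

def activate(node):
    """Activate a node and broadcast to its immediate neighbours only."""
    active[node] = True
    for neighbor in neighbors[node]:
        active[neighbor] = True

rng = random.Random(0)
for click in range(1, 4):
    activate(rng.choice(list(neighbors)))
    print(f"After click {click}: {sorted(n for n, on in active.items() if on)}")
```

Note that the broadcast reaches only one hop: activating node 0 turns on nodes 1 and 2 but not node 3, and an isolated node such as 4 can only be activated by being selected directly.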

## Ignition 

The idea of ignition in cognitive science refers to the moment a stimulus becomes strong enough to trigger a widespread activation across the network, simulating the threshold of conscious awareness.

The following class is equipped with a damping factor that controls how activation spreads through the network, mirroring the dampening effects seen in biological networks where not all signals lead to widespread activation. The network itself is constructed with nodes and directed edges to represent the flow of information, crucial for understanding how different parts of a cognitive system can ignite and sustain activation that leads to consciousness.

By simulating the propagation of a stimulus through the network for given durations, the model visualizes the ignition process. Activation levels are calculated at each step, showing how initial stimuli can either dissipate or amplify to achieve ignition across the network. 


In [None]:
class GlobalWorkspaceNetwork:
    """
    A class to model and visualize a simplified global workspace network.
    """

    def __init__(self, damping_factor=0.1):
        """
        Initialize the network with a specified damping factor.

        :param damping_factor: A float representing the damping factor applied to activation propagation.
        """
        if not 0 <= damping_factor <= 1:
            raise ValueError("Damping factor must be between 0 and 1.")

        self.network = nx.DiGraph()
        self.damping_factor = damping_factor
        self.setup_network()

    def setup_network(self, nodes=None, edges=None):
        """
        Set up the network structure.

        :param nodes: A list of nodes for the network. If None, a default set is used.
        :param edges: A list of tuples representing edges between nodes. If None, a default set is used.
        """
        default_nodes = ["sensory_input", "processing_module", "global_workspace", "visual_cortex"]
        default_edges = [
            ("sensory_input", "processing_module"),
            ("processing_module", "global_workspace"),
            ("global_workspace", "visual_cortex"),
            ("visual_cortex", "processing_module")
        ]
        self.network.add_nodes_from(nodes if nodes is not None else default_nodes)
        self.network.add_edges_from(edges if edges is not None else default_edges)

    def propagate_stimulus(self, duration):
        """
        Simulate the propagation of stimulus through the network for a given duration.

        :param duration: The number of steps to simulate.
        :return: A dictionary of activation levels for each node.
        """
        activation_levels = {node: 0 for node in self.network.nodes}
        activation_levels["sensory_input"] = 1

        for step in range(duration):
            new_activation_levels = activation_levels.copy()
            for node in self.network.nodes:
                total_input = sum(activation_levels[pre] for pre in self.network.predecessors(node))
                new_activation_levels[node] += total_input * (1 - self.damping_factor)
            activation_levels = new_activation_levels

        return activation_levels

    def visualize_network(self, activation_levels):
        """
        Visualize the network with node colors representing activation levels.

        :param activation_levels: A dictionary of activation levels for each node.
        """
        pos = nx.spring_layout(self.network, seed=42)  # For consistent layout
        nx.draw(self.network, pos, with_labels=True, node_size=700)
        nx.draw_networkx_nodes(
            self.network, pos, nodelist=list(activation_levels.keys()),
            node_color=[activation_levels[n] for n in activation_levels],
            cmap=plt.cm.viridis
        )
        plt.show()


    def interactive_propagation(self):
        """
        Create an interactive widget to explore the effects of stimulus duration on the network.
        """
        @interact(duration=IntSlider(min=1, max=10, step=1, value=1))
        def update(duration):
            activation_levels = self.propagate_stimulus(duration)
            self.visualize_network(activation_levels)

# Create an instance of the GlobalWorkspaceNetwork
gwn = GlobalWorkspaceNetwork(damping_factor=0.06)  # Feel free to adjust the damping factor

# Use the interactive widget to explore how different durations of stimulus affect the network
gwn.interactive_propagation()
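
To read the numbers behind the node colours, the update rule in `propagate_stimulus` can be reproduced with plain dictionaries. The sketch below uses the default edge list from `setup_network`; the function name `propagate` is introduced here only for illustration.

```python
edges = [
    ("sensory_input", "processing_module"),
    ("processing_module", "global_workspace"),
    ("global_workspace", "visual_cortex"),
    ("visual_cortex", "processing_module"),
]
nodes = sorted({node for edge in edges for node in edge})

def propagate(duration, damping_factor):
    """Same update rule as propagate_stimulus, written with plain dictionaries."""
    levels = {node: 0.0 for node in nodes}
    levels["sensory_input"] = 1.0
    for _ in range(duration):
        new_levels = dict(levels)
        # Each edge adds the damped activation of its source to its target
        for source, target in edges:
            new_levels[target] += levels[source] * (1 - damping_factor)
        levels = new_levels
    return levels

for damping in (0.0, 0.5, 1.0):
    print(damping, propagate(3, damping))
```

Because `sensory_input` stays at 1 and the updates are additive, activation accumulates over time for any damping factor below 1; with a damping factor of exactly 1 nothing propagates beyond the input node.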

---
# Section 4: Second Order Model

---
# Section 5: HOSS model

The following function is designed for inference within a simplified Bayesian framework, specifically tailored for assessing perceptual states based on observed data. It computes the posterior probabilities of these states and the Kullback-Leibler (KL) divergence between the posterior and prior distributions. This function operates under a model that assumes a flat (or single-layer) Bayesian network, focusing directly on the relationship between perceptual states and observed data. 

In [None]:
def HOSS_evaluate_flat(X, mu, Sigma, Wprior):
    """
    Perform inference on a 2D Bayes net for asymmetric inference on presence vs. absence.

    Parameters:
    X - Observed data
    mu - Means for each perceptual state
    Sigma - Covariance matrix
    Wprior - Prior probabilities of perceptual states

    Returns:
    post_W - Posterior probabilities of perceptual states
    KL_W - Kullback-Leibler divergence from posterior to prior
    """
    # Prior on perceptual states W
    p_W = Wprior

    # Compute log-likelihood of observed X for each possible W (log P(X|W));
    # logpdf avoids taking the log of a density that has underflowed to zero
    log_lik_X_W = np.array([multivariate_normal.logpdf(X, mean=mu[m], cov=Sigma) for m in range(mu.shape[0])])

    # Renormalize to get P(X|W)
    log_p_X_W = log_lik_X_W - logsumexp(log_lik_X_W)

    # Posterior over W (P(W|X=x))
    log_post_W = log_p_X_W + np.log(p_W)
    log_post_W = log_post_W - logsumexp(log_post_W)  # Normalize
    post_W = np.exp(log_post_W)

    # KL divergences
    KL_W = np.sum(post_W * (np.log(post_W) - np.log(p_W)))

    return post_W, KL_W
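
In equations, the function above computes the following, writing $\mu_w$ for the mean of state $w$ and $\Sigma$ for the shared covariance (the notation is introduced here for readability):

$$
P(W = w \mid X = x) = \frac{P(x \mid W = w)\, P(W = w)}{\sum_{w'} P(x \mid W = w')\, P(W = w')}, \qquad P(x \mid W = w) = \mathcal{N}(x;\, \mu_w, \Sigma)
$$

$$
\mathrm{KL}\big(P(W \mid x)\,\|\,P(W)\big) = \sum_{w} P(W = w \mid x)\, \log \frac{P(W = w \mid x)}{P(W = w)}
$$

The intermediate renormalization of the likelihood in the code rescales every state by the same constant, so it leaves the posterior unchanged.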

## Make our stimulus space

The model we are using is grounded in classical "signal detection theory", or SDT for short. SDT is in turn a special case of a Bayesian generative model, in which an arbitrary "evidence" value is drawn from an unknown distribution, and the task of the observer is to infer which distribution this evidence came from.

Let's imagine we have two categories, A and B - for instance, left- and right-tilted visual stimuli. 
The sensory "evidence" can be written as a 2D vector, where the first element is evidence for A, and the second element evidence for B:

In [None]:
# Creating the array X with strong evidence for A and weak evidence for B
X = np.array([1.5, 0])

The origin (0,0) represents low activation of both features, consistent with no stimulus (or noise) being presented. Comparing how the model handles inference on stimulus presence vs. absence - detecting, vs. not detecting a stimulus - allows us to capture the classical conscious vs. unconscious contrast in consciousness science.

Let's start by creating our space, and placing three Gaussian distributions on the space that represent the likelihood of observing a pair of features given each of three stimulus classes: leftward tilt (w1), rightward tilt (w2) and noise/nothing (w0).

In [None]:
# Define the grid
xgrid = np.arange(-4, 6.02, 0.02)
X1, X2 = np.meshgrid(xgrid, xgrid)

# Mean and covariance of the distributions
mu = np.array([[0.5, 0.5], [3.5, 0.5], [0.5, 3.5]])
Sigma = np.array([[1, 0], [0, 1]])

# Plotting
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')

# Colors and labels according to the specification
colors = ['green', 'blue', 'red']
labels = ['w0', 'w1', 'w2']

for i, (color, label) in enumerate(zip(colors, labels)):
    p = multivariate_normal.pdf(np.dstack((X1, X2)), mean=mu[i], cov=Sigma)
    ax.plot_surface(X1, X2, p.reshape(X1.shape), color=color, alpha=0.5, label=label)

# Create custom legends
legend_elements = [Patch(facecolor=color, edgecolor='k', label=label) for color, label in zip(colors, labels)]
ax.legend(handles=legend_elements, loc='upper right')

# Reverse the X1 axis
ax.set_xlim([6, -4])
ax.set_ylim([-4, 6])
ax.set_xlabel('X1')
ax.set_ylabel('X2')
ax.set_title('2D SDT')
ax.view_init(45, -45)

plt.show()
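
In signal detection theory, the separation between two distributions is often summarized by the sensitivity index d'. As a sketch, assuming a shared covariance, d' can be computed as the Mahalanobis distance between class means; the helper `d_prime` is a name introduced here for illustration.

```python
import numpy as np

mu = np.array([[0.5, 0.5], [3.5, 0.5], [0.5, 3.5]])   # w0 (noise), w1, w2, as in the plot above
Sigma = np.eye(2)

def d_prime(mean_a, mean_b, cov):
    """Mahalanobis distance between two class means under a shared covariance."""
    diff = mean_a - mean_b
    return float(np.sqrt(diff @ np.linalg.inv(cov) @ diff))

print("w1 vs noise:", d_prime(mu[1], mu[0], Sigma))
print("w2 vs noise:", d_prime(mu[2], mu[0], Sigma))
print("w1 vs w2:   ", d_prime(mu[1], mu[2], Sigma))
```

With these means, each stimulus class is equally separated from noise, while the two stimulus classes are further from each other than either is from noise.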

In [None]:
# Define the input parameters
mu = np.array([[3.5, 0.5], [0.5, 3.5], [0.5, 0.5]])  # order: w1, w2, w0 (noise); differs from the plotting cell
Sigma = np.array([[1, 0], [0, 1]])
Wprior = np.array([1/3, 1/3, 1/3])  # flat priors

# High evidence for X1, low evidence for X2
X = np.array([3, 0])
post_w, KL_W = HOSS_evaluate_flat(X, mu, Sigma, Wprior)
print('Posterior probabilities for X = [3, 0]:', post_w)
print('KL Divergence for X = [3, 0]:', KL_W)

# High evidence for X2, low evidence for X1
X = np.array([0, 3])
post_w, KL_W = HOSS_evaluate_flat(X, mu, Sigma, Wprior)
print('Posterior probabilities for X = [0, 3]:', post_w)
print('KL Divergence for X = [0, 3]:', KL_W)

# No evidence for either
X = np.array([0, 0])
post_w, KL_W = HOSS_evaluate_flat(X, mu, Sigma, Wprior)
print('Posterior probabilities for X = [0, 0]:', post_w)
print('KL Divergence for X = [0, 0]:', KL_W)

This is as we would expect - the most likely state is recovered in each case. The slightly higher KL divergence in the third scenario indicates a greater degree of "surprise" or information gain, as the prior was uniformly distributed across all states, but the posterior is now highly concentrated on the third state.
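
To make the link between posterior concentration and K-L divergence concrete, here is a minimal sketch. The helper name `kl_divergence` and the example posteriors are illustrative assumptions, not values produced by the model above. It shows that the divergence from a flat prior is zero when beliefs do not change and grows as the posterior concentrates on one state, up to a ceiling of log(3) for three states.

```python
import numpy as np

def kl_divergence(posterior, prior):
    """KL(posterior || prior) in nats, skipping zero-probability posterior entries."""
    posterior = np.asarray(posterior, dtype=float)
    prior = np.asarray(prior, dtype=float)
    nonzero = posterior > 0  # 0 * log(0) is treated as 0
    return float(np.sum(posterior[nonzero] * np.log(posterior[nonzero] / prior[nonzero])))

flat_prior = np.array([1/3, 1/3, 1/3])

# Hypothetical posteriors, ordered from "no update" to "highly concentrated"
unchanged = kl_divergence(flat_prior, flat_prior)
moderate = kl_divergence([0.6, 0.2, 0.2], flat_prior)
concentrated = kl_divergence([0.98, 0.01, 0.01], flat_prior)

print(f"no update:    {unchanged:.3f}")
print(f"moderate:     {moderate:.3f}")
print(f"concentrated: {concentrated:.3f}")
```

Because the prior is flat, the ordering of these values depends only on how peaked each posterior is, not on which state it favours.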

## Add in higher-order node for global detection
So far we have considered a "flat" architecture in which each state (w1, w2 or absence) is independent.
The key addition in HOSS is to allow the model to flexibly answer queries about awareness of any stimulus contents (w1 or w2... or wN).
We achieve this by introducing a higher-order node - the "A" level - that "monitors" for activations in the W-level below.

The inputs remain the same - pairs of X's. But now the outputs give us queries on both the W (content) and A (awareness) levels - where post_A denotes the posterior probability of any content (vs. noise).
We now also need to set priors at both the A- and W-levels.
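
As a rough sketch of what setting priors at both levels means, the two-level prior can be written as a joint table over (A, W). The variable names and the value 0.5 for both priors are illustrative assumptions. The table has one row per awareness state and one column per world state (the two stimulus classes, then absence).

```python
import numpy as np

# Illustrative prior values (assumptions for this sketch)
a_prior = 0.5                     # P(A = aware)
w_prior = np.array([0.5, 0.5])    # P(W | aware) over the two stimulus classes

p_A = np.array([1 - a_prior, a_prior])                       # [P(unaware), P(aware)]
p_W_given_aware = np.append(w_prior, 0.0)                    # absence impossible when aware
p_W_given_unaware = np.append(np.zeros(len(w_prior)), 1.0)   # only absence when unaware

# Joint prior P(A, W): rows index A (unaware, aware), columns index W
joint_prior = np.vstack((p_A[0] * p_W_given_unaware,
                         p_A[1] * p_W_given_aware))

# Marginal prior over W: sum the joint over the awareness states
p_W = joint_prior.sum(axis=0)

print(joint_prior)
print(p_W)
```

The zeros in the table carry the asymmetry of the model: awareness is compatible with several contents, while unawareness is compatible only with absence.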

### Coding exercise

```python
def HOSS_evaluate(X, mu, Sigma, Aprior, Wprior):
    """
    Inference on 2D Bayes net for asymmetric inference on presence vs. absence.
    """

    #################################################
    ## TODO for students: fill in the missing variables ##
    # Fill out function and remove
    raise NotImplementedError("Student exercise: fill in the missing variables")
    #################################################

    # Initialise variables and conditional prob tables
    p_A = np.array([1 - Aprior, Aprior])  # prior on awareness state A
    p_W_a1 = np.append(Wprior, 0)  # likelihood of world states W given aware, last entry is absence
    p_W_a0 = np.append(np.zeros(len(Wprior)), 1)  # likelihood of world states W given unaware, last entry is absence
    p_W = p_A[1] * p_W_a1 + p_A[0] * p_W_a0  # prior on W marginalising over A (for KL); equals the simple average when Aprior = 0.5

    # Compute likelihood of observed X for each possible W (P(X|mu_w, Sigma))
    lik_X_W = np.array([multivariate_normal.pdf(...) for mu_i in mu])
    p_X_W = lik_X_W / lik_X_W.sum()  # normalise to get P(X|W)

    # Combine with likelihood of each world state w given awareness state A
    lik_W_A = np.vstack((p_X_W * p_W_a0 * p_A[0], p_X_W * p_W_a1 * p_A[1]))
    post_A = ...  # sum over W
    post_A = post_A / post_A.sum()  # normalise

    # Posterior over W (P(W|X=x) marginalising over A)
    post_W = ...  # sum over A
    post_W = post_W / post_W.sum()  # normalise

    # KL divergences
    KL_W = (post_W * np.log(post_W / p_W)).sum()
    KL_A = (post_A * np.log(post_A / p_A)).sum()

    return post_W, post_A, KL_W, KL_A

```

In [None]:
# to_remove solution

def HOSS_evaluate(X, mu, Sigma, Aprior, Wprior):
    """
    Inference on 2D Bayes net for asymmetric inference on presence vs. absence.
    """

    # Initialise variables and conditional prob tables
    p_A = np.array([1 - Aprior, Aprior])  # prior on awareness state A
    p_W_a1 = np.append(Wprior, 0)  # likelihood of world states W given aware, last entry is absence
    p_W_a0 = np.append(np.zeros(len(Wprior)), 1)  # likelihood of world states W given unaware, last entry is absence
    p_W = p_A[1] * p_W_a1 + p_A[0] * p_W_a0  # prior on W marginalising over A (for KL); equals the simple average when Aprior = 0.5

    # Compute likelihood of observed X for each possible W (P(X|mu_w, Sigma))
    lik_X_W = np.array([multivariate_normal.pdf(X, mean=mu_i, cov=Sigma) for mu_i in mu])
    p_X_W = lik_X_W / lik_X_W.sum()  # normalise to get P(X|W)

    # Combine with likelihood of each world state w given awareness state A
    lik_W_A = np.vstack((p_X_W * p_W_a0 * p_A[0], p_X_W * p_W_a1 * p_A[1]))
    post_A = lik_W_A.sum(axis=1)  # sum over W
    post_A = post_A / post_A.sum()  # normalise

    # Posterior over W (P(W|X=x) marginalising over A)
    post_W = lik_W_A.sum(axis=0)  # sum over A
    post_W = post_W / post_W.sum()  # normalise

    # KL divergences
    KL_W = (post_W * np.log(post_W / p_W)).sum()
    KL_A = (post_A * np.log(post_A / p_A)).sum()

    return post_W, post_A, KL_W, KL_A

The priors are now factorised in the code: we first set the prior on presence (vs. absence), then the priors on w1 vs. w2, and the model takes care of the rest.

In [None]:
# Define the input parameters for this specific example
X = np.array([0, 3])  # Input observed features
Wprior = np.array([0.5, 0.5])  # Prior probabilities of stimuli
Aprior = 0.5  # Prior probability of being aware

# Call the HOSS_evaluate function with the specified parameters
post_W, post_A, KL_W, KL_A = HOSS_evaluate(X, mu, Sigma, Aprior, Wprior)

# Print the posterior probabilities
print(f"Posterior probabilities at W level: {post_W}")
print(f"Posterior probability at A level: {post_A}")

## Plot surfaces for content / awareness inferences
To explore the properties of the model, we can simulate inference at different levels of the hierarchy over the full 2D space of possible input X's. The left panel below shows that the probability of awareness (of any stimulus contents) rises in a graded manner from the lower left corner of the graph (low activation of any feature) to the upper right (high activation of both features). In contrast, the right panel shows that confidence in making a discrimination response (e.g. rightward vs. leftward) increases away from the major diagonal, as the model becomes sure that the sample was generated by either a leftward or rightward tilted stimulus. 

Together, the two surfaces make predictions about the relationships we might see between discrimination confidence and awareness in a simple psychophysics experiment. One notable prediction is that discrimination could still be possible - and lead to some degree of confidence - even when the higher-order node is "reporting" unawareness of the stimulus.
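
The dissociation described above can be checked at a single point. This is a minimal sketch under stated assumptions: the sample `x_demo` is a hypothetical weak input chosen for illustration, and `p_first_vs_second` (confidence restricted to the two stimulus classes) is a quantity introduced here, not the `confW` measure plotted in this tutorial. With an identity covariance the Gaussian normalising constant cancels, so unnormalised likelihoods are enough.

```python
import numpy as np

# Same means as the grid simulation; rows: two stimulus classes, then absence
mu_demo = np.array([[0.5, 1.5], [1.5, 0.5], [0.5, 0.5]])
a_prior = 0.5
w_prior = np.array([0.5, 0.5])

x_demo = np.array([0.2, 1.0])  # hypothetical weak sample, nearer the first class

# Unnormalised Gaussian likelihoods with identity covariance
lik = np.exp(-0.5 * ((x_demo - mu_demo) ** 2).sum(axis=1))

# Unnormalised joint over (A, W) for the aware and unaware rows
joint_aware = a_prior * np.append(w_prior, 0.0) * lik
joint_unaware = (1 - a_prior) * np.array([0.0, 0.0, 1.0]) * lik

# Posterior probability of awareness
p_aware = joint_aware.sum() / (joint_aware.sum() + joint_unaware.sum())

# Discrimination confidence between the two stimulus classes only
p_first_vs_second = joint_aware[0] / joint_aware[:2].sum()

print(f"P(aware | x)              = {p_aware:.2f}")
print(f"P(first class | stimulus) = {p_first_vs_second:.2f}")
```

At this point the awareness posterior falls below 0.5 while the discrimination between the two stimulus classes is still better than chance, which is the pattern the paragraph above predicts.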

In [None]:
# Define the grid
xgrid = np.arange(0, 2.01, 0.01)

# Define the means for the Gaussian distributions
mu = np.array([[0.5, 1.5], [1.5, 0.5], [0.5, 0.5]])

# Define the covariance matrix
Sigma = np.array([[1, 0], [0, 1]])

# Prior probabilities
Wprior = np.array([0.5, 0.5])
Aprior = 0.5

# Initialize arrays to hold confidence and posterior probability
confW = np.zeros((len(xgrid), len(xgrid)))
posteriorAware = np.zeros((len(xgrid), len(xgrid)))
KL_w = np.zeros((len(xgrid), len(xgrid)))
KL_A = np.zeros((len(xgrid), len(xgrid)))

# Compute confidence and posterior probability for each point in the grid
for i, xi in enumerate(xgrid):
    for j, xj in enumerate(xgrid):
        X = [xi, xj]
        post_w, post_A, KL_w[i, j], KL_A[i, j] = HOSS_evaluate(X, mu, Sigma, Aprior, Wprior)

        confW[i, j] = max(post_w)
        posteriorAware[i, j] = post_A[1]

# Plotting
plt.figure(figsize=(10, 5))

# Posterior probability "seen"
plt.subplot(1, 2, 1)
plt.contourf(xgrid, xgrid, posteriorAware.T)
plt.colorbar()
plt.xlabel('X1')
plt.ylabel('X2')
plt.title('Posterior probability "seen"')
plt.axis('square')

# Confidence in identity
plt.subplot(1, 2, 2)
contour_set = plt.contourf(xgrid, xgrid, confW.T)
plt.colorbar()
plt.contour(xgrid, xgrid, posteriorAware.T, levels=[0.5], linewidths=4, colors=['white'])  # Line contour for threshold
plt.xlabel('X1')
plt.ylabel('X2')
plt.title('Confidence in identity')
plt.axis('square')

plt.show()

## Simulate KL divergence surfaces

We can also simulate K-L divergences (a measure of Bayesian surprise) at each layer in the network, which, under predictive coding models of the brain, have been proposed to scale with neural activation (e.g., Friston, 2005; Summerfield & de Lange, 2014).

In [None]:
# Calculate the mean K-L divergence for absent and present awareness states
KL_A_absent = np.mean(KL_A[posteriorAware < 0.5])
KL_A_present = np.mean(KL_A[posteriorAware >= 0.5])
KL_w_absent = np.mean(KL_w[posteriorAware < 0.5])
KL_w_present = np.mean(KL_w[posteriorAware >= 0.5])

# Plotting
plt.figure(figsize=(18, 6))

# K-L divergence, perceptual states
plt.subplot(1, 2, 1)
plt.contourf(xgrid, xgrid, KL_w.T, cmap='viridis')
plt.colorbar()
plt.xlabel('X1')
plt.ylabel('X2')
plt.title('K-L divergence, perceptual states')
plt.axis('square')

# K-L divergence, awareness state
plt.subplot(1, 2, 2)
plt.contourf(xgrid, xgrid, KL_A.T, cmap='viridis')
plt.colorbar()
plt.xlabel('X1')
plt.ylabel('X2')
plt.title('K-L divergence, awareness state')
plt.axis('square')

plt.show()

## Discussion point

Can you recognise the difference between the K-L divergence for the W-level and the one for the A-level?

#### Answer below 

At the level of perceptual states W, there is substantial asymmetry in the K-L divergence expected when the model says ‘seen’ vs. ‘unseen’ (lefthand panel). This is due to the large belief updates invoked in the perceptual layer W by samples that deviate from the lower lefthand corner - from absence. In contrast, when we compute K-L divergence for the A-level (righthand panel), the level of prediction error is symmetric across seen and unseen decisions, leading to "hot" zones both at the upper righthand (present) and lower lefthand (absent) corners of the 2D space.

We can also sort the K-L divergences as a function of whether the model "reported" presence or absence. As can be seen in the bar plots below, there is more asymmetry in the prediction error at the W level than at the A level.

In [None]:
# Create figure with specified size
plt.figure(figsize=(12, 4))

# KL divergence for W states
plt.subplot(1, 2, 1)
plt.bar(['unseen', 'seen'], [KL_w_absent, KL_w_present], color='k')
plt.ylabel('K-L divergence, W states')
plt.xticks(fontsize=18)
plt.yticks(fontsize=18)

# KL divergence for A states
plt.subplot(1, 2, 2)
plt.bar(['unseen', 'seen'], [KL_A_absent, KL_A_present], color='k')
plt.ylabel('K-L divergence, A states')
plt.xticks(fontsize=18)
plt.yticks(fontsize=18)

plt.tight_layout()

# Show plot
plt.show()

## Simulate ignition (asymmetry vs. symmetry)

A notable feature of the HOSS architecture is that it is asymmetric: there are naturally more possible (perceptual) states nested under "presence" than under "absence".
As we saw in the previous section, this asymmetry in the state space suggests there will be greater summed prediction error in the entire network on presence decisions (as summarized by K-L divergence at each node of W). This may be a computational correlate of the global ignition responses often found to track awareness reports, which are often interpreted as supporting global workspace models (e.g. Del Cul et al. 2007; Dehaene and Changeux 2011).
We can simulate this by asking how the seen vs. unseen difference in prediction error at each of the two levels (shown in the previous plots) changes as a function of stimulus strength, modeled here as sensory precision.
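
To build intuition for why sensory precision can stand in for stimulus strength, the sketch below draws samples at a low and a high precision and checks how often each sample lies closest to the mean that generated it. The helper `nearest_mean_accuracy`, the sample size and the seed are assumptions made for illustration. The precision-to-covariance mapping copies the one used in this tutorial's simulation.

```python
import numpy as np

rng = np.random.default_rng(0)  # fixed seed so the sketch is reproducible

# Class means: two stimulus classes, then absence
mu_demo = np.array([[3.5, 0.5], [0.5, 3.5], [0.5, 0.5]])
n_per_class = 300

def nearest_mean_accuracy(precision):
    """Fraction of samples whose nearest class mean is the one that generated them."""
    # Same mapping from precision to covariance as the tutorial's simulation
    cov = np.diag([1.0 / np.sqrt(precision)] * 2)
    correct = 0
    for k, mean in enumerate(mu_demo):
        samples = rng.multivariate_normal(mean, cov, size=n_per_class)
        # Squared distance from every sample to every class mean
        d2 = ((samples[:, None, :] - mu_demo[None, :, :]) ** 2).sum(axis=2)
        correct += np.sum(d2.argmin(axis=1) == k)
    return correct / (n_per_class * len(mu_demo))

acc_low = nearest_mean_accuracy(0.1)
acc_high = nearest_mean_accuracy(10.0)

print(f"precision 0.1:  accuracy {acc_low:.2f}")
print(f"precision 10.0: accuracy {acc_high:.2f}")
```

Higher precision shrinks the spread of the samples around each mean, so the three classes overlap less and the input carries more information about which state generated it.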

In [None]:
# Experiment parameters
mu = np.array([[3.5, 0.5], [0.5, 3.5], [0.5, 0.5]])
Nsubjects = 30
Ntrials = 600
cond = np.concatenate((np.ones(Ntrials//3), np.ones(Ntrials//3)*2, np.ones(Ntrials//3)*3))
Wprior = [0.5, 0.5]
Aprior = 0.5

# Sensory precision values
gamma = np.linspace(0.1, 10, 6)

# Initialize lists for results
all_KL_w_yes = []
sem_KL_w_yes = []
all_KL_w_no = []
sem_KL_w_no = []
all_KL_A_yes = []
sem_KL_A_yes = []
all_KL_A_no = []
sem_KL_A_no = []
all_prob_y = []

#################################################
## TODO for students: fill in the missing variables ##
# Fill out function and remove
raise NotImplementedError("Student exercise: fill in the missing variables")
#################################################

for y in ...:
    Sigma = np.diag([1./np.sqrt(y)]*2)
    mean_KL_w = np.zeros((Nsubjects, 4))
    mean_KL_A = np.zeros((Nsubjects, 4))
    prob_y = np.zeros(Nsubjects)

    for s in range(Nsubjects):
        KL_w = np.zeros(len(cond))
        KL_A = np.zeros(len(cond))
        posteriorAware = np.zeros(len(cond))

        # Generate sensory samples
        X = np.array([multivariate_normal.rvs(mean=mu[int(c)-1, :], cov=Sigma) for c in cond])

        # Model inversion for each trial
        for i, x in enumerate(X):
            post_w, post_A, KL_w[i], KL_A[i] = HOSS_evaluate(x, mu, Sigma, Aprior, Wprior)
            posteriorAware[i] = post_A[1]  # post_A is [P(unaware), P(aware)]; keep P(aware)

        binaryAware = posteriorAware > 0.5
        for i in range(4):
            conditions = [(cond == 3), (cond != 3), (cond != 3), (cond == 3)]
            aware_conditions = [(binaryAware == 0), (binaryAware == 0), (binaryAware == 1), (binaryAware == 1)]
            mean_KL_w[s, i] = np.mean(KL_w[np.logical_and(aware_conditions[i], conditions[i])])
            mean_KL_A[s, i] = np.mean(KL_A[np.logical_and(aware_conditions[i], conditions[i])])

        prob_y[s] = np.mean(binaryAware[cond != 3])

    # Aggregate results across subjects
    # NaN-aware reductions: a subject with no trials in a cell contributes NaN, which is skipped
    all_KL_w_yes.append(np.nanmean(mean_KL_w[:, 2:4].flatten()))
    sem_KL_w_yes.append(np.nanstd(mean_KL_w[:, 2:4].flatten()) / np.sqrt(Nsubjects))
    all_KL_w_no.append(np.nanmean(mean_KL_w[:, :2].flatten()))
    sem_KL_w_no.append(np.nanstd(mean_KL_w[:, :2].flatten()) / np.sqrt(Nsubjects))
    all_KL_A_yes.append(np.nanmean(mean_KL_A[:, 2:4].flatten()))
    sem_KL_A_yes.append(np.nanstd(mean_KL_A[:, 2:4].flatten()) / np.sqrt(Nsubjects))
    all_KL_A_no.append(np.nanmean(mean_KL_A[:, :2].flatten()))
    sem_KL_A_no.append(np.nanstd(mean_KL_A[:, :2].flatten()) / np.sqrt(Nsubjects))
    all_prob_y.append(np.mean(prob_y))

# Create figure
plt.figure(figsize=(16, 4.67))

# First subplot: Probability of reporting "seen" for w_1 or w_2
plt.subplot(1, 3, 1)
plt.plot(gamma, all_prob_y, linewidth=2)
plt.xlabel('SOA (precision)')
plt.ylabel('Prob. report "seen" for w_1 or w_2')
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
plt.box(False)

# Second subplot: K-L divergence, perceptual states
plt.subplot(1, 3, 2)
plt.errorbar(gamma, all_KL_w_yes, yerr=sem_KL_w_yes, linewidth=2, label='Seen')
plt.errorbar(gamma, all_KL_w_no, yerr=sem_KL_w_no, linewidth=2, label='Unseen')
plt.legend(frameon=False)
plt.xlabel('SOA (precision)')
plt.ylabel('K-L divergence, perceptual states')
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
plt.box(False)

# Third subplot: K-L divergence, awareness state
plt.subplot(1, 3, 3)
plt.errorbar(gamma, all_KL_A_yes, yerr=sem_KL_A_yes, linewidth=2, label='Seen')
plt.errorbar(gamma, all_KL_A_no, yerr=sem_KL_A_no, linewidth=2, label='Unseen')
plt.legend(frameon=False)
plt.xlabel('SOA (precision)')
plt.ylabel('K-L divergence, awareness state')
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
plt.box(False)

# Adjust layout and display the figure
plt.tight_layout()
plt.show()

In [None]:
# to_remove solution

# Experiment parameters
mu = np.array([[3.5, 0.5], [0.5, 3.5], [0.5, 0.5]])
Nsubjects = 30
Ntrials = 600
cond = np.concatenate((np.ones(Ntrials//3), np.ones(Ntrials//3)*2, np.ones(Ntrials//3)*3))
Wprior = [0.5, 0.5]
Aprior = 0.5

# Sensory precision values
gamma = np.linspace(0.1, 10, 6)

# Initialize lists for results
all_KL_w_yes = []
sem_KL_w_yes = []
all_KL_w_no = []
sem_KL_w_no = []
all_KL_A_yes = []
sem_KL_A_yes = []
all_KL_A_no = []
sem_KL_A_no = []
all_prob_y = []

for y in gamma:
    Sigma = np.diag([1./np.sqrt(y)]*2)
    mean_KL_w = np.zeros((Nsubjects, 4))
    mean_KL_A = np.zeros((Nsubjects, 4))
    prob_y = np.zeros(Nsubjects)

    for s in range(Nsubjects):
        KL_w = np.zeros(len(cond))
        KL_A = np.zeros(len(cond))
        posteriorAware = np.zeros(len(cond))

        # Generate sensory samples
        X = np.array([multivariate_normal.rvs(mean=mu[int(c)-1, :], cov=Sigma) for c in cond])

        # Model inversion for each trial
        for i, x in enumerate(X):
            post_w, post_A, KL_w[i], KL_A[i] = HOSS_evaluate(x, mu, Sigma, Aprior, Wprior)
            posteriorAware[i] = post_A[1]  # post_A is [P(unaware), P(aware)]; keep P(aware)

        binaryAware = posteriorAware > 0.5
        for i in range(4):
            conditions = [(cond == 3), (cond != 3), (cond != 3), (cond == 3)]
            aware_conditions = [(binaryAware == 0), (binaryAware == 0), (binaryAware == 1), (binaryAware == 1)]
            mean_KL_w[s, i] = np.mean(KL_w[np.logical_and(aware_conditions[i], conditions[i])])
            mean_KL_A[s, i] = np.mean(KL_A[np.logical_and(aware_conditions[i], conditions[i])])

        prob_y[s] = np.mean(binaryAware[cond != 3])

    # Aggregate results across subjects
    # NaN-aware reductions: a subject with no trials in a cell contributes NaN, which is skipped
    all_KL_w_yes.append(np.nanmean(mean_KL_w[:, 2:4].flatten()))
    sem_KL_w_yes.append(np.nanstd(mean_KL_w[:, 2:4].flatten()) / np.sqrt(Nsubjects))
    all_KL_w_no.append(np.nanmean(mean_KL_w[:, :2].flatten()))
    sem_KL_w_no.append(np.nanstd(mean_KL_w[:, :2].flatten()) / np.sqrt(Nsubjects))
    all_KL_A_yes.append(np.nanmean(mean_KL_A[:, 2:4].flatten()))
    sem_KL_A_yes.append(np.nanstd(mean_KL_A[:, 2:4].flatten()) / np.sqrt(Nsubjects))
    all_KL_A_no.append(np.nanmean(mean_KL_A[:, :2].flatten()))
    sem_KL_A_no.append(np.nanstd(mean_KL_A[:, :2].flatten()) / np.sqrt(Nsubjects))
    all_prob_y.append(np.mean(prob_y))

# Create figure
plt.figure(figsize=(16, 4.67))

# First subplot: Probability of reporting "seen" for w_1 or w_2
plt.subplot(1, 3, 1)
plt.plot(gamma, all_prob_y, linewidth=2)
plt.xlabel('SOA (precision)')
plt.ylabel('Prob. report "seen" for w_1 or w_2')
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
plt.box(False)

# Second subplot: K-L divergence, perceptual states
plt.subplot(1, 3, 2)
plt.errorbar(gamma, all_KL_w_yes, yerr=sem_KL_w_yes, linewidth=2, label='Seen')
plt.errorbar(gamma, all_KL_w_no, yerr=sem_KL_w_no, linewidth=2, label='Unseen')
plt.legend(frameon=False)
plt.xlabel('SOA (precision)')
plt.ylabel('K-L divergence, perceptual states')
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
plt.box(False)

# Third subplot: K-L divergence, awareness state
plt.subplot(1, 3, 3)
plt.errorbar(gamma, all_KL_A_yes, yerr=sem_KL_A_yes, linewidth=2, label='Seen')
plt.errorbar(gamma, all_KL_A_no, yerr=sem_KL_A_no, linewidth=2, label='Unseen')
plt.legend(frameon=False)
plt.xlabel('SOA (precision)')
plt.ylabel('K-L divergence, awareness state')
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
plt.box(False)

# Adjust layout and display the figure
plt.tight_layout()
plt.show()

## Discussion point

Can you think of experiments that could distinguish between the HOSS and global workspace (GWS) accounts of ignition?
