<a href="https://colab.research.google.com/github/neuromatch/NeuroAI_Course/blob/main/tutorials/W2D2_CognitiveStructures/student/W2D2_Tutorial2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a> &nbsp; <a href="https://kaggle.com/kernels/welcome?src=https://raw.githubusercontent.com/neuromatch/NeuroAI_Course/blob/main/tutorials/W2D2_CognitiveStructures/student/W2D2_Tutorial2.ipynb" target="_parent"><img src="https://kaggle.com/static/images/open-in-kaggle.svg" alt="Open in Kaggle"/></a>

# Tutorial 2: Learning from structures

**Week 2, Day 2: Neuro-Symbolic Structures**

**By Neuromatch Academy**

__Content creators:__ Michael Furlong

__Content reviewers:__ Hlib Solodzhuk

__Production editors:__ Konstantine Tsafatinos, Ella Batty, Spiros Chavlis, Samuele Bolotta, Hlib Solodzhuk


___


# Tutorial Objectives

*Estimated timing of tutorial: 50 minutes*

This tutorial presents a few toy examples that use the basic operations (binding, unbinding, bundling, and cleanup) and shows how they let us generalize to new knowledge.

In [None]:
# @title Tutorial slides
# @markdown These are the slides for the videos in all tutorials today

from IPython.display import IFrame
link_id = "kj6p3"

print(f"If you want to download the slides: 'https://osf.io/download/{link_id}'")

IFrame(src=f"https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{link_id}/?direct%26mode=render", width=854, height=480)

---
# Setup



In [None]:
# @title Install and import feedback gadget

# !pip3 install vibecheck datatops --quiet

# from vibecheck import DatatopsContentReviewContainer
# def content_review(notebook_section: str):
#     return DatatopsContentReviewContainer(
#         "",  # No text prompt - leave this as is
#         notebook_section,
#         {
#             "url": "https://pmyvdlilci.execute-api.us-east-1.amazonaws.com/klab",
#             "name": "sciencematch_sm", # change the name of the course : neuromatch_dl, climatematch_ct, etc
#             "user_key": "y1x3mpx5",
#         },
#     ).render()

# feedback_prefix = "W2D2_T2"

In [None]:
# @title Install dependencies
# @markdown

# Install sspspace
!pip install git+https://github.com/ctn-waterloo/sspspace@neuromatch --quiet

In [None]:
# Imports

#working with data
import numpy as np

#plotting
import matplotlib.pyplot as plt
import logging

#interactive display
import ipywidgets as widgets

#modeling
import sspspace
from scipy.special import softmax
from sklearn.metrics import log_loss
from sklearn.neural_network import MLPRegressor

In [None]:
# @title Figure settings

logging.getLogger('matplotlib.font_manager').disabled = True

%matplotlib inline
%config InlineBackend.figure_format = 'retina' # perform high definition rendering for images and plots
plt.style.use("https://raw.githubusercontent.com/NeuromatchAcademy/course-content/main/nma.mplstyle")

In [None]:
# @title Plotting functions

def plot_similarity_matrix(sim_mat, labels, values = False):
    """
    Plot the similarity matrix between vectors.

    Inputs:
    - sim_mat (numpy.ndarray): similarity matrix between vectors.
    - labels (list of str): list of strings which represent concepts.
    - values (bool): True if we would like to plot values of similarity too.
    """
    with plt.xkcd():
        plt.imshow(sim_mat, cmap='Greys')
        plt.colorbar()
        plt.xticks(np.arange(len(labels)), labels, rotation=45, ha="right", rotation_mode="anchor")
        plt.yticks(np.arange(len(labels)), labels)
        if values:
            for x in range(sim_mat.shape[1]):
                for y in range(sim_mat.shape[0]):
                    plt.text(x, y, f"{sim_mat[y, x]:.2f}", fontsize = 8, ha="center", va="center", color="green")
        plt.title('Similarity between vector-symbols')
        plt.xlabel('Symbols')
        plt.ylabel('Symbols')
        plt.show()

def plot_training_and_choice(losses, sims, ant_names, cons_names, action_names):
    """
    Plot loss progression over training as well as predicted similarities for given rules / correct solutions.

    Inputs:
    - losses (list): list of loss values.
    - sims (list): list of similarity matrices.
    - ant_names (list): list of antecedent names.
    - cons_names (list): list of consequent names.
    - action_names (list): full list of concepts.
    """
    with plt.xkcd():
        plt.subplot(1, len(ant_names) + 1, 1)
        plt.plot(losses)
        plt.xlabel('Training number')
        plt.ylabel('Loss')
        plt.title('Training Error')
        index = 1
        for ant_name, cons_name, sim in zip(ant_names, cons_names, sims):
            index += 1
            plt.subplot(1, len(ant_names) + 1, index)
            plt.bar(range(len(action_names)), sim.flatten())
            plt.gca().set_xticks(range(len(action_names)))
            plt.gca().set_xticklabels(action_names, rotation=90)
            plt.title(f'{ant_name}, not*{cons_name}')

def plot_choice(sims, ant_names, cons_names, action_names):
    """
    Plot predicted similarities for given rules / correct solutions.
    """
    with plt.xkcd():
        index = 0
        for ant_name, cons_name, sim in zip(ant_names, cons_names, sims):
            index += 1
            plt.subplot(1, len(ant_names) + 1, index)
            plt.bar(range(len(action_names)), sim.flatten())
            plt.gca().set_xticks(range(len(action_names)))
            plt.gca().set_xticklabels(action_names, rotation=90)
            plt.ylabel("Similarity")
            plt.title(f'{ant_name}, not*{cons_name}')

In [None]:
# @title Set random seed

import random
import numpy as np

def set_seed(seed=None):
  if seed is None:
    seed = np.random.choice(2 ** 32)
  random.seed(seed)
  np.random.seed(seed)

set_seed(seed = 42)

---

# Section 1: Analogies. Part 1

In this section we will construct a simple analogy using Vector Symbolic Algebras. The question we are going to try and solve is "King is to queen as prince is to X".

In [None]:
# @title Video 1: Analogy 1

from ipywidgets import widgets
from IPython.display import YouTubeVideo
from IPython.display import IFrame
from IPython.display import display

class PlayVideo(IFrame):
  def __init__(self, id, source, page=1, width=400, height=300, **kwargs):
    self.id = id
    if source == 'Bilibili':
      src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'
    elif source == 'Osf':
      src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'
    super(PlayVideo, self).__init__(src, width, height, **kwargs)

def display_videos(video_ids, W=400, H=300, fs=1):
  tab_contents = []
  for i, video_id in enumerate(video_ids):
    out = widgets.Output()
    with out:
      if video_ids[i][0] == 'Youtube':
        video = YouTubeVideo(id=video_ids[i][1], width=W,
                             height=H, fs=fs, rel=0)
        print(f'Video available at https://youtube.com/watch?v={video.id}')
      else:
        video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,
                          height=H, fs=fs, autoplay=False)
        if video_ids[i][0] == 'Bilibili':
          print(f'Video available at https://www.bilibili.com/video/{video.id}')
        elif video_ids[i][0] == 'Osf':
          print(f'Video available at https://osf.io/{video.id}')
      display(video)
    tab_contents.append(out)
  return tab_contents

video_ids = [('Youtube', 'qOoUEpIkV6w')]
tab_contents = display_videos(video_ids, W=854, H=480)
tabs = widgets.Tab()
tabs.children = tab_contents
for i in range(len(tab_contents)):
  tabs.set_title(i, video_ids[i][0])
display(tabs)

In [None]:
# @title Submit your feedback
# content_review(f"{feedback_prefix}_analogy_one")

## Coding Exercise 1: Royal Relationships

We're going to start by considering our vocabulary.  We will use the basic discrete concepts of monarch, heir, male and female.

In [None]:
set_seed(42)

symbol_names = ['monarch','heir','male','female']
discrete_space = sspspace.DiscreteSPSpace(symbol_names, ssp_dim=1024, optimize=False)

objs = {n:discrete_space.encode(n) for n in symbol_names}

Now let's create the objects we know about by combinatorially expanding the space:

1. King is a male monarch
2. Queen is a female monarch
3. Prince is a male heir
4. Princess is a female heir

Complete the missing parts of the code to obtain correct representations of new concepts.

```python
###################################################################
## Fill out the following then remove
raise NotImplementedError("Student exercise: complete correct relations for creating new concepts.")
###################################################################

objs['king'] = objs['monarch'] * objs['male']
objs['queen'] = objs['monarch'] * ...
objs['prince'] = objs['heir'] * objs['male']
objs['princess'] = ... * objs['female']

```

In [None]:
#to_remove solution

objs['king'] = objs['monarch'] * objs['male']
objs['queen'] = objs['monarch'] * objs['female']
objs['prince'] = objs['heir'] * objs['male']
objs['princess'] = objs['heir'] * objs['female']

Now we can take an explicit approach. We know that converting king to queen means unbinding male and binding female, so let's apply that to our prince object and see what we uncover.

First, in the cell below, let's recover `queen` from `king` by constructing a new `query` concept that unbinds `male` and binds `female`.

```python
###################################################################
## Fill out the following then remove
raise NotImplementedError("Student exercise: complete correct relation for creating `query` object to compare with `queen`.")
###################################################################

objs['query'] = (objs[...] * ~objs[...]) * objs[...]

```

In [None]:
#to_remove solution

objs['query'] = (objs['king'] * ~objs['male']) * objs['female']

Let's see if this new query object bears any similarity to anything in our vocabulary.

In [None]:
object_names = list(objs.keys())
sims = np.zeros((len(object_names), len(object_names)))

for name_idx, name in enumerate(object_names):
    for other_idx in range(name_idx, len(object_names)):
        sims[name_idx, other_idx] = sims[other_idx, name_idx] = (objs[name] | objs[object_names[other_idx]]).item()

plot_similarity_matrix(sims, object_names, values = True)

The similarity plot above shows that this operation successfully converts king to queen. Let's apply it to 'prince' and see what happens. Now `query` should represent the `princess` concept.

In [None]:
objs['query'] = (objs['prince'] * ~objs['male']) * objs['female']

sims = np.zeros((len(object_names), len(object_names)))

for name_idx, name in enumerate(object_names):
    for other_idx in range(name_idx, len(object_names)):
        sims[name_idx, other_idx] = sims[other_idx, name_idx] = (objs[name] | objs[object_names[other_idx]]).item()

plot_similarity_matrix(sims, object_names, values = True)

Here we have successfully recovered princess, completing the analogy.
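To make the unbind-then-bind step concrete, here is a minimal NumPy sketch that is independent of `sspspace`. It assumes that binding is circular convolution (as in HRR) and that symbols are random unitary vectors, so the inverse is exact. The helper names `random_unitary`, `bind`, and `inverse` are illustrative and are not part of the `sspspace` API.

```python
import numpy as np

def random_unitary(dim, rng):
    """Random real unit vector whose Fourier coefficients all have magnitude 1."""
    phases = rng.uniform(-np.pi, np.pi, dim // 2 + 1)
    coeffs = np.exp(1j * phases)
    coeffs[0] = 1.0   # DC term must be real
    coeffs[-1] = 1.0  # Nyquist term must be real (dim is even)
    return np.fft.irfft(coeffs, n=dim)

def bind(a, b):
    """Circular convolution, computed in the Fourier domain."""
    return np.fft.irfft(np.fft.rfft(a) * np.fft.rfft(b), n=len(a))

def inverse(a):
    """Involution: keep the first element, reverse the rest."""
    return np.roll(a[::-1], 1)

rng = np.random.default_rng(0)
dim = 1024
monarch, heir, male, female = (random_unitary(dim, rng) for _ in range(4))

king = bind(monarch, male)
queen = bind(monarch, female)
prince = bind(heir, male)
princess = bind(heir, female)

# Unbind 'male', then bind 'female'.
query_from_king = bind(bind(king, inverse(male)), female)
query_from_prince = bind(bind(prince, inverse(male)), female)

sim_queen = float(query_from_king @ queen)
sim_princess = float(query_from_prince @ princess)
sim_unrelated = float(query_from_king @ princess)
print(f"king->queen: {sim_queen:.3f}, prince->princess: {sim_princess:.3f}, "
      f"king query vs princess: {sim_unrelated:.3f}")
```

With unitary vectors the transformed king matches queen up to floating-point error, while its similarity to an unrelated concept such as princess stays close to zero.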

This approach, however, requires explicit knowledge of how the objects were constructed. Let's see if we can work with the concepts 'king', 'queen', and 'prince' directly.

In the cell below, construct the `princess` concept using only `king`, `queen`, and `prince`.

```python
###################################################################
## Fill out the following then remove
raise NotImplementedError("Student exercise: complete correct relation for creating `query` object to compare with `princess`.")
###################################################################

objs['query'] = (objs[...] * ~objs[...]) * objs[...]

```

In [None]:
#to_remove solution

objs['query'] = (objs['prince'] * ~objs['king']) * objs['queen']

In [None]:
sims = np.zeros((len(object_names), len(object_names)))

for name_idx, name in enumerate(object_names):
    for other_idx in range(name_idx, len(object_names)):
        sims[name_idx, other_idx] = sims[other_idx, name_idx] = (objs[name] | objs[object_names[other_idx]]).item()

plot_similarity_matrix(sims, object_names, values = True)

Again, we see that we have recovered princess by using our analogy.
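One way to see why this works, assuming binding is circular convolution (which is commutative and associative) and writing $\sim x$ for the vector that `~x` produces:

$$
\begin{aligned}
\text{prince} \circledast \sim\text{king} \circledast \text{queen}
&= (\text{heir} \circledast \text{male}) \circledast (\sim\text{monarch} \circledast \sim\text{male}) \circledast (\text{monarch} \circledast \text{female}) \\
&= \text{heir} \circledast (\text{male} \circledast \sim\text{male}) \circledast (\text{monarch} \circledast \sim\text{monarch}) \circledast \text{female} \\
&\approx \text{heir} \circledast \text{female} = \text{princess}
\end{aligned}
$$

The shared factors `male` and `monarch` cancel, so we never need to know them explicitly. The last step is exact when the symbol vectors are unitary and approximate otherwise.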

That said, the above depends on knowing that the representations are constructed using binding. Can we do a similar thing with the bundling operation? Let's try that out.

First, reassign the concept definitions using the bundling operation.

In [None]:
objs['king'] = (objs['monarch'] + objs['male']).normalize()
objs['queen'] = (objs['monarch'] + objs['female']).normalize()
objs['prince'] = (objs['heir'] + objs['male']).normalize()
objs['princess'] = (objs['heir'] + objs['female']).normalize()

Now that we are using an additive model, we need to take a different approach. Instead of unbinding king and binding queen, we subtract king and add queen to find princess from prince.

Complete the code to reflect the updated mechanism.

```python
###################################################################
## Fill out the following then remove
raise NotImplementedError("Student exercise: complete correct relation for creating `query` object to compare with `princess`.")
###################################################################

objs['query'] = (objs[...] - objs[...]) + objs[...]

```

In [None]:
#to_remove solution

objs['query'] = (objs['prince'] - objs['king']) + objs['queen']

In [None]:
sims = np.zeros((len(object_names), len(object_names)))

for name_idx, name in enumerate(object_names):
    for other_idx in range(name_idx, len(object_names)):
        sims[name_idx, other_idx] = sims[other_idx, name_idx] = (objs[name] | objs[object_names[other_idx]]).item()

plot_similarity_matrix(sims, object_names, values = True)

This is a messier similarity plot because the bundled representations interact with all of their constituent parts in the vocabulary. That said, we see that 'princess' is still the most similar to the query vector.

This approach is more like what we would expect from a word2vec embedding.
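The additive analogy can also be reproduced with plain NumPy, without `sspspace`. The sketch below assumes random Gaussian unit vectors for the base symbols, and it normalizes the query (the cell above does not) so that every dot product is a cosine similarity. The helper `unit` is illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
dim = 1024

def unit(v):
    """Scale a vector to unit length."""
    return v / np.linalg.norm(v)

names = ["monarch", "heir", "male", "female"]
vecs = {n: unit(rng.standard_normal(dim)) for n in names}

# Bundled (additive) definitions of the composite concepts.
vecs["king"] = unit(vecs["monarch"] + vecs["male"])
vecs["queen"] = unit(vecs["monarch"] + vecs["female"])
vecs["prince"] = unit(vecs["heir"] + vecs["male"])
vecs["princess"] = unit(vecs["heir"] + vecs["female"])

# Subtract king, add queen.
query = unit(vecs["prince"] - vecs["king"] + vecs["queen"])

similarities = {n: float(query @ v) for n, v in vecs.items()}
best = max(similarities, key=similarities.get)
for n, s in sorted(similarities.items(), key=lambda kv: -kv[1]):
    print(f"{n:>9s}: {s:+.2f}")
```

The query is closest to `princess`, but it also overlaps noticeably with `heir` and `female` (the parts of `princess`) and with `prince` and `queen`, which is the messiness visible in the plot.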

In [None]:
# @title Submit your feedback
# content_review(f"{feedback_prefix}_royal_relationships")

---

# Section 2: Analogies. Part 2

Estimated timing to here from start of tutorial: 15 minutes

In this section we will construct a database of data structures that describe different countries. Materials are adapted from the paper TBR by Pentti Kanerva.

In [None]:
# @title Video 2: Analogy 2

from ipywidgets import widgets
from IPython.display import YouTubeVideo
from IPython.display import IFrame
from IPython.display import display

class PlayVideo(IFrame):
  def __init__(self, id, source, page=1, width=400, height=300, **kwargs):
    self.id = id
    if source == 'Bilibili':
      src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'
    elif source == 'Osf':
      src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'
    super(PlayVideo, self).__init__(src, width, height, **kwargs)

def display_videos(video_ids, W=400, H=300, fs=1):
  tab_contents = []
  for i, video_id in enumerate(video_ids):
    out = widgets.Output()
    with out:
      if video_ids[i][0] == 'Youtube':
        video = YouTubeVideo(id=video_ids[i][1], width=W,
                             height=H, fs=fs, rel=0)
        print(f'Video available at https://youtube.com/watch?v={video.id}')
      else:
        video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,
                          height=H, fs=fs, autoplay=False)
        if video_ids[i][0] == 'Bilibili':
          print(f'Video available at https://www.bilibili.com/video/{video.id}')
        elif video_ids[i][0] == 'Osf':
          print(f'Video available at https://osf.io/{video.id}')
      display(video)
    tab_contents.append(out)
  return tab_contents

video_ids = [('Youtube', '7RkogP-czNw')]
tab_contents = display_videos(video_ids, W=854, H=480)
tabs = widgets.Tab()
tabs.children = tab_contents
for i in range(len(tab_contents)):
  tabs.set_title(i, video_ids[i][0])
display(tabs)

In [None]:
# @title Submit your feedback
# content_review(f"{feedback_prefix}_analogy_two")

## Coding Exercise 2: Dollar of Mexico

This is going to be a little more involved, because to construct the data structure we are going to need vectors that don't just represent values that we are reasoning about, but also vectors that represent different roles data can play. This is sometimes called a slot-filler representation, or a key-value representation.

First, let us define the concepts and the cleanup object.

In [None]:
set_seed(42)

symbol_names = ['dollar','peso', 'ottawa','mexico-city','currency','capital']
discrete_space = sspspace.DiscreteSPSpace(symbol_names, ssp_dim=1024, optimize=False)


objs = {n:discrete_space.encode(n) for n in symbol_names}

cleanup = sspspace.Cleanup(objs)

Now we will define the `canada` and `mexico` concepts by integrating the available information. You are given the `canada` object; your task is to complete the `mexico` one.

```python
###################################################################
## Fill out the following then remove
raise NotImplementedError("Student exercise: complete `mexico` concept.")
###################################################################

objs['canada'] = (objs['currency'] * objs['dollar'] + objs['capital'] * objs['ottawa']).normalize()
objs['mexico'] = (objs['currency'] * ... + objs['capital'] * ...).normalize()

```

In [None]:
#to_remove solution

objs['canada'] = (objs['currency'] * objs['dollar'] + objs['capital'] * objs['ottawa']).normalize()
objs['mexico'] = (objs['currency'] * objs['peso'] + objs['capital'] * objs['mexico-city']).normalize()

We would like to find out Mexico's currency. Complete the code for constructing the `query` that will help us do that. Note that we are using the cleanup operation.

```python
###################################################################
## Fill out the following then remove
raise NotImplementedError("Student exercise: complete `query` concept which will be similar to currency in Mexico.")
###################################################################

objs['query'] = cleanup((objs[...] * ~objs[...]) * objs['mexico'])

```

In [None]:
#to_remove solution

objs['query'] = cleanup((objs['dollar'] * ~objs['canada']) * objs['mexico'])

In [None]:
object_names = list(objs.keys())
sims = np.zeros((len(object_names), len(object_names)))

for name_idx, name in enumerate(object_names[:-1]):
    for other_idx in range(name_idx, len(object_names)):
        sims[name_idx, other_idx] = sims[other_idx, name_idx] = (objs[name] | objs[object_names[other_idx]]).item()

plot_similarity_matrix(sims, object_names)

After cleanup, the query vector is most similar to the 'peso' object in the vocabulary, correctly answering the question.
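The idea behind a cleanup step can be sketched as a nearest-neighbour lookup over the vocabulary. This is only an assumption about the general technique: how `sspspace.Cleanup` works internally is not shown in this tutorial, and the function name `nearest_symbol` is hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
dim = 1024
names = ["dollar", "peso", "ottawa", "mexico-city", "currency", "capital"]

# One random unit vector per symbol, stacked as the rows of a matrix.
vocab = rng.standard_normal((len(names), dim))
vocab /= np.linalg.norm(vocab, axis=1, keepdims=True)

def nearest_symbol(noisy, vocab, names):
    """Return the best-matching symbol name and its dot-product similarity."""
    sims = vocab @ noisy
    best = int(np.argmax(sims))
    return names[best], float(sims[best])

# A degraded 'peso': half-strength signal plus unit-norm random noise.
noise = rng.standard_normal(dim)
noise /= np.linalg.norm(noise)
noisy_peso = 0.5 * vocab[names.index("peso")] + noise

name, similarity = nearest_symbol(noisy_peso, vocab, names)
print(name, round(similarity, 2))
```

Even though the noise is twice as large as the signal in norm, it is spread over all 1024 dimensions, so the lookup still identifies `peso`.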

Note, however, that the similarity is not exactly 1. This is due to the scale factors applied to the composite vectors 'canada' and 'mexico' to keep them unit vectors, and due to crosstalk. Crosstalk arises because we are binding and unbinding bundles of vector symbols to produce the resulting query vector. The constituent vectors are not perfectly orthogonal (i.e., their dot products are not exactly zero), so the terms in the bundle interact when we measure similarity between them.
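To see where both effects come from, we can expand the vector that is passed to cleanup. This sketch assumes circular-convolution binding, unitary symbol vectors, and nearly orthogonal terms, so that each normalization is close to a factor of $1/\sqrt{2}$; $\sim x$ denotes the vector that `~x` produces.

$$
\begin{aligned}
\text{dollar} \circledast \sim\text{canada} \circledast \text{mexico}
&\approx \tfrac{1}{2}\,\text{dollar} \circledast (\sim\text{currency} \circledast \sim\text{dollar} + \sim\text{capital} \circledast \sim\text{ottawa}) \circledast (\text{currency} \circledast \text{peso} + \text{capital} \circledast \text{mexico-city}) \\
&= \tfrac{1}{2}\,\text{peso} + \tfrac{1}{2}\,(\sim\text{currency} \circledast \text{capital} \circledast \text{mexico-city} \\
&\quad + \text{dollar} \circledast \sim\text{capital} \circledast \sim\text{ottawa} \circledast \text{currency} \circledast \text{peso} + \text{dollar} \circledast \sim\text{ottawa} \circledast \text{mexico-city})
\end{aligned}
$$

The first term is the signal, scaled to $1/2$ by the two normalizations. The remaining three terms are crosstalk: they behave like random vectors and contribute small but nonzero similarities with the items in the vocabulary.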

In [None]:
# @title Submit your feedback
# content_review(f"{feedback_prefix}_dolar_of_mexico")

---

# Section 3: Wason Card Task

Estimated timing to here from start of tutorial: 25 minutes

One of the powerful benefits of using these structured representations is the ability to generalize to other circumstances. To demonstrate this, we are going to show how a simple learning rule can extract a general transformation that also applies to rules it has not seen.

In [None]:
# @title Video 3: Wason Card Task Intro

from ipywidgets import widgets
from IPython.display import YouTubeVideo
from IPython.display import IFrame
from IPython.display import display

class PlayVideo(IFrame):
  def __init__(self, id, source, page=1, width=400, height=300, **kwargs):
    self.id = id
    if source == 'Bilibili':
      src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'
    elif source == 'Osf':
      src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'
    super(PlayVideo, self).__init__(src, width, height, **kwargs)

def display_videos(video_ids, W=400, H=300, fs=1):
  tab_contents = []
  for i, video_id in enumerate(video_ids):
    out = widgets.Output()
    with out:
      if video_ids[i][0] == 'Youtube':
        video = YouTubeVideo(id=video_ids[i][1], width=W,
                             height=H, fs=fs, rel=0)
        print(f'Video available at https://youtube.com/watch?v={video.id}')
      else:
        video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,
                          height=H, fs=fs, autoplay=False)
        if video_ids[i][0] == 'Bilibili':
          print(f'Video available at https://www.bilibili.com/video/{video.id}')
        elif video_ids[i][0] == 'Osf':
          print(f'Video available at https://osf.io/{video.id}')
      display(video)
    tab_contents.append(out)
  return tab_contents

video_ids = [('Youtube', 'KqMMEDjhbKI')]
tab_contents = display_videos(video_ids, W=854, H=480)
tabs = widgets.Tab()
tabs.children = tab_contents
for i in range(len(tab_contents)):
  tabs.set_title(i, video_ids[i][0])
display(tabs)

In [None]:
# @title Submit your feedback
# content_review(f"{feedback_prefix}_wason_card_task_intro")

## Coding Exercise 3: Wason Card Task

We are going to test the generalization property on the Wason Card Task. A person is told a rule of the form "if the card is even, then the back is blue". They are then presented with a number of cards showing either an odd number, an even number, a red back, or a blue back, and are asked which cards they have to flip to determine whether the rule is true.

In this case, the participant needs to flip the even card(s) and the red card(s): an even card with a red back would violate the rule. The odd and blue cards do not need to be flipped, as the rule does not state whether or not odd numbers can have blue backs. In other words, the cards to check are the antecedent and the negation of the consequent.
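The choice of cards can be checked by brute force. The following sketch (plain Python; the function names are illustrative) flips a card only if some possible hidden face would violate the rule "if the card is even, then the back is blue".

```python
NUMBERS = ["even", "odd"]
COLOURS = ["blue", "red"]

def violates(number, colour):
    """The rule 'even -> blue' is broken only by an even card with a non-blue back."""
    return number == "even" and colour != "blue"

def must_flip(visible):
    """A card must be flipped if some hidden face would make it violate the rule."""
    if visible in NUMBERS:
        return any(violates(visible, hidden) for hidden in COLOURS)
    return any(violates(hidden, visible) for hidden in NUMBERS)

cards_to_flip = [face for face in NUMBERS + COLOURS if must_flip(face)]
print(cards_to_flip)
```

The cards that can hide a counterexample are the antecedent (`even`) and the negated consequent (`red`, i.e. not blue), which matches the target $A^{*}$ learned later in this section.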

First, we will define all the needed concepts. For every noun concept we would also like to have a `not concept` counterpart in the space. Please complete the missing parts of the code.

```python
set_seed(42)

card_states = ['red','blue','odd','even','not','green','prime','implies','ant','relation','cons']
encoder = sspspace.DiscreteSPSpace(card_states, ssp_dim=1024, optimize=False)
vocab = {c:encoder.encode(c) for c in card_states}

###################################################################
## Fill out the following then remove
raise NotImplementedError("Student exercise: complete creating `not x` concepts.")
###################################################################

for a in ['red','blue','odd','even','green','prime']:
    vocab[f'not*{a}'] = vocab[...] * vocab[a]

action_names = ['red','blue','odd','even','green','prime','not*red','not*blue','not*odd','not*even','not*green','not*prime']
action_space = np.array([vocab[x] for x in action_names]).squeeze()

```

In [None]:
#to_remove solution

set_seed(42)

card_states = ['red','blue','odd','even','not','green','prime','implies','ant','relation','cons']
encoder = sspspace.DiscreteSPSpace(card_states, ssp_dim=1024, optimize=False)
vocab = {c:encoder.encode(c) for c in card_states}

for a in ['red','blue','odd','even','green','prime']:
    vocab[f'not*{a}'] = vocab['not'] * vocab[a]

action_names = ['red','blue','odd','even','green','prime','not*red','not*blue','not*odd','not*even','not*green','not*prime']
action_space = np.array([vocab[x] for x in action_names]).squeeze()

Now we are going to set up a simple perceptron-style learning rule using the HRR (Holographic Reduced Representations) algebra. We are going to learn a target transformation, $T$, such that $A^{*} = T\circledast R$, where $R$ is a given rule and $A^{*}$ is the antecedent value bundled with $\texttt{not}$ bound with the consequent value.

The rules themselves are composed like the country data structures from the previous section. `ant`, `relation`, and `cons` are extra concepts that define the structure and bind to the specific instances. In the cell below, let us define two rules:

$$\text{blue} \implies \text{even}$$
$$\text{odd} \implies \text{green}$$

```python
###################################################################
## Fill out the following then remove
raise NotImplementedError("Student exercise: complete creating rules as defined above.")
###################################################################

rules = [
    (vocab['ant'] * vocab['blue'] + vocab['relation'] * vocab['implies'] + vocab['cons'] * vocab[...]).normalize(),
    (vocab[...] * vocab[...] + vocab[...] * vocab[...] + vocab[...] * vocab[...]).normalize(),
]

```

In [None]:
#to_remove solution

rules = [
    (vocab['ant'] * vocab['blue'] + vocab['relation'] * vocab['implies'] + vocab['cons'] * vocab['even']).normalize(),
    (vocab['ant'] * vocab['odd'] + vocab['relation'] * vocab['implies'] + vocab['cons'] * vocab['green']).normalize(),
]

Now we are ready to derive the transformation! We will iterate through the rules and their solutions for a specified number of iterations and update the transformation as follows:

$$T \leftarrow T - \text{lr}\,(T - A^{*} \circledast \sim R)$$

where $\text{lr}$ is a constant learning rate. Indeed, as $A^{*} = T\circledast R$, unbinding the rule from $A^{*}$ gives an estimate of the transformation, and each update moves $T$ a small step toward that estimate.
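Rewriting the update as it is implemented in the solution code (`transform -= lr * (transform - a_true * ~rule)`) makes its behavior easier to see:

$$T \leftarrow (1 - \text{lr})\,T + \text{lr}\,(A^{*} \circledast \sim R)$$

Each step is a weighted average of the current transformation and the estimate $A^{*} \circledast \sim R$ obtained from one rule, and the update leaves $T$ unchanged exactly when $T = A^{*} \circledast \sim R$. With several rules, and ignoring the renormalization applied after each step, we would expect $T$ to settle near a weighted average of the per-rule estimates.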

We will also track the loss over training: the log loss between the perfect similarity (ones only for the antecedent value and the negated consequent) and the similarity between the prediction of the current transformation and the full action space. Complete the missing parts of the code in the next cell to complete training.

```python
###################################################################
## Fill out the following then remove
raise NotImplementedError("Student exercise: complete training loop.")
###################################################################

num_iters = 500
losses = []
sims = []
lr = 1e-1
ant_names = ["blue", "odd"]
cons_names = ["even", "green"]

transform = np.zeros((1,encoder.ssp_dim))
for i in range(num_iters):
    loss = 0
    for rule, ant_name, cons_name in zip(rules, ant_names, cons_names):

        #perfect similarity (ones at `ant_name` and at `not*cons_name`)
        y_true = np.eye(len(action_names))[action_names.index(ant_name),:] + np.eye(len(action_names))[action_names.index(f'not*{cons_name}'),:]

        #prediction with current transform (a_hat = transform * rule)
        a_hat = sspspace.SSP(transform) * ...

        #similarity with current transform
        sim_mat = np.einsum('nd,md->nm', action_space, a_hat)

        #cleanup
        y_hat = softmax(sim_mat)

        #true solution (a* = ant_name + not * cons_name)
        a_true = (vocab[ant_name] + vocab['not']*vocab[...]).normalize()

        #calculate loss
        loss += log_loss(y_true, y_hat)

        #update transform (T <- T - lr * (T - A* * (~rule)))
        transform -= (lr) * (... - np.array(... * ~...))
        transform = transform / np.linalg.norm(transform)

        #save predicted similarities if it is last iteration
        if i == num_iters - 1:
            sims.append(sim_mat)

    #save loss
    losses.append(np.copy(loss))

```

In [None]:
#to_remove solution

num_iters = 500
losses = []
sims = []
lr = 1e-1
ant_names = ["blue", "odd"]
cons_names = ["even", "green"]

transform = np.zeros((1,encoder.ssp_dim))
for i in range(num_iters):
    loss = 0
    for rule, ant_name, cons_name in zip(rules, ant_names, cons_names):

        #perfect similarity (ones at `ant_name` and at `not*cons_name`)
        y_true = np.eye(len(action_names))[action_names.index(ant_name),:] + np.eye(len(action_names))[action_names.index(f'not*{cons_name}'),:]

        #prediction with current transform (a_hat = transform * rule)
        a_hat = sspspace.SSP(transform) * rule

        #similarity with current transform
        sim_mat = np.einsum('nd,md->nm', action_space, a_hat)

        #cleanup
        y_hat = softmax(sim_mat)

        #true solution (a* = ant_name + not * cons_name)
        a_true = (vocab[ant_name] + vocab['not']*vocab[cons_name]).normalize()

        #calculate loss
        loss += log_loss(y_true, y_hat)

        #update transform (T <- T - lr * (T - A* * (~rule)))
        transform -= (lr) * (transform - np.array(a_true * ~rule))
        transform = transform / np.linalg.norm(transform)

        #save predicted similarities if it is last iteration
        if i == num_iters - 1:
            sims.append(sim_mat)

    #save loss
    losses.append(np.copy(loss))

In [None]:
plot_training_and_choice(losses, sims, ant_names, cons_names, action_names)

Let's see what happens when we test it on a new rule it hasn't seen before. This time we will use the rule $\text{red} \implies \text{prime}$. Your task is to complete the new rule in the cell below and observe the results.

```python
###################################################################
## Fill out the following then remove
raise NotImplementedError("Student exercise: complete new rule and predict for it.")
###################################################################

new_rule = (vocab['ant'] * vocab[...] + vocab['relation'] * ... + vocab['cons'] * vocab[...]).normalize()

#apply transform on new rule to test the generalization of the transform
a_hat = sspspace.SSP(transform) * ...

new_sims = np.einsum('nd,md->nm', action_space, a_hat)
y_hat = softmax(new_sims)

```

In [None]:
#to_remove solution

new_rule = (vocab['ant'] * vocab['red'] + vocab['relation'] * vocab['implies'] + vocab['cons'] * vocab['prime']).normalize()

#apply transform on new rule to test the generalization of the transform
a_hat = sspspace.SSP(transform) * new_rule

new_sims = np.einsum('nd,md->nm', action_space, a_hat)
y_hat = softmax(new_sims)

In [None]:
plot_choice([new_sims], ["red"], ["prime"], action_names)

Let's compare this with how a standard MLP, which isn't aware of the structure in the representation, performs. Here the features are the rules and the outputs are the solutions. Complete the code below.

```python
###################################################################
## Fill out the following then remove
raise NotImplementedError("Student exercise: complete MLP training.")
###################################################################

#features - rules
X_train = np.array(...).squeeze()

#output - a* for each rule
y_train = np.array([
    (vocab[ant_names[0]] + vocab['not']*vocab[cons_names[0]]).normalize(),
    (vocab[ant_names[1]] + vocab['not']*vocab[cons_names[1]]).normalize(),
]).squeeze()

regr = MLPRegressor(random_state=1, hidden_layer_sizes=(1024,1024), max_iter=1000).fit(..., ...)

a_mlp = regr.predict(new_rule)

mlp_sims = np.einsum('nd,md->nm', action_space, a_mlp)

```

In [None]:
#to_remove solution

#features - rules
X_train = np.array(rules).squeeze()

#output - a* for each rule
y_train = np.array([
    (vocab[ant_names[0]] + vocab['not']*vocab[cons_names[0]]).normalize(),
    (vocab[ant_names[1]] + vocab['not']*vocab[cons_names[1]]).normalize(),
]).squeeze()

regr = MLPRegressor(random_state=1, hidden_layer_sizes=(1024,1024), max_iter=1000).fit(X_train, y_train)

a_mlp = regr.predict(new_rule)

mlp_sims = np.einsum('nd,md->nm', action_space, a_mlp)

In [None]:
plot_choice([mlp_sims], ["red"], ["prime"], action_names)

As you can see, this model, even though it is a more expressive neural network, simply predicts the values it has seen before when presented with a novel stimulus.
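One way to see why the bound transformation generalizes is to build it directly. The sketch below is plain NumPy, not `sspspace`. It assumes circular-convolution binding and random unitary symbol vectors, uses a larger dimension (4096) than the tutorial to keep crosstalk small, and replaces the iterative update with a plain average of $A^{*} \circledast \sim R$ over the two training rules. All helper names are illustrative.

```python
import numpy as np

def random_unitary(dim, rng):
    """Random real unit vector whose Fourier coefficients all have magnitude 1."""
    phases = rng.uniform(-np.pi, np.pi, dim // 2 + 1)
    coeffs = np.exp(1j * phases)
    coeffs[0] = coeffs[-1] = 1.0  # DC and Nyquist terms must be real
    return np.fft.irfft(coeffs, n=dim)

def bind(a, b):
    """Circular convolution via the FFT."""
    return np.fft.irfft(np.fft.rfft(a) * np.fft.rfft(b), n=len(a))

def inverse(a):
    """Involution: keep the first element, reverse the rest."""
    return np.roll(a[::-1], 1)

def unit(v):
    return v / np.linalg.norm(v)

rng = np.random.default_rng(0)
dim = 4096
nouns = ["red", "blue", "odd", "even", "green", "prime"]
atoms = nouns + ["not", "implies", "ant", "relation", "cons"]
v = {n: random_unitary(dim, rng) for n in atoms}
for n in nouns:
    v[f"not*{n}"] = bind(v["not"], v[n])

actions = nouns + [f"not*{n}" for n in nouns]
action_matrix = np.stack([v[a] for a in actions])

def make_rule(ant, cons):
    """Rule structure: ant*<ant> + relation*implies + cons*<cons>."""
    return unit(bind(v["ant"], v[ant])
                + bind(v["relation"], v["implies"])
                + bind(v["cons"], v[cons]))

def make_target(ant, cons):
    """Target A*: the antecedent bundled with the negated consequent."""
    return unit(v[ant] + v[f"not*{cons}"])

# Average the per-rule estimates A* (*) ~R over the two training rules.
training = [("blue", "even"), ("odd", "green")]
transform = unit(sum(bind(make_target(a, c), inverse(make_rule(a, c)))
                     for a, c in training))

# Apply the transformation to a rule that was never used to build it.
prediction = bind(transform, make_rule("red", "prime"))
sims = action_matrix @ prediction
top_two = sorted(actions[i] for i in np.argsort(sims)[-2:])
print(top_two)
```

Because every rule shares the `ant`, `relation`, and `cons` roles, the averaged transformation contains components that do not depend on the specific antecedent or consequent, so it picks out `red` and `not*prime` for the unseen rule. The MLP has no access to this shared structure.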

In [None]:
# @title Video 4: Wason Card Task Outro

from ipywidgets import widgets
from IPython.display import YouTubeVideo
from IPython.display import IFrame
from IPython.display import display

class PlayVideo(IFrame):
  def __init__(self, id, source, page=1, width=400, height=300, **kwargs):
    self.id = id
    if source == 'Bilibili':
      src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'
    elif source == 'Osf':
      src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'
    super(PlayVideo, self).__init__(src, width, height, **kwargs)

def display_videos(video_ids, W=400, H=300, fs=1):
  tab_contents = []
  for i, video_id in enumerate(video_ids):
    out = widgets.Output()
    with out:
      if video_ids[i][0] == 'Youtube':
        video = YouTubeVideo(id=video_ids[i][1], width=W,
                             height=H, fs=fs, rel=0)
        print(f'Video available at https://youtube.com/watch?v={video.id}')
      else:
        video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,
                          height=H, fs=fs, autoplay=False)
        if video_ids[i][0] == 'Bilibili':
          print(f'Video available at https://www.bilibili.com/video/{video.id}')
        elif video_ids[i][0] == 'Osf':
          print(f'Video available at https://osf.io/{video.id}')
      display(video)
    tab_contents.append(out)
  return tab_contents

video_ids = [('Youtube', 'rV3oZXLFrb4')]
tab_contents = display_videos(video_ids, W=854, H=480)
tabs = widgets.Tab()
tabs.children = tab_contents
for i in range(len(tab_contents)):
  tabs.set_title(i, video_ids[i][0])
display(tabs)

In [None]:
# @title Submit your feedback
# content_review(f"{feedback_prefix}_wason_card_task_outro")

---
# Summary

*Estimated timing of tutorial: 45 minutes*

In this tutorial, we have observed three scenarios in which we used the basic operations to develop relations between different concepts and derive useful information about them. The next, closing tutorial presents more complicated tasks and further demonstrates the power of these representations.