# The Sally-Anne test

**Inspired by:** Wimmer, H., & Perner, J. (1983). _Beliefs about beliefs: Representation and constraining function of wrong beliefs in young children's understanding of deception._ Cognition, 13(1), 103-128.

Sally sees a marble in a box, then leaves the room. While she is gone, Anne secretly moves the marble to a basket. When Sally returns to the room, where will she look for the marble: in the box or in the basket?
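Before building the full model, the reasoning behind the puzzle can be checked by hand. The following is a minimal plain-Python sketch, not part of memo, that enumerates Anne's possible actions and conditions on what Sally observed. It assumes the same numbers as the model defined later in this notebook: Anne moves the marble with prior probability 0.01, and Sally either sees nothing or sees Anne's actual action. The function name `sally_look_probs` and the constants are illustrative only.

```python
# Hand-rolled enumeration of Sally's belief (illustrative; does not use memo).
BOX, BASKET = 0, 1   # marble locations
STAY, MOVE = 0, 1    # Anne's possible actions
NONE = -1            # Sally saw nothing

P_MOVE = 0.01        # assumed prior that Anne moves the marble


def sally_look_probs(start, obs, p_move=P_MOVE):
    """Return [P(look in BOX), P(look in BASKET)] for Sally."""
    prior = {STAY: 1.0 - p_move, MOVE: p_move}
    # "Saw nothing" is consistent with either action; otherwise the
    # observation must match the action Anne actually took.
    weights = {a: p * (1.0 if obs in (NONE, a) else 0.0)
               for a, p in prior.items()}
    total = sum(weights.values())
    probs = [0.0, 0.0]
    for a, w in weights.items():
        end = start if a == STAY else 1 - start  # moving swaps box and basket
        probs[end] += w / total                  # Sally looks where she believes it is
    return probs


print(sally_look_probs(BOX, NONE))  # Sally was out of the room
print(sally_look_probs(BOX, MOVE))  # Sally watched Anne move the marble
```

When Sally sees nothing, her belief stays at the prior, so she most likely looks in the box, even though the marble is actually in the basket. That is the false-belief prediction the test probes.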

```python
from memo import memo
import jax.numpy as np
import jax
from enum import IntEnum
```

```python
class Loc(IntEnum):  # marble's location
    BOX = 0
    BASKET = 1

class Action(IntEnum):  # anne's action on marble
    ACT_STAY = 0
    ACT_MOVE = 1

@jax.jit
def do(l, a):  # apply action to marble to get new location
    return np.array([
        [0, 1],
        [1, 0]
    ])[a, l]

class Obs(IntEnum):  # what sally sees
    OBS_NONE = -1  # sees nothing
    OBS_STAY = Action.ACT_STAY
    OBS_MOVE = Action.ACT_MOVE

@memo
def model[marble_pos_t0: Loc, obs: Obs, where_look: Loc]():
    child: knows(marble_pos_t0, obs, where_look)
    child: thinks[
        sally: knows(marble_pos_t0, where_look),
        sally: thinks[
            anne: knows(marble_pos_t0),
            anne: chooses(a in Action, wpp=0.01 if a=={Action.ACT_MOVE} else 0.99),
            anne: chooses(marble_pos_t1 in Loc, wpp=do(marble_pos_t0, a)==marble_pos_t1),
            anne: chooses(o in Obs, wpp=1 if o=={Obs.OBS_NONE} or o==a else 0),
        ],
        sally: observes [anne.o] is obs,
        sally: chooses(where_look in Loc, wpp=Pr[anne.marble_pos_t1 == where_look])
    ]
    return child[ Pr[sally.where_look == where_look] ]

model(print_table=True);
```

+--------------------+----------+-----------------+----------------------------------------+
| marble_pos_t0: Loc | obs: Obs | where_look: Loc | model[marble_pos_t0, obs, where_look]  |
+--------------------+----------+-----------------+----------------------------------------+
| BOX                | OBS_NONE | BOX             | 0.9900000095367432                     |
| BOX                | OBS_NONE | BASKET          | 0.009999999776482582                   |
| BOX                | OBS_STAY | BOX             | 1.0                                    |
| BOX                | OBS_STAY | BASKET          | 0.0                                    |
| BOX                | OBS_MOVE | BOX             | 0.0                                    |
| BOX                | OBS_MOVE | BASKET          | 1.0                                    |
| BASKET             | OBS_NONE | BOX             | 0.009999999776482582                   |
| BASKET             | OBS_NONE | BASKET          | 0.9900000095367432