# Bayes tutorial Part 2: Bayesian Decision Theory

References:
* The ultimate reference (highly recommend): [Bayesian models of action and perception](https://www.cns.nyu.edu/malab/bayesianbook.html)
* Primer on Bayesian Decision Models & fitting them to data: [Ma '19](https://www.cell.com/neuron/pdfExtended/S0896-6273(19)30840-2)
* Primer on the connections between Bayesian Decision Theory & RL: [Dayan & Daw '08](https://www.princeton.edu/~ndaw/dd08.pdf)
* Past tutorials that have inspired this one: [Hermundstad Cosyne '20](https://www.janelia.org/sites/default/files/Labs/Hermundstad%20Lab/Slides_Cosyne2020.pdf), [Ma Cosyne '19](http://www.cns.nyu.edu/malab/static/files/courses/Bayesian/Slides.pdf)


<br>

------


### This part of the tutorial has 2 (+1 bonus) subsections

1. **Perceptual decisions**: Take what you have learnt about Bayesian perceptual inference in Part 1, and use it to make decisions about a stimulus
    1. In the face of unequal priors
    2. In the face of unequal costs/rewards - introducing the concept of *utility*

<br>

2. **Hierarchical inference**: Add an extra level of Bayesian inference when there are more hidden variables of interest (e.g. an underlying context)
    1. Inferring (hidden) context by observing spikes - introducing the concept of *marginalization*
    2. In the face of changing contexts

    

##### BONUS:

3. **Reinforcement learning**: Turn this into an RL problem when costs/rewards are unknown, so you have to additionally *learn them*
    1. Markov decision process & its connections to RL
    2. In the face of unknown costs/rewards - introducing the concept of *value*


In [None]:
# Import relevant modules and define some plotting functions
import pandas as pd
import numpy as np
import re
import seaborn as sns
from scipy.stats import norm, vonmises
import scipy.optimize as opt
from matplotlib.gridspec import GridSpec
from collections import OrderedDict
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import display, HTML
from typing import Dict, List

%matplotlib inline
sns.set_context('talk')

# Problem setup

**We are going to consider the decision processes of a virtual bee foraging in a virtual land with patches of two types of environments (or contexts $c$): "forests" and "fields".** 

The two contexts differ in the kind of flowers they have. We will refer to the flower type as $f$; you can think of it as a continuous variable parametrizing flower color. We will assume that the flower colors $f$ in the two contexts are normally distributed with the same variance $\sigma^2$ but different means. 

$$ p(f | c = \text{forest})  = \mathcal{N}(f; \mu_{\text{forest}}, \sigma^2)$$

$$ p(f | c = \text{field})  = \mathcal{N}(f; \mu_{\text{field}}, \sigma^2)$$


So first, let's set up a class with some methods that will let us easily create a context and sample an encountered flower color $f$ from it:

In [None]:
class Context:
    
    # initialize
    def __init__(self, pFlowerMu, pFlowerSig = 0.1):
        self.pFlowerMu = pFlowerMu   # mean of the flower color distribution
        self.pFlowerSig = pFlowerSig  # standard deviation of flower color distribution

    # method 1: Probability density of flower colors 
    def pFlower(self, f):
        return norm.pdf(
            f, 
            loc = self.pFlowerMu, 
            scale = self.pFlowerSig)

    # method 2: Cumulative distribution function of flower colors 
    def cFlower(self, f):
        return norm.cdf(
            f,
            loc = self.pFlowerMu, 
            scale = self.pFlowerSig)

    # method 3: Randomly sample a flower color 
    def sample(self, n = 1):
        return norm.rvs(
            loc = self.pFlowerMu, 
            scale = self.pFlowerSig,
            size = n)


Now with this class, let's create the two contexts, $c \in \{\text{forest},\text{field}\}$, which differ in the average color of flowers $f$:

In [None]:
# make a dictionary for easily referring to the two contexts
contexts = {
    'forest': Context(pFlowerMu = 0.66), 
    'field': Context(pFlowerMu = 0.33) 
    }

# define some useful variables for plotting
colors = {
    'forest':[0,0,1], 
    'field':[1,0,0]
    }

# possible values of f
f = np.arange(0, 1, 0.01)

# And then plot
fig = plt.figure(figsize=(8, 3))
axs = fig.gca()
for c in contexts:
    axs.fill(
        f,
        contexts[c].pFlower(f),
        alpha = 0.3,
        color = colors[c],
        label = c)
axs.legend()
axs.set_xlabel('Flower color')
axs.set_ylabel('Probability')
axs.set_yticks([])
axs.set_title('Distribution of flower colors')
sns.despine()
plt.show()


# 1. Perceptual decisions

Ok, so now we are going to model how, given an observation of flower color, we can decide which context (*fields* or *forests*) we are in. 

In Part 1 we did something similar: there we asked what the *true* flower color is, given the observation of flower color we made. Here we simplify (or perhaps complicate) the problem: given a flower color, we want to classify which category it came from. This is more akin to some decision-making tasks we set up in the lab.


<br>

So first, using what you've learnt about Bayesian inference, code up a function that returns the likelihood function - that is the probability of observing any given flower color (hence a function of flower color) given that we're in forests or fields.

$$ L(f | c = \text{forest}) = p(f | c = \text{forest}) $$

Here's a schematic of this generative model, with solid bubbles being observables and empty ones being hidden

<div style="text-align:center">
    <img src="./figures/f1.png" alt="Generative Model" width="100"/>
</div>

<div>
<img src="https://drive.google.com/uc?export=view&id=1YH2AXEM_fQxPZ7tz8bMzOBEXBeuLINpg" width="100"/>
</div>

In [None]:
def lik_function(f, context_name:str, contexts: Dict[str, Context]):
    """
    Function that returns the likelihood of observed flower color(s) f under a given context
        f: observed flower color(s)
        context_name: name of the context ('forest' or 'field') for which the likelihood is evaluated
        contexts: dictionary of Context objects with the parameters of the flower distributions
    Returns:
        lr: likelihood of the observed flower color(s) given that context
    """
    
    # write the expression for computing the likelihood function here
    lr = ???
    return lr


Using this function, we're going to write down a *decision rule* (i.e. a way of deciding which context we're in) that *maximizes* this function. We can express this operation of maximising likelihood in a few different ways:
* Picking the context $c$ that has a higher likelihood $L(f | c)$:
$$\text{choose forest if } L(f | c = \text{forest}) > L(f | c = \text{field})$$
* Picking $\text{forest}$ if the *likelihood ratio* (LR) of forest to field is >1:
$$\text{choose forest if LR} \equiv \frac{L(f | c = \text{forest})}{L(f | c = \text{field})}>1$$
* Taking the log, and picking $\text{forest}$ if the *log likelihood ratio* (LLR) is >0:
$$\text{choose forest if LLR}  \equiv \log \bigg( \frac{L(f | c = \text{forest})}{L(f | c = \text{field})} \bigg)>0$$
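
As a quick check that these three forms of the rule agree, here is a minimal sketch for a single observed flower color. It assumes the means (0.66 and 0.33) and standard deviation (0.1) used for the two contexts above. `f_obs` and the other variable names are illustrative and not part of the tutorial code.

```python
import numpy as np
from scipy.stats import norm

# Assumed parameters, matching the two contexts defined above
mu_forest, mu_field, sigma = 0.66, 0.33, 0.1
f_obs = 0.6  # hypothetical observed flower color

# Likelihood of the observation under each context
lik_forest = norm.pdf(f_obs, loc=mu_forest, scale=sigma)
lik_field = norm.pdf(f_obs, loc=mu_field, scale=sigma)

# The three equivalent forms of the maximum-likelihood decision rule
rule_direct = lik_forest > lik_field
rule_lr = (lik_forest / lik_field) > 1
rule_llr = np.log(lik_forest / lik_field) > 0

print(rule_direct, rule_lr, rule_llr)
```

The three comparisons always agree, because dividing by a positive number and taking the logarithm both preserve the ordering of the two likelihoods.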

Let us plot the likelihood function you coded up, and also examine the log likelihood ratios of 1000 observed flower colors, 500 from each context!


In [None]:
# Convenient method to plot any given function of flower color for both contexts
def plot_func_of_f_c(func, func_name: str, f, contexts, colors, **func_kwargs):
    fig = plt.figure(figsize=(8, 4))
    y = dict()
    for c in contexts:
        y[c]=func(f,c, contexts, **func_kwargs)
        plt.plot(f, y[c], '-', color = colors[c], label = c)
    threshold = min(f[y['forest']>y['field']])
    plt.vlines(threshold,min([min(y['forest']), min(y['field'])]),max([max(y['forest']), max(y['field'])]),linestyles='dashed',colors='k')
    
    plt.xlabel('Flower color')
    plt.ylabel(f"{func_name}")
    plt.legend()
    plt.title(f"{func_name} as a function of flower color, context")
    sns.despine()
    return threshold

threshold = plot_func_of_f_c(lik_function, 'Likelihood', f, contexts, colors)

How do you think we can implement the decision rule of *maximizing* this function? To build intuition, let's simulate some datapoints and examine the log likelihood ratio (LLR)

In [None]:
nsamples = 1000

loglik_ratios = dict()
data = dict()
for c in contexts:
    data[c] = contexts[c].sample(n = int(nsamples/2))
    loglik_ratios[c] = ??? 
    plt.scatter(data[c], loglik_ratios[c], color = colors[c],alpha=0.1)
plt.xlim([0,1])
plt.xlabel('Flower color')
plt.ylabel('Log likelihood ratio')
plt.title('Log likelihood ratios of forest v.s. field')
plt.hlines(0,0,1,colors='k')
plt.vlines(threshold, min(loglik_ratios['field']), max(loglik_ratios['forest']), linestyles='dashed',colors='k')
sns.despine()

### Do you notice anything interesting?

> write your answer here (see solutions to check your intuitions)

Let's get Bayesian! We know that the likelihood doesn't tell the entire story, and as Bayesians we must take into account information about the *prior* prevalence of the two categories.

Let's code up the *posterior* probability of being in a forest v.s. field given the observations, by combining the likelihood and prior. 

HINT: Remember that the posterior probability is a *probability* so it must be normalized!

<br>

$$ \begin{align}
p(c  = \text{forest}| f) &∝ p(f | c= \text{forest}) \times p(c= \text{forest}) \\
 &= \frac{p(f | c= \text{forest}) \times p(c= \text{forest})}{p(f | c= \text{forest}) \times p(c= \text{forest}) + p(f | c= \text{field}) \times p(c= \text{field})}
 \end{align}$$
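
To make the normalization concrete, here is a small worked example for one observation. It is only a sketch: the context parameters match those defined above, while `f_obs` and `example_priors` are hypothetical values chosen for illustration.

```python
from scipy.stats import norm

# Assumed parameters, matching the two contexts defined above
mu = {'forest': 0.66, 'field': 0.33}
sigma = 0.1
example_priors = {'forest': 0.9, 'field': 0.1}  # hypothetical skewed prior
f_obs = 0.45  # hypothetical observation, closer to the field mean

# Likelihood of the observation under each context
lik = {c: norm.pdf(f_obs, loc=mu[c], scale=sigma) for c in mu}

# Unnormalized posterior: likelihood x prior for each context
unnorm = {c: lik[c] * example_priors[c] for c in mu}

# Normalize so that the posteriors over the two contexts sum to 1
evidence = sum(unnorm.values())
post = {c: unnorm[c] / evidence for c in unnorm}

print(lik)
print(post)
```

In this example the likelihood alone favors field, yet the posterior favors forest (about 0.67), because the prior strongly favors forest.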



In [None]:
def posterior_probability(f, chosen_context:str, contexts: Dict[str,Context], priors:Dict[str,float]):
    """
    Function that returns the normalized posterior probability of a context given observed flower color(s) f and priors
        f: observed flower color(s)
        chosen_context: context for which posterior is being evaluated
        contexts: dictionary of Context classes
        priors: dictionary of prior probabilities for each context
    Returns:
        posterior: posterior probability of the chosen context given the observed flower color(s)
    """
    # write the expression for computing the posterior here
    unchosen_context = [c for c in contexts if c!=chosen_context][0]
    posterior_chosen = ???
    posterior_unchosen = ???
    denom = ???
    return posterior_chosen/denom

Let's look at the posterior function when the prior probabilities of the two contexts are unequal. What happens to the decision rule then?


In [None]:
priors = {
    'forest': 0.9,
    'field': 0.1
}
threshold = plot_func_of_f_c(posterior_probability, 'Posterior probability', f, contexts, colors, priors = priors)
plt.vlines(0.5,0,1,linestyles='dotted')
plt.show()

#### What should happen at "ambiguous" flower colors of 0.5 when the prior is skewed? Should this affect the decision rule, and if so how?

> write answer (See solutions to check your intuitions)

## Utilities

Ok, now we are going to move even closer to real-life decisions.

For each possible "ground truth" context, there are two possible decisions we could have made, one correct and one incorrect. However, these may not be equally good/bad - different errors could have different costs, similarly different correct decisions could have different rewards.

Bayesian decision theory tells us we should take these into account for our decision rule, by thresholding the **utility ratio**:

$$
\frac{U(\text{choose forest})}{U(\text{choose field})} = \left( \frac{r_{\text{forest}}(\text{choose forest}) \cdot P(c = \text{forest} | f) - c_{\text{field}}(\text{choose forest}) \cdot P(c = \text{field} | f)}{r_{\text{field}}(\text{choose field}) \cdot P(c = \text{field} | f) - c_{\text{forest}}(\text{choose field}) \cdot P(c = \text{forest} | f)} \right)
$$

This introduces additional variables: the action we take $a$ (choose forest, choose field), and the reward $r$ / cost $c$ of correct/incorrect decisions
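
Here is a minimal worked example of how the posterior combines with rewards and costs. All numbers are hypothetical, and `expected_utility` is an illustrative helper rather than part of the tutorial code.

```python
# Hypothetical posterior over contexts for some observed flower color
post = {'forest': 0.67, 'field': 0.33}

# Hypothetical rewards for correct choices and costs for incorrect ones
example_rewards = {'forest': 1.0, 'field': 1.0}
example_costs = {'forest': 5.0, 'field': 1.0}  # cost of wrongly choosing each context

def expected_utility(choice, other):
    # reward if the choice is correct, minus cost if the other context is the true one
    return example_rewards[choice] * post[choice] - example_costs[choice] * post[other]

u_forest = expected_utility('forest', 'field')  # 1 * 0.67 - 5 * 0.33
u_field = expected_utility('field', 'forest')   # 1 * 0.33 - 1 * 0.67

# Pick the action with the larger utility
decision = 'forest' if u_forest > u_field else 'field'
print(u_forest, u_field, decision)
```

Even though forest is the more probable context here, the high cost of wrongly choosing forest tips the decision to field. The sketch compares the two utilities directly. This matches thresholding the ratio at 1 when both utilities are positive, and it stays well defined when they are negative, as they are here.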

<div style="text-align:center">
    <img src="./figures/f15.png" alt="Generative Model" width="100"/>
</div>

<div>
<img src="https://drive.google.com/uc?export=view&id=10BjVANeoLWjSQUUUREIQ8GxsrtJRpfzi" width="100"/>
</div>


Code this up & play with different costs and rewards.

In [None]:
def utility_function(f, chosen_context:str, contexts: Dict[str, Context], priors:Dict[str,float], rewards:Dict[str,float], costs:Dict[str,float]):
    """
    Function that returns the utility of choosing a particular context given observed flower color(s) f and rewards/costs
        f: observed flower color(s)
        chosen_context: context being chosen
        contexts: dictionary of Context classes
        priors: dictionary of prior probabilities for each context
        rewards: dictionary of rewards for correctly choosing each context
        costs: dictionary of costs for incorrectly choosing each context
    Returns:
        utility: utility of choosing that context
    """
    # Compute posteriors for both contexts
    unchosen_context = [c for c in contexts if c!=chosen_context][0]
    posterior_chosen = ???
    posterior_unchosen = ???
    
    # Compute utility for chosen context
    utility = ???
    
    return utility


Let's look at the utility ratios for a number of observed flower color samples, given a prior ratio, reward ratio and associated costs. Play around with these values to see what happens:


In [None]:
# set the prior to be high for forest and 
# the cost of incorrectly choosing forest when the context is really field to be really high

priors = {
    'forest': 0.9,
    'field': 0.1
}

rewards = { # of correctly choosing each context
    'forest': 1,
    'field': 1
}

costs = { # of incorrectly choosing each context
    'forest': 91,
    'field': 11
}


threshold = plot_func_of_f_c(utility_function, 'Utility function', f, contexts, colors, priors = priors, rewards=rewards, costs=costs)
plt.vlines(0.5,-90,0,linestyles='dotted')
plt.show()


#### What do you notice about the interplay between prior ratios and reward ratios?

(see solutions to check your intuitions)

## Takeaways

Bayesian Decision Theory (BDT) extends Bayesian inference to *decisions* i.e. discrete choices between (unobservable) categories, based on observed evidence. 

In BDT, we consider not just the (posterior) probabilities of these categories given the evidence, but also the rewards/costs of making correct/incorrect decisions. 
These two ingredients combine to form the *utility* of different decisions, and BDT lets us make decisions that maximize this utility.

Consequently, BDT is a reminder of the fact that optimal decision-making agents can be biased towards one decision or another for many different reasons - because that decision is more prevalent, or because the associated rewards/costs are high.

# 2. Hierarchical inference

So far, we've assumed direct access to the flower colors $f$. However, in reality we may only have access to noisy sensory observations $x$ from a "color detector" neuron. This modifies our generative model to a new hierarchical graph:

<div style="text-align:center">
    <img src="./figures/f2.png" alt="Generative Model" width="100"/>
</div>

<div>
<img src="https://drive.google.com/uc?export=view&id=1Zjt9nhbkAHvQzyiatmEDAjuex9oprduU" width="100"/>
</div>


Let's assume that the color detector neuron has a firing rate centered around the true color, with sensory noise characterized by a variance of $\sigma_{\text{sensory}}^2$. The tuning function of such a neuron can be written as:

In [None]:
def tuningFun(f, beePars):
    # Gaussian noise on firing rate centered at the color
    rate = f + np.random.normal(0, beePars['sigmaSensory'], size=f.shape) 
    return rate

In this case, we want to *marginalize* over the flower color $f$, since we still want to infer the variable of interest $c$. That is, we are interested in

$$
P(x \mid c) = \int P(x \mid f) P(f \mid c) \, df.
$$

If we assume that both $P(x \mid f)$ and $P(f \mid c)$ are Gaussian with variances $\sigma_{\text{sensory}}^2$ and $\sigma_{f}^2$ respectively, 
$$P(x \mid f) = \mathcal{N}(x | f, \sigma^2_\text{sensory})$$
$$P(f \mid c) = \mathcal{N}(f | c, \sigma^2_\text{f})$$

then the integral is also Gaussian. Because the firing rate distribution is centered on the true flower color, this new Gaussian keeps the mean of $P(f \mid c)$ and has variance $\sigma_{f}^2 + \sigma_{\text{sensory}}^2$. (Here $c$ in the mean position is shorthand for the mean flower color $\mu_c$ of context $c$.)

This is quite a useful property of Gaussians, so we have included a derivation below (expand the cell).

To derive the marginal probability $P(x \mid c)$ given that $P(x \mid f)$ and $P(f \mid c)$ are Gaussian distributions, let's start with the given integral:

$$
P(x \mid c) = \int P(x \mid f) P(f \mid c) \, df
$$

Assume:
$$
P(x \mid f) = \mathcal{N}(x \mid f, \sigma_{\text{sensory}}^2) = \frac{1}{\sqrt{2\pi\sigma_{\text{sensory}}^2}} \exp\left(-\frac{(x - f)^2}{2\sigma_{\text{sensory}}^2}\right)
$$
$$
P(f \mid c) = \mathcal{N}(f \mid c, \sigma_{f}^2) = \frac{1}{\sqrt{2\pi\sigma_{f}^2}} \exp\left(-\frac{(f - c)^2}{2\sigma_{f}^2}\right)
$$

Substituting these Gaussian distributions into the integral, we get:

$$
P(x \mid c) = \int \frac{1}{\sqrt{2\pi\sigma_{\text{sensory}}^2}} \exp\left(-\frac{(x - f)^2}{2\sigma_{\text{sensory}}^2}\right) \frac{1}{\sqrt{2\pi\sigma_{f}^2}} \exp\left(-\frac{(f - c)^2}{2\sigma_{f}^2}\right) \, df
$$

Combining the constants outside the integral:

$$
P(x \mid c) = \frac{1}{2\pi \sqrt{\sigma_{\text{sensory}}^2 \sigma_{f}^2}} \int \exp\left(-\frac{(x - f)^2}{2\sigma_{\text{sensory}}^2} - \frac{(f - c)^2}{2\sigma_{f}^2}\right) \, df
$$

Next, combine the exponents:

$$
-\frac{(x - f)^2}{2\sigma_{\text{sensory}}^2} - \frac{(f - c)^2}{2\sigma_{f}^2}
$$

To simplify this, let's expand both terms:

$$
-\frac{(x - f)^2}{2\sigma_{\text{sensory}}^2} = -\frac{1}{2\sigma_{\text{sensory}}^2} (x^2 - 2xf + f^2)
$$
$$
-\frac{(f - c)^2}{2\sigma_{f}^2} = -\frac{1}{2\sigma_{f}^2} (f^2 - 2fc + c^2)
$$

Combine the exponents:

$$
-\frac{x^2}{2\sigma_{\text{sensory}}^2} + \frac{xf}{\sigma_{\text{sensory}}^2} - \frac{f^2}{2\sigma_{\text{sensory}}^2} - \frac{f^2}{2\sigma_{f}^2} + \frac{fc}{\sigma_{f}^2} - \frac{c^2}{2\sigma_{f}^2}
$$

Group the terms involving $f$:

$$
= -\frac{x^2}{2\sigma_{\text{sensory}}^2} + \frac{xf}{\sigma_{\text{sensory}}^2} - \left( \frac{1}{2\sigma_{\text{sensory}}^2} + \frac{1}{2\sigma_{f}^2} \right) f^2 + \frac{fc}{\sigma_{f}^2} - \frac{c^2}{2\sigma_{f}^2}
$$

Combine the coefficients of $f$:

$$
= -\left( \frac{1}{2\sigma_{\text{sensory}}^2} + \frac{1}{2\sigma_{f}^2} \right) f^2 + \left( \frac{x}{\sigma_{\text{sensory}}^2} + \frac{c}{\sigma_{f}^2} \right) f - \left( \frac{x^2}{2\sigma_{\text{sensory}}^2} + \frac{c^2}{2\sigma_{f}^2} \right)
$$

Complete the square for the $f$-dependent part. Let $A = \frac{1}{\sigma_{\text{sensory}}^2} + \frac{1}{\sigma_{f}^2}$, $B = \frac{x}{\sigma_{\text{sensory}}^2} + \frac{c}{\sigma_{f}^2}$:

$$
= -\frac{A}{2} \left( f^2 - \frac{2B}{A} f \right)
$$

Add and subtract the square of the middle term inside the parentheses:

$$
= -\frac{A}{2} \left( f^2 - \frac{2B}{A} f + \left(\frac{B}{A}\right)^2 - \left(\frac{B}{A}\right)^2 \right)
$$

$$
= -\frac{A}{2} \left( \left( f - \frac{B}{A} \right)^2 - \left(\frac{B}{A}\right)^2 \right)
$$

$$
= -\frac{A}{2} \left( f - \frac{B}{A} \right)^2 + \frac{B^2}{2A}
$$

Now, the exponent is:

$$
-\frac{A}{2} \left( f - \frac{B}{A} \right)^2 - \left( \frac{x^2}{2\sigma_{\text{sensory}}^2} + \frac{c^2}{2\sigma_{f}^2} \right) + \frac{B^2}{2A}
$$

Notice that the integral over $f$ is a standard Gaussian integral, which evaluates to a constant that does not depend on $x$ or $c$:

$$
\int \exp\left( -\frac{A}{2} \left( f - \frac{B}{A} \right)^2 \right) \, df = \sqrt{\frac{2\pi}{A}}
$$

So, we are left with the term outside the integral:

$$
P(x \mid c) = \frac{1}{2\pi \sqrt{\sigma_{\text{sensory}}^2 \sigma_{f}^2}} \sqrt{\frac{2\pi}{A}} \exp\left( \frac{B^2}{2A} - \left( \frac{x^2}{2\sigma_{\text{sensory}}^2} + \frac{c^2}{2\sigma_{f}^2} \right) \right)
$$

Given that $A = \frac{1}{\sigma_{\text{sensory}}^2} + \frac{1}{\sigma_{f}^2}$ and $B = \frac{x}{\sigma_{\text{sensory}}^2} + \frac{c}{\sigma_{f}^2}$, simplify $\frac{B^2}{2A}$:


$$
\frac{B^2}{2A} = \frac{\left( \frac{x}{\sigma_{\text{sensory}}^2} + \frac{c}{\sigma_{f}^2} \right)^2}{2 \left( \frac{1}{\sigma_{\text{sensory}}^2} + \frac{1}{\sigma_{f}^2} \right)} = \frac{\left( \frac{x}{\sigma_{\text{sensory}}^2} + \frac{c}{\sigma_{f}^2} \right)^2 \sigma_{\text{sensory}}^2 \sigma_{f}^2}{2 \left( \sigma_{\text{sensory}}^2 + \sigma_{f}^2 \right)}
$$
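
The remaining algebra can be sketched as follows, using only the definitions of $A$ and $B$ above. Expanding the square in the numerator gives

$$
\frac{B^2}{2A} = \frac{\frac{\sigma_{f}^2}{\sigma_{\text{sensory}}^2} x^2 + 2xc + \frac{\sigma_{\text{sensory}}^2}{\sigma_{f}^2} c^2}{2 \left( \sigma_{\text{sensory}}^2 + \sigma_{f}^2 \right)}
$$

Putting the other two terms of the exponent over the same denominator, the $x^2$ and $c^2$ terms partly cancel and the exponent collapses to a single square:

$$
\frac{B^2}{2A} - \frac{x^2}{2\sigma_{\text{sensory}}^2} - \frac{c^2}{2\sigma_{f}^2} = \frac{-x^2 + 2xc - c^2}{2 \left( \sigma_{\text{sensory}}^2 + \sigma_{f}^2 \right)} = -\frac{(x - c)^2}{2 \left( \sigma_{\text{sensory}}^2 + \sigma_{f}^2 \right)}
$$

The prefactor simplifies in the same way, since $A = \frac{\sigma_{\text{sensory}}^2 + \sigma_{f}^2}{\sigma_{\text{sensory}}^2 \sigma_{f}^2}$:

$$
\frac{1}{2\pi \sqrt{\sigma_{\text{sensory}}^2 \sigma_{f}^2}} \sqrt{\frac{2\pi}{A}} = \frac{1}{2\pi \sqrt{\sigma_{\text{sensory}}^2 \sigma_{f}^2}} \sqrt{\frac{2\pi \sigma_{\text{sensory}}^2 \sigma_{f}^2}{\sigma_{\text{sensory}}^2 + \sigma_{f}^2}} = \frac{1}{\sqrt{2\pi \left( \sigma_{\text{sensory}}^2 + \sigma_{f}^2 \right)}}
$$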

Thus, the resulting distribution $P(x \mid c)$ is Gaussian with mean $c$ and variance $\sigma_{\text{sensory}}^2 + \sigma_{f}^2$:

$$
P(x \mid c) = \frac{1}{\sqrt{2\pi (\sigma_{\text{sensory}}^2 + \sigma_{f}^2)}} \exp\left(-\frac{(x - c)^2}{2(\sigma_{\text{sensory}}^2 + \sigma_{f}^2)}\right)
$$

Therefore, the new variance is the sum of the original variances because they represent independent sources of noise.
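
If you would rather check this result numerically than follow the algebra, the sketch below approximates the marginalization integral with a brute-force sum over a fine grid of flower colors and compares it to the closed form. The parameter values and the observation `x_obs` are assumptions chosen for illustration.

```python
import numpy as np
from scipy.stats import norm

# Assumed parameters: forest mean, flower-color spread and sensory noise
mu_c, sigma_f, sigma_sensory = 0.66, 0.1, 0.1
x_obs = 0.7  # hypothetical firing-rate observation

# Brute-force marginalization: Riemann sum of P(x|f) P(f|c) over a fine grid of f
f_grid = np.linspace(mu_c - 1.0, mu_c + 1.0, 20001)
df = f_grid[1] - f_grid[0]
integrand = (norm.pdf(x_obs, loc=f_grid, scale=sigma_sensory)
             * norm.pdf(f_grid, loc=mu_c, scale=sigma_f))
numeric = np.sum(integrand) * df

# Closed form: Gaussian with the same mean and summed variances
analytic = norm.pdf(x_obs, loc=mu_c, scale=np.sqrt(sigma_f**2 + sigma_sensory**2))

print(numeric, analytic)
```

The two numbers should agree to many decimal places. Note that `scale` in `norm.pdf` is a standard deviation, so the summed variances go under a square root.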



Before we go on and code the marginal distribution, let's define the parameters of the two contexts and the bee:

In [None]:
### Parameters of the bee's *internal* distribution of noisy sensory observations
beePars = {'sigmaSensory': 0.1}

### In contrast, these are the parameters of the *environmental* distribution of flower colors
contexts = {
    'forest': Context(pFlowerMu = 0.66, pFlowerSig = 0.1), 
    'field': Context(pFlowerMu = 0.33, pFlowerSig = 0.1) 
    }

# define some useful variables for plotting
colors = {
    'forest':[0,0,1], 
    'field':[1,0,0]
    }

Complete the function below to compute marginal likelihood $P(x|c)$:

In [None]:
def marginal_lik(x, context, beePars):
    
    marginal_lik = norm.pdf(
        x,
        loc = ???,
        scale = ???
        )
    
    return marginal_lik


#### Monte Carlo techniques to approximate marginalization

Here, because each of our sources of noise was Gaussian and independent, we could (painstakingly) evaluate the integral to get the marginal distribution. In many cases this integral is intractable to evaluate analytically. In such cases we can use *simulation* to approximate it. Such approaches are called Monte Carlo techniques and have wide applications in real-world problems.

$$
P(x \mid c) \approx \frac{1}{N} \sum_{i=1}^{N} P(x \mid f_i), \quad \text{where } f_i \sim P(f \mid c)
$$
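
To see how the quality of this approximation depends on the number of samples $N$, here is a minimal sketch for a single observation. The parameter values, the observation `x_obs` and the sample sizes are assumptions chosen for illustration, and the code is vectorized rather than written as the loop used in the exercise below.

```python
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)  # seeded so the sketch is reproducible

# Assumed parameters: forest mean, flower-color spread and sensory noise
mu_c, sigma_f, sigma_sensory = 0.66, 0.1, 0.1
x_obs = 0.7  # hypothetical firing-rate observation

# Closed-form marginal likelihood for comparison
analytic = norm.pdf(x_obs, loc=mu_c, scale=np.sqrt(sigma_f**2 + sigma_sensory**2))

for n_samples in [10, 1000, 100000]:
    # Draw flower colors f_i ~ P(f | c), then average P(x | f_i) over the draws
    f_samples = rng.normal(mu_c, sigma_f, size=n_samples)
    estimate = norm.pdf(x_obs, loc=f_samples, scale=sigma_sensory).mean()
    print(n_samples, estimate, abs(estimate - analytic))
```

The error of a Monte Carlo estimate shrinks roughly as $1/\sqrt{N}$, so small sample sizes can give visibly noisy marginal likelihoods.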

In [None]:
def marginal_lik_sim(x, context, beePars):
    f_sample = context.sample(n = 100)
    p_sample = np.vstack([norm.pdf(
        x,
        loc = ???,
        scale = ???
        ) for f_i in f_sample])
    return p_sample.mean(axis = 0)


Let us plot both the analytical and simulation-based marginal distributions to confirm that they match:

In [None]:
def compare_analytical_simulated_marginal_likelihoods(contexts, beePars):
    
    fig, axs = plt.subplots(1, 2, figsize=(14, 7))

    for c in contexts:

        # Plot observed firing rates for different true flower colors
        data = contexts[c].sample(n=100)
        rate = tuningFun(data,beePars)
        axs[0].plot(data, data, color = 'k')
        axs[0].set_ylim([0,1])
        axs[0].scatter(data, rate, color = colors[c], alpha = 0.3)
        axs[0].set_xlabel('True flower color')
        axs[0].set_ylabel('Firing rate $(x)$')
        axs[0].set_title('Firing rates $(x)$ observed\nfor the two contexts')

        # we will compute analytical and simulated marginal distributions 
        # over all these possible values of x (firing rate)
        x = np.linspace(rate.min(),rate.max(),100)

        # Plot analytical marginal distribution
        axs[1].plot(
            marginal_lik(x, contexts[c], beePars),
            x, 
            color = colors[c],
            alpha= 0.3, 
            label='analytical')
        
        # Plot simulated marginal distribution
        axs[1].plot(
            marginal_lik_sim(x, contexts[c], beePars),
            x, 
            ls = '--',
            color = colors[c],
            alpha= 0.3, 
            label = 'monte carlo')

        # Plot the distribution of firing rates for the two contexts
        label = f"$P(x | c = {c})$"
        axs[1].hist(
            rate,
            color = colors[c],
            alpha= 0.3,
            label = label, 
            orientation='horizontal', 
            density=True)
        axs[1].set_xlim([0,5])
        axs[1].set_xticks([])
        axs[1].legend()
        axs[1].set_ylabel('Firing rate')
        axs[1].set_xlabel('marginal likelihood $P(x|c)$')
        axs[1].set_title('Comparing analytical vs simulated\nmarginal likelihood $P(x|c)$') 

    sns.despine()
    plt.tight_layout()
    plt.show()
    
compare_analytical_simulated_marginal_likelihoods(contexts, beePars)

So here, given a noisy firing rate, we can infer its context using the marginal likelihood, marginalizing over all possible flower colors. For high firing rates (or intense flower colors), the marginal likelihood is larger for the "forest" context, where the true mean was 0.66, and vice versa. To make a *decision* about the context, just like before, it will be helpful to compute the ratio of the two marginal likelihoods:

In [None]:
def marginal_lik_ratio(x, contexts, beePars):
    return ???

Play around with the sensory noise $\sigma_{\text{sensory}}$ to see the effect on (1) marginal likelihoods and then (2) marginal likelihood ratios!

(1) marginal likelihoods

In [None]:
beePars = {'sigmaSensory': 0.5} # try values: [0.1,0.2,0.4,0.8]

compare_analytical_simulated_marginal_likelihoods(contexts, beePars)

(2) marginal likelihood ratios

In [None]:
# possible firing rates of color detector neuron.
x = np.linspace(0,1,100)

plt.figure(figsize=(5, 5))
legend_label = r"$\sigma_{sensory}$"  # raw string so that \s is not treated as an escape sequence

# vary the sensory noise of the bee
for sigmaSensory in [0.1,0.2,0.4,0.8,0.9]:
    
    beePars = {'sigmaSensory': sigmaSensory}
    plt.plot(
        x, 
        np.log(marginal_lik_ratio(x, contexts, beePars)), 
        label = f"{legend_label}={sigmaSensory}"
        )
plt.xlabel('Firing rate $(x)$')
plt.ylabel('Marginal log likelihood ratio')
plt.title("Effect of sensory noise on marginal LLR")
sns.despine()
plt.legend(bbox_to_anchor=[1.1, 0.8])
plt.show()

### Think about
- What happens to marginal likelihoods and marginal likelihood ratios? 
- Where should the decision boundary be drawn? 
- Is the bee still able to *correctly* decide which context it might be in?

>  [WRITE ANSWER]

## Inferring changing contexts
Ok, now we will let the bee fly around, so the context it is in might change. We will model the decision dynamics of the bee in such changing contexts. 

<div style="text-align:center">
    <img src="./figures/f3.png" alt="Generative Model" width="400"/>
</div>

<div>
<img src="https://drive.google.com/uc?export=view&id=12O54pIopjwTWtJoOu5yohjcQ4XudIAW7" width="500"/>
</div>

**BUT** before going there, we will learn how to do online Bayesian inference (also called recursive Bayes or Bayesian filtering) in a more restricted setting - one in which the contexts change without the bee having any control over whether it samples from field or forest. This is almost like the bee is trapped in a car and the car passes through patches of fields and forests.

#### First, we will define a class that lets us simulate a bee being driven around through patches of forests and fields.
In addition to the parameters we have had for the two contexts (namely $\mu_c$) and the bee (sensory noise $\sigma_{\text{sensory}}^2$), we are going to introduce another parameter $p_{\text{switch}}$, which is the probability of the context changing. We are going to assume that this change process is Markovian, i.e. there is a fixed probability of a change occurring at every time step.
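
To build intuition for what a fixed per-step switch probability implies, the sketch below simulates a long sequence of switch events and measures how long the context stays the same between switches. The value of `n_switch` and the number of simulated steps are assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(1)  # seeded so the sketch is reproducible

n_switch = 20            # assumed average number of encounters between switches
p_switch = 1 / n_switch  # per-step switch probability (hazard rate)

# At every time step, independently decide whether the context switches
switches = rng.random(200_000) < p_switch

# Run length = number of steps between consecutive switches
switch_times = np.flatnonzero(switches)
run_lengths = np.diff(switch_times)

print(run_lengths.mean())  # should be close to n_switch
print(run_lengths.min(), run_lengths.max())
```

Under this assumption the run lengths follow a geometric distribution with mean $1/p_{\text{switch}}$, so both very short and very long stretches in one context occur.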

<div style="text-align:center">
    <img src="./figures/f4.png" alt="Generative Model" width="400"/>
</div>

<div>
<img src="https://drive.google.com/uc?export=view&id=1fRKUIA96_P0sqq3ItwsdCTz2fnsC2IG_" width="300"/>
</div>

In [None]:
class simulateEnv:
    
    def __init__(self, contexts, beePars, **envParams):
        self.contexts = contexts
        self.beePars = beePars
        self.switchType = envParams.get('switchType', None)
        self.nSwitch = envParams.get('nSwitch', None)
        self.contextLabels = list(self.contexts)


    #Sample contexts, observed flower colours aka firing rates
    def sample(self, tMax):
        
        fObs = np.zeros(tMax)  # observations made by the bee i.e. firing rates of the color neuron
        pSwitch = 1/self.nSwitch # Hazard rate i.e. 1/(avg no. encounters before switching contexts)
        contextFlag = np.zeros(tMax, dtype='int') # store which context the bee is in (environmental dynamics)

        # Simulate observations of flower colors for this assumed rate of context change
        
        # sample the initial context 
        contextFlag[0] = np.random.randint(len(self.contextLabels))
        context_0 = self.contextLabels[contextFlag[0]]
        
        # make a firing rate observation in this context (remember marginal likelihood is P(x|c))
        fObs[0] = self.sample_from_marginal_lik(context_0, 1)[0]  # take the scalar out of the length-1 array

        switch_indicator = False
    
        # simulate over time
        for i in range(1, tMax):
            
            # switch every nSwitch trials
            if self.switchType == 'periodic':
                switch_indicator = np.mod(i, self.nSwitch) == 0
                
            # Switch contexts with a probability of 'pSwitch'
            elif self.switchType == 'markov':
                switch_indicator = np.random.rand() < pSwitch
            else:
                raise ValueError("unknown switchType, can be 'periodic' or 'markov'")
            
            # switch contexts
            if switch_indicator:
                contextFlag[i] = 1 - contextFlag[i-1]  # flip between 0 and 1
            else:
                contextFlag[i] = contextFlag[i-1]
                
            # Sample flower colors from chosen context
            fObs[i] = self.sample_from_marginal_lik(self.contextLabels[contextFlag[i]], 1)[0]
            
            
        return contextFlag, fObs
    
    
    # method to sample observations (firing rates) given context
    def sample_from_marginal_lik(self, context, n):
        
        sample = norm.rvs(
            loc = self.contexts[context].pFlowerMu,
            scale = np.sqrt(self.contexts[context].pFlowerSig**2 + self.beePars['sigmaSensory']**2),
            size = n
        )
    
        return sample    
    
    
    

# plotting function for visualizing firing rates when the contexts are changing
def plot_context_and_observations(tMax, contextFlag, fObs, title):

    # Plot true context and sensory observations with two y-axes
    fig, ax1 = plt.subplots(figsize=(12, 3))

    # First subplot (True context)
    scatter1 = ax1.scatter(range(tMax), contextFlag, 5, color='k', label='True context')
    ax1.set_yticks([0, 1])
    ax1.set_yticklabels(list(contexts))
    ax1.set_xlabel('Time')

    # Second y-axis for the second subplot (Sensory observations)
    ax2 = ax1.twinx()
    this_color = [0.5, 0.5, 0]
    scatter2 = ax2.scatter(range(tMax), fObs, 8, color=this_color, label='Firing rate observations')
    ax2.set_ylabel('Firing rate')
    ax2.spines['right'].set_color(this_color)
    ax2.tick_params(axis='y', colors = this_color)
    ax2.yaxis.label.set_color(this_color)

    lines = [scatter1, scatter2]
    labels = [line.get_label() for line in lines]
    ax1.legend(
        lines, 
        labels, 
        bbox_to_anchor = (0.8, -0.4),
        ncol = 2
        )
    
    ax1.set_title(f"{title} environment")
    sns.despine(right = False)
    
    return fig, ax1



Okay, let's visualize this change: you can change the `switchType` to `periodic` in the cell below to very clearly visualize distributional shifts in observations due to changing contexts. Also try increasing the sensory noise to see how that changes the observations that the bee makes.

In [None]:
## Create a bee with a chosen sensory noise parameter - PLAY WITH THIS!
beePars = {'sigmaSensory':0.1}

### Parameters of the *environmental* distribution of flower colors
contexts = OrderedDict({
    'forest': Context(pFlowerMu = 0.66, pFlowerSig = 0.1), 
    'field': Context(pFlowerMu = 0.33, pFlowerSig = 0.1) 
    })

### number of time points
tMax = 200

### Simulate context switches, with an average change every 20 trials
# Change switchType to 'periodic' to easily understand the relationship between context switch and firing rates
envPars = {
    'switchType': 'markov',
    'nSwitch': 20 
    }

# initialize class
sim = simulateEnv(contexts, beePars, **envPars)

### Sample contexts, firing rate samples from that context
contextFlag,fObs = sim.sample(tMax)
fig, ax = plot_context_and_observations(tMax, contextFlag, fObs, envPars['switchType'])


### Inference 

Now let's write a function to perform online inference as the bee encounters flowers in different contexts. 

Some terminology first:
The posterior probability that the bee is in context "field", given the sequence of flowers it has encountered up to and including time $t$, is often referred to as the belief: $p_{\text{belief}, t}^{\text{field}}$. 

This belief will need to be updated at each time step, based on the dynamics of context switches and the encountered flower colors. 

Now, when the bee encounters its first flower, two steps need to occur:
1. Prediction: probability that the bee is still in context "field", given that it encountered a flower (irrespective of its color)

2. Update: probability that the bee is still in context "field", given that it observed the specific color of the flower. 

Let's write down what these steps look like in math:

##### **the prediction step**:
$$p_{\text{predict}, t = 1}^{\text{field}} = p^{\text{field}}_{\text{belief},t=0}(1-p_{\text{switch}}) + (1 - p^{\text{field}}_{\text{belief},t=0})p_{\text{switch}} $$
The first term is the probability that the bee was in context "field" and no switch happened. The second term is the probability that the bee was in context "forest" and a switch happened.

##### **the update step**:
Our prediction acts as the prior, and we compute the likelihood of observing firing rate $x_1$ given that we are in context "field" to arrive at the posterior probability that we are in context "field". Because flower color varies within a context and the bee's sensory measurement adds independent noise, the two variances add:
$$p_{\text{belief}, t = 1}^{\text{field}} \propto p_{\text{predict}, t = 1}^{\text{field}} \, \mathcal{N}(x_1; \mu_{\text{field}}, \sigma^2_{\text{field}} + \sigma^2_{\text{sensory}}) $$
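The two steps can be sketched for a single time step as follows. This is a minimal, standalone illustration rather than the tutorial's `onlineInference` function: the helper name `filter_step` and its arguments are hypothetical, and it assumes two contexts with a symmetric switch probability and Gaussian likelihoods whose variance is the sum of the environmental and sensory variances.

```python
import numpy as np
from scipy.stats import norm

def filter_step(p_field_prev, x, p_switch, mu, sig_env, sig_sensory):
    """One prediction + update step for a two-context (forest/field) filter."""
    # Prediction: stay in 'field', or switch into 'field' from 'forest'
    p_predict = p_field_prev * (1 - p_switch) + (1 - p_field_prev) * p_switch

    # Likelihood of the observation under each context (variances add)
    sd = {c: np.sqrt(sig_env[c]**2 + sig_sensory**2) for c in mu}
    lik_field = norm.pdf(x, loc=mu['field'], scale=sd['field'])
    lik_forest = norm.pdf(x, loc=mu['forest'], scale=sd['forest'])

    # Update: prior times likelihood, normalized over the two contexts
    num = p_predict * lik_field
    return num / (num + (1 - p_predict) * lik_forest)

mu = {'forest': 0.66, 'field': 0.33}
sig_env = {'forest': 0.1, 'field': 0.1}

# A firing rate close to the 'field' mean should push the belief towards 'field'
p1 = filter_step(0.5, x=0.3, p_switch=0.05, mu=mu, sig_env=sig_env, sig_sensory=0.1)
print(p1)
```

Feeding the returned belief back in as `p_field_prev` for the next observation gives the full filter.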

Again, this process will be repeated at every time step. Therefore, essentially in Bayesian filtering **"today’s posterior acts as tomorrow’s prior”** (Lindley, Bayesian statistics, a review. 1972, p. 2). This is comparable to evidence accumulation frameworks, such as the DDM, where the observer is thought to accumulate noisy samples of evidence over time. Here, the evidence is the log-likelihood, and accumulation is the Bayes-optimal solution.
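To make the link to evidence accumulation concrete, here is a small sketch under a simplifying assumption that differs from the tutorial's setting: the context never switches ($p_{\text{switch}} = 0$) and the prior is flat. In that special case the posterior log-odds is just the running sum of log-likelihood ratios. All variable names here are illustrative.

```python
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)
mu_field, mu_forest = 0.33, 0.66
sd = np.sqrt(0.1**2 + 0.1**2)            # environmental + sensory variance
x = rng.normal(mu_field, sd, size=10)    # observations drawn from 'field'

# Evidence per observation: log-likelihood ratio (field vs forest)
llr = norm.logpdf(x, mu_field, sd) - norm.logpdf(x, mu_forest, sd)

# Accumulate evidence, then convert log-odds to a probability
log_odds = np.cumsum(llr)
p_field_accum = 1 / (1 + np.exp(-log_odds))

# Same quantity from step-by-step Bayesian updating without switching
p, p_field_seq = 0.5, []
for xt in x:
    num = p * norm.pdf(xt, mu_field, sd)
    p = num / (num + (1 - p) * norm.pdf(xt, mu_forest, sd))
    p_field_seq.append(p)

print(np.allclose(p_field_accum, p_field_seq))
```

With a non-zero switch probability the prediction step pulls the belief back towards 0.5 before each update, so the accumulation becomes leaky rather than perfect.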



In [None]:
# Perform online inference
def onlineInference(fObs, contexts, beePars, pSwitch = 0.5, p0 = [0.5,0.5]):

    # Construct transition matrix
    T = [[1-pSwitch,pSwitch],[pSwitch,1-pSwitch]]

    # Initialize beliefs
    pContext = np.zeros([2,np.size(fObs)])
    pContext[:,0] = p0
    contextLabel = list(contexts)

    for t in np.arange(1, np.size(fObs)):
        
        #Prediction step: apply transition matrix
        pPredict = ???

        # Update step: Likelihood of observation fObs(t) under the two contexts
        lik = [
            ???,
            ???
            ]
        
        # Update step: Posterior belief 
        pContext[:,t] = ???
        
        # normalize
        pContext[:,t] = pContext[:,t]/sum(pContext[:,t])

    return pContext




Now, let's plot the running posterior belief as well as the posterior belief averaged over switches.

In [None]:
# Create a bee with a chosen sensory noise parameter - PLAY WITH THIS!
beePars = {'sigmaSensory': 0.1}

### Parameters of the *environmental* distribution of flower colors
contexts = OrderedDict({
    'forest': Context(pFlowerMu = 0.66, pFlowerSig = 0.1), 
    'field': Context(pFlowerMu = 0.33, pFlowerSig = 0.1) 
    })

# initialize class
sim = simulateEnv(contexts, beePars, **envPars)

# Sample contexts, firing rate samples from that context
contextFlag,fObs = sim.sample(tMax)

# infer posterior beliefs
pContext = onlineInference(
    fObs, 
    contexts, 
    beePars, 
    pSwitch = 1/envPars['nSwitch'],
    # pSwitch = 0.5,
    p0 = [0.5, 0.5]  # initial beliefs about the two contexts
    )

# plot the beliefs
fig, axs = plot_context_and_observations(tMax, contextFlag, fObs, envPars['switchType'])
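# Wrap the context name in braces so the whole word (not just its first letter) is superscripted
label = rf'$p^{{\mathrm{{{list(contexts)[1]}}}}}$'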
axs.plot(pContext[1,:],color = 'r', label = label)
axs.legend( bbox_to_anchor = (1.4, 0.7),)
plt.show()

# Plot posterior belief averaged over switches
pattern = {'Forest->Field':'1111100000', 'Field->Forest':'0000011111'}
fig2 = plt.figure(figsize=(16.5, 3))
k = 1
for s in pattern.keys():
    fig2.add_subplot(1, 2 ,k, title = 'Average p('+s+')')
    k+=1
    pAvg = np.zeros(10)
    inds = [m.start() for m in re.finditer(pattern[s], ''.join(str(e) for e in contextFlag))]
    for i in inds:
        pAvg = pAvg + pContext[1,i:i+10]
    pAvg = pAvg/np.size(inds)
    plt.plot(np.arange(-4,6), (1 - pAvg), c = 'r')
    plt.ylabel('p (field)')
    plt.plot([0,0],[0,1],'k--')
sns.despine()


#### things to think about:
1. What happens if you increase the sensory (observation) noise?
2. What happens if the bee thinks $p_{\text{switch}}$ is higher than it is?
3. What happens if the bee's knowledge of the distribution of flower colors in the two contexts is not accurate? 

> log your observations here

# 3. Selecting actions
Let's turn now to simulating the actions of the bee in this environment! So far, we assumed (perhaps unnaturally) that the bee's context was shifting because it was being driven around. In reality, however, the bee can move around, changing the context it is in *as a consequence of its actions*! Such a setup takes us from a hidden Markov model (in which hidden states change on their own) to a **Markov decision process** (MDP), in which our decisions don't just yield rewards/costs, but can also change the state of the world around us.

MDPs are the foundation of RL, as you might have already heard. 
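As a rough illustration of the difference, the sketch below contrasts a single context transition matrix (the hidden Markov model case) with one transition matrix per action (the MDP case). The actions `stay` and `fly` and all probabilities are made up for illustration and do not appear elsewhere in this tutorial.

```python
import numpy as np

# HMM: a single transition matrix; rows = current context, columns = next context
p_switch = 0.05
T_hmm = np.array([[1 - p_switch, p_switch],
                  [p_switch, 1 - p_switch]])

# MDP: one transition matrix per action (hypothetical actions and probabilities)
T_mdp = {
    'stay': np.array([[0.95, 0.05], [0.05, 0.95]]),
    'fly':  np.array([[0.20, 0.80], [0.80, 0.20]]),
}

state = np.array([1.0, 0.0])   # the bee is certainly in the first context

# Distribution over the next context now depends on the chosen action
next_state = {action: state @ T for action, T in T_mdp.items()}
for action, p in next_state.items():
    print(action, p)
```

In the HMM the bee can only watch the context change. In the MDP its choice of action selects which transition matrix applies.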

<div style="text-align:center">
    <img src="./figures/f6.png" alt="Generative Model" width="400"/>
</div>

<div>
<img src="https://drive.google.com/uc?export=view&id=1hhMFTuRYNu5Yh8GogIwuYEsgNg9e0CHU" width="500"/>
</div>

Let us start by expanding the bee's environment into a 2d patch that it can forage in, which has flowers of different colors signalling different contexts (forests or fields). Additionally, we will also define the rewards available from different flowers, as a function of context.

In [None]:
class Patch:
    # Create 2d flower color pattern
    def __init__(self):
        X,Y = np.meshgrid(np.linspace(-3,3,100),np.linspace(-3,3,100))
        F = -(X**5 - Y**5 - 2 * np.matmul(X, Y) - 2 * X**3 - 3 * Y**3) * (
            0.9 * np.exp(-0.4 * (X - 1.5)**2 - 0.5 * (Y - 1.5)**2) -
            0.9 * np.exp(-(X + 2)**2 - (Y + 2)**2 - 0.8 * (X + 2) * (Y + 2)) -
            0.7 * np.exp(-(X + 2)**2 - 0.5 * (Y - 2)**2 + 0.3 * (X + 2) * (Y - 2)) -
            0.5 * np.exp(-3 * (X - 2)**6 - 4 * (Y + 2)**2 + 4 * (X - 2) * (Y + 2)**2))     
        F = 1-(F- np.min(F))/(np.max(F)-np.min(F))
        self.X = X
        self.Y = Y
        self.F = F

    # Flower color as a function of location on patch
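    def f(self,xCurr,yCurr,xNext,yNext):
        # meshgrid layout: X varies along columns (axis 1), Y varies along rows (axis 0)
        iCurr = np.max(np.where(self.X<xCurr)[1])
        jCurr = np.max(np.where(self.Y<yCurr)[0])
        iNext = np.max(np.where(self.X<xNext)[1])
        jNext = np.max(np.where(self.Y<yNext)[0])
        # F is indexed [row, column] = [y index, x index], matching contourf(X, Y, F)
        return self.F[jCurr,iCurr], self.F[jNext,iNext]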

    # Nectar as a function of flower color in "field" or "forest" environments
    def nectar(self,f,context):
        if context =='field':
            return 1 - norm.cdf(f, loc = 0.5,scale = 0.15)
        elif context =='forest':
            return norm.cdf(f, loc = 0.5,scale = 0.15)


Let's take a look at this patch, and the available nectar in different contexts

In [None]:
### Create a patch for the bee to forage in
patch = Patch()

# Plot flower frequency patterns in patch
plt.figure(figsize=(12,5))
plt.subplot(121,title = 'Foraging patch')
plt.contourf(patch.F,levels=10,cmap='Spectral')  # pass the colormap by name; cm.get_cmap is removed in recent matplotlib
plt.colorbar(orientation = 'horizontal', label = 'Flower color')
plt.axis('off')

# Plot nectar function in the two contexts
f = np.arange(0,1,0.01)
colors = {'forest':[0,0,1],'field':[1,0,0]}

plt.subplot(122,title = 'Nectar function')
for context in ['forest','field']:  # a list keeps the plotting and legend order fixed
    plt.plot(f,patch.nectar(f,context),color = colors[context],label = context)
plt.xlabel('Flower color')
plt.ylabel('Nectar')
sns.despine()
plt.legend()
plt.show()

Let us now define a bee with the following functions:

1. An **inference function** that lets it infer which context it is in, based on the flower colors it observes. To keep it simple, let's assume the bee directly observes flower color and thresholds it to decide which context it is in (as in section 1).

2. A **value function** - this is similar to the reward/utility functions we have seen before, and this tells the bee how *valuable* different flowers are i.e. how much nectar reward it can expect from different flower colors in different contexts.

3. An **action policy** - this is the function that determines how the bee will act, consequently changing the context it is in. We are going to use a very simple policy, where the bee is going to keep going straight if utility/value increases ahead of it, and reorient randomly if it decreases.
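The reorienting rule in the policy can be sketched on its own as a logistic function of the change in value. The helper name `p_reorient` is hypothetical; its `temp` and `threshold` arguments are assumed to play the roles of the `softmaxTemp` and `softmaxThreshold` parameters used in the `Bee` class below.

```python
import numpy as np

def p_reorient(delta_v, temp=1.0, threshold=0.0):
    """Probability of reorienting given the change in value ahead of the bee."""
    return 1 / (1 + np.exp(temp * (delta_v - threshold)))

# Value increases ahead: reorienting becomes unlikely; value drops: it becomes likely
for delta_v in [-1.0, 0.0, 1.0]:
    print(delta_v, p_reorient(delta_v, temp=1.0), p_reorient(delta_v, temp=10.0))
```

In this parameterization a larger `temp` makes the curve steeper, so the bee reorients almost deterministically whenever value decreases and almost never when it increases.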

In [None]:
class Bee:
    def __init__(self,params):
        self.params = params
        
    def inferenceFunc(self,f):
        return 'forest' if f>self.params['sensoryThreshold'] else 'field'

    #Linear approximation to utility/value function 
    def valueFunc(self,f,context):
        if np.size(f)>0:
            return self.params['w'][context][1]*f+self.params['w'][context][0]
        else:
            return -np.inf #Boundary condition

    #Stochastic action policy
    def policyFunc(self,vTminus1,vT):
        
        # Softmax decision to reorient
        if np.random.rand()<1/(1+np.exp(self.params['softmaxTemp']*(vT-vTminus1-self.params['softmaxThreshold']))):

            # Von-mises distributed reorienting angles
            deltaTheta =  np.random.vonmises(self.params['reorientMu'],self.params['reorientKappa'])
            
            # Turn clockwise (cw), counter-clockwise (ccw), or (both) as determined by 'policyFuncType'
            action = {'cw': deltaTheta, 'ccw': -deltaTheta, 'both': np.random.choice([deltaTheta,-deltaTheta])}
            
            return action[self.params['policyFuncType']]
        else:
            # Continue straight if not reorienting
            return 0


## Policy with known values
Let's instantiate a bee with custom-defined action policy parameters, with a fixed (known) value function. Note that we are approximating the *true* nectar function with a linear value function.
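To get a feel for how good a linear approximation can be, the sketch below compares two straight lines against the "forest" nectar function defined in `Patch.nectar`: the hand-picked line with intercept 0 and slope 1, and a least-squares line. Using `np.polyfit` here is an illustrative choice, not something the tutorial itself does.

```python
import numpy as np
from scipy.stats import norm

f = np.arange(0, 1, 0.01)
nectar_forest = norm.cdf(f, loc=0.5, scale=0.15)   # same form as Patch.nectar for 'forest'

# Hand-picked linear value function: intercept 0, slope 1
v_guess = 0 + 1 * f

# Least-squares straight line through the nectar function
slope, intercept = np.polyfit(f, nectar_forest, deg=1)
v_fit = intercept + slope * f

rmse_guess = np.sqrt(np.mean((v_guess - nectar_forest)**2))
rmse_fit = np.sqrt(np.mean((v_fit - nectar_forest)**2))
print(f"guess rmse={rmse_guess:.3f}, least-squares rmse={rmse_fit:.3f}, slope={slope:.2f}")
```

Either line captures the key qualitative fact the policy relies on: in the forest, value increases with flower color.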

In [None]:
### Create a bee with known, context dependent value function weights 'w'
bee = Bee(
    params = {
        'sensoryThreshold': 0.5,
        'policyFuncType': 'cw',
        'softmaxTemp':1,
        'softmaxThreshold':0,
        'reorientMu':np.pi/2, 
        'reorientKappa':10,
        'stepSize':0.5,
        'discountGamma':0,
        'learningRate':0,
        'w':{'forest': np.array([0,1]), 'field':np.array([1,-1])}
        }
    )


# Plot policy functions - reorienting probability & angles
fig1 = plt.figure(figsize=(10, 3))
plt.subplot(121,title = 'Policy: Reorienting probability')
plt.plot(
    np.linspace(-1,1,100),
    1/(1+np.exp(bee.params['softmaxTemp']*(np.linspace(-1,1,100)-bee.params['softmaxThreshold']))))
plt.xlabel('Change in value')
plt.ylabel('p (Reorient)')

plt.subplot(122,title = 'Policy: Reorienting angles')
plt.fill(
    np.linspace(0,np.pi,100), 
    vonmises.pdf(np.linspace(0,np.pi,100),bee.params['reorientKappa'],loc = bee.params['reorientMu']),
    alpha = 0.4)
plt.xticks([0,np.pi/2,np.pi],[0,90,180])
plt.xlabel('Change in orientation')
sns.despine()

# Plot value functions (assumed known)
fig2 = plt.figure(figsize=(10, 3))
plt.subplot(121,title = 'Value in field')
plt.plot(f,bee.valueFunc(f,'field'),color = 'r')
plt.ylabel('Value')
plt.xlabel('Flower color')
plt.subplot(122,title = 'Value in forest')
plt.plot(f,bee.valueFunc(f,'forest'),color = 'b')
plt.xlabel('Flower color')
sns.despine()
plt.show()


Let's simulate the behavior of the bee on the patch for a single trial with 100 timesteps (set by `tMax` in the simulation parameters)!

In [None]:
class SimulateBee:
    def __init__(self, bee, patch, simPars):
        self.bee = bee
        self.patch = patch
        self.simPars = simPars
        self.f = np.arange(0, 1, 0.01)  
        self.fig, (self.axp, self.axv) = plt.subplots(1, 2,figsize=(12,5))
        
        # Initial position & orientation
        self.xCurr, self.yCurr = 0, 0
        self.theta = simPars['thetaInit']
        self.xNext, self.yNext = self.next_position(self.xCurr, self.yCurr, self.theta)
        self.evaluate_position()
        
    def next_position(self, xCurr, yCurr, theta):
        # Next position based on current position & orientation
        return xCurr + self.bee.params['stepSize'] * np.cos(theta), yCurr + self.bee.params['stepSize'] * np.sin(theta)
        
    def evaluate_position(self):
        # Sensory observations
        self.fCurr, self.fNext = self.patch.f(self.xCurr, self.yCurr, self.xNext, self.yNext)
        # Context inference
        self.context = self.bee.inferenceFunc(self.fCurr)
        # Evaluation of sensory observations given context
        self.vCurr, self.vNext = self.bee.valueFunc(self.fCurr, self.context), self.bee.valueFunc(self.fNext, self.context)
    
    def plan_and_select_action(self, xPrev, yPrev):
        # Planning and action selection 
        plan = True
        while plan:
            # Prospective position, orientation
            deltaTheta = self.bee.policyFunc(self.vCurr, self.vNext)
            self.theta += deltaTheta
            self.xCurr, self.yCurr = self.next_position(xPrev, yPrev, self.theta)
            self.xNext, self.yNext = self.next_position(self.xCurr, self.yCurr, self.theta)

            # Check if the new position is within the bounds of the patch
            if (self.xNext > np.min(self.patch.X) and self.xNext < np.max(self.patch.X) and
                self.yNext > np.min(self.patch.Y) and self.yNext < np.max(self.patch.Y)):
                self.evaluate_position()
                plan = False
            else:
                self.vNext = -np.inf
        
    
    def update(self, t):
        
        # Plot value function
        self.axv.clear()
        current_context = self.context
        other_context = 'forest' if current_context == 'field' else 'field'
        alpha = {
            current_context:1,
            other_context:0.1
        }
        for context in ['forest','field']:
            self.axv.plot(
                self.f, 
                self.bee.valueFunc(self.f, context), 
                color=colors[context], 
                label = context,
                alpha = alpha[context])
        self.axv.set_title('Value in ' + self.context)
        self.axv.set_xlabel('Flower color')
        self.axv.set_ylabel('Value') 
        self.axv.set_ylim(0, 3)
        self.axv.legend()
        sns.despine(ax = self.axv)

        # Plot patch & bee position
        self.axp.clear()
        self.axp.axis('off')
        self.fig.set_facecolor('white')
        self.axp.contourf(self.patch.X, self.patch.Y, self.patch.F, levels=10, cmap='Spectral')
        self.axp.set_title(self.context)

        # Planning and action selection
        xPrev, yPrev = self.xCurr, self.yCurr
        self.plan_and_select_action(xPrev, yPrev)

        # Plot movement to new location
        self.axp.quiver(xPrev, yPrev, self.xCurr - xPrev, self.yCurr - yPrev, scale=7, width=0.003)

    def animate(self):
        anim = FuncAnimation(self.fig, self.update, frames=self.simPars['tMax'], interval=200)
        return anim



In [None]:
# Run simulation. 
simPars = {
    'tMax': 100, 
    'thetaInit': np.random.vonmises(0, 0.1)
    }
contexts = ['field', 'forest']

%matplotlib notebook
sim = SimulateBee(bee, patch, simPars)
anim = sim.animate()
from IPython.display import HTML, display  # needed unless already imported in an earlier cell
html_content = anim.to_jshtml()
display(HTML(html_content))

What do you notice about the bee's inferred context & value function?

> log your observations here

## Learning values with reinforcement learning

We now turn to the final challenge: what happens if you *don't* know the value function, but have to learn it from experience with rewards in the environment? This is the problem of reinforcement learning. Things get tricky *very fast* when we mix uncertainty about sensory context and uncertainty about value, so we will use a popular assumption that you can treat them in sequence: first do contextual inference, and then learn value conditional on the inferred context.

Here, we use a popular RL technique called temporal difference learning to learn *state values*, that is, values as a function of the sensory state (in this case, flower color). We also use function approximation, approximating the true value with a linear function. Play around with the policy, specifically the temperature, which controls exploration. You will see how easy it is to get stuck in a situation where you are unable to learn any further!
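Before looking at the full simulation, here is a toy sketch of temporal difference learning with a linear value function. It makes several simplifying assumptions that are not part of the tutorial's classes: there is a single context, the bee visits flowers of random color rather than moving on the patch, the reward is the "forest" nectar at the flower being left, and the gradient is taken at that same (previous) flower.

```python
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(1)
gamma, alpha = 0.6, 0.05      # discount factor and learning rate
w = np.zeros(2)               # weights: [intercept, slope]

def value(f, w):
    """Linear value function of flower color f."""
    return w[0] + w[1] * f

f_prev = rng.uniform(0, 1)
for _ in range(5000):
    f_next = rng.uniform(0, 1)                         # toy transition: a random new flower
    reward = norm.cdf(f_prev, loc=0.5, scale=0.15)     # nectar at the flower being left
    # Temporally discounted reward prediction error
    delta = reward + gamma * value(f_next, w) - value(f_prev, w)
    # Semi-gradient TD(0): the gradient of value w.r.t. weights at f_prev is [1, f_prev]
    w = w + alpha * delta * np.array([1.0, f_prev])
    f_prev = f_next

print(w)
```

The learned slope comes out positive, so the bee has learned from reward alone that higher flower colors are worth more in this context.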

In [None]:
class BeeWithLearning(Bee):
    
    def learningFunc(self, f, context, vTminus1, vT, reward):
        #Temporally discounted reward prediction error
        delta = reward + self.params['discountGamma']*vT - vTminus1
        # Gradient of value w.r.t. weights
        gradfun = np.array([1, f])
        # Weight update using Temporal Difference learning rule
        self.params['w'][context] = self.params['w'][context] + delta*self.params['learningRate']*gradfun

In [None]:
class SimulateBeeWithLearning(SimulateBee):

    def update(self, t):

        # Plot current learned estimate of value function
        self.axv.clear()
        current_context = self.context
        other_context = 'forest' if current_context == 'field' else 'field'
        alpha = {
            current_context:1,
            other_context:0.1
        }
        for context in ['forest','field']:
            self.axv.plot(
                self.f, 
                self.bee.valueFunc(self.f, context),
                color=colors[context], 
                label = context,
                alpha = alpha[context]
                )
        self.axv.set_title('Value in ' + self.context)
        self.axv.set_xlabel('Flower color')
        self.axv.set_ylabel('Value')
        self.axv.set_ylim(0, 3)
        self.axv.legend()
        sns.despine(ax = self.axv)

        # Plot patch environment
        self.axp.clear()
        self.axp.contourf(self.patch.X, self.patch.Y, self.patch.F, levels=10, cmap='Spectral')
        self.axp.set_title(self.context)

        # Planning and action selection
        xPrev, yPrev = self.xCurr, self.yCurr
        vPrev = self.bee.valueFunc(self.fCurr, self.context)
        self.plan_and_select_action(xPrev, yPrev)

        # ~~~~~~~~~~~~~~~~~~~~~ LEARNING RULE~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Observe reward outcome of action, update weights based on learning rule
        reward = self.patch.nectar(self.fCurr, self.context)
        self.bee.learningFunc(self.fCurr, self.context, vPrev, self.vCurr, reward)

        # Plot movement to new location
        self.axp.quiver(xPrev, yPrev, self.xCurr - xPrev, self.yCurr - yPrev, scale=7, width=0.003)
        sns.despine(ax = self.axp)
        self.fig.set_facecolor("white")
        self.fig.tight_layout()
        

    def animate(self):
        anim = FuncAnimation(self.fig, self.update, frames=self.simPars['tMax'], interval=200)
        plt.show()
        return anim


In [None]:
# Uses the Patch, BeeWithLearning, and SimulateBeeWithLearning classes defined above
# TRY CHANGING THE SOFTMAX TEMP! WHEN VERY HIGH, THE BEE WON'T EXPLORE THE OTHER CONTEXT- HENCE WON'T LEARN ABOUT IT
bee2 = BeeWithLearning(
    params = {
        'sensoryThreshold':0.5,
        'policyFuncType': 'both',
        'softmaxTemp': 1, 
        'softmaxThreshold': 0,
        'reorientMu': np.pi/2, 
        'reorientKappa': 10, 
        'stepSize': 0.5,
        'discountGamma': 0.6, 
        'learningRate': 0.5,
        'w': {'forest': np.array([0, 0]), 'field': np.array([0, 0])}
        }
    )

patch = Patch()  
### Simulate 200 timesteps
simPars = {'tMax': 200, 'thetaInit': np.random.vonmises(0, 0.1)}
simulator = SimulateBeeWithLearning(bee2, patch, simPars)
anim = simulator.animate()
html_content = anim.to_jshtml()
display(HTML(f"{html_content}"))