<a href="https://colab.research.google.com/github/NeuromatchAcademy/course-content/blob/W2D1-postcourse-bugfix/tutorials/W2D1_BayesianStatistics/W2D1_Tutorial1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Neuromatch Academy: Week 3, Day 1, Tutorial 1
# Bayes with a binary hidden state

__Content creators:__ [insert your name here]

__Content reviewers:__ 

# Tutorial Objectives
This is the first in a series of two core tutorials on Bayesian statistics. In these tutorials, we will explore the fundamental concepts of the Bayesian approach from two perspectives. This tutorial works through an example of Bayesian inference and decision making using a binary hidden state; the second main tutorial extends these concepts to a continuous hidden state. In the following days, each of these basic ideas will be extended, first through time, as we consider what happens when we infer a hidden state using multiple observations and when the hidden state changes across time. On the third day, we will introduce the notion of how to use inference and decisions to select actions for optimal control. For this tutorial, you will be introduced to our binary state fishing problem!

This notebook will introduce the fundamental building blocks for Bayesian statistics: 
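1. A binary hidden state and noisy observations of it (the likelihood)
2. Joint probabilities, correlation, and marginalization
3. Bayes' theorem and the posterior distribution
4. Bayesian decisions: combining the posterior with a utility function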


In [None]:
#@title Video 1: Introduction to Bayesian Statistics
from IPython.display import YouTubeVideo
# video = YouTubeVideo(id='K4sSKZtk-Sc', width=854, height=480, fs=1)
# print("Video available at https://youtube.com/watch?v=" + video.id)
# video

## Setup  
Please execute the cells below to initialize the notebook environment.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import patches
from matplotlib import transforms
from matplotlib import gridspec
from scipy.optimize import fsolve

from collections import namedtuple

In [None]:
#@title Figure Settings
import ipywidgets as widgets       # interactive display
from ipywidgets import GridspecLayout
from IPython.display import clear_output
%config InlineBackend.figure_format = 'retina'
plt.style.use("https://raw.githubusercontent.com/NeuromatchAcademy/course-content/master/nma.mplstyle")

In [17]:
def compute_marginal(px, py, cor):
    # Calculate the 2x2 joint probability table P[y, x] from the marginals p(x=1), p(y=1) and the correlation cor
    p11 = px*py + cor*np.sqrt(px*py*(1-px)*(1-py))
    p01 = px - p11
    p10 = py - p11
    p00 = 1.0 - p11 - p01 - p10
    return np.asarray([[p00, p01], [p10, p11]])
# test
# print(compute_marginal(0.4, 0.6, -0.8))

def compute_cor_range(px,py):
    # Calculate the allowed range of correlation values given marginals p(x=1) and p(y=1)
    def p11(corr):
        return px*py + corr*np.sqrt(px*py*(1-px)*(1-py))
    def p01(corr):
        return px - p11(corr)
    def p10(corr):
        return py - p11(corr)
    def p00(corr):
        return 1.0 - p11(corr) - p01(corr) - p10(corr)
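    # fsolve returns a length-1 array; take the scalar root so the bounds are plain floats
    Cmax = min(fsolve(p01, 0.0)[0], fsolve(p10, 0.0)[0])  # p(x=1,y=0) and p(x=0,y=1) must stay >= 0
    Cmin = max(fsolve(p11, 0.0)[0], fsolve(p00, 0.0)[0])  # p(x=1,y=1) and p(x=0,y=0) must stay >= 0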
    return Cmin, Cmax

def plot_joint_probs(P):
    assert np.all(P >= 0), "probabilities should be >= 0"
    # normalize so that the entries sum to 1
    P = P / np.sum(P)
    marginal_y = np.sum(P,axis=1)
    marginal_x = np.sum(P,axis=0)

    # definitions for the axes
    left, width = 0.1, 0.65
    bottom, height = 0.1, 0.65
    spacing = 0.005

    # start with a square Figure
    fig = plt.figure(figsize=(5, 5))

    joint_prob = [left, bottom, width, height]
    rect_histx = [left, bottom + height + spacing, width, 0.2]
    rect_histy = [left + width + spacing, bottom, 0.2, height]

    rect_x_cmap = plt.cm.Blues
    rect_y_cmap = plt.cm.Reds

    # Create the joint-probability axes and the two marginal axes
    ax = fig.add_axes(joint_prob)
    ax_x = fig.add_axes(rect_histx, sharex=ax)
    ax_y = fig.add_axes(rect_histy, sharey=ax)

    # Show joint probs and marginals
    ax.matshow(P,vmin=0., vmax=1., cmap='Greys')
    ax_x.bar(0, marginal_x[0], facecolor=rect_x_cmap(marginal_x[0]))
    ax_x.bar(1, marginal_x[1], facecolor=rect_x_cmap(marginal_x[1]))
    ax_y.barh(0, marginal_y[0], facecolor=rect_y_cmap(marginal_y[0]))
    ax_y.barh(1, marginal_y[1], facecolor=rect_y_cmap(marginal_y[1]))
    # set limits
    ax_x.set_ylim([0,1])
    ax_y.set_xlim([0,1])

    # show values 
    ind = np.arange(2)
    x,y = np.meshgrid(ind,ind)
    for i,j in zip(x.flatten(), y.flatten()):
        c = f"{P[i,j]:.2f}"
        ax.text(j,i, c, va='center', ha='center', color='black')
    for i in ind:
        v = marginal_x[i]
        c = f"{v:.2f}"
        ax_x.text(i, v +0.1, c, va='center', ha='center', color='black')
        v = marginal_y[i]
        c = f"{v:.2f}"
        ax_y.text(v+0.2, i, c, va='center', ha='center', color='black')

    # set up labels
    ax.xaxis.tick_bottom()
    ax.yaxis.tick_left()
    ax.set_xticks([0,1])
    ax.set_yticks([0,1])
    ax.set_xticklabels(['R','B'])
    ax.set_yticklabels(['0','1'])
    ax.set_xlabel('color')
    ax.set_ylabel('size')
    ax_x.axis('off')   
    ax_y.axis('off')   
    return fig
# test
# P = np.random.rand(2,2)
# P = np.asarray([[0.9, 0.8], [0.4, 0.1]])
# P = P / np.sum(P)
# fig = plot_joint_probs(P)
# plt.show(fig)
# plt.close(fig)

def plot_prior_likelihood(ps, cor):
    likelihood = np.asarray([[cor, 1-cor],[1-cor,cor]])
    assert 0.0 <= ps <= 1.0
    assert 0.0 <= cor <= 1.0
    prior = np.asarray([1-ps, ps])

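    # Bayes' rule: p(s | m) = p(m | s) p(s) / p(m); rows index m, columns index s
    joint = likelihood * prior.reshape(1, 2)
    p_m = joint.sum(axis=1, keepdims=True)
    # normalize each row by p(m); rows with p(m) = 0 are left as zeros
    posterior = np.divide(joint, p_m, out=np.zeros_like(joint), where=p_m > 0)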

    fig, ax_all = plt.subplots(1, 3, figsize=(12, 4))
    ax_prior, ax_likelihood, ax_posterior = ax_all
    ax_prior.set_title("prior distribution", pad=15)
    ax_likelihood.set_title("likelihood")
    ax_posterior.set_title("posterior distribution")

    rect_colormap = plt.cm.Blues

    # Show the prior (bars), the likelihood and the posterior (matrices)
    ax_prior.bar(0, prior[0], facecolor = rect_colormap(prior[0]))
    ax_prior.bar(1, prior[1], facecolor = rect_colormap(prior[1]))
    ax_likelihood.matshow(likelihood, vmin=0., vmax=1., cmap='Reds')
    ax_posterior.matshow(posterior, vmin=0., vmax=1., cmap='Greys')

    for ax in ax_all:
        ax.set_xticks([0, 1])
        ax.set_yticks([0, 1])
        ax.set_xticklabels([0, 1])
        ax.set_yticklabels([0, 1])
    ax_posterior.xaxis.set_ticks_position('bottom')
    ax_likelihood.xaxis.set_ticks_position('bottom')

    # show values 
    ind = np.arange(2)
    x,y = np.meshgrid(ind,ind)
    for i,j in zip(x.flatten(), y.flatten()):
        c = f"{posterior[i,j]:.2f}"
        ax_posterior.text(j,i, c, va='center', ha='center', color='black')
    for i,j in zip(x.flatten(), y.flatten()):
        c = f"{likelihood[i,j]:.2f}"
        ax_likelihood.text(j,i, c, va='center', ha='center', color='black')
    for i in ind:
        v = prior[i]
        c = f"{v:.2f}"
        ax_prior.text(i, v +0.05, c, va='center', ha='center', color='black')

    # set up labels
    ax_prior.set_xlabel("s")
    ax_likelihood.set_xlabel("s")
    ax_likelihood.set_ylabel("m")
    ax_posterior.set_xlabel("s")
    ax_posterior.set_ylabel("m")
    return fig


# fig = plot_prior_likelihood(0.5, 0.3)
# plt.show(fig)
# plt.close(fig)

def plot_prior_likelihood_utility(ps, p_a_s1, p_a_s0, loss_s, gain_s):
    likelihood = np.asarray([[p_a_s1, 1-p_a_s1],[1-p_a_s0,p_a_s0]]).T
    assert 0.0 <= ps <= 1.0
    assert 0.0 <= p_a_s1 <= 1.0
    assert 0.0 <= p_a_s0 <= 1.0
    prior = np.asarray([1-ps, ps])

    utility = np.array([[-loss_s, loss_s-1], [gain_s, 1-gain_s]]).T
    posterior = likelihood * prior.reshape(1,2)
    posterior, likelihood = posterior.T, likelihood.T
    expected = np.multiply(utility, posterior)

    # definitions for the axes
    left, width = 0.05, 0.16
    bottom, height = 0.05, 0.9
    padding = 0.04
    small_width = 0.1
    left_space = left + small_width + padding
    added_space = padding + width

    fig = plt.figure(figsize=(17, 3))

    rect_prior = [left, bottom, small_width, height]
    rect_likelihood = [left_space, bottom , width, height]
    rect_posterior = [left_space + added_space, bottom , width, height]
    rect_utility = [left_space + 2*added_space, bottom , width, height]
    rect_expected = [left_space + 3*added_space, bottom , width, height]

    ax_likelihood = fig.add_axes(rect_likelihood)
    ax_prior = fig.add_axes(rect_prior, sharey=ax_likelihood)
    ax_posterior = fig.add_axes(rect_posterior, sharey=ax_likelihood)
    ax_utility = fig.add_axes(rect_utility, sharey=ax_posterior)
    ax_expected = fig.add_axes(rect_expected, sharey=ax_utility)

    ax_prior.set_title("prior distribution", pad=15)
    ax_likelihood.set_title("likelihood")
    ax_posterior.set_title("posterior distribution")
    ax_utility.set_title("utility function")
    ax_expected.set_title("expected utility")

    rect_colormap = plt.cm.Blues

    # Show the prior (bars) and the likelihood, posterior, utility and expected-utility matrices
    ax_prior.barh(0, prior[0], facecolor = rect_colormap(prior[0]))
    ax_prior.barh(1, prior[1], facecolor = rect_colormap(prior[1]))
    ax_likelihood.matshow(likelihood, vmin=0., vmax=1., cmap='Reds')
    ax_posterior.matshow(posterior, vmin=0., vmax=1., cmap='Greys')
    ax_utility.matshow(utility, vmin=0., vmax=1., cmap='cool')
    ax_expected.matshow(expected, vmin=0., vmax=1., cmap='Wistia')

    for ax in [ax_prior, ax_likelihood, ax_posterior, ax_expected]:
        ax.set_xticks([0, 1])
        ax.set_yticks([0, 1])
        ax.set_xticklabels([0, 1])
        ax.set_yticklabels([0, 1])

    ax_utility.set_xticks([0, 1])
    ax_utility.set_yticks([0, 1])
    ax_utility.set_xticklabels(["loss", "gain"])
    # ax_utility.set_yticklabels([0, 1])

    ax_posterior.xaxis.set_ticks_position('bottom')
    ax_likelihood.xaxis.set_ticks_position('bottom')
    ax_utility.xaxis.set_ticks_position('bottom')
    ax_expected.xaxis.set_ticks_position('bottom')

    ax_prior.set_xlim([1, 0])

    # show values 
    ind = np.arange(2)
    x,y = np.meshgrid(ind,ind)
    for i,j in zip(x.flatten(), y.flatten()):
        c = f"{posterior[i,j]:.2f}"
        ax_posterior.text(j,i, c, va='center', ha='center', color='black')
    for i,j in zip(x.flatten(), y.flatten()):
        c = f"{likelihood[i,j]:.2f}"
        ax_likelihood.text(j,i, c, va='center', ha='center', color='black')
    for i,j in zip(x.flatten(), y.flatten()):
        c = f"{utility[i,j]:.2f}"
        ax_utility.text(j,i, c, va='center', ha='center', color='black')
    for i,j in zip(x.flatten(), y.flatten()):
        c = f"{expected[i,j]:.2f}"
        ax_expected.text(j,i, c, va='center', ha='center', color='black')
    for i in ind:
        v = prior[i]
        c = f"{v:.2f}"
        ax_prior.text(v+0.2, i, c, va='center', ha='center', color='black')



    # set up labels
    ax_prior.set_xlabel("m")
    ax_likelihood.set_xlabel("m")
    ax_likelihood.set_ylabel("s")
    ax_posterior.set_xlabel("m")
    ax_posterior.set_ylabel("s")
    ax_utility.set_ylabel("s")
    # ax_expected.set_xlabel("m")
    ax_expected.set_ylabel("s")
    ax_utility.set_xlabel("a")
    ax_prior.axis('off')
    return fig


# fig = plot_prior_likelihood(0.5, 0.3)
# plt.show(fig)
# plt.close(fig)


# Section 1: The binary hidden state

[description of the problem, etc]
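
A minimal sketch of the observation model behind the plots in this section, assuming (as `plot_prior_likelihood` does) that the measurement m agrees with the hidden state s with probability ρ and disagrees with probability 1 - ρ. The helpers `likelihood_table` and `simulate` are hypothetical names introduced only for this illustration; they are not part of the tutorial code.

```python
import numpy as np

# Hypothetical helper: likelihood table p(m | s), assuming p(m = s | s) = rho.
# Rows index the measurement m, columns index the hidden state s,
# matching the layout used by plot_prior_likelihood.
def likelihood_table(rho):
    return np.array([[rho, 1 - rho],
                     [1 - rho, rho]])

# Hypothetical helper: draw hidden states s ~ Bernoulli(ps) and noisy measurements m.
def simulate(ps, rho, n_trials, seed=0):
    rng = np.random.default_rng(seed)
    s = (rng.random(n_trials) < ps).astype(int)   # s = 1 with probability ps
    agree = rng.random(n_trials) < rho            # m equals s with probability rho
    m = np.where(agree, s, 1 - s)                 # otherwise m is flipped
    return s, m

L = likelihood_table(0.8)
print(L.sum(axis=0))        # each column p(m | s) sums to 1: [1. 1.]

s, m = simulate(ps=0.5, rho=0.8, n_trials=10_000)
print(np.mean(m == s))      # fraction of agreeing measurements, close to 0.8
```

Fixing `seed` makes the simulation reproducible; the second printed value fluctuates around ρ and tightens as `n_trials` grows.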

## Video: Observations and costs

In [20]:
cor_widget = widgets.FloatSlider(0.3, description='ρ', min=0.0, max=1.0, step=0.01, disabled=False)

@widgets.interact(
    cor=cor_widget,
)
def make_prior_likelihood_plot(cor):
    fig = plot_prior_likelihood(0.5,cor)
    plt.show(fig)
    plt.close(fig)
    return None



# Section 2: Correlation and marginalization
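
The joint table in this section comes from `compute_marginal`: given p(x=1), p(y=1) and a correlation ρ, it sets p(x=1, y=1) = p(x=1) p(y=1) + ρ·sqrt(p(x=1) p(y=1) (1 - p(x=1)) (1 - p(y=1))) and fills the remaining cells so that the marginals are preserved. The sketch below re-implements that formula under the hypothetical names `joint_table` and `rho_range`, recovers each marginal by summing out the other variable, and checks that ρ is indeed the correlation between the two binary variables. Because every cell is linear in ρ, the admissible range of ρ also has a closed form, which could stand in for the `fsolve` calls in `compute_cor_range`.

```python
import numpy as np

def joint_table(px, py, rho):
    """2x2 joint P[y, x] with p(x=1) = px, p(y=1) = py and correlation rho."""
    s = np.sqrt(px * py * (1 - px) * (1 - py))
    p11 = px * py + rho * s       # x = 1, y = 1
    p01 = px - p11                # x = 1, y = 0
    p10 = py - p11                # x = 0, y = 1
    p00 = 1 - px - py + p11       # x = 0, y = 0
    return np.array([[p00, p01], [p10, p11]])

def rho_range(px, py):
    """Closed-form bounds on rho that keep every cell of the table >= 0."""
    s = np.sqrt(px * py * (1 - px) * (1 - py))
    rho_min = max(-px * py, -(1 - px) * (1 - py)) / s   # from p11 >= 0 and p00 >= 0
    rho_max = min(px * (1 - py), py * (1 - px)) / s     # from p01 >= 0 and p10 >= 0
    return rho_min, rho_max

P = joint_table(0.4, 0.6, -0.5)
p_x = P.sum(axis=0)   # marginalize over y: [p(x=0), p(x=1)], approximately [0.6, 0.4]
p_y = P.sum(axis=1)   # marginalize over x: [p(y=0), p(y=1)], approximately [0.4, 0.6]
print(p_x, p_y)

# Pearson correlation of the two binary variables under P
ex, ey = p_x[1], p_y[1]
cov = P[1, 1] - ex * ey
print(cov / np.sqrt(ex * (1 - ex) * ey * (1 - ey)))   # approximately -0.5
print(rho_range(0.4, 0.6))
```

Solving the linear constraints directly avoids a numerical root finder and returns plain floats, which is convenient when the bounds are fed to a slider's `min` and `max`.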


## Video:

In [12]:
gs = GridspecLayout(2,2)

cor_widget = widgets.FloatSlider(0.3, description='ρ', min=-1, max=1, step=0.01)
px_widget = widgets.FloatSlider(0.5, description='p(x)', min=0.01, max=0.99, step=0.01)
py_widget = widgets.FloatSlider(0.5, description='p(y)', min=0.01, max=0.99, step=0.01)
gs[0,0] = cor_widget
gs[0,1] = px_widget
gs[1,0] = py_widget


@widgets.interact(
    px=px_widget,
    py=py_widget,
    cor=cor_widget,
)
def make_corr_plot(px, py, cor):
    Cmin, Cmax = compute_cor_range(px, py)  # allowed range of correlation values
    cor_widget.min, cor_widget.max = Cmin+0.01, Cmax-0.01
    if cor_widget.value > Cmax:
        cor_widget.value = Cmax
    if cor_widget.value < Cmin:
        cor_widget.value = Cmin
    cor = cor_widget.value
    P = compute_marginal(px,py,cor)
    # print(P)
    fig = plot_joint_probs(P) 
    plt.show(fig)
    plt.close(fig)
    return None

# gs[1,1] = make_corr_plot()


# Section 3: Bayes' Theorem and the Posterior

## Video: Bayes' Theorem

In [None]:
#@title Video 2: Bayes' theorem
from IPython.display import YouTubeVideo
# video = YouTubeVideo(id='ewQPHQMcdBs', width=854, height=480, fs=1)
# print("Video available at https://youtube.com/watch?v=" + video.id)
# video

In [15]:
cor_widget = widgets.FloatSlider(0.3, description='ρ', min=0.0, max=1.0, step=0.01, disabled=False)
ps_widget = widgets.FloatSlider(0.5, description='p(s)', min=0.0, max=1.0, step=0.01)

@widgets.interact(
    ps=ps_widget,
    cor=cor_widget,
)
def make_prior_likelihood_plot(ps,cor):
    fig = plot_prior_likelihood(ps,cor)
    plt.show(fig)
    plt.close(fig)
    return None



Math questions

## Interactive Demo: What affects the posterior?

Now that we can see *Bayes' rule* in action, let's vary the prior and the likelihood to see how each of them affects the posterior.

**Hit the Play button or Ctrl+Enter in the interactive cell above** and play with the sliders to get an intuition for how the prior probability p(s) and the likelihood parameter ρ influence the posterior.

When does the prior have the strongest influence over the posterior? When is it the weakest?  
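
For a binary hidden state, Bayes' rule is short enough to write out directly. The sketch below assumes the same model as `plot_prior_likelihood`: prior p(s=1) = ps and likelihood p(m = s | s) = ρ. The function name `posterior_s_given_m` is hypothetical. Dividing by the sum implements the normalization by p(m) = Σ_s p(m | s) p(s).

```python
import numpy as np

def posterior_s_given_m(ps, rho, m):
    """p(s | m) for binary s and m, assuming p(s=1) = ps and p(m = s | s) = rho."""
    prior = np.array([1 - ps, ps])                 # [p(s=0), p(s=1)]
    likelihood = np.array([[rho, 1 - rho],
                           [1 - rho, rho]])        # rows: m, columns: s
    unnormalized = likelihood[m] * prior           # p(m | s) p(s) for s = 0, 1
    return unnormalized / unnormalized.sum()       # divide by p(m)

# Uninformative measurement (rho = 0.5): the posterior equals the prior, about [0.8, 0.2]
print(posterior_s_given_m(ps=0.2, rho=0.5, m=1))
# Reliable measurement (rho = 0.9): the likelihood dominates, about [0.31, 0.69]
print(posterior_s_given_m(ps=0.2, rho=0.9, m=1))
```

With ρ = 0.5 the measurement carries no information, so the posterior equals the prior; as ρ moves toward 0 or 1 the likelihood increasingly dominates, unless the prior itself is close to 0 or 1.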

In [None]:
#@title
from IPython.display import YouTubeVideo

video = YouTubeVideo(id='AbXorOLBrws', width=854, height=480, fs=1)
print("Video available at https://youtube.com/watch?v=" + video.id)
video

# Section 4: Bayesian decisions

We will explore how taking an action based on our belief (the posterior distribution) about where the fish might be affects the expected gain or loss.
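
Once the posterior is available, a Bayesian decision picks the action with the highest expected utility, E[U | a, m] = Σ_s U(a, s) p(s | m). The sketch below uses a hypothetical utility table `U` (fishing on the side where the fish are pays 2, the wrong side costs 1); these numbers are illustrative only and do not correspond to the slider values in this section. `expected_utility` and `best_action` are likewise hypothetical helper names.

```python
import numpy as np

# Hypothetical utility table U[a, s]: payoff of fishing on side a when the fish are on side s.
U = np.array([[ 2.0, -1.0],    # a = 0
              [-1.0,  2.0]])   # a = 1

def expected_utility(U, posterior):
    """Sum over s of U[a, s] * p(s | m), giving one value per action a."""
    return U @ posterior

def best_action(U, posterior):
    """Index of the action with the highest expected utility."""
    return int(np.argmax(expected_utility(U, posterior)))

posterior = np.array([0.3, 0.7])         # [p(s=0 | m), p(s=1 | m)]
print(expected_utility(U, posterior))    # approximately [-0.1, 1.1]
print(best_action(U, posterior))         # 1
```

Summing over s is the step that collapses the table of utility-weighted probabilities into a single score per action; the decision itself is then just an `argmax`.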


In [18]:
ps_widget = widgets.FloatSlider(0.5, description='p(s)', min=0.0, max=1.0, step=0.01)
p_a_s1_widget = widgets.FloatSlider(0.5, description='p(fish | a = s)', min=0.0, max=1.0, step=0.01)
p_a_s0_widget = widgets.FloatSlider(0.5, description='p(fish | a != s)', min=0.0, max=1.0, step=0.01)
loss_s_widget = widgets.FloatSlider(0.5, description='loss (a = s)', min=0.0, max=1.0, step=0.01)
gain_s_widget = widgets.FloatSlider(0.5, description='gain (a = s)', min=0.0, max=1.0, step=0.01)

@widgets.interact(
    ps=ps_widget,
    p_a_s1=p_a_s1_widget,
    p_a_s0=p_a_s0_widget,
    loss_s=loss_s_widget,
    gain_s=gain_s_widget,
)
def make_prior_likelihood_utility_plot(ps, p_a_s1, p_a_s0, loss_s, gain_s):
    fig = plot_prior_likelihood_utility(ps, p_a_s1, p_a_s0, loss_s, gain_s)
    plt.show(fig)
    plt.close(fig)
    return None

