<a href="https://colab.research.google.com/github/NeuromatchAcademy/course-content/blob/W2D1-postcourse-bugfix/tutorials/W2D1_BayesianStatistics/W2D1_Tutorial1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Neuromatch Academy: Week 3, Day 1, Tutorial 1
# Bayes with a binary hidden state

__Content creators:__ [insert your name here]

__Content reviewers:__ 

# Tutorial Objectives
This is the first in a series of two core tutorials on Bayesian statistics. In these tutorials, we will explore the fundamental concepts of the Bayesian approach from two perspectives. This tutorial works through an example of Bayesian inference and decision making using a binary hidden state. The second main tutorial extends these concepts to a continuous hidden state. Over the next days, each of these basic ideas will be extended: first through time, as we consider what happens when we infer a hidden state from multiple observations and when the hidden state changes across time, and then, on the third day, by introducing how to use inference and decisions to select actions for optimal control. In this tutorial, you will be introduced to our binary-state fishing problem!

This notebook will introduce the fundamental building blocks for Bayesian statistics: 
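1. Likelihoods and utility (loss) functions for a binary hidden state
2. Correlation and marginalization
3. Bayes' theorem and the posterior distribution
4. Bayesian decisions based on expected utility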


In [None]:
#@title Video 1: Introduction to Bayesian Statistics
from IPython.display import YouTubeVideo
# video = YouTubeVideo(id='K4sSKZtk-Sc', width=854, height=480, fs=1)
# print("Video available at https://youtube.com/watch?v=" + video.id)
# video

## Setup  
Please execute the cells below to initialize the notebook environment.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import patches
from matplotlib import transforms
from matplotlib import gridspec
from scipy.optimize import fsolve

from collections import namedtuple

In [None]:
#@title Figure Settings
import ipywidgets as widgets       # interactive display
from ipywidgets import GridspecLayout
from IPython.display import clear_output
%config InlineBackend.figure_format = 'retina'
plt.style.use("https://raw.githubusercontent.com/NeuromatchAcademy/course-content/master/nma.mplstyle")

In [None]:
def compute_marginal(px, py, cor):
    # calculate 2x2 joint probabilities given marginals p(x=1), p(y=1) and correlation
    p11 = px*py + cor*np.sqrt(px*py*(1-px)*(1-py))
    p01 = px - p11
    p10 = py - p11
    p00 = 1.0 - p11 - p01 - p10
    return np.asarray([[p00, p01], [p10, p11]])
# test
# print(compute_marginal(0.4, 0.6, -0.8))

def compute_cor_range(px,py):
    # Calculate the allowed range of correlation values given marginals p(x=1) and p(y=1)
    def p11(corr):
        return px*py + corr*np.sqrt(px*py*(1-px)*(1-py))
    def p01(corr):
        return px - p11(corr)
    def p10(corr):
        return py - p11(corr)
    def p00(corr):
        return 1.0 - p11(corr) - p01(corr) - p10(corr)
    # fsolve returns a length-1 array; take the scalar root so the bounds are plain floats
    Cmax = min(fsolve(p01, 0.0)[0], fsolve(p10, 0.0)[0])
    Cmin = max(fsolve(p11, 0.0)[0], fsolve(p00, 0.0)[0])
    return Cmin, Cmax

def plot_joint_probs(P):
    assert np.all(P >= 0), "probabilities should be >= 0"
    # normalize if not
    P = P / np.sum(P)
    marginal_y = np.sum(P,axis=1)
    marginal_x = np.sum(P,axis=0)

    # definitions for the axes
    left, width = 0.1, 0.65
    bottom, height = 0.1, 0.65
    spacing = 0.005

    # start with a square Figure
    fig = plt.figure(figsize=(5, 5))

    joint_prob = [left, bottom, width, height]
    rect_histx = [left, bottom + height + spacing, width, 0.2]
    rect_histy = [left + width + spacing, bottom, 0.2, height]

    rect_x_cmap = plt.cm.Blues
    rect_y_cmap = plt.cm.Reds

    # Create the axes for the joint distribution and the two marginals
    ax = fig.add_axes(joint_prob)
    ax_x = fig.add_axes(rect_histx, sharex=ax)
    ax_y = fig.add_axes(rect_histy, sharey=ax)

    # Show joint probs and marginals
    ax.matshow(P,vmin=0., vmax=1., cmap='Greys')
    ax_x.bar(0, marginal_x[0], facecolor=rect_x_cmap(marginal_x[0]))
    ax_x.bar(1, marginal_x[1], facecolor=rect_x_cmap(marginal_x[1]))
    ax_y.barh(0, marginal_y[0], facecolor=rect_y_cmap(marginal_y[0]))
    ax_y.barh(1, marginal_y[1], facecolor=rect_y_cmap(marginal_y[1]))
    # set limits
    ax_x.set_ylim([0,1])
    ax_y.set_xlim([0,1])

    # show values 
    ind = np.arange(2)
    x,y = np.meshgrid(ind,ind)
    for i,j in zip(x.flatten(), y.flatten()):
        c = f"{P[i,j]:.2f}"
        ax.text(j,i, c, va='center', ha='center', color='black')
    for i in ind:
        v = marginal_x[i]
        c = f"{v:.2f}"
        ax_x.text(i, v +0.1, c, va='center', ha='center', color='black')
        v = marginal_y[i]
        c = f"{v:.2f}"
        ax_y.text(v+0.2, i, c, va='center', ha='center', color='black')

    # set up labels
    ax.xaxis.tick_bottom()
    ax.yaxis.tick_left()
    ax.set_xticks([0,1])
    ax.set_yticks([0,1])
    ax.set_xticklabels(['R','B'])
    ax.set_yticklabels(['0','1'])
    ax.set_xlabel('color')
    ax.set_ylabel('size')
    ax_x.axis('off')   
    ax_y.axis('off')   
    return fig
# test
# P = np.random.rand(2,2)
# P = np.asarray([[0.9, 0.8], [0.4, 0.1]])
# P = P / np.sum(P)
# fig = plot_joint_probs(P)
# plt.show(fig)
# plt.close(fig)

def plot_prior_likelihood(ps, cor):
    likelihood = np.asarray([[cor, 1-cor],[1-cor,cor]])
    assert 0.0 <= ps <= 1.0
    assert 0.0 <= cor <= 1.0
    prior = np.asarray([1-ps, ps])

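    # Bayes' rule: joint P(m, s) = P(m|s) P(s), then normalize each row (fixed m) by P(m)
    joint = likelihood * prior.reshape(1, 2)
    evidence = joint.sum(axis=1, keepdims=True)
    posterior = np.divide(joint, evidence, out=np.zeros_like(joint), where=evidence > 0)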

    fig, ax_all = plt.subplots(1, 3, figsize=(12, 4))
    ax_prior, ax_likelihood, ax_posterior = ax_all
    ax_prior.set_title("prior distribution", pad=15)
    ax_likelihood.set_title("likelihood")
    ax_posterior.set_title("posterior distribution")

    rect_colormap = plt.cm.Blues

    # Show posterior probs and marginals
    ax_prior.bar(0, prior[0], facecolor = rect_colormap(prior[0]))
    ax_prior.bar(1, prior[1], facecolor = rect_colormap(prior[1]))
    ax_likelihood.matshow(likelihood, vmin=0., vmax=1., cmap='Reds')
    ax_posterior.matshow(posterior, vmin=0., vmax=1., cmap='Greys')

    for ax in ax_all:
        ax.set_xticks([0, 1])
        ax.set_yticks([0, 1])
        ax.set_xticklabels([0, 1])
        ax.set_yticklabels([0, 1])
    ax_posterior.xaxis.set_ticks_position('bottom')
    ax_likelihood.xaxis.set_ticks_position('bottom')

    # show values 
    ind = np.arange(2)
    x,y = np.meshgrid(ind,ind)
    for i,j in zip(x.flatten(), y.flatten()):
        c = f"{posterior[i,j]:.2f}"
        ax_posterior.text(j,i, c, va='center', ha='center', color='black')
    for i,j in zip(x.flatten(), y.flatten()):
        c = f"{likelihood[i,j]:.2f}"
        ax_likelihood.text(j,i, c, va='center', ha='center', color='black')
    for i in ind:
        v = prior[i]
        c = f"{v:.2f}"
        ax_prior.text(i, v +0.05, c, va='center', ha='center', color='black')

    # set up labels
    ax_prior.set_xlabel("s")
    ax_likelihood.set_xlabel("s")
    ax_likelihood.set_ylabel("m")
    ax_posterior.set_xlabel("s")
    ax_posterior.set_ylabel("m")
    return fig


# fig = plot_prior_likelihood(0.5, 0.3)
# plt.show(fig)
# plt.close(fig)

def plot_prior_likelihood_utility(ps, p_a_s1, p_a_s0, loss_s, gain_s):
    likelihood = np.asarray([[p_a_s1, 1-p_a_s1],[1-p_a_s0,p_a_s0]]).T
    assert 0.0 <= ps <= 1.0
    assert 0.0 <= p_a_s1 <= 1.0
    assert 0.0 <= p_a_s0 <= 1.0
    prior = np.asarray([1-ps, ps])

    utility = np.array([[-loss_s, loss_s-1], [gain_s, 1-gain_s]]).T
    posterior = likelihood * prior.reshape(1,2)
    posterior, likelihood = posterior.T, likelihood.T
    expected = np.multiply(utility, posterior)

    # definitions for the axes
    left, width = 0.05, 0.16
    bottom, height = 0.05, 0.9
    padding = 0.04
    small_width = 0.1
    left_space = left + small_width + padding
    added_space = padding + width

    fig = plt.figure(figsize=(17, 3))

    rect_prior = [left, bottom, small_width, height]
    rect_likelihood = [left_space, bottom , width, height]
    rect_posterior = [left_space + added_space, bottom , width, height]
    rect_utility = [left_space + 2*added_space, bottom , width, height]
    rect_expected = [left_space + 3*added_space, bottom , width, height]

    ax_likelihood = fig.add_axes(rect_likelihood)
    ax_prior = fig.add_axes(rect_prior, sharey=ax_likelihood)
    ax_posterior = fig.add_axes(rect_posterior, sharey=ax_likelihood)
    ax_utility = fig.add_axes(rect_utility, sharey=ax_posterior)
    ax_expected = fig.add_axes(rect_expected, sharey=ax_utility)

    ax_prior.set_title("prior distribution", pad=15)
    ax_likelihood.set_title("likelihood")
    ax_posterior.set_title("posterior distribution")
    ax_utility.set_title("utility function")
    ax_expected.set_title("expected utility")

    rect_colormap = plt.cm.Blues

    # Show posterior probs and marginals
    ax_prior.barh(0, prior[0], facecolor = rect_colormap(prior[0]))
    ax_prior.barh(1, prior[1], facecolor = rect_colormap(prior[1]))
    ax_likelihood.matshow(likelihood, vmin=0., vmax=1., cmap='Reds')
    ax_posterior.matshow(posterior, vmin=0., vmax=1., cmap='Greys')
    ax_utility.matshow(utility, vmin=0., vmax=1., cmap='cool')
    ax_expected.matshow(expected, vmin=0., vmax=1., cmap='Wistia')

    for ax in [ax_prior, ax_likelihood, ax_posterior, ax_expected]:
        ax.set_xticks([0, 1])
        ax.set_yticks([0, 1])
        ax.set_xticklabels([0, 1])
        ax.set_yticklabels([0, 1])

    ax_utility.set_xticks([0, 1])
    ax_utility.set_yticks([0, 1])
    ax_utility.set_xticklabels(["loss", "gain"])
    # ax_utility.set_yticklabels([0, 1])

    ax_posterior.xaxis.set_ticks_position('bottom')
    ax_likelihood.xaxis.set_ticks_position('bottom')
    ax_utility.xaxis.set_ticks_position('bottom')
    ax_expected.xaxis.set_ticks_position('bottom')

    ax_prior.set_xlim([1, 0])

    # show values 
    ind = np.arange(2)
    x,y = np.meshgrid(ind,ind)
    for i,j in zip(x.flatten(), y.flatten()):
        c = f"{posterior[i,j]:.2f}"
        ax_posterior.text(j,i, c, va='center', ha='center', color='black')
    for i,j in zip(x.flatten(), y.flatten()):
        c = f"{likelihood[i,j]:.2f}"
        ax_likelihood.text(j,i, c, va='center', ha='center', color='black')
    for i,j in zip(x.flatten(), y.flatten()):
        c = f"{utility[i,j]:.2f}"
        ax_utility.text(j,i, c, va='center', ha='center', color='black')
    for i,j in zip(x.flatten(), y.flatten()):
        c = f"{expected[i,j]:.2f}"
        ax_expected.text(j,i, c, va='center', ha='center', color='black')
    for i in ind:
        v = prior[i]
        c = f"{v:.2f}"
        ax_prior.text(v+0.2, i, c, va='center', ha='center', color='black')



    # set up labels
    ax_prior.set_xlabel("m")
    ax_likelihood.set_xlabel("m")
    ax_likelihood.set_ylabel("s")
    ax_posterior.set_xlabel("m")
    ax_posterior.set_ylabel("s")
    ax_utility.set_ylabel("s")
    # ax_expected.set_xlabel("m")
    ax_expected.set_ylabel("s")
    ax_utility.set_xlabel("a")
    ax_prior.axis('off')
    return fig




# Section 1: The Binary hidden state

You were just introduced to the binary hidden state problem we are going to explore. You are watching a person fishing on a dock, and you want to determine where the fish are so you can decide where to fish. Remember, you can think of yourself either as a scientist conducting an experiment or as a brain trying to make a decision. The Bayesian approach is the same!

In this section, we are going to walk through what it means to take a measurement (also often called an observation) and how to think about what it tells you about the probability of the hidden state we are interested in. Then we are going to think about what happens if you act on a guess about the hidden state.


## Video: Observations and costs

## Exercise 1: What is a likelihood?
We know fish like to school together. On different days the fish are mostly on the left or right, but you don’t know what the case is today. Let’s assume that on a given day all the fish are on one side only. So, we have no prior knowledge, but we can still know something about what catching a fish means for the likelihood of the fish being on one side or the other. You know that if you fish on the side of the dock where the fish are, you have a 50% chance that you catch a fish. Otherwise you catch a fish with only 10% probability. Calculate the following probabilities by hand.

1. We showed $P(m|s)$ for fishing on the right side of the dock. What are the probabilities if you fish on the left side?
2. What does a single measurement tell you?
3. What is the difference in the likelihood if you know that $P(s = left) = 0.3$?
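A minimal sketch of the likelihood table for Exercise 1, assuming only the 50% / 10% catch probabilities stated above. The names `P_CATCH_SAME`, `P_CATCH_OTHER`, and `likelihood` are hypothetical helpers for illustration, not part of the notebook; the measurement $m$ is whether a fish is caught while fishing on a given side.

```python
import numpy as np

# Catch probabilities from the exercise text (assumed to be the same every day).
P_CATCH_SAME = 0.5   # P(catch | you fish on the side where the fish are)
P_CATCH_OTHER = 0.1  # P(catch | you fish on the other side)

SIDES = ("left", "right")


def likelihood(caught, fish_side, hidden_state):
    """Hypothetical helper: P(m | s) for measurement m = `caught` (True/False),
    observed while fishing on `fish_side`, given hidden state s = `hidden_state`."""
    p_catch = P_CATCH_SAME if fish_side == hidden_state else P_CATCH_OTHER
    return p_catch if caught else 1.0 - p_catch


# Likelihood table for fishing on the left: rows are m (catch, no catch),
# columns are s (left, right).
table_left = np.array([[likelihood(m, "left", s) for s in SIDES]
                       for m in (True, False)])
print(table_left)

# Each column is a distribution over m, so each column sums to 1.
print(table_left.sum(axis=0))

# Likelihood ratio for a single catch on the left: left vs. right hidden state.
ratio = likelihood(True, "left", "left") / likelihood(True, "left", "right")
print(ratio)
```

The likelihood ratio is one way to read question 2: a single catch on the left makes "fish on the left" five times more likely than "fish on the right", but by itself it is not yet a probability over $s$.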

To explore what happens as you change the likelihood function, $P(m|s)$, use the widget below, where the slider ρ sets the likelihood. In this widget the (prior) probability of where the fish are today, $P(s)$, is fixed at 0.5; you will vary it in Section 3.


In [None]:
cor_widget = widgets.FloatSlider(0.3, description='ρ', min=0.0, max=1.0, step=0.01, disabled=False)

@widgets.interact(
    cor=cor_widget,
)
def make_prior_likelihood_plot(cor):
    fig = plot_prior_likelihood(0.5,cor)
    plt.show(fig)
    plt.close(fig)
    return None


## Exercise 2: Utility and Loss (gain) functions

Fish are much easier to catch on the left side of the dock, as there are no submarines, but you also know you are going to get sunburnt if you fish on the side where there are no fish. Let’s say you don’t know anything about where the fish are (no fishing yet today).

| Utility: U(s,a)   | a = left   | a = right  |
| ----------------- | ---------- | ---------- |
| s = left          | 2          | -3         |
| s = right         | -2         | 1          |

1. What should cause you to choose left or right?
2. What changes after you have an observation?
3. Calculate the expected utility of fishing on the right and left sides of the dock if you have no measurement and no prior information (a 50/50 probability that the fish are on either side).
4. You observe someone fish on the right side and catch a fish. Using the utilities we described in the video, what utility should you expect?
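A minimal sketch for question 3, assuming the utilities in the table above (the utilities from the video in question 4 are not reproduced here). The expected utility of an action is $\sum_s U(s,a)P(s)$; `expected_utility` is a hypothetical helper that takes any belief vector, so the same code can be reused with a posterior once one is available.

```python
import numpy as np

# Utility table U(s, a) from the table above: rows are the hidden state s,
# columns are the action a, both ordered (left, right).
U = np.array([[ 2.0, -3.0],   # s = left
              [-2.0,  1.0]])  # s = right


def expected_utility(belief, utility=U):
    """Expected utility of each action, sum_s U(s, a) P(s).

    `belief` is the probability vector [P(s=left), P(s=right)]."""
    return np.asarray(belief) @ utility


# No measurement, no prior information: 50/50 belief.
eu = expected_utility([0.5, 0.5])
print(dict(zip(("left", "right"), eu)))

# The action with the highest expected utility.
best_action = ("left", "right")[int(np.argmax(eu))]
print(best_action)
```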

Let's ask the question in a different way: what if we just want to know how close or far off our inference about the location of the fish is today? To answer this, we must decide how much cost we incur if our inference is incorrect. In this case, we call the utility function a loss function, similar to the loss functions you have encountered already. The expected loss is $\sum_{s}u(s,a)P(s)$, which tells us how bad we expect our inference to be, given a probability distribution over the hidden state $s$.

1. Assume your loss function is the squared error $(a-s)^2$. What is the expected loss if $P(s=left)$ is 0.5? If it is 0.3?
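The squared error needs numbers for $s$ and $a$, which the text does not specify. A minimal sketch, assuming the encoding left = 0 and right = 1; `expected_squared_loss` is a hypothetical helper implementing $\sum_s (a-s)^2 P(s)$.

```python
import numpy as np

# Assumed numeric encoding (not specified in the text): left = 0, right = 1.
STATES = np.array([0.0, 1.0])


def expected_squared_loss(a, p_left):
    """Expected loss sum_s (a - s)^2 P(s) for an action/estimate a."""
    probs = np.array([p_left, 1.0 - p_left])
    return float(np.sum((a - STATES) ** 2 * probs))


for p_left in (0.5, 0.3):
    # Expected loss of committing to left (a = 0) or right (a = 1).
    losses = {a: expected_squared_loss(a, p_left) for a in (0.0, 1.0)}
    # If a may be any real number, squared error is minimized by the mean E[s] = P(s=right).
    a_star = 1.0 - p_left
    print(p_left, losses, a_star, expected_squared_loss(a_star, p_left))
```

Under this encoding, the binary choice picks the more probable side, while an unrestricted estimate settles on the mean of $s$.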

To explore what happens if you change the likelihood and utility functions, use the widget below:

In [None]:
ps_widget = widgets.FloatSlider(0.5, description='p(s)', min=0.0, max=1.0, step=0.01)
p_a_s1_widget = widgets.FloatSlider(0.5, description='p(fish | a = s)', min=0.0, max=1.0, step=0.01)
p_a_s0_widget = widgets.FloatSlider(0.5, description='p(fish | a != s)', min=0.0, max=1.0, step=0.01)
loss_s_widget = widgets.FloatSlider(0.5, description='loss (a = s)', min=0.0, max=1.0, step=0.01)
gain_s_widget = widgets.FloatSlider(0.5, description='gain (a = s)', min=0.0, max=1.0, step=0.01)

@widgets.interact(
    ps=ps_widget,
    p_a_s1=p_a_s1_widget,
    p_a_s0=p_a_s0_widget,
    loss_s=loss_s_widget,
    gain_s=gain_s_widget,
)
def make_prior_likelihood_utility_plot(ps, p_a_s1, p_a_s0, loss_s, gain_s):
    fig = plot_prior_likelihood_utility(ps, p_a_s1, p_a_s0, loss_s, gain_s)
    plt.show(fig)
    plt.close(fig)
    return None


# Section 2: Correlation and marginalization

In this section, we are going to think about the amount of information shared between two variables. We want to know how much information you gain about one variable when you observe another (take a measurement). The fundamental concept is the same whether we think about two attributes, for example the size and color of the fish, or about the prior information and the likelihood.

## Video: Correlation and marginalization

## Exercise

To understand the information between two variables, let's first consider the size and color of the fish.

| $P(x, y)$   | y = silver | y = gold |
| ----------- | ---------- | -------- |
| x = small   | 0.4        | 0.2      |
| x = large   | 0.1        | 0.3      |

We want to know the probability of catching a small fish, or of catching a silver fish. To do this, we need to marginalize, or sum out, the variable we are not interested in across the rows or columns. For example, $P(x = small) = \sum_y P(x = small, y)$.

1. Calculate the probability of catching a small fish, a large fish, a silver fish or a gold fish.
2. Calculate the probability of catching a small fish OR a gold fish. (Hint: $P(A\ \textrm{or}\ B) = P(A) + P(B) - P(A\ \textrm{and}\ B)$)
3. Calculate the probability of catching a small gold fish or a large silver fish.
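A minimal sketch of the marginalization in questions 1 to 3, using only the joint table above (rows x = small, large; columns y = silver, gold). Summing a row or column "sums out" the other variable.

```python
import numpy as np

# Joint table P(x, y): rows x = (small, large), columns y = (silver, gold).
P_xy = np.array([[0.4, 0.2],
                 [0.1, 0.3]])

# Marginalize by summing out the other variable.
P_x = P_xy.sum(axis=1)  # [P(small), P(large)]
P_y = P_xy.sum(axis=0)  # [P(silver), P(gold)]

# P(small OR gold) = P(small) + P(gold) - P(small AND gold)
p_small_or_gold = P_x[0] + P_y[1] - P_xy[0, 1]

# "small gold" and "large silver" are disjoint events, so their probabilities add.
p_small_gold_or_large_silver = P_xy[0, 1] + P_xy[1, 0]

print(P_x, P_y, p_small_or_gold, p_small_gold_or_large_silver)
```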

For two binary variables, the joint probabilities are fully determined by the two marginal probabilities together with the correlation between the variables. To see how we calculate the correlation, let's review the definitions of covariance and correlation.

Covariance:
$cov(X,Y) = \sigma_{XY} = E[(X - \mu_{X})(Y - \mu_{Y})] = E[XY] - \mu_{X}\mu_{Y}$

Correlation:
$\rho_{XY} = \frac{cov(X,Y)}{\sqrt{V(X)V(Y)}} = \frac{\sigma_{XY}}{\sigma_{X}\sigma_{Y}}$
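A minimal sketch applying these definitions to the fish table, assuming the encoding small = 0, large = 1 and silver = 0, gold = 1 so that the expectations are numbers. The last line goes the other way, recovering $P(x=1, y=1)$ from the marginals and $\rho$ with the same formula as `p11` in `compute_marginal` above.

```python
import numpy as np

# Joint table P(x, y): rows x = (small, large), columns y = (silver, gold).
P_xy = np.array([[0.4, 0.2],
                 [0.1, 0.3]])
vals = np.array([0.0, 1.0])  # assumed encoding: small/silver = 0, large/gold = 1

mu_x = np.sum(P_xy.sum(axis=1) * vals)       # E[X] = P(large)
mu_y = np.sum(P_xy.sum(axis=0) * vals)       # E[Y] = P(gold)
e_xy = np.sum(P_xy * np.outer(vals, vals))   # E[XY] = P(large, gold)

cov = e_xy - mu_x * mu_y
var_x = mu_x * (1 - mu_x)  # variance of a Bernoulli variable
var_y = mu_y * (1 - mu_y)
rho = cov / np.sqrt(var_x * var_y)

# Marginals plus rho recover the joint probability P(x=1, y=1).
p11 = mu_x * mu_y + rho * np.sqrt(mu_x * mu_y * (1 - mu_x) * (1 - mu_y))
print(cov, rho, p11)
```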

Use the widget below and answer the following questions:

1. 


In [None]:
gs = GridspecLayout(2,2)

cor_widget = widgets.FloatSlider(0.3, description='ρ', min=-1, max=1, step=0.01)
px_widget = widgets.FloatSlider(0.5, description='p(x)', min=0.01, max=0.99, step=0.01)
py_widget = widgets.FloatSlider(0.5, description='p(y)', min=0.01, max=0.99, step=0.01)
gs[0,0] = cor_widget
gs[0,1] = px_widget
gs[1,0] = py_widget


@widgets.interact(
    px=px_widget,
    py=py_widget,
    cor=cor_widget,
)
def make_corr_plot(px, py, cor):
    Cmin, Cmax = compute_cor_range(px, py)  # allowed correlation range
    cor_widget.min, cor_widget.max = Cmin+0.01, Cmax-0.01
    if cor_widget.value > Cmax:
        cor_widget.value = Cmax
    if cor_widget.value < Cmin:
        cor_widget.value = Cmin
    cor = cor_widget.value
    P = compute_marginal(px,py,cor)
    # print(P)
    fig = plot_joint_probs(P) 
    plt.show(fig)
    plt.close(fig)
    return None

# gs[1,1] = make_corr_plot()

## Exercise

Let's return to fishing on the dock! We know the likelihoods if we take a measurement on the left or right side of the dock. We need to determine the likelihood, $\mathcal{L}(m|s) = P(m|s)$, under different assumptions about the (prior) probability that the fish are on the left or right. Assume $P(s = left) = 0.5$.

1. Calculate the likelihood of the fish being on the left side (the hidden state) if you see the person fishing on the right but not catching a fish.
2. Calculate the likelihood of the fish being on the left side if you see the person fishing on the left and catching a fish.

Now assume that $P(s = left) = 0.3$.

3. Calculate the likelihood of the fish being on the right if you see the person catch a fish on the left.
4. Calculate the likelihood of the fish being on the right if you see the person catch a fish on the right.
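A minimal sketch of the four cases, assuming the 50% / 10% catch probabilities from Exercise 1; `likelihood` is a hypothetical helper. Looping over both priors makes the point of the exercise visible: the likelihood $P(m|s)$ is the same for either prior, and only the product $P(m|s)P(s)$ changes.

```python
# Assumed catch probabilities from Exercise 1: 0.5 on the fish side, 0.1 otherwise.
def likelihood(caught, fish_side, hidden_state):
    """Hypothetical helper: P(m | s) for a catch/no-catch while fishing on fish_side."""
    p_catch = 0.5 if fish_side == hidden_state else 0.1
    return p_catch if caught else 1.0 - p_catch


# (caught, side being fished, hidden state s) for questions 1-4.
cases = {
    "1. s=left, fish right, no catch": (False, "right", "left"),
    "2. s=left, fish left, catch": (True, "left", "left"),
    "3. s=right, fish left, catch": (True, "left", "right"),
    "4. s=right, fish right, catch": (True, "right", "right"),
}

for prior_left in (0.5, 0.3):
    for name, (caught, side, s) in cases.items():
        lik = likelihood(caught, side, s)
        prior_s = prior_left if s == "left" else 1.0 - prior_left
        # lik does not depend on prior_left; the product P(m|s)P(s) does.
        print(prior_left, name, lik, lik * prior_s)
```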

You can 

# Section 3: Bayes' Theorem and the Posterior

## Video: Bayes' Theorem

In [None]:
#@title Video 2: Bayes' theorem
from IPython.display import YouTubeVideo
# video = YouTubeVideo(id='ewQPHQMcdBs', width=854, height=480, fs=1)
# print("Video available at https://youtube.com/watch?v=" + video.id)
# video

Now we can calculate the full posterior distribution. The posterior probability of the hidden state, $s$, given a measurement, $m$, differs from the likelihood in two ways: it is weighted by the prior, $P(s)$, and it is divided by the evidence, or marginal likelihood,

$P(m) = E_{s}[P(m|s)] = \sum_s{P(m|s)P(s)}$

which is the partition function (or normalizing constant) that ensures we produce a full probability distribution. This means that we can use the posterior as a complete probability distribution for future computations! We often call the posterior probability distribution our *belief*, $b$, about the hidden state.

$b = P(s|m) = \frac{P(m|s)P(s)}{P(m)}$

However, there is a reason that the likelihood function is an important concept: in many complicated cases, like those we might use to model behavioral or brain inferences, the partition function can be intractable or extremely complex to calculate. This is why we need to be careful to choose probability distributions where we can calculate the posterior probability analytically, or where a numerical approximation is reliable. The important thing to remember is that you can compare likelihoods, as we have seen during model fitting and model comparison, because relative likelihoods do not depend on the partition function.
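The claim that relative likelihoods do not depend on the partition function can be checked numerically. The posterior odds equal the likelihood ratio times the prior odds, and $P(m)$ cancels. A minimal sketch for a catch on the right with $P(s = left) = 0.3$, assuming the 50% / 10% catch probabilities from Exercise 1.

```python
import numpy as np

prior = np.array([0.3, 0.7])   # [P(s=left), P(s=right)]
lik = np.array([0.1, 0.5])     # P(catch on right | s) for s = left, right (assumed)

evidence = np.sum(lik * prior)       # P(m), the partition function
posterior = lik * prior / evidence   # P(s | m)

# Odds computed from the normalized posterior ...
odds_from_posterior = posterior[1] / posterior[0]
# ... match likelihood ratio times prior odds, which never uses P(m).
odds_without_evidence = (lik[1] / lik[0]) * (prior[1] / prior[0])
print(evidence, posterior, odds_from_posterior, odds_without_evidence)
```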

Let's calculate the posterior probability distribution (our belief about the hidden state). Assume $P(s = left) = 0.3$

1. Calculate the posterior probability distribution if you see the person catch a fish on the right.
2. Calculate the posterior probability distribution if you see the person try to fish on the right but not catch a fish.
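A minimal sketch of both posteriors, assuming the 50% / 10% catch probabilities from Exercise 1 and $P(s = left) = 0.3$. `binary_posterior` is a hypothetical helper that multiplies likelihood by prior and divides by the evidence.

```python
import numpy as np


def binary_posterior(prior_left, lik_left, lik_right):
    """Posterior [P(s=left|m), P(s=right|m)] via Bayes' rule.

    lik_left and lik_right are P(m | s=left) and P(m | s=right)."""
    prior = np.array([prior_left, 1.0 - prior_left])
    lik = np.array([lik_left, lik_right])
    unnormalized = lik * prior
    return unnormalized / unnormalized.sum()  # divide by the evidence P(m)


# 1. Catch on the right: P(m|left) = 0.1, P(m|right) = 0.5
post_catch = binary_posterior(0.3, 0.1, 0.5)
# 2. No catch on the right: P(m|left) = 0.9, P(m|right) = 0.5
post_no_catch = binary_posterior(0.3, 0.9, 0.5)
print(post_catch, post_no_catch)
```

Note that failing to catch a fish on the right still shifts belief toward the left (from 0.3 to about 0.44), even though the fish remain more likely to be on the right.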

## Interactive Demo: What affects the posterior?

Now that we have *Bayes' rule*, let's vary the prior and the likelihood to see how each affects the posterior.

**Hit the Play button or Ctrl+Enter in the cell below** and play with the sliders to get an intuition for how the prior probability $P(s)$ and the likelihood parameter ρ influence the posterior.

When does the prior have the strongest influence over the posterior? When is it the weakest?  
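One way to check your intuition is to compute the posterior directly for the widget's symmetric likelihood, where $P(m=1|s=1) = P(m=0|s=0) = \rho$. This minimal sketch uses a hypothetical helper, `posterior_s1_given_m1`, and a fixed prior $P(s=1) = 0.3$: at $\rho = 0.5$ the measurement is uninformative and the posterior equals the prior, and as $\rho$ moves toward 0 or 1 the likelihood increasingly dominates.

```python
def posterior_s1_given_m1(ps, rho):
    """Hypothetical helper: P(s=1 | m=1) when P(m=1|s=1) = P(m=0|s=0) = rho."""
    joint_s1 = rho * ps                  # P(m=1, s=1)
    joint_s0 = (1.0 - rho) * (1.0 - ps)  # P(m=1, s=0)
    return joint_s1 / (joint_s1 + joint_s0)


ps = 0.3
for rho in (0.5, 0.6, 0.75, 0.9, 0.99):
    print(rho, round(posterior_s1_given_m1(ps, rho), 3))
```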

In [None]:
cor_widget = widgets.FloatSlider(0.3, description='ρ', min=0.0, max=1.0, step=0.01, disabled=False)
ps_widget = widgets.FloatSlider(0.5, description='p(s)', min=0.0, max=1.0, step=0.01)

@widgets.interact(
    ps=ps_widget,
    cor=cor_widget,
)
def make_prior_likelihood_plot(ps,cor):
    fig = plot_prior_likelihood(ps,cor)
    plt.show(fig)
    plt.close(fig)
    return None


In [None]:
#@title
from IPython.display import YouTubeVideo

video = YouTubeVideo(id='AbXorOLBrws', width=854, height=480, fs=1)
print("Video available at https://youtube.com/watch?v=" + video.id)
video

# Section 4: Bayesian decisions

We will explore how to consider the expected utility of an action based on our belief (the posterior distribution) about where we think the fish are. Now we have all the components of a Bayesian decision: our prior information, the likelihood given a measurement, the posterior distribution (belief), and our utility (the gains and losses). This allows us to consider the relationship between the true value of the hidden state, $s$, and what we *expect* to get if we take action, $a$, based on our belief!
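A minimal sketch of the full pipeline (prior, likelihood, posterior belief, expected utility, action), assuming a 50/50 prior, the 50% / 10% catch probabilities from Exercise 1, an observed catch on the right, and the utility table from Exercise 2. These values are illustrative choices, not the ones used by the widget below.

```python
import numpy as np

# Illustrative assumptions carried over from earlier exercises.
prior = np.array([0.5, 0.5])    # [P(s=left), P(s=right)]
lik = np.array([0.1, 0.5])      # P(catch on right | s) for s = left, right
U = np.array([[ 2.0, -3.0],     # U(s=left,  a=left), U(s=left,  a=right)
              [-2.0,  1.0]])    # U(s=right, a=left), U(s=right, a=right)

# 1. Belief: posterior over the hidden state after the measurement.
belief = lik * prior
belief /= belief.sum()

# 2. Expected utility of each action under that belief: sum_s U(s, a) b(s).
eu = belief @ U

# 3. Bayesian decision: pick the action with the highest expected utility.
action = ("left", "right")[int(np.argmax(eu))]
print(belief, eu, action)
```

With no measurement the same utilities favor the left side, so here the observation is what flips the decision.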

Let's use the following widget to think about the relationship between these probability distributions and utility function.


In [None]:
ps_widget = widgets.FloatSlider(0.5, description='p(s)', min=0.0, max=1.0, step=0.01)
p_a_s1_widget = widgets.FloatSlider(0.5, description='p(fish | a = s)', min=0.0, max=1.0, step=0.01)
p_a_s0_widget = widgets.FloatSlider(0.5, description='p(fish | a != s)', min=0.0, max=1.0, step=0.01)
loss_s_widget = widgets.FloatSlider(0.5, description='loss (a = s)', min=0.0, max=1.0, step=0.01)
gain_s_widget = widgets.FloatSlider(0.5, description='gain (a = s)', min=0.0, max=1.0, step=0.01)

@widgets.interact(
    ps=ps_widget,
    p_a_s1=p_a_s1_widget,
    p_a_s0=p_a_s0_widget,
    loss_s=loss_s_widget,
    gain_s=gain_s_widget,
)
def make_prior_likelihood_utility_plot(ps, p_a_s1, p_a_s0, loss_s, gain_s):
    fig = plot_prior_likelihood_utility(ps, p_a_s1, p_a_s0, loss_s, gain_s)
    plt.show(fig)
    plt.close(fig)
    return None


# End of tutorial 1