<a href="https://colab.research.google.com/github/motorlearner/neuromatch/blob/main/bayes_intro.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# @title `Code: Definitions and Helpers`
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as stats
import ipywidgets as widgets

# classes
class dotdict(dict):
    """dict whose items can also be read, set and deleted as attributes

    Reading an attribute that is not a key gives None (same as dict.get);
    deleting one that is not a key raises KeyError (same as dict.__delitem__).
    """

    def __getattr__(self, name):
        # only reached when regular attribute lookup has already failed
        return dict.get(self, name)

    def __setattr__(self, name, value):
        dict.__setitem__(self, name, value)

    def __delattr__(self, name):
        dict.__delitem__(self, name)

class NormalParam(dotdict):
    """dotdict for a normal rv with attributes mean, sd, value

    Each attribute is either a number or None (not set yet). `value` holds a
    specific value associated with the variable (make_plot stores the
    presented stimulus / received measurement there).

    Raises TypeError if any argument is neither a number nor None.
    """
    def __init__(self, mean=None, sd=None, value=None):
        # NumPy scalars such as np.int64 or np.float32 do not subclass
        # int/float, so they are listed explicitly to count as numbers too
        allowed = (int, float, np.integer, np.floating, type(None))
        if not all(isinstance(v, allowed) for v in (mean, sd, value)):
            raise TypeError('NormalParam(): all args must be numbers!')
        super().__init__()
        self.mean = mean
        self.sd = sd
        self.value = value

class PlotData(dotdict):
    """dotdict with one entry per plotted quantity, in the order s, xs, l, sx

    s  -> prior p(s)
    xs -> measurement distribution p(x|s)
    l  -> likelihood L(s;x)
    sx -> posterior p(s|x)

    Entries default to None; what is stored (labels, colors, parameters,
    densities) is up to the caller.
    """
    def __init__(self, s=None, xs=None, l=None, sx=None):
        # keyword order fixes the key order, which also sets the panel order
        super().__init__(s=s, xs=xs, l=l, sx=sx)

# functions
def normalize(arr):
    """divide an array by its sum along axis 0, so a 1D array sums to 1"""
    # keepdims keeps the reduced axis so the division broadcasts
    totals = arr.sum(axis=0, keepdims=True)
    return arr / totals

def multiply_normals(a: NormalParam, b: NormalParam) -> NormalParam :
    """NormalParam(mean, sd) of the normalized product of two normal densities

    The mean is the precision-weighted (1/sd**2) average of the two means;
    the variance is var_a * var_b / (var_a + var_b). `value` is left as None.

    Raises TypeError if an argument is not a NormalParam, and ValueError if
    any mean or sd is None.
    """
    # validate inputs
    if not isinstance(a, NormalParam) or not isinstance(b, NormalParam):
        raise TypeError ("Both arguments must be of class NormalParam.")
    if None in [x for x in (a.mean, b.mean, a.sd, b.sd) if x is None]:
        raise ValueError ("Both means and SDs must be initialized.")

    # precisions (inverse variances) act as the weights of the two means
    prec_a = 1/a.sd**2
    prec_b = 1/b.sd**2
    var_a = a.sd**2
    var_b = b.sd**2

    product = NormalParam()
    product.mean = ( prec_a * a.mean + prec_b * b.mean ) / (prec_a + prec_b)
    product.sd = np.sqrt( (var_a * var_b) / (var_a + var_b) )
    return product

# shared plot settings: x-axis limits, then axis label and line color per quantity
xlims  = (-4, 4)
labels = PlotData(
    s='$p(s)$',
    xs='$p(x|s)$',
    l='$\\mathscr{L}(s;x)$',
    sx='$p(s|x)$',
)
colors = PlotData(
    s='#E98A15',
    xs='#012622',
    l='#7B0828',
    sx='#035E7B',
)

# Symbols

* $s$ is the stimulus presented by the experimenter to the observer (known to the experimenter, unknown to the observer)
* $p(s)$ is the observer's prior belief about what stimuli they might receive
* $x$ is a noisy measurement of the stimulus, this is what the observer actually receives (unknown to the experimenter, known to the observer)
* $p(x\mid s)$ is the measurement distribution, given a stimulus $s$, the measurement $x$ is drawn randomly from this distribution
* $\mathscr{L}(s;x)$ is the likelihood function; it is the evidence for each possible stimulus value given only the measurement (ignoring prior belief). Mind you, the likelihood function is generally not a probability distribution, but once you normalize it, it is (because then it sums/integrates to 1)
* $\widehat{S}_\text{ML}$ is the maximum likelihood estimate, the stimulus $s$ that maximizes the likelihood function
* $p(s \mid x)$ is the posterior distribution, it is obtained by multiplying prior with likelihood (and then normalizing, because probability must sum/integrate to 1), it is the evidence for each possible stimulus value given all the information available, i.e. both prior belief and measurement
* $\widehat{S}_\text{PM}$ is the posterior mean, the mean of the posterior distribution

In [None]:
# @title Single Trial Inference

# @markdown A Bayesian observer with Gaussian prior and Gaussian measurement distribution. You can change the SD of the prior, the SD of the measurement, the stimulus $s$ (on which the measurement distribution $p(x\mid s)$ is centered), and the actual measurement received. Note that for this all-Gaussian case, the likelihood function and the posterior are Gaussian as well.

# @markdown Play around with the parameters and note the following:
# @markdown * the posterior mean estimate is always between the prior mean and the mean of the likelihood function
# @markdown * the posterior mean is further towards the *narrower* distribution
# @markdown * the posterior is always narrower than both the prior and likelihood (for non-Gaussian cases, this is only true on average...for the Gaussian case, this is always true)

def make_plot(s_val, s_sd, x_val, x_sd):
    """Plot prior, measurement distribution, likelihood and posterior of one trial.

    Draws one panel per quantity, stacked vertically with a shared x-axis:
    p(s), p(x|s), L(s;x) and p(s|x). The prior is centered on 0.

    s_val -- presented stimulus; also the mean of p(x|s)
    s_sd  -- SD of the prior p(s)
    x_val -- received measurement; the likelihood is centered on it
    x_sd  -- SD of the measurement distribution (and of the likelihood)
    """

    def mark_center(ax, x, text, color, offset, box_extra=(), **text_kw):
        """dashed vertical line at x plus a boxed label at half the upper y-limit"""
        ax.axvline(x, c=color, ls='--', lw=0.8)
        box = dict(boxstyle='round', pad=0.05, facecolor='white', edgecolor=color, ls='')
        box.update(box_extra)
        ax.annotate(
            text, (x, ax.get_ylim()[1]/2), c=color, zorder=2,
            textcoords="offset points", xytext=offset, ha='center', va='top',
            bbox=box, **text_kw
        )

    def mark_value(ax, x, text, color, offset, **scatter_kw):
        """square marker at (x, 0) plus a text label next to it"""
        ax.scatter(x, 0, marker='s', s=20, c=color, **scatter_kw)
        ax.annotate(
            text, (x, 0), c=color, zorder=2,
            textcoords="offset points", xytext=offset, ha='left', va='center'
        )

    # normal parameters of p(s), p(x|s), L(s;x); posterior = prior * likelihood
    prior       = NormalParam(mean=0, sd=s_sd, value=s_val)
    measurement = NormalParam(mean=s_val, sd=x_sd, value=x_val)
    likelihood  = NormalParam(mean=x_val, sd=x_sd)
    posterior   = multiply_normals(prior, likelihood)

    # common support and the four curves evaluated on it
    grid = np.linspace(xlims[0], xlims[1], 500)
    curves = PlotData(
        s=stats.norm.pdf(grid, prior.mean, prior.sd),
        xs=stats.norm.pdf(grid, measurement.mean, measurement.sd),
        # likelihood: measurement fixed, the grid plays the role of s
        l=stats.norm.pdf(measurement.value, grid, measurement.sd),
        sx=stats.norm.pdf(grid, posterior.mean, posterior.sd),
    )

    # one row per PlotData key
    _, panels = plt.subplot_mosaic([[name] for name in PlotData()], sharex=True)

    for name, ax in panels.items():
        color = colors[name]
        ax.set_xlim(np.min(grid), np.max(grid))
        ax.set_yticklabels([])
        ax.set_yticks([])
        ax.set_ylabel(labels[name], rotation=0, horizontalalignment='right')
        ax.plot(grid, curves[name], c=color)

        if name == 's':
            mark_center(ax, prior.mean, r'$\mathrm{\mathbb{E}}[s]$', color, (0, 5), fontsize=8)
            mark_value(ax, prior.value, '$s$', color, (2, 6))
        elif name == 'xs':
            mark_center(ax, measurement.mean, r'$\mathrm{\mathbb{E}}[x|s]$', color, (0, 5), fontsize=8)
            mark_value(ax, measurement.value, '$x$', color, (2, 5), zorder=2)
        elif name == 'l':
            # the likelihood peaks at the measurement itself
            mark_center(ax, measurement.value, '$\\hat{s}_{ML}$', color, (5, 5))
        elif name == 'sx':
            mark_center(ax, posterior.mean, '$\\hat{s}_{PM}$', color, (5, 5), box_extra=dict(alpha=0.9))

# sliders: stimulus s, measurement x, and the two SDs
slide_sval = widgets.FloatSlider(
    value=0, min=xlims[0], max=xlims[1], step=0.1,
    description='s', readout_format='0.1f',
)
slide_xval = widgets.FloatSlider(
    value=0, min=xlims[0], max=xlims[1], step=0.1,
    description='x|s', readout_format='0.1f',
)
slide_ssd = widgets.FloatSlider(
    value=1, min=0.1, max=2, step=0.1,
    description='SD(s)', readout_format='0.1f',
)
slide_xsd = widgets.FloatSlider(
    value=0.8, min=0.1, max=2, step=0.1,
    description='SD(x|s)', readout_format='0.1f',
)

# define update functions
def update_srange(change):
    """set the range of the s slider to 0 +/- 3 * sd(s); `change` is not used"""
    half_width = 3 * slide_ssd.value
    slide_sval.min = 0 - half_width
    slide_sval.max = 0 + half_width
def update_xrange(change):
    """set the range of the x slider to E[x|s] +/- 3 * sd(x|s), with E[x|s]=s

    `change` is the notification passed by the observed slider; it is not
    used, the current slider values are read directly.
    """
    center = slide_sval.value
    half_width = 3 * slide_xsd.value
    new_min = center - half_width
    new_max = center + half_width
    # a bounded slider rejects any single assignment that leaves min > max,
    # so when the new window lies entirely above the old one, raise max first
    if new_min > slide_xval.max:
        slide_xval.max = new_max
        slide_xval.min = new_min
    else:
        slide_xval.min = new_min
        slide_xval.max = new_max

# keep slider ranges in sync: a change of sd(s) rescales the s range,
# a change of s or sd(x|s) moves/rescales the x range
slide_ssd.observe(update_srange, names='value')
slide_sval.observe(update_xrange, names='value')
slide_xsd.observe(update_xrange, names='value')

# redraw the figure whenever one of the four sliders changes
plotoutput = widgets.interactive_output(
    make_plot,
    dict(s_val=slide_sval, s_sd=slide_ssd, x_val=slide_xval, x_sd=slide_xsd),
)

# layout: plot on the left, slider groups stacked on the right
widgets.HBox(
    children=[
        plotoutput,
        widgets.VBox(
            children=[
                widgets.VBox(children=[slide_sval, slide_ssd], layout=widgets.Layout(top='5%')),
                widgets.VBox(children=[slide_xval, slide_xsd], layout=widgets.Layout(top='13%')),
            ],
            layout=widgets.Layout(),
        ),
    ],
    layout=widgets.Layout(),
)

In [None]:
# @title Constructing the likelihood

# @markdown This one can be confusing. The measurement distribution $p(x\mid s)$ tells you how probable a certain measurement is given the stimulus presented...so here, $s$ is fixed to the actual presented stimulus and the input is any hypothetical measurement $x$, and $p(x\mid s)$ is the output! The likelihood function $\mathscr{L}(s;x) \equiv p(x\mid s)$ tells you how probable your measurement is under different stimuli...so here, $x$ is fixed to the actual measurement received and the input is any hypothetical stimulus $s$, and the output is $p(x\mid s)$. So the key difference is, the measurement distribution is a function of $x$, the likelihood is a function of $s$.

# @markdown Below, you can see how the observer constructs the likelihood function. The observer does not know the stimulus $s$, but for any *hypothetical* stimulus, the observer knows the measurement distribution $p(x\mid s)$. Likelihood function is computed as follows:

# @markdown Repeat for each hypothetical $s$:
# @markdown * assume this was the true $s$, and take the corresponding measurement distribution $p(x\mid s)$
# @markdown * check how probable your actual measurement $x$ is in this hypothetical situation
# @markdown * this value is the likelihood of this hypothetical stimulus $s$

# @markdown Proceed as follows below: choose a certain $x$ (the actual measurement received). Then, slide $s$ around (hypothetical stimuli that might have caused this measurement) and see how the likelihood function is constructed. You can also change the width of the measurement distribution, which will affect the width of the likelihood function. You can also change the heteroskedasticity of the measurement distribution, i.e. it is still Gaussian but the SD depends on the stimulus, which will change the shape of the likelihood function. In either case, slide around $s$ to see how the likelihood function is constructed from the measurement distribution.

def myplot(x_val, x_sd, s_hyp, x_sd_fac, show_lines=True):
    """Show how one point of the likelihood is read off the measurement distribution.

    Draws three stacked panels with a shared x-axis: a thin strip marking the
    hypothetical stimulus, the measurement distribution p(x|s_hyp), and the
    likelihood L(s;x_val). The measurement SD is x_sd + |s| * x_sd_fac.

    x_val      -- measurement at which the likelihood is evaluated
    x_sd       -- baseline SD of the measurement distribution
    s_hyp      -- hypothetical stimulus
    x_sd_fac   -- increase of the SD per unit |s| (0 gives a constant SD)
    show_lines -- if true, draw the markers and helper lines that connect
                  p(x_val|s_hyp) in the middle panel to L(s_hyp;x_val) below
    """

    xbase = np.linspace(xlims[0], xlims[1], 5000)

    # SD under the hypothetical stimulus, and the single density value
    # p(x_val|s_hyp) that both lower panels highlight
    sd_hyp = x_sd + np.absolute(s_hyp) * x_sd_fac
    lik_hyp = stats.norm.pdf(x_val, s_hyp, sd_hyp)

    dens = PlotData()
    dens.xs = stats.norm.pdf(xbase, s_hyp, sd_hyp)
    # likelihood: measurement fixed, every grid point is a candidate stimulus
    dens.l  = stats.norm.pdf(x_val, xbase, x_sd + np.absolute(xbase) * x_sd_fac)

    # make plot (no posterior panel here)
    layout = [[x] for x in PlotData().keys() if not x=="sx"]
    fig, axd = plt.subplot_mosaic(layout, sharex=True, height_ratios=[0.2, 1, 1],)

    for key, ax in axd.items():
        ax.set_xlim(np.min(xbase), np.max(xbase))
        ax.set_yticklabels([])
        ax.set_yticks([])

        if key=='s':
            # strip with only a marker for the hypothetical stimulus
            ax.set_ylim(-.05, .15)
            ax.scatter(s_hyp, 0, marker='v', s=45, c=colors[key])
            ax.annotate(
                '$s$', (s_hyp, 0), c=colors[key], zorder=2,
                textcoords="offset points", xytext=(1, 8.5), ha='center', va='center',
                bbox=dict(boxstyle='round', pad=0.001, facecolor='white', edgecolor=colors[key], ls='', alpha=0.9)
            )
            continue

        ax.set_ylabel(labels[key], rotation=0, horizontalalignment='right')
        ax.plot(xbase, dens[key], c=colors[key])

        if not show_lines:
            continue

        if key=='xs':
            ax.scatter(x_val, 0, marker='v', s=45, c=colors[key], zorder=2)
            # up from the measurement to the curve, then across to the left edge
            ax.plot((x_val, x_val), (0, lik_hyp), c=colors['xs'], ls='-', lw=0.8, zorder=1)
            ax.plot((xlims[0], x_val), (lik_hyp, lik_hyp), c=colors['l'], ls='-', lw=0.8, zorder=1)
            ax.annotate(
                '$x$', (x_val, 0), c=colors[key], zorder=2,
                textcoords="offset points", xytext=(1, 8.5), ha='center', va='center',
                bbox=dict(boxstyle='round', pad=0.001, facecolor='white', edgecolor=colors[key], ls='', alpha=0.9)
            )
        elif key=='l':
            # same height, now read at the hypothetical stimulus
            ax.plot((s_hyp, s_hyp), (0, lik_hyp), c=colors['s'], ls='-', lw=0.8, zorder=1)
            ax.plot((xlims[0], s_hyp), (lik_hyp, lik_hyp), c=colors['l'], ls='-', lw=0.8, zorder=1)

# sliders: hypothetical stimulus, measurement, baseline SD, heteroskedasticity
# NOTE(review): slide_xval and slide_xsd are also defined by the
# single-trial cell, whose update_xrange looks them up as globals at call
# time; after this cell runs it acts on these sliders - confirm intended
slide_shyp = widgets.FloatSlider(
    value=0, min=xlims[0], max=xlims[1], step=0.1,
    description='s', readout_format='0.1f',
)
slide_xval = widgets.FloatSlider(
    value=0, min=xlims[0], max=xlims[1], step=0.1,
    description='x|s', readout_format='0.1f',
)
slide_xsd = widgets.FloatSlider(
    value=0.2, min=0.1, max=1, step=0.1,
    description='sd(x|s)', readout_format='0.1f',
)
slide_xsdf = widgets.FloatSlider(
    value=0, min=0, max=1, step=0.1,
    description='Heterosk.', readout_format='0.1f',
)
# checkbox toggling the helper lines
tick_showlines = widgets.Checkbox(
    value=True, description='Helper Lines', disabled=False, indent=True,
)

# redraw the figure whenever a slider or the checkbox changes
plotoutput = widgets.interactive_output(
    myplot,
    dict(
        x_val=slide_xval,
        x_sd=slide_xsd,
        s_hyp=slide_shyp,
        x_sd_fac=slide_xsdf,
        show_lines=tick_showlines,
    ),
)

# layout: plot on the left, control groups stacked on the right
widgets.HBox(
    children=[
        plotoutput,
        widgets.VBox(
            children=[
                widgets.VBox(children=[slide_shyp], layout=widgets.Layout(top='5%')),
                widgets.VBox(children=[slide_xval, slide_xsd, slide_xsdf], layout=widgets.Layout(top='13%')),
                widgets.VBox(children=[tick_showlines], layout=widgets.Layout(top='28%')),
            ],
            layout=widgets.Layout(),
        ),
    ],
    layout=widgets.Layout(),
)