<a href="https://colab.research.google.com/github/motorlearner/neuromatch/blob/main/observers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
from scipy.stats import norm, vonmises, circmean
from scipy.optimize import minimize_scalar
from arviz import kde
from matplotlib import pyplot as plt
import ipywidgets as widget
import time

In [2]:
# @title `utils: center/uncenter angles`

demos = False

def center_angle(x:float|np.ndarray, ref:float|np.ndarray, period:float):
  """
  Map angles to [`-period/2`, `+period/2`) such that 0 corresponds to `ref`,
  positive angles mean you move counter-clockwise and negative mean clockwise.

  Args:
    `x      :` angles (scalar or array)
    `ref    :` reference angles (scalar or array)
    `period :` length of one period (scalar)

  Returns:
    Float or array of centered angles in [`-period/2`, `+period/2`).
  """
  y = np.atleast_1d(((x - ref + period/2) % period) - period/2)
  return y[0] if np.isscalar(x) and np.isscalar(ref) else y

def uncenter_angle(x:float|np.ndarray, ref:float|np.ndarray, period:float):
  """
  Convert centered angles in [`-period/2`, `+period/2`) back to absolute angles
  in [0, period), using the original reference.

  Args:
    `x      :` centered angles (scalar or array)
    `ref    :` reference angle(s) used during centering
    `period :` length of one period (scalar)

  Returns:
    Float or array of absolute angles in [0, period).
  """
  y = np.atleast_1d((x + ref) % period)
  y[np.isclose(y, period)] = 0.0 # floating point errors
  return y[0] if np.isscalar(x) and np.isscalar(ref) else y

if demos:

  def demo(degref:float=0):
    radref = np.deg2rad(degref)

    xdeg = np.linspace(0, 359, 360)
    cdeg = center_angle(xdeg, degref, 360)
    xrad = xdeg * np.pi/180
    crad = center_angle(xrad, radref, 2*np.pi)

    fig, axes = plt.subplots(2, 2, figsize=(5,5))

    for ax in [axes[0,0]]:
      ax.axvline(degref, c='red', lw=1)
      ax.axhline(0, c='red', lw=1)
      ax.scatter(xdeg, cdeg, s=1, c='black')
      ax.set_ylabel('Centered')
      ax.set_title('Degrees')
      ax.set_xticks(np.linspace(0, 360, 5))
      ax.set_yticks(np.linspace(-180, 180, 5))

    for ax in [axes[0,1]]:
      ax.axvline(radref, c='red', lw=1)
      ax.axhline(0, c='red', lw=1)
      ax.scatter(xrad, crad, s=1, c='black')
      ax.set_title('Radians')
      ax.set_xticks(np.linspace(0, 2*np.pi, 5))
      ax.set_yticks(np.linspace(-np.pi, np.pi, 5))
      ax.set_xticklabels([r'$0$', r'$\frac{\pi}{2}$', r'$\pi$', r'$\frac{3\pi}{2}$', r'$2\pi$'])
      ax.set_yticklabels([r'$-\pi$', r'$-\frac{\pi}{2}$', r'$0$', r'$\frac{\pi}{2}$', r'$\pi$'])

    for ax in [axes[1,0]]:
      ax.axline((0,0), slope=1, c='red', lw=1)
      ax.scatter(xdeg, uncenter_angle(cdeg, degref, 360), s=1, c='black')
      ax.set_xlabel('Uncentered')
      ax.set_ylabel('Center then Uncenter')
      ax.set_xticks(np.linspace(0, 360, 5))
      ax.set_yticks(np.linspace(0, 360, 5))

    for ax in [axes[1,1]]:
      ax.axline((0,0), slope=1, c='red', lw=1)
      ax.scatter(xrad, uncenter_angle(crad, radref, 2*np.pi), s=1, c='black')
      ax.set_xlabel('Uncentered')
      ax.set_xticks(np.linspace(0, 2*np.pi, 5))
      ax.set_yticks(np.linspace(0, 2*np.pi, 5))
      ax.set_xticklabels([r'$0$', r'$\frac{\pi}{2}$', r'$\pi$', r'$\frac{3\pi}{2}$', r'$2\pi$'])
      ax.set_yticklabels([r'$0$', r'$\frac{\pi}{2}$', r'$\pi$', r'$\frac{3\pi}{2}$', r'$2\pi$'])

    fig.tight_layout()

  widget.interact(demo, degref=(0, 360, 1))

In [3]:
# @title `utils: normalize densities`

demos = False

def density_rad2deg(y:np.ndarray):
  """
  Convert density over radian support to density over degree support.

  Args:
      `y :` array of density values, which intagrates to 1 over support in rad

  Returns:
    Array of density values, which integrates to 1 over support in deg.
  """
  return y * np.pi/180

def density_normalize(y:np.ndarray, axis:int|None=None):
  """
  Normalize an array of densities so that the total sum is 1.

  Args:
      `density :` array of non-negative density values
      `axis    :` axis to normalize across

  Returns:
      Normalized density array where all values sum to 1.
  """
  return y / np.sum(y, axis=axis, keepdims=True)

if demos:
  r   = vonmises(loc=0*np.pi, kappa=20).rvs(5000)
  x,y = kde(r, circular=True)
  plt.hist(np.rad2deg(r), bins=100, density=True, label='plt.hist')
  plt.plot(np.rad2deg(x), density_rad2deg(y), label='av.kde')
  plt.xlabel('Angle [Deg]')
  plt.ylabel('Density')
  plt.legend()

In [4]:
# @title `utils: von mises concentration`

demos = True

def kappa2sigma(kappa:float|np.ndarray, mu:float=0):
  """
  Convert von Mises concentration parameter `kappa` to circular standard
  deviation (degrees).

  Args:
      `kappa :` concentration parameter of von Mises distribution  (non-negative float/array).
      `mu    :` mean of von Mises distribution in degrees (non-negative float/array)

  Returns:
      Circular standard deviation in degrees (float/array).
  """
  kappa_array = np.atleast_1d(kappa)
  x = np.linspace(1, 360, 3600)

  def compute_sigma(k):
    d = center_angle(x, mu, 360)
    p = density_normalize(vonmises(loc=np.deg2rad(mu), kappa=k).pdf(np.deg2rad(x)))
    return np.sqrt(np.sum(p * d**2))

  sigmas = np.array([compute_sigma(k) for k in kappa_array])
  return sigmas if sigmas.size > 1 else sigmas[0]

def sigma2kappa(sigma:float|np.ndarray):
  """
  Numerically invert circular standard deviation in degrees to
  von Mises concentration parameter `kappa`, rounded to nearest 0.1.
  Note this may fail with small values, e.g. sigma < 3.

  Args:
    `sigma :` Circular standard deviation in degrees (float/array).

  Returns:
    Estimated von Mises concentration parameter (float/array).
  """
  sigma_arr = np.atleast_1d(sigma)

  if sigma_arr.ndim > 1:
    raise ValueError(f"sigma2kappa(sigma): sigma must be shape (n,) but is shape {sigma_arr.shape}")

  def invert_elem(sigma_val):
    def objective(log_kappa):  # log(kappa) search
      k = np.exp(log_kappa)
      s = kappa2sigma(k)
      return (s - sigma_val) ** 2
    res = minimize_scalar(
      objective,
      bounds=(np.log(1e-6), np.log(1e3)),
      method='bounded',
      options={'xatol': 1e-6}
    )
    if not res.success:
        raise RuntimeError(f"Failed to invert sigma={sigma_val}")
    return np.round(np.exp(res.x), 6)

  kappas = np.array([invert_elem(s) for s in sigma_arr])
  # return kappas if kappas.size > 1 else kappas[0]
  return kappas


if demos:
  kappas = (33.3, 8.7, 2.8, 0.7)
  print(f'kappas: {kappas}')
  print(f'to sigmas: {kappa2sigma(kappas)}')
  print(f'to kappas: {sigma2kappa(kappa2sigma(kappas))}')
  sigmas = (10, 20, 40, 80)
  print(f'sigmas: {sigmas}')
  print(f'to kappas: {sigma2kappa(sigmas)}')
  print(f'to sigmas: {kappa2sigma(sigma2kappa(sigmas))}')

kappas: (33.3, 8.7, 2.8, 0.7)
to sigmas: [10.19009342 20.24497727 39.91304508 81.67631574]
to kappas: [33.300004  8.7       2.8       0.7     ]
sigmas: (10, 20, 40, 80)
to kappas: [34.582285  8.901683  2.791226  0.754532]
to sigmas: [ 9.99999969 20.0000001  40.00000173 79.99999995]


In [5]:
# @title `utils: circular statistics`

demos = False

def cmean(x:np.ndarray, w:np.ndarray|None=None):
  """
  Compute the weighted circular mean of angles (in radians) using complex numbers.

  Args:
    x  : angles in degrees
    w  : weights for each angle

  Returns:
    Weighted circular mean angle in radians (wrapped to [0, 360)).
  """
  xrad = np.deg2rad(np.atleast_1d(x))

  if w is None:
    w = np.ones_like(xrad)
  else:
    w = np.atleast_1d(w)

  z = w * np.exp(1j * xrad)
  m = np.sum(z) / np.sum(w)

  mrad = np.angle(m)
  mdeg = np.rad2deg(uncenter_angle(mrad, 0, 2*np.pi))
  return mdeg

def cmode(x:np.ndarray, y:np.ndarray|None=None):
  """
  Compute the circular mode from binned angles in degrees.

  Args:
    x : angles in degrees
    y : array of frequencies

  Returns:
    Mode angle in degrees, in [0, 360)
  """
  if y is None:
      y = np.ones_like(x)
  else:
      y = np.atleast_1d(y)

  return x[np.argmax(y)]

In [6]:
# @title `class: Prior`
class Prior:

  def __init__(self, weights, means, sigmas):
    # argproc
    weights = np.atleast_1d(weights)
    means   = np.atleast_1d(means)
    sigmas  = np.atleast_1d(sigmas)
    # argcheck
    if not len({ arr.shape for arr in (weights, means, sigmas) }) == 1:
      raise ValueError('Expected all inputs to be same shape.')
    if not np.sum(weights)==1:
      raise ValueError('Expected weights to sum to 1.')
    # initialize
    self.weights = weights
    self.means   = means
    self.sigmas  = sigmas

  def plot(self, theta:np.ndarray=np.linspace(0,359,360)):
    y = self.pdf(theta)
    plt.plot(theta, y, ls='-', lw=2, c='black', label='$p(\\theta)$')
    plt.xlim(0,360)
    plt.xlabel('$\\theta$')
    plt.ylabel('$p(\\theta)$')
    plt.legend()
    plt.show()

  def pdf(self, theta:np.ndarray):
    # evaluate pdf
    theta   = np.deg2rad(np.atleast_1d(theta))
    kappas  = self.get_kappas(self.sigmas)
    priors  = self.get_priors(self.means, kappas)
    density = np.array([ w * prior.pdf(theta) for w,prior in zip(self.weights, priors) ]).sum(axis=0)
    return density_rad2deg(density)

  def get_kappas(self, sigmas:np.ndarray):
    kappas = np.atleast_1d(sigma2kappa(sigmas))
    return kappas

  def get_priors(self, means:np.ndarray, kappas:np.ndarray):
    priors = np.array([ vonmises(loc=np.deg2rad(m), kappa=k) for m,k in zip(means,kappas) ])
    return priors

In [91]:
# @title `class: Mdist`
class Mdist:

  def __init__(self, weights, sigma_funcs, mean_funcs):
    """
    Initialize a measurement distribution, which is a mixture of von Mises
    distributions with signal-dependent concentration.

    Args:
      `weights     :` tuple of mixture weights (must sum to 1)
      `sigma_funcs :` tuple of functions f: theta[deg] -> sd
      `mean_funcs  :` tuple of functions f: theta[deg] -> mu (array)
    """
    # argproc
    weights     = np.atleast_1d(weights).astype(float)
    sigma_funcs = np.atleast_1d(sigma_funcs).astype(object)
    mean_funcs  = np.atleast_1d(mean_funcs).astype(object)
    # argchecks
    if not np.all(weights.shape == sigma_funcs.shape) and np.all(weights.shape==mean_funcs.shape):
      raise ValueError('Expected weights and sigma_funcs to be same shape.')
    # initialize
    self.weights     = weights / np.sum(weights)
    self.sigma_funcs = sigma_funcs
    self.mean_funcs  = mean_funcs

  def plot(self, theta:float, n:int=1000):
    # visualize
    theta_til=np.linspace(0,359,360)
    r = self.rvs(theta, n)
    y = self.pdf(theta, theta_til)
    plt.hist(r, density=True, color='orange', edgecolor='white', bins=np.linspace(0,360,90), label='samples')
    plt.plot(theta_til, y, ls='-', lw=2, c='navy', label='$p(\\tilde{\\theta}\\mid\\theta)$')
    plt.axvline(theta, ls='--', lw=1, c='navy', label=f'$\\theta$')
    plt.xlabel('$\\tilde{\\theta}$ [deg]')
    plt.ylabel('$p(\\tilde{\\theta}\\mid\\theta)$')
    plt.xlim(0,360)
    plt.legend()
    plt.show()

  def interact(self):
    widget.interact(
      self.plot,
      theta=widget.FloatSlider(min=0, max=359, step=1, value=180, readout_format='.0f', continuous_update=False),
      n=widget.fixed(1000)
    )

  def rvs(self, theta:float, n:int=10000):
    # draw random samples
    kappas  = self.get_kappas(theta)
    means   = self.get_means(theta)
    mdists  = self.get_mdists(theta, means, kappas)
    rng     = np.random.default_rng()
    choices = rng.choice(len(mdists), size=n, p=self.weights)
    samples = np.empty(n)
    for c in np.arange(len(mdists)):
      linds = (choices == c)
      samples[linds] = mdists[c].rvs(np.sum(linds))
    return np.rad2deg(uncenter_angle(samples, 0, 2*np.pi))

  def pdf(self, theta:float, theta_til:np.ndarray):
    # evaluate the pdf for given theta at theta_til
    pdf = self.get_pdffunc(theta)
    return pdf(theta_til)

  def get_kappas(self, theta:float):
    sigmas = [ f(theta) for f in self.sigma_funcs ]
    kappas = [ sigma2kappa(s) for s in sigmas ]
    # kappas = [ np.atleast_1d(sigma2kappa(s)) for s in sigmas ]
    return kappas

  def get_means(self, theta:float):
    means = [ f(theta) for f in self.mean_funcs ]
    return means

  def get_mdists(self, theta:float, means: np.ndarray, kappas:np.ndarray):
    mdists = np.array([ vonmises(loc=np.deg2rad(m), kappa=k) for m,k in zip(means, kappas) ])
    return mdists

  def get_pdffunc(self, theta:float):
    kappas = self.get_kappas(theta)
    means  = self.get_means(theta)
    mdists = self.get_mdists(theta, means, kappas)
    def pdf_func(theta_til:np.ndarray):
      theta_til = np.deg2rad(np.atleast_1d(theta_til))
      densities = np.array([ w * mdist.pdf(theta_til) for w,mdist in zip(self.weights, mdists) ])
      density   = densities.sum(axis=0)
      return density_rad2deg(density)
    return pdf_func

In [8]:
# @title `class: BayesObserver`
class BayesObserver:

  def __init__(self, prior, mdist, n_support=360):
    """
    Initialize a Bayesian Observer.

    Args:
      prior  :  object of class Prior
      midst  :  object of class Mdist
    """
    # initialize...
    self.prior   = prior
    self.mdist   = mdist
    # ...support
    self.xdeg    = np.linspace(0,359,n_support)
    self.xrad    = np.deg2rad(self.xdeg)
    self.x       = self.xdeg
    # ...random draws from mdist
    self.r_mdist = None
    # ...densities for prior, mdist, like, posterior, rdist
    self.y_prior = self.get_y_prior()

  def respondslow(self, theta:float, n:int=1000, f=cmean):
    self.r_mdist = self.get_r_mdist(theta, n)
    self.y_mdist = self.get_y_mdist(theta)
    self.Y_like  = self.get_Y_like(self.r_mdist)
    self.Y_post  = self.get_Y_post(self.y_prior, self.Y_like)
    self.X_est   = self.get_X_est(self.Y_post, f)
    self.y_est   = self.get_y_est(self.X_est)

  def respond(self, theta:float, n:int=1000, f=cmean):
    self.r_mdist = self.get_r_mdist(theta, n)
    self.y_mdist = self.get_y_mdist(theta)
    self.Y_like  = self.get_Y_like_fast(self.r_mdist)
    self.Y_post  = self.get_Y_post(self.y_prior, self.Y_like)
    self.X_est   = self.get_X_est(self.Y_post, f)
    self.y_est   = self.get_y_est(self.X_est)

  def get_y_prior(self):
    # get density of prior distribution
    y = self.prior.pdf(self.xdeg)
    return density_normalize(y)

  def get_r_mdist(self, theta:float, n:int):
    # get n random samples from measurement distribution
    r = self.mdist.rvs(theta, n)
    return r

  def get_ry_mdist(self, r_mdist:np.ndarray):
    # get density of random samples from measurement distribution
    x_, y_ = kde(np.deg2rad(r_mdist), circular=True)
    y      = np.interp(self.xrad, x_, y_, period=2*np.pi)
    return density_normalize(y)

  def get_y_mdist(self, theta:float):
    # get density of measurement distribution
    y = self.mdist.pdf(theta, self.xdeg)
    return density_normalize(y)

  def get_y_like(self, theta:np.ndarray, theta_til:float):
    # get normalized likelihood
    y = np.squeeze((self.mdist.pdf(theta, theta_til)))
    return density_normalize(y)

  def get_Y_like(self, r_mdist:np.ndarray):
    # get normalized likelihood for each sample (each row a sample)
    Y = np.array([ self.get_y_like(self.xdeg, tt) for tt in r_mdist ])
    return Y

  def get_Y_like_fast(self, r_mdist:np.ndarray):
    # faster version: for each theta, compute likelihood across theta_til
    # the mdist pdf depends on theta, so make pdf once, then compute all
    pdfs = [ self.mdist.get_pdffunc(theta) for theta in self.xdeg ]
    Y = np.array([ f(r_mdist) for f in pdfs ]).T
    return density_normalize(Y, axis=1)

  def get_y_post(self, y_prior:np.ndarray, y_like:np.ndarray):
    y = y_prior * y_like
    return density_normalize(y)

  def get_Y_post(self, y_prior:np.ndarray, Y_like:np.ndarray):
    Y = np.array([ self.get_y_post(y_prior, y_like) for y_like in Y_like ])
    return Y

  def get_x_est(self, y_post:np.ndarray, func=cmean):
    x = func(self.xdeg, y_post)
    return x

  def get_X_est(self, Y_post:np.ndarray, func=cmean):
    X = np.array([ self.get_x_est(y_post, func) for y_post in Y_post ])
    return np.squeeze(X)

  def get_y_est(self, X_est:np.ndarray):
    x_, y_ = kde(np.deg2rad(X_est), circular=True)
    y      = np.interp(self.xrad, x_, y_, period=2*np.pi)
    return density_normalize(y)

In [9]:
# @title `funcs: mean_*`
def mean_identity():
  def f(theta):
    return theta
  return f

def mean_opposite():
  def f(theta):
    return (theta + 180) % 360
  return f

In [75]:
# @title `funcs: sigma_*`
def sigma_static(value:float):
  def f(theta):
    return np.full_like(theta, value)
  return f

def sigma_cardinal_sinabs(smin:float, smax:float, p=None):
  assert 0 < smin < smax, "need 0 > smin > smax"
  def f(theta):
    return smin + (smax-smin) * np.abs(np.sin(2*np.deg2rad(theta)))
  return f

def sigma_cardinal_sinpow(smin:float, smax:float, p:int=2):
  assert 0 < smin < smax, "need 0 > smin > smax"
  assert p > 0 and p % 2 == 0, "p must be positive even integer"
  def f(theta):
    return smin + (smax-smin) * np.sin(2*np.deg2rad(theta))**p
  return f

def sigma_cardinal_cospow(smin:float, smax:float, p:int=2):
  assert 0 < smin < smax, "need 0 > smin > smax"
  assert p > 0 and p % 2 == 0, "p must be positive even integer"
  def f(theta):
    return smax + (smin-smax) * np.cos(2*np.deg2rad(theta))**p
  return f

def sigma_vonmises(smin:float, smax:float, vmloc:float, vmsigma:float):
  assert 0 < smin < smax, "need 0 > smin > smax"
  vmkappa = sigma2kappa(vmsigma)
  vm = vonmises(loc=np.deg2rad(vmloc), kappa=vmkappa)
  x  = np.linspace(0, 2*np.pi, 100)
  y  = (tmp:=vm.pdf(x)) / np.max(tmp)
  def f(theta):
    theta = np.atleast_1d(np.deg2rad(theta))
    inds  = np.argmin([ np.abs(t-x) for t in theta ])
    return smax - (smax-smin) * y[inds]
  return f

In [11]:
# @title `showresponses`
def showresponses(obs:BayesObserver):
  xrel = center_angle(obs.x, 225, 360)
  inds = np.argsort(xrel)
  sort = lambda x: x[inds]

  def makeplot(obs, theta_rel, n, whichfunc):
    func = cmode if whichfunc=='mode' else cmean

    theta = uncenter_angle(theta_rel, 225, 360)
    obs.respond(theta, n, func)
    fig, axs = plt.subplots(2, 2, figsize=(8,4))

    x       = [sort(xrel), obs.x]
    x_stim  = [theta_rel, theta]
    y_mdist = [sort(obs.y_mdist), obs.y_mdist]
    y_prior = [sort(obs.y_prior), obs.y_prior]
    y_est   = [sort(obs.y_est), obs.y_est]

    for (i,j), ax in np.ndenumerate(axs):
      # settings left
      if j==0:
        ax.set_xlim(-180,180)
        ax.set_xticks(np.linspace(-180,180,5))
        ax.set_xlabel('Relative Angle [deg]' if i==1 else '')
      # settings right
      if j==1:
        ax.set_xlim(0,360)
        ax.set_xticks(np.linspace(0,360,5))
        ax.set_yticks([])
        ax.set_xlabel('Angle [deg]' if i==1 else '')
      # data top
      if i==0:
        ax.axvline(x_stim[j], ls='--', lw=1, c='black', label=f'Stim({x_stim[j]:.0f})')
        ax.plot(x[j], y_mdist[j], ls='-', lw=1, c='red', label='Mdist')
      if i==1:
        ax.axvline(x_stim[j], ls='--', lw=1, c='black', label=f'Stim({x_stim[j]:.0f})')
        ax.plot(x[j], y_prior[j], ls='-', lw=1, c='black', label='Prior')
        ax.plot(x[j], y_est[j], ls='-', lw=1, c='magenta', label='Responses')
      # all
      ax.legend()

  widget.interact(
    makeplot,
    obs=widget.fixed(bob),
    whichfunc=widget.Dropdown(options=['mode', 'mean'], value='mode', description='Function'),
    theta_rel=widget.FloatSlider(min=-180, max=179, step=5, value=-170, readout_format='.0f', continuous_update=False),
    n=widget.fixed(10000)
  )

In [12]:
# @title `showinference`
def showinference(obs:BayesObserver):
  x = obs.x
  y_prior = obs.get_y_prior()
  Y_sigma = np.array([ f(x) for f in obs.mdist.sigma_funcs ]).T

  def makeplot(obs, theta, theta_til, whichfunc):
    func = cmode if whichfunc=='mode' else cmean
    y_mdist = obs.get_y_mdist(theta)
    y_like  = obs.get_y_like(obs.xdeg, theta_til)
    y_post  = obs.get_y_post(y_prior, y_like)
    x_est   = obs.get_x_est(y_post, func)

    fig, axs = plt.subplots(
      nrows=3, ncols=1, sharex=True,
      figsize=(6, 6), gridspec_kw={'height_ratios': [0.5, 0.5, 1]}
    )

    for i,ax in enumerate(axs):
      ax.set_xlim(0,360)
      ax.set_xticks(np.linspace(0,360,5))
      ax.set_xlabel('Angle [deg]')
      ax.axvline(theta, ls='--', lw=1, c='black', label=f'Stim')
      i > 0 and ax.axvline(theta_til, ls='--', lw=1, c='red', label='Meas')

    for ax in [axs[0]]:
      ax.plot(x, Y_sigma, c='gray', ls='-', lw=1, label=r'$\sigma$')
      ax.set_ylabel(r'$\sigma$')

    for ax in [axs[1]]:
      ax.plot(x, y_mdist, ls='-', lw=1, c='red', label='Mdist')
      ax.legend()

    for ax in [axs[2]]:
      ax.axvline(x_est, ls='--', lw=1, c='dodgerblue', label='Estim')
      ax.plot(x, y_prior, ls='-', lw=1, c='black', label='Prior')
      ax.plot(x, y_like, ls='-', lw=1, c='red', label='Like')
      ax.plot(x, y_post, ls='-', lw=1, c='dodgerblue', label='Post')
      ax.legend()

  widget.interact(
    makeplot,
    obs=widget.fixed(bob),
    theta=widget.FloatSlider(min=0, max=359, step=1, value=225-170, readout_format='.0f', continuous_update=False),
    theta_til=widget.FloatSlider(min=0, max=359, step=1, value=225-160, readout_format='.0f', continuous_update=False),
    whichfunc=widget.Dropdown(options=['mode', 'mean'], value='mean', description='Function')
  )

In [158]:
pd_stimulus = Prior(weights=1, means=225, sigmas=85)
pd_cardinal = Prior(weights=(.25,.25,.25,.25), means=(0,90,180,270), sigmas=(10,10,10,10))

md_simple    = Mdist(1, sigma_static(20), mean_identity())
md_cardinal  = Mdist(1, sigma_cardinal_cospow(10,40,10), mean_identity())
md_waterfall = Mdist((.5,.5), (sigma_static(30), sigma_static(35)), (mean_identity(), mean_opposite()))
md_prior     = Mdist(1, sigma_vonmises(10,20,225,80), mean_identity())
md_waterfall_prior = Mdist((.5,.5), (sigma_vonmises(10,30,225,90), sigma_vonmises(10,30,225,90)), (mean_identity(), mean_opposite()))


# bob = BayesObserver(pd_stimulus, md_simple, 100)
# bob = BayesObserver(pd_stimulus, md_cardinal, 100)
# bob = BayesObserver(pd_cardinal, md_simple, 100)
# bob = BayesObserver(pd_cardinal, md_cardinal, 100)
# bob = BayesObserver(pd_stimulus, md_waterfall, 100)
# bob = BayesObserver(pd_stimulus, md_prior, 100)
bob = BayesObserver(pd_stimulus, md, 100)

# showinference(bob)
showresponses(bob)

interactive(children=(FloatSlider(value=-170.0, continuous_update=False, description='theta_rel', max=179.0, m…

In [157]:
md = Mdist((1,1), (sigma_vonmises(1,10,225,10), sigma_vonmises(32,100,225,10)), (mean_identity(), mean_opposite()))
md.interact()

interactive(children=(FloatSlider(value=180.0, continuous_update=False, description='theta', max=359.0, readou…