# Causal inference model for auditory-visual duration discrimination with conflict

# Notation
- $S_a$: true auditory duration
- $S_v$: true visual duration
- $m_a$: noisy auditory measurement
- $m_v$: noisy visual measurement
- $\sigma_a$: auditory noise (standard deviation)
- $\sigma_v$: visual noise (standard deviation)
- $C$: causal structure ($C=1$: common cause; $C=2$: independent causes)
- $\mu_p$: prior mean (assumed to be 0 for simplicity)
- $\sigma_p$: prior standard deviation (assumed to be infinite for simplicity, i.e., a flat prior)
- $\sigma_{av,a}$: effective auditory noise in the AV condition
- $\sigma_{av,v}$: effective visual noise in the AV condition
- $\hat{S}_{av,a}$: estimated auditory duration in the AV condition
- $\hat{S}_{av,v}$: estimated visual duration in the AV condition
- $\hat{S}_{CI,t}$: final internal estimate for the test interval
- $\hat{S}_{CI,s}$: final internal estimate for the standard interval



# 1 - Reliability-based duration estimation assuming full fusion


$\hat{S}_{av,a}=\hat{S}_{av,v}= \frac{\sigma_{av,a}^{-2} m_a+\sigma_{av,v}^{-2} m_v}{\sigma_{av,a}^{-2} + \sigma_{av,v}^{-2}}$ 

In [1]:
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interact, FloatSlider, IntSlider
#import normal
from scipy.stats import norm
# unimodal measurements
def unimodalMeasurements(sigma, S, n_samples=1000):
    """Simulate noisy unimodal measurements of a stimulus.

    Each measurement is drawn independently from N(S, sigma**2), i.e. from
    the measurement distribution p(m | S).

    Parameters
    ----------
    sigma : float
        Measurement noise (standard deviation).
    S : float
        True stimulus value (mean of the measurement distribution).
    n_samples : int, optional
        Number of simulated measurements (trials). Defaults to 1000.

    Returns
    -------
    numpy.ndarray
        Array of shape (n_samples,) holding the simulated measurements.
    """
    return np.random.normal(S, sigma, n_samples)

# probability density function of a Gaussian distribution
def gaussianPDF(x, S, sigma):
    """Return the normal density N(S, sigma**2) evaluated at x (scalar or array)."""
    norm_const = 1/(np.sqrt(2*np.pi)*sigma)
    exponent = -((x-S)**2)/(2*(sigma**2))
    return norm_const*np.exp(exponent)

# likelihood function
def likelihood(S, sigma):
    """Evaluate the measurement density p(m | S) on a grid of measurements.

    The grid has 500 points spanning S +/- 4*sigma. Because the Gaussian
    model is symmetric in (m - S), this curve has the same shape as the
    likelihood of S.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        The measurement grid and the density N(S, sigma**2) at each point.
    """
    half_width = 4*sigma
    grid = np.linspace(S - half_width, S + half_width, 500)
    return grid, gaussianPDF(grid, S, sigma)

def plotLikelihood(S, sigma):
    """Draw the analytic density p(m | S) on the current axes, with labels and legend."""
    grid, density = likelihood(S, sigma)
    plt.plot(grid, density, label='Likelihood Function')
    plt.xlabel('Measurement $m$')
    plt.ylabel('Probability Density')
    plt.title('Analytical Likelihood $P(m|s)$')
    plt.legend()

def plotMeasurements(sigma, S):
    """Histogram simulated measurements of S (density-normalised) on the current axes."""
    samples = unimodalMeasurements(sigma, S)
    plt.hist(samples, bins=50, density=True, alpha=0.5,
             label='Measurements Histogram')
    plt.xlabel('Measurement $m$')
    plt.ylabel('Density')
    plt.title('Unimodal Measurements Histogram')
    plt.legend()


def plotMeasurementsAndLikelihood(sigma, S):
    """Overlay simulated measurements and the analytic density for one cue.

    Parameters
    ----------
    sigma : float
        Measurement noise (standard deviation).
    S : float
        True stimulus value the measurements are centred on.
    """
    plt.figure(figsize=(10, 6))
    plotMeasurements(sigma, S)
    plotLikelihood(S, sigma)
    # A fixed (-1, 1) window hides the curve whenever S (slider range
    # -1.5..1.5) or a large sigma puts mass outside it. Keep (-1, 1) as the
    # minimum view, but widen it to cover S +/- 4*sigma.
    plt.xlim(min(-1.0, S - 4 * sigma), max(1.0, S + 4 * sigma))

    plt.show()

# interactive plotting
import ipywidgets as widgets
# One slider per argument of plotMeasurementsAndLikelihood.
interact(
    plotMeasurementsAndLikelihood,
    sigma=widgets.FloatSlider(
        value=0.2, min=0.1, max=5.0, step=0.1, description='Sigma'
    ),
    S=widgets.FloatSlider(
        value=0.5, min=-1.5, max=1.5, step=0.1, description='True Duration S'
    ),
)


interactive(children=(FloatSlider(value=0.2, description='Sigma', max=5.0, min=0.1), FloatSlider(value=0.5, de…

<function __main__.plotMeasurementsAndLikelihood(sigma, S)>

### 2.1 Fusion (C=1)

### **2.1.1 Fusion of one interval**

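$$\hat{S}_{av,a}=\hat{S}_{av,v}= \frac{\sigma_{av,a}^{-2} m_a+\sigma_{av,v}^{-2} m_v}{\sigma_{av,a}^{-2} + \sigma_{av,v}^{-2}}
= w_a m_a+w_v m_v, \qquad w_a=\frac{J_a}{J_a+J_v},\; w_v=\frac{J_v}{J_a+J_v}$$
$$J_a=\frac{1}{\sigma_{av,a}^{2}}, \qquad
J_v=\frac{1}{\sigma_{av,v}^{2}}, \qquad
\sigma_{av}^2=\frac{1}{J_a+J_v}$$

Across trials, $\hat S_{av}$ therefore has mean $w_aS_a+w_vS_v$ and variance $\sigma_{av}^2$. With a flat prior ($\sigma_p\to\infty$), the posterior over the common duration $S$ is

$$p(S|m_a,m_v)\propto p(S)\,p(m_a|S)\,p(m_v|S)\\
p(S|m_a,m_v)= \mathcal{N}(S;\hat S_{av},\sigma_{av}^2)
$$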


In [10]:
# compute fused estimate (reliability weighted avg)
def fusionAV(sigmaAV_A, sigmaAV_V, S_a, visualConflict):
    """Simulate reliability-weighted audio-visual fusion for one interval.

    Auditory measurements are drawn around S_a and visual measurements
    around S_v = S_a + visualConflict. Each auditory/visual pair is then
    combined by inverse-variance (precision) weighting.

    Returns
    -------
    (numpy.ndarray, float)
        The fused estimate for every simulated measurement pair, and the
        analytic standard deviation of the fused estimate,
        sqrt(1 / (J_A + J_V)).
    """
    # Auditory draws come first, then visual, so the random stream is
    # consumed in a fixed order.
    audio_meas = unimodalMeasurements(sigmaAV_A, S_a)
    visual_meas = unimodalMeasurements(sigmaAV_V, S_a + visualConflict)

    prec_A = sigmaAV_A**-2  # auditory precision (inverse variance)
    prec_V = sigmaAV_V**-2  # visual precision (inverse variance)

    total_prec = prec_V + prec_A
    fused = (prec_A*audio_meas + prec_V*visual_meas)/total_prec
    fused_sd = np.sqrt(1 / (prec_A + prec_V))
    return fused, fused_sd


In [None]:
# create interactive plot 
from ipywidgets import interact, FloatSlider, IntSlider
#import normal
from scipy.stats import norm

def generativeModelPlot(sigmaAV_A, sigmaAV_V, S_a, visualConflict, showHistograms=True):
    """Plot the single-interval audio-visual fusion model (C=1).

    Draws the following:
      - the analytic measurement densities of the auditory and visual cues;
      - the analytic density of the reliability-weighted fused estimate;
      - optionally, histograms of simulated measurements and estimates.

    Parameters
    ----------
    sigmaAV_A, sigmaAV_V : float
        Auditory / visual measurement noise (standard deviations) in the AV
        condition.
    S_a : float
        True auditory duration.
    visualConflict : float
        Offset added to S_a to obtain the true visual duration S_v.
    showHistograms : bool
        Overlay histograms of simulated samples when True.
    """
    S_v = S_a + visualConflict

    plt.figure(figsize=(10, 6))

    # Auditory measurement density p(m_a | S_a) and true auditory stimulus.
    colorA = 'teal'
    x_a, p_a = likelihood(S_a, sigmaAV_A)
    plt.plot(x_a, p_a, color=colorA, label='Auditory PDF')
    plt.axvline(S_a, color=colorA, label="Auditory Stimulus", linestyle='--')

    # Visual measurement density p(m_v | S_v) and true visual stimulus.
    colorV = 'forestgreen'
    x_v, p_v = likelihood(S_v, sigmaAV_V)
    plt.plot(x_v, p_v, color=colorV, label='Visual PDF')
    plt.axvline(S_v, color=colorV, label="Visual Stimulus", linestyle='--')

    # Simulated fused estimates (one per simulated trial) and their analytic SD.
    hat_S_AV, sigma_S_AV_hat = fusionAV(sigmaAV_A, sigmaAV_V, S_a, visualConflict)
    # Raw string: '\h' is not a valid escape sequence in a normal literal.
    plt.axvline(np.mean(hat_S_AV), color='orange', label=r'$\hat S_{av}$Fused Estimate', linestyle='--')

    # Analytic fused density: mean w_a*S_a + w_v*S_v, SD sigma_S_AV_hat.
    J_a = sigmaAV_A**-2  # auditory precision
    J_v = sigmaAV_V**-2  # visual precision
    w_a = J_a / (J_a + J_v)  # weight for auditory
    w_v = 1 - w_a  # weight for visual
    mu_Shat = w_a * S_a + w_v * S_v  # fused mean

    x_av = np.linspace(mu_Shat - 4 * sigma_S_AV_hat, mu_Shat + 4 * sigma_S_AV_hat, 500)
    p_S_AV = gaussianPDF(x_av, mu_Shat, sigma_S_AV_hat)
    plt.plot(x_av, p_S_AV, color='orange', label='$S$ Fused PDF')
    plt.axvline(mu_Shat, color='orange', linestyle='--')

    if showHistograms:
        plt.hist(unimodalMeasurements(sigmaAV_A, S_a), bins=30, density=True, alpha=0.5, color='teal')
        plt.hist(unimodalMeasurements(sigmaAV_V, S_v), bins=30, density=True, alpha=0.5, color='forestgreen')
        plt.hist(hat_S_AV, bins=30, density=True, alpha=0.5, color='orange')

    plt.title('Generative Model for AV Fusion')
    plt.xlabel('Internal Measurement / Estimate')
    plt.ylabel('Probability Density')
    plt.legend(loc='best')
    # Keep the default [-1.5, 1.5] view, but widen it so that both densities
    # (+/- 4 SD) stay visible. With the sliders, S_v can reach S_a + 2.0.
    plt.xlim(min(-1.5, x_a[0], x_v[0]), max(1.5, x_a[-1], x_v[-1]))
    plt.show()

# Interactive Plotting
# Widget descriptions are rendered with MathJax, so each $...$ must be a
# balanced LaTeX expression ('$Sigma_{AV,A}' was missing its closing '$').
interact(generativeModelPlot,
         sigmaAV_A=widgets.FloatSlider(value=0.2, min=0.1, max=5.0, step=0.1, description=r'$\sigma_{AV,A}$'),
         sigmaAV_V=widgets.FloatSlider(value=0.2, min=0.1, max=5.0, step=0.1, description=r'$\sigma_{AV,V}$'),
         S_a=widgets.FloatSlider(value=0.3, min=-1.5, max=1.5, step=0.1, description='True Duration S_a'),
         visualConflict=widgets.FloatSlider(value=0.9, min=0.0, max=2.0, step=0.1, description='Visual Conflict'),
         showHistograms=widgets.Checkbox(value=True, description='Show Histograms'))


interactive(children=(FloatSlider(value=0.2, description='$Sigma_{AV,A}', max=5.0, min=0.1), FloatSlider(value…

<function __main__.generativeModelPlot(sigmaAV_A, sigmaAV_V, S_a, visualConflict, showHistograms=True)>

## 2.1.2 Fusion of two intervals

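\begin{align}
\Delta_{t-s} &= w_a(m_a^t - m_a^s) + w_v(m_v^t - m_v^s)\\
\mathbb{E}[\Delta_{t-s}] &= w_a\Delta S_a + w_v\Delta S_v\\
\mathrm{Var}[\Delta_{t-s}] &= 2\,(w_a^2\sigma_{av,a}^2 + w_v^2\sigma_{av,v}^2) = 2\sigma_{av}^2
\end{align}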

In [None]:
# create interactive plot 
from ipywidgets import interact, FloatSlider, IntSlider
#import normal
from scipy.stats import norm

def generativeModelPlot(sigmaAV_A, sigmaAV_V, S_a, visualConflict, showHistograms=True):
    """Plot the two-interval (test minus standard) audio-visual fusion model (C=1).

    Noise in the test and standard intervals is independent. Each cue's
    difference m^t - m^s therefore has standard deviation sqrt(2)*sigma, and
    the fused difference
        Delta_{t-s} = w_a (m_a^t - m_a^s) + w_v (m_v^t - m_v^s)
    has standard deviation sqrt(2)*sigma_av.

    Parameters
    ----------
    sigmaAV_A, sigmaAV_V : float
        Per-interval auditory / visual measurement noise (standard
        deviations) in the AV condition.
    S_a : float
        Auditory test-minus-standard difference (Delta S_a). The simulated
        standard interval is placed at 0 with no conflict.
    visualConflict : float
        Added to S_a to give the visual difference S_v (Delta S_v).
    showHistograms : bool
        Overlay histograms of simulated differences when True.
    """
    S_v = S_a + visualConflict
    # Differences of two independent intervals have sqrt(2) times the SD.
    sqrt2 = np.sqrt(2)

    plt.figure(figsize=(10, 6))

    # Auditory difference density N(S_a, 2*sigmaAV_A**2).
    colorA = 'teal'
    x_a, p_a = likelihood(S_a, sqrt2 * sigmaAV_A)
    plt.plot(x_a, p_a, color=colorA, label='Auditory PDF')
    plt.axvline(S_a, color=colorA, label="Auditory Stimulus", linestyle='--')

    # Visual difference density N(S_v, 2*sigmaAV_V**2).
    colorV = 'forestgreen'
    x_v, p_v = likelihood(S_v, sqrt2 * sigmaAV_V)
    plt.plot(x_v, p_v, color=colorV, label='Visual PDF')
    plt.axvline(S_v, color=colorV, label="Visual Stimulus", linestyle='--')

    # Simulated fused differences: fused test estimate minus fused standard
    # estimate, both produced by the same reliability-weighted rule.
    hat_S_test, sigma_S_AV_hat = fusionAV(sigmaAV_A, sigmaAV_V, S_a, visualConflict)
    hat_S_std, _ = fusionAV(sigmaAV_A, sigmaAV_V, 0.0, 0.0)
    delta_S_AV = hat_S_test - hat_S_std
    # Raw string: '\h' is not a valid escape sequence in a normal literal.
    plt.axvline(np.mean(delta_S_AV), color='orange', label=r'$\hat S_{av}$Fused Estimate', linestyle='--')

    # Analytic fused difference: mean w_a*S_a + w_v*S_v, SD sqrt(2)*sigma_av.
    J_a = sigmaAV_A**-2  # auditory precision
    J_v = sigmaAV_V**-2  # visual precision
    w_a = J_a / (J_a + J_v)  # weight for auditory
    w_v = 1 - w_a  # weight for visual
    mu_Shat = w_a * S_a + w_v * S_v  # fused mean
    sigma_delta_AV = sqrt2 * sigma_S_AV_hat

    # The grid must span +/- 4 SD of the difference distribution itself;
    # otherwise the curve is truncated.
    x_av = np.linspace(mu_Shat - 4 * sigma_delta_AV, mu_Shat + 4 * sigma_delta_AV, 500)
    p_S_AV = gaussianPDF(x_av, mu_Shat, sigma_delta_AV)
    plt.plot(x_av, p_S_AV, color='orange', label='$S$ Fused PDF')
    plt.axvline(mu_Shat, color='orange', linestyle='--')

    if showHistograms:
        # Simulate test-minus-standard differences so that the histograms
        # match the sqrt(2)-scaled analytic curves above.
        m_a_diff = unimodalMeasurements(sigmaAV_A, S_a) - unimodalMeasurements(sigmaAV_A, 0.0)
        m_v_diff = unimodalMeasurements(sigmaAV_V, S_v) - unimodalMeasurements(sigmaAV_V, 0.0)
        plt.hist(m_a_diff, bins=30, density=True, alpha=0.5, color='teal')
        plt.hist(m_v_diff, bins=30, density=True, alpha=0.5, color='forestgreen')
        plt.hist(delta_S_AV, bins=30, density=True, alpha=0.5, color='orange')

    plt.title('Generative Model for AV Fusion')
    plt.xlabel('Internal Measurement / Estimate')
    plt.ylabel('Probability Density')
    plt.legend(loc='best')
    # Keep the default [-1.5, 1.5] view, but widen it so that both densities
    # (+/- 4 SD) stay visible.
    plt.xlim(min(-1.5, x_a[0], x_v[0]), max(1.5, x_a[-1], x_v[-1]))
    plt.show()

# Interactive Plotting
# One control per argument of the two-interval generativeModelPlot.
interact(
    generativeModelPlot,
    sigmaAV_A=widgets.FloatSlider(
        value=0.2, min=0.1, max=5.0, step=0.1, description='Sigma AV,A'
    ),
    sigmaAV_V=widgets.FloatSlider(
        value=0.2, min=0.1, max=5.0, step=0.1, description='Sigma AV V'
    ),
    S_a=widgets.FloatSlider(
        value=0.3, min=-1.5, max=1.5, step=0.1, description='True Duration S_a'
    ),
    visualConflict=widgets.FloatSlider(
        value=0.9, min=0.0, max=2.0, step=0.1, description='Visual Conflict'
    ),
    showHistograms=widgets.Checkbox(value=True, description='Show Histograms'),
)
