In [2]:
import numpy as np
import numpy.matlib as ml
import numba
import matplotlib.pyplot as plt

In [30]:
@numba.njit
def calculate_likelihood_c1(Xv, Xa, N, pCommon, sigV, varV, sigA, varA, sigP, varP):
    # likelihood P(Xv, Xa|C =1)
    
    firstDenom = 2*np.pi*np.sqrt(varV*varA + varV*varP +varA*varP)
    firstTerm = 1/firstDenom 
    secondNum = (Xv - Xa)**2 * varP + (Xv -0)**2 * varA + (Xa - 0)**2* varV 
    secondDenom = (varV * varA) + (varV * varP) + (varA * varP)
    secondTerm = np.exp((-0.5*(secondNum/secondDenom)))
    likelihood_com = firstTerm*secondTerm

    return likelihood_com

@numba.njit
def calculate_likelihood_c2(Xv,Xa, N, pCommon, sigV, varV, sigA, varA, sigP, varP):
    # likelihood P(Xv, Xa|C =2)
    
    firstTerm = 2*np.pi*np.sqrt((varV + varP)*(varA+varP))
    secondTerm1 = (Xv - 0)**2/(varV + varP)
    secondTerm2 = (Xa - 0)**2 / (varA + varP)
    secondTermFull = np.exp((-0.5*(secondTerm1+secondTerm2)) )
    likelihood_ind = secondTermFull/firstTerm

    return likelihood_ind

@numba.njit
def calculate_posterior(Xv, Xa, N, pCommon, sigV, varV, sigA, varA, sigP, varP):
    # p(C = 1|Xv,Xa) posterior
    
    likelihood_common = calculate_likelihood_c1(Xv, Xa, N, pCommon, sigV, varV, sigA, varA, sigP, varP)
    likelihood_ind = calculate_likelihood_c2(Xv, Xa, N, pCommon, sigV, varV, sigA, varA, sigP, varP)
    post_common = likelihood_common * pCommon 
    post_indep = likelihood_ind * (1-pCommon)
    posterior = post_common/(post_common +post_indep)

    return posterior

def opt_position_conditionalised_C1(Xv, Xa, N, pCommon, sigV, varV, sigA, varA, sigP, varP):
    # Get optimal location given C = 1
    
    cues = Xv/varV + Xa/varA + ml.repmat(pCommon,N,1)/varP
    inverseVar = 1/varV + 1/varA + 1/varP
    sHatC1 = cues/inverseVar

    return sHatC1

def opt_position_conditionalised_C2(Xv, Xa, N, pCommon, sigV, varV, sigA, varA, sigP, varP):
        # Get optimal locationS given C = 2
        
    visualCue = Xv/varV +ml.repmat(pCommon,N,1)/varP
    visualInvVar = 1/varV + 1/ varP
    sHatVC2 = visualCue/visualInvVar
    audCue = Xa/varA + ml.repmat(pCommon,N,1)/varP
    audInvVar = 1/varA + 1/ varP
    sHatAC2 = audCue/audInvVar

    return sHatVC2, sHatAC2

def optimal_visual_location(Xv, Xa, N, pCommon, sigV, varV, sigA, varA, sigP, varP):
    # Use Model Averaging to compute final visual est
    
    posterior_1C = calculate_posterior(Xv, Xa, N, pCommon, sigV, varV, sigA, varA, sigP, varP)
    sHatVC1 = opt_position_conditionalised_C1(Xv, Xa, N, pCommon, sigV, varV, sigA, varA, sigP, varP)
    sHatVC2 = opt_position_conditionalised_C2(Xv, Xa, N, pCommon, sigV, varV, sigA, varA, sigP, varP)[0]
    sHatV = posterior_1C*sHatVC1 + (1-posterior_1C)*sHatVC2 #model averaging

    return sHatV

def optimal_aud_location(Xv, Xa, N, pCommon, sigV, varV, sigA, varA, sigP, varP):
    # Use Model Averaging to compute final auditory est
    
    posterior_1C = calculate_posterior(Xv, Xa, N, pCommon, sigV, varV, sigA, varA, sigP, varP)
    sHatAC1 = opt_position_conditionalised_C1(Xv, Xa, N, pCommon, sigV, varV, sigA, varA, sigP, varP)
    sHatAC2 = opt_position_conditionalised_C2(Xv, Xa, N, pCommon, sigV, varV, sigA, varA, sigP, varP)[1]
    sHatA = posterior_1C*sHatAC1 + (1-posterior_1C)*sHatAC2 #model averaging
    return sHatA


In [31]:
N = 10000
pCommon, sigV, sigA, sigP = 0.1, 10, 1, 2
varV, varA, varP = sigV**2, sigA**2, sigP**2
vloc, aloc = 12, 0

Xv, Xa = sigV * np.random.randn(N,1) + vloc, sigA * np.random.randn(N,1) + aloc

In [41]:

sHatA = optimal_aud_location(Xv, Xa, N, pCommon, sigV, varV, sigA, varA, sigP, varP)

In [25]:
@numba.njit
def calculate_posterior(Xv, Xa, N, pCommon, sigV, varV, sigA, varA, sigP, varP):
    # p(C = 1|Xv,Xa) posterior
    
    likelihood_common = calculate_likelihood_c1(Xv, Xa, N, pCommon, sigV, varV, sigA, varA, sigP, varP)
    likelihood_ind = calculate_likelihood_c2(Xv, Xa, N, pCommon, sigV, varV, sigA, varA, sigP, varP)
    post_common = likelihood_common * pCommon 
    post_indep = likelihood_ind * (1-pCommon)
    posterior = post_common/(post_common +post_indep)
    #plt.hist(posterior)
    #plt.title("posterior")
    #plt.show()
    return posterior

In [26]:
%%timeit
calculate_posterior(Xv, Xa, N, pCommon, sigV, varV, sigA, varA, sigP, varP)

188 µs ± 3.45 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [52]:
def count():

    pass

@numba.njit
def bin(y, possible_locations):
    return min(possible_locations, key = lambda x:abs(x-y))

@numba.njit
def binner(sHatA, possible_locations): 
    for i in sHatA:
        min(possible_locations, key = lambda x:abs(x-i))

    return [min(possible_locations, key=lambda x:abs(x-i)) for i in sHatA]

In [53]:
%%timeit
bin(6, [4, 8.2])

TypingError: Failed in nopython mode pipeline (step: convert make_function into JIT functions)
Cannot capture the non-constant value associated with variable 'y' in a function that will escape.

File "../../../../../var/folders/cz/gy460qws18d1blj5gj79f8m80000gq/T/ipykernel_17579/498289010.py", line 7:
<source missing, REPL/exec in use?>


In [58]:
@numba.njit
def count(binnedV, possible_locations):
    return [binnedV.count(bvc) for bvc in possible_locations]

In [59]:
%%timeit
count([12, 12 ,12 , 12 ,12 , 12 , 12, 24, 24, 24, 24,24 ], [-24, -12, 0, 12, 24])

Encountered the use of a type that is scheduled for deprecation: type 'reflected list' found for argument 'binnedV' of function 'count'.

For more information visit https://numba.readthedocs.io/en/stable/reference/deprecation.html#deprecation-of-reflection-for-list-and-set-types

File "../../../../../var/folders/cz/gy460qws18d1blj5gj79f8m80000gq/T/ipykernel_17579/1550215794.py", line 1:
<source missing, REPL/exec in use?>

Encountered the use of a type that is scheduled for deprecation: type 'reflected list' found for argument 'possible_locations' of function 'count'.

For more information visit https://numba.readthedocs.io/en/stable/reference/deprecation.html#deprecation-of-reflection-for-list-and-set-types

File "../../../../../var/folders/cz/gy460qws18d1blj5gj79f8m80000gq/T/ipykernel_17579/1550215794.py", line 1:
<source missing, REPL/exec in use?>

Encountered the use of a type that is scheduled for deprecation: type 'reflected list' found for argument 'lst' of function 'list_count

12.5 µs ± 62.6 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
