BCI Sample Size Determination (SSD)

In [None]:
import numpy as np
import pandas as pd
import pymc3 as pm
from scipy import special

import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Full experiment grid (disabled) - swap these in for a complete run:
# M = 5                             # iterations
# n_subjects = [4, 8, 12, 16, 20]   # number of subjects
# n_trials = [20, 36, 65]           # number of trials

# MCMC settings handed to pm.sample
n_tune = 500       # tuning steps
n_chains = 3       # number of chains
n_samples = 5000   # draws requested from pm.sample

# Reduced grid for a quick run
n_trials = [20]    # number of trials
n_subjects = [4]   # number of subjects
M = 1              # iterations per (subjects, trials) combination

In [None]:
# Simulation-based sample size determination.
# For every design (Ns subjects, T trials) run M iterations; in each one sample
# the hierarchical model, take the width of the central 95% interval of the
# across-subject mean of α (logit scale) and average those widths into the
# ALC for that design. One ALC-versus-T curve is drawn per value of Ns.
#
# NOTE(review): no variable in the model is declared with observed=, so
# pm.sample is not conditioned on any data and T never enters the
# computation. The data-generation step is still a TODO (see below).
for Ns in n_subjects:
    ALCs = []  # one ALC per entry of n_trials, for the current Ns
    for T in n_trials:
        ALC = 0
        for m in range(M):
            print("M: "+str(m)+"\t Ns: "+str(Ns)+"\t T: "+str(T))

            # draw parameters theta hat from sampling prior
            with pm.Model() as model:
                # group level parameters - a single value for mean and std
                group_level_mean_prob = pm.Uniform('μ_φ', lower=0.55, upper=0.95)
                group_level_mean_logit = pm.Deterministic('μ_α', pm.math.logit(group_level_mean_prob))
                group_level_std_logit = pm.Uniform('σ_α', lower=0.2, upper=1.2)

                # subject level parameters - a vector of size Ns
                subject_level_accuracy_logit = pm.Normal('α', mu=group_level_mean_logit,
                                                         sd=group_level_std_logit, shape=Ns)
                subject_level_accuracy_prob = pm.Deterministic('φ', pm.math.invlogit(subject_level_accuracy_logit))

                # sample the subject level parameters; tuning draws are discarded
                trace = pm.sample(n_samples, chains=n_chains, tune=n_tune, discard_tuned_samples=True)

            # TODO: draw a dataset D^(n) from the sampling distribution (a
            # Binomial with T trials per subject) and compute delta(D^(n)) from
            # the posterior given that dataset via Bayes' rule (MCMC). Not
            # implemented yet: delta below is computed from the draws above.

            # (n_samples * n_chains) x Ns array of subject-level α draws.
            # These are logit-scale values, not probabilities.
            trace_np = np.asarray(trace['α'])

            # take mean across all Ns subjects --> one value per draw
            group_level_mean_logit_hat = np.mean(trace_np, axis=1)
            # width of the central 95% interval of that mean
            lower, upper = np.percentile(group_level_mean_logit_hat, [2.5, 97.5])
            delta = upper - lower
            print("delta:", delta)

            # approximate ALC: accumulate, then average over the M iterations
            ALC += delta
        ALC /= M
        print("ALC:", ALC)
        ALCs.append(ALC)
    # x must be the full list of trial counts: ALCs holds one value per entry
    # of n_trials, so plotting against the scalar T (the last loop value)
    # fails as soon as more than one trial count is configured.
    plt.plot(n_trials, ALCs, marker='x', label='Ns = '+str(Ns))

plt.xlabel('Number of trials T')
plt.ylabel('Average 95% CI width ALC(n)')
plt.legend()
plt.show()

In [None]:
# Display the ALC list left over from the loop; it is reset for every Ns,
# so this shows the values for the last subject count only.
ALCs

In [None]:
# Trace plot of `trace`, which holds only the most recent pm.sample run.
# The trailing ';' suppresses the repr of the returned object.
# NOTE(review): newer PyMC3 releases expose this as pm.plot_trace - confirm
# against the installed version.
pm.traceplot(trace);

In [None]:
# Reference points: the logit values of 55% and 95% accuracy.
logit_at_55 = special.logit(0.55)
logit_at_95 = special.logit(0.95)

print("55% accuracy = " + str(logit_at_55))
print("95% accuracy = " + str(logit_at_95))
print()

def print_alc_effect(alc, prob):
    """Print what a logit-scale width of ``alc`` means on the probability scale.

    The accuracy ``prob`` is mapped to the logit scale, lowered by ``alc`` and
    mapped back with the inverse logit. The report shows the starting
    accuracy, the lowered accuracy and their difference, all in percent.

    Parameters
    ----------
    alc : number
        Amount subtracted from ``logit(prob)``.
    prob : float
        Starting accuracy as a probability.

    Returns
    -------
    None
        The result is only printed.
    """
    print("For ALC = " + str(alc) + ":")

    percent_label = str(round(prob*100))
    print("CI Range in probability scale for " + percent_label + "%")

    # accuracy after moving `alc` logit units downwards from `prob`
    lowered_prob = special.expit(special.logit(prob) - alc)
    report_values = (prob * 100, lowered_prob * 100, (prob - lowered_prob) * 100)
    print("= %.2f - %.2f = %.2f" % report_values)
    print()

# Report two fixed widths (3 and 1) and then the ALC left over from the
# simulation loop, each at the upper (95%) and lower (55%) accuracy.
for alc_value in (3, 1, ALC):
    for accuracy in (0.95, 0.55):
        print_alc_effect(alc_value, accuracy)

In [None]:
# Check the dimensions of the α draws kept from the most recent sampling run.
trace_np.shape

In [None]:
# Scratch check of the Binomial sampler: a stand-alone distribution with
# n=20 trials and success probability 0.564146.
# NOTE(review): 0.564146 looks like a value copied from an earlier output - confirm.
y = pm.Binomial.dist(n=20, p=0.564146)
# Draw 20 values from it (each one is a success count out of n=20).
y.random(size=20)

In [None]:
# Simulate an accuracy for every cell of trace_np (one per draw and subject).
#
# NOTE(review): the success probability is a fixed placeholder, not the
# sampled subject-level accuracy - confirm before using `data` downstream.
p_success = 0.564146

# Take the output shape from trace_np itself, not from n_samples, n_chains
# and Ns. The result then always lines up with the trace, even if those
# settings were edited after the trace was computed.
# One vectorised draw replaces building a pm.Binomial distribution and
# sampling from it once per cell.
success_counts = np.random.binomial(n=T, p=p_success, size=trace_np.shape + (T,))

# As before, every cell is the mean of T draws, each draw being
# (successes in T trials) / T.
# NOTE(review): a single Binomial(n=T) draw already covers T trials - confirm
# that averaging T such draws is intended.
data = (success_counts / T).mean(axis=-1)