In [65]:
%matplotlib inline
import numpy as np

import scipy.stats as st
import matplotlib.pyplot as plt
from scipy import stats
import math
import datetime

In [66]:
mean1 = [0 for _ in range(101)]

cov1 = np.diag([1 for _ in range(101)])

# Knowledge
theta1 = stats.multivariate_normal(mean=mean1, cov=cov1)

mean2 = [0 for _ in range(4)]

cov2 = np.diag([0.8 for _ in range(4)])

# Skill
theta2 = stats.multivariate_normal(mean=mean2, cov=cov2)

# Take samples
data1 = theta1.rvs()
data2 = theta2.rvs()
print(type(data2))


<class 'numpy.ndarray'>


In [67]:
theta = np.array([1, 2, 3])
mean = np.array([0.5, 0.5, 0.5])
cov = np.diag([1, 1, 1])
print(np.matmul(cov, (theta - mean)))


[ 0.5  1.5  2.5]


In [91]:
def pseudo_posterior(theta, mean, cov, knowledge_weight, skill_weight, 
                     answer, difficulty, guessing_para, N=101):
    '''
    theta: 
        
    '''
    p = guessing_para + (1 - guessing_para) * \
        sigmoid(np.matmul(np.transpose(theta[0:N]), knowledge_weight) - difficulty) * \
        sigmoid(np.matmul(np.transpose(theta[N:]), skill_weight) - difficulty)
    
    # Calculate likelihood
    likelihood = p**answer * (1-p)**(1-answer)
    
    pseudo_posterior = likelihood * \
    math.exp(-0.5*np.matmul(np.matmul(np.transpose(theta - mean), np.linalg.inv(cov)), (theta - mean)))
    
    return pseudo_posterior

def sigmoid(x):
    return 1 / (1 + math.exp(-x))

def time_to_day(time_delta):
    days = time_delta.day + time_delta.hour / 24 + time_delta.minute / 1440 + time_delta.second / 86400
    return days
    
def gibbs_sampling(mean, cov, knowledge_weight, skill_weight, 
                     answer, difficulty, guessing_para, N=101):
    '''
    return:
        mean, cov of approximation with gibbs sampling
    '''
    prior_dist = stats.multivariate_normal(mean=mean, cov=cov)
    current_theta = prior_dist.rvs()
    # Calculate w_l
    Gibbs_samples = []
    for iter in range(600):
        time1 = datetime.datetime.now()
        for i in range(105):
            z = prior_dist.rvs(size=100)
            for item in z:
                for index in range(len(item)):
                    if index == i:
                        continue
                    item[index] = current_theta[index]
            numerator = []
            for item in z:
                p = pseudo_posterior(item, mean, cov, knowledge_weight, skill_weight, 
                     answer, difficulty, guessing_para)
                q = prior_dist.pdf(item)
                numerator.append(p/q)
            denominator = sum(numerator)
            w_l = [x/denominator for x in numerator]
            chosen_z_idx = np.random.choice(100, 1, p=w_l)
            chosen_z = z[chosen_z_idx]
            current_theta[i] = chosen_z[0][i]
        Gibbs_samples.append(current_theta)
        time2 = datetime.datetime.now()
        print('Iteration {}. Time: {}\n'.format(iter, time2 - time1))
    
    Gibbs_samples = Gibbs_samples[100:]
    Gibbs_samples = np.array(Gibbs_samples)
    mean = np.mean(Gibbs_samples, axis=0)
    cov = np.cov(Gibbs_samples, rowvar=False)
    
    return mean, cov
    
                 
    

# Update distribution
def update_theta(mean, cov, question_no, answer, last_time,
                 knowledge_weight, skill_weight, difficulty, guessing_para, N=101):
    '''
    Input:
    answer: binary, 0 - wrong, 1 - correct
    knowledge_weight: 101*1 vector
    skill_weight: 4*1 vector
    difficulty: string
    
    return:
    new_mean
    new_cov
    update_time
    '''
    # Get time delta
    # t2 - t1 (in days)
    now = datetime.datetime.now()
    days = time_to_day(now - last_time)
                 
    # Step0: Brownian motion
    mean = mean + 0.008 * days
    cov = np.add(cov, np.diag([0.008*days for _ in range(105)]))
    
    # Step1: Get prior sample
    norm_dist = stats.multivariate_normal(mean=mean, cov=cov)
    prior_sample = norm_dist.rvs()
    prior_sample_knowledge = prior_sample[0:N]
    prior_sample_skill = prior_sample[N:]
    
    # Step2: Generating samples from posterior
    gibbs_mean, gibbs_cov = gibbs_sampling(mean, cov, knowledge_weight, skill_weight, 
                     answer, difficulty, guessing_para)    
    

In [92]:
knowledge_weight = [np.random.choice([0,1]) for _ in range(101)]

In [93]:
skill_weight = [np.random.choice([0,1]) for _ in range(4)]

In [94]:
time1 = datetime.datetime.now()
mean, cov = gibbs_sampling([0 for _ in range(105)], np.diag([1 for _ in range(101)] + [0.8 for _ in range(4)]),
                          knowledge_weight, skill_weight,1, 2, 0.25)
time2 = datetime.datetime.now()
print(time2-time1)

Iteration 0. Time: 0:00:08.294427

Iteration 1. Time: 0:00:08.932065

Iteration 2. Time: 0:00:07.679656

Iteration 3. Time: 0:00:08.299462

Iteration 4. Time: 0:00:09.161990

Iteration 5. Time: 0:00:08.679282

Iteration 6. Time: 0:00:10.209145

Iteration 7. Time: 0:00:08.292400

Iteration 8. Time: 0:00:08.113947

Iteration 9. Time: 0:00:09.050769

Iteration 10. Time: 0:00:08.981661

Iteration 11. Time: 0:00:07.869338

Iteration 12. Time: 0:00:07.992737

Iteration 13. Time: 0:00:07.706870

Iteration 14. Time: 0:00:07.998259

Iteration 15. Time: 0:00:08.177054

Iteration 16. Time: 0:00:08.376938

Iteration 17. Time: 0:00:08.315301

Iteration 18. Time: 0:00:07.968607

Iteration 19. Time: 0:00:07.884188

Iteration 20. Time: 0:00:08.334739

Iteration 21. Time: 0:00:07.545504

Iteration 22. Time: 0:00:08.759003

Iteration 23. Time: 0:00:08.056518

Iteration 24. Time: 0:00:08.465745

Iteration 25. Time: 0:00:09.010600

Iteration 26. Time: 0:00:07.188879

Iteration 27. Time: 0:00:07.822995

It

Iteration 225. Time: 0:00:04.531942

Iteration 226. Time: 0:00:04.512602

Iteration 227. Time: 0:00:04.533697

Iteration 228. Time: 0:00:04.506520

Iteration 229. Time: 0:00:04.538717

Iteration 230. Time: 0:00:04.724753

Iteration 231. Time: 0:00:04.982524

Iteration 232. Time: 0:00:05.240315

Iteration 233. Time: 0:00:04.942848

Iteration 234. Time: 0:00:05.056415

Iteration 235. Time: 0:00:05.143165

Iteration 236. Time: 0:00:04.874807

Iteration 237. Time: 0:00:04.843825

Iteration 238. Time: 0:00:04.899584

Iteration 239. Time: 0:00:04.832660

Iteration 240. Time: 0:00:04.838086

Iteration 241. Time: 0:00:04.877733

Iteration 242. Time: 0:00:04.989270

Iteration 243. Time: 0:00:04.537953

Iteration 244. Time: 0:00:04.494678

Iteration 245. Time: 0:00:04.549050

Iteration 246. Time: 0:00:04.545220

Iteration 247. Time: 0:00:04.533043

Iteration 248. Time: 0:00:04.530476

Iteration 249. Time: 0:00:04.510679

Iteration 250. Time: 0:00:04.495373

Iteration 251. Time: 0:00:04.517317

I

Iteration 447. Time: 0:00:04.496559

Iteration 448. Time: 0:00:04.501238

Iteration 449. Time: 0:00:04.472168

Iteration 450. Time: 0:00:04.494204

Iteration 451. Time: 0:00:04.548294

Iteration 452. Time: 0:00:04.512826

Iteration 453. Time: 0:00:04.490299

Iteration 454. Time: 0:00:04.645942

Iteration 455. Time: 0:00:04.527444

Iteration 456. Time: 0:00:04.494745

Iteration 457. Time: 0:00:04.493722

Iteration 458. Time: 0:00:04.494091

Iteration 459. Time: 0:00:04.478856

Iteration 460. Time: 0:00:04.498930

Iteration 461. Time: 0:00:04.523507

Iteration 462. Time: 0:00:04.499023

Iteration 463. Time: 0:00:04.490122

Iteration 464. Time: 0:00:04.510287

Iteration 465. Time: 0:00:04.532004

Iteration 466. Time: 0:00:04.487203

Iteration 467. Time: 0:00:04.719922

Iteration 468. Time: 0:00:04.504537

Iteration 469. Time: 0:00:04.703335

Iteration 470. Time: 0:00:04.501727

Iteration 471. Time: 0:00:04.507321

Iteration 472. Time: 0:00:04.498844

Iteration 473. Time: 0:00:04.521897

I

In [96]:
print(max(mean))
print(min(mean))

1.80247038102
-2.31798419313


In [98]:
print(mean)

[ 1.28247188  0.63748917  1.80247038  1.02918157  0.82248318 -0.71480715
  0.29587115  0.14857127  1.43956019  0.33251357  0.34735089  1.65440324
 -0.57412651  1.74553235  1.26245191 -1.39858631  0.5601826  -0.99654739
  0.7020176   0.8846054   0.24879804  0.86815845  0.06132029 -1.13023133
 -1.43556112  0.25183769 -1.91157598  0.76030668  0.28293769  1.13939494
  0.31108729 -0.51236596  0.48555772 -2.06741238  0.58440537 -0.07790519
  1.35153606  0.87616974  1.55309921  0.03226904  0.16199132  0.59017867
 -0.85261749 -0.1526775   0.28320064  0.4217614  -0.95205898 -0.80692458
  0.4578584  -0.40566918  1.51877578  0.56687946  1.42510046 -0.59580839
 -0.4750243  -1.36430736  1.26249885  0.72349948 -0.95632585  0.42967481
 -1.07403307  1.43038483 -1.56994222 -0.49869158 -0.24300631  0.68196646
  0.86248651 -0.77804233 -1.3487668   0.64610209 -0.65888903 -0.7419359
  0.49722286 -0.59613539 -1.05151505  0.35836648  0.56173366 -2.31798419
  1.02286052 -0.33457472 -0.36459745 -0.57716738  0.