-
Notifications
You must be signed in to change notification settings - Fork 760
/
rasch_model.py
63 lines (49 loc) · 2.12 KB
/
rasch_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
"""Rasch model (Rasch, 1960)."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import edward as ed
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from edward.models import Bernoulli, Normal, Empirical
from scipy.special import expit
# Command-line configuration. ``lower_bound=1`` rejects sizes that would
# produce an empty response matrix or an Empirical approximation that holds
# no posterior samples.
tf.flags.DEFINE_integer("nsubj", default=200, lower_bound=1,
                        help="Number of subjects (rows of the response matrix).")
tf.flags.DEFINE_integer("nitem", default=25, lower_bound=1,
                        help="Number of items (columns of the response matrix).")
tf.flags.DEFINE_integer("T", default=5000, lower_bound=1,
                        help="Number of posterior samples.")
FLAGS = tf.flags.FLAGS
def main(_):
    """Fit a Rasch model to simulated data and check trait recovery.

    Simulates a binary response matrix of shape ``[FLAGS.nsubj, FLAGS.nitem]``
    from randomly drawn subject traits and item thresholds. It then infers
    their posterior with Hamiltonian Monte Carlo and compares the
    posterior-mean traits against the true ones, both numerically (MSE) and
    visually (scatter plot).

    Args:
      _: Unused; receives the argv list passed in by ``tf.app.run``.
    """
    # DATA
    # Traits form a column vector and thresholds a row vector, so their
    # difference broadcasts to an [nsubj, nitem] matrix of logits.
    trait_true = np.random.normal(size=[FLAGS.nsubj, 1])
    thresh_true = np.random.normal(size=[1, FLAGS.nitem])
    X_data = np.random.binomial(1, expit(trait_true - thresh_true))

    # MODEL
    trait = Normal(loc=0.0, scale=1.0, sample_shape=[FLAGS.nsubj, 1])
    thresh = Normal(loc=0.0, scale=1.0, sample_shape=[1, FLAGS.nitem])
    X = Bernoulli(logits=trait - thresh)

    # INFERENCE
    # Each Empirical stores FLAGS.T samples stacked along the leading axis.
    q_trait = Empirical(params=tf.get_variable(
        "q_trait/params", [FLAGS.T, FLAGS.nsubj, 1]))
    q_thresh = Empirical(params=tf.get_variable(
        "q_thresh/params", [FLAGS.T, 1, FLAGS.nitem]))

    inference = ed.HMC({trait: q_trait, thresh: q_thresh}, data={X: X_data})
    inference.run(step_size=0.1)

    # Alternatively, use variational inference.
    # q_trait = Normal(
    #     loc=tf.get_variable("q_trait/loc", [FLAGS.nsubj, 1]),
    #     scale=tf.nn.softplus(
    #         tf.get_variable("q_trait/scale", [FLAGS.nsubj, 1])))
    # q_thresh = Normal(
    #     loc=tf.get_variable("q_thresh/loc", [1, FLAGS.nitem]),
    #     scale=tf.nn.softplus(
    #         tf.get_variable("q_thresh/scale", [1, FLAGS.nitem])))
    # inference = ed.KLqp({trait: q_trait, thresh: q_thresh}, data={X: X_data})
    # inference.run(n_iter=2500, n_samples=10)

    # CRITICISM
    # Evaluate the posterior mean once and reuse it. Each call to
    # ``q_trait.mean()`` adds fresh ops to the graph.
    # NOTE(review): this averages all FLAGS.T samples, including the earliest
    # HMC iterations; consider discarding a burn-in prefix.
    trait_post_mean = q_trait.mean().eval()

    # Print the numeric result before ``plt.show()``, which blocks until the
    # figure window is closed.
    print("MSE between true traits and inferred posterior mean:")
    print(np.mean(np.square(trait_true - trait_post_mean)))

    # Check that the inferred posterior mean captures the true traits.
    plt.scatter(trait_true, trait_post_mean)
    plt.xlabel("True trait")
    plt.ylabel("Posterior mean trait")
    plt.show()
if __name__ == "__main__":
    # Parse the command-line flags, then dispatch to ``main`` with the
    # leftover argv.
    tf.app.run(main=main)