## Define the item response (Rasch / one-parameter logistic) model


In [None]:
import torch
from torch.distributions import Bernoulli
from tqdm import trange
import matplotlib.pyplot as plt

class ResponseModel(torch.nn.Module):
    """One-parameter logistic (Rasch) item response model.

    Each student has a latent skill and each question a latent difficulty.
    The probability that student ``i`` answers question ``j`` correctly is
    ``sigmoid(skill[i] - difficulty[j])``.

    Args:
        n_students: number of students (rows of the answer matrix).
        n_questions: number of questions (columns of the answer matrix).
    """

    def __init__(self, n_students, n_questions):
        super().__init__()
        # Both parameters start at zero, i.e. every answer initially has p = 0.5.
        # (nn.Parameter already defaults to requires_grad=True.)
        self.skill = torch.nn.Parameter(torch.zeros(n_students))
        self.difficulty = torch.nn.Parameter(torch.zeros(n_questions))

    def forward(self):
        """Return the Bernoulli distribution over the (n_students, n_questions) answer matrix."""
        # Broadcasting (n_students, 1) - (1, n_questions) yields one logit per answer.
        logits = self.skill[:, None] - self.difficulty[None, :]
        # Parameterising by logits rather than sigmoid probabilities keeps
        # log_prob accurate when the sigmoid would saturate to 0 or 1 in
        # floating point; .probs / .sample() are unchanged.
        return Bernoulli(logits=logits)

## Generate some data


In [None]:
def generate_data(n_students, n_questions):
    """Sample a synthetic answer matrix from a randomly parameterised ResponseModel.

    Skills and difficulties are drawn independently and uniformly from [-3, 3).

    Args:
        n_students: number of students (rows).
        n_questions: number of questions (columns).

    Returns:
        ``(model, answers)``: the ground-truth model and an
        (n_students, n_questions) tensor of sampled 0/1 answers.
    """
    # The second dimension is the number of questions (was mistakenly n_students).
    model = ResponseModel(n_students, n_questions)
    # Write the ground-truth values in place; autograd must not track this.
    with torch.no_grad():
        model.skill.copy_(-3 + 6 * torch.rand(n_students))
        model.difficulty.copy_(-3 + 6 * torch.rand(n_questions))
    answers = model().sample()
    return model, answers

# Sample a synthetic 100-student x 100-question answer matrix from a random ground truth.
true_model, answers = generate_data(n_questions=100, n_students=100)

# Ground-truth parameters, kept for comparison with the fitted model later.
print(true_model.skill, true_model.difficulty)
print(answers)

## Fit the model by Maximum Likelihood Estimation
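Optimized with full-batch Adam: every step uses the whole answer matrix, so this is full-batch gradient descent rather than stochastic (mini-batch) gradient descent.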

In [None]:
def fit_model(answers, epochs, alpha=1.0, lr=0.01, device=None):
    """Fit a ResponseModel to an answer matrix by penalised maximum likelihood.

    Every optimisation step uses the whole answer matrix (full-batch Adam).

    Args:
        answers: (n_students, n_questions) tensor of 0/1 answers.
        epochs: number of optimisation steps.
        alpha: weight of the penalty ``mean(difficulty) ** 2``. The likelihood
            depends only on ``skill - difficulty``, so it is unchanged when the
            same constant is added to every skill and difficulty; the penalty
            resolves that ambiguity by pulling the mean difficulty towards 0.
        lr: Adam learning rate.
        device: device to train on; defaults to ``"cuda:0"`` when CUDA is
            available and ``"cpu"`` otherwise.

    Returns:
        ``(model, losses)``: the fitted model (moved back to the CPU) and the
        per-step training loss as a list of 0-d CPU tensors.
    """
    if device is None:
        # Fall back to the CPU instead of crashing on machines without CUDA.
        device = "cuda:0" if torch.cuda.is_available() else "cpu"

    model = ResponseModel(n_students=answers.shape[0], n_questions=answers.shape[1]).to(device)
    # Build the optimizer only after the parameters live on their final device.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # Cast to the model's floating dtype; integer/bool answer tensors may not be
    # accepted as Bernoulli log_prob targets.
    answers = answers.to(device=device, dtype=model.skill.dtype)

    losses = []
    for _ in trange(epochs):
        optimizer.zero_grad()
        nll = -model().log_prob(answers).mean()
        loss = nll + alpha * model.difficulty.mean().pow(2)
        # Keep the recorded loss on-device during training to avoid a host sync per step.
        losses.append(loss.detach())
        loss.backward()
        optimizer.step()

    model.to("cpu")
    # Copy the losses to the CPU so callers (e.g. matplotlib) can convert them to numbers.
    return model, [step_loss.cpu() for step_loss in losses]

model, losses = fit_model(answers, epochs=1000, alpha=1e-2)
# float() copies each scalar loss to the host, so plotting also works when the
# losses were recorded as device (e.g. CUDA) tensors, which matplotlib cannot convert.
plt.plot([float(step_loss) for step_loss in losses])
plt.yscale("log")
plt.xlabel("epoch")
plt.ylabel("training loss")


## Visualize estimation errors

In [None]:
# Estimation error = true value - estimate, for every question and every student.
# NOTE(review): the likelihood only sees skill - difficulty, and fit_model pulls the
# mean estimated difficulty towards 0 while the true difficulties (uniform on [-3, 3))
# are not exactly centred, so both error histograms may share a small common offset.
# alpha=0.5 keeps the first histogram visible underneath the second.
plt.hist((true_model.difficulty - model.difficulty).detach().numpy(), density=True, alpha=0.5, label="difficulty error")
plt.hist((true_model.skill - model.skill).detach().numpy(), density=True, alpha=0.5, label="skill error")
plt.xlabel("true - estimated")
plt.legend()
