In [2]:
from torch import nn
import torch
import torch.nn.functional as F
from torch import optim

In [None]:
class DKT(nn.Module):
    
    def __init__(self,input_size,output_size):
        super().__init__()
        self.net = nn.RNN(input_size,100,batch_first=True)
        self.linear = nn.Linear(100,output_size)
        self.sigmoid = nn.Sigmoid()
    
    def forward(self,x):
        # B,T,C
        out,h = self.net(x)
        return self.sigmoid(self.linear(out))

In [9]:
import json
import numpy as np

In [21]:
from tqdm import tqdm

In [177]:
block_size = 50
num_questions = 123

def get_data(name):
    with open('{}.json'.format(name),'r') as f:
        train_data = json.load(f)
        all_data = []
        for x,y in tqdm(train_data):
            x_len = len(x)
            left = block_size - x_len % block_size
            total = x_len + left
            ans = np.zeros((total,num_questions*2))
            for i,(xi,yi) in enumerate(zip(x,y)):
                if yi == 1:
                    ans[i][xi+num_questions] = 1
                else:
                    ans[i][xi] = 1
            all_data.extend(ans.tolist())
    all_data=torch.tensor(all_data)
    all_data=all_data.reshape(-1,block_size,num_questions*2)
    return all_data

In [178]:
all_data=get_data('train')

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3746/3746 [00:03<00:00, 959.19it/s]


In [181]:
def parse_raw_data(raw_question_matrix, raw_pred):
    questions = torch.nonzero(raw_question_matrix)[1:, 1] % num_questions
    length = questions.shape[0]
    pred = raw_pred[: length]
    pred = pred.gather(1, questions.view(-1, 1)).flatten()
    truth = torch.nonzero(raw_question_matrix)[1:, 1] // num_questions
    return truth,pred

In [182]:
def parse_raw_data(x,logits):
    # x, T,C
    # logits T,C
    predicted_idx = torch.nonzero(x)[1:,-1] % num_questions
    true_pred = (torch.nonzero(x)[1:,-1] // num_questions)*1.0
    logits = logits[:len(predicted_idx)]
    prediced = logits.gather(1,predicted_idx.view(-1,1)).flatten()
    return true_pred,prediced

In [183]:
all_data.shape

torch.Size([7812, 50, 246])

In [184]:
x=all_data[0]

In [185]:
torch.nonzero(x)[1:,-1]

tensor([212, 212, 212,  89,  89,  89,  89, 212, 212, 212, 212, 212, 212,  89,
        212, 212, 212, 212, 212, 212, 212,  89, 212, 212, 212,  89, 212, 212,
         89,  89, 212, 212, 212, 212, 212,  89, 212,  89,  26,  26, 149,  26,
         26, 149,  26,  25,  25, 148, 148])

In [186]:
np.concatenate([[[1,2]],[[2,3]]])

array([[1, 2],
       [2, 3]])

In [187]:
dkt = DKT(num_questions*2,num_questions)

In [188]:
op = optim.Adam(dkt.parameters(),lr=0.002)
loss_func = nn.BCELoss()

In [191]:
for i,data in enumerate(all_data):
    logits = dkt(data)
    true_pred,pred = parse_raw_data(data,logits)
#     print(true_pred,pred)
    loss = loss_func(pred,true_pred)
#     if torch.isnan(loss):
#         print(data)
#         print(true_pred,pred)
#         print(torch.nonzero(data))
    if i % 1000 == 0:
        print(loss)
    op.zero_grad(set_to_none=True)
    loss.backward()
    op.step()

tensor(0.5861, grad_fn=<BinaryCrossEntropyBackward0>)
tensor(0.2816, grad_fn=<BinaryCrossEntropyBackward0>)
tensor(0.3498, grad_fn=<BinaryCrossEntropyBackward0>)
tensor(0.3393, grad_fn=<BinaryCrossEntropyBackward0>)
tensor(0.5113, grad_fn=<BinaryCrossEntropyBackward0>)
tensor(1.3329, grad_fn=<BinaryCrossEntropyBackward0>)
tensor(0.5326, grad_fn=<BinaryCrossEntropyBackward0>)
tensor(0.2389, grad_fn=<BinaryCrossEntropyBackward0>)


In [190]:
eval_data(all_data)

tensor(0.6951)


In [87]:
def eval_data(all_data):
    true_pre,pre=torch.tensor([]),torch.tensor([])
    for data in all_data:
        logits = dkt(data)
        true_pred,pred = parse_raw_data(data,logits)
        true_pre = torch.cat([true_pre,true_pred])
        pre = torch.cat([pre,(pred>=0.5).float()])
    print((true_pre==pre).sum()/len(pre)) 

In [168]:
eval_data(val_data)

tensor(0.9725)


In [89]:
val_data=get_data('val')

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 417/417 [00:00<00:00, 678.18it/s]
