In [20]:
from torch import nn
import torch
import torch.nn.functional as F
import json
import numpy as np
from tqdm import tqdm
from torch import optim

In [106]:
def get_data(name, block_size=50, num_questions=123):
    """Load ``<name>.json`` and encode every sequence as fixed-size one-hot blocks.

    The JSON file must hold a list of ``(questions, answers)`` pairs (it is
    unpacked as ``for x, y in data``). Each interaction ``(q, a)`` becomes a
    row of length ``2 * num_questions`` with a single 1:
      * at column ``q`` when ``a != 1``, or
      * at column ``num_questions + q`` when ``a == 1``.
    Each sequence is zero-padded up to the next multiple of ``block_size``.
    All sequences are then concatenated and cut into blocks, so a block never
    mixes two sequences.

    Args:
        name: path of the JSON file without the ``.json`` extension.
        block_size: number of time steps per block.
        num_questions: number of distinct question ids; ids are assumed to lie
            in ``[0, num_questions)``.

    Returns:
        float32 tensor of shape ``(num_blocks, block_size, 2 * num_questions)``.
    """
    with open('{}.json'.format(name), 'r') as f:
        data = json.load(f)

    width = num_questions * 2
    blocks = []
    for x, y in tqdm(data):
        x_len = len(x)
        # Pad only up to the next multiple of block_size. The previous formula
        # `block_size - x_len % block_size` appended a whole empty block when
        # x_len was already a multiple of block_size (or zero).
        total = x_len + (-x_len) % block_size
        res = np.zeros((total, width), dtype=np.float32)
        for i, (xi, yi) in enumerate(zip(x, y)):
            if yi == 1:
                res[i][num_questions + xi] = 1
            else:
                res[i][xi] = 1
        blocks.append(res)

    # One concatenation instead of round-tripping every row through Python lists.
    if blocks:
        ans = torch.from_numpy(np.concatenate(blocks, axis=0))
    else:
        ans = torch.zeros((0, width))
    return ans.reshape(-1, block_size, width)

In [6]:
# Load train.json / val.json with get_data's defaults; each split becomes a
# tensor of shape (num_blocks, block_size, 2 * num_questions).
train_data = get_data('train')
val_data = get_data('val')

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3746/3746 [00:04<00:00, 771.46it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 417/417 [00:00<00:00, 737.39it/s]


In [8]:
# Sanity check: (num_blocks, block_size, 2 * num_questions).
train_data.shape

torch.Size([7812, 50, 246])

In [9]:
# Sanity check: (num_blocks, block_size, 2 * num_questions).
val_data.shape

torch.Size([1573, 50, 246])

In [17]:
def parse_raw_data(x, logits, num_questions=123):
    """Pair each next-step prediction with its ground-truth correctness label.

    Args:
        x: one block of one-hot interactions, shape ``(T, 2 * num_questions)``.
            Interactions must occupy the leading rows; trailing rows may be
            zero padding.
        logits: per-step, per-question scores for the same block, shape
            ``(T, num_questions)`` (presumably the SAKT sigmoid outputs).
        num_questions: number of distinct question ids.

    Returns:
        ``(true_labels, pred_pro)``. ``true_labels`` is an int64 tensor of 0/1
        correctness. ``pred_pro`` holds the matching scores. Both have length
        ``interactions - 1``, because the output at step ``t`` is scored against
        the interaction at step ``t + 1`` and the first one has no target.
    """
    # Column of the active one-hot entry for every interaction except the first.
    active_cols = torch.nonzero(x)[1:, -1]
    question_ids = active_cols % num_questions
    # Row k of `logits` is paired with question_ids[k], i.e. interaction k + 1.
    pred_pro = torch.gather(logits, 1, question_ids.unsqueeze(1)).reshape(-1)
    true_labels = active_cols // num_questions
    return true_labels, pred_pro
    

In [18]:
# Smoke test of parse_raw_data: the one-hot block itself is passed as a
# stand-in for model output, just to check shapes and the label extraction.
parse_raw_data(train_data[0],train_data[0])

(tensor([1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1,
         1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1,
         1]),
 tensor([1., 0., 0., 0., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
         0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 1., 1., 0., 0., 0., 0., 0.,
         1., 0., 0., 1., 1., 0., 1., 1., 0., 0., 1., 1., 0.]))

In [115]:
class Head(nn.Module):
    """Single causal (masked) self-attention head.

    Projects the input into key, query and value spaces of size ``head_size``.
    Each position may attend only to itself and earlier positions.
    """

    def __init__(self, embed_size, head_size):
        super().__init__()
        self.key = nn.Linear(embed_size, head_size)
        self.query = nn.Linear(embed_size, head_size)
        self.value = nn.Linear(embed_size, head_size)

    def forward(self, x):
        """Apply causal scaled dot-product attention.

        Args:
            x: tensor of shape ``(B, T, embed_size)``.

        Returns:
            tensor of shape ``(B, T, head_size)``.
        """
        k = self.key(x)
        q = self.query(x)
        v = self.value(x)
        _, T, C = k.shape
        # Attention scores (B, T, T), scaled by 1/sqrt(head_size).
        w = q @ k.transpose(-1, -2) * (C ** -0.5)
        # Build the causal mask on the input's device. A CPU-only mask makes
        # masked_fill fail once the model is moved to a GPU.
        mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=x.device))
        w = w.masked_fill(~mask, float("-inf"))
        w = F.softmax(w, dim=-1)
        return w @ v

In [116]:
class SAKT(nn.Module):
    """Self-attentive knowledge-tracing model built on one causal ``Head``.

    Maps interactions of shape ``(B, T, input_size)`` to per-step sigmoid
    outputs of shape ``(B, T, output_size)``. In this notebook,
    ``output_size`` is the number of questions.

    Args:
        input_size: width of each input step.
        output_size: number of sigmoid outputs per step.
        head_size: attention / hidden width (default 64, the previously
            hard-coded value).
    """

    def __init__(self, input_size, output_size, head_size=64):
        super().__init__()
        self.head = Head(input_size, head_size)
        # NOTE(review): there is no activation between `linear` and `proj`, so
        # the two layers compose to a single affine map. Kept as-is to preserve
        # existing behaviour; consider inserting a nonlinearity.
        self.linear = nn.Linear(head_size, head_size)
        self.proj = nn.Linear(head_size, output_size)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return per-step probabilities in ``[0, 1]`` for each output unit."""
        x = self.head(x)
        x = self.linear(x)
        x = self.proj(x)
        x = self.sigmoid(x)
        return x

In [79]:
# Number of distinct question ids; must agree with the num_questions=123
# default that get_data and parse_raw_data use.
num_questions = 123

In [117]:
# Inputs are the 2 * num_questions one-hot (question, correctness) columns
# built by get_data; outputs are one sigmoid probability per question.
sakt = SAKT(num_questions*2,num_questions)

In [110]:
# Plain BCE on probabilities: SAKT already applies a sigmoid in forward().
loss_func = nn.BCELoss()

In [111]:
def eval_data(true_labels, pred_pro, threshold=0.5):
    """Print and return the accuracy of thresholded probability predictions.

    Args:
        true_labels: 1-D tensor of 0/1 targets.
        pred_pro: 1-D tensor of predicted probabilities, same length.
        threshold: probabilities ``>= threshold`` count as a positive prediction.

    Returns:
        0-dim float tensor with the fraction of correct predictions (NaN when
        the inputs are empty, as 0/0).
    """
    # A metric helper should not switch a global model into eval mode as a
    # hidden side effect; callers manage train/eval mode themselves.
    pred_labels = (pred_pro >= threshold).float()
    acc = (true_labels == pred_labels).sum() / pred_labels.numel()
    print('acc', acc)
    return acc

In [113]:
# Adam over all SAKT parameters with a fixed learning rate (no schedule).
op = optim.Adam(sakt.parameters(),lr=0.0002)

In [None]:
# Train SAKT one block at a time (batch size 1). After each epoch, report the
# accuracy of the training predictions collected during that epoch.
num_epochs = 10
for _ in range(num_epochs):
    true_chunks = []
    pred_chunks = []
    sakt.train()
    for xb in train_data:
        logits = sakt(xb.unsqueeze(0))
        true_labels, pred_pro = parse_raw_data(xb, logits.squeeze(0))
        # Blocks with fewer than two interactions have no next-step target,
        # and BCELoss over an empty tensor is NaN, so skip them.
        if pred_pro.numel() == 0:
            continue
        loss = loss_func(pred_pro, true_labels.float())
        op.zero_grad()
        loss.backward()
        op.step()
        # Detach before storing. Keeping `pred_pro` attached would chain every
        # block's autograd graph into one ever-growing tensor.
        true_chunks.append(true_labels)
        pred_chunks.append(pred_pro.detach())
    # Concatenate once per epoch; torch.cat inside the loop was quadratic.
    true_arr = torch.cat(true_chunks) if true_chunks else torch.tensor([])
    pre_arr = torch.cat(pred_chunks) if pred_chunks else torch.tensor([])
    eval_data(true_arr, pre_arr)

acc tensor(0.4997)


In [104]:
def eval_all_data(data):
    """Report next-step prediction accuracy of the global ``sakt`` over ``data``.

    The model is put in eval mode and run without gradient tracking; no
    autograd graphs are built. Accuracy over all blocks is printed by
    ``eval_data``, and whatever ``eval_data`` returns is passed through.

    Args:
        data: iterable of blocks, each shaped like one element of ``get_data``'s
            output.
    """
    sakt.eval()
    true_chunks = []
    pred_chunks = []
    with torch.no_grad():
        for xb in data:
            logits = sakt(xb.unsqueeze(0))
            true_labels, pred_pro = parse_raw_data(xb, logits.squeeze(0))
            true_chunks.append(true_labels)
            pred_chunks.append(pred_pro)
    # Concatenate once instead of growing tensors inside the loop (quadratic).
    true_arr = torch.cat(true_chunks) if true_chunks else torch.tensor([])
    pre_arr = torch.cat(pred_chunks) if pred_chunks else torch.tensor([])
    return eval_data(true_arr, pre_arr)

In [105]:
# Accuracy of the trained model on the validation blocks.
eval_all_data(val_data)

acc tensor(0.7875)
