In [3]:
import os
import argparse
import csv
import torch

from data.simulated import SyntheticDataLoader
from data.assistment import AssistDataLoader

from loader.dkt import KTDataLoader

from trainer.assistment import AssistTrainer
from trainer.simulated import SimulatedTrainer

from model.lstm import LSTM
from model.rnn import RNN

import main

In [24]:
import easydict

args = easydict.EasyDict({
        "syn_path": '../data/synthetic/',
        "ast_path": '../data/assistments/',
        "dataset": 'assist',
    
        "syn_c": 5,
        "syn_q": 500,
        "syn_n": 4000,
        "syn_v": 0,
    
        "train_max_seq_len": 500,
        "eval_max_seq_len": 500,
        "use_dev": False,
    
        "model_type": 'LSTM',
        "input_type": 'onehot',
        "n_input_dim": 100,
        "embedding_dim": 100,
        "hidden_dim": 200,
        "n_layers": 1,
    
        "device": 'gpu',
        "device_id": 0,
    
        "loss": 'BCE',
        "optim": 'adam',
        "lr": 0.001,
        "momentum": 0.9,
        "n_epochs": 10,
        "clip_grad": 100.0,
        "log_steps": 10,
    
        "train_batch_size": 100,
        "eval_batch_size": 100,
    
        "task": 'dkt',
        "save_model": True,
        "save_dir": "./../ckpt/",
        "save_epochs": 10,
        "load_dir": "./../ckpt/",
        "load_model": True,
        "load_epoch": 30,
        "refresh_optim": False,
         
    
})

In [25]:
args.n_questions = 124 # 0-123
args.data = AssistDataLoader(args)
args.loader = KTDataLoader(args)
trainer = AssistTrainer(args)

args.model = LSTM

Loading Assisment...
training data:
 n_students: 3274
 n_questions: 124
 total answers: 407967
 longest: 4290
test data:
 n_students: 834
 n_questions: 124
 total answers: 117567
 longest: 8214
There are 1 GPU(s) available.
We will use the GPU: A100-SXM4-40GB MIG 1g.5gb


In [26]:
args.load_dir

'./../ckpt/'

In [31]:
!ls ./../ckpt

assist-dkt-LSTM-1-10.ckpt


In [None]:
# model.load_state_dict(torch.load('./ckpts/dkt/ASSIST2009/model.ckpt'))


In [34]:
train_loader, test_loader = trainer.dataloader()

In [41]:
for data in test_loader:
    print(data)
    break

[tensor([[51, 51, 51,  ...,  0,  0,  0],
        [82, 82, 82,  ...,  0,  0,  0],
        [ 0,  0,  0,  ...,  0,  0,  0],
        ...,
        [14, 14, 14,  ...,  0,  0,  0],
        [ 3,  3,  3,  ...,  0,  0,  0],
        [79, 79, 79,  ...,  0,  0,  0]]), tensor([[0, 1, 1,  ..., 0, 0, 0],
        [0, 0, 1,  ..., 0, 0, 0],
        [1, 0, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 0,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [0, 1, 1,  ..., 0, 0, 0]]), tensor([[ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False]]), tensor([  4.,   9.,  20.,  49.,  90.,  28.,  68.,   4.,  16.,  22., 101., 100.,
         13.,   2.,  33.,  13.,   8.,  11.,  39.,  51.,   4.,   4.,  64.,  30.,
  

In [39]:
trainer.predict(test_loader)

(array([[0.5010038 , 0.5010034 , 0.5010028 , ..., 0.        , 0.        ,
         0.        ],
        [0.5010125 , 0.501015  , 0.5010157 , ..., 0.        , 0.        ,
         0.        ],
        [0.50100356, 0.5010064 , 0.5010058 , ..., 0.        , 0.        ,
         0.        ],
        ...,
        [0.50101346, 0.500956  , 0.5009572 , ..., 0.        , 0.        ,
         0.        ],
        [0.5010071 , 0.5010027 , 0.50100064, ..., 0.        , 0.        ,
         0.        ],
        [0.50096714, 0.50096744, 0.5009672 , ..., 0.        , 0.        ,
         0.        ]], dtype=float32),
 array([[1., 1., 1., ..., 0., 0., 0.],
        [0., 1., 1., ..., 0., 0., 0.],
        [0., 1., 1., ..., 0., 0., 0.],
        ...,
        [1., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [1., 1., 0., ..., 0., 0., 0.]], dtype=float32),
 array([[1., 1., 1., ..., 0., 0., 0.],
        [1., 1., 1., ..., 0., 0., 0.],
        [1., 1., 1., ..., 0., 0., 0.],
        ...,


In [36]:
tst_tp, tst_tc, tst_tm = trainer.predict(test_loader)
tst_predict, tst_correct = trainer.metrics.flatten(tst_tp, tst_tc, tst_tm)


In [37]:
tst_predict

[0.5010038,
 0.5010034,
 0.5010028,
 0.5010125,
 0.501015,
 0.5010157,
 0.50101626,
 0.5010168,
 0.5010172,
 0.5010179,
 0.5010184,
 0.50100356,
 0.5010064,
 0.5010058,
 0.5010054,
 0.5009969,
 0.5009964,
 0.50099605,
 0.50096196,
 0.5009592,
 0.5009651,
 0.5009685,
 0.5009707,
 0.50097203,
 0.50097287,
 0.50100803,
 0.50100946,
 0.5010098,
 0.5010097,
 0.50100946,
 0.50100654,
 0.50096184,
 0.50095856,
 0.50095725,
 0.50095654,
 0.50097495,
 0.5009734,
 0.5009727,
 0.5009732,
 0.50097287,
 0.5009727,
 0.5009832,
 0.5009824,
 0.5009859,
 0.50098366,
 0.50098646,
 0.50098836,
 0.500985,
 0.5009665,
 0.500968,
 0.50096875,
 0.500969,
 0.50096565,
 0.50101745,
 0.5010181,
 0.5010181,
 0.50101805,
 0.501018,
 0.501018,
 0.50101805,
 0.5010044,
 0.5010067,
 0.5010082,
 0.50101036,
 0.50101227,
 0.5010137,
 0.5010147,
 0.50101584,
 0.5010157,
 0.50101525,
 0.50101477,
 0.5010137,
 0.50101465,
 0.50101465,
 0.5010145,
 0.5010143,
 0.50101405,
 0.5010132,
 0.5009678,
 0.5009745,
 0.50098234,
 

In [38]:
tst_correct

[1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 0.0,
 1.0,
 0.0,
 1.0,
 1.0,
 0.0,
 1.0,
 1.0,
 1.0,
 0.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 1.0,
 0.0,
 1.0,
 1.0,
 0.0,
 1.0,
 0.0,
 1.0,
 0.0,
 1.0,
 1.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 1.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 0.0,
 0.0,
 0.0,
 1.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 1.0,
 0.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 0.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 1.0,
 0.0,
 1.0,
 0.0,
 0.0,
 0.0,
 0.0,
 1.0,
 1.0,
 0.0,
 1.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 1.0,
 1.0,
 0.0,
 1.0,
 0.0,
 1.0,
 1.0,
 0.0,
 1.0,
 0.0,
 0.0,
 0.0,
 1.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 1.0,
 1.0,
 0.0,
 1.0