In [79]:
import matplotlib.pyplot as plt 
import json
import pickle
import torch
import random 
import numpy as np
import os 

from torch.optim import SGD, Adam
from torch.utils.data import DataLoader, random_split
from torch.nn.functional import binary_cross_entropy
from sklearn import metrics 

from data_loaders.assist2009 import ASSIST2009
from data_loaders.assist2012 import ASSIST2012
from data_loaders.ednet01 import EdNet01

from models.dkvmn_text import SUBJ_DKVMN
from models.dkvmn_text import train_model as plus_train

from models.utils import collate_fn, collate_ednet, cal_acc_class

%matplotlib inline

# Seed every random number generator used in this notebook with the same value.
seed = 1004
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(seed)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [80]:
# Identity of the experiment: which model on which dataset.
model_name = 'dkvmn+'
dataset_name = 'ASSIST2009'
dataset = None  # placeholder; the dataset object itself is built in a later cell
ckpts = f"ckpts/{model_name}/{dataset_name}/"

# Parse the configuration file, then pick out the model-specific and training sections.
with open("config.json") as f:
    config = json.load(f)
model_config = config[model_name]
train_config = config["train_config"]

# Training hyper-parameters, pulled from the "train_config" section in one go.
#   optimizer: can be sgd, adam
#   seq_len:   number of items to sample
batch_size, num_epochs, train_ratio, learning_rate, optimizer, seq_len = (
    train_config[key]
    for key in (
        "batch_size",
        "num_epochs",
        "train_ratio",
        "learning_rate",
        "optimizer",
        "seq_len",
    )
)

In [81]:
# Additional datasets can be supported by adding a branch below.
collate_pt = collate_fn
if dataset_name == "ASSIST2009":
    dataset = ASSIST2009(seq_len, 'datasets/ASSIST2009/')
elif dataset_name == "ASSIST2012":
    dataset = ASSIST2012(seq_len, 'datasets/ASSIST2012/')
else:
    # Fail right here with a clear message instead of leaving `dataset` unset
    # and crashing later with an unrelated AttributeError.
    raise ValueError(
        f"Unsupported dataset_name {dataset_name!r}; expected 'ASSIST2009' or 'ASSIST2012'"
    )

In [82]:
def train_model(model, test_loader, ckpt_path):
    '''
        Evaluate a saved model on the test set. Despite the name, no training
        is performed: the weights stored in ``<ckpt_path>/model.ckpt`` are
        loaded into ``model`` and a single gradient-free pass over
        ``test_loader`` is made, printing and collecting metrics per batch.

        Args:
            model: the model to evaluate; its weights are replaced by the checkpoint
            test_loader: the PyTorch DataLoader instance for test
            ckpt_path: the directory that contains ``model.ckpt``

        Returns:
            A tuple ``(aucs, loss_means, accs, q_accs, cnt, Mv, w)`` where
                aucs: list with one ROC-AUC value per batch
                loss_means: one-element list with the mean of the per-batch losses
                accs: list with one accuracy value (threshold 0.5) per batch
                q_accs, cnt: the result of ``cal_acc_class`` for the LAST batch only
                Mv, w: the second and third outputs of ``model`` for the LAST batch
            When ``test_loader`` yields no batch, ``q_accs`` is ``{}`` and
            ``cnt``, ``Mv`` and ``w`` are ``None``.

        Raises:
            ValueError: propagated from ``roc_auc_score``, which (per the
                scikit-learn docs) rejects targets that contain a single class.
    '''
    aucs = []
    loss_means = []
    accs = []
    q_accs = {}
    # Bound up front so that an empty loader returns None instead of raising NameError.
    cnt = Mv = w = None

    # Test
    state_dict = torch.load(os.path.join(ckpt_path, "model.ckpt"), map_location=device)
    model.load_state_dict(state_dict)
    model.eval()  # evaluation mode only needs to be set once, not per batch

    batch_losses = []
    with torch.no_grad():
        for i, data in enumerate(test_loader):
            # Each batch is a 12-tuple; the entries bound to `_` are not needed here.
            q, r, _, _, m, bert_s, bert_t, bert_m, q2diff_seqs, pid_seqs, _, _ = data

            y, Mv, w = model(q.long(), r.long(), bert_s, bert_t, bert_m, q2diff_seqs.long(), pid_seqs.long())

            # Keep only the entries of q, y and r whose mask value is True.
            q = torch.masked_select(q, m).detach().cpu()
            y = torch.masked_select(y, m).detach().cpu()
            t = torch.masked_select(r, m).detach().cpu()

            y_np = y.numpy()
            t_np = t.numpy()

            auc = metrics.roc_auc_score(y_true=t_np, y_score=y_np)
            # Binarise the predicted probabilities at 0.5.
            bin_y = (y_np >= 0.5).astype(int).tolist()
            acc = metrics.accuracy_score(t_np, bin_y)
            # Binary cross entropy between the masked predictions and the masked responses.
            loss = binary_cross_entropy(y, t)

            print(f"[Test] number: {i}, AUC: {auc}, ACC: :{acc} Loss: {loss} ")

            # evaluation metrics
            aucs.append(auc)
            batch_losses.append(loss.item())  # store a plain float, not a tensor
            accs.append(acc)
            # NOTE(review): this overwrites the previous batch's per-question
            # statistics, so only the last batch is returned - confirm whether
            # accumulation across batches was intended.
            q_accs, cnt = cal_acc_class(q.long(), t.long(), bin_y)
        loss_means.append(np.mean(batch_losses))

    return aucs, loss_means, accs, q_accs, cnt, Mv, w

In [83]:
# Instantiate SUBJ_DKVMN with dataset.num_q, dataset.num_pid and the model
# section of config.json, wrap it in DataParallel and move it to `device`.
model = torch.nn.DataParallel(
    SUBJ_DKVMN(dataset.num_q, num_qid=dataset.num_pid, **model_config)
).to(device)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [84]:
# Split the dataset into train / validation / test subsets.
# The validation share is half of what is left after the training share;
# the test subset receives every remaining sample.
data_size = len(dataset)
train_size = int(data_size * train_ratio)
valid_size = int(data_size * ((1.0 - train_ratio) / 2.0))
test_size = data_size - train_size - valid_size

train_dataset, valid_dataset, test_dataset = random_split(
    dataset,
    [train_size, valid_size, test_size],
    generator=torch.Generator(device=device),
)

# If a split was pickled earlier, restore those exact indices over the fresh split.
# Only the presence of the training file is checked before all three are opened.
# NOTE: pickle.load can execute arbitrary code; only load index files you trust.
if os.path.exists(os.path.join(dataset.dataset_dir, "train_indices.pkl")):
    _saved_splits = (
        (train_dataset, "train_indices.pkl"),
        # Spelling kept as-is: presumably matches the file name on disk - verify before renaming.
        (valid_dataset, "valid_indicies.pkl"),
        (test_dataset, "test_indices.pkl"),
    )
    for _subset, _file_name in _saved_splits:
        with open(os.path.join(dataset.dataset_dir, _file_name), "rb") as f:
            _subset.indices = pickle.load(f)

In [85]:
def _make_loader(subset):
    """Wrap ``subset`` in a shuffled DataLoader.

    Uses the notebook-level ``batch_size``, ``collate_pt`` and ``device``; every
    loader gets its own freshly created ``torch.Generator`` on ``device``.
    """
    return DataLoader(
        subset, batch_size=batch_size, shuffle=True,
        collate_fn=collate_pt, generator=torch.Generator(device=device)
    )


train_loader = _make_loader(train_dataset)
valid_loader = _make_loader(valid_dataset)
test_loader = _make_loader(test_dataset)

if optimizer == "sgd":
    opt = SGD(model.parameters(), learning_rate, momentum=0.9)
elif optimizer == "adam":
    opt = Adam(model.parameters(), learning_rate)
else:
    # Without this branch `opt` would stay undefined and the scheduler line
    # below would fail with an unhelpful NameError.
    raise ValueError(f"Unsupported optimizer {optimizer!r}; expected 'sgd' or 'adam'")
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(opt, gamma=0.5)
opt.lr_scheduler = lr_scheduler

# Compute the AUCs, losses and accuracies on the test set with train_model().
# NOTE: train_model() as defined in this notebook only evaluates; the optimizer
# and scheduler built above are not passed to it.
aucs, loss_means, accs, q_accs, q_cnts, Mv, w = \
    train_model(
        model, test_loader, ckpts
    )

In [None]:
print(Mv) # NOTE(review): axes presumably relate to number of concepts / sequence (?) / embedding size - confirm

tensor([[[[ 0.1046, -0.0137,  0.0843,  ...,  0.1525, -0.2694, -0.1215],
          [ 0.0667, -0.0029,  0.3115,  ...,  0.2028, -0.1982,  0.1142],
          [-0.1486,  0.0222, -0.0744,  ..., -0.2432, -0.0805, -0.0261],
          ...,
          [-0.0995, -0.0133, -0.0387,  ...,  0.0863, -0.1137, -0.2638],
          [ 0.0052, -0.0726, -0.0853,  ...,  0.0196, -0.2335,  0.2153],
          [-0.0590,  0.2141, -0.0652,  ..., -0.2442, -0.2831, -0.1283]],

         [[ 0.1050, -0.0134,  0.0838,  ...,  0.1526, -0.2688, -0.1211],
          [ 0.2006,  0.0828,  0.0804,  ...,  0.2108, -0.0269,  0.1822],
          [-0.1484,  0.0223, -0.0745,  ..., -0.2430, -0.0804, -0.0259],
          ...,
          [-0.0882, -0.0080, -0.0455,  ...,  0.0891, -0.1051, -0.2522],
          [ 0.0058, -0.0721, -0.0857,  ...,  0.0199, -0.2327,  0.2155],
          [-0.0588,  0.2141, -0.0653,  ..., -0.2440, -0.2828, -0.1281]],

         [[ 0.1055, -0.0132,  0.0833,  ...,  0.1526, -0.2682, -0.1206],
          [ 0.3133,  0.1217, -

In [71]:
# Show the shape of Mv with its last entry along dim 1 dropped, next to the shape of w.
print(Mv[:, :-1].shape, w.shape)

torch.Size([13, 100, 50, 100])
