In [1]:
import numpy as np
import pandas as pd

import gc
import random
from tqdm import tqdm
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn_utils
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
train_df = pd.read_csv('/opt/ml/input/data/train_data.csv')
train_df

Unnamed: 0,userID,assessmentItemID,testId,answerCode,Timestamp,KnowledgeTag
0,0,A060001001,A060000001,1,2020-03-24 00:17:11,7224
1,0,A060001002,A060000001,1,2020-03-24 00:17:14,7225
2,0,A060001003,A060000001,1,2020-03-24 00:17:22,7225
3,0,A060001004,A060000001,1,2020-03-24 00:17:29,7225
4,0,A060001005,A060000001,1,2020-03-24 00:17:36,7225
...,...,...,...,...,...,...
2266581,7441,A030071005,A030000071,0,2020-06-05 06:50:21,438
2266582,7441,A040165001,A040000165,1,2020-08-21 01:06:39,8836
2266583,7441,A040165002,A040000165,1,2020-08-21 01:06:50,8836
2266584,7441,A040165003,A040000165,1,2020-08-21 01:07:36,8836


In [3]:
problems = train_df['assessmentItemID'].unique()
num_problems = len(problems)

In [4]:
ref_df = train_df.copy()

In [5]:
assessmentItemID_to_idx = {v:k for k,v in enumerate(ref_df['assessmentItemID'].unique())}
idx_to_assessmentItemID = {k:v for k,v in enumerate(ref_df['assessmentItemID'].unique())}

In [6]:
ref_df['assessmentItemID'] = ref_df['assessmentItemID'].map(assessmentItemID_to_idx)
ref_df

Unnamed: 0,userID,assessmentItemID,testId,answerCode,Timestamp,KnowledgeTag
0,0,0,A060000001,1,2020-03-24 00:17:11,7224
1,0,1,A060000001,1,2020-03-24 00:17:14,7225
2,0,2,A060000001,1,2020-03-24 00:17:22,7225
3,0,3,A060000001,1,2020-03-24 00:17:29,7225
4,0,4,A060000001,1,2020-03-24 00:17:36,7225
...,...,...,...,...,...,...
2266581,7441,3147,A030000071,0,2020-06-05 06:50:21,438
2266582,7441,1286,A040000165,1,2020-08-21 01:06:39,8836
2266583,7441,1287,A040000165,1,2020-08-21 01:06:50,8836
2266584,7441,1288,A040000165,1,2020-08-21 01:07:36,8836


In [7]:
ref_df = ref_df[['userID', 'assessmentItemID', 'answerCode']].groupby('userID').apply(lambda r: (
            r['assessmentItemID'].values,
            r['answerCode'].values))

In [8]:
ref_df[7441]

(array([3143, 3144, 3145, 3146, 3147, 1286, 1287, 1288, 1289]),
 array([0, 0, 1, 0, 0, 1, 1, 1, 1]))

In [9]:
class SAKTDataset(Dataset):
    def __init__(self, group, n_skill, max_seq=100):
        super(SAKTDataset, self).__init__()
        self.max_seq = max_seq
        self.n_skill = n_skill
        self.samples = group
        
#         self.user_ids = [x for x in group.index]
        self.user_ids = []
        for user_id in group.index:
            q, qa = group[user_id]
            if len(q) < 10:
                continue
            self.user_ids.append(user_id)

    def __len__(self):
        return len(self.user_ids)

    def __getitem__(self, index):
        user_id = self.user_ids[index]
        q_, qa_ = self.samples[user_id]
        seq_len = len(q_)

        q = np.zeros(self.max_seq, dtype=int)
        qa = np.zeros(self.max_seq, dtype=int)
        if seq_len >= self.max_seq:
            q[:] = q_[-self.max_seq:]
            qa[:] = qa_[-self.max_seq:]
        else:
            q[-seq_len:] = q_
            qa[-seq_len:] = qa_
        
        target_id = q[1:]
        label = qa[1:]

        x = np.zeros(self.max_seq-1, dtype=int)
        x = q[:-1].copy()
        x += (qa[:-1] == 1) * self.n_skill

        return x, target_id, label

In [10]:
train, val = train_test_split(ref_df, test_size=0.2)

train_dataset = SAKTDataset(train, num_problems)
train_dataloader = DataLoader(train_dataset, batch_size=2048, shuffle=True, num_workers=8)
del train

val_dataset = SAKTDataset(val, num_problems)
val_dataloader = DataLoader(val_dataset, batch_size=2048, shuffle=True, num_workers=8)
del val

In [11]:
class FFN(nn.Module):
    def __init__(self, state_size=200):
        super(FFN, self).__init__()
        self.state_size = state_size

        self.lr1 = nn.Linear(state_size, state_size)
        self.relu = nn.ReLU()
        self.lr2 = nn.Linear(state_size, state_size)
        self.dropout = nn.Dropout(0.2)
    
    def forward(self, x):
        x = self.lr1(x)
        x = self.relu(x)
        x = self.lr2(x)
        return self.dropout(x)

def future_mask(seq_length):
    future_mask = np.triu(np.ones((seq_length, seq_length)), k=1).astype('bool')
    return torch.from_numpy(future_mask)


class SAKTModel(nn.Module):
    def __init__(self, n_skill, max_seq=100, embed_dim=128):
        super(SAKTModel, self).__init__()
        self.n_skill = n_skill
        self.embed_dim = embed_dim

        self.embedding = nn.Embedding(2*n_skill+1, embed_dim)
        self.pos_embedding = nn.Embedding(max_seq-1, embed_dim)
        self.e_embedding = nn.Embedding(n_skill+1, embed_dim)

        self.multi_att = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=8, dropout=0.2)

        self.dropout = nn.Dropout(0.2)
        self.layer_normal = nn.LayerNorm(embed_dim) 

        self.ffn = FFN(embed_dim)
        self.pred = nn.Linear(embed_dim, 1)
    
    def forward(self, x, question_ids):
        device = x.device        
        x = self.embedding(x)
        pos_id = torch.arange(x.size(1)).unsqueeze(0).to(device)

        pos_x = self.pos_embedding(pos_id)
        x = x + pos_x

        e = self.e_embedding(question_ids)

        x = x.permute(1, 0, 2) # x: [bs, s_len, embed] => [s_len, bs, embed]
        e = e.permute(1, 0, 2)
        att_mask = future_mask(x.size(0)).to(device)
        att_output, att_weight = self.multi_att(e, x, x, attn_mask=att_mask)
        att_output = self.layer_normal(att_output + e)
        att_output = att_output.permute(1, 0, 2) # att_output: [s_len, bs, embed] => [bs, s_len, embed]

        x = self.ffn(att_output)
        x = self.layer_normal(x + att_output)
        x = self.pred(x)

        return x.squeeze(-1), att_weight

In [12]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [13]:
model = SAKTModel(num_problems, embed_dim=128)
# optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.99, weight_decay=0.005)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.BCEWithLogitsLoss() # Binary Cross Entropy Loss에 Sigmod가 결합된 형태

model.to(device)
criterion.to(device)


BCEWithLogitsLoss()

In [14]:
def train_epoch(model, train_iterator, optim, criterion, device="cuda"):
    model.train()

    train_loss = []
    num_corrects = 0
    num_total = 0
    labels = []
    outs = []

    tbar = tqdm(train_iterator)
    for item in tbar:
        x = item[0].to(device).long()
        target_id = item[1].to(device).long()
        label = item[2].to(device).float()
        target_mask = (target_id != 0)

        optim.zero_grad()
        output, atten_weight = model(x, target_id)
        
        output = torch.masked_select(output, target_mask)
        label = torch.masked_select(label, target_mask)
        
        loss = criterion(output, label)
        loss.backward()
        optim.step()
        train_loss.append(loss.item())
        pred = (torch.sigmoid(output) >= 0.5).long()
        
        num_corrects += (pred == label).sum().item()
        num_total += len(label)

        labels.extend(label.view(-1).data.cpu().numpy())
        outs.extend(output.view(-1).data.cpu().numpy())

        tbar.set_description('loss - {:.4f}'.format(loss))

    acc = num_corrects / num_total
    auc = roc_auc_score(labels, outs)
    loss = np.average(train_loss)

    return loss, acc, auc

In [15]:
def val_epoch(model, val_iterator, criterion, device="cuda"):
    model.eval()

    train_loss = []
    num_corrects = 0
    num_total = 0
    labels = []
    outs = []

    tbar = tqdm(val_iterator)
    for item in tbar:
        x = item[0].to(device).long()
        target_id = item[1].to(device).long()
        label = item[2].to(device).float()
        target_mask = (target_id != 0)

        with torch.no_grad():
            output, atten_weight = model(x, target_id)
        
        output = torch.masked_select(output, target_mask)
        label = torch.masked_select(label, target_mask)

        loss = criterion(output, label)
        train_loss.append(loss.item())

        pred = (torch.sigmoid(output) >= 0.5).long()
        
        num_corrects += (pred == label).sum().item()
        num_total += len(label)

        labels.extend(label.view(-1).data.cpu().numpy())
        outs.extend(output.view(-1).data.cpu().numpy())

        tbar.set_description('loss - {:.4f}'.format(loss))

    acc = num_corrects / num_total
    auc = roc_auc_score(labels, outs)
    loss = np.average(train_loss)

    return loss, acc, auc

In [16]:
epochs = 20

over_fit = 0
last_auc = 0
for epoch in range(epochs):
    train_loss, train_acc, train_auc = train_epoch(model, train_dataloader, optimizer, criterion, device)
    print("epoch - {} train_loss - {:.2f} acc - {:.3f} auc - {:.3f}".format(epoch, train_loss, train_acc, train_auc))
    
    val_loss, avl_acc, val_auc = val_epoch(model, val_dataloader, criterion, device)
    print("epoch - {} val_loss - {:.2f} acc - {:.3f} auc - {:.3f}".format(epoch, val_loss, avl_acc, val_auc))
    
    if val_auc > last_auc:
        last_auc = val_auc
        over_fit = 0
    else:
        over_fit += 1
        
    
    if over_fit >= 2:
        print("early stop epoch ", epoch)
        break

loss - 0.6900: 100%|██████████| 3/3 [00:01<00:00,  2.25it/s]


epoch - 0 train_loss - 0.71 acc - 0.532 auc - 0.497


loss - 0.6871: 100%|██████████| 1/1 [00:00<00:00,  2.62it/s]


epoch - 0 val_loss - 0.69 acc - 0.608 auc - 0.509


loss - 0.6882: 100%|██████████| 3/3 [00:00<00:00,  4.83it/s]


epoch - 1 train_loss - 0.69 acc - 0.609 auc - 0.512


loss - 0.6798: 100%|██████████| 1/1 [00:00<00:00,  2.64it/s]


epoch - 1 val_loss - 0.68 acc - 0.615 auc - 0.521


loss - 0.6744: 100%|██████████| 3/3 [00:00<00:00,  4.86it/s]


epoch - 2 train_loss - 0.68 acc - 0.611 auc - 0.525


loss - 0.6685: 100%|██████████| 1/1 [00:00<00:00,  2.57it/s]


epoch - 2 val_loss - 0.67 acc - 0.610 auc - 0.528


loss - 0.6608: 100%|██████████| 3/3 [00:00<00:00,  5.05it/s]


epoch - 3 train_loss - 0.67 acc - 0.610 auc - 0.533


loss - 0.6621: 100%|██████████| 1/1 [00:00<00:00,  2.47it/s]


epoch - 3 val_loss - 0.66 acc - 0.617 auc - 0.537


loss - 0.6632: 100%|██████████| 3/3 [00:00<00:00,  4.80it/s]


epoch - 4 train_loss - 0.66 acc - 0.617 auc - 0.542


loss - 0.6582: 100%|██████████| 1/1 [00:00<00:00,  2.61it/s]


epoch - 4 val_loss - 0.66 acc - 0.624 auc - 0.545


loss - 0.6560: 100%|██████████| 3/3 [00:00<00:00,  4.65it/s]


epoch - 5 train_loss - 0.66 acc - 0.622 auc - 0.551


loss - 0.6556: 100%|██████████| 1/1 [00:00<00:00,  2.50it/s]


epoch - 5 val_loss - 0.66 acc - 0.627 auc - 0.556


loss - 0.6531: 100%|██████████| 3/3 [00:00<00:00,  4.85it/s]


epoch - 6 train_loss - 0.66 acc - 0.625 auc - 0.563


loss - 0.6531: 100%|██████████| 1/1 [00:00<00:00,  2.54it/s]


epoch - 6 val_loss - 0.65 acc - 0.628 auc - 0.571


loss - 0.6533: 100%|██████████| 3/3 [00:00<00:00,  4.76it/s]


epoch - 7 train_loss - 0.65 acc - 0.626 auc - 0.579


loss - 0.6512: 100%|██████████| 1/1 [00:00<00:00,  2.42it/s]


epoch - 7 val_loss - 0.65 acc - 0.628 auc - 0.581


loss - 0.6469: 100%|██████████| 3/3 [00:00<00:00,  4.74it/s]


epoch - 8 train_loss - 0.65 acc - 0.626 auc - 0.589


loss - 0.6500: 100%|██████████| 1/1 [00:00<00:00,  2.63it/s]


epoch - 8 val_loss - 0.65 acc - 0.628 auc - 0.585


loss - 0.6443: 100%|██████████| 3/3 [00:00<00:00,  4.98it/s]


epoch - 9 train_loss - 0.65 acc - 0.627 auc - 0.596


loss - 0.6483: 100%|██████████| 1/1 [00:00<00:00,  2.51it/s]


epoch - 9 val_loss - 0.65 acc - 0.628 auc - 0.593


loss - 0.6452: 100%|██████████| 3/3 [00:00<00:00,  4.97it/s]


epoch - 10 train_loss - 0.65 acc - 0.628 auc - 0.607


loss - 0.6458: 100%|██████████| 1/1 [00:00<00:00,  2.56it/s]


epoch - 10 val_loss - 0.65 acc - 0.629 auc - 0.604


loss - 0.6358: 100%|██████████| 3/3 [00:00<00:00,  4.86it/s]


epoch - 11 train_loss - 0.64 acc - 0.630 auc - 0.619


loss - 0.6431: 100%|██████████| 1/1 [00:00<00:00,  2.59it/s]


epoch - 11 val_loss - 0.64 acc - 0.631 auc - 0.614


loss - 0.6333: 100%|██████████| 3/3 [00:00<00:00,  4.73it/s]


epoch - 12 train_loss - 0.64 acc - 0.633 auc - 0.630


loss - 0.6400: 100%|██████████| 1/1 [00:00<00:00,  2.51it/s]


epoch - 12 val_loss - 0.64 acc - 0.635 auc - 0.623


loss - 0.6313: 100%|██████████| 3/3 [00:00<00:00,  4.78it/s]


epoch - 13 train_loss - 0.64 acc - 0.638 auc - 0.641


loss - 0.6365: 100%|██████████| 1/1 [00:00<00:00,  2.57it/s]


epoch - 13 val_loss - 0.64 acc - 0.639 auc - 0.631


loss - 0.6281: 100%|██████████| 3/3 [00:00<00:00,  4.85it/s]


epoch - 14 train_loss - 0.63 acc - 0.645 auc - 0.651


loss - 0.6325: 100%|██████████| 1/1 [00:00<00:00,  2.56it/s]


epoch - 14 val_loss - 0.63 acc - 0.644 auc - 0.640


loss - 0.6259: 100%|██████████| 3/3 [00:00<00:00,  4.69it/s]


epoch - 15 train_loss - 0.62 acc - 0.651 auc - 0.663


loss - 0.6282: 100%|██████████| 1/1 [00:00<00:00,  2.40it/s]


epoch - 15 val_loss - 0.63 acc - 0.648 auc - 0.649


loss - 0.6171: 100%|██████████| 3/3 [00:00<00:00,  4.88it/s]


epoch - 16 train_loss - 0.62 acc - 0.659 auc - 0.674


loss - 0.6240: 100%|██████████| 1/1 [00:00<00:00,  2.50it/s]


epoch - 16 val_loss - 0.62 acc - 0.653 auc - 0.658


loss - 0.6069: 100%|██████████| 3/3 [00:00<00:00,  4.93it/s]


epoch - 17 train_loss - 0.61 acc - 0.668 auc - 0.686


loss - 0.6205: 100%|██████████| 1/1 [00:00<00:00,  2.44it/s]


epoch - 17 val_loss - 0.62 acc - 0.658 auc - 0.667


loss - 0.6042: 100%|██████████| 3/3 [00:00<00:00,  4.81it/s]


epoch - 18 train_loss - 0.60 acc - 0.675 auc - 0.697


loss - 0.6175: 100%|██████████| 1/1 [00:00<00:00,  2.38it/s]


epoch - 18 val_loss - 0.62 acc - 0.662 auc - 0.675


loss - 0.5989: 100%|██████████| 3/3 [00:00<00:00,  4.79it/s]


epoch - 19 train_loss - 0.60 acc - 0.682 auc - 0.707


loss - 0.6135: 100%|██████████| 1/1 [00:00<00:00,  2.23it/s]

epoch - 19 val_loss - 0.61 acc - 0.667 auc - 0.683



