Paper:   [SAINT+: Integrating Temporal Features for EdNet Correctness Prediction](https://arxiv.org/abs/2010.12042)

# Import everything now...

In [1]:
import torch
from torch import nn
import pandas as pd
import numpy as np
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split
import torch
import numpy as np 
from torch import nn 
import copy 
import pytorch_lightning as pl 
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
import gc
from sklearn.model_selection import train_test_split 
from sklearn.metrics import roc_auc_score
import torch.nn.functional as F
import gc
import random

# Configure constants

In [2]:
class config:
        device = torch.device("cuda") 
        MAX_SEQ = 55
        EMBED_DIMS = 512
        ENC_HEADS = DEC_HEADS = 8
        NUM_ENCODER = NUM_DECODER = 4
        BATCH_SIZE = 256
        #TRAIN_FILE = "../input/riiid-test-answer-prediction/train.csv"
        TOTAL_EXE = 13523
        TOTAL_CAT = 8
        SAMPLE = 50000
        EPOCHS = 10
        lagtime = True

In [3]:
pl.seed_everything(42)

42

In [4]:
def feature_time_lag(df, time_dict):

    tt = np.zeros(len(df), dtype=np.int64)

    for ind, row in enumerate(df[['user_id','timestamp','task_container_id']].values):

        if row[0] in time_dict.keys():
            if row[2]-time_dict[row[0]][1] == 0:

                tt[ind] = time_dict[row[0]][2]

            else:
                t_last = time_dict[row[0]][0]
                task_ind_last = time_dict[row[0]][1]
                tt[ind] = row[1]-t_last
                time_dict[row[0]] = (row[1], row[2], tt[ind])
        else:
            # time_dict : timestamp, task_container_id, lag_time
            time_dict[row[0]] = (row[1], row[2], -1)
            tt[ind] =  0

    df["time_lag"] = tt
    return df

In [5]:
def reduce_mem_usage(df):
    """ 
    iterate through all the columns of a dataframe and 
    modify the data type to reduce memory usage.        
    """
    start_mem = df.memory_usage().sum() / 1024**2
    print(('Memory usage of dataframe is {:.2f}' 
                     'MB').format(start_mem))
    
    for col in df.columns:
        col_type = df[col].dtype
        
        if col_type != object:
            c_min = df[col].min()
            c_max = df[col].max()
            if str(col_type)[:3] == 'int':
                if c_min > np.iinfo(np.int8).min and c_max <\
                  np.iinfo(np.int8).max:
                    df[col] = df[col].astype(np.int8)
                elif c_min > np.iinfo(np.int16).min and c_max <\
                   np.iinfo(np.int16).max:
                    df[col] = df[col].astype(np.int16)
                elif c_min > np.iinfo(np.int32).min and c_max <\
                   np.iinfo(np.int32).max:
                    df[col] = df[col].astype(np.int32)
                elif c_min > np.iinfo(np.int64).min and c_max <\
                   np.iinfo(np.int64).max:
                    df[col] = df[col].astype(np.int64)  
            else:
                if c_min > np.finfo(np.float16).min and c_max <\
                   np.finfo(np.float16).max:
                    df[col] = df[col].astype(np.float16)
                elif c_min > np.finfo(np.float32).min and c_max <\
                   np.finfo(np.float32).max:
                    df[col] = df[col].astype(np.float32)
                else:
                    df[col] = df[col].astype(np.float64)
        else:
            df[col] = df[col].astype('category')
    end_mem = df.memory_usage().sum() / 1024**2
    print(('Memory usage after optimization is: {:.2f}' 
                              'MB').format(end_mem))
    print('Decreased by {:.1f}%'.format(100 * (start_mem - end_mem) 
                                             / start_mem))
    
    return df

# Dataset

In [6]:
class DKTDataset(Dataset):
  def __init__(self,samples,max_seq,start_token=0): 
    super().__init__()
    self.samples = samples
    self.max_seq = max_seq
    self.start_token = start_token
    self.data = []
    for id in self.samples.index:
      exe_ids,answers,ela_time,categories = self.samples[id]
      if len(exe_ids)>max_seq:
        self.data.extend([(exe_ids[l:l+self.max_seq],answers[l:l+self.max_seq],
                                   ela_time[l:l+self.max_seq],categories[l:l+self.max_seq])\
                for l in range(len(exe_ids)) if l%self.max_seq==0])
        #for l in range((len(exe_ids)+max_seq-1)//max_seq):
        #    self.data.append((exe_ids[l:l+max_seq],answers[l:l+max_seq],ela_time[l:l+max_seq],categories[l:l+max_seq]))
      elif len(exe_ids)<self.max_seq and len(exe_ids)>9:
            self.data.append((exe_ids,answers,ela_time,categories))
      else :
            continue

  def __len__(self):
    return len(self.data)
  
  def __getitem__(self,idx):
    question_ids,answers,ela_time,exe_category = self.data[idx]
    seq_len = len(question_ids)

    exe_ids = np.zeros(self.max_seq,dtype=int)
    ans = np.zeros(self.max_seq,dtype=int)
    elapsed_time = np.zeros(self.max_seq,dtype=int)
    exe_cat = np.zeros(self.max_seq,dtype=int)
    if seq_len<self.max_seq:
      exe_ids[-seq_len:] = question_ids
      ans[-seq_len:] = answers
      elapsed_time[-seq_len:] = ela_time 
      exe_cat[-seq_len:] = exe_category
    else:
      exe_ids[:] = question_ids[-self.max_seq:]
      ans[:] = answers[-self.max_seq:]
      elapsed_time[:] = ela_time[-self.max_seq:]
      exe_cat[:] = exe_category[-self.max_seq:]

    #input_rtime = np.zeros(self.max_seq,dtype=int)
    #input_rtime = np.insert(elapsed_time,0,self.start_token)
    #input_rtime = np.delete(input_rtime,-1)
    input_rtime = elapsed_time[:].copy()
    input = {"input_ids":exe_ids,"input_rtime":input_rtime.astype(np.int),"input_cat":exe_cat}
    answers = np.append([0],ans[:-1]) #start token
    assert ans.shape[0]==answers.shape[0] and answers.shape[0]==input_rtime.shape[0], "both ans and label should be \
                                                                                            same len with start-token"
    return input,answers,ans


# SAINT+ model

In [7]:
class FFN(nn.Module):
  def __init__(self,in_feat):
    super(FFN,self).__init__()
    self.linear1 = nn.Linear(in_feat,in_feat)
    self.linear2 = nn.Linear(in_feat,in_feat)
    #self.drop = nn.Dropout(0.2)
  
  def forward(self,x):
    #out = F.relu(self.drop(self.linear1(x)))
    out = F.relu(self.linear1(x))
    out = self.linear2(out)
    return out 


class EncoderEmbedding(nn.Module):
  def __init__(self,n_exercises,n_categories,n_dims,seq_len):
    super(EncoderEmbedding,self).__init__()
    self.n_dims = n_dims
    self.seq_len = seq_len
    self.exercise_embed = nn.Embedding(n_exercises,n_dims)
    self.category_embed = nn.Embedding(n_categories,n_dims)
    self.position_embed = nn.Embedding(seq_len,n_dims)

  def forward(self,exercises,categories):
    e = self.exercise_embed(exercises)
    c = self.category_embed(categories)
    seq = torch.arange(self.seq_len,device=config.device).unsqueeze(0)
    p = self.position_embed(seq)
    return p + c + e

class DecoderEmbedding(nn.Module):
  def __init__(self,n_responses,n_dims,seq_len):
    super(DecoderEmbedding,self).__init__()
    self.n_dims = n_dims
    self.seq_len = seq_len
    self.response_embed = nn.Embedding(n_responses,n_dims)
    self.time_embed = nn.Linear(1,n_dims,bias=False)
    self.position_embed = nn.Embedding(seq_len,n_dims)

  def forward(self,responses):
    e = self.response_embed(responses)
    seq = torch.arange(self.seq_len,device=config.device).unsqueeze(0)
    p = self.position_embed(seq)
    return p + e 


# layers of encoders stacked onver, multiheads-block in each encoder is n.
# Stacked N MultiheadAttentions 
class StackedNMultiHeadAttention(nn.Module):
  def __init__(self,n_stacks,n_dims,n_heads,seq_len,n_multihead=1,dropout=0.0):
    super(StackedNMultiHeadAttention,self).__init__()
    self.n_stacks = n_stacks
    self.n_multihead = n_multihead
    self.n_dims = n_dims 
    self.norm_layers = nn.LayerNorm(n_dims)
    #n_stacks has n_multiheads each
    self.multihead_layers = nn.ModuleList(n_stacks*[nn.ModuleList(n_multihead*[nn.MultiheadAttention(embed_dim = n_dims,
                                                      num_heads = n_heads,
                                                        dropout = dropout),]),])
    self.ffn = nn.ModuleList(n_stacks*[FFN(n_dims)])
    self.mask = torch.triu(torch.ones(seq_len,seq_len),diagonal=1).to(dtype=torch.bool)
  
  def forward(self,input_q,input_k,input_v,encoder_output=None,break_layer=None):
    for stack in range(self.n_stacks):
        for multihead in range(self.n_multihead):
          norm_q = self.norm_layers(input_q)
          norm_k = self.norm_layers(input_k)
          norm_v = self.norm_layers(input_v) 
          heads_output,_ = self.multihead_layers[stack][multihead](query=norm_q.permute(1,0,2),
                                                                    key=norm_k.permute(1,0,2),
                                                                    value=norm_v.permute(1,0,2),
                                                                    attn_mask=self.mask.to(config.device))
          heads_output = heads_output.permute(1,0,2)
          #assert encoder_output != None and break_layer is not None     
          if encoder_output != None and multihead == break_layer:
            assert break_layer <= multihead, " break layer should be less than multihead layers and postive integer"
            input_k = input_v = encoder_output
            input_q =input_q + heads_output
          else:
            input_q =input_q+ heads_output
            input_k =input_k+ heads_output
            input_v =input_v +heads_output
        last_norm = self.norm_layers(heads_output)
        ffn_output = self.ffn[stack](last_norm)
        ffn_output =ffn_output+ heads_output
    return ffn_output


# Final Model with Trainer

In [8]:
# Main model for training 
class PlusSAINTModule(pl.LightningModule):
  def __init__(self):
    super(PlusSAINTModule,self).__init__()
    self.loss = nn.BCEWithLogitsLoss()
    self.encoder_layer = StackedNMultiHeadAttention(n_stacks=config.NUM_DECODER,
                                                    n_dims=config.EMBED_DIMS,
                                                    n_heads=config.DEC_HEADS,
                                                    seq_len=config.MAX_SEQ,
                                                    n_multihead=1,dropout=0.0)
    self.decoder_layer = StackedNMultiHeadAttention(n_stacks=config.NUM_ENCODER,
                                                    n_dims=config.EMBED_DIMS,
                                                    n_heads=config.ENC_HEADS,
                                                    seq_len=config.MAX_SEQ,
                                                    n_multihead=2,dropout=0.0)
    self.encoder_embedding = EncoderEmbedding(n_exercises=config.TOTAL_EXE,
                                              n_categories=config.TOTAL_CAT,
                                              n_dims=config.EMBED_DIMS,seq_len=config.MAX_SEQ)
    self.decoder_embedding = DecoderEmbedding(n_responses=3,n_dims=config.EMBED_DIMS,seq_len=config.MAX_SEQ)
    self._time = nn.Linear(1,config.EMBED_DIMS)
    self.fc = nn.Linear(config.EMBED_DIMS,1)

  #TODO: implement embdding layer and its output
  def forward(self,x,y): 
    enc = self.encoder_embedding(exercises=x["input_ids"].long().to(config.device),categories=x['input_cat'].long().to(config.device))
    dec = self.decoder_embedding(responses=y.long().to(config.device))
    _time=x["input_rtime"].unsqueeze(-1).float()
    time = self._time(_time)
    dec = dec + time
    #this is encoder 
    encoder_output = self.encoder_layer(input_k=enc,
                                        input_q=enc,
                                        input_v=enc)
    #this is decoder
    decoder_output = self.decoder_layer(input_k=dec,
                                        input_q=dec,
                                        input_v=dec,
                                        encoder_output = encoder_output,
                                        break_layer=1)
    #fully connected layer
    out = self.fc(decoder_output)
    return out.squeeze()

  def configure_optimizers(self):
    return torch.optim.Adam(self.parameters())
  
  def training_step(self,batch,batch_ids):
    input,ans,labels = batch
    target_mask = (input["input_ids"]!=0)
    out = self(input,ans)
    loss = self.loss(out.view(-1).float(),labels.view(-1).float()) 
    out = torch.masked_select(out,target_mask)
    out = torch.sigmoid(out) 
    labels = torch.masked_select(labels,target_mask)    
    self.log("train_loss",loss,on_step=True,prog_bar=True)
    return {"loss":loss,"outs":out,"labels":labels}
  
  def validation_step(self,batch,batch_ids):
    input,ans,labels = batch
    target_mask = (input["input_ids"]!=0)
    out = self(input,ans)
    loss = self.loss(out.view(-1).float(),labels.view(-1).float())
    out = torch.masked_select(out,target_mask)
    pred = (torch.sigmoid(out) >= 0.5).long()
    out = torch.sigmoid(out) 
    labels = torch.masked_select(labels,target_mask) 
    num_corrects = 0
    num_total = 0
    num_corrects += (pred == labels).sum().item()
    num_total += len(labels)
    acc = num_corrects / num_total
    self.log("val_loss",loss,on_step=True,prog_bar=True)
    output = {"outs":out,"labels":labels,"loss":loss,"acc":acc}
    return output
  
  def validation_epoch_end(self,validation_ouput): 
    out = torch.cat([i["outs"] for i in validation_ouput]).view(-1) 
    labels = torch.cat([i["labels"] for i in validation_ouput]).view(-1)
    auc = roc_auc_score(labels.cpu().detach().numpy(),out.cpu().detach().numpy())
    ave_loss = np.average([i["loss"].cpu().detach().numpy()
                            for i in validation_ouput])
    acc_list = ([i["acc"]
                                for i in validation_ouput])
    acc = acc_list[-1]
    self.print("val acc", acc)
    self.log("val_acc", acc)
    self.print("val auc", auc)
    self.log("val_auc", auc)
    self.print("val loss", ave_loss)
    self.log("val_loss", ave_loss)
    
  def test_step(self, batch, batch_idx):
    input,ans,labels = batch
    target_mask = (input["input_ids"]!=0)
    out = self(input,ans)
    loss = self.loss(out.view(-1).float(),labels.view(-1).float())
    out = torch.masked_select(out,target_mask)
    pred = (torch.sigmoid(out) >= 0.5).long()
    out = torch.sigmoid(out) 
    labels = torch.masked_select(labels,target_mask) 
    #num_corrects = 0
    #num_total = 0
    #num_corrects += (pred == labels).sum().item()
    #num_total += len(labels)
    #acc = num_corrects / num_total
    self.log("test_loss",loss,on_step=True,prog_bar=True)
    output = {"outs":out,"labels":labels,"loss":loss}
    return output

  def test_epoch_end(self,test_ouput): 
    out = torch.cat([i["outs"] for i in test_ouput]).view(-1) 
    labels = torch.cat([i["labels"] for i in test_ouput]).view(-1)
    auc = roc_auc_score(labels.cpu().detach().numpy(),out.cpu().detach().numpy())
    ave_loss = np.average([i["loss"].cpu().detach().numpy()
                            for i in test_ouput])
    #acc_list = ([i["acc"]
    #                            for i in test_ouput])
    #acc = acc_list[-1]
    #self.print("test acc", acc)
    #self.log("test_acc", acc)
    self.print("test auc", auc)
    self.log("test_auc", auc)
    self.print("test loss", ave_loss)
    self.log("test_loss", ave_loss)

# Dataloader

In [9]:
dtypes = {'timestamp': 'int64', 'user_id': 'int32' ,'content_id': 'int16',
                'answered_correctly':'int8',"content_type_id":"int8",
                  "prior_question_elapsed_time":"float32",
              "task_container_id":"int16"}
print("loading data.....")
train = pd.read_feather('../input/riiid-feather-files/train.feather',columns=set(dtypes.keys()))
test_user=random.sample(list(train['user_id'].unique()),1000)
test_df = train[train['user_id'].isin(test_user)]
train_df = train.drop(train[train['user_id'].isin(test_user)].index)
train_df = train_df[train_df['user_id'].isin(random.sample(list(train_df['user_id'].unique()),config.SAMPLE))]
del train
print("shape of train dataframe :",train_df.shape) 
print("shape of test dataframe :",test_df.shape)

loading data.....
shape of train dataframe : (12703820, 7)
shape of test dataframe : (259734, 7)


In [10]:
    questions_df=pd.read_feather('../input/riiid-feather-files/questions.feather')
    questions_df.rename(columns={'question_id': 'content_id'}, inplace=True)
    train_df=train_df[train_df.content_type_id==False]
    train_df = train_df[train_df['answered_correctly'] != -1].reset_index(drop = True, inplace = False)
    train_df=train_df.merge(questions_df[['content_id', 'part']], how='left')
    #train_df=train_df.groupby('user_id').tail(config.MAX_SEQ*3)
    
    #print("shape of train dataframe :",train_df.shape) 
    
    train_df.prior_question_elapsed_time.fillna(0,inplace=True) 
    train_df.prior_question_elapsed_time /=1000 
    train_df.prior_question_elapsed_time.clip(lower=0,upper=300,inplace=True)    
    train_df.prior_question_elapsed_time = train_df.prior_question_elapsed_time.astype(np.int)    
    train_df = train_df.sort_values(["timestamp"],ascending=True).reset_index(drop=True)
    
    print("train shape after exlusion:",train_df.shape)

    #grouping based on user_id to get the data supplu
    print("Grouping users...") 
    if config.lagtime == False:
        group = train_df[["user_id","content_id","answered_correctly","prior_question_elapsed_time","part"]]\
                    .groupby("user_id")\
                    .apply(lambda r: (r.content_id.values,r.answered_correctly.values,\
                                      r.prior_question_elapsed_time.values,r.part.values))
    else:
        time_dict = dict()
        train_df = feature_time_lag(train_df, time_dict)
        del time_dict
        train_df.time_lag.fillna(0,inplace=True) 
        train_df.time_lag /=1000 
        train_df.time_lag.clip(lower=0,upper=300,inplace=True)    
        train_df.time_lag = train_df.time_lag.astype(np.int)    
        group = train_df[["user_id","content_id","answered_correctly","time_lag","part"]]\
                    .groupby("user_id")\
                    .apply(lambda r: (r.content_id.values,r.answered_correctly.values,\
                                      r.time_lag.values,r.part.values))
    del train_df
    gc.collect() 
    print("splitting") 
    train,val = train_test_split(group,test_size=0.2,random_state=42) 
    print("train size: ",train.shape,"validation size: ",val.shape)
    train_dataset = DKTDataset(train,max_seq = config.MAX_SEQ)
    val_dataset = DKTDataset(val,max_seq = config.MAX_SEQ)
    
    train_loader = DataLoader(train_dataset,
                          batch_size=config.BATCH_SIZE,
                          num_workers=8,
                          shuffle=True) 
    val_loader = DataLoader(val_dataset,
                          batch_size=config.BATCH_SIZE,
                          num_workers=8,
                          shuffle=False)
    del train_dataset,val_dataset 
    gc.collect() 

#train_loader, val_loader = get_dataloaders() 

train shape after exlusion: (12457576, 8)
Grouping users...
splitting
train size:  (40000,) validation size:  (10000,)


0

# Training

In [11]:
%%time
saint_plus = PlusSAINTModule()
callback = pl.callbacks.ModelCheckpoint(monitor='val_loss_epoch',save_top_k=1,mode='min')
trainer = pl.Trainer(gpus=-1,max_epochs=config.EPOCHS,progress_bar_refresh_rate=1,callbacks = [callback]) 
trainer.fit(model=saint_plus,
            train_dataloader=train_loader, 
            val_dataloaders = [val_loader,]
            ) 

#trainer.save_checkpoint("model.pt") 

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type                       | Params
-----------------------------------------------------------------
0 | loss              | BCEWithLogitsLoss          | 0     
1 | encoder_layer     | StackedNMultiHeadAttention | 1.6 M 
2 | decoder_layer     | StackedNMultiHeadAttention | 1.6 M 
3 | encoder_embedding | EncoderEmbedding           | 7.0 M 
4 | decoder_embedding | DecoderEmbedding           | 30.2 K
5 | _time             | Linear                     | 1.0 K 
6 | fc                | Linear                     | 513   


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

val acc 0.5844308560677328
val auc 0.5036863335957266
val loss 0.67877495


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

val acc 0.7577243746934772
val auc 0.7292865657320924
val loss 0.50962937


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

val acc 0.7743992153016185
val auc 0.7442565064407061
val loss 0.49864376


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

val acc 0.77145659637077
val auc 0.7594496631881196
val loss 0.49259546


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

val acc 0.769494850416871
val auc 0.7499772951331317
val loss 0.5036231


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

val acc 0.7685139774399216
val auc 0.7584243076444206
val loss 0.48964912


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

val acc 0.7788131436978911
val auc 0.7618790347312328
val loss 0.4858561


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

val acc 0.7601765571358509
val auc 0.7632160347442325
val loss 0.48997217


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

val acc 0.7719470328592447
val auc 0.7614050368770462
val loss 0.4909211


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

val acc 0.7783227072094164
val auc 0.7648964978515296
val loss 0.48372167


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

val acc 0.77145659637077
val auc 0.7648540013219329
val loss 0.48358694

CPU times: user 45min 44s, sys: 28.6 s, total: 46min 12s
Wall time: 47min 22s


1

# Test

In [12]:
test_df=test_df[test_df.content_type_id==False]
test_df = test_df[test_df['answered_correctly'] != -1].reset_index(drop = True, inplace = False)
test_df=test_df.merge(questions_df[['content_id', 'part']], how='left')
test_df.prior_question_elapsed_time.fillna(0,inplace=True)
test_df.prior_question_elapsed_time /=1000 
test_df.prior_question_elapsed_time.clip(lower=0,upper=300,inplace=True)
test_df.prior_question_elapsed_time = test_df.prior_question_elapsed_time.astype(np.int)
print("test shape after exlusion:",test_df.shape)
if config.lagtime == False:
    test_group = test_df[["user_id","content_id","answered_correctly","prior_question_elapsed_time","part"]]\
                    .groupby("user_id")\
                    .apply(lambda r: (r.content_id.values,r.answered_correctly.values,\
                                      r.prior_question_elapsed_time.values,r.part.values))
else:
    time_dict = dict()
    test_df = feature_time_lag(test_df, time_dict)
    del time_dict
    test_df.time_lag.fillna(0,inplace=True) 
    test_df.time_lag /=1000 
    test_df.time_lag.clip(lower=0,upper=300,inplace=True)    
    test_df.time_lag = test_df.time_lag.astype(np.int)    
    test_group = test_df[["user_id","content_id","answered_correctly","time_lag","part"]]\
                    .groupby("user_id")\
                    .apply(lambda r: (r.content_id.values,r.answered_correctly.values,\
                                      r.time_lag.values,r.part.values))
print("test size: ",test_group.shape)

test shape after exlusion: (255276, 8)
test size:  (1000,)


In [13]:
test_dataset = DKTDataset(test_group,max_seq = config.MAX_SEQ)
test_loader = DataLoader(test_dataset,
                          batch_size=config.BATCH_SIZE,
                          num_workers=8,
                          shuffle=False) 

In [14]:
print(trainer.checkpoint_callback.best_model_path)
print(trainer.checkpoint_callback.best_model_score)

/kaggle/working/lightning_logs/version_0/checkpoints/epoch=9.ckpt
tensor(0.4840, device='cuda:0')


In [15]:
model_clone = saint_plus.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)
trainer_clone = pl.Trainer(gpus=-1,max_epochs=config.EPOCHS) 
trainer_clone.test(model_clone,test_loader)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…

test auc 0.7655655685204239
test loss 0.48228472
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_auc': 0.7655655685204239,
 'test_loss': tensor(0.3455, device='cuda:0'),
 'test_loss_epoch': tensor(0.4888, device='cuda:0')}
--------------------------------------------------------------------------------



[{'test_auc': 0.7655655685204239,
  'test_loss': 0.3454764187335968,
  'test_loss_epoch': 0.48878923058509827}]