In [1]:
import torch
import torch.nn as nn
from transformers import BertConfig, BertModel


class EmbeddingLayer(nn.Module):
    """Embed the interaction ids together with every categorical feature.

    ``hidden_dim`` is shared between one embedding for ``batch["interaction"]``
    and one embedding per key of ``args.n_embeddings``. Each categorical feature
    gets ``hidden_dim // (len(args.n_embeddings) + 1)`` columns. The interaction
    embedding takes whatever is left, so the concatenated output is exactly
    ``hidden_dim`` wide.

    Args:
        args: configuration object; reads ``device`` and ``n_embeddings``
            (mapping feature name -> number of distinct values).
        hidden_dim: width of the concatenated embedding.
    """

    def __init__(self, args, hidden_dim):
        super().__init__()

        self.args = args
        self.device = args.device
        self.hidden_dim = hidden_dim

        n_categorical = len(args.n_embeddings)
        per_feature_dim = hidden_dim // (n_categorical + 1)
        # The interaction embedding absorbs the integer-division remainder.
        remaining_dim = hidden_dim - per_feature_dim * n_categorical

        # Three interaction ids (0, 1, 2); presumably 0 = padding / no previous
        # response -- TODO confirm against the batch preprocessing.
        self.embedding_interaction = nn.Embedding(3, remaining_dim)

        tables = {}
        for name, cardinality in args.n_embeddings.items():
            tables[name] = nn.Embedding(cardinality + 1, per_feature_dim)  # plus 1 for padding
        self.embeddings = nn.ModuleDict(tables)

    def forward(self, batch):
        """Look up every embedding and concatenate them along dim 2.

        Args:
            batch: dict of integer tensors holding ``"interaction"`` and every key
                of ``args.n_embeddings``; presumably (batch, seq) -- TODO confirm.

        Returns:
            The interaction embedding followed by the categorical embeddings in
            ``args.n_embeddings`` order, concatenated on dim 2.
        """
        pieces = [self.embedding_interaction(batch["interaction"])]
        pieces.extend(self.embeddings[name](batch[name]) for name in self.args.n_embeddings)
        return torch.cat(pieces, 2)


class LinearLayer(nn.Module):
    """Project the continuous features listed in ``args.n_linears`` to ``hidden_dim``.

    Args:
        args: configuration object; reads ``device`` and ``n_linears`` (names of
            the continuous features, in the order they are stacked).
        hidden_dim: output width of the projection.
    """

    def __init__(self, args, hidden_dim):
        super().__init__()

        self.args = args
        self.device = args.device
        self.hidden_dim = hidden_dim

        n_features = len(args.n_linears)
        self.fc_layer = nn.Linear(n_features, hidden_dim)

    def forward(self, batch):
        """Stack the continuous features on a new last axis and project them.

        Args:
            batch: dict of float tensors keyed by the names in ``args.n_linears``;
                each is expected to be 2-D (the permute below requires it),
                presumably (batch, seq).

        Returns:
            Tensor of shape (dim0, dim1, hidden_dim).
        """
        columns = [batch[name] for name in self.args.n_linears]
        # (n_features, dim0, dim1) -> (dim0, dim1, n_features)
        features = torch.stack(columns).permute(1, 2, 0)
        return self.fc_layer(features)


class Bert(nn.Module):
    """BERT-encoder model over categorical and continuous per-step features.

    The categorical branch (``EmbeddingLayer``) and the continuous branch
    (``LinearLayer``) are concatenated and mixed by ``comb_proj``. The result is
    fed to a ``BertModel`` as ``inputs_embeds``. A linear head plus sigmoid then
    turns every encoded step into a probability.

    Args:
        args: configuration object; this class reads ``device``, ``hidden_dim``,
            ``n_layers``, ``n_heads`` and ``max_seq_len``. The two feature
            branches additionally read ``n_embeddings`` / ``n_linears``.
    """

    def __init__(self, args):
        super(Bert, self).__init__()
        self.args = args
        self.device = args.device

        # Defining some parameters
        self.hidden_dim = self.args.hidden_dim
        self.n_layers = self.args.n_layers

        # Split hidden_dim between the two feature branches. The continuous
        # branch takes the remainder so the concatenation is exactly hidden_dim
        # wide, as comb_proj requires. Giving both branches hidden_dim // 2
        # would leave it one column short whenever hidden_dim is odd. For even
        # hidden_dim both branches still get hidden_dim // 2.
        emb_dim = self.hidden_dim // 2
        self.emb_layer = EmbeddingLayer(args, emb_dim)
        self.nli_layer = LinearLayer(args, self.hidden_dim - emb_dim)

        self.comb_proj = nn.Linear(self.hidden_dim, self.hidden_dim)

        # Bert config
        self.config = BertConfig(
            3,  # vocab_size: not used, inputs are passed as inputs_embeds
            hidden_size=self.hidden_dim,
            num_hidden_layers=self.args.n_layers,
            num_attention_heads=self.args.n_heads,
            max_position_embeddings=self.args.max_seq_len,
        )

        # Defining the layers
        # Bert Layer
        self.encoder = BertModel(self.config)

        # Fully connected layer
        self.fc = nn.Linear(self.args.hidden_dim, 1)
        self.activation = nn.Sigmoid()

    def forward(self, batch):
        """Predict a probability for every step of every sequence in ``batch``.

        Args:
            batch: dict of tensors. This method reads ``"interaction"`` (for the
                batch size) and ``"mask"``, which is passed to BERT as
                ``attention_mask``. The feature branches read the keys in
                ``args.n_embeddings`` / ``args.n_linears``. Presumably
                (batch, seq) tensors -- TODO confirm against the data loader.

        Returns:
            Tensor of shape (batch_size, L) holding sigmoid outputs, where L is
            the sequence length of the encoder output.
        """
        batch_size = batch["interaction"].size(0)

        embed = self.emb_layer(batch)
        nnbed = self.nli_layer(batch)

        # Concatenate both branches on the feature axis -> width hidden_dim.
        embed = torch.cat([embed, nnbed], 2)
        X = self.comb_proj(embed)

        # Bert
        encoded_layers = self.encoder(inputs_embeds=X, attention_mask=batch["mask"])
        out = encoded_layers[0]  # last hidden states
        out = out.contiguous().view(batch_size, -1, self.hidden_dim)
        out = self.fc(out)
        preds = self.activation(out).view(batch_size, -1)

        return preds

In [2]:
import sys

# Make modules one directory up importable (presumably utils / helper /
# models / trainer live there). Relative sys.path entries are resolved
# against the current working directory. The guard keeps the cell idempotent:
# re-running it does not append ".." again.
if ".." not in sys.path:
    sys.path.append("..")

from utils import get_args, get_root_dir
from helper import get_dkt_loader

In [3]:
# Build the run configuration, then point it at this experiment's log
# directory and at the training data.
args = get_args()
args.root_dir, args.data_dir = (
    get_root_dir("../bert_test"),
    "../../input/data/train_dataset/",
)

In [7]:
import torch
from models.lstm.model import LSTM
from trainer import DKTTrainer

class FeatureTestTrainer(DKTTrainer):
    """DKTTrainer subclass that only overrides per-batch preprocessing."""

    def _process_batch(self, batch):
        """Cast, shift and move one batch to ``args.device``, mutating it in place.

        Steps, in order:

        - ``mask`` and ``answerCode`` become float32.
        - ``interaction`` is ``answerCode + 1`` rolled right by one step along
          dim 1, multiplied by the rolled mask and cast to int64. Each step thus
          sees the previous step's value, and 0 wherever the rolled mask is 0
          (including column 0).
        - ``mask`` is replaced by that rolled mask with column 0 zeroed.
          Presumably this is what the model later receives as ``batch["mask"]``
          -- confirm in DKTTrainer.
        - Every key in ``args.n_linears`` (numerical) becomes float32, and every
          key in ``args.n_embeddings`` (categorical) becomes int64.
        - Every tensor in the batch is moved to ``args.device``.

        Args:
            batch: dict of tensors; presumably (batch, seq) -- TODO confirm
                against the loader.

        Returns:
            The same dict object, updated.
        """
        # .float() keeps each tensor on its current device.
        # .type(torch.FloatTensor) always produced a CPU tensor, forcing a
        # round-trip before the final .to(device).
        batch["mask"] = batch["mask"].float()
        batch["answerCode"] = batch["answerCode"].float()

        # Shift responses right by one step so step t is fed response t-1.
        batch["interaction"] = batch["answerCode"] + 1
        batch["interaction"] = batch["interaction"].roll(shifts=1, dims=1)
        batch["mask"] = batch["mask"].roll(shifts=1, dims=1)
        # After the roll, column 0 holds the wrapped-around last step; drop it.
        batch["mask"][:, 0] = 0
        batch["interaction"] = (batch["interaction"] * batch["mask"]).to(torch.int64)

        for k in self.args.n_linears:  # numerical features
            batch[k] = batch[k].float()

        for k in self.args.n_embeddings:  # categorical features
            batch[k] = batch[k].to(torch.int64)

        for k in batch.keys():
            batch[k] = batch[k].to(self.args.device)

        return batch

In [13]:
from helper import get_dkt_dataset, get_simple_data

In [22]:
# Encoder size for the runs below: 8 attention heads, 12 layers.
# NOTE(review): execution counts show this cell (In [22]) ran *after* the
# "n_heads 4, n_layers 2" run (In [21]). On a fresh Restart & Run All, that
# run will use these values instead.
args.n_heads, args.n_layers = 8, 12

In [None]:
# Load the train / valid / test splits used by every run below.
# NOTE(review): this cell is saved as `In [None]` (never executed in the
# recorded session), yet later cells use these names. They must have come
# from an earlier kernel state.
train_data, valid_data, test_data = get_simple_data(args)

In [23]:
# Build a trainer for the Bert model defined above with the current args.
trainer = FeatureTestTrainer(args, Bert)

In [24]:
# NOTE(review): DKTTrainer.debug is defined outside this notebook; presumably
# a quick smoke check before the full `run` -- confirm.
trainer.debug(train_data, valid_data, test_data)

In [25]:
from IPython.display import clear_output

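## Test 1: n_heads 4, n_layers 2

> **Note:** the result below was recorded at execution count 21, i.e. before cell `In [22]` set `n_heads = 8` and `n_layers = 12`. A fresh *Restart & Run All* will run this cell with 8 heads / 12 layers, so the configuration named in this heading will not be reproduced as-is.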

In [21]:
# Full training run. clear_output() then wipes whatever the run printed in
# this cell, so only the final metrics and the log directory stay visible.
# NOTE(review): the same four lines are repeated for Test 2; a small helper
# function would remove the duplication.
auc, acc = trainer.run(train_data, valid_data, test_data)
clear_output()
print(f"auc : {auc} acc : {acc}")
print(f"logging path: {trainer.prefix_save_path}")

auc : 0.7469929482776747 acc : 0.6935969868173258
logging path: ../bert_test/LOG_[06.11_03:58]


### Test 2: hidden_dim 256 (n_heads 8, n_layers 12)

In [32]:
# Test 2 configuration. All three values are set explicitly so this run does
# not depend on the earlier `In [22]` cell having been executed (or on what it
# contains): 8 heads, 12 layers, hidden size 256. 256 is divisible by 8, as
# BERT requires.
args.n_heads = 8
args.n_layers = 12
args.hidden_dim = 256

In [40]:
# Display the effective configuration: (n_heads, n_layers, hidden_dim, batch_size).
(args.n_heads, args.n_layers, args.hidden_dim, args.batch_size)

(8, 12, 256, 64)

In [37]:
# Rebuild the trainer so the model is constructed from the updated args.
trainer = FeatureTestTrainer(args, Bert)

In [38]:
# NOTE(review): DKTTrainer.debug is defined outside this notebook; presumably
# a quick smoke check before the full `run` -- confirm.
trainer.debug(train_data, valid_data, test_data)

In [39]:
# Full training run for Test 2. clear_output() wipes whatever the run printed
# in this cell, leaving only the final metrics and the log directory.
auc, acc = trainer.run(train_data, valid_data, test_data)
clear_output()
print(f"auc : {auc} acc : {acc}")
print(f"logging path: {trainer.prefix_save_path}")

auc : 0.760216754306576 acc : 0.6949152542372882
logging path: ../bert_test/LOG_[06.11_04:04]
