In [1]:
import numpy as np
import scipy.sparse as sp
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torch.autograd import Variable
import math
import pandas as pd

In [2]:
def build_dense_graph(node_num):
    graph = 1. / (node_num - 1) * np.ones((node_num, node_num))
    np.fill_diagonal(graph, 0)
    graph = torch.from_numpy(graph).float()
    return graph


def sample_gumbel(shape, eps=1e-10):
    U = torch.rand(shape).float()
    return - torch.log(eps - torch.log(U + eps))


def gumbel_softmax_sample(logits, tau=1, eps=1e-10, dim=-1):
    """
    NOTE: Stolen from https://github.com/ethanfetaya/NRI/blob/master/utils.py
    Draw a sample from the Gumbel-Softmax distribution
    based on
    https://github.com/ericjang/gumbel-softmax/blob/3c8584924603869e90ca74ac20a6a03d99a91ef9/Categorical%20VAE.ipynb
    """
    gumbel_noise = sample_gumbel(logits.size(), eps=eps)
    if logits.is_cuda:
        gumbel_noise = gumbel_noise.cuda()
    y = logits + Variable(gumbel_noise)
    return F.softmax(y / tau, dim=dim)


def gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1):
    """
    NOTE: Stolen from https://github.com/ethanfetaya/NRI/blob/master/utils.py
    Sample from the Gumbel-Softmax distribution and optionally discretize.
    Args:
      logits: [batch_size, n_class] unnormalized log-probs
      tau: non-negative scalar temperature
      hard: if True, take argmax, but differentiate w.r.t. soft sample y
    Returns:
      [batch_size, n_class] sample from the Gumbel-Softmax distribution.
      If hard=True, then the returned sample will be one-hot, otherwise it will
      be a probability distribution that sums to 1 across classes
    Constraints:
    - this implementation only works on batch_size x num_features tensor for now
    based on
    https://github.com/ericjang/gumbel-softmax/blob/3c8584924603869e90ca74ac20a6a03d99a91ef9/Categorical%20VAE.ipynb
    """
    y_soft = gumbel_softmax_sample(logits, tau=tau, eps=eps, dim=dim)
    if hard:
        shape = logits.size()
        _, k = y_soft.data.max(-1)
        # this bit is based on
        # https://discuss.pytorch.org/t/stop-gradients-for-st-gumbel-softmax/530/5
        y_hard = torch.zeros(*shape)
        if y_soft.is_cuda:
            y_hard = y_hard.cuda()
        y_hard = y_hard.zero_().scatter_(-1, k.view(shape[:-1] + (1,)), 1.0)
        # this cool bit of code achieves two things:
        # - makes the output value exactly one-hot (since we add then
        #   subtract y_soft value)
        # - makes the gradient equal to y_soft gradient (since we strip
        #   all other gradients)
        y = Variable(y_hard - y_soft.data) + y_soft
    else:
        y = y_soft
    return y


def kl_categorical(preds, log_prior, concept_num, eps=1e-16):
    kl_div = preds * (torch.log(preds + eps) - log_prior)
    return kl_div.sum() / (concept_num * preds.size(0))


def kl_categorical_uniform(preds, concept_num, num_edge_types, add_const=False, eps=1e-16):
    kl_div = preds * torch.log(preds + eps)
    if add_const:
        const = np.log(num_edge_types)
        kl_div += const
    return kl_div.sum() / (concept_num * preds.size(0))


def nll_gaussian(preds, target, variance, add_const=False):
    # pred: [concept_num, embedding_dim]
    # target: [concept_num, embedding_dim]
    neg_log_p = ((preds - target) ** 2 / (2 * variance))
    if add_const:
        const = 0.5 * np.log(2 * np.pi * variance)
        neg_log_p += const
    return neg_log_p.mean()


# Calculate accuracy of prediction result and its corresponding label
# output: tensor, labels: tensor
def accuracy(output, labels):
    preds = output.max(1)[1].type_as(labels)
    correct = preds.eq(labels.reshape(-1)).double()
    correct = correct.sum()
    return correct / len(labels)

In [3]:
class MLP(nn.Module):
    """Two-layer fully-connected ReLU net with batch norm."""

    def __init__(self, input_dim, hidden_dim, output_dim, dropout=0., bias=True):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim, bias=bias)
        self.fc2 = nn.Linear(hidden_dim, output_dim, bias=bias)
        self.norm = nn.BatchNorm1d(output_dim)
        # the paper said they added Batch Normalization for the output of MLPs, as shown in Section 4.2
        self.dropout = dropout
        self.output_dim = output_dim
        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight.data)
                m.bias.data.fill_(0.1)
            elif isinstance(m, nn.BatchNorm1d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def batch_norm(self, inputs):
        if inputs.numel() == self.output_dim or inputs.numel() == 0:
            # batch_size == 1 or 0 will cause BatchNorm error, so return the input directly
            return inputs
        if len(inputs.size()) == 3:
            x = inputs.view(inputs.size(0) * inputs.size(1), -1)
            x = self.norm(x)
            return x.view(inputs.size(0), inputs.size(1), -1)
        else:  # len(input_size()) == 2
            return self.norm(inputs)

    def forward(self, inputs):
        x = F.relu(self.fc1(inputs))
        x = F.dropout(x, self.dropout, training=self.training)  # pay attention to add training=self.training
        x = F.relu(self.fc2(x))
        return self.batch_norm(x)


class EraseAddGate(nn.Module):
    """
    Erase & Add Gate module
    NOTE: this erase & add gate is a bit different from that in DKVMN.
    For more information about Erase & Add gate, please refer to the paper "Dynamic Key-Value Memory Networks for Knowledge Tracing"
    The paper can be found in https://arxiv.org/abs/1611.08108
    """

    def __init__(self, feature_dim, concept_num, bias=True):
        super(EraseAddGate, self).__init__()
        # weight
        self.weight = nn.Parameter(torch.rand(concept_num))
        self.reset_parameters()
        # erase gate
        self.erase = nn.Linear(feature_dim, feature_dim, bias=bias)
        # add gate
        self.add = nn.Linear(feature_dim, feature_dim, bias=bias)

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(0))
        self.weight.data.uniform_(-stdv, stdv)

    def forward(self, x):
        r"""
        Params:
            x: input feature matrix
        Shape:
            x: [batch_size, concept_num, feature_dim]
            res: [batch_size, concept_num, feature_dim]
        Return:
            res: returned feature matrix with old information erased and new information added
        The GKT paper didn't provide detailed explanation about this erase-add gate. As the erase-add gate in the GKT only has one input parameter,
        this gate is different with that of the DKVMN. We used the input matrix to build the erase and add gates, rather than $\mathbf{v}_{t}$ vector in the DKVMN.
        """
        erase_gate = torch.sigmoid(self.erase(x))  # [batch_size, concept_num, feature_dim]
        # self.weight.unsqueeze(dim=1) shape: [concept_num, 1]
        tmp_x = x - self.weight.unsqueeze(dim=1) * erase_gate * x
        add_feat = torch.tanh(self.add(x))  # [batch_size, concept_num, feature_dim]
        res = tmp_x + self.weight.unsqueeze(dim=1) * add_feat
        return res


class ScaledDotProductAttention(nn.Module):
    """
    Scaled Dot-Product Attention
    NOTE: Stole and modify from https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/Modules.py
    """

    def __init__(self, temperature, attn_dropout=0.):
        super().__init__()
        self.temperature = temperature
        self.dropout = attn_dropout

    def forward(self, q, k, mask=None):
        r"""
        Parameters:
            q: multi-head query matrix
            k: multi-head key matrix
            mask: mask matrix
        Shape:
            q: [n_head, mask_num, embedding_dim]
            k: [n_head, concept_num, embedding_dim]
        Return: attention score of all queries
        """
        attn = torch.matmul(q / self.temperature, k.transpose(1, 2))  # [n_head, mask_number, concept_num]
        if mask is not None:
            attn = attn.masked_fill(mask == 0, -1e9)
        # pay attention to add training=self.training!
        attn = F.dropout(F.softmax(attn, dim=0), self.dropout, training=self.training)  # pay attention that dim=-1 is not as good as dim=0!
        return attn


class MLPEncoder(nn.Module):
    """
    MLP encoder module.
    NOTE: Stole and modify the code from https://github.com/ethanfetaya/NRI/blob/master/modules.py
    """
    def __init__(self, input_dim, hidden_dim, output_dim, factor=True, dropout=0., bias=True):
        super(MLPEncoder, self).__init__()
        self.factor = factor
        self.mlp = MLP(input_dim * 2, hidden_dim, hidden_dim, dropout=dropout, bias=bias)
        self.mlp2 = MLP(hidden_dim, hidden_dim, hidden_dim, dropout=dropout, bias=bias)
        if self.factor:
            self.mlp3 = MLP(hidden_dim * 3, hidden_dim, hidden_dim, dropout=dropout, bias=bias)
        else:
            self.mlp3 = MLP(hidden_dim * 2, hidden_dim, hidden_dim, dropout=dropout, bias=bias)
        self.fc_out = nn.Linear(hidden_dim, output_dim)
        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight.data)
                m.bias.data.fill_(0.1)

    def node2edge(self, x, sp_send, sp_rec):
        # NOTE: Assumes that we have the same graph across all samples.
        receivers = torch.matmul(sp_rec, x)
        senders = torch.matmul(sp_send, x)
        edges = torch.cat([senders, receivers], dim=1)
        return edges

    def edge2node(self, x, sp_send_t, sp_rec_t):
        # NOTE: Assumes that we have the same graph across all samples.
        incoming = torch.matmul(sp_rec_t, x)
        return incoming

    def forward(self, inputs, sp_send, sp_rec, sp_send_t, sp_rec_t):
        r"""
        Parameters:
            inputs: input concept embedding matrix
            sp_send: one-hot encoded send-node index(sparse tensor)
            sp_rec: one-hot encoded receive-node index(sparse tensor)
            sp_send_t: one-hot encoded send-node index(sparse tensor, transpose)
            sp_rec_t: one-hot encoded receive-node index(sparse tensor, transpose)
        Shape:
            inputs: [concept_num, embedding_dim]
            sp_send: [edge_num, concept_num]
            sp_rec: [edge_num, concept_num]
            sp_send_t: [concept_num, edge_num]
            sp_rec_t: [concept_num, edge_num]
        Return:
            output: [edge_num, edge_type_num]
        """
        x = self.node2edge(inputs, sp_send, sp_rec)  # [edge_num, 2 * embedding_dim]
        x = self.mlp(x)  # [edge_num, hidden_num]
        x_skip = x

        if self.factor:
            x = self.edge2node(x, sp_send_t, sp_rec_t)  # [concept_num, hidden_num]
            x = self.mlp2(x)  # [concept_num, hidden_num]
            x = self.node2edge(x, sp_send, sp_rec)  # [edge_num, 2 * hidden_num]
            x = torch.cat((x, x_skip), dim=1)  # Skip connection  shape: [edge_num, 3 * hidden_num]
            x = self.mlp3(x)  # [edge_num, hidden_num]
        else:
            x = self.mlp2(x)  # [edge_num, hidden_num]
            x = torch.cat((x, x_skip), dim=1)  # Skip connection  shape: [edge_num, 2 * hidden_num]
            x = self.mlp3(x)  # [edge_num, hidden_num]
        output = self.fc_out(x)  # [edge_num, output_dim]
        return output


class MLPDecoder(nn.Module):
    """
    MLP decoder module.
    NOTE: Stole and modify the code from https://github.com/ethanfetaya/NRI/blob/master/modules.py
    """

    def __init__(self, input_dim, msg_hidden_dim, msg_output_dim, hidden_dim, edge_type_num, dropout=0., bias=True):
        super(MLPDecoder, self).__init__()
        self.msg_out_dim = msg_output_dim
        self.edge_type_num = edge_type_num
        self.dropout = dropout

        self.msg_fc1 = nn.ModuleList([nn.Linear(2 * input_dim, msg_hidden_dim, bias=bias) for _ in range(edge_type_num)])
        self.msg_fc2 = nn.ModuleList([nn.Linear(msg_hidden_dim, msg_output_dim, bias=bias) for _ in range(edge_type_num)])
        self.out_fc1 = nn.Linear(msg_output_dim, hidden_dim, bias=bias)
        self.out_fc2 = nn.Linear(hidden_dim, hidden_dim, bias=bias)
        self.out_fc3 = nn.Linear(hidden_dim, input_dim, bias=bias)

    def node2edge(self, x, sp_send, sp_rec):
        receivers = torch.matmul(sp_rec, x)  # [edge_num, embedding_dim]
        senders = torch.matmul(sp_send, x)  # [edge_num, embedding_dim]
        edges = torch.cat([senders, receivers], dim=-1)  # [edge_num, 2 * embedding_dim]
        return edges

    def edge2node(self, x, sp_send_t, sp_rec_t):
        # NOTE: Assumes that we have the same graph across all samples.
        incoming = torch.matmul(sp_rec_t, x)
        return incoming

    def forward(self, inputs, rel_type, sp_send, sp_rec, sp_send_t, sp_rec_t):
        r"""
        Parameters:
            inputs: input concept embedding matrix
            rel_type: inferred edge weights for all edge types from MLPEncoder
            sp_send: one-hot encoded send-node index(sparse tensor)
            sp_rec: one-hot encoded receive-node index(sparse tensor)
            sp_send_t: one-hot encoded send-node index(sparse tensor, transpose)
            sp_rec_t: one-hot encoded receive-node index(sparse tensor, transpose)
        Shape:
            inputs: [concept_num, embedding_dim]
            sp_send: [edge_num, concept_num]
            sp_rec: [edge_num, concept_num]
            sp_send_t: [concept_num, edge_num]
            sp_rec_t: [concept_num, edge_num]
        Return:
            output: [edge_num, edge_type_num]
        """
        # NOTE: Assumes that we have the same graph across all samples.
        # Node2edge
        pre_msg = self.node2edge(inputs, sp_send, sp_rec)
        all_msgs = Variable(torch.zeros(pre_msg.size(0), self.msg_out_dim, device=inputs.device))  # [edge_num, msg_out_dim]
        for i in range(self.edge_type_num):
            msg = F.relu(self.msg_fc1[i](pre_msg))
            msg = F.dropout(msg, self.dropout, training=self.training)
            msg = F.relu(self.msg_fc2[i](msg))
            msg = msg * rel_type[:, i:i + 1]
            all_msgs += msg

        # Aggregate all msgs to receiver
        agg_msgs = self.edge2node(all_msgs, sp_send_t, sp_rec_t)  # [concept_num, msg_out_dim]
        # Output MLP
        pred = F.dropout(F.relu(self.out_fc1(agg_msgs)), self.dropout, training=self.training)  # [concept_num, hidden_dim]
        pred = F.dropout(F.relu(self.out_fc2(pred)), self.dropout, training=self.training)  # [concept_num, hidden_dim]
        pred = self.out_fc3(pred)  # [concept_num, embedding_dim]
        return pred

In [4]:
class GKT(nn.Module):

    def __init__(self, concept_num, hidden_dim, embedding_dim, edge_type_num, graph_type, graph=None, graph_model=None, dropout=0.5, bias=True, binary=False, has_cuda=False):
        super(GKT, self).__init__()
        self.concept_num = concept_num
        self.hidden_dim = hidden_dim
        self.embedding_dim = embedding_dim
        self.edge_type_num = edge_type_num

        self.res_len = 2 if binary else 12
        self.has_cuda = has_cuda

        assert graph_type in ['Dense', 'Transition', 'DKT', 'PAM', 'MHA', 'VAE']
        self.graph_type = graph_type
        if graph_type in ['Dense', 'Transition', 'DKT']:
            assert edge_type_num == 2
            assert graph is not None and graph_model is None
            self.graph = nn.Parameter(graph)  # [concept_num, concept_num]
            self.graph.requires_grad = False  # fix parameter
            self.graph_model = graph_model
        else:  # ['PAM', 'MHA', 'VAE']
            assert graph is None
            self.graph = graph  # None
            if graph_type == 'PAM':
                assert graph_model is None
                self.graph = nn.Parameter(torch.rand(concept_num, concept_num))
            else:
                assert graph_model is not None
            self.graph_model = graph_model

        # one-hot feature and question
        one_hot_feat = torch.eye(self.res_len * self.concept_num)
        self.one_hot_feat = one_hot_feat.cuda() if self.has_cuda else one_hot_feat
        self.one_hot_q = torch.eye(self.concept_num, device=self.one_hot_feat.device)
        zero_padding = torch.zeros(1, self.concept_num, device=self.one_hot_feat.device)
        self.one_hot_q = torch.cat((self.one_hot_q, zero_padding), dim=0)
        # concept and concept & response embeddings
        self.emb_x = nn.Embedding(self.res_len * concept_num, embedding_dim)
        # last embedding is used for padding, so dim + 1
        self.emb_c = nn.Embedding(concept_num + 1, embedding_dim, padding_idx=-1)

        # f_self function and f_neighbor functions
        mlp_input_dim = hidden_dim + embedding_dim
        self.f_self = MLP(mlp_input_dim, hidden_dim, hidden_dim, dropout=dropout, bias=bias)
        self.f_neighbor_list = nn.ModuleList()
        if graph_type in ['Dense', 'Transition', 'DKT', 'PAM']:
            # f_in and f_out functions
            self.f_neighbor_list.append(MLP(2 * mlp_input_dim, hidden_dim, hidden_dim, dropout=dropout, bias=bias))
            self.f_neighbor_list.append(MLP(2 * mlp_input_dim, hidden_dim, hidden_dim, dropout=dropout, bias=bias))
        else:  # ['MHA', 'VAE']
            for i in range(edge_type_num):
                self.f_neighbor_list.append(MLP(2 * mlp_input_dim, hidden_dim, hidden_dim, dropout=dropout, bias=bias))

        # Erase & Add Gate
        self.erase_add_gate = EraseAddGate(hidden_dim, concept_num)
        # Gate Recurrent Unit
        self.gru = nn.GRUCell(hidden_dim, hidden_dim, bias=bias)
        # prediction layer
        self.predict = nn.Linear(hidden_dim, 1, bias=bias)

    # Aggregate step, as shown in Section 3.2.1 of the paper
    def _aggregate(self, xt, qt, ht, batch_size):
        r"""
        Parameters:
            xt: input one-hot question answering features at the current timestamp
            qt: question indices for all students in a batch at the current timestamp
            ht: hidden representations of all concepts at the current timestamp
            batch_size: the size of a student batch
        Shape:
            xt: [batch_size]
            qt: [batch_size]
            ht: [batch_size, concept_num, hidden_dim]
            tmp_ht: [batch_size, concept_num, hidden_dim + embedding_dim]
        Return:
            tmp_ht: aggregation results of concept hidden knowledge state and concept(& response) embedding
        """
        qt_mask = torch.ne(qt, -1)  # [batch_size], qt != -1
        x_idx_mat = torch.arange(self.res_len * self.concept_num, device=xt.device)
        x_embedding = self.emb_x(x_idx_mat)  # [res_len * concept_num, embedding_dim]
        masked_feat = F.embedding(xt[qt_mask], self.one_hot_feat)  # [mask_num, res_len * concept_num]
        res_embedding = masked_feat.mm(x_embedding)  # [mask_num, embedding_dim]
        mask_num = res_embedding.shape[0]

        concept_idx_mat = self.concept_num * torch.ones((batch_size, self.concept_num), device=xt.device).long()
        concept_idx_mat[qt_mask, :] = torch.arange(self.concept_num, device=xt.device)
        concept_embedding = self.emb_c(concept_idx_mat)  # [batch_size, concept_num, embedding_dim]

        index_tuple = (torch.arange(mask_num, device=xt.device), qt[qt_mask].long())
        concept_embedding[qt_mask] = concept_embedding[qt_mask].index_put(index_tuple, res_embedding)
        tmp_ht = torch.cat((ht, concept_embedding), dim=-1)  # [batch_size, concept_num, hidden_dim + embedding_dim]
        return tmp_ht

    # GNN aggregation step, as shown in 3.3.2 Equation 1 of the paper
    def _agg_neighbors(self, tmp_ht, qt):
        r"""
        Parameters:
            tmp_ht: temporal hidden representations of all concepts after the aggregate step
            qt: question indices for all students in a batch at the current timestamp
        Shape:
            tmp_ht: [batch_size, concept_num, hidden_dim + embedding_dim]
            qt: [batch_size]
            m_next: [batch_size, concept_num, hidden_dim]
        Return:
            m_next: hidden representations of all concepts aggregating neighboring representations at the next timestamp
            concept_embedding: input of VAE (optional)
            rec_embedding: reconstructed input of VAE (optional)
            z_prob: probability distribution of latent variable z in VAE (optional)
        """
        qt_mask = torch.ne(qt, -1)  # [batch_size], qt != -1
        masked_qt = qt[qt_mask]  # [mask_num, ]
        masked_tmp_ht = tmp_ht[qt_mask]  # [mask_num, concept_num, hidden_dim + embedding_dim]
        mask_num = masked_tmp_ht.shape[0]
        self_index_tuple = (torch.arange(mask_num, device=qt.device), masked_qt.long())
        self_ht = masked_tmp_ht[self_index_tuple]  # [mask_num, hidden_dim + embedding_dim]
        self_features = self.f_self(self_ht)  # [mask_num, hidden_dim]
        expanded_self_ht = self_ht.unsqueeze(dim=1).repeat(1, self.concept_num, 1)  #[mask_num, concept_num, hidden_dim + embedding_dim]
        neigh_ht = torch.cat((expanded_self_ht, masked_tmp_ht), dim=-1)  #[mask_num, concept_num, 2 * (hidden_dim + embedding_dim)]
        concept_embedding, rec_embedding, z_prob = None, None, None

        if self.graph_type in ['Dense', 'Transition', 'DKT', 'PAM']:
            adj = self.graph[masked_qt.long(), :].unsqueeze(dim=-1)  # [mask_num, concept_num, 1]
            reverse_adj = self.graph[:, masked_qt.long()].transpose(0, 1).unsqueeze(dim=-1)  # [mask_num, concept_num, 1]
            # self.f_neighbor_list[0](neigh_ht) shape: [mask_num, concept_num, hidden_dim]
            neigh_features = adj * self.f_neighbor_list[0](neigh_ht) + reverse_adj * self.f_neighbor_list[1](neigh_ht)
        else:  # ['MHA', 'VAE']
            concept_index = torch.arange(self.concept_num, device=qt.device)
            concept_embedding = self.emb_c(concept_index)  # [concept_num, embedding_dim]
            if self.graph_type == 'MHA':
                query = self.emb_c(masked_qt)
                key = concept_embedding
                att_mask = Variable(torch.ones(self.edge_type_num, mask_num, self.concept_num, device=qt.device))
                for k in range(self.edge_type_num):
                    index_tuple = (torch.arange(mask_num, device=qt.device), masked_qt.long())
                    att_mask[k] = att_mask[k].index_put(index_tuple, torch.zeros(mask_num, device=qt.device))
                graphs = self.graph_model(masked_qt, query, key, att_mask)
            else:  # self.graph_type == 'VAE'
                sp_send, sp_rec, sp_send_t, sp_rec_t = self._get_edges(masked_qt)
                graphs, rec_embedding, z_prob = self.graph_model(concept_embedding, sp_send, sp_rec, sp_send_t, sp_rec_t)
            neigh_features = 0
            for k in range(self.edge_type_num):
                adj = graphs[k][masked_qt, :].unsqueeze(dim=-1)  # [mask_num, concept_num, 1]
                if k == 0:
                    neigh_features = adj * self.f_neighbor_list[k](neigh_ht)
                else:
                    neigh_features = neigh_features + adj * self.f_neighbor_list[k](neigh_ht)
            if self.graph_type == 'MHA':
                neigh_features = 1. / self.edge_type_num * neigh_features
        # neigh_features: [mask_num, concept_num, hidden_dim]
        m_next = tmp_ht[:, :, :self.hidden_dim]
        m_next[qt_mask] = neigh_features
        m_next[qt_mask] = m_next[qt_mask].index_put(self_index_tuple, self_features)
        return m_next, concept_embedding, rec_embedding, z_prob

    # Update step, as shown in Section 3.3.2 of the paper
    def _update(self, tmp_ht, ht, qt):
        r"""
        Parameters:
            tmp_ht: temporal hidden representations of all concepts after the aggregate step
            ht: hidden representations of all concepts at the current timestamp
            qt: question indices for all students in a batch at the current timestamp
        Shape:
            tmp_ht: [batch_size, concept_num, hidden_dim + embedding_dim]
            ht: [batch_size, concept_num, hidden_dim]
            qt: [batch_size]
            h_next: [batch_size, concept_num, hidden_dim]
        Return:
            h_next: hidden representations of all concepts at the next timestamp
            concept_embedding: input of VAE (optional)
            rec_embedding: reconstructed input of VAE (optional)
            z_prob: probability distribution of latent variable z in VAE (optional)
        """
        qt_mask = torch.ne(qt, -1)  # [batch_size], qt != -1
        mask_num = qt_mask.nonzero().shape[0]
        # GNN Aggregation
        m_next, concept_embedding, rec_embedding, z_prob = self._agg_neighbors(tmp_ht, qt)  # [batch_size, concept_num, hidden_dim]
        # Erase & Add Gate
        m_next[qt_mask] = self.erase_add_gate(m_next[qt_mask])  # [mask_num, concept_num, hidden_dim]
        # GRU
        h_next = m_next
        res = self.gru(m_next[qt_mask].reshape(-1, self.hidden_dim), ht[qt_mask].reshape(-1, self.hidden_dim))  # [mask_num * concept_num, hidden_num]
        index_tuple = (torch.arange(mask_num, device=qt_mask.device), )
        h_next[qt_mask] = h_next[qt_mask].index_put(index_tuple, res.reshape(-1, self.concept_num, self.hidden_dim))
        return h_next, concept_embedding, rec_embedding, z_prob

    # Predict step, as shown in Section 3.3.3 of the paper
    def _predict(self, h_next, qt):
        r"""
        Parameters:
            h_next: hidden representations of all concepts at the next timestamp after the update step
            qt: question indices for all students in a batch at the current timestamp
        Shape:
            h_next: [batch_size, concept_num, hidden_dim]
            qt: [batch_size]
            y: [batch_size, concept_num]
        Return:
            y: predicted correct probability of all concepts at the next timestamp
        """
        qt_mask = torch.ne(qt, -1)  # [batch_size], qt != -1
        y = self.predict(h_next).squeeze(dim=-1)  # [batch_size, concept_num]
        y[qt_mask] = torch.sigmoid(y[qt_mask])  # [batch_size, concept_num]
        return y

    def _get_next_pred(self, yt, q_next):
        r"""
        Parameters:
            yt: predicted correct probability of all concepts at the next timestamp
            q_next: question index matrix at the next timestamp
            batch_size: the size of a student batch
        Shape:
            y: [batch_size, concept_num]
            questions: [batch_size, seq_len]
            pred: [batch_size, ]
        Return:
            pred: predicted correct probability of the question answered at the next timestamp
        """
        next_qt = q_next
        next_qt = torch.where(next_qt != -1, next_qt, self.concept_num * torch.ones_like(next_qt, device=yt.device))
        one_hot_qt = F.embedding(next_qt.long(), self.one_hot_q)  # [batch_size, concept_num]
        # dot product between yt and one_hot_qt
        pred = (yt * one_hot_qt).sum(dim=1)  # [batch_size, ]
        return pred

    # Get edges for edge inference in VAE
    def _get_edges(self, masked_qt):
        r"""
        Parameters:
            masked_qt: qt index with -1 padding values removed
        Shape:
            masked_qt: [mask_num, ]
            rel_send: [edge_num, concept_num]
            rel_rec: [edge_num, concept_num]
        Return:
            rel_send: from nodes in edges which send messages to other nodes
            rel_rec:  to nodes in edges which receive messages from other nodes
        """
        mask_num = masked_qt.shape[0]
        row_arr = masked_qt.cpu().numpy().reshape(-1, 1)  # [mask_num, 1]
        row_arr = np.repeat(row_arr, self.concept_num, axis=1)  # [mask_num, concept_num]
        col_arr = np.arange(self.concept_num).reshape(1, -1)  # [1, concept_num]
        col_arr = np.repeat(col_arr, mask_num, axis=0)  # [mask_num, concept_num]
        # add reversed edges
        new_row = np.vstack((row_arr, col_arr))  # [2 * mask_num, concept_num]
        new_col = np.vstack((col_arr, row_arr))  # [2 * mask_num, concept_num]
        row_arr = new_row.flatten()  # [2 * mask_num * concept_num, ]
        col_arr = new_col.flatten()  # [2 * mask_num * concept_num, ]
        data_arr = np.ones(2 * mask_num * self.concept_num)
        init_graph = sp.coo_matrix((data_arr, (row_arr, col_arr)), shape=(self.concept_num, self.concept_num))
        init_graph.setdiag(0)  # remove self-loop edges
        row_arr, col_arr, _ = sp.find(init_graph)
        row_tensor = torch.from_numpy(row_arr).long()
        col_tensor = torch.from_numpy(col_arr).long()
        one_hot_table = torch.eye(self.concept_num, self.concept_num)
        rel_send = F.embedding(row_tensor, one_hot_table)  # [edge_num, concept_num]
        rel_rec = F.embedding(col_tensor, one_hot_table)  # [edge_num, concept_num]
        sp_rec, sp_send = rel_rec.to_sparse(), rel_send.to_sparse()
        sp_rec_t, sp_send_t = rel_rec.T.to_sparse(), rel_send.T.to_sparse()
        sp_send = sp_send.to(device=masked_qt.device)
        sp_rec = sp_rec.to(device=masked_qt.device)
        sp_send_t = sp_send_t.to(device=masked_qt.device)
        sp_rec_t = sp_rec_t.to(device=masked_qt.device)
        return sp_send, sp_rec, sp_send_t, sp_rec_t

    def forward(self, features, questions):
        r"""
        Parameters:
            features: input one-hot matrix
            questions: question index matrix
        seq_len dimension needs padding, because different students may have learning sequences with different lengths.
        Shape:
            features: [batch_size, seq_len]
            questions: [batch_size, seq_len]
            pred_res: [batch_size, seq_len - 1]
        Return:
            pred_res: the correct probability of questions answered at the next timestamp
            concept_embedding: input of VAE (optional)
            rec_embedding: reconstructed input of VAE (optional)
            z_prob: probability distribution of latent variable z in VAE (optional)
        """
        batch_size, seq_len = features.shape
        ht = Variable(torch.zeros((batch_size, self.concept_num, self.hidden_dim), device=features.device))
        pred_list = []
        ec_list = []  # concept embedding list in VAE
        rec_list = []  # reconstructed embedding list in VAE
        z_prob_list = []  # probability distribution of latent variable z in VAE
        for i in range(seq_len):
            xt = features[:, i]  # [batch_size]
            qt = questions[:, i]  # [batch_size]
            qt_mask = torch.ne(qt, -1)  # [batch_size], next_qt != -1
            tmp_ht = self._aggregate(xt, qt, ht, batch_size)  # [batch_size, concept_num, hidden_dim + embedding_dim]
            h_next, concept_embedding, rec_embedding, z_prob = self._update(tmp_ht, ht, qt)  # [batch_size, concept_num, hidden_dim]
            ht[qt_mask] = h_next[qt_mask]  # update new ht
            yt = self._predict(h_next, qt)  # [batch_size, concept_num]
            if i < seq_len - 1:
                pred = self._get_next_pred(yt, questions[:, i + 1])
                pred_list.append(pred)
            ec_list.append(concept_embedding)
            rec_list.append(rec_embedding)
            z_prob_list.append(z_prob)
        pred_res = torch.stack(pred_list, dim=1)  # [batch_size, seq_len - 1]
        return pred_res, ec_list, rec_list, z_prob_list

In [5]:
class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention module
    NOTE: Stole and modify from https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/SubLayers.py
    """

    def __init__(self, n_head, concept_num, input_dim, d_k, dropout=0.):
        super(MultiHeadAttention, self).__init__()
        self.n_head = n_head
        self.concept_num = concept_num
        self.d_k = d_k
        self.w_qs = nn.Linear(input_dim, n_head * d_k, bias=False)
        self.w_ks = nn.Linear(input_dim, n_head * d_k, bias=False)
        self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5, attn_dropout=dropout)
        # inferred latent graph, used for saving and visualization
        self.graphs = nn.Parameter(torch.zeros(n_head, concept_num, concept_num))
        self.graphs.requires_grad = False

    def _get_graph(self, attn_score, qt):
        r"""
        Parameters:
            attn_score: attention score of all queries
            qt: masked question index
        Shape:
            attn_score: [n_head, mask_num, concept_num]
            qt: [mask_num]
        Return:
            graphs: n_head types of inferred graphs
        """
        graphs = Variable(torch.zeros(self.n_head, self.concept_num, self.concept_num, device=qt.device))
        for k in range(self.n_head):
            index_tuple = (qt.long(), )
            graphs[k] = graphs[k].index_put(index_tuple, attn_score[k])  # used for calculation
            #############################
            # here, we need to detach edges when storing it into self.graphs in case memory leak!
            self.graphs.data[k] = self.graphs.data[k].index_put(index_tuple, attn_score[k].detach())  # used for saving and visualization
            #############################
        return graphs

    def forward(self, qt, query, key, mask=None):
        r"""
        Parameters:
            qt: masked question index
            query: answered concept embedding for a student batch
            key: concept embedding matrix
            mask: mask matrix
        Shape:
            qt: [mask_num]
            query: [mask_num, embedding_dim]
            key: [concept_num, embedding_dim]
        Return:
            graphs: n_head types of inferred graphs
        """
        d_k, n_head = self.d_k, self.n_head
        len_q, len_k = query.size(0), key.size(0)

        # Pass through the pre-attention projection: lq x (n_head *dk)
        # Separate different heads: lq x n_head x dk
        q = self.w_qs(query).view(len_q, n_head, d_k)
        k = self.w_ks(key).view(len_k, n_head, d_k)

        # Transpose for attention dot product: n_head x lq x dk
        q, k = q.transpose(0, 1), k.transpose(0, 1)
        attn_score = self.attention(q, k, mask=mask)  # [n_head, mask_num, concept_num]
        graphs = self._get_graph(attn_score, qt)
        return graphs

In [6]:
class KTLoss(nn.Module):

    def __init__(self):
        super(KTLoss, self).__init__()

    def forward(self, pred_answers, real_answers):
        r"""
        Parameters:
            pred_answers: the correct probability of questions answered at the next timestamp
            real_answers: the real results(0 or 1) of questions answered at the next timestamp
        Shape:
            pred_answers: [batch_size, seq_len - 1]
            real_answers: [batch_size, seq_len]
        Return:
        """
        real_answers = real_answers[:, 1:]  # timestamp=1 ~ T
        # real_answers shape: [batch_size, seq_len - 1]
        # Here we can directly use nn.BCELoss, but this loss doesn't have ignore_index function
        answer_mask = torch.ne(real_answers, -1)
        pred_one, pred_zero = pred_answers, 1.0 - pred_answers  # [batch_size, seq_len - 1]

        # calculate auc and accuracy metrics
        try:
            y_true = real_answers[answer_mask].cpu().detach().numpy()
            y_pred = pred_one[answer_mask].cpu().detach().numpy()
            auc = roc_auc_score(y_true, y_pred)  # may raise ValueError
            output = torch.cat((pred_zero[answer_mask].reshape(-1, 1), pred_one[answer_mask].reshape(-1, 1)), dim=1)
            label = real_answers[answer_mask].reshape(-1, 1)
            acc = accuracy(output, label)
            acc = float(acc.cpu().detach().numpy())
        except ValueError as e:
            auc, acc = -1, -1

        # calculate NLL loss
        pred_one[answer_mask] = torch.log(pred_one[answer_mask])
        pred_zero[answer_mask] = torch.log(pred_zero[answer_mask])
        pred_answers = torch.cat((pred_zero.unsqueeze(dim=1), pred_one.unsqueeze(dim=1)), dim=1)
        # pred_answers shape: [batch_size, 2, seq_len - 1]
        nll_loss = nn.NLLLoss(ignore_index=-1)  # ignore masked values in real_answers
        loss = nll_loss(pred_answers, real_answers.long())
        return loss, auc, acc


class VAELoss(nn.Module):

    def __init__(self, concept_num, edge_type_num=2, prior=False, var=5e-5):
        super(VAELoss, self).__init__()
        self.concept_num = concept_num
        self.edge_type_num = edge_type_num
        self.prior = prior
        self.var = var

    def forward(self, ec_list, rec_list, z_prob_list, log_prior=None):
        time_stamp_num = len(ec_list)
        loss = 0
        for time_idx in range(time_stamp_num):
            output = rec_list[time_idx]
            target = ec_list[time_idx]
            prob = z_prob_list[time_idx]
            loss_nll = nll_gaussian(output, target, self.var)
            if self.prior:
                assert log_prior is not None
                loss_kl = kl_categorical(prob, log_prior, self.concept_num)
            else:
                loss_kl = kl_categorical_uniform(prob, self.concept_num, self.edge_type_num)
            if time_idx == 0:
                loss = loss_nll + loss_kl
            else:
                loss = loss + loss_nll + loss_kl
        return loss / time_stamp_num

In [7]:
import os
from torch.utils.data import Dataset, TensorDataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

class KTDataset(Dataset):
    def __init__(self, features, questions, answers):
        super(KTDataset, self).__init__()
        self.features = features
        self.questions = questions
        self.answers = answers

    def __getitem__(self, index):
        return self.features[index], self.questions[index], self.answers[index]

    def __len__(self):
        return len(self.features)


def pad_collate(batch):
    (features, questions, answers) = zip(*batch)
    features = [torch.LongTensor(feat) for feat in features]
    questions = [torch.LongTensor(qt) for qt in questions]
    answers = [torch.LongTensor(ans) for ans in answers]
    feature_pad = pad_sequence(features, batch_first=True, padding_value=-1)
    question_pad = pad_sequence(questions, batch_first=True, padding_value=-1)
    answer_pad = pad_sequence(answers, batch_first=True, padding_value=-1)
    return feature_pad, question_pad, answer_pad


def load_dataset(file_path, batch_size, graph_type, dkt_graph_path=None, train_ratio=0.7, val_ratio=0.2, shuffle=True, model_type='GKT', use_binary=True, res_len=2, use_cuda=True):
    r"""
    Parameters:
        file_path: input file path of knowledge tracing data
        batch_size: the size of a student batch
        graph_type: the type of the concept graph
        shuffle: whether to shuffle the dataset or not
        use_cuda: whether to use GPU to accelerate training speed
    Return:
        concept_num: the number of all concepts(or questions)
        graph: the static graph is graph type is in ['Dense', 'Transition', 'DKT'], otherwise graph is None
        train_data_loader: data loader of the training dataset
        valid_data_loader: data loader of the validation dataset
        test_data_loader: data loader of the test dataset
    """
    df = pd.read_csv(file_path)
    if "skill_id" not in df.columns:
        raise KeyError(f"The column 'skill_id' was not found on {file_path}")
    if "correct" not in df.columns:
        raise KeyError(f"The column 'correct' was not found on {file_path}")
    if "user_id" not in df.columns:
        raise KeyError(f"The column 'user_id' was not found on {file_path}")

    # if not (df['correct'].isin([0, 1])).all():
    #     raise KeyError(f"The values of the column 'correct' must be 0 or 1.")

    # Step 1.1 - Remove questions without skill
    df.dropna(subset=['skill_id'], inplace=True)

    # Step 1.2 - Remove users with a single answer
    df = df.groupby('user_id').filter(lambda q: len(q) > 1).copy()

    # Step 2 - Enumerate skill id
    df['skill'], _ = pd.factorize(df['skill_id'], sort=True)  # we can also use problem_id to represent exercises

    # Step 3 - Cross skill id with answer to form a synthetic feature
    # use_binary: (0,1); !use_binary: (1,2,3,4,5,6,7,8,9,10,11,12). Either way, the correct result index is guaranteed to be 1
    if use_binary:
        df['skill_with_answer'] = df['skill'] * 2 + df['correct']
    else:
        df['skill_with_answer'] = df['skill'] * res_len + df['correct'] - 1


    # Step 4 - Convert to a sequence per user id and shift features 1 timestep
    feature_list = []
    question_list = []
    answer_list = []
    seq_len_list = []

    def get_data(series):
        feature_list.append(series['skill_with_answer'].tolist())
        question_list.append(series['skill'].tolist())
        answer_list.append(series['correct'].eq(1).astype('int').tolist())
        seq_len_list.append(series['correct'].shape[0])

    df.groupby('user_id').apply(get_data)
    max_seq_len = np.max(seq_len_list)
    print('max seq_len: ', max_seq_len)
    student_num = len(seq_len_list)
    print('student num: ', student_num)
    feature_dim = int(df['skill_with_answer'].max() + 1)
    print('feature_dim: ', feature_dim)
    question_dim = int(df['skill'].max() + 1)
    print('question_dim: ', question_dim)
    concept_num = question_dim

    # print('feature_dim:', feature_dim, 'res_len*question_dim:', res_len*question_dim)
    # assert feature_dim == res_len * question_dim

    kt_dataset = KTDataset(feature_list, question_list, answer_list)
    train_size = int(train_ratio * student_num)
    val_size = int(val_ratio * student_num)
    test_size = student_num - train_size - val_size
    train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(kt_dataset, [train_size, val_size, test_size])
    print('train_size: ', train_size, 'val_size: ', val_size, 'test_size: ', test_size)

    train_data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=pad_collate)
    valid_data_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=pad_collate)
    test_data_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=pad_collate)

    graph = None
    if model_type == 'GKT':
        if graph_type == 'Dense':
            graph = build_dense_graph(concept_num)
        elif graph_type == 'Transition':
            graph = build_transition_graph(question_list, seq_len_list, train_dataset.indices, student_num, concept_num)
        elif graph_type == 'DKT':
            graph = build_dkt_graph(dkt_graph_path, concept_num)
        if use_cuda and graph_type in ['Dense', 'Transition', 'DKT']:
            graph = graph.cuda()
    return concept_num, graph, train_data_loader, valid_data_loader, test_data_loader


def build_transition_graph(question_list, seq_len_list, indices, student_num, concept_num):
    graph = np.zeros((concept_num, concept_num))
    student_dict = dict(zip(indices, np.arange(student_num)))
    for i in range(student_num):
        if i not in student_dict:
            continue
        questions = question_list[i]
        seq_len = seq_len_list[i]
        for j in range(seq_len - 1):
            pre = questions[j]
            next = questions[j + 1]
            graph[pre, next] += 1
    np.fill_diagonal(graph, 0)
    # row normalization
    rowsum = np.array(graph.sum(1))
    def inv(x):
        if x == 0:
            return x
        return 1. / x
    inv_func = np.vectorize(inv)
    r_inv = inv_func(rowsum).flatten()
    r_mat_inv = np.diag(r_inv)
    graph = r_mat_inv.dot(graph)
    # covert to tensor
    graph = torch.from_numpy(graph).float()
    return graph


def build_dkt_graph(file_path, concept_num):
    graph = np.loadtxt(file_path)
    assert graph.shape[0] == concept_num and graph.shape[1] == concept_num
    graph = torch.from_numpy(graph).float()
    return graph

In [8]:
args = {
    'seed':42,
    'data-dir': '/content/drive/MyDrive/GKT-data',
    'data-file': 'assistment_test15.csv',
    'save-dir': 'save',
    'graph-save-dir': '',
    'load-dir': '',
    'dkt-graph-dir': '',
    'dkt-graph': 'dkt-graph',
    'model': 'GKT',
    'hid-dim': 32,
    'emb-dim': 32,
    'attn-dim': 32,
    'vae-encoder-dim': 32,
    'vae-decoder-dim': 32,
    'edge-types': 2,
    'graph-type': 'MHA',
    'dropout': 0,
    'bias': True,
    'binary': True,
    'result-type': 12,
    'temp': 0.5,
    'hard': False,
    'no-factor': False,
    'prior': True,
    'var': 1,
    'epochs': 50,
    'batch-size': 128,
    'train-ratio': 0.6,
    'val-ratio': 0.2,
    'shuffle': True,
    'lr': 0.001,
    'lr-decay': 200,
    'gamma': 0.5,
    'test': False,
    'test-model-dir': ''
}

In [10]:
import random
args['cuda'] = torch.cuda.is_available()
args['factor'] = not args['no-factor']

random.seed(args['seed'])
np.random.seed(args['seed'])
torch.manual_seed(args['seed'])

if args['cuda']:
  torch.cuda.manual_seed(args['seed'])
  torch.cuda.manual_seed_all(args['seed'])
  torch.backends.cudnn.benchmark = False
  torch.backends.cudnn.deterministic = True

res_len = 2 if args['binary'] else args['result_type']

log = None
save_dir = args['save-dir']

exp_counter = 0
now = datetime.datetime.now()
# timestamp = now.isoformat()
timestamp = now.strftime('%Y-%m-%d %H-%M-%S')
    
model_file_name = 'GKT' + '-' + args['graph-type']
save_dir = '{}/exp{}/'.format(args['save-dir'], model_file_name + timestamp)
if not os.path.exists(save_dir):
  os.makedirs(save_dir)
meta_file = os.path.join(save_dir, 'metadata.pkl')
model_file = os.path.join(save_dir, model_file_name + '.pt')
optimizer_file = os.path.join(save_dir, model_file_name + '-Optimizer.pt')
scheduler_file = os.path.join(save_dir, model_file_name + '-Scheduler.pt')
log_file = os.path.join(save_dir, 'log.txt')
log = open(log_file, 'w')
pickle.dump({'args': args}, open(meta_file, "wb"))

dataset_path = os.path.join(args['data-dir'], args['data-file'])
dkt_graph_path = os.path.join(args['dkt-graph-dir'], args['dkt-graph'])
if not os.path.exists(dkt_graph_path):
    dkt_graph_path = None
concept_num, graph, train_loader, valid_loader, test_loader = load_dataset(dataset_path, args['batch-size'], args['graph-type'], dkt_graph_path=dkt_graph_path,
                                                                           train_ratio=args['train-ratio'], val_ratio=args['val-ratio'], shuffle=args['shuffle'],
                                                                           model_type=args['model'], use_cuda=args['cuda'])

if args['graph-type'] == 'MHA':
  graph_model = MultiHeadAttention(args['edge-types'], concept_num, args['emb-dim'], args['attn-dim'], dropout=args['dropout'])
elif args['graph-type'] == 'VAE':
  graph_model = VAE(args['emb_dim'], args['vae-encoder-dim'], args['edge-types'], args['vae-decoder-dim'], args['vae-decoder-dim'], concept_num,
                          edge_type_num=args['edge-types'], tau=args['temp'], factor=args['factor'], dropout=args['dropout'], bias=args['bias'])
  vae_loss = VAELoss(concept_num, edge_type_num=args['edge-types'], prior=args['prior'], var=args['var'])
  if args.cuda:
      vae_loss = vae_loss.cuda()
if args['cuda'] and args['graph-type'] in ['MHA', 'VAE']:
  graph_model = graph_model.cuda()
model = GKT(concept_num, args['hid-dim'], args['emb-dim'], args['edge-types'], args['graph-type'], graph=graph, graph_model=graph_model,
                dropout=args['dropout'], bias=args['bias'], has_cuda=args['cuda'])

kt_loss = KTLoss()

# build optimizer
optimizer = optim.Adam(model.parameters(), lr=args['lr'])
scheduler = lr_scheduler.StepLR(optimizer, step_size=args['lr-decay'], gamma=args['gamma'])

optimizer = optim.Adam(model.parameters(), lr=args['lr'])
scheduler = lr_scheduler.StepLR(optimizer, step_size=args['lr-decay'], gamma=args['gamma'])

if args['model'] == 'GKT' and args['prior']:
    prior = np.array([0.91, 0.03, 0.03, 0.03])  # TODO: hard coded for now
    print("Using prior")
    print(prior)
    log_prior = torch.FloatTensor(np.log(prior))
    log_prior = torch.unsqueeze(log_prior, 0)
    log_prior = torch.unsqueeze(log_prior, 0)
    log_prior = Variable(log_prior)
    if args['cuda']:
        log_prior = log_prior.cuda()

if args['cuda']:
    model = model.cuda()
    kt_loss = KTLoss()

NameError: ignored

In [43]:
def train(epoch, best_val_loss):
    t = time.time()
    loss_train = []
    kt_train = []
    vae_train = []
    auc_train = []
    acc_train = []
    if graph_model is not None:
        graph_model.train()
    model.train()
    for batch_idx, (features, questions, answers) in enumerate(train_loader):
        t1 = time.time()
        if args['cuda']:
            features, questions, answers = features.cuda(), questions.cuda(), answers.cuda()
        ec_list, rec_list, z_prob_list = None, None, None
        pred_res, ec_list, rec_list, z_prob_list = model(features, questions)
        loss_kt, auc, acc = kt_loss(pred_res, answers)
        kt_train.append(float(loss_kt.cpu().detach().numpy()))
        if auc != -1 and acc != -1:
            auc_train.append(auc)
            acc_train.append(acc)

        
        loss = loss_kt
        print('batch idx: ', batch_idx, 'loss kt: ', loss_kt.item(), 'auc: ', auc, 'acc: ', acc, end=' ')
        loss_train.append(float(loss.cpu().detach().numpy()))
        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        del loss
        print('cost time: ', str(time.time() - t1))

    loss_val = []
    kt_val = []
    vae_val = []
    auc_val = []
    acc_val = []

    if graph_model is not None:
        graph_model.eval()
    model.eval()
    with torch.no_grad():
        for batch_idx, (features, questions, answers) in enumerate(valid_loader):
            if args['cuda']:
                features, questions, answers = features.cuda(), questions.cuda(), answers.cuda()
            ec_list, rec_list, z_prob_list = None, None, None
            pred_res, ec_list, rec_list, z_prob_list = model(features, questions)

            loss_kt, auc, acc = kt_loss(pred_res, answers)
            loss_kt = float(loss_kt.cpu().detach().numpy())
            kt_val.append(loss_kt)
            if auc != -1 and acc != -1:
                auc_val.append(auc)
                acc_val.append(acc)

            loss = loss_kt
            loss_val.append(loss)
            del loss
    
    print('Epoch: {:04d}'.format(epoch),
              'loss_train: {:.10f}'.format(np.mean(loss_train)),
              'auc_train: {:.10f}'.format(np.mean(auc_train)),
              'acc_train: {:.10f}'.format(np.mean(acc_train)),
              'loss_val: {:.10f}'.format(np.mean(loss_val)),
              'auc_val: {:.10f}'.format(np.mean(auc_val)),
              'acc_val: {:.10f}'.format(np.mean(acc_val)),
              'time: {:.4f}s'.format(time.time() - t))
    if args['save-dir'] and np.mean(loss_val) < best_val_loss:
        print('Best model so far, saving...')
        torch.save(model.state_dict(), model_file)
        torch.save(optimizer.state_dict(), optimizer_file)
        torch.save(scheduler.state_dict(), scheduler_file)
        print('Epoch: {:04d}'.format(epoch),
                  'loss_train: {:.10f}'.format(np.mean(loss_train)),
                  'auc_train: {:.10f}'.format(np.mean(auc_train)),
                  'acc_train: {:.10f}'.format(np.mean(acc_train)),
                  'loss_val: {:.10f}'.format(np.mean(loss_val)),
                  'auc_val: {:.10f}'.format(np.mean(auc_val)),
                  'acc_val: {:.10f}'.format(np.mean(acc_val)),
                  'time: {:.4f}s'.format(time.time() - t), file=log)
        log.flush()
    res = np.mean(loss_val)
    del loss_train
    del auc_train
    del acc_train
    del loss_val
    del auc_val
    del acc_val
    gc.collect()
    if args['cuda']:
        torch.cuda.empty_cache()
    return res

In [44]:
import numpy as np
import time
import random
import argparse
import pickle
import os
import gc
import datetime
import torch
import torch.optim as optim
from torch.optim import lr_scheduler
from sklearn.metrics import roc_auc_score
def test():
    loss_test = []
    kt_test = []
    vae_test = []
    auc_test = []
    acc_test = []

    if graph_model is not None:
        graph_model.eval()
    model.eval()
    model.load_state_dict(torch.load(model_file))
    with torch.no_grad():
        for batch_idx, (features, questions, answers) in enumerate(test_loader):
            if args.cuda:
                features, questions, answers = features.cuda(), questions.cuda(), answers.cuda()
            ec_list, rec_list, z_prob_list = None, None, None
            pred_res, ec_list, rec_list, z_prob_list = model(features, questions)
            
            loss_kt, auc, acc = kt_loss(pred_res, answers)
            loss_kt = float(loss_kt.cpu().detach().numpy())
            if auc != -1 and acc != -1:
                auc_test.append(auc)
                acc_test.append(acc)
            kt_test.append(loss_kt)
            loss = loss_kt
            loss_test.append(loss)
            del loss
    print('--------------------------------')
    print('--------Testing-----------------')
    print('--------------------------------')
    
    print('loss_test: {:.10f}'.format(np.mean(loss_test)),
              'auc_test: {:.10f}'.format(np.mean(auc_test)),
              'acc_test: {:.10f}'.format(np.mean(acc_test)))
    if args['save_dir']:
        print('--------------------------------', file=log)
        print('--------Testing-----------------', file=log)
        print('--------------------------------', file=log)
        
        print('loss_test: {:.10f}'.format(np.mean(loss_test)),
                  'auc_test: {:.10f}'.format(np.mean(auc_test)),
                  'acc_test: {:.10f}'.format(np.mean(acc_test)), file=log)
        log.flush()
    del loss_test
    del auc_test
    del acc_test
    gc.collect()
    if args['cuda']:
        torch.cuda.empty_cache()

if args['test'] is False:
    # Train model
    print('start training!')
    t_total = time.time()
    best_val_loss = np.inf
    best_epoch = 0
    for epoch in range(args['epochs']):
        val_loss = train(epoch, best_val_loss)
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_epoch = epoch
    print("Optimization Finished!")
    print("Best Epoch: {:04d}".format(best_epoch))
    if args['save-dir']:
        print("Best Epoch: {:04d}".format(best_epoch), file=log)
        log.flush()

start training!
batch idx:  0 loss kt:  0.7155298590660095 auc:  0.5106289051888073 acc:  0.448 cost time:  2.471405267715454
Epoch: 0000 loss_train: 0.7155298591 auc_train: 0.5106289052 acc_train: 0.4480000000 loss_val: 0.7300035954 auc_val: 0.4746738747 acc_val: 0.4561855670 time: 3.0986s
Best model so far, saving...
batch idx:  0 loss kt:  0.7067127823829651 auc:  0.53248664312234 acc:  0.4736 cost time:  2.4328744411468506
Epoch: 0001 loss_train: 0.7067127824 auc_train: 0.5324866431 acc_train: 0.4736000000 loss_val: 0.7227399945 auc_val: 0.4818427547 acc_val: 0.4690721649 time: 3.0563s
Best model so far, saving...
batch idx:  0 loss kt:  0.7003569602966309 auc:  0.5517635606266413 acc:  0.4976 cost time:  2.4646854400634766
Epoch: 0002 loss_train: 0.7003569603 auc_train: 0.5517635606 acc_train: 0.4976000000 loss_val: 0.7214264870 auc_val: 0.4856034787 acc_val: 0.4793814433 time: 3.0793s
Best model so far, saving...
batch idx:  0 loss kt:  0.6946474313735962 auc:  0.5696708322013946

In [46]:
torch.save(model, '/content/model_50.pt')

In [11]:
new_model = torch.load('/content/model_50.pt')