In [18]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F
import pandas as pd

In [19]:
# Hyper-parameters: everything a reader might tune lives in this cell.

# Random seed: makes weight initialisation and dropout masks reproducible
# under Restart Kernel -> Run All.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Nb of questions
nb_question = 30
# Nb of concepts
nb_concept = 10
# Embedding dimension
embedding_dim = 16
# Nb of attention heads
nb_head = 8
# Context window (longest sequence the causal attention masks support)
context_size = 20

# Attentive Knowledge Tracing

In [20]:
class Head(nn.Module):
    """One scaled dot-product attention head with an optional monotonic decay.

    Queries and keys share a single projection (`query_key`); values have their
    own projection (`value`). A lower-triangular mask restricts query position t
    to key positions tau <= t (the mask is aligned at index 0 of both axes, so
    query and key sequences may have different lengths).

    When `monotonic` is True, the raw scores are multiplied by
    exp(-rate * d(t, tau)) before masking and the softmax, where
    rate = min(softplus(theta_raw), 10) and

        d(t, tau) = sqrt(|t - tau| * sum_{t'=tau+1}^{t} gamma(t, t'))

    with gamma the masked softmax of the raw scores. d is computed under
    `torch.no_grad()`, so gradients reach the decay factor only via `theta_raw`.

    Args:
        D: input feature dimension.
        head_size: output dimension of the query/key/value projections.
        T: largest sequence length supported (size of the mask buffer).
        monotonic: enable the distance-based decay.
    """

    def __init__(self, D = embedding_dim, head_size = embedding_dim//8, T = context_size, monotonic = False):
        super().__init__()

        # embedding dim
        self.D = D

        # monotonic attention head or simple attention head
        self.monotonic = monotonic

        # Raw decay parameter; the effective rate is softplus(theta_raw) >= 0.
        # Registered before the projections (same parameter order as before).
        if monotonic:
            self.theta_raw = nn.Parameter(torch.tensor(-2.0, dtype=torch.float32))

        # Dk = Dq = Dv
        self.head_size = head_size

        # Queries and keys share one projection; values get their own.
        self.query_key = nn.Linear(D, head_size, bias=False)
        self.value = nn.Linear(D, head_size, bias=False)

        # Causal mask: tril[t, tau] == 1 iff tau <= t.
        self.register_buffer('tril', torch.tril(torch.ones(T, T)))

    def forward(self, q_in, k_in, v_in):
        """Attend from `q_in` to (`k_in`, `v_in`) under the causal mask.

        Args:
            q_in: (B, Tq, D) query inputs.
            k_in: (B, Tk, D) key inputs.
            v_in: (B, Tk, D) value inputs; must have the same length as `k_in`.

        Returns:
            (B, Tq, head_size) attention output.

        Raises:
            ValueError: if Tq or Tk exceeds the mask size T given at construction.
        """
        _, Tq, _ = q_in.shape
        Tk = k_in.size(1)
        assert v_in.size(1) == Tk

        max_len = self.tril.size(0)
        if Tq > max_len or Tk > max_len:
            raise ValueError(
                f"sequence length (Tq={Tq}, Tk={Tk}) exceeds the context window T={max_len}"
            )

        k = self.query_key(k_in)  # (B, Tk, head_size)
        q = self.query_key(q_in)  # (B, Tq, head_size)
        v = self.value(v_in)      # (B, Tk, head_size)

        # (B, Tq, head_size) @ (B, head_size, Tk) = (B, Tq, Tk)
        weights = q @ k.transpose(-2, -1) * self.head_size**(-0.5)

        mask = self.tril[:Tq, :Tk]  # (Tq, Tk): 1 where tau <= t

        if self.monotonic:
            # The distance d(t, tau) is treated as a constant (no gradient).
            with torch.no_grad():
                scores_masked = weights.masked_fill(mask == 0, -1e32)
                gamma = F.softmax(scores_masked, dim=-1)
                gamma = gamma * mask.float()  # exact zeros for tau > t

                # prefix[b, t, tau] = sum_{t' <= tau} gamma(t, t')
                prefix = torch.cumsum(gamma, dim=-1)  # (B, Tq, Tk)

                # gamma(t, t') is zero for t' > t, so the last column holds the
                # row total, which equals prefix[t, t] whenever t < Tk. Reading
                # the last column (not the diagonal) also works when Tq != Tk.
                prefix_at_t = prefix[..., -1:]  # (B, Tq, 1)
                sum_tau1_to_t = prefix_at_t - prefix  # sum_{t'=tau+1}^{t} gamma(t, t')

                q_pos = torch.arange(Tq, device=q_in.device)
                k_pos = torch.arange(Tk, device=q_in.device)
                abs_dtau = (q_pos.view(1, Tq, 1) - k_pos.view(1, 1, Tk)).abs().float()  # (1, Tq, Tk)

                # Clamp to non-negative before the sqrt.
                d = torch.clamp(abs_dtau * sum_tau1_to_t, min=0.).sqrt()  # (B, Tq, Tk)

            decay_rate = F.softplus(self.theta_raw)
            # Clamp for numerical stability
            decay_rate = torch.clamp(decay_rate, max=10.0)

            # d >= 0 and decay_rate >= 0, so exp(-d * rate) <= 1; the floor keeps
            # far-away scores from being scaled to exactly zero.
            factor = torch.clamp(torch.exp(-d * decay_rate), min=1e-5, max=1e5)  # (B, Tq, Tk)
            weights = weights * factor

        # Apply the causal mask after the decay, then normalise over keys.
        weights = weights.masked_fill(mask == 0, float('-inf'))
        weights = F.softmax(weights, dim=-1)

        # (B, Tq, Tk) @ (B, Tk, head_size) = (B, Tq, head_size)
        out = weights @ v

        return out

In [21]:
class Multi(nn.Module):
    """Multi-head attention built from `nb_head` independent `Head` modules.

    Every head sees the same (query, key, value) inputs. Their outputs are
    concatenated on the feature axis, projected back to dimension D, and
    passed through dropout.

    Args:
        nb_head: number of heads.
        monotonic: forwarded to every Head (True -> distance-decayed attention).
        D: model / embedding dimension.
        head_size: per-head output size; the concatenation has nb_head * head_size features.
        T: maximum sequence length for the heads' causal masks.
        dropout: dropout probability applied after the output projection.
    """

    def __init__(self, nb_head, monotonic, D = embedding_dim, head_size = embedding_dim//8, T = context_size, dropout = 0.2):
        super().__init__()
        self.nb_head = nb_head
        self.head_size = head_size

        # All heads share one configuration (monotonic or plain causal attention).
        self.heads = nn.ModuleList(
            Head(D, head_size, T, monotonic) for _ in range(nb_head)
        )

        concat_dim = nb_head * head_size
        self.proj = nn.Linear(concat_dim, D)
        self.drop = nn.Dropout(dropout)

    def forward(self, q_in, k_in, v_in):
        """Run every head, concatenate, project to D and apply dropout."""
        per_head = [head(q_in, k_in, v_in) for head in self.heads]
        stacked = torch.cat(per_head, dim=-1)  # (B, Tq, nb_head * head_size)
        projected = self.proj(stacked)         # (B, Tq, D)
        return self.drop(projected)
        
    

In [22]:
class EncoderBlock(nn.Module):
    """Post-norm encoder layer: causal self-attention then a position-wise FFN.

        h   = LayerNorm(x + Dropout(SelfAttn(x)))
        out = LayerNorm(h + Dropout(FFN(h)))

    The attention heads are plain causal heads (monotonic=False).

    Args:
        nb_head: number of attention heads.
        D: model / embedding dimension.
        head_size: per-head output size.
        T: maximum sequence length for the causal mask.
        dropout: dropout probability used throughout the block.
    """

    def __init__(self, nb_head = 8, D = embedding_dim, head_size = embedding_dim//8, T = context_size, dropout = 0.2):
        super().__init__()

        # Self-attention sub-layer; second argument False -> non-monotonic heads.
        self.sa_encoder = Multi(nb_head, False, D, head_size, T, dropout)

        self.norm1 = nn.LayerNorm(D)
        # NOTE(review): act1 is never used in forward.
        self.act1 = nn.ReLU()
        # NOTE(review): Multi already applies dropout to its output, so the
        # attention output is dropped out twice (Multi.drop, then drop1).
        self.drop1 = nn.Dropout(dropout)

        hidden = 4 * D
        self.ffn = nn.Sequential(
            nn.Linear(D, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, D),
        )

        self.norm2 = nn.LayerNorm(D)
        self.drop2 = nn.Dropout(dropout)

    def forward(self, x):
        """Encode `x` (B, T, D) -> (B, T, D)."""
        attended = self.sa_encoder(x, x, x)
        h = self.norm1(x + self.drop1(attended))
        return self.norm2(h + self.drop2(self.ffn(h)))



In [23]:
class DecoderBlock(nn.Module):
    """Post-norm retriever layer: monotonic attention then a position-wise FFN.

    Queries, keys and values are passed separately. The residual branch of the
    attention sub-layer is the query input:

        h   = LayerNorm(q_in + Dropout(MonotonicAttn(q_in, k_in, v_in)))
        out = LayerNorm(h + Dropout(FFN(h)))

    Args:
        nb_head: number of attention heads.
        D: model / embedding dimension.
        head_size: per-head output size.
        T: maximum sequence length for the causal mask.
        dropout: dropout probability used throughout the block.
    """

    def __init__(self, nb_head = 8, D = embedding_dim, head_size = embedding_dim//8, T = context_size, dropout = 0.2):
        super().__init__()

        # Second argument True -> every head uses the monotonic decay.
        self.sa_decoder = Multi(nb_head, True, D, head_size, T, dropout)

        self.norm1 = nn.LayerNorm(D)
        # NOTE(review): act1 is never used in forward.
        self.act1 = nn.ReLU()
        # NOTE(review): stacks on the dropout Multi already applies to its output.
        self.drop1 = nn.Dropout(dropout)

        ffn_layers = [
            nn.Linear(D, 4 * D),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(4 * D, D),
        ]
        self.ffn = nn.Sequential(*ffn_layers)

        self.norm2 = nn.LayerNorm(D)
        self.drop2 = nn.Dropout(dropout)

    def forward(self, q_in, k_in, v_in):
        """Retrieve from (k_in, v_in) for each query; output has q_in's shape."""
        attn = self.sa_decoder(q_in, k_in, v_in)
        residual = q_in + self.drop1(attn)
        h = self.norm1(residual)
        ffn_branch = self.drop2(self.ffn(h))
        return self.norm2(h + ffn_branch)



In [24]:
class AKT(nn.Module):
    """Attentive Knowledge Tracing model.

    Pipeline for an input of shape (B, T, 3):
      1. Rasch-style embeddings x (question) and y (question + response).
      2. Two causal encoders produce contextualised x_hat and y_hat.
      3. A monotonic-attention retriever lets step t read x_hat / y_hat of steps < t.
      4. An MLP on [retrieved knowledge, x_t] outputs P(correct) for steps 1..T-1.
    """

    def __init__(self, C = nb_concept, Q = nb_question, D = embedding_dim, nb_head = nb_head, T = context_size, dropout = 0.2):
        """
        AKT
        C : int = nb Concepts
        Q : int = nb Questions
        D : int = embedding Dimension
        nb_head : int = nb attention heads in each encoder / retriever block
                  (default: module-level `nb_head` when this class is defined)
        T : int = context window, the longest sequence the causal masks support
                  (default: module-level `context_size` when this class is defined)
        dropout : float = dropout probability for the blocks and the prediction MLP
        """
        super().__init__()

        self.C = C
        self.Q = Q
        self.D = D

        # Rasch model-based embeddings
        # c_c : one vector per concept
        self.c_embedding = nn.Embedding(C, D)

        # d_c : one vector per concept, scaled by mu_q in forward
        self.d_embedding = nn.Embedding(C, D)

        # difficulty mu_q : one scalar per question
        self.mu_embedding = nn.Embedding(Q, 1)

        # correct or wrong answer g_r : two rows, indexed by the response (0 / 1)
        self.g_embedding = nn.Embedding(2, D)

        # f_(c, r) : row index is c + r * C
        self.f_embedding = nn.Embedding(2*C, D)

        head_size = D // 8
        self.question_encoder = EncoderBlock(nb_head, D, head_size, T, dropout)
        self.knowledge_encoder = EncoderBlock(nb_head, D, head_size, T, dropout)
        self.knowledge_retriever = DecoderBlock(nb_head, D, head_size, T, dropout)

        self.prediction_layer = nn.Sequential(
            nn.Linear(2*D, 512),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, 1)
        )

    def forward(self, idx):
        """Predict the probability of a correct response at steps 1..T-1.

        Args:
            idx: integer tensor of shape (B, T, >=3). Channel 0 is the question
                id, channel 1 the concept id, channel 2 the response, which must
                be 0 or 1 (it indexes the 2-row `g_embedding`).

        Returns:
            (B, T-1, 1) tensor of sigmoid probabilities; position i is the
            prediction for step i+1 given steps 0..i.
        """
        _, T, _ = idx.shape  # also enforces a rank-3 input

        question = idx[:, :, 0]  # (B, T)
        concept = idx[:, :, 1]   # (B, T)
        result = idx[:, :, 2]    # (B, T)

        # Shared lookups, computed once and reused for x and y.
        c = self.c_embedding(concept)     # (B, T, D)
        mu = self.mu_embedding(question)  # (B, T, 1), broadcast over D

        # x_t = c_c + mu_q * d_c
        x = c + mu * self.d_embedding(concept)  # (B, T, D)
        # y_t = c_c + g_r + mu_q * f_(c, r)
        y = c + self.g_embedding(result) + mu * self.f_embedding(concept + result * self.C)  # (B, T, D)

        x_hat = self.question_encoder(x)   # (B, T, D)
        y_hat = self.knowledge_encoder(y)  # (B, T, D)

        # Shift by one step: query position i is step i+1; with the retriever's
        # causal mask it only sees keys/values of steps 0..i.
        x_q = x_hat[:, 1:, :]   # queries (B, T-1, D): question at step t
        x_k = x_hat[:, :-1, :]  # keys    (B, T-1, D): questions up to t-1
        y_v = y_hat[:, :-1, :]  # values  (B, T-1, D): responses up to t-1

        h = self.knowledge_retriever(x_q, x_k, y_v)  # (B, T-1, D)

        out = torch.cat([h, x[:, 1:, :]], dim=-1)  # (B, T-1, 2D)
        out = self.prediction_layer(out)           # (B, T-1, 1)
        return torch.sigmoid(out)                  # (B, T-1, 1)



In [25]:
# Build AKT from the config-cell hyper-parameters, passed explicitly for readability.
model = AKT(C=nb_concept, Q=nb_question, D=embedding_dim)

In [26]:
# Adam with a small L2 penalty. NOTE(review): Adam's `weight_decay` is added to the
# gradient (coupled L2); switch to optim.AdamW if decoupled weight decay is intended.
optimizer = optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-5)