In [2]:
import torch
from torch import nn
from torch.nn import Module, Embedding, Linear, Dropout, MaxPool1d, Sequential, ReLU
import copy
import pandas as pd
import numpy as np
import torch
import torch.nn.functional as F


# Compute device name chosen once at import time: CUDA when available, else CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

class transformer_FFN(Module):
    """Feed-forward stack: Linear -> ReLU -> Dropout -> Linear.

    Both linear layers map ``emb_size`` features to ``emb_size`` features, so
    the size of the last dimension of the input is preserved.

    Args:
        emb_size: Size of the last input/output dimension.
        dropout: Dropout probability applied after the ReLU.
    """

    def __init__(self, emb_size, dropout) -> None:
        super().__init__()
        self.emb_size = emb_size
        self.dropout = dropout
        # Layer order (and therefore state-dict indices 0 and 3) must stay fixed.
        layers = [
            Linear(emb_size, emb_size),
            ReLU(),
            Dropout(dropout),
            Linear(emb_size, emb_size),
        ]
        self.FFN = Sequential(*layers)

    def forward(self, in_fea):
        """Run ``in_fea`` through the feed-forward stack and return the result."""
        transformed = self.FFN(in_fea)
        return transformed

def ut_mask(seq_len):
    """Strictly upper-triangular boolean mask.

    Returns a ``(seq_len, seq_len)`` bool tensor that is True above the main
    diagonal and False on and below it, placed on the module-level ``device``.
    """
    full = torch.ones(seq_len, seq_len)
    strictly_upper = torch.triu(full, diagonal=1).bool()
    return strictly_upper.to(device)

def lt_mask(seq_len):
    """Strictly lower-triangular boolean mask.

    Returns a ``(seq_len, seq_len)`` bool tensor that is True below the main
    diagonal and False on and above it, placed on the module-level ``device``.
    """
    full = torch.ones(seq_len, seq_len)
    strictly_lower = torch.tril(full, diagonal=-1).bool()
    return strictly_lower.to(device)

def pos_encode(seq_len):
    """Position indices ``0 .. seq_len - 1`` as a ``(1, seq_len)`` tensor.

    The result is placed on the module-level ``device``.
    """
    positions = torch.arange(seq_len)
    return positions[None, :].to(device)

def get_clones(module, N):
    """Return an ``nn.ModuleList`` of ``N`` independent deep copies of ``module``."""
    clones = nn.ModuleList()
    for _ in range(N):
        clones.append(copy.deepcopy(module))
    return clones

In [3]:
import torch
from torch import nn
from torch.nn.init import xavier_uniform_
from torch.nn.init import constant_
import torch.nn.init as init
import math
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import torch.nn.functional as F
from enum import IntEnum
from torch.nn.parameter import Parameter
import numpy as np
from torch.nn import Module, Embedding, LSTM, Linear, Dropout, LayerNorm, TransformerEncoder, TransformerEncoderLayer, \
        MultiLabelMarginLoss, MultiLabelSoftMarginLoss, CrossEntropyLoss, BCELoss, MultiheadAttention
from torch.nn.functional import one_hot, cross_entropy, multilabel_margin_loss, binary_cross_entropy

# Compute device as a ``torch.device``: CUDA when available, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class Dim(IntEnum):
    """Names for tensor dimensions, to make indexing code more readable."""

    batch, seq, feature = range(3)
class CSKT(nn.Module):
    """Knowledge-tracing model: ID embeddings -> ``Architecture`` -> MLP -> sigmoid.

    Question/response embeddings are built from integer id sequences, passed
    through the attention stack in ``self.model`` and concatenated with the
    question embedding before a three-layer MLP produces one probability per
    position.

    Args:
        n_question: Number of distinct ids embedded by ``q_embed``.
        n_pid: Number of problem ids; ``0`` disables the difficulty embeddings.
        d_model: Embedding / attention width.
        n_blocks: Number of blocks in the attention stack.
        dropout: Dropout probability used in the attention stack and MLP head.
        d_ff: Hidden width of the feed-forward part of each block.
        seq_len, r, gamma: Forwarded to ``Architecture``.
        kq_same: Forwarded to ``Architecture`` (``1`` shares key/query projection).
        final_fc_dim, final_fc_dim2: Hidden widths of the MLP head.
        num_attn_heads: Number of attention heads.
        separate_qa: If true, use one embedding row per (id, response) pair.
        l2: Stored on the instance; not used inside this class.
        emb_type: One of ``_SUPPORTED_EMB_TYPES`` for ``forward`` to work.
        loss1, loss2, loss3, start, num_layers, nheads, emb_path, pretrain_dim:
            Accepted for interface compatibility; not used inside this class.
    """

    # ``emb_type`` values for which ``forward`` can produce predictions.
    _SUPPORTED_EMB_TYPES = ("qid", "qidaktrasch", "qid_scalar", "qid_norasch")

    def __init__(self, n_question, n_pid, 
            d_model, n_blocks, dropout, d_ff=256, 
            loss1=0.5, loss2=0.5, loss3=0.5, start=50, num_layers=2, nheads=4, seq_len=512, r=1, gamma=1, 
            kq_same=1, final_fc_dim=512, final_fc_dim2=256, num_attn_heads=8, separate_qa=False, l2=1e-5, emb_type="qid", emb_path="", pretrain_dim=768):
        super().__init__()
        self.model_name = "cskt"
        self.n_question = n_question
        self.dropout = dropout
        self.kq_same = kq_same
        self.n_pid = n_pid
        self.l2 = l2
        self.model_type = self.model_name
        self.separate_qa = separate_qa
        self.emb_type = emb_type
        # Read by get_attn_pad_mask. NOTE(review): assumed to be the number of
        # attention heads actually used by the model (num_attn_heads) rather
        # than the otherwise-unused ``nheads`` argument -- confirm.
        self.nhead = num_attn_heads
        embed_l = d_model

        self.r = r
        self.gamma = gamma

        if self.n_pid > 0:
            # Per-problem difficulty: a scalar or a full vector per problem id.
            if "scalar" in emb_type:
                self.difficult_param = nn.Embedding(self.n_pid+1, 1)
            else:
                self.difficult_param = nn.Embedding(self.n_pid+1, embed_l)
            self.q_embed_diff = nn.Embedding(self.n_question+1, embed_l)
            self.qa_embed_diff = nn.Embedding(2 * self.n_question + 1, embed_l)
        if emb_type.startswith("qid"):
            self.q_embed = nn.Embedding(self.n_question, embed_l)
            if self.separate_qa:
                self.qa_embed = nn.Embedding(2*self.n_question+1, embed_l)
            else:
                self.qa_embed = nn.Embedding(2, embed_l)
        self.model = Architecture(n_question=n_question, n_blocks=n_blocks, n_heads=num_attn_heads, dropout=dropout,
                                    d_model=d_model, d_feature=d_model // num_attn_heads, d_ff=d_ff,  kq_same=self.kq_same, model_type=self.model_type, seq_len=seq_len,
                                    r = r, gamma=gamma)

        self.out = nn.Sequential(
            nn.Linear(d_model + embed_l, final_fc_dim),
            nn.ReLU(),
            nn.Dropout(self.dropout),
            nn.Linear(final_fc_dim, final_fc_dim2),
            nn.ReLU(),
            nn.Dropout(self.dropout),
            nn.Linear(final_fc_dim2, 1),
        )

        self.reset()

    def reset(self):
        """Zero-initialise the problem-difficulty embedding (when it exists).

        Only ``difficult_param`` is touched. Matching parameters by their first
        dimension (``n_pid + 1``) could also wipe unrelated weights that happen
        to share that size.
        """
        if self.n_pid > 0:
            torch.nn.init.constant_(self.difficult_param.weight, 0.)

    def base_emb(self, q_data, target):
        """Look up question and question-response embeddings.

        Args:
            q_data: Integer ids fed to ``q_embed``.
            target: Integer responses aligned with ``q_data``.

        Returns:
            ``(q_embed_data, qa_embed_data)``. With ``separate_qa`` the second
            item is ``qa_embed(q_data + n_question * target)``; otherwise it is
            ``qa_embed(target) + q_embed_data``.
        """
        q_embed_data = self.q_embed(q_data)
        if self.separate_qa:
            qa_data = q_data + self.n_question * target
            qa_embed_data = self.qa_embed(qa_data)
        else:
            qa_embed_data = self.qa_embed(target) + q_embed_data
        return q_embed_data, qa_embed_data

    def get_attn_pad_mask(self, sm):
        """Build a padding mask from a 2-D tensor whose zeros mark padding.

        Args:
            sm: Tensor of shape ``(batch, l)``.

        Returns:
            Bool tensor of shape ``(nhead * batch, l, l)``; entry ``[b, i, j]``
            is True when ``sm[b, j] == 0``.
        """
        batch_size, l = sm.size()
        pad_attn_mask = sm.eq(0).unsqueeze(1)
        pad_attn_mask = pad_attn_mask.expand(batch_size, l, l)
        return pad_attn_mask.repeat(self.nhead, 1, 1)

    def forward(self, dcur, qtest=False, train=False):
        """Predict a probability for every position of the input sequences.

        Args:
            dcur: Mapping with tensors under ``"qseqs"``, ``"cseqs"``,
                ``"rseqs"``, ``"shft_qseqs"``, ``"shft_cseqs"`` and
                ``"shft_rseqs"``. The first column of each unshifted tensor is
                concatenated with the shifted tensor along dim 1.
            qtest: When not training, also return the MLP input features.
            train: Return the ``(preds, y2, y3)`` triple (``y2 == y3 == 0``).

        Returns:
            ``(preds, 0, 0)`` if ``train``; ``(preds, concat_q)`` if ``qtest``;
            otherwise ``preds``.

        Raises:
            ValueError: If ``self.emb_type`` is not in ``_SUPPORTED_EMB_TYPES``.
        """
        emb_type = self.emb_type
        # Fail early with a clear message instead of an UnboundLocalError later.
        if emb_type not in self._SUPPORTED_EMB_TYPES:
            raise ValueError(
                f"Unsupported emb_type {emb_type!r}; expected one of {self._SUPPORTED_EMB_TYPES}"
            )

        q, c, r = dcur["qseqs"].long(), dcur["cseqs"].long(), dcur["rseqs"].long()
        qshft, cshft, rshft = dcur["shft_qseqs"].long(), dcur["shft_cseqs"].long(), dcur["shft_rseqs"].long()
        pid_data = torch.cat((q[:, 0:1], qshft), dim=1).to(device)
        q_data = torch.cat((c[:, 0:1], cshft), dim=1).to(device)
        target = torch.cat((r[:, 0:1], rshft), dim=1).to(device)

        # Every supported emb_type starts with "qid", so base_emb always applies.
        q_embed_data, qa_embed_data = self.base_emb(q_data, target)

        if self.n_pid > 0 and "norasch" not in emb_type:
            # Shift the question embedding by a problem-specific difficulty term.
            q_embed_diff_data = self.q_embed_diff(q_data)
            pid_embed_data = self.difficult_param(pid_data)
            q_embed_data = q_embed_data + pid_embed_data * q_embed_diff_data
            if "aktrasch" in emb_type:
                # The "aktrasch" variant shifts the interaction embedding too.
                qa_embed_diff_data = self.qa_embed_diff(target)
                qa_embed_data = qa_embed_data + pid_embed_data * \
                    (qa_embed_diff_data + q_embed_diff_data)

        y2, y3 = 0, 0
        d_output = self.model(q_embed_data, qa_embed_data)
        concat_q = torch.cat([d_output, q_embed_data], dim=-1)
        output = self.out(concat_q).squeeze(-1)
        preds = torch.sigmoid(output)

        if train:
            return preds, y2, y3
        if qtest:
            return preds, concat_q
        return preds

class Architecture(nn.Module):
    """Stack of ``TransformerLayer`` blocks applied to the question stream.

    Args:
        n_question: Accepted for interface compatibility; not used here.
        n_blocks: Number of stacked blocks in the attention.
        d_model: Dimension of attention input/output.
        d_feature: Accepted for interface compatibility; each block is built
            with ``d_model // n_heads`` instead.
        d_ff: Hidden width of each block's feed-forward part.
        n_heads: Number of heads (``n_heads * d_feature = d_model``).
        dropout: Dropout probability passed to each block.
        kq_same: Passed to each block.
        model_type: Must be ``'cskt'``.
        seq_len, r, gamma: Passed to each block.

    Raises:
        ValueError: If ``model_type`` is not supported. Without this check the
            block list would be missing and ``forward`` would fail later.
    """

    def __init__(self, n_question,  n_blocks, d_model, d_feature,
                 d_ff, n_heads, dropout, kq_same, model_type, seq_len, r, gamma):
        super().__init__()
        if model_type not in {'cskt'}:
            raise ValueError(f"Unsupported model_type {model_type!r}; expected 'cskt'")

        self.d_model = d_model
        self.model_type = model_type

        self.blocks_2 = nn.ModuleList([
            TransformerLayer(d_model=d_model, d_feature=d_model // n_heads,
                             d_ff=d_ff, dropout=dropout, n_heads=n_heads, kq_same=kq_same, seq_len = seq_len, r=r, gamma=gamma)
            for _ in range(n_blocks)
        ])

    def forward(self, q_embed_data, qa_embed_data):
        """Run every block with the question stream as query/key.

        Args:
            q_embed_data: Question embeddings; used as query and key, and
                replaced by each block's output.
            qa_embed_data: Question-response embeddings; used as values for
                every block.

        Returns:
            The output of the last block (``q_embed_data`` itself when there
            are no blocks).
        """
        x = q_embed_data
        y = qa_embed_data
        for block in self.blocks_2:
            # mask=0 is the mask offset understood by TransformerLayer.forward.
            x = block(mask=0, query=x, key=x, values=y, apply_pos=True)
        return x

class TransformerLayer(nn.Module):
    """One transformer block: masked multi-head attention, then a feed-forward net.

    Both sub-layers use a residual connection with dropout followed by layer
    normalisation.

    Args:
        d_model: Input/output width.
        d_feature: Per-head width passed to ``MultiHeadAttention``.
        d_ff: Hidden width of the feed-forward net.
        n_heads: Number of attention heads.
        dropout: Dropout probability.
        kq_same: ``1`` means the key and query share one projection.
        seq_len, r, gamma: Passed to ``MultiHeadAttention``.
    """

    def __init__(self, d_model, d_feature,
                 d_ff, n_heads, dropout,  kq_same, seq_len, r, gamma):
        super().__init__()
        kq_same = kq_same == 1

        self.masked_attn_head = MultiHeadAttention(
            d_model, d_feature, n_heads, dropout, kq_same=kq_same, seq_len=seq_len, r=r, gamma=gamma)

        self.layer_norm1 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)

        self.linear1 = nn.Linear(d_model, d_ff)
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ff, d_model)

        self.layer_norm2 = nn.LayerNorm(d_model)
        self.dropout2 = nn.Dropout(dropout)

    @staticmethod
    def _build_src_mask(seqlen, mask, device):
        """Return a ``(1, 1, seqlen, seqlen)`` bool mask, True where ``j < i + mask``.

        Built directly with torch on ``device`` so no NumPy round-trip or
        host-to-device copy is needed on every forward pass.
        """
        upper = torch.triu(torch.ones(seqlen, seqlen, device=device), diagonal=mask)
        return (upper == 0).view(1, 1, seqlen, seqlen)

    def forward(self, mask, query, key, values, apply_pos=True):
        """Apply the block.

        Args:
            mask: Integer diagonal offset of the attention mask. With ``0``
                a position may only attend to strictly earlier positions and
                the attention is called with ``zero_pad=True``.
            query, key, values: Attention inputs; the sequence length is read
                from ``query.size(1)``.
            apply_pos: Whether to apply the feed-forward sub-layer.

        Returns:
            The transformed ``query``.
        """
        seqlen = query.size(1)
        src_mask = self._build_src_mask(seqlen, mask, query.device)
        query2 = self.masked_attn_head(
            query, key, values, mask=src_mask, zero_pad=(mask == 0))

        # Residual connection + layer norm around the attention output.
        query = query + self.dropout1(query2)
        query = self.layer_norm1(query)
        if apply_pos:
            # Position-wise feed-forward net with its own residual + norm.
            query2 = self.linear2(self.dropout(
                self.activation(self.linear1(query))))
            query = query + self.dropout2(query2)
            query = self.layer_norm2(query)
        return query
import torch
import torch.nn as nn

class RelativePositionBias(nn.Module):
    """Learned per-head bias indexed by the relative offset ``j - i``.

    Args:
        max_len: Largest sequence length supported; offsets range over
            ``-(max_len - 1) .. max_len - 1``.
        num_heads: Number of attention heads (one bias table per head).
    """

    def __init__(self, max_len, num_heads): 
        super().__init__()
        self.max_len = max_len
        self.num_heads = num_heads

        # One learnable scalar per head and per relative offset.
        self.rel_pos_bias = nn.Parameter(torch.zeros(num_heads, 2 * max_len - 1))
        nn.init.normal_(self.rel_pos_bias, std=0.02) 

    def forward(self, seq_len):
        """Return the bias for a sequence of length ``seq_len``.

        Returns:
            Tensor of shape ``(num_heads, seq_len, seq_len)`` where entry
            ``[h, i, j]`` is ``rel_pos_bias[h, (j - i) + max_len - 1]``.

        Raises:
            ValueError: If ``seq_len`` exceeds ``max_len``.
        """
        # A scalar check replaces a tensor min/max assert: it does not force a
        # device synchronisation and is not stripped under ``python -O``.
        if seq_len > self.max_len:
            raise ValueError(
                f"seq_len {seq_len} exceeds max_len {self.max_len} of the relative position table"
            )
        positions = torch.arange(seq_len, device=self.rel_pos_bias.device)
        # offset[i, j] = j - i, shifted so the smallest offset maps to index 0.
        table_index = positions[None, :] - positions[:, None] + self.max_len - 1
        return self.rel_pos_bias[:, table_index]

class MultiHeadAttention(nn.Module):
    """Multi-head attention with key/query/value projections and an output layer.

    The per-head computation is delegated to the module-level ``attention``
    function, which receives this module's relative-position bias.

    Args:
        d_model: Input/output width.
        d_feature: Width of each head (``n_heads * d_feature`` must equal
            ``d_model`` for the reshape in ``forward``).
        n_heads: Number of heads.
        dropout: Dropout probability handed to ``attention``.
        kq_same: When exactly ``False`` a separate query projection is
            created; otherwise queries reuse the key projection.
        r, gamma: Stored and forwarded to ``attention``.
        seq_len: Maximum length for the relative-position bias table.
        bias: Whether the linear layers have bias terms.
    """

    def __init__(self, d_model, d_feature, n_heads, dropout, kq_same, r, gamma, seq_len=512,bias=True):
        super().__init__()
        self.d_model = d_model
        self.d_k = d_feature
        self.h = n_heads
        self.kq_same = kq_same

        self.v_linear = nn.Linear(d_model, d_model, bias=bias)
        self.k_linear = nn.Linear(d_model, d_model, bias=bias)
        if kq_same is False:
            self.q_linear = nn.Linear(d_model, d_model, bias=bias)
        self.num_heads = n_heads
        self.dropout = nn.Dropout(dropout)
        self.proj_bias = bias
        self.out_proj = nn.Linear(d_model, d_model, bias=bias)
        self.seq_len = seq_len
        self.r = r
        self.gamma = gamma
        self.rel_pos_bias_module = RelativePositionBias(self.seq_len, self.num_heads)
        # NOTE(review): ParallelKerpleLog is not defined in the visible part of
        # this file; it must be provided elsewhere for construction to succeed.
        self.kernel_bias = ParallelKerpleLog(self.h)

    def _reset_parameters(self):
        """Xavier-initialise the k/v(/q) weights and zero the projection biases.

        NOTE(review): nothing in this class calls this method, so the default
        ``nn.Linear`` initialisation is what is used unless a caller invokes it.
        """
        xavier_uniform_(self.k_linear.weight)
        xavier_uniform_(self.v_linear.weight)
        if self.kq_same is False:
            xavier_uniform_(self.q_linear.weight)

        if self.proj_bias:
            constant_(self.k_linear.bias, 0.)
            constant_(self.v_linear.bias, 0.)
            if self.kq_same is False:
                constant_(self.q_linear.bias, 0.)
            constant_(self.out_proj.bias, 0.)

    def forward(self, q, k, v, mask, zero_pad):
        """Project the inputs, run ``attention`` per head and merge the heads.

        Args:
            q, k, v: Tensors whose first dimension is the batch size and whose
                last dimension is ``d_model``.
            mask: Attention mask forwarded to ``attention``.
            zero_pad: Forwarded to ``attention``.

        Returns:
            Tensor of shape ``(batch, -1, d_model)`` after the output projection.
        """
        bs = q.size(0)

        # Project and split the last dimension into (heads, d_k).
        k = self.k_linear(k).view(bs, -1, self.h, self.d_k)
        if self.kq_same is False:
            q = self.q_linear(q).view(bs, -1, self.h, self.d_k)
        else:
            q = self.k_linear(q).view(bs, -1, self.h, self.d_k)
        v = self.v_linear(v).view(bs, -1, self.h, self.d_k)

        # (batch, seq, heads, d_k) -> (batch, heads, seq, d_k)
        k = k.transpose(1, 2)
        q = q.transpose(1, 2)
        v = v.transpose(1, 2)

        # The learned bias in self.rel_pos_bias_module is what attention uses;
        # no throw-away bias module is created per call.
        scores = attention(q, k, v, self.d_k,
                   mask, self.dropout, zero_pad, self.r, self.gamma, self.kernel_bias, self.rel_pos_bias_module, alpha=0.5)

        # Merge the heads back into the feature dimension.
        concat = scores.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
        output = self.out_proj(concat)

        return output

import torch
import torch.nn.functional as F

def disentangled_attention_scores(q, k, rel_pos_bias_module, mask=None):
    """Scaled dot-product attention logits with an added relative-position bias.

    Args:
        q, k: Query/key tensors; the last two dimensions are (seq, head_dim)
            and dimension 1 of their product is compared with the number of
            heads of the bias.
        rel_pos_bias_module: Callable mapping a sequence length to a
            ``(num_heads, seq_len, seq_len)`` bias tensor.
        mask: Optional mask; positions where ``mask == 0`` are set to
            ``-inf``. A 3-D mask gets a head dimension inserted at dim 1.

    Returns:
        ``(q @ k^T + bias) / sqrt(head_dim)`` with masked positions at ``-inf``.
    """
    logits = q @ k.transpose(-2, -1)
    n_positions = q.size(-2)
    rel_bias = rel_pos_bias_module(n_positions)  # (num_heads, seq_len, seq_len)

    assert rel_bias.size(0) == logits.size(1), f"rel_pos_emb heads {rel_bias.size(0)} != scores heads {logits.size(1)}"
    # Broadcast the bias over the batch: (1, head, seq_len, seq_len).
    logits = (logits + rel_bias.unsqueeze(0)) / (q.size(-1) ** 0.5)

    if mask is None:
        return logits

    # Make the mask shape line up as (batch, 1, seq_len, seq_len).
    if mask.dim() == 3:
        mask = mask.unsqueeze(1)
    # A row that is entirely masked would become all -inf and make softmax
    # return NaN, so such rows are switched to fully unmasked first.
    row_total = mask.sum(dim=-1, keepdim=True)
    safe_mask = torch.where(row_total == 0, torch.ones_like(mask), mask)
    return logits.masked_fill(safe_mask == 0, float("-inf"))

def retnet_retention(q, k, v, mask, decay=0.9):
    """Exponentially decayed running sum of ``v`` along the sequence axis.

    For every step ``t`` the output is ``y_t = decay * y_{t-1} + v_t`` with the
    state starting at zero. ``q`` is only used for its shape and device; ``k``
    is not used.

    Args:
        q: Tensor of shape ``(batch, head, seq_len, d_k)``.
        k: Unused.
        v: Values, indexed as ``v[:, :, t, :]``.
        mask: Optional mask. Only a 4-D mask has an effect: its LAST row
            (index ``seq_len - 1`` of dim 2) is broadcast over the output so
            positions where that row is zero are zeroed.
        decay: Decay factor, clamped to ``[0, 1]``.

    Returns:
        Tensor shaped like ``v``.
    """
    batch, head, seq_len, d_k = q.size()
    output = torch.zeros_like(v)  # final output sequence
    if seq_len == 0:
        # Nothing to accumulate and no "last row" of the mask to apply.
        return output

    # Keep decay inside [0, 1].
    decay = max(min(decay, 1.0), 0.0)

    # Running state carrying the accumulated past information (starts at 0).
    state = torch.zeros(batch, head, d_k, device=q.device)
    # NOTE(review): the state at step t already includes v_t, i.e. the current
    # position contributes to its own output -- confirm this is intended when
    # the attention mask is strictly causal.
    for t in range(seq_len):
        state = decay * state + v[:, :, t, :]
        output[:, :, t, :] = state

    # When a 4-D mask is given, zero the invalid positions.
    if mask is not None and mask.dim() == 4:
        # The row index is spelled out instead of relying on the loop variable
        # leaking out of the ``for`` above.
        # NOTE(review): only the last mask row is applied to all positions --
        # confirm this is the intended masking.
        last_row = mask[:, :, seq_len - 1, :].squeeze(1)
        output = output * last_row.unsqueeze(-1).float()

    return output

def disentangled_retnet_attention(q, k, v, mask, dropout, rel_pos_bias_module, alpha=0.5):
    """Combine conventional softmax attention with the RetNet-style retention.

    The attention branch uses ``disentangled_attention_scores`` followed by
    softmax and ``dropout``; the retention branch is ``retnet_retention``.
    The result is ``attention_output + alpha * retention_output``.
    """
    logits = disentangled_attention_scores(q, k, rel_pos_bias_module, mask)
    weights = dropout(F.softmax(logits, dim=-1))
    attended = weights @ v

    retained = retnet_retention(q, k, v, mask)

    return attended + alpha * retained


def attention(q, k, v, d_k, mask, dropout, zero_pad, r, gamma, kernel_bias, rel_pos_bias_module, alpha=0.5):
    """Run the combined attention and optionally zero the first position.

    ``d_k``, ``r``, ``gamma`` and ``kernel_bias`` are accepted but not used by
    this implementation. When ``zero_pad`` is true, the output row at sequence
    index 0 is replaced by zeros.
    """
    combined = disentangled_retnet_attention(q, k, v, mask, dropout, rel_pos_bias_module, alpha)

    if not zero_pad:
        return combined

    n_batch, n_head, _, width = combined.shape
    first_row = torch.zeros(n_batch, n_head, 1, width).to(q.device)
    return torch.cat((first_row, combined[:, :, 1:, :]), dim=2)

# def attention(q, k, v, d_k, mask, dropout, zero_pad, r, gamma, kernel_bias, rel_pos_bias_module, alpha=0.5):
#     """RetNet retention only."""
#     return retnet_retention(q, k, v, mask)

# def attention(q, k, v, d_k, mask, dropout, zero_pad, r, gamma, kernel_bias, rel_pos_bias_module, alpha=0.5):
#     """Disentangled attention only, with the relative position bias."""
#     scores = disentangled_attention_scores(q, k, rel_pos_bias_module, mask)
#     attn = F.softmax(scores, dim=-1)
#     attn = dropout(attn)
#     return torch.matmul(attn, v)