This may only be useful as an illustration of how to use index features and embeddings in graph convolutions. Look at the code marked with `******` to see how this works.

In [7]:
from collections import defaultdict 

import torch
import dgl
import numpy as np 
import scipy.sparse as sp

In [8]:
# Sizes of the simulated data set.
n_users = 1000
n_items = 1000
n_kcs = 100

# Inclusive bounds on how many KCs are attached to a single item.
kc_min = 1
kc_max = 10

n_prob_dists = 5
n_ics = n_items + n_kcs  # combined count of items and KCs

# Simulated Q-dictionary: item ID -> array of distinct KC IDs.
# Each row of `probs` is one probability vector over the KCs; every item picks
# one of these rows at random and samples its KCs from it.
q_dict = dict()
probs = np.random.dirichlet(np.ones(n_kcs), size=n_prob_dists)
for item_id in range(n_items):
    # `np.random.randint` excludes its upper bound, so `kc_max + 1` is needed
    # for `kc_max` itself to be a possible draw.
    n_item_kcs = int(np.random.randint(kc_min, kc_max + 1))
    kc_samples = np.random.choice(
        a=np.arange(n_kcs),
        size=n_item_kcs,
        p=probs[np.random.randint(0, n_prob_dists)],
        replace=False,  # an item never lists the same KC twice
    )
    q_dict[item_id] = kc_samples

In [10]:
import dgl
from dgl.nn import GraphConv
import torch.nn as nn



def q_dict_to_sp_extended_csr(q_dict) -> sp.spmatrix:
    """Build the symmetric item/KC adjacency matrix (with self-loops) from a Q-dict.

    Nodes ``0 .. n_items - 1`` are the items and nodes
    ``n_items .. n_items + n_kcs - 1`` are the KCs (KC IDs are offset by
    ``n_items``). An entry is set for every (item, KC) pair in ``q_dict``,
    mirrored across the diagonal, and the identity is added so that every
    node is connected to itself.

    Assumes standardised IDs: item IDs are ``0 .. len(q_dict) - 1`` and KC IDs
    are non-negative integers.

    Args:
        q_dict: Mapping of item ID to an iterable of KC IDs.

    Returns:
        A CSR matrix of shape ``(n_items + n_kcs, n_items + n_kcs)``, where
        ``n_items = len(q_dict)`` and ``n_kcs`` is the largest KC ID plus one.

    Raises:
        ValueError: If an item ID paired with a KC lies outside
            ``0 .. len(q_dict) - 1``, or if a KC ID is negative.
    """
    # Flatten the mapping into parallel (item, KC) coordinate lists.
    item_ids = []
    kc_ids = []
    for item_id, item_kcs in q_dict.items():
        for kc_id in item_kcs:
            item_ids.append(item_id)
            kc_ids.append(kc_id)
    # Force an integer dtype so that an empty list does not become a float
    # array, which cannot be used as sparse indices.
    item_ids = np.asarray(item_ids, dtype=np.int64)
    kc_ids = np.asarray(kc_ids, dtype=np.int64)
    data = np.ones(len(item_ids), dtype=float)

    n_items = len(q_dict)
    if item_ids.size and (item_ids.min() < 0 or item_ids.max() >= n_items):
        raise ValueError(
            "item IDs must be standardised to the range 0 .. len(q_dict) - 1"
        )
    if kc_ids.size and kc_ids.min() < 0:
        raise ValueError("KC IDs must be non-negative")

    # Size the KC axis by the largest ID rather than the number of distinct
    # IDs: a KC that no item uses must still leave room for the IDs above it.
    n_kcs = int(kc_ids.max()) + 1 if kc_ids.size else 0
    n_ics = n_items + n_kcs

    # Item -> KC entries of the n_ics x n_ics matrix; offset KC IDs by n_items.
    upper = sp.csr_matrix(
        (data, (item_ids, kc_ids + n_items)), shape=(n_ics, n_ics)
    )

    # Symmetrise and add self-loops.
    return (upper + upper.T + sp.eye(n_ics)).tocsr()



class GCQEmbedding(nn.Module):
    """Embedding-like module whose rows come from graph convolutions over a Q-dict graph.

    A learnable embedding is held for every node of the graph built from
    ``q_dict`` (see ``q_dict_to_sp_extended_csr``). The embeddings are passed
    through ``n_levels`` ``GraphConv`` layers and a final linear layer, and
    the rows for the requested IDs are returned, so the module can be called
    like an ``nn.Embedding``.

    Args:
        q_dict: Mapping of item ID to an iterable of KC IDs.
        n_feats: Width of the node embeddings and of every graph convolution.
        n_out: Width of the returned embeddings.
        n_levels: Number of stacked graph convolutions; must be positive.
    """

    def __init__(self, q_dict, n_feats=10, n_out=3, n_levels=2):
        super().__init__()

        assert isinstance(q_dict, dict)
        assert n_levels > 0

        # Get the sparse extended matrix, and construct the DGL graph object.
        # The matrix values are stored as edge feature 'w'; the convolutions
        # below are called without edge weights, so 'w' is not used here.
        self.q_matrix_extended = q_dict_to_sp_extended_csr(q_dict=q_dict)
        self.dgl_graph = dgl.from_scipy(self.q_matrix_extended, eweight_name='w')

        # Instantiate the node embeddings, graph convolutions, and prediction modules
        # NIKHIL: ********* \/
        self.node_embeddings = nn.Embedding(self.dgl_graph.num_nodes(), n_feats)
        self.gconvs = nn.ModuleList(
            GraphConv(n_feats, n_feats) for _ in range(n_levels)
        )
        self.pred = nn.Linear(n_feats, n_out)

    def graph_embeddings(self, embeddings=None):
        """Return the ``(num_nodes, n_out)`` output embeddings for every graph node.

        Args:
            embeddings: Unused; kept so existing callers keep working.
        """
        # The graph is a plain attribute, so `module.to(device)` does not move
        # it. Follow the parameters' device here so the node index, the graph
        # and the weights all live in the same place.
        device = self.node_embeddings.weight.device
        # `DGLGraph.to` hands back the same graph when it is already on
        # `device`; storing the result avoids a fresh copy on every call.
        self.dgl_graph = self.dgl_graph.to(device)

        # Calculate the graph embeddings
        node_index = torch.arange(self.dgl_graph.num_nodes(), device=device)
        zz = self.node_embeddings(node_index)
        for gconv in self.gconvs:
            zz = gconv(self.dgl_graph, zz.relu())
        return self.pred(zz.relu())

    def forward(self, item_ids):
        """Return the output embeddings of the nodes indexed by ``item_ids``."""
        # Calculate the predicted embeddings for all nodes, then look up the
        # requested rows (equivalent to `embeddings[item_ids]`).
        embeddings = self.graph_embeddings()
        pred = nn.functional.embedding(weight=embeddings, input=item_ids)
        return pred
    
class GCN_MovieLens(nn.Module):
    """Concatenate a user embedding with an item embedding.

    Args:
        n_users: Number of rows in the user embedding table.
        n_items: Number of rows in the plain item embedding table; only used
            when ``q_dict`` is ``None``.
        q_dict: Mapping of item ID to an iterable of KC IDs. When given, item
            embeddings come from a ``GCQEmbedding`` over the resulting graph;
            when ``None``, a plain ``nn.Embedding`` is used instead.
        n_feats: Width of both the user and the item embeddings.
        n_levels: Number of graph convolutions in the ``GCQEmbedding``.
        k: Unused; kept for signature parity with ``GCN_IRT``.
    """

    def __init__(self, n_users, n_items, q_dict=None, n_feats=20, n_levels=2, k=3):
        super().__init__()
        self.user_embeddings = nn.Embedding(n_users, n_feats)

        # Both embedding types share the same forward signature, so the graph
        # component can be switched off simply by not passing a q_dict.
        if q_dict is None:
            self.item_embeddings = nn.Embedding(n_items, n_feats)
        else:
            self.item_embeddings = GCQEmbedding(
                q_dict=q_dict,
                n_feats=n_feats,
                n_levels=n_levels,
                n_out=n_feats,
            )

    def forward(self, user_ids, item_ids):
        """Return the user and item embeddings joined along dimension 1."""
        user_embeddings = self.user_embeddings(user_ids)
        item_embeddings = self.item_embeddings(item_ids)
        return torch.cat([user_embeddings, item_embeddings], dim=1)




class GCN_IRT(nn.Module):
    """3PL IRT response model with learnable user abilities and item parameters.

    Each user has a scalar ability. Each item has ``k`` parameters, of which
    the first three are read as difficulty, log-discrimination and the
    pre-sigmoid guessing parameter.

    Args:
        n_users: Number of rows in the user ability table.
        n_items: Number of rows in the plain item parameter table; only used
            when ``q_dict`` is ``None``.
        q_dict: Mapping of item ID to an iterable of KC IDs. When given, item
            parameters come from a ``GCQEmbedding`` over the resulting graph;
            when ``None``, a plain ``nn.Embedding`` is used instead.
        n_feats: Hidden width of the ``GCQEmbedding``.
        n_levels: Number of graph convolutions in the ``GCQEmbedding``.
        k: Number of parameters per item; must be at least 3.

    Raises:
        ValueError: If ``k`` is smaller than 3.
    """

    def __init__(self, n_users, n_items, q_dict=None, n_feats=20, n_levels=2, k=3):
        super().__init__()
        if k < 3:
            # forward() reads item parameter columns 0, 1 and 2.
            raise ValueError("k must be at least 3 (difficulty, discrimination, guessing)")

        self.user_embeddings = nn.Embedding(n_users, 1)

        # Note: Can easily toggle on/off GCN components by toggling between embedding
        #       types with the same forward signature
        if q_dict is None:
            self.item_embeddings = nn.Embedding(n_items, k)
        else:
            self.item_embeddings = GCQEmbedding(
                q_dict=q_dict,
                n_feats=n_feats,
                n_levels=n_levels,
                n_out=k,
            )

        self.k = k

    def forward(self, user_ids, item_ids):
        """Return the 3PL probability for each (user, item) pair as a column."""
        # User
        ability = self.user_embeddings(user_ids)

        # Item
        item_params = self.item_embeddings(item_ids)
        assert item_params.shape[1] == self.k
        diff = item_params[:, [0]]
        disc = item_params[:, [1]].exp()           # Helps identifiability
        guess = item_params[:, [2]].sigmoid() / 4  # Helps identifiability

        # 2PL IRT prob
        logits = (ability - diff) * disc
        prob = logits.sigmoid()

        # 3PL IRT prob
        return guess + (1 - guess) * prob


# Build the IRT model on top of the simulated Q-dictionary.
model = GCN_IRT(n_users=n_users, n_items=n_items, q_dict=q_dict)

# Number of random (user, item) interactions to score.
n_ints = 100

# Draw the user IDs first and the item IDs second.
user_ids = torch.randint(low=0, high=n_users, size=(n_ints,))
item_ids = torch.randint(low=0, high=n_items, size=(n_ints,))

# Predicted probabilities, detached from the autograd graph for display.
pp = model(user_ids=user_ids, item_ids=item_ids).detach()
pp

tensor([[0.3515],
        [0.2164],
        [0.5256],
        [0.4713],
        [0.4606],
        [0.8242],
        [0.3626],
        [0.5044],
        [0.6287],
        [0.3689],
        [0.4438],
        [0.7753],
        [0.4904],
        [0.7105],
        [0.6538],
        [0.5350],
        [0.7593],
        [0.4709],
        [0.9232],
        [0.6271],
        [0.2440],
        [0.6542],
        [0.8611],
        [0.1874],
        [0.7029],
        [0.6656],
        [0.3428],
        [0.7177],
        [0.2185],
        [0.3276],
        [0.3766],
        [0.4166],
        [0.8782],
        [0.6927],
        [0.8452],
        [0.6744],
        [0.6576],
        [0.2369],
        [0.2931],
        [0.6390],
        [0.2596],
        [0.5329],
        [0.7306],
        [0.6138],
        [0.5707],
        [0.4800],
        [0.8383],
        [0.8356],
        [0.9345],
        [0.2001],
        [0.9314],
        [0.9457],
        [0.5737],
        [0.2621],
        [0.7222],
        [0