# Practical Work

In [90]:
import pandas as pd

## Create the dataset files for RecBole

### Dataset structure
Recbole knowledge aware datasets require three files: .inter, .g, .link files. 

| File | Description |
|------|-------------|
|.inter|User-Item interaction|
|.kg| head, relation, tail|
|.link|item_id to entity_id|

## Create atomic files

 Add type info to kg generated by David

In [91]:
dataset = '05percent_subset'
path = 'data/lfm'
kg_path = '%s/intermediate_kg/%s' % (path, dataset)
prefixed_path = '%s/%s' % (path,dataset)

In [92]:
kg = pd.read_csv('%s_kg.txt' % kg_path, sep='\t', names=['head_id:token', 'relation_id:token', 'tail_id:token'])

In [93]:
kg.head()

Unnamed: 0,head_id:token,relation_id:token,tail_id:token
0,t31469731,in_album,b4202003
1,t11520441,in_album,b14810089
2,t33076177,in_album,b17095418
3,t42806818,in_album,b20270611
4,t25359723,in_album,b19286007


In [94]:
kg['relation_id:token'].unique()

array(['in_album', 'created_by', 'lives_in', 'listened_to',
       'has_micro_genre', 'has_genre', 'has_gender'], dtype=object)

In [95]:
kg_no_listen_events = kg[kg['relation_id:token'] != 'listened_to']
kg_no_listen_events.to_csv('data/rb_lfm/rb_lfm.kg', sep='\t', index=False)

In [96]:
users = pd.read_csv('%s_users.tsv' % prefixed_path, sep='\t', skiprows=[0], names=['user_id:token', 'country:token', 'age:token', 'gender:token', 'creation_time:token'])
users.to_csv('data/rb_lfm/rb_lfm.user', sep='\t', index=False)

In [97]:
items = pd.read_csv('%s_tracks.tsv' % prefixed_path, sep='\t', skiprows=[0], names=['item_id:token', 'artist:token', 'track:token'])
items.to_csv('data/rb_lfm/rb_lfm.item', sep='\t', index=False)
                    
track_ids = pd.DataFrame(items['item_id:token'])
track_ids['entity_id:token'] = 't' + track_ids['item_id:token'].astype(str)
track_ids.to_csv('data/rb_lfm/rb_lfm.link', sep='\t', index=False)

In [98]:
listening_events = pd.read_csv('%s_listening_events.tsv' % prefixed_path, sep='\t', skiprows=[0], names=['user_id:token', 'item_id:token', 'album_id:token', 'timestamp:token'])
listening_events.to_csv('data/rb_lfm/rb_lfm.inter', sep='\t', index=False)

In [99]:
# free memory
del kg
del kg_no_listen_events
del items
del track_ids
del users

## Demo run model

In [100]:
#from recbole.quick_start import run_recbole
#run_recbole(model='KGAT', dataset='rb_lfm', config_file_list=['lfm.yaml'])

# Custom pipeline

### Load config and create Dataset

In [1]:
from logging import getLogger
from recbole.config import Config
from recbole.data import create_dataset, data_preparation
from recbole.model.knowledge_aware_recommender import KGAT
from recbole.trainer import KGATTrainer
from recbole.utils import init_seed, init_logger

In [2]:

# configurations initialization
config = Config(model='KGAT', dataset='rb_lfm', config_file_list=['lfm.yaml'])

# init random seed
init_seed(config['seed'], config['reproducibility'])

# logger initialization
init_logger(config)
logger = getLogger()

# write config info into log
logger.info(config)

# dataset creating and filtering
dataset = create_dataset(config)
logger.info(dataset)

# dataset splitting
train_data, valid_data, test_data = data_preparation(config, dataset)


13 Feb 11:50    INFO  
General Hyper Parameters:
gpu_id = 0
use_gpu = True
seed = 2020
state = debug
reproducibility = True
data_path = ./data/rb_lfm
checkpoint_dir = saved
show_progress = True
save_dataset = False
dataset_save_path = None
save_dataloaders = False
dataloaders_save_path = None
log_wandb = False

Training Hyper Parameters:
epochs = 20
train_batch_size = 2048
learner = adam
learning_rate = 0.001
train_neg_sample_args = {'distribution': 'uniform', 'sample_num': 1, 'alpha': 1.0, 'dynamic': False, 'candidate_num': 0}
eval_step = 1
stopping_step = 10
clip_grad_norm = None
weight_decay = 0.0
loss_decimal_place = 4

Evaluation Hyper Parameters:
eval_args = {'split': {'RS': [8, 1, 1]}, 'group_by': 'user', 'order': 'RO', 'mode': 'full'}
repeatable = False
metrics = ['NDCG', 'Hit', 'Precision']
topk = [10]
valid_metric = NDCG@10
valid_metric_bigger = True
eval_batch_size = 2048
metric_decimal_place = 4

Dataset Hyper Parameters:
field_separator = 	
seq_separator =  
USER_ID_FIELD 

In [107]:
len(test_data)

13559

### Train model and evaluate

In [108]:
# model loading and initialization
model = KGAT(config, train_data._dataset).to(config['device'])
logger.info(model)

# trainer loading and initialization
trainer = KGATTrainer(config, model)

print('Starting to fit model')
# model training
best_valid_score, best_valid_result = trainer.fit(train_data, valid_data, saved=True, show_progress=True)

print('Evaluating model')
# model evaluation
test_result = trainer.evaluate(test_data, load_best_model=True)
print(test_result)

  d_inv = np.power(rowsum, -1).flatten()
12 Feb 20:13    INFO  KGAT(
  (user_embedding): Embedding(14068, 32)
  (entity_embedding): Embedding(134531, 32)
  (relation_embedding): Embedding(8, 32)
  (trans_w): Embedding(8, 1024)
  (aggregator_layers): ModuleList(
    (0): Aggregator(
      (message_dropout): Dropout(p=0.1, inplace=False)
      (W1): Linear(in_features=32, out_features=32, bias=True)
      (W2): Linear(in_features=32, out_features=32, bias=True)
      (activation): LeakyReLU(negative_slope=0.01)
    )
    (1): Aggregator(
      (message_dropout): Dropout(p=0.1, inplace=False)
      (W1): Linear(in_features=32, out_features=16, bias=True)
      (W2): Linear(in_features=32, out_features=16, bias=True)
      (activation): LeakyReLU(negative_slope=0.01)
    )
  )
  (tanh): Tanh()
  (mf_loss): BPRLoss()
  (reg_loss): EmbLoss()
)
Trainable parameters: 4766784


Starting to fit model


[1;35mTrain     0[0m:   0%|          | 0/327 [00:00<?, ?it/s]

[1;35mTrain     0[0m:   0%|          | 0/147 [00:00<?, ?it/s]

12 Feb 20:14    INFO  epoch 0 training [time: 16.92s, train_loss1: 113.4480, train_loss2: 68.2419]


[1;35mEvaluate   [0m:   0%|                                                        | 0/13151 [00:00<?, ?it/s…

12 Feb 20:14    INFO  epoch 0 evaluating [time: 21.34s, valid_score: 0.018900]
12 Feb 20:14    INFO  valid result: 
ndcg@10 : 0.0189    hit@10 : 0.0833    precision@10 : 0.0099
12 Feb 20:14    INFO  Saving current: saved/KGAT-Feb-12-2023_20-13-49.pth


[1;35mTrain     1[0m:   0%|          | 0/327 [00:00<?, ?it/s]

[1;35mTrain     1[0m:   0%|          | 0/147 [00:00<?, ?it/s]

12 Feb 20:14    INFO  epoch 1 training [time: 15.15s, train_loss1: 89.3312, train_loss2: 29.7134]


[1;35mEvaluate   [0m:   0%|                                                        | 0/13151 [00:00<?, ?it/s…

12 Feb 20:15    INFO  epoch 1 evaluating [time: 21.87s, valid_score: 0.032700]
12 Feb 20:15    INFO  valid result: 
ndcg@10 : 0.0327    hit@10 : 0.1271    precision@10 : 0.0166
12 Feb 20:15    INFO  Saving current: saved/KGAT-Feb-12-2023_20-13-49.pth


[1;35mTrain     2[0m:   0%|          | 0/327 [00:00<?, ?it/s]

[1;35mTrain     2[0m:   0%|          | 0/147 [00:00<?, ?it/s]

12 Feb 20:15    INFO  epoch 2 training [time: 15.31s, train_loss1: 82.6103, train_loss2: 18.7961]


[1;35mEvaluate   [0m:   0%|                                                        | 0/13151 [00:00<?, ?it/s…

12 Feb 20:15    INFO  epoch 2 evaluating [time: 21.56s, valid_score: 0.035600]
12 Feb 20:15    INFO  valid result: 
ndcg@10 : 0.0356    hit@10 : 0.1388    precision@10 : 0.0183
12 Feb 20:15    INFO  Saving current: saved/KGAT-Feb-12-2023_20-13-49.pth


[1;35mTrain     3[0m:   0%|          | 0/327 [00:00<?, ?it/s]

[1;35mTrain     3[0m:   0%|          | 0/147 [00:00<?, ?it/s]

12 Feb 20:15    INFO  epoch 3 training [time: 15.26s, train_loss1: 79.7589, train_loss2: 13.9513]


[1;35mEvaluate   [0m:   0%|                                                        | 0/13151 [00:00<?, ?it/s…

12 Feb 20:16    INFO  epoch 3 evaluating [time: 22.01s, valid_score: 0.039100]
12 Feb 20:16    INFO  valid result: 
ndcg@10 : 0.0391    hit@10 : 0.1446    precision@10 : 0.0195
12 Feb 20:16    INFO  Saving current: saved/KGAT-Feb-12-2023_20-13-49.pth


[1;35mTrain     4[0m:   0%|          | 0/327 [00:00<?, ?it/s]

[1;35mTrain     4[0m:   0%|          | 0/147 [00:00<?, ?it/s]

12 Feb 20:16    INFO  epoch 4 training [time: 15.18s, train_loss1: 77.6395, train_loss2: 10.9852]


[1;35mEvaluate   [0m:   0%|                                                        | 0/13151 [00:00<?, ?it/s…

12 Feb 20:16    INFO  epoch 4 evaluating [time: 22.06s, valid_score: 0.042400]
12 Feb 20:16    INFO  valid result: 
ndcg@10 : 0.0424    hit@10 : 0.1576    precision@10 : 0.0213
12 Feb 20:16    INFO  Saving current: saved/KGAT-Feb-12-2023_20-13-49.pth


[1;35mTrain     5[0m:   0%|          | 0/327 [00:00<?, ?it/s]

[1;35mTrain     5[0m:   0%|          | 0/147 [00:00<?, ?it/s]

12 Feb 20:17    INFO  epoch 5 training [time: 15.36s, train_loss1: 75.6580, train_loss2: 9.3092]


[1;35mEvaluate   [0m:   0%|                                                        | 0/13151 [00:00<?, ?it/s…

12 Feb 20:17    INFO  epoch 5 evaluating [time: 21.75s, valid_score: 0.046100]
12 Feb 20:17    INFO  valid result: 
ndcg@10 : 0.0461    hit@10 : 0.1653    precision@10 : 0.0234
12 Feb 20:17    INFO  Saving current: saved/KGAT-Feb-12-2023_20-13-49.pth


[1;35mTrain     6[0m:   0%|          | 0/327 [00:00<?, ?it/s]

[1;35mTrain     6[0m:   0%|          | 0/147 [00:00<?, ?it/s]

12 Feb 20:17    INFO  epoch 6 training [time: 15.39s, train_loss1: 73.6984, train_loss2: 7.9702]


[1;35mEvaluate   [0m:   0%|                                                        | 0/13151 [00:00<?, ?it/s…

12 Feb 20:18    INFO  epoch 6 evaluating [time: 22.05s, valid_score: 0.049700]
12 Feb 20:18    INFO  valid result: 
ndcg@10 : 0.0497    hit@10 : 0.1824    precision@10 : 0.0252
12 Feb 20:18    INFO  Saving current: saved/KGAT-Feb-12-2023_20-13-49.pth


[1;35mTrain     7[0m:   0%|          | 0/327 [00:00<?, ?it/s]

[1;35mTrain     7[0m:   0%|          | 0/147 [00:00<?, ?it/s]

12 Feb 20:18    INFO  epoch 7 training [time: 15.36s, train_loss1: 71.6748, train_loss2: 7.1267]


[1;35mEvaluate   [0m:   0%|                                                        | 0/13151 [00:00<?, ?it/s…

12 Feb 20:18    INFO  epoch 7 evaluating [time: 22.12s, valid_score: 0.051500]
12 Feb 20:18    INFO  valid result: 
ndcg@10 : 0.0515    hit@10 : 0.1874    precision@10 : 0.026
12 Feb 20:18    INFO  Saving current: saved/KGAT-Feb-12-2023_20-13-49.pth


[1;35mTrain     8[0m:   0%|          | 0/327 [00:00<?, ?it/s]

[1;35mTrain     8[0m:   0%|          | 0/147 [00:00<?, ?it/s]

12 Feb 20:19    INFO  epoch 8 training [time: 15.46s, train_loss1: 69.6969, train_loss2: 6.1352]


[1;35mEvaluate   [0m:   0%|                                                        | 0/13151 [00:00<?, ?it/s…

12 Feb 20:19    INFO  epoch 8 evaluating [time: 22.29s, valid_score: 0.054800]
12 Feb 20:19    INFO  valid result: 
ndcg@10 : 0.0548    hit@10 : 0.1966    precision@10 : 0.0271
12 Feb 20:19    INFO  Saving current: saved/KGAT-Feb-12-2023_20-13-49.pth


[1;35mTrain     9[0m:   0%|          | 0/327 [00:00<?, ?it/s]

[1;35mTrain     9[0m:   0%|          | 0/147 [00:00<?, ?it/s]

12 Feb 20:19    INFO  epoch 9 training [time: 15.48s, train_loss1: 67.7431, train_loss2: 5.5648]


[1;35mEvaluate   [0m:   0%|                                                        | 0/13151 [00:00<?, ?it/s…

12 Feb 20:20    INFO  epoch 9 evaluating [time: 22.25s, valid_score: 0.056800]
12 Feb 20:20    INFO  valid result: 
ndcg@10 : 0.0568    hit@10 : 0.199    precision@10 : 0.0277
12 Feb 20:20    INFO  Saving current: saved/KGAT-Feb-12-2023_20-13-49.pth


[1;35mTrain    10[0m:   0%|          | 0/327 [00:00<?, ?it/s]

[1;35mTrain    10[0m:   0%|          | 0/147 [00:00<?, ?it/s]

12 Feb 20:20    INFO  epoch 10 training [time: 15.39s, train_loss1: 65.5563, train_loss2: 4.8028]


[1;35mEvaluate   [0m:   0%|                                                        | 0/13151 [00:00<?, ?it/s…

12 Feb 20:20    INFO  epoch 10 evaluating [time: 22.29s, valid_score: 0.057900]
12 Feb 20:20    INFO  valid result: 
ndcg@10 : 0.0579    hit@10 : 0.203    precision@10 : 0.0285
12 Feb 20:20    INFO  Saving current: saved/KGAT-Feb-12-2023_20-13-49.pth


[1;35mTrain    11[0m:   0%|          | 0/327 [00:00<?, ?it/s]

[1;35mTrain    11[0m:   0%|          | 0/147 [00:00<?, ?it/s]

12 Feb 20:20    INFO  epoch 11 training [time: 15.30s, train_loss1: 63.0078, train_loss2: 4.5266]


[1;35mEvaluate   [0m:   0%|                                                        | 0/13151 [00:00<?, ?it/s…

12 Feb 20:21    INFO  epoch 11 evaluating [time: 22.18s, valid_score: 0.060800]
12 Feb 20:21    INFO  valid result: 
ndcg@10 : 0.0608    hit@10 : 0.2102    precision@10 : 0.0295
12 Feb 20:21    INFO  Saving current: saved/KGAT-Feb-12-2023_20-13-49.pth


[1;35mTrain    12[0m:   0%|          | 0/327 [00:00<?, ?it/s]

[1;35mTrain    12[0m:   0%|          | 0/147 [00:00<?, ?it/s]

12 Feb 20:21    INFO  epoch 12 training [time: 15.49s, train_loss1: 60.4409, train_loss2: 4.1284]


[1;35mEvaluate   [0m:   0%|                                                        | 0/13151 [00:00<?, ?it/s…

12 Feb 20:21    INFO  epoch 12 evaluating [time: 22.16s, valid_score: 0.062900]
12 Feb 20:21    INFO  valid result: 
ndcg@10 : 0.0629    hit@10 : 0.2173    precision@10 : 0.0303
12 Feb 20:21    INFO  Saving current: saved/KGAT-Feb-12-2023_20-13-49.pth


[1;35mTrain    13[0m:   0%|          | 0/327 [00:00<?, ?it/s]

[1;35mTrain    13[0m:   0%|          | 0/147 [00:00<?, ?it/s]

12 Feb 20:22    INFO  epoch 13 training [time: 15.48s, train_loss1: 57.7186, train_loss2: 3.9206]


[1;35mEvaluate   [0m:   0%|                                                        | 0/13151 [00:00<?, ?it/s…

12 Feb 20:22    INFO  epoch 13 evaluating [time: 22.42s, valid_score: 0.063900]
12 Feb 20:22    INFO  valid result: 
ndcg@10 : 0.0639    hit@10 : 0.2205    precision@10 : 0.0306
12 Feb 20:22    INFO  Saving current: saved/KGAT-Feb-12-2023_20-13-49.pth


[1;35mTrain    14[0m:   0%|          | 0/327 [00:00<?, ?it/s]

[1;35mTrain    14[0m:   0%|          | 0/147 [00:00<?, ?it/s]

12 Feb 20:22    INFO  epoch 14 training [time: 15.40s, train_loss1: 54.8881, train_loss2: 3.6452]


[1;35mEvaluate   [0m:   0%|                                                        | 0/13151 [00:00<?, ?it/s…

12 Feb 20:23    INFO  epoch 14 evaluating [time: 22.04s, valid_score: 0.064900]
12 Feb 20:23    INFO  valid result: 
ndcg@10 : 0.0649    hit@10 : 0.2256    precision@10 : 0.0315
12 Feb 20:23    INFO  Saving current: saved/KGAT-Feb-12-2023_20-13-49.pth


[1;35mTrain    15[0m:   0%|          | 0/327 [00:00<?, ?it/s]

[1;35mTrain    15[0m:   0%|          | 0/147 [00:00<?, ?it/s]

12 Feb 20:23    INFO  epoch 15 training [time: 15.52s, train_loss1: 52.0540, train_loss2: 3.2839]


[1;35mEvaluate   [0m:   0%|                                                        | 0/13151 [00:00<?, ?it/s…

12 Feb 20:23    INFO  epoch 15 evaluating [time: 22.13s, valid_score: 0.065300]
12 Feb 20:23    INFO  valid result: 
ndcg@10 : 0.0653    hit@10 : 0.2278    precision@10 : 0.0317
12 Feb 20:23    INFO  Saving current: saved/KGAT-Feb-12-2023_20-13-49.pth


[1;35mTrain    16[0m:   0%|          | 0/327 [00:00<?, ?it/s]

[1;35mTrain    16[0m:   0%|          | 0/147 [00:00<?, ?it/s]

12 Feb 20:24    INFO  epoch 16 training [time: 15.30s, train_loss1: 49.2599, train_loss2: 2.9989]


[1;35mEvaluate   [0m:   0%|                                                        | 0/13151 [00:00<?, ?it/s…

12 Feb 20:24    INFO  epoch 16 evaluating [time: 21.97s, valid_score: 0.066400]
12 Feb 20:24    INFO  valid result: 
ndcg@10 : 0.0664    hit@10 : 0.2325    precision@10 : 0.0324
12 Feb 20:24    INFO  Saving current: saved/KGAT-Feb-12-2023_20-13-49.pth


[1;35mTrain    17[0m:   0%|          | 0/327 [00:00<?, ?it/s]

[1;35mTrain    17[0m:   0%|          | 0/147 [00:00<?, ?it/s]

12 Feb 20:24    INFO  epoch 17 training [time: 15.26s, train_loss1: 46.8925, train_loss2: 3.0904]


[1;35mEvaluate   [0m:   0%|                                                        | 0/13151 [00:00<?, ?it/s…

12 Feb 20:25    INFO  epoch 17 evaluating [time: 21.99s, valid_score: 0.066800]
12 Feb 20:25    INFO  valid result: 
ndcg@10 : 0.0668    hit@10 : 0.2341    precision@10 : 0.0329
12 Feb 20:25    INFO  Saving current: saved/KGAT-Feb-12-2023_20-13-49.pth


[1;35mTrain    18[0m:   0%|          | 0/327 [00:00<?, ?it/s]

[1;35mTrain    18[0m:   0%|          | 0/147 [00:00<?, ?it/s]

12 Feb 20:25    INFO  epoch 18 training [time: 15.39s, train_loss1: 44.2992, train_loss2: 2.6371]


[1;35mEvaluate   [0m:   0%|                                                        | 0/13151 [00:00<?, ?it/s…

12 Feb 20:25    INFO  epoch 18 evaluating [time: 22.04s, valid_score: 0.068500]
12 Feb 20:25    INFO  valid result: 
ndcg@10 : 0.0685    hit@10 : 0.2373    precision@10 : 0.0334
12 Feb 20:25    INFO  Saving current: saved/KGAT-Feb-12-2023_20-13-49.pth


[1;35mTrain    19[0m:   0%|          | 0/327 [00:00<?, ?it/s]

[1;35mTrain    19[0m:   0%|          | 0/147 [00:00<?, ?it/s]

12 Feb 20:26    INFO  epoch 19 training [time: 15.42s, train_loss1: 41.9708, train_loss2: 2.6128]


[1;35mEvaluate   [0m:   0%|                                                        | 0/13151 [00:00<?, ?it/s…

12 Feb 20:26    INFO  epoch 19 evaluating [time: 22.42s, valid_score: 0.069400]
12 Feb 20:26    INFO  valid result: 
ndcg@10 : 0.0694    hit@10 : 0.2412    precision@10 : 0.034
12 Feb 20:26    INFO  Saving current: saved/KGAT-Feb-12-2023_20-13-49.pth


Evaluating model


12 Feb 20:26    INFO  Loading model structure and parameters from saved/KGAT-Feb-12-2023_20-13-49.pth


OrderedDict([('ndcg@10', 0.0676), ('hit@10', 0.2316), ('precision@10', 0.0328)])


In [109]:
test_result = trainer.evaluate(test_data, load_best_model=True)
print(test_result)

12 Feb 20:26    INFO  Loading model structure and parameters from saved/KGAT-Feb-12-2023_20-13-49.pth


OrderedDict([('ndcg@10', 0.0676), ('hit@10', 0.2316), ('precision@10', 0.0328)])


# Load best model

In [3]:
import os
import torch

In [4]:
latest_model = 'saved/%s' % os.listdir('saved')[-1]

In [5]:
latest_model

'saved/KGAT-Feb-12-2023_15-01-25.pth'

In [6]:
state_dict = torch.load(latest_model, map_location='cuda')
model = KGAT(config, train_data._dataset).to(config['device'])

model.load_state_dict(state_dict['state_dict'])

  d_inv = np.power(rowsum, -1).flatten()
  indices = torch.LongTensor([final_adj_matrix.row, final_adj_matrix.col])


<All keys matched successfully>

## Some model investigation

The model contains entity and user embeddings

In [7]:
model.entity_embedding.weight.shape

torch.Size([134531, 32])

Ego Embeddings return current user and entity embeddings and concatenates them

In [8]:
ego_embeddings = model._get_ego_embeddings()
ego_embeddings.shape

torch.Size([148599, 32])

In [9]:
from recbole.data.interaction import Interaction
input_interactions = Interaction({
    'user_id': torch.tensor([1]),
    'item_id_list': torch.tensor([]),
    'item_length': torch.tensor([])
})
predictions = model.full_sort_predict(input_interactions)

In [10]:
predictions.shape

torch.Size([50019])

In [11]:
dataset

[1;35mrb_lfm[0m
[1;34mThe number of users[0m: 14068
[1;34mAverage actions of users[0m: 58.66169048126822
[1;34mThe number of items[0m: 50019
[1;34mAverage actions of items[0m: 16.49794074133312
[1;34mThe number of inters[0m: 825194
[1;34mThe sparsity of the dataset[0m: 99.88272952155349%
[1;34mRemain Fields[0m: ['entity_id', 'user_id', 'item_id', 'timestamp', 'head_id', 'relation_id', 'tail_id', 'neg_item_id', 'neg_tail_id']
The number of entities: 134531
The number of relations: 8
The number of triples: 300264
The number of items that have been linked to KG: 50018

## The attention matrix

The attention matrix is a nxn matrix where $n = n_u + n_e$ with $n_u$ being the number of users and $n_e$ the number of entities

The rows of the matrix indicate the head, the columns the tail and the values are the $\pi(h,r,t)$ values stated in the paper as

$$
\pi(h,r,t) = 
$$

So the attention matrix is more or less a graph with all relations and the corresponding attention value for a connection between head and tail.

In [12]:
attention_matrix = model.A_in

In [13]:
attention_matrix

tensor(indices=tensor([[     1,      1,      1,  ...,  78612,  78613,  78614],
                       [ 14069,  14084,  14093,  ..., 148597, 148597, 148597]]),
       values=tensor([0.0175, 0.0702, 0.0175,  ..., 1.0000, 1.0000, 1.0000]),
       device='cuda:0', size=(148599, 148599), nnz=976186,
       layout=torch.sparse_coo)

In [14]:
attention_matrix.coalesce()

tensor(indices=tensor([[     1,      1,      1,  ...,  78612,  78613,  78614],
                       [ 14069,  14084,  14093,  ..., 148597, 148597, 148597]]),
       values=tensor([0.0175, 0.0702, 0.0175,  ..., 1.0000, 1.0000, 1.0000]),
       device='cuda:0', size=(148599, 148599), nnz=976186,
       layout=torch.sparse_coo)

In [15]:
134531 + 14068

148599

In [16]:
type(dataset.ckg_graph(form="dgl", value_field="relation_id"))

dgl.heterograph.DGLHeteroGraph

In [17]:
user, items = model()
user

tensor([[-0.0116, -0.0090, -0.0043,  ...,  0.0281, -0.0961,  0.0949],
        [-0.3419,  0.1299, -0.2020,  ..., -0.0029,  0.4548, -0.0607],
        [-0.1656,  0.0864, -0.0920,  ..., -0.0000,  0.3450, -0.0588],
        ...,
        [-0.0775,  0.1889, -0.1472,  ..., -0.0017,  0.1380,  0.1519],
        [-0.0378,  0.0831, -0.0664,  ..., -0.0021,  0.2756,  0.3086],
        [-0.0027,  0.0061, -0.0840,  ..., -0.0014,  0.5064, -0.0398]],
       device='cuda:0', grad_fn=<SplitWithSizesBackward0>)

In [18]:
user.shape

torch.Size([14068, 80])

In [19]:
model._get_ego_embeddings().shape

torch.Size([148599, 32])

In [20]:
user[0][:16]

tensor([-0.0116, -0.0090, -0.0043, -0.0053, -0.0094,  0.0055,  0.0058, -0.0204,
         0.0009, -0.0062, -0.0135, -0.0057, -0.0126, -0.0060, -0.0088, -0.0030],
       device='cuda:0', grad_fn=<SliceBackward0>)

In [21]:
user[0][16:]

tensor([-1.6511e-02,  2.6773e-03, -4.9035e-03,  7.4086e-03,  4.5108e-03,
         1.4419e-02, -1.5370e-02,  7.9050e-03, -5.4199e-03, -1.1961e-02,
        -1.3822e-02, -1.3760e-02,  3.0691e-04,  3.7714e-03,  2.2786e-02,
         2.4545e-03, -9.0743e-02, -2.0827e-01, -1.3551e-01, -1.1944e-01,
        -7.0720e-02, -9.5239e-02, -1.4786e-01, -1.7271e-01, -1.0872e-01,
        -1.5972e-01, -0.0000e+00,  3.1510e-01, -1.8873e-01, -1.2029e-01,
        -1.1343e-01, -1.5693e-01,  3.4265e-01, -2.2036e-01, -1.0355e-01,
         1.0774e-01,  1.2383e-01, -9.9345e-02, -1.8687e-01, -1.0239e-01,
        -2.7038e-01,  9.7624e-02, -2.0195e-01, -2.0739e-01, -1.0198e-01,
         4.0795e-01, -9.7126e-02, -1.1919e-01, -4.6320e-01, -0.0000e+00,
        -1.6688e-02, -1.4783e-02, -7.0402e-03,  2.5105e-01,  1.7425e-02,
        -2.3146e-01, -4.8848e-01, -1.2411e-01, -6.8266e-03, -5.9031e-01,
        -2.1549e-01,  2.8075e-02, -9.6091e-02,  9.4950e-02], device='cuda:0',
       grad_fn=<SliceBackward0>)

In [22]:
items.shape

torch.Size([134531, 80])

In [23]:
attention_matrix = model.A_in.cpu().coalesce()

In [24]:
size = attention_matrix.size()[0]

# Try to convert sparse attention matrix to scipy sparse coo matrix

The attention matrix should reflect a directional weighted graph with attention scores for each connection in the graph. Following those connections and picking the shortest path between two nodes should result in an explainable path

In [25]:
from scipy.sparse import coo_matrix
from scipy.sparse.csgraph import shortest_path
import numpy as np

In [60]:
matrix = coo_matrix((1/attention_matrix.values().numpy(), (attention_matrix.indices()[0].numpy(), attention_matrix.indices()[1].numpy())), shape=(size,size))

In [27]:
dataset

[1;35mrb_lfm[0m
[1;34mThe number of users[0m: 14068
[1;34mAverage actions of users[0m: 58.66169048126822
[1;34mThe number of items[0m: 50019
[1;34mAverage actions of items[0m: 16.49794074133312
[1;34mThe number of inters[0m: 825194
[1;34mThe sparsity of the dataset[0m: 99.88272952155349%
[1;34mRemain Fields[0m: ['entity_id', 'user_id', 'item_id', 'timestamp', 'head_id', 'relation_id', 'tail_id', 'neg_item_id', 'neg_tail_id']
The number of entities: 134531
The number of relations: 8
The number of triples: 300264
The number of items that have been linked to KG: 50018

In [61]:
dataset.token2id(dataset.uid_field, ['2'])

array([3243])

In [62]:
dataset.id2token(dataset.uid_field, [1])

array(['42688'], dtype='<U6')

In [63]:
dataset.fields

<bound method Dataset.fields of [1;35mrb_lfm[0m
[1;34mThe number of users[0m: 14068
[1;34mAverage actions of users[0m: 58.66169048126822
[1;34mThe number of items[0m: 50019
[1;34mAverage actions of items[0m: 16.49794074133312
[1;34mThe number of inters[0m: 825194
[1;34mThe sparsity of the dataset[0m: 99.88272952155349%
[1;34mRemain Fields[0m: ['entity_id', 'user_id', 'item_id', 'timestamp', 'head_id', 'relation_id', 'tail_id', 'neg_item_id', 'neg_tail_id']
The number of entities: 134531
The number of relations: 8
The number of triples: 300264
The number of items that have been linked to KG: 50018>

In [64]:
i = 0
model.eval()
for batch_idx, batched_data in enumerate(valid_data):
    interaction, history_index, positive_u, positive_i = batched_data
    print(interaction['user_id'])
    predictions = model.full_sort_predict(interaction).detach().cpu().numpy()
    print(predictions)
    print(len(predictions))
    break
    i+=1

tensor([1])
[-0.48484862  2.0330496   2.0203824  ... -1.3216093  -0.17598814
  1.5596061 ]
50019


In [65]:
input_interactions = Interaction({
    'user_id': torch.tensor([3243])
})
predictions = model.full_sort_predict(input_interactions).detach().cpu().numpy()


In [66]:
ind = np.argpartition(predictions, -10)[-10:]
ind = ind[np.argsort(predictions[ind])]
ind

array([ 2014,   414, 35031,  2093, 40256,  1514, 29794,  1211,   370,
         192])

In [67]:
ind = ind + dataset.user_num
ind

array([16082, 14482, 49099, 16161, 54324, 15582, 43862, 15279, 14438,
       14260])

In [68]:
last_ind = np.argpartition(predictions, 10)[:10]
last_ind = last_ind[np.argsort(predictions[last_ind])]
last_ind

array([37184, 49145, 38688, 35974, 24847, 45954, 47553, 29316, 19554,
       47471])

In [69]:
last_ind = last_ind + dataset.user_num
last_ind

array([51252, 63213, 52756, 50042, 38915, 60022, 61621, 43384, 33622,
       61539])

In [70]:
dataset.id2token(dataset.uid_field, [1])

array(['42688'], dtype='<U6')

In [71]:
shortest_distances, predecessors = shortest_path(matrix, directed=True,  return_predecessors = True, indices=[3243])

In [72]:
shortest_distances

array([[         inf,  88.56428528, 103.16666508, ...,          inf,
                 inf,          inf]])

In [73]:
predecessors

array([[-9999, 32409, 18374, ..., -9999, -9999, -9999]], dtype=int32)

Let's check the distances for the top 10 recommendations

In [74]:
for idx in ind:
    display(shortest_path_user[0][0][idx])

0.1179059831192717

0.12111592653673142

0.11804763542022556

0.11513590009417385

0.1179059831192717

0.12097819827613421

0.12101520586293191

0.1179059831192717

0.12177651727688499

0.12074998475145549

intuitively, the last 10 recommendations should be not reachable

In [75]:
for idx in last_ind:
    display(shortest_path_user[0][0][idx])

inf

inf

inf

inf

inf

inf

inf

inf

inf

inf

In [76]:
def get_path(Pr, j):
    path = [j]
    k = j
    while Pr[k] != -9999:
        path.append(Pr[k])
        k = Pr[k]
    return path[::-1]

In [83]:
get_path(predecessors[0], ind[9])

[3243, 16226, 2213, 17619, 637, 14260]

## Put it all together and evaluate on test set