# Item Response Ranking for NCDM

This notebook shows how to train and use IRR-NCDM.
Refer to the [IRR doc](../../docs/IRR.md) for more details.
First, we load the data (the a0910 dataset is used here).
Then we train an IRR-NCDM model and save its parameters to a file.
Finally, we load the parameters from the file and evaluate the model on the test dataset.

In [1]:
import logging
from longling.lib.structure import AttrDict
from longling import set_logging_info
from EduCDM.IRR import pair_etl as etl, point_etl as vt_etl, extract_item

set_logging_info()

params = AttrDict(
    batch_size=256,
    n_neg=10,
    n_imp=10,
    logger=logging.getLogger(),
    hyper_params={"user_num": 4164, "knowledge_num": 123}
)
item_knowledge = extract_item("../../data/a0910/item.csv", params["hyper_params"]["knowledge_num"], params)
train_data, train_df = etl("../../data/a0910/train.csv", item_knowledge, params)
valid_data, _ = vt_etl("../../data/a0910/valid.csv", item_knowledge, params)
test_data, _ = vt_etl("../../data/a0910/test.csv", item_knowledge, params)

train_data, valid_data, test_data

reading records from ../../data/a0910/item.csv: 100%|██████████| 19529/19529 [00:00<00:00, 56803.95it/s]
rating2triplet: 100%|██████████| 17051/17051 [00:15<00:00, 1073.69it/s]


(<longling.lib.iterator.AsyncLoopIter at 0x19579f67af0>,
 <torch.utils.data.dataloader.DataLoader at 0x195779a4e50>,
 <torch.utils.data.dataloader.DataLoader at 0x1957b015a00>)

In [2]:
train_df

Unnamed: 0,user_id,item_id,score
0,1615,12977,1.0
1,782,13124,0.0
2,1084,16475,0.0
3,593,8690,0.0
4,127,14225,1.0
...,...,...,...
186044,2280,6019,0.0
186045,121,2,1.0
186046,601,5425,1.0
186047,573,2412,0.0
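
The sizes passed to the model constructor in the next cell (`4163 + 1` users, `17746 + 1` items) look like "largest ID plus one". Below is a minimal sketch of how such sizes could be derived from interaction records, assuming IDs are zero-based integers. The records are the first rows of the preview above, and `table_sizes` is a hypothetical helper, not part of EduCDM.

```python
# Rows copied from the train_df preview above: (user_id, item_id, score).
records = [
    (1615, 12977, 1.0),
    (782, 13124, 0.0),
    (1084, 16475, 0.0),
    (593, 8690, 0.0),
    (127, 14225, 1.0),
]


def table_sizes(rows):
    """Return (user_num, item_num) as max ID + 1.

    Assumes IDs are zero-based integers, so a table indexed by ID
    needs max_id + 1 rows. Hypothetical helper, not part of EduCDM.
    """
    user_num = max(user for user, _, _ in rows) + 1
    item_num = max(item for _, item, _ in rows) + 1
    return user_num, item_num


user_num, item_num = table_sizes(records)
print(user_num, item_num)
```

On the full dataset, the maxima would need to be taken over the train, validation, and test files together so that every ID seen at evaluation time is covered.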


In [3]:
from EduCDM.IRR import NCDM

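cdm = NCDM(
    4163 + 1,   # number of users (equals hyper_params["user_num"])
    17746 + 1,  # number of items
    123,        # number of knowledge concepts (equals hyper_params["knowledge_num"])
)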
cdm.train(
    train_data,
    valid_data,
    epoch=2,
)
cdm.save("IRR-NCDM.params")

Epoch 0: 727it [03:23,  3.57it/s]
evaluating: 100%|██████████| 101/101 [00:00<00:00, 115.12it/s]
formatting item df: 100%|██████████| 10415/10415 [00:00<00:00, 12423.18it/s]
ranking metrics: 10415it [00:14, 735.24it/s]
Epoch 1: 100%|██████████| 727/727 [02:49<00:00,  4.29it/s]
evaluating: 100%|██████████| 101/101 [00:01<00:00, 91.61it/s]
formatting item df: 100%|██████████| 10415/10415 [00:00<00:00, 11477.08it/s]
ranking metrics: 10415it [00:14, 707.66it/s]
INFO:root:save parameters to IRR-NCDM.params


[Epoch 0] Loss: 2.558156, PointLoss: 0.647514, PairLoss: 4.468798
[Epoch 0]
      ndcg@k  precision@k  recall@k      f1@k     len@k  support@k
1   1.000000     0.704849  0.493072  0.547682  1.000000      10415
3   0.895209     0.681741  0.743652  0.691947  1.906961      10415
5   0.894566     0.676585  0.796747  0.713769  2.229573      10415
10  0.893579     0.674508  0.816654  0.720577  2.423428      10415
auc: 0.866700	map: 0.855576	mrr: 0.924757	coverage_error: 3.349563	ranking_loss: 0.564976	len: 2.458569	support: 10415
[Epoch 1] Loss: 2.555617, PointLoss: 0.644294, PairLoss: 4.466940
[Epoch 1]
      ndcg@k  precision@k  recall@k      f1@k     len@k  support@k
1   1.000000     0.704849  0.493072  0.547682  1.000000      10415
3   0.895237     0.681741  0.743652  0.691947  1.906961      10415
5   0.894598     0.676585  0.796747  0.713769  2.229573      10415
10  0.893612     0.674508  0.816654  0.720577  2.423428      10415
auc: 0.866700	map: 0.855650	mrr: 0.924757	coverage_error: 3
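
The logged `Loss` values are consistent with an equal-weight average of `PointLoss` and `PairLoss`. The check below is a sketch under that assumption: the parameter name `zeta`, its value 0.5, and which term it multiplies are inferred from the log and are not confirmed here against the library source.

```python
import math


def combined_loss(point_loss, pair_loss, zeta=0.5):
    """Weighted sum of the pointwise and pairwise losses (assumed form)."""
    return zeta * point_loss + (1 - zeta) * pair_loss


# (PointLoss, PairLoss, Loss) as printed in the training log above.
logged = [
    (0.647514, 4.468798, 2.558156),
    (0.644294, 4.466940, 2.555617),
]

for point, pair, total in logged:
    # Each logged total matches the equal-weight combination to 6 decimals.
    assert math.isclose(combined_loss(point, pair), total, abs_tol=1e-6)
```

With equal weights the two terms are interchangeable, so the log alone cannot tell which loss `zeta` scales.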

In [4]:
cdm.load("IRR-NCDM.params")
print(cdm.eval(test_data))

INFO:root:load parameters from IRR-NCDM.params
evaluating: 100%|██████████| 218/218 [00:01<00:00, 169.59it/s]
formatting item df: 100%|██████████| 13682/13682 [00:01<00:00, 12955.64it/s]
ranking metrics: 13682it [00:23, 571.49it/s]


      ndcg@k  precision@k  recall@k      f1@k     len@k  support@k
1   1.000000     0.703552  0.386927  0.454626  1.000000      13682
3   0.871698     0.683635  0.676587  0.646876  2.268528      13682
5   0.871883     0.674940  0.777347  0.696432  2.981582      13682
10  0.869815     0.669786  0.847081  0.725684  3.723652      13682
auc: 0.804861	map: 0.803387	mrr: 0.895832	coverage_error: 5.059019	ranking_loss: 0.636236	len: 4.075428	support: 13682
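
For orientation, the `precision@k` and `recall@k` columns follow the usual top-k ranking idea. The sketch below implements the textbook definitions on a toy ranked list. It is not the library's implementation, which may treat ties and lists shorter than k differently (the `len@k` column suggests that many lists are shorter than k).

```python
def precision_recall_at_k(ranked_labels, k):
    """Textbook precision@k and recall@k for one ranked list.

    ranked_labels: 0/1 relevance labels ordered by predicted score, best first.
    """
    top_k = ranked_labels[:k]
    hits = sum(top_k)
    total_relevant = sum(ranked_labels)
    # Divide by the actual number of items considered, which may be below k.
    precision = hits / len(top_k) if top_k else 0.0
    recall = hits / total_relevant if total_relevant else 0.0
    return precision, recall


# Toy example: five responses ranked by predicted score, 1 = correct answer.
ranked = [1, 0, 1, 1, 0]
for k in (1, 3, 5):
    print(k, precision_recall_at_k(ranked, k))
```

Recall grows with k while precision tends to fall, which matches the trend in the tables above.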
