# Item Response Ranking for MIRT

This notebook shows how to train and use IRR-MIRT, i.e., MIRT trained with Item Response Ranking (IRR).
Refer to the [IRR doc](../../docs/IRR.md) for more details.
First, we load the data (here we use the a0910 dataset).
Then we train an IRR-MIRT model and persist its parameters to a file.
Finally, we load the parameters from the file and evaluate the model on the test set.
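For orientation, the underlying model is a multidimensional IRT (MIRT) model. The sketch below shows a common compensatory MIRT response function, assuming the parametrization P(correct) = sigmoid(a · θ − b), with a user ability vector θ, an item discrimination vector a, and an item difficulty b. The exact parametrization inside `EduCDM.IRR.MIRT` may differ, and `mirt_prob` is a hypothetical helper, not part of EduCDM.

```python
import math
from typing import Sequence


def mirt_prob(theta: Sequence[float], a: Sequence[float], b: float) -> float:
    """Probability of a correct response under a compensatory MIRT model.

    Hypothetical helper: assumes P = sigmoid(a . theta - b).
    """
    if len(theta) != len(a):
        raise ValueError("theta and a must have the same latent dimension")
    # Latent abilities weighted by the item's discrimination on each dimension
    logit = sum(t * w for t, w in zip(theta, a)) - b
    # Logistic link maps the logit to a probability in (0, 1)
    return 1.0 / (1.0 + math.exp(-logit))


# A user strong on the first latent trait facing an item that loads mostly on it
print(mirt_prob([1.0, -0.5], [1.2, 0.3], 0.4))
```

Because the dimensions are summed, a high ability on one trait can compensate for a low ability on another; that is what "compensatory" refers to.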

In [1]:
```python
import logging
from longling.lib.structure import AttrDict
from longling import set_logging_info
from EduCDM.IRR import pair_etl as etl, point_etl as vt_etl, extract_item

set_logging_info()

# n_neg / n_imp: sampling sizes used by pair_etl when building ranking pairs (see the IRR doc)
params = AttrDict(
    batch_size=256,
    n_neg=10,
    n_imp=10,
    logger=logging.getLogger(),
    hyper_params={"user_num": 4164}
)
# 123 is the number of knowledge concepts in a0910
item_knowledge = extract_item("../../data/a0910/item.csv", 123, params)
# Pairwise (ranking) loader for training; train_df holds the raw training records
train_data, train_df = etl("../../data/a0910/train.csv", item_knowledge, params)
# Pointwise loaders for validation and test
valid_data, _ = vt_etl("../../data/a0910/valid.csv", item_knowledge, params)
test_data, _ = vt_etl("../../data/a0910/test.csv", item_knowledge, params)

train_data, valid_data, test_data
```

```
reading records from ../../data/a0910/item.csv: 100%|██████████| 19529/19529 [00:00<00:00, 55001.51it/s]
rating2triplet: 100%|██████████| 17051/17051 [00:16<00:00, 1015.42it/s]

(<longling.lib.iterator.AsyncLoopIter at 0x2c650edb910>,
 <torch.utils.data.dataloader.DataLoader at 0x2c641c7cd00>,
 <torch.utils.data.dataloader.DataLoader at 0x2c650edb7c0>)
```
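The `rating2triplet` progress bar suggests that `pair_etl` turns per-user response records into ranking triplets. The following is a simplified, hypothetical illustration of that idea, not EduCDM's implementation: for each user, every item is paired with up to `n_neg` items that the same user scored strictly lower. EduCDM's version also takes `n_imp`, which this sketch ignores, and `build_triplets` is a name introduced here for illustration.

```python
import random
from collections import defaultdict
from typing import Dict, List, Tuple

Record = Tuple[int, int, float]  # (user_id, item_id, score)


def build_triplets(records: List[Record], n_neg: int, seed: int = 0) -> List[Tuple[int, int, int]]:
    """Hypothetical sketch: (user, higher-scored item, lower-scored item) triplets."""
    rng = random.Random(seed)
    by_user: Dict[int, Dict[int, float]] = defaultdict(dict)
    for user, item, score in records:
        by_user[user][item] = score

    triplets = []
    for user, responses in by_user.items():
        for pos_item, pos_score in responses.items():
            # Items this user scored strictly lower than pos_item
            lower = [i for i, s in responses.items() if s < pos_score]
            # Sample at most n_neg of them without replacement
            for neg_item in rng.sample(lower, min(n_neg, len(lower))):
                triplets.append((user, pos_item, neg_item))
    return triplets


# Toy records: user 1 answered item 10 correctly and items 11, 12 incorrectly
records = [(1, 10, 1.0), (1, 11, 0.0), (1, 12, 0.0), (2, 10, 0.0)]
print(sorted(build_triplets(records, n_neg=10)))  # [(1, 10, 11), (1, 10, 12)]
```

A pairwise loss then pushes the model to score the first item of each triplet above the second for that user.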

In [2]:
```python
train_df
```

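| Unnamed: 0 | user_id | item_id | score |
|---|---|---|---|
| 0 | 1615 | 12977 | 1.0 |
| 1 | 782 | 13124 | 0.0 |
| 2 | 1084 | 16475 | 0.0 |
| 3 | 593 | 8690 | 0.0 |
| 4 | 127 | 14225 | 1.0 |
| ... | ... | ... | ... |
| 186044 | 2280 | 6019 | 0.0 |
| 186045 | 121 | 2 | 1.0 |
| 186046 | 601 | 5425 | 1.0 |
| 186047 | 573 | 2412 | 0.0 |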
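The model below is built with `4163 + 1` users and `17746 + 1` items, which reads like "largest id plus one", so that every id is a valid embedding index. A minimal sketch of deriving such sizes from response records, assuming ids are non-negative integers; `infer_sizes` is a hypothetical helper, and real sizes should come from the full dataset (including the item file), not from the few rows displayed above.

```python
from typing import Iterable, Tuple


def infer_sizes(records: Iterable[Tuple[int, int, float]]) -> Tuple[int, int]:
    """Hypothetical helper: return (user_num, item_num) as max id + 1."""
    max_user = max_item = -1
    for user_id, item_id, _score in records:
        max_user = max(max_user, user_id)
        max_item = max(max_item, item_id)
    # +1 so that the largest id is still a valid 0-based index
    return max_user + 1, max_item + 1


# The first five rows of train_df shown above
rows = [(1615, 12977, 1.0), (782, 13124, 0.0), (1084, 16475, 0.0),
        (593, 8690, 0.0), (127, 14225, 1.0)]
print(infer_sizes(rows))  # (1616, 16476)
```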


In [3]:
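```python
from EduCDM.IRR import MIRT

cdm = MIRT(
    4163 + 1,   # number of users (same as hyper_params["user_num"] above)
    17746 + 1,  # number of items
    123         # number of knowledge concepts (same as in extract_item)
)
# Train on the pairwise loader; the model is evaluated on valid_data after each epoch
cdm.train(
    train_data,
    valid_data,
    epoch=2,
)
# Persist the learned parameters to a file
cdm.save("IRR-MIRT.params")
```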

```
Epoch 0: 727it [01:44,  6.99it/s]
evaluating: 100%|██████████| 101/101 [00:00<00:00, 147.99it/s]
formatting item df: 100%|██████████| 10415/10415 [00:00<00:00, 11414.30it/s]
ranking metrics: 10415it [00:15, 673.20it/s]
Epoch 1: 100%|██████████| 727/727 [01:34<00:00,  7.67it/s]
evaluating: 100%|██████████| 101/101 [00:00<00:00, 162.69it/s]
formatting item df: 100%|██████████| 10415/10415 [00:00<00:00, 13177.35it/s]
ranking metrics: 10415it [00:14, 737.45it/s]
INFO:root:save parameters to IRR-MIRT.params
```


```
[Epoch 0] Loss: 2.564640, PointLoss: 0.664851, PairLoss: 4.464429
[Epoch 0]
      ndcg@k  precision@k  recall@k      f1@k     len@k  support@k
1   1.000000     0.672492  0.473254  0.524685  1.000000      10415
3   0.888882     0.675116  0.737681  0.686057  1.906961      10415
5   0.892384     0.674415  0.793843  0.711479  2.229573      10415
10  0.892711     0.673980  0.816039  0.720015  2.423428      10415
auc: 0.836789	map: 0.911223	mrr: 0.902067	coverage_error: 3.008395	ranking_loss: 0.285400	len: 2.458569	support: 10415
[Epoch 1] Loss: 2.538817, PointLoss: 0.651702, PairLoss: 4.425933
[Epoch 1]
      ndcg@k  precision@k  recall@k      f1@k     len@k  support@k
1   1.000000     0.670475  0.472193  0.523379  1.000000      10415
3   0.889095     0.675148  0.737989  0.686143  1.906961      10415
5   0.892053     0.674396  0.794098  0.711508  2.229573      10415
10  0.892339     0.674066  0.816177  0.720113  2.423428      10415
auc: 0.836738	map: 0.910856	mrr: 0.901024	coverage_error: 3
```
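In both epochs the logged `Loss` equals the mean of `PointLoss` and `PairLoss`, i.e. an equal-weight combination of the pointwise and pairwise objectives. The check below reproduces this from the logged numbers. Treating the weight as a tunable trade-off parameter (called `zeta` here, a hypothetical name) is an assumption about the model; the log itself only shows the 0.5/0.5 split.

```python
# Logged values from the two training epochs above: (Loss, PointLoss, PairLoss)
logged = {
    0: (2.564640, 0.664851, 4.464429),
    1: (2.538817, 0.651702, 4.425933),
}

zeta = 0.5  # assumed weight of the pointwise term (hypothetical name)

for epoch, (loss, point_loss, pair_loss) in logged.items():
    combined = zeta * point_loss + (1 - zeta) * pair_loss
    # Agreement up to the 6-decimal rounding used in the log
    print(epoch, combined, loss, abs(combined - loss) < 1e-6)
```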

In [4]:
```python
cdm.load("IRR-MIRT.params")
print(cdm.eval(test_data))
```

```
INFO:root:load parameters from IRR-MIRT.params
evaluating: 100%|██████████| 218/218 [00:00<00:00, 258.54it/s]
formatting item df: 100%|██████████| 13682/13682 [00:01<00:00, 13198.37it/s]
ranking metrics: 13682it [00:22, 610.71it/s]

      ndcg@k  precision@k  recall@k      f1@k     len@k  support@k
1   1.000000     0.669200  0.371232  0.435331  1.000000      13682
3   0.862998     0.667434  0.663840  0.633772  2.268528      13682
5   0.869094     0.667705  0.770522  0.690039  2.981582      13682
10  0.869793     0.667432  0.844655  0.723425  3.723652      13682
auc: 0.770833	map: 0.870554	mrr: 0.873113	coverage_error: 4.645888	ranking_loss: 0.315248	len: 4.075428	support: 13682
```
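The "formatting item df" step and a support equal to the number of items processed suggest that the ranking metrics are computed per item (ranking that item's users by predicted score) and then averaged. Below is a minimal sketch of a textbook NDCG@k for one such ranked list, using the observed score as relevance with linear gain. The library behind EduCDM's report may use a different gain, tie handling, or averaging, so these numbers need not match the table above; the input values in the usage line are illustrative only.

```python
import math
from typing import Sequence


def dcg_at_k(relevance: Sequence[float], k: int) -> float:
    """Discounted cumulative gain of the first k entries (linear gain)."""
    return sum(rel / math.log2(rank + 2) for rank, rel in enumerate(relevance[:k]))


def ndcg_at_k(y_true: Sequence[float], y_pred: Sequence[float], k: int) -> float:
    """NDCG@k: DCG of the predicted order divided by DCG of the ideal order."""
    # Reorder the observed scores by descending predicted score
    order = sorted(range(len(y_pred)), key=lambda i: y_pred[i], reverse=True)
    ranked = [y_true[i] for i in order]
    ideal = sorted(y_true, reverse=True)
    best = dcg_at_k(ideal, k)
    # Lists with no relevant entries get 0 by convention here
    return dcg_at_k(ranked, k) / best if best > 0 else 0.0


# One item's users: observed scores and predicted probabilities (toy values)
print(ndcg_at_k([1, 0, 1, 0], [0.9, 0.8, 0.3, 0.1], k=3))
```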
