In [69]:
from pathlib import Path

import numpy as np
import pandas as pd
from tqdm.auto import tqdm

from rectools.models import RandomModel, PopularModel
from rectools.dataset import Interactions, Dataset
from rectools.model_selection import TimeRangeSplitter
from rectools import Columns
from rectools.metrics import (
    Precision,
    Accuracy,
    MAP,
    MRR, 
    NDCG,
    calc_metrics,
)
from rectools.models import ImplicitItemKNNWrapperModel

In [70]:
import requests

# url = 'https://github.com/irsafilo/KION_DATASET/raw/f69775be31fa5779907cf0a92ddedb70037fb5ae/data_original.zip'
# req = requests.get(url, stream=True)
# 
# with open('kion.zip', 'wb') as fd:
#     total_size_in_bytes = int(req.headers.get('Content-Length', 0))
#     progress_bar = tqdm(desc='kion dataset download', total=total_size_in_bytes, unit='iB', unit_scale=True)
#     for chunk in req.iter_content(chunk_size=2 ** 20):
#         progress_bar.update(len(chunk))
#         fd.write(chunk)
#         
# import zipfile as zf
# 
# files = zf.ZipFile('kion.zip','r')
# files.extractall()
# files.close()

In [99]:
# Load the interaction log and rename the dataset-specific columns to the
# rectools column constants before wrapping the frame in `Interactions`.
interactions = Interactions(
    pd.read_csv('data_original/interactions.csv').rename(
        columns={
            'last_watch_dt': Columns.Datetime,
            'total_dur': Columns.Weight,
        }
    )
)

# Only these item attributes are read from the items file.
selected_columns = ['item_id', 'title', 'release_year', 'genres', 'countries']
item_data = pd.read_csv('data_original/items.csv', usecols=selected_columns)

In [76]:
# Show the item metadata frame loaded above (pandas truncates the rendered rows).
item_data

Unnamed: 0,item_id,title,release_year,genres,countries
0,10711,Поговори с ней,2002.0,"драмы, зарубежные, детективы, мелодрамы",Испания
1,2508,Голые перцы,2014.0,"зарубежные, приключения, комедии",США
2,10716,Тактическая сила,2011.0,"криминал, зарубежные, триллеры, боевики, комедии",Канада
3,7868,45 лет,2015.0,"драмы, зарубежные, мелодрамы",Великобритания
4,16268,Все решает мгновение,1978.0,"драмы, спорт, советские, мелодрамы",СССР
...,...,...,...,...,...
15958,6443,Полярный круг,2018.0,"драмы, триллеры, криминал","Финляндия, Германия"
15959,2367,Надежда,2020.0,"драмы, боевики",Россия
15960,10632,Сговор,2017.0,"драмы, триллеры, криминал",Россия
15961,4538,Среди камней,2019.0,"драмы, спорт, криминал",Россия


In [77]:
# Preview the first and the last five interaction rows in a single table.
pd.concat([interactions.df.iloc[:5], interactions.df.iloc[-5:]])

Unnamed: 0,user_id,item_id,datetime,weight,watched_pct
0,176549,9506,2021-05-11,4250.0,72.0
1,699317,1659,2021-05-29,8317.0,100.0
2,656683,7107,2021-05-09,10.0,0.0
3,864613,7638,2021-07-05,14483.0,100.0
4,964868,9506,2021-04-30,6725.0,100.0
5476246,648596,12225,2021-08-13,76.0,0.0
5476247,546862,9673,2021-04-13,2308.0,49.0
5476248,697262,15297,2021-08-20,18307.0,63.0
5476249,384202,16197,2021-04-19,6203.0,100.0
5476250,319709,4436,2021-08-15,3921.0,45.0


In [78]:
# Print column dtypes and the memory footprint of the interactions frame;
# memory_usage='deep' asks pandas for the real (introspected) memory usage.
interactions.df.info(memory_usage='deep')

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 5476251 entries, 0 to 5476250
Data columns (total 5 columns):
 #   Column       Dtype         
---  ------       -----         
 0   user_id      int64         
 1   item_id      int64         
 2   datetime     datetime64[ns]
 3   weight       float64       
 4   watched_pct  float64       
dtypes: datetime64[ns](1), float64(2), int64(2)
memory usage: 208.9 MB


In [79]:
# Baseline recommenders to compare, keyed by the label used in the result tables.
models = {}
models["random"] = RandomModel(random_state=42)  # fixed seed for reproducible output
models["popular"] = PopularModel()  # library-default popularity definition
# Popularity taken from the "sum_weight" option; the weight column is the
# renamed `total_dur` from the interactions file.
models["most_raited"] = PopularModel(popularity="sum_weight")

In [84]:
# Every metric family is evaluated at cutoffs 10, 5 and 1; keys follow the
# "<metric>@<k>" pattern and keep the cutoff-major order (10, then 5, then 1).
metrics = {
    f"{label}@{top_k}": metric_cls(k=top_k)
    for top_k in (10, 5, 1)
    for label, metric_cls in (
        ("precision", Precision),
        ("accuracy", Accuracy),
        ("map", MAP),
        ("mrr", MRR),
        ("ndcg", NDCG),
    )
}

In [85]:
# Time-based cross-validation: `n_splits` test folds, each covering 14 days ("14D").
n_splits = 3

splitter = TimeRangeSplitter(
    n_splits=n_splits,
    test_size="14D",
    # All three filters are enabled: cold users, cold items, already seen pairs.
    filter_cold_users=True,
    filter_cold_items=True,
    filter_already_seen=True,
)

Расчёт метрик:

In [86]:
def get_metrics(models, metrics, splitter, k):
    """Cross-validate every model on the splitter's folds and collect metric values.

    NOTE: the interactions are not a parameter. The function reads the
    notebook-level ``interactions`` object, so that name must be defined
    before the function is called.

    Parameters
    ----------
    models : dict
        Mapping ``model name -> model``. ``fit(dataset)`` is called on the very
        same model object for every fold, so after this function returns each
        model was last fitted on the final fold's train part.
    metrics : dict
        Mapping ``metric name -> metric object``; passed unchanged to
        ``calc_metrics``.
    splitter
        Object providing ``split(interactions, collect_fold_stats=True)`` that
        yields ``(train_ids, test_ids, fold_info)`` and an ``n_splits``
        attribute (used only for the progress bar total).
    k : int
        Number of recommendations requested per user.

    Returns
    -------
    list of dict
        One record per (fold, model) pair with the keys ``"fold"`` and
        ``"model"`` plus everything returned by ``calc_metrics``.
    """
    results = []

    fold_iterator = splitter.split(interactions, collect_fold_stats=True)

    for train_ids, test_ids, fold_info in tqdm(fold_iterator, total=splitter.n_splits):
        fold_index = fold_info["i_split"]
        print(f"\n==================== Fold {fold_index}")
        print(fold_info)

        # Positional row selection: the splitter yields row positions of interactions.df.
        df_train = interactions.df.iloc[train_ids]
        dataset = Dataset.construct(df_train)

        # Only the (user, item) pairs of the test part are needed for metrics.
        df_test = interactions.df.iloc[test_ids][Columns.UserItem]
        test_users = np.unique(df_test[Columns.User])

        # Items present in the train part of this fold.
        catalog = df_train[Columns.Item].unique()

        for model_name, model in models.items():
            model.fit(dataset)
            recos = model.recommend(
                users=test_users,
                dataset=dataset,
                k=k,
                filter_viewed=True,
            )

            metric_values = calc_metrics(
                metrics,
                reco=recos,
                interactions=df_test,
                prev_interactions=df_train,
                catalog=catalog,
            )

            res = {"fold": fold_index, "model": model_name}
            res.update(metric_values)
            results.append(res)

    return results

In [125]:
def visual_analys(model, interactions, user_ids, item_data, k=10):
    """Print the top-``k`` recommendations for the given users with item metadata.

    The model is fitted on a dataset built from the passed interactions before
    recommending, so the fitted state and the dataset used for inference always
    match. Without this step the function silently depended on whatever data
    the model had last been fitted on (e.g. the train part of a CV fold).

    Parameters
    ----------
    model
        Recommender exposing ``fit(dataset)`` and ``recommend(...)``. It is
        fitted by this call, i.e. ``fit`` is invoked on the passed object.
    interactions
        Object whose ``df`` attribute holds the interactions frame used to
        build the dataset.
    user_ids
        Users to produce recommendations for.
    item_data : pandas.DataFrame
        Item metadata; must contain an ``item_id`` column, which is the key
        for the left join.
    k : int, default 10
        Number of recommendations per user.

    Returns
    -------
    None
        The joined table is printed, not returned.
    """
    dataset = Dataset.construct(interactions.df)

    # Fit on the same dataset that is passed to `recommend` below.
    model.fit(dataset)

    recos = model.recommend(
        users=user_ids,
        dataset=dataset,
        k=k,
        filter_viewed=True,
    )

    # Left join keeps every recommendation even when item metadata is missing.
    recos = pd.merge(recos, item_data, on='item_id', how='left')

    print(recos.to_string(index=False, max_colwidth=40))

In [88]:
# Run the cross-validation for all baselines with 10 recommendations per user
# and show the raw per-fold records.
results = get_metrics(models=models, metrics=metrics, splitter=splitter, k=10)
results

  0%|          | 0/3 [00:00<?, ?it/s]


{'i_split': 0, 'start': Timestamp('2021-07-12 00:00:00', freq='14D'), 'end': Timestamp('2021-07-26 00:00:00', freq='14D'), 'train': 3239125, 'train_users': 646423, 'train_items': 14730, 'test': 398993, 'test_users': 122488, 'test_items': 7394}

{'i_split': 1, 'start': Timestamp('2021-07-26 00:00:00', freq='14D'), 'end': Timestamp('2021-08-09 00:00:00', freq='14D'), 'train': 3892558, 'train_users': 742256, 'train_items': 15085, 'test': 458757, 'test_users': 135624, 'test_items': 7711}

{'i_split': 2, 'start': Timestamp('2021-08-09 00:00:00', freq='14D'), 'end': Timestamp('2021-08-23 00:00:00', freq='14D'), 'train': 4649162, 'train_users': 850489, 'train_items': 15415, 'test': 521381, 'test_users': 151629, 'test_items': 7705}


[{'fold': 0,
  'model': 'random',
  'precision@10': 0.0002122656913330285,
  'accuracy@10': 0.9991002607268723,
  'precision@5': 0.00019757037424074193,
  'accuracy@5': 0.9994395499590455,
  'precision@1': 0.00022042975638429888,
  'accuracy@1': 0.9997110004108969,
  'ndcg@10': 0.0002118380318504049,
  'ndcg@5': 0.0002033616926661086,
  'ndcg@1': 0.00022042975638429888,
  'mrr@10': 0.0006196298594329424,
  'mrr@5': 0.00047324363747197,
  'mrr@1': 0.00022042975638429888,
  'map@10': 0.00020768850263758139,
  'map@5': 0.00015775375404269358,
  'map@1': 8.110722007185333e-05},
 {'fold': 0,
  'model': 'popular',
  'precision@10': 0.051507086408464506,
  'accuracy@10': 0.999169907462401,
  'precision@5': 0.07838808699627721,
  'accuracy@5': 0.9994926324550552,
  'precision@1': 0.10307948533733917,
  'accuracy@1': 0.999724966338335,
  'ndcg@10': 0.06400819383624766,
  'ndcg@5': 0.085028346947154,
  'ndcg@1': 0.10307948533733917,
  'mrr@10': 0.19365305820693288,
  'mrr@5': 0.18407082054296472

In [127]:
# Users picked for a manual look at what each model recommends.
user_ids = [666262, 672861, 955527]
for model_name in models:
    model = models[model_name]
    print(f'\nmodel "{model_name}":')
    visual_analys(model, interactions, user_ids, item_data)


model "random":
 user_id  item_id  score  rank                                    title  release_year                                   genres                          countries
  666262     7419     10     1                              Ода радости        2019.0                                  комедии                                США
  666262     9109      9     2                          Последняя битва        2017.0                           драмы, военные                     Великобритания
  666262    13917      8     3                              Преисподняя        2016.0      драмы, детективы, триллеры, вестерн                         Нидерланды
  666262    13332      7     4                                Лихорадка        2003.0                                    ужасы                                США
  666262     1331      6     5                                 Вечность        2016.0                                    драмы                   Франция, Бельгия
  666262   

In [128]:
# Aggregate the per-fold records: mean and standard deviation of every metric
# per model, keeping the models in their original order (sort=False).
pivot_results = (
    pd.DataFrame(results)
    .drop("fold", axis="columns")
    .groupby(by=["model"], sort=False)
    .aggregate(["mean", "std"])
)
pivot_results

Unnamed: 0_level_0,precision@10,precision@10,accuracy@10,accuracy@10,precision@5,precision@5,accuracy@5,accuracy@5,precision@1,precision@1,...,mrr@5,mrr@5,mrr@1,mrr@1,map@10,map@10,map@5,map@5,map@1,map@1
Unnamed: 0_level_1,mean,std,mean,std,mean,std,mean,std,mean,std,...,mean,std,mean,std,mean,std,mean,std,mean,std
model,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2
random,0.000222,8e-06,0.999114,1.4e-05,0.000225,2.4e-05,0.999446,7e-06,0.000247,4.2e-05,...,0.000527,6.2e-05,0.000247,4.2e-05,0.000199,1.4e-05,0.000154,1.2e-05,7e-05,1.4e-05
popular,0.045964,0.005104,0.999175,6e-06,0.070339,0.007241,0.999492,2e-06,0.097211,0.007269,...,0.168272,0.015108,0.097211,0.007269,0.098549,0.01375,0.091013,0.012241,0.047186,0.004912
most_raited,0.041627,0.004448,0.999169,7e-06,0.056373,0.01066,0.999483,4e-06,0.098136,0.007737,...,0.152551,0.01826,0.098136,0.007737,0.090972,0.013581,0.08166,0.013846,0.04792,0.005112


In [37]:
# Persist the aggregated metrics table. The target folder is created first so
# the export does not fail when "../artifacts" does not exist yet.
output_path = Path("../artifacts/first_reco_result.csv")
output_path.parent.mkdir(parents=True, exist_ok=True)
pivot_results.to_csv(output_path)