In [1]:
%load_ext autoreload
%autoreload 2
import sys
from pathlib import Path
import pandas as pd
import cudf
import cupy as cp
import numpy as np
from tqdm import tqdm

sys.path.append('../')
from src.retriever import PopularItem, FavoriteItem, CoOccurrenceItem, ConcatRetriever
from src.utils import get_data_period, period_extraction

In [2]:
# ---- Paths -----------------------------------------------------------------
# Input interactions CSV; read in the next cell with a parsed `time_stamp` column.
TRAIN_PATH = '../data/processed/train.csv'
# Directory under which the co-occurrence retriever's pickle path is built.
COOCCURRENCE_DIR = Path('../data') / 'retriever' / 'co-occurrence'
# Experiment identifier; it names the output directory for the pair files.
EXP_NO = 'exp006'
OUTPUT_DIR = Path('../data') / EXP_NO

# ---- Folds -----------------------------------------------------------------
# One row per run: (date threshold, train flag). The main loop evaluates the
# retriever when the flag is True and takes users from the raw test file when
# it is False.
_folds = [
    ('2017-04-16', True),
    ('2017-04-23', True),
    ('2017-04-30', False),
]
date_th_list = [fold_date for fold_date, _fold_flag in _folds]
train_flag_list = [fold_flag for _fold_date, fold_flag in _folds]

# ---- Window lengths and candidate count --------------------------------------
# Both periods are passed to `get_data_period`; presumably expressed in days —
# TODO confirm against src.utils.
train_period = 14
eval_period = 7
# Appears in the output filenames as `n{top_n}`.
top_n = 100

In [3]:
# Build candidate (user, item) pairs for every date threshold and write one CSV
# per threshold into OUTPUT_DIR.
#
# For each (date_th, train_flag):
#   1. derive the train / eval windows from date_th,
#   2. fit a ConcatRetriever made of Favorite, Co-occurrence and Popular retrievers,
#   3. search candidates for the target users,
#   4. export the resulting pairs.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
df = cudf.read_csv(TRAIN_PATH, parse_dates=['time_stamp'])

# The co-occurrence retriever is configured with a larger candidate count than
# the other retrievers; the concatenated retriever is then configured with `top_n`.
cooccurrence_top_n = 200

for date_th, train_flag in zip(date_th_list, train_flag_list):
    print('='*50)
    print(f'# date_th={date_th}, train_flag={train_flag}')

    train_start, train_end, eval_start, eval_end = get_data_period(date_th, train_period, eval_period)

    # Date window shared by every retriever in this fold.
    config_common = {
        'train_start_date': train_start,
        'train_end_date': train_end,
        'eval_start_date': eval_start,
        'eval_end_date': eval_end,
    }

    # Use the configured `top_n` rather than a literal 100 so the number of
    # retrieved candidates always agrees with the `n{top_n}` tag in the
    # filenames below.
    config_popular = {**config_common, 'top_n': top_n}

    config_favorite = {**config_common, 'top_n': top_n}

    config_cooccurrence = {
        **config_common,
        'top_n': cooccurrence_top_n,
        # NOTE(review): the filename suffix reports `top_n`, not
        # `cooccurrence_top_n`. It is left unchanged so that existing files
        # under COOCCURRENCE_DIR keep the same name.
        'output_path': COOCCURRENCE_DIR/ f'co-occurrence_{date_th}_t{train_period}_e{eval_period}_n{top_n}.pickle',
    }

    config_concat = {
        **config_common,
        'top_n': top_n,
        # The list order is Favorite, Co-occurrence, Popular; presumably this
        # is also the priority used when concatenating — TODO confirm in
        # src.retriever.
        'retrievers': [
            FavoriteItem(**config_favorite),
            CoOccurrenceItem(**config_cooccurrence),
            PopularItem(**config_popular),
        ]
    }

    retriever = ConcatRetriever(**config_concat)
    retriever.fit(df)

    if train_flag:
        # Target users come from the eval window of the training data, so the
        # retriever can be evaluated before the pairs are exported.
        users = period_extraction(df, eval_start, eval_end)['user_id'].unique().to_arrow().tolist()
        retriever.search(users)
        scores = retriever.evaluate(df, verbose=True)
        pairs = retriever.get_pairs(df)
    else:
        # Target users come from the raw test file; no evaluation is run and
        # pairs are requested with target=False.
        users = pd.read_csv('../data/raw/test.tsv', delimiter='\t')['user_id'].tolist()
        retriever.search(users)
        pairs = retriever.get_pairs(df, target=False)

    filename = OUTPUT_DIR/f'pairs_{date_th}_t{train_period}_e{eval_period}_n{top_n}.csv'
    pairs.to_csv(filename, index=False)

# date_th=2017-04-16, train_flag=True
[ConcatRetriever] n=12,011, n_items=100.0 max_ndcg=0.5012, recall=0.3774, precision=0.0357
# date_th=2017-04-23, train_flag=True
[ConcatRetriever] n=11,501, n_items=100.0 max_ndcg=0.4905, recall=0.3706, precision=0.0352
# date_th=2017-04-30, train_flag=False
