# Setting

In [1]:
import os
from typing import List, Dict, Union

import polars as pl

from scripts.metrics import map_at_k

In [2]:
INPUT_DIR = "../../input/raw/"
OUTPUT_DIR = "./candidates/"

In [3]:
def explode_and_add_seq_no(df:pl.DataFrame) -> pl.DataFrame:
    df = df.explode(["prev_items"])
    df = df.with_columns(
        df.select(pl.col("session_id").cumcount().over("session_id").alias("seq_no").cast(pl.Int64))
    )
    return df

In [4]:
def generate_co_visit_matrix(df:pl.DataFrame) -> pl.DataFrame:
    # 共起ペアの作成
    df = df.join(df, on="session_id")

    # yad_noが同じものは除外する
    df = df.filter(pl.col("yad_no") != pl.col("yad_no_right"))

    # yad_noのペアごとに共起回数を計算
    df = df.group_by(["yad_no", "yad_no_right"]).count()

    # rankを計算
    df = df.with_columns(pl.col("count").rank(descending=True).over("yad_no").alias("trend_co_visit_weight_rank"))

    # 整形
    df = df.rename({"yad_no_right":"candidate_yad_no"})[["yad_no", "candidate_yad_no", "trend_co_visit_weight_rank"]]

    return df

# For local train/eval

In [5]:
train_log = pl.read_csv(os.path.join(INPUT_DIR, "train_log.csv"))

In [6]:
co_visit_matrix = generate_co_visit_matrix(train_log)

In [7]:
co_visit_matrix.write_parquet(os.path.join(OUTPUT_DIR, "co_visit_matrix_trend_for_train_or_eval.parquet"))

In [8]:
co_visit_matrix.head()

yad_no,candidate_yad_no,trend_co_visit_weight_rank
i64,i64,f64
10095,12425,92.0
6514,7890,1.0
10856,11146,18.5
13198,3653,13.5
7787,12750,8.5


# MAP@k=10

In [9]:
train_log = pl.read_csv(os.path.join(INPUT_DIR, "train_log.csv"))
train_label = pl.read_csv(os.path.join(INPUT_DIR, "train_label.csv")).rename({"yad_no":"label_yad_no"})

In [10]:
last_items = train_log.group_by("session_id").last()

In [11]:
co_visit_matrix = pl.read_parquet(os.path.join(OUTPUT_DIR, "co_visit_matrix_trend_for_train_or_eval.parquet"))

In [12]:
prediction = last_items \
    .join(co_visit_matrix, on="yad_no", how="left") \
    .join(train_label, on="session_id", how="left") \
    .sort(["session_id", "trend_co_visit_weight_rank"], descending=[False, False]) \
    .with_columns((pl.col("candidate_yad_no") == pl.col("label_yad_no")).cast(pl.Int8).alias("user_relevance")) \
    .fill_null(0)

In [13]:
user_relevances = prediction.group_by("session_id", maintain_order=True).all()["user_relevance"].to_list()

In [14]:
map_at_k(user_relevances, 10)

0.21876196463663303

# For test

In [15]:
test_log = pl.read_csv(os.path.join(INPUT_DIR, "test_log.csv"))

In [16]:
co_visit_matrix = generate_co_visit_matrix(test_log)

In [17]:
co_visit_matrix.write_parquet(os.path.join(OUTPUT_DIR, "co_visit_matrix_trend_for_test.parquet"))

In [18]:
co_visit_matrix.head()

yad_no,candidate_yad_no,trend_co_visit_weight_rank
i64,i64,f64
7044,10033,3.0
2288,8441,5.0
8958,12183,4.5
3370,6927,3.0
6770,12868,10.5
