In [1]:
from google.colab import drive
# Mount Google Drive at /gdrive so the project files under /gdrive/MyDrive are reachable.
drive.mount(mountpoint='/gdrive')

Mounted at /gdrive


In [2]:
!pip install polars

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [3]:
import concurrent.futures
import math

from tqdm import tqdm
import numpy as np
import polars as pl
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from scipy.sparse import csr_matrix
from scipy.spatial.distance import cosine

In [4]:
# Project root on the mounted Drive. Keep the trailing "/" because file paths
# below are built by plain string concatenation (DIR + "data/...").
DIR = "/gdrive/MyDrive/amazon_kdd_2023/"
VER = "15"    # version tag embedded in the imf candidate file names
TOP_N = 100   # number of highest-scoring candidates kept per session

# Evaluation on training sessions: MRR@100 of the IMF candidates

In [None]:
# Load the preprocessed task-1 training sessions. The path is derived from DIR
# (as every other path in this notebook is) so relocating the project only
# requires editing the config cell.
train = pl.read_parquet(DIR + "data/preprocessed/task1/train_task1.parquet")

In [None]:
# For each locale (in the order DE, UK, JP), load the IMF train/eval candidates:
# - keep only training sessions;
# - order rows by session id, with candidates by descending imf_score;
# - retain the first TOP_N candidates of each session.
# `group_by` is the polars >= 0.19 name of the former `groupby` (removed in 1.0).
# A list (not a generator) is passed to pl.concat for compatibility with older polars.
candidates = pl.concat([
    pl.read_parquet(DIR + f"data/interim/candidates/task1/imf_{VER}_{locale}_for_train_or_eval.parquet")
    .filter(pl.col("session_id").str.starts_with("train"))
    .sort(["session_id", "imf_score"], descending=[False, True])
    .group_by("session_id", maintain_order=True)
    .head(TOP_N)
    for locale in ("DE", "UK", "JP")
])

In [None]:
# Drop candidates whose imf_score is exactly zero, then persist the merged
# (all-locale) train/eval candidates as a single parquet file.
train_eval_candidates_path = DIR + f"data/interim/candidates/task1/imf_{VER}_for_train_or_eval.parquet"
candidates = candidates.filter(pl.col("imf_score") != 0)
candidates.write_parquet(train_eval_candidates_path)

In [None]:
# Preview the first rows of the filtered train/eval candidates.
candidates.head(n=5)

session_id,candidate_item,imf_score,imf_rank
str,str,f32,u32
"""train_0""","""B07KDC7PJH""",0.739599,1
"""train_0""","""B09P83WGQG""",0.67242,2
"""train_0""","""B084G8H15D""",0.616889,3
"""train_0""","""B089W9VDMN""",0.602682,4
"""train_0""","""B08L66WJWR""",0.589559,5


In [None]:
# For every training session, build the list of 0/1 labels marking which of its
# candidates (in candidate order) equals the session's next_item.
# Sessions with no candidates get a single null label.
# The rows are processed in slices of n_rows to bound the size of each join.
label_lists = []
n_rows = 400_000
n_slices = math.ceil(train.height / n_rows)  # lets tqdm show a proper progress bar

# Only these columns are needed; joining fewer columns keeps each slice's join
# and aggregation small.
session_targets = train.select(["session_id", "next_item"])
candidate_items = candidates.select(["session_id", "candidate_item"])

# NOTE(review): the MRR cell ranks candidates by their position in these lists.
# That relies on the left join keeping each session's candidate rows in the same
# (score-descending) order as in `candidates`. Confirm this for the installed
# polars version; newer releases make join row order explicit via `maintain_order`.
for chunk in tqdm(session_targets.iter_slices(n_rows=n_rows), total=n_slices):
    chunk = chunk.join(candidate_items, on="session_id", how="left")
    chunk = chunk.with_columns(
        (pl.col("candidate_item") == pl.col("next_item")).cast(pl.Int8).alias("label")
    )
    label_lists.extend(
        chunk.group_by("session_id", maintain_order=True)
        .agg(pl.col("label"))["label"]
        .to_list()
    )

100%|██████████| 9/9 [04:05<00:00, 27.30s/it]


In [None]:
# Compute MRR@100 over all training sessions. Each session contributes
# 1 / (1-based position of its first positive label among its first
# MRR_CUTOFF candidates), or 0 if the true next item is not among them.
MRR_CUTOFF = 100  # matches the MRR@100 heading; independent of TOP_N on purpose
assert label_lists, "label_lists is empty: no sessions to evaluate"
rr = sum(
    next(
        (1 / rank for rank, label in enumerate(labels[:MRR_CUTOFF], start=1) if label == 1),
        0,
    )
    for labels in label_lists
)
mrr = rr / len(label_lists)
print("MRR:", round(mrr, 5))

MRR: 0.20113


# Candidates for inference (test sessions)

In [5]:
# For each locale (in the order DE, UK, JP), load the IMF inference candidates:
# - keep only test sessions;
# - order rows by session id, with candidates by descending imf_score;
# - retain the first TOP_N candidates of each session.
# `group_by` is the polars >= 0.19 name of the former `groupby` (removed in 1.0).
# A list (not a generator) is passed to pl.concat for compatibility with older polars.
candidates = pl.concat([
    pl.read_parquet(DIR + f"data/interim/candidates/task1/imf_{VER}_{locale}_for_inference.parquet")
    .filter(pl.col("session_id").str.starts_with("test"))
    .sort(["session_id", "imf_score"], descending=[False, True])
    .group_by("session_id", maintain_order=True)
    .head(TOP_N)
    for locale in ("DE", "UK", "JP")
])

In [6]:
# Drop candidates whose imf_score is exactly zero, then persist the merged
# (all-locale) inference candidates as a single parquet file.
inference_candidates_path = DIR + f"data/interim/candidates/task1/imf_{VER}_for_inference.parquet"
candidates = candidates.filter(pl.col("imf_score") != 0)
candidates.write_parquet(inference_candidates_path)

In [7]:
# Preview the first rows of the filtered inference candidates.
candidates.head(n=5)

session_id,candidate_item,imf_score,imf_rank
str,str,f32,u32
"""test_phase2_0""","""B0B3HMH1JP""",0.57608,1
"""test_phase2_0""","""B0B87N98MM""",0.45538,2
"""test_phase2_0""","""B09BR9YCJ9""",0.454744,3
"""test_phase2_0""","""B00FFTDZTY""",0.452741,4
"""test_phase2_0""","""B07SDFLVKD""",0.444103,5
