In [1]:
# Mount Google Drive at /gdrive; every data path used below lives under /gdrive/MyDrive.
from google.colab import drive
drive.mount('/gdrive')

Drive already mounted at /gdrive; to attempt to forcibly remount, call drive.mount("/gdrive", force_remount=True).


In [2]:
!pip install polars

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [3]:
from collections import defaultdict, Counter
from typing import List, Dict

from tqdm import tqdm
import pandas as pd
import polars as pl

In [4]:
# Locale codes to keep; rows of the train and test session tables whose
# "locale" value is not in this list are dropped when the data is loaded.
LOCALES = ["FR", "ES", "IT"]

In [5]:
def _read_sessions(csv_paths):
    """Read session CSVs with pandas, keep the LOCALES rows, return a polars frame.

    Several paths are stacked row-wise (in the given order) before filtering.
    """
    frames = [pd.read_csv(path) for path in csv_paths]
    sessions = frames[0] if len(frames) == 1 else pd.concat(frames)
    in_scope = sessions["locale"].isin(LOCALES)
    return pl.from_pandas(sessions[in_scope])


# Training sessions come from a single file.
train = _read_sessions([
    "/gdrive/MyDrive/amazon_kdd_2023/data/raw/sessions_train.csv",
])
# Test sessions are the task 2 and task 3 files stacked together.
test = _read_sessions([
    "/gdrive/MyDrive/amazon_kdd_2023/data/raw/sessions_test_task2.csv",
    "/gdrive/MyDrive/amazon_kdd_2023/data/raw/sessions_test_task3.csv",
])

In [6]:
# Load the preprocessed product table.
# NOTE(review): `product` is not referenced by any other visible cell — confirm this load is still needed.
product = pl.read_parquet("/gdrive/MyDrive/amazon_kdd_2023/data/preprocessed/common/product_03.parquet")

In [7]:
# Parsing of the prev_items column
def str2list(s):
    """Split a stringified item list into a list of item-id strings.

    Square brackets and single quotes are removed, line breaks are treated as
    separators, and the remainder is split on whitespace.
    """
    # One pass: delete "[", "]" and "'", and turn "\n" / "\r" into spaces.
    cleanup = str.maketrans({"[": None, "]": None, "'": None, "\n": " ", "\r": " "})
    return s.translate(cleanup).split()  # split on whitespace into a list

# Replace the raw prev_items strings with parsed lists in both frames.
# NOTE(review): later polars releases rename Expr.apply to map_elements — confirm
# against the installed polars version before upgrading.
parsed_prev_items = pl.col("prev_items").apply(str2list).alias("prev_items")
train = train.with_columns(parsed_prev_items)
test = test.with_columns(parsed_prev_items)

In [8]:
# Prepare the test sessions so they can be added to train.

# Condition 1: the session holds at least 3 items.
test = (
    test
    .with_columns(pl.col("prev_items").apply(len).alias("session_count"))
    .filter(pl.col("session_count") >= 3)
)

# Condition 2: the final item was not interacted with earlier in the session.
# The last element of each session is split off as next_item and the
# remaining elements become the new prev_items.
sessions = test["prev_items"].to_list()
test = test.with_columns([
    pl.Series(name="next_item", values=[items[-1] for items in sessions]),
    pl.Series(name="prev_items", values=[items[:-1] for items in sessions]),
])
label_already_seen = pl.col("next_item").is_in(pl.col("prev_items"))
test = test.filter(~label_already_seen)

# Keep only these three columns, in this order.
test = test.select(["prev_items", "next_item", "locale"])

In [9]:
# Stack the prepared test rows under train and report the row count before and after.
print("test追加前", train.height)
train = pl.concat([train, test], how="vertical")
print("test追加後", train.height)

test追加前 333533
test追加後 358312


In [10]:
# Assign a session_id to every row: "train_" followed by the row position.
session_ids = ["train_" + str(row_idx) for row_idx in range(train.height)]
train = train.with_columns(pl.Series(name="session_id", values=session_ids))

In [11]:
# Preview the first rows of the combined frame (prev_items, next_item, locale, session_id).
train.head()

prev_items,next_item,locale,session_id
list[str],str,str,str
"[""B08MV5B53K"", ""B08MV4RCQR"", ""B08MV5B53K""]","""B012408XPC""","""ES""","""train_0"""
"[""B07JGW4QWX"", ""B085VCXHXL""]","""B07JFPYN5P""","""ES""","""train_1"""
"[""B08BFQ52PR"", ""B08LVSTZVF"", ""B08BFQ52PR""]","""B08NJP3KT6""","""ES""","""train_2"""
"[""B08PPBF9C6"", ""B08PPBF9C6"", … ""B08PPBF9C6""]","""B08PP6BLLK""","""ES""","""train_3"""
"[""B0B6W67XCR"", ""B0B712FY2M"", ""B0B6ZYJ3S2""]","""B09SL4MBM2""","""ES""","""train_4"""


In [13]:
# Persist the combined training frame as parquet.
# NOTE(review): execution counts jump from In [11] to In [13]; confirm that a
# Restart & Run All reproduces the same file.
train.write_parquet("/gdrive/MyDrive/amazon_kdd_2023/data/preprocessed/task2/train_task2_04.parquet")