In [1]:
# Mount Google Drive at /gdrive; every data path in this notebook lives under /gdrive/MyDrive.
from google.colab import drive
drive.mount('/gdrive')

Drive already mounted at /gdrive; to attempt to forcibly remount, call drive.mount("/gdrive", force_remount=True).


In [2]:
# NOTE(review): unpinned install. The cells below call `apply`, `groupby` and
# `cumcount`, which newer polars releases rename -- consider pinning the version
# this notebook was last run with (TODO confirm which one).
!pip install polars

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [3]:
from collections import defaultdict, Counter
from typing import List, Dict

from tqdm import tqdm
import polars as pl

In [4]:
# Target locales: a separate set of training sessions is assembled for each of these.
LOCALES = ["FR", "ES", "IT"]

In [5]:
# Raw training sessions; the columns used below are prev_items, next_item and locale.
train = pl.read_csv("/gdrive/MyDrive/amazon_kdd_2023/data/raw/sessions_train.csv")

In [6]:
# Preprocessed product table; only `id` and `available_locales` are used in this notebook.
product = pl.read_parquet("/gdrive/MyDrive/amazon_kdd_2023/data/preprocessed/common/product_03.parquet")

In [7]:
# Number of raw sessions per locale.
train["locale"].value_counts()

locale,counts
str,u32
"""FR""",117561
"""ES""",89047
"""UK""",1182181
"""JP""",979119
"""IT""",126925
"""DE""",1111416


In [8]:
# Parse prev_items
def str2list(s):
    """Turn the raw ``prev_items`` string into a list of item ids.

    Square brackets and single quotes are dropped, then what remains is split
    on any run of whitespace (spaces, newlines, carriage returns, ...).
    """
    # One translate pass deletes "[", "]" and "'"; split() with no argument
    # already treats "\n" and "\r" as separators.
    stripped = s.translate(str.maketrans("", "", "[]'"))
    return stripped.split()

# Replace the raw string column with the parsed list of item ids.
train = train.with_columns(pl.col("prev_items").apply(str2list).alias("prev_items"))

In [9]:
# Give every session a synthetic id of the form "train_<row number>".
session_id_column = pl.Series(
    name="session_id",
    values=[f"train_{row_number}" for row_number in range(len(train))],
)
train = train.with_columns(session_id_column)
# Keep a handle on the un-exploded frame (one row per session); later cells
# filter it by session_id.
train_original = train

In [10]:
# Append next_item to the end of prev_items so each row holds the full session sequence.
prev_items_list = train["prev_items"].to_list()
next_item_list = train["next_item"].to_list()
# Build new lists instead of appending in place, so prev_items_list keeps the
# original histories and does not alias the updated sequences.
prev_items_list_updated = [
    prev_items + [next_item]
    for prev_items, next_item in zip(prev_items_list, next_item_list)
]
train = train.with_columns(
    pl.Series(name="prev_items", values=prev_items_list_updated)
)

In [11]:
# explode
def explode(df: pl.DataFrame) -> pl.DataFrame:
    """Expand each session into one row per item.

    Args:
        df: Frame with a list column ``prev_items`` and a ``session_id`` column.

    Returns:
        The frame with ``prev_items`` exploded to one item per row, plus a
        ``sequence_num`` column holding the 0-based position of each row
        within its session.
    """
    df = df.explode(["prev_items"])
    # Hand the window expression straight to with_columns; there is no need to
    # materialise it through a separate select() and pass a whole DataFrame.
    return df.with_columns(
        pl.col("session_id").cumcount().over("session_id").alias("sequence_num")
    )

# One row per (session, item); next_item was appended above, so it is the last row of each session.
train = explode(train)

In [12]:
# Attach each item's available locales from the product table
# (one row per product id, left join so every session row is kept).
product_locales = product.unique(subset="id").select(["id", "available_locales"])
train = train.join(product_locales, left_on="prev_items", right_on="id", how="left")

In [13]:
# Preview the exploded frame with the joined availability column.
train.head()

prev_items,next_item,locale,session_id,sequence_num,available_locales
str,str,str,str,u32,list[str]
"""B09W9FND7K""","""B09M7GY217""","""DE""","""train_0""",0,"[""DE"", ""UK"", ""FR""]"
"""B09JSPLN1M""","""B09M7GY217""","""DE""","""train_0""",1,"[""DE"", ""UK"", ""FR""]"
"""B09M7GY217""","""B09M7GY217""","""DE""","""train_0""",2,"[""DE""]"
"""B076THCGSG""","""B001B4THSA""","""DE""","""train_1""",0,"[""DE""]"
"""B007MO8IME""","""B001B4THSA""","""DE""","""train_1""",1,"[""DE""]"


In [14]:
# Preview the un-exploded frame (one row per session) for comparison.
train_original.head()

prev_items,next_item,locale,session_id
list[str],str,str,str
"[""B09W9FND7K"", ""B09JSPLN1M""]","""B09M7GY217""","""DE""","""train_0"""
"[""B076THCGSG"", ""B007MO8IME"", … ""B001B4TKA0""]","""B001B4THSA""","""DE""","""train_1"""
"[""B0B1LGXWDS"", ""B00AZYORS2"", … ""B00AZYORS2""]","""B0767DTG2Q""","""DE""","""train_2"""
"[""B09XMTWDVT"", ""B0B4MZZ8MB"", … ""B0B71CHT1L""]","""B0B4R9NN4B""","""DE""","""train_3"""
"[""B09Y5CSL3T"", ""B09Y5DPTXN"", ""B09FKD61R8""]","""B0BGVBKWGZ""","""DE""","""train_4"""


In [15]:
# Per session (in first-appearance order), collect the available-locales list of every item.
# A single aggregation yields both columns from the same frame, so the two lists are
# aligned by construction and the unrelated columns are not gathered into lists.
session_locales = train.groupby("session_id", maintain_order=True).agg(
    pl.col("available_locales")
)
session_id_lst = session_locales["session_id"].to_list()
available_locales_lst = session_locales["available_locales"].to_list()

In [16]:
dfs = []
for locale in tqdm(LOCALES):
    # Keep the sessions in which every item (next_item included) is sold in this locale.
    # The generator lets all() stop at the first unavailable item; an item with a
    # null availability list (no match in the product join) counts as unavailable
    # instead of raising a TypeError.
    session_ids = [
        session_id
        for session_id, available_locales in zip(session_id_lst, available_locales_lst)
        if all(
            item_locales is not None and locale in item_locales
            for item_locales in available_locales
        )
    ]
    # Pull those sessions from the un-exploded frame, dropping sessions that
    # originate from one of the other minor locales.
    locale_df = train_original.filter(
        pl.col("session_id").is_in(session_ids) &
        pl.col("locale").is_in([locale, "UK", "JP", "DE"])
    )
    # Relabel every selected session with the target locale.
    dfs.append(locale_df.with_columns(pl.lit(locale).alias("locale")))
df = pl.concat(dfs)

100%|██████████| 3/3 [00:08<00:00,  2.97s/it]


In [17]:
# Keep one row per session_id. A session that qualified for several locales is in
# df once per locale at this point, so this assigns it to a single locale.
# NOTE(review): which of the duplicate rows survives depends on the polars defaults
# for `keep` / `maintain_order` -- confirm the choice is intended and reproducible.
df = df.unique(subset="session_id")

In [18]:
# Number of sessions assigned to each target locale after de-duplication.
df["locale"].value_counts()

locale,counts
str,u32
"""IT""",136324
"""ES""",99896
"""FR""",148715


In [19]:
# Preview the assembled per-locale training sessions.
df.head()

prev_items,next_item,locale,session_id
list[str],str,str,str
"[""B00L6H2L5S"", ""B088XS76JB"", ""B07Q6VWDS6""]","""B00D3I07Z4""","""IT""","""train_3565199"""
"[""B071WKH2P7"", ""B07NJLM811""]","""B0733777R3""","""FR""","""train_3454996"""
"[""B08LK2HGWW"", ""B085M6RNRD"", … ""B00HZV9WTM""]","""B07NWT6YLD""","""IT""","""train_3489019"""
"[""B09MZBQLTW"", ""B00RTCUIX6""]","""B00J5ETUOY""","""IT""","""train_3574643"""
"[""B07TTMZR27"", ""B09WMZY3SB""]","""B07TSLQDSZ""","""FR""","""train_433450"""


In [20]:
# Persist the assembled sessions to Drive as parquet.
df.write_parquet("/gdrive/MyDrive/amazon_kdd_2023/data/preprocessed/task2/train_task2_03.parquet")

In [24]:
# Size of the assembled set relative to the original sessions of the target locales.
# Every row of df already carries one of LOCALES (the locale column was overwritten
# when the per-locale frames were built), so the filter on df does not drop anything.
len(df.filter(pl.col("locale").is_in(LOCALES))) / len(train_original.filter(pl.col("locale").is_in(LOCALES)))

1.1541136859021446