In [None]:
import warnings
import pandas as pd
import plotly.express as px

warnings.filterwarnings("ignore")

In [None]:
# Input data files (paths are relative to the notebook's working directory).
AW_DATA_FILEPATH = "../resources/item data 2026_AW.csv"
SS_DATA_FILEPATH = "../resources/item data 2026_SS.xlsx"

# Number of rows train_test_split places in the test set of each input frame.
TEST_ROWS = 500

# Values of the "main" column that are kept when the AW file is loaded.
AW_MAIN_CLASSES = [
    "Furniture", "Lighting", "Home Textiles",
    "Tableware", "Decoration", "Flowers & Plants",
]

In [None]:
def train_test_split(
    df: pd.DataFrame,
    test_rows: "int | None" = None,
    random_state: "int | None" = 42,
):
    """Split ``df`` into a test set stratified on ``main`` and a train set.

    Missing values in the ``main``, ``sub``, ``detail`` and ``level4`` columns
    are replaced with ``"Unspecified"`` on a private copy, so the caller's
    frame is never modified. Each ``main`` class contributes
    ``int(test_rows * class_share)`` rows to the test set; because of the
    truncation the total can fall short, in which case the gap is filled with
    rows drawn from whatever has not been selected yet. Every row whose
    ``item_id`` is not in the test set goes to the train set.

    Parameters
    ----------
    df
        Frame with at least the columns ``item_id``, ``main``, ``sub``,
        ``detail`` and ``level4``.
    test_rows
        Size of the test set. ``None`` (default) uses the notebook-level
        ``TEST_ROWS`` constant.
    random_state
        Seed handed to every ``DataFrame.sample`` call so that re-running the
        notebook yields the same split. Pass ``None`` for a fresh random
        split on every call.

    Returns
    -------
    tuple of (pd.DataFrame, pd.DataFrame)
        ``(test_df, train_df)`` -- test first, despite the function's name.
        Both have a fresh 0..n-1 index and a ``dataset`` column holding
        ``"test"`` / ``"train"``.

    Raises
    ------
    ValueError
        If ``test_rows`` is negative or larger than ``len(df)``.
    """
    if test_rows is None:
        test_rows = TEST_ROWS
    if test_rows < 0 or test_rows > len(df):
        raise ValueError(
            f"test_rows must be between 0 and len(df)={len(df)}, got {test_rows}"
        )

    # Work on a copy: the fillna below must not leak into the caller's frame.
    df = df.copy()
    for col in ["main", "sub", "detail", "level4"]:
        df[col] = df[col].fillna("Unspecified")

    ratios = df["main"].value_counts(normalize=True)

    # Shuffle so the train set does not keep the file's original row order.
    df = df.sample(frac=1, random_state=random_state)

    # Collect the per-class samples and concatenate once, instead of growing
    # a DataFrame inside the loop.
    parts = []
    for main_class, ratio in ratios.items():
        class_rows = df[df["main"] == main_class]
        n_rows = min(int(test_rows * ratio), len(class_rows))
        if n_rows:
            parts.append(class_rows.sample(n_rows, random_state=random_state))
    test_df = pd.concat(parts) if parts else df.iloc[0:0]

    # int() truncation can leave the test set a few rows short; top it up
    # from rows whose item_id has not been picked yet.
    shortfall = test_rows - len(test_df)
    if shortfall > 0:
        leftover = df[~df["item_id"].isin(test_df["item_id"])]
        test_df = pd.concat(
            [test_df, leftover.sample(shortfall, random_state=random_state)]
        )

    train_df = df[~df["item_id"].isin(test_df["item_id"])]

    # assign() returns new frames, so no column is written onto a slice.
    test_df = test_df.assign(dataset="test").reset_index(drop=True)
    train_df = train_df.assign(dataset="train").reset_index(drop=True)
    return test_df, train_df

In [None]:
# SS items are read from the Excel workbook without any filtering.
ss_df = pd.read_excel(SS_DATA_FILEPATH)

# AW items come from the CSV; only rows whose "main" value is listed in
# AW_MAIN_CLASSES are kept.
aw_df = pd.read_csv(AW_DATA_FILEPATH, sep=",").loc[
    lambda frame: frame["main"].isin(AW_MAIN_CLASSES)
]

In [None]:
# train_test_split hands back (test, train) -- in that order.
aw_df_test, aw_df_train = train_test_split(df=aw_df)
ss_df_test, ss_df_train = train_test_split(df=ss_df)

In [None]:
# Stack the AW and SS splits; ignore_index gives each result a fresh 0..n-1 index.
train_df = pd.concat([aw_df_train, ss_df_train], ignore_index=True)
test_df = pd.concat([aw_df_test, ss_df_test], ignore_index=True)

In [None]:
def _main_class_proportions(df: pd.DataFrame, dataset: str) -> pd.DataFrame:
    """Return the share of rows per (season, main) pair in ``df``.

    The result has the columns ``season``, ``main``, ``proportion`` and
    ``dataset`` (set to the ``dataset`` argument), sorted by ascending
    ``proportion``.
    """
    return (
        df[["season", "main"]]
        .value_counts(normalize=True)
        # Name the values explicitly: the automatic "proportion" column name
        # only exists on pandas >= 2.0.
        .rename("proportion")
        .reset_index()
        .sort_values("proportion")
        .assign(dataset=dataset)
    )


train_proportions = _main_class_proportions(train_df, "train")
test_proportions = _main_class_proportions(test_df, "test")

# Long-format frame: one row per (dataset, season, main) combination.
data = pd.concat([train_proportions, test_proportions])

# Horizontal bars, one facet per dataset, coloured by season, so the
# train and test class distributions can be compared side by side.
fig = px.bar(
    data,
    orientation="h",
    x="proportion",
    y="main",
    color="season",
    facet_col="dataset",
    width=1000,
    height=700,
    title="Main class values proportions",
    subtitle=f"train rows={len(train_df)}, test rows={len(test_df)}"
)
fig.show(renderer="notebook")

In [None]:
# train_df.to_csv("../resources/train_df.csv", index=False)
# test_df.to_csv("../resources/test_df.csv", index=False)