In [None]:
import polars as pl
from sklearn.model_selection import GroupKFold, StratifiedKFold, KFold
import numpy as np

In [None]:
# df = pl.read_csv("/home/user/work/output/exp052/run0/train.csv")
# df.group_by("fold").agg(pl.mean("QuestionId")).sort("fold")

In [None]:
# df.filter(pl.col("fold") == 0).select("QuestionId_Answer").sort("QuestionId_Answer")

In [None]:
def preprocess_table(df: pl.DataFrame, common_cols: list[str]) -> pl.DataFrame:
    """Reshape the wide question table into (question, correct answer, other answer) rows.

    The four ``Answer{A-D}Text`` columns are unpivoted into one row per answer
    option, the row matching ``CorrectAnswer`` is joined back onto every option
    of the same question, and the correct option itself is then filtered out.

    Args:
        df: Wide table holding ``common_cols`` plus ``AnswerAText`` .. ``AnswerDText``.
        common_cols: Columns carried through unchanged. The body references
            ``QuestionId``, ``CorrectAnswer``, ``ConstructName``, ``SubjectName``
            and ``QuestionText``, so these must be included.

    Returns:
        A frame with the common columns (minus ``CorrectAnswer``), the
        ``QuestionId_Answer`` key, the correct / incorrect answer letter and
        text columns, and an ``AllText`` column concatenating them under
        markdown-style headings.
    """
    answer_text_cols = [f"Answer{letter}Text" for letter in "ABCD"]
    # Letter (A-D) recovered from the unpivoted column name, e.g. "AnswerBText" -> "B".
    answer_letter = pl.col("AnswerType").str.extract(r"Answer([A-D])Text$")

    answers = (
        df.select(common_cols + answer_text_cols)
        .unpivot(index=common_cols, variable_name="AnswerType", value_name="AnswerText")
        .with_columns(
            answer_letter.alias("AnswerAlphabet"),
            pl.concat_str([pl.col("QuestionId"), answer_letter], separator="_").alias("QuestionId_Answer"),
        )
        .sort("QuestionId_Answer")
    )

    # Build question / correct answer / incorrect answer pairs.
    correct_answers = answers.filter(pl.col("AnswerAlphabet") == pl.col("CorrectAnswer")).select(
        "QuestionId",
        pl.col("AnswerAlphabet").alias("CorrectAnswerAlphabet"),
        pl.col("AnswerText").alias("CorrectAnswerText"),
    )
    pairs = (
        answers.join(correct_answers, on="QuestionId", how="left")
        # Keep only the options that are not the correct one.
        .filter(pl.col("AnswerAlphabet") != pl.col("CorrectAnswerAlphabet"))
        .rename({"AnswerAlphabet": "InCorrectAnswerAlphabet", "AnswerText": "InCorrectAnswerText"})
        .drop("AnswerType", "CorrectAnswer")
    )

    # (heading, column) pairs, concatenated in this order to form AllText.
    sections = [
        ("\n## Construct\n", "ConstructName"),
        ("\n## Subject\n", "SubjectName"),
        ("\n## Question\n", "QuestionText"),
        ("\n## CorrectAnswer\n", "CorrectAnswerText"),
        ("\n## InCorrectAnswer\n", "InCorrectAnswerText"),
    ]
    pieces = []
    for heading, column in sections:
        pieces.extend([pl.lit(heading), pl.col(column)])

    return pairs.with_columns(pl.concat_str(pieces, separator="").alias("AllText"))


def preprocess_misconception(df: pl.DataFrame, common_cols: list[str]) -> pl.DataFrame:
    """Unpivot the per-answer misconception ids into a ``QuestionId_Answer`` lookup.

    Args:
        df: Wide table holding ``common_cols`` plus ``MisconceptionAId`` ..
            ``MisconceptionDId``.
        common_cols: Columns used as the unpivot index; must include
            ``QuestionId``, which is used to build the key.

    Returns:
        A two-column frame, sorted by key: ``QuestionId_Answer``
        (``"<QuestionId>_<letter>"``) and ``MisconceptionId`` cast to ``Int64``.
    """
    id_cols = [f"Misconception{letter}Id" for letter in "ABCD"]
    # Letter (A-D) recovered from the unpivoted column name, e.g. "MisconceptionCId" -> "C".
    answer_letter = pl.col("MisconceptionType").str.extract(r"Misconception([A-D])Id$")

    melted = df.select(common_cols + id_cols).unpivot(
        index=common_cols,
        variable_name="MisconceptionType",
        value_name="MisconceptionId",
    )
    keyed = melted.with_columns(
        answer_letter.alias("AnswerAlphabet"),
        pl.concat_str([pl.col("QuestionId"), answer_letter], separator="_").alias("QuestionId_Answer"),
    )
    return keyed.sort("QuestionId_Answer").select(
        "QuestionId_Answer",
        pl.col("MisconceptionId").cast(pl.Int64),
    )

def get_fold(_train: pl.DataFrame, cv: list[tuple[np.ndarray, np.ndarray]]) -> pl.DataFrame:
    """Attach a ``fold`` column to the training frame.

    For each ``(train_idx, valid_idx)`` pair in ``cv``, rows whose ``"index"``
    value appears in ``valid_idx`` receive that pair's position in ``cv`` as
    their fold number. Rows that appear in no validation set keep ``-1``; a row
    listed in several validation sets takes the number of the last one.

    Args:
        _train: Frame with an ``"index"`` column to match against.
        cv: Sequence of ``(train_idx, valid_idx)`` pairs; only ``valid_idx`` is read.

    Returns:
        A new frame equal to ``_train`` plus the ``fold`` column.
    """
    # Nest the conditions so that later folds wrap (and therefore override)
    # earlier ones, then evaluate the whole expression in a single pass.
    fold_expr = pl.lit(-1)
    for fold_no, (_, valid_ids) in enumerate(cv):
        fold_expr = pl.when(pl.col("index").is_in(valid_ids)).then(fold_no).otherwise(fold_expr)
    return _train.with_columns(fold_expr.alias("fold"))


def get_groupkfold(train: pl.DataFrame, group_col: str, n_splits: int, seed: int) -> pl.DataFrame:
    """Assign folds so that all rows sharing a ``group_col`` value share a fold.

    The distinct group values are shuffled and split with ``KFold``; every row
    then inherits the fold of its group via ``get_fold``.

    Args:
        train: Frame containing ``group_col`` and an ``"index"`` column of row ids.
        group_col: Column whose values define the groups.
        n_splits: Number of folds.
        seed: ``random_state`` passed to ``KFold``.

    Returns:
        ``train`` with a ``"fold"`` column added.
    """
    # ``unique()`` does not guarantee any particular output order, while KFold
    # splits purely by position. Sorting pins the group order so that a given
    # seed always yields the same fold assignment, independent of row order.
    group_ids = train[group_col].unique().sort()
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=seed)

    def _row_ids(group_positions: np.ndarray) -> np.ndarray:
        # "index" values of every row whose group is among the selected groups.
        in_selected_groups = train[group_col].is_in(group_ids[group_positions])
        return train.filter(in_selected_groups)["index"].to_numpy()

    cv = [
        (_row_ids(train_positions), _row_ids(valid_positions))
        for train_positions, valid_positions in kf.split(X=group_ids)
    ]
    return get_fold(train, cv)


def get_stratifiedkfold(train: pl.DataFrame, target_col: str, n_splits: int, seed: int) -> pl.DataFrame:
    """Assign stratified folds based on the values of ``target_col``.

    Args:
        train: Frame containing ``target_col`` and an ``"index"`` column of row ids.
        target_col: Column used as the stratification target.
        n_splits: Number of folds.
        seed: ``random_state`` passed to ``StratifiedKFold``.

    Returns:
        ``train`` with a ``"fold"`` column added by ``get_fold``.
    """
    kf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)
    # StratifiedKFold yields row *positions*, whereas get_fold matches on the
    # *values* of the "index" column. Translate positions into index values so
    # the assignment stays correct even when "index" is not 0..n-1 in row order
    # (e.g. when ``train`` is a subset of a larger, already-indexed frame).
    index_values = train["index"].to_numpy()
    targets = train[target_col].to_numpy()
    cv = [
        (index_values[train_positions], index_values[valid_positions])
        # Only the number of samples in X matters for the split.
        for train_positions, valid_positions in kf.split(X=index_values, y=targets)
    ]
    return get_fold(train, cv)

In [None]:
def add_fold(df: pl.DataFrame, split_rate: float = 0.9, n_splits: int = 5, seed: int = 0) -> pl.DataFrame:
    """Add a ``fold`` column using group k-fold over ``MisconceptionId``.

    Rows sharing a ``MisconceptionId`` always receive the same fold, so each
    fold's misconceptions do not appear in any other fold.

    Args:
        df: Frame with ``MisconceptionId`` and ``QuestionId_Answer`` columns.
            It must not already contain an ``"index"`` column, because a
            temporary row index of that name is added and dropped again.
        split_rate: Currently unused; kept so existing callers keep working.
        n_splits: Number of folds.
        seed: Random seed forwarded to the splitter.

    Returns:
        ``df`` with a ``fold`` column, sorted by ``QuestionId_Answer``.
    """
    # get_groupkfold / get_fold identify rows through an "index" column.
    indexed = df.with_row_index()
    folded = get_groupkfold(indexed, group_col="MisconceptionId", n_splits=n_splits, seed=seed)
    return folded.drop("index").sort("QuestionId_Answer")

In [None]:
# Load the raw tables.
train = pl.read_csv("/home/user/work/input/eedi-mining-misconceptions-in-mathematics/train.csv")
misconception_mapping = pl.read_csv("/home/user/work/input/eedi-mining-misconceptions-in-mathematics/misconception_mapping.csv")

# Question-level columns carried through both reshaping steps.
common_cols = ["QuestionId", "ConstructName", "SubjectName", "QuestionText", "CorrectAnswer"]

# One row per (question, incorrect answer), keeping only rows that have a
# misconception id attached.
pp_misconception = preprocess_misconception(train, common_cols)
df = (
    preprocess_table(train, common_cols)
    .join(pp_misconception, on="QuestionId_Answer", how="inner")
    .filter(pl.col("MisconceptionId").is_not_null())
)

In [None]:
# Attach the "fold" column (group k-fold over MisconceptionId, using add_fold's default n_splits=5 and seed=0).
df = add_fold(df)

In [None]:
# Mean MisconceptionId per fold, ordered by fold.
# NOTE(review): presumably a quick fingerprint for comparing fold assignments between runs — confirm.
df.group_by("fold").agg(pl.mean("MisconceptionId")).sort("fold")

In [None]:
# Mean QuestionId per fold, ordered by fold.
# NOTE(review): presumably a quick fingerprint for comparing fold assignments between runs — confirm.
df.group_by("fold").agg(pl.mean("QuestionId")).sort("fold")