#### Importing required libraries

In [1]:
from pathlib import Path

import numpy as np
import polars as pl
from sklearn.model_selection import train_test_split

#### Function definition

In [2]:
def prepare_dataset_with_content_based_cold_start(
    ratings: pl.DataFrame,
    items: pl.DataFrame,
    cold_start_joke_count: int = 10,
    test_size: float = 0.10,
    random_state: int = 42
):
    """
    Prepares a train-test split with a content-based cold-start scenario for jokes.

    Steps:
    1. Randomly select `cold_start_joke_count` joke IDs from `items` as cold-start jokes.
    2. Set aside every rating of those jokes; none of them reach the training set.
    3. Among the remaining ratings, drop the ratings of users who have only one
       remaining rating (a split stratified by user needs at least two per user).
    4. Split the kept ratings into a train set and a general test set, stratified by userId.
    5. Append the cold-start ratings to the general test set.

    Args:
        ratings (pl.DataFrame): User-joke interaction data (userId, jokeId, rating).
        items (pl.DataFrame): Joke metadata; must contain a "jokeId" column.
        cold_start_joke_count (int): Number of jokes to treat as cold-start.
        test_size (float): Fraction of the non-cold-start ratings (after step 3)
            placed in the general test set.
        random_state (int): Seed for both the cold-start selection and the
            train/test split.

    Returns:
        train_data (pl.DataFrame): Training set.
        test_data (pl.DataFrame): Test set: the general test ratings followed by
            all cold-start ratings.
        items (pl.DataFrame): The `items` argument, returned as-is (it is not
            shuffled or filtered); can be used to look up cold-start jokes.

    Note:
        Ratings dropped in step 3 end up in neither train_data nor test_data.
    """
    # Step 1: Select cold-start jokes.
    # A private RandomState yields exactly the same shuffle as
    # np.random.seed(random_state) + np.random.shuffle(...), but does not
    # reset NumPy's global random state as a side effect.
    joke_ids = items["jokeId"].to_list()
    rng = np.random.RandomState(random_state)
    rng.shuffle(joke_ids)
    cold_start_jokes = joke_ids[:cold_start_joke_count]

    # Step 2: Partition ratings into cold-start ratings and everything else.
    is_cold_start = pl.col("jokeId").is_in(cold_start_jokes)
    cold_start_ratings = ratings.filter(is_cold_start)
    remaining_ratings = ratings.filter(~is_cold_start)

    # Step 3: Keep only users with more than one remaining rating, because
    # stratify=userId below raises if any user has a single sample.
    # A per-user window count avoids the `groupby(...).count()` API and its
    # "count" column name, which were renamed in newer Polars releases.
    remaining_ratings_filtered = remaining_ratings.filter(
        pl.col("userId").count().over("userId") > 1
    )

    # Step 4: Stratified random split (sklearn operates on the pandas frame).
    remaining_pd = remaining_ratings_filtered.to_pandas()
    train_pd, test_pd_general = train_test_split(
        remaining_pd,
        test_size=test_size,
        random_state=random_state,
        stratify=remaining_pd["userId"]
    )

    train_data = pl.from_pandas(train_pd)
    test_data_general = pl.from_pandas(test_pd_general)

    # Step 5: Cold-start jokes are only ever seen at test time.
    test_data = pl.concat([test_data_general, cold_start_ratings])

    return train_data, test_data, items

#### Reading the data

In [3]:
# Raw Jester ratings (any row containing a null is discarded) and joke metadata.
ratings = (
    pl.read_csv('../data/raw/jester_ratings.csv')
    .drop_nulls()
)
items = pl.read_csv('../data/raw/jester_items.csv')

#### Splitting the data

In [4]:
# 7 jokes are held out entirely as cold-start items; test_size and
# random_state keep the function defaults.
# NOTE: the third return value is the `items` frame passed in, returned as-is;
# the name `shuffled_items` is kept only because later cells reference it.
train_data, test_data, shuffled_items = prepare_dataset_with_content_based_cold_start(
    ratings,
    items,
    cold_start_joke_count=7,
)

#### Summary statistics of the resulting splits

In [5]:
# Summary statistics of the training split.
train_data.describe()

describe,userId,jokeId,rating
str,f64,f64,f64
"""count""",323433.0,323433.0,323433.0
"""null_count""",0.0,0.0,0.0
"""mean""",6684.506238,71.466202,1.414847
"""std""",4521.82508,46.444096,5.622701
"""min""",1.0,5.0,-10.0
"""max""",14780.0,150.0,10.0
"""median""",6139.0,70.0,2.094
"""25%""",2501.0,22.0,-2.812
"""75%""",10757.0,113.0,5.938


In [6]:
# Summary statistics of the test split (general test ratings + cold-start ratings).
test_data.describe()

describe,userId,jokeId,rating
str,f64,f64,f64
"""count""",61983.0,61983.0,61983.0
"""null_count""",0.0,0.0,0.0
"""mean""",6781.68622,62.076715,1.62009
"""std""",4500.700437,43.642258,5.581423
"""min""",1.0,5.0,-10.0
"""max""",14780.0,150.0,10.0
"""median""",6308.0,56.0,2.312
"""25%""",2594.0,19.0,-2.406
"""75%""",10821.0,103.0,6.094


In [7]:
# Number of rows in the items frame returned by the split function
# (the function returns its `items` argument unchanged).
len(shuffled_items)

150

#### Saving the split data

In [8]:
# Persist the splits as Parquet. Create the output directory first so the
# notebook also runs on a fresh checkout where it may not exist yet.
PROCESSED_DIR = Path("../data/processed")
PROCESSED_DIR.mkdir(parents=True, exist_ok=True)

train_data.write_parquet(PROCESSED_DIR / "train_data.parquet")
test_data.write_parquet(PROCESSED_DIR / "test_data.parquet")
# NOTE: despite the file name, this frame is the unshuffled items metadata;
# the name is kept for downstream consumers of this file.
shuffled_items.write_parquet(PROCESSED_DIR / "shuffled_jokes.parquet")