# INSTALL REQUIREMENTS

In [1]:
%pip install -r /kaggle/input/requirements1.txt

Collecting appnope==0.1.4 (from -r /kaggle/input/requirements1.txt (line 1))
  Downloading appnope-0.1.4-py2.py3-none-any.whl.metadata (908 bytes)
Collecting black==24.10.0 (from -r /kaggle/input/requirements1.txt (line 3))
  Downloading black-24.10.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl.metadata (79 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m79.2/79.2 kB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
Collecting cramjam==2.9.0 (from -r /kaggle/input/requirements1.txt (line 8))
  Downloading cramjam-2.9.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.9 kB)
Collecting debugpy==1.6.7 (from -r /kaggle/input/requirements1.txt (line 10))
  Downloading debugpy-1.6.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.1 kB)
Collecting decorator==5.1.1 (from -r /kaggle/input/requirements1.txt (line 11))
  Downloading decorator-5.1.1-py3-none-any.whl.metadata (4.0 kB)
Collecting executing

# DATA PROCESSING UTILS

In [2]:
import numpy as np


class StatelessRandomGenerator:
    """Seeded facade over ``numpy.random.Generator`` that keeps no RNG state.

    Every sampling method builds a brand-new generator from ``self.seed``, so
    calling the same method twice with the same arguments yields the same
    values until the seed is changed via ``set_seed``.
    """

    def __init__(self, seed=42):
        self.seed = seed

    def set_seed(self, new_seed):
        """Replace the seed used by all subsequent sampling calls."""
        self.seed = new_seed

    def _fresh_generator(self):
        """Return a new numpy Generator initialised from the current seed."""
        return np.random.default_rng(self.seed)

    def random(self, size=None):
        """Draw floats via ``Generator.random`` from a freshly seeded generator."""
        return self._fresh_generator().random(size)

    def integers(self, low, high=None, size=None):
        """Draw integers via ``Generator.integers`` from a freshly seeded generator."""
        return self._fresh_generator().integers(low, high, size)

    def choice(self, a, size=None, replace=True, p=None):
        """Sample via ``Generator.choice`` from a freshly seeded generator."""
        return self._fresh_generator().choice(a, size, replace, p)


# Shared generator instance used by the data-processing helpers (seed 42 by default).
global_rng = StatelessRandomGenerator(seed=42)


def set_global_seed(new_seed):
    """Re-seed the shared ``global_rng`` instance with ``new_seed``."""
    global_rng.set_seed(new_seed=new_seed)

In [3]:
import torch


def wmape_metric(pred: torch.Tensor, true: torch.Tensor) -> torch.Tensor:
    """Weighted MAPE: summed absolute error divided by the summed targets.

    Both sums run over dimension 0, so 1-D inputs give a scalar tensor and
    2-D inputs give one value per column.

    Note:
        The denominator is the plain (signed) sum of ``true``; no absolute
        value is taken and a zero sum is not guarded against.
    """
    summed_abs_error = (pred - true).abs().sum(dim=0)
    summed_target = true.sum(dim=0)
    return summed_abs_error / summed_target

# DATA PROCESSING

In [4]:
from datetime import datetime
import json
from pathlib import Path
import polars as pl
#from data_processing.utils.stateless_rng import global_rng

def filter_purchases_purchases_per_month_pl(
    df_pl: pl.DataFrame, train_end: datetime.date, group_by_channel_id: bool = False
):
    """Groups purchases per customer and date and identifies extreme customers.

    This function:
    1. Groups transactions by customer, date, and optionally sales channel
    2. Identifies extreme customers based on the 99th percentile of total items
       purchased before ``train_end``
    3. Returns both results. The extreme customers are NOT removed from the
       grouped data here; the caller is expected to filter them out.

    Args:
        df_pl (pl.DataFrame): Input transaction dataframe containing:
            - customer_id: Customer identifier
            - date: Transaction date
            - article_id: Product identifier
            - price: Transaction price
            - sales_channel_id: Sales channel identifier
        train_end (datetime.date): End date for training period. Only rows with
            ``date < train_end`` are used to determine extreme customers.
        group_by_channel_id (bool, optional): Whether to group transactions by
            sales channel. Defaults to False.

    Returns:
        tuple[pl.DataFrame, pl.DataFrame]: Tuple containing:
            - grouped_df: Grouped transaction data with columns customer_id, date,
              article_ids, total_price, prices, num_items, plus either the
              grouping key sales_channel_id (``group_by_channel_id=True``) or the
              aggregated list sales_channel_ids (``group_by_channel_id=False``)
            - extreme_customers: Single-column frame (customer_id) of customers
              whose total number of purchased items in the training period is at
              or above the 99th percentile

    Notes:
        Only the total number of items is used as outlier criterion. A summary
        line with the cutoff and the number of flagged customers is printed.
    """
    # The two grouping modes differ only in the treatment of the sales channel:
    # it is either an extra grouping key or an aggregated list column.
    group_keys = ["customer_id", "date"]
    aggregations = [
        pl.col("article_id").explode().alias("article_ids"),
        pl.col("price").sum().round(2).alias("total_price"),
    ]
    if group_by_channel_id:
        # Used for multi variate time series
        group_keys.append("sales_channel_id")
    else:
        aggregations.append(
            pl.col("sales_channel_id").explode().alias("sales_channel_ids")
        )
    aggregations.append(pl.col("price").explode().alias("prices"))

    grouped_df = (
        df_pl.lazy()
        .group_by(group_keys)
        .agg(aggregations)
        .with_columns(pl.col("article_ids").list.len().alias("num_items"))
    )

    # Only consider purchases in the train period when looking for extreme customers
    customers_summary = (
        df_pl.lazy()
        .filter(pl.col("date") < train_end)
        .group_by("customer_id")
        .agg(pl.col("article_id").flatten().alias("flattened_ids"))
        .with_columns(pl.col("flattened_ids").list.len().alias("total_items"))
    )

    quantile = 0.99
    total_items_99 = (
        customers_summary.select(pl.col("total_items").quantile(quantile))
        .collect()
        .to_numpy()
        .flatten()[0]
    )

    # Customers at or above the cutoff are flagged (note the inclusive comparison)
    extreme_customers = (
        customers_summary.filter(pl.col("total_items") >= total_items_99)
        .select("customer_id")
        .unique()
        .collect()
    )

    print(
        f"""
        Cutoff Values for {quantile*100}th Percentiles:
        -----------------------------------
        Total items bought:    {total_items_99:.0f} items

        -----------------------------------
        Removed Customers:     {len(extreme_customers):,}
        """
    )

    return grouped_df.collect(), extreme_customers

def train_test_split(
    train_df: pl.DataFrame,
    test_df: pl.DataFrame,
    subset: int = None,
    train_subsample_percentage: float = None,
) -> tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]:
    """Splits data into train, validation, and test sets with optional subsampling.

    The function performs the following operations:
    1. Optionally reduces the data: either both frames are cut to their first
       ``subset`` rows, or (only when ``subset`` is None) the training frame is
       randomly subsampled
    2. Moves a random 10% of the remaining training rows into a validation set

    Args:
        train_df (pl.DataFrame): Training dataset.
        test_df (pl.DataFrame): Test dataset.
        subset (int, optional): If provided, limits both train and test sets to
            the first n rows. Defaults to None.
        train_subsample_percentage (float, optional): Fraction of training rows
            to keep (it is multiplied directly with the row count, so 0.5 keeps
            half of the rows). Ignored when ``subset`` is given. Defaults to None.

    Returns:
        tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]: Tuple containing:
            - train_df: Training rows that were not drawn for validation
            - val_df: Validation dataset (10% of training data, rounded down)
            - test_df: Test dataset (truncated only when ``subset`` is given)

    Notes:
        ``subset`` takes precedence: when it is provided,
        ``train_subsample_percentage`` is not applied.
        All random draws use the module-level ``global_rng``.
    """

    if subset is not None:
        train_df = train_df[:subset]
        test_df = test_df[:subset]
    elif train_subsample_percentage is not None:
        kept_indices = global_rng.choice(
            len(train_df),
            size=int(train_subsample_percentage * len(train_df)),
            replace=False,
        )
        train_df = train_df[kept_indices]

    # Train-val-split: draw 10% of the row positions without replacement
    val_indices = global_rng.choice(
        len(train_df), size=int(0.1 * len(train_df)), replace=False
    )
    val_df = train_df[val_indices]
    # Keep every row whose position was not drawn for validation.
    # pl.len() replaces the deprecated pl.count().
    train_df = train_df.filter(~pl.int_range(0, pl.len()).is_in(val_indices))

    return train_df, val_df, test_df

def map_article_ids(df: pl.DataFrame, data_path: Path) -> pl.DataFrame:
    """Replaces original article IDs with the running IDs stored in a JSON file.

    Args:
        df (pl.DataFrame): Frame with an 'article_id' column to translate.
        data_path (Path): Directory holding 'running_id_dict.json'; the mapping
            is read from its "combined" entry.

    Returns:
        pl.DataFrame: Frame whose 'article_id' holds the new IDs, sorted by that
            column. Rows without a mapping are dropped by the inner join.
    """
    with open(data_path / "running_id_dict.json", "r") as handle:
        old_to_new = json.load(handle)["combined"]

    lookup = pl.DataFrame(
        {
            "old_id": list(old_to_new.keys()),
            "new_id": list(old_to_new.values()),
        },
        schema_overrides={"old_id": pl.Int32, "new_id": pl.Int32},
    )

    # Inner join discards articles that have no running id
    joined = df.join(lookup, left_on="article_id", right_on="old_id", how="inner")

    # Put the new id first under the original column name, keep all other columns
    remapped = joined.select(
        pl.col("new_id").alias("article_id"),
        pl.all().exclude(["article_id", "old_id", "new_id"]),
    )

    return remapped.sort("article_id")

In [5]:
#from pathlib import Path
#from data_processing.customer_df.customer_df import get_customer_df_benchmarks
#from data_processing.transaction_df.transaction_df import get_tx_article_dfs
import polars as pl


def expand_list_columns(
    df: pl.DataFrame, date_col: str = "days_before_lst", num_col: str = "num_items_lst"
) -> pl.DataFrame:
    """
    Repeat every element of one list column as often as the matching entry of a
    second list column says.

    Args:
        df: Input Polars DataFrame with list columns
        date_col: Name of the list column whose elements are repeated
        num_col: Name of the list column holding the repetition counts

    Returns:
        A new Polars DataFrame in which date_col is replaced by its expanded
        version; all other columns are unchanged
    """

    def _repeat_by_count(row):
        # row is the struct of both list columns for a single DataFrame row
        repeated = []
        for value, count in zip(row[date_col], row[num_col]):
            repeated.extend([value] * count)
        return repeated

    paired = pl.struct([date_col, num_col])
    return df.with_columns(paired.map_elements(_repeat_by_count).alias(date_col))


def add_benchmark_tx_features(df: pl.DataFrame) -> pl.DataFrame:
    """Creates benchmark transaction features from aggregated customer transaction data.

    Args:
        df: A Polars DataFrame containing aggregated transaction data with the
            columns customer_id, total_price_lst, num_items_lst, days_before_lst
            and CLV_label. The list columns are assumed to hold one entry per
            transaction in chronological order (oldest first).

    Returns:
        pl.DataFrame: A DataFrame with derived features including:
            - total_spent: Sum of all transaction amounts
            - total_purchases: Count of transactions
            - total_items: Sum of items purchased
            - days_since_last_purchase: Last entry of days_before_lst
            - days_since_first_purchase: First entry of days_before_lst
            - avg_spent_per_transaction: Mean transaction amount
            - avg_items_per_transaction: Mean items per transaction
            - avg_days_between: Mean days between transactions
            - regression_label: CLV label for regression
            - classification_label: Binary CLV label (>0)

    Note:
        The avg_days_between calculation may return None for customers with single
        transactions, which is handled by tree-based algorithms.
    """
    return df.select(
        "customer_id",
        pl.col("total_price_lst").list.sum().alias("total_spent"),
        pl.col("total_price_lst").list.len().alias("total_purchases"),
        pl.col("num_items_lst").list.sum().alias("total_items"),
        pl.col("days_before_lst").list.get(-1).alias("days_since_last_purchase"),
        pl.col("days_before_lst").list.get(0).alias("days_since_first_purchase"),
        # Average over per-transaction totals. The per-item price list
        # (price_lst) would yield the mean item price instead.
        pl.col("total_price_lst").list.mean().alias("avg_spent_per_transaction"),
        (
            pl.col("num_items_lst")
            .list.mean()
            .cast(pl.Float32)
            .alias("avg_items_per_transaction")
        ),
        # Code below returns None values for customers with single Tx
        # Tree algos should be able to handle this
        (
            pl.col("days_before_lst")
            .list.diff(null_behavior="drop")
            .list.mean()
            # days_before decreases over time, so the raw differences are negative
            .mul(-1)
            .cast(pl.Float32)
            .alias("avg_days_between")
        ),
        pl.col("CLV_label").alias("regression_label"),
        pl.col("CLV_label").gt(0).cast(pl.Int32).alias("classification_label"),
    )


def process_dataframe(df: pl.DataFrame, max_length: int = 20) -> pl.DataFrame:
    """Builds the sequence-style benchmark frame from aggregated customer data.

    Steps:
    1. days_before_lst is expanded with expand_list_columns, repeating each
       entry by the matching count in num_items_lst
    2. Both list columns are cut to their last ``max_length`` elements
    3. The label columns are derived from CLV_label

    Args:
        df: A polars DataFrame containing customer transaction data
        max_length: Maximum number of trailing elements kept per list (default: 20)

    Returns:
        A polars DataFrame with the columns:
            - customer_id: Customer identifier
            - days_before_lst: Expanded and truncated list of day offsets
            - articles_ids_lst: Truncated list of article identifiers
            - regression_label: CLV_label unchanged
            - classification_label: 1 if CLV_label > 0 else 0 (Int32)
    """
    expanded = expand_list_columns(
        df, date_col="days_before_lst", num_col="num_items_lst"
    )
    clv = pl.col("CLV_label")
    return expanded.select(
        "customer_id",
        pl.col("days_before_lst").list.tail(max_length),
        pl.col("articles_ids_lst").list.tail(max_length),
        clv.alias("regression_label"),
        clv.gt(0).cast(pl.Int32).alias("classification_label"),
    )


def get_benchmark_dfs(
    data_path: Path, config: dict
) -> tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]:
    """Builds the benchmark train, validation, and test frames.

    The transaction splits come from get_tx_article_dfs. Each split is passed
    through process_dataframe and then left-joined on customer_id with the
    frame returned by get_customer_df_benchmarks.

    Args:
        data_path: Path object pointing to the data directory
        config: Configuration dictionary; ``config["max_length"]`` is read here,
            the rest is forwarded to the loader functions

    Returns:
        tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]: (train_df, val_df, test_df),
        each holding the processed transaction columns plus the customer columns.
    """
    requested_columns = [
        "date",
        "days_before",
        "article_ids",
        "sales_channel_ids",
        "total_price",
        "prices",
        "num_items",
    ]
    article_splits = get_tx_article_dfs(
        data_path=data_path,
        config=config,
        cols_to_aggregate=requested_columns,
        keep_customer_id=True,
    )

    customer_df = get_customer_df_benchmarks(data_path=data_path, config=config)

    def _with_customer_features(split_df: pl.DataFrame) -> pl.DataFrame:
        # Same treatment for every split: process, then attach customer columns
        processed = process_dataframe(df=split_df, max_length=config["max_length"])
        return processed.join(customer_df, on="customer_id", how="left")

    train_df, val_df, test_df = (
        _with_customer_features(split_df) for split_df in article_splits
    )

    return train_df, val_df, test_df

In [6]:
import polars as pl
#from pathlib import Path


def get_customer_df_benchmarks(data_path: Path, config: dict):
    """Loads the customer table with missing ages imputed.

    Args:
        data_path (Path): Path to directory containing 'customers.csv'.
        config (dict): Accepted for interface compatibility; it is neither read
            nor modified by this function.

    Returns:
        pl.DataFrame: DataFrame with the columns customer_id, age (nulls filled
            with the column mean) and postal_code.

    Note:
        No age grouping or zip code mapping is performed; age and postal_code
        are returned as read from the file (apart from the null fill).
    """
    customers = pl.scan_csv(data_path / "customers.csv").select(
        [
            "customer_id",
            pl.col("age").fill_null(strategy="mean"),
            "postal_code",
        ]
    )

    return customers.collect()

In [7]:
#from datetime import datetime
#from pathlib import Path
#import polars as pl

#from data_processing.utils.utils_transaction_df import (
 #   filter_purchases_purchases_per_month_pl,
  #  map_article_ids,
   # train_test_split,
#)


def generate_clv_data_pl(
    df: pl.DataFrame,
    agg_df: pl.DataFrame,
    label_threshold: datetime.date,
    pred_end: datetime.date,
    clv_periods: list,
    log_clv: bool = False,
):
    """Generates Customer Lifetime Value (CLV) data from transaction dataframe.

    The CLV of a customer is the sum of ``total_price`` over all rows of ``df``
    whose date lies in the closed interval [label_threshold, pred_end].

    Args:
        df (pl.DataFrame): Input transaction dataframe containing customer purchases
            (columns customer_id, date, total_price are used).
        agg_df (pl.DataFrame): Aggregated dataframe containing customer data.
        label_threshold (datetime.date): Start date for CLV calculation period (inclusive).
        pred_end (datetime.date): End date for CLV calculation period (inclusive).
        clv_periods (list): List of periods for CLV calculation. Only its length is
            checked (a single period is supported); the value itself is not used.
        log_clv (bool, optional): Whether to apply log1p transformation to CLV values. Defaults to False.

    Returns:
        pl.DataFrame: ``agg_df`` with an added ``CLV_label`` column. Customers
            without purchases in the label period get a CLV of 0.

    Raises:
        ValueError: If more than one CLV period is provided.
    """
    if len(clv_periods) > 1:
        raise ValueError("CLV periods should be a single number for now.")

    # Transactions that fall into the label period (both bounds inclusive)
    in_label_period = (pl.col("date") >= label_threshold) & (
        pl.col("date") <= pred_end
    )

    # Sum total_price of the label-period transactions per customer. This is the CLV
    clv_per_customer = (
        df.filter(in_label_period)
        .group_by("customer_id")
        .agg(pl.sum("total_price").round(2).alias("CLV_label"))
    )
    if log_clv:
        clv_per_customer = clv_per_customer.with_columns(
            pl.col("CLV_label").log1p().round(2).alias("CLV_label")
        )

    agg_df = agg_df.join(clv_per_customer, on="customer_id", how="left")

    # Only the label is zero-filled; nulls in other columns of agg_df are left
    # as they are instead of being overwritten by a frame-wide fill.
    return agg_df.with_columns(pl.col("CLV_label").fill_null(0))


def group_and_convert_df_pl(
    df: pl.DataFrame,
    label_start_date: datetime.date,
    pred_end: datetime.date,
    clv_periods: list,
    cols_to_aggregate: list = [
        "date",
        "days_before",
        "num_items",
        "article_ids",
        "sales_channel_ids",
        "total_price",
        "prices",
    ],
    keep_customer_id: bool = True,
    log_clv: bool = False,
) -> pl.DataFrame:
    """Collapses per-transaction rows into one row of list features per customer.

    Only transactions dated before ``label_start_date`` feed the list columns.
    When ``clv_periods`` is not None, a CLV_label column is added through
    generate_clv_data_pl.

    Args:
        df (pl.DataFrame): Input transaction dataframe.
        label_start_date (datetime.date): Start date for clv label period.
        pred_end (datetime.date): End date for prediction period.
        clv_periods (list): List of periods for CLV calculation, or None to skip the label.
        cols_to_aggregate (list, optional): Source columns whose list version is kept in the output.
        keep_customer_id (bool, optional): Whether to retain customer_id in output. Defaults to True.
        log_clv (bool, optional): Whether to apply log1p transformation to CLV values. Defaults to False.

    Returns:
        pl.DataFrame: Aggregated customer-level dataframe.

    Raises:
        ValueError: If required columns (days_before, article_ids, num_items) are missing from cols_to_aggregate.
    """
    required_cols = ("days_before", "article_ids", "num_items")
    if not all(col in cols_to_aggregate for col in required_cols):
        raise ValueError(
            "The columns days_before, article_ids, and num_items are required "
            "for the aggregation"
        )

    # Source column -> name of the aggregated list column (also the output order)
    output_names = {
        "date": "date_lst",
        "days_before": "days_before_lst",
        "article_ids": "articles_ids_lst",
        "sales_channel_ids": "sales_channel_id_lst",
        "total_price": "total_price_lst",
        "prices": "price_lst",
        "num_items": "num_items_lst",
    }

    # History = everything strictly before the label period
    history = df.filter(pl.col("date") < label_start_date).with_columns(
        (label_start_date - pl.col("date"))
        .dt.total_days()
        .cast(pl.Int32)
        .alias("days_before"),
        pl.col("sales_channel_ids").cast(pl.List(pl.Int32)),
        pl.col("article_ids").cast(pl.List(pl.Int32)),
    )

    # Sales channels are combined with concat_list, all other columns are exploded
    aggregations = [
        (
            pl.concat_list(pl.col(source)).alias(target)
            if source == "sales_channel_ids"
            else pl.col(source).explode().alias(target)
        )
        for source, target in output_names.items()
    ]

    # Sorting first makes the list entries of each customer follow the date order
    agg_df = (
        history.sort("customer_id", "date").group_by("customer_id").agg(aggregations)
    )

    if clv_periods is not None:
        agg_df = generate_clv_data_pl(
            df=df,
            agg_df=agg_df,
            label_threshold=label_start_date,
            pred_end=pred_end,
            clv_periods=clv_periods,
            log_clv=log_clv,
        )

    # Drop list columns whose source column was not requested
    unwanted = [
        target
        for source, target in output_names.items()
        if source not in cols_to_aggregate
    ]
    if not keep_customer_id:
        unwanted.append("customer_id")

    return agg_df.drop(*unwanted)


def split_df_and_group_pl(
    df: pl.DataFrame,
    clv_periods: list,
    config: dict,
    cols_to_aggregate: list = [
        "date",
        "days_before",
        "article_ids",
        "sales_channel_ids",
        "total_price",
        "prices",
        "num_items",
    ],
    keep_customer_id: bool = True,
    log_clv: bool = False,
) -> tuple[pl.DataFrame, pl.DataFrame]:
    """Cuts the transactions into a train and a test window and aggregates each.

    Args:
        df (pl.DataFrame): Input transaction dataframe.
        clv_periods (list): List of periods for CLV calculation.
        config (dict): Configuration dictionary containing the "%Y-%m-%d" date
            strings train_begin, train_label_begin, train_end, test_begin,
            test_label_begin and test_end.
        cols_to_aggregate (list, optional): Columns to include in aggregation. Defaults to standard transaction columns.
        keep_customer_id (bool, optional): Whether to retain customer_id in output. Defaults to True.
        log_clv (bool, optional): Whether to apply log1p transformation to CLV values. Defaults to False.

    Returns:
        tuple[pl.DataFrame, pl.DataFrame]: Tuple containing:
            - train_df: Aggregated training dataset
            - test_df: Aggregated test dataset
    """

    def _parse(key: str) -> datetime:
        return datetime.strptime(config.get(key), "%Y-%m-%d")

    # (window begin, label start, window end) for each split; all dates are
    # parsed up front, before any filtering happens
    train_window = (
        _parse("train_begin"),
        _parse("train_label_begin"),
        _parse("train_end"),
    )
    test_window = (
        _parse("test_begin"),
        _parse("test_label_begin"),
        _parse("test_end"),
    )

    def _aggregate_window(begin, label_start, end) -> pl.DataFrame:
        # Window bounds are inclusive on both sides
        in_window = (pl.col("date") >= begin) & (pl.col("date") <= end)
        return group_and_convert_df_pl(
            df=df.filter(in_window),
            label_start_date=label_start,
            pred_end=end,
            clv_periods=clv_periods,
            cols_to_aggregate=cols_to_aggregate,
            keep_customer_id=keep_customer_id,
            log_clv=log_clv,
        )

    train_df = _aggregate_window(*train_window)
    test_df = _aggregate_window(*test_window)

    return train_df, test_df


def load_data_rem_outlier_pl(
    data_path: Path, train_end: datetime.date, group_by_channel_id: bool = False
):
    """Reads the transaction parquet file, prepares its columns and groups it.

    Preparation: ``t_dat`` is exposed as a Date column named ``date``,
    ``article_id`` is cast to Int32, ``price`` is multiplied by 590 and rounded
    to two decimals, and article ids are translated with map_article_ids.

    Args:
        data_path (Path): Path to directory containing 'transactions_polars.parquet'.
        train_end (datetime.date): End date for training period.
        group_by_channel_id (bool, optional): Whether to group data by sales channel ID. Defaults to False.

    Returns:
        tuple[pl.DataFrame, pl.DataFrame]: Tuple containing:
            - grouped_df: Processed transaction dataframe
            - extreme_customers: Dataframe of customers identified as outliers
    """
    transactions = pl.read_parquet(data_path / "transactions_polars.parquet")

    transactions = transactions.with_columns(
        pl.col("t_dat").cast(pl.Date).alias("date"),
        pl.col("article_id").cast(pl.Int32),
        # NOTE(review): 590 is an unexplained scaling factor - confirm its meaning
        pl.col("price").mul(590).cast(pl.Float32).round(2),
    )

    # Map article ids to running ids so that they match with feature matrix
    transactions = map_article_ids(df=transactions, data_path=data_path)

    grouped_df, extreme_customers = filter_purchases_purchases_per_month_pl(
        transactions, train_end=train_end, group_by_channel_id=group_by_channel_id
    )

    return grouped_df, extreme_customers


def get_customer_train_test_articles_pl(
    data_path: Path,
    config: dict,
    clv_periods: list = None,
    cols_to_aggregate: list = [
        "date",
        "days_before",
        "article_ids",
        "sales_channel_ids",
        "total_price",
        "prices",
        "num_items",
    ],
    keep_customer_id: bool = True,
):
    """Loads the transactions and returns aggregated train and test frames.

    Customers flagged as extreme by load_data_rem_outlier_pl are excluded from
    both frames via an anti join on customer_id.

    Args:
        data_path (Path): Path to directory containing transaction data.
        config (dict): Configuration dictionary; "train_end" and "log_clv" are
            read here, the dictionary is also forwarded to split_df_and_group_pl.
        clv_periods (list, optional): List of periods for CLV calculation. Defaults to None.
        cols_to_aggregate (list, optional): Columns to include in aggregation. Defaults to standard transaction columns.
        keep_customer_id (bool, optional): Whether to retain customer_id in output. Defaults to True.

    Returns:
        tuple[pl.DataFrame, pl.DataFrame]: Tuple containing:
            - train_df: Processed training dataset with article information
            - test_df: Processed test dataset with article information
    """
    train_end = datetime.strptime(config.get("train_end"), "%Y-%m-%d")
    grouped_df, extreme_customers = load_data_rem_outlier_pl(
        data_path=data_path, train_end=train_end
    )

    # customer_id is always kept here because the anti join below needs it
    aggregated_splits = split_df_and_group_pl(
        df=grouped_df,
        clv_periods=clv_periods,
        config=config,
        cols_to_aggregate=cols_to_aggregate,
        keep_customer_id=True,
        log_clv=config.get("log_clv", False),
    )

    cleaned_splits = []
    for split_df in aggregated_splits:
        kept = split_df.join(extreme_customers, on="customer_id", how="anti")
        if not keep_customer_id:
            kept = kept.drop("customer_id")
        cleaned_splits.append(kept)

    train_df, test_df = cleaned_splits
    return train_df, test_df


def get_tx_article_dfs(
    data_path: Path,
    config: dict,
    cols_to_aggregate: list = [
        "date",
        "days_before",
        "article_ids",
        "sales_channel_ids",
        "total_price",
        "prices",
        "num_items",
    ],
    keep_customer_id: bool = True,
):
    """Creates train, validation, and test datasets with optional subsampling.

    Args:
        data_path (Path): Path to directory containing transaction data files.
        config (dict): Configuration dictionary. The keys "clv_periods"
            (default [6]), "subset" and "train_subsample_percentage" are read
            here; the dictionary is also forwarded to
            get_customer_train_test_articles_pl.
        cols_to_aggregate (list, optional): Transaction columns to include in output.
        keep_customer_id (bool, optional): Whether to retain customer_id column.

    Returns:
        tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]: Tuple containing:
            - train_df: Final training dataset (training rows not used for validation)
            - val_df: Validation dataset (10% of the training data)
            - test_df: Test dataset (optionally truncated)

    Columns of dfs:
        - customer_id
        - date_lst (list[date]): Dates of each transaction
        - days_before_lst (list[int]): Number of days between start of prediction and date of transaction
        - articles_ids_lst (list[int]): Flattened list of all items a customer purchased
        - sales_channel_id_lst (list[list[int]]): Sales channel of a transaction (repeated for each item within a transaction)
        - total_price_lst (list[float]): Value of each transaction
        - price_lst (list[float]): Flattened list of prices of all items customer purchased
        - num_items_lst (list[int]): Number of items in each transaction
        - CLV_label (float): Sales in prediction period (label to be used)
    """
    full_train_df, full_test_df = get_customer_train_test_articles_pl(
        data_path=data_path,
        config=config,
        clv_periods=config.get("clv_periods", [6]),
        cols_to_aggregate=cols_to_aggregate,
        keep_customer_id=keep_customer_id,
    )
    return train_test_split(
        train_df=full_train_df,
        test_df=full_test_df,
        subset=config.get("subset"),
        train_subsample_percentage=config.get("train_subsample_percentage"),
    )

In [8]:
#from pathlib import Path
#from data_processing.get_data import get_benchmark_dfs
#import polars as pl

In [9]:
# Run configuration handed to get_benchmark_dfs together with data_path.
# Only "clv_periods", "subset" and "train_subsample_percentage" are read in the
# visible part of get_benchmark_dfs; the remaining keys are presumably consumed
# by the helpers it calls (date windows, filtering, aggregation) -- TODO confirm.
config = {
    "train_begin": "2018-09-20",
    "train_label_begin": "2019-09-20",
    "train_end": "2020-03-17",
    "test_begin": "2019-03-19",
    "test_label_begin": "2020-03-18",
    "test_end": "2020-09-13",
    "min_zip_code_count": 3,
    "date_aggregation": "daily",
    "group_by_channel_id": False,
    "log_clv": False,
    "clv_periods": [6],
    "subset": None,
    "train_subsample_percentage": None,
    "max_length":20, # author's note: how many items are considered in the transformer sequence
}
# Alternative dataset location, kept for reference:
# data_path = Path("/kaggle/input/hm-dataset/data/data")
data_path = Path("/kaggle/input/data/data/")

# Banner so the loading step is easy to find in the cell output.
print(10 * "#", " Loading data ", 10 * "#")
# Returns the (train, validation, test) frames used by the model cells below.
train_df, val_df, test_df = get_benchmark_dfs(data_path, config)

##########  Loading data  ##########

        Cutoff Values for 99.0th Percentiles:
        -----------------------------------
        Total items bought:    152 items

        -----------------------------------
        Removed Customers:     11,908
        


  train_df = train_df.filter(~pl.arange(0, pl.count()).is_in(sampled_indices))


In [10]:
test_df

customer_id,days_before_lst,articles_ids_lst,regression_label,classification_label,age,postal_code
str,list[i64],list[i32],f32,i32,i64,str
"""76a9e3e7690518…","[68, 68, … 29]","[9423, 9423, … 80293]",0.0,0,33,"""2c29ae653a9282…"
"""fd397bca991fb3…",[183],[15965],67.970001,1,27,"""6955a641e0ba93…"
"""09a10086a549fa…","[155, 110, … 110]","[37963, 33830, … 71079]",0.0,0,24,"""2c29ae653a9282…"
"""2ee66d660b16e0…","[289, 289, … 30]","[2271, 24744, … 83441]",0.0,0,44,"""0fec0086f1e39e…"
"""a6d717dde77801…","[171, 171, … 110]","[3455, 77717, … 85852]",58.470001,1,47,"""bc38656f2a2b19…"
…,…,…,…,…,…,…
"""a3b8ed7d0788ed…",[97],[68207],0.0,0,36,"""83d355719a39cc…"
"""c30cc75277dac2…","[320, 309, … 185]","[55901, 8, … 75628]",0.0,0,26,"""120c8f1333d624…"
"""31473511ded7a7…","[168, 168, … 66]","[44984, 44987, … 88701]",197.100006,1,50,"""2faae2f724d7a0…"
"""66fa41f4d5f338…","[228, 228, … 139]","[15674, 15689, … 2247]",148.949997,1,45,"""c8217776877a67…"


# BST TRAINING AND TESTING (FINAL VERSION)

In [17]:
import math
import torch 
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
import pytorch_lightning as pl
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from typing import List, Dict, Optional


# Custom Transformer classes

class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned gain.

    Args:
        dim: Size of the last dimension (length of the learnable ``gain`` vector).
        eps: Constant added to the mean square before taking the reciprocal root.
    """

    def __init__(self, dim: int, eps: float = 1e-8) -> None:
        super().__init__()
        self.eps = eps
        self.gain = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Scale ``x`` by the reciprocal RMS of its last axis, then by ``gain``."""
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return (x * inv_rms) * self.gain

class CustomMultiheadAttention(nn.Module):
    """Multi-head self-attention built on ``F.scaled_dot_product_attention``.

    Queries, keys and values are all projected from ``query`` through a single
    fused linear layer; the ``key`` and ``value`` arguments of ``forward`` are
    accepted but not read.

    Args:
        embed_dim: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Attention dropout probability, applied only in training mode.
        bias: Whether the input and output projections carry a bias.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.0, bias=True):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.dropout = dropout

    def forward(self, query, key, value, attn_mask=None, key_padding_mask=None):
        """Attend over ``query`` (B, T, embed_dim) and return a tensor of the same shape."""
        batch, steps, _ = query.size()

        def split_heads(tensor):
            # (B, T, E) -> (B, heads, T, head_dim)
            return tensor.view(batch, steps, self.num_heads, self.head_dim).transpose(1, 2)

        q, k, v = (split_heads(part) for part in self.in_proj(query).chunk(3, dim=-1))

        combined_mask = self.merge_masks(attn_mask, key_padding_mask, query)
        drop_p = self.dropout if self.training else 0.0
        context = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=combined_mask,
            dropout_p=drop_p,
            is_causal=False,
        )

        # (B, heads, T, head_dim) -> (B, T, E)
        context = context.transpose(1, 2).contiguous().view(batch, steps, self.embed_dim)
        return self.out_proj(context)

    def merge_masks(self,
                    attn_mask: Optional[torch.Tensor],
                    key_padding_mask: Optional[torch.Tensor],
                    query: torch.Tensor) -> Optional[torch.Tensor]:
        """Combine the optional masks into one additive float mask.

        Boolean masks are turned into floats where ``True`` becomes ``-inf`` and
        ``False`` becomes ``0``; non-boolean masks are used unchanged.

        Args:
            attn_mask: ``(T, T)`` or ``(B * heads, T, T)`` mask, or ``None``.
            key_padding_mask: Mask reshaped to ``(B, 1, 1, T)``, or ``None``.
            query: Tensor whose first two dimensions provide ``B`` and ``T``.

        Returns:
            ``None`` when both masks are ``None``, otherwise the merged mask.

        Raises:
            RuntimeError: If ``attn_mask`` has an unsupported rank or shape.
        """
        batch_size, seq_len, _ = query.shape

        def as_additive(mask):
            if mask.dtype != torch.bool:
                return mask
            return mask.float().masked_fill(mask, float("-inf"))

        padding_part = None
        if key_padding_mask is not None:
            per_head = key_padding_mask.view(batch_size, 1, 1, seq_len).expand(
                -1, self.num_heads, -1, -1)
            padding_part = as_additive(per_head)

        if attn_mask is None:
            return padding_part

        rank = attn_mask.dim()
        if rank == 2:
            correct_2d_size = (seq_len, seq_len)
            if attn_mask.shape != correct_2d_size:
                raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, "
                                   f"but should be {correct_2d_size}.")
            attn_mask = attn_mask.unsqueeze(0).expand(batch_size, self.num_heads, -1, -1)
        elif rank == 3:
            correct_3d_size = (batch_size * self.num_heads, seq_len, seq_len)
            if attn_mask.shape != correct_3d_size:
                raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, "
                                   f"but should be {correct_3d_size}.")
            attn_mask = attn_mask.view(batch_size, self.num_heads, seq_len, seq_len)
        else:
            raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")

        attn_part = as_additive(attn_mask)
        if padding_part is None:
            return attn_part
        return padding_part + attn_part

class TransformerEncoderLayer(nn.Module):
    """One encoder block: self-attention plus a GELU feed-forward network.

    Both sub-blocks are wrapped in a residual connection with RMSNorm, applied
    either before the sub-block (``norm_first=True``) or after the residual sum.

    Config keys read: ``embedding_dim``, ``heads`` (default 8),
    ``transformer_dropout``, ``dim_feedforward``, ``norm_first`` (default False).
    """

    def __init__(self, config: dict):
        super().__init__()
        width = config["embedding_dim"]
        head_count = config.get("heads", 8)
        drop_p = config["transformer_dropout"]
        hidden_width = config["dim_feedforward"]
        self.norm_first = config.get("norm_first", False)

        self.self_attn = CustomMultiheadAttention(width, head_count, dropout=drop_p)
        self.linear1 = nn.Linear(width, hidden_width)
        self.dropout1 = nn.Dropout(drop_p)
        self.linear2 = nn.Linear(hidden_width, width)
        self.dropout2 = nn.Dropout(drop_p)
        self.norm1 = RMSNorm(width)
        self.norm2 = RMSNorm(width)
        self.activation = nn.GELU()

    def _sa_block(self, src, attn_mask=None, key_padding_mask=None):
        """Self-attention over ``src`` followed by dropout."""
        attended = self.self_attn(
            src, src, src,
            key_padding_mask=key_padding_mask,
            attn_mask=attn_mask,
        )
        return self.dropout1(attended)

    def _ff_block(self, src):
        """Linear -> GELU -> dropout -> linear -> dropout."""
        expanded = self.activation(self.linear1(src))
        projected = self.linear2(self.dropout2(expanded))
        return self.dropout2(projected)

    def forward(self,
                src: torch.Tensor,
                src_key_padding_mask: torch.Tensor = None,
                src_mask: torch.Tensor = None):
        """Run the block on ``src`` and return a tensor of the same shape."""
        if not self.norm_first:
            # Post-norm: normalise after each residual sum.
            attended = self._sa_block(src,
                                      attn_mask=src_mask,
                                      key_padding_mask=src_key_padding_mask)
            src = self.norm1(src + attended)
            return self.norm2(src + self._ff_block(src))

        # Pre-norm: normalise the input of each sub-block.
        attended = self._sa_block(self.norm1(src),
                                  attn_mask=src_mask,
                                  key_padding_mask=src_key_padding_mask)
        src = src + attended
        return src + self._ff_block(self.norm2(src))

class TransformerEncoder(nn.Module):
    """Stack of ``config["num_transformer_layers"]`` encoder layers applied in order."""

    def __init__(self, config):
        super().__init__()
        depth = config["num_transformer_layers"]
        blocks = []
        for _ in range(depth):
            blocks.append(TransformerEncoderLayer(config))
        self.encoder = nn.ModuleList(blocks)

    def forward(self,
                src,
                src_key_padding_mask=None,
                src_mask=None):
        """Feed ``src`` through every layer, passing the same masks to each one."""
        hidden = src
        for block in self.encoder:
            hidden = block(
                hidden,
                src_key_padding_mask=src_key_padding_mask,
                src_mask=src_mask,
            )
        return hidden


# Preparing vocabularies and the dataset (without customer_id)


def prepare_vocabularies(train_df, val_df, test_df):
    """Derive the embedding vocabulary sizes from the three data splits.

    Args:
        train_df, val_df, test_df: pandas DataFrames, or objects exposing
            ``to_pandas()`` (e.g. polars DataFrames), with the columns
            ``postal_code``, ``articles_ids_lst``, ``days_before_lst`` and ``age``.

    Returns:
        tuple: ``(postal2idx, num_postal, num_articles, max_day, num_age)`` where
            - postal2idx (dict): postal code -> index, in order of first appearance
            - num_postal (int): number of distinct postal codes
            - num_articles (int): largest article id + 1
            - max_day (int): largest value found in ``days_before_lst``
            - num_age (int): largest age + 1
        All sizes are plain Python ints.

    Raises:
        ValueError: If a list column holds no values in any split.
    """
    def to_pandas_if_polars(df):
        return df if hasattr(df, "iloc") else df.to_pandas()

    def max_of_list_column(frames, column):
        # Running maximum: avoids copying every id of every customer into one
        # huge Python list just to take its max.
        largest = None
        for frame in frames:
            for seq in frame[column]:
                if len(seq) == 0:
                    continue
                candidate = int(max(seq))
                if largest is None or candidate > largest:
                    largest = candidate
        if largest is None:
            raise ValueError(f"column '{column}' contains no values")
        return largest

    frames = [to_pandas_if_polars(df) for df in (train_df, val_df, test_df)]
    combined = pd.concat(frames, ignore_index=True)

    postal2idx = {p: i for i, p in enumerate(combined['postal_code'].unique())}
    num_postal = len(postal2idx)

    num_articles = max_of_list_column(frames, 'articles_ids_lst') + 1
    max_day = max_of_list_column(frames, 'days_before_lst')

    # int(...) turns the numpy scalar from pandas into a plain Python int.
    num_age = int(combined['age'].max()) + 1

    return postal2idx, num_postal, num_articles, max_day, num_age

class CustomerDataset(Dataset):
    """Row-wise dataset over a customer frame (``customer_id`` is not used).

    Columns read per row:
      - postal_code: key looked up in ``postal2idx``
      - days_before_lst (list[int])
      - articles_ids_lst (list[int])
      - regression_label (float)
      - classification_label (int) (0 means churn, 1 means not churn)
      - age (int)
    """

    def __init__(self, df, postal2idx: Dict[str, int]):
        # Anything without ``iloc`` is converted through its ``to_pandas()`` method.
        self.data = df if hasattr(df, "iloc") else df.to_pandas()
        self.postal2idx = postal2idx

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(articles, days, age, postal_id, regression_label, classification_label)``."""
        record = self.data.iloc[idx]
        postal_index = self.postal2idx[record['postal_code']]
        customer_age = int(record['age'])
        article_seq, day_seq = (
            torch.tensor(record[column], dtype=torch.long)
            for column in ('articles_ids_lst', 'days_before_lst')
        )
        target_value = torch.tensor(float(record['regression_label']), dtype=torch.float)
        target_class = torch.tensor(int(record['classification_label']), dtype=torch.long)
        return (article_seq, day_seq, customer_age, postal_index, target_value, target_class)

# Custom collate function for variable-length sequences
def fixed_length_collate_fn(batch: list[tuple[torch.Tensor, torch.Tensor, int, int, torch.Tensor, torch.Tensor]], 
                           sequence_length: int = 8, padding_value: int = 0) -> tuple[torch.Tensor, ...]:
    """
    Collates a batch, forcing both sequence tensors to exactly ``sequence_length`` columns.

    Longer sequences are truncated and shorter ones are right-padded with
    ``padding_value``. Padding only up to the longest sequence in the batch is
    not enough: a batch whose sequences are all shorter than ``sequence_length``
    would come out narrower than the width the model's flattened input expects.

    Args:
        batch: List of tuples where each tuple contains 
              (articles, days, age, postal_id, regression_label, classification_label)
        sequence_length: Desired length of the sequences
        padding_value: Value to use for padding sequences

    Returns:
        Tuple of tensors: 
          (article_seqs_tensor, day_seqs_tensor, ages_tensor, postal_ids_tensor, reg_labels_tensor, class_labels_tensor)
        The two sequence tensors have shape (batch_size, sequence_length).
    """
    # Unpack the batch
    article_seqs, day_seqs, ages, postal_ids, reg_labels, class_labels = zip(*batch)

    def to_fixed_length(seqs):
        # Start from an all-padding matrix and copy in at most sequence_length items per row.
        out = torch.full((len(seqs), sequence_length), padding_value, dtype=seqs[0].dtype)
        for row, seq in enumerate(seqs):
            keep = min(seq.size(0), sequence_length)
            out[row, :keep] = seq[:keep]
        return out

    article_seqs_tensor = to_fixed_length(article_seqs)
    day_seqs_tensor = to_fixed_length(day_seqs)

    # Convert scalar values to tensors
    ages_tensor = torch.tensor(ages, dtype=torch.long)
    postal_ids_tensor = torch.tensor(postal_ids, dtype=torch.long)
    reg_labels_tensor = torch.stack(reg_labels, dim=0)
    class_labels_tensor = torch.stack(class_labels, dim=0)

    return (
        article_seqs_tensor,
        day_seqs_tensor,
        ages_tensor,
        postal_ids_tensor,
        reg_labels_tensor,
        class_labels_tensor
    )

# BST model WITHOUT customer_id input

class BST(pl.LightningModule):
    """Sequence model over purchased articles with a regression head and an
    optional churn-classification head (no customer-id embedding).

    Article and day embeddings are concatenated per sequence position, passed
    through the custom ``TransformerEncoder``, flattened, and joined with the
    age and postal-code embeddings before the heads.

    Args:
        num_articles: Size of the article embedding table.
        max_day: Largest day index; the day table has ``max_day + 1`` rows.
        num_age: Size of the age embedding table.
        num_postal: Size of the postal-code embedding table.
        sequence_length: Fixed number of sequence positions per sample; the
            heads expect ``sequence_length * (article_emb_dim + day_emb_dim)``
            transformer features.
        train_df, val_df, test_df: Frames wrapped in ``CustomerDataset`` by ``setup``.
        postal2idx: Postal code -> index mapping passed to ``CustomerDataset``.
        article_emb_dim, day_emb_dim, age_emb_dim, postal_emb_dim: Embedding widths.
        transformer_nhead: Number of attention heads.
        transformer_ff_dim: Feed-forward width inside each encoder layer.
        num_transformer_layers: Number of encoder layers.
        predict_churn: When True, a classifier head is added and its
            BCE-with-logits loss is summed with the regression MSE loss.
        learning_rate: AdamW learning rate.
        batch_size: Batch size of all three dataloaders.
        num_workers: Worker processes of all three dataloaders.
        transformer_dropout: Dropout probability inside the transformer.
    """

    def __init__(
        self,
        num_articles,
        max_day,
        num_age,
        num_postal,
        sequence_length,
        train_df,
        val_df,
        test_df,
        postal2idx,
        # embedding dims
        article_emb_dim=16,
        day_emb_dim=8,
        age_emb_dim=4,
        postal_emb_dim=4,
        # transformer config
        transformer_nhead=2,
        transformer_ff_dim=64,
        num_transformer_layers=1,
        # multi-task config
        predict_churn=False,  # Flag to enable/disable churn prediction
        # training
        learning_rate=0.0005,
        # previously hard-coded values, now tunable with the same defaults
        batch_size=128,
        num_workers=4,
        transformer_dropout=0.2,
    ):
        super().__init__()
        self.save_hyperparameters(ignore=['train_df','val_df','test_df','postal2idx'])
        self.learning_rate = learning_rate
        self.predict_churn = predict_churn

        # DataFrames and mapping
        self.train_df = train_df
        self.val_df   = val_df
        self.test_df  = test_df
        self.postal2idx = postal2idx

        # Embeddings (customer embedding removed)
        self.embeddings_age = nn.Embedding(num_age, age_emb_dim)
        self.embeddings_postal = nn.Embedding(num_postal, postal_emb_dim)
        self.embeddings_article = nn.Embedding(num_articles, article_emb_dim)
        self.embeddings_day = nn.Embedding(max_day + 1, day_emb_dim)

        # Sequence features: concatenation of article and day embeddings
        self.seq_feature_dim = article_emb_dim + day_emb_dim

        # Custom Transformer setup
        config = {
            "embedding_dim": self.seq_feature_dim,
            "heads": transformer_nhead,
            "transformer_dropout": transformer_dropout,
            "dim_feedforward": transformer_ff_dim,
            "norm_first": False,
            "num_transformer_layers": num_transformer_layers,
        }
        self.transformer = TransformerEncoder(config)
        self.transformer_output_dim = sequence_length * self.seq_feature_dim

        # User features: only age and postal embeddings are used
        user_feature_dim = age_emb_dim + postal_emb_dim

        combined_dim = self.transformer_output_dim + user_feature_dim

        # Separate heads for regression and (optional) classification
        self.regressor_head = nn.Sequential(
            nn.Linear(combined_dim, 512),
            nn.LeakyReLU(),
            nn.Linear(512, 256),
            nn.LeakyReLU(),
            nn.Linear(256, 1)
        )

        if self.predict_churn:
            self.classifier_head = nn.Sequential(
                nn.Linear(combined_dim, 128),
                nn.LeakyReLU(),
                nn.Linear(128, 1)  # single logit for binary classification
            )
            self.classification_criterion = nn.BCEWithLogitsLoss()

        self.regression_criterion = nn.MSELoss()

    def encode_input(self, batch):
        """Embed a collated batch and return ``(features, regression_label, classification_label)``.

        ``batch`` is the tuple
        ``(articles, days, age, postal_id, regression_label, classification_label)``.
        """
        articles, days, age, postal_id, regression_label, classification_label = batch

        article_embeds = self.embeddings_article(articles)  # (B, L, article_emb_dim)
        day_embeds = self.embeddings_day(days)              # (B, L, day_emb_dim)
        # Concatenate to form the sequence features; day_embeds serve as the time/position signal.
        transformer_input = torch.cat([article_embeds, day_embeds], dim=-1)  # (B, L, seq_feature_dim)

        # No padding mask is passed, so every position attends to every other one.
        transformer_output = self.transformer(transformer_input)  # (B, L, seq_feature_dim)
        transformer_output_flat = transformer_output.reshape(transformer_output.size(0), -1)

        age_embed = self.embeddings_age(age)
        postal_embed = self.embeddings_postal(postal_id)
        user_features = torch.cat([age_embed, postal_embed], dim=1)

        combined_features = torch.cat([transformer_output_flat, user_features], dim=1)
        return combined_features, regression_label, classification_label

    def forward(self, batch):
        """Return ``(reg_output, class_output, reg_label, class_label)``.

        ``class_output`` is ``None`` unless ``predict_churn`` is enabled.
        """
        features, reg_label, class_label = self.encode_input(batch)
        reg_output = self.regressor_head(features).squeeze(dim=-1)
        if self.predict_churn:
            class_output = self.classifier_head(features).squeeze(dim=-1)
        else:
            class_output = None
        return reg_output, class_output, reg_label, class_label

    def _shared_step(self, batch, stage):
        """Compute and log the losses for one batch; ``stage`` prefixes the metric names."""
        reg_output, class_output, reg_label, class_label = self(batch)
        reg_loss = self.regression_criterion(reg_output, reg_label)
        self.log(f"{stage}_reg_loss", reg_loss)
        loss = reg_loss
        if self.predict_churn:
            class_loss = self.classification_criterion(class_output, class_label.float())
            self.log(f"{stage}_class_loss", class_loss)
            loss = reg_loss + class_loss
        self.log(f"{stage}_loss", loss)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        return self._shared_step(batch, "test")

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)

    def setup(self, stage=None):
        """Build the datasets needed for the requested stage (all of them when ``stage`` is None)."""
        if stage == "fit" or stage is None:
            self.train_dataset = CustomerDataset(self.train_df, self.postal2idx)
            self.val_dataset   = CustomerDataset(self.val_df, self.postal2idx)
        if stage == "test" or stage is None:
            self.test_dataset  = CustomerDataset(self.test_df, self.postal2idx)

    def _make_dataloader(self, dataset, shuffle):
        """Create a DataLoader that collates every batch to ``sequence_length`` columns."""
        return DataLoader(
            dataset,
            batch_size=self.hparams.batch_size,
            shuffle=shuffle,
            num_workers=self.hparams.num_workers,
            collate_fn=lambda batch: fixed_length_collate_fn(batch, sequence_length=self.hparams.sequence_length)
        )

    def train_dataloader(self):
        return self._make_dataloader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._make_dataloader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        return self._make_dataloader(self.test_dataset, shuffle=False)



# TRAIN AND TEST


#
# Build the vocabulary sizes inside this cell so that it no longer depends on
# variables left behind by running the "BASE VERSION" cell first.
postal2idx, num_postal, num_articles, max_day, num_age = prepare_vocabularies(
    train_df, val_df, test_df
)

# Same value the base version uses (and the collate function's default).
# NOTE(review): config["max_length"] is 20 -- confirm which length is intended.
sequence_length = 8

# predict_churn=True trains the regression and churn heads together;
# set it to False for the regression-only variant.
model = BST(
    num_articles=num_articles,
    max_day=max_day,
    num_age=num_age,
    num_postal=num_postal,
    sequence_length=sequence_length,
    train_df=train_df,
    val_df=val_df,
    test_df=test_df,
    postal2idx=postal2idx,
    article_emb_dim=16,
    day_emb_dim=8,
    age_emb_dim=4,
    postal_emb_dim=4,
    transformer_nhead=2,
    transformer_ff_dim=64,
    num_transformer_layers=1,
    predict_churn=True,   # <--- Enable churn
    learning_rate=0.0005
)

trainer = pl.Trainer(accelerator="gpu", devices="auto", max_epochs=1)
trainer.fit(model)
trainer.test(model)



Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Testing: |          | 0/? [00:00<?, ?it/s]

[{'test_reg_loss': 29194.90234375,
  'test_class_loss': 0.5889449119567871,
  'test_loss': 29195.490234375}]

# BASE VERSION

In [None]:
import math
import torch
import torch.nn as nn
import pytorch_lightning as pl
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from typing import List, Dict


def prepare_vocabularies(train_df, val_df, test_df):
    """Derive the id mappings and embedding vocabulary sizes from the three splits.

    Args:
        train_df, val_df, test_df: pandas DataFrames, or objects exposing
            ``to_pandas()`` (e.g. polars DataFrames), with the columns
            ``customer_id``, ``postal_code``, ``articles_ids_lst``,
            ``days_before_lst`` and ``age``.

    Returns:
        tuple: ``(user2idx, postal2idx, num_customers, num_postal, num_articles,
        max_day, num_age)`` where the two dicts map ids to indices in order of
        first appearance, ``num_articles`` is the largest article id + 1,
        ``max_day`` is the largest value in ``days_before_lst`` and ``num_age``
        is the largest age + 1. All sizes are plain Python ints.

    Raises:
        ValueError: If a list column holds no values in any split.
    """

    def to_pandas_if_polars(df):
        return df if hasattr(df, "iloc") else df.to_pandas()

    def max_of_list_column(frames, column):
        # Running maximum: avoids copying every id of every customer into one
        # huge Python list just to take its max.
        largest = None
        for frame in frames:
            for seq in frame[column]:
                if len(seq) == 0:
                    continue
                candidate = int(max(seq))
                if largest is None or candidate > largest:
                    largest = candidate
        if largest is None:
            raise ValueError(f"column '{column}' contains no values")
        return largest

    frames = [to_pandas_if_polars(df) for df in (train_df, val_df, test_df)]

    # Combine for global vocabularies
    combined = pd.concat(frames, ignore_index=True)

    # Map string-based customer_id -> integer
    user2idx = {u: i for i, u in enumerate(combined['customer_id'].unique())}
    num_customers = len(user2idx)

    # Map string-based postal_code -> integer
    postal2idx = {p: i for i, p in enumerate(combined['postal_code'].unique())}
    num_postal = len(postal2idx)

    num_articles = max_of_list_column(frames, 'articles_ids_lst') + 1  # for embedding
    max_day = max_of_list_column(frames, 'days_before_lst')

    # Age is treated as a discrete index; int(...) drops the numpy scalar type.
    num_age = int(combined['age'].max()) + 1

    return user2idx, postal2idx, num_customers, num_postal, num_articles, max_day, num_age

class CustomerDataset(Dataset):
    """Row-wise dataset over a customer frame, including the customer index.

    Columns read per row:
    - customer_id: key looked up in ``user2idx``
    - postal_code: key looked up in ``postal2idx``
    - days_before_lst (list[int])
    - articles_ids_lst (list[int])
    - regression_label (float)
    - age (int)
    ``classification_label`` is not read here.
    """

    def __init__(self, df, user2idx: Dict[str,int], postal2idx: Dict[str,int]):
        # Anything without ``iloc`` is converted through its ``to_pandas()`` method.
        self.data = df if hasattr(df, "iloc") else df.to_pandas()
        self.user2idx = user2idx
        self.postal2idx = postal2idx

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(user_id, articles, days, age, postal_id, regression_label)``."""
        record = self.data.iloc[idx]

        # String ids -> integer indices
        user_index = self.user2idx[record['customer_id']]
        postal_index = self.postal2idx[record['postal_code']]
        customer_age = int(record['age'])

        article_seq, day_seq = (
            torch.tensor(record[column], dtype=torch.long)
            for column in ('articles_ids_lst', 'days_before_lst')
        )
        target_value = torch.tensor(float(record['regression_label']), dtype=torch.float)

        return (user_index, article_seq, day_seq, customer_age, postal_index, target_value)


# CUSTOM COLLATE FUNCTION FOR VARIABLE-LENGTH SEQUENCES


def fixed_length_collate_fn(batch, sequence_length=8):
    """
    Collates a batch, forcing the 'articles' and 'days' sequences to 'sequence_length'.

    Sequences that are too long are cut; sequences that are too short are
    right-padded with zeros of the same dtype.
    Each item in the batch is a tuple:
      (user_id, articles, days, age, postal_id, regression_label)

    Returns the tuple
      (user_ids, article_seqs, day_seqs, ages, postal_ids, labels)
    where both sequence tensors have shape (batch_size, sequence_length).
    """
    def fit_length(seq):
        missing = sequence_length - seq.size(0)
        if missing < 0:
            return seq[:sequence_length]
        if missing == 0:
            return seq
        filler = torch.zeros(missing, dtype=seq.dtype)
        return torch.cat([seq, filler], dim=0)

    user_ids, article_seqs, day_seqs = [], [], []
    ages, postal_ids, labels = [], [], []

    # Split the per-sample tuples into per-field columns.
    for user_id, articles, days, age, postal_id, label in batch:
        user_ids.append(user_id)
        article_seqs.append(articles)
        day_seqs.append(days)
        ages.append(age)
        postal_ids.append(postal_id)
        labels.append(label)

    fitted_articles = []
    fitted_days = []
    for articles, days in zip(article_seqs, day_seqs):
        fitted_articles.append(fit_length(articles))
        fitted_days.append(fit_length(days))

    return (
        torch.tensor(user_ids, dtype=torch.long),
        torch.stack(fitted_articles, dim=0),   # (batch_size, sequence_length)
        torch.stack(fitted_days, dim=0),       # (batch_size, sequence_length)
        torch.tensor(ages, dtype=torch.long),
        torch.tensor(postal_ids, dtype=torch.long),
        torch.stack(labels, dim=0),            # (batch_size,)
    )


# BST MODEL
class PositionalEmbedding(nn.Module):
    """
    Learned positional embedding: one trainable vector per position (0..max_len-1).
    """
    def __init__(self, max_len, d_model):
        super().__init__()
        self.pe = nn.Embedding(max_len, d_model)

    def forward(self, x):
        """Return the position vectors for ``x`` of shape (batch_size, seq_length, d_model).

        Only the shape and device of ``x`` are used; the result has shape
        (batch_size, seq_length, d_model) and is identical for every batch row.
        """
        n_rows, n_steps, _ = x.size()
        index_row = torch.arange(n_steps, device=x.device)
        index_grid = index_row.unsqueeze(0).expand(n_rows, n_steps)
        return self.pe(index_grid)

class BST(pl.LightningModule):
    """Base sequence model: article/day embeddings with learned positions, one
    ``nn.TransformerEncoderLayer``, and an MLP regressor over the flattened
    sequence joined with customer, age and postal-code embeddings.

    Args:
        num_customers, num_articles, num_age, num_postal: Embedding table sizes.
        max_day: Largest day index; the day table has ``max_day + 1`` rows.
        sequence_length: Fixed number of sequence positions per sample.
        train_df, val_df, test_df: Frames wrapped in ``CustomerDataset`` by ``setup``.
        user2idx, postal2idx: Id -> index mappings passed to ``CustomerDataset``.
        article_emb_dim, day_emb_dim, customer_emb_dim, age_emb_dim,
            postal_emb_dim: Embedding widths.
        transformer_nhead: Number of attention heads.
        learning_rate: AdamW learning rate.
    """

    def __init__(
        self,
        num_customers,
        num_articles,
        max_day,
        num_age,
        num_postal,
        sequence_length,
        train_df,
        val_df,
        test_df,
        user2idx,
        postal2idx,
        article_emb_dim=16,
        day_emb_dim=8,
        customer_emb_dim=16,
        age_emb_dim=4,
        postal_emb_dim=4,
        transformer_nhead=2,
        learning_rate=0.0005
    ):
        super().__init__()
        self.save_hyperparameters(ignore=['train_df','val_df','test_df','user2idx','postal2idx'])
        self.learning_rate = learning_rate

        # DataFrames + Mappings
        self.train_df = train_df
        self.val_df   = val_df
        self.test_df  = test_df
        self.user2idx = user2idx
        self.postal2idx = postal2idx

        # Embeddings
        self.embeddings_customer = nn.Embedding(num_customers, customer_emb_dim)
        self.embeddings_age = nn.Embedding(num_age, age_emb_dim)
        self.embeddings_postal = nn.Embedding(num_postal, postal_emb_dim)

        self.embeddings_article = nn.Embedding(num_articles, article_emb_dim)
        self.embeddings_day = nn.Embedding(max_day + 1, day_emb_dim)

        # Sequence dimension
        self.seq_feature_dim = article_emb_dim + day_emb_dim
        self.positional_embedding = PositionalEmbedding(sequence_length, self.seq_feature_dim)

        # Transformer
        self.transformer_layer = nn.TransformerEncoderLayer(
            d_model=self.seq_feature_dim,
            nhead=transformer_nhead,
            dropout=0.2
        )

        # Flattened dimension after transformer
        transformer_output_dim = sequence_length * self.seq_feature_dim

        # User features dimension
        user_feature_dim = customer_emb_dim + age_emb_dim + postal_emb_dim

        # Combined dimension
        combined_dim = transformer_output_dim + user_feature_dim

        # Final regressor
        self.linear = nn.Sequential(
            nn.Linear(combined_dim, 512),
            nn.LeakyReLU(),
            nn.Linear(512, 256),
            nn.LeakyReLU(),
            nn.Linear(256, 1)
        )

        self.criterion = nn.MSELoss()

    def encode_input(self, batch):
        """Embed a collated batch and return ``(combined_features, regression_label)``.

        ``batch`` is the tuple
        ``(user_id, articles, days, age, postal_id, regression_label)``.
        """
        user_id, articles, days, age, postal_id, regression_label = batch

        # Sequence embeddings
        article_embeds = self.embeddings_article(articles)  # (B, L, article_emb_dim)
        day_embeds = self.embeddings_day(days)              # (B, L, day_emb_dim)
        sequence_features = torch.cat([article_embeds, day_embeds], dim=-1)  # (B, L, seq_feature_dim)

        # Positional embeddings
        pos_embeds = self.positional_embedding(sequence_features)
        transformer_input = sequence_features + pos_embeds

        # The layer is built with the default layout, so it expects (L, B, d_model)
        transformer_input = transformer_input.transpose(0, 1)  # (L, B, seq_feature_dim)
        transformer_output = self.transformer_layer(transformer_input)
        transformer_output = transformer_output.transpose(0, 1)  # (B, L, seq_feature_dim)

        # Flatten
        transformer_output_flat = transformer_output.reshape(transformer_output.size(0), -1)

        # User features
        customer_embed = self.embeddings_customer(user_id)
        age_embed = self.embeddings_age(age)
        postal_embed = self.embeddings_postal(postal_id)
        user_features = torch.cat([customer_embed, age_embed, postal_embed], dim=1)

        # Combine
        combined_features = torch.cat([transformer_output_flat, user_features], dim=1)
        return combined_features, regression_label

    def forward(self, batch):
        """Return ``(prediction, target)``, both of shape (batch_size,)."""
        features, target = self.encode_input(batch)
        output = self.linear(features)
        # Drop only the trailing feature axis: a bare squeeze() would also drop
        # the batch axis when a batch holds a single sample.
        return output.squeeze(-1), target

    def _shared_step(self, batch, metric_name):
        """Compute the MSE loss for one batch and log it under ``metric_name``."""
        output, target = self(batch)
        loss = self.criterion(output, target)
        self.log(metric_name, loss)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train_loss")

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, "val_loss")

    def test_step(self, batch, batch_idx):
        return self._shared_step(batch, "test_loss")

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)

    def setup(self, stage=None):
        """Build the datasets needed for the requested stage (all of them when ``stage`` is None)."""
        if stage == "fit" or stage is None:
            self.train_dataset = CustomerDataset(self.train_df, self.user2idx, self.postal2idx)
            self.val_dataset   = CustomerDataset(self.val_df,   self.user2idx, self.postal2idx)
        if stage == "test" or stage is None:
            self.test_dataset  = CustomerDataset(self.test_df,  self.user2idx, self.postal2idx)

    def _make_dataloader(self, dataset, shuffle):
        """Create a DataLoader that collates every batch to ``sequence_length`` columns."""
        return DataLoader(
            dataset,
            batch_size=128,
            shuffle=shuffle,
            num_workers=4,
            collate_fn=lambda b: fixed_length_collate_fn(b, sequence_length=self.hparams.sequence_length)
        )

    def train_dataloader(self):
        return self._make_dataloader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._make_dataloader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        return self._make_dataloader(self.test_dataset, shuffle=False)


# TRAIN AND TEST


# Build the id mappings and embedding-table sizes from all three splits.
user2idx, postal2idx, num_customers, num_postal, num_articles, max_day, num_age = prepare_vocabularies(
     train_df, val_df, test_df
 )


# Number of sequence positions every sample is padded or truncated to by the
# collate function; it also fixes the width of the model's flattened input.
sequence_length = 8


# Base model: the splits and mappings are handed over so the module can build
# its own datasets and dataloaders.
model = BST(
  num_customers=num_customers,
    num_articles=num_articles,
    max_day=max_day,
    num_age=num_age,
    num_postal=num_postal,
    sequence_length=sequence_length,
    train_df=train_df,
    val_df=val_df,
    test_df=test_df,
    user2idx=user2idx,
    postal2idx=postal2idx,
    article_emb_dim=16,
    day_emb_dim=8,
    customer_emb_dim=16,
    age_emb_dim=4,
   postal_emb_dim=4,
    transformer_nhead=2,
    learning_rate=0.0005
)


# Train for a single epoch on GPU, then evaluate on the test split.
trainer = pl.Trainer(accelerator="gpu", devices="auto", max_epochs=1)
trainer.fit(model)
trainer.test(model)

# USING ALTERNATIVE TRANSFORMER

In [None]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from typing import List, Dict, Optional


# Custom Transformer classes


class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension.

    Every vector along the last axis is multiplied by
    ``1 / sqrt(mean(x ** 2) + eps)`` and then scaled element-wise by the
    learnable ``gain`` (initialised to ones).
    """

    def __init__(self, dim: int, eps: float = 1e-8) -> None:
        super().__init__()
        self.eps = eps
        self.gain = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` normalised by its RMS over the last axis, times ``gain``."""
        mean_square = torch.mean(torch.square(x), dim=-1, keepdim=True)
        inverse_rms = torch.rsqrt(mean_square + self.eps)
        return (x * inverse_rms) * self.gain


class CustomMultiheadAttention(nn.Module):
    """Batch-first multi-head self-attention on top of ``F.scaled_dot_product_attention``.

    Q, K and V are all produced from ``query`` by one fused linear projection.
    The ``key`` and ``value`` arguments are accepted so the call signature
    mirrors a generic attention layer, but they are never read: this module
    only performs self-attention.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.0, bias=True):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.dropout = dropout

    def forward(self, query, key, value, attn_mask=None, key_padding_mask=None):
        """Self-attend over ``query``.

        Args:
            query: input of shape (B, T, embed_dim).
            key, value: unused (see class docstring).
            attn_mask: optional mask of shape (T, T) or (B * num_heads, T, T).
            key_padding_mask: optional mask of shape (B, T).

        Boolean masks use ``True`` for positions that must NOT be attended to;
        non-boolean masks are added to the attention scores.

        Returns:
            Tensor of shape (B, T, embed_dim).
        """
        B, T, _ = query.size()
        qkv = self.in_proj(query)
        q, k, v = qkv.chunk(3, dim=-1)

        # Reshape for multi-head: (B, T, E) -> (B, num_heads, T, head_dim)
        q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

        # Merge masks if present
        attn_mask = self.merge_masks(attn_mask, key_padding_mask, query)
        if attn_mask is not None:
            # The projected q can differ in dtype from the raw input (e.g. under
            # autocast); a float mask has to match the tensors it is added to.
            attn_mask = attn_mask.to(dtype=q.dtype)

        # Use PyTorch's scaled dot-product attention
        attn_output = F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=attn_mask,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=False,
        )

        # Reshape back: (B, num_heads, T, head_dim) -> (B, T, E)
        attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, self.embed_dim)
        output = self.out_proj(attn_output)
        return output

    def merge_masks(
        self,
        attn_mask: Optional[torch.Tensor],
        key_padding_mask: Optional[torch.Tensor],
        query: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Combine ``attn_mask`` and ``key_padding_mask`` into one additive mask.

        Boolean masks become 0 where attention is allowed and ``-inf`` where the
        mask is ``True``; non-boolean masks are cast to the dtype of ``query``.
        The result broadcasts against (B, num_heads, T, T) and has the dtype of
        ``query``; ``None`` is returned when both inputs are ``None``.

        Raises:
            RuntimeError: if ``attn_mask`` is not 2D/3D or has the wrong shape.
        """
        merged_mask = None
        batch_size, seq_len, _ = query.shape
        mask_dtype = query.dtype

        def convert_to_float_mask(mask):
            if mask.dtype == torch.bool:
                # Build the additive mask directly in the query dtype instead of
                # hard-coding float32, so fp16/bf16/fp64 inputs get a matching mask.
                return torch.zeros_like(mask, dtype=mask_dtype).masked_fill(
                    mask, float("-inf")
                )
            return mask.to(mask_dtype)

        # key_padding_mask -> float mask
        if key_padding_mask is not None:
            # shape (B, T) -> (B, 1, 1, T) -> expand to (B, num_heads, 1, T)
            key_padding_mask = key_padding_mask.view(batch_size, 1, 1, seq_len).expand(
                -1, self.num_heads, -1, -1
            )
            merged_mask = convert_to_float_mask(key_padding_mask)

        # attn_mask -> float mask
        if attn_mask is not None:
            if attn_mask.dim() == 2:
                # shape (T, T) -> (B, num_heads, T, T)
                correct_2d_size = (seq_len, seq_len)
                if attn_mask.shape != correct_2d_size:
                    raise RuntimeError(
                        f"The shape of the 2D attn_mask is {attn_mask.shape}, "
                        f"but should be {correct_2d_size}."
                    )
                attn_mask = attn_mask.unsqueeze(0).expand(
                    batch_size, self.num_heads, -1, -1
                )
            elif attn_mask.dim() == 3:
                # shape (B*num_heads, T, T) -> (B, num_heads, T, T)
                correct_3d_size = (batch_size * self.num_heads, seq_len, seq_len)
                if attn_mask.shape != correct_3d_size:
                    raise RuntimeError(
                        f"The shape of the 3D attn_mask is {attn_mask.shape}, "
                        f"but should be {correct_3d_size}."
                    )
                attn_mask = attn_mask.view(batch_size, self.num_heads, seq_len, seq_len)
            else:
                raise RuntimeError(
                    f"attn_mask's dimension {attn_mask.dim()} is not supported"
                )

            attn_mask = convert_to_float_mask(attn_mask)

            if merged_mask is None:
                merged_mask = attn_mask
            else:
                # Additive masks combine by summation (-inf wins over any finite value).
                merged_mask = merged_mask + attn_mask

        return merged_mask


class TransformerEncoderLayer(nn.Module):
    """One encoder layer: self-attention followed by a GELU feed-forward network.

    Both sub-blocks sit inside a residual connection with RMSNorm, applied
    before the sub-block (pre-norm) or after the residual sum (post-norm)
    depending on ``norm_first``.

    Config keys read: ``embedding_dim``, ``heads`` (default 8),
    ``transformer_dropout``, ``dim_feedforward`` and ``norm_first``
    (default ``False``).
    """

    def __init__(self, config: dict):
        super().__init__()
        d_model = config["embedding_dim"]
        n_heads = config.get("heads", 8)
        p_drop = config["transformer_dropout"]
        d_hidden = config["dim_feedforward"]
        self.norm_first = config.get("norm_first", False)

        self.self_attn = CustomMultiheadAttention(d_model, n_heads, dropout=p_drop)
        self.linear1 = nn.Linear(d_model, d_hidden)
        self.dropout1 = nn.Dropout(p_drop)
        self.linear2 = nn.Linear(d_hidden, d_model)
        self.dropout2 = nn.Dropout(p_drop)

        self.norm1 = RMSNorm(d_model)
        self.norm2 = RMSNorm(d_model)
        self.activation = nn.GELU()

    def _sa_block(self, src, attn_mask=None, key_padding_mask=None):
        """Self-attention over ``src`` followed by ``dropout1``."""
        attended = self.self_attn(
            src, src, src, key_padding_mask=key_padding_mask, attn_mask=attn_mask
        )
        return self.dropout1(attended)

    def _ff_block(self, src):
        """linear1 -> GELU -> dropout2 -> linear2 -> dropout2."""
        hidden = self.activation(self.linear1(src))
        hidden = self.dropout2(hidden)
        projected = self.linear2(hidden)
        return self.dropout2(projected)

    def forward(
        self,
        src: torch.Tensor,
        src_key_padding_mask: torch.Tensor = None,
        src_mask: torch.Tensor = None,
    ):
        """Apply the layer to ``src``; both masks are forwarded to the attention block."""
        if not self.norm_first:
            # Post-norm: normalise after each residual sum.
            attended = self._sa_block(
                src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
            )
            src = self.norm1(src + attended)
            return self.norm2(src + self._ff_block(src))

        # Pre-norm: normalise the input of each sub-block, keep the residual raw.
        attended = self._sa_block(
            self.norm1(src),
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask,
        )
        src = src + attended
        return src + self._ff_block(self.norm2(src))


class TransformerEncoder(nn.Module):
    """Stack of ``config["num_transformer_layers"]`` encoder layers applied in order."""

    def __init__(self, config):
        super().__init__()
        depth = config["num_transformer_layers"]
        blocks = []
        for _ in range(depth):
            blocks.append(TransformerEncoderLayer(config))
        self.encoder = nn.ModuleList(blocks)

    def forward(
        self,
        src,
        src_key_padding_mask=None,
        src_mask=None,
    ):
        """
        src: shape (B, T, E)

        The same masks are handed to every layer; the output of the last layer
        is returned.
        """
        hidden = src
        for block in self.encoder:
            hidden = block(
                hidden,
                src_key_padding_mask=src_key_padding_mask,
                src_mask=src_mask,
            )
        return hidden



#  Preparing vocabularies and the dataset classes


def prepare_vocabularies(train_df, val_df, test_df):
    """
    Build the lookup tables and size constants needed to create the embeddings.

    1) Any split without an ``iloc`` attribute is converted with ``.to_pandas()``.
    2) ``customer_id`` and ``postal_code`` values are numbered in order of first
       appearance across train, val and test.
    3) The largest article ID, day offset and age across all splits are found.

    Returns:
        (user2idx, postal2idx, num_customers, num_postal, num_articles, max_day, num_age)
        where ``num_articles`` and ``num_age`` are the respective maxima plus one.
    """
    frames = [
        split if hasattr(split, "iloc") else split.to_pandas()
        for split in (train_df, val_df, test_df)
    ]

    # One frame holding every row, used for the global vocabularies.
    combined = pd.concat(frames, ignore_index=True)

    def index_by_first_appearance(column):
        return {value: position for position, value in enumerate(combined[column].unique())}

    def largest_in_list_column(column):
        # Each cell holds a list of ints; scan all of them across the three splits.
        return max(
            element
            for frame in frames
            for cell in frame[column]
            for element in cell
        )

    user2idx = index_by_first_appearance('customer_id')
    postal2idx = index_by_first_appearance('postal_code')

    num_articles = largest_in_list_column('articles_ids_lst') + 1  # for embedding
    max_day = largest_in_list_column('days_before_lst')

    # Age is treated as a discrete value, so the table needs max_age + 1 rows.
    num_age = combined['age'].max() + 1

    return (
        user2idx,
        postal2idx,
        len(user2idx),
        len(postal2idx),
        num_articles,
        max_day,
        num_age,
    )


class CustomerDataset(Dataset):
    """
    Row-per-customer dataset yielding model-ready samples.

    Expects columns:
    - customer_id (str)
    - days_before_lst (list[int])
    - articles_ids_lst (list[int])
    - regression_label (float)
    - classification_label (int)  (not used here)
    - age (int)
    - postal_code (str)

    Each item is the tuple
    ``(user_id, articles, days, age, postal_id, regression_label)``.
    """
    def __init__(self, df, user2idx: Dict[str,int], postal2idx: Dict[str,int]):
        # Anything without `.iloc` is converted via `.to_pandas()` (e.g. Polars).
        self.data = df if hasattr(df, "iloc") else df.to_pandas()
        self.user2idx = user2idx
        self.postal2idx = postal2idx

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data.iloc[idx]

        # String IDs -> integer indices
        user_id = self.user2idx[record['customer_id']]
        postal_id = self.postal2idx[record['postal_code']]

        age = int(record['age'])  # embedding or numeric

        # Variable-length histories stay unpadded here; the collate fn fixes the length.
        articles = torch.tensor(record['articles_ids_lst'], dtype=torch.long)
        days = torch.tensor(record['days_before_lst'], dtype=torch.long)

        target_value = float(record['regression_label'])
        regression_label = torch.tensor(target_value, dtype=torch.float)

        return (user_id, articles, days, age, postal_id, regression_label)


# Custom collate function for variable-length sequences
def fixed_length_collate_fn(batch, sequence_length=8):
    """
    Collate samples into batched tensors with fixed-length sequences.

    Every item of ``batch`` is a tuple
      (user_id, articles, days, age, postal_id, regression_label)
    whose 'articles' and 'days' tensors are truncated to ``sequence_length``
    or right-padded with zeros up to it.

    Returns a 6-tuple of tensors:
      user_ids (B,), articles (B, sequence_length), days (B, sequence_length),
      ages (B,), postal_ids (B,), labels (B,)
    """
    # Unpack strictly so a sample with the wrong number of fields fails here.
    samples = [
        (user_id, articles, days, age, postal_id, label)
        for (user_id, articles, days, age, postal_id, label) in batch
    ]

    def fit_to_length(seq):
        missing = sequence_length - seq.size(0)
        if missing > 0:
            filler = torch.zeros(missing, dtype=seq.dtype)
            return torch.cat([seq, filler], dim=0)
        # Long enough already: keep the first `sequence_length` entries.
        return seq[:sequence_length]

    article_seqs = [fit_to_length(sample[1]) for sample in samples]
    day_seqs = [fit_to_length(sample[2]) for sample in samples]

    user_ids_tensor = torch.tensor([sample[0] for sample in samples], dtype=torch.long)
    article_seqs_tensor = torch.stack(article_seqs, dim=0)  # (batch_size, sequence_length)
    day_seqs_tensor = torch.stack(day_seqs, dim=0)          # (batch_size, sequence_length)
    ages_tensor = torch.tensor([sample[3] for sample in samples], dtype=torch.long)
    postal_ids_tensor = torch.tensor([sample[4] for sample in samples], dtype=torch.long)
    labels_tensor = torch.stack([sample[5] for sample in samples], dim=0)  # (batch_size,)

    return (
        user_ids_tensor,
        article_seqs_tensor,
        day_seqs_tensor,
        ages_tensor,
        postal_ids_tensor,
        labels_tensor,
    )



# BST model with the custom Transformer


class PositionalEmbedding(nn.Module):
    """
    Learned positional embedding: one trainable vector per position (0..max_len-1).

    ``forward`` only uses the shape and device of its input and returns the
    position vectors themselves; adding them to the input is left to the caller.
    """
    def __init__(self, max_len, d_model):
        super().__init__()
        self.pe = nn.Embedding(max_len, d_model)

    def forward(self, x):
        # x: (batch_size, seq_length, d_model)
        n_rows, n_steps, _ = x.size()
        step_ids = torch.arange(n_steps, device=x.device)
        # Same 0..n_steps-1 index row for every batch element.
        positions = step_ids.unsqueeze(0).expand(n_rows, n_steps)
        return self.pe(positions)  # (batch_size, seq_length, d_model)


class BST(pl.LightningModule):
    """Sequence regressor: transformer-encoded purchase history plus user embeddings.

    The article and day sequences are embedded, concatenated, given learned
    positional embeddings and passed through the custom ``TransformerEncoder``.
    The flattened encoder output is concatenated with customer/age/postal
    embeddings and fed to an MLP that predicts one value per sample, trained
    with MSE against the regression label.

    The module also owns its data: it keeps the three data frames plus the ID
    mappings and builds the datasets/dataloaders itself, so ``trainer.fit(model)``
    needs no separate data module.
    """

    def __init__(
        self,
        num_customers,
        num_articles,
        max_day,
        num_age,
        num_postal,
        sequence_length,
        train_df,
        val_df,
        test_df,
        user2idx,
        postal2idx,
        article_emb_dim=16,
        day_emb_dim=8,
        customer_emb_dim=16,
        age_emb_dim=4,
        postal_emb_dim=4,
        transformer_nhead=2,
        transformer_ff_dim=64,     # <-- new hyperparam for the feed-forward layer
        num_transformer_layers=1,  # <-- how many layers in the custom transformer
        learning_rate=0.0005
    ):
        super().__init__()
        # Data frames and mappings are kept out of the saved hyperparameters.
        self.save_hyperparameters(ignore=['train_df','val_df','test_df','user2idx','postal2idx'])
        self.learning_rate = learning_rate

        # DataFrames + Mappings
        self.train_df = train_df
        self.val_df   = val_df
        self.test_df  = test_df
        self.user2idx = user2idx
        self.postal2idx = postal2idx

        # Embeddings
        self.embeddings_customer = nn.Embedding(num_customers, customer_emb_dim)
        self.embeddings_age = nn.Embedding(num_age, age_emb_dim)
        self.embeddings_postal = nn.Embedding(num_postal, postal_emb_dim)

        self.embeddings_article = nn.Embedding(num_articles, article_emb_dim)
        # max_day is itself a valid index, hence the +1.
        self.embeddings_day = nn.Embedding(max_day + 1, day_emb_dim)

        # Sequence dimension
        self.seq_feature_dim = article_emb_dim + day_emb_dim
        self.positional_embedding = PositionalEmbedding(sequence_length, self.seq_feature_dim)

        # -------------------------
        # Custom Transformer setup
        # -------------------------
        config = {
            "embedding_dim": self.seq_feature_dim,
            "heads": transformer_nhead,
            "transformer_dropout": 0.2,
            "dim_feedforward": transformer_ff_dim,
            "norm_first": False,
            "num_transformer_layers": num_transformer_layers,
        }
        self.transformer = TransformerEncoder(config)

        # Flattened dimension after transformer
        self.transformer_output_dim = sequence_length * self.seq_feature_dim

        # User features dimension
        user_feature_dim = customer_emb_dim + age_emb_dim + postal_emb_dim

        # Combined dimension
        combined_dim = self.transformer_output_dim + user_feature_dim

        # Final regressor
        self.linear = nn.Sequential(
            nn.Linear(combined_dim, 512),
            nn.LeakyReLU(),
            nn.Linear(512, 256),
            nn.LeakyReLU(),
            nn.Linear(256, 1)
        )

        self.criterion = nn.MSELoss()

    def encode_input(self, batch):
        """Turn a collated batch into ``(combined_features, regression_label)``.

        ``batch`` is the 6-tuple
        ``(user_id, articles, days, age, postal_id, regression_label)``.
        """
        user_id, articles, days, age, postal_id, regression_label = batch

        # Sequence embeddings
        article_embeds = self.embeddings_article(articles)  # (B, L, article_emb_dim)
        day_embeds = self.embeddings_day(days)              # (B, L, day_emb_dim)
        sequence_features = torch.cat([article_embeds, day_embeds], dim=-1)  # (B, L, seq_feature_dim)

        # Positional embeddings
        pos_embeds = self.positional_embedding(sequence_features)  # (B, L, seq_feature_dim)
        transformer_input = sequence_features + pos_embeds         # (B, L, seq_feature_dim)

        # Pass through our custom Transformer (B, L, d_model); no masks are supplied.
        transformer_output = self.transformer(transformer_input)   # (B, L, seq_feature_dim)

        # Flatten
        transformer_output_flat = transformer_output.reshape(transformer_output.size(0), -1)

        # User features
        customer_embed = self.embeddings_customer(user_id)
        age_embed = self.embeddings_age(age)
        postal_embed = self.embeddings_postal(postal_id)
        user_features = torch.cat([customer_embed, age_embed, postal_embed], dim=1)

        # Combine
        combined_features = torch.cat([transformer_output_flat, user_features], dim=1)
        return combined_features, regression_label

    def forward(self, batch):
        """Return ``(prediction, target)`` for a collated batch, both shaped (B,)."""
        features, target = self.encode_input(batch)
        output = self.linear(features)
        # Drop only the trailing size-1 feature axis. A bare squeeze() would also
        # remove the batch axis when a batch holds a single sample, leaving a
        # 0-d prediction against a (1,) target.
        return output.squeeze(-1), target

    def _shared_step(self, batch, metric_name):
        """Compute the MSE loss for ``batch`` and log it under ``metric_name``."""
        output, target = self(batch)
        loss = self.criterion(output, target)
        self.log(metric_name, loss)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train_loss")

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, "val_loss")

    def test_step(self, batch, batch_idx):
        return self._shared_step(batch, "test_loss")

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)

    def setup(self, stage=None):
        """Create the datasets needed for ``stage`` ("fit", "test" or None for all)."""
        if stage == "fit" or stage is None:
            self.train_dataset = CustomerDataset(self.train_df, self.user2idx, self.postal2idx)
            self.val_dataset   = CustomerDataset(self.val_df,   self.user2idx, self.postal2idx)
        if stage == "test" or stage is None:
            self.test_dataset  = CustomerDataset(self.test_df,  self.user2idx, self.postal2idx)

    def _make_dataloader(self, dataset, shuffle):
        """Build a DataLoader with the settings shared by all three splits."""
        return DataLoader(
            dataset,
            batch_size=128,
            shuffle=shuffle,
            num_workers=4,
            collate_fn=lambda b: fixed_length_collate_fn(b, sequence_length=self.hparams.sequence_length)
        )

    def train_dataloader(self):
        return self._make_dataloader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._make_dataloader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        return self._make_dataloader(self.test_dataset, shuffle=False)



# Example usage (Train and Test)


# Every article/day sequence is padded or truncated to this many steps.
sequence_length = 8

model = BST(
    # vocabulary / table sizes
    num_customers=num_customers,
    num_articles=num_articles,
    max_day=max_day,
    num_age=num_age,
    num_postal=num_postal,
    sequence_length=sequence_length,
    # data splits and ID mappings
    train_df=train_df,
    val_df=val_df,
    test_df=test_df,
    user2idx=user2idx,
    postal2idx=postal2idx,
    # embedding widths
    article_emb_dim=16,
    day_emb_dim=8,
    customer_emb_dim=16,
    age_emb_dim=4,
    postal_emb_dim=4,
    # custom transformer
    transformer_nhead=2,
    transformer_ff_dim=64,
    num_transformer_layers=1,
    # optimisation
    learning_rate=0.0005,
)

# One epoch on GPU, then run the test loop.
trainer = pl.Trainer(accelerator="gpu", devices="auto", max_epochs=1)
trainer.fit(model)
trainer.test(model)


# ADDING CHURN PREDICTION

In [None]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from typing import List, Dict, Optional


# Custom Transformer classes 

class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension.

    Every vector along the last axis is multiplied by
    ``1 / sqrt(mean(x ** 2) + eps)`` and then scaled element-wise by the
    learnable ``gain`` (initialised to ones).
    """

    def __init__(self, dim: int, eps: float = 1e-8) -> None:
        super().__init__()
        self.eps = eps
        self.gain = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` normalised by its RMS over the last axis, times ``gain``."""
        mean_square = torch.mean(torch.square(x), dim=-1, keepdim=True)
        inverse_rms = torch.rsqrt(mean_square + self.eps)
        return (x * inverse_rms) * self.gain

class CustomMultiheadAttention(nn.Module):
    """Batch-first multi-head self-attention on top of ``F.scaled_dot_product_attention``.

    Q, K and V are all produced from ``query`` by one fused linear projection.
    The ``key`` and ``value`` arguments are accepted so the call signature
    mirrors a generic attention layer, but they are never read: this module
    only performs self-attention.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.0, bias=True):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.dropout = dropout

    def forward(self, query, key, value, attn_mask=None, key_padding_mask=None):
        """Self-attend over ``query``.

        Args:
            query: input of shape (B, T, embed_dim).
            key, value: unused (see class docstring).
            attn_mask: optional mask of shape (T, T) or (B * num_heads, T, T).
            key_padding_mask: optional mask of shape (B, T).

        Boolean masks use ``True`` for positions that must NOT be attended to;
        non-boolean masks are added to the attention scores.

        Returns:
            Tensor of shape (B, T, embed_dim).
        """
        B, T, _ = query.size()
        qkv = self.in_proj(query)
        q, k, v = qkv.chunk(3, dim=-1)

        # Reshape for multi-head: (B, T, E) -> (B, num_heads, T, head_dim)
        q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

        # Merge masks if present
        attn_mask = self.merge_masks(attn_mask, key_padding_mask, query)
        if attn_mask is not None:
            # The projected q can differ in dtype from the raw input (e.g. under
            # autocast); a float mask has to match the tensors it is added to.
            attn_mask = attn_mask.to(dtype=q.dtype)

        # Use PyTorch's scaled dot-product attention
        attn_output = F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=attn_mask,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=False,
        )

        # Reshape back: (B, num_heads, T, head_dim) -> (B, T, E)
        attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, self.embed_dim)
        output = self.out_proj(attn_output)
        return output

    def merge_masks(
        self,
        attn_mask: Optional[torch.Tensor],
        key_padding_mask: Optional[torch.Tensor],
        query: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Combine ``attn_mask`` and ``key_padding_mask`` into one additive mask.

        Boolean masks become 0 where attention is allowed and ``-inf`` where the
        mask is ``True``; non-boolean masks are cast to the dtype of ``query``.
        The result broadcasts against (B, num_heads, T, T) and has the dtype of
        ``query``; ``None`` is returned when both inputs are ``None``.

        Raises:
            RuntimeError: if ``attn_mask`` is not 2D/3D or has the wrong shape.
        """
        merged_mask = None
        batch_size, seq_len, _ = query.shape
        mask_dtype = query.dtype

        def convert_to_float_mask(mask):
            if mask.dtype == torch.bool:
                # Build the additive mask directly in the query dtype instead of
                # hard-coding float32, so fp16/bf16/fp64 inputs get a matching mask.
                return torch.zeros_like(mask, dtype=mask_dtype).masked_fill(
                    mask, float("-inf")
                )
            return mask.to(mask_dtype)

        # key_padding_mask -> float mask
        if key_padding_mask is not None:
            # shape (B, T) -> (B, 1, 1, T) -> expand to (B, num_heads, 1, T)
            key_padding_mask = key_padding_mask.view(batch_size, 1, 1, seq_len).expand(
                -1, self.num_heads, -1, -1
            )
            merged_mask = convert_to_float_mask(key_padding_mask)

        # attn_mask -> float mask
        if attn_mask is not None:
            if attn_mask.dim() == 2:
                # shape (T, T) -> (B, num_heads, T, T)
                correct_2d_size = (seq_len, seq_len)
                if attn_mask.shape != correct_2d_size:
                    raise RuntimeError(
                        f"The shape of the 2D attn_mask is {attn_mask.shape}, "
                        f"but should be {correct_2d_size}."
                    )
                attn_mask = attn_mask.unsqueeze(0).expand(
                    batch_size, self.num_heads, -1, -1
                )
            elif attn_mask.dim() == 3:
                # shape (B*num_heads, T, T) -> (B, num_heads, T, T)
                correct_3d_size = (batch_size * self.num_heads, seq_len, seq_len)
                if attn_mask.shape != correct_3d_size:
                    raise RuntimeError(
                        f"The shape of the 3D attn_mask is {attn_mask.shape}, "
                        f"but should be {correct_3d_size}."
                    )
                attn_mask = attn_mask.view(batch_size, self.num_heads, seq_len, seq_len)
            else:
                raise RuntimeError(
                    f"attn_mask's dimension {attn_mask.dim()} is not supported"
                )

            attn_mask = convert_to_float_mask(attn_mask)

            if merged_mask is None:
                merged_mask = attn_mask
            else:
                # Additive masks combine by summation (-inf wins over any finite value).
                merged_mask = merged_mask + attn_mask

        return merged_mask


class TransformerEncoderLayer(nn.Module):
    """One encoder layer: self-attention followed by a GELU feed-forward network.

    Both sub-blocks sit inside a residual connection with RMSNorm, applied
    before the sub-block (pre-norm) or after the residual sum (post-norm)
    depending on ``norm_first``.

    Config keys read: ``embedding_dim``, ``heads`` (default 8),
    ``transformer_dropout``, ``dim_feedforward`` and ``norm_first``
    (default ``False``).
    """

    def __init__(self, config: dict):
        super().__init__()
        d_model = config["embedding_dim"]
        n_heads = config.get("heads", 8)
        p_drop = config["transformer_dropout"]
        d_hidden = config["dim_feedforward"]
        self.norm_first = config.get("norm_first", False)

        self.self_attn = CustomMultiheadAttention(d_model, n_heads, dropout=p_drop)
        self.linear1 = nn.Linear(d_model, d_hidden)
        self.dropout1 = nn.Dropout(p_drop)
        self.linear2 = nn.Linear(d_hidden, d_model)
        self.dropout2 = nn.Dropout(p_drop)

        self.norm1 = RMSNorm(d_model)
        self.norm2 = RMSNorm(d_model)
        self.activation = nn.GELU()

    def _sa_block(self, src, attn_mask=None, key_padding_mask=None):
        """Self-attention over ``src`` followed by ``dropout1``."""
        attended = self.self_attn(
            src, src, src, key_padding_mask=key_padding_mask, attn_mask=attn_mask
        )
        return self.dropout1(attended)

    def _ff_block(self, src):
        """linear1 -> GELU -> dropout2 -> linear2 -> dropout2."""
        hidden = self.activation(self.linear1(src))
        hidden = self.dropout2(hidden)
        projected = self.linear2(hidden)
        return self.dropout2(projected)

    def forward(
        self,
        src: torch.Tensor,
        src_key_padding_mask: torch.Tensor = None,
        src_mask: torch.Tensor = None,
    ):
        """Apply the layer to ``src``; both masks are forwarded to the attention block."""
        if not self.norm_first:
            # Post-norm: normalise after each residual sum.
            attended = self._sa_block(
                src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
            )
            src = self.norm1(src + attended)
            return self.norm2(src + self._ff_block(src))

        # Pre-norm: normalise the input of each sub-block, keep the residual raw.
        attended = self._sa_block(
            self.norm1(src),
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask,
        )
        src = src + attended
        return src + self._ff_block(self.norm2(src))


class TransformerEncoder(nn.Module):
    """Stack of ``config["num_transformer_layers"]`` encoder layers applied in order."""

    def __init__(self, config):
        super().__init__()
        depth = config["num_transformer_layers"]
        blocks = []
        for _ in range(depth):
            blocks.append(TransformerEncoderLayer(config))
        self.encoder = nn.ModuleList(blocks)

    def forward(
        self,
        src,
        src_key_padding_mask=None,
        src_mask=None,
    ):
        """
        src: shape (B, T, E)

        The same masks are handed to every layer; the output of the last layer
        is returned.
        """
        hidden = src
        for block in self.encoder:
            hidden = block(
                hidden,
                src_key_padding_mask=src_key_padding_mask,
                src_mask=src_mask,
            )
        return hidden


# Preparing vocabularies and the Dataset (with churn)


def prepare_vocabularies(train_df, val_df, test_df):
    """
    Build the lookup tables and size constants needed to create the embeddings.

    1) Ensures each df is a pandas DataFrame (anything without ``iloc`` is
       converted with ``.to_pandas()``).
    2) Builds dictionaries to map string IDs (customer_id, postal_code) to integer
       indices, numbered in order of first appearance across train, val and test.
    3) Finds max article ID, max day, and max age so we can define embedding sizes.

    Returns:
        (user2idx, postal2idx, num_customers, num_postal, num_articles, max_day, num_age)
        with all five size values as plain Python ints; ``num_articles`` and
        ``num_age`` are the respective maxima plus one.

    Raises:
        ValueError: if a list column holds no values at all across the splits.
    """

    def to_pandas_if_polars(df):
        return df.to_pandas() if not hasattr(df, "iloc") else df

    train_pd = to_pandas_if_polars(train_df)
    val_pd   = to_pandas_if_polars(val_df)
    test_pd  = to_pandas_if_polars(test_df)
    # Only the converted frames are used below, never the raw arguments.
    frames = [train_pd, val_pd, test_pd]

    # Combine for global vocabularies
    combined = pd.concat(frames, ignore_index=True)

    # Map string-based customer_id -> integer
    unique_users = combined['customer_id'].unique()
    user2idx = {u: i for i, u in enumerate(unique_users)}
    num_customers = len(user2idx)

    # Map string-based postal_code -> integer
    unique_postals = combined['postal_code'].unique()
    postal2idx = {p: i for i, p in enumerate(unique_postals)}
    num_postal = len(postal2idx)

    def max_over_list_column(column):
        # Each cell is a list of ints; scan them lazily instead of materialising
        # one huge list. int() turns numpy scalars into plain Python ints.
        return int(max(
            value
            for df_pd in frames
            for lst in df_pd[column]
            for value in lst
        ))

    # Determine max article ID
    max_article_id = max_over_list_column('articles_ids_lst')
    num_articles = max_article_id + 1  # for embedding

    # Determine max day
    max_day = max_over_list_column('days_before_lst')

    # Determine max age if we treat age as discrete
    max_age = int(combined['age'].max())
    num_age = max_age + 1

    return user2idx, postal2idx, num_customers, num_postal, num_articles, max_day, num_age


class CustomerDataset(Dataset):
    """
    Row-per-customer dataset yielding model-ready samples with both labels.

    Expects columns:
    - customer_id (str)
    - days_before_lst (list[int])
    - articles_ids_lst (list[int])
    - regression_label (float)
    - classification_label (int)  (0 means churn, 1 means not churn)
    - age (int)
    - postal_code (str)

    Each item is the 7-tuple
    ``(user_id, articles, days, age, postal_id, regression_label, classification_label)``.
    """
    def __init__(self, df, user2idx: Dict[str,int], postal2idx: Dict[str,int]):
        # Anything without `.iloc` is converted via `.to_pandas()` (e.g. Polars).
        self.data = df if hasattr(df, "iloc") else df.to_pandas()
        self.user2idx = user2idx
        self.postal2idx = postal2idx

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data.iloc[idx]

        # String IDs -> integer indices
        user_id = self.user2idx[record['customer_id']]
        postal_id = self.postal2idx[record['postal_code']]

        age = int(record['age'])  # embedding or numeric

        # Variable-length histories stay unpadded here; the collate fn fixes the length.
        articles = torch.tensor(record['articles_ids_lst'], dtype=torch.long)
        days = torch.tensor(record['days_before_lst'], dtype=torch.long)

        target_value = float(record['regression_label'])
        regression_label = torch.tensor(target_value, dtype=torch.float)
        class_value = int(record['classification_label'])
        classification_label = torch.tensor(class_value, dtype=torch.long)

        return (
            user_id, articles, days, age, postal_id,
            regression_label, classification_label,
        )


# Custom collate function for variable-length sequences
def fixed_length_collate_fn(batch, sequence_length=8):
    """
    Collate samples into batched tensors with fixed-length sequences.

    Every item of ``batch`` is a tuple
      (user_id, articles, days, age, postal_id, regression_label, classification_label)
    whose 'articles' and 'days' tensors are truncated to ``sequence_length``
    or right-padded with zeros up to it.

    Returns a 7-tuple of tensors:
      user_ids (B,), articles (B, sequence_length), days (B, sequence_length),
      ages (B,), postal_ids (B,), regression labels (B,), class labels (B,)
    """
    # Unpack strictly so a sample with the wrong number of fields fails here.
    samples = [
        (user_id, articles, days, age, postal_id, reg_label, cls_label)
        for (user_id, articles, days, age, postal_id, reg_label, cls_label) in batch
    ]

    def fit_to_length(seq):
        missing = sequence_length - seq.size(0)
        if missing > 0:
            filler = torch.zeros(missing, dtype=seq.dtype)
            return torch.cat([seq, filler], dim=0)
        # Long enough already: keep the first `sequence_length` entries.
        return seq[:sequence_length]

    article_seqs = [fit_to_length(sample[1]) for sample in samples]
    day_seqs = [fit_to_length(sample[2]) for sample in samples]

    user_ids_tensor = torch.tensor([sample[0] for sample in samples], dtype=torch.long)
    article_seqs_tensor = torch.stack(article_seqs, dim=0)  # (batch_size, sequence_length)
    day_seqs_tensor = torch.stack(day_seqs, dim=0)          # (batch_size, sequence_length)
    ages_tensor = torch.tensor([sample[3] for sample in samples], dtype=torch.long)
    postal_ids_tensor = torch.tensor([sample[4] for sample in samples], dtype=torch.long)
    reg_labels_tensor = torch.stack([sample[5] for sample in samples], dim=0)    # (batch_size,)
    class_labels_tensor = torch.stack([sample[6] for sample in samples], dim=0)  # (batch_size,)

    return (
        user_ids_tensor,
        article_seqs_tensor,
        day_seqs_tensor,
        ages_tensor,
        postal_ids_tensor,
        reg_labels_tensor,
        class_labels_tensor,
    )



# BST model with optional Churn Head


class PositionalEmbedding(nn.Module):
    """
    Learned positional embedding: one trainable vector per position index
    in ``0..max_len-1``.

    The input tensor is consulted only for its batch size, sequence length and
    device; its values do not influence the result.
    """
    def __init__(self, max_len, d_model):
        super().__init__()
        # Lookup table with ``max_len`` rows, each of width ``d_model``.
        self.pe = nn.Embedding(max_len, d_model)

    def forward(self, x):
        """
        Args:
            x: tensor of shape (batch_size, seq_length, d_model).

        Returns:
            Tensor of shape (batch_size, seq_length, d_model) holding the
            embeddings of positions ``0..seq_length-1`` for every batch row.
        """
        n_rows, n_steps, _ = x.shape
        step_index = torch.arange(n_steps, device=x.device)
        # Same index row for every sample in the batch.
        position_grid = step_index.unsqueeze(0).expand(n_rows, n_steps)
        return self.pe(position_grid)


class BST(pl.LightningModule):
    """
    Behavior Sequence Transformer with a regression head and an optional
    churn-classification head.

    Every time step is the concatenation of an article embedding and a day
    embedding, to which a learned positional embedding is added. The transformer
    output is flattened, concatenated with the customer / age / postal
    embeddings, and passed to the prediction head(s).

    Note: the transformer is called without any padding mask, so zero-padded
    positions produced by the collate function take part in attention.

    Args:
        num_customers: number of rows in the customer embedding table.
        num_articles: number of rows in the article embedding table.
        max_day: largest day index; the day table has ``max_day + 1`` rows.
        num_age: number of rows in the age embedding table.
        num_postal: number of rows in the postal-code embedding table.
        sequence_length: fixed number of time steps per sample; also the size of
            the positional table and the length every batch is padded/truncated to.
        train_df, val_df, test_df: frames wrapped in ``CustomerDataset`` by ``setup``.
        user2idx, postal2idx: ID-to-index mappings handed to ``CustomerDataset``.
        article_emb_dim, day_emb_dim, customer_emb_dim, age_emb_dim,
            postal_emb_dim: widths of the respective embedding tables.
        transformer_nhead: number of attention heads.
        transformer_ff_dim: hidden width of the transformer feed-forward block.
        num_transformer_layers: number of stacked encoder layers.
        predict_churn: when True, a binary classification head is added and its
            BCE-with-logits loss is summed with the regression loss.
        learning_rate: learning rate passed to AdamW.
        batch_size: batch size used by all three dataloaders.
        num_workers: worker process count used by all three dataloaders.
    """
    def __init__(
        self,
        num_customers,
        num_articles,
        max_day,
        num_age,
        num_postal,
        sequence_length,
        train_df,
        val_df,
        test_df,
        user2idx,
        postal2idx,
        # embedding dims
        article_emb_dim=16,
        day_emb_dim=8,
        customer_emb_dim=16,
        age_emb_dim=4,
        postal_emb_dim=4,
        # transformer config
        transformer_nhead=2,
        transformer_ff_dim=64,
        num_transformer_layers=1,
        # multi-task config
        predict_churn=False,  # <-- Flag to enable or disable churn prediction
        # training
        learning_rate=0.0005,
        # data loading
        batch_size=128,
        num_workers=4,
    ):
        super().__init__()
        # Data frames and lookup dicts are excluded from the saved hyperparameters.
        self.save_hyperparameters(ignore=['train_df','val_df','test_df','user2idx','postal2idx'])
        self.learning_rate = learning_rate
        self.predict_churn = predict_churn  # store the flag

        # DataFrames + Mappings
        self.train_df = train_df
        self.val_df   = val_df
        self.test_df  = test_df
        self.user2idx = user2idx
        self.postal2idx = postal2idx

        # Embeddings
        self.embeddings_customer = nn.Embedding(num_customers, customer_emb_dim)
        self.embeddings_age = nn.Embedding(num_age, age_emb_dim)
        self.embeddings_postal = nn.Embedding(num_postal, postal_emb_dim)

        self.embeddings_article = nn.Embedding(num_articles, article_emb_dim)
        self.embeddings_day = nn.Embedding(max_day + 1, day_emb_dim)

        # Sequence dimension
        self.seq_feature_dim = article_emb_dim + day_emb_dim
        self.positional_embedding = PositionalEmbedding(sequence_length, self.seq_feature_dim)

        # -------------------------
        # Custom Transformer setup
        # -------------------------
        config = {
            "embedding_dim": self.seq_feature_dim,
            "heads": transformer_nhead,
            "transformer_dropout": 0.2,
            "dim_feedforward": transformer_ff_dim,
            "norm_first": False,
            "num_transformer_layers": num_transformer_layers,
        }
        self.transformer = TransformerEncoder(config)

        # Flattened dimension after transformer
        self.transformer_output_dim = sequence_length * self.seq_feature_dim

        # User features dimension
        user_feature_dim = customer_emb_dim + age_emb_dim + postal_emb_dim

        # Combined dimension
        combined_dim = self.transformer_output_dim + user_feature_dim

        # -------------------------
        # Separate heads:
        #   1) Regression (always)
        #   2) Classification (optional)
        # -------------------------
        self.regressor_head = nn.Sequential(
            nn.Linear(combined_dim, 512),
            nn.LeakyReLU(),
            nn.Linear(512, 256),
            nn.LeakyReLU(),
            nn.Linear(256, 1)
        )

        if self.predict_churn:
            # Binary classification head (churn=0, not-churn=1)
            self.classifier_head = nn.Sequential(
                nn.Linear(combined_dim, 128),
                nn.LeakyReLU(),
                nn.Linear(128, 1)  # single logit for BCE
            )
            self.classification_criterion = nn.BCEWithLogitsLoss()

        # MSE for regression
        self.regression_criterion = nn.MSELoss()

    def encode_input(self, batch):
        """
        Turn a collated batch into one feature vector per sample.

        Args:
            batch: 7-tuple ``(user_id, articles, days, age, postal_id,
                regression_label, classification_label)``.

        Returns:
            ``(combined_features, regression_label, classification_label)``;
            the two labels are passed through untouched.
        """
        (user_id, articles, days, age, postal_id, regression_label, classification_label) = batch

        # Sequence embeddings
        article_embeds = self.embeddings_article(articles)  # (B, L, article_emb_dim)
        day_embeds = self.embeddings_day(days)              # (B, L, day_emb_dim)
        sequence_features = torch.cat([article_embeds, day_embeds], dim=-1)  # (B, L, seq_feature_dim)

        # Positional embeddings are added, not concatenated
        pos_embeds = self.positional_embedding(sequence_features)  # (B, L, seq_feature_dim)
        transformer_input = sequence_features + pos_embeds         # (B, L, seq_feature_dim)

        # Pass through our custom Transformer (no masks are supplied)
        transformer_output = self.transformer(transformer_input)   # (B, L, seq_feature_dim)

        # Flatten the sequence axis into the feature axis
        transformer_output_flat = transformer_output.reshape(transformer_output.size(0), -1)

        # User features
        customer_embed = self.embeddings_customer(user_id)
        age_embed = self.embeddings_age(age)
        postal_embed = self.embeddings_postal(postal_id)
        user_features = torch.cat([customer_embed, age_embed, postal_embed], dim=1)

        # Combine
        combined_features = torch.cat([transformer_output_flat, user_features], dim=1)
        return combined_features, regression_label, classification_label

    def forward(self, batch):
        """
        Returns:
            ``(reg_output, class_output, reg_label, class_label)`` where
            ``class_output`` is a logit tensor, or ``None`` when
            ``predict_churn`` is False.
        """
        features, reg_label, class_label = self.encode_input(batch)

        # 1) Regression output
        reg_output = self.regressor_head(features).squeeze(dim=-1)

        # 2) Classification output (only if predict_churn=True)
        if self.predict_churn:
            class_output = self.classifier_head(features).squeeze(dim=-1)  # logit
        else:
            class_output = None

        return reg_output, class_output, reg_label, class_label

    def _shared_step(self, batch, stage):
        """
        Compute and log the loss for one batch.

        Logs ``{stage}_reg_loss``, ``{stage}_class_loss`` (churn only) and
        ``{stage}_loss``; ``stage`` is ``"train"``, ``"val"`` or ``"test"``.
        """
        reg_output, class_output, reg_label, class_label = self(batch)

        # Regression loss is always part of the objective
        reg_loss = self.regression_criterion(reg_output, reg_label)
        self.log(f"{stage}_reg_loss", reg_loss)
        loss = reg_loss

        if self.predict_churn:
            # BCEWithLogitsLoss expects float targets of 0 or 1
            class_loss = self.classification_criterion(class_output, class_label.float())
            self.log(f"{stage}_class_loss", class_loss)
            loss = reg_loss + class_loss  # simple unweighted sum

        self.log(f"{stage}_loss", loss)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        return self._shared_step(batch, "test")

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)

    def setup(self, stage=None):
        """Create the datasets needed for the requested stage (all of them when ``stage`` is None)."""
        if stage == "fit" or stage is None:
            self.train_dataset = CustomerDataset(self.train_df, self.user2idx, self.postal2idx)
            self.val_dataset   = CustomerDataset(self.val_df,   self.user2idx, self.postal2idx)
        if stage == "test" or stage is None:
            self.test_dataset  = CustomerDataset(self.test_df,  self.user2idx, self.postal2idx)

    def _make_dataloader(self, dataset, shuffle):
        """Build a DataLoader whose batches are padded/truncated to ``sequence_length``."""
        # Imported here so the class does not rely on an import cell elsewhere.
        from functools import partial

        # A partial of a module-level function instead of a lambda: lambdas
        # cannot be pickled, which worker processes require when they are
        # started with the "spawn" method.
        collate = partial(
            fixed_length_collate_fn,
            sequence_length=self.hparams.sequence_length,
        )
        return DataLoader(
            dataset,
            batch_size=self.hparams.batch_size,
            shuffle=shuffle,
            num_workers=self.hparams.num_workers,
            collate_fn=collate,
        )

    def train_dataloader(self):
        return self._make_dataloader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._make_dataloader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        return self._make_dataloader(self.test_dataset, shuffle=False)


# train and test

# Multi-task switch: True trains regression + churn classification,
# False trains the regression head only. Only one model is built, so no
# throw-away set of embedding tables is allocated.
PREDICT_CHURN = True

model = BST(
    num_customers=num_customers,
    num_articles=num_articles,
    max_day=max_day,
    num_age=num_age,
    num_postal=num_postal,
    sequence_length=sequence_length,
    train_df=train_df,
    val_df=val_df,
    test_df=test_df,
    user2idx=user2idx,
    postal2idx=postal2idx,
    article_emb_dim=16,
    day_emb_dim=8,
    customer_emb_dim=16,
    age_emb_dim=4,
    postal_emb_dim=4,
    transformer_nhead=2,
    transformer_ff_dim=64,
    num_transformer_layers=1,
    predict_churn=PREDICT_CHURN,
    learning_rate=0.0005
)

# One epoch of training followed by evaluation on the test split.
trainer = pl.Trainer(accelerator="gpu", devices="auto", max_epochs=1)
trainer.fit(model)
trainer.test(model)



# REMOVED ADDITIVE POSITIONAL ENCODING

In [None]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from typing import List, Dict, Optional


# Custom Transformer classes

class RMSNorm(nn.Module):
    """
    Root-mean-square normalization over the last dimension with a learned
    per-feature gain. No mean is subtracted and there is no bias term.
    """
    def __init__(self, dim: int, eps: float = 1e-8) -> None:
        super().__init__()
        self.eps = eps
        # Multiplicative scale, initialised to ones.
        self.gain = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Scale ``x`` by the reciprocal RMS of its last axis, then by ``gain``."""
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return (x * inv_rms) * self.gain

class CustomMultiheadAttention(nn.Module):
    """
    Multi-head self-attention built on ``F.scaled_dot_product_attention``.

    Q, K and V all come from a single fused linear projection of ``query``;
    the ``key`` and ``value`` arguments of ``forward`` are accepted but not read.
    """
    def __init__(self, embed_dim, num_heads, dropout=0.0, bias=True):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # Fused projection producing Q, K and V side by side on the last axis.
        self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.dropout = dropout

    def forward(self, query, key, value, attn_mask=None, key_padding_mask=None):
        """
        Args:
            query: (B, T, embed_dim) input; the only tensor that is projected.
            key, value: unused.
            attn_mask: optional 2D (T, T) or 3D (B*num_heads, T, T) mask.
            key_padding_mask: optional (B, T) mask.

        Returns:
            (B, T, embed_dim) attention output after the output projection.
        """
        n_batch, n_tokens, _ = query.size()

        def split_heads(t):
            # (B, T, E) -> (B, num_heads, T, head_dim)
            return t.view(n_batch, n_tokens, self.num_heads, self.head_dim).transpose(1, 2)

        q, k, v = (split_heads(part) for part in self.in_proj(query).chunk(3, dim=-1))

        combined_mask = self.merge_masks(attn_mask, key_padding_mask, query)

        # Attention dropout is active only in training mode.
        p_drop = self.dropout if self.training else 0.0
        context = F.scaled_dot_product_attention(
            q, k, v, attn_mask=combined_mask, dropout_p=p_drop, is_causal=False
        )

        # (B, num_heads, T, head_dim) -> (B, T, E)
        context = context.transpose(1, 2).contiguous().view(n_batch, n_tokens, self.embed_dim)
        return self.out_proj(context)

    @staticmethod
    def _to_additive(mask):
        """Boolean mask -> float mask with -inf where True and 0 elsewhere; other dtypes pass through."""
        if mask.dtype != torch.bool:
            return mask
        return mask.float().masked_fill(mask, float("-inf"))

    def merge_masks(
        self,
        attn_mask: Optional[torch.Tensor],
        key_padding_mask: Optional[torch.Tensor],
        query: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """
        Combine the two optional masks into one additive mask.

        Returns ``None`` when both inputs are ``None``. A key-padding mask alone
        yields shape (B, num_heads, 1, T); an attention mask yields
        (B, num_heads, T, T). When both are given they are summed.

        Raises:
            RuntimeError: if ``attn_mask`` is not 2D/3D or has the wrong shape.
        """
        batch_size, seq_len, _ = query.shape
        heads = self.num_heads

        padding_part = None
        if key_padding_mask is not None:
            # (B, T) -> (B, 1, 1, T) -> (B, num_heads, 1, T)
            broadcast_padding = key_padding_mask.view(batch_size, 1, 1, seq_len).expand(
                -1, heads, -1, -1
            )
            padding_part = self._to_additive(broadcast_padding)

        if attn_mask is None:
            return padding_part

        rank = attn_mask.dim()
        if rank == 2:
            # (T, T) -> (B, num_heads, T, T)
            expected_2d = (seq_len, seq_len)
            if attn_mask.shape != expected_2d:
                raise RuntimeError(
                    f"The shape of the 2D attn_mask is {attn_mask.shape}, "
                    f"but should be {expected_2d}."
                )
            attn_mask = attn_mask.unsqueeze(0).expand(batch_size, heads, -1, -1)
        elif rank == 3:
            # (B*num_heads, T, T) -> (B, num_heads, T, T)
            expected_3d = (batch_size * heads, seq_len, seq_len)
            if attn_mask.shape != expected_3d:
                raise RuntimeError(
                    f"The shape of the 3D attn_mask is {attn_mask.shape}, "
                    f"but should be {expected_3d}."
                )
            attn_mask = attn_mask.view(batch_size, heads, seq_len, seq_len)
        else:
            raise RuntimeError(
                f"attn_mask's dimension {attn_mask.dim()} is not supported"
            )

        attn_part = self._to_additive(attn_mask)
        if padding_part is None:
            return attn_part
        return padding_part + attn_part


class TransformerEncoderLayer(nn.Module):
    """
    One encoder layer: self-attention and a GELU feed-forward block, each
    wrapped in a residual connection with RMSNorm.

    Config keys read: ``embedding_dim``, ``heads`` (default 8),
    ``transformer_dropout``, ``dim_feedforward`` and ``norm_first``
    (default False, i.e. post-norm).
    """
    def __init__(self, config: dict):
        super().__init__()
        d_model = config["embedding_dim"]
        n_heads = config.get("heads", 8)
        p_drop = config["transformer_dropout"]
        d_hidden = config["dim_feedforward"]
        self.norm_first = config.get("norm_first", False)

        self.self_attn = CustomMultiheadAttention(d_model, n_heads, dropout=p_drop)
        self.linear1 = nn.Linear(d_model, d_hidden)
        self.dropout1 = nn.Dropout(p_drop)
        self.linear2 = nn.Linear(d_hidden, d_model)
        self.dropout2 = nn.Dropout(p_drop)

        self.norm1 = RMSNorm(d_model)
        self.norm2 = RMSNorm(d_model)
        self.activation = nn.GELU()

    def _sa_block(self, src, attn_mask=None, key_padding_mask=None):
        """Self-attention over ``src`` followed by dropout."""
        attended = self.self_attn(
            src, src, src, key_padding_mask=key_padding_mask, attn_mask=attn_mask
        )
        return self.dropout1(attended)

    def _ff_block(self, src):
        """linear1 -> GELU -> dropout -> linear2 -> dropout (``dropout2`` is used both times)."""
        hidden = self.activation(self.linear1(src))
        projected = self.linear2(self.dropout2(hidden))
        return self.dropout2(projected)

    def forward(
        self,
        src: torch.Tensor,
        src_key_padding_mask: torch.Tensor = None,
        src_mask: torch.Tensor = None,
    ):
        """Apply the layer in pre-norm or post-norm order depending on ``norm_first``."""
        if not self.norm_first:
            # Post-norm: normalise after each residual sum.
            attended = self.norm1(
                src
                + self._sa_block(
                    src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
                )
            )
            return self.norm2(attended + self._ff_block(attended))

        # Pre-norm: normalise the input of each sub-block, keep the residual raw.
        attended = src + self._sa_block(
            self.norm1(src),
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask,
        )
        return attended + self._ff_block(self.norm2(attended))


class TransformerEncoder(nn.Module):
    """Stack of ``config["num_transformer_layers"]`` encoder layers applied in sequence."""
    def __init__(self, config):
        super().__init__()
        depth = config["num_transformer_layers"]
        self.encoder = nn.ModuleList()
        for _ in range(depth):
            self.encoder.append(TransformerEncoderLayer(config))

    def forward(
        self,
        src,
        src_key_padding_mask=None,
        src_mask=None,
    ):
        """
        src: shape (B, T, E)

        The same masks are forwarded to every layer; the output of the last
        layer is returned (``src`` itself when the stack is empty).
        """
        hidden = src
        for block in self.encoder:
            hidden = block(
                hidden, src_key_padding_mask=src_key_padding_mask, src_mask=src_mask
            )
        return hidden



# Preparing vocabularies and the Dataset (with churn)


def prepare_vocabularies(train_df, val_df, test_df):
    """
    Build ID vocabularies and embedding-table sizes from the three splits.

    1) Converts each frame to pandas when it has no ``iloc`` attribute
       (by calling its ``to_pandas()``).
    2) Maps every distinct ``customer_id`` / ``postal_code`` across all splits
       to a consecutive integer index, in order of first appearance.
    3) Finds the largest article ID, day value and age across all splits.

    Returns:
        ``(user2idx, postal2idx, num_customers, num_postal, num_articles,
        max_day, num_age)`` where ``num_articles = max article id + 1`` and
        ``num_age = max age + 1``. All counts are plain Python ``int`` so they
        can be used directly as embedding sizes and saved as hyperparameters.

    Raises:
        ValueError: if a list column holds no values at all, or ``age`` has no
            non-missing value.
    """

    def to_pandas_if_polars(df):
        return df if hasattr(df, "iloc") else df.to_pandas()

    frames = [to_pandas_if_polars(df) for df in (train_df, val_df, test_df)]

    # Combine for global vocabularies
    combined = pd.concat(frames, ignore_index=True)

    def build_index(column):
        # unique() keeps first-appearance order, so indices are deterministic
        return {value: i for i, value in enumerate(combined[column].unique())}

    def max_of_list_column(column):
        # Stream over the nested lists instead of materialising one flat list.
        largest = max(
            (value for frame in frames for lst in frame[column] for value in lst),
            default=None,
        )
        if largest is None:
            raise ValueError(f"Column '{column}' contains no values.")
        return int(largest)

    # Map string-based customer_id -> integer
    user2idx = build_index('customer_id')
    num_customers = len(user2idx)

    # Map string-based postal_code -> integer
    postal2idx = build_index('postal_code')
    num_postal = len(postal2idx)

    # Article IDs index the embedding table directly, hence max + 1 rows
    num_articles = max_of_list_column('articles_ids_lst') + 1

    max_day = max_of_list_column('days_before_lst')

    # Age is treated as a discrete index. The column may be float (e.g. when it
    # contains missing values), so cast explicitly to get an integer size.
    max_age = combined['age'].max()
    if pd.isna(max_age):
        raise ValueError("Column 'age' contains no non-missing values.")
    num_age = int(max_age) + 1

    return user2idx, postal2idx, num_customers, num_postal, num_articles, max_day, num_age


class CustomerDataset(Dataset):
    """
    Row-wise dataset over a customer frame.

    Columns read from each row:
    - customer_id (looked up in ``user2idx``)
    - postal_code (looked up in ``postal2idx``)
    - age (cast to int)
    - articles_ids_lst (list[int])
    - days_before_lst (list[int])
    - regression_label (cast to float)
    - classification_label (cast to int; 0 means churn, 1 means not churn)
    """
    def __init__(self, df, user2idx: Dict[str,int], postal2idx: Dict[str,int]):
        # Objects without ``iloc`` (e.g. Polars frames) are converted to pandas.
        self.data = df if hasattr(df, "iloc") else df.to_pandas()

        self.user2idx = user2idx
        self.postal2idx = postal2idx

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """
        Returns the 7-tuple ``(user_id, articles, days, age, postal_id,
        regression_label, classification_label)``. ``user_id``, ``age`` and
        ``postal_id`` are plain values; the sequences and labels are tensors.
        """
        record = self.data.iloc[idx]

        # String IDs -> integer indices
        user_idx = self.user2idx[record['customer_id']]
        postal_idx = self.postal2idx[record['postal_code']]

        age_value = int(record['age'])

        # Variable-length integer sequences
        article_seq = torch.tensor(record['articles_ids_lst'], dtype=torch.long)
        day_seq = torch.tensor(record['days_before_lst'], dtype=torch.long)

        # Scalar targets
        reg_target = torch.tensor(float(record['regression_label']), dtype=torch.float)
        cls_target = torch.tensor(int(record['classification_label']), dtype=torch.long)

        return (user_idx, article_seq, day_seq, age_value, postal_idx, reg_target, cls_target)


# Custom collate function for variable-length sequences
def fixed_length_collate_fn(batch, sequence_length=8):
    """
    Collate samples into batch tensors with fixed-length sequences.

    Each sample is the tuple
      (user_id, articles, days, age, postal_id, regression_label, classification_label).
    ``articles`` and ``days`` are cut to their first ``sequence_length`` entries
    or right-padded with zeros up to that length.

    Returns a 7-tuple of tensors in the same field order: IDs, ages and postal
    IDs as long tensors of shape (batch_size,), the two sequences with shape
    (batch_size, sequence_length), and the stacked labels.
    """
    user_ids, ages, postal_ids = [], [], []
    article_seqs, day_seqs = [], []
    reg_labels, class_labels = [], []

    # 1) Split the samples into per-field lists
    for uid, arts, dys, age_value, pid, reg, cls in batch:
        user_ids.append(uid)
        article_seqs.append(arts)
        day_seqs.append(dys)
        ages.append(age_value)
        postal_ids.append(pid)
        reg_labels.append(reg)
        class_labels.append(cls)

    # 2) Bring every sequence to exactly ``sequence_length`` entries
    def fit_to_length(seq):
        missing = sequence_length - seq.size(0)
        if missing > 0:
            filler = torch.zeros(missing, dtype=seq.dtype)
            return torch.cat([seq, filler], dim=0)
        if missing < 0:
            return seq[:sequence_length]
        return seq

    article_seqs = [fit_to_length(seq) for seq in article_seqs]
    day_seqs = [fit_to_length(seq) for seq in day_seqs]

    # 3) Assemble the batch tensors
    return (
        torch.tensor(user_ids, dtype=torch.long),
        torch.stack(article_seqs, dim=0),   # (batch_size, sequence_length)
        torch.stack(day_seqs, dim=0),       # (batch_size, sequence_length)
        torch.tensor(ages, dtype=torch.long),
        torch.tensor(postal_ids, dtype=torch.long),
        torch.stack(reg_labels, dim=0),     # (batch_size,)
        torch.stack(class_labels, dim=0),   # (batch_size,)
    )



# BST model WITHOUT a separate positional embedding


class BST(pl.LightningModule):
    """
    Behavior Sequence Transformer without a separate positional embedding,
    with a regression head and an optional churn-classification head.

    Every time step is the concatenation of an article embedding and a day
    embedding; the day embedding is the only time/position signal. The
    transformer output is flattened, concatenated with the customer / age /
    postal embeddings, and passed to the prediction head(s).

    Note: the transformer is called without any padding mask, so zero-padded
    positions produced by the collate function take part in attention.

    Args:
        num_customers: number of rows in the customer embedding table.
        num_articles: number of rows in the article embedding table.
        max_day: largest day index; the day table has ``max_day + 1`` rows.
        num_age: number of rows in the age embedding table.
        num_postal: number of rows in the postal-code embedding table.
        sequence_length: fixed number of time steps per sample; every batch is
            padded/truncated to this length.
        train_df, val_df, test_df: frames wrapped in ``CustomerDataset`` by ``setup``.
        user2idx, postal2idx: ID-to-index mappings handed to ``CustomerDataset``.
        article_emb_dim, day_emb_dim, customer_emb_dim, age_emb_dim,
            postal_emb_dim: widths of the respective embedding tables.
        transformer_nhead: number of attention heads.
        transformer_ff_dim: hidden width of the transformer feed-forward block.
        num_transformer_layers: number of stacked encoder layers.
        predict_churn: when True, a binary classification head is added and its
            BCE-with-logits loss is summed with the regression loss.
        learning_rate: learning rate passed to AdamW.
        batch_size: batch size used by all three dataloaders.
        num_workers: worker process count used by all three dataloaders.
    """
    def __init__(
        self,
        num_customers,
        num_articles,
        max_day,
        num_age,
        num_postal,
        sequence_length,
        train_df,
        val_df,
        test_df,
        user2idx,
        postal2idx,
        # embedding dims
        article_emb_dim=16,
        day_emb_dim=8,
        customer_emb_dim=16,
        age_emb_dim=4,
        postal_emb_dim=4,
        # transformer config
        transformer_nhead=2,
        transformer_ff_dim=64,
        num_transformer_layers=1,
        # multi-task config
        predict_churn=False,  # <-- Flag to enable or disable churn prediction
        # training
        learning_rate=0.0005,
        # data loading
        batch_size=128,
        num_workers=4,
    ):
        super().__init__()
        # Data frames and lookup dicts are excluded from the saved hyperparameters.
        self.save_hyperparameters(ignore=['train_df','val_df','test_df','user2idx','postal2idx'])
        self.learning_rate = learning_rate
        self.predict_churn = predict_churn  # store the flag

        # DataFrames + Mappings
        self.train_df = train_df
        self.val_df   = val_df
        self.test_df  = test_df
        self.user2idx = user2idx
        self.postal2idx = postal2idx

        # Embeddings
        self.embeddings_customer = nn.Embedding(num_customers, customer_emb_dim)
        self.embeddings_age = nn.Embedding(num_age, age_emb_dim)
        self.embeddings_postal = nn.Embedding(num_postal, postal_emb_dim)

        self.embeddings_article = nn.Embedding(num_articles, article_emb_dim)
        self.embeddings_day = nn.Embedding(max_day + 1, day_emb_dim)

        # We treat "day_embeds" as the positional/time feature.
        # So final dimension of each timestep = article_emb_dim + day_emb_dim
        self.seq_feature_dim = article_emb_dim + day_emb_dim

        # -------------------------
        # Custom Transformer setup
        # -------------------------
        config = {
            "embedding_dim": self.seq_feature_dim,
            "heads": transformer_nhead,
            "transformer_dropout": 0.2,
            "dim_feedforward": transformer_ff_dim,
            "norm_first": False,
            "num_transformer_layers": num_transformer_layers,
        }
        self.transformer = TransformerEncoder(config)

        # Flattened dimension after transformer
        self.transformer_output_dim = sequence_length * self.seq_feature_dim

        # User features dimension
        user_feature_dim = customer_emb_dim + age_emb_dim + postal_emb_dim

        # Combined dimension
        combined_dim = self.transformer_output_dim + user_feature_dim

        # -------------------------
        # Separate heads:
        #   1) Regression (always)
        #   2) Classification (optional)
        # -------------------------
        self.regressor_head = nn.Sequential(
            nn.Linear(combined_dim, 512),
            nn.LeakyReLU(),
            nn.Linear(512, 256),
            nn.LeakyReLU(),
            nn.Linear(256, 1)
        )

        if self.predict_churn:
            # Binary classification head (churn=0, not-churn=1)
            self.classifier_head = nn.Sequential(
                nn.Linear(combined_dim, 128),
                nn.LeakyReLU(),
                nn.Linear(128, 1)  # single logit for BCE
            )
            self.classification_criterion = nn.BCEWithLogitsLoss()

        # MSE for regression
        self.regression_criterion = nn.MSELoss()

    def encode_input(self, batch):
        """
        Turn a collated batch into one feature vector per sample.

        Args:
            batch: 7-tuple ``(user_id, articles, days, age, postal_id,
                regression_label, classification_label)``.

        Returns:
          combined_features: (B, combined_dim)
          regression_label:  (B,)   passed through untouched
          classification_label: (B,)   passed through untouched
        """
        (user_id, articles, days, age, postal_id, regression_label, classification_label) = batch

        # Sequence embeddings
        article_embeds = self.embeddings_article(articles)  # (B, L, article_emb_dim)
        day_embeds = self.embeddings_day(days)              # (B, L, day_emb_dim)

        # No separate positional embedding.
        # We treat 'days' as our time/pos feature.
        transformer_input = torch.cat([article_embeds, day_embeds], dim=-1)
        # shape: (B, L, seq_feature_dim)

        # Pass through our custom Transformer (no masks are supplied)
        transformer_output = self.transformer(transformer_input)
        # shape: (B, L, seq_feature_dim)

        # Flatten the sequence dimension
        transformer_output_flat = transformer_output.reshape(transformer_output.size(0), -1)

        # User-level features
        customer_embed = self.embeddings_customer(user_id)
        age_embed = self.embeddings_age(age)
        postal_embed = self.embeddings_postal(postal_id)
        user_features = torch.cat([customer_embed, age_embed, postal_embed], dim=1)

        # Combine sequence output + user features
        combined_features = torch.cat([transformer_output_flat, user_features], dim=1)

        return combined_features, regression_label, classification_label

    def forward(self, batch):
        """
        Returns:
            ``(reg_output, class_output, reg_label, class_label)`` where
            ``class_output`` is a logit tensor, or ``None`` when
            ``predict_churn`` is False.
        """
        features, reg_label, class_label = self.encode_input(batch)

        # 1) Regression output
        reg_output = self.regressor_head(features).squeeze(dim=-1)

        # 2) Classification output (only if predict_churn=True)
        if self.predict_churn:
            class_output = self.classifier_head(features).squeeze(dim=-1)  # logit
        else:
            class_output = None

        return reg_output, class_output, reg_label, class_label

    def _shared_step(self, batch, stage):
        """
        Compute and log the loss for one batch.

        Logs ``{stage}_reg_loss``, ``{stage}_class_loss`` (churn only) and
        ``{stage}_loss``; ``stage`` is ``"train"``, ``"val"`` or ``"test"``.
        """
        reg_output, class_output, reg_label, class_label = self(batch)

        # Regression loss is always part of the objective
        reg_loss = self.regression_criterion(reg_output, reg_label)
        self.log(f"{stage}_reg_loss", reg_loss)
        loss = reg_loss

        if self.predict_churn:
            # BCEWithLogitsLoss expects float targets of 0 or 1
            class_loss = self.classification_criterion(class_output, class_label.float())
            self.log(f"{stage}_class_loss", class_loss)
            loss = reg_loss + class_loss  # simple unweighted sum

        self.log(f"{stage}_loss", loss)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        return self._shared_step(batch, "test")

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)

    def setup(self, stage=None):
        """Create the datasets needed for the requested stage (all of them when ``stage`` is None)."""
        if stage == "fit" or stage is None:
            self.train_dataset = CustomerDataset(self.train_df, self.user2idx, self.postal2idx)
            self.val_dataset   = CustomerDataset(self.val_df,   self.user2idx, self.postal2idx)
        if stage == "test" or stage is None:
            self.test_dataset  = CustomerDataset(self.test_df,  self.user2idx, self.postal2idx)

    def _make_dataloader(self, dataset, shuffle):
        """Build a DataLoader whose batches are padded/truncated to ``sequence_length``."""
        # Imported here so the class does not rely on an import cell elsewhere.
        from functools import partial

        # A partial of a module-level function instead of a lambda: lambdas
        # cannot be pickled, which worker processes require when they are
        # started with the "spawn" method.
        collate = partial(
            fixed_length_collate_fn,
            sequence_length=self.hparams.sequence_length,
        )
        return DataLoader(
            dataset,
            batch_size=self.hparams.batch_size,
            shuffle=shuffle,
            num_workers=self.hparams.num_workers,
            collate_fn=collate,
        )

    def train_dataloader(self):
        return self._make_dataloader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._make_dataloader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        return self._make_dataloader(self.test_dataset, shuffle=False)



# Train and Test


# Data-dependent sizes, splits and lookup tables (all computed earlier in the notebook).
bst_data_args = dict(
    num_customers=num_customers,
    num_articles=num_articles,
    max_day=max_day,
    num_age=num_age,
    num_postal=num_postal,
    sequence_length=sequence_length,
    train_df=train_df,
    val_df=val_df,
    test_df=test_df,
    user2idx=user2idx,
    postal2idx=postal2idx,
)

# Architecture and optimisation settings for the multi-task run.
bst_model_args = dict(
    article_emb_dim=16,
    day_emb_dim=8,
    customer_emb_dim=16,
    age_emb_dim=4,
    postal_emb_dim=4,
    transformer_nhead=2,
    transformer_ff_dim=64,
    num_transformer_layers=1,
    predict_churn=True,   # <--- Enable churn
    learning_rate=0.0005,
)

model = BST(**bst_data_args, **bst_model_args)

# One epoch of training followed by evaluation on the test split.
trainer = pl.Trainer(accelerator="gpu", devices="auto", max_epochs=1)
trainer.fit(model)
trainer.test(model)



# REMOVED customer_id

In [None]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from typing import List, Dict, Optional


# Custom Transformer Classes

class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learnable gain.

    Every vector along the last axis is multiplied by
    ``1 / sqrt(mean(x ** 2) + eps)`` and then by the per-feature ``gain``.
    """

    def __init__(self, dim: int, eps: float = 1e-8) -> None:
        super().__init__()
        self.eps = eps
        # Gain starts at one, so a fresh layer applies pure RMS scaling.
        self.gain = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean_square = torch.mean(x.pow(2), dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return (x * inv_rms) * self.gain

class CustomMultiheadAttention(nn.Module):
    """Multi-head self-attention built on ``F.scaled_dot_product_attention``.

    One linear layer projects ``query`` into queries, keys and values. The
    ``key`` and ``value`` arguments of ``forward`` are accepted but never
    read, so the layer always attends ``query`` to itself.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.0, bias=True):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # Fused q/k/v projection followed by the output projection.
        self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.dropout = dropout

    def forward(self, query, key, value, attn_mask=None, key_padding_mask=None):
        """Return the attended tensor, shaped like ``query`` (batch, steps, embed_dim)."""
        batch, steps, _ = query.size()

        def split_heads(tensor):
            # (batch, steps, embed_dim) -> (batch, heads, steps, head_dim)
            return tensor.view(batch, steps, self.num_heads, self.head_dim).transpose(1, 2)

        q, k, v = (split_heads(part) for part in self.in_proj(query).chunk(3, dim=-1))

        mask = self.merge_masks(attn_mask, key_padding_mask, query)
        # Attention dropout is only active in training mode.
        drop_p = self.dropout if self.training else 0.0
        per_head = F.scaled_dot_product_attention(
            q, k, v, attn_mask=mask, dropout_p=drop_p, is_causal=False
        )

        joined = per_head.transpose(1, 2).contiguous().view(batch, steps, self.embed_dim)
        return self.out_proj(joined)

    def merge_masks(self,
                    attn_mask: Optional[torch.Tensor],
                    key_padding_mask: Optional[torch.Tensor],
                    query: torch.Tensor) -> Optional[torch.Tensor]:
        """Combine the padding mask and the attention mask into one additive mask.

        Boolean masks are converted to float: ``True`` becomes ``-inf`` and
        ``False`` becomes ``0``. Non-boolean masks are used unchanged. Returns
        ``None`` when both inputs are ``None``.

        Raises:
            RuntimeError: if ``attn_mask`` is not 2D/3D or has the wrong shape.
        """
        batch_size, seq_len, _ = query.shape

        def as_additive(mask):
            if mask.dtype != torch.bool:
                return mask
            return mask.float().masked_fill(mask, float("-inf"))

        padding_part = None
        if key_padding_mask is not None:
            per_head_padding = key_padding_mask.view(batch_size, 1, 1, seq_len).expand(
                -1, self.num_heads, -1, -1)
            padding_part = as_additive(per_head_padding)

        if attn_mask is None:
            return padding_part

        rank = attn_mask.dim()
        if rank == 2:
            correct_2d_size = (seq_len, seq_len)
            if attn_mask.shape != correct_2d_size:
                raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, "
                                   f"but should be {correct_2d_size}.")
            attn_mask = attn_mask.unsqueeze(0).expand(batch_size, self.num_heads, -1, -1)
        elif rank == 3:
            correct_3d_size = (batch_size * self.num_heads, seq_len, seq_len)
            if attn_mask.shape != correct_3d_size:
                raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, "
                                   f"but should be {correct_3d_size}.")
            attn_mask = attn_mask.view(batch_size, self.num_heads, seq_len, seq_len)
        else:
            raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")

        attention_part = as_additive(attn_mask)
        if padding_part is None:
            return attention_part
        return padding_part + attention_part

class TransformerEncoderLayer(nn.Module):
    """One encoder block: self-attention plus a GELU feed-forward network.

    Both sub-blocks sit in a residual connection with ``RMSNorm``. With
    ``norm_first`` the norm is applied to each sub-block's input; otherwise
    it is applied after the residual sum.

    ``config`` keys read: ``embedding_dim``, ``heads`` (default 8),
    ``transformer_dropout``, ``dim_feedforward`` and ``norm_first``
    (default ``False``).
    """

    def __init__(self, config: dict):
        super().__init__()
        width = config["embedding_dim"]
        head_count = config.get("heads", 8)
        drop_rate = config["transformer_dropout"]
        hidden_width = config["dim_feedforward"]
        self.norm_first = config.get("norm_first", False)

        self.self_attn = CustomMultiheadAttention(width, head_count, dropout=drop_rate)
        self.linear1 = nn.Linear(width, hidden_width)
        self.dropout1 = nn.Dropout(drop_rate)
        self.linear2 = nn.Linear(hidden_width, width)
        self.dropout2 = nn.Dropout(drop_rate)
        self.norm1 = RMSNorm(width)
        self.norm2 = RMSNorm(width)
        self.activation = nn.GELU()

    def _sa_block(self, src, attn_mask=None, key_padding_mask=None):
        """Self-attention over ``src`` followed by dropout."""
        attended = self.self_attn(
            src, src, src, key_padding_mask=key_padding_mask, attn_mask=attn_mask
        )
        return self.dropout1(attended)

    def _ff_block(self, src):
        """Feed-forward network; ``dropout2`` is applied after the activation and after ``linear2``."""
        hidden = self.activation(self.linear1(src))
        hidden = self.dropout2(hidden)
        projected = self.linear2(hidden)
        return self.dropout2(projected)

    def forward(self,
                src: torch.Tensor,
                src_key_padding_mask: torch.Tensor = None,
                src_mask: torch.Tensor = None):
        mask_kwargs = dict(attn_mask=src_mask, key_padding_mask=src_key_padding_mask)
        if not self.norm_first:
            # Post-norm: normalise after each residual sum.
            attended = self.norm1(src + self._sa_block(src, **mask_kwargs))
            return self.norm2(attended + self._ff_block(attended))
        # Pre-norm: normalise the input of each sub-block.
        attended = src + self._sa_block(self.norm1(src), **mask_kwargs)
        return attended + self._ff_block(self.norm2(attended))

class TransformerEncoder(nn.Module):
    """Stack of ``config["num_transformer_layers"]`` encoder layers applied in order."""

    def __init__(self, config):
        super().__init__()
        depth = config["num_transformer_layers"]
        layers = []
        for _ in range(depth):
            layers.append(TransformerEncoderLayer(config))
        self.encoder = nn.ModuleList(layers)

    def forward(self,
                src,
                src_key_padding_mask=None,
                src_mask=None):
        """Feed ``src`` through every layer, passing the same masks to each."""
        hidden = src
        for block in self.encoder:
            hidden = block(
                hidden,
                src_key_padding_mask=src_key_padding_mask,
                src_mask=src_mask,
            )
        return hidden


# Preparing vocabularies and the dataset (without customer_id)


def prepare_vocabularies(train_df, val_df, test_df):
    """Build the postal-code vocabulary and the sizes needed for the embedding tables.

    Each input is converted with ``.to_pandas()`` unless it already exposes
    ``.iloc``. All three splits are scanned together.

    Args:
        train_df, val_df, test_df: frames with the columns ``postal_code``,
            ``articles_ids_lst``, ``days_before_lst`` and ``age``.

    Returns:
        Tuple ``(postal2idx, num_postal, num_articles, max_day, num_age)``:
        ``postal2idx`` maps each distinct postal code to an index in order of
        first appearance; ``num_articles`` is the largest article id plus one;
        ``max_day`` is the largest value in ``days_before_lst``; ``num_age``
        is the largest age plus one. All four sizes are plain Python ``int``
        so they can be passed directly to ``nn.Embedding``, even when the
        source columns hold floats or numpy scalars.

    Raises:
        ValueError: if a list column holds no values in any split, or if
            ``age`` has no non-missing value.
    """
    def to_pandas_if_polars(df):
        return df if hasattr(df, "iloc") else df.to_pandas()

    frames = [to_pandas_if_polars(df) for df in (train_df, val_df, test_df)]
    combined = pd.concat(frames, ignore_index=True)

    postal2idx = {p: i for i, p in enumerate(combined['postal_code'].unique())}
    num_postal = len(postal2idx)

    def max_over_lists(column):
        # Running max over the rows: avoids copying every element of every
        # list into one huge list just to take its maximum. Empty rows are
        # skipped; max() still raises ValueError if nothing is left.
        per_row_max = (
            max(lst)
            for df_pd in frames
            for lst in df_pd[column]
            if len(lst) > 0
        )
        return int(max(per_row_max))

    num_articles = max_over_lists('articles_ids_lst') + 1
    max_day = max_over_lists('days_before_lst')

    # int() guards against a float age column producing a float table size.
    num_age = int(combined['age'].max()) + 1

    return postal2idx, num_postal, num_articles, max_day, num_age

class CustomerDataset(Dataset):
    """Row-per-customer dataset; ``customer_id`` is not used.

    Columns read from each row:
      - postal_code (looked up in ``postal2idx``)
      - days_before_lst (list[int])
      - articles_ids_lst (list[int])
      - regression_label (float)
      - classification_label (int) (0 means churn, 1 means not churn)
      - age (int)
    """

    def __init__(self, df, postal2idx: Dict[str, int]):
        # Anything without ``.iloc`` is converted via ``.to_pandas()``.
        self.data = df if hasattr(df, "iloc") else df.to_pandas()
        self.postal2idx = postal2idx

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(articles, days, age, postal_id, regression_label, classification_label)``."""
        record = self.data.iloc[idx]
        postal_id = self.postal2idx[record['postal_code']]
        age = int(record['age'])

        def as_long(column):
            return torch.tensor(record[column], dtype=torch.long)

        articles = as_long('articles_ids_lst')
        days = as_long('days_before_lst')
        regression_label = torch.tensor(float(record['regression_label']), dtype=torch.float)
        classification_label = torch.tensor(int(record['classification_label']), dtype=torch.long)
        return articles, days, age, postal_id, regression_label, classification_label

# Custom collate function for variable-length sequences (without customer_id)
def fixed_length_collate_fn(batch, sequence_length=8):
    """Collate dataset items into tensors whose sequences have exactly ``sequence_length`` steps.

    Each item is ``(articles, days, age, postal_id, reg_label, cls_label)``.
    Article and day sequences are cut to ``sequence_length`` or right-padded
    with zeros. Returns ``(articles, days, ages, postal_ids, reg_labels,
    class_labels)`` stacked along a new first dimension.
    """

    def to_fixed(seq):
        shortfall = sequence_length - seq.size(0)
        if shortfall > 0:
            filler = torch.zeros(shortfall, dtype=seq.dtype)
            return torch.cat([seq, filler], dim=0)
        # Long enough already: keep the first ``sequence_length`` entries.
        return seq[:sequence_length]

    fixed_articles, fixed_days = [], []
    ages, postal_ids, reg_labels, class_labels = [], [], [], []
    for articles, days, age, postal_id, reg_label, cls_label in batch:
        fixed_articles.append(to_fixed(articles))
        fixed_days.append(to_fixed(days))
        ages.append(age)
        postal_ids.append(postal_id)
        reg_labels.append(reg_label)
        class_labels.append(cls_label)

    return (
        torch.stack(fixed_articles, dim=0),
        torch.stack(fixed_days, dim=0),
        torch.tensor(ages, dtype=torch.long),
        torch.tensor(postal_ids, dtype=torch.long),
        torch.stack(reg_labels, dim=0),
        torch.stack(class_labels, dim=0),
    )


# BST model WITHOUT customer_id input

class BST(pl.LightningModule):
    """Sequence model over article/day histories, without a customer-id embedding.

    Article and day embeddings are concatenated per step, passed through the
    custom ``TransformerEncoder``, flattened, and joined with age and
    postal-code embeddings. A regression head (MSE loss) is always trained.
    With ``predict_churn=True`` a single-logit classification head
    (BCE-with-logits loss) is added and the two losses are summed.

    The module also owns its data: ``setup`` wraps the three frames in
    ``CustomerDataset`` and the ``*_dataloader`` hooks serve them with
    ``fixed_length_collate_fn``.

    Args:
        num_articles: size of the article embedding table.
        max_day: largest day index; the day table has ``max_day + 1`` rows.
        num_age: size of the age embedding table.
        num_postal: size of the postal-code embedding table.
        sequence_length: number of steps per sequence. The heads take the
            flattened transformer output, so every batch must have exactly
            this many steps.
        train_df, val_df, test_df: frames handed to ``CustomerDataset``.
        postal2idx: postal-code -> index mapping handed to ``CustomerDataset``.
        article_emb_dim, day_emb_dim, age_emb_dim, postal_emb_dim: embedding widths.
        transformer_nhead, transformer_ff_dim, num_transformer_layers:
            transformer configuration.
        predict_churn: add and train the classification head.
        learning_rate: AdamW learning rate.
    """

    def __init__(
        self,
        num_articles,
        max_day,
        num_age,
        num_postal,
        sequence_length,
        train_df,
        val_df,
        test_df,
        postal2idx,
        # embedding dims
        article_emb_dim=16,
        day_emb_dim=8,
        age_emb_dim=4,
        postal_emb_dim=4,
        # transformer config
        transformer_nhead=2,
        transformer_ff_dim=64,
        num_transformer_layers=1,
        # multi-task config
        predict_churn=False,  # Flag to enable/disable churn prediction
        # training
        learning_rate=0.0005
    ):
        super().__init__()
        # Frames and the mapping are kept out of hparams; everything else
        # (including sequence_length) is reachable via self.hparams.
        self.save_hyperparameters(ignore=['train_df','val_df','test_df','postal2idx'])
        self.learning_rate = learning_rate
        self.predict_churn = predict_churn

        # DataFrames and mapping
        self.train_df = train_df
        self.val_df   = val_df
        self.test_df  = test_df
        self.postal2idx = postal2idx

        # Embeddings (customer embedding removed)
        self.embeddings_age = nn.Embedding(num_age, age_emb_dim)
        self.embeddings_postal = nn.Embedding(num_postal, postal_emb_dim)
        self.embeddings_article = nn.Embedding(num_articles, article_emb_dim)
        self.embeddings_day = nn.Embedding(max_day + 1, day_emb_dim)

        # Sequence features: concatenation of article and day embeddings
        self.seq_feature_dim = article_emb_dim + day_emb_dim

        # Custom Transformer setup
        config = {
            "embedding_dim": self.seq_feature_dim,
            "heads": transformer_nhead,
            "transformer_dropout": 0.2,
            "dim_feedforward": transformer_ff_dim,
            "norm_first": False,
            "num_transformer_layers": num_transformer_layers,
        }
        self.transformer = TransformerEncoder(config)
        # The encoder output is flattened over all steps before the heads.
        self.transformer_output_dim = sequence_length * self.seq_feature_dim

        # User features: only age and postal embeddings are used
        user_feature_dim = age_emb_dim + postal_emb_dim

        combined_dim = self.transformer_output_dim + user_feature_dim

        # Separate heads for regression and (optional) classification
        self.regressor_head = nn.Sequential(
            nn.Linear(combined_dim, 512),
            nn.LeakyReLU(),
            nn.Linear(512, 256),
            nn.LeakyReLU(),
            nn.Linear(256, 1)
        )

        if self.predict_churn:
            self.classifier_head = nn.Sequential(
                nn.Linear(combined_dim, 128),
                nn.LeakyReLU(),
                nn.Linear(128, 1)  # single logit for binary classification
            )
            self.classification_criterion = nn.BCEWithLogitsLoss()

        self.regression_criterion = nn.MSELoss()

    def encode_input(self, batch):
        """Turn a collated batch into one feature vector per row.

        ``batch`` is ``(articles, days, age, postal_id, regression_label,
        classification_label)``. Returns ``(combined_features,
        regression_label, classification_label)``.
        """
        articles, days, age, postal_id, regression_label, classification_label = batch

        article_embeds = self.embeddings_article(articles)  # (B, L, article_emb_dim)
        day_embeds = self.embeddings_day(days)              # (B, L, day_emb_dim)
        # Concatenate to form the sequence features; day_embeds serve as the time/position signal.
        transformer_input = torch.cat([article_embeds, day_embeds], dim=-1)  # (B, L, seq_feature_dim)

        # NOTE: no padding mask is passed, so padded steps take part in attention.
        transformer_output = self.transformer(transformer_input)  # (B, L, seq_feature_dim)
        transformer_output_flat = transformer_output.reshape(transformer_output.size(0), -1)

        age_embed = self.embeddings_age(age)
        postal_embed = self.embeddings_postal(postal_id)
        user_features = torch.cat([age_embed, postal_embed], dim=1)

        combined_features = torch.cat([transformer_output_flat, user_features], dim=1)
        return combined_features, regression_label, classification_label

    def forward(self, batch):
        """Return ``(reg_output, class_output, reg_label, class_label)``.

        ``class_output`` is the raw logit per row, or ``None`` when
        ``predict_churn`` is off.
        """
        features, reg_label, class_label = self.encode_input(batch)
        reg_output = self.regressor_head(features).squeeze(dim=-1)
        if self.predict_churn:
            class_output = self.classifier_head(features).squeeze(dim=-1)
        else:
            class_output = None
        return reg_output, class_output, reg_label, class_label

    def _step_losses(self, batch, stage):
        """Compute and log the losses for one batch.

        Logs ``{stage}_reg_loss``, ``{stage}_class_loss`` (only with
        ``predict_churn``) and ``{stage}_loss``, and returns the total loss.
        """
        reg_output, class_output, reg_label, class_label = self(batch)
        reg_loss = self.regression_criterion(reg_output, reg_label)
        self.log(f"{stage}_reg_loss", reg_loss)
        loss = reg_loss
        if self.predict_churn:
            # BCEWithLogitsLoss needs float targets; labels arrive as long.
            class_loss = self.classification_criterion(class_output, class_label.float())
            self.log(f"{stage}_class_loss", class_loss)
            loss = reg_loss + class_loss
        self.log(f"{stage}_loss", loss)
        return loss

    def training_step(self, batch, batch_idx):
        return self._step_losses(batch, "train")

    def validation_step(self, batch, batch_idx):
        return self._step_losses(batch, "val")

    def test_step(self, batch, batch_idx):
        return self._step_losses(batch, "test")

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)

    def setup(self, stage=None):
        """Create the datasets for ``stage`` (``"fit"``, ``"test"`` or ``None`` for all)."""
        if stage == "fit" or stage is None:
            self.train_dataset = CustomerDataset(self.train_df, self.postal2idx)
            self.val_dataset   = CustomerDataset(self.val_df, self.postal2idx)
        if stage == "test" or stage is None:
            self.test_dataset  = CustomerDataset(self.test_df, self.postal2idx)

    def _build_loader(self, dataset, shuffle):
        """Wrap ``dataset`` in a DataLoader with the settings shared by all three splits."""
        return DataLoader(
            dataset,
            batch_size=128,
            shuffle=shuffle,
            num_workers=4,
            collate_fn=lambda b: fixed_length_collate_fn(b, sequence_length=self.hparams.sequence_length)
        )

    def train_dataloader(self):
        return self._build_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._build_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        return self._build_loader(self.test_dataset, shuffle=False)


# train and test

# This variant takes no customer arguments (num_customers, user2idx, customer_emb_dim).
model = BST(
    # data splits and postal-code lookup
    train_df=train_df,
    val_df=val_df,
    test_df=test_df,
    postal2idx=postal2idx,
    # vocabulary sizes and sequence length
    num_articles=num_articles,
    num_postal=num_postal,
    num_age=num_age,
    max_day=max_day,
    sequence_length=sequence_length,
    # embedding widths
    article_emb_dim=16,
    day_emb_dim=8,
    age_emb_dim=4,
    postal_emb_dim=4,
    # transformer shape
    transformer_nhead=2,
    transformer_ff_dim=64,
    num_transformer_layers=1,
    # task and optimisation
    predict_churn=True,  # enable churn prediction
    learning_rate=0.0005,
)

# One epoch on GPU, then evaluate the same model on the test split.
trainer = pl.Trainer(max_epochs=1, accelerator="gpu", devices="auto")
trainer.fit(model)
trainer.test(model)



# ALTERNATIVE VERSION (WITH CUSTOM COLLATOR)

In [None]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from typing import List, Dict, Optional


class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learnable gain.

    Every vector along the last axis is multiplied by
    ``1 / sqrt(mean(x ** 2) + eps)`` and then by the per-feature ``gain``.
    """

    def __init__(self, dim: int, eps: float = 1e-8) -> None:
        super().__init__()
        self.eps = eps
        # Gain starts at one, so a fresh layer applies pure RMS scaling.
        self.gain = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean_square = torch.mean(x.pow(2), dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return (x * inv_rms) * self.gain

class CustomMultiheadAttention(nn.Module):
    """Multi-head self-attention built on ``F.scaled_dot_product_attention``.

    One linear layer projects ``query`` into queries, keys and values. The
    ``key`` and ``value`` arguments of ``forward`` are accepted but never
    read, so the layer always attends ``query`` to itself.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.0, bias=True):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # Fused q/k/v projection followed by the output projection.
        self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.dropout = dropout

    def forward(self, query, key, value, attn_mask=None, key_padding_mask=None):
        """Return the attended tensor, shaped like ``query`` (batch, steps, embed_dim)."""
        batch, steps, _ = query.size()

        def split_heads(tensor):
            # (batch, steps, embed_dim) -> (batch, heads, steps, head_dim)
            return tensor.view(batch, steps, self.num_heads, self.head_dim).transpose(1, 2)

        q, k, v = (split_heads(part) for part in self.in_proj(query).chunk(3, dim=-1))

        mask = self.merge_masks(attn_mask, key_padding_mask, query)
        # Attention dropout is only active in training mode.
        drop_p = self.dropout if self.training else 0.0
        per_head = F.scaled_dot_product_attention(
            q, k, v, attn_mask=mask, dropout_p=drop_p, is_causal=False
        )

        joined = per_head.transpose(1, 2).contiguous().view(batch, steps, self.embed_dim)
        return self.out_proj(joined)

    def merge_masks(
        self,
        attn_mask: Optional[torch.Tensor],
        key_padding_mask: Optional[torch.Tensor],
        query: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Combine the padding mask and the attention mask into one additive mask.

        Boolean masks are converted to float: ``True`` becomes ``-inf`` and
        ``False`` becomes ``0``. Non-boolean masks are used unchanged. Returns
        ``None`` when both inputs are ``None``.

        Raises:
            RuntimeError: if ``attn_mask`` is not 2D/3D or has the wrong shape.
        """
        batch_size, seq_len, _ = query.shape

        def as_additive(mask):
            if mask.dtype != torch.bool:
                return mask
            return mask.float().masked_fill(mask, float("-inf"))

        padding_part = None
        if key_padding_mask is not None:
            per_head_padding = key_padding_mask.view(batch_size, 1, 1, seq_len).expand(
                -1, self.num_heads, -1, -1)
            padding_part = as_additive(per_head_padding)

        if attn_mask is None:
            return padding_part

        rank = attn_mask.dim()
        if rank == 2:
            correct_2d_size = (seq_len, seq_len)
            if attn_mask.shape != correct_2d_size:
                raise RuntimeError(
                    f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}."
                )
            attn_mask = attn_mask.unsqueeze(0).expand(batch_size, self.num_heads, -1, -1)
        elif rank == 3:
            correct_3d_size = (batch_size * self.num_heads, seq_len, seq_len)
            if attn_mask.shape != correct_3d_size:
                raise RuntimeError(
                    f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}."
                )
            attn_mask = attn_mask.view(batch_size, self.num_heads, seq_len, seq_len)
        else:
            raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")

        attention_part = as_additive(attn_mask)
        if padding_part is None:
            return attention_part
        return padding_part + attention_part

class TransformerEncoderLayer(nn.Module):
    """One encoder block: self-attention plus a GELU feed-forward network.

    Both sub-blocks sit in a residual connection with ``RMSNorm``. With
    ``norm_first`` the norm is applied to each sub-block's input; otherwise
    it is applied after the residual sum.

    ``config`` keys read: ``embedding_dim``, ``heads`` (default 8),
    ``transformer_dropout``, ``dim_feedforward`` and ``norm_first``
    (default ``False``).
    """

    def __init__(self, config: dict):
        super().__init__()
        width = config["embedding_dim"]
        head_count = config.get("heads", 8)
        drop_rate = config["transformer_dropout"]
        hidden_width = config["dim_feedforward"]
        self.norm_first = config.get("norm_first", False)
        self.self_attn = CustomMultiheadAttention(width, head_count, dropout=drop_rate)
        self.linear1 = nn.Linear(width, hidden_width)
        self.dropout1 = nn.Dropout(drop_rate)
        self.linear2 = nn.Linear(hidden_width, width)
        self.dropout2 = nn.Dropout(drop_rate)
        self.norm1 = RMSNorm(width)
        self.norm2 = RMSNorm(width)
        self.activation = nn.GELU()

    def _sa_block(self, src, attn_mask=None, key_padding_mask=None):
        """Self-attention over ``src`` followed by dropout."""
        attended = self.self_attn(
            src, src, src, key_padding_mask=key_padding_mask, attn_mask=attn_mask
        )
        return self.dropout1(attended)

    def _ff_block(self, src):
        """Feed-forward network; ``dropout2`` is applied after the activation and after ``linear2``."""
        hidden = self.activation(self.linear1(src))
        hidden = self.dropout2(hidden)
        projected = self.linear2(hidden)
        return self.dropout2(projected)

    def forward(self,
                src: torch.Tensor,
                src_key_padding_mask: torch.Tensor = None,
                src_mask: torch.Tensor = None):
        mask_kwargs = dict(attn_mask=src_mask, key_padding_mask=src_key_padding_mask)
        if not self.norm_first:
            # Post-norm: normalise after each residual sum.
            attended = self.norm1(src + self._sa_block(src, **mask_kwargs))
            return self.norm2(attended + self._ff_block(attended))
        # Pre-norm: normalise the input of each sub-block.
        attended = src + self._sa_block(self.norm1(src), **mask_kwargs)
        return attended + self._ff_block(self.norm2(attended))

class TransformerEncoder(nn.Module):
    """Stack of ``config["num_transformer_layers"]`` encoder layers applied in order."""

    def __init__(self, config):
        super().__init__()
        depth = config["num_transformer_layers"]
        layers = []
        for _ in range(depth):
            layers.append(TransformerEncoderLayer(config))
        self.encoder = nn.ModuleList(layers)

    def forward(self,
                src,
                src_key_padding_mask=None,
                src_mask=None):
        """Feed ``src`` through every layer, passing the same masks to each."""
        hidden = src
        for block in self.encoder:
            hidden = block(
                hidden,
                src_key_padding_mask=src_key_padding_mask,
                src_mask=src_mask,
            )
        return hidden


def prepare_vocabularies(train_df, val_df, test_df):
    """Build the postal-code vocabulary and the sizes needed for the embedding tables.

    Each input is converted with ``.to_pandas()`` unless it already exposes
    ``.iloc``. All three splits are scanned together.

    Args:
        train_df, val_df, test_df: frames with the columns ``postal_code``,
            ``articles_ids_lst``, ``days_before_lst`` and ``age``.

    Returns:
        Tuple ``(postal2idx, num_postal, num_articles, max_day, num_age)``:
        ``postal2idx`` maps each distinct postal code to an index in order of
        first appearance; ``num_articles`` is the largest article id plus one;
        ``max_day`` is the largest value in ``days_before_lst``; ``num_age``
        is the largest age plus one. All four sizes are plain Python ``int``
        so they can be passed directly to ``nn.Embedding``, even when the
        source columns hold floats or numpy scalars.

    Raises:
        ValueError: if a list column holds no values in any split, or if
            ``age`` has no non-missing value.
    """
    def to_pandas_if_polars(df):
        return df if hasattr(df, "iloc") else df.to_pandas()

    frames = [to_pandas_if_polars(df) for df in (train_df, val_df, test_df)]
    combined = pd.concat(frames, ignore_index=True)

    postal2idx = {p: i for i, p in enumerate(combined['postal_code'].unique())}
    num_postal = len(postal2idx)

    def max_over_lists(column):
        # Running max over the rows: avoids copying every element of every
        # list into one huge list just to take its maximum. Empty rows are
        # skipped; max() still raises ValueError if nothing is left.
        per_row_max = (
            max(lst)
            for df_pd in frames
            for lst in df_pd[column]
            if len(lst) > 0
        )
        return int(max(per_row_max))

    num_articles = max_over_lists('articles_ids_lst') + 1
    max_day = max_over_lists('days_before_lst')

    # int() guards against a float age column producing a float table size.
    num_age = int(combined['age'].max()) + 1

    return postal2idx, num_postal, num_articles, max_day, num_age

class CustomerDataset(Dataset):
    """Row-per-customer dataset; ``customer_id`` is not used.

    Columns read from each row:
      - postal_code (looked up in ``postal2idx``)
      - days_before_lst (list[int])
      - articles_ids_lst (list[int])
      - regression_label (float)
      - classification_label (int) (0 means churn, 1 means not churn)
      - age (int)
    """

    def __init__(self, df, postal2idx: Dict[str, int]):
        # Anything without ``.iloc`` is converted via ``.to_pandas()``.
        self.data = df if hasattr(df, "iloc") else df.to_pandas()
        self.postal2idx = postal2idx

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(articles, days, age, postal_id, regression_label, classification_label)``."""
        record = self.data.iloc[idx]
        postal_id = self.postal2idx[record['postal_code']]
        age = int(record['age'])

        def as_long(column):
            return torch.tensor(record[column], dtype=torch.long)

        articles = as_long('articles_ids_lst')
        days = as_long('days_before_lst')
        regression_label = torch.tensor(float(record['regression_label']), dtype=torch.float)
        classification_label = torch.tensor(int(record['classification_label']), dtype=torch.long)
        return articles, days, age, postal_id, regression_label, classification_label


def fixed_length_collate_fn(batch: list[tuple[torch.Tensor, torch.Tensor, int, int, torch.Tensor, torch.Tensor]],
                           sequence_length: int = 8, padding_value: int = 0) -> tuple[torch.Tensor, ...]:
    """
    Collates a batch so that every sequence has exactly ``sequence_length`` steps.

    ``pad_sequence`` only pads up to the longest sequence in the batch, so
    after truncating, a batch whose longest sequence is shorter than
    ``sequence_length`` is right-padded again with ``padding_value``. The
    model flattens the sequence dimension, so a narrower batch would not
    match its linear layers.

    Args:
        batch: List of tuples where each tuple contains
              (articles, days, age, postal_id, regression_label, classification_label)
        sequence_length: Exact length of the returned sequences
        padding_value: Value to use for padding sequences

    Returns:
        Tuple of tensors:
          (article_seqs_tensor, day_seqs_tensor, ages_tensor, postal_ids_tensor, reg_labels_tensor, class_labels_tensor)
        where the two sequence tensors have shape (len(batch), sequence_length).
    """
    article_seqs, day_seqs, ages, postal_ids, reg_labels, class_labels = zip(*batch)

    def to_fixed_length(seqs):
        # Pad to the longest sequence in the batch, then cut anything too long.
        padded = pad_sequence(list(seqs), batch_first=True, padding_value=padding_value)
        padded = padded[:, :sequence_length]
        # Top up batches whose longest sequence is still too short.
        missing = sequence_length - padded.size(1)
        if missing > 0:
            filler = padded.new_full((padded.size(0), missing), padding_value)
            padded = torch.cat([padded, filler], dim=1)
        return padded

    article_seqs_tensor = to_fixed_length(article_seqs)
    day_seqs_tensor = to_fixed_length(day_seqs)
    ages_tensor = torch.tensor(ages, dtype=torch.long)
    postal_ids_tensor = torch.tensor(postal_ids, dtype=torch.long)
    reg_labels_tensor = torch.stack(reg_labels, dim=0)
    class_labels_tensor = torch.stack(class_labels, dim=0)
    return (
        article_seqs_tensor,
        day_seqs_tensor,
        ages_tensor,
        postal_ids_tensor,
        reg_labels_tensor,
        class_labels_tensor
    )


class BST(pl.LightningModule):
    """
    Lightning module that embeds a customer's article/day sequence, runs it
    through ``TransformerEncoder``, concatenates the flattened result with age
    and postal embeddings, and feeds that to a regression head and an optional
    binary classification ("churn") head.

    The module also owns its data: ``setup`` wraps the stored DataFrames in
    ``CustomerDataset`` and the ``*_dataloader`` hooks build the loaders.
    """

    def __init__(
        self,
        num_articles,
        max_day,
        num_age,
        num_postal,
        sequence_length,
        train_df,
        val_df,
        test_df,
        postal2idx,
        # Embedding dimensions
        article_emb_dim: int = 16,
        day_emb_dim: int = 8,
        age_emb_dim: int = 4,
        postal_emb_dim: int = 4,
        # Transformer config
        transformer_nhead: int = 2,
        transformer_ff_dim: int = 64,
        num_transformer_layers: int = 1,
        # Multi-task config
        predict_churn: bool = False,
        # Training
        learning_rate: float = 0.0005,
        # Data loading
        batch_size: int = 128,
        num_workers: int = 4,
    ):
        """
        Args:
            num_articles: Number of rows in the article embedding table.
            max_day: Largest day index; the day embedding table has ``max_day + 1`` rows.
            num_age: Number of rows in the age embedding table.
            num_postal: Number of rows in the postal embedding table.
            sequence_length: Fixed sequence width. Batches must have exactly this
                many sequence positions because the transformer output is flattened
                into a ``sequence_length * (article_emb_dim + day_emb_dim)`` vector.
            train_df, val_df, test_df: DataFrames handed to ``CustomerDataset`` in ``setup``.
            postal2idx: Mapping handed to ``CustomerDataset`` alongside each DataFrame.
            article_emb_dim, day_emb_dim, age_emb_dim, postal_emb_dim: Embedding widths.
            transformer_nhead, transformer_ff_dim, num_transformer_layers:
                Forwarded to ``TransformerEncoder`` through its config dict.
            predict_churn: When True, a classification head is created and its
                BCE-with-logits loss is added to the regression loss.
            learning_rate: Learning rate for AdamW.
            batch_size: Batch size used by all three dataloaders.
            num_workers: Worker process count used by all three dataloaders.
        """
        super().__init__()
        # Save hyperparameters (including sequence_length); the DataFrames and
        # the mapping are excluded so they are not stored with the hparams.
        self.save_hyperparameters(ignore=['train_df','val_df','test_df','postal2idx'])
        self.learning_rate = learning_rate
        self.predict_churn = predict_churn
        self.batch_size = batch_size
        self.num_workers = num_workers

        # DataFrames and mapping
        self.train_df = train_df
        self.val_df   = val_df
        self.test_df  = test_df
        self.postal2idx = postal2idx

        # Embeddings (customer_id removed)
        self.embeddings_age = nn.Embedding(num_age, age_emb_dim)
        self.embeddings_postal = nn.Embedding(num_postal, postal_emb_dim)
        self.embeddings_article = nn.Embedding(num_articles, article_emb_dim)
        self.embeddings_day = nn.Embedding(max_day + 1, day_emb_dim)

        # Sequence features: concatenation of article and day embeddings
        self.seq_feature_dim = article_emb_dim + day_emb_dim

        # Custom Transformer setup
        config = {
            "embedding_dim": self.seq_feature_dim,
            "heads": transformer_nhead,
            "transformer_dropout": 0.2,
            "dim_feedforward": transformer_ff_dim,
            "norm_first": False,
            "num_transformer_layers": num_transformer_layers,
        }
        self.transformer = TransformerEncoder(config)
        # The per-position outputs are flattened, so the width is tied to sequence_length.
        self.transformer_output_dim = sequence_length * self.seq_feature_dim

        # User features: only age and postal embeddings are used
        user_feature_dim = age_emb_dim + postal_emb_dim
        combined_dim = self.transformer_output_dim + user_feature_dim

        # Separate heads for regression and optional classification
        self.regressor_head = nn.Sequential(
            nn.Linear(combined_dim, 512),
            nn.LeakyReLU(),
            nn.Linear(512, 256),
            nn.LeakyReLU(),
            nn.Linear(256, 1)
        )
        if self.predict_churn:
            self.classifier_head = nn.Sequential(
                nn.Linear(combined_dim, 128),
                nn.LeakyReLU(),
                nn.Linear(128, 1)  # single logit for binary classification
            )
            self.classification_criterion = nn.BCEWithLogitsLoss()
        self.regression_criterion = nn.MSELoss()

    def encode_input(self, batch):
        """
        Turn a collated batch into the feature matrix consumed by the heads.

        Args:
            batch: (articles, days, age, postal_id, regression_label, classification_label)

        Returns:
            (combined_features, regression_label, classification_label); the
            labels are passed through untouched.
        """
        articles, days, age, postal_id, regression_label, classification_label = batch
        article_embeds = self.embeddings_article(articles)  # (B, L, article_emb_dim)
        day_embeds = self.embeddings_day(days)              # (B, L, day_emb_dim)
        # Concatenate article and day embeddings as the sequence features
        transformer_input = torch.cat([article_embeds, day_embeds], dim=-1)  # (B, L, seq_feature_dim)
        # Expected to keep the (B, L, seq_feature_dim) shape -- the heads rely on it.
        transformer_output = self.transformer(transformer_input)
        # Flatten all sequence positions into one vector per sample.
        transformer_output_flat = transformer_output.reshape(transformer_output.size(0), -1)
        age_embed = self.embeddings_age(age)
        postal_embed = self.embeddings_postal(postal_id)
        user_features = torch.cat([age_embed, postal_embed], dim=1)
        combined_features = torch.cat([transformer_output_flat, user_features], dim=1)
        return combined_features, regression_label, classification_label

    def forward(self, batch):
        """
        Run both heads on a collated batch.

        Returns:
            (reg_output, class_output, reg_label, class_label) where
            ``class_output`` is the raw logit per sample, or None when
            ``predict_churn`` is False.
        """
        features, reg_label, class_label = self.encode_input(batch)
        reg_output = self.regressor_head(features).squeeze(dim=-1)
        if self.predict_churn:
            class_output = self.classifier_head(features).squeeze(dim=-1)
        else:
            class_output = None
        return reg_output, class_output, reg_label, class_label

    def _shared_step(self, batch, stage: str):
        """Compute and log the loss for one batch; ``stage`` prefixes the logged metric names."""
        reg_output, class_output, reg_label, class_label = self(batch)
        reg_loss = self.regression_criterion(reg_output, reg_label)
        self.log(f"{stage}_reg_loss", reg_loss)
        loss = reg_loss
        if self.predict_churn:
            # BCEWithLogitsLoss needs float targets; labels arrive as integer tensors.
            class_loss = self.classification_criterion(class_output, class_label.float())
            self.log(f"{stage}_class_loss", class_loss)
            loss = reg_loss + class_loss
        self.log(f"{stage}_loss", loss)
        return loss

    def training_step(self, batch, batch_idx):
        """Log ``train_*`` losses and return the total loss for backprop."""
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Log ``val_*`` losses and return the total loss."""
        return self._shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        """Log ``test_*`` losses and return the total loss."""
        return self._shared_step(batch, "test")

    def configure_optimizers(self):
        """Optimise all parameters with AdamW at ``self.learning_rate``."""
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)

    def setup(self, stage=None):
        """Build the datasets needed for ``stage`` ("fit", "test", or None for all)."""
        if stage == "fit" or stage is None:
            self.train_dataset = CustomerDataset(self.train_df, self.postal2idx)
            self.val_dataset   = CustomerDataset(self.val_df, self.postal2idx)
        if stage == "test" or stage is None:
            self.test_dataset  = CustomerDataset(self.test_df, self.postal2idx)

    def _make_dataloader(self, dataset, shuffle: bool):
        """Build a DataLoader over ``dataset`` with the module's batch/worker settings."""
        from functools import partial

        # functools.partial over the top-level collate function is picklable,
        # unlike a lambda, so worker processes can also be started with the
        # spawn/forkserver multiprocessing start methods.
        collate = partial(fixed_length_collate_fn, sequence_length=self.hparams.sequence_length)
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            collate_fn=collate
        )

    def train_dataloader(self):
        """Shuffled loader over the training dataset."""
        return self._make_dataloader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Ordered loader over the validation dataset."""
        return self._make_dataloader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        """Ordered loader over the test dataset."""
        return self._make_dataloader(self.test_dataset, shuffle=False)
# Lookup table and cardinalities derived from the train/val/test frames.
(
    postal2idx,
    num_postal,
    num_articles,
    max_day,
    num_age,
) = prepare_vocabularies(train_df, val_df, test_df)

# Fixed number of sequence positions the model sees per customer.
sequence_length = 8

model = BST(
    # Vocabulary sizes
    num_articles=num_articles,
    max_day=max_day,
    num_age=num_age,
    num_postal=num_postal,
    # Data
    sequence_length=sequence_length,
    train_df=train_df,
    val_df=val_df,
    test_df=test_df,
    postal2idx=postal2idx,
    # Embedding widths
    article_emb_dim=16,
    day_emb_dim=8,
    age_emb_dim=4,
    postal_emb_dim=4,
    # Transformer
    transformer_nhead=2,
    transformer_ff_dim=64,
    num_transformer_layers=1,
    # Multi-task: also train the churn classification head
    predict_churn=True,
    # Optimiser
    learning_rate=0.0005,
)

# Single-epoch run on GPU, followed by evaluation on the test split.
trainer = pl.Trainer(max_epochs=1, accelerator="gpu", devices="auto")
trainer.fit(model)
trainer.test(model)

# GET ENVIRONMENT 

In [18]:
!pip freeze > requirements.txt
