## Cross-Cohort Transaction Sequence Forecasting on the Multichannel Dataset

This notebook implements the complete forecasting pipeline described in the accompanying thesis. It is designed to run on a CUDA GPU, which accelerates model training and inference; when no GPU is available it falls back to the CPU.

The workflow proceeds as follows:
1. **Environment Setup**: Load libraries and define utility functions.
2. **Data Preparation**: Load the Multichannel dataset and aggregate each customer's transactions into weekly counts for the training and prediction windows (an input period followed by a target period).
3. **Model Definition**: Specify the architecture of the Transformer-based forecasting model and supporting components.
4. **Training & Evaluation**: Configure and execute the training loop, followed by performance evaluation and visualization of key results.
5. **Execution Flow**: A main execution block is provided to orchestrate the training and prediction processes.

Most configurations (e.g., model parameters, training horizon, evaluation metrics) are modular and adjustable. For details, please refer to the final section of the notebook.


**Note**: This notebook expects the transaction dataset at `15-transactions_allCohorts.csv`. Please ensure this file is placed in the working directory.


## 1. Setup and Imports

In [None]:
# Standard libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import datetime
import gc
from tqdm.notebook import tqdm
import warnings
warnings.filterwarnings('ignore')

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.distributions as dist
from torch.optim.lr_scheduler import OneCycleLR
from torch.cuda.amp import autocast, GradScaler

# For metrics
from sklearn.metrics import mean_squared_error, mean_absolute_error

# Choose the compute device once; later cells move models and tensors onto `device`.
gpu_available = torch.cuda.is_available()
device = torch.device("cuda") if gpu_available else torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


## 2. Data Loading and Preprocessing


In [None]:
# Read the raw transaction log from the working directory.
csv_path = '15-transactions_allCohorts.csv'
df = pd.read_csv(csv_path)
print(f"Data loaded: {len(df)} transactions")
df.head(5)

# Parse ORDER_DATE into a dedicated datetime column named 'Date'.
df['Date'] = pd.to_datetime(df['ORDER_DATE'])

Data loaded: 190315 transactions


In [None]:
def prepare_sequence_data(df, customer_field='CUSTNO', date_field='Date', cohort_field='COHORT_NUMBER',
                        cohort_range=None, train_periods=None, pred_periods=None):
    """
    Prepare sequence data for transaction forecasting with flexible period and cohort configuration.

    Each customer is turned into four weekly transaction-count sequences: the
    training input/target windows and the prediction input/target windows.

    Parameters:
    -----------
    df : DataFrame
        Transaction data (one row per transaction). Not modified.
    customer_field : str
        Column name for customer ID
    date_field : str
        Column name for transaction date
    cohort_field : str or None
        Column name for cohort information
    cohort_range : tuple or None
        Range of cohorts to include (min_cohort, max_cohort), inclusive on both ends
    train_periods : dict or None
        Dictionary with 'input_start', 'input_end', 'target_start', 'target_end' for training
    pred_periods : dict or None
        Dictionary with 'input_start', 'input_end', 'target_start', 'target_end' for prediction

    Returns:
    --------
    dict
        'X_train', 'y_target', 'X_pred', 'y_pred_target' : float32 tensors with one
            row per customer and one column per weekly date of the matching window.
        'customer_indices' : int64 tensor with each customer's row position.
        'customer_mapping' : dict mapping customer ID -> row position.
        'num_customers', 'customer_cohorts' (first cohort value seen per customer,
            or 0 when cohort_field is None), the two period dicts, and the four
            weekly DatetimeIndex objects used as column labels.

    Notes:
    ------
    * Conversions (date parsing, numeric cohort coercion) are applied to a copy,
      so the caller's DataFrame is left untouched.
    * Weekly bins come from ``pd.Grouper(freq='W-MON')``: each bin covers Tuesday
      through Monday and is labelled by its closing Monday, which matches the
      Monday dates produced by ``pd.date_range(freq='W-MON')``. Because the window
      bounds are not week-aligned, one bin can straddle two windows (with the
      defaults, the week labelled 2007-01-01 also contains 26-31 Dec 2006).
    """
    # Work on a copy so the dtype conversions below never leak into the caller's frame.
    df = df.copy()
    df[date_field] = pd.to_datetime(df[date_field])

    # Set default periods if not provided
    if train_periods is None:
        train_periods = {
            'input_start': '2005-01-01',
            'input_end': '2006-12-31',
            'target_start': '2007-01-01',
            'target_end': '2007-12-31'
        }

    if pred_periods is None:
        pred_periods = {
            'input_start': '2008-01-01',
            'input_end': '2009-12-31',
            'target_start': '2010-01-01',
            'target_end': '2010-12-31'
        }

    # Extract periods
    INPUT_START = train_periods['input_start']
    INPUT_END = train_periods['input_end']
    TARGET_START = train_periods['target_start']
    TARGET_END = train_periods['target_end']

    PRED_INPUT_START = pred_periods['input_start']
    PRED_INPUT_END = pred_periods['input_end']
    PRED_TARGET_START = pred_periods['target_start']
    PRED_TARGET_END = pred_periods['target_end']

    print(f"Training input: {INPUT_START} to {INPUT_END}")
    print(f"Training target: {TARGET_START} to {TARGET_END}")
    print(f"Prediction input: {PRED_INPUT_START} to {PRED_INPUT_END}")
    print(f"Prediction target: {PRED_TARGET_START} to {PRED_TARGET_END}")

    # Apply cohort filtering if specified
    if cohort_range is not None and cohort_field is not None:
        min_cohort, max_cohort = cohort_range

        # Match the range values to the column's dtype so the comparison is meaningful.
        cohort_dtype = df[cohort_field].dtype

        if cohort_dtype == 'object' or cohort_dtype == 'string':
            # String/object cohorts: compare as strings
            min_cohort_val = str(min_cohort)
            max_cohort_val = str(max_cohort)
        elif 'datetime' in str(cohort_dtype):
            # Datetime cohorts: compare as timestamps
            min_cohort_val = pd.to_datetime(min_cohort)
            max_cohort_val = pd.to_datetime(max_cohort)
        else:
            # Numeric cohorts
            if isinstance(min_cohort, str):
                # Range given as strings: coerce the (copied) column and the bounds to numbers
                df[cohort_field] = pd.to_numeric(df[cohort_field], errors='coerce')
                min_cohort_val = float(min_cohort)
                max_cohort_val = float(max_cohort)
            else:
                min_cohort_val = min_cohort
                max_cohort_val = max_cohort

        df_filtered = df[(df[cohort_field] >= min_cohort_val) & (df[cohort_field] <= max_cohort_val)]
        print(f"Filtered to cohorts {min_cohort}-{max_cohort}: {len(df_filtered[customer_field].unique())} customers")
    else:
        df_filtered = df
        print(f"Using all cohorts: {len(df_filtered[customer_field].unique())} customers")

    # Customer ID -> row position, in order of first appearance
    customer_ids = df_filtered[customer_field].unique()
    customer_to_idx = {cid: idx for idx, cid in enumerate(customer_ids)}

    # Weekly column labels (Mondays) for each window; see Notes for bin semantics.
    input_dates = pd.date_range(INPUT_START, INPUT_END, freq='W-MON')
    target_dates = pd.date_range(TARGET_START, TARGET_END, freq='W-MON')
    pred_input_dates = pd.date_range(PRED_INPUT_START, PRED_INPUT_END, freq='W-MON')
    pred_target_dates = pd.date_range(PRED_TARGET_START, PRED_TARGET_END, freq='W-MON')

    print(f"Training input weeks: {len(input_dates)}")
    print(f"Training target weeks: {len(target_dates)}")
    print(f"Prediction input weeks: {len(pred_input_dates)}")
    print(f"Prediction target weeks: {len(pred_target_dates)}")

    # Split the frame by customer once; filtering the full frame per customer
    # was O(customers x rows).
    customer_groups = dict(tuple(df_filtered.groupby(customer_field, sort=False)))
    # Customers with no group (e.g. a null ID) see an empty frame, as the old
    # boolean filter would have produced.
    empty_frame = df_filtered.iloc[0:0]

    def _align(weekly, dates):
        # Place counts on the window's weekly grid; weeks without transactions become 0.
        return weekly.reindex(dates, fill_value=0).to_numpy(dtype=np.float64)

    # Prepare containers for data
    X_train_data = []
    y_target_data = []
    X_pred_data = []
    y_pred_target_data = []
    customer_indices = []
    customer_cohorts = []

    for customer_id in tqdm(customer_ids, desc="Processing customers"):
        customer_df = customer_groups.get(customer_id, empty_frame)

        # Cohort of the customer's first row, if cohorts are tracked
        if cohort_field is not None:
            customer_cohort = customer_df[cohort_field].iloc[0]
        else:
            customer_cohort = 0

        # Transactions per week, labelled by the Monday that closes each week
        weekly_counts = customer_df.groupby([pd.Grouper(key=date_field, freq='W-MON')]).size()

        X_train_data.append(_align(weekly_counts, input_dates))
        y_target_data.append(_align(weekly_counts, target_dates))
        X_pred_data.append(_align(weekly_counts, pred_input_dates))
        y_pred_target_data.append(_align(weekly_counts, pred_target_dates))

        customer_indices.append(customer_to_idx[customer_id])
        customer_cohorts.append(customer_cohort)

    # Convert to tensors
    X_train_tensor = torch.tensor(np.array(X_train_data), dtype=torch.float32)
    y_target_tensor = torch.tensor(np.array(y_target_data), dtype=torch.float32)
    X_pred_tensor = torch.tensor(np.array(X_pred_data), dtype=torch.float32)
    y_pred_target_tensor = torch.tensor(np.array(y_pred_target_data), dtype=torch.float32)
    customer_indices_tensor = torch.tensor(customer_indices, dtype=torch.long)

    return {
        'X_train': X_train_tensor,
        'y_target': y_target_tensor,
        'X_pred': X_pred_tensor,
        'y_pred_target': y_pred_target_tensor,
        'customer_indices': customer_indices_tensor,
        'customer_mapping': customer_to_idx,
        'num_customers': len(customer_ids),
        'customer_cohorts': customer_cohorts,
        'train_periods': train_periods,
        'pred_periods': pred_periods,
        'input_dates': input_dates,
        'target_dates': target_dates,
        'pred_input_dates': pred_input_dates,
        'pred_target_dates': pred_target_dates
    }

In [None]:
def prepare_cross_customer_validation(data, validation_ratio=0.1, random_seed=None):
    """
    Split data by customers for validation.

    A random subset of customers (rows) is held out; each customer appears in
    exactly one of the two splits.

    Parameters:
    -----------
    data : dict
        Dictionary containing 'X_train', 'y_target', and 'customer_indices'
        (indexable by an integer NumPy array, e.g. torch tensors).
    validation_ratio : float
        Proportion of customers to use for validation, in [0, 1] (default: 0.1).
        The validation count is rounded down.
    random_seed : int, optional
        Random seed for reproducibility (default: None). When given, a private
        RNG is used, so the global NumPy random state is not reset.

    Returns:
    --------
    dict
        Dictionary containing train/validation splits and the row indices used.

    Raises:
    -------
    ValueError
        If validation_ratio is outside [0, 1].
    """
    if not 0.0 <= validation_ratio <= 1.0:
        raise ValueError(f"validation_ratio must be in [0, 1], got {validation_ratio}")

    X_train = data['X_train']
    y_target = data['y_target']
    customer_indices = data['customer_indices']

    num_customers = len(X_train)
    num_val = int(num_customers * validation_ratio)

    # A seeded RandomState yields the same permutation that
    # np.random.seed(seed) + np.random.shuffle produced, without clobbering the
    # global RNG used elsewhere. Unseeded calls keep drawing from the global RNG.
    rng = np.random.RandomState(random_seed) if random_seed is not None else np.random

    all_indices = np.arange(num_customers)
    rng.shuffle(all_indices)
    val_indices = all_indices[:num_val]
    train_indices = all_indices[num_val:]

    return {
        'X_train': X_train[train_indices],
        'y_train': y_target[train_indices],
        'customer_train': customer_indices[train_indices],
        'X_val': X_train[val_indices],
        'y_val': y_target[val_indices],
        'customer_val': customer_indices[val_indices],
        'train_indices': train_indices,
        'val_indices': val_indices
    }

## 3. Dataset and Model Classes


In [None]:
class CustomerTransactionDataset(Dataset):
    """Map-style dataset yielding ``(inputs, targets, customer_id)`` triples.

    The three containers are indexed in lockstep; nothing is copied or converted.
    """

    def __init__(self, X, y, customer_ids):
        # Hold references only; indexing is delegated to the underlying containers.
        self.X, self.y, self.customer_ids = X, y, customer_ids

    def __len__(self):
        # The number of samples is defined by the inputs.
        return len(self.X)

    def __getitem__(self, idx):
        sample_x = self.X[idx]
        sample_y = self.y[idx]
        sample_id = self.customer_ids[idx]
        return sample_x, sample_y, sample_id


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first input.

    Even columns hold ``sin(pos * w_k)`` and odd columns ``cos(pos * w_k)``, with
    ``w_k = 10000 ** (-2k / d_model)``.

    Parameters
    ----------
    d_model : int
        Size of the last dimension of the input; odd sizes are supported.
    max_len : int
        Longest sequence the precomputed table supports.
    """
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe)  # moves with the module's .to()/.cuda()

    def forward(self, x):
        """Return ``x + pe[:seq_len]`` for ``x`` shaped ``(batch, seq_len, d_model)``."""
        seq_len = x.size(1)
        if seq_len > self.pe.size(0):
            raise ValueError(
                f"sequence length {seq_len} exceeds PositionalEncoding max_len {self.pe.size(0)}"
            )
        return x + self.pe[:seq_len, :].to(x.device)


class TransactionAwareTransformer(nn.Module):
    """Transformer encoder mapping a transaction-count sequence to non-negative outputs.

    Parameters
    ----------
    input_dim : int
        Features per time step (a 2-D input is treated as ``input_dim == 1``).
    embed_dim : int
        Model width (d_model of the encoder).
    hidden_dim : int
        Feed-forward width of each encoder layer and of the output head.
    num_layers : int
        Number of stacked encoder layers.
    num_heads : int
        Attention heads per layer (must divide ``embed_dim``).
    output_dim : int
        Number of outputs per sample (presumably the forecast horizon in weeks).
    num_customers : int or None
        Size of the customer-ID vocabulary. The embedding table gets
        ``num_customers + 1`` rows. With ``None``, no customer embedding is
        created and ``customer_ids`` is ignored.
    dropout : float
        Dropout used in the encoder layers and the output head.
    """
    def __init__(self, input_dim, embed_dim, hidden_dim, num_layers, num_heads, output_dim, num_customers=None, dropout=0.1):
        super(TransactionAwareTransformer, self).__init__()

        # Standard components (creation order kept so initialisation and
        # state_dict layout match earlier checkpoints)
        self.input_projection = nn.Linear(input_dim, embed_dim)
        self.pos_encoding = PositionalEncoding(embed_dim)
        # Previously `None + 1` crashed for the default num_customers=None.
        self.customer_embedding = (
            nn.Embedding(num_customers + 1, embed_dim) if num_customers is not None else None
        )

        # Transaction event detector: projects the raw input, later gated by (x > 0)
        self.transaction_detector = nn.Sequential(
            nn.Linear(input_dim, embed_dim),
            nn.LayerNorm(embed_dim),
            nn.GELU()
        )

        # Transformer encoder. NOTE: nn.TransformerEncoder deep-copies this layer,
        # so `self.encoder_layer` itself is never run; it stays registered only
        # so existing state_dicts keep loading.
        self.encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim,
            dropout=dropout,
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)

        # Output head; Softplus keeps predictions non-negative
        self.output_projection = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim),
            nn.Softplus()
        )

        # Learnable global output scale, applied as exp(log_output_scale)
        self.log_output_scale = nn.Parameter(torch.tensor(0.0))

    def forward(self, x, customer_ids, time_indices=None):
        """
        Parameters
        ----------
        x : Tensor
            ``(batch, seq_len)`` or ``(batch, seq_len, input_dim)``.
        customer_ids : LongTensor
            ``(batch,)`` indices into the customer embedding (ignored when the
            model was built without ``num_customers``).
        time_indices : Tensor, optional
            1-D ``(len,)`` shared by the whole batch (truncated, or padded with its
            last value, to ``seq_len``), or presumably 2-D ``(batch, seq_len)``.
            ``(t % 52) + 1`` is treated as week of year.

        Returns
        -------
        Tensor
            ``(batch, output_dim)``, non-negative.
        """
        if x.dim() == 2:
            x = x.unsqueeze(-1)

        batch_size, seq_len = x.size(0), x.size(1)

        # 1.0 where a transaction occurred, 0.0 otherwise
        transaction_mask = (x > 0).float()

        x_embed = self.input_projection(x)
        x_embed = self.pos_encoding(x_embed)

        # Transaction signal is only added at steps with activity
        trans_signal = self.transaction_detector(x) * transaction_mask
        if self.customer_embedding is not None:
            cust_embed = self.customer_embedding(customer_ids).unsqueeze(1).expand(-1, seq_len, -1)
            x_embed = x_embed + cust_embed + trans_signal * 0.5
        else:
            x_embed = x_embed + trans_signal * 0.5

        if time_indices is not None:
            if time_indices.dim() == 1:
                if time_indices.size(0) != seq_len:
                    time_indices = time_indices[:seq_len] if time_indices.size(0) > seq_len else torch.cat([
                        time_indices,
                        time_indices[-1].repeat(seq_len - time_indices.size(0))
                    ])
                time_indices = time_indices.unsqueeze(0).expand(batch_size, -1)

            # Week of year in 1..52, on the same device as the embeddings
            week_of_year = ((time_indices % 52) + 1).to(x_embed.device)[..., :seq_len]

            # Seasonal bump broadcast over every embedding dim: +0.2 for weeks 50-52
            # (holidays), +0.1 for weeks 25-27 (mid-year). Vectorised replacement
            # for a per-element Python loop that forced a device sync per step.
            season = torch.zeros(week_of_year.shape, dtype=torch.float32, device=x_embed.device)
            season = season.masked_fill(week_of_year >= 50, 0.2)
            season = season.masked_fill((week_of_year >= 25) & (week_of_year <= 27), 0.1)
            x_embed = x_embed + season.unsqueeze(-1)

        x_embed = self.transformer_encoder(x_embed)

        # Average-pool over time so the whole context contributes, not just the last step
        final_hidden = x_embed.mean(dim=1)

        output = self.output_projection(final_hidden)
        output = output * torch.exp(self.log_output_scale)

        return output

## 4. Training and Evaluation Functions

In [None]:
def simplified_transaction_forecaster_train(model, train_loader, val_loader, time_indices=None,
                           num_epochs=50, patience=10, learning_rate=0.001, weight_decay=0.001, device=None):
    """
    Training function for transaction sequence forecasting with
    focus on handling imbalanced data and preventing zero predictions.

    Parameters:
    -----------
    model : nn.Module
        Transaction aware transformer model (must expose `log_output_scale`;
        it is overwritten in place)
    train_loader : DataLoader
        DataLoader for training data, yielding (X, y, customer_ids) batches
    val_loader : DataLoader
        DataLoader for validation data, yielding (X, y, customer_ids) batches
    time_indices : torch.Tensor, optional
        Time indices for temporal information
    num_epochs : int
        Number of epochs to train
    patience : int
        Number of epochs to wait for improvement before early stopping
    learning_rate : float
        Learning rate for optimizer
    weight_decay : float
        Weight decay for optimizer
    device : torch.device
        Device to use for training

    Returns:
    --------
    model : nn.Module
        Trained model, with the best checkpoint restored when one was recorded
        (only epochs >= 5 are eligible). `model.best_scale_factor` holds the
        running scale estimate.
    dict
        Training history ('train_loss', 'val_loss', 'val_rmse',
        'transaction_rmse', 'scale_factors', 'pred_means')

    Notes:
    ------
    * Loss-weighting statistics (transaction ratio, target mean) come from
      `train_loader` only, so validation labels do not shape the training loss.
    * Checkpoints are real copies of the weights. `Tensor.cpu()` returns the same
      storage for CPU tensors, so a plain `.cpu()` snapshot would be overwritten
      by later optimizer steps when training on CPU.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = model.to(device)

    def _forward(X, customers):
        # Single call site so the main and fallback paths feed the model identical inputs.
        if time_indices is not None:
            return model(X, customers, time_indices)
        return model(X, customers)

    def _output_bias_layers():
        # The final nn.Linear of model.output_projection (index -2, before the activation).
        if not hasattr(model, 'output_projection'):
            return []
        return [module for name, module in model.named_modules()
                if isinstance(module, nn.Linear) and 'output_projection' in name
                and module is model.output_projection[-2]]

    # 1. Initialization to prevent zero predictions
    with torch.no_grad():
        for layer in _output_bias_layers():
            layer.bias.fill_(0.2)
            print(f"Set positive bias in output layer")

        # Start from exp(1.0) output scale
        model.log_output_scale.fill_(1.0)
        print(f"Set log_output_scale to {model.log_output_scale.item():.2f}")

    if time_indices is not None:
        time_indices = time_indices.to(device)

    # 2. Analyze data imbalance on the TRAINING targets (no validation leakage)
    print("Analyzing transaction data statistics...")
    total_elements = 0
    nonzero_elements = 0
    target_sum = 0

    with torch.no_grad():
        for _, y_batch, _ in train_loader:
            y_np = y_batch.detach().cpu().numpy()
            total_elements += y_np.size
            nonzero_elements += np.sum(y_np > 0)
            target_sum += np.sum(y_np)

    tx_ratio = nonzero_elements / total_elements if total_elements > 0 else 0.01
    target_mean = target_sum / total_elements if total_elements > 0 else 0.01

    print(f"Transaction ratio: {tx_ratio:.6f} ({nonzero_elements}/{total_elements})")
    print(f"Target mean: {target_mean:.6f}")

    # Up-weight non-zero targets by inverse frequency, clamped to [5, 100]
    tx_weight = min(max(5.0, 1.0 / (tx_ratio + 1e-8)), 100.0)
    print(f"Using transaction weight: {tx_weight:.2f}")

    # 3. Optimizer: parameters named '*output*' train at half the base LR
    output_params = []
    other_params = []

    for name, param in model.named_parameters():
        if 'output' in name or name == 'log_output_scale':
            output_params.append(param)
        else:
            other_params.append(param)

    optimizer = optim.AdamW([
        {'params': other_params, 'lr': learning_rate},
        {'params': output_params, 'lr': learning_rate * 0.5}
    ], weight_decay=weight_decay)

    # One-cycle LR schedule, stepped once per batch, with 20% warmup
    scheduler = OneCycleLR(
        optimizer,
        max_lr=[learning_rate, learning_rate * 0.5],
        total_steps=num_epochs * len(train_loader),
        pct_start=0.2,
        div_factor=20
    )

    # 4. Training state variables
    best_val_metric = float('inf')
    best_model = None
    patience_counter = 0
    best_scale_factor = 1.0
    zero_prediction_counter = 0

    # Mixed-precision gradient scaler
    scaler = GradScaler()

    history = {
        'train_loss': [],
        'val_loss': [],
        'val_rmse': [],
        'transaction_rmse': [],
        'scale_factors': [],
        'pred_means': []
    }

    # 5. Loss function with imbalance handling
    def imbalanced_loss(y_pred, y_true, epoch):
        """Loss function optimized for imbalanced transaction data"""
        batch_mean_true = torch.mean(y_true)
        batch_mean_pred = torch.mean(y_pred)

        # Rescale predictions to the batch's target mean (bounded) before comparing
        if batch_mean_pred < 0.001:
            scale_factor = torch.tensor(10.0, device=y_pred.device)
        else:
            scale_factor = batch_mean_true / (batch_mean_pred + 1e-8)
            scale_factor = torch.clamp(scale_factor, 0.2, 50.0)

        y_pred_scaled = y_pred * scale_factor

        nonzero_mask = (y_true > 0)

        # Weighted MSE with higher weights for transactions
        weights = torch.ones_like(y_true, device=y_true.device)
        if nonzero_mask.sum() > 0:
            weights[nonzero_mask] = tx_weight

            # Extra weight for high-value transactions
            high_value_mask = (y_true > target_mean * 2)
            if high_value_mask.sum() > 0:
                weights[high_value_mask] *= 1.5

        base_loss = F.mse_loss(y_pred_scaled, y_true, reduction='none')
        weighted_mse = (base_loss * weights).mean()

        if nonzero_mask.sum() > 0:
            # Accuracy on weeks that actually had transactions
            tx_loss = F.mse_loss(y_pred_scaled[nonzero_mask], y_true[nonzero_mask])

            # Push near-zero predictions on transaction weeks towards 0.2
            zero_pred_mask = (y_pred_scaled[nonzero_mask] < 0.05)
            if zero_pred_mask.sum() > 0:
                zero_penalty = F.mse_loss(
                    y_pred_scaled[nonzero_mask][zero_pred_mask],
                    torch.ones_like(y_pred_scaled[nonzero_mask][zero_pred_mask]) * 0.2
                ) * 10.0
            else:
                zero_penalty = torch.tensor(0.0, device=y_pred.device)
        else:
            tx_loss = torch.tensor(0.0, device=y_pred.device)
            zero_penalty = torch.tensor(0.0, device=y_pred.device)

        # Per-sample total volume should match
        pred_sum = torch.sum(y_pred_scaled, dim=1)
        true_sum = torch.sum(y_true, dim=1)
        volume_penalty = torch.mean(torch.abs(pred_sum - true_sum))

        # Large penalty when the whole batch collapses towards zero
        if batch_mean_pred < 0.01:
            zero_trap_penalty = torch.exp(-100.0 * batch_mean_pred) * 10.0
        else:
            zero_trap_penalty = torch.tensor(0.0, device=y_pred.device)

        # Early epochs: focus on non-zero predictions
        if epoch < 5:
            return weighted_mse + 2.0 * tx_loss + volume_penalty + zero_penalty + zero_trap_penalty
        else:
            return weighted_mse + 1.5 * tx_loss + volume_penalty + zero_penalty * 0.5

    # 6. Training loop
    print(f"Starting transaction forecaster training on {device}...")

    for epoch in range(num_epochs):
        # TRAINING
        model.train()
        train_loss = 0.0
        epoch_scale_factors = []
        epoch_pred_means = []

        for X_batch, y_batch, customer_batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            X_batch = X_batch.to(device)
            y_batch = y_batch.to(device)
            customer_batch = customer_batch.to(device)

            optimizer.zero_grad()

            try:
                with autocast():
                    predictions = _forward(X_batch, customer_batch)

                    batch_mean_pred = torch.mean(predictions).item()
                    epoch_pred_means.append(batch_mean_pred)

                    # Small offset to escape an all-zero output
                    if batch_mean_pred < 0.001:
                        predictions = predictions + 0.01
                        batch_mean_pred = torch.mean(predictions).item()

                    # Track the target/prediction mean ratio (clamped to [0.1, 100])
                    batch_mean_true = torch.mean(y_batch).item()
                    if batch_mean_pred > 1e-8:
                        scale = batch_mean_true / batch_mean_pred
                        scale = min(max(scale, 0.1), 100.0)
                        epoch_scale_factors.append(scale)

                    loss = imbalanced_loss(predictions, y_batch, epoch)

                scaler.scale(loss).backward()

                # Clip on unscaled gradients
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)

                scaler.step(optimizer)
                scaler.update()

                scheduler.step()

            except Exception as e:
                print(f"Error in training step: {e}")
                # Fall back to plain full-precision MSE, with the same model inputs
                # as the main path (previously time_indices were dropped here)
                optimizer.zero_grad()
                predictions = _forward(X_batch, customer_batch)
                simple_loss = F.mse_loss(predictions, y_batch)
                simple_loss.backward()
                optimizer.step()
                loss = simple_loss

            train_loss += loss.item()

        avg_train_loss = train_loss / len(train_loader)
        history['train_loss'].append(avg_train_loss)

        avg_pred_mean = np.mean(epoch_pred_means) if epoch_pred_means else 0
        history['pred_means'].append(avg_pred_mean)
        print(f"  Average prediction mean: {avg_pred_mean:.6f}")

        if epoch_scale_factors:
            avg_scale_factor = np.mean(epoch_scale_factors)
            history['scale_factors'].append(avg_scale_factor)
            print(f"  Average scale factor: {avg_scale_factor:.2f}")

            # Exponential moving average of plausible epoch scale factors
            if avg_scale_factor > 0.5 and avg_scale_factor < 50.0:
                if best_scale_factor == 1.0:
                    best_scale_factor = avg_scale_factor
                else:
                    best_scale_factor = 0.8 * best_scale_factor + 0.2 * avg_scale_factor

        # Intervene after two consecutive near-zero epochs
        if avg_pred_mean < 0.001:
            zero_prediction_counter += 1
            print(f"  WARNING: Near-zero predictions detected ({zero_prediction_counter} epochs)")

            if zero_prediction_counter >= 2:
                print("  Applying scale intervention")
                with torch.no_grad():
                    current_scale = model.log_output_scale.item()
                    model.log_output_scale.fill_(current_scale + 1.0)

                    for layer in _output_bias_layers():
                        layer.bias.fill_(0.5)
        else:
            zero_prediction_counter = 0

        # VALIDATION
        model.eval()
        val_loss = 0.0
        val_predictions = []
        val_targets = []

        with torch.no_grad():
            for X_batch, y_batch, customer_batch in val_loader:
                X_batch = X_batch.to(device)
                y_batch = y_batch.to(device)
                customer_batch = customer_batch.to(device)

                val_preds = _forward(X_batch, customer_batch)

                # Until a scale factor > 1 has been estimated, use a fixed 5x
                scale_to_apply = best_scale_factor if best_scale_factor > 1.0 else 5.0
                val_preds_scaled = val_preds * scale_to_apply

                # NOTE: val_loss uses the *unscaled* predictions, whereas the
                # RMSE metrics below use the scaled ones.
                batch_loss = F.mse_loss(val_preds, y_batch).item()
                val_loss += batch_loss

                val_predictions.append(val_preds_scaled.cpu().numpy())
                val_targets.append(y_batch.cpu().numpy())

        val_predictions_np = np.vstack(val_predictions)
        val_targets_np = np.vstack(val_targets)

        val_mse = mean_squared_error(val_targets_np, val_predictions_np)
        val_rmse = np.sqrt(val_mse)

        # RMSE restricted to weeks that had transactions
        nonzero_mask = val_targets_np > 0
        if nonzero_mask.sum() > 0:
            tx_mse = mean_squared_error(
                val_targets_np[nonzero_mask],
                val_predictions_np[nonzero_mask]
            )
            tx_rmse = np.sqrt(tx_mse)
        else:
            tx_rmse = 0.0

        val_loss /= len(val_loader)
        history['val_loss'].append(val_loss)
        history['val_rmse'].append(val_rmse)
        history['transaction_rmse'].append(tx_rmse)

        print(f"  Validation RMSE: {val_rmse:.6f}, Transaction RMSE: {tx_rmse:.6f}")

        # Model selection weights transaction accuracy more heavily
        val_metric = val_rmse * 0.4 + tx_rmse * 0.6

        if epoch >= 5 and avg_pred_mean > 0.0005 and val_metric < best_val_metric:
            best_val_metric = val_metric
            # Real copy: `.cpu()` alone aliases parameter storage on CPU.
            best_model = {k: v.detach().to('cpu', copy=True) for k, v in model.state_dict().items()}
            patience_counter = 0
            print(f"  New best model! Val metric: {val_metric:.6f}")
        else:
            patience_counter += 1
            print(f"  No improvement for {patience_counter} epochs")

        if patience_counter >= patience and epoch >= 10:
            print(f"Early stopping after {epoch+1} epochs")
            break

    if best_model is not None:
        model.load_state_dict(best_model)
        print("Loaded best model from checkpoint")

    model.best_scale_factor = best_scale_factor
    print(f"Best scale factor: {best_scale_factor:.2f}")

    return model, history

In [None]:
def predict_and_evaluate(model, X_pred, y_target, customer_indices, time_indices=None, fixed_scale=14.0,
                         pred_threshold=0.1):
    """
    Generate predictions, multiply them by a fixed factor and report error/volume metrics.

    No scale factor is estimated here: the raw model output is simply multiplied
    by ``fixed_scale``. All inputs are moved to the module-level ``device`` and
    scored in a single forward pass under ``torch.no_grad()``.

    Parameters:
    -----------
    model : nn.Module
        Trained model; switched to eval mode and called as
        ``model(X_pred, customer_indices, time_indices)``
    X_pred : torch.Tensor
        Input data
    y_target : torch.Tensor
        Target values, compared element-wise with the scaled predictions
    customer_indices : torch.Tensor
        Customer indices
    time_indices : torch.Tensor, optional
        Time indices for temporal information
    fixed_scale : float, optional
        Fixed scaling factor to apply to predictions (default: 14.0)
    pred_threshold : float, optional
        A scaled prediction >= this value counts as a predicted transaction
        when computing recall and precision (default: 0.1)

    Returns:
    --------
    predictions : torch.Tensor
        Scaled predictions, moved to the CPU.
    metrics : dict
        'mse', 'rmse', 'mae'            -- over all flattened cells;
        'trans_mse', 'trans_rmse', 'trans_mae', 'tx_recall', 'tx_precision'
                                        -- over cells whose target is > 0
                                           (all None when there are none);
        'volume_ratio', 'raw_volume_ratio'
                                        -- summed scaled / raw predictions over
                                           summed targets (0 if target sum <= 0);
        'predictions', 'raw_predictions', 'targets'
                                        -- NumPy copies of the arrays used.
    """
    model.eval()

    # Prepare data
    X_pred = X_pred.to(device)
    y_target = y_target.to(device)
    customer_indices = customer_indices.to(device)

    if time_indices is not None:
        time_indices = time_indices.to(device)

    # Generate predictions
    with torch.no_grad():
        # Get raw predictions
        raw_predictions = model(X_pred, customer_indices, time_indices)

        # Apply fixed scaling for better results on new cohorts
        predictions = raw_predictions * fixed_scale

        # Log the scale used
        print(f"Applied fixed scale factor: {fixed_scale:.2f}")

    # Calculate metrics on host copies
    predictions_np = predictions.cpu().numpy()
    raw_predictions_np = raw_predictions.cpu().numpy()
    targets_np = y_target.cpu().numpy()

    # Overall metrics
    mse = mean_squared_error(targets_np.flatten(), predictions_np.flatten())
    rmse = np.sqrt(mse)
    mae = mean_absolute_error(targets_np.flatten(), predictions_np.flatten())

    print(f"Prediction Metrics:")
    print(f"MSE: {mse:.6f}")
    print(f"RMSE: {rmse:.6f}")
    print(f"MAE: {mae:.6f}")

    # Transaction-specific metrics: restricted to cells with a positive target
    nonzero_mask = targets_np > 0
    if nonzero_mask.sum() > 0:
        trans_mse = mean_squared_error(
            targets_np[nonzero_mask],
            predictions_np[nonzero_mask]
        )
        trans_rmse = np.sqrt(trans_mse)
        trans_mae = mean_absolute_error(
            targets_np[nonzero_mask],
            predictions_np[nonzero_mask]
        )

        # Recall: share of actual transactions whose prediction reaches the threshold.
        # total_transactions > 0 is guaranteed by the enclosing branch.
        detected = (predictions_np[nonzero_mask] >= pred_threshold).sum()
        total_transactions = nonzero_mask.sum()
        tx_recall = detected / total_transactions

        # Precision: share of above-threshold predictions that hit an actual transaction
        predicted_tx_mask = predictions_np >= pred_threshold
        true_positives = (predicted_tx_mask & nonzero_mask).sum()
        tx_precision = true_positives / predicted_tx_mask.sum() if predicted_tx_mask.sum() > 0 else 0

        print(f"Transaction-only Metrics:")
        print(f"Trans MSE: {trans_mse:.6f}")
        print(f"Trans RMSE: {trans_rmse:.6f}")
        print(f"Trans MAE: {trans_mae:.6f}")
        print(f"Transaction Recall: {tx_recall:.4f} ({detected}/{total_transactions})")
        print(f"Transaction Precision: {tx_precision:.4f}")
    else:
        trans_mse = trans_rmse = trans_mae = tx_recall = tx_precision = None

    # Volume metrics: compare summed predictions against summed targets
    pred_volume = np.sum(predictions_np)
    raw_pred_volume = np.sum(raw_predictions_np)
    target_volume = np.sum(targets_np)
    volume_ratio = pred_volume / target_volume if target_volume > 0 else 0
    raw_volume_ratio = raw_pred_volume / target_volume if target_volume > 0 else 0

    print(f"Volume Metrics:")
    print(f"Predicted total transactions: {pred_volume:.1f}")
    print(f"Raw (unscaled) predicted total: {raw_pred_volume:.1f}")
    print(f"Actual total transactions: {target_volume:.1f}")
    print(f"Volume ratio: {volume_ratio:.4f} (Raw: {raw_volume_ratio:.4f})")

    return predictions.cpu(), {
        'mse': mse,
        'rmse': rmse,
        'mae': mae,
        'predictions': predictions_np,
        'raw_predictions': raw_predictions_np,
        'targets': targets_np,
        'trans_mse': trans_mse,
        'trans_rmse': trans_rmse,
        'trans_mae': trans_mae,
        'tx_recall': tx_recall,
        'tx_precision': tx_precision,
        'volume_ratio': volume_ratio,
        'raw_volume_ratio': raw_volume_ratio,
    }

## 6. Visualizations

In [None]:
def visualize_results(X_input, y_true, predictions, num_samples=5, title="Model Predictions", seed=None):
    """
    Plot input history, true target and prediction for randomly selected rows.

    One subplot is drawn per selected row. The input sequence occupies weeks
    ``0 .. len(input)-1``, the true target and the prediction occupy the weeks
    that follow, and a dashed vertical line marks the last input week.

    Parameters
    ----------
    X_input, y_true, predictions : indexable collections of CPU tensors
        Row ``idx`` of each is plotted together. Every row is converted with
        ``.numpy()``, so the tensors must already live on the CPU.
    num_samples : int, optional
        Number of rows to plot (default: 5); capped at ``len(X_input)``.
    title : str, optional
        Figure title.
    seed : int, optional
        When given, rows are drawn with ``np.random.default_rng(seed)`` so the
        same rows are shown on every run. When None (default), NumPy's global
        random state is used, as before.
    """
    # Never request more samples than there are rows
    num_samples = min(num_samples, len(X_input))
    if num_samples <= 0:
        # Nothing to draw; avoids creating a zero-height figure
        return

    if seed is None:
        indices = np.random.choice(len(X_input), num_samples, replace=False)
    else:
        indices = np.random.default_rng(seed).choice(len(X_input), num_samples, replace=False)

    plt.figure(figsize=(15, 4 * num_samples))

    for i, idx in enumerate(indices):
        plt.subplot(num_samples, 1, i + 1)

        # Get data
        input_seq = X_input[idx].numpy()
        true_seq = y_true[idx].numpy()
        pred_seq = predictions[idx].numpy()

        # Targets/predictions continue the week axis right after the input window
        input_weeks = np.arange(len(input_seq))
        output_weeks = np.arange(len(input_seq), len(input_seq) + len(true_seq))

        # Plot
        plt.plot(input_weeks, input_seq, 'o-', color='blue',
                 label='Input Sequence', alpha=0.7, markersize=3)
        plt.plot(output_weeks, true_seq, 'o-', color='green',
                 label='True Target', alpha=0.7, markersize=3)
        plt.plot(output_weeks, pred_seq, 'x--', color='red',
                 label='Prediction', alpha=0.7, markersize=4)

        # Add vertical line to separate input and output
        plt.axvline(x=len(input_seq)-1, color='gray', linestyle='--', alpha=0.5)

        plt.title(f'Customer {idx} Transaction Pattern')
        plt.xlabel('Weeks')
        plt.ylabel('Transactions')
        plt.legend()
        plt.grid(True, alpha=0.3)

    plt.suptitle(title, fontsize=16)
    # Leave room at the top for the suptitle
    plt.tight_layout(rect=[0, 0, 1, 0.97])
    plt.show()


def _as_numpy_array(values):
    """Convert an ndarray, a tensor, or a sequence of tensors/arrays into one NumPy array."""
    if isinstance(values, np.ndarray):
        return values
    if isinstance(values, torch.Tensor):
        # detach()/cpu() so CUDA tensors and tensors that require grad also convert
        return values.detach().cpu().numpy()
    return np.array([
        v.detach().cpu().numpy() if isinstance(v, torch.Tensor) else np.asarray(v)
        for v in values
    ])


def visualize_aggregated_results(X_input, y_true, predictions, title="Aggregated Model Predictions",
                                 show_input=False):
    """
    Plot the per-week mean target and mean prediction across all samples.

    Means are taken over axis 0 (the sample axis). Target and prediction weeks
    are placed after the input window, and a dashed vertical line marks the
    last input week.

    Parameters
    ----------
    X_input, y_true, predictions : np.ndarray, torch.Tensor, or sequence of tensors/arrays
        Inputs, targets and predictions; converted to NumPy arrays.
    title : str, optional
        Plot title.
    show_input : bool, optional
        Also plot the mean input sequence on the input weeks (default: False,
        which matches the previous output where that curve was disabled).
    """
    X_input = _as_numpy_array(X_input)
    y_true = _as_numpy_array(y_true)
    predictions = _as_numpy_array(predictions)

    # Compute mean across samples
    mean_input = X_input.mean(axis=0)
    mean_true = y_true.mean(axis=0)
    mean_pred = predictions.mean(axis=0)

    # Target/prediction weeks continue right after the input window
    input_weeks = np.arange(len(mean_input))
    output_weeks = np.arange(len(mean_input), len(mean_input) + len(mean_true))

    plt.figure(figsize=(12, 6))
    if show_input:
        plt.plot(input_weeks, mean_input, 'o-', label='Mean Input Sequence', color='blue')
    plt.plot(output_weeks, mean_true, 'o-', label='Mean True Target', color='green')
    plt.plot(output_weeks, mean_pred, 'x--', label='Mean Prediction', color='red')

    plt.axvline(x=len(mean_input)-1, color='gray', linestyle='--', alpha=0.5)
    plt.title(title)
    plt.xlabel('Weeks')
    plt.ylabel('Transaction Value')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()


def plot_training_history(history):
    """
    Draw a 2x2 dashboard of the metrics recorded during training.

    Panels:
      1. training vs. validation loss;
      2. validation RMSE, plus transaction-only RMSE when recorded;
      3. validation MAE if recorded, otherwise prediction means if recorded;
      4. scale factors if a non-empty sequence is recorded, otherwise the
         training loss again.

    Parameters
    ----------
    history : dict
        Must provide 'train_loss', 'val_loss' and 'val_rmse'. The keys
        'transaction_rmse', 'val_mae', 'pred_means' and 'scale_factors'
        are optional.
    """
    def finish_panel(y_label):
        # Shared axis labelling / legend / grid applied at the end of every panel
        plt.xlabel('Epoch')
        plt.ylabel(y_label)
        plt.legend()
        plt.grid(True, alpha=0.3)

    plt.figure(figsize=(18, 12))

    # Panel 1: loss curves
    plt.subplot(2, 2, 1)
    for key, label in (('train_loss', 'Training Loss'), ('val_loss', 'Validation Loss')):
        plt.plot(history[key], label=label)
    plt.title('Loss During Training')
    finish_panel('Loss')

    # Panel 2: RMSE curves
    plt.subplot(2, 2, 2)
    plt.plot(history['val_rmse'], label='Validation RMSE')
    if 'transaction_rmse' in history:
        plt.plot(history['transaction_rmse'], label='Transaction-only RMSE')
    plt.title('RMSE During Training')
    finish_panel('RMSE')

    # Panel 3: MAE, falling back to prediction means
    plt.subplot(2, 2, 3)
    if 'val_mae' in history:
        plt.plot(history['val_mae'], label='Validation MAE')
        plt.title('MAE During Training')
    elif 'pred_means' in history:
        plt.plot(history['pred_means'], label='Prediction Means')
        plt.title('Prediction Means During Training')
    finish_panel('Value')

    # Panel 4: scale factors, falling back to the training loss
    plt.subplot(2, 2, 4)
    scale_history = history.get('scale_factors')
    if scale_history:
        plt.plot(history['scale_factors'], label='Scale Factors')
        plt.title('Scale Factor Evolution')
    else:
        plt.plot(history['train_loss'], label='Training Loss Evolution')
        plt.title('Loss Evolution')
    finish_panel('Value')

    plt.tight_layout()
    plt.show()

## 7. Main Execution Flow

In [None]:
# Configurable time windows (ISO date strings) for the training cohorts:
# an input-history window followed by a target window.
training_periods = dict(
    input_start='2005-01-01',
    input_end='2006-12-31',
    target_start='2007-01-01',
    target_end='2009-12-26',
)

# Windows used for the prediction cohorts, starting two years later.
# NOTE(review): this input window (2007-2008) lies inside the training target
# window in calendar time; confirm that is intended for the cross-cohort setup.
prediction_periods = dict(
    input_start='2007-01-01',
    input_end='2008-12-28',
    target_start='2009-01-01',
    target_end='2011-12-31',
)

In [None]:
# Prepare the transaction data: parse ORDER_DATE into the 'Date' column that
# date_field='Date' refers to below
df['Date'] = pd.to_datetime(df['ORDER_DATE'])

# Both calls share the column mapping and the time windows; only the cohort
# range differs.
# Training data for cohorts 1-24
training_data = prepare_sequence_data(
    df,
    customer_field='CUSTNO',
    date_field='Date',
    cohort_field='COHORT_NUMBER',
    cohort_range=(1, 24),  # For training cohorts 1-24
    train_periods=training_periods,
    pred_periods=prediction_periods
)

# Prediction data for cohorts 25-48 (no overlap with the training cohorts)
prediction_data = prepare_sequence_data(
    df,
    customer_field='CUSTNO',
    date_field='Date',
    cohort_field='COHORT_NUMBER',
    cohort_range=(25, 48),  # For prediction cohorts 25-48
    train_periods=training_periods,
    pred_periods=prediction_periods
)


In [None]:
# Training cohorts: inputs, targets, customer indices and number of customers
X_train, y_train_target, train_customer_indices, train_num_customers = (
    training_data['X_train'],
    training_data['y_target'],
    training_data['customer_indices'],
    training_data['num_customers'],
)

# Prediction cohorts: the same four items under their prediction-side keys
X_pred, y_pred_target, pred_customer_indices, pred_num_customers = (
    prediction_data['X_pred'],
    prediction_data['y_pred_target'],
    prediction_data['customer_indices'],
    prediction_data['num_customers'],
)


In [None]:
def week_of_year_indices(start_date, num_weeks):
    """
    Return a LongTensor holding the ISO week-of-year number of each of
    ``num_weeks`` consecutive weeks, the first week starting at ``start_date``.
    """
    start = pd.to_datetime(start_date)
    return torch.tensor(
        [(start + pd.Timedelta(weeks=i)).isocalendar()[1] for i in range(num_weeks)],
        dtype=torch.long,
    )


# Create time indices for weeks of year for both datasets; one entry per
# input time step (dimension 1 of the input tensors)
training_input_start = pd.to_datetime(training_periods['input_start'])
training_input_weeks = X_train.shape[1]
training_time_indices = week_of_year_indices(training_input_start, training_input_weeks)

prediction_input_start = pd.to_datetime(prediction_periods['input_start'])
prediction_input_weeks = X_pred.shape[1]
prediction_time_indices = week_of_year_indices(prediction_input_start, prediction_input_weeks)

In [None]:
# Create a reproducible train/validation split of the training cohorts
# (validation_ratio=0.1, fixed random_seed=66)
validation_data = prepare_cross_customer_validation(
    {
        'X_train': X_train,
        'y_target': y_train_target,
        'customer_indices': train_customer_indices
    },
    validation_ratio=0.1,
    random_seed=66
)

# Check average transactions - useful for model tuning
print(f"Average transactions in training data: {X_train.mean():.4f}")
print(f"Average transactions in training target: {y_train_target.mean():.4f}")
print(f"Average transactions in prediction data: {X_pred.mean():.4f}")
print(f"Average transactions in prediction target: {y_pred_target.mean():.4f}")

In [None]:
# Transformer hyper-parameters; output length and customer count come from
# the training data.
model_params = dict(
    input_dim=1,
    embed_dim=128,
    hidden_dim=512,
    num_layers=2,
    num_heads=8,
    output_dim=y_train_target.shape[1],
    num_customers=train_num_customers,
    dropout=0.1,
)

model = TransactionAwareTransformer(**model_params).to(device)

# Datasets for the two halves of the validation split
train_dataset = CustomerTransactionDataset(
    validation_data['X_train'], validation_data['y_train'], validation_data['customer_train']
)
val_dataset = CustomerTransactionDataset(
    validation_data['X_val'], validation_data['y_val'], validation_data['customer_val']
)

# Only the training loader shuffles; validation order stays fixed
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Train on train_loader, validating with val_loader; the week-of-year indices
# of the training input window are passed alongside.
trained_model, training_history = simplified_transaction_forecaster_train(
    model=model,
    device=device,
    train_loader=train_loader,
    val_loader=val_loader,
    time_indices=training_time_indices,
    # optimiser settings
    learning_rate=0.001,
    weight_decay=0.001,
    # run length / early-stopping patience
    num_epochs=30,
    patience=10,
)

# Visual summary of the recorded training metrics
plot_training_history(training_history)

## 8. Prediction

In [None]:
# For prediction on new cohorts we build a fresh model whose customer
# embedding is sized for the prediction customers, then transfer every other
# trained tensor from the training model.
prediction_model_params = model_params.copy()
prediction_model_params['num_customers'] = pred_num_customers
prediction_model_params['output_dim'] = y_pred_target.shape[1]

prediction_model = TransactionAwareTransformer(**prediction_model_params).to(device)

# Transfer parameters AND persistent buffers (via state_dict), except the
# customer embedding. Tensors whose shape differs (e.g. an output layer when
# output_dim changes) are skipped with a warning, as before.
# NOTE(review): the new customer embedding keeps whatever initialisation
# TransactionAwareTransformer gives it (presumably random); confirm intended.
target_state = prediction_model.state_dict()  # looked up once, not per tensor
transferable_state = {}
for name, tensor in trained_model.state_dict().items():
    if 'customer_embedding' in name:
        continue
    if name not in target_state:
        print(f"Warning: {name} not found in prediction model, skipping...")
        continue
    if tensor.shape != target_state[name].shape:
        print(f"Warning: Shape mismatch for {name}, skipping...")
        continue
    transferable_state[name] = tensor

# strict=False: the skipped entries are intentionally left untouched
prediction_model.load_state_dict(transferable_state, strict=False)

In [None]:
# Score the transferred model on the prediction cohorts; fixed_scale=1 leaves
# the raw model output unscaled.
predictions, metrics = predict_and_evaluate(
    model=prediction_model,
    X_pred=X_pred,
    y_target=y_pred_target,
    customer_indices=pred_customer_indices,
    time_indices=prediction_time_indices,
    fixed_scale=1,
)

In [None]:
# Per-row plots for 10 randomly chosen prediction-cohort rows
visualize_results(
    X_input=X_pred,
    y_true=y_pred_target,
    predictions=predictions,
    num_samples=10,
    title="Transaction Predictions for Selected Customers",
)

In [None]:
# Mean target vs. mean prediction across all prediction-cohort rows
visualize_aggregated_results(
    X_input=X_pred,
    y_true=y_pred_target,
    predictions=predictions,
    title="Aggregated Transaction Predictions",
)

## 9. Alternative Configurations

In [None]:
# Alternative, shorter windows: one year of input history and a six-month target.
short_period_config = dict(
    input_start='2005-01-01',
    input_end='2005-12-31',
    target_start='2006-01-01',
    target_end='2006-06-30',
)

# The same shape of windows for the prediction side, in 2008/2009.
short_prediction_config = dict(
    input_start='2008-01-01',
    input_end='2008-12-31',
    target_start='2009-01-01',
    target_end='2009-06-30',
)

# Executing this cell preprocesses data with the shorter windows:
# cohorts 1-24 for training and cohorts 37-60 for prediction
# (the main flow above predicts cohorts 25-48).
_short_prep_kwargs = dict(
    customer_field='CUSTNO',
    date_field='Date',
    cohort_field='COHORT_NUMBER',
    train_periods=short_period_config,
    pred_periods=short_prediction_config,
)
short_training_data = prepare_sequence_data(df, cohort_range=(1, 24), **_short_prep_kwargs)
short_prediction_data = prepare_sequence_data(df, cohort_range=(37, 60), **_short_prep_kwargs)
del _short_prep_kwargs


# Alternative model size relative to model_params: embed/hidden dimensions,
# layer count and head count are doubled, and dropout is raised to 0.2.
larger_model_params = dict(
    input_dim=1,
    embed_dim=256,
    hidden_dim=1024,
    num_layers=4,
    num_heads=16,
    output_dim=y_train_target.shape[1],
    num_customers=train_num_customers,
    dropout=0.2,
)

# Executing this cell builds the larger model and runs a second full training
# pass on the same train/validation loaders.
larger_model = TransactionAwareTransformer(**larger_model_params).to(device)

larger_trained_model, larger_training_history = simplified_transaction_forecaster_train(
    model=larger_model,
    device=device,
    train_loader=train_loader,
    val_loader=val_loader,
    time_indices=training_time_indices,
    learning_rate=0.0005,  # lower than the base run's 0.001
    weight_decay=0.002,    # higher than the base run's 0.001
    num_epochs=30,
    patience=10,
)


# Alternative cohort ranges for the training and prediction sides
different_cohorts = dict(
    training=(5, 30),
    prediction=(45, 70),
)

# Executing this cell preprocesses the alternative cohort ranges with the
# main training/prediction time windows.
_cohort_prep_kwargs = dict(
    customer_field='CUSTNO',
    date_field='Date',
    cohort_field='COHORT_NUMBER',
    train_periods=training_periods,
    pred_periods=prediction_periods,
)
different_cohort_training_data = prepare_sequence_data(
    df, cohort_range=different_cohorts['training'], **_cohort_prep_kwargs
)
different_cohort_prediction_data = prepare_sequence_data(
    df, cohort_range=different_cohorts['prediction'], **_cohort_prep_kwargs
)
del _cohort_prep_kwargs

