In [None]:
import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from collections import defaultdict
import os
import time

# PyTorch Lightning
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, TQDMProgressBar

# Hydra and OmegaConf
import hydra
from omegaconf import DictConfig, OmegaConf

# ClearML (optional, for experiment tracking)
# ClearML is optional (experiment tracking only). If it cannot be imported,
# `Task` is bound to None so callers can guard with `if Task: ...`.
try:
    from clearml import Task
except ImportError:
    Task = None
    print("ClearML not found. To use ClearML, install with: pip install clearml")


# mamba_ssm is optional at import time. When it is missing, a stand-in `Mamba`
# with a compatible constructor/forward signature is defined so the notebook
# still runs end to end -- it is only a single Linear layer, so training
# results are not meaningful.
try:
    from mamba_ssm import Mamba
except ImportError:
    print("WARNING: mamba_ssm not found. Using a placeholder Mamba module.")
    print("Please install mamba_ssm: pip install mamba-ssm causal-conv1d>=1.0.0")

    class Mamba(nn.Module):
        """Shape-preserving stand-in for ``mamba_ssm.Mamba`` (one Linear layer)."""

        def __init__(self, d_model, **kwargs):
            super().__init__()
            # Other Mamba hyperparameters (d_state, d_conv, expand, ...) are accepted and ignored.
            self.d_model = d_model
            self.dummy_layer = nn.Linear(d_model, d_model)
            print("Placeholder Mamba initialized. Model will not train correctly.")

        def forward(self, hidden_states, **kwargs):
            """Apply the Linear layer over the last dim (must equal ``d_model``)."""
            return self.dummy_layer(hidden_states)


In [None]:
class Mamba4Rec(nn.Module):
    """
    Sequential recommender: entity embedding -> one Mamba block -> optional vocab projection.
    """
    def __init__(self, vocab_size, mamba_config, add_head=True,
                 tie_weights=True, padding_idx=0, init_std=0.02,
                 **kwargs):
        """
        Build the embedding table, the Mamba block and (optionally) the output head.

        Args:
            vocab_size (int): Number of distinct entity IDs including padding
                (items + users + padding in this notebook's ID scheme).
            mamba_config (Mapping): Keyword arguments for ``Mamba``. Must contain
                ``'d_model'``, which also sets the embedding/hidden size.
            add_head (bool, optional): Append a bias-free Linear mapping hidden
                states to vocab logits. Defaults to True.
            tie_weights (bool, optional): Share the head weight with the embedding
                table (only used when ``add_head`` is True). Defaults to True.
            padding_idx (int, optional): Embedding row reserved for padding; it is
                zeroed at init. Defaults to 0.
            init_std (float, optional): Std of the normal init for the embedding
                (and the head when untied). Defaults to 0.02.
            **kwargs: Extra keyword arguments forwarded to ``Mamba``.

        Raises:
            ValueError: If ``mamba_config`` lacks ``'d_model'`` or the tied
                head/embedding shapes differ.
        """
        super().__init__()

        # Keep constructor arguments as attributes.
        self.vocab_size = vocab_size
        self.mamba_config = mamba_config
        self.add_head = add_head
        self.tie_weights = tie_weights
        self.padding_idx = padding_idx
        self.init_std = init_std

        if 'd_model' not in mamba_config:
            raise ValueError("mamba_config must contain 'd_model'")
        hidden = mamba_config['d_model']
        self.hidden_size = hidden

        # Submodules are created in the order embedding -> Mamba -> head.
        self.embed_layer = nn.Embedding(num_embeddings=vocab_size,
                                        embedding_dim=hidden,
                                        padding_idx=padding_idx)
        self.mamba_model = Mamba(**mamba_config, **kwargs)

        if add_head:
            self.head = nn.Linear(hidden, vocab_size, bias=False)
            if tie_weights:
                head_shape = self.head.weight.shape
                embed_shape = self.embed_layer.weight.shape
                if head_shape != embed_shape:
                    raise ValueError(f"Head ({head_shape}) and Embed ({embed_shape}) "
                                     f"shapes don't match for tied weights.")
                # One Parameter shared by the input embedding and the output projection.
                self.head.weight = self.embed_layer.weight
        self.init_weights()

    def init_weights(self):
        """Redraw embedding (and untied head) weights from N(0, init_std^2); zero the padding row."""
        with torch.no_grad():
            self.embed_layer.weight.normal_(mean=0.0, std=self.init_std)
            if self.padding_idx is not None:
                self.embed_layer.weight[self.padding_idx].zero_()
            if self.add_head and not self.tie_weights:
                self.head.weight.normal_(mean=0.0, std=self.init_std)

    def forward(self, input_ids, attention_mask=None):
        """
        Return vocab logits (or Mamba hidden states when ``add_head`` is False).

        ``attention_mask`` is accepted for API parity with attention models and ignored.
        """
        hidden_states = self.mamba_model(self.embed_layer(input_ids))
        return self.head(hidden_states) if self.add_head else hidden_states

In [None]:

def add_time_idx(df, user_col='user_id', timestamp_col='timestamp', sort=True):
    """
    Number each user's interactions in row order.

    Adds ``time_idx`` (0 = user's first row) and ``time_idx_reversed``
    (0 = user's last row). With ``sort=True`` rows are first ordered by
    (user, timestamp) and the sorted copy is returned. With ``sort=False`` the
    columns are written into ``df`` itself, which is then returned.
    """
    if sort:
        print(f"Sorting interactions by {user_col} and {timestamp_col}...")
        df = df.sort_values([user_col, timestamp_col])
    print("Adding time indices (time_idx and time_idx_reversed)...")
    grouped = df.groupby(user_col)
    forward_rank = grouped.cumcount()
    backward_rank = grouped.cumcount(ascending=False)
    df['time_idx'] = forward_rank
    df['time_idx_reversed'] = backward_rank
    return df

def filter_items(df, item_min_count, user_col='user_id', item_col='item_id'):
    """
    Keep only items interacted with by at least ``item_min_count`` distinct users.

    Returns a filtered copy; ``df`` is left unchanged.
    """
    print(f"Filtering items with less than {item_min_count} unique user interactions...")
    n_items_before = df[item_col].nunique()
    users_per_item = df.groupby(item_col)[user_col].nunique()
    keep = users_per_item.index[(users_per_item >= item_min_count).to_numpy()]
    df_filtered = df.loc[df[item_col].isin(keep)].copy()
    print(f"  Items before: {n_items_before}, After: {df_filtered[item_col].nunique()}")
    return df_filtered

def filter_users(df, user_min_count, user_col='user_id', item_col='item_id'):
    """
    Keep only users who interacted with at least ``user_min_count`` distinct items.

    Returns a filtered copy; ``df`` is left unchanged.
    """
    print(f"Filtering users with less than {user_min_count} unique item interactions...")
    n_users_before = df[user_col].nunique()
    items_per_user = df.groupby(user_col)[item_col].nunique()
    keep = items_per_user.index[(items_per_user >= user_min_count).to_numpy()]
    df_filtered = df.loc[df[user_col].isin(keep)].copy()
    print(f"  Users before: {n_users_before}, After: {df_filtered[user_col].nunique()}")
    return df_filtered

def map_ids(df, user_col='user_id', item_col='item_id'):
    """
    Assign contiguous integer IDs in one shared vocabulary where 0 is padding.

    Items get 1..N and users N+1..N+M, each in order of first appearance in
    ``df``. Columns ``mapped_item_id`` and ``mapped_user_id`` are added to
    ``df`` in place.

    Returns:
        tuple: (df, user_map, item_map, num_items, num_users, total_entities_in_vocab),
        where total_entities_in_vocab = N + M + 1.
    """
    print("Mapping original IDs to contiguous integer ranges...")
    item_map = {item_id: idx for idx, item_id in enumerate(df[item_col].unique(), start=1)}
    num_items = len(item_map)
    df['mapped_item_id'] = df[item_col].map(item_map)
    print(f"  Mapped {num_items} unique items (IDs 1 to {num_items}).")

    # User IDs continue right after the last item ID.
    first_user_id = num_items + 1
    user_map = {user_id: idx for idx, user_id in enumerate(df[user_col].unique(), start=first_user_id)}
    num_users = len(user_map)
    df['mapped_user_id'] = df[user_col].map(user_map)
    print(f"  Mapped {num_users} unique users (IDs {first_user_id} to {num_items + num_users}).")

    total_entities_in_vocab = num_items + num_users + 1  # +1 for padding token 0
    print(f"  Total entities for vocab (incl. padding=0): {total_entities_in_vocab}")
    return df, user_map, item_map, num_items, num_users, total_entities_in_vocab

def create_sequences_with_user_id(df, user_col='user_id', mapped_item_col='mapped_item_id',
                                  mapped_user_col='mapped_user_id', inject_every=2):
    """
    Build one mixed item/user-token sequence per user.

    The user's own mapped ID is inserted after every ``inject_every`` mapped
    item IDs (every two items by default). Row order inside each user group is
    used as-is, so ``df`` should already be sorted by user and time.

    Args:
        df (pd.DataFrame): Interactions containing the three columns below.
        user_col (str): Column whose values become the keys of the result.
        mapped_item_col (str): Column with mapped item IDs.
        mapped_user_col (str): Column with the user's mapped ID (first row per user is used).
        inject_every (int): Number of items between consecutive user tokens (>= 1).

    Returns:
        dict: ``{user_id: [entity_id, ...]}``. User tokens are plain Python ints,
        matching the ``tolist()``-derived item IDs.

    Raises:
        ValueError: If ``inject_every`` is smaller than 1.
    """
    if inject_every < 1:
        raise ValueError(f"inject_every must be >= 1, got {inject_every}")
    print("Generating sequences with user ID injection...")
    user_sequences = {}
    for user_id, group in df.groupby(user_col):
        # iloc yields a numpy scalar; convert so sequences hold one int type.
        user_token = int(group[mapped_user_col].iloc[0])
        sequence = []
        for position, item_id in enumerate(group[mapped_item_col].tolist(), start=1):
            sequence.append(item_id)
            if position % inject_every == 0:
                sequence.append(user_token)
        user_sequences[user_id] = sequence
    print(f"Finished generating sequences for {len(user_sequences)} users.")
    return user_sequences


In [None]:
def preds_to_item_recs_mixed_vocab(predictions_batches, item_id_reverse_map, num_items_in_vocab, top_k_items_to_return=10):
    """
    Convert raw top-k entity predictions over the mixed vocabulary into item recommendations.

    Each batch is a dict with 'user_ids' (one entry per row), 'preds' (per-row
    arrays of predicted entity IDs) and 'scores' (matching score arrays).
    Entities outside the item range 1..``num_items_in_vocab`` (padding and user
    tokens) are discarded. The remaining items are ranked by score, keeping at
    most ``top_k_items_to_return`` per user.

    Args:
        predictions_batches (Iterable[dict]): Raw prediction batches as described above.
        item_id_reverse_map (dict | None): Mapped item ID -> original item ID. Rows
            whose mapped ID is missing from it are dropped. When None, the mapped
            ID is reported as ``item_id``.
        num_items_in_vocab (int): Largest mapped item ID.
        top_k_items_to_return (int): Maximum items kept per user.

    Returns:
        pd.DataFrame: Columns ['user_id', 'item_id', 'prediction_score'], sorted by
        user ascending then score descending. The frame is empty if nothing survives.
    """
    print(f"Starting post-processing of prediction batches for top {top_k_items_to_return} items...")
    rows = []
    for batch_output in predictions_batches:
        batch_users = batch_output['user_ids']
        batch_entities = batch_output['preds']  # Mapped entity IDs
        batch_scores = batch_output['scores']
        for row, user_id in enumerate(batch_users):
            entities = batch_entities[row]
            scores = batch_scores[row]
            # Only IDs 1..num_items_in_vocab are items; 0 is padding, higher IDs are users.
            keep = (entities >= 1) & (entities <= num_items_in_vocab)
            item_ids, item_scores = entities[keep], scores[keep]
            if len(item_ids) == 0:
                continue
            # Highest score first, truncated to the requested count.
            best = np.argsort(item_scores)[::-1][:top_k_items_to_return]
            rows.extend(
                {'user_id': user_id, 'mapped_item_id': item_id, 'prediction_score': score}
                for item_id, score in zip(item_ids[best], item_scores[best])
            )

    if not rows:
        print("No valid item recommendations generated.")
        return pd.DataFrame(columns=['user_id', 'item_id', 'prediction_score'])

    recs_df = pd.DataFrame(rows)
    if item_id_reverse_map is None:
        recs_df['item_id'] = recs_df['mapped_item_id']  # No reverse map: report mapped IDs
    else:
        recs_df['item_id'] = recs_df['mapped_item_id'].map(item_id_reverse_map)
        recs_df.dropna(subset=['item_id'], inplace=True)  # Unmapped IDs are dropped
        if not recs_df['item_id'].empty:
            # Restore the original item ID dtype when possible; otherwise fall back to object.
            try:
                original_item_type = pd.Series(list(item_id_reverse_map.values())).dtype
                recs_df['item_id'] = recs_df['item_id'].astype(original_item_type)
            except Exception:
                recs_df['item_id'] = recs_df['item_id'].astype(object)

    final_df = recs_df[['user_id', 'item_id', 'prediction_score']].sort_values(
        ['user_id', 'prediction_score'], ascending=[True, False]
    )
    print(f"Post-processing complete. Returning {len(final_df)} item recommendations.")
    return final_df

In [None]:
class RecSysDataset(Dataset):
    """Next-entity training pairs: input ``seq[:-1]`` and target ``seq[1:]``, right-padded."""

    def __init__(self, sequences_dict, max_length, padding_idx=0, keep_most_recent=True):
        """
        Args:
            sequences_dict (dict): Dict mapping original user_id to their sequence of mapped entity IDs.
            max_length (int): Length of the returned input/target tensors.
            padding_idx (int): Value to use for right padding.
            keep_most_recent (bool): For sequences longer than ``max_length + 1``, keep
                the latest entities (default). Predictions are made after the latest
                interaction, so training should see the same recent window. ``False``
                keeps the earliest entities instead.
        """
        self.user_ids_orig = list(sequences_dict.keys())
        self.sequences = [sequences_dict[uid] for uid in self.user_ids_orig]
        self.max_length = max_length
        self.padding_idx = padding_idx
        self.keep_most_recent = keep_most_recent

    def __len__(self):
        return len(self.user_ids_orig)

    def _pad(self, values):
        """Cut ``values`` to ``max_length`` and right-pad it to exactly that length."""
        values = list(values[:self.max_length])
        return values + [self.padding_idx] * (self.max_length - len(values))

    def __getitem__(self, idx):
        seq = self.sequences[idx]
        if self.keep_most_recent:
            # max_length (input, target) pairs need max_length + 1 consecutive entities.
            seq = seq[max(len(seq) - (self.max_length + 1), 0):]
        input_seq, target_seq = seq[:-1], seq[1:]  # Target is the input shifted by one step

        return {
            'user_id_orig': self.user_ids_orig[idx],  # For tracking/prediction
            'input_ids': torch.LongTensor(self._pad(input_seq)),
            'target_ids': torch.LongTensor(self._pad(target_seq)),
        }

class PaddingCollateFn:
    """Collate RecSysDataset items into one batch dict, right-padding with ``padding_idx``."""

    def __init__(self, padding_idx=0):
        self.padding_idx = padding_idx

    def _pad_stack(self, tensors):
        """Stack 1-D tensors into (batch, longest_len), right-padding shorter ones."""
        # For equal-length tensors (the usual case, since the Dataset already pads)
        # this equals torch.stack; ragged inputs are padded instead of raising.
        return nn.utils.rnn.pad_sequence(tensors, batch_first=True, padding_value=self.padding_idx)

    def __call__(self, batch):
        """
        Args:
            batch (list[dict]): Items with 'user_id_orig', 'input_ids' and 'target_ids'.

        Returns:
            dict: 'user_ids_orig' (list) plus batch-first 'input_ids' / 'target_ids' tensors.
        """
        return {
            'user_ids_orig': [item['user_id_orig'] for item in batch],
            'input_ids': self._pad_stack([item['input_ids'] for item in batch]),
            'target_ids': self._pad_stack([item['target_ids'] for item in batch]),
        }

class RecSysPredictionDataset(Dataset):
    """Inference inputs: one right-padded entity sequence per user (no targets)."""

    def __init__(self, sequences_dict, max_length, padding_idx=0, keep_most_recent=True):
        """
        Args:
            sequences_dict (dict): Maps original user_id to its mapped entity IDs; the
                whole sequence is the model input.
            max_length (int): Length of the returned ``input_ids`` tensor.
            padding_idx (int): Value used for right padding.
            keep_most_recent (bool): For sequences longer than ``max_length``, keep the
                latest entities (default), because the recommendation is for what
                follows the last interaction. ``False`` keeps the earliest entities.
        """
        self.user_ids_orig = list(sequences_dict.keys())
        # Use the full sequence as input for prediction
        self.sequences = [sequences_dict[uid] for uid in self.user_ids_orig]
        self.max_length = max_length
        self.padding_idx = padding_idx
        self.keep_most_recent = keep_most_recent

    def __len__(self):
        return len(self.user_ids_orig)

    def __getitem__(self, idx):
        seq = self.sequences[idx]
        if self.keep_most_recent:
            # max(..., 0) avoids seq[-0:] (the whole list) when max_length is 0.
            seq = seq[max(len(seq) - self.max_length, 0):]
        window = list(seq[:self.max_length])
        padded_seq = window + [self.padding_idx] * (self.max_length - len(window))

        return {
            'user_id_orig': self.user_ids_orig[idx],
            'input_ids': torch.LongTensor(padded_seq),
        }


In [None]:
class SeqRecModule(pl.LightningModule):
    """
    LightningModule for next-entity prediction with a token-level cross-entropy loss.

    The wrapped ``model`` is expected to map ``input_ids`` of shape
    (batch, seq_len) to logits of shape (batch, seq_len, vocab_size).
    """
    def __init__(self, model, learning_rate, predict_top_k=10, padding_idx=0):
        """
        Args:
            model (nn.Module): Sequence model producing per-position vocab logits.
            learning_rate (float): Adam learning rate.
            predict_top_k (int): Entities returned per sequence by ``predict_step``
                (capped at the vocabulary size).
            padding_idx (int): Padding ID; ignored by the loss and by ``val_acc``.
        """
        super().__init__()
        self.model = model
        self.learning_rate = learning_rate
        self.predict_top_k = predict_top_k
        self.padding_idx = padding_idx  # For loss calculation
        self.criterion = nn.CrossEntropyLoss(ignore_index=self.padding_idx)
        # `model` is excluded to avoid saving it twice; consequently
        # load_from_checkpoint must be called with `model=...`.
        self.save_hyperparameters(ignore=['model'])

    def forward(self, input_ids):
        return self.model(input_ids)

    def _token_loss(self, logits, target_ids):
        """Cross-entropy over all (batch * seq_len) positions; padding targets are ignored."""
        # reshape (not view) also copes with non-contiguous logits.
        return self.criterion(logits.reshape(-1, logits.size(-1)), target_ids.reshape(-1))

    def training_step(self, batch, batch_idx):
        input_ids, target_ids = batch['input_ids'], batch['target_ids']
        logits = self(input_ids)  # (batch_size, seq_len, vocab_size)
        loss = self._token_loss(logits, target_ids)
        # Explicit batch_size: the batch dict also holds a non-tensor user-id list.
        self.log('train_loss', loss, prog_bar=True, batch_size=input_ids.size(0))
        return loss

    def validation_step(self, batch, batch_idx):
        input_ids, target_ids = batch['input_ids'], batch['target_ids']
        batch_size = input_ids.size(0)
        logits = self(input_ids)
        loss = self._token_loss(logits, target_ids)
        self.log('val_loss', loss, prog_bar=True, batch_size=batch_size)

        # Next-entity accuracy over non-padded target positions only.
        flat_logits = logits.reshape(-1, logits.size(-1))
        flat_targets = target_ids.reshape(-1)
        mask = flat_targets != self.padding_idx
        if mask.any():
            preds = torch.argmax(flat_logits[mask], dim=1)
            accuracy = (preds == flat_targets[mask]).float().mean()
            self.log('val_acc', accuracy, prog_bar=True, batch_size=batch_size)
        else:
            self.log('val_acc', 0.0, prog_bar=True, batch_size=batch_size)  # Or handle as NaN/skip
        return loss

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """
        Return the top-k next-entity predictions for each input sequence.

        Returns:
            dict: 'user_ids' (np.ndarray), 'preds' (top-k entity IDs) and 'scores'
            (their logits), both arrays of shape (batch_size, k).
        """
        input_ids = batch['input_ids']  # (batch_size, seq_len)
        user_ids_orig = batch['user_ids_orig']
        logits = self(input_ids)  # (batch_size, seq_len, vocab_size)

        # Inputs are right-padded, so logits[:, -1] usually sits on a padding
        # position the loss never trained. Use each row's last non-padding
        # position instead (0 for an all-padding row). This also works for
        # left-padded inputs, where that position is simply the last one.
        positions = torch.arange(input_ids.size(1), device=input_ids.device).expand_as(input_ids)
        is_real = input_ids != self.padding_idx
        last_pos = torch.where(is_real, positions, torch.zeros_like(positions)).max(dim=1).values
        rows = torch.arange(input_ids.size(0), device=input_ids.device)
        last_step_logits = logits[rows, last_pos]  # (batch_size, vocab_size)

        k = min(self.predict_top_k, last_step_logits.size(-1))  # topk fails if k > vocab
        scores, top_k_preds = torch.topk(last_step_logits, k, dim=1)

        return {
            'user_ids': np.array(user_ids_orig),
            'preds': top_k_preds.cpu().numpy(),
            'scores': scores.detach().float().cpu().numpy(),
        }

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

In [None]:

def prepare_data_for_run(config: DictConfig):
    """
    Load interactions, filter and map them, and split per-user sequences.

    Pipeline:
    1. Load either the built-in toy data (``config.data.path == "dummy"``) or a CSV
       expected to hold 'user_id', 'item_id' and 'timestamp'.
    2. Sort and index the rows with ``add_time_idx``.
    3. Optionally apply ``filter_items`` / ``filter_users`` when the corresponding
       ``*_min_count`` is > 0.
    4. Map IDs into one vocabulary with ``map_ids``.
    5. Build mixed item/user-token sequences with ``create_sequences_with_user_id``.

    Split, per user, on the mixed sequence ``seq``:
      * len(seq) < 3: train only (``seq``);
      * otherwise: train = seq[:-2], val = seq[:-1], test = seq.
    Afterwards, train/val entries shorter than 2 and empty test entries are dropped.

    NOTE(review): test sequences are the full sequence, so no element is held
    out as a test target. Each val sequence contains the whole train sequence
    plus one element, and that element may be an injected user token. Confirm
    this is intended before trusting downstream metrics.

    Returns:
        tuple: (train_sequences, val_sequences, test_sequences, item_map,
        user_map, num_items, num_users, total_entities_vocab)

    Raises:
        ValueError: If no interactions remain after filtering.
    """
    print("--- Preparing Data ---")
    data_cfg = config.data
    if data_cfg.path != "dummy":
        print(f"Loading data from {data_cfg.path}...")
        # Assuming CSV with 'user_id', 'item_id', 'timestamp' columns; add sep/header if needed.
        df = pd.read_csv(data_cfg.path)
    else:
        print("Using dummy data...")
        df = pd.DataFrame({
            'user_id': ['u1'] * 5 + ['u2'] * 3 + ['u3'] * 4 + ['u4'] * 5,
            'item_id': ['i1', 'i2', 'i3', 'i4', 'i5',
                        'i2', 'i3', 'i6',
                        'i1', 'i7', 'i3', 'i8',
                        'i2', 'i5', 'i9', 'i1', 'i2'],
            'timestamp': [*range(1, 6), *range(1, 4), *range(1, 5), *range(1, 6)],
        })

    df = add_time_idx(df.copy())  # Sorts by user/timestamp and adds time_idx columns

    item_min = data_cfg.get('item_min_count', 0)
    if item_min > 0:
        df = filter_items(df, item_min)
    user_min = data_cfg.get('user_min_count', 0)
    if user_min > 0:
        df = filter_users(df, user_min)

    if df.empty:
        raise ValueError("DataFrame is empty after filtering. Check filter conditions or data.")

    df_mapped, user_map, item_map, num_items, num_users, total_entities_vocab = map_ids(df)
    # Sequences are built on the whole mapped frame, then split per user below.
    all_sequences_mapped = create_sequences_with_user_id(df_mapped.sort_values(['user_id', 'time_idx']))

    train_sequences, val_sequences, test_sequences = {}, {}, {}
    for user_id_orig, seq in all_sequences_mapped.items():
        if len(seq) >= 3:
            test_sequences[user_id_orig] = seq
            val_sequences[user_id_orig] = seq[:-1]
            train_sequences[user_id_orig] = seq[:-2]
        else:
            # Too short to split: training only.
            train_sequences[user_id_orig] = seq

    # Train/val need at least 2 entities to form one (input, target) pair.
    train_sequences = {uid: s for uid, s in train_sequences.items() if len(s) >= 2}
    val_sequences = {uid: s for uid, s in val_sequences.items() if len(s) >= 2}
    test_sequences = {uid: s for uid, s in test_sequences.items() if len(s) > 0}

    print(f"  Num train user sequences: {len(train_sequences)}")
    print(f"  Num val user sequences: {len(val_sequences)}")
    print(f"  Num test user sequences: {len(test_sequences)}")

    return (train_sequences, val_sequences, test_sequences, item_map, user_map,
            num_items, num_users, total_entities_vocab)


def create_dataloaders_for_run(train_seq, val_seq, config: DictConfig, padding_idx=0):
    """
    Wrap the train/val sequence dicts in RecSysDataset and build their DataLoaders.

    Both loaders share one PaddingCollateFn and the ``config.trainer`` batch size
    and worker count; only the training loader shuffles.

    Returns:
        tuple: (train_loader, eval_loader, train_dataset, val_dataset)
    """
    print("--- Creating DataLoaders ---")
    max_length = config.model.max_length
    train_dataset = RecSysDataset(train_seq, max_length, padding_idx)
    val_dataset = RecSysDataset(val_seq, max_length, padding_idx)

    shared_loader_kwargs = {
        'batch_size': config.trainer.batch_size,
        'num_workers': config.trainer.num_workers,
        'collate_fn': PaddingCollateFn(padding_idx),
        'pin_memory': True,
    }
    train_loader = DataLoader(train_dataset, shuffle=True, **shared_loader_kwargs)
    eval_loader = DataLoader(val_dataset, shuffle=False, **shared_loader_kwargs)

    print(f"  Train Dataloader: {len(train_loader)} batches")
    print(f"  Validation Dataloader: {len(eval_loader)} batches")
    return train_loader, eval_loader, train_dataset, val_dataset


def create_mamba_model_for_run(config: DictConfig, vocab_size: int, padding_idx: int):
    """
    Instantiate Mamba4Rec from ``config.model.model_params``.

    ``model_params.mamba_config`` is forwarded to the Mamba block. ``add_head``,
    ``tie_weights`` and ``init_std`` default to True, True and 0.02 when absent.
    """
    print("--- Creating Model ---")
    params = config.model.model_params
    mamba_specific_config = params.mamba_config  # Nested config for the Mamba layer itself
    head_options = {
        'add_head': params.get('add_head', True),
        'tie_weights': params.get('tie_weights', True),
        'init_std': params.get('init_std', 0.02),
    }
    return Mamba4Rec(
        vocab_size=vocab_size,
        mamba_config=mamba_specific_config,
        padding_idx=padding_idx,
        **head_options,
    )

def training_run(seq_rec_module: SeqRecModule, train_loader: DataLoader, eval_loader: DataLoader, config: DictConfig):
    """
    Fit ``seq_rec_module`` and return the trainer together with the best module.

    A ModelCheckpoint keeps the single best checkpoint according to
    ``config.trainer.early_stopping.monitor``/``mode``. EarlyStopping is added when
    ``config.trainer.early_stopping.enabled`` is set.

    Returns:
        tuple: (trainer, trained_module). The module is reloaded from the best
        checkpoint when one exists, otherwise the in-memory module is returned.
    """
    print("--- Training Model ---")
    es_cfg = config.trainer.early_stopping
    checkpoint_cb = ModelCheckpoint(
        save_top_k=1,
        monitor=es_cfg.monitor,
        mode=es_cfg.mode,
        filename='best_model-{epoch:02d}-{val_loss:.2f}',
    )
    # refresh_rate is in batches: a larger value means fewer progress-bar redraws.
    callbacks = [TQDMProgressBar(refresh_rate=10), checkpoint_cb]
    if es_cfg.enabled:
        callbacks.append(EarlyStopping(
            monitor=es_cfg.monitor,
            patience=es_cfg.patience,
            mode=es_cfg.mode,
            verbose=True,
        ))

    trainer = pl.Trainer(
        max_epochs=config.trainer.max_epochs,
        accelerator=config.trainer.accelerator,
        devices=config.trainer.devices,
        callbacks=callbacks,
        logger=True,  # Basic logger, replace with ClearMLLogger or TensorBoardLogger if needed
        deterministic=config.trainer.get('deterministic', False),  # For reproducibility
    )
    trainer.fit(model=seq_rec_module, train_dataloaders=train_loader, val_dataloaders=eval_loader)

    best_model_path = checkpoint_cb.best_model_path
    print(f"Best model path: {best_model_path}")
    if best_model_path:
        # SeqRecModule excludes `model` from save_hyperparameters, so it must be
        # supplied here (omitting it raises TypeError). The checkpoint's
        # state_dict is then loaded into that model, restoring the best weights.
        trained_module = SeqRecModule.load_from_checkpoint(best_model_path, model=seq_rec_module.model)
    else:
        print("No best model path found, using model from last epoch.")
        trained_module = seq_rec_module  # Fallback

    return trainer, trained_module


def _collate_prediction_batch(batch):
    """Collate RecSysPredictionDataset items; module-level so DataLoader workers can pickle it."""
    return {
        'user_ids_orig': [item['user_id_orig'] for item in batch],
        'input_ids': torch.stack([item['input_ids'] for item in batch]),
    }


def predict_run(trainer: pl.Trainer, seq_rec_module: SeqRecModule, sequences_to_predict: dict, config: DictConfig, padding_idx: int):
    """
    Run ``trainer.predict`` over ``sequences_to_predict``.

    Side effect: sets ``seq_rec_module.predict_top_k`` from
    ``config.evaluation.predict_top_k_entities`` (default 20).

    Returns:
        tuple: (raw_predictions_batches, predict_dataset). The first element is the
        list returned by ``trainer.predict``. When there is nothing to predict,
        ``([], None)`` is returned.
    """
    print("--- Predicting ---")
    if not sequences_to_predict:
        print("No sequences to predict. Skipping.")
        # An empty list (not a DataFrame), so `if not raw_predictions_batches` stays valid downstream.
        return [], None

    predict_dataset = RecSysPredictionDataset(sequences_to_predict, config.model.max_length, padding_idx)

    # Prediction items carry no 'target_ids', so the training collate cannot be reused.
    predict_loader_custom = DataLoader(
        predict_dataset,
        batch_size=config.trainer.batch_size,  # Use same batch size or a specific predict_batch_size
        shuffle=False,
        num_workers=config.trainer.num_workers,
        collate_fn=_collate_prediction_batch,
    )

    seq_rec_module.model.eval()  # Ensure model is in eval mode
    seq_rec_module.predict_top_k = config.evaluation.get('predict_top_k_entities', 20)  # How many entities model predicts

    raw_predictions_batches = trainer.predict(model=seq_rec_module, dataloaders=predict_loader_custom)

    return raw_predictions_batches, predict_dataset


def evaluate_run(raw_predictions_batches: list, item_id_reverse_map: dict, num_items: int,
                 ground_truth_sequences: dict, config: DictConfig, clearml_task: Task = None, prefix: str = "eval"):
    """
    Turn raw predictions into item recommendations and report placeholder metrics.

    WARNING: the NDCG/recall values are simulated with ``np.random.rand`` and do
    not measure model quality. ``ground_truth_sequences`` is currently only used
    to decide whether to evaluate at all.

    Args:
        raw_predictions_batches (list): Batches returned by ``trainer.predict``.
        item_id_reverse_map (dict): Mapped item ID -> original item ID.
        num_items (int): Largest mapped item ID.
        ground_truth_sequences (dict): Per-user sequences for this split.
        config (DictConfig): Reads ``evaluation.top_k_metrics``.
        clearml_task (Task, optional): When given, metrics are reported as scalars.
        prefix (str): Prefix for metric names and log lines.

    Returns:
        dict: Metric name -> value. Empty when evaluation is skipped.
    """
    print(f"--- Evaluating {prefix} ---")
    # Length-based check: plain truthiness (`not df`) raises for a DataFrame.
    no_predictions = raw_predictions_batches is None or len(raw_predictions_batches) == 0
    if no_predictions or not ground_truth_sequences:
        print(f"No predictions or ground truth for {prefix}. Skipping evaluation.")
        return {}

    recs_df = preds_to_item_recs_mixed_vocab(
        predictions_batches=raw_predictions_batches,
        item_id_reverse_map=item_id_reverse_map,
        num_items_in_vocab=num_items,
        top_k_items_to_return=max(config.evaluation.top_k_metrics),  # Enough items for every K
    )
    print(f"{prefix} recommendations head:\n{recs_df.head()}")

    # TODO: replace the simulated values with real NDCG@k / Recall@k. That needs
    # the held-out target item(s) per user. `ground_truth_sequences` holds the
    # prediction *inputs*, so the targets must come from the data split itself.
    all_metrics = {}
    for k_metric in config.evaluation.top_k_metrics:
        simulated_ndcg = np.random.rand() * 0.1 + (0.1 / k_metric)  # Dummy value
        simulated_recall = np.random.rand() * 0.2 + (0.2 / k_metric)  # Dummy value

        metrics = {f'{prefix}_ndcg@{k_metric}': simulated_ndcg, f'{prefix}_recall@{k_metric}': simulated_recall}
        print(metrics)
        all_metrics.update(metrics)

    if clearml_task:
        clearml_logger = clearml_task.get_logger()
        for key, value in all_metrics.items():
            clearml_logger.report_scalar(title=key, series=key, value=value, iteration=0)
        print(f"Metrics reported to ClearML for {prefix}.")

    return all_metrics


In [None]:
@hydra.main(version_base=None, config_path=None, config_name="experiment_config")
def run_experiment(config: DictConfig):
    """
    End-to-end experiment: data prep, training, then test-set prediction and evaluation.

    NOTE(review): with ``config_path=None`` Hydra can only resolve
    "experiment_config" if it is registered some other way (e.g. ConfigStore).
    ``save_dummy_config`` writes ``configs/experiment_config.yaml``; confirm how
    the config is meant to be located.
    """
    print("--- Experiment Start ---")
    print(OmegaConf.to_yaml(config))

    # For reproducibility. Compare against None so that seed=0 is honoured.
    seed = config.get('seed')
    if seed is not None:
        pl.seed_everything(seed, workers=True)

    # Initialize ClearML Task (optional); skipped when ClearML or its config section is missing.
    clearml_task_instance = None
    clearml_cfg = config.get('clearml')
    if Task and clearml_cfg and clearml_cfg.get('project_name') and clearml_cfg.get('task_name'):
        clearml_task_instance = Task.init(
            project_name=clearml_cfg.project_name,
            task_name=clearml_cfg.task_name,
            reuse_last_task_id=False,
        )
        clearml_task_instance.connect(OmegaConf.to_container(config, resolve=True))
        print("ClearML task initialized.")

    try:
        # 1. Prepare Data
        train_seqs, val_seqs, test_seqs, item_map, user_map, num_items, num_users, total_vocab_size = prepare_data_for_run(config)
        item_id_reverse_map = {v: k for k, v in item_map.items()}
        padding_idx = config.model.get('padding_idx', 0)

        # 2. Create DataLoaders
        train_loader, val_loader, train_dataset, val_dataset = create_dataloaders_for_run(train_seqs, val_seqs, config, padding_idx)

        # 3. Create Model
        mamba_model = create_mamba_model_for_run(config, total_vocab_size, padding_idx)

        # 4. Setup Lightning Module
        seq_rec_lightning_module = SeqRecModule(
            mamba_model,
            learning_rate=config.trainer.learning_rate,
            padding_idx=padding_idx,
        )

        # 5. Training (perf_counter is the monotonic clock intended for durations)
        start_time = time.perf_counter()
        trainer, trained_seq_rec_module = training_run(seq_rec_lightning_module, train_loader, val_loader, config)
        training_time = time.perf_counter() - start_time
        print(f"Training completed in {training_time:.2f} seconds.")
        if clearml_task_instance:
            clearml_task_instance.get_logger().report_single_value('training_time_seconds', training_time)

        # 6. Optional explicit validation-set prediction/evaluation after training:
        # val_raw_preds, _ = predict_run(trainer, trained_seq_rec_module, val_seqs, config, padding_idx)
        # evaluate_run(val_raw_preds, item_id_reverse_map, num_items, val_seqs, config, clearml_task_instance, prefix="val_post_train")

        # 7. Prediction & Evaluation on Test Set
        test_raw_preds, _ = predict_run(trainer, trained_seq_rec_module, test_seqs, config, padding_idx)
        evaluate_run(test_raw_preds, item_id_reverse_map, num_items, test_seqs, config, clearml_task_instance, prefix="test")
    finally:
        # Close the ClearML task even when a step above raises.
        if clearml_task_instance:
            clearml_task_instance.close()
    print("--- Experiment End ---")

In [None]:
def save_dummy_config(config_dir="configs", config_name="experiment_config.yaml"):
    """Write a minimal Hydra config for a quick standalone demo run.

    The file is (over)written unconditionally; callers that want to keep an
    existing config must check for it before calling.

    Args:
        config_dir (str, optional): Directory to write into; created if it
            does not exist. Defaults to "configs".
        config_name (str, optional): File name of the config inside
            ``config_dir``. Defaults to "experiment_config.yaml".

    Returns:
        str: Path of the written config file.
    """
    # In a real Hydra project this YAML would live, version-controlled, in the
    # config directory rather than being generated at runtime.
    #
    # The ${now:...} / ${hydra.job.num} interpolations must NOT be written as
    # \${...}: `\$` is an invalid Python escape (SyntaxWarning on 3.12+) whose
    # backslash survives into the file, and OmegaConf treats `\${...}` as an
    # escaped literal, so Hydra would create a directory literally named
    # "${now:%Y-%m-%d}" instead of a timestamped one.
    #
    # No `# @package` header: `_group_` in package headers is deprecated since
    # Hydra 1.1, and a primary config is merged at the root package anyway.
    config_content = """
seed: 42

data:
  path: "dummy" # Path to data CSV or "dummy" for internal dummy data
  item_min_count: 1 # Min interactions for an item to be kept
  user_min_count: 1 # Min interactions for a user to be kept

model:
  name: "Mamba4Rec"
  max_length: 50 # Max sequence length for padding and processing
  padding_idx: 0
  model_params:
    add_head: True
    tie_weights: True
    init_std: 0.02
    mamba_config: # Parameters specific to the Mamba layer itself
      d_model: 32
      # n_layer: 1 # The Mamba class from mamba_ssm is a single block. Stacking needs custom logic.
      d_state: 8
      d_conv: 2
      expand: 2
      # Add other mamba-specific params like bias, conv_bias if needed

trainer:
  batch_size: 16 # Reduced for small dummy data
  num_workers: 0 # For Windows, 0 is often more stable. For Linux, can be > 0.
  learning_rate: 1.0e-3
  max_epochs: 3 # Keep low for quick demo
  accelerator: "auto" # "cpu", "gpu", "tpu", "mps", "auto"
  devices: "auto" # Number of devices or "auto"
  deterministic: True # For reproducibility
  early_stopping:
    enabled: True
    monitor: "val_loss" # Metric to monitor
    patience: 3         # Number of epochs with no improvement
    mode: "min"         # "min" for loss/error, "max" for accuracy/NDCG

evaluation:
  top_k_metrics: [5, 10] # List of K values for metrics
  predict_top_k_entities: 20 # How many entities model predicts before filtering for items

clearml: # Optional ClearML configuration
  project_name: "Mamba4Rec_Experiments"
  task_name: "Default_Run"

# Hydra specific settings
hydra:
  run:
    dir: outputs_mamba4rec/${now:%Y-%m-%d}/${now:%H-%M-%S}
  sweep:
    dir: multirun_mamba4rec/${now:%Y-%m-%d}/${now:%H-%M-%S}
    subdir: ${hydra.job.num}

"""
    os.makedirs(config_dir, exist_ok=True)
    config_path = os.path.join(config_dir, config_name)
    with open(config_path, "w", encoding="utf-8") as f:
        f.write(config_content)
    print(f"Dummy '{config_path}' created.")
    return config_path


In [None]:
if __name__ == '__main__':
    # Standalone entry point.
    #
    # run_experiment is expected to be wrapped by @hydra.main where it is
    # defined. That wrapper parses sys.argv itself, so calling it with no
    # arguments is the normal launch path: command-line overrides such as
    # `python script.py trainer.max_epochs=5` and `--multirun` still apply.
    #
    # Ensure a config exists so a fresh checkout can run end to end.
    # NOTE(review): this existence check is relative to the current working
    # directory, whereas Hydra resolves a relative `config_path` against the
    # file that declares the decorated function. Run from the script's own
    # directory, or pass `--config-dir ./configs --config-name experiment_config`.
    if not os.path.exists("configs/experiment_config.yaml"):
        save_dummy_config()

    # Depending on the Hydra version / `hydra.job.chdir`, the working directory
    # may be switched to the run's output directory, so relative paths in the
    # config (e.g. data.path) may need hydra.utils.to_absolute_path().
    run_experiment()
