In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import RobustScaler
import time
import os
import sys
from tqdm import tqdm
import shutil
import pyarrow.parquet as pq
import pyarrow as pa
from uuid import uuid4
import warnings
import subprocess
warnings.filterwarnings('ignore')

# Enhanced PyTorch and CUDA diagnostics
def check_cuda_environment():
    """Print PyTorch/CUDA diagnostics and abort when no usable GPU is found.

    The report covers, in order: library and CUDA versions, the visible
    device(s), a tiny on-device arithmetic smoke test, the raw ``nvidia-smi``
    output, and the memory figures torch reports for device 0.

    Raises:
        RuntimeError: if ``torch.cuda.is_available()`` returns False.
        Exception: any error raised by the smoke test is printed and re-raised.
    """
    print(f"PyTorch version: {torch.__version__}")
    print(f"CUDA available: {torch.cuda.is_available()}")
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available. This notebook requires a GPU.")

    # Describe the CUDA build and the devices torch can see.
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU device: {torch.cuda.get_device_name(0)}")
    print(f"GPU count: {torch.cuda.device_count()}")
    print(f"Current device: {torch.cuda.current_device()}")

    # Smoke test: allocate a small tensor on the GPU and do one addition.
    try:
        probe = torch.tensor([1.0, 2.0, 3.0], device='cuda')
        print(f"CUDA test operation successful: {probe + 1}")
    except Exception as err:
        print(f"CUDA test operation failed: {err}")
        raise

    # Driver/toolkit report. Best effort: a failure to launch nvidia-smi is
    # reported but does not abort the check.
    try:
        smi_result = subprocess.run(['nvidia-smi'], capture_output=True, text=True)
        print("NVIDIA-SMI output:")
        print(smi_result.stdout)
    except Exception as err:
        print(f"Failed to run nvidia-smi: {err}")

    # Memory figures for device 0, converted from bytes to GiB.
    gib = 2**30
    device_props = torch.cuda.get_device_properties(0)
    print(f"Total GPU memory: {device_props.total_memory / gib:.2f} GiB")
    print(f"Allocated GPU memory: {torch.cuda.memory_allocated(0) / gib:.2f} GiB")
    print(f"Reserved GPU memory: {torch.cuda.memory_reserved(0) / gib:.2f} GiB")

# Run the GPU check up front. On any failure, print a reinstall hint and
# re-raise so the notebook stops here instead of failing later.
try:
    check_cuda_environment()
except Exception as e:
    print(f"Error with PyTorch or CUDA setup: {e}")
    print("Try reinstalling PyTorch: pip install torch==2.7.0 --index-url https://download.pytorch.org/whl/cu124")
    raise

# Environment diagnostics: interpreter, PATH, and free space on the root filesystem
print(f"Python version: {sys.version}")
print(f"Python executable: {sys.executable}")
print(f"PATH: {os.environ.get('PATH')}")
print(f"Available disk space: {shutil.disk_usage('/').free / (2**30):.2f} GiB")

# Check for module shadowing: a local torch.py/torch.pyc in /workspace/XAI.
# This only prints a warning; it does not stop execution.
if os.path.exists('/workspace/XAI/torch.py') or os.path.exists('/workspace/XAI/torch.pyc'):
    print("Warning: Found 'torch.py' or 'torch.pyc' in /workspace/XAI. Please rename or remove it.")

# Set random seed for reproducibility (PyTorch RNG and NumPy's global RNG)
torch.manual_seed(42)
np.random.seed(42)

# Device configuration (GPU only): no CPU fallback is provided
device = torch.device("cuda")
print(f"Using device: {device}")

# Clear GPU memory held in PyTorch's caching allocator
torch.cuda.empty_cache()
print("Cleared GPU memory cache")

In [None]:
def load_data(data_dir="/workspace/data", file_name="merged_data.csv"):
    """Read the merged sales CSV into a pandas DataFrame.

    The file is looked up at ``data_dir/file_name`` first and then at a fixed
    fallback location. Progress and a column sanity check are printed.

    Args:
        data_dir: Directory expected to contain the CSV.
        file_name: Name of the CSV file inside ``data_dir``.

    Returns:
        The loaded ``pandas.DataFrame``.

    Raises:
        FileNotFoundError: if neither the primary nor the fallback path exists.
        RuntimeError: if reading or describing the CSV fails.
    """
    file_path = os.path.join(data_dir, file_name)
    alt_path = "/workspace/XAI-1/Predict Future Sales/merged_data.csv"

    # The first existing candidate wins; the fallback is only probed when the
    # primary path is absent.
    resolved = next(filter(os.path.exists, (file_path, alt_path)), None)
    if resolved is None:
        raise FileNotFoundError(f"File not found at {file_path} or {alt_path}")
    file_path = resolved

    try:
        data = pd.read_csv(file_path)
        print(f"Loaded data from {file_path}")
        print(f"Dataset shape: {data.shape}")
        print(f"Columns: {list(data.columns)}")
    except Exception as e:
        raise RuntimeError(f"Failed to load {file_path}: {e}")

    # Verify expected columns. A mismatch is only reported, never fatal.
    expected_columns = ['date', 'shop_id', 'item_id', 'item_name', 'item_cnt_day', 'item_price', 'item_category_id', 'shop_name', 'item_category_name', 'date_block_num']
    present = set(data.columns)
    missing_cols = [name for name in expected_columns if name not in present]
    if missing_cols:
        print(f"Warning: Missing expected columns: {missing_cols}")

    return data

# Load data: read the merged sales CSV into the pandas DataFrame `data`
data = load_data()

In [None]:
# Import libraries
import os
import time
import json
import pickle
import logging
import psutil
import numpy as np
import pandas as pd
import polars as pl
from datetime import datetime
from sklearn.preprocessing import RobustScaler

# Configure logging
# logging.FileHandler opens its target file as soon as it is constructed and
# does not create parent directories, so make sure the log directory exists
# before the handler is built (otherwise a fresh workspace fails right here).
os.makedirs('/workspace/processed_data', exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    handlers=[
        # Persist the run log next to the processed outputs...
        logging.FileHandler('/workspace/processed_data/preprocess.log'),
        # ...and echo the same records to the notebook output.
        logging.StreamHandler()
    ]
)
# Module-level logger used by every preprocessing helper below.
logger = logging.getLogger(__name__)

def log_memory_usage():
    """Log the resident set size (RSS) of the current process in GB."""
    bytes_per_gb = 1024 ** 3
    current_process = psutil.Process()
    mem = current_process.memory_info().rss / bytes_per_gb
    logger.info(f"Memory usage: {mem:.2f} GB")

def load_data(data_dir="/workspace/data", file_name="merged_data.csv"):
    """Load the merged sales CSV, falling back to a fixed alternate path.

    Note: an earlier cell of this notebook defines a function with the same
    name; once this cell runs, this logging-based definition replaces it.

    Args:
        data_dir: Directory expected to contain the CSV.
        file_name: Name of the CSV file inside ``data_dir``.

    Returns:
        The loaded ``pandas.DataFrame``.

    Raises:
        FileNotFoundError: if neither ``data_dir/file_name`` nor the fallback
            path exists. The message lists both locations that were tried.
    """
    file_path = os.path.join(data_dir, file_name)
    alt_path = "/workspace/XAI-1/Predict Future Sales/merged_data.csv"
    path = file_path if os.path.exists(file_path) else alt_path
    if not os.path.exists(path):
        # Report both candidates: at this point `path` is always the fallback,
        # so naming only `path` would hide the primary location that was tried.
        raise FileNotFoundError(f"File not found at {file_path} or {alt_path}")

    df = pd.read_csv(path)
    logger.info(f"Loaded data from {path}, shape: {df.shape}")
    logger.info(f"Columns: {df.columns.tolist()}")

    # Schema sanity check: a mismatch is logged as a warning, not raised.
    expected_cols = ['date', 'shop_id', 'item_id', 'item_name', 'item_cnt_day',
                     'item_price', 'item_category_id', 'shop_name', 'item_category_name', 'date_block_num']
    missing_cols = [col for col in expected_cols if col not in df.columns]
    if missing_cols:
        logger.warning(f"Missing columns: {missing_cols}")

    return df

# Load data (this re-reads the CSV with the logging-based load_data defined in
# this cell and rebinds `data`, which an earlier cell already loaded once)
data = load_data()

In [None]:
def get_russian_holidays():
    """Build a Polars table of Russian holidays and shopping events (2013–2015).

    Returns:
        pl.DataFrame with one row per event and the columns ``date``
        (datetime), ``holiday`` (event name), ``month`` and ``year`` (both
        Int32, derived from ``date``).
    """
    holidays = [
        ("2013-01-01", "New Year"), ("2014-01-01", "New Year"), ("2015-01-01", "New Year"),
        ("2013-02-23", "Defender Day"), ("2014-02-23", "Defender Day"), ("2015-02-23", "Defender Day"),
        ("2013-03-08", "Women's Day"), ("2014-03-08", "Women's Day"), ("2015-03-08", "Women's Day"),
        ("2013-06-12", "Russia Day"), ("2014-06-12", "Russia Day"), ("2015-06-12", "Russia Day"),
        ("2013-11-29", "Black Friday"), ("2014-11-28", "Black Friday"), ("2015-11-27", "Black Friday")
    ]

    # Split the (date string, name) pairs into two parallel columns.
    event_dates = []
    event_names = []
    for stamp, label in holidays:
        event_dates.append(datetime.strptime(stamp, "%Y-%m-%d"))
        event_names.append(label)

    events = pl.DataFrame({"date": event_dates, "holiday": event_names})

    # Calendar keys later used to attach events to monthly rows.
    month_key = pl.col("date").dt.month().cast(pl.Int32).alias("month")
    year_key = pl.col("date").dt.year().cast(pl.Int32).alias("year")
    holiday_df = events.with_columns([month_key, year_key])
    return holiday_df

# Create holiday DataFrame (Polars) and log how many events it contains
holiday_df = get_russian_holidays()
logger.info(f"Holiday DataFrame created with {len(holiday_df)} entries")

In [None]:
def initial_preprocessing(df):
    """Convert the raw pandas frame to Polars, add calendar columns, and trim it.

    Steps, in order:
      1. Convert ``df`` with ``pl.from_pandas`` and log its size, shop/item
         cardinalities and ``date_block_num`` range.
      2. Parse ``date`` with the format ``%Y-%m-%d`` into ``pl.Date`` and
         derive ``month`` and ``year``.
      3. Keep only rows with ``date_block_num <= 32``.
      4. Downcast id, count, price and calendar columns.

    Args:
        df: pandas DataFrame containing at least ``date``, ``date_block_num``,
            ``shop_id``, ``item_id``, ``item_category_id``, ``item_cnt_day``
            and ``item_price``.

    Returns:
        The preprocessed ``pl.DataFrame``.
    """
    start_time = time.time()

    # Convert to Polars and record the starting point.
    df = pl.from_pandas(df)
    logger.info(f"Initial dataset size: {len(df)}")
    logger.info(f"Unique shops: {df['shop_id'].n_unique()}, items: {df['item_id'].n_unique()}")
    logger.info(f"Date block range: {df['date_block_num'].min()}–{df['date_block_num'].max()}")
    log_memory_usage()

    # Parse dates first, then derive the calendar parts from the parsed column.
    parsed_date = pl.col('date').str.strptime(pl.Date, "%Y-%m-%d")
    calendar_parts = [
        pl.col('date').dt.month().cast(pl.Int32).alias('month'),
        pl.col('date').dt.year().cast(pl.Int32).alias('year'),
    ]
    df = df.with_columns(parsed_date).with_columns(calendar_parts)

    # Filter to date_block_num <= 32
    within_range = pl.col('date_block_num') <= 32
    df = df.filter(within_range)
    logger.info(f"Dataset size after date filter: {len(df)}")

    # Note: Top-54 shop filtering is skipped, assuming all shops are used (paper reports 54 shops)
    logger.info(f"Unique shops after processing: {df['shop_id'].n_unique()}")

    # Optimize dtypes using a column -> target dtype plan.
    target_dtypes = {
        'date_block_num': pl.Int16,
        'shop_id': pl.Int32,
        'item_id': pl.Int32,
        'item_category_id': pl.Int32,
        'item_cnt_day': pl.Float32,
        'item_price': pl.Float32,
        'month': pl.Int32,
        'year': pl.Int32,
    }
    df = df.with_columns([pl.col(name).cast(dtype) for name, dtype in target_dtypes.items()])

    logger.info(f"Initial preprocessing time: {time.time() - start_time:.2f} seconds")
    log_memory_usage()
    return df

# Apply initial preprocessing: pandas `data` in, Polars `df` out
df = initial_preprocessing(data)

In [None]:
def handle_outliers_and_returns(df):
    """Cap daily counts at a rolling 0.99 quantile and split off negative sales.

    Adds three columns to ``df`` and returns the result:

    * ``rolling_quantile``: per (shop_id, item_id), the 0.99 quantile of
      ``item_cnt_day`` over a trailing window of 30 rows. The window is a row
      count, not a span of calendar days.
    * ``item_cnt_day_winsor``: ``item_cnt_day`` clipped from above by
      ``rolling_quantile`` and then from below by 0.
    * ``returns``: the absolute value of negative (upper-clipped) counts, 0
      for all other rows.

    Args:
        df: Polars DataFrame with ``shop_id``, ``item_id``, ``date`` and
            ``item_cnt_day``.

    Returns:
        The frame sorted by shop, item and date with the columns above added.
    """
    start_time = time.time()
    
    # Rolling Winsorization (trailing window of 30 rows per shop/item pair)
    # NOTE(review): the trailing window appears to include the current row, and
    # with Polars' default 'nearest' interpolation the 0.99 quantile of <= 30
    # values looks like it resolves to the window maximum. If so, this clip
    # never changes a value and `outlier_count` is always 0. Confirm against
    # the logged "Outliers capped" figure.
    df = df.sort(['shop_id', 'item_id', 'date']).with_columns(
        rolling_quantile=pl.col('item_cnt_day').rolling_quantile(
            quantile=0.99, window_size=30, min_periods=1
        ).over(['shop_id', 'item_id'])
    ).with_columns(
        item_cnt_day_winsor=pl.col('item_cnt_day').clip(None, pl.col('rolling_quantile'))
    )
    # Count rows whose raw value exceeded the cap (i.e. rows the clip changed).
    outlier_count = df.filter(pl.col('item_cnt_day') > pl.col('rolling_quantile')).height
    logger.info(f"Outliers capped: {outlier_count} ({outlier_count / len(df) * 100:.2f}%)")
    
    # Handle negative sales
    # Both expressions live in the same with_columns call, so `returns` is
    # derived from the still-negative values, not from the zero-clipped column.
    df = df.with_columns([
        pl.when(pl.col('item_cnt_day_winsor') < 0)
          .then(pl.col('item_cnt_day_winsor').abs())
          .otherwise(0)
          .alias('returns'),
        pl.col('item_cnt_day_winsor').clip(lower_bound=0).alias('item_cnt_day_winsor')
    ])
    # Sanity check: expected to log 0 after the lower clip above.
    logger.info(f"Negative sales after processing: {df.filter(pl.col('item_cnt_day_winsor') < 0).height}")
    
    logger.info(f"Outlier and returns handling time: {time.time() - start_time:.2f} seconds")
    log_memory_usage()
    return df

# Apply outlier and returns handling (rebinds `df` with the added columns)
df = handle_outliers_and_returns(df)

In [None]:
def aggregate_to_monthly(df):
    """Collapse daily rows into one row per shop/item/month.

    Rows are grouped by ``date_block_num``, ``shop_id``, ``item_id``,
    ``item_category_id``, ``month`` and ``year``. Winsorized counts and
    returns are summed; ``item_price`` is averaged.

    Args:
        df: Polars DataFrame produced by the outlier/returns step.

    Returns:
        The aggregated ``pl.DataFrame``.
    """
    start_time = time.time()

    group_keys = ['date_block_num', 'shop_id', 'item_id', 'item_category_id', 'month', 'year']
    monthly_totals = [
        pl.col('item_cnt_day_winsor').sum().alias('item_cnt_day_winsor'),
        pl.col('returns').sum().alias('returns'),
        pl.col('item_price').mean().alias('item_price'),
    ]
    df = df.group_by(group_keys).agg(monthly_totals)

    logger.info(f"Dataset size after aggregation: {len(df)}")
    logger.info(f"Unique shops: {df['shop_id'].n_unique()}, items: {df['item_id'].n_unique()}")
    logger.info(f"Aggregation time: {time.time() - start_time:.2f} seconds")
    log_memory_usage()
    return df

# Aggregate to monthly: one row per (shop, item, month) that has sales records
df = aggregate_to_monthly(df)

In [None]:
def create_full_grid(df):
    """Expand the monthly data to a dense shop × item × month grid.

    Every combination of the shops and items present in ``df`` with the date
    blocks 0–32 becomes one row. Combinations that have no sales record are
    filled as follows:

    * ``item_cnt_day_winsor`` and ``returns``: 0.
    * ``item_price``: the item's mean observed price, else 0.
    * ``item_category_id``: the item's first *non-null* category, else 0.

    A ``date`` column set to the first day of each month is added.

    Args:
        df: Monthly Polars DataFrame with ``shop_id``, ``item_id``,
            ``date_block_num``, ``month``, ``year``, ``item_category_id``,
            ``item_cnt_day_winsor``, ``returns`` and ``item_price``.

    Returns:
        The dense ``pl.DataFrame`` (one row per grid cell).
    """
    start_time = time.time()
    
    shops = df['shop_id'].unique().to_list()
    items = df['item_id'].unique().to_list()
    date_blocks = list(range(33))  # 0–32
    
    # Cartesian product of shops, items and months. Month/year are derived
    # from date_block_num (block 0 is treated as January 2013).
    grid = pl.DataFrame({'shop_id': shops}).join(
        pl.DataFrame({'item_id': items}), how='cross'
    ).join(
        pl.DataFrame({'date_block_num': date_blocks}), how='cross'
    ).with_columns([
        pl.col('shop_id').cast(pl.Int32),
        pl.col('item_id').cast(pl.Int32),
        pl.col('date_block_num').cast(pl.Int16),
        ((pl.col('date_block_num') % 12) + 1).cast(pl.Int32).alias('month'),
        ((pl.col('date_block_num') // 12) + 2013).cast(pl.Int32).alias('year')
    ])
    
    # Merge with aggregated data
    df = grid.join(
        df, on=['shop_id', 'item_id', 'date_block_num', 'month', 'year'], how='left'
    ).with_columns([
        pl.col('item_cnt_day_winsor').fill_null(0),
        pl.col('returns').fill_null(0),
        pl.col('item_price').fill_null(pl.col('item_price').mean().over('item_id')).fill_null(0),
        # `first()` alone does not skip nulls. After the left join the first
        # grid row of most items is an unmatched (null) cell, which used to
        # turn the category into 0 for every filled row of that item. Drop the
        # nulls before taking the first value.
        pl.col('item_category_id').fill_null(
            pl.col('item_category_id').drop_nulls().first().over('item_id')
        ).fill_null(0)
    ]).with_columns(
        pl.datetime(pl.col('year'), pl.col('month'), 1).alias('date')
    )
    
    logger.info(f"Grid size: {len(grid)}, after merge: {len(df)}")
    logger.info(f"Unique shops: {df['shop_id'].n_unique()}, items: {df['item_id'].n_unique()}")
    logger.info(f"Grid creation time: {time.time() - start_time:.2f} seconds")
    log_memory_usage()
    return df

# Create full grid: densify `df` to every shop × item × month combination
df = create_full_grid(df)

In [None]:
def seasonal_imputation(df, col):
    """Fill nulls in ``col`` using interpolation, seasonality, then a moving mean.

    For each (shop_id, item_id) series the replacement for a null is chosen in
    this order:

    1. the interpolated value between neighbouring observations;
    2. the value of the same calendar month in the previous year present in
       the data;
    3. the rolling mean over the trailing 12 rows.

    Values that are still null after these steps are left as null.

    Rows are expected to be in chronological order within each shop/item pair,
    since every step relies on row position.

    Args:
        df: Polars DataFrame with ``shop_id``, ``item_id``, ``month`` and ``col``.
        col: Name of the numeric column to impute.

    Returns:
        ``df`` with ``col`` replaced by its imputed version.
    """
    df = df.with_columns(
        pl.col(col).interpolate().over(['shop_id', 'item_id']).alias(f'{col}_interp')
    ).with_columns(
        # Each (shop, item, month) partition holds one row per year, so the
        # previous year's value is one position back. The former shift(12)
        # inside this partition could never reach a row and was always null.
        seasonal_value=pl.col(col).shift(1).over(['shop_id', 'item_id', 'month']),
        ma_value=pl.col(col).rolling_mean(window_size=12, min_periods=1).over(['shop_id', 'item_id'])
    ).with_columns(
        pl.when(pl.col(f'{col}_interp').is_null() & pl.col('seasonal_value').is_not_null())
          .then(pl.col('seasonal_value'))
          .when(pl.col(f'{col}_interp').is_null())
          .then(pl.col('ma_value'))
          .otherwise(pl.col(f'{col}_interp'))
          .alias(col)
    ).drop([f'{col}_interp', 'seasonal_value', 'ma_value'])
    return df

def apply_imputation(df, cols):
    """Run seasonal imputation on each column, then zero-fill what is left.

    Args:
        df: Polars DataFrame to impute.
        cols: Names of the numeric columns to process, in order.

    Returns:
        The frame with no nulls remaining in ``cols``.
    """
    start_time = time.time()

    for name in cols:
        df = seasonal_imputation(df, name)

    # Anything the seasonal pass could not resolve becomes 0.
    zero_fills = [pl.col(name).fill_null(0) for name in cols]
    df = df.with_columns(zero_fills)

    logger.info(f"Imputation time: {time.time() - start_time:.2f} seconds")
    log_memory_usage()
    return df

# Apply imputation to the base numeric columns.
# `numerical_cols` is reused (and extended with lag columns) by later cells.
numerical_cols = ['item_cnt_day_winsor', 'returns', 'item_price']
df = apply_imputation(df, numerical_cols)

In [None]:
def add_holiday_features(df, holiday_df):
    """Attach holiday information to each monthly row.

    ``holiday_df`` is left-joined on (``year``, ``month``). Rows with a match
    get the event name in ``holiday`` and ``is_holiday = 1``; all other rows
    get the string ``'None'`` and ``is_holiday = 0``.

    Because this is a left join, a (year, month) that appears more than once
    in ``holiday_df`` yields one output row per matching event.

    Args:
        df: Polars DataFrame with ``year`` and ``month`` columns.
        holiday_df: Polars DataFrame with ``year``, ``month`` and ``holiday``.

    Returns:
        ``df`` with ``holiday`` and ``is_holiday`` columns added.
    """
    start_time = time.time()

    month_keys = ['year', 'month']
    holiday_lookup = holiday_df.select(month_keys + ['holiday'])
    joined = df.join(holiday_lookup, on=month_keys, how='left')

    # Both expressions read the joined `holiday` column before it is filled,
    # so the flag reflects real matches only.
    df = joined.with_columns([
        pl.col('holiday').is_not_null().cast(pl.Int8).alias('is_holiday'),
        pl.col('holiday').fill_null('None').alias('holiday'),
    ])

    logger.info(f"Holiday feature addition time: {time.time() - start_time:.2f} seconds")
    log_memory_usage()
    return df

# Add holiday features (`holiday` name and `is_holiday` flag) by year/month
df = add_holiday_features(df, holiday_df)

In [None]:
def filter_sparse_products(df):
    """Keep only shop-item pairs that sell in most months.

    A month counts as "missing" when ``item_cnt_day_winsor`` equals 0. Pairs
    whose share of such months is above 30% are removed entirely.

    Args:
        df: Polars DataFrame with ``shop_id``, ``item_id`` and
            ``item_cnt_day_winsor``.

    Returns:
        The rows of ``df`` that belong to the retained pairs.
    """
    start_time = time.time()

    pair_keys = ['shop_id', 'item_id']
    zero_share = pl.col('item_cnt_day_winsor').eq(0).mean().alias('missing_ratio')
    shop_item_missing = df.group_by(pair_keys).agg(zero_share)

    dense_enough = pl.col('missing_ratio') <= 0.3
    valid_shop_items = shop_item_missing.filter(dense_enough).select(pair_keys)

    initial_size = len(df)
    # The inner join acts as a semi-filter on the retained pairs.
    df = df.join(valid_shop_items, on=pair_keys, how='inner')

    logger.info(f"Records dropped due to >30% missing: {initial_size - len(df)}")
    logger.info(f"Records after filtering: {len(df)}, shop-item pairs: {len(valid_shop_items)}")
    logger.info(f"Filtering time: {time.time() - start_time:.2f} seconds")
    log_memory_usage()
    return df

# Filter sparse products: drop shop-item pairs with zero sales in >30% of months
df = filter_sparse_products(df)

In [None]:
def create_lag_features(df):
    """Add 1-, 2- and 3-row lags of sales, returns and price per shop/item.

    The frame is sorted by shop, item and date first. Nine columns are added:
    ``lag_sales_N``, ``lag_returns_N`` and ``lag_price_N`` for N in 1..3. The
    first N rows of every shop/item series are null in the ``_N`` columns.

    Args:
        df: Polars DataFrame with ``shop_id``, ``item_id``, ``date``,
            ``item_cnt_day_winsor``, ``returns`` and ``item_price``.

    Returns:
        The sorted frame with the lag columns appended.
    """
    start_time = time.time()

    pair_keys = ['shop_id', 'item_id']
    df = df.sort(pair_keys + ['date'])

    # (source column, label used in the lag column name)
    lag_sources = (
        ('item_cnt_day_winsor', 'sales'),
        ('returns', 'returns'),
        ('item_price', 'price'),
    )
    # Every lag reads only the base columns, so all nine can be built in one
    # pass. The ordering keeps the columns grouped by lag, then by source.
    lag_exprs = [
        pl.col(source).shift(lag).over(pair_keys).alias(f'lag_{label}_{lag}')
        for lag in (1, 2, 3)
        for source, label in lag_sources
    ]
    df = df.with_columns(lag_exprs)

    logger.info(f"Lag feature creation time: {time.time() - start_time:.2f} seconds")
    log_memory_usage()
    return df

# Create lag features (adds lag_sales_*, lag_returns_*, lag_price_* for lags 1–3)
df = create_lag_features(df)

In [None]:
# Impute lag features
lag_cols = ['lag_sales_1', 'lag_sales_2', 'lag_sales_3',
            'lag_returns_1', 'lag_returns_2', 'lag_returns_3',
            'lag_price_1', 'lag_price_2', 'lag_price_3']
# Extend the shared list in place, but only with names it does not hold yet.
# A plain `+=` appended the lag names again on every re-run of this cell, and
# the duplicated names then broke the column selections below.
new_lag_cols = [col for col in lag_cols if col not in numerical_cols]
numerical_cols.extend(new_lag_cols)
logger.info("\nMissing values before lag imputation:")
logger.info(df.select(numerical_cols).null_count().to_pandas().to_string())

# Runs over the full list, i.e. the base columns are passed through again
# together with the new lag columns.
df = apply_imputation(df, numerical_cols)

logger.info("\nMissing values after lag imputation:")
logger.info(df.select(numerical_cols).null_count().to_pandas().to_string())

# Note: Ensure generate_embeddings.py creates entity embeddings for shop_id, item_id, item_category_id

In [None]:
def scale_and_save_data(df, numerical_cols, train_end_block=30):
    """Save the unscaled data, robust-scale the numeric columns, pickle the scaler.

    Side effects (all under ``/workspace/processed_data``, created if absent):

    * ``monthly_sales_unscaled.parquet``: the frame before scaling.
    * ``scaler.pkl``: the fitted ``RobustScaler``.

    The scaled frame itself is returned, not written to disk here.

    Args:
        df: Polars DataFrame holding the fully engineered monthly data.
        numerical_cols: Names of the columns to scale and store as float32.
        train_end_block: Last ``date_block_num`` (inclusive) whose rows are
            used to fit the scaler. The default of 30 matches the training
            months (0–30) that ``split_and_save_sets`` selects, so the scaler
            is fit on the whole training split and on nothing else.

    Returns:
        pandas DataFrame with ``numerical_cols`` scaled and dtypes downcast.

    Raises:
        ValueError: if scaling produced NaNs in ``numerical_cols``.
    """
    start_time = time.time()

    # The first write below would fail on a fresh workspace without this.
    output_dir = '/workspace/processed_data'
    os.makedirs(output_dir, exist_ok=True)

    # Convert to pandas
    monthly_sales = df.to_pandas()
    # This only drops the function's own reference. The Polars frame stays
    # alive for as long as the caller keeps its variable.
    del df
    log_memory_usage()

    # Save unscaled data
    monthly_sales.to_parquet(os.path.join(output_dir, 'monthly_sales_unscaled.parquet'))
    logger.info("Saved unscaled data")

    # Robust scaling, fit on the training months only to avoid leaking
    # validation/test statistics into the scaler.
    scaler = RobustScaler()
    train_mask = monthly_sales['date_block_num'] <= train_end_block
    train_data = monthly_sales.loc[train_mask, numerical_cols]
    logger.info(f"Scaler training data shape: {train_data.shape}")

    scaler.fit(train_data)
    monthly_sales[numerical_cols] = scaler.transform(monthly_sales[numerical_cols])

    if monthly_sales[numerical_cols].isna().any().any():
        raise ValueError("NaNs introduced during scaling")

    with open(os.path.join(output_dir, 'scaler.pkl'), 'wb') as f:
        pickle.dump(scaler, f)

    # Optimize dtypes
    dtypes = {col: 'float32' for col in numerical_cols}
    dtypes.update({
        'shop_id': 'int32', 'item_id': 'int32', 'item_category_id': 'int32',
        'date_block_num': 'int16', 'month': 'int32', 'year': 'int32', 'is_holiday': 'int8'
    })
    monthly_sales = monthly_sales.astype(dtypes, errors='ignore')

    logger.info(f"Scaling and saving time: {time.time() - start_time:.2f} seconds")
    log_memory_usage()
    return monthly_sales

# Scale and save data: `monthly_sales` is a pandas DataFrame whose numeric
# columns hold scaled values from here on
monthly_sales = scale_and_save_data(df, numerical_cols)

In [None]:
def split_and_save_sets(df):
    """Split data into train/val/test per paper (months 0–30, 31, 32) and save them.

    Writes ``X_train_processed.parquet``, ``X_val_processed.parquet`` and
    ``X_test_processed.parquet`` to ``/workspace/processed_data``, plus the
    full frame to ``/workspace/raw_data/processed_sales.parquet``. Both target
    directories are created when they do not exist.

    Args:
        df: pandas DataFrame with a ``date_block_num`` column.

    Returns:
        Tuple ``(train_df, val_df, test_df)``.
    """
    start_time = time.time()

    train_df = df[df['date_block_num'] <= 30]  # Inclusive
    val_df = df[df['date_block_num'] == 31]
    test_df = df[df['date_block_num'] == 32]

    logger.info(f"Train: {train_df.shape}, Val: {val_df.shape}, Test: {test_df.shape}")

    output_dir = '/workspace/processed_data'
    os.makedirs(output_dir, exist_ok=True)

    train_df.to_parquet(os.path.join(output_dir, 'X_train_processed.parquet'))
    val_df.to_parquet(os.path.join(output_dir, 'X_val_processed.parquet'))
    test_df.to_parquet(os.path.join(output_dir, 'X_test_processed.parquet'))

    # The combined file goes to a different directory, which was previously
    # never created, so the write failed whenever it did not already exist.
    full_path = '/workspace/raw_data/processed_sales.parquet'
    os.makedirs(os.path.dirname(full_path), exist_ok=True)
    df.to_parquet(full_path)

    logger.info(f"Split and save time: {time.time() - start_time:.2f} seconds")
    return train_df, val_df, test_df

# Split and save datasets: months 0–30 train, 31 validation, 32 test
train_df, val_df, test_df = split_and_save_sets(monthly_sales)

In [None]:
def save_metadata(df, numerical_cols):
    """Write the feature order and a per-row date lookup to a JSON file.

    The file ``/workspace/processed_data/metadata.json`` receives:

    * ``feature_index``: column name -> position within ``numerical_cols``.
    * ``date_index``: one record per row of ``df`` with ``row_index`` (the
      frame's index value) and ``date``. Values that are not JSON-native are
      written via ``str``.

    Args:
        df: pandas DataFrame with a ``date`` column.
        numerical_cols: Ordered list of feature column names.
    """
    start_time = time.time()

    feature_index = dict((name, position) for position, name in enumerate(numerical_cols))

    # Expose the frame's index as an explicit `row_index` field next to `date`.
    dated_rows = df[['date']].reset_index()
    dated_rows = dated_rows.rename(columns={'index': 'row_index'})
    date_index = dated_rows.to_dict(orient='records')

    metadata = {"feature_index": feature_index, "date_index": date_index}

    with open('/workspace/processed_data/metadata.json', 'w') as f:
        json.dump(metadata, f, indent=2, default=str)

    logger.info("Saved metadata to /workspace/processed_data/metadata.json")
    logger.info(f"Metadata save time: {time.time() - start_time:.2f} seconds")

# Save metadata (feature order and per-row dates) as JSON
save_metadata(monthly_sales, numerical_cols)

In [None]:
def validate_statistics(raw_df, processed_df):
    """Log the coefficient of variation (std / mean) before and after processing.

    A warning is logged when the raw value is more than 0.1 away from 2.8 or
    the processed value is more than 0.1 away from 1.9.

    Args:
        raw_df: Frame with an ``item_cnt_day`` column.
        processed_df: Frame with an ``item_cnt_day_winsor`` column.
    """
    raw_counts = raw_df['item_cnt_day']
    processed_counts = processed_df['item_cnt_day_winsor']

    raw_cv = raw_counts.std() / raw_counts.mean()
    processed_cv = processed_counts.std() / processed_counts.mean()

    logger.info(f"Raw coefficient of variation: {raw_cv:.2f}")
    logger.info(f"Processed coefficient of variation: {processed_cv:.2f}")

    raw_is_off = abs(raw_cv - 2.8) > 0.1
    processed_is_off = abs(processed_cv - 1.9) > 0.1
    if raw_is_off or processed_is_off:
        logger.warning("Coefficient of variation deviates from paper's reported values (raw: 2.8, processed: 1.9)")

# Validate statistics
# `monthly_sales` holds RobustScaler-transformed values at this point, so its
# std/mean ratio is not comparable to the raw data. Use the unscaled snapshot
# that scale_and_save_data wrote, reading only the column that is needed.
unscaled_sales = pd.read_parquet(
    '/workspace/processed_data/monthly_sales_unscaled.parquet',
    columns=['item_cnt_day_winsor']
)
validate_statistics(data, unscaled_sales)

In [None]:
def validate_final_dataset(df, expected_size=2935849, tolerance=0.1):
    """Log summary statistics of the final dataset and check its size.

    Args:
        df: pandas DataFrame with ``shop_id``, ``item_id`` and
            ``date_block_num`` columns.
        expected_size: Row count to compare against. The default is the raw
            record count from the paper, so a filtered or aggregated frame
            will normally need its own reference value.
        tolerance: Largest accepted relative deviation from ``expected_size``
            before a warning is logged (0.1 means 10%).
    """
    logger.info(f"Final dataset size: {len(df)}")
    logger.info(f"Unique shops: {df['shop_id'].nunique()}, items: {df['item_id'].nunique()}")
    logger.info(f"Date block range: {df['date_block_num'].min()}–{df['date_block_num'].max()}")

    relative_deviation = abs(len(df) - expected_size) / expected_size
    if relative_deviation > tolerance:
        logger.warning(f"Dataset size {len(df)} deviates from expected {expected_size}")

# Validate final dataset: logs size, cardinalities and date range of `monthly_sales`
validate_final_dataset(monthly_sales)