In [71]:
import os
import sys
import subprocess
import warnings
import shutil
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import RobustScaler

warnings.filterwarnings('ignore')

def check_cuda_environment():
    print(f"PyTorch version: {torch.__version__}")
    print(f"CUDA available: {torch.cuda.is_available()}")
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available. This notebook requires a GPU.")
    
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU device: {torch.cuda.get_device_name(0)}")
    print(f"GPU count: {torch.cuda.device_count()}")
    print(f"Current device: {torch.cuda.current_device()}")
    
    try:
        test_tensor = torch.tensor([1.0, 2.0, 3.0], device='cuda')
        test_result = test_tensor + 1
        print(f"CUDA test operation successful: {test_result}")
    except Exception as e:
        print(f"CUDA test operation failed: {e}")
        raise
    
    try:
        nvidia_smi = subprocess.run(['nvidia-smi'], capture_output=True, text=True)
        print("NVIDIA-SMI output:")
        print(nvidia_smi.stdout)
    except Exception as e:
        print(f"Failed to run nvidia-smi: {e}")
    
    print(f"Total GPU memory: {torch.cuda.get_device_properties(0).total_memory / 2**30:.2f} GiB")
    print(f"Allocated GPU memory: {torch.cuda.memory_allocated(0) / 2**30:.2f} GiB")
    print(f"Reserved GPU memory: {torch.cuda.memory_reserved(0) / 2**30:.2f} GiB")

try:
    check_cuda_environment()
except Exception as e:
    print(f"Error with PyTorch or CUDA setup: {e}")
    print("Try reinstalling PyTorch: pip install torch==2.7.0 --index-url https://download.pytorch.org/whl/cu124")
    raise

print(f"Python version: {sys.version}")
print(f"Python executable: {sys.executable}")
print(f"PATH: {os.environ.get('PATH')}")
print(f"Available disk space: {shutil.disk_usage('/').free / (2**30):.2f} GiB")

if os.path.exists('/workspace/XAI-2/XAI/torch.py') or os.path.exists('/workspace/XAI-2/XAI/torch.pyc'):
    print("Warning: Found 'torch.py' or 'torch.pyc' in /workspace/XAI-2/XAI. Please rename or remove it.")

torch.manual_seed(42)
np.random.seed(42)

device = torch.device("cuda")
print(f"Using device: {device}")

torch.cuda.empty_cache()
print("Cleared GPU memory cache")

PyTorch version: 2.7.0+cu118
CUDA available: True
CUDA version: 11.8
GPU device: NVIDIA RTX A6000
GPU count: 1
Current device: 0
CUDA test operation successful: tensor([2., 3., 4.], device='cuda:0')


NVIDIA-SMI output:
Mon May 12 18:28:05 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 565.57.01              Driver Version: 565.57.01      CUDA Version: 12.7     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA RTX A6000               On  |   00000000:53:00.0 Off |                  Off |
| 30%   32C    P0             23W /  300W |     316MiB /  49140MiB |      1%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                             

In [72]:
def load_data(data_dir="/workspace/XAI-2/XAI/Predict Future Sales", file_name="merged_data.csv"):
    """Load CSV data with fallback path."""
    file_path = os.path.join(data_dir, file_name)
    alt_path = "/workspace/XAI-2/Predict Future Sales/merged_data.csv"
    
    if not os.path.exists(file_path):
        if os.path.exists(alt_path):
            file_path = alt_path
        else:
            raise FileNotFoundError(f"File not found at {file_path} or {alt_path}")
    
    try:
        data = pd.read_csv(file_path)
        print(f"Loaded data from {file_path}")
        print(f"Dataset shape: {data.shape}")
        print(f"Columns: {list(data.columns)}")
    except Exception as e:
        raise RuntimeError(f"Failed to load {file_path}: {e}")
    
    expected_columns = ['date', 'shop_id', 'item_id', 'item_name', 'item_cnt_day', 
                        'item_price', 'item_category_id', 'shop_name', 'item_category_name', 'date_block_num']
    missing_cols = [col for col in expected_columns if col not in data.columns]
    if missing_cols:
        print(f"Warning: Missing expected columns: {missing_cols}")
    
    return data

data = load_data()

Loaded data from /workspace/XAI-2/XAI/Predict Future Sales/merged_data.csv
Dataset shape: (2935849, 10)
Columns: ['date', 'date_block_num', 'shop_id', 'item_id', 'item_price', 'item_cnt_day', 'item_name', 'item_category_id', 'item_category_name', 'shop_name']


In [73]:
def handle_null_names(df):
    """Impute nulls in item_name with 'Unknown' and verify nulls."""
    print(f"Nulls in item_name before imputation: {df['item_name'].isna().sum()}")
    df['item_name'] = df['item_name'].fillna('Unknown')
    print(f"Imputed {df['item_name'].isna().sum()} nulls in item_name with 'Unknown'")
    return df

data = handle_null_names(data)

Nulls in item_name before imputation: 84


Imputed 0 nulls in item_name with 'Unknown'


In [74]:
import polars as pl
from datetime import datetime

def get_russian_holidays():
    """Return DataFrame of Russian holidays and shopping events (2013–2015)."""
    holidays = [
        ("2013-01-01", "New Year"), ("2014-01-01", "New Year"), ("2015-01-01", "New Year"),
        ("2013-02-23", "Defender Day"), ("2014-02-23", "Defender Day"), ("2015-02-23", "Defender Day"),
        ("2013-03-08", "Women's Day"), ("2014-03-08", "Women's Day"), ("2015-03-08", "Women's Day"),
        ("2013-06-12", "Russia Day"), ("2014-06-12", "Russia Day"), ("2015-06-12", "Russia Day"),
        ("2013-11-29", "Black Friday"), ("2014-11-28", "Black Friday"), ("2015-11-27", "Black Friday")
    ]
    holiday_df = pl.DataFrame({
        "date": [datetime.strptime(date, "%Y-%m-%d") for date, _ in holidays],
        "holiday": [name for _, name in holidays]
    }).with_columns([
        pl.col("date").dt.month().cast(pl.Int32).alias("month"),
        pl.col("date").dt.year().cast(pl.Int32).alias("year")
    ])
    return holiday_df

holiday_df = get_russian_holidays()
print(f"Holiday DataFrame created with {len(holiday_df)} entries")

Holiday DataFrame created with 15 entries


In [75]:
import time
import polars as pl

def initial_preprocessing(df):
    """Perform initial data cleaning and filtering, including shop selection."""
    start_time = time.time()
    
    # Convert to Polars
    df = pl.from_pandas(df)
    print(f"Initial dataset size: {len(df)}")
    print(f"Unique shops: {df['shop_id'].n_unique()}, items: {df['item_id'].n_unique()}")
    print(f"Date block range: {df['date_block_num'].min()}–{df['date_block_num'].max()}")
    
    # Parse dates
    df = df.with_columns(
        pl.col('date').str.strptime(pl.Date, "%d.%m.%Y")
    ).with_columns([
        pl.col('date').dt.month().cast(pl.Int32).alias('month'),
        pl.col('date').dt.year().cast(pl.Int32).alias('year')
    ])
    
    # Filter to date_block_num <= 32
    df = df.filter(pl.col('date_block_num') <= 32)
    print(f"Dataset size after date filter: {len(df)}")
    
    # Filter to top 54 shops by total sales volume
    shop_sales = df.group_by('shop_id').agg(
        total_sales=pl.col('item_cnt_day').sum()
    ).sort('total_sales', descending=True).head(54)
    valid_shops = shop_sales['shop_id'].to_list()
    df = df.filter(pl.col('shop_id').is_in(valid_shops))
    print(f"Dataset size after shop filter: {len(df)}")
    print(f"Unique shops after processing: {df['shop_id'].n_unique()}")
    
    # Optimize dtypes
    df = df.with_columns([
        pl.col('date_block_num').cast(pl.Int16),
        pl.col('shop_id').cast(pl.Int32),
        pl.col('item_id').cast(pl.Int32),
        pl.col('item_category_id').cast(pl.Int32),
        pl.col('item_cnt_day').cast(pl.Float32),
        pl.col('item_price').cast(pl.Float32),
        pl.col('month').cast(pl.Int32),
        pl.col('year').cast(pl.Int32)
    ])
    
    print(f"Initial preprocessing time: {time.time() - start_time:.2f} seconds")
    return df

df = initial_preprocessing(data)

Initial dataset size: 2935849
Unique shops: 60, items: 21807
Date block range: 0–33
Dataset size after date filter: 2882335
Dataset size after shop filter: 2868195
Unique shops after processing: 54
Initial preprocessing time: 1.81 seconds


In [76]:
def handle_negative_sales_and_outliers(df):
    """Apply z-score-based clipping and handle negative sales at daily level."""
    start_time = time.time()
    
    # Diagnostics: Inspect item_cnt_day distribution
    print("item_cnt_day distribution before processing:")
    print(df['item_cnt_day'].describe())
    print(f"Rows with item_cnt_day > 100: {df.filter(pl.col('item_cnt_day') > 100).height}")
    print(f"Rows with item_cnt_day > 500: {df.filter(pl.col('item_cnt_day') > 500).height}")
    negative_count = df.filter(pl.col('item_cnt_day') < 0).height
    print(f"Negative sales before processing: {negative_count} ({negative_count / len(df) * 100:.2f}%)")
    
    # Check for sparse shop-item pairs
    shop_item_counts = df.group_by(['shop_id', 'item_id']).agg(
        day_count=pl.col('date').count()
    )
    print(f"Shop-item pairs with < 30 days: {shop_item_counts.filter(pl.col('day_count') < 30).height}")
    
    # Sort for consistent processing
    df = df.sort(['shop_id', 'item_id', 'date'])
    
    # Z-score-based clipping per shop-item pair
    print("Z-score multiplier used: 4.5")
    df = df.with_columns([
        pl.col('item_cnt_day').mean().over(['shop_id', 'item_id']).alias('mean'),
        pl.col('item_cnt_day').std().over(['shop_id', 'item_id']).alias('std')
    ]).with_columns(
        item_cnt_day_winsor=pl.col('item_cnt_day').clip(
            pl.col('mean') - 4.5 * pl.col('std'),
            pl.min_horizontal(pl.col('mean') + 4.5 * pl.col('std'), 300)
        )
    )
    
    # Count outliers capped
    outlier_count = df.filter((pl.col('item_cnt_day') < pl.col('item_cnt_day_winsor')) | 
                             (pl.col('item_cnt_day') > pl.col('item_cnt_day_winsor'))).height
    print(f"Outliers capped: {outlier_count} ({outlier_count / len(df) * 100:.2f}%)")
    
    # Diagnostics: Inspect winsorized values
    print("item_cnt_day_winsor stats:")
    print(df['item_cnt_day_winsor'].describe())
    print(f"Max item_cnt_day_winsor: {df['item_cnt_day_winsor'].max()}")
    print(f"Clipped values count: {df.filter(pl.col('item_cnt_day_winsor') == 300).height}")
    
    # Handle negative sales
    df = df.with_columns([
        pl.when(pl.col('item_cnt_day_winsor') < 0)
          .then(pl.col('item_cnt_day_winsor').abs())
          .otherwise(0)
          .alias('returns'),
        pl.col('item_cnt_day_winsor').clip(lower_bound=0).alias('item_cnt_day_winsor')
    ])
    print(f"Negative sales after processing: {df.filter(pl.col('item_cnt_day_winsor') < 0).height}")
    
    # Calculate processed CV
    processed_cv = df['item_cnt_day_winsor'].std() / df['item_cnt_day_winsor'].mean()
    print(f"Processed CV after clipping: {processed_cv:.2f}")
    
    # Drop temporary columns
    df = df.drop(['mean', 'std'])
    
    print(f"Negative sales and outlier handling time: {time.time() - start_time:.2f} seconds")
    return df

df = handle_negative_sales_and_outliers(df)

item_cnt_day distribution before processing:
shape: (9, 2)
┌────────────┬────────────┐
│ statistic  ┆ value      │
│ ---        ┆ ---        │
│ str        ┆ f64        │
╞════════════╪════════════╡
│ count      ┆ 2.868195e6 │
│ null_count ┆ 0.0        │
│ mean       ┆ 1.240954   │
│ std        ┆ 2.28652    │
│ min        ┆ -22.0      │
│ 25%        ┆ 1.0        │
│ 50%        ┆ 1.0        │
│ 75%        ┆ 1.0        │
│ max        ┆ 1000.0     │
└────────────┴────────────┘
Rows with item_cnt_day > 100: 135
Rows with item_cnt_day > 500: 11
Negative sales before processing: 7186 (0.25%)


Shop-item pairs with < 30 days: 395283
Z-score multiplier used: 4.5
Outliers capped: 11090 (0.39%)
item_cnt_day_winsor stats:
shape: (9, 2)
┌────────────┬────────────┐
│ statistic  ┆ value      │
│ ---        ┆ ---        │
│ str        ┆ f64        │
╞════════════╪════════════╡
│ count      ┆ 2.868195e6 │
│ null_count ┆ 0.0        │
│ mean       ┆ 1.231432   │
│ std        ┆ 1.727064   │
│ min        ┆ -22.0      │
│ 25%        ┆ 1.0        │
│ 50%        ┆ 1.0        │
│ 75%        ┆ 1.0        │
│ max        ┆ 300.0      │
└────────────┴────────────┘
Max item_cnt_day_winsor: 300.0
Clipped values count: 14
Negative sales after processing: 0
Processed CV after clipping: 1.40
Negative sales and outlier handling time: 0.68 seconds


In [77]:
def aggregate_to_monthly(df):
    """Aggregate data to monthly level with duplicate handling and monthly cap."""
    start_time = time.time()
    
    # Handle daily duplicates
    df = df.group_by(['date', 'shop_id', 'item_id', 'item_category_id', 'month', 'year', 'date_block_num']).agg(
        item_cnt_day_winsor=pl.col('item_cnt_day_winsor').sum(),
        returns=pl.col('returns').sum(),
        item_price=pl.col('item_price').mean()
    )
    print(f"Dataset size after daily deduplication: {len(df)}")
    
    # Log daily duplicates (should be 0 after deduplication)
    daily_duplicates = df.group_by(['date', 'shop_id', 'item_id']).agg(count=pl.count()).filter(pl.col('count') > 1)
    print(f"Daily duplicates after deduplication: {daily_duplicates.height}")
    
    # Log monthly duplicates
    duplicates = df.group_by(['date_block_num', 'shop_id', 'item_id']).agg(count=pl.count()).filter(pl.col('count') > 1)
    print(f"Monthly duplicates found: {duplicates.height}")
    print(f"Duplicate details:\n{duplicates.head()}")
    
    # Aggregate to monthly
    df = df.group_by(['date_block_num', 'shop_id', 'item_id', 'item_category_id', 'month', 'year']).agg([
        pl.col('item_cnt_day_winsor').sum().alias('item_cnt_day_winsor'),
        pl.col('returns').sum().alias('returns'),
        pl.col('item_price').mean().alias('item_price')
    ]).with_columns(
        pl.col('item_cnt_day_winsor').clip(upper_bound=1000).alias('item_cnt_day_winsor')
    )
    
    # Recreate date column
    df = df.with_columns(
        pl.datetime(pl.col('year'), pl.col('month'), 1).alias('date')
    )
    
    # Diagnostics: Inspect aggregated item_cnt_day_winsor
    print("item_cnt_day_winsor distribution after aggregation:")
    print(df['item_cnt_day_winsor'].describe())
    
    print(f"Dataset size after aggregation: {len(df)}")
    print(f"Unique shops: {df['shop_id'].n_unique()}, items: {df['item_id'].n_unique()}")
    print(f"Aggregation time: {time.time() - start_time:.2f} seconds")
    return df

df = aggregate_to_monthly(df)

Dataset size after daily deduplication: 2868168
Daily duplicates after deduplication: 0
Monthly duplicates found: 519897
Duplicate details:
shape: (5, 4)
┌────────────────┬─────────┬─────────┬───────┐
│ date_block_num ┆ shop_id ┆ item_id ┆ count │
│ ---            ┆ ---     ┆ ---     ┆ ---   │
│ i16            ┆ i32     ┆ i32     ┆ u32   │
╞════════════════╪═════════╪═════════╪═══════╡
│ 5              ┆ 30      ┆ 11859   ┆ 3     │
│ 27             ┆ 37      ┆ 4178    ┆ 2     │
│ 23             ┆ 31      ┆ 18587   ┆ 2     │
│ 23             ┆ 4       ┆ 12472   ┆ 2     │
│ 19             ┆ 15      ┆ 18114   ┆ 3     │
└────────────────┴─────────┴─────────┴───────┘
item_cnt_day_winsor distribution after aggregation:
shape: (9, 2)
┌────────────┬───────────┐
│ statistic  ┆ value     │
│ ---        ┆ ---       │
│ str        ┆ f64       │
╞════════════╪═══════════╡
│ count      ┆ 1.56857e6 │
│ null_count ┆ 0.0       │
│ mean       ┆ 2.255657  │
│ std        ┆ 7.887975  │
│ min        ┆ 0.0  

In [78]:
def create_full_grid(df):
    """Create full shop-item-month grid."""
    start_time = time.time()
    
    shops = df['shop_id'].unique().to_list()
    items = df['item_id'].unique().to_list()
    date_blocks = list(range(33))  # 0–32
    
    grid = pl.DataFrame({'shop_id': shops}).join(
        pl.DataFrame({'item_id': items}), how='cross'
    ).join(
        pl.DataFrame({'date_block_num': date_blocks}), how='cross'
    ).with_columns([
        pl.col('shop_id').cast(pl.Int32),
        pl.col('item_id').cast(pl.Int32),
        pl.col('date_block_num').cast(pl.Int16),
        ((pl.col('date_block_num') % 12) + 1).cast(pl.Int32).alias('month'),
        ((pl.col('date_block_num') // 12) + 2013).cast(pl.Int32).alias('year')
    ])
    
    df = grid.join(
        df, on=['shop_id', 'item_id', 'date_block_num', 'month', 'year'], how='left'
    ).with_columns([
        pl.col('item_cnt_day_winsor').fill_null(0),
        pl.col('returns').fill_null(0),
        pl.col('item_price').fill_null(pl.col('item_price').mean().over('item_id')).fill_null(0),
        pl.col('item_category_id').fill_null(pl.col('item_category_id').first().over('item_id')).fill_null(0)
    ]).with_columns(
        pl.datetime(pl.col('year'), pl.col('month'), 1).alias('date')
    )
    
    # Verify date column
    if 'date' not in df.columns:
        raise ValueError("date column missing after grid creation")
    print(f"Columns after grid creation: {df.columns}")
    
    zero_sales = df.filter(pl.col('item_cnt_day_winsor') == 0).height
    non_zero_sales = df.filter(pl.col('item_cnt_day_winsor') > 0).height
    print(f"Grid size: {len(grid)}, after merge: {len(df)}")
    print(f"Zero sales introduced: {zero_sales} ({zero_sales / len(df) * 100:.2f}%)")
    print(f"Non-zero sales: {non_zero_sales} ({non_zero_sales / len(df) * 100:.2f}%)")
    print(f"Unique shops: {df['shop_id'].n_unique()}, items: {df['item_id'].n_unique()}")
    print(f"Grid creation time: {time.time() - start_time:.2f} seconds")
    return df

df = create_full_grid(df)

Columns after grid creation: ['shop_id', 'item_id', 'date_block_num', 'month', 'year', 'item_category_id', 'item_cnt_day_winsor', 'returns', 'item_price', 'date']
Grid size: 37972638, after merge: 37972638
Zero sales introduced: 36404950 (95.87%)
Non-zero sales: 1567688 (4.13%)
Unique shops: 54, items: 21309
Grid creation time: 4.79 seconds


In [79]:
def seasonal_imputation(df, col):
    """Apply seasonality-aware imputation to a column, treating zeros as pseudo-nulls."""
    df = df.with_columns(
        seasonal_value=pl.col(col).shift(12).over(['shop_id', 'item_id', 'month']),
        ma_value=pl.col(col).rolling_mean(window_size=12, min_periods=1).over(['shop_id', 'item_id'])
    ).with_columns(
        pl.when(pl.col(col).eq(0) & pl.col('seasonal_value').is_not_null())
          .then(pl.col('seasonal_value'))
          .when(pl.col(col).eq(0))
          .then(pl.col('ma_value'))
          .otherwise(pl.col(col))
          .alias(col)
    ).drop(['seasonal_value', 'ma_value'])
    return df

def apply_imputation(df, cols):
    """Impute missing values for specified columns."""
    start_time = time.time()
    
    for col in cols:
        df = seasonal_imputation(df, col)
    df = df.with_columns([pl.col(col).fill_null(0) for col in cols])
    
    print("item_cnt_day_winsor after imputation:")
    print(df['item_cnt_day_winsor'].describe())
    print(f"Imputation time: {time.time() - start_time:.2f} seconds")
    return df

numerical_cols = ['item_cnt_day_winsor', 'returns', 'item_price']
df = apply_imputation(df, numerical_cols)

item_cnt_day_winsor after imputation:
shape: (9, 2)
┌────────────┬─────────────┐
│ statistic  ┆ value       │
│ ---        ┆ ---         │
│ str        ┆ f64         │
╞════════════╪═════════════╡
│ count      ┆ 3.7972638e7 │
│ null_count ┆ 0.0         │
│ mean       ┆ 0.138108    │
│ std        ┆ 1.703088    │
│ min        ┆ -7.9473e-8  │
│ 25%        ┆ 0.0         │
│ 50%        ┆ 0.0         │
│ 75%        ┆ 0.0         │
│ max        ┆ 1000.0      │
└────────────┴─────────────┘
Imputation time: 53.62 seconds


In [80]:
def add_holiday_features(df, holiday_df):
    """Add holiday features to the dataset."""
    start_time = time.time()
    
    df = df.join(
        holiday_df.select(['year', 'month', 'holiday']),
        on=['year', 'month'], how='left'
    ).with_columns(
        is_holiday=pl.col('holiday').is_not_null().cast(pl.Int8),
        holiday=pl.col('holiday').fill_null('None')
    )
    
    print(f"Holiday feature addition time: {time.time() - start_time:.2f} seconds")
    return df

df = add_holiday_features(df, holiday_df)

Holiday feature addition time: 0.84 seconds


In [81]:
def filter_sparse_products(df):
    """Exclude shop-item pairs with >60% missing data and low sales."""
    start_time = time.time()
    
    # Pre-filter sparsity
    print(f"Pre-filter sparsity: {df.filter(pl.col('item_cnt_day_winsor') == 0).height / len(df):.2%}")
    
    shop_item_missing = df.group_by(['shop_id', 'item_id']).agg(
        missing_ratio=pl.col('item_cnt_day_winsor').eq(0).mean(),
        total_sales=pl.col('item_cnt_day_winsor').sum()
    )
    valid_shop_items = shop_item_missing.filter(
        (pl.col('missing_ratio') <= 0.60) & 
        (pl.col('total_sales') > 10)
    ).select(['shop_id', 'item_id'])
    
    initial_size = len(df)
    df = df.join(valid_shop_items, on=['shop_id', 'item_id'], how='inner')
    
    zero_sales = df.filter(pl.col('item_cnt_day_winsor') == 0).height
    item_count = df['item_id'].n_unique()
    shop_item_count = len(valid_shop_items)
    print(f"Records dropped due to >60% missing or low sales: {initial_size - len(df)}")
    print(f"Records after filtering: {len(df)}, shop-item pairs: {shop_item_count}")
    print(f"Sparsity after filtering: {zero_sales} ({zero_sales / len(df) * 100:.2f}%)")
    print(f"Unique items: {item_count}")
    print(f"Item count status: {'PASS' if 10000 <= item_count <= 15000 else 'FAIL'}")
    print(f"Filtering time: {time.time() - start_time:.2f} seconds")
    return df

df = filter_sparse_products(df)

Pre-filter sparsity: 82.92%


Records dropped due to >60% missing or low sales: 34490709
Records after filtering: 3481929, shop-item pairs: 105513
Sparsity after filtering: 976653 (28.05%)
Unique items: 9870
Item count status: FAIL
Filtering time: 0.60 seconds


In [82]:
def create_lag_features(df):
    """Create lag features for 1–3 months as per paper."""
    start_time = time.time()
    
    # Verify required columns
    required_cols = ['shop_id', 'item_id', 'date_block_num', 'item_cnt_day_winsor']
    missing_cols = [col for col in required_cols if col not in df.columns]
    if missing_cols:
        raise ValueError(f"Missing required columns: {missing_cols}")
    print(f"Columns before lag creation: {df.columns}")
    
    # Sort by shop_id, item_id, date_block_num
    df = df.sort(['shop_id', 'item_id', 'date_block_num'])
    
    # Create lag features for item_cnt_day_winsor only
    for lag in [1, 2, 3]:
        df = df.with_columns(
            pl.col('item_cnt_day_winsor').shift(lag).over(['shop_id', 'item_id']).alias(f'lag_sales_{lag}')
        )
    
    print(f"Lag feature creation time: {time.time() - start_time:.2f} seconds")
    return df

df = create_lag_features(df)

Columns before lag creation: ['shop_id', 'item_id', 'date_block_num', 'month', 'year', 'item_category_id', 'item_cnt_day_winsor', 'returns', 'item_price', 'date', 'holiday', 'is_holiday']


Lag feature creation time: 0.89 seconds


In [83]:
# Update numerical_cols to include only sales lags
numerical_cols = ['item_cnt_day_winsor', 'returns', 'item_price', 'lag_sales_1', 'lag_sales_2', 'lag_sales_3']
print("\nMissing values before lag imputation:")
print(df.select(numerical_cols).null_count().to_pandas().to_string())

df = apply_imputation(df, numerical_cols)

print("\nMissing values after lag imputation:")
print(df.select(numerical_cols).null_count().to_pandas().to_string())


Missing values before lag imputation:
   item_cnt_day_winsor  returns  item_price  lag_sales_1  lag_sales_2  lag_sales_3
0                    0        0           0       105513       211026       316539


item_cnt_day_winsor after imputation:
shape: (9, 2)
┌────────────┬────────────┐
│ statistic  ┆ value      │
│ ---        ┆ ---        │
│ str        ┆ f64        │
╞════════════╪════════════╡
│ count      ┆ 3.481929e6 │
│ null_count ┆ 0.0        │
│ mean       ┆ 1.037785   │
│ std        ┆ 5.071925   │
│ min        ┆ -4.0916e-7 │
│ 25%        ┆ 0.083333   │
│ 50%        ┆ 0.416667   │
│ 75%        ┆ 1.0        │
│ max        ┆ 1000.0     │
└────────────┴────────────┘
Imputation time: 12.42 seconds

Missing values after lag imputation:
   item_cnt_day_winsor  returns  item_price  lag_sales_1  lag_sales_2  lag_sales_3
0                    0        0           0            0            0            0


In [84]:
import pickle
import os
import time
from sklearn.preprocessing import RobustScaler
import polars as pl
import pandas as pd

def scale_and_save_data(df, numerical_cols):
    """Apply robust scaling and save datasets."""
    start_time = time.time()
    
    # Pre-save diagnostics
    zero_sales = df.filter(pl.col('item_cnt_day_winsor') == 0).height
    print(f"Pre-save sparsity: {zero_sales} ({zero_sales / len(df) * 100:.2f}%)")
    processed_cv = df['item_cnt_day_winsor'].std() / df['item_cnt_day_winsor'].mean()
    print(f"Pre-save processed CV: {processed_cv:.2f}")
    print(f"Columns before saving: {df.columns}")
    
    # Check for unexpected columns
    expected_cols = ['date_block_num', 'shop_id', 'item_id', 'item_category_id', 'month', 'year',
                     'item_cnt_day_winsor', 'returns', 'item_price', 'is_holiday', 'date',
                     'lag_sales_1', 'lag_sales_2', 'lag_sales_3', 'holiday']
    unexpected_cols = [col for col in df.columns if col not in expected_cols]
    if unexpected_cols:
        print(f"Warning: Unexpected columns before saving: {unexpected_cols}")
    
    # Pre-scaling null check (Polars)
    print("\nNull values in numerical_cols before scaling (Polars):")
    null_counts_polars = df.select(numerical_cols).null_count().to_pandas()
    print(null_counts_polars.to_string())
    for col in numerical_cols:
        null_rows = df.filter(pl.col(col).is_null()).height
        print(f"Rows with null {col}: {null_rows}")
    
    monthly_sales = df.to_pandas()
    del df
    
    # Pre-scaling null check (Pandas)
    print("\nNull values in numerical_cols before scaling (Pandas):")
    null_counts_pandas = monthly_sales[numerical_cols].isna().sum()
    print(null_counts_pandas.to_string())
    for col in numerical_cols:
        null_rows = monthly_sales[monthly_sales[col].isna()].shape[0]
        print(f"Rows with null {col}: {null_rows}")
    
    save_dir = '/workspace/XAI-2/XAI/processed_data'
    os.makedirs(save_dir, exist_ok=True)
    
    monthly_sales.to_parquet(os.path.join(save_dir, 'monthly_sales_unscaled.parquet'))
    print("Saved unscaled data")
    
    scaler = RobustScaler()
    train_data = monthly_sales[monthly_sales['date_block_num'] < 30][numerical_cols]
    print(f"Scaler training data shape: {train_data.shape}")
    
    scaler.fit(train_data)
    monthly_sales[numerical_cols] = scaler.transform(monthly_sales[numerical_cols])
    
    if monthly_sales[numerical_cols].isna().any().any():
        print("\nNull values in numerical_cols after scaling:")
        null_counts_after = monthly_sales[numerical_cols].isna().sum()
        print(null_counts_after.to_string())
        for col in numerical_cols:
            null_rows = monthly_sales[monthly_sales[col].isna()].shape[0]
            print(f"Rows with null {col}: {null_rows}")
        raise ValueError("NaNs introduced during scaling")
    
    with open(os.path.join(save_dir, 'scaler.pkl'), 'wb') as f:
        pickle.dump(scaler, f)
    
    dtypes = {col: 'float32' for col in numerical_cols}
    dtypes.update({
        'shop_id': 'int32', 'item_id': 'int32', 'item_category_id': 'int32',
        'date_block_num': 'int16', 'month': 'int32', 'year': 'int32', 'is_holiday': 'int8'
    })
    monthly_sales = monthly_sales.astype(dtypes, errors='ignore')
    
    print(f"Scaling and saving time: {time.time() - start_time:.2f} seconds")
    return monthly_sales

numerical_cols = ['item_cnt_day_winsor', 'returns', 'item_price', 'lag_sales_1', 'lag_sales_2', 'lag_sales_3']
monthly_sales = scale_and_save_data(df, numerical_cols)

Pre-save sparsity: 563610 (16.19%)
Pre-save processed CV: 4.89
Columns before saving: ['shop_id', 'item_id', 'date_block_num', 'month', 'year', 'item_category_id', 'item_cnt_day_winsor', 'returns', 'item_price', 'date', 'holiday', 'is_holiday', 'lag_sales_1', 'lag_sales_2', 'lag_sales_3']

Null values in numerical_cols before scaling (Polars):
   item_cnt_day_winsor  returns  item_price  lag_sales_1  lag_sales_2  lag_sales_3
0                    0        0           0            0            0            0
Rows with null item_cnt_day_winsor: 0
Rows with null returns: 0
Rows with null item_price: 0
Rows with null lag_sales_1: 0
Rows with null lag_sales_2: 0
Rows with null lag_sales_3: 0



Null values in numerical_cols before scaling (Pandas):
item_cnt_day_winsor    0
returns                0
item_price             0
lag_sales_1            0
lag_sales_2            0
lag_sales_3            0
Rows with null item_cnt_day_winsor: 0
Rows with null returns: 0
Rows with null item_price: 0
Rows with null lag_sales_1: 0
Rows with null lag_sales_2: 0
Rows with null lag_sales_3: 0
Saved unscaled data
Scaler training data shape: (3165390, 6)
Scaling and saving time: 2.84 seconds


In [85]:
def split_and_save_sets(df):
    """Split data into train/val/test per paper (months 0–30, 31, 32) and save X and y."""
    start_time = time.time()
    
    train_df = df[df['date_block_num'] <= 30]
    val_df = df[df['date_block_num'] == 31]
    test_df = df[df['date_block_num'] == 32]
    
    target_col = 'item_cnt_day_winsor'
    exclude_cols = [target_col, 'holiday', 'date']
    feature_cols = [
        'date_block_num', 'shop_id', 'item_id', 'item_category_id', 'month', 'year',
        'item_price', 'returns', 'is_holiday', 'lag_sales_1', 'lag_sales_2', 'lag_sales_3'
    ]
    
    # Debug unexpected columns
    unexpected_cols = [col for col in df.columns if col not in feature_cols + exclude_cols]
    if unexpected_cols:
        print(f"Warning: Unexpected columns in dataset: {unexpected_cols}")
    
    print(f"Feature columns: {feature_cols}")
    
    X_train = train_df[feature_cols]
    y_train = train_df[[target_col]]
    X_val = val_df[feature_cols]
    y_val = val_df[[target_col]]
    X_test = test_df[feature_cols]
    y_test = test_df[[target_col]]
    
    print(f"X_train: {X_train.shape}, y_train: {y_train.shape}")
    print(f"X_val: {X_val.shape}, y_val: {y_val.shape}")
    print(f"X_test: {X_test.shape}, y_test: {y_test.shape}")
    print(f"Unique shops in train: {X_train['shop_id'].nunique()}, items: {X_train['item_id'].nunique()}")
    
    output_dir = '/workspace/XAI-2/XAI/processed_data'
    os.makedirs(output_dir, exist_ok=True)
    
    X_train.to_parquet(os.path.join(output_dir, 'X_train_processed.parquet'))
    y_train.to_parquet(os.path.join(output_dir, 'y_train_processed.parquet'))
    X_val.to_parquet(os.path.join(output_dir, 'X_val_processed.parquet'))
    y_val.to_parquet(os.path.join(output_dir, 'y_val_processed.parquet'))
    X_test.to_parquet(os.path.join(output_dir, 'X_test_processed.parquet'))
    y_test.to_parquet(os.path.join(output_dir, 'y_test_processed.parquet'))
    
    df.to_parquet(os.path.join(output_dir, 'processed_sales.parquet'))
    
    print(f"Split and save time: {time.time() - start_time:.2f} seconds")
    return X_train, y_train, X_val, y_val, X_test, y_test

X_train, y_train, X_val, y_val, X_test, y_test = split_and_save_sets(monthly_sales)

Feature columns: ['date_block_num', 'shop_id', 'item_id', 'item_category_id', 'month', 'year', 'item_price', 'returns', 'is_holiday', 'lag_sales_1', 'lag_sales_2', 'lag_sales_3']
X_train: (3270903, 12), y_train: (3270903, 1)
X_val: (105513, 12), y_val: (105513, 1)
X_test: (105513, 12), y_test: (105513, 1)
Unique shops in train: 52, items: 9870
Split and save time: 4.06 seconds


In [86]:
import sys
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import polars as pl
from tqdm import tqdm
from pathlib import Path
import os
import logging
import pkg_resources

# Check Polars version
required_polars_version = "0.20.0"
current_polars_version = pkg_resources.get_distribution("polars").version
if pkg_resources.parse_version(current_polars_version) < pkg_resources.parse_version(required_polars_version):
    raise ImportError(
        f"Polars version {current_polars_version} is outdated. Please upgrade to {required_polars_version} or higher "
        "using 'pip install --upgrade polars'"
    )

# Set file descriptor limit
os.system('ulimit -n 4096')

# Minimal logging with DEBUG level
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    handlers=[logging.StreamHandler()]
)
logger = logging.getLogger(__name__)

# GPU Configuration
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision('high')

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f"Using device: {device}")

class SalesDataset(Dataset):
    def __init__(self, X_file, y_file, sequence_length=12, num_shops=54, num_items=22170, num_categories=84):
        logger.info(f"Loading dataset from {X_file}")
        start = time.time()
        
        # Load data
        self.X = pl.read_parquet(X_file)
        self.y = pl.read_parquet(y_file).select(['item_cnt_day_winsor']).to_numpy().flatten().astype(np.float32)
        logger.info(f"Loaded Parquet files in {time.time() - start:.2f}s")
        
        if len(self.X) != len(self.y):
            raise ValueError(f"Mismatch between X ({len(self.X)}) and y ({len(self.y)}) rows")
        
        self.sequence_length = sequence_length
        self.num_shops = num_shops
        self.num_items = num_items
        self.num_categories = num_categories
        
        self.numerical_cols = [
            'item_cnt_day_winsor', 'returns', 'item_price',
            'lag_sales_1', 'lag_sales_2', 'lag_sales_3'
        ]
        self.categorical_cols = ['shop_id', 'item_id', 'item_category_id']
        
        # Validate and normalize numerical data
        numerical_data = self.X.select(self.numerical_cols).to_numpy()
        if np.isnan(numerical_data).any() or np.isnan(self.y).any():
            raise ValueError("NaN values in X or y")
        if np.isinf(numerical_data).any() or np.isinf(self.y).any():
            raise ValueError("Infinite values in X or y")
        
        # Clip and normalize numerical data
        numerical_data = np.clip(numerical_data, -1e5, 1e5)
        mean = numerical_data.mean(axis=0, keepdims=True)
        std = numerical_data.std(axis=0, keepdims=True) + 1e-6
        numerical_data = (numerical_data - mean) / std
        
        # Clip and normalize target
        self.y = np.clip(self.y, -1e5, 1e5)
        self.y_mean = self.y.mean()
        self.y_std = self.y.std() + 1e-6
        self.y = (self.y - self.y_mean) / self.y_std
        
        # Clip categorical indices
        self.X = self.X.with_columns([
            pl.col('shop_id').clip(upper_bound=num_shops - 1),
            pl.col('item_id').clip(upper_bound=num_items - 1),
            pl.col('item_category_id').clip(upper_bound=num_categories - 1)
        ])
        
        # Preload data
        self.numerical = numerical_data.astype(np.float32)
        self.shop_ids = self.X['shop_id'].to_numpy().astype(np.int64)
        self.item_ids = self.X['item_id'].to_numpy().astype(np.int64)
        self.category_ids = self.X['item_category_id'].to_numpy().astype(np.int64)
        self.date_block_num = self.X['date_block_num'].to_numpy().astype(np.int32)
        
    def __len__(self):
        return len(self.X) - self.sequence_length + 1
    
    def __getitem__(self, idx):
        try:
            start_idx = idx
            end_idx = idx + self.sequence_length
            if end_idx > len(self.X):
                raise IndexError("Index out of range")
            
            numerical = torch.tensor(self.numerical[start_idx:end_idx], dtype=torch.float32)
            shop_ids = torch.tensor(self.shop_ids[start_idx:end_idx], dtype=torch.int64)
            item_ids = torch.tensor(self.item_ids[start_idx:end_idx], dtype=torch.int64)
            category_ids = torch.tensor(self.category_ids[start_idx:end_idx], dtype=torch.int64)
            date_block_num = torch.tensor(self.date_block_num[start_idx:end_idx], dtype=torch.int32)
            target = torch.tensor(self.y[end_idx - 1], dtype=torch.float32)
            
            if torch.isnan(numerical).any() or torch.isnan(target).any():
                raise ValueError(f"NaN detected at index {idx}")
            
            identifiers = torch.tensor([
                int(self.shop_ids[end_idx - 1]), 
                int(self.item_ids[end_idx - 1]), 
                int(self.date_block_num[end_idx - 1])
            ], dtype=torch.int32)
            
            return {
                'numerical': numerical,
                'shop_ids': shop_ids,
                'item_ids': item_ids,
                'category_ids': category_ids,
                'target': target,
                'date_block_num': date_block_num[-1],
                'identifiers': identifiers
            }
        except Exception as e:
            logger.error(f"Error in __getitem__ at index {idx}: {str(e)}")
            raise

class FeatureAttention(nn.Module):
    def __init__(self, feature_dim, attention_dim=64):
        super(FeatureAttention, self).__init__()
        self.query = nn.Linear(feature_dim, attention_dim)
        self.key = nn.Linear(feature_dim, attention_dim)
        self.value = nn.Linear(feature_dim, feature_dim)
        self.scale = 1 / (attention_dim ** 0.5)
        self.softmax = nn.Softmax(dim=-1)
        self.norm = nn.LayerNorm(feature_dim)
        nn.init.xavier_uniform_(self.query.weight)
        nn.init.xavier_uniform_(self.key.weight)
        nn.init.xavier_uniform_(self.value.weight)
    
    def forward(self, x):
        query = self.query(x)
        key = self.key(x)
        value = self.value(x)
        scores = torch.bmm(query, key.transpose(1, 2)) * self.scale
        weights = self.softmax(scores)
        output = torch.bmm(weights, value)
        output = self.norm(output)
        return output, weights

class HALSTM(nn.Module):
    def __init__(self, num_shops=54, num_items=22170, num_categories=84, embed_dim=16, numerical_dim=6, 
                 hidden_dim=128, num_layers=2, num_heads=4, dropout=0.4, forecast_horizon=1):
        super(HALSTM, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.forecast_horizon = forecast_horizon
        
        self.shop_embed = nn.Embedding(num_shops, embed_dim)
        self.item_embed = nn.Embedding(num_items, embed_dim)
        self.category_embed = nn.Embedding(num_categories, embed_dim)
        nn.init.normal_(self.shop_embed.weight, mean=0.0, std=0.02)
        nn.init.normal_(self.item_embed.weight, mean=0.0, std=0.02)
        nn.init.normal_(self.category_embed.weight, mean=0.0, std=0.02)
        
        self.input_dim = embed_dim * 3 + numerical_dim
        self.feature_attention = FeatureAttention(self.input_dim)
        self.lstm = nn.LSTM(self.input_dim, hidden_dim, num_layers, batch_first=True, dropout=dropout)
        self.lstm_norm = nn.LayerNorm(hidden_dim)
        self.mha = nn.MultiheadAttention(hidden_dim, num_heads, dropout=dropout, batch_first=True)
        self.mha_norm = nn.LayerNorm(hidden_dim)
        self.gate = nn.Linear(hidden_dim * 2, hidden_dim)
        self.sigmoid = nn.Sigmoid()
        self.fc_shared = nn.Linear(hidden_dim, hidden_dim)
        self.fc_horizons = nn.ModuleList([nn.Linear(hidden_dim, 1) for _ in range(forecast_horizon)])
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()
        
        # Positional encoding
        self.positional_encoding = self._create_positional_encoding(max_seq_len=100, d_model=hidden_dim)
        
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
    
    def _create_positional_encoding(self, max_seq_len, d_model):
        pe = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        return pe.to(device)
    
    def forward(self, numerical, shop_ids, item_ids, category_ids, trace_mode=False):
        batch_size, seq_len, _ = numerical.size()
        shop_embed = self.shop_embed(shop_ids)
        item_embed = self.item_embed(item_ids)
        category_embed = self.category_embed(category_ids)
        
        x = torch.cat([numerical, shop_embed, item_embed, category_embed], dim=-1).contiguous()
        x, feature_weights = self.feature_attention(x)
        x = self.dropout(x)
        
        h0 = torch.zeros(self.num_layers, batch_size, self.hidden_dim, device=x.device).contiguous()
        c0 = torch.zeros(self.num_layers, batch_size, self.hidden_dim, device=x.device).contiguous()
        lstm_out, _ = self.lstm(x, (h0, c0))
        lstm_out = self.lstm_norm(lstm_out)
        lstm_out = self.dropout(lstm_out)
        
        # Add positional encoding
        lstm_out = lstm_out + self.positional_encoding[:seq_len, :].unsqueeze(0).to(lstm_out.device)
        
        mha_out, mha_weights = self.mha(lstm_out, lstm_out, lstm_out)
        mha_out = self.mha_norm(mha_out)
        mha_out = self.dropout(mha_out)
        
        combined = torch.cat([lstm_out[:, -1, :], mha_out[:, -1, :]], dim=-1)
        gate = self.sigmoid(self.gate(combined))
        fused = gate * lstm_out[:, -1, :] + (1 - gate) * mha_out[:, -1, :]
        
        shared = self.relu(self.fc_shared(fused))
        outputs = torch.cat([fc(shared).unsqueeze(1) for fc in self.fc_horizons], dim=1)
        outputs = outputs.squeeze(-1)
        
        if trace_mode:
            return outputs
        return outputs, {
            'feature_weights': feature_weights,
            'mha_weights': mha_weights,
            'fused_output': fused,
            'gate_weights': gate
        }

def collate_fn(batch):
    if not batch:
        logger.warning("Empty batch received")
        return {}
    
    numerical = torch.stack([item['numerical'] for item in batch])
    shop_ids = torch.stack([item['shop_ids'] for item in batch])
    item_ids = torch.stack([item['item_ids'] for item in batch])
    category_ids = torch.stack([item['category_ids'] for item in batch])
    target = torch.stack([item['target'] for item in batch])
    date_block_num = torch.stack([item['date_block_num'] for item in batch])
    identifiers = torch.stack([item['identifiers'] for item in batch])
    
    return {
        'numerical': numerical.contiguous(),
        'shop_ids': shop_ids.contiguous(),
        'item_ids': item_ids.contiguous(),
        'category_ids': category_ids.contiguous(),
        'target': target.contiguous(),
        'date_block_num': date_block_num.contiguous(),
        'identifiers': identifiers.contiguous()
    }

def train_model(model, train_loader, val_loader, num_epochs=50, lr=0.0005, accum_steps=2):
    logger.info("Starting training")
    criterion = nn.MSELoss().to(device)
    scaler = torch.cuda.amp.GradScaler()
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=0.05)
    l1_lambda = 1e-5
    att_lambda = 1e-5
    temp_lambda = 1e-5
    
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=lr, epochs=num_epochs, steps_per_epoch=len(train_loader)//accum_steps, pct_start=0.1
    )
    
    best_val_loss = float('inf')
    output_dir = Path('/workspace/XAI-2/XAI/processed_data')
    output_dir.mkdir(exist_ok=True)
    
    patience = 5  # Increased for stability
    early_stop_counter = 0
    
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        optimizer.zero_grad(set_to_none=True)
        
        progress = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch+1}/{num_epochs}")
        
        for batch_idx, batch in progress:
            try:
                numerical = batch['numerical'].to(device, non_blocking=True)
                shop_ids = batch['shop_ids'].to(device, non_blocking=True)
                item_ids = batch['item_ids'].to(device, non_blocking=True)
                category_ids = batch['category_ids'].to(device, non_blocking=True)
                target = batch['target'].to(device, non_blocking=True)
                
                with torch.cuda.amp.autocast():
                    output, attention_dict = model(numerical, shop_ids, item_ids, category_ids)
                    mse_loss = criterion(output[:, -1], target) / accum_steps
                    
                    att_loss = attention_dict['mha_weights'].abs().sum()
                    temp_loss = torch.zeros(1, device=device)
                    if output.shape[1] > 1:
                        temp_loss = (output[:, 1:] - output[:, :-1]).abs().sum()
                    l1_loss = sum(p.abs().sum() for p in model.parameters() if p.requires_grad)
                    
                    loss = mse_loss + l1_lambda * l1_loss + att_lambda * att_loss + temp_lambda * temp_loss
                
                scaler.scale(loss).backward()
                
                if (batch_idx + 1) % accum_steps == 0:
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
                    logger.debug(f"Batch {batch_idx}: Calling optimizer.step()")
                    scaler.step(optimizer)
                    logger.debug(f"Batch {batch_idx}: Calling scheduler.step()")
                    scheduler.step()
                    scaler.update()
                    optimizer.zero_grad(set_to_none=True)
                
                train_loss += mse_loss.item() * accum_steps
                
                if batch_idx % 100 == 0:
                    try:
                        import pynvml
                        logger.info(f"Batch {batch_idx}, GPU Memory: {torch.cuda.memory_allocated() / 1e9:.2f} GB, "
                                    f"Utilization: {torch.cuda.utilization()}%")
                    except (ImportError, Exception) as e:
                        logger.warning(f"GPU monitoring unavailable: {str(e)}")
                
                progress.set_postfix({"batch_loss": f"{mse_loss.item() * accum_steps:.6f}"})
            except Exception as e:
                logger.error(f"Error in batch {batch_idx}: {str(e)}")
                continue
        
        train_loss /= len(train_loader)
        
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Validation", leave=False):
                try:
                    numerical = batch['numerical'].to(device, non_blocking=True)
                    shop_ids = batch['shop_ids'].to(device, non_blocking=True)
                    item_ids = batch['item_ids'].to(device, non_blocking=True)
                    category_ids = batch['category_ids'].to(device, non_blocking=True)
                    target = batch['target'].to(device, non_blocking=True)
                    
                    with torch.cuda.amp.autocast():
                        output, _ = model(numerical, shop_ids, item_ids, category_ids)
                        loss = criterion(output[:, -1], target)
                    
                    if torch.isnan(loss):
                        logger.error("NaN detected in validation loss")
                        raise ValueError("Validation loss is NaN")
                    
                    val_loss += loss.item()
                except Exception as e:
                    logger.error(f"Error in validation: {str(e)}")
                    continue
        
        val_loss /= len(val_loader)
        logger.info(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
        
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), output_dir / 'best_ha_lstm.pth')
            early_stop_counter = 0
        else:
            early_stop_counter += 1
            
        if early_stop_counter >= patience:
            logger.info(f"Early stopping after {epoch+1} epochs")
            break
        
        torch.cuda.empty_cache()
    
    logger.info(f"Training done. Best val loss: {best_val_loss:.6f}")
    model.load_state_dict(torch.load(output_dir / 'best_ha_lstm.pth'))
    return model

def predict(model, test_loader, dataset):
    logger.info("Predicting")
    model.eval()
    predictions = []
    identifiers = []
    fused_outputs = []
    
    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Predicting"):
            try:
                numerical = batch['numerical'].to(device, non_blocking=True)
                shop_ids = batch['shop_ids'].to(device, non_blocking=True)
                item_ids = batch['item_ids'].to(device, non_blocking=True)
                category_ids = batch['category_ids'].to(device, non_blocking=True)
                
                with torch.cuda.amp.autocast():
                    output, attention_dict = model(numerical, shop_ids, item_ids, category_ids)
                    preds = output.cpu().numpy()
                    preds = preds * dataset.y_std + dataset.y_mean
                    fused = attention_dict['fused_output'].cpu().numpy()
                
                predictions.append(preds)
                identifiers.append(batch['identifiers'].cpu().numpy())
                fused_outputs.append(fused)
            except Exception as e:
                logger.error(f"Error in prediction: {str(e)}")
                continue
            
            torch.cuda.empty_cache()
    
    if not predictions or not identifiers or not fused_outputs:
        logger.error("No predictions, identifiers, or fused outputs collected")
        return pd.DataFrame()
        
    predictions = np.concatenate(predictions, axis=0)
    identifiers = np.concatenate(identifiers, axis=0)
    fused_outputs = np.concatenate(fused_outputs, axis=0)
    
    pred_df = pd.DataFrame({
        'shop_id': identifiers[:, 0],
        'item_id': identifiers[:, 1],
        'date_block_num': identifiers[:, 2]
    })
    for h in range(predictions.shape[1]):
        pred_df[f'forecast_h{h+1}'] = predictions[:, h]
    
    fused_df = pd.DataFrame(fused_outputs, columns=[f'fused_dim_{i}' for i in range(fused_outputs.shape[1])])
    fused_df[['shop_id', 'item_id', 'date_block_num']] = identifiers
    
    output_dir = Path('/workspace/XAI-2/XAI/results')
    output_dir.mkdir(exist_ok=True)
    pred_df.to_csv(output_dir / 'predictions.csv', index=False)
    fused_df.to_csv(output_dir / 'fused_outputs.csv', index=False)
    logger.info(f"Predictions saved to {output_dir / 'predictions.csv'}")
    logger.info(f"Fused outputs saved to {output_dir / 'fused_outputs.csv'}")
    
    return pred_df

def visualize_results(pred_df, true_df=None, output_dir='/workspace/XAI-2/XAI/results'):
    output_dir = Path(output_dir)
    output_dir.mkdir(exist_ok=True)
    logger.info(f"Saving plots to {output_dir}")
    
    plt.style.use('seaborn-v0_8-whitegrid')
    plt.figure(figsize=(12, 8))
    
    if pred_df.empty:
        logger.error("Empty prediction dataframe")
        return None
    
    forecast_cols = [col for col in pred_df.columns if 'forecast' in col]
    if not forecast_cols:
        logger.error("No forecast columns found")
        return None
        
    for h in range(1, len(forecast_cols) + 1):
        sns.kdeplot(pred_df[f'forecast_h{h}'], label=f'Horizon {h}')
    
    plt.title('Prediction Distribution')
    plt.xlabel('Predicted Sales')
    plt.ylabel('Density')
    plt.legend()
    plt.tight_layout()
    plt.savefig(output_dir / 'prediction_distribution.png', dpi=300)
    plt.close()
    
    if true_df is not None and not true_df.empty:
        try:
            merged_df = pred_df.merge(true_df, on=['shop_id', 'item_id', 'date_block_num'], how='inner')
            if not merged_df.empty:
                merged_df['error'] = merged_df['forecast_h1'] - merged_df['item_cnt_day_winsor']
                
                plt.figure(figsize=(12, 8))
                sns.histplot(merged_df['error'], kde=True, bins=50)
                plt.title('Error Distribution')
                plt.xlabel('Prediction Error')
                plt.ylabel('Count')
                plt.tight_layout()
                plt.savefig(output_dir / 'error_distribution.png', dpi=300)
                plt.close()
                
                plt.figure(figsize=(12, 8))
                plt.scatter(merged_df['item_cnt_day_winsor'], merged_df['forecast_h1'], alpha=0.5)
                plt.plot([0, merged_df['item_cnt_day_winsor'].max()], 
                         [0, merged_df['item_cnt_day_winsor'].max()], 'r--')
                plt.title('Predicted vs Actual')
                plt.xlabel('Actual Sales')
                plt.ylabel('Predicted Sales')
                plt.tight_layout()
                plt.savefig(output_dir / 'predicted_vs_actual.png', dpi=300)
                plt.close()
                
                mae = merged_df['error'].abs().mean()
                rmse = (merged_df['error'] ** 2).mean() ** 0.5
                mape = (merged_df['error'].abs() / merged_df['item_cnt_day_winsor'].abs().replace(0, 1e-6)).mean() * 100
                logger.info(f"MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.2f}%")
                
                return mae, rmse, mape
            else:
                logger.warning("Merged dataframe is empty")
        except Exception as e:
            logger.error(f"Error in visualization: {str(e)}")
    
    return None

def main():
    seed = 42
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    
    data_dir = Path('/workspace/XAI-2/XAI/processed_data')
    batch_size = 8192
    num_workers = 0
    prefetch_factor = None
    num_shops = 54
    num_items = 22170
    num_categories = 84
    sequence_length = 12
    num_epochs = 50
    lr = 0.0005
    accum_steps = 2
    
    logger.info("Loading datasets...")
    required_files = [
        'X_train_processed.parquet', 'y_train_processed.parquet',
        'X_val_processed.parquet', 'y_val_processed.parquet',
        'X_test_processed.parquet', 'y_test_processed.parquet'
    ]
    
    missing_files = [f for f in required_files if not (data_dir / f).exists()]
    if missing_files:
        raise FileNotFoundError(f"Missing files: {', '.join(missing_files)}")
    
    try:
        train_dataset = SalesDataset(
            data_dir / 'X_train_processed.parquet',
            data_dir / 'y_train_processed.parquet',
            sequence_length=sequence_length,
            num_shops=num_shops,
            num_items=num_items,
            num_categories=num_categories
        )
        val_dataset = SalesDataset(
            data_dir / 'X_val_processed.parquet',
            data_dir / 'y_val_processed.parquet',
            sequence_length=sequence_length,
            num_shops=num_shops,
            num_items=num_items,
            num_categories=num_categories
        )
        test_dataset = SalesDataset(
            data_dir / 'X_test_processed.parquet',
            data_dir / 'y_test_processed.parquet',
            sequence_length=sequence_length,
            num_shops=num_shops,
            num_items=num_items,
            num_categories=num_categories
        )
        
        logger.info(f"Train dataset size: {len(train_dataset)}")
        logger.info(f"Val dataset size: {len(val_dataset)}")
        logger.info(f"Test dataset size: {len(test_dataset)}")
        
        # Verify dataset sizes against paper's expectations
        expected_sizes = {'train': 2700000, 'val': 100000, 'test': 100000}
        for split, size, expected in [('train', len(train_dataset), expected_sizes['train']),
                                      ('val', len(val_dataset), expected_sizes['val']),
                                      ('test', len(test_dataset), expected_sizes['test'])]:
            if abs(size - expected) / expected > 0.1:
                logger.warning(f"{split.capitalize()} dataset size {size} deviates significantly from expected {expected}. "
                               "Verify preprocessing and split_dataset output.")
        
    except Exception as e:
        logger.error(f"Error creating datasets: {str(e)}")
        raise
    
    try:
        train_loader = DataLoader(
            train_dataset, 
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            pin_memory=True,
            prefetch_factor=prefetch_factor,
            persistent_workers=False,
            collate_fn=collate_fn,
            timeout=0
        )
        val_loader = DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=True,
            prefetch_factor=prefetch_factor,
            persistent_workers=False,
            collate_fn=collate_fn,
            timeout=0
        )
        test_loader = DataLoader(
            test_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=True,
            prefetch_factor=prefetch_factor,
            persistent_workers=False,
            collate_fn=collate_fn,
            timeout=0
        )
        
        logger.info("DataLoaders created successfully")
        
    except Exception as e:
        logger.error(f"Error creating DataLoaders: {str(e)}")
        raise
    
    try:
        model = HALSTM(
            num_shops=num_shops,
            num_items=num_items,
            num_categories=num_categories,
            embed_dim=16,
            numerical_dim=6,
            hidden_dim=128,
            num_layers=2,
            num_heads=4,
            dropout=0.4,
            forecast_horizon=1
        ).to(device)
        
        logger.info(f"Model parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
        
        model = train_model(model, train_loader, val_loader, num_epochs=num_epochs, lr=lr, accum_steps=accum_steps)
        
        predictions = predict(model, test_loader, test_dataset)
        
        y_test = pl.read_parquet(data_dir / 'y_test_processed.parquet')
        x_test_identifiers = pl.read_parquet(data_dir / 'X_test_processed.parquet').select(['shop_id', 'item_id', 'date_block_num'])
        true_df = pl.concat([x_test_identifiers, y_test], how='horizontal').to_pandas()
        
        visualize_results(predictions, true_df)
        
    except Exception as e:
        logger.error(f"Error in main: {str(e)}")
        raise

if __name__ == '__main__':
    try:
        logger.info("Starting program")
        main()
        logger.info("Program completed successfully")
    except Exception as e:
        logger.error(f"Program failed: {str(e)}")
        import traceback
        logger.error(traceback.format_exc())
        sys.exit(1)

2025-05-12 18:36:15 - INFO - Using device: cuda
2025-05-12 18:36:15 - INFO - Starting program
2025-05-12 18:36:15 - INFO - Loading datasets...
2025-05-12 18:36:15 - INFO - Loading dataset from /workspace/XAI-2/XAI/processed_data/X_train_processed.parquet
2025-05-12 18:36:15 - INFO - Loaded Parquet files in 0.49s
2025-05-12 18:36:15 - ERROR - Error creating datasets: item_cnt_day_winsor

Resolved plan until failure:

	---> FAILED HERE RESOLVING 'sink' <---
DF ["date_block_num", "shop_id", "item_id", "item_category_id", ...]; PROJECT */13 COLUMNS
2025-05-12 18:36:15 - ERROR - Program failed: item_cnt_day_winsor

Resolved plan until failure:

	---> FAILED HERE RESOLVING 'sink' <---
DF ["date_block_num", "shop_id", "item_id", "item_category_id", ...]; PROJECT */13 COLUMNS
2025-05-12 18:36:15 - ERROR - Traceback (most recent call last):
  File "/tmp/ipykernel_2858/2081807326.py", line 667, in <module>
    main()
  File "/tmp/ipykernel_2858/2081807326.py", line 551, in main
    train_dataset

SystemExit: 1