# Snowflake Native ML Demo: Time Series Forecasting & Customer Classification

This notebook demonstrates Snowflake's built-in ML capabilities using the retail streaming data:
1. **Time Series Forecasting** - Predict future sales trends using ML.FORECAST
2. **Customer Classification** - Predict customer segments using ML.CLASSIFICATION

## Prerequisites
- Running in **Snowflake Notebooks** with container runtime
- Historical data loaded into the Gold layer (run `CALL GENERATE_HISTORICAL_DATA()`, then the bronze and gold transformation stored procedures)
- Warehouse: RETAIL_TRANSFORM_WH or similar

## Setup: Install & Import Required Packages

This cell installs all necessary packages in the Snowflake Notebooks container runtime. Run this first!

In [None]:
# Install required packages in Snowflake Notebooks container runtime
import subprocess
import sys

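def install_package(package):
    """Install one package with pip; return True on success, False on failure."""
    try:
        subprocess.check_call([sys.executable, "-m", "pip", "install", package, "-q"])
        print(f"✅ {package} installed successfully")
        return True
    except Exception as e:
        print(f"⚠️ Failed to install {package}: {e}")
        return False

# Install essential packages for this notebook
packages = [
    "snowflake-snowpark-python[pandas]>=1.11.0",
    "pandas>=2.0.0",
    "numpy>=1.24.0",
    "matplotlib>=3.0.0",
    "seaborn>=0.12.0"
]

print("📦 Installing required packages...")
failed = [package for package in packages if not install_package(package)]

# Only report success when every package actually installed
if failed:
    print(f"⚠️ Installation finished with failures: {', '.join(failed)}")
else:
    print("🎉 Package installation complete!")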

## 1. Connect to Snowflake

This cell uses the notebook's active Snowflake session, which is available in Snowflake Notebooks with container runtime.

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime, timedelta
from snowflake.snowpark.context import get_active_session
import warnings
warnings.filterwarnings('ignore')

# Get the active Snowflake session
session = get_active_session()

# Set context
session.sql("USE DATABASE RETAIL_STREAMING_DEMO").collect()
session.sql("USE SCHEMA PUBLIC").collect() 

# Set plotting style
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (14, 6)

print("✅ Libraries imported successfully")
print("✅ Connected to Snowflake session")
print(f"   Database: {session.get_current_database()}")
print(f"   Schema: {session.get_current_schema()}")
print(f"   Warehouse: {session.get_current_warehouse()}")

In [None]:
# Helper function to execute SQL and return a pandas DataFrame
def query_to_df(query, use_arrow=False):
    """
    Execute a SQL query and return results as a pandas DataFrame.

    Args:
        query: SQL query string
        use_arrow: Set to True for ML functions (FORECAST, CLASSIFICATION) for better compatibility
    """
    if not use_arrow:
        return session.sql(query).to_pandas()

    # Use ARROW format for ML functions to avoid JSON parsing issues
    param = "PYTHON_CONNECTOR_QUERY_RESULT_FORMAT"
    rows = session.sql(f"SHOW PARAMETERS LIKE '{param}' IN SESSION").collect()
    # SHOW PARAMETERS returns (key, value, ...); remember the current value, defaulting to JSON
    original_format = rows[0][1] if rows else "JSON"
    session.sql(f"ALTER SESSION SET {param} = 'ARROW'").collect()
    try:
        return session.sql(query).to_pandas()
    finally:
        # Restore the format that was active before this call
        session.sql(f"ALTER SESSION SET {param} = '{original_format}'").collect()

print("✅ Helper function defined")

## 2. Explore the Gold Layer Data

In [None]:
-- Check data availability
SELECT 
    'DIM_CUSTOMERS' as table_name,
    COUNT(*) as record_count
FROM S3_GOLD.DIM_CUSTOMERS
UNION ALL
SELECT 
    'FCT_SALES',
    COUNT(*)
FROM S3_GOLD.FCT_SALES


In [None]:
# Preview customer data
customers_sample = query_to_df("""
    SELECT 
        CUSTOMER_ID,
        CUSTOMER_SEGMENT,
        AGE,
        ANNUAL_INCOME,
        TOTAL_ORDERS,
        TOTAL_SPENT,
        AVG_ORDER_VALUE,
        CHURN_PROBABILITY,
        ENGAGEMENT_LEVEL
    FROM S3_GOLD.DIM_CUSTOMERS
    LIMIT 10
""")

print("👥 Customer Data Sample:")
customers_sample

In [None]:
# Preview sales data
sales_sample = query_to_df("""
    SELECT 
        PURCHASE_TIMESTAMP,
        CUSTOMER_ID,
        CUSTOMER_SEGMENT,
        PRODUCT_CATEGORY,
        FINAL_TOTAL,
        SALES_CHANNEL,
        IS_REPEAT_CUSTOMER
    FROM S3_GOLD.FCT_SALES
    ORDER BY PURCHASE_TIMESTAMP DESC
    LIMIT 10
""")

print("🛒 Sales Data Sample:")
sales_sample

## 3. Time Series Forecasting with ML.FORECAST

We'll forecast daily sales revenue for the next 30 days using Snowflake's built-in forecasting function.

### 3.1 Prepare Historical Sales Data

In [None]:
# Aggregate daily sales for time series analysis
daily_sales = query_to_df("""
    SELECT 
        DATE(PURCHASE_TIMESTAMP) as sales_date,
        COUNT(*) as transaction_count,
        SUM(FINAL_TOTAL) as total_revenue,
        AVG(FINAL_TOTAL) as avg_order_value,
        COUNT(DISTINCT CUSTOMER_ID) as unique_customers
    FROM S3_GOLD.FCT_SALES
    GROUP BY sales_date
    ORDER BY sales_date
""")

print(f"📈 Historical data: {len(daily_sales)} days")

# Visualize historical sales trend
fig, ax = plt.subplots(2, 1, figsize=(14, 10))

ax[0].plot(daily_sales['SALES_DATE'], daily_sales['TOTAL_REVENUE'], linewidth=2, color='#1f77b4')
ax[0].set_title('Daily Total Revenue', fontsize=14, fontweight='bold')
ax[0].set_ylabel('Revenue ($)', fontsize=12)
ax[0].grid(True, alpha=0.3)

ax[1].plot(daily_sales['SALES_DATE'], daily_sales['TRANSACTION_COUNT'], linewidth=2, color='#ff7f0e')
ax[1].set_title('Daily Transaction Count', fontsize=14, fontweight='bold')
ax[1].set_xlabel('Date', fontsize=12)
ax[1].set_ylabel('Transactions', fontsize=12)
ax[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
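
Before training a forecast on the daily series, it can help to confirm that no calendar days are missing from the aggregate. The following is a minimal sketch on a small made-up frame rather than the Gold-layer data; the helper name `find_missing_dates` and the sample values are assumptions for illustration only.

```python
import pandas as pd

def find_missing_dates(df, date_col="SALES_DATE"):
    """Return the calendar days between the first and last date that have no row."""
    dates = pd.to_datetime(df[date_col])
    full_range = pd.date_range(dates.min(), dates.max(), freq="D")
    return full_range.difference(pd.DatetimeIndex(dates))

# Hypothetical daily aggregate with one gap (2024-01-03 is absent)
demo_daily = pd.DataFrame({
    "SALES_DATE": ["2024-01-01", "2024-01-02", "2024-01-04", "2024-01-05"],
    "TOTAL_REVENUE": [120.0, 95.5, 130.25, 110.0],
})

missing_days = [d.strftime("%Y-%m-%d") for d in find_missing_dates(demo_daily)]
print(f"Missing days: {missing_days}")
```

In the notebook, the same helper could be applied to `daily_sales` after the aggregation query has run.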

### 3.2 Create Forecasting View and Generate Predictions

In [None]:
CREATE OR REPLACE VIEW S3_GOLD.DAILY_SALES_TS AS
SELECT 
    DATE(PURCHASE_TIMESTAMP) as sales_date,
    SUM(FINAL_TOTAL) as total_revenue
FROM S3_GOLD.FCT_SALES
GROUP BY sales_date
ORDER BY sales_date

In [None]:
CREATE OR REPLACE SNOWFLAKE.ML.FORECAST daily_sales(
            INPUT_DATA => SYSTEM$REFERENCE('VIEW', 'S3_GOLD.DAILY_SALES_TS'),
            TIMESTAMP_COLNAME => 'SALES_DATE',
            TARGET_COLNAME => 'TOTAL_REVENUE'
        )

In [None]:
# Number of days to forecast (the 30-day horizon described in this section)
FORECAST_DAYS = 30

forecast_results = session.sql(f"""
SELECT
    ts::DATE AS forecast_date,
    forecast AS predicted_revenue,
    lower_bound,
    upper_bound
FROM TABLE(daily_sales!FORECAST(FORECASTING_PERIODS => {FORECAST_DAYS}))
ORDER BY forecast_date
""").to_pandas()

# Calculate forecast statistics
avg_forecast = forecast_results['PREDICTED_REVENUE'].mean()
total_forecast = forecast_results['PREDICTED_REVENUE'].sum()
print("Forecast Summary:")
print(f"   Average Daily Revenue (Next {FORECAST_DAYS} Days): ${avg_forecast:,.2f}")
print(f"   Total Projected Revenue ({FORECAST_DAYS} Days): ${total_forecast:,.2f}")
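
One common way to sanity-check a forecast is to hold out the most recent days, forecast them, and compare the result with the actuals. The sketch below computes mean absolute percentage error (MAPE) on invented numbers; the function name `mape` and the sample values are illustrative assumptions, and nothing here calls Snowflake.

```python
import pandas as pd

def mape(actual, predicted):
    """Mean absolute percentage error, skipping rows where the actual value is zero."""
    actual = pd.Series(actual, dtype="float64").reset_index(drop=True)
    predicted = pd.Series(predicted, dtype="float64").reset_index(drop=True)
    nonzero = actual != 0
    pct_error = (actual[nonzero] - predicted[nonzero]).abs() / actual[nonzero].abs()
    return float(pct_error.mean() * 100)

# Hypothetical held-out revenue and the forecast for the same three days
holdout_actual = [100.0, 200.0, 400.0]
holdout_forecast = [110.0, 180.0, 400.0]

error_pct = mape(holdout_actual, holdout_forecast)
print(f"Holdout MAPE: {error_pct:.2f}%")
```

A lower value means the forecast tracked the held-out days more closely; what counts as acceptable depends on how the forecast will be used.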

### 3.3 Visualize Forecast Results

In [None]:
# Combine historical and forecast data for visualization
# Filter to show only the last year of historical data for better readability
latest_date = daily_sales['SALES_DATE'].max()
one_year_ago = latest_date - timedelta(days=365)
daily_sales_last_year = daily_sales[daily_sales['SALES_DATE'] >= one_year_ago]

print(f"📅 Showing last year of data: {one_year_ago.strftime('%Y-%m-%d')} to {latest_date.strftime('%Y-%m-%d')}")
print(f"   Total historical days: {len(daily_sales)} | Displaying: {len(daily_sales_last_year)} days")

fig, ax = plt.subplots(figsize=(16, 8))

# Plot historical data (last year only)
ax.plot(daily_sales_last_year['SALES_DATE'], daily_sales_last_year['TOTAL_REVENUE'], 
        linewidth=2.5, label='Historical Revenue (Last Year)', color='#1f77b4', marker='o', markersize=3)

# Plot forecast
ax.plot(forecast_results['FORECAST_DATE'], forecast_results['PREDICTED_REVENUE'], 
        linewidth=2.5, label=f'Forecasted Revenue (Next {len(forecast_results)} Days)', color='#ff7f0e', linestyle='--', marker='s', markersize=4)

# Plot prediction interval (LOWER_BOUND to UPPER_BOUND from the forecast output)
ax.fill_between(forecast_results['FORECAST_DATE'], 
                 forecast_results['LOWER_BOUND'], 
                 forecast_results['UPPER_BOUND'], 
                 alpha=0.3, color='#ff7f0e', label='95% Prediction Interval')

# Add vertical line to separate historical from forecast
split_date = daily_sales['SALES_DATE'].max()
ax.axvline(x=split_date, color='red', linestyle=':', linewidth=2, label='Forecast Start')

ax.set_title(f'📈 Revenue Forecast: Last Year Historical vs. Next {len(forecast_results)} Days Predicted', fontsize=16, fontweight='bold', pad=20)
ax.set_xlabel('Date', fontsize=13, fontweight='bold')
ax.set_ylabel('Revenue ($)', fontsize=13, fontweight='bold')
ax.legend(loc='best', fontsize=11, framealpha=0.9)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 4. Customer Segment Classification with ML.CLASSIFICATION

We'll train a classification model to predict customer segments based on behavioral features.

### 4.1 Prepare Training Data

In [None]:
# Explore customer segment distribution
segment_dist = query_to_df("""
    SELECT 
        CUSTOMER_SEGMENT,
        COUNT(*) as customer_count,
        AVG(AGE) as avg_age,
        AVG(ANNUAL_INCOME) as avg_income,
        AVG(TOTAL_ORDERS) as avg_orders,
        AVG(AVG_ORDER_VALUE) as avg_order_value,
        AVG(CHURN_PROBABILITY) as avg_churn_prob
    FROM S3_GOLD.DIM_CUSTOMERS
    GROUP BY CUSTOMER_SEGMENT
    ORDER BY customer_count DESC
""")

print("👥 Customer Segment Distribution:")
display(segment_dist)

# Visualize segment distribution
fig, ax = plt.subplots(1, 2, figsize=(14, 5))

# Segment counts
ax[0].bar(segment_dist['CUSTOMER_SEGMENT'], segment_dist['CUSTOMER_COUNT'], color=['#2ecc71', '#3498db', '#e74c3c'])
ax[0].set_title('Customer Count by Segment', fontsize=13, fontweight='bold')
ax[0].set_ylabel('Count', fontsize=11)
ax[0].grid(axis='y', alpha=0.3)

# Average order value by segment
ax[1].bar(segment_dist['CUSTOMER_SEGMENT'], segment_dist['AVG_ORDER_VALUE'], color=['#2ecc71', '#3498db', '#e74c3c'])
ax[1].set_title('Average Order Value by Segment', fontsize=13, fontweight='bold')
ax[1].set_ylabel('Avg Order Value ($)', fontsize=11)
ax[1].grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Create training view with features and target
session.sql("""
    CREATE OR REPLACE VIEW S3_GOLD.CUSTOMER_CLASSIFICATION_TRAINING AS
    SELECT 
        CUSTOMER_ID,
        CUSTOMER_SEGMENT as target,  -- Target variable
        AGE,
        ANNUAL_INCOME,
        TOTAL_ORDERS,
        TOTAL_SPENT,
        AVG_ORDER_VALUE,
        CUSTOMER_TENURE_DAYS,
        CHURN_PROBABILITY,
        CASE WHEN PREFERRED_SALES_CHANNEL = 'web' THEN 1 ELSE 0 END as is_web_customer,
        CASE WHEN PREFERRED_SALES_CHANNEL = 'mobile' THEN 1 ELSE 0 END as is_mobile_customer,
        CASE WHEN ENGAGEMENT_LEVEL = 'Highly Engaged' THEN 1 ELSE 0 END as is_highly_engaged
    FROM S3_GOLD.DIM_CUSTOMERS
    WHERE CUSTOMER_SEGMENT IS NOT NULL
        AND AGE IS NOT NULL
        AND TOTAL_ORDERS > 0  -- Only customers with purchase history
""").collect()

print("✅ Training view created: S3_GOLD.CUSTOMER_CLASSIFICATION_TRAINING")

# Preview training data
training_preview = query_to_df("""
    SELECT * FROM S3_GOLD.CUSTOMER_CLASSIFICATION_TRAINING LIMIT 10
""")
display(training_preview)
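
The `CASE WHEN` expressions in the training view turn two categorical columns into 0/1 flags. A rough pandas equivalent on a tiny invented frame shows the resulting encoding. The `store` and `Low` values are assumptions for illustration, not values confirmed to exist in the Gold layer.

```python
import pandas as pd

# Hypothetical customers; only the column names come from the training view
demo_customers = pd.DataFrame({
    "CUSTOMER_ID": [1, 2, 3],
    "PREFERRED_SALES_CHANNEL": ["web", "mobile", "store"],
    "ENGAGEMENT_LEVEL": ["Highly Engaged", "Low", "Highly Engaged"],
})

# Same logic as the CASE WHEN ... THEN 1 ELSE 0 END expressions
demo_features = demo_customers.assign(
    IS_WEB_CUSTOMER=(demo_customers["PREFERRED_SALES_CHANNEL"] == "web").astype(int),
    IS_MOBILE_CUSTOMER=(demo_customers["PREFERRED_SALES_CHANNEL"] == "mobile").astype(int),
    IS_HIGHLY_ENGAGED=(demo_customers["ENGAGEMENT_LEVEL"] == "Highly Engaged").astype(int),
)

print(demo_features[["CUSTOMER_ID", "IS_WEB_CUSTOMER", "IS_MOBILE_CUSTOMER", "IS_HIGHLY_ENGAGED"]])
```

Note that a customer whose preferred channel is neither `web` nor `mobile` gets 0 for both channel flags, so any other channel acts as the implicit baseline.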

### 4.2 Train Classification Model

In [None]:
# Train a classification model using Snowflake ML
# Note: This creates a model object that can be used for predictions

print("🤖 Training customer segment classification model...")
print("   This may take a few moments...\n")

session.sql("""
    CREATE OR REPLACE SNOWFLAKE.ML.CLASSIFICATION CUSTOMER_SEGMENT_CLASSIFIER(
        INPUT_DATA => SYSTEM$REFERENCE('VIEW', 'S3_GOLD.CUSTOMER_CLASSIFICATION_TRAINING'),
        TARGET_COLNAME => 'TARGET',
        CONFIG_OBJECT => {'evaluate': TRUE, 'on_error': 'skip'}
    )
""").collect()

print("✅ Model trained successfully: CUSTOMER_SEGMENT_CLASSIFIER")

### 4.3 Evaluate Model Performance

In [None]:
# Get model evaluation metrics
model_metrics = query_to_df("""
    CALL CUSTOMER_SEGMENT_CLASSIFIER!SHOW_EVALUATION_METRICS()
""")

print("📊 Model Evaluation Metrics:")
display(model_metrics)
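
Classifier evaluations are commonly reported as per-class precision, recall, and F1. As a reference for reading such numbers, here is a small sketch that computes them by hand from invented labels. The helper name `per_class_metrics` and the `regular` and `budget` segment names are assumptions; this does not reproduce how Snowflake computes its own metrics.

```python
import pandas as pd

def per_class_metrics(actual, predicted):
    """Compute precision, recall, and F1 for each class label."""
    actual = pd.Series(actual)
    predicted = pd.Series(predicted)
    rows = []
    for label in sorted(set(actual) | set(predicted)):
        tp = int(((actual == label) & (predicted == label)).sum())  # correctly predicted as label
        fp = int(((actual != label) & (predicted == label)).sum())  # wrongly predicted as label
        fn = int(((actual == label) & (predicted != label)).sum())  # label missed by the model
        precision = tp / (tp + fp) if (tp + fp) else 0.0
        recall = tp / (tp + fn) if (tp + fn) else 0.0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
        rows.append({"class": label, "precision": precision, "recall": recall, "f1": f1})
    return pd.DataFrame(rows).set_index("class")

# Hypothetical actual vs. predicted segments for six customers
demo_actual = ["premium", "premium", "regular", "regular", "budget", "budget"]
demo_predicted = ["premium", "regular", "regular", "regular", "budget", "premium"]

demo_metrics = per_class_metrics(demo_actual, demo_predicted)
print(demo_metrics.round(3))
```

Precision answers "of the customers predicted as this segment, how many really were", while recall answers "of the customers in this segment, how many did the model find".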

### 4.4 Generate Predictions

In [None]:
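# Create a test dataset: a random sample of existing customers with the segment column left out.
# These customers also appear in the training view, so this only simulates scoring new customers.
# Because the view uses ORDER BY RANDOM(), each query against it returns a different sample.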
session.sql("""
    CREATE OR REPLACE VIEW S3_GOLD.CUSTOMER_CLASSIFICATION_TEST AS
    SELECT 
        CUSTOMER_ID,
        AGE,
        ANNUAL_INCOME,
        TOTAL_ORDERS,
        TOTAL_SPENT,
        AVG_ORDER_VALUE,
        CUSTOMER_TENURE_DAYS,
        CHURN_PROBABILITY,
        CASE WHEN PREFERRED_SALES_CHANNEL = 'web' THEN 1 ELSE 0 END as is_web_customer,
        CASE WHEN PREFERRED_SALES_CHANNEL = 'mobile' THEN 1 ELSE 0 END as is_mobile_customer,
        CASE WHEN ENGAGEMENT_LEVEL = 'Highly Engaged' THEN 1 ELSE 0 END as is_highly_engaged
    FROM S3_GOLD.DIM_CUSTOMERS
    WHERE TOTAL_ORDERS > 0
    ORDER BY RANDOM()
    LIMIT 100  -- Sample 100 customers for prediction
""").collect()

print("✅ Test dataset created")

In [None]:
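# Make predictions using the trained model
# PREDICT is called once per row and returns an object holding the predicted
# class and the per-class probabilities
predictions = query_to_df("""
    WITH scored AS (
        SELECT
            CUSTOMER_ID,
            AGE,
            ANNUAL_INCOME,
            TOTAL_ORDERS,
            AVG_ORDER_VALUE,
            CUSTOMER_SEGMENT_CLASSIFIER!PREDICT(INPUT_DATA => OBJECT_CONSTRUCT(*)) AS prediction
        FROM S3_GOLD.CUSTOMER_CLASSIFICATION_TEST
    )
    SELECT
        s.CUSTOMER_ID,
        s.AGE,
        s.ANNUAL_INCOME,
        s.TOTAL_ORDERS,
        s.AVG_ORDER_VALUE,
        c.CUSTOMER_SEGMENT AS actual_segment,
        s.prediction:class::STRING AS predicted_segment,
        ROUND(GET(s.prediction:probability, s.prediction:class::STRING)::FLOAT, 3) AS confidence
    FROM scored s
    INNER JOIN S3_GOLD.DIM_CUSTOMERS c ON s.CUSTOMER_ID = c.CUSTOMER_ID
    LIMIT 20
""")

print("🎯 Sample Predictions (Actual vs Predicted):")
display(predictions)

# Calculate accuracy on the sample
# (these customers were part of the training data, so this figure is likely optimistic)
accuracy = (predictions['ACTUAL_SEGMENT'] == predictions['PREDICTED_SEGMENT']).mean() * 100
print(f"\n✨ Model Accuracy on Sample: {accuracy:.1f}%")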
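
A single accuracy figure hides which segments get confused with each other. A confusion matrix makes that visible. The sketch below builds one with `pd.crosstab` on invented labels; the frame `demo_predictions` and the `regular` and `budget` segment names are assumptions for illustration.

```python
import pandas as pd

# Hypothetical prediction output with the same two columns the notebook compares
demo_predictions = pd.DataFrame({
    "ACTUAL_SEGMENT": ["premium", "premium", "regular", "regular", "budget", "budget"],
    "PREDICTED_SEGMENT": ["premium", "regular", "regular", "regular", "budget", "premium"],
})

# Rows are actual segments, columns are predicted segments
demo_confusion = pd.crosstab(
    demo_predictions["ACTUAL_SEGMENT"],
    demo_predictions["PREDICTED_SEGMENT"],
    rownames=["actual"],
    colnames=["predicted"],
)

demo_accuracy = (demo_predictions["ACTUAL_SEGMENT"] == demo_predictions["PREDICTED_SEGMENT"]).mean()
print(demo_confusion)
print(f"Accuracy: {demo_accuracy:.1%}")
```

Off-diagonal cells show the mistakes: in this made-up sample, one `premium` customer is predicted as `regular` and one `budget` customer as `premium`.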

## 5. Business Insights & Applications

### 5.1 Identify High-Value Customers at Risk of Churning

In [None]:
# Find premium customers with high churn risk
at_risk_customers = query_to_df("""
    SELECT 
        CUSTOMER_ID,
        CUSTOMER_SEGMENT,
        AGE,
        TOTAL_SPENT,
        TOTAL_ORDERS,
        AVG_ORDER_VALUE,
        CHURN_PROBABILITY,
        ENGAGEMENT_LEVEL,
        CUSTOMER_TENURE_DAYS
    FROM S3_GOLD.DIM_CUSTOMERS
    WHERE CUSTOMER_SEGMENT = 'premium'
        AND CHURN_PROBABILITY > 0.5
        AND TOTAL_SPENT > 1000
    ORDER BY TOTAL_SPENT DESC
    LIMIT 20
""")

print("⚠️ High-Value Customers at Risk (Premium + High Churn Probability):")
display(at_risk_customers)

total_at_risk_value = at_risk_customers['TOTAL_SPENT'].sum()
print(f"\n💰 Total Lifetime Value at Risk: ${total_at_risk_value:,.2f}")

### 5.2 Product Category Performance by Segment

In [None]:
# Analyze product preferences by customer segment
category_performance = query_to_df("""
    SELECT 
        CUSTOMER_SEGMENT,
        PRODUCT_CATEGORY,
        COUNT(*) as purchase_count,
        SUM(FINAL_TOTAL) as total_revenue,
        AVG(FINAL_TOTAL) as avg_purchase_value
    FROM S3_GOLD.FCT_SALES
    GROUP BY CUSTOMER_SEGMENT, PRODUCT_CATEGORY
    ORDER BY CUSTOMER_SEGMENT, total_revenue DESC
""")

# Pivot for better visualization
pivot_revenue = category_performance.pivot_table(
    values='TOTAL_REVENUE', 
    index='PRODUCT_CATEGORY', 
    columns='CUSTOMER_SEGMENT',
    aggfunc='sum',  # totals, not the default mean
    fill_value=0
)

print("🛍️ Revenue by Product Category & Customer Segment:")
display(pivot_revenue)

# Visualize
pivot_revenue.plot(kind='bar', figsize=(14, 6), width=0.8)
plt.title('Revenue by Product Category and Customer Segment', fontsize=14, fontweight='bold', pad=15)
plt.xlabel('Product Category', fontsize=12)
plt.ylabel('Total Revenue ($)', fontsize=12)
plt.legend(title='Customer Segment', fontsize=10)
plt.xticks(rotation=45, ha='right')
plt.grid(axis='y', alpha=0.3)
plt.tight_layout()
plt.show()

## 6. Summary & Next Steps

In [None]:
print("="*80)
print("🎉 DEMO SUMMARY")
print("="*80)
print("\n✅ Completed Tasks:")
print("   1. Connected to Snowflake Gold Layer")
print("   2. Time Series Forecasting with ML.FORECAST")
print("      - Forecasted 30 days of revenue")
print("      - Generated confidence intervals")
print("   3. Customer Classification with ML.CLASSIFICATION")
print("      - Trained segment prediction model")
print("      - Evaluated model performance")
print("      - Generated predictions for new customers")
print("   4. Business Insights Analysis")
print("      - Identified at-risk high-value customers")
print("      - Analyzed product preferences by segment")
print("\n🚀 Next Steps:")
print("   • Deploy models to production for real-time scoring")
print("   • Create automated alerts for high-churn customers")
print("   • Build personalized marketing campaigns by segment")
print("   • Integrate forecasts into inventory planning")
print("   • Set up automated model retraining on new data")
print("\n💡 Snowflake ML Benefits Demonstrated:")
print("   ✓ No data movement required")
print("   ✓ Native SQL integration")
print("   ✓ Automatic feature engineering")
print("   ✓ Scalable to billions of rows")
print("   ✓ Secure and governed")
print("="*80)

## Cleanup (Optional)

In [None]:
# Session is managed by Snowflake Notebooks
# No need to close the session
print("✅ Notebook complete! Session remains active for your Snowflake Notebook.")