# Churn Prediction with LSTM + Attention
## Using Real Snowflake Data with 12-Month Lookback

This notebook trains a churn prediction model using:
- **Data Source**: Snowflake tables (PHONE_USAGE_DATA, ACCOUNT_ATTRIBUTES_MONTHLY, CHURN_RECORDS)
- **Model**: LSTM with Attention mechanism
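- **Features**: Up to 12 months of monthly usage history per account (10 normalized features per month)
- **Target**: Binary churn label (1 if the account appears in CHURN_RECORDS, otherwise 0)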

---

## 1. Setup and Imports

In [None]:
# Core libraries
import warnings
import numpy as np
import pandas as pd
from datetime import datetime
warnings.filterwarnings('ignore')

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

# Sklearn
from sklearn.metrics import (
    auc, roc_curve, precision_score, recall_score,
    f1_score, confusion_matrix, precision_recall_curve,
    classification_report
)
from sklearn.preprocessing import StandardScaler, LabelEncoder

# Snowflake
from snowflake.snowpark.context import get_active_session
from snowflake.snowpark.functions import col, lit, count, sum as spark_sum
from snowflake.snowpark.types import StructType, StructField, StringType, FloatType, IntegerType

# Progress bar
from tqdm import tqdm

# Report the environment; the check marks are plain UTF-8 glyphs.
print("✓ Libraries imported successfully")
print(f"✓ PyTorch version: {torch.__version__}")

# Query CUDA once so the printed status and the device choice cannot disagree.
cuda_available = torch.cuda.is_available()
print(f"✓ CUDA available: {cuda_available}")

# The model and every batch are moved to this device in later cells.
device = torch.device('cuda' if cuda_available else 'cpu')
print(f"✓ Using device: {device}")

## 2. Connect to Snowflake

In [None]:
# Get active Snowflake session
session = get_active_session()

# Set database and schema to MY_DATABASE.PUBLIC
session.use_database("MY_DATABASE")
session.use_schema("PUBLIC")

print("✓ Snowflake session active")
print(f"  Database: {session.get_current_database()}")
print(f"  Schema: {session.get_current_schema()}")
print(f"  Warehouse: {session.get_current_warehouse()}")
print(f"  Role: {session.get_current_role()}")

# Verify tables exist by counting the rows of each source table
print("\n✓ Verifying tables exist...")
tables_to_check = ["PHONE_USAGE_DATA", "ACCOUNT_ATTRIBUTES_MONTHLY", "CHURN_RECORDS"]
for table in tables_to_check:
    # Named row_count (not count) so the `count` function imported from
    # snowflake.snowpark.functions is not overwritten by an integer.
    row_count = session.table(f"MY_DATABASE.PUBLIC.{table}").count()
    print(f"  {table}: {row_count:,} rows")

## 3. Load Data from Snowflake Tables

In [None]:
# Load data from Snowflake tables in MY_DATABASE.PUBLIC.
# Every table is fetched in full (no filter or projection is applied).
print("Loading data from MY_DATABASE.PUBLIC...")

# 1. Usage data; MONTH is parsed so rows can be ordered chronologically later
usage_df = session.table("MY_DATABASE.PUBLIC.PHONE_USAGE_DATA").to_pandas()
usage_df['MONTH'] = pd.to_datetime(usage_df['MONTH'])
print(f"✓ PHONE_USAGE_DATA: {len(usage_df):,} rows")

# 2. Account attributes
account_df = session.table("MY_DATABASE.PUBLIC.ACCOUNT_ATTRIBUTES_MONTHLY").to_pandas()
account_df['MONTH'] = pd.to_datetime(account_df['MONTH'])
print(f"✓ ACCOUNT_ATTRIBUTES_MONTHLY: {len(account_df):,} rows")

# 3. Churn records
churn_df = session.table("MY_DATABASE.PUBLIC.CHURN_RECORDS").to_pandas()
churn_df['CHURN_DATE'] = pd.to_datetime(churn_df['CHURN_DATE'])
print(f"✓ CHURN_RECORDS: {len(churn_df):,} rows")

print(f"\nData date range: {usage_df['MONTH'].min()} to {usage_df['MONTH'].max()}")
print(f"Unique accounts: {usage_df['USERID'].nunique():,}")

In [None]:
# Show the column dtypes of the usage table, then its first rows.
print("\n=== PHONE_USAGE_DATA Schema ===", usage_df.dtypes, sep="\n")
print("\n=== Sample Usage Data ===", usage_df.head(), sep="\n")

## 4. Feature Engineering - Create 12-Month Sequences

In [None]:
# (column, divisor) pairs that define the per-month feature vector, in order.
# Each raw value is divided by its divisor and capped at 1.0.
_USAGE_FEATURE_SCALES = (
    # Call volume features
    ('PHONE_TOTAL_CALLS', 1000),
    ('PHONE_TOTAL_MINUTES_OF_USE', 10000),
    ('VOICE_CALLS', 1000),
    ('FAX_CALLS', 100),
    # Call direction features
    ('PHONE_TOTAL_NUM_INBOUND_CALLS', 500),
    ('PHONE_TOTAL_NUM_OUTBOUND_CALLS', 500),
    # Device usage features
    ('HARDPHONE_CALLS', 500),
    ('SOFTPHONE_CALLS', 500),
    ('MOBILE_CALLS', 500),
    # Engagement feature
    ('PHONE_MAU', 100),
)

# Display names, index-aligned with _USAGE_FEATURE_SCALES.
_USAGE_FEATURE_NAMES = (
    'Total Calls (norm)', 'Total Minutes (norm)', 'Voice Calls (norm)', 'Fax Calls (norm)',
    'Inbound Calls (norm)', 'Outbound Calls (norm)',
    'Hardphone (norm)', 'Softphone (norm)', 'Mobile (norm)',
    'MAU (norm)',
)

_SEQUENCE_COLUMNS = ['account_id', 'sequence', 'churn_label', 'seq_length']


def create_churn_sequences(usage_df, account_df, churn_df, max_lookback=12, min_history=3):
    """
    Create time-series sequences for churn prediction.

    For each account (USERID) in ``usage_df``:
    - Sort its rows by MONTH and keep the last ``max_lookback`` of them
    - Turn every kept row into a feature vector (value / divisor, capped at 1.0)
    - Label the account 1 if its USERID appears in ``churn_df``, else 0

    Accounts with fewer than ``min_history`` usage rows are skipped. Missing
    usage values propagate as NaN into the feature vectors.

    NOTE(review): the label only checks membership in ``churn_df``; CHURN_DATE
    is never consulted, so the usage window is not cut off at the churn date.
    Confirm this is intended, since it may leak post-churn behaviour.

    Args:
        usage_df: Monthly usage rows with USERID, MONTH and the usage columns
            listed in ``_USAGE_FEATURE_SCALES``.
        account_df: Accepted for interface compatibility; not read here.
        churn_df: Churn records; only the USERID column is used.
        max_lookback: Maximum number of most-recent months kept per account.
        min_history: Minimum number of usage rows an account needs.

    Returns:
        DataFrame with columns: account_id, sequence (list of feature vectors),
        churn_label, seq_length. Rows follow the order in which accounts first
        appear in ``usage_df``. The frame is empty (same columns) when no
        account qualifies.
    """
    print("\n" + "="*70)
    print(f"Creating {max_lookback}-Month Sequences for Churn Prediction")
    print("="*70)

    # Get set of churned accounts (a set gives O(1) membership tests)
    churned_accounts = set(churn_df['USERID'].unique())
    print(f"\nChurned accounts: {len(churned_accounts):,}")

    print(f"Total accounts: {len(usage_df['USERID'].unique()):,}")

    feature_columns = [column for column, _ in _USAGE_FEATURE_SCALES]
    divisors = np.array([divisor for _, divisor in _USAGE_FEATURE_SCALES], dtype=float)

    sequences = []

    # One grouping pass instead of a full-table boolean filter per account.
    # sort=False keeps accounts in order of first appearance.
    grouped = usage_df.groupby('USERID', sort=False)
    for account_id, account_usage in tqdm(grouped, total=grouped.ngroups,
                                          desc="Processing accounts"):
        if len(account_usage) < min_history:
            continue

        # Take last max_lookback months (or all available if fewer)
        recent = account_usage.sort_values('MONTH').tail(max_lookback)

        # astype(float) also handles object-typed numeric columns.
        raw = recent[feature_columns].astype(float).to_numpy()
        monthly_features = np.minimum(raw / divisors, 1.0).tolist()

        sequences.append({
            'account_id': account_id,
            'sequence': monthly_features,
            'churn_label': 1 if account_id in churned_accounts else 0,
            'seq_length': len(monthly_features)
        })

    # Passing the columns explicitly keeps the schema stable when empty.
    result_df = pd.DataFrame(sequences, columns=_SEQUENCE_COLUMNS)

    if result_df.empty:
        print(f"\nNo sequences created: no account has at least {min_history} months of usage")
        return result_df

    print(f"\n✓ Created {len(result_df):,} sequences")
    print(f"  Churn rate: {result_df['churn_label'].mean():.2%}")
    print(f"  Avg sequence length: {result_df['seq_length'].mean():.1f} months")
    print(f"  Min/Max length: {result_df['seq_length'].min()}/{result_df['seq_length'].max()} months")

    # Show feature layout
    print(f"\n  Feature vector size: {len(_USAGE_FEATURE_SCALES)}")
    print(f"  Features:")
    for i, name in enumerate(_USAGE_FEATURE_NAMES):
        print(f"    [{i}] {name}")

    return result_df

# Build one sequence record per account from the loaded usage and churn tables,
# keeping at most the 12 most recent months of usage for each account.
sequence_df = create_churn_sequences(usage_df, account_df, churn_df, max_lookback=12)

In [None]:
# Preview the per-account summary columns for the first ten sequences.
print("\n=== Sample Sequences ===")
print(sequence_df[['account_id', 'seq_length', 'churn_label']].head(10))

# Inspect the first record in detail.
print("\n=== Example Sequence (first account) ===")
first_record = sequence_df.iloc[0]
example_seq = first_record['sequence']
print(f"Account ID: {first_record['account_id']}")
print(f"Sequence length: {len(example_seq)} months")
print(f"Churn label: {first_record['churn_label']}")
print(f"\nFirst 3 months of features:")
for month_no, month_features in enumerate(example_seq[:3], start=1):
    formatted = [f'{value:.3f}' for value in month_features]
    print(f"  Month {month_no}: {formatted}")

## 5. Split Data into Train/Val/Test

In [None]:
def split_data(df, train_ratio=0.7, val_ratio=0.15, random_state=42):
    """Shuffle ``df`` and split it into train, validation, and test sets.

    The split is positional after a seeded shuffle: the first
    ``int(train_ratio * n)`` rows become train, the next ``int(val_ratio * n)``
    rows become validation, and the remainder becomes test. It is not
    stratified, so small or imbalanced inputs can yield splits that lack one
    of the classes.

    Args:
        df: DataFrame with a ``churn_label`` column.
        train_ratio: Fraction of rows assigned to the training set.
        val_ratio: Fraction of rows assigned to the validation set.
        random_state: Seed for the shuffle, for reproducible splits.

    Returns:
        Tuple ``(train_df, val_df, test_df)``, each with a fresh 0..k-1 index.

    Raises:
        ValueError: If a ratio is outside [0, 1] or the two ratios sum to
            more than 1.
    """
    # Small tolerance so sums such as 0.85 + 0.15 are not rejected by rounding.
    if train_ratio < 0 or val_ratio < 0 or train_ratio + val_ratio > 1 + 1e-9:
        raise ValueError(
            f"train_ratio and val_ratio must be non-negative and sum to at most 1 "
            f"(got {train_ratio} and {val_ratio})"
        )

    # Shuffle data
    shuffled = df.sample(frac=1, random_state=random_state).reset_index(drop=True)

    n = len(shuffled)
    train_end = int(train_ratio * n)
    val_end = train_end + int(val_ratio * n)

    train_df = shuffled.iloc[:train_end].reset_index(drop=True)
    val_df = shuffled.iloc[train_end:val_end].reset_index(drop=True)
    test_df = shuffled.iloc[val_end:].reset_index(drop=True)

    def share(part):
        # Avoid ZeroDivisionError when the input frame is empty.
        return len(part) / n if n else 0.0

    print(f"\nData split:")
    print(f"  Train: {len(train_df):,} samples ({share(train_df):.1%})")
    print(f"  Val:   {len(val_df):,} samples ({share(val_df):.1%})")
    print(f"  Test:  {len(test_df):,} samples ({share(test_df):.1%})")
    print(f"\nChurn rates:")
    print(f"  Train: {train_df['churn_label'].mean():.2%}")
    print(f"  Val:   {val_df['churn_label'].mean():.2%}")
    print(f"  Test:  {test_df['churn_label'].mean():.2%}")

    return train_df, val_df, test_df

# 70% train / 15% validation / remaining rows test, shuffled with split_data's default seed.
train_df, val_df, test_df = split_data(sequence_df, train_ratio=0.7, val_ratio=0.15)

## 6. Define PyTorch Dataset

In [None]:
class ChurnDataset(Dataset):
    """Wrap a sequence DataFrame as fixed-length (sequence, label) tensor pairs.

    Each row of ``df`` must provide a ``sequence`` entry (list of per-timestep
    feature vectors) and a ``churn_label`` entry. Every item is returned with
    exactly ``max_lookback_window`` timesteps.
    """

    def __init__(self, df, max_lookback_window=12):
        self.df = df
        self.max_lookback_window = max_lookback_window

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        features = torch.tensor(row['sequence'], dtype=torch.float32)
        target = torch.tensor(row['churn_label'], dtype=torch.float32)

        window = self.max_lookback_window
        missing = window - features.shape[0]
        if missing > 0:
            # Short history: left-pad with zero rows so the real months stay
            # at the end of the window.
            filler = torch.zeros((missing, features.shape[1]))
            features = torch.cat([filler, features], dim=0)
        elif missing < 0:
            # Long history: keep only the most recent `window` timesteps.
            features = features[-window:, :]

        return features, target

# Configuration
MAX_LOOKBACK_WINDOW = 12
BATCH_SIZE = 32
N_FEATURES = 10  # Number of features per timestep

# Create datasets
train_dataset = ChurnDataset(train_df, MAX_LOOKBACK_WINDOW)
val_dataset = ChurnDataset(val_df, MAX_LOOKBACK_WINDOW)
test_dataset = ChurnDataset(test_df, MAX_LOOKBACK_WINDOW)

# Create dataloaders; only the training loader reshuffles each epoch so that
# validation/test batches keep the row order of their DataFrames.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

print("✓ DataLoaders created")
print(f"  Train batches: {len(train_loader)}")
print(f"  Val batches: {len(val_loader)}")
print(f"  Test batches: {len(test_loader)}")

# Test data loading by pulling a single training batch
sample_batch = next(iter(train_loader))
print(f"\nSample batch shapes:")
print(f"  Sequences: {sample_batch[0].shape}")  # [batch_size, seq_len, n_features]
print(f"  Labels: {sample_batch[1].shape}")     # [batch_size]

## 7. Define LSTM with Attention Model

In [None]:
class LSTMWithAttention(nn.Module):
    """Sequence classifier: stacked LSTM, attention pooling over time, MLP head.

    The head ends in a sigmoid, so ``forward`` returns values in [0, 1].
    """

    def __init__(self, input_size, hidden_size, output_size, num_layers=2, dropout=0.2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # Recurrent encoder; dropout is only passed when there is more than
        # one layer to apply it between.
        inter_layer_dropout = dropout if num_layers > 1 else 0
        self.lstm = nn.LSTM(
            input_size,
            hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=inter_layer_dropout,
        )

        # Scores each timestep's hidden state with a single scalar.
        self.attention = nn.Linear(hidden_size, 1)

        # Classification head.
        half_width = hidden_size // 2
        self.fc1 = nn.Linear(hidden_size, half_width)
        self.fc2 = nn.Linear(half_width, output_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, return_attention=False):
        """Run the model on ``x`` ([batch, seq_len, input_size]).

        Returns the sigmoid output ([batch, output_size]); when
        ``return_attention`` is true, returns ``(output, attention_weights)``
        with weights of shape [batch, seq_len, 1].
        """
        # lstm_out: [batch, seq_len, hidden_size]; final states are not needed.
        lstm_out, _ = self.lstm(x)

        # Normalise the per-timestep scores across the time axis.
        scores = self.attention(lstm_out)
        attention_weights = scores.softmax(dim=1)

        # Weighted sum over time -> [batch, hidden_size]
        context_vector = (attention_weights * lstm_out).sum(dim=1)

        hidden = self.dropout(self.relu(self.fc1(context_vector)))
        out = self.sigmoid(self.fc2(hidden))

        return (out, attention_weights) if return_attention else out

# Initialize model
HIDDEN_SIZE = 64
NUM_LAYERS = 2
DROPOUT = 0.3

# One output unit: the model emits a single value per sequence.
model = LSTMWithAttention(
    input_size=N_FEATURES,
    hidden_size=HIDDEN_SIZE,
    output_size=1,
    num_layers=NUM_LAYERS,
    dropout=DROPOUT
).to(device)

print("✓ Model initialized")
print(f"\nModel architecture:")
print(model)
print(f"\nTotal parameters: {sum(p.numel() for p in model.parameters()):,}")

## 8. Define Training and Evaluation Functions

In [None]:
def train_epoch(model, device, train_loader, criterion, optimizer):
    """Run one optimisation pass over ``train_loader``.

    Returns the loss averaged per sample: each batch loss is weighted by its
    batch size and the total is divided by the dataset length.
    """
    model.train()
    loss_sum = 0.0

    for batch_x, batch_y in tqdm(train_loader, desc="Training", leave=False):
        batch_x = batch_x.to(device)
        # Column-vector targets to line up with the model's [batch, 1] output.
        targets = batch_y.to(device).reshape(-1, 1)

        optimizer.zero_grad()
        batch_loss = criterion(model(batch_x), targets)
        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item() * targets.size(0)

    return loss_sum / len(train_loader.dataset)


def evaluate_model(model, device, data_loader, criterion, threshold=0.5):
    """Score ``model`` on ``data_loader`` without updating its weights.

    Returns ``(avg_loss, precision, recall, f1, y_probs, y_labels)`` where the
    loss is averaged per sample, the three metrics use predictions obtained by
    comparing model outputs to ``threshold``, and the last two items are flat
    NumPy arrays of model outputs and true labels.
    """
    model.eval()
    loss_sum = 0.0
    prob_batches, label_batches = [], []

    with torch.no_grad():
        for batch_x, batch_y in data_loader:
            batch_x = batch_x.to(device)
            targets = batch_y.to(device).reshape(-1, 1)

            preds = model(batch_x)
            loss_sum += criterion(preds, targets).item() * targets.size(0)

            # Keep the batch results on the CPU for metric computation.
            prob_batches.append(preds.cpu())
            label_batches.append(targets.cpu())

    # Stitch the batches together into flat arrays.
    y_probs = torch.cat(prob_batches).numpy().flatten()
    y_labels = torch.cat(label_batches).numpy().flatten()

    avg_loss = loss_sum / len(data_loader.dataset)
    y_pred = (y_probs > threshold).astype(int)

    # zero_division=0 yields 0 instead of a warning when a metric is undefined.
    precision, recall, f1 = (
        metric(y_labels, y_pred, zero_division=0)
        for metric in (precision_score, recall_score, f1_score)
    )

    return avg_loss, precision, recall, f1, y_probs, y_labels

print("✓ Training and evaluation functions defined")

## 9. Train the Model

In [None]:
# Training configuration
LEARNING_RATE = 0.001
NUM_EPOCHS = 50
PATIENCE = 7
THRESHOLD = 0.5

# Initialize optimizer and loss function
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Training loop with early stopping
best_val_loss = float('inf')
best_f1 = 0
# Snapshot of the best weights so far; None until the first epoch finishes.
best_model_state = None
patience_counter = 0
history = {
    'train_loss': [], 'val_loss': [],
    'val_precision': [], 'val_recall': [], 'val_f1': []
}

print("="*70)
print("Starting Training")
print("="*70)

for epoch in range(NUM_EPOCHS):
    # Train
    train_loss = train_epoch(model, device, train_loader, criterion, optimizer)

    # Validate
    val_loss, precision, recall, f1, _, _ = evaluate_model(
        model, device, val_loader, criterion, THRESHOLD
    )

    # Save history
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['val_precision'].append(precision)
    history['val_recall'].append(recall)
    history['val_f1'].append(f1)

    print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")
    print(f"  Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
    print(f"  Val Metrics - P: {precision:.4f} | R: {recall:.4f} | F1: {f1:.4f}")

    # Early stopping based on F1 score. The first epoch always counts as the
    # best so far, so a snapshot exists even if validation F1 never rises
    # above 0.
    if best_model_state is None or f1 > best_f1:
        best_f1 = f1
        best_val_loss = val_loss
        patience_counter = 0
        # state_dict() returns references to the live parameters, which later
        # epochs keep updating; clone them so the snapshot stays frozen.
        best_model_state = {
            name: tensor.detach().clone()
            for name, tensor in model.state_dict().items()
        }
        print("  ✓ New best model!")
    else:
        patience_counter += 1
        print(f"  Patience: {patience_counter}/{PATIENCE}")

    if patience_counter >= PATIENCE:
        print(f"\n✓ Early stopping triggered after {epoch+1} epochs")
        break

# Load best model (there is no snapshot only when NUM_EPOCHS is 0)
if best_model_state is not None:
    model.load_state_dict(best_model_state)
print("\n" + "="*70)
print("Training completed! Best model loaded.")
print("="*70)

## 10. Evaluate on Test Set

In [None]:
print("\n" + "="*70)
print("Final Evaluation on Test Set")
print("="*70)

test_loss, test_precision, test_recall, test_f1, test_probs, test_labels = evaluate_model(
    model, device, test_loader, criterion, THRESHOLD
)

# Calculate AUC
fpr, tpr, _ = roc_curve(test_labels, test_probs)
test_auc = auc(fpr, tpr)

# Calculate confusion matrix. The label order is pinned to [0, 1] so the
# matrix is always 2x2 (row/column 0 = No Churn, 1 = Churn), even when the
# test split or the predictions contain only one class.
CLASS_LABELS = [0, 1]
test_pred = (test_probs > THRESHOLD).astype(int)
cm = confusion_matrix(test_labels, test_pred, labels=CLASS_LABELS)

print(f"\nTest Set Results:")
print(f"  Loss:      {test_loss:.4f}")
print(f"  Precision: {test_precision:.4f}")
print(f"  Recall:    {test_recall:.4f}")
print(f"  F1 Score:  {test_f1:.4f}")
print(f"  AUC-ROC:   {test_auc:.4f}")

print(f"\nConfusion Matrix:")
print(f"                Predicted")
print(f"              No Churn  Churn")
print(f"Actual No Churn  {cm[0,0]:4d}    {cm[0,1]:4d}")
print(f"       Churn     {cm[1,0]:4d}    {cm[1,1]:4d}")

# Classification report; passing labels keeps target_names aligned with both
# classes, and zero_division=0 matches the metrics in evaluate_model.
print(f"\nDetailed Classification Report:")
print(classification_report(
    test_labels,
    test_pred,
    labels=CLASS_LABELS,
    target_names=['No Churn', 'Churn'],
    zero_division=0,
))

## 11. Visualizations

In [None]:
# Training History: loss curves plus one panel per validation metric
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Loss
loss_ax = axes[0, 0]
loss_ax.plot(history['train_loss'], label='Train Loss', linewidth=2)
loss_ax.plot(history['val_loss'], label='Val Loss', linewidth=2)
loss_ax.set_xlabel('Epoch', fontsize=12)
loss_ax.set_ylabel('Loss', fontsize=12)
loss_ax.set_title('Training and Validation Loss', fontsize=14, fontweight='bold')
loss_ax.legend(fontsize=11)
loss_ax.grid(True, alpha=0.3)

# Precision, Recall and F1 share the same layout, so draw them from a table:
# (axis, history key, metric name, line colour)
metric_panels = [
    (axes[0, 1], 'val_precision', 'Precision', 'blue'),
    (axes[1, 0], 'val_recall', 'Recall', 'green'),
    (axes[1, 1], 'val_f1', 'F1 Score', 'red'),
]
for ax, history_key, metric_name, color in metric_panels:
    ax.plot(history[history_key], label=metric_name, color=color, linewidth=2)
    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel(metric_name, fontsize=12)
    ax.set_title(f'Validation {metric_name}', fontsize=14, fontweight='bold')
    ax.grid(True, alpha=0.3)
    # All three metrics live in [0, 1]; a fixed range keeps panels comparable.
    ax.set_ylim([0, 1])

plt.tight_layout()
plt.show()

print("✓ Training history plotted")

In [None]:
# ROC Curve and Confusion Matrix
fig, axes = plt.subplots(1, 2, figsize=(15, 5))
roc_ax, cm_ax = axes

# ROC Curve, with the diagonal as the chance-level reference
roc_ax.plot(fpr, tpr, label=f'ROC curve (AUC = {test_auc:.3f})', linewidth=3, color='#2E86AB')
roc_ax.plot([0, 1], [0, 1], 'k--', label='Random Classifier', linewidth=2)
roc_ax.set_xlabel('False Positive Rate', fontsize=12)
roc_ax.set_ylabel('True Positive Rate', fontsize=12)
roc_ax.set_title('ROC Curve - Churn Prediction', fontsize=14, fontweight='bold')
roc_ax.legend(fontsize=11)
roc_ax.grid(True, alpha=0.3)

# Confusion Matrix
im = cm_ax.imshow(cm, interpolation='nearest', cmap='Blues')
cm_ax.figure.colorbar(im, ax=cm_ax)
cm_ax.set(xticks=[0, 1], yticks=[0, 1],
          xticklabels=['No Churn', 'Churn'],
          yticklabels=['No Churn', 'Churn'],
          xlabel='Predicted Label',
          ylabel='True Label',
          title='Confusion Matrix')

# Add text annotations: count and share of all samples per cell. Text turns
# white on cells darker than the midpoint so it stays readable.
thresh = cm.max() / 2
cm_total = cm.sum()
for i, j in np.ndindex(cm.shape):
    cm_ax.text(j, i, f'{cm[i, j]}\n({cm[i, j]/cm_total*100:.1f}%)',
               ha="center", va="center",
               color="white" if cm[i, j] > thresh else "black",
               fontsize=14, fontweight='bold')

plt.tight_layout()
plt.show()

print("✓ ROC curve and confusion matrix plotted")

In [None]:
# Precision-Recall Curve over all decision thresholds on the test set
precision_vals, recall_vals, thresholds_pr = precision_recall_curve(test_labels, test_probs)

fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(recall_vals, precision_vals, linewidth=3, color='#A23B72')
ax.set_xlabel('Recall', fontsize=12)
ax.set_ylabel('Precision', fontsize=12)
ax.set_title('Precision-Recall Curve', fontsize=14, fontweight='bold')
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

print("✓ Precision-Recall curve plotted")

## 12. Generate Predictions for Test-Set Accounts

In [None]:
# Generate predictions for test set (one forward pass per account)
model.eval()
test_predictions = []

with torch.no_grad():
    # test_dataset wraps test_df, so position idx refers to the same account
    # in both.
    for idx, account in enumerate(test_df.itertuples(index=False)):
        sequence, label = test_dataset[idx]

        # Add batch dimension: the model expects [batch, seq_len, n_features]
        prob = model(sequence.unsqueeze(0).to(device))
        churn_probability = float(prob.item())

        test_predictions.append({
            'account_id': account.account_id,
            'actual_churn': int(label.item()),
            'churn_probability': churn_probability,
            'predicted_churn': int(churn_probability > THRESHOLD),
            'sequence_length': int(account.seq_length)
        })

# Create predictions DataFrame
predictions_df = pd.DataFrame(test_predictions)

print(f"\n✓ Generated predictions for {len(predictions_df):,} test accounts")
print(f"\nSample predictions:")
print(predictions_df.head(10))

# Show high-risk accounts, highest probability first
print(f"\n=== High Churn Risk Accounts (Probability > 0.7) ===")
high_risk = predictions_df[predictions_df['churn_probability'] > 0.7].sort_values('churn_probability', ascending=False)
print(high_risk[['account_id', 'churn_probability', 'actual_churn', 'predicted_churn']].head(10))

## 13. Save Results to Snowflake

In [None]:
# Save predictions to Snowflake (MY_DATABASE.PUBLIC)
print("\nSaving predictions to MY_DATABASE.PUBLIC.CHURN_PREDICTIONS...")

try:
    # NOTE(review): lower-case pandas column names are presumably stored as
    # quoted, case-sensitive identifiers, which unquoted SQL such as
    # `WHERE CHURN_PROBABILITY > 0.7` would not match. Upper-casing them first
    # keeps the table queryable without quoting; confirm against the Snowpark
    # create_dataframe documentation.
    upload_df = predictions_df.rename(columns=str.upper)

    # Create Snowpark DataFrame from predictions
    predictions_snowpark = session.create_dataframe(upload_df)

    # Write to table in MY_DATABASE.PUBLIC, replacing any previous contents
    predictions_snowpark.write.mode("overwrite").save_as_table("MY_DATABASE.PUBLIC.CHURN_PREDICTIONS")

    print(f"✓ Predictions saved to MY_DATABASE.PUBLIC.CHURN_PREDICTIONS table")
    print(f"  Rows: {len(predictions_df):,}")

    # Verify by reading the row count back from the table
    result_count = session.table("MY_DATABASE.PUBLIC.CHURN_PREDICTIONS").count()
    print(f"  Verified row count: {result_count:,}")

except Exception as e:
    # Best effort at the notebook boundary: report the failure and carry on so
    # the remaining cells can still run.
    print(f"✗ Error saving predictions: {str(e)}")

In [None]:
# Save model metrics to Snowflake (MY_DATABASE.PUBLIC)
print("\nSaving model metrics to MY_DATABASE.PUBLIC.CHURN_MODEL_METRICS...")

try:
    # Create a one-row metrics DataFrame describing this training run
    metrics_df = pd.DataFrame({
        'model_name': ['LSTM_with_Attention'],
        'train_date': [datetime.now()],
        'test_loss': [test_loss],
        'test_precision': [test_precision],
        'test_recall': [test_recall],
        'test_f1_score': [test_f1],
        'test_auc_roc': [test_auc],
        'num_features': [N_FEATURES],
        'lookback_window': [MAX_LOOKBACK_WINDOW],
        'hidden_size': [HIDDEN_SIZE],
        'num_layers': [NUM_LAYERS],
        'learning_rate': [LEARNING_RATE],
        'train_samples': [len(train_df)],
        'test_samples': [len(test_df)]
    })

    # Save to Snowflake in MY_DATABASE.PUBLIC; "append" keeps the rows of
    # earlier runs, so column names are left unchanged to stay compatible
    # with an already existing table.
    metrics_snowpark = session.create_dataframe(metrics_df)
    metrics_snowpark.write.mode("append").save_as_table("MY_DATABASE.PUBLIC.CHURN_MODEL_METRICS")

    print(f"✓ Model metrics saved to MY_DATABASE.PUBLIC.CHURN_MODEL_METRICS table")

except Exception as e:
    # Best effort at the notebook boundary: report the failure and carry on.
    print(f"✗ Error saving metrics: {str(e)}")

## 14. Summary and Next Steps

In [None]:
# Final recap of the run: test metrics, data sizes, model configuration and
# where the outputs were written.
print("\n" + "="*70)
print("CHURN PREDICTION MODEL - SUMMARY")
print("="*70)

print(f"\n📊 Model Performance:")
print(f"  F1 Score:   {test_f1:.4f}")
print(f"  Precision:  {test_precision:.4f}")
print(f"  Recall:     {test_recall:.4f}")
print(f"  AUC-ROC:    {test_auc:.4f}")

print(f"\n📈 Data Statistics:")
print(f"  Data source: MY_DATABASE.PUBLIC")
print(f"  Total accounts processed: {len(sequence_df):,}")
print(f"  Training samples: {len(train_df):,}")
print(f"  Test samples: {len(test_df):,}")
print(f"  Churn rate: {sequence_df['churn_label'].mean():.2%}")

print(f"\n🎯 Model Configuration:")
print(f"  Architecture: LSTM with Attention")
print(f"  Features: {N_FEATURES}")
print(f"  Lookback window: {MAX_LOOKBACK_WINDOW} months")
print(f"  Hidden size: {HIDDEN_SIZE}")
print(f"  Layers: {NUM_LAYERS}")
print(f"  Total parameters: {sum(p.numel() for p in model.parameters()):,}")

print(f"\n💾 Saved to Snowflake (MY_DATABASE.PUBLIC):")
print(f"  CHURN_PREDICTIONS - Account-level predictions")
print(f"  CHURN_MODEL_METRICS - Model performance metrics")

print(f"\n✅ Next Steps:")
print(f"  1. Query predictions: SELECT * FROM MY_DATABASE.PUBLIC.CHURN_PREDICTIONS WHERE CHURN_PROBABILITY > 0.7")
print(f"  2. Monitor model performance over time")
print(f"  3. Retrain model with new data periodically")
print(f"  4. Integrate predictions into business workflows")
print(f"  5. Consider feature engineering improvements")

print("\n" + "="*70)
print("✓ Churn Prediction Pipeline Complete!")
print("="*70)