# Databricks Integration with Pydantic

This notebook demonstrates how to integrate Pydantic models with Databricks workflows for robust data processing, validation, and machine learning pipelines.

## Learning Objectives
- Use Pydantic models in Databricks environments
- Implement data validation in ETL pipelines
- Create type-safe data transformations
- Build ML feature stores with validated schemas
- Measure and monitor data quality

In [None]:
# Install required packages for Databricks environment
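# %pip install "pydantic<2" pandas numpy scikit-learn
# Note: this notebook uses the Pydantic 1.x API (validator, Config.json_encoders, .json(indent=...), Field(regex=...)).
# In Databricks notebooks, install with %pip; dbutils.library.installPyPI() is deprecated.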

## 1. Setting Up Pydantic Models for Data Pipeline

Define data schemas that ensure consistency across your data pipeline:
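`DataPipelineModel` in the next cell sets `json_encoders` so that `datetime` and `date` values serialize as ISO-8601 strings. For readers less familiar with Pydantic's `Config`, the same behavior with only the standard library is a `default=` hook for `json.dumps`. This is a minimal sketch; the helper name `iso_default` is illustrative and not part of Pydantic.

```python
import json
from datetime import date, datetime

def iso_default(obj):
    # json.dumps calls this only for objects it cannot serialize natively.
    # datetime subclasses date, so one isinstance check covers both types.
    if isinstance(obj, (datetime, date)):
        return obj.isoformat()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

record = {
    "created_at": datetime(2023, 1, 10, 9, 30),
    "registration_date": date(2023, 1, 10),
    "source_system": "databricks",
}
print(json.dumps(record, default=iso_default))
# {"created_at": "2023-01-10T09:30:00", "registration_date": "2023-01-10", "source_system": "databricks"}
```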

In [None]:
from pydantic import BaseModel, Field, validator
from typing import Optional, List, Dict, Any, Union
from datetime import datetime, date
from enum import Enum
import pandas as pd
import numpy as np
from decimal import Decimal

# Base model for audit tracking
class DataPipelineModel(BaseModel):
    """Base model for all data pipeline entities"""
    created_at: datetime = Field(default_factory=datetime.now)
    source_system: str = Field(default="databricks")
    data_quality_score: Optional[float] = Field(None, ge=0.0, le=1.0)
    
    class Config:
        json_encoders = {
            datetime: lambda dt: dt.isoformat(),
            date: lambda d: d.isoformat()
        }

# Raw data models
class CustomerRaw(DataPipelineModel):
    """Raw customer data from source systems"""
    customer_id: str
    first_name: str
    last_name: str
    email: str
    phone: Optional[str] = None
    birth_date: Optional[date] = None
    registration_date: date
    country: str
    
    @validator('email')
    def validate_email(cls, v):
        if '@' not in v or '.' not in v.split('@')[1]:
            raise ValueError('Invalid email format')
        return v.lower().strip()
    
    @validator('customer_id')
    def validate_customer_id(cls, v):
        # Strip before checking the length so padded IDs such as "  AB " are rejected
        v = v.strip()
        if len(v) < 3:
            raise ValueError('Customer ID must be at least 3 characters')
        return v.upper()

class TransactionRaw(DataPipelineModel):
    """Raw transaction data"""
    transaction_id: str
    customer_id: str
    product_id: str
    quantity: int = Field(gt=0)
    unit_price: float = Field(gt=0)
    transaction_date: datetime
    currency: str = Field(default="USD")
    payment_method: str
    discount_amount: float = Field(default=0.0, ge=0)
    
    @property
    def total_amount(self) -> float:
        return (self.quantity * self.unit_price) - self.discount_amount
    
    @validator('currency')
    def validate_currency(cls, v):
        valid_currencies = ['USD', 'EUR', 'GBP', 'JPY', 'CAD']
        if v.upper() not in valid_currencies:
            raise ValueError(f'Currency must be one of {valid_currencies}')
        return v.upper()

print("Data models defined successfully!")
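The email and currency validators and the `total_amount` property reduce to plain string and arithmetic rules. Below is a stdlib-only sketch of those same rules, handy for unit-testing them without constructing models. The function names are hypothetical; the notebook keeps this logic inside the validators.

```python
VALID_CURRENCIES = ['USD', 'EUR', 'GBP', 'JPY', 'CAD']

def check_email(value: str) -> str:
    # Same rule as CustomerRaw.validate_email: an '@' plus a '.' in the part after it
    if '@' not in value or '.' not in value.split('@')[1]:
        raise ValueError('Invalid email format')
    return value.lower().strip()

def check_currency(value: str) -> str:
    # Same rule as TransactionRaw.validate_currency: case-insensitive whitelist
    if value.upper() not in VALID_CURRENCIES:
        raise ValueError(f'Currency must be one of {VALID_CURRENCIES}')
    return value.upper()

def total_amount(quantity: int, unit_price: float, discount_amount: float = 0.0) -> float:
    # Same arithmetic as the TransactionRaw.total_amount property
    return quantity * unit_price - discount_amount

print(check_email(" Alice@Example.com "))  # alice@example.com
print(check_currency("eur"))               # EUR
print(total_amount(3, 20.0, 5.0))          # 55.0
```

Nothing in `TransactionRaw` stops `discount_amount` from exceeding `quantity * unit_price`, so `total_amount` can go negative. If that matters downstream, add a validator for it.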

## 2. Data Ingestion and Validation

Implement robust data ingestion with validation:
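`DataIngestionPipeline` in the next cell routes every record either to a list of valid objects or to a list of rejected records that keep their index and raw payload. Stripped of Pydantic, the pattern looks like the sketch below. `partition_records` and `parse_customer` are illustrative names, and `parse` stands in for any validating constructor such as `CustomerRaw(**record)`.

```python
from typing import Any, Callable, Dict, List, Tuple

def partition_records(
    records: List[Dict[str, Any]],
    parse: Callable[[Dict[str, Any]], Any],
) -> Tuple[List[Any], List[Dict[str, Any]]]:
    """Send each record to `valid`, or to `rejected` with its index and error message."""
    valid, rejected = [], []
    for index, record in enumerate(records):
        try:
            valid.append(parse(record))
        except (ValueError, TypeError, KeyError) as exc:
            # Keep the raw record so it can be quarantined and replayed later
            rejected.append({'record_index': index, 'raw_data': record, 'errors': [str(exc)]})
    return valid, rejected

def parse_customer(record: Dict[str, Any]) -> str:
    # Toy validator: require a customer_id of at least 3 characters
    customer_id = record['customer_id'].strip()
    if len(customer_id) < 3:
        raise ValueError('Customer ID must be at least 3 characters')
    return customer_id.upper()

ok, bad = partition_records([{'customer_id': 'cust001'}, {'customer_id': ''}, {}], parse_customer)
print(ok)                                # ['CUST001']
print([r['record_index'] for r in bad])  # [1, 2]
```

Pydantic's `ValidationError` subclasses `ValueError`, so the same `except` clause would also catch model validation failures.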

In [None]:
from pydantic import ValidationError
from typing import Tuple, List
import json

class DataIngestionPipeline:
    """Data ingestion pipeline with Pydantic validation"""
    
    def __init__(self):
        self.validation_errors = []
        self.processed_records = 0
        self.failed_records = 0
    
    def validate_and_transform_customers(self, raw_data: List[Dict]) -> Tuple[List[CustomerRaw], List[Dict]]:
        """Validate customer data and return valid/invalid records"""
        valid_customers = []
        invalid_records = []
        
        for i, record in enumerate(raw_data):
            try:
                # Validate and create customer instance
                customer = CustomerRaw(**record)
                
                # Additional business logic validation
                if customer.birth_date and customer.birth_date > date.today():
                    raise ValueError("Birth date cannot be in the future")
                
                # Calculate data quality score
                quality_score = self._calculate_customer_quality_score(customer)
                customer.data_quality_score = quality_score
                
                valid_customers.append(customer)
                self.processed_records += 1
                
            except ValidationError as e:
                error_info = {
                    'record_index': i,
                    'raw_data': record,
                    'errors': [f"{error['loc'][0]}: {error['msg']}" for error in e.errors()],
                    'error_type': 'validation'
                }
                invalid_records.append(error_info)
                self.failed_records += 1
            except Exception as e:
                error_info = {
                    'record_index': i,
                    'raw_data': record,
                    'errors': [str(e)],
                    'error_type': 'business_logic'
                }
                invalid_records.append(error_info)
                self.failed_records += 1
        
        return valid_customers, invalid_records
    
    def _calculate_customer_quality_score(self, customer: CustomerRaw) -> float:
        """Calculate data quality score for customer record"""
        score = 0.0
        
        # Email validation (0.3)
        if customer.email and '@' in customer.email:
            score += 0.3
        
        # Phone number (0.2)
        if customer.phone and len(customer.phone.strip()) >= 10:
            score += 0.2
        
        # Birth date (0.2)
        if customer.birth_date:
            score += 0.2
        
        # Name completeness (0.3)
        if customer.first_name and customer.last_name:
            if len(customer.first_name.strip()) > 1 and len(customer.last_name.strip()) > 1:
                score += 0.3
        
        return min(score, 1.0)
    
    def get_pipeline_stats(self) -> Dict[str, Any]:
        """Get pipeline processing statistics"""
        total = self.processed_records + self.failed_records
        success_rate = (self.processed_records / total) * 100 if total > 0 else 0
        
        return {
            'total_records': total,
            'successful': self.processed_records,
            'failed': self.failed_records,
            'success_rate': round(success_rate, 2)
        }

# Sample raw customer data (simulating data from various sources)
raw_customer_data = [
    {
        "customer_id": "CUST001",
        "first_name": "Alice",
        "last_name": "Johnson",
        "email": "alice.johnson@email.com",
        "phone": "+1-555-123-4567",
        "birth_date": "1985-03-15",
        "registration_date": "2023-01-10",
        "country": "USA"
    },
    {
        "customer_id": "CUST002",
        "first_name": "Bob",
        "last_name": "Smith",
        "email": "bob@email.com",
        "registration_date": "2023-02-01",
        "country": "Canada"
    },
    {
        "customer_id": "",  # Invalid - empty ID
        "first_name": "Charlie",
        "last_name": "Brown",
        "email": "invalid-email",  # Invalid email
        "registration_date": "2023-03-01",
        "country": "UK"
    },
    {
        "customer_id": "CUST004",
        "first_name": "Diana",
        "last_name": "Wilson",
        "email": "diana@email.com",
        "phone": "555-9876",
        "birth_date": "1990-07-22",
        "registration_date": "2023-04-15",
        "country": "Australia"
    }
]

# Process the data
pipeline = DataIngestionPipeline()
valid_customers, invalid_customers = pipeline.validate_and_transform_customers(raw_customer_data)

print("Data Ingestion Results:")
print(f"Valid customers: {len(valid_customers)}")
print(f"Invalid records: {len(invalid_customers)}")
print(f"Pipeline stats: {pipeline.get_pipeline_stats()}")

print("\nValid Customers:")
for customer in valid_customers:
    print(f"  - {customer.customer_id}: {customer.first_name} {customer.last_name} (Quality: {customer.data_quality_score:.1f})")

print("\nInvalid Records:")
for invalid in invalid_customers:
    print(f"  - Record {invalid['record_index']}: {'; '.join(invalid['errors'])}")
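The quality score is a fixed weighted checklist, so you can work the sample records out by hand. Alice passes all four checks (0.3 + 0.2 + 0.2 + 0.3 = 1.0). Bob has no phone or birth date (0.3 + 0.3 = 0.6). Diana's phone `555-9876` has only 8 characters, which is below the 10-character threshold, so she scores 0.8. The sketch below recomputes these scores on plain dicts. It copies the notebook's weights, but the function itself is illustrative.

```python
# Weights copied from DataIngestionPipeline._calculate_customer_quality_score
EMAIL_W, PHONE_W, BIRTH_W, NAMES_W = 0.3, 0.2, 0.2, 0.3

def quality_score(record: dict) -> float:
    score = 0.0
    if '@' in (record.get('email') or ''):
        score += EMAIL_W
    if len((record.get('phone') or '').strip()) >= 10:  # raw length, so '-' and '+' count
        score += PHONE_W
    if record.get('birth_date'):
        score += BIRTH_W
    first = (record.get('first_name') or '').strip()
    last = (record.get('last_name') or '').strip()
    if len(first) > 1 and len(last) > 1:
        score += NAMES_W
    return min(score, 1.0)

samples = {
    'Alice': {'email': 'alice.johnson@email.com', 'phone': '+1-555-123-4567',
              'birth_date': '1985-03-15', 'first_name': 'Alice', 'last_name': 'Johnson'},
    'Bob': {'email': 'bob@email.com', 'first_name': 'Bob', 'last_name': 'Smith'},
    'Diana': {'email': 'diana@email.com', 'phone': '555-9876',
              'birth_date': '1990-07-22', 'first_name': 'Diana', 'last_name': 'Wilson'},
}
for name, record in samples.items():
    print(f"{name}: {quality_score(record):.1f}")
```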

## 3. Feature Engineering with Pydantic

Create feature engineering pipelines with validated schemas:
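The two computed fields in the `CustomerFeatures` cell below, `age_group` and `customer_tier`, are threshold buckets. Each lower bound is inclusive and each upper bound is exclusive. Below is a stdlib-only sketch of the same thresholds, useful for checking boundary values. The function names are illustrative.

```python
from typing import Optional

def age_group(age: Optional[int]) -> str:
    # Mirrors CustomerFeatures.set_age_group
    if age is None:
        return "unknown"
    if age < 25:
        return "young"
    if age < 45:
        return "middle"
    if age < 65:
        return "mature"
    return "senior"

def customer_tier(score: float) -> str:
    # Mirrors CustomerFeatures.set_customer_tier
    if score >= 0.8:
        return "gold"
    if score >= 0.6:
        return "silver"
    return "bronze"

print([age_group(a) for a in (None, 24, 25, 44, 45, 64, 65)])
# ['unknown', 'young', 'middle', 'middle', 'mature', 'mature', 'senior']
print([customer_tier(s) for s in (1.0, 0.8, 0.79, 0.6, 0.3)])
# ['gold', 'gold', 'silver', 'silver', 'bronze']
```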

In [None]:
from datetime import timedelta
import math

class CustomerFeatures(BaseModel):
    """Engineered features for customer analytics"""
    customer_id: str
    
    # Demographic features
    age: Optional[int] = Field(None, ge=0, le=150)
    age_group: Optional[str] = None
    country: str
    
    # Engagement features
    days_since_registration: int = Field(ge=0)
    registration_month: int = Field(ge=1, le=12)
    registration_year: int = Field(ge=1900, le=2030)
    
    # Data quality features
    has_phone: bool
    has_birth_date: bool
    data_completeness_score: float = Field(ge=0.0, le=1.0)
    
    # Computed fields
    is_new_customer: bool = False  # True when <= 30 days since registration
    customer_tier: str = "bronze"  # bronze, silver, gold based on completeness
    
    @validator('age_group', pre=False, always=True)
    def set_age_group(cls, v, values):
        age = values.get('age')
        if age is None:
            return "unknown"
        elif age < 25:
            return "young"
        elif age < 45:
            return "middle"
        elif age < 65:
            return "mature"
        else:
            return "senior"
    
    @validator('customer_tier', pre=False, always=True)
    def set_customer_tier(cls, v, values):
        score = values.get('data_completeness_score', 0)
        if score >= 0.8:
            return "gold"
        elif score >= 0.6:
            return "silver"
        else:
            return "bronze"

class FeatureEngineer:
    """Feature engineering pipeline"""
    
    def __init__(self, reference_date: Optional[date] = None):
        self.reference_date = reference_date or date.today()
    
    def engineer_customer_features(self, customers: List[CustomerRaw]) -> List[CustomerFeatures]:
        """Generate features from customer data"""
        features = []
        
        for customer in customers:
            # Calculate age
            age = None
            if customer.birth_date:
                age = self._calculate_age(customer.birth_date)
            
            # Calculate days since registration
            days_since_reg = (self.reference_date - customer.registration_date).days
            
            # Create features
            feature_record = CustomerFeatures(
                customer_id=customer.customer_id,
                age=age,
                country=customer.country,
                days_since_registration=days_since_reg,
                registration_month=customer.registration_date.month,
                registration_year=customer.registration_date.year,
                has_phone=bool(customer.phone),
                has_birth_date=bool(customer.birth_date),
                data_completeness_score=customer.data_quality_score or 0.0,
                is_new_customer=days_since_reg <= 30
            )
            
            features.append(feature_record)
        
        return features
    
    def _calculate_age(self, birth_date: date) -> int:
        """Calculate age from birth date"""
        today = self.reference_date
        age = today.year - birth_date.year
        
        # Adjust if birthday hasn't occurred this year
        if today.month < birth_date.month or (today.month == birth_date.month and today.day < birth_date.day):
            age -= 1
            
        return max(0, age)

# Generate features
feature_engineer = FeatureEngineer()
customer_features = feature_engineer.engineer_customer_features(valid_customers)

print("Customer Features Generated:")
print(f"Total features: {len(customer_features)}")

for features in customer_features:
    print(f"\n{features.customer_id}:")
    print(f"  Age: {features.age} ({features.age_group})")
    print(f"  Days since registration: {features.days_since_registration}")
    print(f"  Data completeness: {features.data_completeness_score:.1f}")
    print(f"  Customer tier: {features.customer_tier}")
    print(f"  New customer: {features.is_new_customer}")

# Convert to DataFrame for further processing
features_df_data = [feature.dict() for feature in customer_features]
print(f"\nFeatures ready for DataFrame: {len(features_df_data)} records")
print(f"Feature columns: {list(features_df_data[0].keys())}")
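`FeatureEngineer._calculate_age` subtracts the birth year from the reference year and then decrements the result if the birthday has not yet occurred. The sketch below is a compact equivalent that compares `(month, day)` tuples. Passing the reference date explicitly, as `FeatureEngineer` does through `reference_date`, keeps the feature reproducible across reruns. The standalone function is illustrative.

```python
from datetime import date

def calculate_age(birth_date: date, today: date) -> int:
    # (month, day) tuples compare lexicographically: True once this year's birthday has passed
    had_birthday = (today.month, today.day) >= (birth_date.month, birth_date.day)
    return max(0, today.year - birth_date.year - (0 if had_birthday else 1))

print(calculate_age(date(1985, 3, 15), date(2024, 3, 14)))  # 38
print(calculate_age(date(1985, 3, 15), date(2024, 3, 15)))  # 39
# A Feb 29 birthday counts as reached on Mar 1 in non-leap years
print(calculate_age(date(2000, 2, 29), date(2023, 2, 28)))  # 22
print(calculate_age(date(2000, 2, 29), date(2023, 3, 1)))   # 23
```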

## 4. Data Quality Monitoring

Implement data quality monitoring with Pydantic models:

In [None]:
from collections import defaultdict

class DataQualityMetrics(BaseModel):
    """Data quality metrics for monitoring"""
    dataset_name: str
    total_records: int = Field(ge=0)
    valid_records: int = Field(ge=0)
    invalid_records: int = Field(ge=0)
    
    # Quality scores
    average_quality_score: float = Field(ge=0.0, le=1.0)
    completeness_rate: float = Field(ge=0.0, le=1.0)
    validity_rate: float = Field(0.0, ge=0.0, le=1.0)  # recomputed by the validator below
    
    # Field-specific metrics
    field_completeness: Dict[str, float] = {}
    field_validity: Dict[str, float] = {}
    
    # Error analysis
    common_errors: List[Dict[str, Any]] = []
    error_categories: Dict[str, int] = {}
    
    # Timestamp
    measured_at: datetime = Field(default_factory=datetime.now)
    
    @validator('validity_rate', pre=False, always=True)
    def calculate_validity_rate(cls, v, values):
        total = values.get('total_records', 0)
        valid = values.get('valid_records', 0)
        return (valid / total) if total > 0 else 0.0
    
    class Config:
        json_encoders = {
            datetime: lambda dt: dt.isoformat()
        }

class DataQualityMonitor:
    """Monitor data quality across pipelines"""
    
    def __init__(self):
        self.quality_history = []
    
    def analyze_customer_quality(self, 
                                valid_customers: List[CustomerRaw], 
                                invalid_records: List[Dict]) -> DataQualityMetrics:
        """Analyze data quality for customer dataset"""
        
        total_records = len(valid_customers) + len(invalid_records)
        valid_count = len(valid_customers)
        invalid_count = len(invalid_records)
        
        # Calculate average quality score
        avg_quality = 0.0
        if valid_customers:
            total_quality = sum(c.data_quality_score or 0 for c in valid_customers)
            avg_quality = total_quality / len(valid_customers)
        
        # Field completeness analysis
        field_completeness = self._analyze_field_completeness(valid_customers)
        
        # Error analysis
        error_analysis = self._analyze_errors(invalid_records)
        
        # Calculate overall completeness
        completeness_rate = sum(field_completeness.values()) / len(field_completeness) if field_completeness else 0.0
        
        metrics = DataQualityMetrics(
            dataset_name="customers",
            total_records=total_records,
            valid_records=valid_count,
            invalid_records=invalid_count,
            average_quality_score=avg_quality,
            completeness_rate=completeness_rate,
            field_completeness=field_completeness,
            common_errors=error_analysis['common_errors'],
            error_categories=error_analysis['error_categories']
        )
        
        self.quality_history.append(metrics)
        return metrics
    
    def _analyze_field_completeness(self, customers: List[CustomerRaw]) -> Dict[str, float]:
        """Analyze completeness of each field"""
        if not customers:
            return {}
        
        tracked_fields = ['email', 'phone', 'birth_date', 'first_name', 'last_name', 'country']
        total = len(customers)
        
        # Share of records with a non-empty value per field; fields that are
        # never populated report 0.0 instead of being silently omitted
        return {
            field: sum(1 for customer in customers if getattr(customer, field)) / total
            for field in tracked_fields
        }
    
    def _analyze_errors(self, invalid_records: List[Dict]) -> Dict[str, Any]:
        """Analyze common errors and patterns"""
        error_counts = defaultdict(int)
        error_categories = defaultdict(int)
        
        for record in invalid_records:
            error_type = record.get('error_type', 'unknown')
            error_categories[error_type] += 1
            
            for error in record.get('errors', []):
                error_counts[error] += 1
        
        # Get top 5 most common errors
        common_errors = [
            {'error': error, 'count': count} 
            for error, count in sorted(error_counts.items(), key=lambda x: x[1], reverse=True)[:5]
        ]
        
        return {
            'common_errors': common_errors,
            'error_categories': dict(error_categories)
        }
    
    def generate_quality_report(self, metrics: DataQualityMetrics) -> str:
        """Generate a human-readable quality report"""
        report = f"""
=== DATA QUALITY REPORT ===
Dataset: {metrics.dataset_name}
Measured at: {metrics.measured_at.strftime('%Y-%m-%d %H:%M:%S')}

OVERALL METRICS:
  Total Records: {metrics.total_records:,}
  Valid Records: {metrics.valid_records:,} ({metrics.validity_rate:.1%})
  Invalid Records: {metrics.invalid_records:,}
  Average Quality Score: {metrics.average_quality_score:.2f}
  Data Completeness: {metrics.completeness_rate:.1%}

FIELD COMPLETENESS:
"""
        for field, rate in metrics.field_completeness.items():
            report += f"  {field}: {rate:.1%}\n"
        
        if metrics.common_errors:
            report += "\nCOMMON ERRORS:\n"
            for error in metrics.common_errors:
                report += f"  - {error['error']} ({error['count']} times)\n"
        
        if metrics.error_categories:
            report += "\nERROR CATEGORIES:\n"
            for category, count in metrics.error_categories.items():
                report += f"  - {category}: {count}\n"
        
        return report

# Monitor data quality
monitor = DataQualityMonitor()
quality_metrics = monitor.analyze_customer_quality(valid_customers, invalid_customers)

# Generate and print report
quality_report = monitor.generate_quality_report(quality_metrics)
print(quality_report)

# Quality metrics as JSON (for storing in data lake)
print("\nQuality Metrics JSON:")
print(quality_metrics.json(indent=2))
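`_analyze_errors` builds its top-five list by sorting a `defaultdict(int)`. `collections.Counter.most_common` does the same job in one call, and it orders ties by first occurrence, just as the stable `sorted(..., reverse=True)` does. The sketch below uses hand-made error records shaped like the pipeline's `invalid_records`. The records themselves are illustrative.

```python
from collections import Counter

invalid_records = [
    {'error_type': 'validation',
     'errors': ['customer_id: Customer ID must be at least 3 characters', 'email: Invalid email format']},
    {'error_type': 'validation', 'errors': ['email: Invalid email format']},
    {'error_type': 'business_logic', 'errors': ['Birth date cannot be in the future']},
]

# Count every individual error message and every error category
error_counts = Counter(e for record in invalid_records for e in record.get('errors', []))
categories = Counter(record.get('error_type', 'unknown') for record in invalid_records)

common_errors = [{'error': e, 'count': c} for e, c in error_counts.most_common(5)]
print(common_errors[0])  # {'error': 'email: Invalid email format', 'count': 2}
print(dict(categories))  # {'validation': 2, 'business_logic': 1}
```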

## 5. ML Feature Store Integration

Create ML-ready feature stores with Pydantic validation:

In [None]:
from sklearn.preprocessing import LabelEncoder
import numpy as np

class MLFeatureVector(BaseModel):
    """ML-ready feature vector"""
    customer_id: str
    
    # Numerical features
    age_normalized: float = Field(ge=0.0, le=1.0)
    days_since_registration_log: float
    data_completeness_score: float = Field(ge=0.0, le=1.0)
    
    # Categorical features (encoded)
    country_encoded: int = Field(ge=0)
    age_group_encoded: int = Field(ge=0)
    customer_tier_encoded: int = Field(ge=0)
    
    # Binary features
    has_phone: int = Field(ge=0, le=1)
    has_birth_date: int = Field(ge=0, le=1)
    is_new_customer: int = Field(ge=0, le=1)
    
    # Feature metadata
    feature_version: str = "v1.0"
    created_at: datetime = Field(default_factory=datetime.now)
    
    def to_array(self) -> np.ndarray:
        """Convert to numpy array for ML models"""
        return np.array([
            self.age_normalized,
            self.days_since_registration_log,
            self.data_completeness_score,
            self.country_encoded,
            self.age_group_encoded,
            self.customer_tier_encoded,
            self.has_phone,
            self.has_birth_date,
            self.is_new_customer
        ], dtype=np.float32)
    
    @classmethod
    def get_feature_names(cls) -> List[str]:
        """Get feature names for ML models"""
        return [
            'age_normalized',
            'days_since_registration_log',
            'data_completeness_score',
            'country_encoded',
            'age_group_encoded',
            'customer_tier_encoded',
            'has_phone',
            'has_birth_date',
            'is_new_customer'
        ]

class FeatureStore:
    """ML Feature Store with validation"""
    
    def __init__(self):
        self.encoders = {}
        self.scalers = {}
        self.feature_stats = {}
    
    def prepare_ml_features(self, features: List[CustomerFeatures]) -> List[MLFeatureVector]:
        """Transform customer features to ML-ready features"""
        
        # Extract data for preprocessing
        ages = [f.age for f in features if f.age is not None]
        days_reg = [f.days_since_registration for f in features]
        countries = [f.country for f in features]
        age_groups = [f.age_group for f in features]
        tiers = [f.customer_tier for f in features]
        
        # Fit encoders and scalers
        if 'country' not in self.encoders:
            self.encoders['country'] = LabelEncoder()
            self.encoders['country'].fit(countries)
        
        if 'age_group' not in self.encoders:
            self.encoders['age_group'] = LabelEncoder()
            self.encoders['age_group'].fit(age_groups)
        
        if 'customer_tier' not in self.encoders:
            self.encoders['customer_tier'] = LabelEncoder()
            self.encoders['customer_tier'].fit(tiers)
        
        # Age normalization (0-100 years -> 0-1)
        if ages:
            self.feature_stats['age_max'] = max(100, max(ages))
        else:
            self.feature_stats['age_max'] = 100
        
        # Create ML feature vectors
        ml_features = []
        
        for feature in features:
            # Normalize age
            age_norm = (feature.age / self.feature_stats['age_max']) if feature.age is not None else 0.0
            
            # Log transform days since registration
            days_log = np.log1p(feature.days_since_registration)
            
            # Encode categorical variables
            country_enc = self.encoders['country'].transform([feature.country])[0]
            age_group_enc = self.encoders['age_group'].transform([feature.age_group])[0]
            tier_enc = self.encoders['customer_tier'].transform([feature.customer_tier])[0]
            
            # Create ML feature vector
            ml_feature = MLFeatureVector(
                customer_id=feature.customer_id,
                age_normalized=age_norm,
                days_since_registration_log=days_log,
                data_completeness_score=feature.data_completeness_score,
                country_encoded=int(country_enc),
                age_group_encoded=int(age_group_enc),
                customer_tier_encoded=int(tier_enc),
                has_phone=int(feature.has_phone),
                has_birth_date=int(feature.has_birth_date),
                is_new_customer=int(feature.is_new_customer)
            )
            
            ml_features.append(ml_feature)
        
        return ml_features
    
    def create_training_dataset(self, ml_features: List[MLFeatureVector]) -> Tuple[np.ndarray, List[str], List[str]]:
        """Create training dataset from ML features"""
        feature_matrix = np.array([f.to_array() for f in ml_features])
        feature_names = MLFeatureVector.get_feature_names()
        customer_ids = [f.customer_id for f in ml_features]
        
        return feature_matrix, feature_names, customer_ids
    
    def get_feature_stats(self, ml_features: List[MLFeatureVector]) -> Dict[str, Any]:
        """Calculate feature statistics"""
        feature_matrix, feature_names, _ = self.create_training_dataset(ml_features)
        
        stats = {}
        for i, name in enumerate(feature_names):
            column = feature_matrix[:, i]
            stats[name] = {
                'mean': float(np.mean(column)),
                'std': float(np.std(column)),
                'min': float(np.min(column)),
                'max': float(np.max(column)),
                'null_count': int(np.sum(np.isnan(column)))
            }
        
        return stats

# Create feature store and prepare ML features
feature_store = FeatureStore()
ml_features = feature_store.prepare_ml_features(customer_features)

print(f"ML Features prepared: {len(ml_features)}")

# Create training dataset
X, feature_names, customer_ids = feature_store.create_training_dataset(ml_features)
print(f"\nTraining matrix shape: {X.shape}")
print(f"Feature names: {feature_names}")

# Show sample data
print("\nSample ML Features:")
for i, ml_feat in enumerate(ml_features[:2]):
    print(f"\nCustomer {ml_feat.customer_id}:")
    array_repr = ml_feat.to_array()
    for j, name in enumerate(feature_names):
        print(f"  {name}: {array_repr[j]:.4f}")

# Feature statistics
feature_stats = feature_store.get_feature_stats(ml_features)
print("\nFeature Statistics:")
for name, stats in feature_stats.items():
    print(f"  {name}: mean={stats['mean']:.3f}, std={stats['std']:.3f}")
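`LabelEncoder` assigns integer codes to the sorted unique labels, and `transform` raises `ValueError` on labels it has not seen. `FeatureStore.prepare_ml_features` fits each encoder only on its first call. As a result, a later batch that contains a new country fails at encoding time. The stdlib-only sketch below mirrors that behavior and the `log1p` transform. The function names are illustrative, not scikit-learn API.

```python
import math
from typing import Dict, List

def fit_label_encoding(values: List[str]) -> Dict[str, int]:
    # Sorted unique values -> 0..n-1, the same ordering LabelEncoder stores in classes_
    return {value: code for code, value in enumerate(sorted(set(values)))}

def encode(value: str, mapping: Dict[str, int]) -> int:
    # Like LabelEncoder.transform, refuse labels that were not seen during fitting
    if value not in mapping:
        raise ValueError(f"previously unseen label: {value!r}")
    return mapping[value]

country_codes = fit_label_encoding(['USA', 'Canada', 'Australia', 'USA'])
print(country_codes)                    # {'Australia': 0, 'Canada': 1, 'USA': 2}
print(encode('Canada', country_codes))  # 1

# log1p compresses long tenures while keeping day 0 at 0.0
print([round(math.log1p(d), 3) for d in (0, 30, 365)])  # [0.0, 3.434, 5.903]
```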

## 6. Databricks-Specific Integration Patterns

Patterns for validating Databricks job configuration and job results with Pydantic (job execution is simulated here):

In [None]:
# Databricks-specific utilities
class DatabricksConfig(BaseModel):
    """Configuration for Databricks jobs"""
    job_name: str
    cluster_size: str = "small"
    timeout_minutes: int = Field(default=60, gt=0)
    max_retries: int = Field(default=3, ge=0)
    
    # Data paths
    input_path: str
    output_path: str
    checkpoint_path: Optional[str] = None
    
    # Quality thresholds
    min_quality_score: float = Field(default=0.7, ge=0.0, le=1.0)
    max_error_rate: float = Field(default=0.1, ge=0.0, le=1.0)
    
    @validator('input_path', 'output_path')
    def validate_paths(cls, v):
        if not v.startswith(('/dbfs/', 's3://', 'abfss://')):
            raise ValueError('Path must start with /dbfs/, s3://, or abfss://')
        return v

class DatabricksJobResult(BaseModel):
    """Result of Databricks job execution"""
    job_name: str
    status: str = Field(regex=r'^(success|failed|running|cancelled)$')  # Pydantic 1.x; Pydantic 2 uses pattern=
    start_time: datetime
    end_time: Optional[datetime] = None
    
    # Processing metrics
    records_processed: int = Field(ge=0)
    records_failed: int = Field(ge=0)
    processing_time_seconds: Optional[float] = None
    
    # Quality metrics
    average_quality_score: Optional[float] = Field(None, ge=0.0, le=1.0)
    error_rate: Optional[float] = Field(None, ge=0.0, le=1.0)
    
    # Output information
    output_location: Optional[str] = None
    output_format: str = "parquet"
    
    # Error information
    error_message: Optional[str] = None
    stack_trace: Optional[str] = None
    
    @property
    def success_rate(self) -> float:
        total = self.records_processed + self.records_failed
        return (self.records_processed / total) if total > 0 else 0.0
    
    @validator('end_time')
    def validate_end_time(cls, v, values):
        start_time = values.get('start_time')
        if v and start_time and v < start_time:
            raise ValueError('End time cannot be before start time')
        return v

def simulate_databricks_job(config: DatabricksConfig) -> DatabricksJobResult:
    """Simulate a Databricks job execution"""
    import time
    import random
    
    start_time = datetime.now()
    
    # Simulate processing
    time.sleep(0.1)  # Simulate processing time
    
    # Simulate results
    total_records = random.randint(1000, 10000)
    failed_records = int(total_records * random.uniform(0, config.max_error_rate))
    processed_records = total_records - failed_records
    
    end_time = datetime.now()
    processing_time = (end_time - start_time).total_seconds()
    
    # Determine status based on quality thresholds
    error_rate = failed_records / total_records
    avg_quality = random.uniform(0.6, 1.0)
    
    status = "success" if (error_rate <= config.max_error_rate and 
                         avg_quality >= config.min_quality_score) else "failed"
    
    result = DatabricksJobResult(
        job_name=config.job_name,
        status=status,
        start_time=start_time,
        end_time=end_time,
        records_processed=processed_records,
        records_failed=failed_records,
        processing_time_seconds=processing_time,
        average_quality_score=avg_quality,
        error_rate=error_rate,
        output_location=config.output_path,
        error_message="Quality threshold not met" if status == "failed" else None
    )
    
    return result

# Example Databricks job configuration
job_config = DatabricksConfig(
    job_name="customer_data_processing",
    cluster_size="medium",
    timeout_minutes=120,
    input_path="/dbfs/raw/customers/",
    output_path="/dbfs/processed/customers/",
    checkpoint_path="/dbfs/checkpoints/customers/",
    min_quality_score=0.8,
    max_error_rate=0.05
)

print("Databricks Job Configuration:")
print(job_config.json(indent=2))

# Simulate job execution
print("\nExecuting Databricks job...")
job_result = simulate_databricks_job(job_config)

print("\nJob Result:")
print(f"Status: {job_result.status}")
print(f"Records processed: {job_result.records_processed:,}")
print(f"Records failed: {job_result.records_failed:,}")
print(f"Success rate: {job_result.success_rate:.1%}")
print(f"Processing time: {job_result.processing_time_seconds:.2f} seconds")
print(f"Average quality score: {job_result.average_quality_score:.3f}")
print(f"Error rate: {job_result.error_rate:.1%}")

if job_result.error_message:
    print(f"Error: {job_result.error_message}")

# Job result as JSON (for logging/monitoring)
print("\nJob Result JSON:")
print(job_result.json(indent=2))
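In `simulate_databricks_job`, the code computes `failed_records` as `int(total_records * random.uniform(0, config.max_error_rate))`, so the simulated error rate never exceeds `max_error_rate`. A run is marked `failed` only when `avg_quality`, drawn from `uniform(0.6, 1.0)`, falls below `min_quality_score` (0.8 in `job_config`), which happens in roughly half the runs. The sketch below isolates the path check and the status rule as plain functions. The names are illustrative. Note that the `dbfs:/` URI form is rejected by this prefix check.

```python
from typing import Tuple

SUPPORTED_PREFIXES: Tuple[str, ...] = ('/dbfs/', 's3://', 'abfss://')

def is_supported_path(path: str) -> bool:
    # Same prefix check as DatabricksConfig.validate_paths
    return path.startswith(SUPPORTED_PREFIXES)

def job_status(error_rate: float, avg_quality: float,
               max_error_rate: float, min_quality_score: float) -> str:
    # Same rule as simulate_databricks_job: both thresholds must hold
    passed = error_rate <= max_error_rate and avg_quality >= min_quality_score
    return "success" if passed else "failed"

print(is_supported_path('/dbfs/raw/customers/'))  # True
print(is_supported_path('dbfs:/raw/customers/'))  # False
print(job_status(0.03, 0.85, max_error_rate=0.05, min_quality_score=0.8))  # success
print(job_status(0.03, 0.75, max_error_rate=0.05, min_quality_score=0.8))  # failed
```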

## 7. Integration with Databricks Delta Lake

Derive Delta Lake table definitions (columns, check constraints, DDL) from Pydantic models, with a `schema_version` to track schema evolution:
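The helpers in the next cell walk the JSON Schema that Pydantic generates for a model. The sketch below applies the same idea to a hand-written dict, turning it into Spark SQL column definitions and check expressions. The dict is an assumption: it is a trimmed example shaped like `TransactionRaw.schema()` under Pydantic 1.x, and exact keys can differ between versions. In the JSON Schema, `gt`/`lt` appear as `exclusiveMinimum`/`exclusiveMaximum`, separate from `minimum`/`maximum`.

```python
from typing import Any, Dict, List

TYPE_MAPPING = {'integer': 'bigint', 'number': 'double', 'string': 'string', 'boolean': 'boolean'}

def spark_type(prop: Dict[str, Any]) -> str:
    # Dates and datetimes appear as JSON Schema strings with a format hint
    if prop.get('format') == 'date-time':
        return 'timestamp'
    if prop.get('format') == 'date':
        return 'date'
    return TYPE_MAPPING.get(prop.get('type', 'string'), 'string')

def column_definitions(schema: Dict[str, Any]) -> List[str]:
    # Required properties become NOT NULL columns
    required = set(schema.get('required', []))
    return [
        f"{name} {spark_type(prop)}" + (" NOT NULL" if name in required else "")
        for name, prop in schema['properties'].items()
    ]

def check_expressions(schema: Dict[str, Any]) -> List[str]:
    # ge/gt/le/lt surface as minimum/exclusiveMinimum/maximum/exclusiveMaximum
    operators = {'minimum': '>=', 'exclusiveMinimum': '>', 'maximum': '<=', 'exclusiveMaximum': '<'}
    return [
        f"{name} {op} {prop[key]}"
        for name, prop in schema['properties'].items()
        for key, op in operators.items()
        if key in prop
    ]

# Hypothetical, trimmed schema dict shaped like Pydantic 1.x output for TransactionRaw
transaction_schema = {
    'title': 'TransactionRaw',
    'type': 'object',
    'properties': {
        'transaction_id': {'title': 'Transaction Id', 'type': 'string'},
        'quantity': {'title': 'Quantity', 'type': 'integer', 'exclusiveMinimum': 0},
        'transaction_date': {'title': 'Transaction Date', 'type': 'string', 'format': 'date-time'},
        'discount_amount': {'title': 'Discount Amount', 'type': 'number', 'default': 0.0, 'minimum': 0},
    },
    'required': ['transaction_id', 'quantity', 'transaction_date'],
}

print(',\n'.join(column_definitions(transaction_schema)))
print(check_expressions(transaction_schema))  # ['quantity > 0', 'discount_amount >= 0']
```

On a Delta table, expressions like these can be attached with `ALTER TABLE ... ADD CONSTRAINT <name> CHECK (<expr>)`, after which writes that violate them are rejected.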

In [None]:
class DeltaTableSchema(BaseModel):
    """Schema definition for Delta tables"""
    table_name: str
    database_name: str = "default"
    schema_version: str = "1.0"
    
    # Table properties
    partition_columns: List[str] = []
    z_order_columns: List[str] = []
    table_properties: Dict[str, str] = {}
    
    # Schema fields
    fields: List[Dict[str, str]] = []
    
    # Data quality constraints
    check_constraints: List[str] = []
    
    @validator('table_name', 'database_name')
    def validate_names(cls, v):
        if not v.replace('_', '').replace('-', '').isalnum():
            raise ValueError('Name must be alphanumeric with underscores/hyphens')
        return v.lower()

def create_delta_schema_from_pydantic(model_class: type) -> DeltaTableSchema:
    """Generate Delta table schema from Pydantic model"""
    schema = model_class.schema()
    
    fields = []
    constraints = []
    
    for field_name, field_info in schema['properties'].items():
        # Map Pydantic types to Spark SQL types
        spark_type = map_pydantic_to_spark_type(field_info)
        
        # Check if field is required
        nullable = field_name not in schema.get('required', [])
        
        field_def = {
            'name': field_name,
            'type': spark_type,
            'nullable': str(nullable).lower()
        }
        
        fields.append(field_def)
        
        # Add constraints based on Pydantic field definitions
        field_constraints = extract_field_constraints(field_name, field_info)
        constraints.extend(field_constraints)
    
    table_name = model_class.__name__.lower().replace('raw', '').replace('features', '')
    
    delta_schema = DeltaTableSchema(
        table_name=table_name,
        fields=fields,
        check_constraints=constraints,
        table_properties={
            'delta.autoOptimize.optimizeWrite': 'true',
            'delta.autoOptimize.autoCompact': 'true'
        }
    )
    
    return delta_schema

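def map_pydantic_to_spark_type(field_info: Dict) -> str:
    """Map a JSON-schema field definition (from Model.schema()) to a Spark SQL type"""
    # Dates and datetimes are JSON strings with a 'format' hint, so check that first
    if field_info.get('format') == 'date-time':
        return 'timestamp'
    if field_info.get('format') == 'date':
        return 'date'
    
    field_type = field_info.get('type', 'string')
    
    if field_type == 'array':
        # Recurse into the item schema, e.g. List[float] -> array<double>;
        # tuple-style 'items' lists fall back to array<string>
        items = field_info.get('items')
        return f"array<{map_pydantic_to_spark_type(items if isinstance(items, dict) else {})}>"
    
    type_mapping = {
        'integer': 'bigint',
        'number': 'double',
        'string': 'string',
        'boolean': 'boolean',
        'object': 'string',  # dicts / free-form objects stored as JSON strings
    }
    
    # Missing or unknown types (e.g. a $ref to a nested model or enum) default to string
    return type_mapping.get(field_type, 'string')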

def extract_field_constraints(field_name: str, field_info: Dict) -> List[str]:
    """Extract CHECK constraint expressions from a JSON-schema field definition"""
    constraints = []
    
    # Inclusive bounds: Field(ge=...), Field(le=...)
    if 'minimum' in field_info:
        constraints.append(f"{field_name} >= {field_info['minimum']}")
    if 'maximum' in field_info:
        constraints.append(f"{field_name} <= {field_info['maximum']}")
    
    # Exclusive bounds: Field(gt=...), Field(lt=...)
    if 'exclusiveMinimum' in field_info:
        constraints.append(f"{field_name} > {field_info['exclusiveMinimum']}")
    if 'exclusiveMaximum' in field_info:
        constraints.append(f"{field_name} < {field_info['exclusiveMaximum']}")
    
    # String length constraints: min_length / max_length
    if 'minLength' in field_info:
        constraints.append(f"length({field_name}) >= {field_info['minLength']}")
    if 'maxLength' in field_info:
        constraints.append(f"length({field_name}) <= {field_info['maxLength']}")
    
    # Pattern constraints (regex)
    if 'pattern' in field_info:
        # Caveat: the regex is embedded verbatim; single quotes and backslashes
        # must be escaped before this is valid Spark SQL
        constraints.append(f"{field_name} RLIKE '{field_info['pattern']}'")
    
    return constraints

def generate_delta_ddl(schema: DeltaTableSchema) -> str:
    """Generate Delta table DDL, plus constraint and Z-order statements, from a schema"""
    full_name = f"{schema.database_name}.{schema.table_name}"
    
    # Create table statement
    ddl = f"CREATE TABLE IF NOT EXISTS {full_name} (\n"
    
    # Column definitions
    field_definitions = []
    for field in schema.fields:
        nullable = "" if field['nullable'] == 'true' else " NOT NULL"
        field_definitions.append(f"  {field['name']} {field['type']}{nullable}")
    
    ddl += ",\n".join(field_definitions)
    # Declare the format explicitly; Delta is the default only on Databricks
    ddl += "\n)\nUSING DELTA"
    
    # Partitioning
    if schema.partition_columns:
        ddl += f"\nPARTITIONED BY ({', '.join(schema.partition_columns)})"
    
    # Table properties
    if schema.table_properties:
        props = [f"'{k}' = '{v}'" for k, v in schema.table_properties.items()]
        ddl += f"\nTBLPROPERTIES ({', '.join(props)})"
    ddl += ";\n"
    
    # CHECK constraints are added after creation, one ALTER TABLE per constraint
    if schema.check_constraints:
        ddl += "\n-- Check Constraints\n"
        for i, constraint in enumerate(schema.check_constraints):
            ddl += f"ALTER TABLE {full_name} ADD CONSTRAINT check_{i} CHECK ({constraint});\n"
    
    # Z-ordering is applied by OPTIMIZE, not declared in CREATE TABLE
    if schema.z_order_columns:
        ddl += f"\nOPTIMIZE {full_name} ZORDER BY ({', '.join(schema.z_order_columns)});\n"
    
    return ddl

# Generate Delta schemas from our Pydantic models
customer_schema = create_delta_schema_from_pydantic(CustomerRaw)
feature_schema = create_delta_schema_from_pydantic(CustomerFeatures)

print("Delta Lake Schema for CustomerRaw:")
print(customer_schema.json(indent=2))

print("\n" + "="*60)
print("DELTA TABLE DDL for CustomerRaw:")
print("="*60)
customer_ddl = generate_delta_ddl(customer_schema)
print(customer_ddl)

print("\n" + "="*60)
print("DELTA TABLE DDL for CustomerFeatures:")
print("="*60)
feature_ddl = generate_delta_ddl(feature_schema)
print(feature_ddl)
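`extract_field_constraints` embeds the `pattern` regex in the `RLIKE` string literal verbatim. That breaks as soon as the pattern contains a single quote, and because Spark's SQL parser unescapes string literals, a regex escape such as `\d` has to be written as `\\d` to reach the regex engine intact. A minimal sketch of the quoting step, assuming Spark's default `spark.sql.parser.escapedStringLiterals=false`; the helper names `to_spark_sql_string_literal` and `pattern_constraint` are hypothetical:

```python
def to_spark_sql_string_literal(value: str) -> str:
    """Quote a Python string as a single-quoted Spark SQL string literal.

    Assumes spark.sql.parser.escapedStringLiterals=false (the default), so the
    SQL parser treats a backslash inside a literal as an escape character.
    """
    # Double backslashes first so regex escapes like \d reach the regex engine intact
    escaped = value.replace("\\", "\\\\")
    # Then escape single quotes so they don't terminate the literal early
    escaped = escaped.replace("'", "\\'")
    return f"'{escaped}'"


def pattern_constraint(field_name: str, pattern: str) -> str:
    """Build an RLIKE CHECK expression with the regex safely quoted (hypothetical helper)."""
    return f"{field_name} RLIKE {to_spark_sql_string_literal(pattern)}"


# Illustrative pattern, not taken from the notebook's models
print(pattern_constraint("customer_id", r"^CUST\d{6}$"))
# customer_id RLIKE '^CUST\\d{6}$'
```

Inside `extract_field_constraints`, the pattern branch could call `pattern_constraint(field_name, field_info['pattern'])` instead of formatting the string directly. `RLIKE` succeeds when the regex matches anywhere in the value, so keep `^`/`$` anchors when the whole value must match.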

## Summary and Best Practices

This notebook demonstrated the following patterns for combining Pydantic with Databricks workflows:

### Key Integration Patterns:
1. **Data Validation**: Robust input validation for ETL pipelines
2. **Feature Engineering**: Type-safe feature transformations
3. **Quality Monitoring**: Automated data quality checks
4. **ML Feature Stores**: Validated ML-ready feature vectors
5. **Job Configuration**: Type-safe job configuration and results
6. **Delta Lake Integration**: DDL and CHECK constraint generation from Pydantic models

### Best Practices for Databricks + Pydantic:

#### 1. **Schema Management**
- Use Pydantic models as single source of truth for schemas
- Generate Delta table schemas from Pydantic models
- Version your data models for schema evolution
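When a model gains or loses fields between versions, the generated column lists can be diffed to decide what is safe to apply automatically. A minimal sketch, assuming two `fields` lists in the `{'name', 'type', 'nullable'}` shape produced by `create_delta_schema_from_pydantic`: new nullable columns become an `ALTER TABLE ... ADD COLUMNS` statement, while removals, type changes, and new `NOT NULL` columns are only reported. The function `plan_schema_migration` and the example columns are hypothetical:

```python
from typing import Dict, List, Tuple

ColumnDef = Dict[str, str]  # {'name': ..., 'type': ..., 'nullable': 'true' | 'false'}


def plan_schema_migration(
    table: str, old_fields: List[ColumnDef], new_fields: List[ColumnDef]
) -> Tuple[List[str], List[str]]:
    """Return (SQL statements safe to run, issues needing manual review)."""
    old = {f["name"]: f for f in old_fields}
    new = {f["name"]: f for f in new_fields}
    statements: List[str] = []
    issues: List[str] = []

    # Columns present only in the new version
    addable = []
    for name, field in new.items():
        if name in old:
            continue
        if field["nullable"] == "true":
            addable.append(f"{name} {field['type']}")
        else:
            # Existing rows would have no value for a new NOT NULL column
            issues.append(f"new NOT NULL column needs a backfill plan: {name}")
    if addable:
        statements.append(f"ALTER TABLE {table} ADD COLUMNS ({', '.join(addable)})")

    # Columns that disappeared or changed type
    for name, field in old.items():
        if name not in new:
            issues.append(f"column removed: {name}")
        elif field["type"] != new[name]["type"]:
            issues.append(f"type changed: {name} {field['type']} -> {new[name]['type']}")
    return statements, issues


v1 = [{"name": "customer_id", "type": "string", "nullable": "false"}]
v2 = v1 + [
    {"name": "segment", "type": "string", "nullable": "true"},
    {"name": "region", "type": "string", "nullable": "false"},
]
statements, issues = plan_schema_migration("default.customer_raw", v1, v2)
print(statements)
print(issues)
```

Removals and type changes are reported rather than executed because they can break downstream readers or lose data; how to apply them is left to you.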

#### 2. **Data Quality**
- Implement validation at ingestion time
- Route valid and invalid records to separate outputs so bad rows can be inspected without blocking the pipeline
- Monitor data quality metrics continuously
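Routing valid and invalid records separately only needs a validation callable that raises on bad input. A minimal sketch: `validate` stands in for a Pydantic model constructor such as `lambda r: CustomerRaw(**r).dict()`, and catching `ValueError` also covers Pydantic because its `ValidationError` subclasses `ValueError` in both v1 and v2. The helper names and the hand-written `validate_email_record` validator are hypothetical:

```python
from typing import Any, Callable, Dict, Iterable, List, Tuple

Record = Dict[str, Any]


def split_valid_invalid(
    records: Iterable[Record], validate: Callable[[Record], Record]
) -> Tuple[List[Record], List[Record]]:
    """Return (validated records, quarantined records with their error messages)."""
    valid: List[Record] = []
    quarantine: List[Record] = []
    for record in records:
        try:
            valid.append(validate(record))
        except ValueError as exc:  # pydantic.ValidationError is a ValueError subclass
            quarantine.append({"record": record, "error": str(exc)})
    return valid, quarantine


def validate_email_record(record: Record) -> Record:
    """Hypothetical stand-in for a Pydantic model, similar to CustomerRaw.validate_email."""
    email = str(record.get("email", "")).strip().lower()
    if "@" not in email or "." not in email.split("@")[1]:
        raise ValueError(f"invalid email: {record.get('email')!r}")
    return {**record, "email": email}


valid, quarantine = split_valid_invalid(
    [{"email": " Ana@Example.com "}, {"email": "not-an-email"}],
    validate_email_record,
)
print(valid)       # [{'email': 'ana@example.com'}]
print(quarantine)  # [{'record': {'email': 'not-an-email'}, 'error': "invalid email: 'not-an-email'"}]
```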

#### 3. **Error Handling**
- Gracefully handle validation errors
- Log detailed error information for debugging
- Implement retry logic for transient failures
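Retries should cover only transient failures, such as a dropped connection or a timeout while reading a source, and never validation errors, which fail the same way on every attempt. A minimal exponential-backoff sketch; which exception types count as transient is an assumption to adjust for your sources, and `retry_with_backoff` is a hypothetical helper, not a Databricks API:

```python
import time
from typing import Callable, Tuple, Type, TypeVar

T = TypeVar("T")


def retry_with_backoff(
    fn: Callable[[], T],
    attempts: int = 3,
    base_delay: float = 0.5,
    retry_on: Tuple[Type[BaseException], ...] = (ConnectionError, TimeoutError),
    sleep: Callable[[float], None] = time.sleep,
) -> T:
    """Call fn, retrying only on `retry_on` exceptions with exponential backoff."""
    if attempts < 1:
        raise ValueError("attempts must be >= 1")
    for attempt in range(1, attempts + 1):
        try:
            return fn()
        except retry_on:
            if attempt == attempts:
                raise  # out of attempts: surface the last transient error
            # Wait base_delay, 2 * base_delay, 4 * base_delay, ...
            sleep(base_delay * 2 ** (attempt - 1))
    raise AssertionError("unreachable")


calls = []


def flaky_read() -> str:
    """Simulated source read that fails twice before succeeding."""
    calls.append(1)
    if len(calls) < 3:
        raise ConnectionError("transient source error")
    return "ok"


delays = []  # record the backoff delays instead of actually sleeping
print(retry_with_backoff(flaky_read, sleep=delays.append))  # ok
print(delays)  # [0.5, 1.0]
```

Any exception outside `retry_on`, including a `ValueError` from validation, propagates on the first attempt.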

#### 4. **Performance**
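- Validate records in batches rather than one call at a time to reduce per-record overhead
- Map Pydantic types to matching Spark SQL types (e.g., `datetime` → `timestamp`, `date` → `date`)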
- Consider validation overhead in job planning
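The JSON-schema route in `map_pydantic_to_spark_type` loses some type information along the way: Pydantic v1, for example, reports `Decimal` fields as a plain JSON `number`, which then becomes `double`. A sketch of the alternative, mapping Python annotations directly; the `decimal(38,18)` default (38 is Spark's maximum decimal precision) and the fallback to `string` for unsupported types are assumptions, and `python_type_to_spark` is a hypothetical helper:

```python
import datetime as dt
import types
from decimal import Decimal
from typing import Any, List, Optional, Union, get_args, get_origin

# Assumed default for Decimal columns; 38 is Spark's maximum decimal precision
DEFAULT_DECIMAL_TYPE = "decimal(38,18)"

# Exact-type lookup, so bool is not mistaken for int and datetime not for date
_SCALAR_TYPES = {
    str: "string",
    bool: "boolean",
    int: "bigint",
    float: "double",
    Decimal: DEFAULT_DECIMAL_TYPE,
    dt.date: "date",
    dt.datetime: "timestamp",
}

# typing.Union, plus the `X | None` syntax (types.UnionType) on Python 3.10+
_UNION_ORIGINS = (Union, getattr(types, "UnionType", Union))


def python_type_to_spark(tp: Any) -> str:
    """Map a Python type annotation to a Spark SQL type string."""
    origin = get_origin(tp)
    if origin in _UNION_ORIGINS:
        # Optional[X] is Union[X, None]; nullability is tracked separately
        members = [a for a in get_args(tp) if a is not type(None)]
        if len(members) == 1:
            return python_type_to_spark(members[0])
        return "string"  # mixed unions have no single Spark type
    if origin is list:
        args = get_args(tp)
        item = args[0] if args else str
        return f"array<{python_type_to_spark(item)}>"
    return _SCALAR_TYPES.get(tp, "string")


print(python_type_to_spark(Optional[Decimal]))  # decimal(38,18)
print(python_type_to_spark(List[float]))        # array<double>
```

Because it works on annotations rather than on `model_class.schema()`, it sees through `Optional[...]` without depending on how a particular Pydantic version renders nullable fields in JSON schema.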

#### 5. **Monitoring**
- Track validation success rates
- Monitor feature drift and quality scores
- Set up alerting for quality thresholds
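Threshold alerting can start as a pure function over the counts a validation run already produces. A minimal sketch; the threshold values and the `QualityThresholds`/`check_quality` names are hypothetical, and the returned messages would go to whatever alerting channel your jobs already use:

```python
from dataclasses import dataclass
from typing import List


@dataclass
class QualityThresholds:
    min_success_rate: float = 0.95      # assumed: at least 95% of records must validate
    min_avg_quality_score: float = 0.8  # assumed floor for data_quality_score


def check_quality(
    total: int, valid: int, quality_scores: List[float], thresholds: QualityThresholds
) -> List[str]:
    """Return alert messages; an empty list means the batch passed."""
    alerts: List[str] = []
    # Treat an empty batch as 0% success so a silent upstream outage still alerts
    success_rate = valid / total if total else 0.0
    if success_rate < thresholds.min_success_rate:
        alerts.append(
            f"validation success rate {success_rate:.1%} is below {thresholds.min_success_rate:.1%}"
        )
    if quality_scores:
        avg_score = sum(quality_scores) / len(quality_scores)
        if avg_score < thresholds.min_avg_quality_score:
            alerts.append(
                f"average quality score {avg_score:.2f} is below {thresholds.min_avg_quality_score:.2f}"
            )
    return alerts


print(check_quality(1000, 990, [0.9, 0.95], QualityThresholds()))  # []
print(check_quality(1000, 900, [0.5, 0.7], QualityThresholds()))
```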

### Next Steps:
1. Implement these patterns in your Databricks workspaces
2. Create reusable libraries for common validation patterns
3. Set up CI/CD pipelines to version and test your Pydantic data models
4. Integrate with MLflow for model and feature tracking
5. Build automated data quality dashboards