# Core ML Model Validation Notebook

This notebook validates converted Core ML style transfer models by:
1. Loading and inspecting model metadata
2. Testing inference on sample images
3. Measuring performance (FPS, latency)
4. Comparing the stylized output with the input image (MSE, PSNR, color statistics)
5. Generating validation reports

In [None]:
import sys
import time
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import json

# Import Core ML Tools. If it is not installed, install it from inside the
# notebook with pip and retry the import.
# The `!pip` line is IPython shell syntax, so this cell only runs inside
# Jupyter/IPython, not as a plain Python script.
# NOTE(review): `%pip install` targets the running kernel's environment more
# reliably than `!pip` — consider switching.
try:
    import coremltools as ct
    from coremltools.models import MLModel
except ImportError:
    print("Installing coremltools...")
    !pip install coremltools
    import coremltools as ct
    from coremltools.models import MLModel

# Report which coremltools version is in use (useful when comparing reports).
print(f"Core ML Tools version: {ct.__version__}")

## Configuration

In [None]:
# --- Inputs -------------------------------------------------------------
# Core ML model under test; edit this to point at the model to validate.
MODEL_PATH = "../models/exported/sci-fi.mlmodel"
# Content image pushed through the model for inference and quality checks.
TEST_IMAGE_PATH = "../datasets/sample/content/test_001.jpg"

# --- Benchmark settings -------------------------------------------------
NUM_WARMUP_RUNS = 5        # untimed predictions executed before measuring
NUM_BENCHMARK_RUNS = 20    # timed predictions used for the statistics

# --- Outputs ------------------------------------------------------------
# Plots, the styled image and the JSON report are all written here;
# the directory (and any missing parents) is created if it does not exist.
OUTPUT_DIR = Path("./validation_results")
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)

## 1. Load and Inspect Model

In [None]:
def _print_feature_spec(label, feature):
    """Print the name, type, shape/size and description of one model I/O feature.

    Args:
        label: Heading prefix, e.g. "Input" or "Output".
        feature: One entry of ``spec.description.input`` / ``spec.description.output``.
    """
    type_name = feature.type.WhichOneof('Type')
    print(f"\n{label}: {feature.name}")
    print(f"  Type: {type_name}")
    # Test the oneof explicitly. Reading an unset sub-message such as
    # `feature.type.multiArrayType` returns a default instance that is still
    # truthy, so `if feature.type.multiArrayType:` would also match image
    # features and print an empty shape with data type 0.
    if type_name == 'multiArrayType':
        array_type = feature.type.multiArrayType
        print(f"  Shape: {list(array_type.shape)}")
        print(f"  Data Type: {array_type.dataType}")
    elif type_name == 'imageType':
        image_type = feature.type.imageType
        print(f"  Image Size (W x H): {image_type.width} x {image_type.height}")
    print(f"  Description: {feature.shortDescription}")


def inspect_model(model_path):
    """Load a Core ML model and print its metadata and I/O specification.

    Args:
        model_path: Path (str or ``pathlib.Path``) to the ``.mlmodel`` file.

    Returns:
        Tuple ``(model, spec)``: the loaded ``MLModel`` and its protobuf spec
        from ``model.get_spec()``.
    """
    print(f"Loading model: {model_path}\n")

    model = MLModel(str(model_path))
    spec = model.get_spec()

    # Model metadata
    print("=" * 60)
    print("MODEL METADATA")
    print("=" * 60)
    print(f"Author: {model.author}")
    print(f"Description: {model.short_description}")
    print(f"Version: {model.version}")

    # Custom (user-defined) metadata map, printed only when non-empty
    user_defined = spec.description.metadata.userDefined
    if user_defined:
        print("\nCustom Metadata:")
        for key, value in user_defined.items():
            print(f"  {key}: {value}")

    # Input/Output specs
    print("\n" + "=" * 60)
    print("INPUT/OUTPUT SPECIFICATIONS")
    print("=" * 60)

    for inp in spec.description.input:
        _print_feature_spec("Input", inp)

    for out in spec.description.output:
        _print_feature_spec("Output", out)

    return model, spec

model, spec = inspect_model(MODEL_PATH)

## 2. Load and Preprocess Test Image

In [None]:
def load_and_preprocess_image(image_path, target_size=(256, 256), framework='pytorch'):
    """Read an image from disk and turn it into a batched float32 model input.

    Args:
        image_path: Path of the image file to load.
        target_size: ``(width, height)`` handed to ``PIL.Image.resize``.
        framework: ``'pytorch'`` yields a ``(1, C, H, W)`` array; any other
            value yields a ``(1, H, W, C)`` array.

    Returns:
        Tuple of the resized RGB ``PIL.Image`` and the pixel array scaled
        from [0, 255] to [0, 1].
    """
    resized = Image.open(image_path).convert('RGB').resize(target_size)

    # (H, W, C) float32 in [0, 1]
    pixels = np.asarray(resized, dtype=np.float32) / 255.0

    if framework == 'pytorch':
        # channels-first with a leading batch axis: (1, C, H, W)
        batch = pixels.transpose(2, 0, 1)[np.newaxis, ...]
    else:
        # channels-last with a leading batch axis: (1, H, W, C)
        batch = pixels[np.newaxis, ...]

    return resized, batch

def display_images(original, styled, title="Style Transfer Result"):
    """Show the original image and its stylized version side by side.

    Args:
        original: Image shown on the left, captioned 'Original'.
        styled: Image shown on the right, captioned 'Stylized'.
        title: Figure-level title.
    """
    fig, (left_ax, right_ax) = plt.subplots(1, 2, figsize=(12, 6))

    panels = ((left_ax, original, 'Original'), (right_ax, styled, 'Stylized'))
    for ax, picture, caption in panels:
        ax.imshow(picture)
        ax.set_title(caption)
        ax.axis('off')

    fig.suptitle(title)
    fig.tight_layout()
    plt.show()

# Detect framework from model: read the user-defined 'framework' metadata key,
# falling back to 'pytorch' when it is absent. It selects the tensor layout used
# below and in pre/post-processing.
framework = spec.description.metadata.userDefined.get('framework', 'pytorch')
print(f"Detected framework: {framework}")

# Get input shape. The rest of the notebook feeds raw NumPy arrays, so require a
# rank-4 multi-array input and fail early with a clear message otherwise. For
# example, an image-typed input has an empty multiArrayType.shape and would
# otherwise raise a bare IndexError below.
input_feature = spec.description.input[0]
input_kind = input_feature.type.WhichOneof('Type')
if input_kind != 'multiArrayType':
    raise TypeError(
        f"Model input '{input_feature.name}' has type {input_kind!r}; "
        "this notebook only supports multiArrayType inputs."
    )
input_shape = input_feature.type.multiArrayType.shape
if len(input_shape) != 4:
    raise ValueError(
        f"Expected a rank-4 input shape, got {list(input_shape)}; "
        "unbatched or other-rank input shapes are not supported here."
    )
if framework == 'pytorch':
    # (B, C, H, W)
    h, w = int(input_shape[2]), int(input_shape[3])
else:
    # (B, H, W, C)
    h, w = int(input_shape[1]), int(input_shape[2])

print(f"Input size: {h}x{w}")

# Load test image; PIL's resize takes (width, height).
original_img, input_array = load_and_preprocess_image(
    TEST_IMAGE_PATH,
    target_size=(w, h),
    framework=framework
)

plt.figure(figsize=(6, 6))
plt.imshow(original_img)
plt.title('Original Test Image')
plt.axis('off')
plt.show()

print(f"Input array shape: {input_array.shape}")
print(f"Input array range: [{input_array.min():.3f}, {input_array.max():.3f}]")

## 3. Run Inference

In [None]:
def run_inference(model, input_array, *, input_name='input_image', output_name='stylized_image'):
    """Run a single prediction and measure how long it takes.

    Args:
        model: Object with a ``predict(dict) -> dict`` method (e.g. a Core ML ``MLModel``).
        input_array: Value fed to the model under ``input_name``.
        input_name: Name of the model's input feature.
        output_name: Name of the output feature to return.

    Returns:
        Tuple ``(output, seconds)``: the requested output value and the
        elapsed time of the ``predict`` call in seconds.
    """
    # perf_counter is monotonic and high-resolution; time.time() follows the
    # system clock (which can be adjusted mid-measurement) and is coarser on
    # some platforms.
    start_time = time.perf_counter()
    output = model.predict({input_name: input_array})
    inference_time = time.perf_counter() - start_time

    return output[output_name], inference_time

def postprocess_output(output_array, framework='pytorch'):
    """Convert a batched model output into a displayable RGB image.

    Args:
        output_array: Model output with a leading batch axis, laid out as
            ``(B, C, H, W)`` for ``'pytorch'`` or ``(B, H, W, C)`` otherwise.
            Only the first batch item is used. Values are assumed to be in
            [0, 1], mirroring the /255 scaling in ``load_and_preprocess_image``.
        framework: Layout selector, as above.

    Returns:
        ``PIL.Image`` built from the ``(H, W, C)`` uint8 pixels.
    """
    output_array = np.asarray(output_array)
    if framework == 'pytorch':
        # (B, C, H, W) -> (H, W, C)
        hwc = np.transpose(output_array[0], (1, 2, 0))
    else:
        # (B, H, W, C) -> (H, W, C)
        hwc = output_array[0]

    # Round before casting: astype(uint8) truncates, so a value like
    # 0.99999 * 255 = 254.997 would become 254 instead of 255 and the /255
    # preprocessing would not round-trip.
    pixels = np.clip(np.rint(hwc * 255.0), 0, 255).astype(np.uint8)

    return Image.fromarray(pixels)

# Run one prediction on the test image and report latency plus output stats.
print("Running inference...")
output_array, inference_time = run_inference(model, input_array)

elapsed_ms = inference_time * 1000
print(f"Inference time: {elapsed_ms:.2f} ms")
print(f"Output shape: {output_array.shape}")
out_min, out_max = output_array.min(), output_array.max()
print(f"Output range: [{out_min:.3f}, {out_max:.3f}]")

# Turn the raw output into an image and show it next to the input.
styled_img = postprocess_output(output_array, framework)
display_images(original_img, styled_img)

## 4. Performance Benchmarking

In [None]:
def benchmark_model(model, input_array, num_warmup=5, num_runs=20, *, input_name='input_image'):
    """Measure prediction latency over repeated runs.

    Args:
        model: Object with a ``predict(dict) -> dict`` method (e.g. a Core ML ``MLModel``).
        input_array: Value fed to the model under ``input_name``.
        num_warmup: Untimed predictions executed first.
        num_runs: Timed predictions; must be at least 1.
        input_name: Name of the model's input feature.

    Returns:
        Tuple ``(results, times)``. ``results`` maps 'mean_ms', 'std_ms',
        'min_ms', 'max_ms', 'median_ms' and 'fps' to builtin floats.
        ``times`` is a float64 array of per-run durations in seconds.

    Raises:
        ValueError: if ``num_runs`` is less than 1.
    """
    # Without at least one timed run, np.min/np.max would fail on an empty array.
    if num_runs < 1:
        raise ValueError(f"num_runs must be >= 1, got {num_runs}")

    feed = {input_name: input_array}

    print(f"Warming up ({num_warmup} runs)...")
    for _ in range(num_warmup):
        model.predict(feed)

    print(f"\nBenchmarking ({num_runs} runs)...")
    times = np.empty(num_runs)
    for i in range(num_runs):
        # perf_counter: monotonic, high-resolution clock meant for timing.
        start_time = time.perf_counter()
        model.predict(feed)
        times[i] = time.perf_counter() - start_time

        if (i + 1) % 5 == 0:
            print(f"  Progress: {i + 1}/{num_runs}")

    # Builtin floats keep downstream comparisons as plain bools and JSON-friendly.
    mean_s = float(np.mean(times))
    results = {
        'mean_ms': mean_s * 1000,
        'std_ms': float(np.std(times)) * 1000,
        'min_ms': float(np.min(times)) * 1000,
        'max_ms': float(np.max(times)) * 1000,
        'median_ms': float(np.median(times)) * 1000,
        'fps': 1.0 / mean_s if mean_s > 0 else float('inf'),
    }

    return results, times

# Time repeated predictions on the preprocessed test image.
results, times = benchmark_model(
    model, input_array, num_warmup=NUM_WARMUP_RUNS, num_runs=NUM_BENCHMARK_RUNS
)

# Summary table
banner = "=" * 60
print("\n" + banner)
print("PERFORMANCE RESULTS")
print(banner)
print(f"Mean inference time: {results['mean_ms']:.2f} ± {results['std_ms']:.2f} ms")
print(f"Median inference time: {results['median_ms']:.2f} ms")
print(f"Min inference time: {results['min_ms']:.2f} ms")
print(f"Max inference time: {results['max_ms']:.2f} ms")
print(f"Average FPS: {results['fps']:.2f}")

# Left panel: latency histogram; right panel: latency per run.
times_ms = times * 1000
fig, (hist_ax, trace_ax) = plt.subplots(1, 2, figsize=(10, 5))

hist_ax.hist(times_ms, bins=20, edgecolor='black')
hist_ax.set_xlabel('Inference Time (ms)')
hist_ax.set_ylabel('Frequency')
hist_ax.set_title('Inference Time Distribution')
hist_ax.axvline(results['mean_ms'], color='r', linestyle='--', label=f"Mean: {results['mean_ms']:.2f} ms")
hist_ax.legend()

trace_ax.plot(times_ms, marker='o', markersize=3)
trace_ax.set_xlabel('Run Number')
trace_ax.set_ylabel('Inference Time (ms)')
trace_ax.set_title('Inference Time Over Runs')
trace_ax.axhline(results['mean_ms'], color='r', linestyle='--', alpha=0.7)
trace_ax.grid(True, alpha=0.3)

fig.tight_layout()
fig.savefig(OUTPUT_DIR / 'performance_benchmark.png', dpi=150, bbox_inches='tight')
plt.show()

## 5. Quality Assessment

In [None]:
def assess_quality(original, styled):
    """Compare the stylized image against the original with pixel statistics.

    Both inputs are converted with ``np.asarray`` (PIL images or arrays work)
    and must have the same shape. In this notebook they are ``(H, W, 3)`` RGB
    images with values in [0, 255]; mean/std are taken over the first two axes,
    i.e. per channel.

    Args:
        original: Reference image.
        styled: Stylized image to compare.

    Returns:
        dict with 'mse', 'psnr' (dB; ``inf`` when the images are identical),
        and per-channel 'original_color_mean', 'styled_color_mean',
        'original_color_std', 'styled_color_std' lists. All numbers are
        builtin Python floats, so the dict can go straight to ``json.dump``.

    Raises:
        ValueError: if the two images do not have the same shape.
    """
    # float64 avoids precision loss in the squared-error sum.
    original_np = np.asarray(original, dtype=np.float64)
    styled_np = np.asarray(styled, dtype=np.float64)
    if original_np.shape != styled_np.shape:
        raise ValueError(
            f"Image shapes differ: original {original_np.shape} vs styled {styled_np.shape}"
        )

    # Mean Squared Error
    mse = float(np.mean((original_np - styled_np) ** 2))

    # Peak Signal-to-Noise Ratio for 8-bit data (peak value 255)
    if mse == 0:
        psnr = float('inf')
    else:
        psnr = float(20 * np.log10(255.0 / np.sqrt(mse)))

    # Per-channel color distribution
    return {
        'mse': mse,
        'psnr': psnr,
        'original_color_mean': original_np.mean(axis=(0, 1)).tolist(),
        'styled_color_mean': styled_np.mean(axis=(0, 1)).tolist(),
        'original_color_std': original_np.std(axis=(0, 1)).tolist(),
        'styled_color_std': styled_np.std(axis=(0, 1)).tolist(),
    }

# Compare the stylized output with the input image and print the metrics.
quality_metrics = assess_quality(original_img, styled_img)

rule = "=" * 60
print("\n" + rule)
print("QUALITY METRICS")
print(rule)
print(f"Mean Squared Error: {quality_metrics['mse']:.2f}")
print(f"PSNR: {quality_metrics['psnr']:.2f} dB")

# One paragraph per statistic: original first, then stylized.
for stat_name in ('mean', 'std'):
    print(f"\nOriginal Color {stat_name.title()} (RGB): {quality_metrics[f'original_color_{stat_name}']}")
    print(f"Styled Color {stat_name.title()} (RGB): {quality_metrics[f'styled_color_{stat_name}']}")

## 6. Generate Validation Report

In [None]:
def _json_default(obj):
    """``json.dump`` fallback that converts NumPy scalars/arrays to builtins.

    ``np.float32`` and ``np.bool_`` are not subclasses of ``float``/``bool``,
    so without this hook ``json.dump`` raises ``TypeError`` on them.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


# Verdict thresholds
REALTIME_MIN_FPS = 15          # minimum average frames per second
MOBILE_MAX_MEAN_MS = 100       # maximum mean latency in milliseconds
MIN_ACCEPTABLE_PSNR_DB = 20    # minimum PSNR of stylized output vs. original

# Compile validation report
validation_report = {
    'model_path': str(MODEL_PATH),
    'test_image': str(TEST_IMAGE_PATH),
    'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
    'model_info': {
        'author': model.author,
        'description': model.short_description,
        'version': model.version,
        'framework': framework,
        'input_shape': list(input_shape),
    },
    'performance': results,
    'quality': quality_metrics,
    # bool(): comparisons on NumPy scalars yield np.bool_, which json cannot encode.
    'verdict': {
        'realtime_capable': bool(results['fps'] >= REALTIME_MIN_FPS),
        'mobile_optimized': bool(results['mean_ms'] <= MOBILE_MAX_MEAN_MS),
        'quality_acceptable': bool(quality_metrics['psnr'] >= MIN_ACCEPTABLE_PSNR_DB),
    }
}

# Save report. default= handles any remaining NumPy scalars (e.g. float32 metrics).
# NOTE: an infinite PSNR is written as the non-standard JSON token `Infinity`.
report_path = OUTPUT_DIR / 'validation_report.json'
with open(report_path, 'w', encoding='utf-8') as f:
    json.dump(validation_report, f, indent=2, default=_json_default)

print(f"\n✓ Validation report saved to: {report_path}")

# Save output image
output_img_path = OUTPUT_DIR / 'styled_output.jpg'
styled_img.save(output_img_path)
print(f"✓ Styled image saved to: {output_img_path}")

# Print verdict
print("\n" + "=" * 60)
print("VALIDATION VERDICT")
print("=" * 60)
for key, value in validation_report['verdict'].items():
    status = "✓ PASS" if value else "✗ FAIL"
    print(f"{key.replace('_', ' ').title()}: {status}")

print("\n🎉 Validation complete!")