# Lab 04: Advanced Image Classification with Azure Custom Vision

This notebook covers advanced topics in image classification with Azure Custom Vision, building on the basic concepts to implement production-ready solutions.

## Advanced Topics Covered

1. **Transfer Learning** - Understanding and leveraging pre-trained models
2. **Domain Selection** - Choosing specialized domains for better performance
3. **Data Augmentation** - Expanding training data programmatically
4. **Advanced Metrics** - F1 scores, per-tag performance, and confusion matrices
5. **Multi-label Classification** - Assigning multiple tags to single images
6. **Model Optimization** - Improving speed and accuracy
7. **Model Export** - Deploying to edge devices
8. **Batch Predictions** - Efficient processing of multiple images

## 1. Setup and Installation

In [None]:
# Install required packages including image processing libraries
!pip install azure-cognitiveservices-vision-customvision python-dotenv pillow matplotlib numpy scikit-learn seaborn

In [None]:
from azure.cognitiveservices.vision.customvision.training import CustomVisionTrainingClient
from azure.cognitiveservices.vision.customvision.prediction import CustomVisionPredictionClient
from azure.cognitiveservices.vision.customvision.training.models import (
    ImageFileCreateBatch, ImageFileCreateEntry, Classification
)
from msrest.authentication import ApiKeyCredentials
from dotenv import load_dotenv
from PIL import Image, ImageEnhance, ImageOps
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report
import time
import os
import io
import uuid
from collections import defaultdict

print("Libraries imported successfully!")

## 2. Configure Credentials

In [None]:
env_path = 'python/train-classifier/.env'
load_dotenv(env_path)

training_endpoint = os.getenv('TrainingEndpoint')
training_key = os.getenv('TrainingKey')
prediction_endpoint = os.getenv('PredictionEndpoint', training_endpoint)
prediction_key = os.getenv('PredictionKey', training_key)
project_id = os.getenv('ProjectID', None)

training_credentials = ApiKeyCredentials(in_headers={"Training-key": training_key})
training_client = CustomVisionTrainingClient(training_endpoint, training_credentials)

prediction_credentials = ApiKeyCredentials(in_headers={"Prediction-key": prediction_key})
prediction_client = CustomVisionPredictionClient(prediction_endpoint, prediction_credentials)

print("Clients authenticated successfully!")
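
If one of these variables is missing, `os.getenv` returns `None` and the problem only surfaces later as an authentication error. A minimal sketch of an early check follows. The helper name `require_env` is hypothetical and not part of the Azure SDK, and the example uses a stand-in mapping rather than the real environment.

```python
import os

def require_env(names, environ=None):
    """Return a dict of the requested variables, raising if any is unset or empty."""
    environ = os.environ if environ is None else environ
    missing = [name for name in names if not environ.get(name)]
    if missing:
        raise RuntimeError(f"Missing environment variables: {', '.join(missing)}")
    return {name: environ[name] for name in names}

# Example with a stand-in mapping instead of the real environment
settings = require_env(
    ["TrainingEndpoint", "TrainingKey"],
    environ={"TrainingEndpoint": "https://example.invalid/", "TrainingKey": "placeholder"},
)
print(sorted(settings))
```

In the notebook, calling `require_env(['TrainingEndpoint', 'TrainingKey'])` right after `load_dotenv` would fail fast with a readable message.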

## 3. Understanding Transfer Learning and Domains

Azure Custom Vision uses **transfer learning** - it starts with pre-trained models and adapts them to your specific classification task. This allows excellent performance with relatively few training images.

### Available Domains:
- **General**: Default, works well for most scenarios
- **Food**: Optimized for food/meal recognition
- **Landmarks**: Buildings and natural landmarks
- **Retail**: Products and items in retail context
- **General (compact)**: Smaller model for edge deployment
- **Food (compact)**, **Landmarks (compact)**, **Retail (compact)**: Smaller, exportable variants of the specialized domains

In [None]:
# List all available domains
domains = training_client.get_domains()

print("Available Classification Domains:")
print("=" * 80)
print(f"{'Name':<25} {'Type':<15} {'Exportable':<12} {'ID'}")
print("-" * 80)

for domain in domains:
    if domain.type == 'Classification':
        exportable = "Yes" if domain.exportable else "No"
        print(f"{domain.name:<25} {domain.type:<15} {exportable:<12} {domain.id}")

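# Select the Food domain for our fruit classification.
# Match the exact name and type so that 'Food (compact)' is not picked by accident.
food_domain = next((d for d in domains if d.type == 'Classification' and d.name == 'Food'), None)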

if food_domain:
    print(f"\n✓ Selected domain: {food_domain.name}")
    print(f"  This domain is optimized for food classification tasks.")
else:
    print("\n⚠️  Food domain not found, using General domain")
    food_domain = next(d for d in domains if 'general' in d.name.lower() and d.type == 'Classification')

## 4. Create Project with Custom Domain

When creating a project, choosing the right domain can significantly improve performance.

In [None]:
# Create project with Food domain
project_name = f"Advanced Fruit Classification {uuid.uuid4().hex[:8]}"

project = training_client.create_project(
    name=project_name,
    description="Advanced fruit classification with optimized domain",
    domain_id=food_domain.id,
    classification_type="Multiclass"  # Single label per image
)

print(f"✓ Created project: {project.name}")
print(f"  Project ID: {project.id}")
print(f"  Domain: {food_domain.name}")
print(f"  Classification Type: Multiclass")

## 5. Data Augmentation Techniques

Data augmentation artificially expands your training dataset by creating modified versions of existing images. This helps the model generalize better and reduces the risk of overfitting.

### Common Augmentation Techniques:
- **Rotation**: Rotate images by various angles
- **Flipping**: Horizontal and vertical flips
- **Brightness adjustment**: Simulate different lighting conditions
- **Contrast adjustment**: Enhance or reduce contrast
- **Color adjustment**: Modify color saturation
- **Cropping**: Random crops to focus on different areas

In [None]:
def augment_image(image, augmentation_type):
    """
    Apply data augmentation to an image.
    
    Args:
        image: PIL Image object
        augmentation_type: Type of augmentation to apply
    
    Returns:
        Augmented PIL Image
    """
    if augmentation_type == 'rotate_90':
        return image.rotate(90, expand=True)
    elif augmentation_type == 'rotate_270':
        return image.rotate(270, expand=True)
    elif augmentation_type == 'flip_horizontal':
        return ImageOps.mirror(image)
    elif augmentation_type == 'flip_vertical':
        return ImageOps.flip(image)
    elif augmentation_type == 'brightness_increase':
        enhancer = ImageEnhance.Brightness(image)
        return enhancer.enhance(1.3)
    elif augmentation_type == 'brightness_decrease':
        enhancer = ImageEnhance.Brightness(image)
        return enhancer.enhance(0.7)
    elif augmentation_type == 'contrast_increase':
        enhancer = ImageEnhance.Contrast(image)
        return enhancer.enhance(1.3)
    elif augmentation_type == 'contrast_decrease':
        enhancer = ImageEnhance.Contrast(image)
        return enhancer.enhance(0.7)
    elif augmentation_type == 'saturation_increase':
        enhancer = ImageEnhance.Color(image)
        return enhancer.enhance(1.3)
    elif augmentation_type == 'saturation_decrease':
        enhancer = ImageEnhance.Color(image)
        return enhancer.enhance(0.7)
    else:
        return image

# Demonstrate augmentation on a sample image
sample_path = 'training-images/apple'
if os.path.exists(sample_path):
    sample_files = [f for f in os.listdir(sample_path) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
    if sample_files:
        sample_image = Image.open(os.path.join(sample_path, sample_files[0]))
        
        # Show different augmentations
        augmentation_types = ['rotate_90', 'flip_horizontal', 'brightness_increase', 
                            'contrast_increase', 'saturation_increase']
        
        fig, axes = plt.subplots(2, 3, figsize=(15, 10))
        axes = axes.flatten()
        
        axes[0].imshow(sample_image)
        axes[0].set_title('Original')
        axes[0].axis('off')
        
        for idx, aug_type in enumerate(augmentation_types, 1):
            augmented = augment_image(sample_image, aug_type)
            axes[idx].imshow(augmented)
            axes[idx].set_title(aug_type.replace('_', ' ').title())
            axes[idx].axis('off')
        
        plt.tight_layout()
        plt.suptitle('Data Augmentation Examples', y=1.02, fontsize=16)
        plt.show()
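
Cropping appears in the list of techniques but is not one of the options handled by `augment_image`. The sketch below shows one possible way to pick a random crop region. The helper name `random_crop_box` and the default `scale` of 0.8 are assumptions made for illustration. It only computes the box, so it runs without an image file.

```python
import random

def random_crop_box(width, height, scale=0.8, rng=None):
    """Return a (left, upper, right, lower) box covering `scale` of each side."""
    if not 0 < scale <= 1:
        raise ValueError("scale must be in (0, 1]")
    rng = rng or random.Random()
    crop_w = max(1, int(width * scale))
    crop_h = max(1, int(height * scale))
    # Choose the top-left corner so that the box stays inside the image
    left = rng.randint(0, width - crop_w)
    upper = rng.randint(0, height - crop_h)
    return (left, upper, left + crop_w, upper + crop_h)

box = random_crop_box(640, 480, scale=0.8, rng=random.Random(0))
print(box)
```

With Pillow, the resulting tuple can be passed to `image.crop(box)` to produce the cropped copy.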

## 6. Upload Images with Augmentation

Now let's upload training images along with augmented versions to increase our training dataset size.

In [None]:
# Create tags
tag_names = ['apple', 'banana', 'orange']
tags = {}

for tag_name in tag_names:
    tag = training_client.create_tag(project.id, tag_name)
    tags[tag_name] = tag
    print(f"Created tag: {tag_name}")

def upload_images_with_augmentation(folder_path, project_id, tags_dict, augment=True, max_augmentations=3):
    """
    Upload images with optional data augmentation.
    
    Args:
        folder_path: Path to images folder
        project_id: Custom Vision project ID
        tags_dict: Dictionary of tag objects
        augment: Whether to apply data augmentation
        max_augmentations: Maximum number of augmented versions per image
    """
    augmentation_types = ['flip_horizontal', 'brightness_increase', 'brightness_decrease',
                         'contrast_increase', 'saturation_increase']
    
    for tag_name, tag in tags_dict.items():
        tag_folder = os.path.join(folder_path, tag_name)
        
        if not os.path.exists(tag_folder):
            continue
        
        image_files = [f for f in os.listdir(tag_folder) 
                      if f.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp'))]
        
        total_uploaded = 0
        
        for image_file in image_files:
            image_path = os.path.join(tag_folder, image_file)
            
            # Upload original image
            with open(image_path, "rb") as image_data:
                training_client.create_images_from_data(
                    project_id, 
                    image_data.read(), 
                    [tag.id]
                )
            total_uploaded += 1
            
            # Upload augmented versions
            if augment:
                # Convert to RGB so PNG images with an alpha channel or palette can be saved as JPEG
                with Image.open(image_path) as opened_image:
                    original_image = opened_image.convert('RGB')
                
                for aug_type in augmentation_types[:max_augmentations]:
                    augmented_image = augment_image(original_image, aug_type)
                    
                    # Convert to bytes
                    img_byte_arr = io.BytesIO()
                    augmented_image.save(img_byte_arr, format='JPEG')
                    
                    training_client.create_images_from_data(
                        project_id,
                        img_byte_arr.getvalue(),
                        [tag.id]
                    )
                    total_uploaded += 1
        
        print(f"✓ Uploaded {total_uploaded} images (including augmented) for tag '{tag_name}'")

# Upload with augmentation
print("Uploading images with data augmentation...\n")
upload_images_with_augmentation('training-images', project.id, tags, augment=True, max_augmentations=3)
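
Uploading one image per request is simple but slow for larger datasets. The training API also accepts batches of files, with a limit of 64 images per batch. The sketch below shows only the chunking step, using a hypothetical helper named `chunked` and strings as stand-ins for `ImageFileCreateEntry` objects. The commented lines indicate how each chunk might be sent with the classes already imported in this notebook.

```python
def chunked(items, size=64):
    """Yield successive lists of at most `size` items."""
    if size < 1:
        raise ValueError("size must be at least 1")
    for start in range(0, len(items), size):
        yield items[start:start + size]

# Stand-ins for a list of ImageFileCreateEntry objects
entries = [f"image_{i}.jpg" for i in range(130)]
batch_sizes = [len(batch) for batch in chunked(entries)]
print(batch_sizes)  # [64, 64, 2]

# In the notebook, each chunk would go out in one request, for example:
# for batch in chunked(entries):
#     result = training_client.create_images_from_files(
#         project.id, ImageFileCreateBatch(images=batch))
#     if not result.is_batch_successful:
#         print([image.status for image in result.images])
```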

## 7. Train and Publish Model

In [None]:
print("Training model with augmented dataset...\n")

iteration = training_client.train_project(project.id)

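# Poll until training finishes; stop on failure instead of looping forever
while iteration.status not in ("Completed", "Failed"):
    time.sleep(5)
    iteration = training_client.get_iteration(project.id, iteration.id)
    print(f"Training status: {iteration.status}")

if iteration.status != "Completed":
    raise RuntimeError(f"Training ended with status: {iteration.status}")

print(f"\n✓ Model trained successfully!")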

# Publish the model.
# publish_iteration requires the Azure resource ID of the prediction resource.
publish_name = "AdvancedFruitClassifier"
prediction_resource_id = os.getenv('PredictionResourceId', None)

if prediction_resource_id:
    try:
        training_client.publish_iteration(project.id, iteration.id, publish_name, prediction_resource_id)
        print(f"✓ Model published as '{publish_name}'")
    except Exception as e:
        print(f"Publishing failed: {e}")
else:
    print("⚠️  PredictionResourceId is not set in the .env file, so the model was not published.")
    print("   The prediction cells below require a published iteration.")
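
A polling loop like the one used for training benefits from an upper bound on the waiting time. The following is a minimal sketch of a generic wait helper. The name `wait_until` and its parameters are hypothetical, and the example polls a canned sequence of statuses instead of calling `training_client.get_iteration`.

```python
import time

def wait_until(poll, is_finished, timeout=600.0, interval=5.0,
               sleep=time.sleep, clock=time.monotonic):
    """Call poll() until is_finished(result) is true or `timeout` seconds pass."""
    deadline = clock() + timeout
    while True:
        result = poll()
        if is_finished(result):
            return result
        if clock() >= deadline:
            raise TimeoutError(f"Gave up after {timeout} seconds")
        sleep(interval)

# Stand-in for repeatedly reading iteration.status from the service
statuses = iter(["Training", "Training", "Completed"])
final = wait_until(
    poll=lambda: next(statuses),
    is_finished=lambda status: status in ("Completed", "Failed"),
    interval=0,
)
print(final)
```

Treating both `Completed` and `Failed` as terminal lets the caller decide what to do with a failed run instead of waiting on it.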

## 8. Advanced Performance Evaluation

Beyond basic precision and recall, let's examine advanced evaluation metrics including confusion matrices and per-class performance.

In [None]:
def evaluate_model_advanced(project_id, iteration_id, training_client):
    """
    Perform advanced model evaluation with detailed metrics.
    """
    performance = training_client.get_iteration_performance(project_id, iteration_id)
    
    print("Advanced Performance Metrics")
    print("=" * 70)
    
    # Overall metrics
    print(f"\nOverall Performance:")
    print(f"  Precision:          {performance.precision:.4f} ({performance.precision:.2%})")
    print(f"  Recall:             {performance.recall:.4f} ({performance.recall:.2%})")
    print(f"  Average Precision:  {performance.average_precision:.4f} ({performance.average_precision:.2%})")
    
    # Calculate F1 score
    if performance.precision + performance.recall > 0:
        f1_score = 2 * (performance.precision * performance.recall) / (performance.precision + performance.recall)
        print(f"  F1 Score:           {f1_score:.4f} ({f1_score:.2%})")
    
    # Per-tag detailed metrics
    print(f"\nPer-Tag Performance:")
    print("-" * 70)
    print(f"{'Tag':<15} {'Precision':<12} {'Recall':<12} {'AP':<12} {'F1':<12}")
    print("-" * 70)
    
    tag_metrics = []
    for tag_perf in performance.per_tag_performance:
        if tag_perf.precision + tag_perf.recall > 0:
            f1 = 2 * (tag_perf.precision * tag_perf.recall) / (tag_perf.precision + tag_perf.recall)
        else:
            f1 = 0
        
        print(f"{tag_perf.name:<15} {tag_perf.precision:<12.2%} {tag_perf.recall:<12.2%} "
              f"{tag_perf.average_precision:<12.2%} {f1:<12.2%}")
        
        tag_metrics.append({
            'name': tag_perf.name,
            'precision': tag_perf.precision,
            'recall': tag_perf.recall,
            'f1': f1
        })
    
    # Visualize metrics comparison
    fig, axes = plt.subplots(1, 2, figsize=(15, 5))
    
    # Bar chart of metrics
    tag_names_eval = [tm['name'] for tm in tag_metrics]
    precisions = [tm['precision'] for tm in tag_metrics]
    recalls = [tm['recall'] for tm in tag_metrics]
    f1_scores = [tm['f1'] for tm in tag_metrics]
    
    x = np.arange(len(tag_names_eval))
    width = 0.25
    
    axes[0].bar(x - width, precisions, width, label='Precision', color='skyblue')
    axes[0].bar(x, recalls, width, label='Recall', color='lightcoral')
    axes[0].bar(x + width, f1_scores, width, label='F1 Score', color='lightgreen')
    axes[0].set_xlabel('Tags')
    axes[0].set_ylabel('Score')
    axes[0].set_title('Performance Metrics by Tag')
    axes[0].set_xticks(x)
    axes[0].set_xticklabels(tag_names_eval)
    axes[0].legend()
    axes[0].set_ylim([0, 1.1])
    axes[0].grid(axis='y', alpha=0.3)
    
    # Radar chart: swap the second Cartesian axes for a polar one so they do not overlap
    axes[1].remove()
    if len(tag_metrics) >= 3:
        angles = np.linspace(0, 2 * np.pi, len(tag_names_eval), endpoint=False).tolist()
        precisions_plot = precisions + [precisions[0]]
        recalls_plot = recalls + [recalls[0]]
        angles_plot = angles + [angles[0]]
        
        ax = fig.add_subplot(1, 2, 2, projection='polar')
        ax.plot(angles_plot, precisions_plot, 'o-', linewidth=2, label='Precision', color='skyblue')
        ax.fill(angles_plot, precisions_plot, alpha=0.25, color='skyblue')
        ax.plot(angles_plot, recalls_plot, 'o-', linewidth=2, label='Recall', color='lightcoral')
        ax.fill(angles_plot, recalls_plot, alpha=0.25, color='lightcoral')
        ax.set_xticks(angles)
        ax.set_xticklabels(tag_names_eval)
        ax.set_ylim(0, 1)
        ax.set_title('Performance Radar Chart')
        ax.legend(loc='upper right', bbox_to_anchor=(1.3, 1.1))
        ax.grid(True)
    
    plt.tight_layout()
    plt.show()
    
    return performance

performance = evaluate_model_advanced(project.id, iteration.id, training_client)
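
The F1 score reported above is the harmonic mean of precision and recall, so it stays low whenever either of the two is low. The short example below illustrates this with made-up numbers; the function name `harmonic_f1` is chosen here only for illustration.

```python
def harmonic_f1(precision, recall):
    """Harmonic mean of precision and recall; 0.0 when both are zero."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)

# Balanced precision and recall give an F1 close to both values
print(f"{harmonic_f1(0.8, 0.6):.4f}")    # 0.6857

# A very low recall drags F1 down even though precision is high
print(f"{harmonic_f1(0.95, 0.10):.4f}")  # 0.1810
```

For comparison, the arithmetic mean of 0.95 and 0.10 is 0.525, which would hide the poor recall.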

## 9. Confusion Matrix Analysis

A confusion matrix helps visualize where the model makes mistakes. Let's create one by testing on multiple images.

In [None]:
def create_confusion_matrix(test_folder, project_id, publish_name, prediction_client, tags_dict):
    """
    Create confusion matrix by testing on images organized in folders by true label.
    """
    y_true = []
    y_pred = []
    
    print("Testing model and building confusion matrix...\n")
    
    for tag_name in tags_dict.keys():
        tag_folder = os.path.join(test_folder, tag_name)
        
        if not os.path.exists(tag_folder):
            continue
        
        test_images = [f for f in os.listdir(tag_folder) 
                      if f.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp'))]
        
        for test_image in test_images:
            image_path = os.path.join(tag_folder, test_image)
            
            with open(image_path, "rb") as image_data:
                results = prediction_client.classify_image(project_id, publish_name, image_data.read())
            
            # Get top prediction
            if results.predictions:
                top_prediction = max(results.predictions, key=lambda p: p.probability)
                y_true.append(tag_name)
                y_pred.append(top_prediction.tag_name)
    
    if not y_true:
        print("No test images found")
        return
    
    # Create confusion matrix
    labels = sorted(tags_dict.keys())
    cm = confusion_matrix(y_true, y_pred, labels=labels)
    
    # Plot confusion matrix
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                xticklabels=labels, yticklabels=labels,
                cbar_kws={'label': 'Count'})
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.title('Confusion Matrix')
    plt.tight_layout()
    plt.show()
    
    # Print classification report
    print("\nClassification Report:")
    print("=" * 70)
    print(classification_report(y_true, y_pred, labels=labels, target_names=labels))
    
    # Calculate and display accuracy
    accuracy = np.trace(cm) / np.sum(cm)
    print(f"\nOverall Accuracy: {accuracy:.2%}")
    
    return cm

# Create confusion matrix if test images are available
test_images_path = 'test-images'
if os.path.exists(test_images_path):
    cm = create_confusion_matrix(test_images_path, project.id, publish_name, prediction_client, tags)
else:
    print(f"Test images not found at {test_images_path}")
    print("Organize test images in folders by category to generate confusion matrix")
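
To make the layout of the matrix concrete, here is a small dependency-free sketch that builds one by hand from made-up labels. Rows correspond to true labels and columns to predicted labels, the same orientation used in the heatmap above. The function name `build_confusion_matrix` is hypothetical.

```python
def build_confusion_matrix(y_true, y_pred, labels):
    """Rows are true labels, columns are predicted labels."""
    index = {label: i for i, label in enumerate(labels)}
    matrix = [[0] * len(labels) for _ in labels]
    for truth, guess in zip(y_true, y_pred):
        matrix[index[truth]][index[guess]] += 1
    return matrix

sample_labels = ["apple", "banana", "orange"]
sample_true = ["apple", "apple", "banana", "orange", "orange"]
sample_pred = ["apple", "banana", "banana", "orange", "apple"]

matrix = build_confusion_matrix(sample_true, sample_pred, sample_labels)
for label, row in zip(sample_labels, matrix):
    print(f"{label:<8} {row}")

# Correct predictions sit on the diagonal
correct = sum(matrix[i][i] for i in range(len(sample_labels)))
print(f"accuracy = {correct / len(sample_true):.2f}")
```

In this toy data, one apple is mistaken for a banana and one orange for an apple, which shows up as the two off-diagonal counts.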

## 10. Multi-Label Classification

Unlike multi-class classification (one label per image), **multi-label classification** allows assigning multiple tags to a single image.

For example, an image might contain both an apple and a banana.

### Creating a Multi-Label Project:

In [None]:
# Create a multi-label classification project
multilabel_project_name = f"Multi-Label Fruit {uuid.uuid4().hex[:8]}"

multilabel_project = training_client.create_project(
    name=multilabel_project_name,
    description="Multi-label fruit classification - images can have multiple fruit types",
    domain_id=food_domain.id,
    classification_type="Multilabel"  # Key difference!
)

print(f"✓ Created multi-label project: {multilabel_project.name}")
print(f"  Project ID: {multilabel_project.id}")
print(f"  Classification Type: Multilabel")
print(f"\nIn this project, images can be tagged with multiple labels simultaneously.")
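
With a multi-label model, taking only the single highest-scoring tag discards information, because several tags can apply at once. One common approach, sketched below with made-up probabilities, is to keep every tag whose probability clears a threshold. The helper name `tags_above_threshold` and the 0.5 default are assumptions for illustration.

```python
def tags_above_threshold(predictions, threshold=0.5):
    """Return (tag, probability) pairs at or above threshold, highest first."""
    kept = [(tag, prob) for tag, prob in predictions if prob >= threshold]
    return sorted(kept, key=lambda pair: pair[1], reverse=True)

# Made-up scores for an image that shows an apple and a banana
sample_predictions = [("banana", 0.78), ("apple", 0.91), ("orange", 0.04)]
print(tags_above_threshold(sample_predictions))
```

In the notebook, the input pairs would come from `(p.tag_name, p.probability)` for each entry in `results.predictions`.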

## 11. Batch Predictions

The prediction API classifies one image per request, so batch processing here means looping over a folder, calling the API once per image, and aggregating the results.

In [None]:
def batch_predict_images(image_folder, project_id, publish_name, prediction_client, 
                        confidence_threshold=0.5, max_images=10):
    """
    Perform batch predictions on multiple images.
    
    Args:
        image_folder: Path to folder containing images
        project_id: Custom Vision project ID
        publish_name: Published model name
        prediction_client: Authenticated prediction client
        confidence_threshold: Minimum confidence to report
        max_images: Maximum number of images to process
    
    Returns:
        Dictionary of results
    """
    results_dict = defaultdict(list)
    
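    image_files = [f for f in os.listdir(image_folder) 
                  if f.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp'))]
    
    # Only files directly inside image_folder are considered; subfolders are not searched
    if not image_files:
        print(f"No images found in {image_folder}")
        return {}
    
    print(f"Processing {min(len(image_files), max_images)} images...\n")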
    
    start_time = time.time()
    
    for idx, image_file in enumerate(image_files[:max_images]):
        image_path = os.path.join(image_folder, image_file)
        
        with open(image_path, "rb") as image_data:
            results = prediction_client.classify_image(
                project_id, 
                publish_name, 
                image_data.read()
            )
        
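        # Skip images for which the service returned no predictions
        if not results.predictions:
            print(f"⚠ {idx+1}/{min(len(image_files), max_images)}: {image_file} -> no predictions returned")
            continue
        
        # Store results
        top_prediction = max(results.predictions, key=lambda p: p.probability)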
        
        if top_prediction.probability >= confidence_threshold:
            results_dict[top_prediction.tag_name].append({
                'file': image_file,
                'confidence': top_prediction.probability
            })
            status = "✓"
        else:
            results_dict['low_confidence'].append({
                'file': image_file,
                'confidence': top_prediction.probability,
                'predicted': top_prediction.tag_name
            })
            status = "⚠"
        
        print(f"{status} {idx+1}/{min(len(image_files), max_images)}: {image_file} -> "
              f"{top_prediction.tag_name} ({top_prediction.probability:.2%})")
    
    elapsed_time = time.time() - start_time
    
    # Summary statistics
    print(f"\n" + "=" * 70)
    print(f"Batch Processing Summary")
    print("=" * 70)
    print(f"Total images processed: {min(len(image_files), max_images)}")
    print(f"Total time: {elapsed_time:.2f} seconds")
    print(f"Average time per image: {elapsed_time / min(len(image_files), max_images):.2f} seconds")
    print(f"\nResults by category:")
    
    for category, items in results_dict.items():
        if category != 'low_confidence':
            avg_conf = np.mean([item['confidence'] for item in items])
            print(f"  {category}: {len(items)} images (avg confidence: {avg_conf:.2%})")
    
    if 'low_confidence' in results_dict:
        print(f"  Low confidence: {len(results_dict['low_confidence'])} images")
    
    return dict(results_dict)

# Run batch predictions
if os.path.exists(test_images_path):
    batch_results = batch_predict_images(
        test_images_path, 
        project.id, 
        publish_name, 
        prediction_client,
        confidence_threshold=0.7
    )
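
Because each image is a separate request, the requests are independent and can overlap in time. The sketch below shows one possible pattern using a thread pool from the standard library. The function `fake_classify` is a stand-in for a call to `prediction_client.classify_image`, and `classify_many` is a hypothetical helper. Whether concurrent calls are appropriate depends on the rate limits of your pricing tier, which this sketch does not handle.

```python
from concurrent.futures import ThreadPoolExecutor

def classify_many(items, classify_one, max_workers=4):
    """Apply classify_one to every item concurrently, preserving input order."""
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(classify_one, items))

def fake_classify(file_name):
    """Stand-in for a network call that classifies one image."""
    label = file_name.split("_")[0]
    return {"file": file_name, "label": label}

sample_files = ["apple_1.jpg", "banana_1.jpg", "orange_1.jpg"]
sample_results = classify_many(sample_files, fake_classify)
print([r["label"] for r in sample_results])
```

`pool.map` returns results in input order, so each result lines up with the file it came from.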

## 12. Model Export for Edge Deployment

Azure Custom Vision allows exporting models for edge deployment, enabling offline predictions on devices.

### Supported Export Formats:
- **TensorFlow**: For general ML applications
- **CoreML**: For iOS devices
- **ONNX**: Cross-platform format
- **Dockerfile**: Containerized deployment

**Note**: Only models trained with "compact" domains can be exported.

In [None]:
def export_model(project_id, iteration_id, platform, training_client):
    """
    Export model for edge deployment.
    
    Args:
        project_id: Custom Vision project ID
        iteration_id: Iteration to export
        platform: Export platform (e.g., 'TensorFlow', 'CoreML', 'ONNX')
        training_client: Authenticated training client
    """
    try:
        print(f"Attempting to export model to {platform} format...")
        
        # Check if iteration can be exported
        iteration_details = training_client.get_iteration(project_id, iteration_id)
        
        if not iteration_details.exportable:
            print(f"\n⚠️  This iteration cannot be exported.")
            print(f"To export models, create a project with a 'compact' domain.")
            print(f"\nAvailable compact domains:")
            
            domains = training_client.get_domains()
            for domain in domains:
                if 'compact' in domain.name.lower() and domain.exportable:
                    print(f"  - {domain.name}")
            return
        
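        # Export the model
        export = training_client.export_iteration(project_id, iteration_id, platform)
        
        # Wait for export to complete
        while export.status == "Exporting":
            print(f"Export status: {export.status}...")
            time.sleep(5)
            # get_exports returns every export of this iteration, so match on platform
            exports = training_client.get_exports(project_id, iteration_id)
            export = next((e for e in exports if e.platform.lower() == platform.lower()), exports[0])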
        
        if export.status == "Done":
            print(f"\n✓ Model exported successfully!")
            print(f"  Download URL: {export.download_uri}")
            print(f"  Platform: {export.platform}")
            print(f"\n⚠️  Download link expires after a certain time period.")
        else:
            print(f"\n❌ Export failed with status: {export.status}")
    
    except Exception as e:
        print(f"Export error: {e}")

# Attempt to export (will show instructions if not using compact domain)
export_model(project.id, iteration.id, 'TensorFlow', training_client)
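
Since export requires a compact domain, the domain has to be chosen before the project is created. The sketch below shows one way to pick an exportable classification domain from a list of domain objects. It uses `SimpleNamespace` stand-ins that mimic only the `name`, `type`, and `exportable` attributes used earlier in this notebook. The helper name `pick_exportable_domain` is hypothetical.

```python
from types import SimpleNamespace

def pick_exportable_domain(domains, preferred="General (compact)"):
    """Return the preferred exportable classification domain, or any exportable one."""
    candidates = [d for d in domains if d.type == "Classification" and d.exportable]
    for domain in candidates:
        if domain.name == preferred:
            return domain
    return candidates[0] if candidates else None

# Stand-ins mimicking objects returned by training_client.get_domains()
sample_domains = [
    SimpleNamespace(name="Food", type="Classification", exportable=False),
    SimpleNamespace(name="Food (compact)", type="Classification", exportable=True),
    SimpleNamespace(name="General (compact)", type="Classification", exportable=True),
]
print(pick_exportable_domain(sample_domains).name)
```

The `id` of the selected domain would then be passed as `domain_id` to `create_project`.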

## 13. Model Optimization Strategies

### Probability Threshold Tuning

Adjusting the probability threshold affects the precision-recall tradeoff:
- **Higher threshold** (e.g., 0.8): Higher precision, lower recall (fewer false positives)
- **Lower threshold** (e.g., 0.3): Higher recall, lower precision (fewer false negatives)

In [None]:
def analyze_threshold_impact(test_folder, project_id, publish_name, prediction_client, tags_dict):
    """
    Analyze the impact of different probability thresholds on classification.
    """
    # Collect all predictions
    all_predictions = []
    
    for tag_name in tags_dict.keys():
        tag_folder = os.path.join(test_folder, tag_name)
        
        if not os.path.exists(tag_folder):
            continue
        
        test_images = [f for f in os.listdir(tag_folder) 
                      if f.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp'))]
        
        for test_image in test_images[:5]:  # Limit for demo
            image_path = os.path.join(tag_folder, test_image)
            
            with open(image_path, "rb") as image_data:
                results = prediction_client.classify_image(project_id, publish_name, image_data.read())
            
            if results.predictions:
                top_pred = max(results.predictions, key=lambda p: p.probability)
                all_predictions.append({
                    'true_label': tag_name,
                    'predicted_label': top_pred.tag_name,
                    'confidence': top_pred.probability
                })
    
    if not all_predictions:
        print("No predictions to analyze")
        return
    
    # Test different thresholds
    thresholds = [0.3, 0.5, 0.7, 0.9]
    results_by_threshold = []
    
    for threshold in thresholds:
        correct = 0
        total = 0
        rejected = 0
        
        for pred in all_predictions:
            if pred['confidence'] >= threshold:
                total += 1
                if pred['true_label'] == pred['predicted_label']:
                    correct += 1
            else:
                rejected += 1
        
        accuracy = correct / total if total > 0 else 0
        acceptance_rate = total / len(all_predictions)
        
        results_by_threshold.append({
            'threshold': threshold,
            'accuracy': accuracy,
            'acceptance_rate': acceptance_rate,
            'rejected': rejected
        })
    
    # Display results
    print("Threshold Impact Analysis")
    print("=" * 70)
    print(f"{'Threshold':<12} {'Accuracy':<12} {'Acceptance':<15} {'Rejected'}")
    print("-" * 70)
    
    for result in results_by_threshold:
        print(f"{result['threshold']:<12.1f} {result['accuracy']:<12.2%} "
              f"{result['acceptance_rate']:<15.2%} {result['rejected']}")
    
    # Visualize
    fig, ax1 = plt.subplots(figsize=(10, 6))
    
    ax1.set_xlabel('Confidence Threshold')
    ax1.set_ylabel('Accuracy', color='blue')
    ax1.plot([r['threshold'] for r in results_by_threshold],
             [r['accuracy'] for r in results_by_threshold],
             'o-', color='blue', linewidth=2, markersize=8, label='Accuracy')
    ax1.tick_params(axis='y', labelcolor='blue')
    
    ax2 = ax1.twinx()
    ax2.set_ylabel('Acceptance Rate', color='red')
    ax2.plot([r['threshold'] for r in results_by_threshold],
             [r['acceptance_rate'] for r in results_by_threshold],
             's-', color='red', linewidth=2, markersize=8, label='Acceptance Rate')
    ax2.tick_params(axis='y', labelcolor='red')
    
    plt.title('Threshold Impact: Accuracy vs Acceptance Rate')
    fig.tight_layout()
    plt.grid(True, alpha=0.3)
    plt.show()

# Analyze threshold impact
if os.path.exists(test_images_path):
    analyze_threshold_impact(test_images_path, project.id, publish_name, prediction_client, tags)
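
The analysis above reports accuracy and acceptance rate. To see the precision-recall tradeoff itself, the sketch below computes both metrics for a single tag at two thresholds, using made-up scores. Each pair holds the predicted probability for the tag and whether the image truly carries it. The function name `precision_recall_at` is hypothetical.

```python
def precision_recall_at(scored, threshold):
    """scored: list of (probability, is_positive) pairs for a single tag."""
    tp = sum(1 for prob, positive in scored if prob >= threshold and positive)
    fp = sum(1 for prob, positive in scored if prob >= threshold and not positive)
    fn = sum(1 for prob, positive in scored if prob < threshold and positive)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall

scored = [(0.95, True), (0.85, True), (0.60, False),
          (0.55, True), (0.40, False), (0.20, True)]

for threshold in (0.3, 0.8):
    precision, recall = precision_recall_at(scored, threshold)
    print(f"threshold={threshold:.1f} precision={precision:.2f} recall={recall:.2f}")
# threshold=0.3 precision=0.60 recall=0.75
# threshold=0.8 precision=1.00 recall=0.50
```

Raising the threshold from 0.3 to 0.8 removes both false positives but also drops one true positive, which is the tradeoff described in this section.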

## 14. Production Deployment Best Practices

### Key Considerations:

1. **Model Versioning**: Track iterations and publish names
2. **Monitoring**: Log predictions and confidence scores
3. **Fallback Logic**: Handle low-confidence predictions
4. **Continuous Improvement**: Collect production data for retraining
5. **Error Handling**: Implement retry logic and timeouts
6. **Performance**: Reuse one client instance and process images in bulk where rate limits allow
7. **Security**: Protect API keys, use managed identities
8. **Cost Management**: Monitor API usage and optimize calls

### Example Production Code Structure:

In [None]:
class ProductionClassifier:
    """
    Production-ready image classifier with error handling and logging.
    """
    
    def __init__(self, endpoint, key, project_id, model_name, 
                 confidence_threshold=0.7, retry_attempts=3):
        self.endpoint = endpoint
        self.project_id = project_id
        self.model_name = model_name
        self.confidence_threshold = confidence_threshold
        self.retry_attempts = retry_attempts
        
        credentials = ApiKeyCredentials(in_headers={"Prediction-key": key})
        self.client = CustomVisionPredictionClient(endpoint, credentials)
        
        # Tracking
        self.prediction_count = 0
        self.low_confidence_count = 0
        self.error_count = 0
    
    def classify(self, image_data, return_all_predictions=False):
        """
        Classify an image with error handling and retry logic.
        
        Args:
            image_data: Image bytes
            return_all_predictions: Whether to return all predictions or just top one
        
        Returns:
            Dictionary with prediction results and metadata
        """
        for attempt in range(self.retry_attempts):
            try:
                results = self.client.classify_image(
                    self.project_id, 
                    self.model_name, 
                    image_data
                )
                
                self.prediction_count += 1
                
                if not results.predictions:
                    return {
                        'success': False,
                        'error': 'No predictions returned',
                        'confidence': 0
                    }
                
                top_prediction = max(results.predictions, key=lambda p: p.probability)
                
                # Check confidence
                if top_prediction.probability < self.confidence_threshold:
                    self.low_confidence_count += 1
                
                if return_all_predictions:
                    predictions = [
                        {'label': p.tag_name, 'confidence': p.probability}
                        for p in results.predictions
                    ]
                else:
                    predictions = {
                        'label': top_prediction.tag_name,
                        'confidence': top_prediction.probability
                    }
                
                return {
                    'success': True,
                    'prediction': predictions,
                    'meets_threshold': top_prediction.probability >= self.confidence_threshold,
                    'iteration_id': results.iteration
                }
            
            except Exception as e:
                self.error_count += 1
                if attempt == self.retry_attempts - 1:
                    return {
                        'success': False,
                        'error': str(e),
                        'attempts': attempt + 1
                    }
                time.sleep(2 ** attempt)  # Exponential backoff: 1s, 2s, 4s, ...
    
    def get_stats(self):
        """Get classifier statistics."""
        return {
            'total_predictions': self.prediction_count,
            'low_confidence': self.low_confidence_count,
            'low_confidence_rate': self.low_confidence_count / max(1, self.prediction_count),
            'errors': self.error_count,
            'error_rate': self.error_count / max(1, self.prediction_count + self.error_count)
        }

# Example usage
print("Production Classifier Example:")
print("=" * 70)

prod_classifier = ProductionClassifier(
    prediction_endpoint,
    prediction_key,
    project.id,
    publish_name,
    confidence_threshold=0.7
)

# Test with a few images
if os.path.exists(test_images_path):
    test_files = [f for f in os.listdir(test_images_path) 
                  if f.lower().endswith(('.jpg', '.jpeg', '.png'))][:3]
    
    for test_file in test_files:
        with open(os.path.join(test_images_path, test_file), 'rb') as f:
            image_data = f.read()
        
        result = prod_classifier.classify(image_data)
        
        if result['success']:
            pred = result['prediction']
            threshold_met = "✓" if result['meets_threshold'] else "⚠"
            print(f"{threshold_met} {test_file}: {pred['label']} ({pred['confidence']:.2%})")
        else:
            print(f"❌ {test_file}: Error - {result['error']}")
    
    # Print statistics
    stats = prod_classifier.get_stats()
    print(f"\nClassifier Statistics:")
    print(f"  Total predictions: {stats['total_predictions']}")
    print(f"  Low confidence: {stats['low_confidence']} ({stats['low_confidence_rate']:.1%})")
    print(f"  Errors: {stats['errors']} ({stats['error_rate']:.1%})")
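
The fallback logic mentioned in the considerations above can be layered on top of the dictionaries returned by `ProductionClassifier.classify`. The sketch below is one possible routing rule, not a prescribed design. The action names and the function `route_result` are hypothetical, and the example dictionaries are hand-written to match the shape returned by the class.

```python
def route_result(result):
    """Map a classify-style result dict to a follow-up action."""
    if not result.get("success"):
        return "retry_later"
    if not result.get("meets_threshold"):
        return "manual_review"
    return "accept"

examples = [
    {"success": True, "meets_threshold": True,
     "prediction": {"label": "apple", "confidence": 0.93}},
    {"success": True, "meets_threshold": False,
     "prediction": {"label": "banana", "confidence": 0.41}},
    {"success": False, "error": "timeout", "attempts": 3},
]
print([route_result(r) for r in examples])
```

Keeping the routing separate from the classifier makes it easy to change the policy, for example to send low-confidence images to a review queue, without touching the API code.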

## 15. Summary

### Advanced Concepts Covered:
✓ Transfer learning and domain selection  
✓ Data augmentation for improved performance  
✓ Advanced evaluation metrics (F1, confusion matrix)  
✓ Multi-label vs multi-class classification  
✓ Batch predictions for efficiency  
✓ Model export for edge deployment  
✓ Threshold tuning for precision-recall tradeoff  
✓ Production-ready deployment patterns  

### Key Takeaways:
1. **Domain selection** can significantly impact model performance
2. **Data augmentation** helps with limited training data
3. **Evaluation metrics** beyond accuracy provide deeper insights
4. **Threshold tuning** allows balancing precision and recall
5. **Production deployment** requires error handling and monitoring

### Next Steps:
- Explore Lab 05: Object Detection for locating and classifying objects
- Experiment with different domains and augmentation strategies
- Implement active learning to improve models over time
- Deploy models to edge devices using exported formats

## Cleanup

In [None]:
# Uncomment to delete projects
# training_client.delete_project(project.id)
# training_client.delete_project(multilabel_project.id)
# print("Projects deleted")