# Lab 04: Image Classification with Azure Custom Vision

In this lab, you'll learn how to train a custom image classification model using Azure Custom Vision. Image classification is the task of assigning a label or category to an entire image based on its content.

## Learning Objectives

By the end of this lab, you will be able to:
- Create and configure an Azure Custom Vision project for image classification
- Upload and tag training images
- Train a custom classification model
- Test and evaluate model predictions
- Iterate and improve model performance

## Prerequisites

- An Azure subscription with a Custom Vision resource already created
- Training and prediction keys from the Azure portal
- Training images organized into one folder per category

## 1. Setup and Installation

First, let's install the required packages for working with Azure Custom Vision.

In [None]:
# Install required packages
!pip install azure-cognitiveservices-vision-customvision python-dotenv pillow matplotlib

## 2. Import Required Libraries

Import the necessary libraries for authentication, training, and prediction.

In [None]:
from azure.cognitiveservices.vision.customvision.training import CustomVisionTrainingClient
from azure.cognitiveservices.vision.customvision.prediction import CustomVisionPredictionClient
from azure.cognitiveservices.vision.customvision.training.models import ImageFileCreateBatch, ImageFileCreateEntry
from msrest.authentication import ApiKeyCredentials
from dotenv import load_dotenv
from PIL import Image
import matplotlib.pyplot as plt
import time
import os
import uuid

print("Libraries imported successfully!")

## 3. Configure Azure Custom Vision Credentials

Load your Azure Custom Vision credentials from a `.env` file. This file should contain:
- `TrainingEndpoint`: Your Custom Vision training endpoint
- `TrainingKey`: Your Custom Vision training key
- `PredictionEndpoint`: Your Custom Vision prediction endpoint
- `PredictionKey`: Your Custom Vision prediction key
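- `PredictionResourceId`: Azure resource ID of your prediction resource, read in step 10 when publishing the model
- `ProjectID`: (Optional) ID of an existing project; leave blank to create a new project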

In [None]:
# Load environment variables from .env file
# You can use the .env file from the python/train-classifier/ directory
env_path = 'python/train-classifier/.env'
load_dotenv(env_path)

# Get configuration settings
training_endpoint = os.getenv('TrainingEndpoint')
training_key = os.getenv('TrainingKey')
prediction_endpoint = os.getenv('PredictionEndpoint', training_endpoint)
prediction_key = os.getenv('PredictionKey', training_key)
project_id = os.getenv('ProjectID', None)

print(f"Training Endpoint: {training_endpoint}")
print(f"Prediction Endpoint: {prediction_endpoint}")
print(f"Project ID: {project_id if project_id else 'Will create new project'}")
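
A missing or misspelled key in `.env` otherwise only surfaces later as an authentication error, so it can help to check the settings right after loading them. The sketch below shows one way to do that. `find_missing_settings` and `REQUIRED_SETTINGS` are hypothetical names introduced for illustration, not part of the SDK, and the example dictionary stands in for the real environment.

```python
import os

# Setting names this lab reads from .env. The prediction values fall back to
# the training values in the cell above, so only the training pair is listed.
REQUIRED_SETTINGS = ["TrainingEndpoint", "TrainingKey"]


def find_missing_settings(required, env=None):
    """Return the names in `required` that are unset or blank in `env`."""
    env = os.environ if env is None else env
    return [name for name in required if not (env.get(name) or "").strip()]


# Example with an explicit dictionary instead of the real environment
example_env = {"TrainingEndpoint": "https://example.invalid/", "TrainingKey": ""}
missing = find_missing_settings(REQUIRED_SETTINGS, example_env)
if missing:
    print(f"Missing settings: {', '.join(missing)}")
```

Calling `find_missing_settings(REQUIRED_SETTINGS)` with no second argument checks the process environment that `load_dotenv` populated.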

## 4. Authenticate Training and Prediction Clients

Create authenticated clients for both training and prediction operations.

In [None]:
# Authenticate training client
training_credentials = ApiKeyCredentials(in_headers={"Training-key": training_key})
training_client = CustomVisionTrainingClient(training_endpoint, training_credentials)

# Authenticate prediction client
prediction_credentials = ApiKeyCredentials(in_headers={"Prediction-key": prediction_key})
prediction_client = CustomVisionPredictionClient(prediction_endpoint, prediction_credentials)

print("Clients authenticated successfully!")

## 5. Create or Get Custom Vision Project

Create a new Custom Vision project for image classification, or connect to an existing one.

In [None]:
# Create new project or get existing one
if project_id:
    # Get existing project
    print(f"Connecting to existing project: {project_id}")
    project = training_client.get_project(project_id)
else:
    # Create new project
    project_name = f"Fruit Classification {uuid.uuid4().hex[:8]}"
    print(f"Creating new project: {project_name}")
    project = training_client.create_project(
        name=project_name,
        description="Classification of fruits (apple, banana, orange)",
        domain_id=None  # Use default domain
    )
    project_id = project.id

print(f"\nProject Details:")
print(f"  Name: {project.name}")
print(f"  ID: {project.id}")
print(f"  Description: {project.description}")
print(f"\n⚠️  Save this Project ID for future use: {project.id}")

## 6. Create Tags for Classification

Create tags (labels) for each category of images you want to classify. For this lab, we'll classify fruits: apple, banana, and orange.

In [None]:
# Define the classification categories
tag_names = ['apple', 'banana', 'orange']

# Get existing tags or create new ones
existing_tags = training_client.get_tags(project.id)
existing_tag_names = [tag.name for tag in existing_tags]

tags = {}
for tag_name in tag_names:
    if tag_name in existing_tag_names:
        # Get existing tag
        tag = next(t for t in existing_tags if t.name == tag_name)
        print(f"Found existing tag: {tag_name}")
    else:
        # Create new tag
        tag = training_client.create_tag(project.id, tag_name)
        print(f"Created new tag: {tag_name}")
    tags[tag_name] = tag

print(f"\nTotal tags: {len(tags)}")

## 7. Upload and Tag Training Images

Upload training images from the `training-images/` directory. Each subdirectory represents a category, and images within it will be tagged accordingly.

**Note**: Azure Custom Vision requires at least 5 images per tag for training.
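
Because of this minimum, it can be worth counting the images in each tag folder before uploading anything. The following is a minimal sketch of such a pre-flight check. The helper names are hypothetical, and the folder layout (one subfolder per tag under `training-images`) is the one described above.

```python
import os

IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp')
MIN_IMAGES_PER_TAG = 5  # minimum stated in the note above


def count_images_per_tag(folder_path, tag_names):
    """Count image files in each tag subfolder; missing folders count as 0."""
    counts = {}
    for tag_name in tag_names:
        tag_folder = os.path.join(folder_path, tag_name)
        if os.path.isdir(tag_folder):
            counts[tag_name] = sum(
                1 for f in os.listdir(tag_folder)
                if f.lower().endswith(IMAGE_EXTENSIONS)
            )
        else:
            counts[tag_name] = 0
    return counts


def tags_below_minimum(counts, minimum=MIN_IMAGES_PER_TAG):
    """Return the tags that do not yet have enough images to train."""
    return sorted(tag for tag, n in counts.items() if n < minimum)


counts = count_images_per_tag('training-images', ['apple', 'banana', 'orange'])
print(counts)
print("Tags needing more images:", tags_below_minimum(counts))
```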

In [None]:
def upload_images_from_folder(folder_path, project_id, tags_dict):
    """
    Upload images from a folder structure where each subdirectory is a tag name.
    
    Args:
        folder_path: Path to the parent folder containing category subdirectories
        project_id: Custom Vision project ID
        tags_dict: Dictionary mapping tag names to tag objects
    """
    print(f"Uploading images from: {folder_path}\n")
    
    for tag_name, tag in tags_dict.items():
        tag_folder = os.path.join(folder_path, tag_name)
        
        if not os.path.exists(tag_folder):
            print(f"⚠️  Warning: Folder not found for tag '{tag_name}': {tag_folder}")
            continue
        
        # Get all image files in the folder
        image_files = [f for f in os.listdir(tag_folder) 
                      if f.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp'))]
        
        if len(image_files) < 5:
            print(f"⚠️  Warning: Only {len(image_files)} images found for '{tag_name}'. "
                  f"Training requires at least 5 images per tag.")
        
        # Upload images in batches (max 64 per batch)
        batch_size = 64
        uploaded_count = 0
        
        for i in range(0, len(image_files), batch_size):
            batch_files = image_files[i:i + batch_size]
            image_list = []
            
            for image_file in batch_files:
                image_path = os.path.join(tag_folder, image_file)
                with open(image_path, "rb") as image_data:
                    image_list.append(ImageFileCreateEntry(
                        name=image_file,
                        contents=image_data.read(),
                        tag_ids=[tag.id]
                    ))
            
            # Upload the batch
            upload_result = training_client.create_images_from_files(
                project_id, 
                ImageFileCreateBatch(images=image_list)
            )
            
            if upload_result.is_batch_successful:
                uploaded_count += len(batch_files)
            else:
                print(f"  ⚠️  Some images in batch failed to upload for '{tag_name}'")
                for image in upload_result.images:
                    if image.status == "OK":
                        # Count the images that uploaded even though the batch had failures
                        uploaded_count += 1
                    else:
                        print(f"    - {image.source_url}: {image.status}")

        print(f"✓ Uploaded {uploaded_count} images for tag '{tag_name}'")

    print("\nUpload finished.")

# Upload training images
training_images_path = 'training-images'
upload_images_from_folder(training_images_path, project.id, tags)
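
Each entry in `upload_result.images` carries its own `status` string, so a per-status tally gives a clearer picture of a batch than a single success flag. The sketch below illustrates the idea with stand-in objects rather than a real upload. `summarize_statuses` is a hypothetical helper, and the status strings shown are examples whose exact set should be checked against the SDK reference.

```python
from collections import Counter
from types import SimpleNamespace


def summarize_statuses(image_results):
    """Tally the status strings of a batch's per-image results."""
    return Counter(result.status for result in image_results)


# Stand-in objects shaped like the entries of upload_result.images
# (illustrative values only, not output from a real upload)
fake_results = [
    SimpleNamespace(status="OK"),
    SimpleNamespace(status="OK"),
    SimpleNamespace(status="OKDuplicate"),
    SimpleNamespace(status="ErrorImageFormat"),
]

summary = summarize_statuses(fake_results)
for status, count in summary.most_common():
    print(f"{status:20} : {count}")
```

With a real upload, `summarize_statuses(upload_result.images)` would produce the same kind of tally.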

## 8. Verify Uploaded Images

Check that images were uploaded correctly and display summary statistics.

In [None]:
# Get image count for each tag
print("Image Count Summary:")
print("-" * 40)

total_images = 0
for tag_name, tag in tags.items():
    tag_info = training_client.get_tag(project.id, tag.id)
    print(f"{tag_name:15} : {tag_info.image_count} images")
    total_images += tag_info.image_count

print("-" * 40)
print(f"{'Total':15} : {total_images} images")

# Display a few sample images
print("\nSample Training Images:")
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

for idx, tag_name in enumerate(tag_names):
    tag_folder = os.path.join(training_images_path, tag_name)
    if os.path.exists(tag_folder):
        image_files = [f for f in os.listdir(tag_folder) 
                      if f.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp'))]
        if image_files:
            sample_image = Image.open(os.path.join(tag_folder, image_files[0]))
            axes[idx].imshow(sample_image)
            axes[idx].set_title(tag_name.capitalize())
            axes[idx].axis('off')

plt.tight_layout()
plt.show()

## 9. Train the Classification Model

Train the image classification model using the uploaded training images. This process may take several minutes.

**Training Process:**
1. The model learns patterns from your training images
2. It associates visual features with each tag
3. The training status is monitored until completion

In [None]:
def train_model(project_id, training_client):
    """
    Train the Custom Vision model and monitor progress.
    
    Args:
        project_id: Custom Vision project ID
        training_client: Authenticated training client
    
    Returns:
        Completed iteration object
    """
    print("Starting model training...")
    print("This may take several minutes. Please wait...\n")
    
    # Start training
    iteration = training_client.train_project(project_id)
    
    # Monitor training status
    while iteration.status != "Completed":
        iteration = training_client.get_iteration(project_id, iteration.id)
        print(f"Training status: {iteration.status}")
        
        if iteration.status == "Failed":
            print("❌ Training failed!")
            return None
        
        time.sleep(5)
    
    print(f"\n✓ Model trained successfully!")
    print(f"  Iteration Name: {iteration.name}")
    print(f"  Iteration ID: {iteration.id}")
    
    return iteration

# Train the model
iteration = train_model(project.id, training_client)
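
The polling loop in `train_model` keeps running until the status becomes `Completed` or `Failed`. If you would rather give up after a fixed time, a timeout can be layered on top. The sketch below is one possible shape for that, not the SDK's own mechanism. `wait_for_status`, the 30-minute default, and the `"TimedOut"` return value are all assumptions made for illustration, and the example runs against a fake status source.

```python
import time


def wait_for_status(get_status, done="Completed", failed="Failed",
                    timeout_seconds=1800, poll_seconds=5, sleep=time.sleep):
    """Poll get_status() until it returns `done` or `failed`, or time runs out.

    Returns the final status string, or "TimedOut" if the deadline passes.
    """
    deadline = time.monotonic() + timeout_seconds
    while True:
        status = get_status()
        if status in (done, failed):
            return status
        if time.monotonic() >= deadline:
            return "TimedOut"
        sleep(poll_seconds)


# Example with a fake status source instead of a real training run;
# the no-op sleep keeps the example instant
fake_statuses = iter(["Training", "Training", "Completed"])
final_status = wait_for_status(lambda: next(fake_statuses), sleep=lambda s: None)
print(final_status)
```

In this notebook the status source could be `lambda: training_client.get_iteration(project.id, iteration.id).status`, which is the same call the loop above already makes.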

## 10. Publish the Model for Prediction

Publish the trained iteration to make it available for predictions. Publishing takes a publish name (also called the model name) and the Azure resource ID of your prediction resource.

In [None]:
# publish_iteration requires the Azure resource ID of the prediction resource
# (Azure portal > your prediction resource > Properties > Resource ID)
publish_name = "FruitClassifier"
prediction_resource_id = os.getenv('PredictionResourceId', None)

if iteration:
    if not prediction_resource_id:
        print("PredictionResourceId is not set in your .env file.")
        print("Add it and rerun this cell, or publish manually through the Custom Vision portal.")
    else:
        try:
            # Publish the iteration to the prediction resource
            training_client.publish_iteration(
                project.id, 
                iteration.id, 
                publish_name, 
                prediction_resource_id
            )
            print(f"✓ Model published successfully as '{publish_name}'")
            print(f"\n⚠️  Save this publish name for predictions: {publish_name}")
        except Exception as e:
            print(f"Error publishing model: {e}")
            print("You may need to publish manually through the Custom Vision portal.")

## 11. Test the Model with Predictions

Now let's test the trained model by making predictions on test images. The model will classify each image and return prediction scores for each tag.

In [None]:
def classify_image(image_path, project_id, publish_name, prediction_client, confidence_threshold=0.5):
    """
    Classify an image using the trained model.
    
    Args:
        image_path: Path to the image file
        project_id: Custom Vision project ID
        publish_name: Published model name
        prediction_client: Authenticated prediction client
        confidence_threshold: Minimum confidence score to display (0-1)
    
    Returns:
        Prediction results
    """
    with open(image_path, "rb") as image_data:
        results = prediction_client.classify_image(
            project_id, 
            publish_name, 
            image_data.read()
        )
    
    # Display image
    img = Image.open(image_path)
    plt.figure(figsize=(8, 6))
    plt.imshow(img)
    plt.axis('off')
    plt.title(f"Test Image: {os.path.basename(image_path)}")
    plt.show()
    
    # Display predictions
    print(f"\nPredictions for {os.path.basename(image_path)}:")
    print("-" * 50)
    
    predictions_found = False
    for prediction in results.predictions:
        if prediction.probability >= confidence_threshold:
            print(f"{prediction.tag_name:15} : {prediction.probability:.2%} confidence")
            predictions_found = True
    
    if not predictions_found:
        print(f"No predictions above {confidence_threshold:.0%} confidence threshold")
        print("\nAll predictions:")
        for prediction in results.predictions:
            print(f"{prediction.tag_name:15} : {prediction.probability:.2%} confidence")
    
    return results

# Test with images from the test-images folder
test_images_path = 'test-images'

if os.path.exists(test_images_path):
    test_images = [f for f in os.listdir(test_images_path) 
                   if f.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp'))]
    
    print(f"Found {len(test_images)} test images\n")
    
    for test_image in test_images[:3]:  # Test first 3 images
        image_path = os.path.join(test_images_path, test_image)
        classify_image(image_path, project.id, publish_name, prediction_client)
        print("\n" + "="*50 + "\n")
else:
    print(f"Test images folder not found: {test_images_path}")
    print("You can use images from python/test-classifier/test-images/")

## 12. Evaluate Model Performance

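Retrieve the metrics Custom Vision computed for the training iteration. Precision is the share of predicted tags that were correct; recall is the share of actual tags that the model found.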

In [None]:
if iteration:
    # Get iteration performance
    performance = training_client.get_iteration_performance(project.id, iteration.id)
    
    print("Model Performance Metrics:")
    print("=" * 50)
    print(f"\nOverall Metrics:")
    print(f"  Precision: {performance.precision:.2%}")
    print(f"  Recall: {performance.recall:.2%}")
    print(f"  Average Precision: {performance.average_precision:.2%}")
    
    print(f"\nPer-Tag Performance:")
    print("-" * 50)
    
    for tag_performance in performance.per_tag_performance:
        print(f"\n{tag_performance.name}:")
        print(f"  Precision: {tag_performance.precision:.2%}")
        print(f"  Recall: {tag_performance.recall:.2%}")
        print(f"  Average Precision: {tag_performance.average_precision:.2%}")
    
    # Visualize performance
    tag_names_perf = [tp.name for tp in performance.per_tag_performance]
    precisions = [tp.precision for tp in performance.per_tag_performance]
    recalls = [tp.recall for tp in performance.per_tag_performance]
    
    fig, ax = plt.subplots(figsize=(10, 6))
    x = range(len(tag_names_perf))
    width = 0.35
    
    ax.bar([i - width/2 for i in x], precisions, width, label='Precision', color='skyblue')
    ax.bar([i + width/2 for i in x], recalls, width, label='Recall', color='lightcoral')
    
    ax.set_xlabel('Tags')
    ax.set_ylabel('Score')
    ax.set_title('Model Performance by Tag')
    ax.set_xticks(x)
    ax.set_xticklabels(tag_names_perf)
    ax.legend()
    ax.set_ylim([0, 1.1])
    
    plt.tight_layout()
    plt.show()

## 13. Improving Model Performance

If your model's performance isn't satisfactory, here are strategies to improve it:

### Add More Training Images

More diverse training data typically leads to better performance. Let's add more images from the `more-training-images` folder.

In [None]:
# Upload additional training images
more_images_path = 'python/train-classifier/more-training-images'

if os.path.exists(more_images_path):
    print("Uploading additional training images...\n")
    upload_images_from_folder(more_images_path, project.id, tags)
    
    # Verify new image counts
    print("\nUpdated Image Count:")
    print("-" * 40)
    total_images = 0
    for tag_name, tag in tags.items():
        tag_info = training_client.get_tag(project.id, tag.id)
        print(f"{tag_name:15} : {tag_info.image_count} images")
        total_images += tag_info.image_count
    print("-" * 40)
    print(f"{'Total':15} : {total_images} images")
else:
    print(f"Additional images folder not found: {more_images_path}")

### Retrain with More Data

After adding more images, retrain the model to see if performance improves.

In [None]:
# Retrain the model with additional data
print("Retraining model with additional images...\n")
new_iteration = train_model(project.id, training_client)

if new_iteration:
    # Publish the new iteration (requires the prediction resource ID, as in step 10)
    new_publish_name = "FruitClassifier_v2"
    
    if not prediction_resource_id:
        print("PredictionResourceId is not set; skipping publish.")
    else:
        try:
            training_client.publish_iteration(
                project.id, 
                new_iteration.id, 
                new_publish_name, 
                prediction_resource_id
            )
            print(f"✓ New model published as '{new_publish_name}'")
        except Exception as e:
            print(f"Error publishing model: {e}")
    
    # Compare performance
    new_performance = training_client.get_iteration_performance(project.id, new_iteration.id)
    
    print("\nPerformance Comparison:")
    print("=" * 60)
    print(f"{'Metric':<20} {'Original':<15} {'After Retraining':<15} {'Change'}")
    print("-" * 60)
    print(f"{'Precision':<20} {performance.precision:<15.2%} {new_performance.precision:<15.2%} "
          f"{(new_performance.precision - performance.precision):+.2%}")
    print(f"{'Recall':<20} {performance.recall:<15.2%} {new_performance.recall:<15.2%} "
          f"{(new_performance.recall - performance.recall):+.2%}")
    print(f"{'Avg Precision':<20} {performance.average_precision:<15.2%} "
          f"{new_performance.average_precision:<15.2%} "
          f"{(new_performance.average_precision - performance.average_precision):+.2%}")

## 14. Best Practices and Tips

### Image Quality Guidelines:
- **Variety**: Include images with different angles, lighting, and backgrounds
- **Quality**: Use clear, well-lit images (at least 256 pixels on the shortest edge)
- **Quantity**: At least 50 images per tag for good performance
- **Balance**: Try to have similar numbers of images for each tag

### Model Improvement Strategies:
1. **Add negative examples**: Include images that contain none of your categories, so the model learns what to reject
2. **Review misclassifications**: Identify patterns in errors
3. **Adjust probability threshold**: Balance precision vs recall
4. **Use appropriate domain**: Choose specialized domains (Food, Retail, etc.)
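
To make the precision versus recall trade-off in item 3 concrete, the sketch below recomputes both metrics for a single tag at several thresholds. The scores are made up for illustration, `precision_recall` is a hypothetical helper, and treating precision as 1.0 when nothing is predicted positive is just one common convention.

```python
def precision_recall(scored_examples, threshold):
    """Compute precision and recall for one tag at a probability threshold.

    scored_examples: list of (probability, is_positive) pairs, where
    is_positive says whether the image truly belongs to the tag.
    """
    tp = sum(1 for p, pos in scored_examples if p >= threshold and pos)
    fp = sum(1 for p, pos in scored_examples if p >= threshold and not pos)
    fn = sum(1 for p, pos in scored_examples if p < threshold and pos)
    # By convention here, an empty denominator yields 1.0
    precision = tp / (tp + fp) if (tp + fp) else 1.0
    recall = tp / (tp + fn) if (tp + fn) else 1.0
    return precision, recall


# Made-up scores for an "apple" tag, for illustration only
examples = [(0.95, True), (0.80, True), (0.60, False), (0.55, True), (0.30, False)]

for threshold in (0.5, 0.7, 0.9):
    precision, recall = precision_recall(examples, threshold)
    print(f"threshold {threshold:.1f}: precision {precision:.2f}, recall {recall:.2f}")
```

On this toy data, raising the threshold from 0.5 to 0.7 removes the one false positive (precision rises from 0.75 to 1.00) but also drops a true apple (recall falls from 1.00 to 0.67).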

### Production Considerations:
- Monitor prediction confidence scores
- Implement fallback logic for low confidence predictions
- Regularly retrain with new data
- Version control your models (track publish names)
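
As one example of fallback logic for low-confidence predictions, the sketch below returns a placeholder label whenever the top prediction falls under a chosen confidence. `choose_label`, the 0.7 cutoff, and the `"unknown"` label are assumptions made for illustration; a suitable cutoff depends on your own data.

```python
def choose_label(predictions, min_confidence=0.7, fallback="unknown"):
    """Return the most probable tag, or `fallback` if it is not confident enough.

    predictions: list of (tag_name, probability) pairs.
    """
    if not predictions:
        return fallback, 0.0
    tag_name, probability = max(predictions, key=lambda item: item[1])
    if probability < min_confidence:
        return fallback, probability
    return tag_name, probability


# Made-up prediction scores, for illustration only
confident = [("apple", 0.92), ("banana", 0.05), ("orange", 0.03)]
uncertain = [("apple", 0.40), ("banana", 0.35), ("orange", 0.25)]

print(choose_label(confident))
print(choose_label(uncertain))
```

With real results from `classify_image`, the pairs can be built as `[(p.tag_name, p.probability) for p in results.predictions]`.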

## 15. Summary and Next Steps

### What You've Learned:
✓ Created an Azure Custom Vision classification project  
✓ Uploaded and tagged training images  
✓ Trained and published a custom classification model  
✓ Made predictions on test images  
✓ Evaluated model performance  
✓ Improved the model with additional training data  

### Next Steps:
- Explore the **Advanced Image Classification** notebook for:
  - Transfer learning concepts
  - Data augmentation techniques
  - Advanced evaluation metrics
  - Model export for edge deployment
  - Batch predictions

### Important Information to Save:
```python
# Save these for future use:
# Project ID: [YOUR_PROJECT_ID]
# Published Model Name: [YOUR_MODEL_NAME]
```
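
One lightweight way to keep these values is to write them to a small JSON file next to the notebook. This is a sketch only: the helper names and the `lab04_state.json` file name are hypothetical, and the example writes placeholder values to a temporary directory.

```python
import json
import os
import tempfile


def save_lab_state(path, project_id, publish_name):
    """Write the project ID and publish name to a small JSON file."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump({"project_id": project_id, "publish_name": publish_name}, f, indent=2)


def load_lab_state(path):
    """Read back the values written by save_lab_state."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


# Example with placeholder values written to a temporary directory
with tempfile.TemporaryDirectory() as tmp_dir:
    state_path = os.path.join(tmp_dir, "lab04_state.json")
    save_lab_state(state_path, "[YOUR_PROJECT_ID]", "FruitClassifier")
    restored = load_lab_state(state_path)

print(restored)
```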

### Cleanup (Optional):
```python
# Uncomment to delete the project
# training_client.delete_project(project.id)
# print("Project deleted")
```

## Additional Resources

- [Azure Custom Vision Documentation](https://docs.microsoft.com/azure/cognitive-services/custom-vision-service/)
- [Custom Vision Python SDK Reference](https://docs.microsoft.com/python/api/azure-cognitiveservices-vision-customvision/)
- [Image Classification Best Practices](https://docs.microsoft.com/azure/cognitive-services/custom-vision-service/getting-started-improving-your-classifier)