In [None]:
import boto3
import json
from PIL import Image
import io
import os
import random
import matplotlib.pyplot as plt

# Create a SageMaker runtime client, shared by test_prediction() below to
# invoke the deployed endpoint. No region or credentials are passed here, so
# boto3 falls back to whatever its default configuration provides.
runtime = boto3.client('sagemaker-runtime')

def test_prediction(image_path, endpoint_name='cat-dog-classifier-v1'):
    """
    Test the endpoint with an image

    Displays the image, sends the file's bytes to the SageMaker endpoint
    and prints the class probabilities found in the response.

    Args:
        image_path (str): Path to the image file
        endpoint_name (str): Name of the SageMaker endpoint to invoke.
            Defaults to 'cat-dog-classifier-v1'.
    Returns:
        dict: Parsed JSON response. It is expected to contain the keys
            'cat_probability' and 'dog_probability'.
    Raises:
        KeyError: If the response lacks one of the probability keys.
    """
    # Display the image. The context manager closes the underlying file
    # handle, which Image.open() would otherwise keep open.
    with Image.open(image_path) as img:
        plt.figure(figsize=(6, 6))
        plt.imshow(img)
        plt.axis('off')
        plt.show()

    # Send the file exactly as stored on disk. Re-saving through PIL would
    # re-encode it (lossy for JPEG), so the model would score different
    # pixels than the ones in the file.
    with open(image_path, 'rb') as image_file:
        image_data = image_file.read()

    # Call the endpoint
    response = runtime.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType='application/x-image',
        Body=image_data
    )

    # Parse the response (the body is a stream holding a JSON document)
    result = json.loads(response['Body'].read().decode())

    # Print formatted results
    print("\nPrediction Results:")
    print("-----------------")
    print(f"Cat Probability: {result['cat_probability']:.2%}")
    print(f"Dog Probability: {result['dog_probability']:.2%}")

    return result

def get_random_test_image(base_path='../data/sample_data'):
    """
    Get a random image from the sample data

    Args:
        base_path (str): Directory containing the 'cats' and 'dogs'
            sub-folders. Defaults to '../data/sample_data'.
    Returns:
        tuple: (image_path, true_label), where true_label is 'cat' or 'dog'
    Raises:
        FileNotFoundError: If the chosen category folder does not exist or
            contains no usable files.
    """
    # Folder name -> label; explicit mapping instead of stripping a suffix
    labels = {'cats': 'cat', 'dogs': 'dog'}

    # Randomly choose cat or dog
    category = random.choice(list(labels))
    category_path = os.path.join(base_path, category)

    # Keep only regular, non-hidden files: os.listdir() also returns entries
    # such as '.DS_Store' or a '.ipynb_checkpoints' directory, which are not
    # images and would break the prediction call later on.
    images = [
        name for name in os.listdir(category_path)
        if not name.startswith('.')
        and os.path.isfile(os.path.join(category_path, name))
    ]
    if not images:
        raise FileNotFoundError(f"No image files found in {category_path!r}")

    # Choose a random image
    image_name = random.choice(images)
    return os.path.join(category_path, image_name), labels[category]

In [None]:
# Test with a random image: pick a random cat/dog sample together with its
# true label, report which file is being used, then send it to the endpoint.
image_path, true_label = get_random_test_image()
print(f"Testing with a {true_label} image: {image_path}")
# test_prediction() displays the image and prints the probabilities; the
# parsed response is kept in `result` for further inspection.
result = test_prediction(image_path)