# Rotation Classification Inference

This notebook runs the trained rotation-classification model on every image in a specified folder.
By default it plots only images predicted to need rotation (a non-zero angle). Images predicted as 0° are still counted in the summary.

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.models import resnet18
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import glob

## Configuration

In [None]:
# Configuration
MODEL_PATH = "pipeline/checkpoints/best_model.pth"  # checkpoint produced by training
IMAGE_SIZE = 300  # inputs are resized to IMAGE_SIZE x IMAGE_SIZE (see transform below)

# Run on the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
_cuda_visible = torch.cuda.is_available()
DEVICE = torch.device("cuda" if _cuda_visible else "cpu")

# Label for each output index of the classifier (rotation angle in degrees).
# NOTE(review): this order presumably mirrors the class ordering used at
# training time (it looks like a lexicographic sort of folder names) —
# confirm against the training script; a mismatch silently mislabels output.
CLASS_NAMES = ['0', '180', '270', '90']

print(f"Using device: {DEVICE}")

## Model Loading

In [None]:
def load_model(model_path, num_classes, device):
    """Build a ResNet-18 classifier and load trained weights into it.

    Args:
        model_path: Path to a file saved with ``torch.save``. Either a dict
            holding the weights under ``'model_state_dict'`` (optionally with
            ``'epoch'`` / ``'val_acc'`` metadata) or a bare state dict.
        num_classes: Number of outputs of the replacement ``fc`` layer.
        device: Device used as ``map_location`` and to host the model.

    Returns:
        The loaded model on ``device``, switched to eval mode.
    """
    model = resnet18()
    # Swap the stock classification head for one sized to our classes.
    model.fc = nn.Linear(model.fc.in_features, num_classes)

    # NOTE: torch.load unpickles the file — only load checkpoints you trust.
    # On newer PyTorch (weights_only=True by default) any metadata stored in
    # the checkpoint must be plain tensors/Python types.
    checkpoint = torch.load(model_path, map_location=device)

    # Accept both the training-script format and a bare state dict.
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)

    model.to(device)
    model.eval()

    print(f"Model loaded successfully from {model_path}")
    # Metadata is informational only: a checkpoint without it must not make
    # loading fail after the weights were restored successfully.
    metadata = checkpoint if isinstance(checkpoint, dict) else {}
    if 'epoch' in metadata:
        print(f"Model was trained for {metadata['epoch']} epochs")
    if 'val_acc' in metadata:
        print(f"Best validation accuracy: {metadata['val_acc']:.2f}%")

    return model

# Instantiate the classifier once; the inference cell below reuses this global `model`.
model = load_model(MODEL_PATH, num_classes=len(CLASS_NAMES), device=DEVICE)

## Image Preprocessing

In [None]:
# Standard ImageNet per-channel mean / std.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Inference preprocessing. NOTE(review): this is intended to match the
# training-time (evaluation) preprocessing — confirm against the training script.
transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),  # exact square size; aspect ratio not preserved
    transforms.ToTensor(),                        # PIL image -> float CHW tensor in [0, 1]
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
])

def preprocess_image(image_path):
    """Load an image and turn it into a single-item model input batch.

    Args:
        image_path: Path to an image file readable by PIL.

    Returns:
        ``(image, image_tensor)``: the RGB ``PIL.Image`` (kept for display) and
        the output of the module-level ``transform`` with a leading batch
        dimension of size 1.
    """
    # convert() returns a fully loaded copy, so the source file can be closed
    # immediately instead of keeping its handle open until garbage collection
    # (which leaks descriptors when iterating over large folders).
    with Image.open(image_path) as raw_image:
        image = raw_image.convert('RGB')
    image_tensor = transform(image).unsqueeze(0)  # add batch dimension
    return image, image_tensor

## Inference Function

In [None]:
def predict_rotation(model, image_tensor, device):
    """Classify one preprocessed image batch.

    Args:
        model: Callable returning logits of shape (batch, num_classes).
        image_tensor: Model input; moved to ``device`` before the forward pass.
        device: Device to run inference on.

    Returns:
        ``(class_index, confidence, probabilities)``: the arg-max class as an
        int, its softmax probability as a float, and the softmax vector as a
        squeezed NumPy array on the CPU.
    """
    with torch.no_grad():
        logits = model(image_tensor.to(device))
        class_probs = F.softmax(logits, dim=1)
        best_prob, best_index = class_probs.max(dim=1)
    # The tensors were produced under no_grad, so converting them outside the
    # context is equivalent.
    return best_index.item(), best_prob.item(), class_probs.squeeze().cpu().numpy()

def visualize_prediction(image, image_path, predicted_class, confidence, probabilities, class_names):
    """Plot an image beside its class-probability bar chart and print a summary.

    Args:
        image: Image passed to ``imshow`` (the RGB PIL image from preprocessing).
        image_path: Source path; only its base name is displayed.
        predicted_class: Index of the predicted class into ``class_names``.
        confidence: Probability of the predicted class.
        probabilities: Per-class probabilities, aligned with ``class_names``.
        class_names: Label for each class index (rotation angle in degrees).
    """
    file_name = os.path.basename(image_path)
    predicted_label = class_names[predicted_class]

    _, (image_ax, prob_ax) = plt.subplots(1, 2, figsize=(12, 5))

    # Left panel: the input image.
    image_ax.imshow(image)
    image_ax.set_title(f"Image: {file_name}")
    image_ax.axis('off')

    # Right panel: probability for every class.
    prob_ax.bar(class_names, probabilities)
    prob_ax.set_title(f"Prediction: {predicted_label}° (Confidence: {confidence:.3f})")
    prob_ax.set_ylabel('Probability')
    prob_ax.set_xlabel('Rotation Angle')
    # Overdraw the winning class's bar in red so it stands out.
    prob_ax.bar(predicted_label, probabilities[predicted_class], color='red', alpha=0.7)

    plt.tight_layout()
    plt.show()

    rounded = {label: f'{p:.3f}' for label, p in zip(class_names, probabilities)}
    print(f"File: {file_name}")
    print(f"Predicted rotation: {predicted_label}°")
    print(f"Confidence: {confidence:.3f}")
    print(f"All probabilities: {rounded}")
    print("-" * 50)

## Batch Inference on Folder

In [None]:
def process_folder(folder_path, model, device, class_names, show_only_non_zero=True):
    """Run rotation inference on every image directly inside ``folder_path``.

    Args:
        folder_path: Directory to scan (non-recursive).
        model: Model forwarded to ``predict_rotation``.
        device: Device forwarded to ``predict_rotation``.
        class_names: Label for each class index; the label ``'0'`` is treated
            as "no rotation needed".
        show_only_non_zero: If True, only visualize images whose predicted
            label is not ``'0'``; otherwise visualize every image.

    Returns:
        A list with one dict per successfully processed image, with keys
        ``'path'``, ``'image'``, ``'predicted_class'``, ``'confidence'`` and
        ``'probabilities'``. It is empty when no images are found. Images
        that fail are reported and skipped. NOTE: every loaded PIL image is
        kept in the results, so memory grows with folder size.
    """
    # Supported image extensions, compared case-insensitively.
    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif'}

    # A single directory scan with a case-insensitive suffix check. Globbing
    # each pattern in lower- and upper-case had three problems:
    #   - it missed mixed-case names such as "photo.Jpg";
    #   - it listed every file twice on case-insensitive filesystems
    #     (e.g. Windows);
    #   - it also matched directories.
    image_files = []
    try:
        with os.scandir(folder_path) as entries:
            for entry in entries:
                # Skip hidden entries (e.g. macOS "._*" metadata files), as
                # glob's "*" did, and anything that is not a regular file.
                if entry.name.startswith('.') or not entry.is_file():
                    continue
                if os.path.splitext(entry.name)[1].lower() in image_extensions:
                    image_files.append(os.path.join(folder_path, entry.name))
    except (FileNotFoundError, NotADirectoryError):
        # Same outcome as the previous glob-based lookup: nothing found.
        pass
    image_files.sort()

    if not image_files:
        print(f"No image files found in {folder_path}")
        return []

    print(f"Found {len(image_files)} images in {folder_path}")
    print("=" * 60)

    results = []

    for image_path in image_files:
        try:
            image, image_tensor = preprocess_image(image_path)
            predicted_class, confidence, probabilities = predict_rotation(model, image_tensor, device)

            results.append({
                'path': image_path,
                'image': image,
                'predicted_class': predicted_class,
                'confidence': confidence,
                'probabilities': probabilities,
            })

            # Plot every image, or only those predicted to need rotation.
            if not show_only_non_zero or class_names[predicted_class] != '0':
                visualize_prediction(image, image_path, predicted_class, confidence, probabilities, class_names)

        except Exception as e:
            # Best-effort batch run: report the failing file and keep going.
            print(f"Error processing {image_path}: {str(e)}")

    return results

## Run Inference

In [7]:
# Update the INPUT_FOLDER path to your actual folder
BASE_FOLDER = "data/rotation/batches/"
ADDENDUM = "/images/boxes"
FOLDER = "task_lyd batch 26_backup_2025_07_18_13_47_53_COCO"
# Plain concatenation on purpose: BASE_FOLDER ends with "/" and ADDENDUM
# starts with "/" (os.path.join would discard everything before ADDENDUM).
INPUT_FOLDER = f"{BASE_FOLDER}{FOLDER}{ADDENDUM}"


if os.path.exists(INPUT_FOLDER):
    # Run the model over the folder, plotting only images that need rotation.
    results = process_folder(INPUT_FOLDER, model, DEVICE, CLASS_NAMES, show_only_non_zero=True)

    if results:
        banner = "=" * 60
        print("\n" + banner)
        print("SUMMARY")
        print(banner)

        total_images = len(results)
        # Predicted-label histogram, keyed in CLASS_NAMES order.
        class_counts = dict.fromkeys(CLASS_NAMES, 0)
        for res in results:
            class_counts[CLASS_NAMES[res['predicted_class']]] += 1

        print(f"Total images processed: {total_images}")
        print("\nPrediction distribution:")
        for label, n_images in class_counts.items():
            print(f"  {label}°: {n_images} images ({n_images / total_images * 100:.1f}%)")

        # Every result is counted under exactly one label, so the non-zero
        # count is the total minus the '0' bucket.
        non_zero_count = total_images - class_counts.get('0', 0)
        print(f"\nImages requiring rotation: {non_zero_count} ({(non_zero_count/total_images)*100:.1f}%)")
else:
    print(f"Error: Folder '{INPUT_FOLDER}' does not exist.")

KeyboardInterrupt: 

## Optional: Show All Results (including 0° predictions)

In [None]:
# Uncomment and run this cell to also display the 0° predictions that the main run skipped (non-zero predictions were already shown above, so together this covers every image)

# show_all = input("Do you want to see ALL predictions including 0° rotations? (y/n): ")
# if show_all.lower() == 'y':
#     print("\n" + "=" * 60)
#     print("SHOWING ALL PREDICTIONS (INCLUDING 0° ROTATIONS)")
#     print("=" * 60)
#     
#     if 'results' in locals():
#         for result in results:
#             if CLASS_NAMES[result['predicted_class']] == '0':  # Show only 0° predictions this time
#                 visualize_prediction(
#                     result['image'], 
#                     result['path'], 
#                     result['predicted_class'], 
#                     result['confidence'], 
#                     result['probabilities'], 
#                     CLASS_NAMES
#                 )