In [None]:
import torch
from torchvision import transforms
from PIL import Image
import torch.nn as nn
from torchvision.models import resnet50, ResNet50_Weights
import torch.nn.functional as F
import json

# --- Load class mappings from saved file ---
CLASS_MAPPING_PATH = "class_mapping.json"
# Fallback source for the same mapping: the complete training checkpoint,
# which stores it under its 'class_mapping' key.
COMPLETE_CHECKPOINT_PATH = 'pose_classifier_complete.pth'


def _parse_class_mapping(mapping):
    """Unpack a saved class-mapping dict into (class_to_idx, idx_to_class, class_names).

    ``idx_to_class`` keys are converted to ``int`` because they may be stored as
    strings (JSON object keys always are); ``int()`` is a no-op for int keys.

    Raises:
        KeyError: if 'class_to_idx', 'idx_to_class' or 'classes' is missing.
        ValueError: if some index in 0..len(classes)-1 has no label in
            ``idx_to_class`` -- the classifier head has one output per class,
            so such an index could be predicted but never translated to a label.
    """
    c2i = mapping['class_to_idx']
    i2c = {int(k): v for k, v in mapping['idx_to_class'].items()}
    names = mapping['classes']

    missing = sorted(set(range(len(names))) - set(i2c))
    if missing:
        raise ValueError(
            f"idx_to_class has no label for class indices {missing}; "
            f"expected every index 0..{len(names) - 1}"
        )
    return c2i, i2c, names


try:
    with open(CLASS_MAPPING_PATH, 'r') as f:
        class_mapping = json.load(f)
    class_to_idx, idx_to_class, class_names = _parse_class_mapping(class_mapping)

    print(f"Loaded class mappings from {CLASS_MAPPING_PATH}")
    print(f"Available classes: {class_names}")
    print(f"Class to index mapping: {class_to_idx}")

except FileNotFoundError:
    print(f"Warning: {CLASS_MAPPING_PATH} not found. Trying to load from complete checkpoint...")

    # Alternative: Load from complete checkpoint
    try:
        checkpoint = torch.load(COMPLETE_CHECKPOINT_PATH, map_location='cpu')
    except FileNotFoundError as err:
        raise FileNotFoundError("Neither class_mapping.json nor pose_classifier_complete.pth found. Please run training first.") from err

    class_mapping = checkpoint['class_mapping']
    class_to_idx, idx_to_class, class_names = _parse_class_mapping(class_mapping)
    print("Loaded class mappings from complete checkpoint")

# --- Pose name mapping (human-readable names) ---
# Maps raw class labels (as produced by idx_to_class) to display names.
# Labels without an entry are rendered as "Unknown Pose (<label>)" by the
# prediction helpers.
# NOTE(review): there are no entries for pose_9 / pose_10 -- confirm whether
# those classes exist in the training data.
pose_name_map = dict(
    pose_1="Front Relaxed",
    pose_2="Back Relaxed",
    pose_3="Quarter Turn (Left)",
    pose_4="Quarter Turn (Right)",
    pose_5="Back Double Biceps",
    pose_6="Front Double Biceps",
    pose_7="Front Lat Spread",
    pose_8="Side Chest (Left)",
    pose_11="Abs & Thighs",
)

# --- Config ---
MODEL_PATH = "pose_classifier.pth"

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
_use_cuda = torch.cuda.is_available()
DEVICE = torch.device('cuda') if _use_cuda else torch.device('cpu')

# Evaluation preprocessing, intended to mirror the validation transforms used
# during training: resize to exactly 224x224 (aspect ratio is not preserved),
# convert to a tensor, then normalize each channel with the ImageNet mean/std.
_eval_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transform = transforms.Compose(_eval_steps)

# Load model architecture & weights.
# Kept so any later cell that references `weights` still works; it is
# deliberately NOT passed to the constructor below.
weights = ResNet50_Weights.DEFAULT
# Build the backbone without pretrained weights: the strict load_state_dict
# below overwrites every parameter and buffer anyway, so fetching the ImageNet
# weights first only costs download time and breaks offline inference.
model = resnet50(weights=None)
# Swap the ImageNet head for one output per pose class so the layer shapes
# match the fine-tuned checkpoint.
model.fc = nn.Linear(model.fc.in_features, len(class_names))
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.to(DEVICE)
model.eval()  # inference behaviour for BatchNorm / Dropout layers

def predict_pose(image_path):
    """Classify the pose shown in a single image.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        Tuple ``(pose_name, confidence, class_name)``:
          - ``pose_name``: display name from ``pose_name_map``, or
            ``"Unknown Pose (<class_name>)"`` when the label has no entry.
          - ``confidence``: softmax probability of the predicted class (float).
          - ``class_name``: raw label looked up in ``idx_to_class``.
    """
    # Open as a context manager so the file handle is released deterministically;
    # Pillow can keep it open after load() for multi-frame formats.
    with Image.open(image_path) as image:
        input_tensor = transform(image.convert("RGB")).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        output = model(input_tensor)
        probabilities = F.softmax(output, dim=1)  # logits -> class probabilities

    predicted_idx = probabilities.argmax(dim=1).item()
    confidence = probabilities[0, predicted_idx].item()

    predicted_class = idx_to_class[predicted_idx]
    predicted_pose_name = pose_name_map.get(predicted_class, f"Unknown Pose ({predicted_class})")

    return predicted_pose_name, confidence, predicted_class

def predict_pose_with_top_k(image_path, k=3):
    """Get top k predictions with confidence scores.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.
        k: Number of predictions to return. Values larger than the number of
            classes are clamped to it (``torch.topk`` would otherwise raise).

    Returns:
        List of up to ``k`` dicts with keys 'pose_name', 'class_name' and
        'confidence', ordered from most to least confident.

    Raises:
        ValueError: if ``k`` is negative.
    """
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")

    # Context manager releases the file handle even for multi-frame images.
    with Image.open(image_path) as image:
        input_tensor = transform(image.convert("RGB")).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        probabilities = F.softmax(model(input_tensor), dim=1)
        # Asking for more predictions than there are classes returns all classes.
        k = min(k, probabilities.size(1))
        top_k_probs, top_k_indices = torch.topk(probabilities, k)

    results = []
    for prob, idx in zip(top_k_probs[0].tolist(), top_k_indices[0].tolist()):
        class_name = idx_to_class[idx]
        results.append({
            'pose_name': pose_name_map.get(class_name, f"Unknown Pose ({class_name})"),
            'class_name': class_name,
            'confidence': prob,
        })

    return results

# Example usage: classify one sample image and report the top prediction.
SAMPLE_IMAGE = "test_images/cbum.png"
pose, confidence, class_name = predict_pose(SAMPLE_IMAGE)
print(f"Predicted Pose: {pose} (Class: {class_name}, Confidence: {confidence:.2%})")

Predicted Pose: Front Lat Spread (Confidence: 27.15%)


In [4]:
# NOTE(review): the recorded output shows 'nbstripout' is not recognized, so the
# tool is not installed in this kernel's environment and this cell does nothing.
# Presumably the intent is to strip outputs before committing -- that is usually
# done from a terminal or a git filter rather than from inside the notebook.
# TODO confirm the intended invocation; `--strip-files` may not be a valid flag.
!nbstripout --strip-files

'nbstripout' is not recognized as an internal or external command,
operable program or batch file.
