# In [None]:
import torch
from torchvision import transforms
from PIL import Image
import torch.nn as nn
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.datasets import ImageFolder
import torch.nn.functional as F  

# Recover the label ordering used at training time. ImageFolder assigns an
# integer index to each class sub-directory of this folder (torchvision sorts
# the folder names, so e.g. 'pose_11' sorts before 'pose_2'). predict_pose
# reads model output i as class_names[i], so this ordering has to match the
# one the checkpoint was trained with -- presumably ImageFolder over this same
# directory; keep the two in sync.
dataset = ImageFolder(r"C:\Users\jrgaynor\Downloads\checkin-photos\training")
class_to_idx = dataset.class_to_idx

# Display names for each class folder. Folders without an entry here are
# reported as "Unknown Pose" by predict_pose.
# NOTE(review): there is no entry for 'pose_9' / 'pose_10' -- confirm whether
# those folders exist in the training directory.
pose_name_map = {
    'pose_1': "Front Relaxed",
    'pose_2': "Back Relaxed",
    'pose_3': "Quarter Turn (Left)",
    'pose_4': "Quarter Turn (Right)",
    'pose_5': "Back Double Biceps",
    'pose_6': "Front Double Biceps",
    'pose_7': "Front Lat Spread",
    'pose_8': "Side Chest (Left)",
    'pose_11': "Abs & Thighs",
}

# Reverse lookup (index -> folder name), then the folder names laid out in
# index order so that class_names[i] labels output unit i of the classifier.
idx_to_class = {index: folder for folder, index in class_to_idx.items()}
class_names = list(map(idx_to_class.__getitem__, range(len(idx_to_class))))

# --- Config ---
# Checkpoint file whose state_dict is loaded into the model further down.
MODEL_PATH = "pose_classifier.pth"
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Per-channel normalisation statistics (the standard ImageNet mean / std).
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Inference-time preprocessing; per the original note this mirrors the
# validation transforms used during training. Resize with an (h, w) tuple
# scales straight to 224x224 without cropping, so the aspect ratio is not
# preserved. The PIL image is then converted to a float tensor and normalised.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)

# Build the ResNet-50 architecture with a classification head sized to the
# pose classes, then load the trained parameters from MODEL_PATH.
#
# The backbone is created with weights=None on purpose: load_state_dict below
# runs with its default strict=True, so it must supply (and overwrite) every
# parameter and buffer. Fetching the ImageNet pretrained weights first would
# only cost a download and extra load time, then be thrown away.
weights = ResNet50_Weights.DEFAULT  # name kept for later cells; not passed to resnet50()
model = resnet50(weights=None)
model.fc = nn.Linear(model.fc.in_features, len(class_names))
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
# On torch >= 1.13, passing weights_only=True restricts what can be unpickled.
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.to(DEVICE)
# Evaluation mode: BatchNorm layers use their running statistics.
model.eval()

def predict_pose(image_path):
    """Classify the pose shown in a single image.

    Args:
        image_path: Path to the image file (anything ``PIL.Image.open``
            accepts).

    Returns:
        A ``(pose_name, confidence)`` tuple. ``pose_name`` is the readable
        label from ``pose_name_map`` for the predicted class folder, or
        ``"Unknown Pose"`` if that folder has no entry. ``confidence`` is the
        softmax probability of the predicted class, a float in [0, 1].
    """
    # Use a context manager so the file handle is released deterministically.
    # Pillow does not always close it after loading (e.g. multi-frame formats).
    # convert() returns a new image, so it stays valid after the file closes.
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    # Add a leading batch dimension of 1 before moving to the model's device.
    input_tensor = transform(image).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        logits = model(input_tensor)
        probabilities = F.softmax(logits, dim=1)  # logits -> class probabilities

    # Single-image batch: take row 0 and pick the most probable class.
    predicted_idx = probabilities[0].argmax().item()
    confidence = probabilities[0, predicted_idx].item()

    predicted_class = class_names[predicted_idx]
    predicted_pose_name = pose_name_map.get(predicted_class, "Unknown Pose")
    return predicted_pose_name, confidence

# Example usage: classify one sample image and print the predicted pose with
# its confidence shown as a percentage.
sample_image_path = "test_images/cbum.png"
pose, confidence = predict_pose(sample_image_path)
report = f"Predicted Pose: {pose} (Confidence: {confidence:.2%})"
print(report)

# Example output: Predicted Pose: Front Lat Spread (Confidence: 27.15%)
