# Day 4.3: Predictions & Analysis

**Goal:** Visualize and analyze model predictions

**What we'll do:**
1. Load trained model and test data
2. Visualize correct predictions with confidence scores
3. Visualize misclassifications with detailed analysis
4. Show the most confident and least confident predictions
5. Analyze prediction patterns
6. Make predictions on new sample images

**Expected time:** 15-20 minutes

---

## 1. Import Libraries

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import cv2
from datetime import datetime

# TensorFlow and Keras
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.preprocessing.image import ImageDataGenerator

print(f"TensorFlow version: {tf.__version__}")
print("Libraries loaded successfully!")

## 2. Configuration

In [None]:
# --- Input / output locations (relative paths) ---
TEST_CSV = "../../outputs/data_splits/test_split.csv"
MODEL_DIR = "../../outputs/models"
VIZ_DIR = "../../outputs/visualizations"
RESULTS_DIR = "../../outputs/evaluation_results"

# --- Image and batching settings used when feeding the model ---
IMG_SIZE = (128, 128)
BATCH_SIZE = 32

# --- Tumor class names and how many there are ---
CLASS_NAMES = ['glioma', 'meningioma', 'pituitary']
NUM_CLASSES = len(CLASS_NAMES)

# Figures are written into VIZ_DIR later on, so make sure it exists up front.
os.makedirs(VIZ_DIR, exist_ok=True)

print("Configuration loaded!")

## 3. Load Model and Data

In [None]:
# Find the most recently modified .keras file in MODEL_DIR
model_files = [f for f in os.listdir(MODEL_DIR) if f.endswith('.keras')]
if not model_files:
    # Fail with an actionable message instead of an IndexError on model_files[0].
    raise FileNotFoundError(
        f"No .keras model files found in {MODEL_DIR!r}; train and save a model first."
    )
model_files.sort(key=lambda x: os.path.getmtime(os.path.join(MODEL_DIR, x)), reverse=True)
model_path = os.path.join(MODEL_DIR, model_files[0])

print(f"Loading model: {model_files[0]}")
model = keras.models.load_model(model_path)
print("✅ Model loaded!\n")

# Load test data; labels are cast to str because class_mode='categorical'
# below works on string class labels.
test_df = pd.read_csv(TEST_CSV)
test_df['label'] = test_df['label'].astype(str)

print(f"Test set: {len(test_df)} images")

# Create generator. shuffle=False keeps the prediction order equal to the
# generator's row order, which the analysis cells rely on when they index
# test_df by position.
test_datagen = ImageDataGenerator(rescale=1./255)
test_generator = test_datagen.flow_from_dataframe(
    dataframe=test_df,
    x_col='filepath',
    y_col='label',
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    color_mode='grayscale',
    shuffle=False
)

# Get predictions
print("\nMaking predictions...")
y_pred_probs = model.predict(test_generator, verbose=0)
y_pred_classes = np.argmax(y_pred_probs, axis=1)
# Coerce to an ndarray so that elementwise comparisons such as
# `y_true_classes == class_idx` behave the same whether the generator
# exposes its classes as a list or as an array.
y_true_classes = np.asarray(test_generator.classes)
class_labels = list(test_generator.class_indices.keys())

# NOTE(review): flow_from_dataframe is documented to drop rows whose files it
# cannot validate. If that happened, positional lookups into test_df no longer
# line up with the predictions, so make the mismatch visible.
if len(y_true_classes) != len(test_df):
    print(
        f"⚠️ Generator kept {len(y_true_classes)} of {len(test_df)} rows; "
        "positional lookups into test_df may be misaligned with predictions."
    )

print("✅ Predictions completed!")

## 4. Helper Function to Display Images

In [None]:
def display_predictions(indices, title, n_cols=5, figsize=(20, 12)):
    """
    Display images with their predictions and confidence scores.

    Reads the notebook-level variables test_df, class_labels, y_true_classes,
    y_pred_classes and y_pred_probs. Each panel shows the image, a title with
    the true label, predicted label, confidence and patient id (green when the
    prediction is correct, red otherwise), and a text box listing the
    probability of every class.

    Args:
        indices: List of sample indices (positions in test_df) to display
        title: Title for the figure
        n_cols: Number of columns in the grid
        figsize: Figure size

    Returns:
        The matplotlib figure. When `indices` is empty, a figure containing
        only a "No samples to display" message is returned.
    """
    n_samples = len(indices)

    # plt.subplots cannot build a grid with zero rows, so handle the empty
    # selection explicitly and still hand back a figure the caller can save.
    if n_samples == 0:
        fig, ax = plt.subplots(figsize=figsize)
        ax.text(0.5, 0.5, "No samples to display",
                ha='center', va='center', fontsize=14, transform=ax.transAxes)
        ax.axis('off')
        plt.suptitle(title, fontsize=16, fontweight='bold', y=0.995)
        plt.tight_layout()
        return fig

    n_rows = (n_samples + n_cols - 1) // n_cols

    # squeeze=False always returns a 2-D array of axes, so flattening works
    # for every grid shape (including a single sample in a multi-column row).
    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
    axes = axes.ravel()

    for idx, ax in enumerate(axes):
        if idx >= n_samples:
            # Unused grid cell: hide it.
            ax.axis('off')
            continue

        sample_idx = indices[idx]
        row = test_df.iloc[sample_idx]

        # Load image
        img_path = row['filepath']
        img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)

        # Get predictions
        true_label = class_labels[y_true_classes[sample_idx]]
        pred_label = class_labels[y_pred_classes[sample_idx]]
        probs = y_pred_probs[sample_idx]
        confidence = probs[y_pred_classes[sample_idx]]

        # Display image; cv2.imread returns None for unreadable files, in
        # which case a placeholder is shown instead of crashing the figure.
        if img is None:
            ax.text(0.5, 0.5, "Image could not be loaded",
                    ha='center', va='center', fontsize=9, transform=ax.transAxes)
        else:
            ax.imshow(img, cmap='gray')

        # Set title with color coding
        is_correct = true_label == pred_label
        title_color = 'green' if is_correct else 'red'

        title_text = (
            f"True: {true_label}\n"
            f"Pred: {pred_label} ({confidence:.1%})\n"
            f"Patient: {row['patient_id']}"
        )

        ax.set_title(title_text, fontsize=10, color=title_color, fontweight='bold')
        ax.axis('off')

        # Add text box with all probabilities
        prob_text = "\n".join(
            f"{label}: {prob:.1%}" for label, prob in zip(class_labels, probs)
        )
        ax.text(
            0.02, 0.98, prob_text,
            transform=ax.transAxes,
            fontsize=8,
            verticalalignment='top',
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8)
        )

    plt.suptitle(title, fontsize=16, fontweight='bold', y=0.995)
    plt.tight_layout()
    return fig

## 5. Visualize Correct Predictions (High Confidence)

In [None]:
# Find correct predictions with high confidence.
# correct_mask and confidence_scores are reused by later cells.
correct_mask = y_pred_classes == np.asarray(y_true_classes)
confidence_scores = np.max(y_pred_probs, axis=1)

correct_high_conf_indices = np.where(
    correct_mask & (confidence_scores > 0.95)
)[0]

# Select random samples (fixed seed so the figure is reproducible)
np.random.seed(42)

# Guard against an empty selection, like the other visualization cells do:
# a model with no >95% correct predictions should not crash the notebook.
if len(correct_high_conf_indices) > 0:
    if len(correct_high_conf_indices) > 15:
        sample_indices = np.random.choice(correct_high_conf_indices, 15, replace=False)
    else:
        sample_indices = correct_high_conf_indices

    # Display
    fig = display_predictions(
        sample_indices,
        "Correct Predictions (High Confidence > 95%)",
        n_cols=5
    )

    # Save
    save_path = os.path.join(VIZ_DIR, 'day4_03_correct_predictions_high_conf.png')
    plt.savefig(save_path, dpi=200, bbox_inches='tight')
    print(f"Saved to: {save_path}")
    plt.show()
else:
    print("No correct predictions with high confidence found!")

## 6. Visualize Correct Predictions (Low Confidence)

In [None]:
# Correctly classified samples whose winning probability is below 70%.
low_conf_hits = correct_mask & (confidence_scores < 0.7)
correct_low_conf_indices = np.where(low_conf_hits)[0]

if len(correct_low_conf_indices) == 0:
    print("No correct predictions with low confidence found!")
else:
    # Order the candidates from least to most confident ...
    ascending = np.argsort(confidence_scores[correct_low_conf_indices])
    sorted_indices = correct_low_conf_indices[ascending]

    # ... and keep at most the ten most uncertain ones.
    sample_indices = sorted_indices[:10]

    # Render the grid of uncertain-but-correct cases.
    fig = display_predictions(
        sample_indices,
        "Correct Predictions (Low Confidence < 70%) - Uncertain Cases",
        n_cols=5,
    )

    # Write the figure to disk, then show it inline.
    save_path = os.path.join(VIZ_DIR, 'day4_03_correct_predictions_low_conf.png')
    plt.savefig(save_path, dpi=200, bbox_inches='tight')
    print(f"Saved to: {save_path}")
    plt.show()

## 7. Visualize Misclassifications

In [None]:
# Positions where the predicted class differs from the true class.
misclassified_indices = np.nonzero(y_pred_classes != y_true_classes)[0]

print(f"Total misclassifications: {len(misclassified_indices)}")

if len(misclassified_indices) == 0:
    print("No misclassifications found! Perfect model!")
else:
    # Rank the errors so the most confident (and most interesting) mistakes come first.
    by_descending_confidence = np.argsort(-confidence_scores[misclassified_indices])
    sorted_indices = misclassified_indices[by_descending_confidence]

    # Show no more than fifteen of them.
    sample_indices = sorted_indices[:15]

    # Render the grid of mistakes.
    fig = display_predictions(
        sample_indices,
        "Misclassifications (Sorted by Confidence - Model was very confident but wrong!)",
        n_cols=5,
    )

    # Write the figure to disk, then show it inline.
    save_path = os.path.join(VIZ_DIR, 'day4_03_misclassifications.png')
    plt.savefig(save_path, dpi=200, bbox_inches='tight')
    print(f"Saved to: {save_path}")
    plt.show()

## 8. Analyze Misclassification Patterns

In [None]:
if len(misclassified_indices) > 0:
    # Tally how often each "true → predicted" combination occurs among the errors.
    confusion_pairs = {}
    for idx in misclassified_indices:
        true = class_labels[y_true_classes[idx]]
        pred = class_labels[y_pred_classes[idx]]
        pair = f"{true} → {pred}"
        confusion_pairs.setdefault(pair, 0)
        confusion_pairs[pair] += 1

    # Most frequent confusion first (ties keep their first-seen order).
    sorted_pairs = sorted(confusion_pairs.items(), key=lambda item: -item[1])

    # Text report: count and share of all errors for every pair.
    rule = "=" * 60
    total_errors = len(misclassified_indices)
    print("\n" + rule)
    print("🔍 MISCLASSIFICATION PATTERNS")
    print(rule)
    for pair, count in sorted_pairs:
        percentage = (count / total_errors) * 100
        print(f"  {pair}: {count} times ({percentage:.1f}% of errors)")

    # Horizontal bar chart of the same counts.
    pairs = tuple(name for name, _ in sorted_pairs)
    counts = tuple(n for _, n in sorted_pairs)
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.barh(pairs, counts, color='coral', alpha=0.7)
    ax.set_xlabel('Number of Misclassifications', fontsize=12)
    ax.set_ylabel('True → Predicted', fontsize=12)
    ax.set_title('Misclassification Patterns', fontsize=14, fontweight='bold')
    ax.grid(True, alpha=0.3, axis='x')
    plt.tight_layout()

    # Write the chart to disk, then show it inline.
    save_path = os.path.join(VIZ_DIR, 'day4_03_misclassification_patterns.png')
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"\nSaved to: {save_path}")
    plt.show()

## 9. Per-Class Prediction Examples

In [None]:
# Show examples for each class.
# Compare against an ndarray: if y_true_classes is a plain Python list,
# `y_true_classes == class_idx` is a single False and every class would be
# silently skipped.
true_classes_arr = np.asarray(y_true_classes)

for class_idx, class_name in enumerate(class_labels):
    # Find correct predictions for this class
    class_indices = np.where(
        (true_classes_arr == class_idx) &
        (y_pred_classes == class_idx)
    )[0]

    if len(class_indices) == 0:
        # Make the absence visible instead of producing no output at all.
        print(f"No correct predictions found for class: {class_name}")
        continue

    # Select random samples (re-seeded per class so each figure is reproducible)
    np.random.seed(42)
    sample_indices = np.random.choice(
        class_indices,
        min(10, len(class_indices)),
        replace=False
    )

    # Display
    fig = display_predictions(
        sample_indices,
        f"Correct Predictions: {class_name.upper()}",
        n_cols=5
    )

    # Save
    save_path = os.path.join(VIZ_DIR, f'day4_03_examples_{class_name}.png')
    plt.savefig(save_path, dpi=200, bbox_inches='tight')
    print(f"Saved to: {save_path}")
    plt.show()

## 10. Prediction Function for New Images

In [None]:
def predict_single_image(image_path, model, show=True):
    """
    Make prediction on a single image.

    The image is read as grayscale, resized to IMG_SIZE, divided by 255 and
    passed to the model as a batch of one with a trailing channel axis.
    Class names are taken from the notebook-level CLASS_NAMES list.

    Args:
        image_path: Path to image file
        model: Trained Keras model
        show: Whether to display the image and a probability bar chart

    Returns:
        dict: Prediction results with keys
            'predicted_class' (str), 'confidence' (float) and
            'all_probabilities' (dict mapping class name -> float)

    Raises:
        ValueError: If the image cannot be loaded, or if the model returns a
            different number of outputs than there are CLASS_NAMES.
    """
    # Load and preprocess image
    img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        raise ValueError(f"Could not load image: {image_path}")

    # Resize. IMG_SIZE is used as Keras target_size, i.e. (height, width),
    # while cv2.resize expects dsize as (width, height) - swap explicitly so
    # a non-square IMG_SIZE is handled correctly.
    target_height, target_width = IMG_SIZE
    img_resized = cv2.resize(img, (target_width, target_height))

    # Normalize
    img_normalized = img_resized / 255.0

    # Add batch and channel dimensions
    img_batch = np.expand_dims(img_normalized, axis=(0, -1))

    # Predict
    predictions = model.predict(img_batch, verbose=0)[0]

    # Mapping indices to CLASS_NAMES only makes sense when the counts agree.
    if len(predictions) != NUM_CLASSES:
        raise ValueError(
            f"Model returned {len(predictions)} outputs but "
            f"{NUM_CLASSES} class names are configured"
        )

    # Get predicted class and confidence
    pred_class_idx = int(np.argmax(predictions))
    pred_class = CLASS_NAMES[pred_class_idx]
    confidence = predictions[pred_class_idx]

    # Create results dictionary
    results = {
        'predicted_class': pred_class,
        'confidence': float(confidence),
        'all_probabilities': {
            name: float(prob) for name, prob in zip(CLASS_NAMES, predictions)
        }
    }

    # Display if requested
    if show:
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

        # Show the original (un-resized) image
        ax1.imshow(img, cmap='gray')
        ax1.set_title(f"Predicted: {pred_class.upper()}\nConfidence: {confidence:.1%}",
                     fontsize=14, fontweight='bold')
        ax1.axis('off')

        # Show probability bars; the predicted class is highlighted in green
        colors = ['green' if i == pred_class_idx else 'gray' for i in range(NUM_CLASSES)]
        ax2.barh(CLASS_NAMES, predictions, color=colors, alpha=0.7)
        ax2.set_xlabel('Probability', fontsize=12)
        ax2.set_title('Class Probabilities', fontsize=14, fontweight='bold')
        ax2.set_xlim(0, 1)
        ax2.grid(True, alpha=0.3, axis='x')

        # Add percentage labels just to the right of each bar
        for i, prob in enumerate(predictions):
            ax2.text(prob + 0.02, i, f'{prob:.1%}', va='center')

        plt.tight_layout()
        plt.show()

    return results

print("✅ Prediction function ready!")
print("\nUsage: predict_single_image('path/to/image.png', model)")

## 11. Test Prediction Function

Let's test the prediction function on a few random test images:

In [None]:
# Draw three distinct test rows with a fixed seed so the demo is reproducible.
np.random.seed(42)
sample_indices = np.random.choice(len(test_df), 3, replace=False)

print("Testing prediction function on sample images:\n")

separator = '=' * 60
for i, idx in enumerate(sample_indices, 1):
    # Banner for this sample
    print(f"\n{separator}")
    print(f"Sample {i}:")
    print(separator)

    # Ground-truth information for the chosen row
    row = test_df.iloc[idx]
    img_path = row['filepath']
    true_label = row['label']
    patient_id = row['patient_id']

    print(f"Image: {os.path.basename(img_path)}")
    print(f"True label: {true_label}")
    print(f"Patient ID: {patient_id}")

    # Run the single-image pipeline (also draws the image and probability bars)
    results = predict_single_image(img_path, model, show=True)

    # Report what the model said
    print(f"\nPrediction: {results['predicted_class']}")
    print(f"Confidence: {results['confidence']:.1%}")
    print("\nAll probabilities:")
    for class_name, prob in results['all_probabilities'].items():
        print(f"  {class_name}: {prob:.1%}")

## 12. Summary

In [None]:
print("\n" + "="*70)
print("🎉 DAY 4.3 COMPLETE - PREDICTION ANALYSIS FINISHED!")
print("="*70)

# Overall accuracy recomputed from the predictions made earlier in the notebook
print("\n📊 Summary:")
correct = np.sum(y_pred_classes == y_true_classes)
total = len(y_pred_classes)
accuracy = correct / total

print(f"  Total predictions: {total}")
print(f"  Correct: {correct}")
print(f"  Incorrect: {total - correct}")
print(f"  Accuracy: {accuracy*100:.2f}%")

print("\n📁 Visualizations Created:")
viz_files = [
    'day4_03_correct_predictions_high_conf.png',
    'day4_03_correct_predictions_low_conf.png',
    'day4_03_misclassifications.png',
    'day4_03_misclassification_patterns.png',
]
# The per-class figures are saved as day4_03_examples_<label>.png using the
# generator's class labels, so build the expected names from the same labels
# rather than hard-coding them.
viz_files += [f'day4_03_examples_{class_name}.png' for class_name in class_labels]

# Only files that actually exist in VIZ_DIR are listed.
for f in viz_files:
    path = os.path.join(VIZ_DIR, f)
    if os.path.exists(path):
        print(f"  ✅ {f}")

print("\n💡 Key Insights:")
print("  - Review high-confidence misclassifications to find challenging cases")
print("  - Low-confidence correct predictions show model uncertainty")
print("  - Misclassification patterns reveal systematic errors")
print("  - Use predict_single_image() function for new images")

print("\n🎯 Project Complete!")
print("  You've successfully:")
print("  ✅ Extracted and enhanced 7,181 brain MRI images")
print("  ✅ Created patient-wise data splits")
print("  ✅ Trained a CNN model")
print("  ✅ Evaluated model performance")
print("  ✅ Analyzed predictions and errors")

print("\n🚀 Next Steps (Optional):")
print("  1. Try transfer learning (ResNet, VGG, etc.)")
print("  2. Experiment with different architectures")
print("  3. Tune hyperparameters (learning rate, batch size, etc.)")
print("  4. Create a simple web app for predictions")
print("  5. Export model for deployment")

print("\n" + "="*70)
print("🎊 CONGRATULATIONS! You've completed the Brain Tumor Classification project!")
print("="*70)