# Medical X-ray Triage - Setup and Environment Verification

This notebook verifies the environment setup and demonstrates basic functionality of the Medical X-ray Triage project.

## Table of Contents
1. [Environment Check](#environment-check)
2. [Data Exploration](#data-exploration)
3. [Model Verification](#model-verification)
4. [Sample Data Generation](#sample-data-generation)
5. [Quick Demo](#quick-demo)


## 1. Environment Check


In [None]:
# Import required libraries
import sys
import os
import torch
import torchvision
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import warnings
warnings.filterwarnings('ignore')

# Add src to path
sys.path.append('../src')

print("Python version:", sys.version)
print("PyTorch version:", torch.__version__)
print("Torchvision version:", torchvision.__version__)
print("NumPy version:", np.__version__)
print("Pandas version:", pd.__version__)
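
Printing versions is more useful when paired with a minimum-version check. The `weights=` argument used later in this notebook belongs to torchvision's multi-weight API, which was introduced in torchvision 0.13. The following is a minimal sketch under that single assumption; the helper names (`parse_version`, `meets_minimum`, `check_minimums`) are illustrative and not part of the project.

```python
import re
from importlib import metadata


def parse_version(version: str) -> tuple:
    """Return the leading numeric release components, e.g. '2.1.0+cu118' -> (2, 1, 0)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"Unrecognised version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def meets_minimum(installed: str, minimum: str) -> bool:
    """Compare two version strings after padding them to the same length."""
    a, b = parse_version(installed), parse_version(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


# Assumed minimum: torchvision 0.13 for the `weights=` API. Extend as needed.
MINIMUMS = {"torchvision": "0.13"}


def check_minimums(minimums: dict) -> dict:
    """Map each package to True/False, or None if it is not installed."""
    results = {}
    for name, minimum in minimums.items():
        try:
            results[name] = meets_minimum(metadata.version(name), minimum)
        except metadata.PackageNotFoundError:
            results[name] = None
    return results


for name, ok in check_minimums(MINIMUMS).items():
    status = "not installed" if ok is None else ("OK" if ok else "too old")
    print(f"{name} (>= {MINIMUMS[name]}): {status}")
```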


In [None]:
# Check CUDA availability
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("CUDA device count:", torch.cuda.device_count())
    print("Current CUDA device:", torch.cuda.current_device())
    print("CUDA device name:", torch.cuda.get_device_name())

# Check MPS availability (Apple Silicon)
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    print("MPS available: True")
else:
    print("MPS available: False")

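# Select device: prefer CUDA, then MPS (Apple Silicon), then CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")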


In [None]:
# Test torchvision models
print("Testing torchvision models...")

# Test ResNet50
try:
    resnet = torchvision.models.resnet50(weights='IMAGENET1K_V2')
    print("✓ ResNet50 loaded successfully")
except Exception as e:
    print(f"✗ ResNet50 failed: {e}")

# Test EfficientNetV2-S
try:
    efficientnet = torchvision.models.efficientnet_v2_s(weights='IMAGENET1K_V1')
    print("✓ EfficientNetV2-S loaded successfully")
except Exception as e:
    print(f"✗ EfficientNetV2-S failed: {e}")

# Test Grad-CAM
try:
    from pytorch_grad_cam import GradCAM
    print("✓ pytorch-grad-cam available")
except ImportError:
    print("✗ pytorch-grad-cam not available. Install with: pip install grad-cam")

# Test Streamlit
try:
    import streamlit
    print("✓ Streamlit available")
except ImportError:
    print("✗ Streamlit not available. Install with: pip install streamlit")
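
The optional-dependency checks above can also be done without importing each package, which is faster and avoids import side effects. This is a minimal sketch using the standard library's `importlib.util.find_spec`; the `OPTIONAL_PACKAGES` mapping and `missing_packages` helper are illustrative names, not part of the project.

```python
from importlib.util import find_spec

# Import name -> pip package name. The module `pytorch_grad_cam` is
# published on PyPI under the name "grad-cam".
OPTIONAL_PACKAGES = {
    "pytorch_grad_cam": "grad-cam",
    "streamlit": "streamlit",
}


def missing_packages(packages: dict) -> list:
    """Return the pip names of all top-level modules that cannot be found."""
    return [pip_name for module, pip_name in packages.items()
            if find_spec(module) is None]


missing = missing_packages(OPTIONAL_PACKAGES)
if missing:
    print("✗ Missing packages. Install with: pip install " + " ".join(missing))
else:
    print("✓ All optional packages available")
```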


## 2. Data Exploration


In [None]:
# Check if sample data exists
sample_data_dir = "../data/sample"
labels_path = os.path.join(sample_data_dir, "labels.csv")
images_dir = os.path.join(sample_data_dir, "images")

print(f"Sample data directory: {sample_data_dir}")
print(f"Labels file exists: {os.path.exists(labels_path)}")
print(f"Images directory exists: {os.path.exists(images_dir)}")

if os.path.exists(images_dir):
    image_files = [f for f in os.listdir(images_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
    print(f"Number of images: {len(image_files)}")
    print(f"Image files: {image_files}")


In [None]:
# Load and explore labels
if os.path.exists(labels_path):
    labels_df = pd.read_csv(labels_path)
    print("Labels DataFrame:")
    print(labels_df.head())
    print(f"\nDataset shape: {labels_df.shape}")
    print(f"Columns: {list(labels_df.columns)}")
    
    # Class distribution
    print("\nClass distribution:")
    class_counts = labels_df['label'].value_counts().sort_index()
    print(class_counts)
    
    # Plot class distribution
    plt.figure(figsize=(8, 5))
    class_counts.plot(kind='bar', color=['lightblue', 'lightcoral'])
    plt.title('Class Distribution in Sample Dataset')
    plt.xlabel('Class (0=Normal, 1=Abnormal)')
    plt.ylabel('Count')
    plt.xticks(rotation=0)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()
else:
    print("Labels file not found. Will generate sample data.")


In [None]:
# Display sample images
if os.path.exists(images_dir) and os.path.exists(labels_path) and len(image_files) > 0:
    print("Sample Images:")
    
    # Select first 4 images
    sample_images = image_files[:4]
    
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    axes = axes.ravel()
    
    for i, img_name in enumerate(sample_images):
        img_path = os.path.join(images_dir, img_name)
        
        # Load image
        image = Image.open(img_path)
        
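        # Get label ('Unknown' if the image has no row in labels.csv)
        matches = labels_df.loc[labels_df['filepath'] == f'images/{img_name}', 'label']
        if matches.empty:
            class_name = 'Unknown'
        else:
            class_name = 'Normal' if matches.iloc[0] == 0 else 'Abnormal'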
        
        # Display image
        axes[i].imshow(image, cmap='gray')
        axes[i].set_title(f'{img_name}\n({class_name})')
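        axes[i].axis('off')

    # Hide unused subplots when fewer than four images are available
    for ax in axes[len(sample_images):]:
        ax.axis('off')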
    
    plt.tight_layout()
    plt.show()
    
    # Image statistics
    print("\nImage Statistics:")
    for img_name in sample_images[:2]:  # Show stats for first 2 images
        img_path = os.path.join(images_dir, img_name)
        image = Image.open(img_path)
        img_array = np.array(image)
        
        print(f"{img_name}:")
        print(f"  Shape: {img_array.shape}")
        print(f"  Min value: {img_array.min()}")
        print(f"  Max value: {img_array.max()}")
        print(f"  Mean value: {img_array.mean():.2f}")
        print(f"  Std value: {img_array.std():.2f}")
else:
    print("No sample images found.")
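
The label lookup above depends on every image having a row in `labels.csv` whose `filepath` value is `images/<name>`. A quick consistency check can surface mismatches before they cause errors. This is a minimal sketch; `find_mismatches` is a hypothetical helper and the file names below are made up for illustration.

```python
def find_mismatches(label_paths, image_files, prefix="images/"):
    """Compare labelled file paths against the files found on disk.

    Returns (labelled_but_missing, on_disk_but_unlabelled), both sorted.
    """
    labelled = {
        path[len(prefix):] if path.startswith(prefix) else path
        for path in label_paths
    }
    on_disk = set(image_files)
    return sorted(labelled - on_disk), sorted(on_disk - labelled)


# Illustrative data only; these files do not come from the project.
label_paths = ["images/a.png", "images/b.png", "images/c.png"]
image_files = ["a.png", "b.png", "d.png"]

missing_files, unlabelled = find_mismatches(label_paths, image_files)
print("Labelled but missing on disk:", missing_files)
print("On disk but unlabelled:", unlabelled)
```

In this notebook the equivalent call would be `find_mismatches(labels_df['filepath'], image_files)`, assuming both variables were defined by the earlier cells.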


## 3. Model Verification


In [None]:
# Test model creation
from model import create_model, get_model_summary

print("Testing model creation...")

# Test ResNet50 model
try:
    resnet_model = create_model("resnet50", num_classes=1, pretrained=True)
    resnet_summary = get_model_summary(resnet_model)
    
    print("✓ ResNet50 model created successfully")
    print(f"  Total parameters: {resnet_summary['total_parameters']:,}")
    print(f"  Trainable parameters: {resnet_summary['trainable_parameters']:,}")
    print(f"  Model size: {resnet_summary['model_size_mb']:.2f} MB")
    
except Exception as e:
    print(f"✗ ResNet50 model creation failed: {e}")

# Test EfficientNetV2-S model
try:
    effnet_model = create_model("efficientnet_v2_s", num_classes=1, pretrained=True)
    effnet_summary = get_model_summary(effnet_model)
    
    print("✓ EfficientNetV2-S model created successfully")
    print(f"  Total parameters: {effnet_summary['total_parameters']:,}")
    print(f"  Trainable parameters: {effnet_summary['trainable_parameters']:,}")
    print(f"  Model size: {effnet_summary['model_size_mb']:.2f} MB")
    
except Exception as e:
    print(f"✗ EfficientNetV2-S model creation failed: {e}")
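
The summary figures can be sanity-checked by hand. The sketch below shows the arithmetic behind a parameter count and a size estimate, assuming float32 weights (4 bytes each) and a classification head that is a single linear layer on ResNet50's 2048-dimensional features. How `get_model_summary` computes its values and how `create_model` builds the head are not shown in this notebook, so treat the numbers as a rough cross-check only.

```python
def linear_params(in_features: int, out_features: int, bias: bool = True) -> int:
    """Number of parameters in a fully connected layer."""
    return in_features * out_features + (out_features if bias else 0)


def size_mb(num_params: int, bytes_per_param: int = 4) -> float:
    """Approximate in-memory size in MiB, assuming float32 by default."""
    return num_params * bytes_per_param / (1024 ** 2)


# torchvision's ResNet50 feeds 2048 features into its final layer.
imagenet_head = linear_params(2048, 1000)  # original 1000-class head
binary_head = linear_params(2048, 1)       # assumed single-logit head

print(f"1000-class head:   {imagenet_head:,} parameters")
print(f"Single-logit head: {binary_head:,} parameters")
print(f"Size of 1M float32 parameters: {size_mb(1_000_000):.2f} MB")
```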


In [None]:
# Test forward pass
print("Testing forward pass...")

try:
    # Create dummy input
    dummy_input = torch.randn(2, 3, 320, 320)
    
    # Test ResNet50
    resnet_model.eval()
    with torch.no_grad():
        output = resnet_model(dummy_input)
        prob = resnet_model.predict_proba(dummy_input)
    
    print("✓ Forward pass successful")
    print(f"  Input shape: {dummy_input.shape}")
    print(f"  Output shape: {output.shape}")
    print(f"  Probability shape: {prob.shape}")
    print(f"  Sample probabilities: {prob.flatten()[:5].tolist()}")
    
except Exception as e:
    print(f"✗ Forward pass failed: {e}")
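
With `num_classes=1` the model emits one raw score (logit) per image, and a probability is conventionally obtained by applying the sigmoid function. Whether `predict_proba` does exactly this is an assumption, since its implementation is not shown here. A minimal, dependency-free sketch of the mapping:

```python
import math


def sigmoid(logit: float) -> float:
    """Numerically stable logistic function mapping a logit to (0, 1)."""
    if logit >= 0:
        return 1.0 / (1.0 + math.exp(-logit))
    # For negative inputs, exp(logit) cannot overflow.
    z = math.exp(logit)
    return z / (1.0 + z)


logits = [-2.0, 0.0, 3.5]  # illustrative values, not model output
probs = [sigmoid(x) for x in logits]
for x, p in zip(logits, probs):
    print(f"logit {x:+.1f} -> probability {p:.3f}")
```

A logit of 0 corresponds to a probability of exactly 0.5, which is why an untrained head tends to produce probabilities clustered around that value.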


## 4. Sample Data Generation


In [None]:
# Generate sample data if it doesn't exist
if not os.path.exists(labels_path) or not os.path.exists(images_dir):
    print("Generating sample data...")
    
    from make_sample_data import create_sample_dataset
    
    try:
        labels_df = create_sample_dataset("../data/sample")
        print("✓ Sample data generated successfully")
        print(f"  Generated {len(labels_df)} samples")
        print("  Re-run the Data Exploration cells above to inspect the new data")
        
    except Exception as e:
        print(f"✗ Sample data generation failed: {e}")
else:
    print("✓ Sample data already exists")


## 5. Quick Demo


In [None]:
# Test data loading
from data import get_simple_data_loader, print_dataset_info

print("Testing data loading...")

try:
    # Print dataset info
    print_dataset_info(labels_path, images_dir)
    
    # Test data loader
    data_loader, dataset = get_simple_data_loader(
        labels_path=labels_path,
        images_dir=images_dir,
        batch_size=2,
        img_size=320,
        is_training=False
    )
    
    print("\n✓ Data loader created successfully")
    print(f"  Number of batches: {len(data_loader)}")
    
    # Test a batch
    for images, labels in data_loader:
        print(f"  Batch - Images shape: {images.shape}, Labels shape: {labels.shape}")
        print(f"  Sample labels: {labels.numpy()}")
        break
        
except Exception as e:
    print(f"✗ Data loading failed: {e}")
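
The reported number of batches can be checked against the dataset size. For a PyTorch `DataLoader` over a map-style dataset, `len(data_loader)` is the sample count divided by the batch size, rounded up, or rounded down when `drop_last=True`. The sketch below reproduces that arithmetic; whether `get_simple_data_loader` sets `drop_last` is not shown here, so the default of `False` is an assumption.

```python
import math


def num_batches(num_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Expected number of batches for a given dataset size."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)


# Illustrative dataset sizes with the notebook's batch_size=2.
for n in (4, 5):
    print(f"{n} samples -> {num_batches(n, 2)} batches "
          f"({num_batches(n, 2, drop_last=True)} with drop_last=True)")
```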


In [None]:
# Test utility functions
from utils import compute_metrics, seed_everything

print("Testing utility functions...")

try:
    # Test seeding
    seed_everything(42)
    print("✓ Seeding function works")
    
    # Test metrics computation
    y_true = np.array([0, 1, 0, 1, 1])
    y_prob = np.array([0.1, 0.9, 0.2, 0.8, 0.7])
    
    metrics = compute_metrics(y_true, y_prob)
    print("✓ Metrics computation works")
    print(f"  AUROC: {metrics['auroc']:.3f}")
    print(f"  F1 Score: {metrics['f1']:.3f}")
    print(f"  Precision: {metrics['precision']:.3f}")
    print(f"  Recall: {metrics['recall']:.3f}")
    
except Exception as e:
    print(f"✗ Utility functions failed: {e}")
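
To confirm that `compute_metrics` returns sensible values, the same metrics can be computed from their definitions. This is a dependency-free sketch that assumes a decision threshold of 0.5 for precision, recall and F1; the threshold used by `compute_metrics` is not shown in this notebook. AUROC is computed as the fraction of positive/negative pairs in which the positive sample scores higher, with ties counting as half.

```python
def binary_metrics(y_true, y_prob, threshold=0.5):
    """Precision, recall, F1 at `threshold`, plus pairwise AUROC."""
    y_pred = [1 if p >= threshold else 0 for p in y_prob]
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)

    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if precision + recall else 0.0)

    pos = [p for t, p in zip(y_true, y_prob) if t == 1]
    neg = [p for t, p in zip(y_true, y_prob) if t == 0]
    if pos and neg:
        wins = sum(1.0 if a > b else 0.5 if a == b else 0.0
                   for a in pos for b in neg)
        auroc = wins / (len(pos) * len(neg))
    else:
        auroc = float("nan")  # undefined when only one class is present

    return {"auroc": auroc, "f1": f1, "precision": precision, "recall": recall}


# Same toy inputs as the cell above.
metrics = binary_metrics([0, 1, 0, 1, 1], [0.1, 0.9, 0.2, 0.8, 0.7])
for name, value in metrics.items():
    print(f"{name}: {value:.3f}")
```

On these toy inputs every positive sample scores above every negative one and the 0.5 threshold separates them perfectly, so all four values in this sketch equal 1.0.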


## Summary

If every cell above printed a ✓, this notebook has verified:

1. ✅ **Environment Setup**: Python, PyTorch, and required libraries
2. ✅ **Hardware**: CUDA/MPS availability and device selection
3. ✅ **Models**: ResNet50 and EfficientNetV2-S model creation
4. ✅ **Data**: Sample data generation and loading
5. ✅ **Utilities**: Metrics computation and helper functions

### Next Steps

1. **Train a model**: Run `python src/train.py` to train on sample data
2. **Evaluate model**: Run `python src/eval.py` to evaluate the trained model
3. **Generate Grad-CAM**: Run `python src/interpret.py` for interpretability
4. **Launch UI**: Run `streamlit run ui/app.py` for the web interface

### Available Commands

```bash
# Generate sample data
python src/make_sample_data.py

# Train model
python src/train.py --epochs 5

# Evaluate model
python src/eval.py

# Generate Grad-CAM visualizations
python src/interpret.py

# Launch Streamlit UI
streamlit run ui/app.py
```

If all checks passed, the environment is ready for the Medical X-ray Triage project! 🎉
 