# Multi-Class Chest X-Ray Disease Detection

**Production Notebook**: 4-Class Classification (Normal | TB | Pneumonia | COVID-19)

- **Model**: EfficientNet-B2 + Adaptive Sparse Training (AST)
- **Target**: 92-95% accuracy, 85-90% energy savings
- **Features**: Grad-CAM explainability, advanced augmentation, class-weighted loss

---

## 1. Environment Setup

In [None]:
# Install dependencies
!pip install -q torch torchvision kaggle matplotlib seaborn pillow opencv-python scikit-learn pandas tqdm

import torch
import warnings
warnings.filterwarnings('ignore')

# Report the installed PyTorch build and, when CUDA is usable, the first GPU's name.
cuda_ready = torch.cuda.is_available()
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {cuda_ready}")
if cuda_ready:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Detect environment and setup
import os
from pathlib import Path

# Colab is detected by whether `google.colab` can be imported; the `files`
# helper imported here is reused later for uploading and downloading.
try:
    from google.colab import files
    IN_COLAB = True
    print("Environment: Google Colab")
except ImportError:
    IN_COLAB = False
    print("Environment: Local Jupyter")

# Clone the repository unless one of its scripts is already in the current
# working directory, then switch into the clone so relative paths resolve.
if not os.path.exists('train_multiclass_simple.py'):
    !git clone https://github.com/oluwafemidiakhoa/Tuberculosis.git
    %cd Tuberculosis
else:
    print("Already in Tuberculosis directory")

# Kaggle credentials are looked up at ~/.kaggle/kaggle.json.
kaggle_dir = Path.home() / '.kaggle'
kaggle_file = kaggle_dir / 'kaggle.json'

if not kaggle_file.exists():
    if IN_COLAB:
        # Interactive upload; the `cp` below expects the uploaded file to be
        # named exactly `kaggle.json` in the current directory.
        print("Upload kaggle.json:")
        uploaded = files.upload()
        kaggle_dir.mkdir(parents=True, exist_ok=True)
        !cp kaggle.json ~/.kaggle/
        !chmod 600 ~/.kaggle/kaggle.json
    else:
        # Outside Colab this only prints instructions; execution continues
        # without credentials.
        print("Place kaggle.json in ~/.kaggle/")
        print("Download from: https://www.kaggle.com/settings/account")
else:
    # Restrict the existing token file to owner read/write.
    !chmod 600 ~/.kaggle/kaggle.json
    print("✓ Kaggle configured")

## 2. Data Preparation

In [None]:
# Download datasets
datasets = [
    ('tawsifurrahman/covid19-radiography-database', 'data_covid', 'COVID-19 & Normal'),
    ('paultimothymooney/chest-xray-pneumonia', 'data_pneumonia', 'Pneumonia'),
    ('tawsifurrahman/tuberculosis-tb-chest-xray-dataset', 'data_tb', 'TB')
]

for dataset_id, output_dir, name in datasets:
    if not os.path.exists(output_dir):
        print(f"Downloading {name}...")
        !kaggle datasets download -d {dataset_id}
        zip_name = dataset_id.split('/')[-1] + '.zip'
        !unzip -q {zip_name} -d {output_dir}
        !rm {zip_name}
        print(f"✓ {name} ready")
    else:
        print(f"✓ {name} exists")

print("\n✓ All datasets downloaded")

In [None]:
# Organize dataset
from PIL import Image
import shutil
import random

# Fixed seed for the `random` module, which is used to shuffle images before splitting.
random.seed(42)
# Root of the organized dataset, laid out as <split>/<class>/<file>.png.
data_dir = Path('data_multiclass')

# Treat the dataset as already organized once train/Normal holds more than
# 100 PNG files; only that one folder is inspected.
if (data_dir / 'train' / 'Normal').exists() and len(list((data_dir / 'train' / 'Normal').glob('*.png'))) > 100:
    print("✓ Dataset already organized")
else:
    print("Organizing dataset...\n")
    
    # Create every split/class directory up front so copies never hit a missing folder.
    for split in ['train', 'val', 'test']:
        for cls in ['Normal', 'TB', 'Pneumonia', 'COVID']:
            (data_dir / split / cls).mkdir(parents=True, exist_ok=True)
    
    def is_valid_image(img_path):
        """Return True if *img_path* decodes fully and is at least 10x10 pixels.

        Any error raised while opening, verifying or decoding the file is
        treated as "not a usable image" and yields False.
        """
        try:
            # Integrity check first; the file is then reopened for the full
            # decode, as the original code did.
            with Image.open(img_path) as img:
                img.verify()
            with Image.open(img_path) as img:
                img.load()
                if img.size[0] < 10 or img.size[1] < 10:
                    return False
            return True
        except Exception:
            # Deliberately broad for decoder errors, but unlike a bare
            # `except:` this lets KeyboardInterrupt/SystemExit propagate so a
            # long scan can still be interrupted.
            return False
    
    def copy_images(source_patterns, class_name, target_root, max_count=3000,
                    exclude_dirs=('masks',)):
        """Collect valid images for one class and copy them into train/val/test.

        Args:
            source_patterns: glob patterns, searched recursively from the
                current directory with ``Path('.').rglob``.
            class_name: class folder name; also the prefix of copied files.
            target_root: root directory containing ``<split>/<class_name>/``.
            max_count: maximum number of images kept after shuffling.
            exclude_dirs: directory names (case-insensitive) whose contents
                are skipped.

        Returns:
            Tuple ``(train, val, test)`` with the number of files actually
            copied into each split.
        """
        # NOTE(review): the default assumes the COVID-19 radiography archive
        # keeps segmentation masks in `masks/` folders beside `images/`, which
        # the `**` patterns would otherwise pick up as X-rays. Confirm against
        # the extracted archive.
        excluded = {name.lower() for name in exclude_dirs}
        images = []
        corrupted = 0

        for pattern in source_patterns:
            # Sort the matches: directory enumeration order varies between
            # filesystems, and the seeded shuffle below is only reproducible
            # when it starts from the same ordering.
            for img_path in sorted(Path('.').rglob(pattern)):
                if excluded.intersection(part.lower() for part in img_path.parent.parts):
                    continue
                if is_valid_image(img_path):
                    images.append(img_path)
                else:
                    corrupted += 1

        print(f"  {class_name}: {len(images)} valid ({corrupted} corrupted)")

        random.shuffle(images)
        images = images[:max_count]

        # 70% train, 15% val, remainder test.
        n = len(images)
        n_train = int(0.70 * n)
        n_val = int(0.15 * n)

        splits = {
            'train': images[:n_train],
            'val': images[n_train:n_train+n_val],
            'test': images[n_train+n_val:]
        }

        copied = {}
        for split_name, split_images in splits.items():
            n_copied = 0
            for i, img_path in enumerate(split_images):
                # Files keep a .png name whatever the source extension,
                # because later cells look the images up with '*.png'.
                dest = target_root / split_name / class_name / f"{class_name}_{i}.png"
                try:
                    shutil.copy(img_path, dest)
                except OSError as err:
                    # Best effort: report the failure and carry on.
                    print(f"  ! Could not copy {img_path}: {err}")
                    continue
                n_copied += 1
            copied[split_name] = n_copied

        # Report what actually landed on disk, not what was planned.
        return copied['train'], copied['val'], copied['test']
    
    # One entry per class: (glob patterns for the source images, class name).
    # Each class is capped at 3000 images; copy_images returns the sizes of
    # the train/val/test splits, which are printed per class.
    for patterns, cls_name in [
        (['data_covid/**/Normal/**/*.png', 'data_covid/**/Normal/**/*.jpg'], 'Normal'),
        (['data_covid/**/COVID/**/*.png', 'data_covid/**/COVID/**/*.jpg'], 'COVID'),
        (['data_pneumonia/**/PNEUMONIA/**/*.jpeg', 'data_pneumonia/**/PNEUMONIA/**/*.png'], 'Pneumonia'),
        (['data_tb/**/Tuberculosis/**/*.png', 'data_tb/**/Tuberculosis/**/*.jpg'], 'TB')
    ]:
        train, val, test = copy_images(patterns, cls_name, data_dir, 3000)
        print(f"    Train: {train}, Val: {val}, Test: {test}\n")
    
    print("✓ Dataset organized")

In [None]:
# Visualize distribution
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

CLASSES = ['Normal', 'TB', 'Pneumonia', 'COVID']


def _count_pngs(split):
    """Return the number of '*.png' files per class (in CLASSES order) for one split."""
    return [len(list((data_dir / split / name).glob('*.png'))) for name in CLASSES]


train_counts = _count_pngs('train')
val_counts = _count_pngs('val')
test_counts = _count_pngs('test')

fig, axes = plt.subplots(1, 2, figsize=(16, 6))
fig.suptitle('Dataset Distribution', fontsize=20, fontweight='bold')
pie_ax, bar_ax = axes

# Left panel: share of each class in the training split.
colors = ['#2ecc71', '#e74c3c', '#f39c12', '#9b59b6']
pie_ax.pie(
    train_counts,
    labels=CLASSES,
    autopct='%1.1f%%',
    colors=colors,
    explode=[0.05]*4,
    shadow=True,
    startangle=90,
    textprops={'fontsize': 14, 'weight': 'bold'},
)
pie_ax.set_title('Class Distribution', fontsize=16, fontweight='bold')

# Right panel: grouped bars, one bar per split for every class.
x = np.arange(len(CLASSES))
width = 0.25
bar_ax.bar(x - width, train_counts, width, label='Train (70%)', color='#3498db')
bar_ax.bar(x, val_counts, width, label='Val (15%)', color='#e67e22')
bar_ax.bar(x + width, test_counts, width, label='Test (15%)', color='#95a5a6')
bar_ax.set_ylabel('Number of Images', fontsize=12, fontweight='bold')
bar_ax.set_title('Train/Val/Test Split', fontsize=16, fontweight='bold')
bar_ax.set_xticks(x)
bar_ax.set_xticklabels(CLASSES)
bar_ax.legend()
bar_ax.grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.savefig('dataset_distribution.png', dpi=300, bbox_inches='tight')
plt.show()

# Text summary: total images per class with the per-split breakdown.
print("\nDataset Summary:")
for i, cls in enumerate(CLASSES):
    total = sum(counts[i] for counts in (train_counts, val_counts, test_counts))
    print(f"  {cls:12s}: {total:4d} images (Train: {train_counts[i]}, Val: {val_counts[i]}, Test: {test_counts[i]})")

## 3. Model Training

**Config**: EfficientNet-B2 | 100 epochs | Batch 32 | AST 15% activation  
**Expected**: 92-95% accuracy, 85-90% energy savings

In [None]:
# Train the model
!python train_optimized_90_95.py

print("\n✓ Training complete")
print("  Checkpoints: checkpoints_multiclass_optimized/")
print("  Metrics: checkpoints_multiclass_optimized/metrics_optimized.csv")

## 4. Evaluation & Visualization

In [None]:
# Visualize training metrics
import pandas as pd

df = pd.read_csv('checkpoints_multiclass_optimized/metrics_optimized.csv')

# val_acc is handled as a 0-1 fraction below; a maximum above 1 is taken to
# mean the column was logged as a percentage.
if df['val_acc'].max() > 1:
    df['val_acc'] = df['val_acc'] / 100

axis_font = dict(fontsize=14, fontweight='bold')
title_font = dict(fontsize=16, fontweight='bold')


def _finish_panel(ax, ylabel, title, with_legend=True):
    """Apply the shared x/y labels, title, optional legend and grid to one panel."""
    ax.set_xlabel('Epoch', **axis_font)
    ax.set_ylabel(ylabel, **axis_font)
    ax.set_title(title, **title_font)
    if with_legend:
        ax.legend(fontsize=12)
    ax.grid(True, alpha=0.3)


fig, axes = plt.subplots(2, 2, figsize=(18, 12))
fig.suptitle('Training Results - Multi-Class Disease Detection', fontsize=24, fontweight='bold')
(loss_ax, acc_ax), (activation_ax, energy_ax) = axes
epochs = df['epoch']

# Top-left: train vs. validation loss.
loss_ax.plot(epochs, df['train_loss'], label='Train', linewidth=3, color='#e74c3c')
loss_ax.plot(epochs, df['val_loss'], label='Validation', linewidth=3, color='#3498db')
_finish_panel(loss_ax, 'Loss', 'Training & Validation Loss')

# Top-right: validation accuracy with best-value and 90% reference lines.
best_acc = df['val_acc'].max() * 100
acc_ax.plot(epochs, df['val_acc']*100, linewidth=3, color='#2ecc71')
acc_ax.axhline(
    best_acc,
    color='#e74c3c',
    linestyle='--',
    linewidth=2.5,
    label=f'Best: {best_acc:.2f}%',
)
acc_ax.axhline(90, color='#f39c12', linestyle=':', linewidth=2, label='Target: 90%')
_finish_panel(acc_ax, 'Accuracy (%)', f'Validation Accuracy (Best: {best_acc:.2f}%)')
acc_ax.set_ylim([0, 105])

# Bottom-left: activation rate (scaled by 100 for display) against the 15% line.
avg_activation = df['activation_rate'].mean() * 100
activation_ax.plot(epochs, df['activation_rate']*100, linewidth=3, color='#f39c12')
activation_ax.axhline(15, color='#e74c3c', linestyle='--', linewidth=2.5, label='Target: 15%')
_finish_panel(activation_ax, 'Activation Rate (%)', f'Network Activation (Avg: {avg_activation:.2f}%)')

# Bottom-right: energy savings, plotted as logged (no rescaling) with a filled area.
avg_energy = df['energy_savings'].mean()
energy_ax.plot(epochs, df['energy_savings'], linewidth=3, color='#9b59b6')
energy_ax.fill_between(epochs, df['energy_savings'], alpha=0.3, color='#9b59b6')
_finish_panel(
    energy_ax,
    'Energy Savings (%)',
    f'Energy Efficiency (Avg: {avg_energy:.2f}%)',
    with_legend=False,
)
energy_ax.set_ylim([0, 100])

plt.tight_layout()
plt.savefig('training_results.png', dpi=300, bbox_inches='tight')
plt.show()

print(f"\nTraining Summary:")
print(f"  Best Accuracy: {best_acc:.2f}%")
print(f"  Avg Energy Savings: {avg_energy:.2f}%")
print(f"  Avg Activation Rate: {avg_activation:.2f}%")

# Per-class accuracy is only reported when the CSV has per-class columns.
if 'Normal_acc' in df.columns:
    best_epoch = df['val_acc'].idxmax()
    print(f"\nPer-Class Accuracy (Epoch {df.iloc[best_epoch]['epoch']:.0f}):")
    for cls in CLASSES:
        if f'{cls}_acc' not in df.columns:
            continue
        acc = df.iloc[best_epoch][f'{cls}_acc']
        print(f"  {cls:12s}: {acc:.2f}%")

In [None]:
# Load trained model
import torch.nn as nn
from torchvision import models, transforms
from collections import OrderedDict

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = models.efficientnet_b2(weights=None)
# Size the replacement head from the layer it replaces and from the class
# list, so the two cannot drift apart from hard-coded numbers.
model.classifier[1] = nn.Linear(model.classifier[1].in_features, len(CLASSES))

checkpoint_path = 'checkpoints_multiclass_optimized/best.pt'
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
checkpoint = torch.load(checkpoint_path, map_location=device)

# The checkpoint is either a training dict that wraps the weights under
# 'model_state_dict', or a bare state dict.
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
    state_dict = checkpoint['model_state_dict']
    print(f"Checkpoint - Epoch: {checkpoint.get('epoch', 'N/A')}, Acc: {checkpoint.get('val_acc', 0):.2f}%")
else:
    state_dict = checkpoint

# Strip a leading 'model.' prefix and drop the 'activation_mask' entry so the
# remaining keys can line up with the plain torchvision model.
clean_state_dict = OrderedDict()
for key, value in state_dict.items():
    name = key[len('model.'):] if key.startswith('model.') else key
    if name == 'activation_mask':
        continue
    clean_state_dict[name] = value

# strict=False tolerates key mismatches, so inspect the result: parameters
# listed as missing keep their random initialisation, which would otherwise
# go unnoticed.
load_result = model.load_state_dict(clean_state_dict, strict=False)
if load_result.unexpected_keys:
    print(f"⚠ Ignored {len(load_result.unexpected_keys)} unexpected checkpoint keys, "
          f"e.g. {load_result.unexpected_keys[:3]}")
model = model.to(device)
model.eval()

if load_result.missing_keys:
    print(f"⚠ Model loaded, but {len(load_result.missing_keys)} parameters were missing "
          f"from the checkpoint and stay randomly initialised, e.g. {load_result.missing_keys[:3]}")
else:
    print("✓ Model loaded")

# Preprocessing for inference: resize, centre crop, convert to tensor, then
# normalise with the mean/std values listed below.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

def predict(img_path):
    """Classify the image at *img_path*.

    Returns a ``(class_name, confidence)`` tuple, where confidence is the
    softmax probability of the predicted class expressed in percent.
    """
    image = Image.open(img_path).convert('RGB')
    batch = transform(image).unsqueeze(0).to(device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        logits = model(batch)
        class_probs = torch.softmax(logits, dim=1)[0]

    best = logits.argmax(dim=1).item()
    label = CLASSES[best]
    confidence = float(class_probs[best]*100)
    return label, confidence

In [None]:
# Confusion matrix
from sklearn.metrics import confusion_matrix, classification_report

all_preds, all_labels = [], []
skipped = 0
first_error = None
# Explicit label ids keep the report and matrix at one row per class even
# when a class has no evaluated samples; without them target_names can
# disagree with the labels actually present.
label_ids = list(range(len(CLASSES)))

print("Evaluating test set...")
for class_idx, cls in enumerate(CLASSES):
    test_path = data_dir / 'test' / cls
    # Sort before truncating so the same (at most 100) images per class are
    # evaluated on every run, independent of directory listing order.
    test_imgs = sorted(test_path.glob('*.png'))[:100]

    for img_path in test_imgs:
        try:
            pred, _ = predict(img_path)
        except Exception as err:
            # Best effort: skip the image but remember that it happened.
            skipped += 1
            if first_error is None:
                first_error = f"{img_path}: {err}"
            continue
        all_preds.append(CLASSES.index(pred))
        all_labels.append(class_idx)

if skipped:
    print(f"  ! Skipped {skipped} images that could not be evaluated (first: {first_error})")

if all_preds:
    print("\nClassification Report:\n")
    print(classification_report(all_labels, all_preds, labels=label_ids,
                                target_names=CLASSES, digits=3))

    cm = confusion_matrix(all_labels, all_preds, labels=label_ids)
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
               xticklabels=CLASSES, yticklabels=CLASSES,
               cbar_kws={'label': 'Count'},
               annot_kws={'fontsize': 14, 'fontweight': 'bold'})
    ax.set_title('Confusion Matrix', fontsize=18, fontweight='bold', pad=20)
    ax.set_ylabel('True Label', fontsize=14, fontweight='bold')
    ax.set_xlabel('Predicted Label', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig('confusion_matrix.png', dpi=300, bbox_inches='tight')
    plt.show()
    print("\n✓ Confusion matrix saved")
else:
    print("No predictions - check test images")

In [None]:
# Grad-CAM visualization
import cv2

class GradCAM:
    """Grad-CAM heatmap generator for one layer of a model.

    A forward hook on ``target_layer`` records its output and a full backward
    hook records the gradient with respect to that output. ``generate``
    combines the two into a heatmap normalised to the range [0, 1].

    ``generate`` indexes sample 0 of the batch, so it is meant to be called
    with a batch of one image.
    """

    def __init__(self, model, target_layer):
        """Attach the recording hooks to *target_layer* of *model*."""
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None

        def save_activation(module, inputs, output):
            self.activations = output.detach()

        def save_gradient(module, grad_input, grad_output):
            # Returns None implicitly, so the gradient flow is not altered.
            self.gradients = grad_output[0]

        # Keep the handles so the hooks can be detached again; otherwise each
        # new GradCAM instance would leave another pair of hooks on the layer.
        self._hook_handles = [
            target_layer.register_forward_hook(save_activation),
            target_layer.register_full_backward_hook(save_gradient),
        ]

    def remove_hooks(self):
        """Detach the hooks registered by this instance (safe to call twice)."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def generate(self, input_img, class_idx=None):
        """Compute the Grad-CAM map for *input_img*.

        Args:
            input_img: input batch passed to the model; sample 0 is explained.
            class_idx: class to explain. Defaults to the class with the
                highest score for sample 0.

        Returns:
            ``(cam, output)`` where ``output`` is the raw model output and
            ``cam`` is a NumPy array scaled to [0, 1], or ``None`` when the
            hooks captured nothing during this call.
        """
        # Clear results of any earlier call so a hook that did not fire this
        # time is detected instead of silently reusing stale tensors.
        self.gradients = None
        self.activations = None

        output = self.model(input_img)
        if class_idx is None:
            class_idx = int(output.argmax(dim=1)[0])

        # Backpropagate only the score of the chosen class.
        self.model.zero_grad()
        one_hot = torch.zeros_like(output)
        one_hot[0, class_idx] = 1
        output.backward(gradient=one_hot, retain_graph=True)

        if self.gradients is None or self.activations is None:
            return None, output

        # Channel weights are the spatial mean of the gradients; the map is
        # the ReLU of the weighted sum of the activations.
        weights = self.gradients.mean(dim=(2, 3), keepdim=True)
        cam = (weights * self.activations).sum(dim=1, keepdim=True)
        cam = torch.relu(cam)
        cam = cam.squeeze().cpu().numpy()
        # Min-max normalise; the epsilon avoids dividing by zero on a flat map.
        cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)

        return cam, output

grad_cam = GradCAM(model, model.features[-1])

# Geometric steps of the preprocessing only (no tensor conversion or
# normalisation). The heatmap is computed on this resized, centre-cropped
# view, so the displayed image must use the same framing. Resizing the full
# image straight to 224x224 would shift and stretch it relative to the map.
# Keep these two steps in sync with `transform`.
display_view = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
])

# One sample per class; sorting makes the chosen image the same on every run.
samples = []
for cls in CLASSES:
    test_path = data_dir / 'test' / cls
    img_files = sorted(test_path.glob('*.png'))
    if img_files:
        samples.append((img_files[0], cls))

if samples:
    fig, axes = plt.subplots(len(samples), 3, figsize=(15, 4.5*len(samples)))
    # With a single row plt.subplots returns a 1-D array; restore 2-D indexing.
    if len(samples) == 1:
        axes = axes.reshape(1, -1)

    fig.suptitle('Grad-CAM Explainable AI', fontsize=20, fontweight='bold')

    for idx, (img_path, true_class) in enumerate(samples):
        img = Image.open(img_path).convert('RGB')
        img_tensor = transform(img).unsqueeze(0).to(device)

        # Grad-CAM needs gradients even if they were disabled globally.
        with torch.set_grad_enabled(True):
            cam, output = grad_cam.generate(img_tensor)

        probs = torch.softmax(output, dim=1)[0].cpu().detach().numpy()
        pred_idx = output.argmax(dim=1).item()
        pred_class = CLASSES[pred_idx]
        confidence = probs[pred_idx] * 100

        # Same 224x224 view the model saw, kept as displayable RGB pixels.
        img_resized = display_view(img)
        img_array = np.array(img_resized)

        if cam is not None:
            cam_resized = cv2.resize(cam, (224, 224))
            heatmap = cv2.applyColorMap(np.uint8(255 * cam_resized), cv2.COLORMAP_JET)
            # OpenCV colour maps are BGR; matplotlib expects RGB.
            heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
            overlay = img_array * 0.5 + heatmap * 0.5
            overlay = np.clip(overlay, 0, 255).astype(np.uint8)
        else:
            # No heatmap available: show a blank map and the plain image.
            heatmap = np.zeros_like(img_array)
            overlay = img_array

        axes[idx, 0].imshow(img_resized)
        axes[idx, 0].set_title(f'Original\n{true_class}', fontsize=12, fontweight='bold')
        axes[idx, 0].axis('off')

        axes[idx, 1].imshow(heatmap)
        axes[idx, 1].set_title('Attention Map', fontsize=12, fontweight='bold')
        axes[idx, 1].axis('off')

        status = '✓' if pred_class == true_class else '✗'
        color = 'green' if pred_class == true_class else 'red'
        axes[idx, 2].imshow(overlay)
        axes[idx, 2].set_title(f'{status} {pred_class} ({confidence:.1f}%)',
                              fontsize=12, fontweight='bold', color=color)
        axes[idx, 2].axis('off')

    plt.tight_layout()
    plt.savefig('gradcam_visualization.png', dpi=300, bbox_inches='tight')
    plt.show()
    print("✓ Grad-CAM saved")
else:
    print("No test images found")

## 5. Export Results

In [None]:
# List and download results
# Artifacts produced by the earlier cells, relative to the working directory.
results = [
    'checkpoints_multiclass_optimized/best.pt',
    'checkpoints_multiclass_optimized/metrics_optimized.csv',
    'dataset_distribution.png',
    'training_results.png',
    'confusion_matrix.png',
    'gradcam_visualization.png'
]

print("Generated Files:\n")
for file in results:
    if os.path.exists(file):
        size_mb = os.path.getsize(file) / (1024 * 1024)
        print(f"  ✓ {file} ({size_mb:.2f} MB)")
    else:
        print(f"  ✗ {file}")

if IN_COLAB:
    print("\nDownloading...")
    for file in results:
        if not os.path.exists(file):
            continue
        try:
            files.download(file)
        except Exception as err:
            # Still best effort, but say which file failed instead of
            # dropping the error silently.
            print(f"  ✗ Could not download {file}: {err}")
    print("✓ Complete")
else:
    print("\n✓ Files saved locally")

print("\nDeployment: Use gradio_app/app.py with best.pt")
print("Repository: https://github.com/oluwafemidiakhoa/Tuberculosis")