# 🫁 PneuNet - Chest X-ray Multi-Label Classification

**Dataset:** CheXpert subset with 3 labels (Cardiomegaly, Edema, Pneumothorax); the "No Finding" class is excluded

**Model:** PneuNet (ResNet-18 + Transformer Encoder)

**Architecture:**
- ResNet-18 backbone for feature extraction
- Transformer encoder for spatial reasoning (6 layers, 8 heads)
- Deep MLP classifier

**Outputs:**
- 3 Binary Classification Reports (one per class)
- 3 Binary Confusion Matrices (one per class)
- Training/Validation Loss & Accuracy Curves
- Per-class Metrics Comparison bar charts (Accuracy, Precision, Recall, F1-Score)
- ROC Curves for all 3 classes

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR

# Torchvision
import torchvision.transforms as transforms
import torchvision.models as models

# Sklearn metrics
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    classification_report, confusion_matrix, 
    roc_curve, auc, accuracy_score, f1_score,
    precision_score, recall_score
)

# Record the runtime environment in the notebook log: the installed PyTorch
# version and whether a CUDA device is visible to it.
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

## ⚙️ Configuration

In [None]:
class Config:
    """Central collection of the notebook's paths and hyperparameters.

    All values are class attributes, so they can be read either from the
    class itself or from the module-level ``config`` instance created below.
    """

    # Paths - UPDATE THIS TO YOUR DATASET NAME
    DATA_DIR = "/kaggle/input/chest-xray-4class-100k"  # Your uploaded dataset
    OUTPUT_DIR = "/kaggle/working"

    # Model
    MODEL_NAME = "pneunet"
    LABELS = ["Cardiomegaly", "Edema", "Pneumothorax"]
    # Derived from LABELS so the label list and the output width cannot drift apart
    NUM_CLASSES = len(LABELS)

    # PneuNet specific
    FREEZE_BACKBONE = False  # Set to True to freeze ResNet backbone
    INPUT_CHANNELS = 3  # RGB images
    USE_80_TOKENS = True  # True for 10x8 tokens, False for 7x7 tokens

    # Training
    BATCH_SIZE = 32
    NUM_EPOCHS = 25
    LEARNING_RATE = 1e-4
    WEIGHT_DECAY = 1e-5

    # Image (side length in pixels; images are resized to IMG_SIZE x IMG_SIZE)
    IMG_SIZE = 224

    # Device: use the GPU when PyTorch can see one, otherwise fall back to CPU
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Random seed
    SEED = 42


config = Config()

# Set seeds for reproducibility
def set_seed(seed):
    """Seed every random number generator the notebook may draw from.

    Seeds Python's ``random`` module, NumPy, and PyTorch (CPU generator and
    all CUDA devices), then puts cuDNN into deterministic mode with kernel
    autotuning disabled.

    Args:
        seed (int): Seed value applied to all generators.
    """
    # Local import: ``random`` is not among the notebook's top-level imports.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Prefer run-to-run repeatability over cuDNN's autotuned kernel selection
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Seed the generators once, before any dataset split or model initialisation
set_seed(config.SEED)

# Echo the key configuration values so the run log is self-describing
print(f"🖥️ Device: {config.DEVICE}")
print(f"🤖 Model: {config.MODEL_NAME}")
print(f"🏷️ Labels: {config.LABELS}")
print(f"🔧 Tokens: {'80 (10x8)' if config.USE_80_TOKENS else '49 (7x7)'}")

## 📦 Dataset Class

In [None]:
class ChestXrayDataset(Dataset):
    """Dataset yielding ``(image, multi-label target)`` pairs from a metadata table.

    Every row of ``dataframe`` must provide a ``new_path`` column (image path
    relative to ``data_dir``) and one column per label name.
    """

    def __init__(self, dataframe, data_dir, transform=None, labels=None):
        """
        Parameters:
          - dataframe (pd.DataFrame): One row per image.
          - data_dir (str): Directory that the ``new_path`` values are relative to.
          - transform (callable, optional): Applied to the loaded PIL image.
          - labels (sequence of str, optional): Label column names, in target
            order. Defaults to ``config.LABELS`` when omitted.
        """
        self.dataframe = dataframe
        self.data_dir = data_dir
        self.transform = transform
        # Fall back to the notebook-wide label list only when none is supplied
        self.labels = list(labels) if labels is not None else config.LABELS

    def __len__(self):
        """Return the number of rows (images) in the metadata table."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Load the image at positional index ``idx`` and build its target.

        Returns:
            tuple: ``(image, label)`` where ``image`` is the (optionally
            transformed) RGB image and ``label`` is a float32 tensor with one
            entry per label name.
        """
        row = self.dataframe.iloc[idx]
        img_path = os.path.join(self.data_dir, row['new_path'])

        # Load image; the context manager releases the file handle as soon as
        # the pixel data has been converted.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        # Get multi-label target
        label = torch.tensor(
            [float(row[name]) for name in self.labels],
            dtype=torch.float32,
        )

        return image, label

## 🎨 Data Augmentation

In [None]:
# Steps shared by both pipelines: a fixed-size resize up front and a
# per-channel normalisation at the end.
# NOTE(review): the mean/std values look like the usual ImageNet statistics
# for the pretrained backbone - confirm.
_resize = transforms.Resize((config.IMG_SIZE, config.IMG_SIZE))
_normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Training pipeline: resize, random augmentation on the PIL image, then
# tensor conversion and normalisation.
train_transform = transforms.Compose([
    _resize,
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    _normalize,
])

# Validation pipeline: deterministic, no augmentation.
val_transform = transforms.Compose([
    _resize,
    transforms.ToTensor(),
    _normalize,
])

## 📊 Load CheXpert Dataset

In [None]:
# The metadata table is expected at <DATA_DIR>/metadata.csv
metadata_path = os.path.join(config.DATA_DIR, 'metadata.csv')

if not os.path.exists(metadata_path):
    # No metadata: continue with an empty frame that has the expected columns
    print("⚠️ Metadata file not found. Creating dummy data structure for compilation check.")
    df = pd.DataFrame(columns=['new_path'] + config.LABELS)
else:
    df = pd.read_csv(metadata_path)
    print(f"📁 Total samples: {len(df)}")
    print(f"\n📋 Label distribution:")
    # Report the share of rows marked positive (== 1.0) for each label column
    # that is actually present in the table
    for label in config.LABELS:
        if label not in df.columns:
            continue
        count = (df[label] == 1.0).sum()
        print(f"  {label}: {count} ({count/len(df)*100:.2f}%)")

In [None]:
# Hold out 20% of the rows for validation (seeded, unstratified); an empty
# table is passed through unchanged for both splits.
if len(df) == 0:
    train_df, val_df = df, df
else:
    train_df, val_df = train_test_split(
        df,
        test_size=0.2,
        random_state=config.SEED,
        stratify=None,
    )

print(f"🏋️ Train samples: {len(train_df)}")
print(f"🧪 Validation samples: {len(val_df)}")

# Wrap each split in a dataset with its own transform pipeline
train_dataset = ChestXrayDataset(train_df, config.DATA_DIR, train_transform)
val_dataset = ChestXrayDataset(val_df, config.DATA_DIR, val_transform)

# Loader settings common to both splits; only shuffling differs
loader_kwargs = dict(
    batch_size=config.BATCH_SIZE,
    num_workers=2,
    pin_memory=True,
)

train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

print(f"\n✅ DataLoaders created!")
print(f"   Train batches: {len(train_loader)}")
print(f"   Val batches: {len(val_loader)}")

## 🤖 PneuNet Model Definition

In [None]:
class PneuNet(nn.Module):
    """ResNet-18 backbone + Transformer encoder + MLP head for image classification.

    The layer4 feature map of a pretrained ResNet-18 is batch-normalised,
    pooled to a fixed token grid, passed through a 6-layer Transformer
    encoder, layer-normalised and classified by an MLP. ``forward`` returns
    raw logits; no sigmoid/softmax is applied inside the model.
    """

    def __init__(self, num_classes=3, freeze_backbone=False, input_channels=3, use_80_tokens=True):
        """
        PneuNet: ResNet-18 + Transformer Encoder for medical image classification

        Parameters:
          - num_classes (int): Number of output classes.
          - freeze_backbone (bool): If True, freeze ResNet backbone parameters
            and keep the backbone in eval mode so its BatchNorm statistics
            are not updated either.
          - input_channels (int): Number of channels in input images (1 for grayscale, 3 for RGB).
          - use_80_tokens (bool): If True, use 10x8 tokens; otherwise, use 7x7 tokens.
        """
        super().__init__()

        # Remembered so train() can keep a frozen backbone in eval mode
        self._freeze_backbone = freeze_backbone

        # Load pretrained ResNet-18 backbone
        resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

        # Modify first conv layer if input channels != 3
        # (the replacement layer is freshly initialised, not pretrained)
        if input_channels != 3:
            resnet.conv1 = nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)

        # Truncate ResNet to get spatial feature map from layer4
        self.backbone = nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            resnet.relu,
            resnet.maxpool,
            resnet.layer1,
            resnet.layer2,
            resnet.layer3,
            resnet.layer4,  # output: [b, 512, 7, 7] for 224x224 inputs
        )

        # Optionally freeze backbone
        if freeze_backbone:
            for param in self.backbone.parameters():
                param.requires_grad = False
            # requires_grad=False does not stop BatchNorm from updating its
            # running statistics; eval mode does.
            self.backbone.eval()

        # BatchNorm after backbone
        self.post_resnet_bn = nn.BatchNorm2d(512)

        # Token grid: pooling to a fixed size keeps the token count (and so
        # the first Linear layer's input width) independent of the input
        # resolution. On a 7x7 feature map the (7, 7) pool is a no-op.
        token_grid = (10, 8) if use_80_tokens else (7, 7)
        self.adaptive_pool = nn.AdaptiveAvgPool2d(token_grid)

        # Transformer Encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=512,
            nhead=8,
            dim_feedforward=2048,
            dropout=0.1,
            activation='relu',
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)

        # Layer Normalization
        self.layer_norm = nn.LayerNorm(512)

        # Determine number of tokens
        token_count = token_grid[0] * token_grid[1]

        # MLP Classifier (its leading Flatten turns [b, tokens, 512] into [b, tokens * 512])
        self.mlp = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(0.5),
            nn.Linear(token_count * 512, 1024),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(1024, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, 16),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(16, num_classes)
        )

    def train(self, mode=True):
        """Set training mode, keeping a frozen backbone in eval mode.

        Returns:
            PneuNet: ``self``, as ``nn.Module.train`` does.
        """
        super().train(mode)
        if self._freeze_backbone:
            self.backbone.eval()
        return self

    def forward(self, x):
        """Compute class logits for a batch of images.

        Parameters:
          - x (Tensor): Image batch of shape [b, input_channels, H, W].

        Returns:
          Tensor of shape [b, num_classes] holding raw logits.
        """
        # Extract features with ResNet
        features = self.backbone(x)  # [b, 512, h', w']

        # Apply BatchNorm
        features = self.post_resnet_bn(features)

        # Adapt spatial dimensions to the fixed token grid
        features = self.adaptive_pool(features)

        # Reshape to sequence for Transformer. flatten() (unlike view()) also
        # accepts non-contiguous feature maps such as channels_last tensors.
        tokens = features.flatten(2).transpose(1, 2)  # [b, tokens, 512]

        # Transformer encoding
        transformed = self.transformer_encoder(tokens)

        # Layer normalization
        transformed = self.layer_norm(transformed)

        # Classification (the MLP flattens the token sequence itself)
        out = self.mlp(transformed)
        return out

# Build PneuNet from the notebook configuration and place it on the target device
model = PneuNet(
    num_classes=config.NUM_CLASSES,
    freeze_backbone=config.FREEZE_BACKBONE,
    input_channels=config.INPUT_CHANNELS,
    use_80_tokens=config.USE_80_TOKENS,
).to(config.DEVICE)

# Collect (element count, trainable?) once, then derive both totals from it
param_info = [(p.numel(), p.requires_grad) for p in model.parameters()]
total_params = sum(count for count, _ in param_info)
trainable_params = sum(count for count, is_trainable in param_info if is_trainable)
print(f"📊 Total parameters: {total_params:,}")
print(f"📊 Trainable parameters: {trainable_params:,}")

## 🎯 Training Setup

In [None]:
# Loss function, optimizer, scheduler
# BCEWithLogitsLoss consumes raw logits, which is what the model's final
# Linear layer produces; each label is scored as an independent binary target.
criterion = nn.BCEWithLogitsLoss()
# All model parameters are handed to AdamW (frozen ones simply receive no gradient)
optimizer = optim.AdamW(model.parameters(), lr=config.LEARNING_RATE, weight_decay=config.WEIGHT_DECAY)
# Cosine decay from LEARNING_RATE down to 1e-6 over NUM_EPOCHS scheduler steps
scheduler = CosineAnnealingLR(optimizer, T_max=config.NUM_EPOCHS, eta_min=1e-6)

print("✅ Training setup complete!")
print(f"  Loss: BCEWithLogitsLoss")
print(f"  Optimizer: AdamW (lr={config.LEARNING_RATE})")
print(f"  Scheduler: CosineAnnealingLR")

## 🔄 Training & Validation Functions

In [None]:
def train_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Network to train; put into train mode here.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss taking ``(logits, labels)``. The epoch loss assumes it
            returns a per-batch mean (the default reduction).
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.

    Returns:
        tuple[float, float]: ``(avg_loss, accuracy)``. ``avg_loss`` is the
        sample-weighted mean loss over the epoch; ``accuracy`` is the fraction
        of individual label decisions (sigmoid > 0.5) that match the targets.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.train()
    loss_sum = 0.0
    sample_count = 0
    all_preds = []
    all_labels = []

    for images, labels in tqdm(loader, desc="Training", leave=False):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Weight each batch mean by its size so a short final batch does not
        # count as much as a full one.
        batch_size = labels.size(0)
        loss_sum += loss.item() * batch_size
        sample_count += batch_size
        all_preds.append(torch.sigmoid(outputs).detach().cpu())
        all_labels.append(labels.cpu())

    if sample_count == 0:
        raise ValueError("train_epoch received an empty loader; there is nothing to train on.")

    avg_loss = loss_sum / sample_count
    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)
    accuracy = ((all_preds > 0.5) == all_labels).float().mean().item()

    return avg_loss, accuracy

def validate_epoch(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without updating any weights.

    Args:
        model: Network to evaluate; put into eval mode here.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss taking ``(logits, labels)``. The epoch loss assumes it
            returns a per-batch mean (the default reduction).
        device: Device the batches are moved to.

    Returns:
        tuple: ``(avg_loss, accuracy, all_preds, all_labels)`` where
        ``avg_loss`` is the sample-weighted mean loss, ``accuracy`` is the
        fraction of individual label decisions (sigmoid > 0.5) matching the
        targets, ``all_preds`` holds the sigmoid probabilities and
        ``all_labels`` the targets, both as CPU tensors in loader order.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    sample_count = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels in tqdm(loader, desc="Validating", leave=False):
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            # Weight each batch mean by its size so a short final batch does
            # not count as much as a full one.
            batch_size = labels.size(0)
            loss_sum += loss.item() * batch_size
            sample_count += batch_size
            all_preds.append(torch.sigmoid(outputs).cpu())
            all_labels.append(labels.cpu())

    if sample_count == 0:
        raise ValueError("validate_epoch received an empty loader; there is nothing to evaluate.")

    avg_loss = loss_sum / sample_count
    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)
    accuracy = ((all_preds > 0.5) == all_labels).float().mean().item()

    return avg_loss, accuracy, all_preds, all_labels

## 🚀 Training Loop

In [None]:
# Per-epoch metrics, appended to once per epoch and plotted afterwards
history = {
    'train_loss': [],
    'train_acc': [],
    'val_loss': [],
    'val_acc': []
}

best_val_loss = float('inf')
# Make sure the checkpoint directory exists before the first save
os.makedirs(config.OUTPUT_DIR, exist_ok=True)
best_model_path = os.path.join(config.OUTPUT_DIR, 'pneunet_chestxray.pth')

print("🏋️ Starting training...\n")

for epoch in range(config.NUM_EPOCHS):
    print(f"Epoch {epoch+1}/{config.NUM_EPOCHS}")

    # Read the learning rate before the scheduler advances, so the value
    # logged below is the one this epoch actually trained with.
    current_lr = optimizer.param_groups[0]['lr']

    # Train
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, config.DEVICE)

    # Validate
    val_loss, val_acc, val_preds, val_labels = validate_epoch(model, val_loader, criterion, config.DEVICE)

    # Update scheduler (once per epoch)
    scheduler.step()

    # Store history
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)

    print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
    print(f"  Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")
    print(f"  LR: {current_lr:.6f}\n")

    # Save best model: keep only the checkpoint with the lowest validation loss.
    # 'epoch' is stored as the zero-based loop index.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss,
            'val_acc': val_acc
        }, best_model_path)
        print(f"  💾 Saved best model (val_loss: {val_loss:.4f})\n")

print("✅ Training complete!")

## 📈 Training Curves

In [None]:
# One figure with two side-by-side panels: loss on the left, accuracy on the right
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Per panel: (train key, val key, train legend, val legend, y-axis label, title)
panel_specs = [
    ('train_loss', 'val_loss', 'Train Loss', 'Val Loss',
     'Loss', 'Training & Validation Loss'),
    ('train_acc', 'val_acc', 'Train Accuracy', 'Val Accuracy',
     'Accuracy', 'Training & Validation Accuracy'),
]

for ax, spec in zip(axes, panel_specs):
    train_key, val_key, train_legend, val_legend, y_label, panel_title = spec
    ax.plot(history[train_key], label=train_legend, linewidth=2)
    ax.plot(history[val_key], label=val_legend, linewidth=2)
    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel(y_label, fontsize=12)
    ax.set_title(panel_title, fontsize=14, fontweight='bold')
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)

plt.tight_layout()
# Write the figure to the output directory before displaying it
curves_path = os.path.join(config.OUTPUT_DIR, 'pneunet_training_curves.png')
plt.savefig(curves_path, dpi=300, bbox_inches='tight')
plt.show()
print("✅ Saved: pneunet_training_curves.png")

## 📊 Evaluation - Load Best Model

In [None]:
# Load best model
# map_location remaps the stored tensors onto the device in use now, so the
# checkpoint also loads when it was written on a different device.
checkpoint = torch.load(best_model_path, map_location=config.DEVICE)
model.load_state_dict(checkpoint['model_state_dict'])
# The checkpoint stores the zero-based loop index; show it one-based so it
# lines up with the "Epoch N/M" lines of the training log.
print(f"✅ Loaded best model from epoch {checkpoint['epoch'] + 1}")
print(f"   Val Loss: {checkpoint['val_loss']:.4f}")
print(f"   Val Acc: {checkpoint['val_acc']:.4f}")

# Get predictions: re-run validation with the restored weights
_, val_acc, all_probs, all_labels = validate_epoch(model, val_loader, criterion, config.DEVICE)
# Hard 0/1 decisions at the 0.5 probability threshold
all_preds = (all_probs > 0.5).float()

# Convert to NumPy for the scikit-learn metrics used downstream
all_preds = all_preds.numpy()
all_probs = all_probs.numpy()
all_labels = all_labels.numpy()

print(f"\n📊 Predictions shape: {all_preds.shape}")
print(f"📊 Labels shape: {all_labels.shape}")

## 📋 Per-Class Classification Reports

In [None]:
reports_data = []

# Always report both classes. Without an explicit label list,
# classification_report raises when a column contains only one class
# (its inferred class count no longer matches the two target names).
binary_labels = [0, 1]
class_names = ['Negative', 'Positive']

for i, label in enumerate(config.LABELS):
    print(f"\n{'='*60}")
    print(f"📊 {label}")
    print(f"{'='*60}")

    # Column i holds the targets / hard predictions for this label
    y_true = all_labels[:, i]
    y_pred = all_preds[:, i]

    # Classification report
    print(classification_report(
        y_true, y_pred,
        labels=binary_labels,
        target_names=class_names,
        zero_division=0,
    ))

    # Calculate metrics for the positive class; undefined ratios
    # (no predicted or no actual positives) are reported as 0.
    acc = accuracy_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred, zero_division=0)
    precision = precision_score(y_true, y_pred, zero_division=0)
    recall = recall_score(y_true, y_pred, zero_division=0)

    reports_data.append({
        'Label': label,
        'Accuracy': acc,
        'Precision': precision,
        'Recall': recall,
        'F1-Score': f1
    })

# Save to CSV
reports_df = pd.DataFrame(reports_data)
reports_df.to_csv(os.path.join(config.OUTPUT_DIR, 'pneunet_classification_reports.csv'), index=False)
print(f"\n✅ Saved: pneunet_classification_reports.csv")
print(f"\n{reports_df}")

## 🔲 Confusion Matrices

In [None]:
# One heatmap per label, laid out side by side
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
colors = ['#e74c3c', '#3498db', '#2ecc71']

for i, label in enumerate(config.LABELS):
    y_true = all_labels[:, i]
    y_pred = all_preds[:, i]

    # Fix the class order to [0, 1] so the matrix is always 2x2 and matches
    # the Negative/Positive tick labels, even if one class never occurs.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])

    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=['Negative', 'Positive'],
                yticklabels=['Negative', 'Positive'],
                ax=axes[i], cbar=True, square=True)

    axes[i].set_xlabel('Predicted', fontsize=11)
    axes[i].set_ylabel('Actual', fontsize=11)
    axes[i].set_title(f'🏷️ {label}', fontsize=13, fontweight='bold', color=colors[i])

plt.suptitle('PneuNet - Confusion Matrices', fontsize=16, fontweight='bold', y=1.02)
plt.tight_layout()
plt.savefig(os.path.join(config.OUTPUT_DIR, 'pneunet_confusion_matrices.png'), dpi=300, bbox_inches='tight')
plt.show()
print("✅ Saved: pneunet_confusion_matrices.png")

## 📊 Metrics Comparison

In [None]:
# One bar chart per label, each showing the four summary metrics
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
colors = ['#e74c3c', '#3498db', '#2ecc71']
metrics = ['Accuracy', 'Precision', 'Recall', 'F1-Score']

for i, label in enumerate(config.LABELS):
    ax = axes[i]

    # Pull this label's row out of the report table, in metric order
    label_rows = reports_df[reports_df['Label'] == label]
    values = label_rows[metrics].values[0]

    bars = ax.bar(metrics, values, color=colors[i], alpha=0.7, edgecolor='black')
    ax.set_ylim([0, 1])
    ax.set_ylabel('Score', fontsize=11)
    ax.set_title(f'🏷️ {label}', fontsize=13, fontweight='bold')
    ax.grid(True, axis='y', alpha=0.3)

    # Annotate every bar with its value, centred just above the bar top
    for bar in bars:
        height = bar.get_height()
        x_center = bar.get_x() + bar.get_width()/2.
        ax.text(x_center, height, f'{height:.3f}',
                ha='center', va='bottom', fontsize=10)

plt.suptitle('PneuNet - Metrics Comparison', fontsize=16, fontweight='bold', y=1.02)
plt.tight_layout()
comparison_path = os.path.join(config.OUTPUT_DIR, 'pneunet_metrics_comparison.png')
plt.savefig(comparison_path, dpi=300, bbox_inches='tight')
plt.show()
print("✅ Saved: pneunet_metrics_comparison.png")

## 📈 ROC Curves

In [None]:
# Only draw ROC curves when there are predictions to evaluate
if len(all_preds) > 0:
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    colors = ['#e74c3c', '#3498db', '#2ecc71']

    for i, label in enumerate(config.LABELS):
        ax = axes[i]
        curve_color = colors[i]

        # ROC is computed from the probabilities, not the thresholded predictions
        class_probs = all_probs[:, i]
        class_labels = all_labels[:, i]
        fpr, tpr, _ = roc_curve(class_labels, class_probs)
        roc_auc = auc(fpr, tpr)

        # Curve, chance diagonal, and shaded area under the curve
        ax.plot(fpr, tpr, color=curve_color, linewidth=2,
                label=f'ROC curve (AUC = {roc_auc:.4f})')
        ax.plot([0, 1], [0, 1], 'k--', linewidth=1, alpha=0.5)
        ax.fill_between(fpr, tpr, alpha=0.2, color=curve_color)

        # Axes cosmetics
        ax.set_xlim([0.0, 1.0])
        ax.set_ylim([0.0, 1.05])
        ax.set_xlabel('False Positive Rate', fontsize=12)
        ax.set_ylabel('True Positive Rate', fontsize=12)
        ax.set_title(f'🏷️ {label}\nROC Curve', fontsize=14, fontweight='bold')
        ax.legend(loc='lower right', fontsize=11)
        ax.grid(True, alpha=0.3)

    plt.suptitle('PneuNet - ROC Curves', fontsize=16, fontweight='bold', y=1.02)
    plt.tight_layout()
    roc_path = os.path.join(config.OUTPUT_DIR, 'pneunet_roc_curves.png')
    plt.savefig(roc_path, dpi=300, bbox_inches='tight')
    plt.show()
    print("✅ Saved: pneunet_roc_curves.png")

## 🎉 Final Summary

In [None]:
# Banner for the closing summary
print("="*60)
print("🎉 TRAINING COMPLETE! - PneuNet")
print("="*60)

# Summarise the run only when evaluation produced predictions.
# NOTE(review): val_acc is whatever was last assigned to that name; if the
# cells ran in order, that is the accuracy of the reloaded best checkpoint,
# not of the last training epoch - confirm the "Final" wording is intended.
if len(all_preds) > 0:
    print(f"""
    🤖 Model: PneuNet (ResNet-18 + Transformer)
    📊 Dataset: CheXpert (3-class: {', '.join(config.LABELS)})
    🏋️ Training samples: {len(train_df)}
    🧪 Validation samples: {len(val_df)}

    📉 Best Validation Loss: {best_val_loss:.4f}
    📈 Final Validation Accuracy: {val_acc:.4f}

    📁 Output Files:
      📊 pneunet_training_curves.png
      📊 pneunet_confusion_matrices.png
      📊 pneunet_metrics_comparison.png
      📊 pneunet_roc_curves.png
      📄 pneunet_classification_reports.csv
      🔧 pneunet_chestxray.pth

    📋 Per-Class Performance:
    """)

    # One line per label, taken from the per-class report table
    for _, row in reports_df.iterrows():
        print(f"  {row['Label']}: Acc={row['Accuracy']:.4f} | F1={row['F1-Score']:.4f}")

    print("\n✅ All outputs saved successfully!")
else:
    print("⚠️ No results to report.")