# 3D ResNet Classification with Feature Extraction
## ResNet3D model for volumetric medical image classification

This notebook implements:
- **Model:** ResNet3D with feature extraction hooks
- **Data:** Preprocessed 3D medical volumes (NIfTI), with labels and data splits read from Excel metadata
- **Task:** Multi-class classification (5 classes)
- **Features:** CNN feature extraction from intermediate layers
- **Optimization:** Adam optimizer with ReduceLROnPlateau learning-rate scheduling

## 1. Imports and Configuration

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
import nibabel as nib
import numpy as np
import pandas as pd
from pathlib import Path
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
import json
import warnings
# Silence all Python warnings for the rest of the notebook session.
warnings.filterwarnings('ignore')

# Report the runtime so results can be tied to a PyTorch/GPU setup.
print(f"PyTorch version: {torch.__version__}")
_cuda_ok = torch.cuda.is_available()
print(f"GPU available: {_cuda_ok}")
if _cuda_ok:
    _gpu_name = torch.cuda.get_device_name(0)
    _gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU: {_gpu_name}")
    print(f"GPU Memory: {_gpu_mem_gb:.2f} GB")

## 2. Paths and Configuration

In [None]:
# Filesystem layout: inputs live under BASE_DIR, outputs are written beside them.
BASE_DIR = Path('C:/FeatureEx')
PREPROCESSED_DIR = BASE_DIR / 'preprocessed_3d_data'
IMAGES_DIR = PREPROCESSED_DIR / 'images'
LABELS_DIR = PREPROCESSED_DIR / 'labels'
METADATA_FILE = BASE_DIR / 'classification_metadata.xlsx'
MODELS_DIR = BASE_DIR / 'models_3d'
FEATURES_DIR = BASE_DIR / 'extracted_features'

# Create the output directories (only the leaf; BASE_DIR must already exist).
for _output_dir in (MODELS_DIR, FEATURES_DIR):
    _output_dir.mkdir(exist_ok=True)

# Model / training hyperparameters
NUM_CLASSES = 5  # integer labels 0-4
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BATCH_SIZE, NUM_EPOCHS = 2, 50
LEARNING_RATE, WEIGHT_DECAY = 0.001, 1e-4

print(f"Base directory: {BASE_DIR}")
print(f"Images directory: {IMAGES_DIR}")
print(f"Metadata file: {METADATA_FILE}")
print(f"Device: {DEVICE}")
print("Configuration:")
for _setting, _value in (
    ("Num classes", NUM_CLASSES),
    ("Batch size", BATCH_SIZE),
    ("Learning rate", LEARNING_RATE),
    ("Epochs", NUM_EPOCHS),
):
    print(f"  {_setting}: {_value}")

## 3. Load Metadata from Excel

In [None]:
# One row per sample; the 'label' and 'split' columns drive the datasets below.
metadata_df = pd.read_excel(METADATA_FILE, sheet_name='samples')

print(f"Loaded metadata: {len(metadata_df)} samples")
print(f"\nColumns: {list(metadata_df.columns)}")

# Each summary is computed lazily, right before it is printed.
for _title, _make_summary in (
    ("First 5 rows", lambda: metadata_df.head()),
    ("Class distribution", lambda: metadata_df['label'].value_counts().sort_index()),
    ("Data split distribution", lambda: metadata_df['split'].value_counts()),
):
    print(f"\n{_title}:")
    print(_make_summary())

## 4. 3D Classification Dataset

In [None]:
class Classification3DDataset(Dataset):
    """Dataset for 3D image classification driven by Excel metadata.

    ``metadata_df`` must provide ``sample_id``, ``label`` and ``split``
    columns. Only rows whose ``split`` equals the requested split are kept,
    and of those only rows whose image file
    ``<images_dir>/<sample_id>.nii.gz`` exists on disk.

    Each item is ``(image_tensor, label_tensor, sample_id)``:
    ``image_tensor`` is float32, min-max scaled to [0, 1] per volume (all
    zeros for a constant volume) and reordered from the on-disk axis order
    (assumed (H, W, D, C) — TODO confirm) to (C, D, H, W); ``label_tensor``
    is a 0-d int64 tensor.
    """

    def __init__(self, images_dir, metadata_df, split='train'):
        self.images_dir = Path(images_dir)
        split_df = metadata_df[metadata_df['split'] == split].reset_index(drop=True)

        # Keep only rows whose image file actually exists (positional index ==
        # label index after reset_index, so .loc below is safe).
        valid_indices = [
            idx for idx, sample_id in enumerate(split_df['sample_id'])
            if self._image_path(sample_id).exists()
        ]
        num_missing = len(split_df) - len(valid_indices)
        self.metadata_df = split_df.loc[valid_indices].reset_index(drop=True)

        print(f"Dataset initialized ({split}):")
        print(f"  Total samples in {split}: {len(self.metadata_df)}")
        if num_missing:
            print(f"  Skipped {num_missing} sample(s) with missing image files")

    def __len__(self):
        return len(self.metadata_df)

    def _image_path(self, sample_id):
        """Return the expected NIfTI path for ``sample_id``."""
        return self.images_dir / f"{sample_id}.nii.gz"

    @staticmethod
    def _min_max_normalize(volume):
        """Scale ``volume`` to [0, 1]; a constant volume maps to all zeros."""
        v_min = volume.min()
        v_max = volume.max()
        if v_max > v_min:
            return (volume - v_min) / (v_max - v_min)
        # A constant volume has no contrast; returning the raw intensities
        # would put this sample on a different scale from every other one.
        return np.zeros_like(volume)

    def __getitem__(self, idx):
        row = self.metadata_df.iloc[idx]
        sample_id = row['sample_id']
        label = int(row['label'])

        # Load directly as float32: half the memory of nibabel's float64
        # default, and the tensor is float32 anyway.
        img_data = nib.load(self._image_path(sample_id)).get_fdata(dtype=np.float32)
        img_normalized = self._min_max_normalize(img_data)

        # Reorder axes (3, 2, 0, 1): assumed (H, W, D, C) -> (C, D, H, W),
        # e.g. (512, 1024, 32, 2) -> (2, 32, 512, 1024) — TODO confirm layout.
        # Make the result contiguous so downstream ops see a dense tensor.
        img_tensor = torch.from_numpy(
            np.ascontiguousarray(np.transpose(img_normalized, (3, 2, 0, 1)))
        ).float()

        label_tensor = torch.tensor(label, dtype=torch.long)

        return img_tensor, label_tensor, sample_id

# One dataset per value of the metadata 'split' column, built in order.
train_dataset, val_dataset, test_dataset = (
    Classification3DDataset(IMAGES_DIR, metadata_df, split=split_name)
    for split_name in ('train', 'val', 'test')
)

print("\nTotal dataset sizes:")
for _title, _ds in (("Train", train_dataset), ("Val", val_dataset), ("Test", test_dataset)):
    print(f"  {_title}: {len(_ds)}")

## 5. Data Loaders

In [None]:
# Shared loader settings; only the training loader is shuffled.
_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=0)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

print("DataLoaders created:")
for _title, _loader in (("Train", train_loader), ("Val", val_loader), ("Test", test_loader)):
    print(f"  {_title} batches: {len(_loader)}")

# Smoke test: pull one training batch to check loading and collation.
batch_images, batch_labels, batch_ids = next(iter(train_loader))
print("\nSample batch:")
print(f"  Images shape: {batch_images.shape}")
print(f"  Labels shape: {batch_labels.shape}")
print(f"  Sample IDs: {batch_ids}")

## 6. ResNet3D Classification Model with Feature Extraction Hooks

In [None]:
class ResNet3DBlock(nn.Module):
    """Basic 3D residual block: two 3x3x3 conv + BN stages and a skip path.

    The first convolution carries ``stride``. When the stride or the channel
    count changes, the skip path is a 1x1x1 conv + BN projection so it can be
    added to the main path; otherwise it is an empty ``nn.Sequential``
    (identity). Submodule names (conv1, bn1, conv2, bn2, shortcut) define the
    state_dict layout.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # Main path (positional args: in, out, kernel_size, stride, padding).
        self.conv1 = nn.Conv3d(in_channels, out_channels, 3, stride, 1)
        self.bn1 = nn.BatchNorm3d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(out_channels, out_channels, 3, 1, 1)
        self.bn2 = nn.BatchNorm3d(out_channels)

        # Skip path: project only when the output shape differs from the input.
        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm3d(out_channels),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        identity = self.shortcut(x)
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        y += identity
        return self.relu(y)


class ResNet3DClassification(nn.Module):
    """3D ResNet classifier that also exposes intermediate activations.

    Architecture:
      * Stem: 7x7x7 stride-2 conv, BN, ReLU and a 3x3x3 stride-2 max pool.
      * Four residual stages (``layer1``..``layer4``) built from
        ``ResNet3DBlock``, with widths ``base_channels * (1, 2, 4, 8)``.
      * Global average pooling and a linear head.

    The defaults (64 channels, (3, 4, 6, 3) blocks) reproduce the original
    configuration and state_dict layout.

    Forward hooks on ``layer1``..``layer4`` and ``avgpool`` store each
    module's output in ``self.features``, flattened to ``(batch, -1)`` when it
    has more than two dims. The dict is reset at the start of every
    ``forward`` and returned by ``get_features()``. Stored tensors are not
    detached: captured outside ``torch.no_grad()`` they remain attached to
    the autograd graph.

    Args:
        in_channels: Number of input channels.
        num_classes: Number of output logits.
        base_channels: Width of the stem and ``layer1``; each later stage
            doubles it.
        blocks_per_layer: Number of residual blocks in each of the 4 stages.
    """

    def __init__(self, in_channels=2, num_classes=5, base_channels=64,
                 blocks_per_layer=(3, 4, 6, 3)):
        super().__init__()
        if len(blocks_per_layer) != 4:
            raise ValueError(
                f"blocks_per_layer must have 4 entries, got {len(blocks_per_layer)}"
            )
        w1, w2, w3, w4 = (base_channels * mult for mult in (1, 2, 4, 8))
        n1, n2, n3, n4 = blocks_per_layer

        # Initial convolution (stride 2) + max pool (stride 2).
        self.conv1 = nn.Conv3d(in_channels, w1, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm3d(w1)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)

        # Residual stages; stages 2-4 downsample by 2 in their first block.
        self.layer1 = self._make_layer(w1, w1, n1, stride=1)
        self.layer2 = self._make_layer(w1, w2, n2, stride=2)
        self.layer3 = self._make_layer(w2, w3, n3, stride=2)
        self.layer4 = self._make_layer(w3, w4, n4, stride=2)

        # Global average pooling to (1, 1, 1), then classification head.
        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.fc = nn.Linear(w4, num_classes)

        # Feature extraction hooks storage (+ handles so hooks can be removed).
        self.features = {}
        self._hook_handles = []
        self._register_hooks()

    def _make_layer(self, in_channels, out_channels, blocks, stride):
        """Build a stage: one (possibly strided) block, then ``blocks - 1`` more."""
        layers = [ResNet3DBlock(in_channels, out_channels, stride)]
        for _ in range(1, blocks):
            layers.append(ResNet3DBlock(out_channels, out_channels, stride=1))
        return nn.Sequential(*layers)

    def _register_hooks(self):
        """Register forward hooks that capture flattened intermediate outputs."""
        def hook_fn(name):
            def hook(module, inputs, output):
                # torch.flatten (unlike .view) also works on non-contiguous
                # outputs, e.g. when the model runs in channels_last_3d format.
                if output.dim() > 2:
                    self.features[name] = torch.flatten(output, start_dim=1)
                else:
                    self.features[name] = output
            return hook

        for name in ('layer1', 'layer2', 'layer3', 'layer4', 'avgpool'):
            handle = getattr(self, name).register_forward_hook(hook_fn(name))
            self._hook_handles.append(handle)

    def remove_hooks(self):
        """Remove all feature-capture hooks; ``get_features()`` then stays empty."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def forward(self, x):
        # Reset captured features so they always belong to this call.
        self.features = {}

        # Initial layers
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        # Residual stages
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # Global pooling, flatten (layout-agnostic) and classify.
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def get_features(self):
        """Return the features captured during the most recent forward pass."""
        return self.features


# Instantiate the classifier for 2-channel volumes and move it to DEVICE.
model = ResNet3DClassification(in_channels=2, num_classes=NUM_CLASSES).to(DEVICE)

# Tally parameters in one pass: all of them, and those the optimizer updates.
total_params = 0
trainable_params = 0
for _param in model.parameters():
    _count = _param.numel()
    total_params += _count
    if _param.requires_grad:
        trainable_params += _count

print("Model created:")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")
print(f"  Device: {DEVICE}")

## 7. Loss Function and Optimizer

In [None]:
# Cross-entropy on raw logits; Adam with L2-style weight decay.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

# Halve the LR after 5 epochs without validation-loss improvement, floor 1e-6.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=5, min_lr=1e-6,
)

for _line in (
    "Loss function: CrossEntropyLoss",
    "Optimizer: Adam",
    f"Learning rate: {LEARNING_RATE}",
    "Learning rate scheduler: ReduceLROnPlateau",
):
    print(_line)

## 8. Training Loop

In [None]:
def train_epoch(model, loader, criterion, optimizer, device):
    """Run one training pass over ``loader``.

    Each batch is an ``(images, labels, sample_ids)`` triple; ``sample_ids``
    is ignored. Gradients are clipped to a global L2 norm of 1.0 before every
    optimizer step. The current batch loss is printed every 10 batches.

    Args:
        model: Classifier returning logits of shape (batch, num_classes).
        loader: Iterable of batches that supports ``len()``, e.g. a DataLoader.
        criterion: Loss returning the batch-mean loss (assumes
            ``reduction='mean'``, the CrossEntropyLoss default).
        optimizer: Optimizer over ``model``'s parameters.
        device: Device the batch tensors are moved to.

    Returns:
        (avg_loss, accuracy): the per-sample mean loss and the top-1 accuracy
        in [0, 1]. Each batch's mean loss is weighted by its size, so a short
        final batch is not over-weighted. Both values come from outputs
        computed before each batch's optimizer step.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    seen = 0

    for batch_idx, (images, labels, _) in enumerate(loader):
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass and loss
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass with gradient clipping
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        batch_size = labels.size(0)
        loss_sum += loss.item() * batch_size
        # Count hits on-device; only one scalar per batch crosses to the host.
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        seen += batch_size

        if (batch_idx + 1) % 10 == 0:
            print(f"  Batch {batch_idx + 1}/{len(loader)}, Loss: {loss.item():.4f}")

    if seen == 0:
        raise ValueError("train_epoch() received an empty loader; nothing to train on")

    return loss_sum / seen, correct / seen


def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` in eval mode without tracking gradients.

    Args:
        model: Classifier returning logits of shape (batch, num_classes).
        loader: Iterable of ``(images, labels, sample_ids)`` batches.
        criterion: Loss returning the batch-mean loss (assumes
            ``reduction='mean'``, the CrossEntropyLoss default).
        device: Device the batch tensors are moved to.

    Returns:
        (avg_loss, accuracy): the per-sample mean loss and the top-1 accuracy
        in [0, 1]. Each batch's mean loss is weighted by its size, so a short
        final batch is not over-weighted.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    seen = 0

    with torch.no_grad():
        for images, labels, _ in loader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            batch_size = labels.size(0)
            loss_sum += loss.item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            seen += batch_size

    if seen == 0:
        raise ValueError("validate() received an empty loader; cannot compute loss/accuracy")

    return loss_sum / seen, correct / seen


print("Training functions defined.")

## 9. Train Model

In [None]:
# Per-epoch curves, consumed later by the plot and the metrics file.
_HISTORY_KEYS = ('train_loss', 'train_acc', 'val_loss', 'val_acc', 'learning_rates')
history = {key: [] for key in _HISTORY_KEYS}
best_val_loss = float('inf')
best_epoch = 0  # 0-based index of the best epoch
_best_ckpt_path = MODELS_DIR / 'best_resnet3d_classification.pth'

print(f"Starting training for {NUM_EPOCHS} epochs...\n")

for epoch in range(NUM_EPOCHS):
    print(f"Epoch {epoch + 1}/{NUM_EPOCHS}")

    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, DEVICE)
    print(f"  Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f}")

    val_loss, val_acc = validate(model, val_loader, criterion, DEVICE)
    print(f"  Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}")

    # The plateau scheduler watches validation loss (mode='min').
    scheduler.step(val_loss)
    current_lr = optimizer.param_groups[0]['lr']

    for _key, _value in zip(_HISTORY_KEYS, (train_loss, train_acc, val_loss, val_acc, current_lr)):
        history[_key].append(value := _value) if False else history[_key].append(_value)

    # Checkpoint whenever validation loss strictly improves.
    if not (val_loss >= best_val_loss) and val_loss < best_val_loss:
        best_val_loss, best_epoch = val_loss, epoch
        torch.save(model.state_dict(), _best_ckpt_path)
        print(f"  ** Best model saved (Epoch {epoch + 1}) **")

    print(f"  Learning rate: {current_lr:.2e}\n")

print("Training complete!")
print(f"Best epoch: {best_epoch + 1} with val_loss: {best_val_loss:.4f}")

## 10. Evaluate on Test Set

In [None]:
# Restore the checkpoint with the lowest validation loss before testing.
# map_location keeps this working when the checkpoint was written on a
# different device (e.g. saved on GPU, notebook now running on CPU).
_best_state = torch.load(MODELS_DIR / 'best_resnet3d_classification.pth', map_location=DEVICE)
model.load_state_dict(_best_state)
model.eval()

test_loss, test_acc = validate(model, test_loader, criterion, DEVICE)

print("Test Results:")
print(f"  Test Loss: {test_loss:.4f}")
print(f"  Test Accuracy: {test_acc:.4f}")

## 11. Extract Features from All Splits

In [None]:
def extract_features_from_loader(model, loader, device, feature_types=('layer4', 'avgpool')):
    """Run ``model`` over ``loader`` and collect hook-captured features.

    The model is put in eval mode and run under ``torch.no_grad()``. After
    each forward pass, ``model.get_features()`` is read and the requested
    entries are moved to CPU.

    Args:
        model: Module whose ``forward`` populates the dict returned by
            ``get_features()``.
        loader: Iterable of ``(images, labels, sample_ids)`` batches.
        device: Device the images are moved to.
        feature_types: Names of the features to collect.

    Returns:
        (features, labels, ids):
          * ``features`` maps each name to a NumPy array with one row per
            sample; it is a ``(0, 0)`` float32 array if the loader is empty.
          * ``labels`` is a NumPy array.
          * ``ids`` is a list of sample ids, in the same order as the rows.

    Raises:
        KeyError: if the model did not produce a requested feature for a
            batch. Skipping it would misalign feature rows with labels/ids.
    """
    feature_types = list(feature_types)
    model.eval()

    chunks = {ft: [] for ft in feature_types}
    all_labels = []
    all_ids = []

    with torch.no_grad():
        for images, labels, sample_ids in loader:
            # Forward pass; the model's hooks fill get_features().
            _ = model(images.to(device))
            features = model.get_features()

            missing = [ft for ft in feature_types if ft not in features]
            if missing:
                raise KeyError(
                    f"Model did not produce requested feature(s) {missing}; "
                    f"available: {sorted(features)}"
                )
            for ft in feature_types:
                chunks[ft].append(features[ft].cpu())

            all_labels.extend(labels.cpu().numpy())
            all_ids.extend(sample_ids)

    # Concatenate per-batch chunks into one array per feature type.
    all_features = {
        ft: torch.cat(parts, dim=0).numpy() if parts else np.empty((0, 0), dtype=np.float32)
        for ft, parts in chunks.items()
    }

    return all_features, np.array(all_labels), all_ids


# Extract layer4 + avgpool features for every split, in train/val/test order.
_extracted = []
for _split_title, _split_loader in (
    ('training', train_loader),
    ('validation', val_loader),
    ('test', test_loader),
):
    print(f"Extracting features from {_split_title} set...")
    _extracted.append(
        extract_features_from_loader(
            model, _split_loader, DEVICE, feature_types=['layer4', 'avgpool']
        )
    )

(train_features, train_labels, train_ids), \
    (val_features, val_labels, val_ids), \
    (test_features, test_labels, test_ids) = _extracted

print("\nFeature extraction complete:")
for ft in ['layer4', 'avgpool']:
    print(f"  {ft} shape (train): {train_features[ft].shape}")

## 12. Save Extracted Features

In [None]:
import pickle

# One pickle per split: both feature arrays plus the row-aligned labels/ids.
_split_payloads = (
    ('train', train_features, train_labels, train_ids),
    ('val', val_features, val_labels, val_ids),
    ('test', test_features, test_labels, test_ids),
)
for split, features, labels, ids in _split_payloads:
    features_data = {key: features[key] for key in ('layer4', 'avgpool')}
    features_data['labels'] = labels
    features_data['sample_ids'] = ids

    pkl_path = FEATURES_DIR / f'resnet3d_features_{split}.pkl'
    with pkl_path.open('wb') as f:
        pickle.dump(features_data, f)

    print(f"Saved {split} features to {pkl_path.name}")

print(f"\nFeatures saved to {FEATURES_DIR}")

## 13. Plot Training History

In [None]:
# Plot against a 1-based epoch axis so x-values match the "Epoch N" numbering
# printed during training and the best-epoch label (best_epoch is 0-based).
_epochs = list(range(1, len(history['train_loss']) + 1))
_best_epoch_1based = best_epoch + 1
_plot_path = MODELS_DIR / 'training_history_classification_3d.png'

fig, axes = plt.subplots(1, 3, figsize=(18, 4))

# Loss curve
axes[0].plot(_epochs, history['train_loss'], label='Training Loss', marker='o', markersize=3)
axes[0].plot(_epochs, history['val_loss'], label='Validation Loss', marker='s', markersize=3)
axes[0].axvline(x=_best_epoch_1based, color='r', linestyle='--', alpha=0.5,
                label=f'Best Epoch ({_best_epoch_1based})')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training and Validation Loss')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Accuracy curve
axes[1].plot(_epochs, history['train_acc'], label='Training Accuracy', marker='o', markersize=3)
axes[1].plot(_epochs, history['val_acc'], label='Validation Accuracy', marker='s', markersize=3)
axes[1].axvline(x=_best_epoch_1based, color='r', linestyle='--', alpha=0.5)
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy')
axes[1].set_title('Training and Validation Accuracy')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

# Learning rate (log scale, since ReduceLROnPlateau halves it)
axes[2].plot(_epochs, history['learning_rates'], label='Learning Rate', color='orange')
axes[2].set_xlabel('Epoch')
axes[2].set_ylabel('Learning Rate')
axes[2].set_title('Learning Rate Schedule')
axes[2].legend()
axes[2].grid(True, alpha=0.3)
axes[2].set_yscale('log')

plt.tight_layout()
# Save before show(): show() may release the figure in some backends.
plt.savefig(_plot_path, dpi=150, bbox_inches='tight')
plt.show()

print(f"Training history plot saved to {_plot_path}")

## 14. Save Training Metrics

In [None]:
# Summarise the run. Facts that depend on what actually ran are derived, not
# restated from config:
#   * input_shape comes from the sample batch loaded earlier (minus the batch dim);
#   * num_epochs_trained comes from the recorded history.
metrics = {
    'best_epoch': best_epoch + 1,
    'best_val_loss': float(best_val_loss),
    'best_val_acc': float(history['val_acc'][best_epoch]),
    'final_test_loss': float(test_loss),
    'final_test_acc': float(test_acc),
    'num_epochs_trained': len(history['train_loss']),
    'num_classes': NUM_CLASSES,
    'input_shape': list(batch_images.shape[1:]),
    'batch_size': BATCH_SIZE,
    'learning_rate': LEARNING_RATE,
    'total_parameters': int(total_params),
    'training_samples': len(train_dataset),
    'validation_samples': len(val_dataset),
    'test_samples': len(test_dataset)
}

with open(MODELS_DIR / 'metrics_classification_3d.json', 'w') as f:
    json.dump(metrics, f, indent=2)

print("Metrics saved:")
for key, val in metrics.items():
    print(f"  {key}: {val}")

## 15. Summary

In [None]:
_RULE = "=" * 70
print("\n" + _RULE)
print("3D RESNET CLASSIFICATION - COMPLETE")
print(_RULE)

print("\nModel Performance:")
print(f"  Best Val Loss: {metrics['best_val_loss']:.4f} (Epoch {metrics['best_epoch']})")
print(f"  Best Val Acc: {metrics['best_val_acc']:.4f}")
print(f"  Test Loss: {metrics['final_test_loss']:.4f}")
print(f"  Test Acc: {metrics['final_test_acc']:.4f}")

print("\nExtracted Features:")
print(f"  Layer4 features shape: {train_features['layer4'].shape}")
print(f"  AvgPool features shape: {train_features['avgpool'].shape}")

print("\nOutput Files:")
for _label, _path in (
    ("Model", MODELS_DIR / 'best_resnet3d_classification.pth'),
    ("Metrics", MODELS_DIR / 'metrics_classification_3d.json'),
    ("Training history plot", MODELS_DIR / 'training_history_classification_3d.png'),
    ("Features (train)", FEATURES_DIR / 'resnet3d_features_train.pkl'),
    ("Features (val)", FEATURES_DIR / 'resnet3d_features_val.pkl'),
    ("Features (test)", FEATURES_DIR / 'resnet3d_features_test.pkl'),
):
    print(f"  {_label}: {_path}")

print("\nNext Steps:")
for _step_no, _step in enumerate((
    "Use extracted features with classifiers (SVM, RFC, etc.)",
    "Combine with radiomic features for multi-modal analysis",
    "Perform ablation studies on intermediate layers",
), start=1):
    print(f"  {_step_no}. {_step}")