In [None]:
import os
import numpy as np
import nibabel as nib
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from sklearn.metrics import classification_report, accuracy_score
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from tqdm import tqdm
import warnings
import math
warnings.filterwarnings('ignore')

Select the compute device and define the training configuration.

In [None]:

# Pick the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if device.type == 'cuda':
    # Only reached when CUDA was reported available above.
    print(f"GPU Count: {torch.cuda.device_count()}")
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")

# Hyper-parameters and paths shared by every cell below
# (speed-optimized, aiming at roughly 5 minutes per epoch).
CONFIG = dict(
    data_path='/kaggle/input/nii-larger-dataset/DATASET_NIFTI2',
    image_size=224,        # reduced from 260 for faster processing
    num_slices=16,         # reduced from 20 to decrease total samples
    batch_size=32,         # increased from 16 for fewer iterations
    num_epochs=35,
    learning_rate=3e-4,
    num_classes=3,
    patience=12,           # epochs without improvement before early stopping
    weight_decay=0.01,
    label_smoothing=0.1,
    T_max=10,              # period (in epochs) of the cosine LR schedule
)

# Class sub-directory names; the list position is used as the integer label.
class_names = ['AD', 'CN', 'MCI']

In [None]:
class LabelSmoothingCrossEntropy(nn.Module):
    """Cross-entropy loss blended with a uniform-over-classes penalty.

    The per-sample loss is
    ``(1 - smoothing) * NLL(target) + smoothing * mean_over_classes(-log p)``
    and the batch mean of that quantity is returned.
    """

    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, x, target):
        """Return the scalar smoothed loss.

        Args:
            x: logits; the last dimension indexes classes.
            target: integer class indices, one per row of ``x`` (it is
                unsqueezed to a column and used as a gather index).
        """
        log_probs = torch.log_softmax(x, dim=-1)

        # Negative log-likelihood of the labelled class for each sample.
        target_nll = -torch.gather(log_probs, -1, target.unsqueeze(1))
        target_nll = target_nll.squeeze(1)

        # Average negative log-probability across all classes.
        uniform_nll = -log_probs.mean(dim=-1)

        keep = 1. - self.smoothing
        per_sample = keep * target_nll + self.smoothing * uniform_nll
        return per_sample.mean()


In [None]:

class SEBlock(nn.Module):
    """Squeeze-and-excitation style channel gate.

    Each channel is globally average-pooled, passed through a two-layer
    bias-free bottleneck (``channels -> channels // reduction -> channels``)
    and a sigmoid, and the resulting per-channel weight rescales the input.
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        bottleneck = channels // reduction
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(channels, bottleneck, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(bottleneck, channels, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled channel-wise; ``x`` must be 4-D (b, c, h, w)."""
        batch, chans, _height, _width = x.shape

        # Squeeze: one scalar per (sample, channel).
        gate = self.avgpool(x).view(batch, chans)

        # Excite: bottleneck MLP followed by a sigmoid in (0, 1).
        gate = self.sigmoid(self.fc2(self.relu(self.fc1(gate))))

        # Broadcast the gate back over the spatial dimensions.
        return x * gate.view(batch, chans, 1, 1).expand_as(x)

In [None]:
class EnhancedMRISliceDataset(Dataset):
    """Dataset yielding one intensity-normalised 2-D slice per sample.

    ``samples`` is a sequence of ``(nifti_path, label, slice_idx)`` tuples.
    Each item is returned as ``(image, label)`` where ``image`` is an
    ``H x W x 3`` uint8 array (the grey slice replicated over three channels),
    or whatever ``transform`` makes of that array when a transform is given.
    """

    def __init__(self, samples, transform=None):
        self.samples = samples
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def _load_plane(self, path, slice_idx):
        """Read the volume at ``path`` and return the requested plane along axis 2.

        Falls back to the middle plane when ``slice_idx`` is past the end.
        """
        volume = nib.load(path).get_fdata()
        depth = volume.shape[2]
        chosen = slice_idx if slice_idx < depth else depth // 2
        return volume[:, :, chosen]

    def __getitem__(self, idx):
        path, label, slice_idx = self.samples[idx]
        plane = self._load_plane(path, slice_idx)

        # Clip to the 2nd-98th percentile window, then rescale to [0, 1].
        lo, hi = np.percentile(plane, (2, 98))
        plane = np.clip(plane, lo, hi)
        plane = (plane - plane.min()) / (plane.max() - plane.min() + 1e-8)

        # Replicate the grey plane into three channels and quantise to 8 bits.
        image = np.stack([plane] * 3, axis=2)
        image = (image * 255).astype(np.uint8)

        if self.transform:
            image = self.transform(image)

        return image, label


In [None]:
def extract_key_slices(nii_path, num_slices=16):
    """Pick evenly spaced slice indices from the central part of a volume.

    Indices are spread with ``np.linspace`` over the window from 15% to 85%
    of the extent of axis 2 (the upper bound is exclusive).

    Only the image header is consulted: ``nib.load`` is lazy and the image's
    ``shape`` attribute is available without decoding the voxel array, so this
    no longer reads every volume in full just to learn its depth.

    Args:
        nii_path: path of the NIfTI file.
        num_slices: how many indices to return.

    Returns:
        1-D integer ``numpy`` array of length ``num_slices``.
    """
    total_slices = nib.load(nii_path).shape[2]

    start_slice = int(total_slices * 0.15)
    end_slice = int(total_slices * 0.85)
    # For degenerate (<= 1 slice) volumes the window would be empty and
    # linspace would run backwards to -1; keep at least one slice in it.
    end_slice = max(end_slice, start_slice + 1)

    slice_indices = np.linspace(start_slice, end_slice - 1, num_slices, dtype=int)
    return slice_indices


In [None]:

def create_dataset():
    """Enumerate every (file, label, slice) training sample under the data root.

    For each entry of ``class_names`` the directory
    ``CONFIG['data_path']/<class_name>`` is scanned for ``.nii`` files, and
    ``extract_key_slices`` chooses ``CONFIG['num_slices']`` slice indices per
    file. A file whose slices cannot be extracted is reported and skipped.

    File names are processed in sorted order: ``os.listdir`` returns entries
    in arbitrary, filesystem-dependent order, which would otherwise make the
    sample list (and any seeded split derived from it) differ between runs.

    Returns:
        list of ``(file_path, class_idx, slice_idx)`` tuples, grouped by class
        and, within a class, ordered by file name.
    """
    print("Creating enhanced dataset...")

    all_samples = []
    class_counts = {}

    for class_idx, class_name in enumerate(class_names):
        class_path = os.path.join(CONFIG['data_path'], class_name)
        nii_files = sorted(f for f in os.listdir(class_path) if f.endswith('.nii'))

        print(f"Processing {len(nii_files)} files for {class_name}")
        class_counts[class_name] = len(nii_files)

        for file in tqdm(nii_files, desc=f"Loading {class_name}"):
            file_path = os.path.join(class_path, file)

            try:
                slice_indices = extract_key_slices(file_path, CONFIG['num_slices'])
            except Exception as e:
                # Best effort: one unreadable volume must not abort the scan.
                print(f"Error processing {file}: {e}")
                continue

            all_samples.extend(
                (file_path, class_idx, slice_idx) for slice_idx in slice_indices
            )

    print(f"Total samples: {len(all_samples)}")
    print(f"Class distribution: {class_counts}")
    return all_samples


In [None]:

class EfficientNetB2WithSE(nn.Module):
    """ImageNet-pretrained EfficientNet-B2 features + SE gate + small MLP head.

    The backbone's own classifier is replaced by ``nn.Identity``; its
    convolutional features are globally average-pooled, re-weighted per
    channel by an ``SEBlock`` (reduction 8) and classified by a
    dropout / linear(256) / ReLU / dropout / linear head.
    """

    def __init__(self, num_classes=3):
        super().__init__()

        # Prefer the `weights=` API; fall back to the legacy `pretrained=True`
        # call when the weights enum cannot be imported.
        try:
            from torchvision.models import efficientnet_b2, EfficientNet_B2_Weights
            self.backbone = efficientnet_b2(weights=EfficientNet_B2_Weights.IMAGENET1K_V1)
        except ImportError:
            self.backbone = models.efficientnet_b2(pretrained=True)

        # Width of the feature vector the stock classifier used to consume.
        feature_dim = self.backbone.classifier[1].in_features
        self.backbone.classifier = nn.Identity()

        self.se_block = SEBlock(feature_dim, reduction=8)

        head_layers = [
            nn.Dropout(0.3),
            nn.Linear(feature_dim, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(256, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Map a batch of images to class logits."""
        feature_map = self.backbone.features(x)

        # Global average pooling down to a 1x1 map per channel.
        pooled = nn.functional.adaptive_avg_pool2d(feature_map, 1)

        # Channel attention, then flatten to (batch, feature_dim) for the head.
        gated = self.se_block(pooled)
        return self.classifier(gated.flatten(1))


In [None]:
class CosineAnnealingWarmupRestarts(optim.lr_scheduler._LRScheduler):
    """Cosine learning-rate decay that restarts every ``T_max`` epochs.

    At epoch ``e`` the rate for each parameter group is
    ``eta_min + (base_lr - eta_min) * (1 + cos(pi * (e % T_max) / T_max)) / 2``.

    Note: ``T_mult`` is stored but not used by ``get_lr`` or ``step`` (the
    period stays fixed), and no warm-up phase is implemented here.
    """

    def __init__(self, optimizer, T_max, eta_min=0, T_mult=1, last_epoch=-1):
        # These attributes must exist before the base constructor runs,
        # because it performs an initial `step()` that calls `get_lr()`.
        self.T_max, self.eta_min, self.T_mult = T_max, eta_min, T_mult
        self.T_cur = 0
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the current learning rate for every parameter group."""
        one_plus_cos = 1 + math.cos(math.pi * self.T_cur / self.T_max)
        rates = []
        for base_lr in self.base_lrs:
            rates.append(self.eta_min + (base_lr - self.eta_min) * one_plus_cos / 2)
        return rates

    def step(self, epoch=None):
        """Advance to ``epoch`` (default: the next epoch) and update the rates."""
        target_epoch = self.last_epoch + 1 if epoch is None else epoch
        # Position inside the current cosine cycle.
        self.T_cur = target_epoch % self.T_max
        super().step(target_epoch)

In [None]:

def _run_epoch(model, loader, criterion, optimizer=None, desc=None):
    """Make one pass over ``loader`` and return ``(mean batch loss, accuracy %)``.

    With an ``optimizer`` the model is put in train mode and updated after
    every batch (gradient norm clipped to 1.0); without one the model is put
    in eval mode and gradients are disabled. A progress bar is shown only
    when ``desc`` is given.
    """
    training = optimizer is not None
    model.train(training)

    batches = tqdm(loader, desc=desc) if desc is not None else loader
    loss_sum = 0.0
    correct = 0
    seen = 0

    with torch.set_grad_enabled(training):
        for images, labels in batches:
            images, labels = images.to(device), labels.to(device)

            if training:
                optimizer.zero_grad()

            outputs = model(images)
            loss = criterion(outputs, labels)

            if training:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

            loss_sum += loss.item()
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            seen += labels.size(0)

    return loss_sum / len(loader), 100.0 * correct / seen


def train_model(model, train_loader, val_loader):
    """Train ``model`` with early stopping and restore its best checkpoint.

    Uses label-smoothed cross-entropy, AdamW and the cosine-restart scheduler,
    all configured from ``CONFIG``. After each epoch the model is validated;
    whenever validation accuracy improves the weights are written to
    ``best_alzheimer_efficientnet.pth``. Training stops early after
    ``CONFIG['patience']`` epochs without improvement.

    The best accuracy starts at ``-inf`` so that the first epoch always
    writes a checkpoint; starting at 0 with a strict ``>`` left no file (or a
    stale one from an earlier run) to reload when validation accuracy never
    rose above 0.

    Args:
        model: network to train; moved to ``device``.
        train_loader: iterable of ``(images, labels)`` training batches.
        val_loader: iterable of ``(images, labels)`` validation batches.

    Returns:
        ``(train_accs, val_accs, train_losses)``: per-epoch training accuracy
        (%), validation accuracy (%) and mean training loss.
    """
    model = model.to(device)

    criterion = LabelSmoothingCrossEntropy(smoothing=CONFIG['label_smoothing'])
    optimizer = optim.AdamW(model.parameters(),
                            lr=CONFIG['learning_rate'],
                            weight_decay=CONFIG['weight_decay'])

    scheduler = CosineAnnealingWarmupRestarts(optimizer, T_max=CONFIG['T_max'])

    checkpoint_path = 'best_alzheimer_efficientnet.pth'
    best_val_acc = float('-inf')
    patience_counter = 0

    train_accs = []
    val_accs = []
    train_losses = []

    for epoch in range(CONFIG['num_epochs']):
        avg_train_loss, train_acc = _run_epoch(
            model, train_loader, criterion,
            optimizer=optimizer, desc=f"Epoch {epoch+1}")
        avg_val_loss, val_acc = _run_epoch(model, val_loader, criterion)

        train_accs.append(train_acc)
        val_accs.append(val_acc)
        train_losses.append(avg_train_loss)

        print(f"Epoch {epoch+1}/{CONFIG['num_epochs']}")
        print(f"Train Loss: {avg_train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        print(f"Val Loss: {avg_val_loss:.4f}, Val Acc: {val_acc:.2f}%")
        # Reported before scheduler.step(): this is the rate the epoch ran with.
        print(f"LR: {optimizer.param_groups[0]['lr']:.6f}")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            patience_counter = 0
            torch.save(model.state_dict(), checkpoint_path)
            print(f"✓ New best model! Val Acc: {best_val_acc:.2f}%")
        else:
            patience_counter += 1

        if patience_counter >= CONFIG['patience']:
            print(f"Early stopping at epoch {epoch+1}")
            break

        scheduler.step()
        print("-" * 60)

    # Reload only a checkpoint written by this run (none exists if no epoch
    # ran); map_location keeps the load valid on whichever device is in use.
    if val_accs:
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    return train_accs, val_accs, train_losses


In [None]:

def _slice_to_rgb_uint8(volume, slice_idx):
    """Return plane ``slice_idx`` of ``volume`` (axis 2) as an H x W x 3 uint8 image.

    An out-of-range index falls back to the middle plane. Intensities are
    clipped to the 2nd-98th percentile window and rescaled to 0..255.
    """
    depth = volume.shape[2]
    plane = volume[:, :, slice_idx if slice_idx < depth else depth // 2]

    p2, p98 = np.percentile(plane, (2, 98))
    plane = np.clip(plane, p2, p98)
    plane = (plane - plane.min()) / (plane.max() - plane.min() + 1e-8)

    rgb = np.stack([plane, plane, plane], axis=2)
    return (rgb * 255).astype(np.uint8)


def evaluate_per_volume(model, test_samples):
    """Predict one class per volume by averaging its slices' softmax outputs.

    Samples are first grouped by file path so that every NIfTI volume is
    loaded and decoded once, instead of once per slice. Volumes keep the
    order in which they first appear in ``test_samples``, and a volume's
    label is the label of its first sample.

    Args:
        model: trained classifier; switched to eval mode.
        test_samples: iterable of ``(path, label, slice_idx)`` tuples.

    Returns:
        ``(final_predictions, final_labels)``: two parallel lists with one
        entry per distinct volume path.
    """
    model.eval()

    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((CONFIG['image_size'], CONFIG['image_size'])),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # path -> requested slice indices (dicts preserve first-seen order).
    slices_by_volume = {}
    volume_labels = {}
    for path, label, slice_idx in test_samples:
        if path not in slices_by_volume:
            slices_by_volume[path] = []
            volume_labels[path] = label
        slices_by_volume[path].append(slice_idx)

    final_predictions = []
    final_labels = []

    with torch.no_grad():
        for path, slice_indices in tqdm(slices_by_volume.items(), desc="Evaluating"):
            img = nib.load(path).get_fdata()  # one load per volume

            slice_probs = []
            for slice_idx in slice_indices:
                slice_2d = _slice_to_rgb_uint8(img, slice_idx)
                slice_tensor = transform(slice_2d).unsqueeze(0).to(device)

                output = model(slice_tensor)
                slice_probs.append(torch.softmax(output, dim=1).cpu().numpy()[0])

            # Soft vote: average the class probabilities over the volume.
            avg_prob = np.mean(slice_probs, axis=0)
            final_predictions.append(np.argmax(avg_prob))
            final_labels.append(volume_labels[path])

    return final_predictions, final_labels


In [None]:

def _split_by_volume(samples, test_size, random_state):
    """Split slice samples so all slices of one volume stay on the same side.

    The distinct file paths (sorted, so the result does not depend on sample
    order) are split with ``train_test_split``; every sample then follows its
    path. Returns ``(kept_samples, held_out_samples)``.
    """
    volume_paths = sorted({path for path, _, _ in samples})
    _, held_paths = train_test_split(
        volume_paths, test_size=test_size, random_state=random_state)
    held = set(held_paths)

    kept_samples = [s for s in samples if s[0] not in held]
    held_samples = [s for s in samples if s[0] in held]
    return kept_samples, held_samples


def main():
    """Run the full pipeline: build samples, train, plot curves, test.

    The train/validation/test partition is made per volume rather than per
    slice. Splitting individual slices placed neighbouring slices of the same
    scan in both the training and the evaluation sets, which leaks
    subject-specific information and inflates accuracy, and it also handed
    ``evaluate_per_volume`` only a fraction of each test volume's slices.

    Returns:
        ``(model, accuracy)``: the trained model (best checkpoint restored by
        ``train_model``) and the per-volume test accuracy in ``[0, 1]``.
    """
    print("=" * 60)
    print("🧠 ENHANCED ALZHEIMER'S CLASSIFICATION - BALANCED DATASET (TARGET: >90%)")
    print("=" * 60)

    all_samples = create_dataset()

    # 80/20 train+val vs test, then 80/20 train vs val, both by volume.
    train_samples, test_samples = _split_by_volume(all_samples, test_size=0.2, random_state=42)
    train_samples, val_samples = _split_by_volume(train_samples, test_size=0.2, random_state=42)

    print(f"Train: {len(train_samples)}, Val: {len(val_samples)}, Test: {len(test_samples)}")

    # Augmentation is applied to training images only.
    train_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((CONFIG['image_size'], CONFIG['image_size'])),
        transforms.RandomHorizontalFlip(0.5),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.RandomAffine(degrees=10, translate=(0.1, 0.1)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    val_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((CONFIG['image_size'], CONFIG['image_size'])),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    train_dataset = EnhancedMRISliceDataset(train_samples, train_transform)
    val_dataset = EnhancedMRISliceDataset(val_samples, val_transform)

    # Faster data loading with more workers
    train_loader = DataLoader(train_dataset, batch_size=CONFIG['batch_size'],
                              shuffle=True, num_workers=4, pin_memory=True,
                              persistent_workers=True, prefetch_factor=2)
    val_loader = DataLoader(val_dataset, batch_size=CONFIG['batch_size']*2,
                            shuffle=False, num_workers=4, pin_memory=True,
                            persistent_workers=True, prefetch_factor=2)

    model = EfficientNetB2WithSE(CONFIG['num_classes'])
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    print("Starting enhanced training...")
    train_accs, val_accs, train_losses = train_model(model, train_loader, val_loader)

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    axes[0].plot(train_accs, label='Train', color='blue')
    axes[0].plot(val_accs, label='Validation', color='red')
    axes[0].set_title('Accuracy Progress')
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Accuracy (%)')
    axes[0].legend()
    axes[0].grid(True)

    axes[1].plot(train_losses, label='Train Loss', color='green')
    axes[1].set_title('Training Loss')
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('Loss')
    axes[1].legend()
    axes[1].grid(True)

    axes[2].axhline(y=90, color='red', linestyle='--', label='90% Target')
    axes[2].plot(val_accs, label='Validation Accuracy', color='red')
    axes[2].set_title('Validation Accuracy vs Target')
    axes[2].set_xlabel('Epoch')
    axes[2].set_ylabel('Accuracy (%)')
    axes[2].legend()
    axes[2].grid(True)

    plt.tight_layout()
    plt.show()

    print("Enhanced evaluation on test set...")
    test_predictions, test_labels = evaluate_per_volume(model, test_samples)

    accuracy = accuracy_score(test_labels, test_predictions)
    print(f"\n🎯 ENHANCED RESULTS")
    print(f"Test Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")

    if accuracy >= 0.90:
        print("🎉 TARGET ACHIEVED: >90% ACCURACY!")
    else:
        print(f"📊 Progress: {accuracy*100:.2f}% (Target: 90%)")

    print("\nDetailed Classification Report:")
    print(classification_report(test_labels, test_predictions, target_names=class_names))

    return model, accuracy


In [None]:
if __name__ == "__main__":
    # Run the whole pipeline, then summarise the result against the 90% goal.
    model, accuracy = main()
    print(f"\n🚀 Enhanced training complete! Final accuracy: {accuracy*100:.2f}%")

    print(
        "✅ MISSION ACCOMPLISHED: >90% ACCURACY ACHIEVED!"
        if accuracy >= 0.90
        else f"📈 Current: {accuracy*100:.2f}% | Target: 90%"
    )