In [1]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image

# Configuration
DATA_ROOTS = [
    r"/home/nigmu/NPersonal/Projects/SDP/nigmu-parkinsons_disease_prediction/Dataset/MDVR",
    r"/home/nigmu/NPersonal/Projects/SDP/nigmu-parkinsons_disease_prediction/Dataset/Italian"
]
BATCH_SIZE = 8
NUM_WORKERS = 0  # Optimize based on CPU
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Custom Dataset Class
class ParkinsonSpectrogramDataset(Dataset):
    def __init__(self, root_dirs, transform=None):
        if isinstance(root_dirs, str):  # If a single path is given, convert it to a list
            root_dirs = [root_dirs]
        self.root_dirs = root_dirs
        self.transform = transform
        self.samples = self._load_samples()
        
    def _load_samples(self):
        samples = []
        for root_dir in self.root_dirs:
            for class_name in ['HC', 'PD']:
                class_dir = os.path.join(root_dir, class_name)
                if not os.path.exists(class_dir):
                    continue

                # Traverse subfolders inside HC/PD
                for patient_folder in os.listdir(class_dir):
                    patient_path = os.path.join(class_dir, patient_folder)
                    if os.path.isdir(patient_path):  # Ensure it's a directory
                        for img_file in os.listdir(patient_path):
                            if img_file.lower().endswith('.png'):  # Only PNG images
                                img_path = os.path.join(patient_path, img_file)
                                samples.append((img_path, 0 if class_name == 'HC' else 1))

        print(f"✅ Loaded {len(samples)} samples from {self.root_dirs}")
        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        img = Image.open(img_path).convert('L')  # Convert to grayscale
        if self.transform:
            img = self.transform(img)
        return img, label

# Data Transforms
train_transform = transforms.Compose([
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),  # Only slight translations
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])  # Normalize between -1 and 1
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

# Create Datasets (Loading from both MDVR & Italian datasets)
train_dataset = ParkinsonSpectrogramDataset([os.path.join(root, 'train') for root in DATA_ROOTS], transform=train_transform)
val_dataset = ParkinsonSpectrogramDataset([os.path.join(root, 'val') for root in DATA_ROOTS], transform=test_transform)
test_dataset = ParkinsonSpectrogramDataset([os.path.join(root, 'test') for root in DATA_ROOTS], transform=test_transform)

# Optimized DataLoaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, 
                          num_workers=NUM_WORKERS, pin_memory=True, persistent_workers=False)

val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, 
                        num_workers=NUM_WORKERS, pin_memory=True, persistent_workers=False)

test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, 
                         num_workers=NUM_WORKERS, pin_memory=True)

print(f"✅ DataLoaders ready! Using device: {DEVICE}")

✅ Loaded 154732 samples from ['/home/nigmu/NPersonal/Projects/SDP/nigmu-parkinsons_disease_prediction/Dataset/MDVR/train', '/home/nigmu/NPersonal/Projects/SDP/nigmu-parkinsons_disease_prediction/Dataset/Italian/train']
✅ Loaded 37807 samples from ['/home/nigmu/NPersonal/Projects/SDP/nigmu-parkinsons_disease_prediction/Dataset/MDVR/val', '/home/nigmu/NPersonal/Projects/SDP/nigmu-parkinsons_disease_prediction/Dataset/Italian/val']
✅ Loaded 41886 samples from ['/home/nigmu/NPersonal/Projects/SDP/nigmu-parkinsons_disease_prediction/Dataset/MDVR/test', '/home/nigmu/NPersonal/Projects/SDP/nigmu-parkinsons_disease_prediction/Dataset/Italian/test']
✅ DataLoaders ready! Using device: cuda


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchvision import models
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report, roc_curve, auc
import seaborn as sns
import time
import copy
import gc

# Configuration
NUM_EPOCHS = 30
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-5
EARLY_STOPPING_PATIENCE = 5
CHECKPOINT_PATH = 'best_densenet_model.pth'

# Model definition
def create_densenet_model(num_classes=2):
    # Using DenseNet121 as base model (smaller than DenseNet169/201 to conserve VRAM)
    model = models.densenet121(weights='DEFAULT')
    
    # Modify first conv layer to accept grayscale images (1 channel)
    # We'll convert the first conv layer to accept 1 channel instead of 3
    first_conv = model.features.conv0
    model.features.conv0 = nn.Conv2d(
        1, 64, kernel_size=7, stride=2, padding=3, bias=False
    )
    # Initialize with weight averaging from the pre-trained 3-channel weights
    with torch.no_grad():
        model.features.conv0.weight.data = first_conv.weight.data.sum(dim=1, keepdim=True)
    
    # Replace classifier
    in_features = model.classifier.in_features
    model.classifier = nn.Sequential(
        nn.Dropout(0.2),  # Adding dropout for regularization
        nn.Linear(in_features, num_classes)
    )
    
    return model

# Trainer class
class ParkinsonsTrainer:
    def __init__(self, model, train_loader, val_loader, test_loader, device, 
                 criterion=nn.CrossEntropyLoss(), learning_rate=1e-4, 
                 weight_decay=1e-5, checkpoint_path='best_model.pth'):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.device = device
        self.criterion = criterion
        self.optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
        self.scheduler = ReduceLROnPlateau(self.optimizer, mode='max', factor=0.5, patience=3, verbose=True)
        self.checkpoint_path = checkpoint_path
        self.best_val_acc = 0.0
        self.early_stopping_counter = 0
        
    def train_one_epoch(self):
        self.model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        
        for batch_idx, (inputs, targets) in enumerate(self.train_loader):
            inputs, targets = inputs.to(self.device), targets.to(self.device)
            
            # Zero gradients
            self.optimizer.zero_grad()
            
            # Forward pass
            outputs = self.model(inputs)
            loss = self.criterion(outputs, targets)
            
            # Backward pass and optimize
            loss.backward()
            self.optimizer.step()
            
            # Statistics
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            
            # Clear memory
            if batch_idx % 20 == 0:  # Clear every 20 batches
                del inputs, targets, outputs
                torch.cuda.empty_cache()
                gc.collect()
            
            # Print progress every 50 batches
            if batch_idx % 50 == 49:
                print(f'Batch {batch_idx+1}/{len(self.train_loader)}, Loss: {running_loss/(batch_idx+1):.4f}, '
                      f'Acc: {100.*correct/total:.2f}%')
        
        epoch_loss = running_loss / len(self.train_loader)
        epoch_acc = 100. * correct / total
        
        return epoch_loss, epoch_acc
    
    def validate(self):
        self.model.eval()
        running_loss = 0.0
        correct = 0
        total = 0
        
        with torch.no_grad():
            for inputs, targets in self.val_loader:
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                
                # Forward pass
                outputs = self.model(inputs)
                loss = self.criterion(outputs, targets)
                
                # Statistics
                running_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
                
                # Clear memory
                del inputs, targets, outputs
                
        torch.cuda.empty_cache()
        gc.collect()
        
        val_loss = running_loss / len(self.val_loader)
        val_acc = 100. * correct / total
        
        return val_loss, val_acc
    
    def save_checkpoint(self):
        print(f"Saving checkpoint with validation accuracy: {self.best_val_acc:.2f}%")
        torch.save(self.model.state_dict(), self.checkpoint_path)
    
    def train(self, num_epochs, early_stopping_patience=5):
        start_time = time.time()
        train_losses, train_accs = [], []
        val_losses, val_accs = [], []
        
        for epoch in range(num_epochs):
            print(f"\nEpoch {epoch+1}/{num_epochs}")
            print('-' * 60)
            
            # Train
            train_loss, train_acc = self.train_one_epoch()
            train_losses.append(train_loss)
            train_accs.append(train_acc)
            
            # Validate
            val_loss, val_acc = self.validate()
            val_losses.append(val_loss)
            val_accs.append(val_acc)
            
            # Update learning rate based on validation accuracy
            self.scheduler.step(val_acc)
            
            # Print epoch results
            print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
            print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")
            
            # Save checkpoint if validation accuracy improves
            if val_acc > self.best_val_acc:
                self.best_val_acc = val_acc
                self.save_checkpoint()
                self.early_stopping_counter = 0
            else:
                self.early_stopping_counter += 1
                print(f"EarlyStopping counter: {self.early_stopping_counter} out of {early_stopping_patience}")
                
                if self.early_stopping_counter >= early_stopping_patience:
                    print("Early stopping triggered")
                    break
        
        end_time = time.time()
        training_time = end_time - start_time
        hours, remainder = divmod(training_time, 3600)
        minutes, seconds = divmod(remainder, 60)
        print(f"Training completed in {int(hours)}h {int(minutes)}m {int(seconds)}s")
        
        # Plot training/validation metrics
        self.plot_training_metrics(train_losses, val_losses, train_accs, val_accs)
        
        return train_losses, train_accs, val_losses, val_accs
    
    def plot_training_metrics(self, train_losses, val_losses, train_accs, val_accs):
        plt.figure(figsize=(15, 5))
        
        # Plot losses
        plt.subplot(1, 2, 1)
        plt.plot(train_losses, label='Train Loss')
        plt.plot(val_losses, label='Validation Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title('Training and Validation Loss')
        plt.legend()
        
        # Plot accuracies
        plt.subplot(1, 2, 2)
        plt.plot(train_accs, label='Train Accuracy')
        plt.plot(val_accs, label='Validation Accuracy')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy (%)')
        plt.title('Training and Validation Accuracy')
        plt.legend()
        
        plt.tight_layout()
        plt.savefig('training_metrics.png')
        plt.close()
    
    def load_best_model(self):
        try:
            self.model.load_state_dict(torch.load(self.checkpoint_path))
            print(f"Loaded best model from {self.checkpoint_path}")
            return True
        except:
            print("Could not load checkpoint. Using current model state.")
            return False
    
    def test(self):
        # Load best model for testing
        self.load_best_model()
        self.model.eval()
        
        all_preds = []
        all_targets = []
        all_probs = []
        
        with torch.no_grad():
            for inputs, targets in self.test_loader:
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                
                # Forward pass
                outputs = self.model(inputs)
                probs = nn.Softmax(dim=1)(outputs)
                
                # Get predictions
                _, preds = outputs.max(1)
                
                # Store for evaluation
                all_preds.extend(preds.cpu().numpy())
                all_targets.extend(targets.cpu().numpy())
                all_probs.extend(probs[:, 1].cpu().numpy())  # Store probability of class 1 (PD)
                
                # Clear memory
                del inputs, targets, outputs
        
        torch.cuda.empty_cache()
        
        # Calculate metrics
        test_acc = 100. * np.mean(np.array(all_preds) == np.array(all_targets))
        print(f"Test Accuracy: {test_acc:.2f}%")
        
        # Classification report
        print("\nClassification Report:")
        report = classification_report(all_targets, all_preds, 
                                       target_names=['Healthy Control', 'Parkinson\'s Disease'])
        print(report)
        
        # Confusion Matrix
        cm = confusion_matrix(all_targets, all_preds)
        self.plot_confusion_matrix(cm)
        
        # ROC Curve
        self.plot_roc_curve(all_targets, all_probs)
        
        return test_acc, all_preds, all_targets, all_probs
    
    def plot_confusion_matrix(self, cm):
        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=['Healthy Control', 'Parkinson\'s Disease'],
                    yticklabels=['Healthy Control', 'Parkinson\'s Disease'])
        plt.xlabel('Predicted Label')
        plt.ylabel('True Label')
        plt.title('Confusion Matrix')
        plt.tight_layout()
        plt.savefig('confusion_matrix.png')
        plt.close()
    
    def plot_roc_curve(self, y_true, y_score):
        fpr, tpr, _ = roc_curve(y_true, y_score)
        roc_auc = auc(fpr, tpr)
        
        plt.figure(figsize=(8, 6))
        plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
        plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('Receiver Operating Characteristic (ROC) Curve')
        plt.legend(loc='lower right')
        plt.tight_layout()
        plt.savefig('roc_curve.png')
        plt.close()

# Memory optimization functions
def optimize_memory():
    # Empty CUDA cache
    torch.cuda.empty_cache()
    # Collect garbage
    gc.collect()

# Main execution
def main():
    # Check available memory before starting
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"Total GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
        print(f"Available GPU memory: {torch.cuda.memory_reserved(0) / 1e9:.2f} GB reserved")
    
    # Create model
    model = create_densenet_model(num_classes=2)
    print("Model created")
    
    # Initialize trainer
    trainer = ParkinsonsTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        device=DEVICE,
        learning_rate=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
        checkpoint_path=CHECKPOINT_PATH
    )
    print("Trainer initialized")
    
    # Train the model
    print("\nStarting training...")
    trainer.train(num_epochs=NUM_EPOCHS, early_stopping_patience=EARLY_STOPPING_PATIENCE)
    
    # Test the model
    print("\nEvaluating on test set...")
    test_acc, all_preds, all_targets, all_probs = trainer.test()
    
    # Save predictions for further analysis
    np.save('test_predictions.npy', {
        'predictions': all_preds,
        'targets': all_targets,
        'probabilities': all_probs
    })
    
    print("Complete! Check the saved model and evaluation outputs.")

if __name__ == "__main__":
    # Set seed for reproducibility
    torch.manual_seed(42)
    np.random.seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(42)
    
    # Run main function
    try:
        main()
    except RuntimeError as e:
        if 'out of memory' in str(e):
            print("\n🚨 CUDA OUT OF MEMORY ERROR")
            print("Try one of these solutions:")
            print("1. Reduce batch size (e.g., BATCH_SIZE = 8 or 4)")
            print("2. Use a smaller DenseNet variant (e.g., DenseNet121 instead of 169/201)")
            print("3. Resize images to smaller dimensions in your data transforms")
            print("4. Add gradient accumulation to simulate larger batch sizes")
            optimize_memory()
        else:
            raise e

GPU: NVIDIA GeForce GTX 1650
Total GPU memory: 3.90 GB
Available GPU memory: 0.00 GB reserved
Model created




Trainer initialized

Starting training...

Epoch 1/30
------------------------------------------------------------
Batch 50/19342, Loss: 0.6568, Acc: 61.50%
Batch 100/19342, Loss: 0.6138, Acc: 65.62%
Batch 150/19342, Loss: 0.5842, Acc: 68.17%
Batch 200/19342, Loss: 0.5869, Acc: 68.62%
Batch 250/19342, Loss: 0.5763, Acc: 69.70%
Batch 300/19342, Loss: 0.5626, Acc: 70.71%
Batch 350/19342, Loss: 0.5540, Acc: 71.46%
Batch 400/19342, Loss: 0.5463, Acc: 72.12%
Batch 450/19342, Loss: 0.5400, Acc: 72.25%
Batch 500/19342, Loss: 0.5301, Acc: 72.85%
Batch 550/19342, Loss: 0.5232, Acc: 73.14%
Batch 600/19342, Loss: 0.5202, Acc: 73.58%
Batch 650/19342, Loss: 0.5189, Acc: 73.67%
Batch 700/19342, Loss: 0.5131, Acc: 74.11%
Batch 750/19342, Loss: 0.5094, Acc: 74.23%
Batch 800/19342, Loss: 0.5060, Acc: 74.52%
Batch 850/19342, Loss: 0.5017, Acc: 74.91%
Batch 900/19342, Loss: 0.4960, Acc: 75.42%
Batch 950/19342, Loss: 0.4964, Acc: 75.34%
Batch 1000/19342, Loss: 0.4919, Acc: 75.66%
Batch 1050/19342, Loss: 0