In [None]:
import os
import time
import random
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.cuda.amp import autocast, GradScaler

# Set seeds for reproducibility
def set_seed(seed=42):
    """Seed every RNG used in this file and make cuDNN deterministic.

    Args:
        seed: Integer seed applied to Python's ``random`` module, NumPy's
            global RNG and PyTorch (CPU and every CUDA device).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed all GPUs explicitly, not only the currently selected one.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may pick a different convolution algorithm on each
    # run, which defeats the purpose of seeding; keep it off so that repeated
    # runs select the same kernels.
    torch.backends.cudnn.benchmark = False


set_seed()

class MDVRDataset(Dataset):
    """Image dataset read from ``root_dir/<split>/<class>/<patient>/<image>``.

    The two class folders are ``hc`` (label 0) and ``pd`` (label 1). Every
    image file found inside a patient sub-folder becomes one sample.

    Args:
        root_dir: Directory that contains the split folders.
        split: Name of the split folder to read (e.g. ``'train'``).
        transform: Optional callable applied to the grayscale PIL image
            before it is returned.
    """

    # File extensions (lower-case) that are treated as images.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_dir, split='train', transform=None):
        self.root_dir = root_dir
        self.split = split
        self.transform = transform
        self.classes = ['hc', 'pd']
        self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}
        self.samples = []
        self._load_data()

    def _load_data(self):
        """Populate ``self.samples`` with ``(image_path, label)`` tuples.

        Raises:
            FileNotFoundError: If a class folder is missing for this split.
        """
        split_dir = os.path.join(self.root_dir, self.split)
        for class_name in self.classes:
            class_dir = os.path.join(split_dir, class_name)
            label = self.class_to_idx[class_name]
            # os.listdir returns entries in arbitrary, OS-dependent order;
            # sorting keeps sample indices stable across runs and machines.
            for patient_folder in sorted(os.listdir(class_dir)):
                patient_path = os.path.join(class_dir, patient_folder)
                if not os.path.isdir(patient_path):
                    continue
                for img_name in sorted(os.listdir(patient_path)):
                    # Compare case-insensitively so '.PNG' / '.JPG' are kept.
                    if img_name.lower().endswith(self.IMAGE_EXTENSIONS):
                        img_path = os.path.join(patient_path, img_name)
                        self.samples.append((img_path, label))
        print(f"Loaded {len(self.samples)} samples for {self.split} split")

    def __len__(self):
        """Return the number of samples found for this split."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is loaded as single-channel ('L') and passed through
        ``self.transform`` when one was given.
        """
        img_path, label = self.samples[idx]
        # The context manager releases the file handle as soon as the pixel
        # data has been converted.
        with Image.open(img_path) as raw:
            image = raw.convert('L')
        if self.transform is not None:
            image = self.transform(image)
        return image, label

class SpecAugment:
    """Zero out random bands along the two spatial axes of an image.

    The transform takes a PIL image, converts it to a tensor, blanks
    ``num_masks`` bands along axis 1 (width drawn below ``freq_mask``) and
    ``num_masks`` bands along axis 2 (width drawn below ``time_mask``), and
    converts the result back to a PIL image.

    Args:
        freq_mask: Exclusive upper bound for the width of an axis-1 band.
        time_mask: Exclusive upper bound for the width of an axis-2 band.
        num_masks: Number of bands blanked along each axis.
    """

    def __init__(self, freq_mask=30, time_mask=50, num_masks=2):
        self.freq_mask = freq_mask
        self.time_mask = time_mask
        self.num_masks = num_masks

    def __call__(self, img):
        """Return a masked copy of the PIL image ``img`` as a PIL image."""
        img_tensor = transforms.ToTensor()(img)
        return transforms.ToPILImage()(self._mask_tensor(img_tensor))

    def _mask_tensor(self, img_tensor):
        """Apply the masks in place to a 3-D tensor and return it.

        Axis 1 is masked first, then axis 2, so the sequence of random draws
        is the same as when both loops are written inline.
        """
        _, n_mels, n_time = img_tensor.shape

        # Frequency masking
        for _ in range(self.num_masks):
            self._zero_band(img_tensor, 1, n_mels, self.freq_mask)

        # Time masking
        for _ in range(self.num_masks):
            self._zero_band(img_tensor, 2, n_time, self.time_mask)

        return img_tensor

    @staticmethod
    def _zero_band(img_tensor, dim, size, max_width):
        """Zero one randomly placed band of ``img_tensor`` along ``dim``.

        The width bound is clamped to the axis length so that inputs smaller
        than the configured mask no longer make ``np.random.randint`` raise;
        when no band of width >= 1 fits, the tensor is left untouched.
        """
        upper = min(max_width, size)
        if upper <= 1:
            return
        width = np.random.randint(1, upper)
        # width <= size - 1 here, so the range [0, size - width) is non-empty.
        start = np.random.randint(0, size - width)
        img_tensor.narrow(dim, start, width).zero_()

class GaussianNoise:
    """Callable transform that adds zero-mean Gaussian noise to a tensor.

    Args:
        std: Standard deviation used to scale the sampled noise.
    """

    def __init__(self, std=0.02):
        self.std = std

    def __call__(self, tensor):
        """Return ``tensor`` plus noise of the same shape, dtype and device."""
        noise = torch.randn_like(tensor)
        scaled_noise = noise * self.std
        return tensor + scaled_noise

class DepthwiseSeparableConv(nn.Module):
    """Per-channel convolution followed by a 1x1 channel-mixing convolution.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        kernel_size: Kernel size of the per-channel convolution, which uses
            'same' padding so the spatial size is preserved.
    """

    def __init__(self, in_ch, out_ch, kernel_size):
        super().__init__()
        # One filter group per input channel: channels are convolved independently.
        self.depthwise = nn.Conv2d(
            in_channels=in_ch,
            out_channels=in_ch,
            kernel_size=kernel_size,
            padding='same',
            groups=in_ch,
        )
        # 1x1 convolution that combines the channels into out_ch outputs.
        self.pointwise = nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=1)

    def forward(self, x):
        """Run the per-channel convolution, then the 1x1 convolution."""
        per_channel = self.depthwise(x)
        return self.pointwise(per_channel)

class ParkinsonNet(nn.Module):
    """Four-stage convolutional classifier for single-channel images.

    The feature extractor stacks two depthwise-separable stages and two
    regular 3x3 convolution stages (each followed by batch norm and GELU),
    ending in global average pooling to a 256-value vector. A two-layer
    head maps that vector to ``num_classes`` outputs.

    Args:
        num_classes: Size of the final linear layer's output.
    """

    def __init__(self, num_classes=2):
        super().__init__()

        # Layers are listed in execution order; their positions define the
        # indices used in the state dict (features.0, features.1, ...).
        feature_layers = [
            # Stage 1: 1 -> 32 channels, halve both spatial dims.
            DepthwiseSeparableConv(1, 32, 3),
            nn.BatchNorm2d(32),
            nn.GELU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # Stage 2: 32 -> 64 channels, halve both spatial dims.
            DepthwiseSeparableConv(32, 64, 3),
            nn.BatchNorm2d(64),
            nn.GELU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # Stage 3: 64 -> 128 channels, pool by 2 along height and 4 along width.
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.GELU(),
            nn.MaxPool2d(kernel_size=(2, 4)),
            # Stage 4: 128 -> 256 channels, then collapse to 1x1.
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.GELU(),
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
        ]
        self.features = nn.Sequential(*feature_layers)

        head_layers = [
            nn.Linear(in_features=256, out_features=128),
            nn.Dropout(p=0.5),
            nn.GELU(),
            nn.Linear(in_features=128, out_features=num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return the classifier outputs for a batch ``x``."""
        pooled = self.features(x)
        # Flatten everything after the batch dimension for the linear head.
        flat = pooled.view(pooled.size(0), -1)
        return self.classifier(flat)

def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, device='cuda', patience=5):
    """Train ``model`` with validation-driven LR scheduling and early stopping.

    Each epoch runs a 'train' pass and a 'val' pass. The validation accuracy
    drives a ``ReduceLROnPlateau`` scheduler, decides whether the weights are
    checkpointed to ``best_model.pth`` (in the current working directory),
    and feeds the early-stopping counter.

    Args:
        model: Network to train; moved to ``device``.
        dataloaders: Mapping with at least the keys ``'train'`` and ``'val'``.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer that updates ``model``'s parameters.
        num_epochs: Maximum number of epochs.
        device: Device the model and batches are moved to.
        patience: Number of consecutive epochs without a validation-accuracy
            improvement after which training stops.

    Returns:
        ``(model, history)`` where ``model`` carries the best checkpointed
        weights of this run (if any epoch ran) and ``history`` maps
        ``'train_loss'``, ``'train_acc'``, ``'val_loss'`` and ``'val_acc'``
        to per-epoch lists of floats.
    """
    model.to(device)
    # Mixed precision through torch.cuda.amp only applies to CUDA devices;
    # on any other device run in full precision without scaling.
    use_amp = torch.device(device).type == 'cuda'
    scaler = GradScaler(enabled=use_amp)
    # 'verbose' is deprecated on LR schedulers, so LR reductions are reported
    # explicitly after each scheduler step below.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=2)
    # Start below any reachable accuracy so the first validation pass always
    # writes a checkpoint; otherwise a run that never beats 0.0 would end by
    # loading a missing or stale 'best_model.pth'.
    best_acc = float('-inf')
    checkpoint_saved = False
    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
    early_stop_counter = 0

    for epoch in range(num_epochs):
        print(f'\nEpoch {epoch+1}/{num_epochs}')

        for phase in ['train', 'val']:
            is_train = phase == 'train'
            model.train(is_train)
            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in tqdm(dataloaders[phase], desc=f'{phase} phase'):
                inputs, labels = inputs.to(device), labels.to(device)
                if is_train:
                    optimizer.zero_grad()

                with torch.set_grad_enabled(is_train):
                    with autocast(enabled=use_amp):
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)

                    if is_train:
                        scaler.scale(loss).backward()
                        scaler.step(optimizer)
                        scaler.update()

                preds = torch.argmax(outputs, 1)
                # Weight the batch-mean loss by batch size so the epoch
                # average is per sample even when the last batch is smaller.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += (preds == labels).sum().item()

            # Guard against an empty dataset instead of failing on division.
            n_samples = max(len(dataloaders[phase].dataset), 1)
            epoch_loss = running_loss / n_samples
            epoch_acc = running_corrects / n_samples

            if not is_train:
                previous_lrs = [group['lr'] for group in optimizer.param_groups]
                scheduler.step(epoch_acc)
                for group_idx, group in enumerate(optimizer.param_groups):
                    new_lr = group['lr']
                    if new_lr != previous_lrs[group_idx]:
                        print(f'Reducing learning rate of group {group_idx} to {new_lr:.4e}.')

                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    torch.save(model.state_dict(), 'best_model.pth')
                    checkpoint_saved = True
                    early_stop_counter = 0
                else:
                    early_stop_counter += 1

            history[f'{phase}_loss'].append(epoch_loss)
            history[f'{phase}_acc'].append(epoch_acc)
            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

        if early_stop_counter >= patience:
            print(f'Early stopping at epoch {epoch+1}')
            break

    # Only restore weights that this run actually wrote.
    if checkpoint_saved:
        model.load_state_dict(torch.load('best_model.pth', map_location=device))
    return model, history

def _dataset_mean_std(samples):
    """Return ``(mean, std)`` over all pixels of the images in ``samples``.

    ``samples`` yields ``(image_tensor, label)`` pairs. Sums are accumulated
    image by image in float64, so memory use does not grow with the number
    of images. The standard deviation uses the n-1 denominator, matching
    ``torch.Tensor.std``'s default.

    Raises:
        ValueError: If ``samples`` yields no pixels.
    """
    count = 0
    total = 0.0
    total_sq = 0.0
    for img, _ in samples:
        values = img.double()
        count += values.numel()
        total += values.sum().item()
        total_sq += (values * values).sum().item()
    if count == 0:
        raise ValueError("Cannot compute statistics of an empty dataset")
    mean = total / count
    # Clamp at zero: rounding can make the difference slightly negative.
    variance = max(total_sq - count * mean * mean, 0.0) / max(count - 1, 1)
    return mean, variance ** 0.5


def run_experiment(data_dir, batch_size=64, num_epochs=50):
    """Train, evaluate and plot a ``ParkinsonNet`` on the data in ``data_dir``.

    Steps: compute the training-set pixel mean/std, build augmented train
    and plain val/test pipelines, train with early stopping, evaluate on the
    test split, and save loss/accuracy curves to ``training_curves.png``.

    Args:
        data_dir: Root directory holding ``train``, ``val`` and ``test``
            splits in the layout read by ``MDVRDataset``.
        batch_size: Batch size for all three dataloaders.
        num_epochs: Maximum number of training epochs.

    Returns:
        ``(model, history, (test_loss, test_acc))``.
    """
    # Calculate dataset statistics
    temp_dataset = MDVRDataset(data_dir, 'train', transforms.Compose([
        transforms.Resize((496, 200)), transforms.ToTensor()
    ]))
    # Stream the statistics: concatenating every pixel of the training set
    # into one tensor needs tens of gigabytes for a dataset of this size.
    mean, std = _dataset_mean_std(tqdm(temp_dataset))

    # Enhanced transforms
    train_transform = transforms.Compose([
        transforms.Resize((496, 200)),
        SpecAugment(freq_mask=30, time_mask=50),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
        GaussianNoise(std=0.02)
    ])

    test_transform = transforms.Compose([
        transforms.Resize((496, 200)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    # Create datasets
    train_dataset = MDVRDataset(data_dir, 'train', train_transform)
    val_dataset = MDVRDataset(data_dir, 'val', test_transform)
    test_dataset = MDVRDataset(data_dir, 'test', test_transform)

    # Create dataloaders with optimized settings
    dataloaders = {
        'train': DataLoader(train_dataset, batch_size, shuffle=True,
                          num_workers=4, pin_memory=True, persistent_workers=True),
        'val': DataLoader(val_dataset, batch_size, num_workers=4),
        'test': DataLoader(test_dataset, batch_size, num_workers=4)
    }

    # Initialize model
    model = ParkinsonNet()
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)  # Reduces overconfidence
    optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)

    # Train
    model, history = train_model(
        model, dataloaders, criterion, optimizer,
        num_epochs=num_epochs, patience=7
    )

    # Test
    # NOTE(review): test_model is not defined in the visible part of this
    # file -- confirm it is provided elsewhere before running.
    test_loss, test_acc = test_model(model, dataloaders['test'], criterion)

    # Plot results
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(history['train_loss'], label='Train')
    plt.plot(history['val_loss'], label='Val')
    plt.title('Loss Curve')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(history['train_acc'], label='Train')
    plt.plot(history['val_acc'], label='Val')
    plt.title('Accuracy Curve')
    plt.legend()

    plt.tight_layout()
    plt.savefig('training_curves.png')

    return model, history, (test_loss, test_acc)

# Example usage: run the full train / validate / test pipeline when this file
# is executed as a script.
if __name__ == "__main__":
    # NOTE(review): data_dir is a machine-specific absolute path; it must
    # point at a folder with the train/val/test layout read by MDVRDataset.
    model, history, test_results = run_experiment(
        data_dir="C:/NPersonal/Projects/SDP/Prediction Stuff/Dataset/MDVR",
        batch_size=64,
        num_epochs=50
    )

Loaded 41744 samples for train split


 12%|█████████▎                                                                  | 5130/41744 [00:17<02:13, 273.53it/s]