In [None]:
import os
import time
import random
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.cuda.amp import autocast, GradScaler
import gc

# Set seeds for reproducibility
def set_seed(seed=42):
    """Seed every random number generator used by this script.

    Seeds Python's ``random``, NumPy's global generator and PyTorch (CPU and
    all CUDA devices), and configures cuDNN for deterministic behaviour.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not just the current one; a no-op without CUDA.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN auto-tuner may select different kernels from run to run, which
    # defeats the deterministic flag above, so it has to stay disabled.
    torch.backends.cudnn.benchmark = False


set_seed()

class MDVRDataset(Dataset):
    """Grayscale image dataset read from ``root_dir/split/<class>/<patient>/``.

    Every image file found inside a patient folder becomes one sample labelled
    with the index of its class directory (``'HC'`` -> 0, ``'PD'`` -> 1).

    Args:
        root_dir: Directory that contains one sub-directory per split.
        split: Name of the split sub-directory to index (e.g. ``'train'``).
        transform: Optional callable applied to the loaded PIL image.

    Raises:
        FileNotFoundError: If a class directory of the requested split is
            missing (propagated from ``os.listdir``).
    """

    # File name suffixes (compared case-insensitively) treated as images.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_dir, split='train', transform=None):
        self.root_dir = root_dir
        self.split = split
        self.transform = transform
        self.classes = ['HC', 'PD']
        self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}
        self.samples = []
        self._load_paths()

    def _load_paths(self):
        """Populate ``self.samples`` with ``(image_path, label)`` tuples."""
        split_dir = os.path.join(self.root_dir, self.split)
        for class_name in self.classes:
            class_dir = os.path.join(split_dir, class_name)
            label = self.class_to_idx[class_name]
            # os.listdir returns entries in arbitrary order; sorting keeps the
            # sample order (and therefore seeded shuffling) reproducible.
            for patient_folder in sorted(os.listdir(class_dir)):
                patient_path = os.path.join(class_dir, patient_folder)
                if not os.path.isdir(patient_path):
                    continue
                self.samples.extend(
                    (os.path.join(patient_path, img_name), label)
                    for img_name in sorted(os.listdir(patient_path))
                    # Lower-case the name so '.PNG' / '.JPG' files are kept.
                    if img_name.lower().endswith(self.IMAGE_EXTENSIONS)
                )
        print(f"Found {len(self.samples)} samples for {self.split} split")

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` as a single-channel image and return it with its label."""
        img_path, label = self.samples[idx]
        # The context manager closes the file handle as soon as the pixel
        # data has been converted, instead of waiting for garbage collection.
        with Image.open(img_path) as raw:
            image = raw.convert('L')
        if self.transform is not None:
            image = self.transform(image)
        return image, label

class SimpleAugment:
    """Zero out one random band of rows and one random band of columns.

    The comments in the original code call the row band "frequency masking"
    and the column band "time masking".

    Random numbers come from PyTorch's generator rather than NumPy's:
    DataLoader worker processes get distinct PyTorch seeds, whereas NumPy's
    global state is not reseeded per worker, which can make workers produce
    identical masks.

    Args:
        freq_mask: Exclusive upper bound on the height of the row band.
        time_mask: Exclusive upper bound on the width of the column band.
    """

    def __init__(self, freq_mask=15, time_mask=25):
        self.freq_mask = freq_mask
        self.time_mask = time_mask

    @staticmethod
    def _random_span(size, max_width):
        """Pick a random ``(start, width)`` band inside ``range(size)``.

        Returns ``None`` when no band of at least one element that is also
        strictly smaller than ``size`` and ``max_width`` exists, so tiny
        images or mask limits <= 1 skip masking instead of raising.
        """
        upper = min(int(max_width), int(size))
        if upper <= 1:
            return None
        width = int(torch.randint(1, upper, (1,)))
        # "+ 1" lets the band reach the last row/column as well.
        start = int(torch.randint(0, size - width + 1, (1,)))
        return start, width

    def _apply_masks(self, img_tensor):
        """Mask ``img_tensor`` (``(..., H, W)``) in place and return it."""
        height, width = img_tensor.shape[-2:]

        # Frequency masking
        rows = self._random_span(height, self.freq_mask)
        if rows is not None:
            start, span = rows
            img_tensor[..., start:start + span, :] = 0

        # Time masking
        cols = self._random_span(width, self.time_mask)
        if cols is not None:
            start, span = cols
            img_tensor[..., :, start:start + span] = 0

        return img_tensor

    def __call__(self, img):
        """Return a masked copy of ``img``.

        A PIL image is converted to a tensor, masked and converted back to a
        PIL image. A ``torch.Tensor`` is cloned, masked and returned as a
        tensor, leaving the input untouched.
        """
        if isinstance(img, torch.Tensor):
            return self._apply_masks(img.clone())
        masked = self._apply_masks(transforms.ToTensor()(img))
        return transforms.ToPILImage()(masked)

class LightCNN(nn.Module):
    """Two-stage convolutional classifier for single-channel images.

    Each stage is Conv(3x3, padding 1) -> BatchNorm -> ReLU -> MaxPool(2);
    the channel widths are 16 and 32. A global average pool collapses the
    spatial dimensions and a single linear layer produces the class logits.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=2):
        super().__init__()

        def conv_stage(in_channels, out_channels):
            # One conv / norm / activation / downsample stage.
            return [
                nn.Conv2d(in_channels, out_channels, 3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2, 2),
            ]

        layers = conv_stage(1, 16) + conv_stage(16, 32)
        layers.append(nn.AdaptiveAvgPool2d((1, 1)))
        # Layer order (and therefore the Sequential indices used in
        # state_dict keys) is kept stable so saved checkpoints stay loadable.
        self.features = nn.Sequential(*layers)
        self.classifier = nn.Linear(32, num_classes)

    def forward(self, x):
        """Map a ``(N, 1, H, W)`` batch to ``(N, num_classes)`` logits."""
        pooled = self.features(x)
        flat = torch.flatten(pooled, 1)
        return self.classifier(flat)

def compute_stats_safely(dataset, min_std=1e-6, show_progress=True):
    """Compute the global pixel mean and standard deviation of a dataset.

    Only running sums are kept, so memory use does not grow with the dataset.
    ``dataset`` may yield single images or batches (e.g. a ``DataLoader``):
    every element of each yielded tensor counts as one pixel.

    Args:
        dataset: Iterable of ``(tensor, label)`` pairs; the label is ignored.
        min_std: Lower bound applied to the returned standard deviation so
            that a constant dataset cannot cause a division by zero during
            normalisation. The previous hard-coded floor of 0.5 replaced
            every computed value, because data in ``[0, 1]`` can never have
            a standard deviation above 0.5.
        show_progress: Wrap the iteration in a tqdm progress bar.

    Returns:
        ``(mean, std)`` as plain Python floats.

    Raises:
        ValueError: If the dataset yields no pixels.
    """
    pixel_sum = 0.0
    pixel_sq_sum = 0.0
    pixel_count = 0

    samples = tqdm(dataset, desc="Calculating stats") if show_progress else dataset
    for img, _ in samples:
        # Accumulate in float64 so that summing millions of float32 pixels
        # does not lose precision.
        values = img.reshape(-1).to(torch.float64)
        pixel_sum += values.sum().item()
        pixel_sq_sum += (values * values).sum().item()
        pixel_count += values.numel()
        # No per-sample gc.collect()/empty_cache(): the tensors are freed by
        # reference counting, and forcing a collection on every image
        # dominated the runtime.

    if pixel_count == 0:
        raise ValueError("Cannot compute statistics of an empty dataset")

    mean = pixel_sum / pixel_count
    # Var[x] = E[x^2] - E[x]^2; clamp tiny negative values from rounding.
    variance = max(pixel_sq_sum / pixel_count - mean * mean, 0.0)
    return mean, max(variance ** 0.5, min_std)

def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, device='cuda', patience=5,
                show_progress=True):
    """Train ``model`` with validation-based checkpointing and early stopping.

    Each epoch runs a ``'train'`` and a ``'val'`` phase. Mixed precision is
    used on CUDA devices. After every validation phase the learning-rate
    scheduler is stepped on the validation accuracy; whenever that accuracy
    improves, the weights are written to ``best_model.pth`` and kept in
    memory. Training stops once the accuracy has not improved for
    ``patience`` consecutive epochs.

    Args:
        model: Network to optimise; moved to ``device``.
        dataloaders: Mapping with ``'train'`` and ``'val'`` loaders yielding
            ``(inputs, labels)`` batches.
        criterion: Loss callable taking ``(outputs, labels)``.
        optimizer: Optimiser constructed over ``model``'s parameters.
        num_epochs: Maximum number of epochs.
        device: Target device; falls back to CPU when CUDA is requested but
            unavailable.
        patience: Number of non-improving validation epochs tolerated.
        show_progress: Show a tqdm progress bar for every phase.

    Returns:
        ``(model, history)`` where the model carries the best validation
        weights and ``history`` maps ``'train_loss'``, ``'train_acc'``,
        ``'val_loss'`` and ``'val_acc'`` to per-epoch lists.
    """
    device = torch.device(device)
    if device.type == 'cuda' and not torch.cuda.is_available():
        print('CUDA is not available; falling back to CPU')
        device = torch.device('cpu')
    # Autocast and gradient scaling only make sense on CUDA here.
    use_amp = device.type == 'cuda'

    model.to(device)
    scaler = GradScaler(enabled=use_amp)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=2)
    # Start below any reachable accuracy so the first epoch always produces a
    # checkpoint; otherwise a run stuck at 0.0 would never save one.
    best_acc = float('-inf')
    best_state = None
    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
    early_stop_counter = 0

    for epoch in range(num_epochs):
        print(f'\nEpoch {epoch+1}/{num_epochs}')

        for phase in ['train', 'val']:
            is_train = phase == 'train'
            model.train(is_train)
            running_loss = 0.0
            running_corrects = 0
            seen = 0

            loader = dataloaders[phase]
            batches = tqdm(loader, desc=f'{phase} phase') if show_progress else loader
            for inputs, labels in batches:
                inputs, labels = inputs.to(device), labels.to(device)
                batch_size = inputs.size(0)

                with torch.set_grad_enabled(is_train):
                    with autocast(enabled=use_amp):
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)

                    if is_train:
                        optimizer.zero_grad(set_to_none=True)
                        scaler.scale(loss).backward()
                        # Gradients must be unscaled before clipping, or the
                        # threshold is compared against loss-scaled values.
                        scaler.unscale_(optimizer)
                        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                        scaler.step(optimizer)
                        scaler.update()

                preds = torch.argmax(outputs, 1)
                running_loss += loss.item() * batch_size
                running_corrects += torch.sum(preds == labels).item()
                seen += batch_size

                if show_progress:
                    # Mean per-sample loss over the batches seen so far.
                    batches.set_postfix({'loss': running_loss / seen})

            # Guard against an empty loader instead of dividing by zero.
            epoch_loss = running_loss / max(seen, 1)
            epoch_acc = running_corrects / max(seen, 1)

            if phase == 'val':
                scheduler.step(epoch_acc)
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    torch.save(model.state_dict(), 'best_model.pth')
                    # Keep an in-memory copy so a stale file on disk can
                    # never be loaded by mistake at the end.
                    best_state = {
                        name: tensor.detach().cpu().clone()
                        for name, tensor in model.state_dict().items()
                    }
                    early_stop_counter = 0
                else:
                    early_stop_counter += 1

            history[f'{phase}_loss'].append(epoch_loss)
            history[f'{phase}_acc'].append(epoch_acc)
            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

        if early_stop_counter >= patience:
            print(f'Early stopping at epoch {epoch+1}')
            break

    if best_state is not None:
        model.load_state_dict(best_state)
    return model, history

def run_experiment(data_dir, batch_size=16, num_epochs=50, image_size=(496, 200)):
    """Run the full pipeline: statistics, training and curve plotting.

    Steps:
        1. Compute the pixel mean/std of the training split.
        2. Build the train/val/test datasets and loaders.
        3. Train a ``LightCNN`` with ``train_model``.
        4. Save loss and accuracy curves to ``training_curves.png``.

    Args:
        data_dir: Dataset root containing ``train``, ``val`` and ``test``
            sub-directories.
        batch_size: Batch size used by all three loaders.
        num_epochs: Maximum number of training epochs.
        image_size: ``(height, width)`` every image is resized to.

    Returns:
        ``(model, history)`` as returned by ``train_model``.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Pinned host memory only helps host-to-GPU copies.
    pin_memory = device == 'cuda'

    # Calculate dataset statistics safely
    temp_dataset = MDVRDataset(data_dir, 'train', transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor()
    ]))
    # Feed the statistics pass through a DataLoader so images are decoded by
    # worker processes and reduced a batch at a time, rather than one image
    # at a time in the main process.
    stats_loader = DataLoader(temp_dataset, batch_size=64, num_workers=2)
    mean, std = compute_stats_safely(stats_loader)
    print(f"Dataset stats - Mean: {mean:.4f}, Std: {std:.4f}")
    del stats_loader, temp_dataset
    gc.collect()

    # Simplified transforms
    train_transform = transforms.Compose([
        transforms.Resize(image_size),
        SimpleAugment(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    test_transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    # Create datasets
    train_dataset = MDVRDataset(data_dir, 'train', train_transform)
    val_dataset = MDVRDataset(data_dir, 'val', test_transform)
    test_dataset = MDVRDataset(data_dir, 'test', test_transform)

    # Memory-safe dataloaders
    dataloaders = {
        'train': DataLoader(train_dataset, batch_size, shuffle=True,
                            num_workers=2, pin_memory=pin_memory,
                            persistent_workers=False),
        'val': DataLoader(val_dataset, batch_size, num_workers=2,
                          pin_memory=pin_memory),
        # NOTE(review): the test loader is built but never evaluated here.
        'test': DataLoader(test_dataset, batch_size, num_workers=2,
                           pin_memory=pin_memory)
    }

    # Initialize model
    model = LightCNN()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)

    # Train on the device selected above instead of relying on the 'cuda'
    # default of train_model, which fails on CPU-only machines.
    model, history = train_model(
        model, dataloaders, criterion, optimizer,
        num_epochs=num_epochs, device=device, patience=7
    )

    # Plot results
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(history['train_loss'], label='Train')
    plt.plot(history['val_loss'], label='Val')
    plt.title('Loss Curve')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(history['train_acc'], label='Train')
    plt.plot(history['val_acc'], label='Val')
    plt.title('Accuracy Curve')
    plt.legend()

    plt.tight_layout()
    plt.savefig('training_curves.png')

    return model, history

if __name__ == "__main__":
    # The dataset root below is specific to the author's Linux machine.
    # Windows checkout used previously:
    #   C:/NPersonal/Projects/SDP/Prediction Stuff/Dataset/MDVR
    model, history = run_experiment(
        num_epochs=50,
        batch_size=16,
        data_dir=r"/home/nigmu/NPersonal/Projects/SDP/nigmu-parkinsons_disease_prediction/Dataset/MDVR",
    )

Found 41744 samples for train split


Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7743c90ccd40>>       | 3399/41744 [06:47<1:03:45, 10.02it/s]
Traceback (most recent call last):
  File "/home/nigmu/pytorch_env/lib/python3.12/site-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(

KeyboardInterrupt: 
Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7743c90ccd40>>       | 3429/41744 [06:50<1:03:26, 10.06it/s]
Traceback (most recent call last):
  File "/home/nigmu/pytorch_env/lib/python3.12/site-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(

KeyboardInterrupt: 
Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7743c90ccd40>>       | 3431/41744 [06:50<1:04:51,  9.85it/s