<a href="https://colab.research.google.com/github/nimrashaheen001/Programming_for_AI/blob/main/BasepaperImplementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install nibabel

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
import nibabel as nib  # For potential metadata extraction (if needed)
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from PIL import Image  # For image loading
from torchvision import transforms # For image transformations
#Mount Google Drive (uncomment this in Google Colab)
from google.colab import drive
drive.mount('/content/drive')

class BasicBlock2D(nn.Module):  # Changed class name to BasicBlock2D
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock2D, self).__init__()
        # Changed to 2D convolutional layers
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        # Changed to 2D batch normalization
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

class ResNet2D(nn.Module):  # Changed class name to ResNet2D
    def __init__(self, block, num_blocks, in_channels=3, num_classes=4):  # Updated in_channels and num_classes
        super(ResNet2D, self).__init__()
        self.in_planes = 64

        # Initial convolution layer (changed to 2D)
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)  # Changed to 2D batch normalization

        # ResNet layers (using BasicBlock2D)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)

        # Classification head (changed to 2D)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # Changed to 2D adaptive average pooling
        self.fc = nn.Linear(512 * block.expansion, 512)

        # Multiple output heads (adjusted for 4 classes)
        self.classification_head = nn.Linear(512, num_classes)
        self.regression_mmse = nn.Linear(512, 1)  # MMSE score regression
        self.regression_cdr = nn.Linear(512, 1)   # Clinical Dementia Rating regression

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        # Input processing
        out = F.relu(self.bn1(self.conv1(x)))

        # ResNet blocks
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)

        # Global pooling
        out = self.avgpool(out)
        out = out.view(out.size(0), -1)

        # Shared features
        features = F.relu(self.fc(out))

        # Multi-task outputs
        classification = self.classification_head(features)
        mmse_score = self.regression_mmse(features)
        cdr_score = self.regression_cdr(features)

        return {
            'classification': classification,
            'mmse_score': mmse_score,
            'cdr_score': cdr_score,
            'features': features
        }

def ResNet18_2D(in_channels=3, num_classes=4):  # Changed function name and defaults
    return ResNet2D(BasicBlock2D, [2, 2, 2, 2], in_channels, num_classes)  # Us

class BrainMRIDataset(Dataset):
    def __init__(self, data_dir, classes=['VeryMildDemented', 'MildDemented', 'NonDemented', 'ModerateDemented'], transform=None):
        """
        Dataset for multimodal brain MRI data

        Args:
            data_dir (str): Directory containing the data
            classes (list): List of class names
            transform (callable, optional): Optional transform to be applied on a sample
        """
        self.data_dir = data_dir
        self.classes = classes
        self.transform = transform

        # Find all images and labels
        self.images = []
        self.labels = []

        # Assuming directory structure: data_dir/class_label/image.png (or other image format)
        for class_idx, class_name in enumerate(classes):
            class_path = os.path.join(data_dir, class_name)
            image_files = [f for f in os.listdir(class_path) if os.path.isfile(os.path.join(class_path, f))]

            for image_file in image_files:
                image_path = os.path.join(class_path, image_file)
                self.images.append(image_path)
                self.labels.append(class_idx) # Assign class index as label

        print(f"Found {len(self.images)} images in total.")

    def __getitem__(self, idx):
        image_path = self.images[idx]
        label = self.labels[idx]

        # Load image
        image = Image.open(image_path).convert('RGB')  # Convert to RGB if needed

        # Apply transformations
        if self.transform:
            image = self.transform(image)

        # Simulate or load mmse_score and cdr_score (Replace with your actual logic)
        # Here I'm simulating them based on the label for demonstration purposes
        mmse_score = torch.tensor([28.0 if label == 2 else 20.0 + np.random.normal(0, 2)], dtype=torch.float32) # Assuming label 2 is 'NonDemented'
        cdr_score = torch.tensor([0.0 if label == 2 else 1.0 + np.random.normal(0, 0.5)], dtype=torch.float32)   # Assuming label 2 is 'NonDemented'


        return {'image': image, 'label': torch.tensor(label, dtype=torch.long), 'mmse_score': mmse_score, 'cdr_score': cdr_score}

    def __len__(self):
        """
        Returns the number of samples in the dataset.
        """
        print("Calling __len__ function!")  # Debug print statement
        return len(self.images)

# ... (Rest of the code) ...

def train_model(model, dataloaders, criterion_dict, optimizer, scheduler, num_epochs=25, device='cuda'):
    model = model.to(device)

    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0
            running_mmse_loss = 0.0
            running_cdr_loss = 0.0

            # Iterate over data
            for inputs in tqdm(dataloaders[phase]):
                images = inputs['image'].to(device)
                labels = inputs['label'].to(device)
                mmse_scores = inputs['mmse_score'].to(device)
                cdr_scores = inputs['cdr_score'].to(device)

                # Zero the parameter gradients
                optimizer.zero_grad()

                # Forward pass
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(images)
                    _, preds = torch.max(outputs['classification'], 1)

                    # Compute losses
                    classification_loss = criterion_dict['classification'](outputs['classification'], labels)
                    mmse_loss = criterion_dict['regression'](outputs['mmse_score'], mmse_scores)
                    cdr_loss = criterion_dict['regression'](outputs['cdr_score'], cdr_scores)

                    # Combined loss
                    loss = classification_loss + 0.5 * mmse_loss + 0.5 * cdr_loss

                    # Backward + optimize only in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # Statistics
                running_loss += loss.item() * images.size(0)
                running_corrects += torch.sum(preds == labels.data)
                running_mmse_loss += mmse_loss.item() * images.size(0)
                running_cdr_loss += cdr_loss.item() * images.size(0)

            if phase == 'train' and scheduler is not None:
                scheduler.step()

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
            epoch_mmse_loss = running_mmse_loss / len(dataloaders[phase].dataset)
            epoch_cdr_loss = running_cdr_loss / len(dataloaders[phase].dataset)

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f} MMSE Loss: {epoch_mmse_loss:.4f} CDR Loss: {epoch_cdr_loss:.4f}')

            # Deep copy the model if it's the best validation accuracy so far
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = model.state_dict().copy()

    print(f'Best val Acc: {best_acc:4f}')

    # Load best model weights
    model.load_state_dict(best_model_wts)
    return model

def main():
    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

 # Set paths
    data_dir = "/content/drive/MyDrive/Alzheimer_MRI_4_classes_dataset"  # Update with your Google Drive path

    # Define transformations (resize, normalize, etc.)
    transform = transforms.Compose([
        transforms.Resize((224, 224)),  # Resize to a common size
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize using ImageNet stats
    ])

    # Create dataset
    # Create dataset
    dataset = BrainMRIDataset(data_dir=data_dir, classes=['VeryMildDemented', 'MildDemented', 'NonDemented', 'ModerateDemented'], transform=transform)
    # Split into train and validation sets
    train_indices, val_indices = train_test_split(
        range(len(dataset)),
        test_size=0.2,
        random_state=42,
        stratify=dataset.labels
    )

    train_dataset = torch.utils.data.Subset(dataset, train_indices)
    val_dataset = torch.utils.data.Subset(dataset, val_indices)

    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=4)

    dataloaders = {
        'train': train_loader,
        'val': val_loader
    }

   # Create model
    model = ResNet18_2D(in_channels=3, num_classes=4)  # Update in_channels and num_classes for 2D images and 4 classes
    # ... (Loss, optimizer, training, saving - Similar a

    # Define loss functions
    criterion_dict = {
        'classification': nn.CrossEntropyLoss(),
        'regression': nn.MSELoss()
    }

    # Optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # Learning rate scheduler
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

    # Train model
    model = train_model(
        model=model,
        dataloaders=dataloaders,
        criterion_dict=criterion_dict,
        optimizer=optimizer,
        scheduler=scheduler,
        num_epochs=25,
        device=device
    )

    # Save model
    torch.save(model.state_dict(), 'brain_mri_model.pth')

    print("Training complete!")

def predict(model, dataloader, device='cuda'):
    model.eval()
    model = model.to(device)

    results = {
        'subject_ids': [],
        'true_labels': [],
        'predictions': [],
        'mmse_scores': [],
        'cdr_scores': []
    }

    with torch.no_grad():
        for inputs in tqdm(dataloader):
            images = inputs['image'].to(device)
            labels = inputs['label'].cpu().numpy()
            subject_ids = inputs['subject_id']

            outputs = model(images)
            _, preds = torch.max(outputs['classification'], 1)

            # Store results
            results['subject_ids'].extend(subject_ids)
            results['true_labels'].extend(labels)
            results['predictions'].extend(preds.cpu().numpy())
            results['mmse_scores'].extend(outputs['mmse_score'].cpu().numpy().flatten())
            results['cdr_scores'].extend(outputs['cdr_score'].cpu().numpy().flatten())

    return results

if __name__ == "__main__":
    main()

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Using device: cuda




Found 6400 images in total.
Calling __len__ function!
Epoch 0/24
----------


100%|██████████| 1280/1280 [04:45<00:00,  4.49it/s]


train Loss: 10.8677 Acc: 0.4854 MMSE Loss: 19.0536 CDR Loss: 0.4276


100%|██████████| 320/320 [01:06<00:00,  4.81it/s]


val Loss: 8.3758 Acc: 0.5438 MMSE Loss: 14.5230 CDR Loss: 0.3276
Epoch 1/24
----------


100%|██████████| 1280/1280 [03:29<00:00,  6.11it/s]


train Loss: 9.8764 Acc: 0.4955 MMSE Loss: 17.1688 CDR Loss: 0.4098


100%|██████████| 320/320 [00:17<00:00, 17.92it/s]


val Loss: 9.1959 Acc: 0.5195 MMSE Loss: 15.4099 CDR Loss: 0.3628
Epoch 2/24
----------


100%|██████████| 1280/1280 [03:29<00:00,  6.12it/s]


train Loss: 9.5573 Acc: 0.5115 MMSE Loss: 16.5912 CDR Loss: 0.4070


100%|██████████| 320/320 [00:17<00:00, 18.23it/s]


val Loss: 8.1571 Acc: 0.5539 MMSE Loss: 14.1137 CDR Loss: 0.3780
Epoch 3/24
----------


100%|██████████| 1280/1280 [03:28<00:00,  6.12it/s]


train Loss: 9.3156 Acc: 0.5199 MMSE Loss: 16.1708 CDR Loss: 0.3822


100%|██████████| 320/320 [00:17<00:00, 18.34it/s]


val Loss: 8.2580 Acc: 0.5547 MMSE Loss: 14.3374 CDR Loss: 0.3220
Epoch 4/24
----------


100%|██████████| 1280/1280 [03:29<00:00,  6.12it/s]


train Loss: 9.1210 Acc: 0.5180 MMSE Loss: 15.8462 CDR Loss: 0.3906


100%|██████████| 320/320 [00:17<00:00, 18.43it/s]


val Loss: 8.3710 Acc: 0.5141 MMSE Loss: 14.4666 CDR Loss: 0.3613
Epoch 5/24
----------


100%|██████████| 1280/1280 [03:29<00:00,  6.12it/s]


train Loss: 9.1940 Acc: 0.5170 MMSE Loss: 15.9758 CDR Loss: 0.3956


100%|██████████| 320/320 [00:17<00:00, 18.28it/s]


val Loss: 9.5478 Acc: 0.5734 MMSE Loss: 16.8320 CDR Loss: 0.3123
Epoch 6/24
----------


100%|██████████| 1280/1280 [03:29<00:00,  6.12it/s]


train Loss: 8.8529 Acc: 0.5320 MMSE Loss: 15.3409 CDR Loss: 0.3883


100%|██████████| 320/320 [00:17<00:00, 18.25it/s]


val Loss: 11.0605 Acc: 0.5734 MMSE Loss: 19.9656 CDR Loss: 0.3798
Epoch 7/24
----------


100%|██████████| 1280/1280 [03:29<00:00,  6.12it/s]


train Loss: 8.0100 Acc: 0.5555 MMSE Loss: 13.8578 CDR Loss: 0.3244


100%|██████████| 320/320 [00:17<00:00, 18.32it/s]


val Loss: 8.0343 Acc: 0.5922 MMSE Loss: 13.9963 CDR Loss: 0.3287
Epoch 8/24
----------


100%|██████████| 1280/1280 [03:29<00:00,  6.12it/s]


train Loss: 7.7802 Acc: 0.5641 MMSE Loss: 13.4033 CDR Loss: 0.3364


100%|██████████| 320/320 [00:17<00:00, 18.27it/s]


val Loss: 10.2897 Acc: 0.5672 MMSE Loss: 18.4394 CDR Loss: 0.3537
Epoch 9/24
----------


100%|██████████| 1280/1280 [03:29<00:00,  6.12it/s]


train Loss: 7.5550 Acc: 0.5670 MMSE Loss: 12.9485 CDR Loss: 0.3513


100%|██████████| 320/320 [00:17<00:00, 18.09it/s]


val Loss: 10.2087 Acc: 0.5953 MMSE Loss: 18.3982 CDR Loss: 0.3160
Epoch 10/24
----------


100%|██████████| 1280/1280 [03:29<00:00,  6.11it/s]


train Loss: 7.4164 Acc: 0.5736 MMSE Loss: 12.7149 CDR Loss: 0.3274


100%|██████████| 320/320 [00:17<00:00, 18.34it/s]


val Loss: 7.3784 Acc: 0.5914 MMSE Loss: 12.7335 CDR Loss: 0.3136
Epoch 11/24
----------


100%|██████████| 1280/1280 [03:29<00:00,  6.12it/s]


train Loss: 7.1266 Acc: 0.5748 MMSE Loss: 12.1545 CDR Loss: 0.3274


100%|██████████| 320/320 [00:17<00:00, 18.28it/s]


val Loss: 16.7154 Acc: 0.5859 MMSE Loss: 31.3014 CDR Loss: 0.3158
Epoch 12/24
----------


100%|██████████| 1280/1280 [03:29<00:00,  6.12it/s]


train Loss: 7.0439 Acc: 0.5938 MMSE Loss: 12.0173 CDR Loss: 0.3209


100%|██████████| 320/320 [00:17<00:00, 18.29it/s]


val Loss: 6.7891 Acc: 0.5898 MMSE Loss: 11.5489 CDR Loss: 0.3105
Epoch 13/24
----------


100%|██████████| 1280/1280 [03:29<00:00,  6.12it/s]


train Loss: 6.9188 Acc: 0.5979 MMSE Loss: 11.7742 CDR Loss: 0.3271


100%|██████████| 320/320 [00:17<00:00, 18.28it/s]


val Loss: 6.9563 Acc: 0.6109 MMSE Loss: 11.9811 CDR Loss: 0.3041
Epoch 14/24
----------


100%|██████████| 1280/1280 [03:28<00:00,  6.13it/s]


train Loss: 6.5515 Acc: 0.6059 MMSE Loss: 11.1080 CDR Loss: 0.3081


100%|██████████| 320/320 [00:17<00:00, 18.17it/s]


val Loss: 6.6122 Acc: 0.6141 MMSE Loss: 11.3066 CDR Loss: 0.2938
Epoch 15/24
----------


100%|██████████| 1280/1280 [03:29<00:00,  6.12it/s]


train Loss: 6.4516 Acc: 0.6146 MMSE Loss: 10.9140 CDR Loss: 0.3117


100%|██████████| 320/320 [00:17<00:00, 18.40it/s]


val Loss: 6.7326 Acc: 0.6125 MMSE Loss: 11.5481 CDR Loss: 0.2996
Epoch 16/24
----------


100%|██████████| 1280/1280 [03:29<00:00,  6.12it/s]


train Loss: 6.5369 Acc: 0.6098 MMSE Loss: 11.0872 CDR Loss: 0.3057


100%|██████████| 320/320 [00:17<00:00, 18.44it/s]


val Loss: 6.4218 Acc: 0.6148 MMSE Loss: 10.9425 CDR Loss: 0.2929
Epoch 17/24
----------


100%|██████████| 1280/1280 [03:28<00:00,  6.13it/s]


train Loss: 6.3379 Acc: 0.6158 MMSE Loss: 10.7248 CDR Loss: 0.2936


100%|██████████| 320/320 [00:17<00:00, 18.32it/s]


val Loss: 7.1625 Acc: 0.6195 MMSE Loss: 12.4386 CDR Loss: 0.2834
Epoch 18/24
----------


100%|██████████| 1280/1280 [03:28<00:00,  6.13it/s]


train Loss: 6.3577 Acc: 0.6215 MMSE Loss: 10.7445 CDR Loss: 0.3103


100%|██████████| 320/320 [00:17<00:00, 18.40it/s]


val Loss: 6.5805 Acc: 0.6141 MMSE Loss: 11.2610 CDR Loss: 0.3020
Epoch 19/24
----------


100%|██████████| 1280/1280 [03:29<00:00,  6.12it/s]


train Loss: 6.3144 Acc: 0.6213 MMSE Loss: 10.6665 CDR Loss: 0.3042


100%|██████████| 320/320 [00:17<00:00, 18.38it/s]


val Loss: 6.5938 Acc: 0.6078 MMSE Loss: 11.2772 CDR Loss: 0.2830
Epoch 20/24
----------


100%|██████████| 1280/1280 [03:28<00:00,  6.13it/s]


train Loss: 6.2600 Acc: 0.6244 MMSE Loss: 10.5832 CDR Loss: 0.2980


100%|██████████| 320/320 [00:17<00:00, 18.40it/s]


val Loss: 6.7695 Acc: 0.6234 MMSE Loss: 11.6220 CDR Loss: 0.2972
Epoch 21/24
----------


100%|██████████| 1280/1280 [03:29<00:00,  6.12it/s]


train Loss: 6.3265 Acc: 0.6215 MMSE Loss: 10.7187 CDR Loss: 0.3011


100%|██████████| 320/320 [00:17<00:00, 18.34it/s]


val Loss: 6.4332 Acc: 0.6227 MMSE Loss: 10.9774 CDR Loss: 0.3001
Epoch 22/24
----------


100%|██████████| 1280/1280 [03:28<00:00,  6.13it/s]


train Loss: 6.1509 Acc: 0.6162 MMSE Loss: 10.3386 CDR Loss: 0.3072


100%|██████████| 320/320 [00:17<00:00, 18.34it/s]


val Loss: 6.4792 Acc: 0.6203 MMSE Loss: 11.0903 CDR Loss: 0.2797
Epoch 23/24
----------


100%|██████████| 1280/1280 [03:29<00:00,  6.12it/s]


train Loss: 6.2713 Acc: 0.6258 MMSE Loss: 10.6065 CDR Loss: 0.2952


100%|██████████| 320/320 [00:17<00:00, 18.39it/s]


val Loss: 6.3210 Acc: 0.6281 MMSE Loss: 10.7665 CDR Loss: 0.2865
Epoch 24/24
----------


100%|██████████| 1280/1280 [03:28<00:00,  6.12it/s]


train Loss: 6.3565 Acc: 0.6256 MMSE Loss: 10.7592 CDR Loss: 0.3074


100%|██████████| 320/320 [00:17<00:00, 18.35it/s]

val Loss: 6.6636 Acc: 0.6203 MMSE Loss: 11.4279 CDR Loss: 0.3124
Best val Acc: 0.628125
Training complete!





In [None]:
!pip install nibabel

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
import nibabel as nib  # For potential metadata extraction (if needed)
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from PIL import Image  # For image loading
from torchvision import transforms # For image transformations
#Mount Google Drive (uncomment this in Google Colab)
from google.colab import drive
drive.mount('/content/drive')

class BasicBlock2D(nn.Module):  # Changed class name to BasicBlock2D
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock2D, self).__init__()
        # Changed to 2D convolutional layers
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        # Changed to 2D batch normalization
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

class ResNet2D(nn.Module):  # Changed class name to ResNet2D
    def __init__(self, block, num_blocks, in_channels=3, num_classes=4):  # Updated in_channels and num_classes
        super(ResNet2D, self).__init__()
        self.in_planes = 64

        # Initial convolution layer (changed to 2D)
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)  # Changed to 2D batch normalization

        # ResNet layers (using BasicBlock2D)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)

        # Classification head (changed to 2D)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # Changed to 2D adaptive average pooling
        self.fc = nn.Linear(512 * block.expansion, 512)

        # Multiple output heads (adjusted for 4 classes)
        self.classification_head = nn.Linear(512, num_classes)
        self.regression_mmse = nn.Linear(512, 1)  # MMSE score regression
        self.regression_cdr = nn.Linear(512, 1)   # Clinical Dementia Rating regression

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        # Input processing
        out = F.relu(self.bn1(self.conv1(x)))

        # ResNet blocks
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)

        # Global pooling
        out = self.avgpool(out)
        out = out.view(out.size(0), -1)

        # Shared features
        features = F.relu(self.fc(out))

        # Multi-task outputs
        classification = self.classification_head(features)
        mmse_score = self.regression_mmse(features)
        cdr_score = self.regression_cdr(features)

        return {
            'classification': classification,
            'mmse_score': mmse_score,
            'cdr_score': cdr_score,
            'features': features
        }

def ResNet18_2D(in_channels=3, num_classes=4):  # Changed function name and defaults
    return ResNet2D(BasicBlock2D, [2, 2, 2, 2], in_channels, num_classes)  # Us

class BrainMRIDataset(Dataset):
    def __init__(self, data_dir, classes=['VeryMildDemented', 'MildDemented', 'NonDemented', 'ModerateDemented'], transform=None):
        """
        Dataset for multimodal brain MRI data

        Args:
            data_dir (str): Directory containing the data
            classes (list): List of class names
            transform (callable, optional): Optional transform to be applied on a sample
        """
        self.data_dir = data_dir
        self.classes = classes
        self.transform = transform

        # Find all images and labels
        self.images = []
        self.labels = []

        # Assuming directory structure: data_dir/class_label/image.png (or other image format)
        for class_idx, class_name in enumerate(classes):
            class_path = os.path.join(data_dir, class_name)
            image_files = [f for f in os.listdir(class_path) if os.path.isfile(os.path.join(class_path, f))]

            for image_file in image_files:
                image_path = os.path.join(class_path, image_file)
                self.images.append(image_path)
                self.labels.append(class_idx) # Assign class index as label

        print(f"Found {len(self.images)} images in total.")

    def __getitem__(self, idx):
        image_path = self.images[idx]
        label = self.labels[idx]

        # Load image
        image = Image.open(image_path).convert('RGB')  # Convert to RGB if needed

        # Apply transformations
        if self.transform:
            image = self.transform(image)

        # Simulate or load mmse_score and cdr_score (Replace with your actual logic)
        # Here I'm simulating them based on the label for demonstration purposes
        mmse_score = torch.tensor([28.0 if label == 2 else 20.0 + np.random.normal(0, 2)], dtype=torch.float32) # Assuming label 2 is 'NonDemented'
        cdr_score = torch.tensor([0.0 if label == 2 else 1.0 + np.random.normal(0, 0.5)], dtype=torch.float32)   # Assuming label 2 is 'NonDemented'


        return {'image': image, 'label': torch.tensor(label, dtype=torch.long), 'mmse_score': mmse_score, 'cdr_score': cdr_score}

    def __len__(self):
        """
        Returns the number of samples in the dataset.
        """
        print("Calling __len__ function!")  # Debug print statement
        return len(self.images)

# ... (Rest of the code) ...

# Define training function with accuracy tracking
def train_model(model, dataloaders, criterion_dict, optimizer, scheduler, num_epochs=100, device='cuda'):
    model = model.to(device)
    best_acc = 0.0

    train_acc_history = []
    val_acc_history = []

    for epoch in range(num_epochs):
        print(f'Epoch {epoch + 1}/{num_epochs}')
        print('-' * 10)

        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0
            total_samples = 0

            for inputs in tqdm(dataloaders[phase]):
                images = inputs['image'].to(device)
                labels = inputs['label'].to(device)
                mmse_scores = inputs['mmse_score'].to(device)
                cdr_scores = inputs['cdr_score'].to(device)

                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(images)
                    _, preds = torch.max(outputs['classification'], 1)

                    classification_loss = criterion_dict['classification'](outputs['classification'], labels)
                    mmse_loss = criterion_dict['regression'](outputs['mmse_score'], mmse_scores)
                    cdr_loss = criterion_dict['regression'](outputs['cdr_score'], cdr_scores)

                    loss = classification_loss + 0.5 * mmse_loss + 0.5 * cdr_loss

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * images.size(0)
                running_corrects += torch.sum(preds == labels.data)
                total_samples += labels.size(0)

            if phase == 'train' and scheduler is not None:
                scheduler.step()

            epoch_loss = running_loss / total_samples
            epoch_acc = running_corrects.double() / total_samples
            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            if phase == 'train':
                train_acc_history.append(epoch_acc.item())
            else:
                val_acc_history.append(epoch_acc.item())

            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = model.state_dict().copy()

    print(f'Best validation accuracy: {best_acc:.4f}')
    model.load_state_dict(best_model_wts)
    return model, train_acc_history, val_acc_history


def main():
    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

 # Set paths
    data_dir = "/content/drive/MyDrive/Alzheimer_MRI_4_classes_dataset"  # Update with your Google Drive path

    # Define transformations (resize, normalize, etc.)
    transform = transforms.Compose([
        transforms.Resize((224, 224)),  # Resize to a common size
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize using ImageNet stats
    ])

    # Create dataset
    # Create dataset
    dataset = BrainMRIDataset(data_dir=data_dir, classes=['VeryMildDemented', 'MildDemented', 'NonDemented', 'ModerateDemented'], transform=transform)
    # Split into train and validation sets
    train_indices, val_indices = train_test_split(
        range(len(dataset)),
        test_size=0.2,
        random_state=42,
        stratify=dataset.labels
    )

    train_dataset = torch.utils.data.Subset(dataset, train_indices)
    val_dataset = torch.utils.data.Subset(dataset, val_indices)

    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=4)

    dataloaders = {
        'train': train_loader,
        'val': val_loader
    }

   # Create model
    model = ResNet18_2D(in_channels=3, num_classes=4)  # Update in_channels and num_classes for 2D images and 4 classes
    # ... (Loss, optimizer, training, saving - Similar a

    # Define loss functions
    criterion_dict = {
        'classification': nn.CrossEntropyLoss(),
        'regression': nn.MSELoss()
    }

    # Optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # Learning rate scheduler
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

    # Train model
    model = train_model(
        model=model,
        dataloaders=dataloaders,
        criterion_dict=criterion_dict,
        optimizer=optimizer,
        scheduler=scheduler,
        num_epochs=25,
        device=device
    )

    # Save model
    torch.save(model.state_dict(), 'brain_mri_model.pth')

    print("Training complete!")

def predict(model, dataloader, device='cuda'):
    model.eval()
    model = model.to(device)

    results = {
        'subject_ids': [],
        'true_labels': [],
        'predictions': [],
        'mmse_scores': [],
        'cdr_scores': []
    }

    with torch.no_grad():
        for inputs in tqdm(dataloader):
            images = inputs['image'].to(device)
            labels = inputs['label'].cpu().numpy()
            subject_ids = inputs['subject_id']

            outputs = model(images)
            _, preds = torch.max(outputs['classification'], 1)

            # Store results
            results['subject_ids'].extend(subject_ids)
            results['true_labels'].extend(labels)
            results['predictions'].extend(preds.cpu().numpy())
            results['mmse_scores'].extend(outputs['mmse_score'].cpu().numpy().flatten())
            results['cdr_scores'].extend(outputs['cdr_score'].cpu().numpy().flatten())

    return results

if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    data_dir = "/content/drive/MyDrive/Alzheimer_MRI_4_classes_dataset"
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    dataset = BrainMRIDataset(data_dir=data_dir, classes=['VeryMildDemented', 'MildDemented', 'NonDemented', 'ModerateDemented'], transform=transform)
    train_indices, val_indices = train_test_split(range(len(dataset)), test_size=0.2, random_state=42, stratify=dataset.labels)
    train_dataset = torch.utils.data.Subset(dataset, train_indices)
    val_dataset = torch.utils.data.Subset(dataset, val_indices)

    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=4)
    dataloaders = {'train': train_loader, 'val': val_loader}

    model = ResNet18_2D(in_channels=3, num_classes=4)
    criterion_dict = {'classification': nn.CrossEntropyLoss(), 'regression': nn.MSELoss()}
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

    model, train_acc, val_acc = train_model(model, dataloaders, criterion_dict, optimizer, scheduler, num_epochs=100, device=device)
    torch.save(model.state_dict(), 'brain_mri_model.pth')

    print("Training complete!")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Using device: cuda
Found 6400 images in total.
Calling __len__ function!




Epoch 1/100
----------


100%|██████████| 1280/1280 [05:08<00:00,  4.15it/s]


train Loss: 11.1434 Acc: 0.4684


100%|██████████| 320/320 [00:59<00:00,  5.42it/s]


val Loss: 8.6331 Acc: 0.5523
Epoch 2/100
----------


100%|██████████| 1280/1280 [03:27<00:00,  6.16it/s]


train Loss: 9.7905 Acc: 0.5135


100%|██████████| 320/320 [00:17<00:00, 18.53it/s]


val Loss: 8.6142 Acc: 0.5750
Epoch 3/100
----------


100%|██████████| 1280/1280 [03:26<00:00,  6.18it/s]


train Loss: 9.3597 Acc: 0.5102


100%|██████████| 320/320 [00:17<00:00, 18.34it/s]


val Loss: 8.8621 Acc: 0.5508
Epoch 4/100
----------


100%|██████████| 1280/1280 [03:27<00:00,  6.18it/s]


train Loss: 9.1881 Acc: 0.5320


100%|██████████| 320/320 [00:17<00:00, 18.47it/s]


val Loss: 7.9488 Acc: 0.5664
Epoch 5/100
----------


100%|██████████| 1280/1280 [03:27<00:00,  6.17it/s]


train Loss: 9.0723 Acc: 0.5172


100%|██████████| 320/320 [00:18<00:00, 17.04it/s]


val Loss: 8.1068 Acc: 0.5586
Epoch 6/100
----------


100%|██████████| 1280/1280 [03:27<00:00,  6.17it/s]


train Loss: 9.1736 Acc: 0.5336


100%|██████████| 320/320 [00:18<00:00, 16.94it/s]


val Loss: 8.3636 Acc: 0.5578
Epoch 7/100
----------


100%|██████████| 1280/1280 [03:27<00:00,  6.16it/s]


train Loss: 8.7519 Acc: 0.5291


100%|██████████| 320/320 [00:17<00:00, 18.17it/s]


val Loss: 8.3783 Acc: 0.5766
Epoch 8/100
----------


100%|██████████| 1280/1280 [03:27<00:00,  6.18it/s]


train Loss: 7.8471 Acc: 0.5678


100%|██████████| 320/320 [00:17<00:00, 18.27it/s]


val Loss: 7.7580 Acc: 0.5898
Epoch 9/100
----------


100%|██████████| 1280/1280 [03:27<00:00,  6.16it/s]


train Loss: 7.8448 Acc: 0.5740


100%|██████████| 320/320 [00:17<00:00, 18.25it/s]


val Loss: 7.5954 Acc: 0.6047
Epoch 10/100
----------


100%|██████████| 1280/1280 [03:27<00:00,  6.16it/s]


train Loss: 7.6211 Acc: 0.5869


100%|██████████| 320/320 [00:17<00:00, 17.92it/s]


val Loss: 7.2306 Acc: 0.6063
Epoch 11/100
----------


100%|██████████| 1280/1280 [03:28<00:00,  6.15it/s]


train Loss: 7.3375 Acc: 0.5945


100%|██████████| 320/320 [00:17<00:00, 18.04it/s]


val Loss: 7.2868 Acc: 0.6117
Epoch 12/100
----------


100%|██████████| 1280/1280 [03:28<00:00,  6.15it/s]


train Loss: 7.2290 Acc: 0.6006


100%|██████████| 320/320 [00:17<00:00, 17.89it/s]


val Loss: 7.4251 Acc: 0.6125
Epoch 13/100
----------


100%|██████████| 1280/1280 [03:27<00:00,  6.16it/s]


train Loss: 7.0835 Acc: 0.6176


100%|██████████| 320/320 [00:17<00:00, 18.03it/s]


val Loss: 7.3650 Acc: 0.6148
Epoch 14/100
----------


100%|██████████| 1280/1280 [03:27<00:00,  6.16it/s]


train Loss: 6.9401 Acc: 0.6117


100%|██████████| 320/320 [00:17<00:00, 18.07it/s]


val Loss: 8.2565 Acc: 0.6000
Epoch 15/100
----------


100%|██████████| 1280/1280 [03:28<00:00,  6.15it/s]


train Loss: 6.8374 Acc: 0.6229


100%|██████████| 320/320 [00:17<00:00, 18.32it/s]


val Loss: 6.5302 Acc: 0.6188
Epoch 16/100
----------


100%|██████████| 1280/1280 [03:28<00:00,  6.15it/s]


train Loss: 6.5838 Acc: 0.6355


100%|██████████| 320/320 [00:17<00:00, 18.22it/s]


val Loss: 6.7761 Acc: 0.6242
Epoch 17/100
----------


100%|██████████| 1280/1280 [03:28<00:00,  6.15it/s]


train Loss: 6.8009 Acc: 0.6303


100%|██████████| 320/320 [00:17<00:00, 18.29it/s]


val Loss: 7.0065 Acc: 0.6273
Epoch 18/100
----------


100%|██████████| 1280/1280 [03:28<00:00,  6.14it/s]


train Loss: 6.6811 Acc: 0.6406


100%|██████████| 320/320 [00:17<00:00, 17.96it/s]


val Loss: 6.7293 Acc: 0.6281
Epoch 19/100
----------


100%|██████████| 1280/1280 [03:28<00:00,  6.15it/s]


train Loss: 6.6862 Acc: 0.6377


100%|██████████| 320/320 [00:17<00:00, 18.08it/s]


val Loss: 6.8244 Acc: 0.6273
Epoch 20/100
----------


100%|██████████| 1280/1280 [03:28<00:00,  6.14it/s]


train Loss: 6.5947 Acc: 0.6379


100%|██████████| 320/320 [00:17<00:00, 18.42it/s]


val Loss: 6.8471 Acc: 0.6297
Epoch 21/100
----------


100%|██████████| 1280/1280 [03:28<00:00,  6.15it/s]


train Loss: 6.6739 Acc: 0.6410


100%|██████████| 320/320 [00:17<00:00, 18.33it/s]


val Loss: 6.7127 Acc: 0.6328
Epoch 22/100
----------


100%|██████████| 1280/1280 [03:28<00:00,  6.15it/s]


train Loss: 6.6519 Acc: 0.6398


100%|██████████| 320/320 [00:17<00:00, 18.39it/s]


val Loss: 6.8315 Acc: 0.6227
Epoch 23/100
----------


100%|██████████| 1280/1280 [03:27<00:00,  6.16it/s]


train Loss: 6.5976 Acc: 0.6420


100%|██████████| 320/320 [00:17<00:00, 18.05it/s]


val Loss: 6.7003 Acc: 0.6289
Epoch 24/100
----------


100%|██████████| 1280/1280 [03:28<00:00,  6.15it/s]


train Loss: 6.4518 Acc: 0.6408


100%|██████████| 320/320 [00:17<00:00, 18.34it/s]


val Loss: 6.8581 Acc: 0.6320
Epoch 25/100
----------


100%|██████████| 1280/1280 [03:28<00:00,  6.15it/s]


train Loss: 6.4163 Acc: 0.6521


100%|██████████| 320/320 [00:17<00:00, 18.42it/s]


val Loss: 6.5984 Acc: 0.6281
Epoch 26/100
----------


100%|██████████| 1280/1280 [03:28<00:00,  6.15it/s]


train Loss: 6.5647 Acc: 0.6387


100%|██████████| 320/320 [00:17<00:00, 18.50it/s]


val Loss: 6.8425 Acc: 0.6320
Epoch 27/100
----------


100%|██████████| 1280/1280 [03:28<00:00,  6.15it/s]


train Loss: 6.4770 Acc: 0.6486


100%|██████████| 320/320 [00:17<00:00, 18.27it/s]


val Loss: 6.8438 Acc: 0.6281
Epoch 28/100
----------


100%|██████████| 1280/1280 [03:28<00:00,  6.15it/s]


train Loss: 6.5369 Acc: 0.6459


100%|██████████| 320/320 [00:17<00:00, 18.38it/s]


val Loss: 6.6726 Acc: 0.6305
Epoch 29/100
----------


100%|██████████| 1280/1280 [03:28<00:00,  6.15it/s]


train Loss: 6.5788 Acc: 0.6441


100%|██████████| 320/320 [00:17<00:00, 18.48it/s]


val Loss: 6.4016 Acc: 0.6289
Epoch 30/100
----------


100%|██████████| 1280/1280 [03:28<00:00,  6.15it/s]


train Loss: 6.5389 Acc: 0.6484


100%|██████████| 320/320 [00:17<00:00, 18.29it/s]


val Loss: 6.6205 Acc: 0.6313
Epoch 31/100
----------


100%|██████████| 1280/1280 [03:28<00:00,  6.14it/s]


train Loss: 6.5696 Acc: 0.6354


100%|██████████| 320/320 [00:17<00:00, 18.29it/s]


val Loss: 6.5542 Acc: 0.6328
Epoch 32/100
----------


100%|██████████| 1280/1280 [03:28<00:00,  6.15it/s]


train Loss: 6.5763 Acc: 0.6457


100%|██████████| 320/320 [00:17<00:00, 18.28it/s]


val Loss: 6.6996 Acc: 0.6281
Epoch 33/100
----------


100%|██████████| 1280/1280 [03:28<00:00,  6.15it/s]


train Loss: 6.6979 Acc: 0.6467


100%|██████████| 320/320 [00:17<00:00, 18.42it/s]


val Loss: 6.5686 Acc: 0.6289
Epoch 34/100
----------


100%|██████████| 1280/1280 [03:27<00:00,  6.15it/s]


train Loss: 6.6256 Acc: 0.6438


100%|██████████| 320/320 [00:17<00:00, 18.39it/s]


val Loss: 6.7116 Acc: 0.6297
Epoch 35/100
----------


100%|██████████| 1280/1280 [03:27<00:00,  6.16it/s]


train Loss: 6.6357 Acc: 0.6391


100%|██████████| 320/320 [00:17<00:00, 18.41it/s]


val Loss: 6.7847 Acc: 0.6289
Epoch 36/100
----------


100%|██████████| 1280/1280 [03:27<00:00,  6.15it/s]


train Loss: 6.6759 Acc: 0.6451


100%|██████████| 320/320 [00:17<00:00, 18.34it/s]


val Loss: 6.6935 Acc: 0.6258
Epoch 37/100
----------


100%|██████████| 1280/1280 [03:28<00:00,  6.15it/s]


train Loss: 6.6379 Acc: 0.6404


100%|██████████| 320/320 [00:17<00:00, 18.45it/s]


val Loss: 6.5484 Acc: 0.6297
Epoch 38/100
----------


100%|██████████| 1280/1280 [03:28<00:00,  6.14it/s]


train Loss: 6.4745 Acc: 0.6514


100%|██████████| 320/320 [00:17<00:00, 18.22it/s]


val Loss: 6.3060 Acc: 0.6320
Epoch 39/100
----------


100%|██████████| 1280/1280 [03:28<00:00,  6.14it/s]


train Loss: 6.4711 Acc: 0.6428


100%|██████████| 320/320 [00:17<00:00, 18.34it/s]


val Loss: 6.5717 Acc: 0.6297
Epoch 40/100
----------


100%|██████████| 1280/1280 [03:28<00:00,  6.15it/s]


train Loss: 6.5789 Acc: 0.6393


100%|██████████| 320/320 [00:17<00:00, 18.16it/s]


val Loss: 6.7533 Acc: 0.6297
Epoch 41/100
----------


100%|██████████| 1280/1280 [03:28<00:00,  6.14it/s]


train Loss: 6.4449 Acc: 0.6438


100%|██████████| 320/320 [00:17<00:00, 18.16it/s]


val Loss: 6.5984 Acc: 0.6313
Epoch 42/100
----------


100%|██████████| 1280/1280 [03:28<00:00,  6.15it/s]


train Loss: 6.5363 Acc: 0.6404


100%|██████████| 320/320 [00:17<00:00, 18.10it/s]


val Loss: 6.6519 Acc: 0.6289
Epoch 43/100
----------


100%|██████████| 1280/1280 [03:28<00:00,  6.15it/s]


train Loss: 6.6021 Acc: 0.6436


100%|██████████| 320/320 [00:17<00:00, 17.88it/s]


val Loss: 6.4335 Acc: 0.6242
Epoch 44/100
----------


100%|██████████| 1280/1280 [03:28<00:00,  6.15it/s]


train Loss: 6.6017 Acc: 0.6396


100%|██████████| 320/320 [00:17<00:00, 18.05it/s]


val Loss: 6.5389 Acc: 0.6313
Epoch 45/100
----------


100%|██████████| 1280/1280 [03:28<00:00,  6.15it/s]


train Loss: 6.6303 Acc: 0.6430


100%|██████████| 320/320 [00:17<00:00, 18.24it/s]


val Loss: 6.7138 Acc: 0.6297
Epoch 46/100
----------


100%|██████████| 1280/1280 [03:28<00:00,  6.15it/s]


train Loss: 6.4598 Acc: 0.6463


100%|██████████| 320/320 [00:17<00:00, 18.41it/s]


val Loss: 6.7223 Acc: 0.6289
Epoch 47/100
----------


100%|██████████| 1280/1280 [03:28<00:00,  6.15it/s]


train Loss: 6.5371 Acc: 0.6488


100%|██████████| 320/320 [00:17<00:00, 18.35it/s]


val Loss: 6.7087 Acc: 0.6328
Epoch 48/100
----------


 25%|██▍       | 316/1280 [00:51<02:35,  6.19it/s]

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
import random
import logging
from typing import List, Dict, Tuple

import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler

from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.metrics import confusion_matrix, classification_report, mean_squared_error
from tqdm import tqdm
from PIL import Image
from torchvision import transforms
from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s: %(message)s')
logger = logging.getLogger(__name__)

class ResNetBlock(nn.Module):
    """Enhanced ResNet block with optional squeeze-and-excitation"""
    def __init__(self, in_planes, planes, stride=1, use_se=True):
        super(ResNetBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes)
            )

        # Squeeze-and-Excitation
        self.use_se = use_se
        if use_se:
            self.se = nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Conv2d(planes, planes // 16, kernel_size=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(planes // 16, planes, kernel_size=1),
                nn.Sigmoid()
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        if self.use_se:
            se_weights = self.se(out)
            out = out * se_weights

        out += self.shortcut(x)
        out = F.relu(out)
        return out

class BrainMRIModel(nn.Module):
    def __init__(self, in_channels=3, num_classes=4, block_type=ResNetBlock):
        super(BrainMRIModel, self).__init__()
        self.in_planes = 64

        # Initial layers
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # ResNet layers with residual blocks
        self.layer1 = self._make_layer(block_type, 64, 2, stride=1)
        self.layer2 = self._make_layer(block_type, 128, 2, stride=2)
        self.layer3 = self._make_layer(block_type, 256, 2, stride=2)
        self.layer4 = self._make_layer(block_type, 512, 2, stride=2)

        # Global pooling and feature extraction
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.feature_extractor = nn.Sequential(
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5)
        )

        # Multi-task heads
        self.classification_head = nn.Linear(256, num_classes)
        self.regression_heads = nn.ModuleDict({
            'mmse': nn.Linear(256, 1),
            'cdr': nn.Linear(256, 1)
        })

    def _make_layer(self, block_type, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for curr_stride in strides:
            layers.append(block_type(self.in_planes, planes, curr_stride))
            self.in_planes = planes
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.maxpool(F.relu(self.bn1(self.conv1(x))))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        features = self.feature_extractor(x)

        return {
            'classification': self.classification_head(features),
            'mmse_score': self.regression_heads['mmse'](features),
            'cdr_score': self.regression_heads['cdr'](features),
            'features': features
        }

class BrainMRIDataset(Dataset):
    def __init__(self, data_dir: str, classes: List[str] = None, transform=None):
        """Enhanced dataset with more robust loading and metadata handling"""
        self.data_dir = data_dir
        self.classes = classes or ['VeryMildDemented', 'MildDemented', 'NonDemented', 'ModerateDemented']
        self.transform = transform or self._get_default_transforms()

        self.images = []
        self.labels = []
        self.metadata = []  # Optional: store additional metadata

        self._load_dataset()

    def _get_default_transforms(self):
        return transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.2, contrast=0.2),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def _load_dataset(self):
        for class_idx, class_name in enumerate(self.classes):
            class_path = os.path.join(self.data_dir, class_name)
            for image_file in os.listdir(class_path):
                full_path = os.path.join(class_path, image_file)
                if os.path.isfile(full_path):
                    self.images.append(full_path)
                    self.labels.append(class_idx)
                    # Optional: Add metadata logic here

        logger.info(f"Loaded {len(self.images)} images across {len(self.classes)} classes")

    def __getitem__(self, idx):
        image_path = self.images[idx]
        label = self.labels[idx]

        image = Image.open(image_path).convert('RGB')
        image = self.transform(image)

        # Simulated scores with more sophisticated generation
        noise = np.random.normal(0, 0.5)
        mmse_score = torch.tensor([max(0, min(30, 24 + noise * (label + 1)))], dtype=torch.float32)
        cdr_score = torch.tensor([max(0, min(3, 0.5 * label + noise))], dtype=torch.float32)

        return {
            'image': image,
            'label': torch.tensor(label, dtype=torch.long),
            'mmse_score': mmse_score,
            'cdr_score': cdr_score
        }

    def __len__(self):
        return len(self.images)

def train_model(
    model: nn.Module,
    dataloaders: Dict[str, DataLoader],
    criterion_dict: Dict[str, nn.Module],
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler._LRScheduler,
    num_epochs: int = 50,
    device: str = 'cuda'
) -> Tuple[nn.Module, List[float], List[float]]:
    """Enhanced training function with early stopping and advanced logging"""
    model = model.to(device)
    best_val_loss = float('inf')
    patience = 10
    early_stopping_counter = 0

    train_losses, val_losses = [], []
    train_accuracies, val_accuracies = [], []

    for epoch in range(num_epochs):
        for phase in ['train', 'val']:
            model.train() if phase == 'train' else model.eval()
            running_loss, correct, total = 0.0, 0, 0

            with torch.set_grad_enabled(phase == 'train'):
                for batch in tqdm(dataloaders[phase], desc=f"{phase.capitalize()} Epoch {epoch+1}"):
                    images = batch['image'].to(device)
                    labels = batch['label'].to(device)
                    mmse_scores = batch['mmse_score'].to(device)
                    cdr_scores = batch['cdr_score'].to(device)

                    optimizer.zero_grad()
                    outputs = model(images)

                    class_loss = criterion_dict['classification'](outputs['classification'], labels)
                    mmse_loss = criterion_dict['regression'](outputs['mmse_score'], mmse_scores)
                    cdr_loss = criterion_dict['regression'](outputs['cdr_score'], cdr_scores)

                    # Weighted loss
                    loss = class_loss + 0.3 * mmse_loss + 0.3 * cdr_loss

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                    running_loss += loss.item()
                    _, predicted = torch.max(outputs['classification'], 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()

            # Compute epoch metrics
            epoch_loss = running_loss / len(dataloaders[phase])
            epoch_acc = correct / total

            if phase == 'train':
                train_losses.append(epoch_loss)
                train_accuracies.append(epoch_acc)
                scheduler.step()
            else:
                val_losses.append(epoch_loss)
                val_accuracies.append(epoch_acc)

                # Early stopping
                if epoch_loss < best_val_loss:
                    best_val_loss = epoch_loss
                    early_stopping_counter = 0
                    torch.save(model.state_dict(), 'best_model.pth')
                else:
                    early_stopping_counter += 1

                if early_stopping_counter >= patience:
                    logger.info(f"Early stopping triggered after {epoch+1} epochs")
                    break

        logger.info(f"Epoch {epoch+1}: Train Loss {train_losses[-1]:.4f}, Val Loss {val_losses[-1]:.4f}")
        logger.info(f"Train Accuracy: {train_accuracies[-1]:.4f}, Val Accuracy: {val_accuracies[-1]:.4f}")

    return model, train_accuracies, val_accuracies

def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Dataset and training configuration
    data_dir = "/path/to/Alzheimer_MRI_4_classes_dataset"

    # K-Fold Cross Validation
    dataset = BrainMRIDataset(data_dir)
    kfold = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

    for fold, (train_indices, val_indices) in enumerate(kfold.split(dataset.images, dataset.labels), 1):
        logger.info(f"Training Fold {fold}")

        train_dataset = torch.utils.data.Subset(dataset, train_indices)
        val_dataset = torch.utils.data.Subset(dataset, val_indices)

        train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4)
        val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=4)

        dataloaders = {'train': train_loader, 'val': val_loader}

        model = BrainMRIModel(in_channels=3, num_classes=4)

        criterion_dict = {
            'classification': nn.CrossEntropyLoss(),
            'regression': nn.MSELoss()
        }

        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-5)
        scheduler = CosineAnnealingLR(optimizer, T_max=50)

        model, train_acc, val_acc = train_model(
            model=model,
            dataloaders=dataloaders,
            criterion_dict=criterion_dict,
            optimizer=optimizer,
            scheduler=scheduler,
            num_epochs=50,
            device=device
        )

if __name__ == "__main__":
    main()

FileNotFoundError: [Errno 2] No such file or directory: '/path/to/Alzheimer_MRI_4_classes_dataset/VeryMildDemented'