# CS184A Final Project: Alzheimer's Detection using CNN

Group 37: Jason Dunn, Rohan Jayasekara, Ben Boben

Dataset: https://www.kaggle.com/datasets/ninadaithal/imagesoasis/

### Dependencies

In [37]:
import os
import pandas as pd
from tqdm import tqdm
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

### Load in data in small amount of scans from OASIS dataset for evaluation in <1 min

In [None]:
current_folder = os.getcwd()
data_path = os.path.join(current_folder, 'src', 'archive')
print(data_path)

c:\Users\rohan\Desktop\184a\AD-Detection\archive3


Each scan consists of 61 JPG images

We simplified our problem to a binary classification so there are only two labels: Non Demented and Demented.

In [39]:
# Create your DataFrame
images = []
labels = []
for subfolder in tqdm(os.listdir(data_path)):
    subfolder_path = os.path.join(data_path, subfolder)
    label = 0
    for folder in os.listdir(subfolder_path):
        subfolder_path2 = os.path.join(subfolder_path, folder)
        print(subfolder_path2)
        scan = []
        img_num = 0
        for image_filename in os.listdir(subfolder_path2):
            image_path = os.path.join(subfolder_path2, image_filename)
            scan.append(image_path)
            img_num += 1
            if img_num > 60:
                images.append(scan)
                if label == 2: # modded for binary classification
                    labels.append(0)
                else:
                    labels.append(1)
                img_num = 0
                scan = []
        label += 1

df = pd.DataFrame({'image': images, 'label': labels})
df

100%|██████████| 1/1 [00:00<00:00, 20.00it/s]

c:\Users\rohan\Desktop\184a\AD-Detection\archive3\Data\Mild-Dementia
c:\Users\rohan\Desktop\184a\AD-Detection\archive3\Data\Moderate-Dementia
c:\Users\rohan\Desktop\184a\AD-Detection\archive3\Data\Non-Demented
c:\Users\rohan\Desktop\184a\AD-Detection\archive3\Data\Very-mild-Dementia





Unnamed: 0,image,label
0,[c:\Users\rohan\Desktop\184a\AD-Detection\arch...,1
1,[c:\Users\rohan\Desktop\184a\AD-Detection\arch...,1
2,[c:\Users\rohan\Desktop\184a\AD-Detection\arch...,1
3,[c:\Users\rohan\Desktop\184a\AD-Detection\arch...,1
4,[c:\Users\rohan\Desktop\184a\AD-Detection\arch...,1
...,...,...
133,[c:\Users\rohan\Desktop\184a\AD-Detection\arch...,1
134,[c:\Users\rohan\Desktop\184a\AD-Detection\arch...,1
135,[c:\Users\rohan\Desktop\184a\AD-Detection\arch...,1
136,[c:\Users\rohan\Desktop\184a\AD-Detection\arch...,1


### Custom Dataset applying resizing and normalization

In [40]:
class ADDataset(Dataset):
    def __init__(self, dataframe, transform=None):

        self.dataframe = dataframe
        self.transform = transform

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_paths = self.dataframe.iloc[idx, 0]  # List of 61 image paths
        label = self.dataframe.iloc[idx, 1]

        # Load and process images
        images = []
        for img_path in img_paths:
            image = Image.open(img_path).convert("RGB")
            if self.transform:
                image = self.transform(image)
            images.append(image)

        # Stack the images into a tensor of shape
        images = torch.stack(images)
        label = torch.tensor(label, dtype=torch.long)

        return images, label

# Define transformations for each image
transform = transforms.Compose([
    transforms.Resize((176, 176)),  # Resize all images to 176x176
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize with ImageNet stats
])

In [41]:
test_dataset = ADDataset(dataframe=df, transform=transform)
batch_size = 8
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

### CNN Model

Utilizes 4 feature extraction blocks

In [42]:
class FeatureExtractionBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(FeatureExtractionBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.batch_norm = nn.BatchNorm2d(out_channels)
        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.batch_norm(x)
        x = self.max_pool(x)
        return x


class ADCNN(nn.Module):
    def __init__(self, dropout_rate=0.25):
        super(ADCNN, self).__init__()

        self.feature_extractor = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            FeatureExtractionBlock(16, 32),
            FeatureExtractionBlock(32, 64),
            FeatureExtractionBlock(64, 128),
            FeatureExtractionBlock(128, 256),
        )

        self.dropout1 = nn.Dropout(p=dropout_rate)
        self.dropout2 = nn.Dropout(p=dropout_rate)
        
        # Fully connected layers
        self.fc1 = nn.Linear(256 * 5 * 5, 512)
        self.fc2 = nn.Linear(512, 128)
        self.fc3 = nn.Linear(128, 64) 
        self.fc4 = nn.Linear(64, 2)
        
    def forward(self, x):
        # Input: (batch_size, 61, 3, 176, 176)
        batch_size, seq_len, channels, height, width = x.shape

        # Flatten temporal dimension
        x = x.view(batch_size * seq_len, channels, height, width)  # Shape: (batch_size * 61, 3, 176, 176)
        x = self.feature_extractor(x)

        # Flatten spatial dimensions
        x = torch.flatten(x, start_dim=1)  # Shape: (batch_size * 61, 256 * 5 * 5)
        x = x.view(batch_size, seq_len, -1)

        x = torch.mean(x, dim=1)  # Shape: (batch_size, 256 * 5 * 5)

        # Fully connected layers
        x = self.dropout1(x)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = self.fc4(x)  # Final output: (batch_size, 2)

        return x

In [43]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
model = ADCNN().to(device)

cuda:0


### Here is our training function for reference but we will not run it for this demo

In [44]:
def train_model(model, train_loader, num_epochs=20, learning_rate=0.001, device='cuda'):
    model.to(device)
    
    optimizer = optim.RMSprop(model.parameters(), lr=learning_rate)
    criterion = torch.nn.CrossEntropyLoss()

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        
        for batch_idx, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            # Calculate accuracy
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        # Print stats for the epoch
        epoch_loss = running_loss / len(train_loader)
        epoch_acc = 100 * correct / total
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%")

        if epoch % 5 == 0:
            torch.save(model.state_dict(), r'C:\Users\rohan\Desktop\184a\AD-Detection\trained.pth')
     
    print('Training Finished!')

### Load the weights from our pre-trained model

In [None]:
current_folder = os.getcwd()
model_path = os.path.join(current_folder, 'src','3d_cnn_model.pth')
print(model_path)
model.load_state_dict(torch.load(model_path))

  model.load_state_dict(torch.load(model_path))


c:\Users\rohan\Desktop\184a\AD-Detection\3d_cnn_model.pth


<All keys matched successfully>

### Evaluate using Confusion Matrix, Precision, Recall, and F1 Score

Important Note: There is data leakage in this demo due to the fact that we used sklearn train_test_split to train our model with the full dataset, and we were unable to replicate this when creating this small dataset. Please see src/3D-AD-Detection-CNN.ipynb for our complete process.

In [46]:
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score
import itertools
import torch

def evaluate_model(model, test_loader, device='cuda'):
    y_pred_list = []
    y_target_list = []

    model.eval()
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            # Get model predictions
            outputs = model(inputs)
            _, y_pred = torch.max(outputs, 1)  # Predicted labels
            
            y_pred_list.append(y_pred.cpu().numpy())  # Move to CPU for metrics
            y_target_list.append(labels.cpu().numpy())

    # Flatten the predictions and targets
    y_pred_list = list(itertools.chain.from_iterable(y_pred_list))
    y_target_list = list(itertools.chain.from_iterable(y_target_list))

    # Confusion Matrix
    conf_matrix = confusion_matrix(y_target_list, y_pred_list)
    print("Confusion Matrix of the Test Set")
    print("-----------")
    print(conf_matrix)

    # Calculate metrics
    precision = precision_score(y_target_list, y_pred_list, average='weighted')
    recall = recall_score(y_target_list, y_pred_list, average='weighted')
    f1 = f1_score(y_target_list, y_pred_list, average='weighted')

    print(f"Precision of the Model :\t{precision:.4f}")
    print(f"Recall of the Model    :\t{recall:.4f}")
    print(f"F1 Score of the Model  :\t{f1:.4f}")

evaluate_model(model, test_dataloader, device='cuda')

Confusion Matrix of the Test Set
-----------
[[79  0]
 [ 0 59]]
Precision of the Model :	1.0000
Recall of the Model    :	1.0000
F1 Score of the Model  :	1.0000
