In [8]:
# libraries
import numpy as np
import os
import cv2
import time
import random
import shutil
import glob

import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split

import pydicom
from pydicom import dcmread

import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader


In [28]:
# device use CUDA
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

# hyperparameters
random_seed = 1
learning_rate = 0.001
num_epochs = 10
batch_size = 128

# architecture
num_classes = 2

cuda:0


In [None]:
# Extract all .dcm files because this dataset is too big and I do not want to do all that sorting by hand, this and that

# Define the parent directory where dataset are located
parent_mri_dir = "C:/Users/j/Downloads/MRI_PET_2D_Dataset/ss"  # Change this to your actual dataset directory

# Define source directories based on the parent directory
source_mri_dirs = [os.path.join(parent_mri_dir, "MRI_Norm"), os.path.join(parent_mri_dir, "MRI_AD"),
                   os.path.join(parent_mri_dir, "PET_Norm"), os.path.join(parent_mri_dir, "PET_AD")]

# Define the destination directory
destination_mri_root = os.path.join(parent_mri_dir, "ss")

# Ensure the destination root directory exists
os.makedirs(destination_mri_root, exist_ok=True)

for source in source_mri_dirs:
    label = os.path.basename(source)  # Extracts "Norm_MRI" or "AD_MRI"
    destination = os.path.join(destination_mri_root, label)
    os.makedirs(destination, exist_ok=True)  # Create labeled destination folder

    # Walk through all subdirectories
    for root, _, files in os.walk(source):
        for file in files:
            if file.endswith(".dcm"):  # Check for .dcm files
                src_path = os.path.join(root, file)
                dst_path = os.path.join(destination, f"{os.path.splitext(file)[0]}.png")
                
                dicom_data = dcmread(src_path)
                pixel_array = dicom_data.pixel_array

                if len(pixel_array.shape) == 2:

                    if os.path.exists(dst_path):
                        print(f"Skipping {file}, already extracted.")
                        continue

                    plt.imshow(pixel_array, cmap=plt.cm.bone)
    
                    # Save the 2D image as a .png file in the new folder
                    # output_path = os.path.join(dst_path, f"{os.path.splitext(file)[0]}.png")
                    plt.imsave(dst_path, pixel_array, cmap='bone')  # Save using imsave
                    plt.clf()  # Clear the plot to avoid overlapping images in the next iteration

                    shutil.copy(src_path, dst_path)
                    print(f"Extracted: {file} to {dst_path} as .png")

print("Extraction complete. Files are sorted into labeled folders for images.")

In [14]:
base_dir = "C:/Users/J/Downloads/MRI_PET_2D_Dataset"
categories = {
    "MRI_Norm": 0, "MRI_AD": 1,
    "PET_Norm": 0, "PET_AD": 1
}

data = []
for category, label in categories.items():
    files = glob.glob(os.path.join(base_dir, category, "*.dcm"))
    data.extend([(f, label) for f in files])

np.random.shuffle(data)

# Split into 80% train, 20% test
train_data, test_data = train_test_split(data, test_size=0.2, random_state=42)

# Further split train into 90% test, 10% validation
test_data, val_data = train_test_split(test_data, test_size=0.1, random_state=42)

print(f"Train: {len(train_data)}, Validation: {len(val_data)}, Test: {len(test_data)}")


Train: 167275, Validation: 4182, Test: 37637


In [16]:
class BrainScanDataset:
    def __init__(self, data_list, transform=None):
        self.data_list = data_list
        self.transform = transform

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        file_path, label = self.data_list[idx]
        
        # Load DICOM image
        dicom_image = pydicom.dcmread(file_path).pixel_array

        # Ensure native byte order (NumPy 2.0 compatible)
        if dicom_image.dtype.byteorder not in ('=', '|'):
            dicom_image = dicom_image.view(dicom_image.dtype.newbyteorder('='))

        # Convert to float32 and normalize
        dicom_image = dicom_image.astype(np.float32)
        dicom_image = (dicom_image - dicom_image.min()) / (dicom_image.max() - dicom_image.min() + 1e-8)

        # Convert to PyTorch tensor
        image = torch.from_numpy(dicom_image).unsqueeze(0)  # Shape: (1, H, W)

        # Apply transformations
        if self.transform:
            image = self.transform(image)

        return image, torch.tensor(label, dtype=torch.long)

In [17]:
transform = transforms.Compose([
    transforms.Resize((192, 192), antialias=True),
    transforms.GaussianBlur(kernel_size=3, sigma=1.5),
    transforms.Normalize(mean=[0.5], std=[0.5]) 
])

train_dataset = BrainScanDataset(train_data, transform=transform)
val_dataset = BrainScanDataset(val_data, transform=transform)
test_dataset = BrainScanDataset(test_data, transform=transform)


train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)    

In [18]:
for images, labels in train_loader:
    print('Image batch dimensions:', images.shape)
    print('Image label dimensions:', labels.shape)
    break


Image batch dimensions: torch.Size([128, 1, 192, 192])
Image label dimensions: torch.Size([128])


In [19]:
for batch_idx, (data, targets) in enumerate(train_loader):
    print(f"Batch {batch_idx}:")
    print("  Data shape:", data.shape)
    print("  Targets shape:", targets.shape)
    print("  First data sample:\n", data[0])
    print("  First target sample:\n", targets[0])

    if batch_idx == 9:
        break
    

Batch 0:
  Data shape: torch.Size([128, 1, 192, 192])
  Targets shape: torch.Size([128])
  First data sample:
 tensor([[[1.5378e-05, 1.5378e-05, 1.5378e-05,  ..., 1.5378e-05,
          1.5378e-05, 1.5378e-05],
         [1.5378e-05, 1.5378e-05, 1.5378e-05,  ..., 1.5378e-05,
          1.5378e-05, 1.5378e-05],
         [1.5378e-05, 1.5378e-05, 1.5378e-05,  ..., 1.5378e-05,
          1.5378e-05, 1.5378e-05],
         ...,
         [1.5378e-05, 1.5378e-05, 1.5378e-05,  ..., 1.5378e-05,
          1.5378e-05, 1.5378e-05],
         [1.5378e-05, 1.5378e-05, 1.5378e-05,  ..., 1.5378e-05,
          1.5378e-05, 1.5378e-05],
         [1.5378e-05, 1.5378e-05, 1.5378e-05,  ..., 1.5378e-05,
          1.5378e-05, 1.5378e-05]]])
  First target sample:
 tensor(0)
Batch 1:
  Data shape: torch.Size([128, 1, 192, 192])
  Targets shape: torch.Size([128])
  First data sample:
 tensor([[[-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1.,

In [21]:
# CNN model
class CNN_model(torch.nn.Module):
    def __init__(self, num_classes):
        super(CNN_model, self).__init__()

        self.conv_1 = torch.nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=0)

        self.maxpool_1 = torch.nn.MaxPool2d(2, 2)


        self.conv_2 = torch.nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=0) 

        self.maxpool_2 = torch.nn.MaxPool2d(2, 2)


        self.conv_3 = torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=0) 

        self.maxpool_3 = torch.nn.MaxPool2d(2, 2)


        self.conv_4 = torch.nn.Conv2d(in_channels=64, out_channels=128, kernel_size=5, stride=1, padding=0) 

        self.maxpool_4 = torch.nn.MaxPool2d(2, 2)


        self.fc1 = torch.nn.Linear(128 * 8 * 8, 1024)

        self.fc2 = torch.nn.Linear(1024, num_classes)

        

    def forward(self, x):
        out = F.relu(self.conv_1(x))
        out = self.maxpool_1(out)
        # output: [(batch_size)128, 16, 188, 188]

        out = F.relu(self.conv_2(out))
        out = self.maxpool_2(out)
        # Output: [(batch_size)128, 32, 184, 184]

        out = F.relu(self.conv_3(out))
        out = self.maxpool_3(out)
        # Output: [(batch_size)128, 64, 180, 180]

        out = F.relu(self.conv_4(out))
        out = self.maxpool_4(out)
        # Output: [(batch_size)128, 128, 176, 176]

        out = torch.flatten(out, start_dim=1)
        out = F.relu(self.fc1(out))
        out = self.fc2(out)

        return out

In [29]:
# training CNN model
torch.manual_seed(random_seed)

model = CNN_model(num_classes=num_classes).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

criterion = torch.nn.CrossEntropyLoss()

start_time = time.time()

print(device)

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for batch_idx, (features, targets) in enumerate(train_loader):
        features, targets = features.to(device), targets.to(device)

        optimizer.zero_grad()
        
        outputs = model(features)

        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if batch_idx % 10 == 0:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / 100))
            running_loss = 0.0

print('CNN Training Finished')

cuda:0
[1,     1] loss: 0.007
[1,    11] loss: 0.068
[1,    21] loss: 0.065
[1,    31] loss: 0.062
[1,    41] loss: 0.064
[1,    51] loss: 0.060
[1,    61] loss: 0.058
[1,    71] loss: 0.057
[1,    81] loss: 0.056
[1,    91] loss: 0.069
[1,   101] loss: 0.062
[1,   111] loss: 0.058
[1,   121] loss: 0.057
[1,   131] loss: 0.055
[1,   141] loss: 0.056
[1,   151] loss: 0.058
[1,   161] loss: 0.054
[1,   171] loss: 0.054
[1,   181] loss: 0.056
[1,   191] loss: 0.054
[1,   201] loss: 0.054
[1,   211] loss: 0.054
[1,   221] loss: 0.054
[1,   231] loss: 0.049
[1,   241] loss: 0.052
[1,   251] loss: 0.051
[1,   261] loss: 0.052
[1,   271] loss: 0.056
[1,   281] loss: 0.052
[1,   291] loss: 0.051
[1,   301] loss: 0.054
[1,   311] loss: 0.052
[1,   321] loss: 0.051
[1,   331] loss: 0.055
[1,   341] loss: 0.049
[1,   351] loss: 0.049
[1,   361] loss: 0.053
[1,   371] loss: 0.053
[1,   381] loss: 0.052
[1,   391] loss: 0.052
[1,   401] loss: 0.051
[1,   411] loss: 0.051
[1,   421] loss: 0.051
[1, 

In [30]:
def evaluate_model(model, val_loader, criterion, device):
    model.eval()  # Set model to evaluation mode
    val_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for features, targets in val_loader:
            features, targets = features.to(device), targets.to(device)
            outputs = model(features)
            loss = criterion(outputs, targets)
            val_loss += loss.item()

            # Accuracy Calculation
            _, predicted = torch.max(outputs, 1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

    avg_val_loss = val_loss / len(val_loader)
    val_accuracy = 100 * correct / total
    print(f'Validation Loss: {avg_val_loss:.4f}, Accuracy: {val_accuracy:.2f}%')

    return avg_val_loss, val_accuracy


In [31]:
torch.save(model.state_dict(), "cnn_model.pth")
print("Model saved successfully.")

Model saved successfully.


In [32]:
loaded_model = CNN_model(num_classes=num_classes).to(device)
loaded_model.load_state_dict(torch.load("cnn_model.pth"))
loaded_model.eval()  # Set model to evaluation mode

print("Model loaded successfully.")

Model loaded successfully.


In [33]:
# compute accuracy of model
def compute_accuracy(model, data_loader):
    correct=0
    total=0

    for batch_idx, (features, targets) in enumerate(data_loader):
        features, targets = features.to(device), targets.to(device)

        outputs = model(features)
        _, predicted = torch.max(outputs, 1)

        correct += (predicted == targets).sum().item()
        total += targets.size(0)

    acc = (correct/total) * 100
    
    return acc

In [34]:
# compute sensitivity (recall)
def compute_sensitivity(model, data_loader):
    truePositive=0
    falseNegative=0

    for batch_idx, (features, targets) in enumerate(data_loader):
        features, targets = features.to(device), targets.to(device)

        outputs = model(features)
        _, predicted = torch.max(outputs, 1)

        truePositive += ((predicted == 1) & (targets == 1)).sum().item()
        falseNegative += ((predicted == 0) & (targets == 1)).sum().item()

    sens = truePositive / (truePositive + falseNegative)

    return sens

In [35]:
# compute specificity 
def compute_specificity(model, data_loader):
    trueNegative=0
    falsePositive=0

    for batch_idx, (features, targets) in enumerate(data_loader):
        features, targets = features.to(device), targets.to(device)

        outputs = model(features)
        _, predicted = torch.max(outputs, 1)

        trueNegative += ((predicted == 0) & (targets == 0)).sum().item()
        falsePositive += ((predicted == 1) & (targets == 0)).sum().item()

    spec = trueNegative / (trueNegative + falsePositive)

    return spec

In [36]:
# compute f1-score
def compute_f1_score(model, data_loader):
    truePositive=0
    trueNegative=0
    falsePositive=0
    falseNegative=0

    for batch_idx, (features, targets) in enumerate(data_loader):
        features, targets = features.to(device), targets.to(device)

        outputs = model(features)
        _, predicted = torch.max(outputs, 1)

        truePositive += ((predicted == 1) & (targets == 1)).sum().item()
        trueNegative += ((predicted == 0) & (targets == 0)).sum().item()

        falsePositive += ((predicted == 1) & (targets == 0)).sum().item()
        falseNegative += ((predicted == 0) & (targets == 1)).sum().item()

        # precison = TP / (TP + FP)
        precision = truePositive / (truePositive + falseNegative)

        # sensitivity = TP / (TP + FN)
        sensitivity = truePositive / (truePositive + falseNegative)

    f1 = (2 * precision * sensitivity) / (precision + sensitivity)

    return f1

In [37]:
print('Test accuracy: %.2f%%' % (compute_accuracy(model, test_loader)))
print('Test sensitivity: %.2f%%' % (compute_sensitivity(model, test_loader)))
print('Test specificity: %.2f%%' % (compute_specificity(model, test_loader)))
print('Test f1-score: %.2f%%' % (compute_f1_score(model, test_loader)))

Test accuracy: 97.35%
Test sensitivity: 0.97%
Test specificity: 0.98%
Test f1-score: 0.97%
