In [1]:
import numpy as np
import os
import cv2
import time
import random
import shutil
import glob

import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split

import pydicom
from pydicom import dcmread

import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader


In [2]:
# device use CUDA
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

# hyperparameters
random_seed = 1
learning_rate = 0.001
num_epochs = 10
batch_size = 128

# architecture
num_classes = 2

cuda:0


Extracting .dcm Files

In [None]:
# Extract all .dcm files because this dataset is too big and I do not want to do all that sorting by hand

# Define the parent directory where dataset are located
parent_mri_dir = "C:/Users/j/Downloads/MRI_PET_2D_Dataset/"  # Change this to your actual dataset directory

# Define source directories based on the parent directory
source_mri_dirs = [os.path.join(parent_mri_dir, "MRI_Norm"), os.path.join(parent_mri_dir, "MRI_AD"),
                   os.path.join(parent_mri_dir, "PET_Norm"), os.path.join(parent_mri_dir, "PET_AD")]

# Define the destination directory
destination_mri_root = os.path.join(parent_mri_dir, "extracted")

# Ensure the destination root directory exists
os.makedirs(destination_mri_root, exist_ok=True)

for source in source_mri_dirs:
    label = os.path.basename(source)  # Extracts "Norm_MRI" or "AD_MRI"
    destination = os.path.join(destination_mri_root, label)
    os.makedirs(destination, exist_ok=True)  # Create labeled destination folder

    # Walk through all subdirectories
    for root, _, files in os.walk(source):
        for file in files:
            if file.endswith(".dcm"):  # Check for .dcm files
                src_path = os.path.join(root, file)
                dst_path = os.path.join(destination, f"{os.path.splitext(file)[0]}.png")
                
                dicom_data = dcmread(src_path)
                pixel_array = dicom_data.pixel_array

                if len(pixel_array.shape) == 2:

                    if os.path.exists(dst_path):
                        print(f"Skipping {file}, already extracted.")
                        continue

                    plt.imshow(pixel_array, cmap=plt.cm.bone)
    
                    # Save the 2D image as a .png file in the new folder
                    # output_path = os.path.join(dst_path, f"{os.path.splitext(file)[0]}.png")
                    plt.imsave(dst_path, pixel_array, cmap='bone')  # Save using imsave
                    plt.clf()  # Clear the plot to avoid overlapping images in the next iteration

                    shutil.copy(src_path, dst_path)
                    print(f"Extracted: {file} to {dst_path} as .png")

print("Extraction complete. Files are sorted into labeled folders for images.")

Data Splitting

In [3]:
base_dir = "C:/Users/J/Downloads/MRI_PET_2D_Dataset"
categories = {
    "MRI_Norm": 0, "MRI_AD": 1,
    "PET_Norm": 0, "PET_AD": 1
}

data = []
for category, label in categories.items():
    files = glob.glob(os.path.join(base_dir, category, "*.dcm"))
    data.extend([(f, label) for f in files])

np.random.shuffle(data)

# Split into 80% train, 20% test
train_data, test_data = train_test_split(data, test_size=0.2, random_state=42)

# Further split train into 90% test, 10% validation
train_data, val_data = train_test_split(train_data, test_size=0.1, random_state=42)

print(f"Train: {len(train_data)}, Validation: {len(val_data)}, Test: {len(test_data)}")


Train: 150547, Validation: 16728, Test: 41819


Creating Dataset

In [4]:
class BrainScanDataset:
    def __init__(self, data_list, transform=None):
        self.data_list = data_list
        self.transform = transform

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        file_path, label = self.data_list[idx]
        
        # Load DICOM image
        dicom_image = pydicom.dcmread(file_path).pixel_array

        # Ensure native byte order (NumPy 2.0 compatible)
        if dicom_image.dtype.byteorder not in ('=', '|'):
            dicom_image = dicom_image.view(dicom_image.dtype.newbyteorder('='))

        # Convert to float32 and normalize
        dicom_image = dicom_image.astype(np.float32)
        dicom_image = (dicom_image - dicom_image.min()) / (dicom_image.max() - dicom_image.min() + 1e-8)

        # Convert to PyTorch tensor
        image = torch.from_numpy(dicom_image).unsqueeze(0)  # Shape: (1, H, W)

        # Apply transformations
        if self.transform:
            image = self.transform(image)

        return image, torch.tensor(label, dtype=torch.long)

In [5]:
transform = transforms.Compose([
    transforms.Resize((192, 192), antialias=True),
    transforms.GaussianBlur(kernel_size=3, sigma=1.5),
    transforms.Normalize(mean=[0.5], std=[0.5]) 
])

train_dataset = BrainScanDataset(train_data, transform=transform)
val_dataset = BrainScanDataset(val_data, transform=transform)
test_dataset = BrainScanDataset(test_data, transform=transform)


train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)    

In [None]:
for batch_idx, (data, targets) in enumerate(train_loader):
    print(f"Batch {batch_idx}:")
    print("  Data shape:", data.shape)
    print("  Targets shape:", targets.shape)
    print("  First data sample:\n", data[0])
    print("  First target sample:\n", targets[0])

    if batch_idx == 2:
        break

Visualizing Neuroimages

In [None]:
# if the file is a directory, we need to use os.listdir() to get the .dcm files inside the directory
# pydicom.dcmread is trying to read a file and not a directory which our file is.

# for mri normal images
dicom_MRI_Norm_dir = "C:/Users/J/Downloads/MRI_PET_2D_Dataset/MRI_Norm"
dicom_MRI_Norm_files = [f for f in os.listdir(dicom_MRI_Norm_dir) if f.endswith(".dcm")]

random_int = random.randint(0, len(dicom_MRI_Norm_files))

if dicom_MRI_Norm_files:
    dicom_file = os.path.join(dicom_MRI_Norm_dir , dicom_MRI_Norm_files[random_int])
    dataset = pydicom.dcmread(dicom_file)

    pixel_array = dataset.pixel_array

    if len(pixel_array.shape) == 2:
        plt.imshow(dataset.pixel_array, cmap=plt.cm.bone)

    plt.title("Random MRI Normal DICOM Image")
    plt.axis("off")
    image_shape = dataset.pixel_array.shape
    print("Original MRI Normal Image Shape:", image_shape)
    plt.show()

else:
    print("No DICOM files found in the directory!")



# for mri ad images
dicom_MRI_AD_dir = "C:/Users/J/Downloads/MRI_PET_2D_Dataset/MRI_AD"
dicom_MRI_AD_files = [f for f in os.listdir(dicom_MRI_AD_dir) if f.endswith(".dcm")]

random_int = random.randint(0, len(dicom_MRI_AD_files))

if dicom_MRI_AD_files:
    dicom_file = os.path.join(dicom_MRI_AD_dir , dicom_MRI_AD_files[random_int])
    dataset = pydicom.dcmread(dicom_file)

    pixel_array = dataset.pixel_array

    if len(pixel_array.shape) == 2:
        plt.imshow(dataset.pixel_array, cmap=plt.cm.bone)

    plt.title("Random MRI AD DICOM Image")
    plt.axis("off")
    image_shape = dataset.pixel_array.shape
    print("Original MRI AD Image Shape:", image_shape)
    plt.show()
else:
    print("No DICOM files found in the directory!")

Model

In [8]:
# CNN model
class CNN_model(torch.nn.Module):
    def __init__(self, num_classes):
        super(CNN_model, self).__init__()

        self.conv_1 = torch.nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=0)

        self.maxpool_1 = torch.nn.MaxPool2d(2, 2)


        self.conv_2 = torch.nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=0) 

        self.maxpool_2 = torch.nn.MaxPool2d(2, 2)


        self.conv_3 = torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=0) 

        self.maxpool_3 = torch.nn.MaxPool2d(2, 2)


        self.conv_4 = torch.nn.Conv2d(in_channels=64, out_channels=128, kernel_size=5, stride=1, padding=0) 

        self.maxpool_4 = torch.nn.MaxPool2d(2, 2)


        self.fc1 = torch.nn.Linear(128 * 8 * 8, 1024)

        self.fc2 = torch.nn.Linear(1024, num_classes)

        

    def forward(self, x):
        out = F.relu(self.conv_1(x))
        # output: [(batch_size)128, 16, 188, 188]
        out = self.maxpool_1(out)
        # output: [94]

        out = F.relu(self.conv_2(out))
        # Output: [(batch_size)128, 32, 90, 90]
        out = self.maxpool_2(out)
        # Output: [45]

        out = F.relu(self.conv_3(out))
        # Output: [(batch_size)128, 64, 41, 41]
        out = self.maxpool_3(out)
        # Output: [20]

        out = F.relu(self.conv_4(out))
        # Output: [(batch_size)128, 128, 20, 20]
        out = self.maxpool_4(out)
        # Output: [8]

        out = torch.flatten(out, start_dim=1)
        out = F.relu(self.fc1(out))
        out = self.fc2(out)

        return out

Training Model

In [9]:
# training CNN model
torch.manual_seed(random_seed)

model = CNN_model(num_classes=num_classes).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

criterion = torch.nn.CrossEntropyLoss()

start_time = time.time()

print(device)
print(f'Training for [ {num_epochs} ] epochs')

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for batch_idx, (features, targets) in enumerate(train_loader):
        features, targets = features.to(device), targets.to(device)

        optimizer.zero_grad()
        
        outputs = model(features)

        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if batch_idx % 100 == 0:
            print('[Epoch %d, Batch %5d] loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / 100))
            running_loss = 0.0

print('CNN Training Finished')

cuda:0
Training for [ 10 ] epochs
[Epoch 1, Batch     1] loss: 0.007
[Epoch 1, Batch   101] loss: 0.626
[Epoch 1, Batch   201] loss: 0.590
[Epoch 1, Batch   301] loss: 0.550
[Epoch 1, Batch   401] loss: 0.541
[Epoch 1, Batch   501] loss: 0.518
[Epoch 1, Batch   601] loss: 0.503
[Epoch 1, Batch   701] loss: 0.474
[Epoch 1, Batch   801] loss: 0.453
[Epoch 1, Batch   901] loss: 0.426
[Epoch 1, Batch  1001] loss: 0.409
[Epoch 1, Batch  1101] loss: 0.389
[Epoch 2, Batch     1] loss: 0.004
[Epoch 2, Batch   101] loss: 0.351
[Epoch 2, Batch   201] loss: 0.327
[Epoch 2, Batch   301] loss: 0.316
[Epoch 2, Batch   401] loss: 0.296
[Epoch 2, Batch   501] loss: 0.280
[Epoch 2, Batch   601] loss: 0.267
[Epoch 2, Batch   701] loss: 0.257
[Epoch 2, Batch   801] loss: 0.240
[Epoch 2, Batch   901] loss: 0.243
[Epoch 2, Batch  1001] loss: 0.224
[Epoch 2, Batch  1101] loss: 0.208
[Epoch 3, Batch     1] loss: 0.001
[Epoch 3, Batch   101] loss: 0.172
[Epoch 3, Batch   201] loss: 0.170
[Epoch 3, Batch   301

Evaluation

In [10]:
def evaluate_model(model, val_loader, criterion, device):
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for features, targets in val_loader:
            features, targets = features.to(device), targets.to(device)
            outputs = model(features)
            loss = criterion(outputs, targets)
            val_loss += loss.item()

            # Accuracy Calculation
            _, predicted = torch.max(outputs, 1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

    avg_val_loss = val_loss / len(val_loader)
    val_accuracy = 100 * correct / total
    print(f'Validation Loss: {avg_val_loss:.4f}, Accuracy: {val_accuracy:.2f}%')

    return avg_val_loss, val_accuracy


In [11]:
torch.save(model.state_dict(), f"cnn_model_v1_{num_epochs}_epochs.pth")
print(f"cnn_model_v1_{num_epochs}_epochs.pth saved successfully.")

cnn_model_v1_10_epochs.pth saved successfully.


In [12]:
loaded_model = CNN_model(num_classes=num_classes).to(device)
loaded_model.load_state_dict(torch.load(f"cnn_model_v1_{num_epochs}_epochs.pth"))
loaded_model.eval()

print(f"cnn_model_v1_{num_epochs}_epochs.pth loaded successfully.")

cnn_model_v1_10_epochs.pth loaded successfully.


In [13]:
# compute accuracy of model
def compute_accuracy(model, data_loader):
    correct=0
    total=0

    for batch_idx, (features, targets) in enumerate(data_loader):
        features, targets = features.to(device), targets.to(device)

        outputs = model(features)
        _, predicted = torch.max(outputs, 1)

        correct += (predicted == targets).sum().item()
        total += targets.size(0)

    acc = (correct/total) * 100
    
    return acc

In [14]:
# compute sensitivity (recall)
def compute_sensitivity(model, data_loader):
    truePositive=0
    falseNegative=0

    for batch_idx, (features, targets) in enumerate(data_loader):
        features, targets = features.to(device), targets.to(device)

        outputs = model(features)
        _, predicted = torch.max(outputs, 1)

        truePositive += ((predicted == 1) & (targets == 1)).sum().item()
        falseNegative += ((predicted == 0) & (targets == 1)).sum().item()

    sens = truePositive / (truePositive + falseNegative) * 100

    return sens

In [15]:
# compute specificity 
def compute_specificity(model, data_loader):
    trueNegative=0
    falsePositive=0

    for batch_idx, (features, targets) in enumerate(data_loader):
        features, targets = features.to(device), targets.to(device)

        outputs = model(features)
        _, predicted = torch.max(outputs, 1)

        trueNegative += ((predicted == 0) & (targets == 0)).sum().item()
        falsePositive += ((predicted == 1) & (targets == 0)).sum().item()

    spec = trueNegative / (trueNegative + falsePositive) * 100

    return spec

In [16]:
# compute f1-score
def compute_f1_score(model, data_loader):
    truePositive=0
    trueNegative=0
    falsePositive=0
    falseNegative=0

    for batch_idx, (features, targets) in enumerate(data_loader):
        features, targets = features.to(device), targets.to(device)

        outputs = model(features)
        _, predicted = torch.max(outputs, 1)

        truePositive += ((predicted == 1) & (targets == 1)).sum().item()
        trueNegative += ((predicted == 0) & (targets == 0)).sum().item()

        falsePositive += ((predicted == 1) & (targets == 0)).sum().item()
        falseNegative += ((predicted == 0) & (targets == 1)).sum().item()

        precision = truePositive / (truePositive + falseNegative)

        sensitivity = truePositive / (truePositive + falseNegative)

    f1 = (2 * precision * sensitivity) / (precision + sensitivity) * 100

    return f1

In [17]:
print('Test accuracy: %.2f%%' % (compute_accuracy(model, test_loader)))
print('Test sensitivity: %.2f%%' % (compute_sensitivity(model, test_loader)))
print('Test specificity: %.2f%%' % (compute_specificity(model, test_loader)))
print('Test f1-score: %.2f%%' % (compute_f1_score(model, test_loader)))

Test accuracy: 96.42%
Test sensitivity: 94.41%
Test specificity: 97.71%
Test f1-score: 94.41%


In [None]:
from sklearn.metrics import confusion_matrix, classification_report

for batch_idx, (features, targets) in enumerate(test_loader):
    features, targets = features.to(device), targets.to(device)

    outputs = model(features)
    _, predicted = torch.max(outputs, 1)

cm = confusion_matrix(targets, predicted)
print(cm)

In [None]:
img_path = "//"
dcm_data = pydicom.dcmread(img_path)
pixel_array = dcm_data.pixel_array

image = Image.fromarray(pixel_array.astype(np.uint8)).convert("L")


transform = transforms.Compose([
    transforms.Resize((192, 192), antialias=True),
    transforms.GaussianBlur(kernel_size=3, sigma=1.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

image_tensor = transform(image).unsqueeze(0).to(device)

plt.imshow(pixel_array, cmap=plt.cm.bone)
plt.axis("off")
plt.show()

loaded_model = CNN_model(num_classes=num_classes).to(device)
loaded_model.load_state_dict(torch.load("cnn_model_v1.pth"))
loaded_model.eval()

with torch.no_grad():
    output = loaded_model(image_tensor)
    prediction = torch.argmax(output, dim=1)
    print(f"Predicted class: {prediction.item()}")

if prediction.item() == 0:
    print("No brain atrophy detected.")
else:
    print("Brain atrophy detected.")
