In [10]:
import os
import glob
import numpy as np
import pandas as pd
import nibabel as nib
import torch
torch.cuda.empty_cache()
import torch.nn as nn
from sklearn.preprocessing import LabelEncoder
import torch
import torch.nn.functional as F
from torchvision.transforms import Resize
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import KFold
import matplotlib.pyplot as plt
from sklearn import metrics
import torch.nn.init as init
from sklearn.utils import shuffle
from collections import Counter
from torch.optim.lr_scheduler import ExponentialLR
from torch.optim.lr_scheduler import MultiStepLR
from torch.cuda.amp import autocast, GradScaler
from imblearn.over_sampling import ADASYN
from tqdm import tqdm

from sklearn.model_selection import StratifiedKFold
from sklearn.utils import shuffle
from torch.utils.data import DataLoader
from collections import Counter


In [11]:
# Directory containing the concatenated images
concat_directory = "/home/yasmine/OASIS3/CNN/concat2"

In [12]:
# Set the device for training (GPU if available, else CPU)
device = torch.device("cuda")
# Print the device being used
print("Device:", device)
torch.cuda.is_available()

Device: cuda


True

In [13]:
class Conv3DModel(nn.Module):
    def __init__(self):
        super(Conv3DModel, self).__init__()
        self.conv1 = nn.Conv3d(1, 32, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(32 * 16 * 16 * 16, 128)
        self.fc2 = nn.Linear(128, 4)  # Assuming you have 4 classes

    def forward(self, x):
        # x has shape (batch_size, num_patches, 1, 16, 16, 16)
        batch_size, num_patches, _, _, _ = x.size()
        x = x.view(batch_size * num_patches, 1, 16, 16, 16)
        x = self.conv1(x)
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        return x

In [18]:
# Define the custom dataset class
class MRICTDataset(Dataset):
    def __init__(self, data_directory, patch_size=(16, 16, 16), num_patches=200):
        self.data_files = glob.glob(os.path.join(data_directory, "*.nii.gz"))
        self.labels_df = pd.read_csv("/home/yasmine/OASIS3/CNN/hot_deck_labels_2.csv")
        self.label_encoder = LabelEncoder()
        self.resize_transform = Resize((16, 16))
        self.patch_size = patch_size
        self.num_patches = num_patches

        # Fit and transform the string labels to numerical values
        self.labels = self.label_encoder.fit_transform(self.labels_df["Diagnosis"])

        # Create a dictionary to map subject IDs to labels
        self.subject_labels = dict(zip(self.labels_df["Subject_ID"], self.labels))

    def __len__(self):
        return len(self.data_files)

    def extract_patches(self, image):
        _, depth, height, width = image.shape
        patches = []

        for z in range(0, depth - self.patch_size[0] + 1, self.patch_size[0]):
            for y in range(0, height - self.patch_size[1] + 1, self.patch_size[1]):
                for x in range(0, width - self.patch_size[2] + 1, self.patch_size[2]):
                    patch = image[:, z:z+self.patch_size[0], y:y+self.patch_size[1], x:x+self.patch_size[2]]
                    patch_std = np.std(patch)
                    patches.append((patch, patch_std))

        # Sort patches based on information measure (from most info to less info)
        patches.sort(key=lambda x: x[1], reverse=True)

        # Take the top 'num_patches' patches
        selected_patches = [patch for patch, _ in patches[:self.num_patches]]

        return selected_patches

    def __getitem__(self, index):
        file = self.data_files[index]
        image = nib.load(file).get_fdata()
        image = image.astype(np.float32)  # Convert image to NumPy array

        # Ensure that the image has the correct shape (1, depth, height, width)
        image = image[np.newaxis, ...]

        # Extract and select patches
        selected_patches = self.extract_patches(image)

        # Convert selected_patches to NumPy arrays
        selected_patches = np.array(selected_patches)

        # Extract the subject ID from the file name
        subject_id = os.path.splitext(os.path.basename(file))[0].split("_")[0][11:]

        # Retrieve the corresponding label from the dictionary
        label = self.subject_labels.get(subject_id, 3)
        #print(selected_patches.shape)

        return selected_patches, label

In [19]:

# Hyperparameters
batch_size = 32
learning_rate = 0.001
num_epochs = 20
num_patches = 200
# Split your dataset into training and validation sets
# Directory containing the concatenated images
concat_directory = "/home/yasmine/OASIS3/CNN/concat2"
dataset = MRICTDataset(concat_directory)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)

# Initialize the model and optimizer
model = Conv3DModel()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

In [20]:
from tqdm import tqdm  # Import tqdm for the progress bar

# Training loop with progress bar
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    # Wrap train_loader with tqdm to add the progress bar
    for batch_data, batch_labels in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}"):
        optimizer.zero_grad()

        # Forward pass for each patch
        patch_outputs = []
        for patch_idx in range(num_patches):
            current_patch = batch_data[:, patch_idx, :, :, :]
            patch_output = model(current_patch)
            patch_outputs.append(patch_output)

        # Concatenate the patch outputs along the batch dimension
        outputs = torch.cat(patch_outputs, dim=1)

        # Compute loss for each subject (patch) separately
        loss = criterion(outputs, batch_labels)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        _, predicted = torch.max(outputs, 1)
        total += batch_labels.size(0)
        correct += (predicted == batch_labels).sum().item()

    # Calculate and print training accuracy
    train_accuracy = 100 * correct / total
    print(f"Epoch [{epoch + 1}/{num_epochs}]")
    print(f"Train Loss: {total_loss / len(train_loader):.4f}")
    print(f"Train Accuracy: {train_accuracy:.2f}%")

    # Validation loop
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for batch_data, batch_labels in val_loader:
            # Forward pass for each patch
            patch_outputs = []
            for patch_idx in range(num_patches):
                current_patch = batch_data[:, patch_idx, :, :, :]
                patch_output = model(current_patch)
                patch_outputs.append(patch_output)

            # Concatenate the patch outputs along the batch dimension
            outputs = torch.cat(patch_outputs, dim=1)

            # Compute loss for each subject (patch) separately
            loss = criterion(outputs, batch_labels)
            total_loss += loss.item()

            _, predicted = torch.max(outputs, 1)
            total += batch_labels.size(0)
            correct += (predicted == batch_labels).sum().item()

    # Calculate and print validation accuracy
    val_accuracy = 100 * correct / total
    print(f"Validation Loss: {total_loss / len(val_loader):.4f}")
    print(f"Validation Accuracy: {val_accuracy:.2f}%")

# After training and validation, you can include a testing loop in a similar manner


Epoch 1/20: 100%|██████████████████████████████████████| 34/34 [16:17<00:00, 28.76s/it]


Epoch [1/20]
Train Loss: 17.4819
Train Accuracy: 38.99%
Validation Loss: 5.2062
Validation Accuracy: 51.67%


Epoch 2/20: 100%|██████████████████████████████████████| 34/34 [44:27<00:00, 78.47s/it]


Epoch [2/20]
Train Loss: 4.2493
Train Accuracy: 53.26%
Validation Loss: 3.4074
Validation Accuracy: 53.16%


Epoch 3/20: 100%|███████████████████████████████████| 34/34 [1:08:47<00:00, 121.39s/it]


Epoch [3/20]
Train Loss: 2.6071
Train Accuracy: 59.24%
Validation Loss: 3.0378
Validation Accuracy: 56.13%


Epoch 4/20: 100%|███████████████████████████████████| 34/34 [1:07:19<00:00, 118.82s/it]


Epoch [4/20]
Train Loss: 2.3752
Train Accuracy: 62.31%
Validation Loss: 3.4824
Validation Accuracy: 54.65%


Epoch 5/20: 100%|███████████████████████████████████| 34/34 [1:14:57<00:00, 132.29s/it]


Epoch [5/20]
Train Loss: 2.3013
Train Accuracy: 61.75%
Validation Loss: 2.8729
Validation Accuracy: 53.53%


Epoch 6/20: 100%|███████████████████████████████████| 34/34 [1:14:04<00:00, 130.73s/it]


Epoch [6/20]
Train Loss: 2.0973
Train Accuracy: 61.66%
Validation Loss: 2.7971
Validation Accuracy: 57.99%


Epoch 7/20: 100%|███████████████████████████████████| 34/34 [1:06:49<00:00, 117.94s/it]


Epoch [7/20]
Train Loss: 2.1567
Train Accuracy: 61.94%
Validation Loss: 4.3395
Validation Accuracy: 53.53%


Epoch 8/20:  41%|███████████████▏                     | 14/34 [27:40<39:32, 118.61s/it]


KeyboardInterrupt: 

In [None]:
# Testing loop (similar to validation loop)
model.eval()
total_loss = 0.0
correct = 0
total = 0

with torch.no_grad():
    for batch_data, batch_labels in test_loader:
        # Forward pass for each patch
        patch_outputs = []
        for patch_idx in range(num_patches):
            current_patch = batch_data[:, patch_idx, :, :, :]
            patch_output = model(current_patch)
            patch_outputs.append(patch_output)

        # Concatenate the patch outputs along the batch dimension
        outputs = torch.cat(patch_outputs, dim=1)

        # Compute loss for each subject (patch) separately
        loss = criterion(outputs, batch_labels)
        total_loss += loss.item()

        _, predicted = torch.max(outputs, 1)
        total += batch_labels.size(0)
        correct += (predicted == batch_labels).sum().item()

    print(f"Test Loss: {total_loss / len(test_loader):.4f}")
    print(f"Test Accuracy: {(100 * correct / total):.2f}%")