In [None]:
import os

import joblib
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchio as tio
from nilearn.masking import compute_brain_mask
from sklearn.metrics import (
    accuracy_score,
    confusion_matrix,
    f1_score,
    precision_score,
    recall_score,
    roc_auc_score,
    roc_curve,
)
from sklearn.model_selection import KFold, StratifiedKFold
from sklearn.naive_bayes import GaussianNB
from torch.utils.data import DataLoader, Dataset, SubsetRandomSampler
from torchvision import models
from tqdm import tqdm

In [None]:
# Function to clip background slices in all three axes
def clip_background_slices(image_data, threshold=0):
    """Remove slices that contain no voxel above ``threshold`` on each axis.

    The volume is trimmed along axis 2 first, then axis 1, then axis 0, each
    pass operating on the result of the previous one. A slice is kept only
    when its maximum is strictly greater than ``threshold``.

    Args:
        image_data: 3-D array-like volume.
        threshold: Background level; slices whose maximum does not exceed it
            are dropped.

    Returns:
        A new ``numpy.ndarray`` holding only the retained slices. When no
        voxel exceeds ``threshold`` the result has length 0 on every axis
        (previously this case raised ``ValueError`` from a zero-size
        reduction).
    """
    clipped = np.asarray(image_data)
    for axis in (2, 1, 0):
        if clipped.size == 0:
            # Nothing survived an earlier pass; a max() over an empty array
            # would raise, so mark every remaining slice as background.
            keep = np.zeros(clipped.shape[axis], dtype=bool)
        else:
            other_axes = tuple(a for a in range(clipped.ndim) if a != axis)
            keep = clipped.max(axis=other_axes) > threshold
        clipped = np.compress(keep, clipped, axis=axis)
    return clipped

# Function to normalize image data
def normalize_image(image_data):
    """Rescale a volume's intensities with torchio and return a NumPy array.

    The array is wrapped in a tensor with a leading singleton axis, passed
    through ``tio.RescaleIntensity(out_min_max=(0, 1))``, and unwrapped again.

    Args:
        image_data: Array-like volume (the caller passes a 3-D array).

    Returns:
        ``numpy.ndarray`` with the same spatial shape as the input.
    """
    rescale = tio.RescaleIntensity(out_min_max=(0, 1))
    # The transform is given a leading channel-like axis, removed afterwards.
    with_channel = torch.tensor(image_data).unsqueeze(0)
    rescaled = rescale(with_channel)
    return rescaled.squeeze(0).numpy()

# Function to visualize a few slices
def visualize_slices(image_data, title, num_slices=5):
    """Show ``num_slices`` evenly spaced slices taken along axis 2.

    Also prints the shape of ``image_data`` before plotting.

    Args:
        image_data: 3-D array; slices are read as ``image_data[:, :, k]``.
        title: Figure-level title.
        num_slices: Number of panels to draw in a single row.
    """
    # squeeze=False always returns a 2-D array of axes, so a single-panel
    # figure (num_slices=1) can be indexed the same way as a multi-panel one.
    fig, axs = plt.subplots(1, num_slices, figsize=(15, 5), squeeze=False)
    print(image_data.shape)
    step = image_data.shape[2] // (num_slices + 1)
    for panel, ax in enumerate(axs[0]):
        slice_idx = step * (panel + 1)
        ax.imshow(image_data[:, :, slice_idx], cmap='gray')
        ax.set_title(f'Slice {slice_idx}')
    fig.suptitle(title)
    plt.show()

# Function to preprocess NIfTI file and visualize each step
def preprocess_nifti_with_visualization(file_path, target_size=(224, 224),
                                        num_slices=50, visualize=False):
    """Load a NIfTI scan and turn it into a stack of fixed-size 2-D slices.

    Steps: load -> skull strip with nilearn's ``compute_brain_mask`` ->
    resample to 1x1x1 spacing -> swap the last two axes -> rescale
    intensities -> drop background slices -> keep the central ``num_slices``
    slices along axis 2 -> bilinear resize of each slice to ``target_size``.

    Args:
        file_path: Path to the NIfTI file.
        target_size: ``(height, width)`` of every output slice.
        num_slices: Maximum number of central slices to keep (fewer are
            returned when the clipped volume is shallower).
        visualize: When True, display intermediate results after each step
            with ``visualize_slices``.

    Returns:
        ``numpy.ndarray`` of shape ``(target_size[0], target_size[1], n)``,
        or ``None`` when the brain mask is empty, nothing is left after
        background clipping, or any step raises (the error is printed).
    """
    try:
        # Load NIfTI image
        nifti_image = nib.load(file_path)
        image_data = nifti_image.get_fdata()
        if visualize:
            visualize_slices(np.transpose(image_data, (0, 2, 1)), "Axial View of Original Image")

        # Skull stripping
        brain_mask = compute_brain_mask(nifti_image, threshold=0.1).get_fdata()

        # Check if brain mask is empty
        if np.sum(brain_mask) == 0:
            print(f"Skipping {file_path} due to empty brain mask.")
            return None

        stripped_image = image_data * brain_mask
        if visualize:
            visualize_slices(np.transpose(stripped_image, (0, 2, 1)), "Skull Stripped Image")

        # Resample to (1, 1, 1). The scan's own affine must be supplied:
        # without it torchio assumes an identity affine (already unit
        # spacing) and the resampling does nothing.
        resample = tio.Resample((1, 1, 1), image_interpolation='linear')
        subject = tio.Subject(
            image=tio.ScalarImage(
                tensor=stripped_image[np.newaxis, ...],
                affine=nifti_image.affine,
            )
        )
        resampled_subject = resample(subject)
        # Drop only the channel axis so a singleton spatial axis survives.
        resampled_native = resampled_subject.image.data.numpy()[0]

        # Swap the last two axes (the original "sagittal to axial" step).
        # Done after resampling so the voxel grid still matches the affine.
        image_data_resampled = np.transpose(resampled_native, (0, 2, 1))
        if visualize:
            visualize_slices(image_data_resampled, "Resampled Image")

        # Normalize image data
        normalized_image = normalize_image(image_data_resampled)
        if visualize:
            visualize_slices(normalized_image, "Normalized Image")

        # Clip background slices
        image_data_clipped = clip_background_slices(normalized_image)
        if image_data_clipped.size == 0:
            print(f"Skipping {file_path} due to empty volume after background clipping.")
            return None
        if visualize:
            visualize_slices(image_data_clipped, "Clipped Image")

        # Extract up to num_slices slices centred on the middle of axis 2
        depth = image_data_clipped.shape[2]
        middle = depth // 2
        half = num_slices // 2
        start = max(0, middle - half)
        end = min(depth, middle + (num_slices - half))
        extracted_slices = image_data_clipped[:, :, start:end]
        if visualize:
            visualize_slices(extracted_slices, "Extracted Middle Slices")

        # Resize all slices in one batched call: (H, W, D) -> (D, 1, H, W)
        volume = torch.from_numpy(np.ascontiguousarray(extracted_slices)).float()
        batch = volume.permute(2, 0, 1).unsqueeze(1)
        resized = F.interpolate(batch, size=target_size, mode='bilinear', align_corners=False)
        # Back to (H, W, D), matching the layout downstream code slices on.
        resized_slices = resized.squeeze(1).permute(1, 2, 0).contiguous().numpy()
        if visualize:
            visualize_slices(resized_slices, "Resized Slices")

        return resized_slices

    except Exception as e:
        # Best-effort loader: one unreadable scan must not abort the whole run.
        print(f"Error processing {file_path}: {str(e)}")
        return None

# Load all samples and preprocess
root_dir = '/kaggle/input/adni3dataset/ADNI3'
classes = {'AD': 0, 'CN': 1, 'MCI': 2 , 'EMCI': 3}
samples = []
for class_label, class_idx in classes.items():
    class_dir = os.path.join(root_dir, f'{class_label}_Collection/ADNI')
    if not os.path.isdir(class_dir):
        # os.walk silently yields nothing for a missing directory, which
        # would otherwise produce a dataset with a class quietly absent.
        print(f"Warning: class directory not found: {class_dir}")
        continue
    for root, dirs, files in os.walk(class_dir):
        # Sorting makes the sample order reproducible across filesystems;
        # the unshuffled k-fold split later on depends on that order.
        dirs.sort()
        for file_name in sorted(files):
            if file_name.endswith('.nii'):
                file_path = os.path.join(root, file_name)
                preprocessed_data = preprocess_nifti_with_visualization(file_path)
                if preprocessed_data is not None:
                    samples.append((preprocessed_data, class_idx))


# Flatten every volume into (2-D slice, label) pairs; the slices are taken
# along axis 2 of each preprocessed volume.
all_slices = [
    (volume[:, :, slice_idx], label)
    for volume, label in samples
    for slice_idx in range(volume.shape[2])
]

In [None]:
class ExtractedSlicesDataset(Dataset):
    """Dataset over ``(slice, label)`` pairs that serves 3-channel tensors.

    Each item is converted to a float tensor and tiled three times along a
    new leading axis, so a single-channel slice can feed a network that
    expects three input channels.
    """

    def __init__(self, slices):
        # Sequence of (slice array, class index) pairs, stored by reference.
        self.slices = slices

    def __len__(self):
        """Return the number of stored pairs."""
        return len(self.slices)

    def __getitem__(self, idx):
        """Return ``(tensor, class_idx)`` for the pair at position ``idx``."""
        image, label = self.slices[idx]
        # repeat(3, 1, 1) on a 2-D (H, W) tensor produces shape (3, H, W).
        three_channel = torch.tensor(image).float().repeat(3, 1, 1)
        return three_channel, label

# Wrap every (slice, label) pair collected in all_slices in a single dataset.
extracted_slices_dataset = ExtractedSlicesDataset(all_slices)

In [None]:
# Build a shuffled loader over the whole dataset and inspect one batch
batch_size = 25
full_loader = DataLoader(extracted_slices_dataset, shuffle=True, batch_size=batch_size)

# Report the total number of slices
print(f"Full dataset size: {len(extracted_slices_dataset)}")

for i, (slices, label) in enumerate(full_loader):
    if i == 0:
        # Only the first batch is printed, as a shape/label sanity check.
        print(f"Batch {i+1}:")
        print("Slices shape:", slices.shape)
        print("Labels:", label)
        continue
    break


In [None]:
# Preview a handful of consecutive dataset items, starting at index 40
num_slices_to_visualize = 5
fig, axes = plt.subplots(1, num_slices_to_visualize, figsize=(15, 5))

for i, ax in enumerate(axes):
    slice_data, class_idx = extracted_slices_dataset[40 + i]
    slice_data_np = slice_data.numpy()

    # The three channels are copies of one grayscale slice; show the first.
    if slice_data_np.shape[0] == 3:
        slice_data_np = slice_data_np[0]

    ax.imshow(slice_data_np, cmap='gray')
    ax.set_title(f"Class: {class_idx}")
    ax.axis('off')

plt.tight_layout()
plt.show()

In [None]:
def compute_specificity(y_true, y_pred):
    """Compute specificity (true-negative rate) from predicted labels.

    For two classes this is ``TN / (TN + FP)`` read from the 2x2 confusion
    matrix, with the second label treated as positive (unchanged behaviour).
    For more than two classes, ``cm[0, 0]`` and ``cm[0, 1]`` are not TN/FP,
    so specificity is computed one-vs-rest for every class and averaged with
    equal weight.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels.

    Returns:
        Specificity in ``[0, 1]``. ``0`` is returned whenever there are no
        negative samples to evaluate, including the case where only a single
        label occurs in both inputs.
    """
    cm = confusion_matrix(y_true, y_pred)
    n_classes = cm.shape[0]

    if n_classes < 2:
        # Only one label present: there are no negatives at all.
        return 0

    if n_classes == 2:
        tn, fp = cm[0, 0], cm[0, 1]
        return tn / (tn + fp) if (tn + fp) > 0 else 0

    true_pos = np.diag(cm)
    false_pos = cm.sum(axis=0) - true_pos
    false_neg = cm.sum(axis=1) - true_pos
    true_neg = cm.sum() - true_pos - false_pos - false_neg
    negatives = true_neg + false_pos
    # Classes without any negative sample contribute 0 instead of NaN.
    per_class = np.divide(
        true_neg,
        negatives,
        out=np.zeros(n_classes, dtype=float),
        where=negatives > 0,
    )
    return float(per_class.mean())


def extract_features(model, dataloader, device):
    """Run ``model.features`` over a loader and collect flattened outputs.

    The model is switched to eval mode (and left there) and gradients are
    disabled. The output of ``model.features`` is flattened from dimension 1
    onward, so no pooling or activation is applied on top of it.

    Args:
        model: Module exposing a ``features`` sub-module.
        dataloader: Iterable of ``(inputs, labels)`` tensor batches.
        device: Device the inputs are moved to before the forward pass.

    Returns:
        ``(features, labels)`` as NumPy arrays, concatenated over batches in
        iteration order.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.eval()
    feature_batches = []
    label_batches = []

    with torch.no_grad():
        for inputs, lbls in tqdm(dataloader, desc="Extracting features", leave=False):
            outputs = model.features(inputs.to(device))
            feature_batches.append(torch.flatten(outputs, 1).cpu())
            # Labels are only collected, never used on the device, so they
            # are not sent there and back.
            label_batches.append(lbls.cpu())

    if not feature_batches:
        raise ValueError("extract_features received an empty dataloader")

    features = torch.cat(feature_batches, dim=0).numpy()
    labels = torch.cat(label_batches, dim=0).numpy()

    return features, labels


def initialize_densenet(num_classes):
    """Create a DenseNet-201 (``pretrained=True``) with a new linear head.

    Args:
        num_classes: Output size of the replacement ``nn.Linear`` classifier.

    Returns:
        The model with ``classifier`` swapped for a layer mapping the
        original classifier's input width to ``num_classes`` outputs.
    """
    backbone = models.densenet201(pretrained=True)
    # Reuse the input width of the stock classifier for the new head.
    head_inputs = backbone.classifier.in_features
    backbone.classifier = nn.Linear(head_inputs, num_classes)
    return backbone


def evaluate_model(model, test_loader, device, num_classes):
    """Evaluate a classifier on a loader and return summary metrics.

    Args:
        model: Module mapping an input batch to ``(batch, num_classes)`` logits.
        test_loader: Iterable of ``(inputs, labels)`` tensor batches.
        device: Device the inputs are moved to.
        num_classes: Number of classes; labels are assumed to be the
            integers ``0 .. num_classes - 1``.

    Returns:
        Tuple ``(accuracy_percent, f1, precision, recall, specificity,
        roc_auc)``. F1/precision/recall are support-weighted; specificity is
        the unweighted mean of one-vs-rest specificities; ``roc_auc`` is
        ``None`` when scikit-learn cannot compute it (for example a class is
        missing from the labels).

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    model.eval()
    pred_batches = []
    label_batches = []
    prob_batches = []

    with torch.no_grad():
        for inputs, labels in tqdm(test_loader, desc='Evaluating', leave=False):
            outputs = model(inputs.to(device))
            pred_batches.append(outputs.argmax(dim=1).cpu().numpy())
            prob_batches.append(torch.softmax(outputs, dim=1).cpu().numpy())
            label_batches.append(labels.cpu().numpy())

    if not label_batches:
        raise ValueError("evaluate_model received an empty test_loader")

    # Real arrays are needed below: the binary ROC branch slices a column,
    # which a plain Python list of per-sample rows does not support.
    all_preds = np.concatenate(pred_batches)
    all_labels = np.concatenate(label_batches)
    all_probs = np.concatenate(prob_batches, axis=0)

    # Accuracy
    test_accuracy = 100 * float((all_preds == all_labels).mean())

    # F1, Precision, and Recall
    test_f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)
    test_precision = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
    test_recall = recall_score(all_labels, all_preds, average='weighted', zero_division=0)

    # Specificity Calculation (One-vs-Rest approach). Passing labels fixes
    # the matrix to num_classes x num_classes even when a class is absent
    # from this split, so per-class indexing cannot go out of bounds.
    class_ids = list(range(num_classes))
    cm = confusion_matrix(all_labels, all_preds, labels=class_ids)
    true_pos = np.diag(cm)
    false_pos = cm.sum(axis=0) - true_pos
    false_neg = cm.sum(axis=1) - true_pos
    true_neg = cm.sum() - true_pos - false_pos - false_neg
    negatives = true_neg + false_pos
    specificity_per_class = np.divide(
        true_neg,
        negatives,
        out=np.zeros(num_classes, dtype=float),
        where=negatives > 0,  # a class with no negatives scores 0, not NaN
    )
    test_specificity = float(specificity_per_class.mean())

    # ROC AUC
    try:
        if num_classes > 2:
            test_roc_auc = roc_auc_score(
                all_labels,
                all_probs,
                multi_class='ovr',
                average='weighted',
                labels=class_ids,
            )
        else:
            test_roc_auc = roc_auc_score(all_labels, all_probs[:, 1])
    except ValueError:
        test_roc_auc = None

    return test_accuracy, test_f1, test_precision, test_recall, test_specificity, test_roc_auc


def train_densenet(model, train_loader, num_epochs, criterion, optimizer, device):
    """Train ``model`` in place for ``num_epochs`` passes over ``train_loader``.

    Args:
        model: Module to optimise; put into train mode before the first epoch.
        train_loader: Iterable of ``(inputs, labels)`` tensor batches.
        num_epochs: Number of full passes over ``train_loader``.
        criterion: Loss callable taking ``(outputs, labels)``.
        optimizer: Optimizer bound to ``model``'s parameters.
        device: Device inputs and labels are moved to.

    Returns:
        List with one dict per epoch: ``{'epoch', 'loss', 'accuracy'}``.
        ``loss`` is the sample-weighted mean of the batch losses (assumes
        ``criterion`` returns a per-batch mean - TODO confirm for custom
        losses) and ``accuracy`` is a percentage. Both are ``nan`` for an
        epoch that saw no samples. These statistics used to be accumulated
        and then discarded; callers that ignore the return value are
        unaffected.
    """
    model.train()
    history = []
    for epoch in range(num_epochs):
        loss_sum = 0.0
        correct = 0
        total = 0

        for inputs, labels in tqdm(train_loader, desc=f'Training Epoch {epoch+1}/{num_epochs}', leave=False):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Backward pass and optimization
            loss.backward()
            optimizer.step()

            batch_count = labels.size(0)
            loss_sum += loss.item() * batch_count
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_count

        history.append({
            'epoch': epoch + 1,
            'loss': loss_sum / total if total else float('nan'),
            'accuracy': 100 * correct / total if total else float('nan'),
        })

    return history

def run_folds(fold_start, fold_end, dataset, num_classes, num_epochs=100,
              n_splits=None, output_dir='/kaggle/working/'):
    """Cross-validate a DenseNet-201 feature extractor + Gaussian Naive Bayes.

    For every selected fold: fine-tune a fresh DenseNet on the training
    slices, extract ``model.features`` activations for both splits, fit
    ``GaussianNB`` on the training features, print test metrics, and save
    both models to ``output_dir``.

    Args:
        fold_start: Zero-based index of the first fold to run.
        fold_end: Zero-based index one past the last fold to run.
        dataset: Object whose ``slices`` attribute is a sequence of
            ``(slice, class_idx)`` pairs.
        num_classes: Number of output classes for the DenseNet head.
        num_epochs: Training epochs per fold.
        n_splits: Total number of folds in the split. Defaults to
            ``fold_end``, which reproduces the earlier behaviour where
            ``fold_end`` doubled as the fold count.
        output_dir: Directory for the saved model files; created if missing.

    NOTE(review): the split is made per slice, so slices from one scan can
    end up in both train and test. The unshuffled split keeps most of a
    scan together, but a group-aware split needs scan identifiers that the
    dataset does not carry.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    criterion = torch.nn.CrossEntropyLoss()

    if n_splits is None:
        n_splits = fold_end
    os.makedirs(output_dir, exist_ok=True)

    pairs = list(dataset.slices)
    labels = [class_idx for _, class_idx in pairs]

    # Deliberately unshuffled: shuffling would scatter each scan's slices
    # across folds and worsen the leakage noted in the docstring.
    skf = StratifiedKFold(n_splits=n_splits)

    # Only the labels drive a stratified split; X is a length placeholder.
    placeholder = np.zeros(len(labels))
    for fold, (train_idx, test_idx) in enumerate(skf.split(placeholder, labels)):
        if fold < fold_start:
            continue
        if fold >= fold_end:
            break

        print(f"\nFold {fold+1}/{n_splits}")

        model = initialize_densenet(num_classes=num_classes).to(device)

        train_dataset = ExtractedSlicesDataset([pairs[i] for i in train_idx])
        test_dataset = ExtractedSlicesDataset([pairs[i] for i in test_idx])
        train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
        test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

        # Train DenseNet for the specified number of epochs
        train_densenet(model, train_loader, num_epochs, criterion, optimizer, device)

        # Extract DenseNet features for GaussianNB. Labels come back in
        # loader order, so the shuffled train loader stays consistent.
        train_features, train_labels = extract_features(model, train_loader, device)
        test_features, test_labels = extract_features(model, test_loader, device)

        # Train Gaussian Naive Bayes
        gnb_classifier = GaussianNB()
        gnb_classifier.fit(train_features, train_labels)
        test_predictions = gnb_classifier.predict(test_features)

        # Compute GaussianNB metrics
        test_accuracy = accuracy_score(test_labels, test_predictions)
        test_f1 = f1_score(test_labels, test_predictions, average='weighted', zero_division=0)
        test_precision = precision_score(test_labels, test_predictions, average='weighted', zero_division=0)
        test_recall = recall_score(test_labels, test_predictions, average='weighted', zero_division=0)
        test_specificity = compute_specificity(test_labels, test_predictions)

        # ROC AUC needs class scores, not hard labels: with predicted labels
        # the multiclass call always raised and the metric was always 'N/A'.
        try:
            test_probabilities = gnb_classifier.predict_proba(test_features)
            if test_probabilities.shape[1] > 2:
                test_roc_auc = roc_auc_score(
                    test_labels,
                    test_probabilities,
                    multi_class='ovr',
                    average='weighted',
                    labels=gnb_classifier.classes_,
                )
            else:
                test_roc_auc = roc_auc_score(test_labels, test_probabilities[:, 1])
        except (ValueError, IndexError):
            test_roc_auc = None

        roc_auc_value = f"{test_roc_auc:.4f}" if test_roc_auc is not None else 'N/A'
        print(f"Fold {fold+1} - GaussianNB Metrics: Accuracy: {test_accuracy:.4f}, F1: {test_f1:.4f}, Precision: {test_precision:.4f}, "
              f"Recall: {test_recall:.4f}, Specificity: {test_specificity:.4f}, ROC AUC: {roc_auc_value}")

        # Save DenseNet and GaussianNB models for this fold
        torch.save(model.state_dict(), os.path.join(output_dir, f'fold_{fold+1}_densenet.pth'))
        joblib.dump(gnb_classifier, os.path.join(output_dir, f'fold_{fold+1}_gaussian_nb.pkl'))

        # Release this fold's model before the next one is built.
        del model, optimizer
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        
# Run folds 1-10 of a 10-fold split (fold_end is also used as the number of splits).
run_folds(fold_start=0, fold_end=10, dataset=extracted_slices_dataset, num_classes=4)
