In [None]:
import os
from typing import List

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchio as tio
from nilearn.masking import compute_brain_mask
from pytorch_metric_learning import losses, miners
from sklearn.metrics import (
    confusion_matrix,
    f1_score,
    precision_score,
    recall_score,
    roc_auc_score,
    roc_curve,
)
from sklearn.model_selection import KFold
from sklearn.preprocessing import label_binarize
from torch.utils.data import DataLoader, Dataset, SubsetRandomSampler
from tqdm import tqdm
from transformers import Dinov2Model

In [None]:
# Function to clip background slices in all three axes
def clip_background_slices(image_data, threshold=0):
    """Drop every slice, along each of the three axes, that is pure background.

    A slice is kept when at least one of its voxels is strictly greater than
    ``threshold``. Kept slices need not be contiguous: an empty slice that
    sits between two non-empty ones is removed as well.

    Unlike the previous loop-based version, a volume without any foreground
    voxel no longer raises a zero-size reduction ``ValueError``; it yields an
    empty array instead. A slice containing NaN next to real foreground is
    now kept (``np.max`` used to turn the whole slice into NaN and drop it).

    Args:
        image_data: 3-D array-like volume.
        threshold: Voxels at or below this value count as background.

    Returns:
        A new ``numpy.ndarray`` containing only the kept slices. Its shape is
        ``(0, 0, 0)`` when no voxel exceeds ``threshold``.
    """
    volume = np.asarray(image_data)
    foreground = volume > threshold

    # A slice that is empty along one axis cannot contribute foreground to a
    # slice along another axis, so the three masks can be computed from the
    # un-cropped volume and still match a sequential axis-by-axis crop.
    keep_axis0 = foreground.any(axis=(1, 2))
    keep_axis1 = foreground.any(axis=(0, 2))
    keep_axis2 = foreground.any(axis=(0, 1))

    return volume[keep_axis0][:, keep_axis1][:, :, keep_axis2]

# Function to normalize image data
def normalize_image(image_data):
    """Rescale the intensities of ``image_data`` with TorchIO.

    The array is converted to a tensor, given a temporary leading axis for
    the transform, passed through ``tio.RescaleIntensity`` configured with
    ``out_min_max=(0, 1)``, and returned as a NumPy array without that axis.

    Args:
        image_data: Array-like volume accepted by ``torch.tensor``.

    Returns:
        The rescaled volume as a ``numpy.ndarray``.
    """
    rescale = tio.RescaleIntensity(out_min_max=(0, 1))
    with_leading_axis = torch.tensor(image_data).unsqueeze(0)
    rescaled = rescale(with_leading_axis)
    return rescaled.squeeze(0).numpy()

# Function to visualize a few slices
def visualize_slices(image_data, title, num_slices=5):
    """Show ``num_slices`` evenly spaced slices taken along the last axis.

    The volume's shape is printed first. Slice indices are multiples of
    ``image_data.shape[2] // (num_slices + 1)``, so the first and last slices
    of the volume are never among them.

    Args:
        image_data: 3-D array indexed as ``image_data[:, :, k]``.
        title: Figure-level title.
        num_slices: Number of panels to draw; ``1`` is supported.
    """
    # squeeze=False keeps `axs` two-dimensional even for a single panel;
    # without it plt.subplots returns a bare Axes and indexing it fails.
    fig, axs = plt.subplots(1, num_slices, figsize=(15, 5), squeeze=False)
    print(image_data.shape)

    step = image_data.shape[2] // (num_slices + 1)
    for panel, ax in enumerate(axs[0]):
        slice_idx = step * (panel + 1)
        ax.imshow(image_data[:, :, slice_idx], cmap='gray')
        ax.set_title(f'Slice {slice_idx}')

    fig.suptitle(title)
    plt.show()

# Function to preprocess NIfTI file and visualize each step
def preprocess_nifti_with_visualization(file_path, target_size=(224, 224),
                                        visualize=False, num_middle_slices=50):
    """Turn one NIfTI volume into a stack of resized 2-D slices.

    Steps, in order:
      1. load the volume with nibabel;
      2. compute a brain mask with nilearn and zero everything outside it;
      3. resample to 1 x 1 x 1 spacing with TorchIO;
      4. swap the last two axes (the notebook treats this as the
         sagittal -> axial conversion);
      5. rescale intensities with ``normalize_image``;
      6. remove background-only slices with ``clip_background_slices``;
      7. keep up to ``num_middle_slices`` slices centred on the middle of the
         last axis;
      8. bilinearly resize every slice to ``target_size``.

    Fix over the previous version: the TorchIO image used to be built from
    the bare array, so TorchIO assumed an identity affine (already 1 mm
    spacing) and the resampling step did nothing. The NIfTI affine is now
    passed along, and resampling happens in the file's native axis order
    before the axis swap so that the affine still matches the data.

    Args:
        file_path: Path of the ``.nii`` file.
        target_size: ``(height, width)`` of each returned slice.
        visualize: When true, show intermediate results via
            ``visualize_slices`` (off by default).
        num_middle_slices: Maximum number of central slices to keep.

    Returns:
        ``numpy.ndarray`` of shape ``(target_size[0], target_size[1], n)``
        with ``n <= num_middle_slices``, or ``None`` when the brain mask is
        empty, nothing is left after clipping, or any step raises.
    """
    def _show(volume, title):
        # Optional debugging aid; replaces the formerly commented-out calls.
        if visualize:
            visualize_slices(volume, title)

    try:
        # Load NIfTI image
        nifti_image = nib.load(file_path)
        image_data = nifti_image.get_fdata()
        _show(np.transpose(image_data, (0, 2, 1)), "Axial View of Original Image")

        # Skull stripping
        brain_mask = compute_brain_mask(nifti_image, threshold=0.1).get_fdata()

        # Check if brain mask is empty
        if np.sum(brain_mask) == 0:
            print(f"Skipping {file_path} due to empty brain mask.")
            return None

        stripped_native = image_data * brain_mask
        _show(np.transpose(stripped_native, (0, 2, 1)), "Skull Stripped Image")

        # Resample to (1, 1, 1). The affine carries the real voxel spacing;
        # without it TorchIO would treat the volume as already 1 mm isotropic.
        resample = tio.Resample((1, 1, 1), image_interpolation='linear')
        subject = tio.Subject(
            image=tio.ScalarImage(
                tensor=stripped_native[np.newaxis, ...],
                affine=nifti_image.affine,
            )
        )
        resampled_subject = resample(subject)
        # Index the channel axis explicitly: a blanket squeeze() would also
        # drop a spatial axis that happens to have length 1.
        resampled_native = resampled_subject.image.data.numpy()[0]

        # Swap the last two axes only after resampling (see docstring).
        image_data_resampled = np.transpose(resampled_native, (0, 2, 1))
        _show(image_data_resampled, "Resampled Image")

        # Normalize image data
        normalized_image = normalize_image(image_data_resampled)
        _show(normalized_image, "Normalized Image")

        # Clip background slices
        image_data_clipped = clip_background_slices(normalized_image)
        if image_data_clipped.size == 0:
            print(f"Skipping {file_path}: no foreground left after clipping.")
            return None
        _show(image_data_clipped, "Clipped Image")

        # Extract the central window of slices (50 by default)
        depth = image_data_clipped.shape[2]
        middle = depth // 2
        half = num_middle_slices // 2
        start = max(0, middle - half)
        end = min(depth, middle + (num_middle_slices - half))
        extracted_slices = image_data_clipped[:, :, start:end]
        if extracted_slices.shape[2] == 0:
            print(f"Skipping {file_path}: no slices selected.")
            return None
        _show(extracted_slices, "Extracted Middle Slices")

        # Resize all slices in one call: (H, W, D) -> (D, 1, H, W) -> resize -> (H', W', D)
        slice_batch = torch.tensor(extracted_slices).permute(2, 0, 1).unsqueeze(1).float()
        resized_batch = F.interpolate(slice_batch, size=target_size, mode='bilinear', align_corners=False)
        resized_slices = np.ascontiguousarray(resized_batch.squeeze(1).permute(1, 2, 0).numpy())
        _show(resized_slices, "Resized Slices")

        return resized_slices

    except Exception as e:
        # Best-effort loading: a file that cannot be processed is reported
        # and skipped so that the remaining files are still used.
        print(f"Error processing {file_path}: {str(e)}")
        return None

# Load all samples and preprocess
root_dir = '/kaggle/input/adni3dataset/ADNI3'
classes = {'AD': 0, 'CN': 1, 'MCI': 2, 'EMCI': 3}

# Walk each class folder, preprocess every .nii file found, and keep the
# volumes that preprocessing did not reject (it returns None for those).
samples = []
for class_label, class_idx in classes.items():
    class_dir = os.path.join(root_dir, f'{class_label}_Collection/ADNI')
    for root, _, files in os.walk(class_dir):
        nifti_paths = [os.path.join(root, name) for name in files if name.endswith('.nii')]
        for file_path in nifti_paths:
            volume = preprocess_nifti_with_visualization(file_path)
            if volume is None:
                continue
            samples.append((volume, class_idx))

# Flatten the volumes into (2-D slice, class index) pairs, one per position
# along the last axis, preserving volume order and slice order.
all_slices = [
    (volume[:, :, k], label)
    for volume, label in samples
    for k in range(volume.shape[2])
]

In [None]:
class ExtractedSlicesDataset(Dataset):
    """Dataset over pre-extracted slices paired with their class index.

    ``slices`` must support ``len()`` and integer indexing, and each element
    must unpack into ``(slice_array, class_idx)``.
    """

    def __init__(self, slices):
        self.slices = slices

    def __len__(self):
        return len(self.slices)

    def __getitem__(self, idx):
        """Return ``(float32 tensor tiled 3x along a new leading axis, class_idx)``."""
        pixels, label = self.slices[idx]
        as_float = torch.tensor(pixels).float()
        # Tile the single-channel slice so it looks like a 3-channel image,
        # e.g. (224, 224) -> (3, 224, 224).
        three_channel = as_float.repeat(3, 1, 1)
        return three_channel, label

# Wrap every (slice, class index) pair in a Dataset; the DataLoaders created
# later in the notebook all draw from this single object.
extracted_slices_dataset = ExtractedSlicesDataset(all_slices)

In [None]:
# Build a shuffled loader over the complete slice dataset
batch_size = 25
full_loader = DataLoader(
    extracted_slices_dataset,
    batch_size=batch_size,
    shuffle=True,
)

# Report the dataset size
print(f"Full dataset size: {len(extracted_slices_dataset)}")

# Sanity check: print the tensor shape and labels of the first batch only
for batch_number, (slices, label) in enumerate(full_loader, start=1):
    print(f"Batch {batch_number}:")
    print("Slices shape:", slices.shape)
    print("Labels:", label)
    break


In [None]:
# Visualize a few slices
num_slices_to_visualize = 5
fig, axes = plt.subplots(1, num_slices_to_visualize, figsize=(15, 5))

# Show dataset items 40, 41, ... side by side, one per panel.
for ax, dataset_index in zip(axes, range(40, 40 + num_slices_to_visualize)):
    slice_data, class_idx = extracted_slices_dataset[dataset_index]
    slice_data_np = slice_data.numpy()

    # The dataset tiles each slice into 3 identical channels; one is enough
    # for a grayscale plot.
    if slice_data_np.shape[0] == 3:
        slice_data_np = slice_data_np[0]

    ax.imshow(slice_data_np, cmap='gray')
    ax.set_title(f"Class: {class_idx}")
    ax.axis('off')

plt.tight_layout()
plt.show()

In [None]:
class SplineLinear(nn.Linear):
    """Bias-free ``nn.Linear`` whose weight is Xavier-uniform initialised.

    ``init_scale`` is stored on the instance but is not consulted by
    ``reset_parameters``.
    """

    def __init__(self, in_features: int, out_features: int, init_scale: float = 0.1, **kw) -> None:
        # Assigned before the parent constructor runs, so the attribute
        # already exists if reset_parameters() is invoked during construction.
        self.init_scale = init_scale
        super().__init__(in_features, out_features, bias=False, **kw)

    def reset_parameters(self) -> None:
        """(Re)initialise the weight matrix in place."""
        nn.init.xavier_uniform_(self.weight)


class ReflectionalSwitchFunction(nn.Module):
    """Basis expansion ``1 - tanh((x - g) / denominator) ** 2`` per grid point ``g``.

    The grid holds ``num_grids`` evenly spaced points between ``grid_min``
    and ``grid_max`` and is stored as a non-trainable parameter.
    ``exponent`` is accepted but not used.
    """

    def __init__(
            self,
            grid_min: float = -2.,
            grid_max: float = 2.,
            num_grids: int = 8,
            exponent: int = 2,
            denominator: float = 0.33,
    ):
        super().__init__()
        centers = torch.linspace(grid_min, grid_max, num_grids)
        self.grid = torch.nn.Parameter(centers, requires_grad=False)
        self.denominator = denominator
        # Cached reciprocal: forward() multiplies instead of dividing.
        self.inv_denominator = 1 / self.denominator

    def forward(self, x):
        """Append a trailing axis of length ``num_grids`` holding the basis values."""
        scaled = (x[..., None] - self.grid).mul(self.inv_denominator)
        squashed = torch.tanh(scaled)
        return 1 - squashed * squashed


class FasterKANLayer(nn.Module):
    """LayerNorm -> basis expansion -> bias-free linear projection.

    ``use_base_update`` and ``base_activation`` are accepted but not used by
    this implementation.
    """

    def __init__(
            self,
            input_dim: int,
            output_dim: int,
            grid_min: float = -2.,
            grid_max: float = 2.,
            num_grids: int = 8,
            exponent: int = 2,
            denominator: float = 0.33,
            use_base_update: bool = True,
            base_activation=F.silu,
            spline_weight_init_scale: float = 0.1,
    ) -> None:
        super().__init__()
        self.layernorm = nn.LayerNorm(input_dim)
        self.rbf = ReflectionalSwitchFunction(grid_min, grid_max, num_grids, exponent, denominator)
        # Each input feature expands into `num_grids` basis values, hence the
        # widened input of the projection.
        self.spline_linear = SplineLinear(input_dim * num_grids, output_dim, spline_weight_init_scale)

    def forward(self, x, time_benchmark=False):
        """Project ``x``; ``time_benchmark=True`` skips the layer norm."""
        basis_input = x if time_benchmark else self.layernorm(x)
        # Flatten everything after the first axis into one feature vector per row.
        spline_basis = self.rbf(basis_input).view(x.shape[0], -1)
        return self.spline_linear(spline_basis)

class FasterKAN(nn.Module):
    """Sequential stack of ``FasterKANLayer`` modules.

    ``layers_hidden`` lists the layer widths; one layer is created for each
    consecutive pair ``(layers_hidden[i], layers_hidden[i + 1])``. All other
    arguments are forwarded unchanged to every layer.
    """

    def __init__(
            self,
            layers_hidden: List[int],
            grid_min: float = -2.,
            grid_max: float = 2.,
            num_grids: int = 8,
            exponent: int = 2,
            denominator: float = 0.33,
            use_base_update: bool = True,
            base_activation=F.silu,
            spline_weight_init_scale: float = 0.667,
    ) -> None:
        super().__init__()
        shared_kwargs = dict(
            grid_min=grid_min,
            grid_max=grid_max,
            num_grids=num_grids,
            exponent=exponent,
            denominator=denominator,
            use_base_update=use_base_update,
            base_activation=base_activation,
            spline_weight_init_scale=spline_weight_init_scale,
        )
        width_pairs = zip(layers_hidden[:-1], layers_hidden[1:])
        self.layers = nn.ModuleList([
            FasterKANLayer(in_dim, out_dim, **shared_kwargs)
            for in_dim, out_dim in width_pairs
        ])

    def forward(self, x):
        """Feed ``x`` through the layers in order and return the last output."""
        hidden = x
        for kan_layer in self.layers:
            hidden = kan_layer(hidden)
        return hidden
    
class kanBlock(nn.Module):
    """Pre-norm transformer-style block whose feed-forward part is a FasterKAN.

    ``forward`` takes and returns a ``(batch, tokens, dim)`` tensor and
    applies two residual branches: multi-head self-attention, then a
    two-layer FasterKAN (``dim -> hdim_kan -> dim``) applied to every token
    independently.

    Fix over the previous version: ``nn.MultiheadAttention`` defaults to
    ``(tokens, batch, dim)`` inputs, while this block is fed batch-first
    tensors. Attention therefore ran across the samples of a batch instead
    of across the tokens of one sample. ``batch_first=True`` makes the
    layout explicit; parameter names and shapes are unchanged, so existing
    checkpoints still load.

    ``mlp_ratio``, ``qkv_bias``, ``qk_scale`` and ``drop`` are accepted for
    signature compatibility but are not used. ``drop_path`` is implemented
    as ordinary dropout on each residual branch.
    """

    def __init__(self, dim, num_heads=8, hdim_kan=192, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = nn.MultiheadAttention(
            embed_dim=dim,
            num_heads=num_heads,
            dropout=attn_drop,
            batch_first=True,  # inputs are (batch, tokens, dim)
        )
        self.drop_path = nn.Identity() if drop_path <= 0. else nn.Dropout(drop_path)
        self.norm2 = norm_layer(dim)
        self.kan = FasterKAN(
            layers_hidden=[dim, hdim_kan, dim],
            grid_min=-2.,
            grid_max=2.,
            num_grids=8,
            exponent=2,
            denominator=0.33,
            use_base_update=True,
            base_activation=act_layer(),
            spline_weight_init_scale=0.1
        )

    def forward(self, x):
        """Apply the attention and KAN residual branches to ``x`` of shape (b, t, d)."""
        b, t, d = x.shape

        x_norm1 = self.norm1(x)
        # The attention weights are never used, so skip computing them.
        attn_output, _ = self.attn(x_norm1, x_norm1, x_norm1, need_weights=False)
        x = x + self.drop_path(attn_output)

        # FasterKAN works on 2-D input: fold batch and token axes together,
        # then restore them afterwards.
        x_norm2 = self.norm2(x).reshape(-1, d)
        kan_output = self.kan(x_norm2).reshape(b, t, d)
        x = x + self.drop_path(kan_output)

        return x

In [None]:
class DinoV2KAN(nn.Module):
    """Frozen DINOv2 backbone followed by a ``kanBlock`` and a linear classifier.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        # Pre-trained backbone; its parameters are excluded from training.
        self.dino = Dinov2Model.from_pretrained('facebook/dinov2-base')
        self.dino.requires_grad_(False)

        # Width of the backbone's hidden states, reused by the head below.
        hidden_size = self.dino.config.hidden_size

        self.kan_block = kanBlock(
            dim=hidden_size,
            num_heads=12,
            hdim_kan=768,
            mlp_ratio=4.,
            qkv_bias=False,
            qk_scale=None,
            drop=0.1,
            attn_drop=0.1,
            drop_path=0.1,
            act_layer=nn.GELU,
            norm_layer=nn.LayerNorm
        )

        self.classifier = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return class logits for a batch of images ``x``."""
        token_states = self.dino(x).last_hidden_state
        # Average over axis 1 of the backbone output to get one vector per image.
        pooled = token_states.mean(dim=1)

        # kanBlock expects (batch, tokens, dim): add a length-1 token axis
        # and remove it again afterwards.
        refined = self.kan_block(pooled.unsqueeze(1)).squeeze(1)

        return self.classifier(refined)

In [None]:
# Metric-learning components from pytorch_metric_learning, both created with
# library-default settings. They live at module level because run_folds()
# reads `miner` and `contrastive_loss` as globals when it builds the hybrid
# loss; there they are applied to the classifier outputs and labels.
contrastive_loss = losses.ContrastiveLoss()
miner = miners.MultiSimilarityMiner()

In [None]:
# Hyperparameters (read as globals by run_folds)
num_epochs = 100  # training epochs per fold
learning_rate = 0.0001  # Adam learning rate
betas = (0.9, 0.999)  # Adam betas
eps = 1e-07  # Adam eps
batch_size = 25  # NOTE: also assigned (same value) in the DataLoader preview cell
num_classes = 4  # matches the four entries of `classes` (AD, CN, MCI, EMCI)
k_folds = 10  # number of cross-validation splits
checkpoint_dir = '/kaggle/working/'  # where per-fold weights are written
random_seed = 42  # NOTE(review): only KFold below consumes this; no torch/numpy seeding is visible in the notebook

# Create directory for checkpoints if it doesn't exist
os.makedirs(checkpoint_dir, exist_ok=True)

# Initialize device: CUDA when available, otherwise CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize KFold with a fixed random seed so the shuffled splits are
# reproducible for a given dataset ordering
kf = KFold(n_splits=k_folds, shuffle=True, random_state=random_seed)

# Function to calculate specificity
def specificity(y_true, y_pred, num_classes):
    """Unweighted mean of the one-vs-rest specificity, TN / (TN + FP), per class.

    Args:
        y_true: True class indices.
        y_pred: Predicted class indices.
        num_classes: Classes ``0 .. num_classes - 1`` are scored.

    Returns:
        The average specificity over all classes. A class with no negatives
        (``TN + FP == 0``) contributes 0.
    """
    cm = confusion_matrix(y_true, y_pred, labels=list(range(num_classes)))

    true_pos = np.diag(cm)
    false_neg = cm.sum(axis=1) - true_pos  # rest of the class's row
    false_pos = cm.sum(axis=0) - true_pos  # rest of the class's column
    true_neg = cm.sum() - (true_pos + false_neg + false_pos)

    negatives = true_neg + false_pos
    per_class = np.zeros(num_classes, dtype=float)
    # Entries with zero negatives keep their initial value of 0.
    np.divide(true_neg, negatives, out=per_class, where=negatives != 0)

    return np.mean(per_class)


def _train_one_epoch(model, loader, optimizer, criterion_ce, alpha, desc):
    """Run one optimisation pass over ``loader``; return (mean hybrid loss, accuracy %)."""
    model.train()
    loss_sum, correct, seen = 0.0, 0, 0

    for inputs, labels in tqdm(loader, desc=desc, leave=False):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)

        # Hybrid loss: alpha * contrastive + (1 - alpha) * cross-entropy.
        # The miner and the contrastive loss both operate on the logits.
        loss_ce = criterion_ce(outputs, labels)
        miner_output = miner(outputs, labels)
        loss_contrastive = contrastive_loss(outputs, labels, miner_output)
        total_loss = alpha * loss_contrastive + (1 - alpha) * loss_ce
        total_loss.backward()
        optimizer.step()

        _, predicted = torch.max(outputs, 1)
        seen += labels.size(0)
        correct += (predicted == labels).sum().item()
        loss_sum += total_loss.item() * inputs.size(0)

    return loss_sum / max(seen, 1), 100.0 * correct / max(seen, 1)


def _evaluate(model, loader, criterion_ce, desc):
    """Score ``model`` on ``loader`` without gradients.

    Returns ``(mean CE loss, accuracy %, labels, predictions, probabilities)``
    where the last three are lists with one entry per evaluated sample.
    """
    model.eval()
    loss_sum, correct, seen = 0.0, 0, 0
    all_labels, all_predictions, all_probs = [], [], []

    with torch.no_grad():
        for inputs, labels in tqdm(loader, desc=desc, leave=False):
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            loss_ce = criterion_ce(outputs, labels)

            _, predicted = torch.max(outputs, 1)
            seen += labels.size(0)
            correct += (predicted == labels).sum().item()
            loss_sum += loss_ce.item() * inputs.size(0)

            all_labels.extend(labels.cpu().numpy())
            all_predictions.extend(predicted.cpu().numpy())
            all_probs.extend(torch.softmax(outputs, dim=1).cpu().numpy())

    # Normalise by the samples actually evaluated. len(loader.dataset) would
    # be the size of the full dataset, because the loader uses a subset sampler.
    mean_loss = loss_sum / max(seen, 1)
    accuracy = 100.0 * correct / max(seen, 1)
    return mean_loss, accuracy, all_labels, all_predictions, all_probs


def _plot_fold_roc(labels_onehot, probs, fold):
    """Draw one one-vs-rest ROC curve per class for a finished fold."""
    fig, ax = plt.subplots()
    for class_idx in range(labels_onehot.shape[1]):
        class_truth = labels_onehot[:, class_idx]
        if class_truth.min() == class_truth.max():
            # ROC is undefined when the fold has no positives or no negatives.
            continue
        fpr, tpr, _ = roc_curve(class_truth, probs[:, class_idx])
        class_auc = roc_auc_score(class_truth, probs[:, class_idx])
        ax.plot(fpr, tpr, label=f'Class {class_idx} (area = {class_auc:.4f})')

    ax.plot([0, 1], [0, 1], 'k--')
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title(f'ROC Curve for Fold {fold + 1}')
    ax.legend(loc='best')
    plt.show()
    plt.close(fig)


def run_folds(fold_start, fold_end):
    """Train and evaluate the cross-validation folds ``fold_start .. fold_end - 1``.

    For every selected fold a fresh ``DinoV2KAN`` is trained for
    ``num_epochs`` epochs with the hybrid contrastive / cross-entropy loss
    and evaluated on the fold's held-out slices after each epoch. After the
    last epoch the fold's metrics are printed, the weights are saved under
    ``checkpoint_dir`` and per-class ROC curves are plotted.

    Fixes over the previous version:
      * ``fold_start`` now skips the earlier splits; before, it only
        renumbered them, so ``run_folds(3, 6)`` trained splits 0-2.
      * The test loss is averaged over the evaluated samples instead of the
        whole dataset.
      * ROC curves are drawn one-vs-rest per class; ``roc_curve`` cannot
        take multiclass labels with a 2-D score matrix.
      * Per-epoch train/test loss and accuracy, which were computed but
        never shown, are printed.
      * ``num_epochs == 0`` no longer leaves the metrics undefined.

    NOTE(review): the splits are drawn over individual slices, so slices of
    one volume can land in both the training and the test part of a fold.
    NOTE(review): the saved file is named ``*_best_model.pth`` but holds the
    weights after the final epoch; no best-epoch selection is performed.

    Args:
        fold_start: Zero-based index of the first fold to run (inclusive).
        fold_end: Zero-based index at which to stop (exclusive).
    """
    alpha = 0.1  # Hybrid loss weight for contrastive loss
    class_ids = list(range(num_classes))

    for fold, (train_ids, test_ids) in enumerate(kf.split(extracted_slices_dataset)):
        if fold < fold_start:
            continue
        if fold >= fold_end:
            break

        print(f'Fold {fold + 1}/{k_folds}')

        # Create loaders restricted to this fold's indices
        train_loader = DataLoader(extracted_slices_dataset, batch_size=batch_size,
                                  sampler=SubsetRandomSampler(train_ids))
        test_loader = DataLoader(extracted_slices_dataset, batch_size=batch_size,
                                 sampler=SubsetRandomSampler(test_ids))

        # Initialize model, optimizer, loss function
        model = DinoV2KAN(num_classes=num_classes).to(device)
        optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=betas, eps=eps)
        criterion_ce = nn.CrossEntropyLoss()

        evaluation = None
        for epoch in range(num_epochs):
            train_loss, train_acc = _train_one_epoch(
                model, train_loader, optimizer, criterion_ce, alpha,
                desc=f'Epoch {epoch + 1}/{num_epochs} Training',
            )
            evaluation = _evaluate(model, test_loader, criterion_ce,
                                   desc=f'Test Epoch {epoch + 1}')
            print(f'Epoch {epoch + 1}/{num_epochs} | '
                  f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | '
                  f'Test Loss: {evaluation[0]:.4f} | Test Acc: {evaluation[1]:.2f}%')

        if evaluation is None:
            # num_epochs == 0: still report metrics for the untrained model.
            evaluation = _evaluate(model, test_loader, criterion_ce, desc='Test')

        test_loss, test_acc, all_labels, all_predictions, all_probs = evaluation
        probs = np.array(all_probs)

        # Compute additional metrics
        f1 = f1_score(all_labels, all_predictions, average='weighted')
        precision = precision_score(all_labels, all_predictions, average='weighted')
        recall = recall_score(all_labels, all_predictions, average='weighted')
        specificity_avg = specificity(all_labels, all_predictions, num_classes=num_classes)

        all_labels_binarized = label_binarize(all_labels, classes=class_ids)
        if all_labels_binarized.shape[1] == 1:
            # label_binarize collapses two classes into one column; expand it
            # so the columns line up with the probability matrix.
            all_labels_binarized = np.hstack([1 - all_labels_binarized, all_labels_binarized])
        roc_auc = roc_auc_score(all_labels_binarized, probs, multi_class="ovr", average="weighted")

        torch.save(model.state_dict(), os.path.join(checkpoint_dir, f'fold_{fold + 1}_best_model.pth'))

        print(f'\nMetrics for Fold {fold + 1}:')
        print(f'Test Accuracy: {test_acc:.2f}%')
        print(f'F1 Score: {f1:.4f} | Precision: {precision:.4f} | Recall: {recall:.4f}')
        print(f'Specificity: {specificity_avg:.4f} | ROC-AUC: {roc_auc:.4f}\n')

        # Plot ROC curves for the fold
        _plot_fold_roc(all_labels_binarized, probs, fold)

# Run every fold: indices 0 up to (but not including) 10, matching k_folds = 10.
run_folds(0, 10)
