# Alzheimer MRI 4-Class Detection using SAM as Feature Extractor

This notebook uses the Segment Anything Model (SAM) ViT backbone as a feature extractor and a lightweight classifier for Alzheimer MRI classification. Optimized for 16GB RAM.

In [1]:
pip install torchvision

Note: you may need to restart the kernel to use updated packages.


In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
from tqdm import tqdm
import sys

# Set device
# Run on the GPU when CUDA is usable, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Using device:', device)

# --- Dataset Setup ---
# One folder name per class; a name's position in the list is its integer label.
class_names = [
    'MildDemented',
    'ModerateDemented',
    'NonDemented',
    'VeryMildDemented',
]
class_to_idx = dict(zip(class_names, range(len(class_names))))

class AlzheimerMRIDataset(Dataset):
    """Dataset of images stored as ``root_dir/<class name>/<image file>``.

    Args:
        root_dir: Directory that contains one sub-directory per class.
        transform: Optional callable applied to every loaded RGB image.
        classes: Optional ordered sequence of class folder names; a name's
            position becomes its integer label. When omitted, the
            notebook-level ``class_names`` / ``class_to_idx`` are used.

    Raises:
        OSError: from ``os.listdir`` (e.g. ``FileNotFoundError``) if a class
            folder does not exist under ``root_dir``.
    """

    def __init__(self, root_dir, transform=None, classes=None):
        if classes is None:
            classes, label_of = class_names, class_to_idx
        else:
            label_of = {name: i for i, name in enumerate(classes)}
        self.samples = []  # list of (image path, integer label)
        for class_name in classes:
            class_dir = os.path.join(root_dir, class_name)
            # os.listdir returns entries in arbitrary order; sorting keeps the
            # sample indices (and any seeded split) stable between runs.
            for fname in sorted(os.listdir(class_dir)):
                if fname.lower().endswith(('.png', '.jpg', '.jpeg')):
                    self.samples.append((os.path.join(class_dir, fname), label_of[class_name]))
        self.transform = transform

    def __len__(self):
        """Return the number of image files that were found."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``; the transform is applied if one was given."""
        img_path, label = self.samples[idx]
        # Release the file handle as soon as the pixels are decoded; convert()
        # returns an independent image that outlives the handle.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label

# Data transforms
# SAM models are trained on 1024x1024 images, so inputs are resized to that native
# size even though it is costly in memory (hence the tiny batch size below).
# Mean/std are the ImageNet statistics for inputs scaled to [0, 1].
img_size = 1024
transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Split dataset (simple 80/20 split for demo)
# IMPORTANT: Replace this path with the actual path to your dataset
dataset = AlzheimerMRIDataset(r'C:\\Users\\shrir\\Music\\New folder\\Alzheimer_MRI_4_classes_dataset', transform=transform)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
# Seed the split: without a fixed generator every re-run of this cell reshuffles
# which images land in train vs. validation, so results are not comparable and a
# previously trained classifier may be validated on images it was trained on.
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42)
)

# Change num_workers to 0 for CPU to avoid multiprocessing issues
# Reduced batch_size significantly due to memory constraints on 16GB RAM for 1024x1024 images
batch_size = 1 # Start with batch size 1 or 2 due to severe memory limitations
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

# --- Load SAM ViT Backbone as Feature Extractor ---
sys.path.append('.')  # Ensure SAM code can be found if in current directory

# IMPORTANT: You need to download sam_vit_b_01ec64.pth for this to work
# It's a smaller model suitable for memory-constrained environments.
# Download from: https://github.com/facebookresearch/segment-anything#model-checkpoints
sam_checkpoint = r'C:\Users\shrir\Music\New folder\sam_vit_b_01ec64.pth' # Updated path
try:
    from segment_anything import sam_model_registry
    sam = sam_model_registry['vit_b'](checkpoint=sam_checkpoint) # Changed to 'vit_b'
except FileNotFoundError:
    print(f"Error: SAM checkpoint '{sam_checkpoint}' not found.")
    print("Please ensure the file path is correct.")
    sys.exit(1) # Exit if SAM checkpoint is not found
except ImportError:
    print("Error: 'segment_anything' library not found.")
    print("Please install it: pip install git+https://github.com/facebookresearch/segment-anything.git")
    sys.exit(1)

# SAM is used purely as a fixed feature extractor: inference mode, on the chosen device.
sam.eval()
sam.to(device)

# Freeze SAM parameters
for param in sam.parameters():
    param.requires_grad = False

# Use only the image encoder as feature extractor
@torch.no_grad()
def extract_features(images):
    """Encode a batch with SAM's image encoder and average over dims 2 and 3.

    Gradient tracking is disabled for the whole call. The input is assumed to be
    an image batch already on the same device as ``sam`` (TODO confirm with callers).
    Returns one pooled feature vector per input image.
    """
    feature_map = sam.image_encoder(images)
    # Global average pooling to get a fixed-size feature vector
    return feature_map.mean(dim=(2, 3))

# --- Classifier on top of SAM features ---
# THIS CLASS DEFINITION MUST BE HERE BEFORE IT'S USED
class SAMClassifier(nn.Module):
    """Small MLP head: Linear(feat_dim, 256) -> ReLU -> Dropout(0.3) -> Linear(256, num_classes)."""

    def __init__(self, feat_dim, num_classes):
        super().__init__()
        hidden_dim = 256
        layers = [
            nn.Linear(feat_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(hidden_dim, num_classes),
        ]
        self.classifier = nn.Sequential(*layers)

    def forward(self, x):
        """Map feature vectors ``x`` to unnormalised class scores."""
        logits = self.classifier(x)
        return logits

# Feature dimension for ViT-B is 256.
feat_dim = 256
num_classes = 4
model = SAMClassifier(feat_dim, num_classes).to(device)

# --- Training Loop ---
# Only the classifier head is optimised; SAM features are computed without gradients.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

num_epochs = 5  # Increase for better results
print(f"Starting training with batch_size={batch_size}, img_size={img_size}, feat_dim={feat_dim}")
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    train_seen = 0  # samples that actually contributed to running_loss
    for images, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}'):
        images, labels = images.to(device), labels.to(device)
        try:
            feats = extract_features(images)
            outputs = model(feats)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # criterion returns the batch mean, so weight it by the batch size
            running_loss += loss.item() * images.size(0)
            train_seen += images.size(0)
        except RuntimeError as e:
            print(f"RuntimeError during training batch: {e}")
            print(f"Image shape: {images.shape}, Labels shape: {labels.shape}")
            continue # Skip this batch if an error occurs

    # Average over the samples that were really processed: dividing by the full
    # dataset size would under-report the loss whenever a batch was skipped above.
    if train_seen > 0:
        print(f'Train Loss: {running_loss/train_seen:.4f}')
    else:
        print('Train Loss: N/A (no training samples were processed)')


    # Validation
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            try:
                feats = extract_features(images)
                outputs = model(feats)
                _, preds = torch.max(outputs, 1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)
            except RuntimeError as e:
                print(f"RuntimeError during validation batch: {e}")
                print(f"Image shape: {images.shape}, Labels shape: {labels.shape}")
                continue # Skip this batch if an error occurs
    if total > 0: # Avoid division by zero
        print(f'Val Accuracy: {100*correct/total:.2f}%')
    else:
        print('Val Accuracy: N/A (Empty validation dataset)')

# --- Inference Function ---
def predict_image(image_path):
    """Classify a single image file with the trained classifier head.

    Args:
        image_path: Path of an image file that PIL can open.

    Returns:
        Tuple ``(class_name, confidence)``: the predicted entry of ``class_names``
        and its softmax probability as a Python float.
    """
    # Dropout must be disabled for inference; the model can still be in train
    # mode if training was interrupted in the middle of an epoch.
    model.eval()
    # Close the file handle once the pixels are decoded.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    image = transform(image).unsqueeze(0).to(device)  # add a batch dimension of 1
    with torch.no_grad():
        feats = extract_features(image)
        logits = model(feats)
        probs = torch.softmax(logits, dim=1)
        pred = torch.argmax(probs, dim=1).item()
        confidence = probs[0, pred].item()
    return class_names[pred], confidence

# Example usage (uncomment to test):
# try:
#     # Make sure to replace with an actual image path from your dataset
#     example_image_path = r'C:\Users\shrir\Music\New folder\Alzheimer_MRI_4_classes_dataset\NonDemented\example.jpg'
#     if os.path.exists(example_image_path):
#         pred_class, conf = predict_image(example_image_path)
#         print(f'Predicted: {pred_class} (Confidence: {conf:.2f})')
#     else:
#         print(f"Example image not found at {example_image_path}. Please provide a valid path.")
# except Exception as e:
#     print(f"Error during inference example: {e}")

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
from tqdm import tqdm
import sys

# Set device
# Prefer CUDA when it is available; fall back to the CPU otherwise.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Using device:', device)

# --- Dataset Setup ---
# Class folder names; the list index of a name is used as its label.
class_names = ['MildDemented', 'ModerateDemented', 'NonDemented', 'VeryMildDemented']
class_to_idx = {class_names[i]: i for i in range(len(class_names))}

class AlzheimerMRIDataset(Dataset):
    """Dataset of images stored as ``root_dir/<class name>/<image file>``.

    Args:
        root_dir: Directory that contains one sub-directory per class.
        transform: Optional callable applied to every loaded RGB image.
        classes: Optional ordered sequence of class folder names; a name's
            position becomes its integer label. When omitted, the
            notebook-level ``class_names`` / ``class_to_idx`` are used.

    Raises:
        OSError: from ``os.listdir`` (e.g. ``FileNotFoundError``) if a class
            folder does not exist under ``root_dir``.
    """

    def __init__(self, root_dir, transform=None, classes=None):
        if classes is None:
            classes, label_of = class_names, class_to_idx
        else:
            label_of = {name: i for i, name in enumerate(classes)}
        self.samples = []  # list of (image path, integer label)
        for class_name in classes:
            class_dir = os.path.join(root_dir, class_name)
            # os.listdir returns entries in arbitrary order; sorting keeps the
            # sample indices (and any seeded split) stable between runs.
            for fname in sorted(os.listdir(class_dir)):
                if fname.lower().endswith(('.png', '.jpg', '.jpeg')):
                    self.samples.append((os.path.join(class_dir, fname), label_of[class_name]))
        self.transform = transform

    def __len__(self):
        """Return the number of image files that were found."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``; the transform is applied if one was given."""
        img_path, label = self.samples[idx]
        # Release the file handle as soon as the pixels are decoded; convert()
        # returns an independent image that outlives the handle.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label

# Data transforms
# Images are resized to 1024x1024 before being passed to SAM's image encoder.
img_size = 1024 # Changed from 224 to 1024
transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Split dataset (simple 80/20 split for demo)
# IMPORTANT: Replace this path with the actual path to your dataset
dataset = AlzheimerMRIDataset(r'C:\\Users\\shrir\\Music\\New folder\\Alzheimer_MRI_4_classes_dataset', transform=transform) 
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
# Seed the split: without a fixed generator every re-run of this cell reshuffles
# which images land in train vs. validation, so results are not comparable and a
# previously trained classifier may be validated on images it was trained on.
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42)
)

# Change num_workers from 2 to 0 to address potential hanging issues, especially on Windows
# Change batch_size to a smaller value
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=0) # Changed from 16 to 2
val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False, num_workers=0)   # Changed from 16 to 2

# --- Load SAM ViT Backbone as Feature Extractor ---
sys.path.append('.')  # Ensure SAM code can be found if in current directory

# You need to have the SAM model code available (e.g., cloned from https://github.com/facebookresearch/segment-anything)
# and the pre-trained SAM checkpoint file ('sam_vit_h_4b8939.pth') in the same directory or specified path.
from segment_anything import sam_model_registry

sam_checkpoint = 'sam_vit_h_4b8939.pth'
try:
    sam = sam_model_registry['vit_h'](checkpoint=sam_checkpoint)
except FileNotFoundError:
    print(f"Error: SAM checkpoint '{sam_checkpoint}' not found.")
    print("Please download it from https://github.com/facebookresearch/segment-anything/blob/main/notebooks/predictor_example.ipynb (under 'Download the SAM checkpoint')")
    print("Or ensure 'segment_anything' library and the checkpoint are in the correct path.")
    sys.exit(1) # Exit if SAM checkpoint is not found

# SAM is used purely as a fixed feature extractor: inference mode, on the chosen device.
sam.eval()
sam.to(device)

# Freeze SAM parameters
for param in sam.parameters():
    param.requires_grad = False

# Use only the image encoder as feature extractor
def extract_features(images):
    """Return SAM image-encoder features averaged over dims 2 and 3, computed without gradients.

    Assumes ``images`` is a batch tensor on the same device as ``sam`` (TODO confirm).
    """
    with torch.no_grad():
        # Global average pooling to get a fixed-size feature vector
        return torch.mean(sam.image_encoder(images), dim=(2, 3))

# --- Classifier on top of SAM features ---
class SAMClassifier(nn.Module):
    """Classification head applied to pooled SAM features.

    Layers, in order: Linear(feat_dim -> 256), ReLU, Dropout(p=0.3),
    Linear(256 -> num_classes).
    """

    def __init__(self, feat_dim, num_classes):
        super().__init__()
        head = nn.Sequential(
            nn.Linear(in_features=feat_dim, out_features=256),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(in_features=256, out_features=num_classes),
        )
        self.classifier = head

    def forward(self, x):
        """Return the class scores produced by the head for features ``x``."""
        return self.classifier(x)

# SAM's image encoder ends in a neck that projects to 256 channels for every
# backbone size (ViT-B/L/H), so the pooled feature vector is 256-d. The previous
# value of 1024 made the first nn.Linear reject the features with a shape error.
feat_dim = 256
num_classes = 4
model = SAMClassifier(feat_dim, num_classes).to(device)

# --- Training Loop ---
# Only the classifier head is optimised; SAM features are computed without gradients.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

num_epochs = 5  # Increase for better results
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for images, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}'):
        images, labels = images.to(device), labels.to(device)
        feats = extract_features(images)
        outputs = model(feats)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # criterion returns the batch mean, so weight it by the batch size
        running_loss += loss.item() * images.size(0)
    print(f'Train Loss: {running_loss/len(train_loader.dataset):.4f}')

    # Validation
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            feats = extract_features(images)
            outputs = model(feats)
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    print(f'Val Accuracy: {100*correct/total:.2f}%')

# --- Inference Function ---
def predict_image(image_path):
    """Classify a single image file with the trained classifier head.

    Args:
        image_path: Path of an image file that PIL can open.

    Returns:
        Tuple ``(class_name, confidence)``: the predicted entry of ``class_names``
        and its softmax probability as a Python float.
    """
    # Dropout must be disabled for inference; the model can still be in train
    # mode if training was interrupted in the middle of an epoch.
    model.eval()
    # Close the file handle once the pixels are decoded.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    image = transform(image).unsqueeze(0).to(device)  # add a batch dimension of 1
    with torch.no_grad():
        feats = extract_features(image)
        logits = model(feats)
        probs = torch.softmax(logits, dim=1)
        pred = torch.argmax(probs, dim=1).item()
        confidence = probs[0, pred].item()
    return class_names[pred], confidence

# Example usage (uncomment to test):
# pred_class, conf = predict_image('Alzheimer_MRI_4_classes_dataset/NonDemented/example.jpg')
# print(f'Predicted: {pred_class} (Confidence: {conf:.2f})')

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
from tqdm import tqdm

# Set device
# Map the CUDA availability flag straight to the device name.
device = torch.device({True: 'cuda', False: 'cpu'}[torch.cuda.is_available()])
print('Using device:', device)

In [None]:
# --- Dataset Setup ---
# Class folder names, in label order (index in this list == integer label).
class_names = [
    'MildDemented',
    'ModerateDemented',
    'NonDemented',
    'VeryMildDemented',
]
class_to_idx = dict((name, idx) for idx, name in enumerate(class_names))

class AlzheimerMRIDataset(Dataset):
    """Dataset of images stored as ``root_dir/<class name>/<image file>``.

    Args:
        root_dir: Directory that contains one sub-directory per class.
        transform: Optional callable applied to every loaded RGB image.
        classes: Optional ordered sequence of class folder names; a name's
            position becomes its integer label. When omitted, the
            notebook-level ``class_names`` / ``class_to_idx`` are used.

    Raises:
        OSError: from ``os.listdir`` (e.g. ``FileNotFoundError``) if a class
            folder does not exist under ``root_dir``.
    """

    def __init__(self, root_dir, transform=None, classes=None):
        if classes is None:
            classes, label_of = class_names, class_to_idx
        else:
            label_of = {name: i for i, name in enumerate(classes)}
        self.samples = []  # list of (image path, integer label)
        for class_name in classes:
            class_dir = os.path.join(root_dir, class_name)
            # os.listdir returns entries in arbitrary order; sorting keeps the
            # sample indices (and any seeded split) stable between runs.
            for fname in sorted(os.listdir(class_dir)):
                if fname.lower().endswith(('.png', '.jpg', '.jpeg')):
                    self.samples.append((os.path.join(class_dir, fname), label_of[class_name]))
        self.transform = transform

    def __len__(self):
        """Return the number of image files that were found."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``; the transform is applied if one was given."""
        img_path, label = self.samples[idx]
        # Release the file handle as soon as the pixels are decoded; convert()
        # returns an independent image that outlives the handle.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label

# Data transforms
# SAM's image encoder is built with a fixed positional embedding for 1024x1024
# inputs; a 224x224 tensor fails inside the encoder with a shape mismatch, so the
# images must be resized to 1024.
img_size = 1024
transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Split dataset (simple 80/20 split for demo)
dataset = AlzheimerMRIDataset(r'C:\\Users\\shrir\\Music\\New folder\\Alzheimer_MRI_4_classes_dataset', transform=transform)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
# Seed the split: without a fixed generator every re-run of this cell reshuffles
# which images land in train vs. validation.
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42)
)

# Change num_workers from 2 to 0
# Batch size lowered from 16 to 2: 1024x1024 inputs make large batches too memory-hungry.
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False, num_workers=0)

In [None]:
# --- Dataset Setup ---
# Folder name of each class; labels are assigned by position in this list.
class_names = ['MildDemented', 'ModerateDemented', 'NonDemented', 'VeryMildDemented']
class_to_idx = dict(zip(class_names, range(len(class_names))))

class AlzheimerMRIDataset(Dataset):
    """Dataset of images stored as ``root_dir/<class name>/<image file>``.

    Args:
        root_dir: Directory that contains one sub-directory per class.
        transform: Optional callable applied to every loaded RGB image.
        classes: Optional ordered sequence of class folder names; a name's
            position becomes its integer label. When omitted, the
            notebook-level ``class_names`` / ``class_to_idx`` are used.

    Raises:
        OSError: from ``os.listdir`` (e.g. ``FileNotFoundError``) if a class
            folder does not exist under ``root_dir``.
    """

    def __init__(self, root_dir, transform=None, classes=None):
        if classes is None:
            classes, label_of = class_names, class_to_idx
        else:
            label_of = {name: i for i, name in enumerate(classes)}
        self.samples = []  # list of (image path, integer label)
        for class_name in classes:
            class_dir = os.path.join(root_dir, class_name)
            # os.listdir returns entries in arbitrary order; sorting keeps the
            # sample indices (and any seeded split) stable between runs.
            for fname in sorted(os.listdir(class_dir)):
                if fname.lower().endswith(('.png', '.jpg', '.jpeg')):
                    self.samples.append((os.path.join(class_dir, fname), label_of[class_name]))
        self.transform = transform

    def __len__(self):
        """Return the number of image files that were found."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``; the transform is applied if one was given."""
        img_path, label = self.samples[idx]
        # Release the file handle as soon as the pixels are decoded; convert()
        # returns an independent image that outlives the handle.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label

# Data transforms
# SAM's image encoder is built with a fixed positional embedding for 1024x1024
# inputs; a 224x224 tensor fails inside the encoder with a shape mismatch, so the
# images must be resized to 1024.
img_size = 1024
transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Split dataset (simple 80/20 split for demo)
dataset = AlzheimerMRIDataset(r'C:\Users\shrir\Music\New folder\Alzheimer_MRI_4_classes_dataset', transform=transform)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
# Seed the split: without a fixed generator every re-run of this cell reshuffles
# which images land in train vs. validation.
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42)
)

# num_workers=0: loader worker processes were a source of hangs with this
# notebook-defined dataset, especially on Windows.
# Batch size lowered from 16 to 2: 1024x1024 inputs make large batches too memory-hungry.
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False, num_workers=0)

In [None]:
# --- Load SAM ViT Backbone as Feature Extractor ---
import sys
sys.path.append('.')  # If SAM code is in the current directory

# You need to have the SAM model code available, e.g. from https://github.com/facebookresearch/segment-anything
from segment_anything import sam_model_registry

# Build the ViT-H variant from its checkpoint file, then move it to the chosen
# device and switch it to inference mode.
sam_checkpoint = 'sam_vit_h_4b8939.pth'
sam = sam_model_registry['vit_h'](checkpoint=sam_checkpoint)
sam = sam.to(device).eval()

# Freeze SAM parameters: no gradients are kept for any SAM weight.
for param in sam.parameters():
    param.requires_grad = False

# Use only the image encoder as feature extractor
def extract_features(images):
    """Run ``images`` through SAM's image encoder (no gradients) and average dims 2 and 3.

    Assumes ``images`` is a batch tensor on the same device as ``sam`` (TODO confirm).
    """
    encoder = sam.image_encoder
    with torch.no_grad():
        spatial_feats = encoder(images)
        # Global average pooling
        pooled = spatial_feats.mean(dim=(2, 3))
    return pooled


In [None]:
# --- Classifier on top of SAM features ---
class SAMClassifier(nn.Module):
    """MLP head over pooled features: feat_dim -> 256 -> num_classes, with ReLU and Dropout(0.3) in between."""

    def __init__(self, feat_dim, num_classes):
        super().__init__()
        hidden_units = 256
        drop_prob = 0.3
        input_layer = nn.Linear(feat_dim, hidden_units)
        output_layer = nn.Linear(hidden_units, num_classes)
        self.classifier = nn.Sequential(
            input_layer,
            nn.ReLU(),
            nn.Dropout(drop_prob),
            output_layer,
        )

    def forward(self, x):
        """Return class scores for the feature batch ``x``."""
        scores = self.classifier(x)
        return scores

# SAM's image encoder ends in a neck that projects to 256 channels for every
# backbone size (ViT-B/L/H), so the pooled feature vector is 256-d. The previous
# value of 1024 made the first nn.Linear reject the features with a shape error.
feat_dim = 256
num_classes = 4
model = SAMClassifier(feat_dim, num_classes).to(device)

In [None]:
# --- Training Loop ---
# Cross-entropy on the classifier head; Adam only sees the head's parameters.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

num_epochs = 5  # Increase for better results
for epoch in range(num_epochs):
    # ---- training pass ----
    model.train()
    running_loss = 0.0
    progress = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')
    for images, labels in progress:
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        feats = extract_features(images)
        outputs = model(feats)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # criterion returns the batch mean, so weight it by the batch size
        running_loss += loss.item() * images.size(0)
    print(f'Train Loss: {running_loss/len(train_loader.dataset):.4f}')
    # ---- validation pass ----
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device)
            labels = labels.to(device)
            feats = extract_features(images)
            outputs = model(feats)
            _, preds = outputs.max(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    print(f'Val Accuracy: {100*correct/total:.2f}%')

In [None]:
# --- Inference Function ---
def predict_image(image_path):
    """Classify a single image file with the trained classifier head.

    Args:
        image_path: Path of an image file that PIL can open.

    Returns:
        Tuple ``(class_name, confidence)``: the predicted entry of ``class_names``
        and its softmax probability as a Python float.
    """
    # Dropout must be disabled for inference; the model can still be in train
    # mode if training was interrupted in the middle of an epoch.
    model.eval()
    # Close the file handle once the pixels are decoded.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    image = transform(image).unsqueeze(0).to(device)  # add a batch dimension of 1
    with torch.no_grad():
        feats = extract_features(image)
        logits = model(feats)
        probs = torch.softmax(logits, dim=1)
        pred = torch.argmax(probs, dim=1).item()
        confidence = probs[0, pred].item()
    return class_names[pred], confidence

# Example usage:
# pred_class, conf = predict_image('Alzheimer_MRI_4_classes_dataset/NonDemented/example.jpg')
# print(f'Predicted: {pred_class} (Confidence: {conf:.2f})')