# MRI Embedding Extraction

This notebook extracts embeddings from MRI data using a pretrained CNN backbone.

**Data sources supported:**
- Kaggle 2D MRI image datasets (JPG/PNG)
- NACC volumetric MRI via S3 (NIfTI format)

**Output:** Compressed `.npz` files with float16 embeddings (plus a `labels.csv` for the Kaggle images), saved to Google Drive.

---

**Instructions:**
1. Select GPU runtime: Runtime → Change runtime type → T4 GPU
2. Run all cells in order
3. Embeddings will be saved to your Google Drive

In [None]:
# Mount Google Drive and clone the project repo
from google.colab import drive
drive.mount('/content/drive')

import os
PROJECT_DIR = '/content/drive/MyDrive/alzheimer-research'
os.makedirs(PROJECT_DIR, exist_ok=True)

# Clone repo if not already present
REPO_DIR = '/content/alzheimer-research'
if not os.path.exists(REPO_DIR):
    # Replace with your actual repo URL
    !git clone https://github.com/YOUR_USERNAME/alzheimer-research.git {REPO_DIR}
else:
    !cd {REPO_DIR} && git pull

os.chdir(REPO_DIR)
!pip install -q -r requirements.txt

In [None]:
import sys
sys.path.insert(0, REPO_DIR)

import numpy as np
import torch
from pathlib import Path
from tqdm import tqdm

# Select the compute device once; prefer the GPU when CUDA is usable.
_cuda_ok = torch.cuda.is_available()
device = torch.device('cuda') if _cuda_ok else torch.device('cpu')
print(f'Device: {device}')
if _cuda_ok:
    print(f'GPU: {torch.cuda.get_device_name(0)}')

## Option A: Kaggle 2D MRI Images

Download the Kaggle Alzheimer's multiclass dataset and extract embeddings.

In [None]:
# Download Kaggle dataset (set up kaggle.json first)
# !pip install -q kaggle
# !mkdir -p ~/.kaggle && cp /content/drive/MyDrive/kaggle.json ~/.kaggle/
# !chmod 600 ~/.kaggle/kaggle.json
# !kaggle datasets download -d aryansinghal10/alzheimers-multiclass-dataset-equal-and-augmented -p /tmp/mri_data
# !unzip -q /tmp/mri_data/*.zip -d /tmp/mri_data/

# OR: manually upload / specify the path to your data directory
MRI_DATA_DIR = '/tmp/mri_data'  # Update this path

# Expected structure (one sub-folder per class):
# /tmp/mri_data/
#   NonDemented/
#   VeryMildDemented/
#   MildDemented/
#   ModerateDemented/

import os

# Report how many entries each class folder holds, or explain what is missing.
if not os.path.exists(MRI_DATA_DIR):
    print(f'Data directory not found: {MRI_DATA_DIR}')
    print('Please download the dataset first (uncomment the kaggle commands above).')
else:
    for entry in sorted(os.listdir(MRI_DATA_DIR)):
        entry_path = os.path.join(MRI_DATA_DIR, entry)
        if not os.path.isdir(entry_path):
            continue  # only sub-folders are summarised
        n_entries = len(os.listdir(entry_path))
        print(f'{entry}: {n_entries} images')

In [None]:
from models.mri_cnn import MRIResNet2D
from torchvision import transforms
from PIL import Image
import pandas as pd

# Label mapping: folder name -> ordinal CDR class.
# Folders are listed in increasing class order, so the position is the label.
LABEL_MAP = {
    folder: ordinal
    for ordinal, folder in enumerate(
        ['NonDemented', 'VeryMildDemented', 'MildDemented', 'ModerateDemented']
    )
}

# Image transform for 2D ResNet: fixed-size, single-channel, scaled with
# mean 0.5 / std 0.5.
_input_size = (224, 224)
transform = transforms.Compose([
    transforms.Resize(_input_size),
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

# Load model in inference mode on the selected device
EMBED_DIM = 256
model = MRIResNet2D(embed_dim=EMBED_DIM, pretrained=True)
model = model.to(device)
model.eval()
print(f'MRI ResNet-2D loaded (embed_dim={EMBED_DIM})')

In [None]:
# Extract embeddings from 2D MRI images.
#
# Walks every class folder named in LABEL_MAP, runs the images through
# `model.extract_embedding` in batches, and collects:
#   all_embeddings - list of per-batch numpy arrays
#   all_labels     - integer label per successfully loaded image
#   all_filenames  - file name per successfully loaded image
# `embeddings` / `labels` are the stacked numpy versions of the first two.
all_embeddings = []
all_labels = []
all_filenames = []

batch_size = 64  # images per forward pass


def _embed_batch(image_tensors):
    """Stack image tensors, run them through the model, return a numpy array."""
    batch = torch.stack(image_tensors).to(device)
    with torch.no_grad():
        embs = model.extract_embedding(batch)
    return embs.cpu().numpy()


for class_name, label in LABEL_MAP.items():
    class_dir = os.path.join(MRI_DATA_DIR, class_name)
    if not os.path.exists(class_dir):
        print(f'Skipping {class_name}: directory not found')
        continue

    files = sorted(os.listdir(class_dir))
    print(f'Processing {class_name} ({len(files)} images)...')

    batch_images = []

    for fname in tqdm(files, desc=class_name):
        fpath = os.path.join(class_dir, fname)
        try:
            # Context manager releases the file handle as soon as the pixels
            # have been converted.
            with Image.open(fpath) as raw:
                img_tensor = transform(raw.convert('L'))  # grayscale
        except Exception as e:
            # Unreadable entries are reported and skipped; nothing is recorded
            # for them, so embeddings/labels/filenames stay aligned.
            print(f'Error loading {fname}: {e}')
            continue

        batch_images.append(img_tensor)
        all_labels.append(label)
        all_filenames.append(fname)

        # Process in batches
        if len(batch_images) >= batch_size:
            all_embeddings.append(_embed_batch(batch_images))
            batch_images = []

    # Process remaining images of this class (last, possibly partial, batch)
    if batch_images:
        all_embeddings.append(_embed_batch(batch_images))
        batch_images = []

# Fail with an actionable message instead of numpy's
# "need at least one array to concatenate".
if not all_embeddings:
    raise RuntimeError(
        f'No images could be embedded from {MRI_DATA_DIR}. '
        'Check the path and that the class folders match LABEL_MAP.'
    )

embeddings = np.concatenate(all_embeddings, axis=0)
labels = np.array(all_labels)
assert len(embeddings) == len(labels) == len(all_filenames), \
    'embeddings, labels and filenames are out of sync'

print(f'\nExtracted {len(embeddings)} embeddings of dimension {embeddings.shape[1]}')
for name, lbl in LABEL_MAP.items():
    print(f'  {name}: {(labels == lbl).sum()}')

In [None]:
# Save compressed embeddings to Google Drive
SAVE_DIR = os.path.join(PROJECT_DIR, 'data_embeddings')
os.makedirs(SAVE_DIR, exist_ok=True)

# Save embeddings as float16 for compression
emb_path = os.path.join(SAVE_DIR, 'mri_embeddings.npz')
np.savez_compressed(emb_path, embeddings=embeddings.astype(np.float16))
print(f'Saved MRI embeddings to {emb_path}')
print(f'File size: {os.path.getsize(emb_path) / 1024 / 1024:.2f} MB')

# Save labels. Rows are in the same order as the rows of `embeddings`.
labels_path = os.path.join(SAVE_DIR, 'labels.csv')

# Look class names up by label value rather than by position in LABEL_MAP,
# so the CSV stays correct if the mapping is reordered or made non-contiguous.
label_to_class = {value: name for name, value in LABEL_MAP.items()}

df = pd.DataFrame({
    'filename': all_filenames,
    'label': labels,
    'class_name': [label_to_class[int(lbl)] for lbl in labels],
})
df.to_csv(labels_path, index=False)
print(f'Saved labels to {labels_path}')

## Option B: NACC 3D MRI Volumes via S3

Download NIfTI volumes from NACC S3 bucket and extract 3D embeddings.

**Important:** Set your S3 credentials as environment variables, never hardcode them.

In [None]:
# Install AWS CLI and nibabel for NIfTI loading
!pip install -q boto3 nibabel

import boto3
from getpass import getpass

# Securely input credentials (DO NOT hardcode)
AWS_ACCESS_KEY = getpass('Enter AWS Access Key ID: ')
AWS_SECRET_KEY = getpass('Enter AWS Secret Access Key: ')
S3_BUCKET = 'adsp-phc-quickaccess'
S3_PREFIX = 'investigator/'

s3 = boto3.client(
    's3',
    aws_access_key_id=AWS_ACCESS_KEY,
    aws_secret_access_key=AWS_SECRET_KEY,
)

# List available files
response = s3.list_objects_v2(Bucket=S3_BUCKET, Prefix=S3_PREFIX, MaxKeys=20)
if 'Contents' in response:
    for obj in response['Contents']:
        print(f"{obj['Key']}  ({obj['Size'] / 1024 / 1024:.1f} MB)")
else:
    print('No objects found. Check credentials and bucket/prefix.')

In [None]:
import nibabel as nib
from scipy.ndimage import zoom
from models.mri_cnn import MRIResNet3D

# Load 3D model.
# EMBED_DIM is normally set by the 2D model cell; define the same value here
# when that cell was skipped so this section can run on its own.
if 'EMBED_DIM' not in globals():
    EMBED_DIM = 256

model_3d = MRIResNet3D(embed_dim=EMBED_DIM, pretrained=True).to(device)
model_3d.eval()

TARGET_SHAPE = (64, 128, 128)  # (D, H, W) for 3D ResNet input

def preprocess_nifti(nifti_path):
    """Load a NIfTI MRI volume and turn it into a model-ready tensor.

    Steps: load voxel data as float32, drop trailing singleton axes, zero out
    non-finite voxels, min-max normalise to [0, 1], and linearly resample to
    ``TARGET_SHAPE``.

    Args:
        nifti_path: Path to a ``.nii`` / ``.nii.gz`` file.

    Returns:
        A float32 ``torch.Tensor`` of shape ``(1, 1, *TARGET_SHAPE)``
        (batch and channel axes prepended).

    Raises:
        ValueError: If the volume is not 3-D after trailing singleton axes
            have been removed (e.g. a multi-frame 4-D series).
    """
    img = nib.load(nifti_path)
    data = img.get_fdata().astype(np.float32)

    # Files are often stored as (X, Y, Z, 1); strip such trailing axes so the
    # array rank matches TARGET_SHAPE, which `zoom` requires.
    while data.ndim > 3 and data.shape[-1] == 1:
        data = data[..., 0]
    if data.ndim != 3:
        raise ValueError(
            f'Expected a 3D volume, got shape {data.shape} for {nifti_path}'
        )

    # NaN/Inf voxels would turn the min/max below (and thus the whole volume)
    # into NaN, so replace them with 0 first.
    finite_mask = np.isfinite(data)
    if not finite_mask.all():
        data = np.where(finite_mask, data, 0.0).astype(np.float32)

    # Normalize to [0, 1]; the epsilon guards against a constant volume.
    data = (data - data.min()) / (data.max() - data.min() + 1e-8)

    # Resize to target shape with linear interpolation.
    # NOTE(review): array axes are mapped to TARGET_SHAPE positionally, i.e.
    # the first stored axis is treated as D. Confirm this matches the
    # orientation of the input files and what the 3D model expects.
    factors = [t / s for t, s in zip(TARGET_SHAPE, data.shape)]
    data = zoom(data, factors, order=1)

    # Add batch and channel dims: (1, 1, D, H, W)
    tensor = torch.tensor(data).unsqueeze(0).unsqueeze(0)
    return tensor

print('3D MRI model ready.')

In [None]:
# Download and process NACC MRI volumes.
# Volumes are downloaded one at a time and deleted right after use to
# minimize disk usage.

TEMP_DIR = '/tmp/nacc_mri'
os.makedirs(TEMP_DIR, exist_ok=True)

# Output folder on Drive. Defined here as well so this cell does not depend on
# the 2D save cell having been run first.
SAVE_DIR = os.path.join(PROJECT_DIR, 'data_embeddings')
os.makedirs(SAVE_DIR, exist_ok=True)

MAX_VOLUMES = 100  # Limit for demo

# List NIfTI files in the S3 bucket (paginated, so large listings are complete)
paginator = s3.get_paginator('list_objects_v2')
nifti_keys = []
for page in paginator.paginate(Bucket=S3_BUCKET, Prefix=S3_PREFIX):
    for obj in page.get('Contents', []):
        if obj['Key'].endswith(('.nii', '.nii.gz')):
            nifti_keys.append(obj['Key'])

print(f'Found {len(nifti_keys)} NIfTI files')

nacc_embeddings = []
nacc_ids = []  # file name (S3 key basename) of each embedded volume

for key in tqdm(nifti_keys[:MAX_VOLUMES], desc='Processing NACC MRI'):
    local_path = os.path.join(TEMP_DIR, os.path.basename(key))
    try:
        # Download
        s3.download_file(S3_BUCKET, key, local_path)

        # Preprocess and extract embedding
        volume = preprocess_nifti(local_path).to(device)
        with torch.no_grad():
            emb = model_3d.extract_embedding(volume)
        nacc_embeddings.append(emb.cpu().numpy())
        nacc_ids.append(os.path.basename(key))
    except Exception as e:
        # One bad volume should not abort the run; report it and move on.
        print(f'Error processing {key}: {e}')
    finally:
        # Delete the raw file whether or not processing succeeded
        if os.path.exists(local_path):
            os.remove(local_path)

if nacc_embeddings:
    nacc_emb_array = np.concatenate(nacc_embeddings, axis=0)
    print(f'Extracted {len(nacc_emb_array)} NACC embeddings')

    # Save embeddings (float16) together with the matching file names
    nacc_path = os.path.join(SAVE_DIR, 'nacc_mri_embeddings.npz')
    np.savez_compressed(nacc_path, embeddings=nacc_emb_array.astype(np.float16), ids=nacc_ids)
    print(f'Saved to {nacc_path} ({os.path.getsize(nacc_path) / 1024 / 1024:.2f} MB)')

In [None]:
# Cleanup temp files: remove each scratch directory that still exists.
import shutil

for scratch_dir in ('/tmp/mri_data', '/tmp/nacc_mri'):
    if not os.path.exists(scratch_dir):
        continue
    shutil.rmtree(scratch_dir)
    print(f'Cleaned up {scratch_dir}')

print('Done! Embeddings saved to Google Drive.')