─────────────────────────────────────────────────────────────
# Deep BOLD Perceptual Similarity (DBPS) - Model Implementation

Basis:
  - Inspired by Zhang et al. (2018), "The Unreasonable Effectiveness of Deep Features
    as a Perceptual Metric" (https://arxiv.org/abs/1801.03924)
  - Instead of natural images (LPIPS relies on ImageNet-pretrained networks), we apply
    the concept to fMRI BOLD slices.
  - The pretrained VGG features are replaced with a custom CNN feature extractor
    trained on ground-truth (GT) fMRI data.

Goal:
  - Learn a feature extractor that measures perceptual similarity between
    GT (BOLD-activated) fMRI slices and denoised (or noisy) slices.
  - Compute slice-wise similarity distances as a quantitative denoising metric.

This notebook builds:
  1. Data loading pipeline for GT slices.
  2. A simple CNN-based feature extractor.
  3. LPIPS-like feature-distance calculation (no 2AFC head).
  4. Training loop on GT data to learn domain-specific features.

─────────────────────────────────────────────────────────────

In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt

In [2]:
# ─── Device Setup ──────────────────────────────────────────────
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device:", device)

Using device: cuda


In [3]:
# ─── Data Paths ────────────────────────────────────────────────
gt_path = '../inputs/ground_truth/gt_func_train_1.npy'  # shape (64, 64, 156, 300)
# We memory-map the data to avoid loading the full volume into RAM
# (about 1.5 GB for this shape if stored as float64).
gt = np.load(gt_path, mmap_mode='r')
print("GT volume shape:", gt.shape)

GT volume shape: (64, 64, 156, 300)
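
The memory-mapping rationale can be checked with a quick footprint estimate. The sketch below uses only the shape printed above. The on-disk dtype is not shown in this notebook, so both float64 and float32 are assumptions, and `footprint_gb` is a hypothetical helper.

```python
import numpy as np

shape = (64, 64, 156, 300)  # shape printed by the cell above
n_values = int(np.prod(shape, dtype=np.int64))

def footprint_gb(dtype):
    """Size in GB (10**9 bytes) if the whole volume were loaded into RAM."""
    return n_values * np.dtype(dtype).itemsize / 1e9

size_f64 = footprint_gb(np.float64)  # assumed dtype, not confirmed by the notebook
size_f32 = footprint_gb(np.float32)
print(f"{n_values:,} values: {size_f64:.2f} GB as float64, {size_f32:.2f} GB as float32")
```

With `mmap_mode='r'`, only the slices that are actually indexed are read from disk, so the resident memory stays far below these totals.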


In [4]:
class FMRI2DSliceDataset(Dataset):
    """
    Dataset for loading 2D fMRI slices from a 4D fMRI volume (64, 64, Z, T).
    Each sample is a single 2D slice (shape: 64×64), loaded as a float32 tensor.
    """
    def __init__(self, volume: np.ndarray):
        """
        Args:
            volume (np.ndarray): 4D numpy array, shape (H, W, Z, T).
        """
        super().__init__()
        self.volume = volume
        self.H, self.W, self.Z, self.T = volume.shape

    def __len__(self):
        # The dataset size is the total number of 2D slices (Z * T)
        return self.Z * self.T

    def __getitem__(self, idx):
        """
        Map a flat index in [0, Z*T) to (z, t) and return the 2D slice.
        """
        z = idx // self.T
        t = idx % self.T
        # Copy the slice out of the (possibly memory-mapped) volume as float32
        slice_2d = self.volume[:, :, z, t].astype(np.float32)
        return torch.from_numpy(slice_2d).unsqueeze(0)  # shape: (1, 64, 64)
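
The flat-index mapping used in __getitem__ can be checked on a toy volume. This is a minimal sketch with made-up sizes (Z=4, T=5); `flat_to_zt` is a hypothetical helper that mirrors the `idx // T`, `idx % T` logic above.

```python
import numpy as np

Z, T = 4, 5  # toy sizes; the real volume has Z=156, T=300
volume = np.arange(2 * 2 * Z * T, dtype=np.float64).reshape(2, 2, Z, T)

def flat_to_zt(idx, T):
    """Same mapping as __getitem__: z = idx // T, t = idx % T."""
    return divmod(idx, T)

# Enumerate every flat index and collect the (z, t) pair it maps to.
pairs = [flat_to_zt(i, T) for i in range(Z * T)]

# Consecutive indices walk through time first, then move to the next z.
z, t = flat_to_zt(7, T)
slice_2d = volume[:, :, z, t].astype(np.float32)
print(z, t, slice_2d.shape, slice_2d.dtype)
```

Because consecutive indices share the same z, an unshuffled loader would serve all time points of one slice position before moving on; `shuffle=True` breaks up that ordering.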

In [5]:
class SimpleFMRIEncoder(nn.Module):
    """
    A small CNN feature extractor for fMRI slices.
    We keep it shallow (3 conv blocks) to avoid overfitting.
    """
    def __init__(self):
        super().__init__()
        # Input: (B, 1, 64, 64)
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.AvgPool2d(2)  # downsample after each conv

    def forward(self, x):
        # 1st conv block
        x = self.conv1(x)
        x = F.relu(x)
        x = self.pool(x)
        # 2nd conv block
        x = self.conv2(x)
        x = F.relu(x)
        x = self.pool(x)
        # 3rd conv block
        x = self.conv3(x)
        x = F.relu(x)
        x = self.pool(x)
        return x  # output shape: (B, 64, 8, 8)
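
The LPIPS-like feature distance listed in the overview (no 2AFC head) could look roughly like the sketch below. It is not the notebook's own implementation. `TinyEncoder` is a hypothetical stand-in with the same layer layout as SimpleFMRIEncoder that also exposes intermediate activations, and the layers are weighted uniformly because no learned linear weights are assumed. The inputs are random tensors, not fMRI data.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyEncoder(nn.Module):
    """Hypothetical stand-in with the same layout as SimpleFMRIEncoder."""
    def __init__(self):
        super().__init__()
        self.convs = nn.ModuleList([
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
        ])
        self.pool = nn.AvgPool2d(2)

    def features(self, x):
        """Return the activation after every conv block, not just the last."""
        feats = []
        for conv in self.convs:
            x = self.pool(F.relu(conv(x)))
            feats.append(x)
        return feats

def lpips_like_distance(enc, a, b, eps=1e-10):
    """Unweighted LPIPS-style distance: one value per sample in the batch."""
    dist = 0.0
    for fa, fb in zip(enc.features(a), enc.features(b)):
        # Unit-normalize each spatial position along the channel axis.
        fa = fa / (fa.norm(dim=1, keepdim=True) + eps)
        fb = fb / (fb.norm(dim=1, keepdim=True) + eps)
        # Squared difference summed over channels, averaged over space.
        dist = dist + ((fa - fb) ** 2).sum(dim=1).mean(dim=(1, 2))
    return dist

torch.manual_seed(0)
enc = TinyEncoder().eval()
gt_slices = torch.rand(4, 1, 64, 64)                     # random stand-in data
noisy = gt_slices + 0.1 * torch.randn_like(gt_slices)    # synthetic noise
with torch.no_grad():
    d_same = lpips_like_distance(enc, gt_slices, gt_slices)
    d_noisy = lpips_like_distance(enc, gt_slices, noisy)
print(d_same, d_noisy)
```

Identical inputs give a distance of zero, and the distance is returned per slice, which matches the goal of a slice-wise denoising metric.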

In [None]:
# ──────────────────────────────────────────────────────────────
# Create Dataset & DataLoader
# We sample random 2D slices from the 4D fMRI volume as training samples.
# This way, our CNN learns to handle realistic anatomical + functional structures.
# ──────────────────────────────────────────────────────────────
dataset = FMRI2DSliceDataset(gt)
batch_size = 32
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)
print(f"Dataset size: {len(dataset)} samples, Batch size: {batch_size}")

# ──────────────────────────────────────────────────────────────
# Instantiate the CNN feature extractor (our Encoder).
# It takes 1-channel (64,64) slices and compresses them into deep feature maps.
# ──────────────────────────────────────────────────────────────
encoder = SimpleFMRIEncoder().to(device)
print("Encoder architecture:")
print(encoder)

# ──────────────────────────────────────────────────────────────
# Loss & Optimizer
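# We use a simple "autoencoding-like" objective: the target is the input slice
# average-pooled to 8×8 and repeated across all 64 output channels.
# No explicit anatomical labels are needed; training is purely self-supervised.
# Note: the encoder output is non-negative (ReLU followed by average pooling),
# so this target can only be matched exactly where the pooled slice is >= 0.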
# ──────────────────────────────────────────────────────────────
optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()
print("Optimizer:", optimizer)
print("Loss function:", loss_fn)

# ──────────────────────────────────────────────────────────────
# Training loop
# Each epoch:
#   - Feed shuffled 2D slices to the encoder
#   - Minimize the MSE between the (B, 64, 8, 8) features and the downsampled slice
# ──────────────────────────────────────────────────────────────
epochs = 5  # adjust as needed!
for epoch in range(epochs):
    encoder.train()
    total_loss = 0.0

    for step, batch in enumerate(loader):
        batch = batch.to(device)  # shape (B, 1, 64, 64)

        # Forward pass: extract deep feature maps
        features = encoder(batch)

        # Target: downsampled input slice to shape (B, 1, 8, 8)
        target = F.avg_pool2d(batch, kernel_size=8)
        target = target.repeat(1, 64, 1, 1)  # match output channels (64)

        # Compute MSE loss
        loss = loss_fn(features, target)

        # Backpropagation & optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate loss for statistics
        total_loss += loss.item() * batch.size(0)

        # Print every 50 steps for better progress visibility
        if step % 50 == 0:
            print(f"Epoch {epoch+1}, Step {step}/{len(loader)} - Batch Loss: {loss.item():.6f}")

    avg_loss = total_loss / len(dataset)
    print(f"Epoch {epoch+1}/{epochs} - Average Loss: {avg_loss:.6f}")
    print("─────────────────────────────────────────────")


Dataset size: 46800 samples, Batch size: 32
Encoder architecture:
SimpleFMRIEncoder(
  (conv1): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
)
Optimizer: Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    decoupled_weight_decay: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.001
    maximize: False
    weight_decay: 0
)
Loss function: MSELoss()
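
The training target relies on `F.avg_pool2d(batch, kernel_size=8)` landing on the same 8×8 grid as the encoder's three 2× pooling steps. The sketch below checks this on a random stand-in batch (not fMRI data) and shows the channel repeat used to match the 64 feature maps.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch = torch.rand(2, 1, 64, 64)  # random stand-in for a batch of slices

# One 8x8 average pool, as used to build the training target.
target = F.avg_pool2d(batch, kernel_size=8)

# Three successive 2x2 average pools, matching the encoder's downsampling.
stacked = batch
for _ in range(3):
    stacked = F.avg_pool2d(stacked, kernel_size=2)

# Repeat the single-channel target across 64 feature channels.
target_64 = target.repeat(1, 64, 1, 1)
same_grid = torch.allclose(target, stacked, atol=1e-6)
print(target.shape, target_64.shape, same_grid)
```

Since every one of the 64 channels receives the same target, this objective pushes all channels toward the same downsampled image. That is worth keeping in mind when the features are later used for distance calculations.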
