In [1]:
import os
from pathlib import Path
import random
import numpy as np

import torch
from torch.utils.data import DataLoader
import torchmetrics

from eeg2fmri_datasets import EEG2fMRIDataset
from models import create_unet, EEGEncoder, fMRIDecoder, EEG2fMRINet
import utils
import data_cfg

from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from matplotlib import animation, rc
rc('animation', html='jshtml')

In [2]:
def create_animation(gt_data: np.ndarray, pred_data: np.ndarray, figsize: int=5):
    """Build a side-by-side, slice-by-slice animation of two volumes.

    Frame ``i`` shows slice ``[:, :, i]`` of ``gt_data`` in the left panel
    ("Ground-truth") and of ``pred_data`` in the right panel ("Generated").
    Every slice is passed through ``utils.normalize_data`` and rotated with
    ``np.rot90(..., k=-1)`` before being drawn in grayscale.

    Args:
        gt_data: Ground-truth volume, indexed as ``[:, :, slice]``.
        pred_data: Generated volume, indexed the same way. It must provide at
            least as many slices as ``gt_data``.
        figsize: Side length of one panel; the figure is
            ``(2 * figsize, figsize)``.

    Returns:
        A ``matplotlib.animation.FuncAnimation`` with one frame per slice of
        ``gt_data`` (last axis) and a 100 ms frame interval.
    """
    def _prepare_slice(volume, index):
        # Single place for the per-slice preprocessing used by both the
        # initial frame and every animated frame.
        return np.rot90(utils.normalize_data(volume[:, :, index]), k=-1)

    fig, axs = plt.subplots(1, 2, figsize=(figsize*2, figsize))

    im1 = axs[0].imshow(_prepare_slice(gt_data, 0), cmap='gray')
    im2 = axs[1].imshow(_prepare_slice(pred_data, 0), cmap='gray')

    axs[0].set_title("Ground-truth")
    axs[1].set_title("Generated")

    def animate_func(i):
        im1.set_data(_prepare_slice(gt_data, i))
        im2.set_data(_prepare_slice(pred_data, i))
        # Returning the updated artists keeps the callback usable if the
        # animation is ever created with blit=True.
        return im1, im2

    # Operate on this figure explicitly instead of pyplot's implicit
    # "current figure", so the function stays correct even if another
    # figure is active.
    fig.tight_layout()
    # Close the figure handle; presumably this keeps the notebook from also
    # rendering the static first frame next to the animation.
    plt.close(fig)

    num_slices = gt_data.shape[-1]

    return animation.FuncAnimation(fig, animate_func, frames=num_slices, interval=100)

### Load data

In [3]:
data_name = 'NODDI'
fmri_channel = 30  # default; the Oddball branch below overrides it

assert data_name in ['NODDI', 'Oddball', 'CNEPFL']

data_root = Path(data_cfg.processed_data_roots[data_name])

# Per-dataset test split and the global EEG min/max used for normalization.
if data_name == 'NODDI':
    test_list = ['43']
    eeg_min, eeg_max = -3.904906883760493, 7.937204954155734

elif data_name == 'Oddball':
    test_ID = [9, 10]  # subjects 9 and 10 form the test split
    test_list = [f"sub{idx:03}/task001_run001" for idx in test_ID]

    fmri_channel = 32

    eeg_min, eeg_max = -2.466110737041575, 6.480417369333849

elif data_name == 'CNEPFL':
    # Entries under data_root, sorted by name with the extension stripped.
    individuals = sorted(Path(x).stem for x in os.listdir(data_root))
    test_list = individuals[-4:]  # last 4 individuals

    eeg_min, eeg_max = -4.551622643288133, 7.93715188090758

eeg_test, fmri_test = utils.load_h5_from_list(data_root, test_list)

# Only the EEG needs rescaling here (fmri_test is already in range [0 - 1]).
eeg_test = utils.normalize_data(eeg_test, base_range=(eeg_min, eeg_max))

  0%|          | 0/1 [00:00<?, ?it/s]

In [4]:
# Wrap the test arrays in a dataset and iterate it one sample at a time,
# in dataset order (batch_size=1, no shuffling).
test_dataset = EEG2fMRIDataset(eeg_test, fmri_test)
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=1,
    shuffle=False,
)

### Load model

In [5]:
# Build the three sub-networks and chain them into the full EEG -> fMRI model.
# The U-Net input/output and the decoder input all use 256 channels.
latent_channels = 256

eeg_encoder = EEGEncoder(in_channels=20, img_size=64)
unet_module = create_unet(in_channels=latent_channels, out_channels=latent_channels)
# The decoder emits one output channel per fmri_channel of the chosen dataset.
fmri_decoder = fMRIDecoder(in_channels=latent_channels, out_channels=fmri_channel)

model = EEG2fMRINet(
    eeg_encoder=eeg_encoder,
    unet_module=unet_module,
    fmri_decoder=fmri_decoder,
)

In [6]:
# replace the path to your checkpoint
ckpt_path = 'NODDI_idv43_ssim_0.6556.pth'
# Use the first GPU when available; otherwise fall back to CPU so the
# notebook still runs on machines without CUDA.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Deserialize onto CPU first: without map_location, torch.load tries to put
# the tensors back on the device they were saved from, which fails when that
# device does not exist here. The weights are moved to `device` afterwards.
state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
model = model.to(device)
model = model.eval()  # switch the model to evaluation mode

### Evaluation

In [7]:
# Metrics are moved to the same device as the model outputs.
ssim_metric = torchmetrics.image.StructuralSimilarityIndexMeasure(data_range=1.0).to(device)
psnr_metric = torchmetrics.image.PeakSignalNoiseRatio().to(device)
rmse_metric = torchmetrics.MeanSquaredError(squared=False).to(device)  # squared=False -> RMSE

total_batch = len(test_loader)

# Running sums of the per-batch metric values.
test_ssim_score, test_psnr_score, test_rmse_score = 0.0, 0.0, 0.0

for eeg_batch, fmri_batch in tqdm(test_loader):
    eeg_batch, fmri_batch = eeg_batch.to(device), fmri_batch.to(device)

    # Forward pass only; no gradients are needed for evaluation.
    with torch.no_grad():
        pred_fmri = model(eeg_batch)

    # Score this batch and add it to the running sums.
    test_ssim_score += ssim_metric(pred_fmri, fmri_batch).item()
    test_psnr_score += psnr_metric(pred_fmri, fmri_batch).item()
    test_rmse_score += rmse_metric(pred_fmri, fmri_batch).item()

# Report the mean of the per-batch scores.
ssim_score = test_ssim_score / total_batch
psnr_score = test_psnr_score / total_batch
rmse_score = test_rmse_score / total_batch

print(f"SSIM: {ssim_score:.7f}")
print(f"PSNR: {psnr_score:.7f}")
print(f"RMSE: {rmse_score:.7f}")

  0%|          | 0/274 [00:00<?, ?it/s]

SSIM: 0.6556156
PSNR: 19.3428653
RMSE: 0.1079007


### Predict and visualize results

In [8]:
# Randomly sample one (EEG, fMRI) pair from the test set.
# Leave sample_seed as None to draw from the global `random` stream (the
# original behaviour); set it to an int to get the same sample on every run.
sample_seed = None
sampler = random if sample_seed is None else random.Random(sample_seed)

# Keep the drawn index so the visualized sample can be identified and
# re-created later.
sample_idx = sampler.randrange(len(test_dataset))
eeg_data, fmri_data = test_dataset[sample_idx]

# Add a leading batch dimension and move the input to the model's device.
eeg_data = eeg_data.unsqueeze(0)
eeg_data = eeg_data.to(device)

# model prediction (no gradients needed)
with torch.no_grad():
    pred_fmri = model(eeg_data)

In [9]:
# Take the single prediction out of the batch and bring it to host memory.
pred_volume = pred_fmri[0].cpu().numpy()

# Move the leading axis to the end for both volumes: create_animation steps
# through slices along the last axis.
slice_last = (1, 2, 0)
pred_data = np.transpose(pred_volume, slice_last)
gt_data = np.transpose(fmri_data, slice_last)

# Must stay the last expression so the notebook renders the animation.
create_animation(gt_data, pred_data)