In [1]:
!pip install dipy
!pip install einops

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting dipy
  Downloading dipy-1.6.0-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (8.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.2/8.2 MB[0m [31m27.5 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: dipy
Successfully installed dipy-1.6.0
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting einops
  Downloading einops-0.6.1-py3-none-any.whl (42 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m548.0 kB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.6.1


In [2]:
# Mount Google Drive and move into the model-code directory; presumably this is
# what lets the project-local modules (dti_model, data_loader) be imported in a
# later cell. TODO confirm.
import os
from google.colab import drive

mount_path = '/content/drive'
drive.mount(mount_path)

# Project locations inside the mounted drive.
test_data = f"{mount_path}/MyDrive/dti-transformer/dti_data"    # evaluation cases
drive_path = f"{mount_path}/MyDrive/dti-transformer/code/model"  # model code
results_path = f"{drive_path}/results"                           # checkpoints + metrics
os.chdir(drive_path)

Mounted at /content/drive


In [3]:
# Daniel Bandala @ nov-2022
# dti-model validation script
# general libraries
import csv
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from math import log10, sqrt
# diffusion image processing
from dipy.io.image import load_nifti
# import torch libraries
import torch 
from torch import nn
# image-quality metrics and auxiliary project modules (model, data loader)
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import mean_squared_error
from dti_model import DiffusionTensorModel
from data_loader import data_preprocessing

In [10]:
# Evaluation configuration shared by the data loader and the model.
#   signals: number of input channels (model `in_chans`, loader `signals=`)
#   maps:    target map(s) to evaluate; other options referenced in this
#            notebook: "MD", "MO", "L1", "L2", "L3", "FA"
signals, maps = 7, ["RGB"]

In [12]:
# Load the trained RGB-map model for CPU-only validation.
# The model was trained on GPU, so the checkpoint may contain CUDA tensors.
# Remap every tensor to the CPU at load time; without `map_location`,
# torch.load tries to restore them onto a GPU and fails on a CPU-only runtime.
# (Other checkpoints in this project: dti_fa.weights, FA_2023-03-30.)
checkpoint = torch.load(results_path + '/dti_rgb.weights',
                        map_location=torch.device('cpu'))

# These hyperparameters must match the configuration the checkpoint was trained
# with: load_state_dict (strict by default) raises on missing or unexpected keys.
model = DiffusionTensorModel(
    in_chans=signals,    # one input channel per diffusion signal
    out_chans=3,         # three output channels for the RGB map
    img_size=140,
    embed_dim=64,
    n_heads=[1,2,4,8],
    mlp_ratio=[2,2,4,4],
    reduction_ratio=1,
    depth_prob=0.2,
    tanh_output=False
)
model.to('cpu')
model.load_state_dict(checkpoint)
_ = model.eval()  # inference mode; `_ =` keeps the notebook from echoing the model repr

In [13]:
# Evaluation cases, relative to the `test_data` root:
# three HCP subjects followed by three ADNI subjects.
test_cases = [
    '/HCP/test/case_12',
    '/HCP/test/case_14',
    '/HCP/test/case_31',
    '/ADNI/test/case_13',
    '/ADNI/test/case_14',
    '/ADNI/test/case_30',
]
test_list = [test_data + case for case in test_cases]

In [14]:
# Run the model on every slice of every test case and record one row of
# image-quality metrics per slice in `test_results` (first row = column names).

# Peak signal value used for PSNR. This assumes the target maps are scaled
# to [0, 1]. TODO confirm against data_preprocessing.
PSNR_PEAK = 1.0


def _slice_metrics(label_np, output_np, peak=PSNR_PEAK):
    """Compare one predicted slice against its reference map.

    Args:
        label_np: reference (ground-truth) array, channels on axis 0.
        output_np: model prediction with the same shape as ``label_np``.
        peak: maximum possible signal value, used for PSNR.

    Returns:
        ``(mse, nmse, ssim, psnr)`` where
        - nmse = ||label - output||^2 / ||label||^2 (a dimensionless ratio);
          NaN when the label is all zeros, because the ratio is undefined.
        - psnr = 10*log10(peak^2 / mse); +inf for a perfect prediction
          (mse == 0) instead of raising ZeroDivisionError.
    """
    mse = mean_squared_error(label_np, output_np)
    # Normalise by the label's mean energy (same units as mse).
    label_energy = np.mean(np.square(label_np, dtype=np.float64))
    nmse = mse / label_energy if label_energy > 0 else float('nan')
    ssi = ssim(label_np, output_np,
               data_range=label_np.max() - label_np.min(),
               channel_axis=0)
    psnr = 10 * log10(peak ** 2 / mse) if mse > 0 else float('inf')
    return mse, nmse, ssi, psnr


test_results = [["Slice","MSE","NMSE","SSIM","PSNR","Full path"]]
for data_path in test_list:
    print(f"Processing {data_path}")
    data_eval, label_eval = data_preprocessing(data_path, maps=maps, signals=signals)
    case_name = os.path.basename(data_path)
    for sidx in range(data_eval.shape[0]):
        with torch.no_grad():
            output = model(data_eval[sidx])
        label_np = label_eval[sidx].detach().numpy()
        output_np = output.detach().numpy()
        mse, nmse, ssi, psnr = _slice_metrics(label_np, output_np)
        test_results.append([f'{case_name}_{sidx}', mse, nmse, ssi, psnr, data_path])

Processing /content/drive/MyDrive/dti-transformer/dti_data/HCP/test/case_12
Processing /content/drive/MyDrive/dti-transformer/dti_data/HCP/test/case_14
Processing /content/drive/MyDrive/dti-transformer/dti_data/HCP/test/case_31
Processing /content/drive/MyDrive/dti-transformer/dti_data/ADNI/test/case_13
Processing /content/drive/MyDrive/dti-transformer/dti_data/ADNI/test/case_14
Processing /content/drive/MyDrive/dti-transformer/dti_data/ADNI/test/case_30


In [15]:
# save results to csv file
# Persist the per-slice metrics table. The first row of `test_results` already
# holds the column names, so pandas must write neither its own header nor an index.
results_csv = f"{results_path}/test.csv"
pd.DataFrame(test_results).to_csv(results_csv, header=False, index=False)