In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from pathlib import Path
from typing import TypeAlias

import numpy as np
import nibabel
import matplotlib.pyplot as plt

In [3]:
# Root of the FreeSurfer subject directory processed throughout this notebook.
# (The parenchyma mask is generated in the next cell, after its helpers are defined.)
subject_dir = Path("FREESURFER/PAT_001/")

In [4]:
# Segmentation label IDs treated as CSF: voxels carrying any of these labels are
# excluded from the parenchyma mask. The names below are presumably the entries
# in FreeSurferColorLUT.txt — NOTE(review): verify against the LUT in use.
CSF_LABELS = [
    4,    # Left-Lateral-Ventricle
    5,    # Left-Inf-Lat-Vent
    14,   # 3rd-Ventricle
    15,   # 4th-Ventricle
    24,   # CSF
    43,   # Right-Lateral-Ventricle
    44,   # Right-Inf-Lat-Vent
    72,   # 5th-Ventricle
    213,  # lateral_ventricle
    221,  # Inf_Lat_Vent
]

# Shorthand for the image type returned by the mask helpers below.
NiftiImage: TypeAlias = nibabel.nifti1.Nifti1Image


def create_parenchyma_mask(aseg_file: Path, csf_labels: list[int] | None = None) -> NiftiImage:
    """Build a brain-parenchyma mask from a FreeSurfer segmentation volume.

    A voxel belongs to the parenchyma when its (integer-cast) label is positive
    and is not one of the CSF labels.

    Args:
        aseg_file: Path to the segmentation volume (e.g. ``mri/aseg.mgz``).
        csf_labels: Label IDs to exclude as CSF. Defaults to the module-level
            ``CSF_LABELS`` when ``None``.

    Returns:
        A ``Nifti1Image`` holding the mask as 0.0/1.0 floats, with the
        segmentation's affine.
    """
    if csf_labels is None:
        csf_labels = CSF_LABELS
    aseg = nibabel.load(aseg_file)
    aseg_data = aseg.get_fdata().astype(int)

    # list(...) so that sets/tuples/other iterables work with np.isin.
    csf_mask = np.isin(aseg_data, list(csf_labels))
    brain_mask = (aseg_data > 0) & ~csf_mask
    return nibabel.Nifti1Image(brain_mask.astype(float), aseg.affine)


def parenchyma_mask(subject_dir: Path) -> Path:
    """Create and save the parenchyma mask for a FreeSurfer subject.

    Reads ``<subject_dir>/mri/aseg.mgz``, builds the mask with
    ``create_parenchyma_mask`` and writes it to
    ``<subject_dir>/mri/parenchyma_mask.mgz``.

    Args:
        subject_dir: FreeSurfer subject directory (``str`` or ``Path``).

    Returns:
        Path of the written mask file.
    """
    mri_dir = Path(subject_dir) / "mri"
    aseg_file = mri_dir / "aseg.mgz"
    output_file = mri_dir / "parenchyma_mask.mgz"
    # NOTE: a Nifti1Image is saved under a ".mgz" name; nibabel.save is relied on
    # to convert it to the format implied by the extension.
    nibabel.save(create_parenchyma_mask(aseg_file), output_file)
    return output_file


def syntethic_t1_constant(t1: float, aseg_file: Path) -> NiftiImage:
    """Build a synthetic T1 map with a constant value inside the parenchyma.

    Voxels inside the mask from ``create_parenchyma_mask`` get the value ``t1``;
    all other voxels are 0.

    Args:
        t1: Constant T1 value to assign. Units should match those of
            ``TMIN``/``TMAX`` (presumably seconds — TODO confirm).
        aseg_file: Path to the FreeSurfer segmentation volume.

    Returns:
        A ``Nifti1Image`` with the segmentation's affine.
    """
    mask_image = create_parenchyma_mask(aseg_file)
    t1_data = t1 * mask_image.get_fdata()
    return nibabel.Nifti1Image(t1_data, mask_image.affine)


# Correctly spelled alias; the original (misspelled) name is kept for existing callers.
synthetic_t1_constant = syntethic_t1_constant


# Write this subject's parenchyma mask; the returned output path is displayed.
subject_dir = Path("FREESURFER/PAT_001/")
parenchyma_mask(subject_dir)

PosixPath('FREESURFER/PAT_001/mri/parenchyma_mask.mgz')

In [5]:
# Clipping range for estimated T1 values: anything outside [TMIN, TMAX] is
# considered noise and clamped to the nearest bound in signal_to_consentration.
# Units presumably seconds — TODO confirm.
TMIN, TMAX = 0.2, 4.2

# Coefficient passed as ``b`` to signal_to_consentration / signal_to_t1.
B = 1.48087682

# Relaxivity passed as ``r`` to signal_to_consentration / t1_to_concentration.
# NOTE(review): presumably in 1/(mM*s). The name keeps its original spelling
# because later cells reference it.
RELAXIVITY_CONSTANT = 3.2


def signal_to_t1(St: np.ndarray, S0: np.ndarray, T10: np.ndarray, b: float) -> np.ndarray:
    """Estimate T1 from the ratio of a signal to its baseline.

    Computes ``T1 = T10 - ln(St / S0) / b`` element-wise.

    Args:
        St: Signal at the time point of interest.
        S0: Baseline signal.
        T10: Baseline T1 values.
        b: Conversion coefficient.

    Returns:
        Estimated T1 values. Where ``St / S0`` is 0 the log is ``-inf``, so the
        result is ``+inf`` for positive ``b``.
    """
    log_ratio = np.log(St / S0)
    return T10 - log_ratio / b



def t1_to_concentration(T1: np.ndarray, T10: np.ndarray, r: float) -> np.ndarray:
    """Convert T1 values to concentration.

    Uses ``C = (1/T1 - 1/T10) / r``: the change in relaxation rate relative to
    baseline, scaled by the relaxivity ``r``. The result is positive where
    ``T1 < T10`` (for positive ``r``).

    Args:
        T1: T1 values at the time point of interest.
        T10: Baseline T1 values.
        r: Relaxivity.

    Returns:
        Concentration values, element-wise.
    """
    delta_rate = 1.0 / T1 - 1.0 / T10
    return (1.0 / r) * delta_rate


def signal_to_consentration(St: np.ndarray, S0: np.ndarray, T10: np.ndarray,
                           r: float, b: float) -> np.ndarray:
    """Convert a signal to concentration through an intermediate T1 estimate.

    The T1 estimate from ``signal_to_t1`` is clamped to ``[TMIN, TMAX]`` before
    being converted with ``t1_to_concentration``.

    Args:
        St: Signal at the time point of interest.
        S0: Baseline signal.
        T10: Baseline T1 values.
        r: Relaxivity.
        b: Signal-to-T1 conversion coefficient.

    Returns:
        Concentration values, element-wise.
    """
    raw_t1 = signal_to_t1(St, S0, T10, b)
    # Clamp with np.where (not np.clip): NaN fails the comparison, so a NaN T1
    # becomes TMIN here. NOTE(review): that maps NaN to the largest
    # concentration — confirm this is intended.
    floored_t1 = np.where(raw_t1 > TMIN, raw_t1, TMIN)
    clamped_t1 = np.where(floored_t1 < TMAX, floored_t1, TMAX)
    return t1_to_concentration(clamped_t1, T10, r)

In [None]:
# Convert every post-baseline image in NORMALIZED/ into a concentration map
# saved under CONCS_s/ with the same file name.
input_dir = subject_dir / "NORMALIZED"
# Sorted by file name; the first entry is used as the baseline below.
input_images = sorted(input_dir.iterdir())

output_dir = subject_dir / "CONCS_s"
output_dir.mkdir(parents=True, exist_ok=True)

# Index along the first array axis of the slice shown in the preview plots.
PREVIEW_SLICE = 100


# Placeholder ("fake") T1 map: constant 1.0 inside the parenchyma mask, 0 elsewhere.
mask = create_parenchyma_mask(subject_dir / "mri/aseg.mgz").get_fdata().astype(bool)
t1map = 1.0 * mask


# Process baseline: only parenchyma voxels with a positive baseline signal are converted.
image = nibabel.load(input_images[0])
base_intensity = image.get_fdata()
base_mask = mask & (base_intensity > 0)
base_intensity *= base_mask
base_transform = image.affine

# Baseline signal and T1 at the converted voxels are the same for every image,
# so extract them once instead of on every loop iteration.
base_signal = base_intensity[base_mask]
base_t1 = t1map[base_mask]


for image_path in input_images[1:]:
    image = nibabel.load(image_path)
    # Validate geometry before using the data so a mismatch fails with a clear message.
    assert image.shape == base_intensity.shape, f"Shape mismatch for {image_path.name}."
    assert np.allclose(base_transform, image.affine), "Registering not good enough."
    St = image.get_fdata() * base_mask

    concentration = np.zeros_like(St)
    concentration[base_mask] = signal_to_consentration(
        St[base_mask], base_signal, base_t1, r=RELAXIVITY_CONSTANT, b=B
    )
    concentration_image = nibabel.Nifti1Image(concentration, base_transform)
    nibabel.save(concentration_image, output_dir / image_path.name)

    print(image_path.name, concentration.min(), concentration.max())
    fig, ax = plt.subplots()
    im = ax.imshow(concentration[PREVIEW_SLICE])
    fig.colorbar(im, ax=ax, label="concentration (mM)")
    ax.set_title(f"{image_path.name}, slice {PREVIEW_SLICE}")
    plt.show()
    # Release each figure; this loop creates one per image.
    plt.close(fig)

### Concentration statistics over time

In [None]:
from datetime import datetime

# File names encode the acquisition time, e.g. "20230101_120000.mgz".
TIMESTAMP_FORMAT = "%Y%m%d_%H%M%S"


def _acquisition_time(path: Path) -> datetime:
    """Parse the acquisition time encoded at the start of a file name.

    Everything after the first "." is dropped, so multi-suffix names such as
    ``.nii.gz`` parse too (``Path.stem`` would only strip ``.gz``).
    """
    return datetime.strptime(path.name.split(".", 1)[0], TIMESTAMP_FORMAT)


def _voxel_volume_litres(image) -> float:
    """Return the voxel volume in litres, read from the image header.

    Assumes the header zooms are in mm (1 mm^3 = 1e-6 L) — TODO confirm for
    non-FreeSurfer inputs.
    """
    return float(np.prod(image.header.get_zooms()[:3])) * 1e-6


start_time = _acquisition_time(input_images[0])

# The baseline image has no concentration map, so it is recorded as zero.
mass = [0.0]
mean = [0.0]
median = [0.0]
times = [start_time]
for image_path in sorted(output_dir.iterdir()):
    # Skip non-timepoint files that may also live in the output directory.
    if "template" in image_path.name:
        continue

    image = nibabel.load(image_path)
    image_data = image.get_fdata()

    # Concentration (mM = mmol/L) summed over voxels times voxel volume (L) -> mmol.
    total_mass = image_data.sum() * _voxel_volume_litres(image)
    positive = image_data[image_data > 0.0]
    # Guard against an empty selection (mean/median of [] warn and return NaN).
    mean_concentration = positive.mean() if positive.size else np.nan
    median_concentration = np.median(positive) if positive.size else np.nan
    mass.append(total_mass)
    mean.append(mean_concentration)
    median.append(median_concentration)

    timestamp = _acquisition_time(image_path)
    print(timestamp, total_mass)
    times.append(timestamp)

In [None]:
# Mass (mmol) and concentration (mM) have different units, so plot them on
# separate, labelled axes that share the time axis.
fig, (mass_ax, conc_ax) = plt.subplots(2, 1, sharex=True, figsize=(8, 6))

mass_ax.plot(times, mass, "o-", label="mass (mmol)")
mass_ax.set_ylabel("mass (mmol)")
mass_ax.legend()

conc_ax.plot(times, mean, "o-", label="mean (mM)")
conc_ax.plot(times, median, "o-", label="median (mM)")
conc_ax.set_ylabel("concentration (mM)")
conc_ax.set_xlabel("acquisition time")
conc_ax.legend()

fig.suptitle("Concentration statistics over time")
fig.autofmt_xdate()
plt.show()