In [None]:
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
import yaml
from pathlib import Path
from ext.lab2im import utils
from ext.lab2im.edit_volumes import align_volume_to_ref
from SynthSeg.brain_generator import read_tfrecords
import tensorflow as tf

## Peek into data

Data downloaded from https://datashare.mpcdf.mpg.de/f/348054516

In [None]:
image_nii = nib.load("t1w_pdw_validation_pair/t1_pdw_scaled_256.nii")

In [None]:
image = image_nii.get_fdata()
image.shape

In [None]:
labels = nib.load("t1w_pdw_validation_pair/label_256.nii").get_fdata()

In [None]:
plt.imshow(image[100,..., 0], alpha=1.0)
plt.imshow(labels[100, ...], alpha=0.5)

## Align CBS volume to train volumes

The training images use `np.eye(4)` as their reference affine, so the CBS validation image and its labels are reoriented to match it.

In [None]:
image, _, aff, n_dims, n_channels, h, im_res = utils.get_volume_info("./t1w_pdw_validation_pair/t1_pdw_scaled_256.nii", return_volume=True)

In [None]:
plt.imshow(image[100,:,:, 0])

In [None]:
image_aligned = align_volume_to_ref(image, aff, aff_ref=np.eye(4), n_dims=3)

In [None]:
labels_aligned = align_volume_to_ref(labels, aff, aff_ref=np.eye(4), n_dims=3)

In [None]:
plt.imshow(image_aligned[150,:,:, 1], alpha=1.0)
plt.imshow(labels_aligned[150,:,:], alpha=0.2)
plt.colorbar()

## Map labels

In [None]:
class LabelMapping:
    """Map label values to contiguous indices and back.

    The label set is the ``output_labels`` entry of a YAML config. Its sorted unique
    values are passed to ``utils.get_mapping_lut``, whose result is used as a look-up
    table indexed by label value.
    """

    def __init__(self, cfg_path):
        cfg_path = Path(cfg_path)
        with cfg_path.open() as file:
            cfg = yaml.safe_load(file)
        output_labels = cfg["output_labels"]
        # Keep the configured labels so idx2lab only ever returns a real label.
        self._labels = np.unique(output_labels)
        self._mapping = utils.get_mapping_lut(self._labels)

    def lab2idx(self, label: int) -> int:
        """Return the index of ``label``.

        Labels outside the look-up table (negative or larger than its size) map to 0.
        NOTE: ``label`` must be an integer; numpy raises IndexError for float indices,
        so a float label would also silently map to 0.
        """
        if label < 0:
            # A negative index would wrap around to the end of the table.
            return 0
        try:
            return self._mapping[label]
        except IndexError:
            return 0

    def idx2lab(self, idx: int):
        """Return the configured label that maps to ``idx``.

        Raises:
            IndexError: if no configured label maps to ``idx``.
        """
        # The table is indexed by label value, so positions that are not configured
        # labels presumably hold a filler value (0 in lab2im's get_mapping_lut).
        # Restrict the search to configured labels so that filler cannot match.
        labels = self._labels
        matches = labels[self._mapping[labels] == idx]
        if matches.size == 0:
            raise IndexError(f"no label is mapped to index {idx}")
        return matches[0]

    def map_array(self, labels) -> np.ndarray:
        """Vectorized ``lab2idx``: map an integer label array to indices (unknown -> 0)."""
        labels = np.asarray(labels)
        in_table = (labels >= 0) & (labels < len(self._mapping))
        # Replace out-of-table entries by a valid index before gathering, then zero them.
        safe = np.where(in_table, labels, 0).astype(np.intp)
        return np.where(in_table, self._mapping[safe], 0)

In [None]:
label_mapping = LabelMapping("./generator.yml")

In [None]:
labels_aligned = labels_aligned.astype(np.int32)
np.unique(labels_aligned), len(np.unique(labels_aligned))

In [None]:
labels_mapped = np.array(list(map(label_mapping.lab2idx, labels_aligned.flatten()))).reshape(labels_aligned.shape)

In [None]:
np.unique(labels_mapped)

## Create TFRecords

In [None]:
compression_type = None
file = "./t1w_pdw_validation_pair.tfrecord"


def _serialized_bytes_feature(array):
    """Wrap ``array`` as a tf.train.Feature holding its serialized tensor bytes."""
    serialized = tf.io.serialize_tensor(array).numpy()
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[serialized]))


record_options = tf.io.TFRecordOptions(compression_type=compression_type)
with tf.io.TFRecordWriter(str(file), options=record_options) as writer:
    # A single example: the aligned image (cast to float32) and the mapped labels.
    features = {
        "image": _serialized_bytes_feature(image_aligned.astype(np.float32)),
        "labels": _serialized_bytes_feature(labels_mapped),
    }
    example = tf.train.Example(features=tf.train.Features(feature=features))
    writer.write(example.SerializeToString())

In [None]:
# Sanity check: read the record back and compare it with the in-memory arrays.
file = "./t1w_pdw_validation_pair.tfrecord"
ds = read_tfrecords([file])
img, lab = next(iter(ds))
img_np, lab_np = img.numpy(), lab.numpy()
# np.allclose broadcasts, so a missing or extra axis would not be caught without these.
assert img_np.shape == image_aligned.shape, (img_np.shape, image_aligned.shape)
assert lab_np.shape == labels_mapped.shape, (lab_np.shape, labels_mapped.shape)
# The image was written as float32, so compare approximately; labels must match exactly.
assert np.allclose(img_np, image_aligned), "image differs after TFRecord round-trip"
assert np.array_equal(lab_np, labels_mapped), "labels differ after TFRecord round-trip"

Output file uploaded to https://datashare.mpcdf.mpg.de/f/348054516

# Alignment: tests

## CBS full scale image

In [None]:
image, _, _, _, _, _, _ = utils.get_volume_info("../01_from_datashare/T1w_rescaled_for_segmentation.nii", return_volume=True)

In [None]:
image.shape

In [None]:
plt.imshow(image[:, :, 150])

## CBS half scale image

In [None]:
image_hs, _, aff, n_dims, n_channels, h, im_res = utils.get_volume_info("./t1w_pdw_validation_pair/t1_pdw_scaled_256.nii", return_volume=True)

In [None]:
plt.imshow(image_hs[:,:,100,0])

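**No change in orientation** between the full-scale and half-scale images, so the half-scale affine `aff` is reused below to align the full-scale image.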

## Train full scale image

In [None]:
train_ds = read_tfrecords(["train_tfrecords/000000_512.tfrecord"])
train_it = iter(train_ds)

In [None]:
train_img, _ = next(train_it)

In [None]:
image_aligned = align_volume_to_ref(image, aff, aff_ref=np.eye(4), n_dims=3)

In [None]:
image_aligned.shape

In [None]:
plt.imshow(image_aligned[150,:,:], alpha=0.5)
plt.imshow(train_img[200, :, :, 0], alpha=0.5)

In [None]:
plt.imshow(image_aligned[:, 200,:], alpha=0.5)
plt.imshow(train_img[:, 250, :, 0], alpha=0.5)

In [None]:
plt.imshow(image_aligned[:, :, 140], alpha=0.5)
plt.imshow(train_img[:, :, 120, 0], alpha=0.5)

**Seems to be aligned!**

## Align labels

In [None]:
image_hs, _, aff, n_dims, n_channels, h, im_res = utils.get_volume_info("./t1w_pdw_validation_pair/t1_pdw_scaled_256.nii", return_volume=True)

In [None]:
labels, _, _, _, _, _, _ = utils.get_volume_info("./t1w_pdw_validation_pair/label_256.nii", return_volume=True)

In [None]:
image_hs_aligned = align_volume_to_ref(image_hs, aff, aff_ref=np.eye(4), n_dims=3)
labels_aligned = align_volume_to_ref(labels, aff, aff_ref=np.eye(4), n_dims=3)

In [None]:
plt.imshow(image_hs_aligned[150,:,:, 0], alpha=0.9)
plt.imshow(labels_aligned[150,:,:], alpha=0.2)
plt.colorbar()

# Playground

In [None]:
valid_nii = nib.load("./t1w_pdw_validation_pair/t1_pdw_scaled_256.nii")
valid_img = valid_nii.get_fdata()
valid_img.shape

In [None]:
plt.imshow(valid_img[:,:,1, 0])

In [None]:
from ext.lab2im.utils import get_volume_info
from ext.lab2im.edit_volumes import align_volume_to_ref

In [None]:
im, _, aff, n_dims, n_channels, h, im_res = get_volume_info("./t1w_pdw_validation_pair/t1_pdw_scaled_256.nii", return_volume=True)
#im, _, aff, n_dims, n_channels, h, im_res = get_volume_info("./T1w_rescaled_for_segmentation.nii", return_volume=True)

In [None]:
aff

In [None]:
plt.imshow(im[:, 220, :, 0])

In [None]:
im2 = align_volume_to_ref(im, aff, aff_ref=np.eye(4), n_dims=3, return_copy=False)

In [None]:
plt.imshow(im2[:,:,20,0])

In [None]:
train_ds = read_tfrecords(["train_tfrecords/000000_512.tfrecord"])
train_it = iter(train_ds)

In [None]:
train_img, _ = next(train_it)

In [None]:
plt.imshow(train_img[:, :, 150, 0])

In [None]:
from SynthSeg.brain_generator import BrainGenerator

# `help()` instead of IPython's `obj?`, so the cell stays valid Python when exported/run as a script.
help(BrainGenerator.tfrecord_to_brain)

In [None]:
img_nii = nib.load("../01_from_datashare/T1w.nii")
img = img_nii.get_fdata()

In [None]:
plt.imshow(img[100,...])

In [None]:
img.max()

In [None]:
plt.hist(crop(test_vol).flatten(), bins=100);

In [None]:
test_vol = nib.load("../01_from_datashare/T1w_rescaled_for_segmentation.nii").get_fdata()
test_vol.shape

In [None]:
plt.imshow(test_vol[400, :, :])

In [None]:
def crop(volume):
    max_idxs = []
    for dim in volume.shape:
        max_idxs.append(dim - 256)
        
    slices = []
    for max_idx in max_idxs:
        rand_int = np.random.randint(0, max_idx)
        slices.append(slice(rand_int, rand_int + 256))

    return volume[slices[0], slices[1], slices[2]]

In [None]:
iterator = iter(train_ds)

In [None]:
train_img, label_img = next(iterator)

In [None]:
train_img.shape, train_img[..., 0].numpy().sum()

In [None]:
plt.hist(train_img.numpy()[..., 0].flatten(), bins=100);
plt.hist(img[...,0].flatten(), bins=100, alpha=0.3);

In [None]:
from SynthSeg.brain_generator import read_tfrecords

In [None]:
train_tfrecord_files = ["./train_tfrecords/000000.tfrecord"]
train_ds = read_tfrecords(train_tfrecord_files)

In [None]:
plt.imshow(seg[100,...], alpha=0.5)
plt.imshow(img_rs[100, ...], alpha=0.5)

In [None]:
img_rs_nii = nib.load("../01_from_datashare/T1w_rescaled_for_segmentation.nii")
img_rs = img_rs_nii.get_fdata()
img_rs.max()

Run segmentation as described in `approach.ipynb`

In [None]:
seg_nii = nib.load("T1w_segmentations2.nii")
seg = seg_nii.get_fdata()

In [None]:
seg_nii.header.get_data_shape()

In [None]:
labels = np.unique(seg)
labels, len(labels)

In [None]:
img_rs.shape, seg.shape

In [None]:
import os

from SynthSeg.analysis.contrast_analysis import clip_and_rescale_nifti

# Repository root. Set SYNTHSEG_ROOT to override it instead of editing the path;
# the default is the original author's local checkout.
root_dir = os.environ.get("SYNTHSEG_ROOT", "/home/david/mpcdf/cbs/segmentation/SynthSeg")
input_file = f"{root_dir}/data/cbs/01_from_datashare/T1w.nii"
# NOTE(review): this writes into t1w_pdw_config/, while earlier cells read
# "../01_from_datashare/T1w_rescaled_for_segmentation.nii" — confirm which copy is intended.
output_file = f"{root_dir}/data/cbs/t1w_pdw_config/T1w_rescaled_for_segmentation.nii"

# Clip intensities to [0, 2000] and rescale them linearly to [0, 1].
clip_and_rescale_nifti(
    nifti_file=input_file,
    out_file=output_file,
    min_clip=0.0,
    max_clip=2000,
    min_out=0.0,
    max_out=1.0,
)