# Brain-Tumor-Progression dataset visualization and preprocessing

## Imports

In [None]:
import h5py
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import pydicom
import re

from datetime import datetime
from ipywidgets import interact, IntSlider
from scipy import interpolate
from typing import Any, Callable, Sequence

## Load DICOM images

In [None]:
# One image volume, as produced by stacking the pixel arrays of a DICOM folder.
Image = np.ndarray
# Image type name -> volume, for a single examination date.
Study = dict[str, Image]
# Examination date (string) -> the studies recorded on that date.
Subject = dict[str, Study]
# Subject identifier -> everything loaded for that subject.
Data = dict[str, Subject]

# Any level of the nested dataset tree, from the root down to a single volume.
Node = Image | Study | Subject | Data


class ReUtils:
    """Regex helpers that recognise dataset folder and file names.

    Folder names are matched against the subject, date and study patterns
    (in that order); the text extracted by the first matching pattern is
    used as the dictionary key for that folder.
    """

    # Subject identifier such as "PGBM-004"; no capture group, so the whole
    # match is returned.
    _subject = re.compile(r"PGBM-\d{3}")
    # Date in the form "01-12-1994"; the captured date is returned.
    _date = re.compile(r"(\d{2}-\d{2}-\d{4})")
    # "<number>.<number>-<name>-<number>"; only the captured <name> is returned.
    _study = re.compile(r"\d+\.\d+-(\w+)-\d+")
    # Any file name ending in ".dcm".
    _dicom = re.compile(r".*\.dcm$")

    # Order matters: the first pattern that matches wins.
    _folder_patterns = [_subject, _date, _study]

    @staticmethod
    def is_dicom(name: str) -> bool:
        """Return True when ``name`` ends with the ``.dcm`` extension."""
        return ReUtils._dicom.match(name) is not None

    @staticmethod
    def folder_match(name: str) -> str:
        """Return the key extracted from ``name`` by the first matching pattern.

        Raises:
            ValueError: if none of the folder patterns occurs in ``name``.
        """
        for pattern in ReUtils._folder_patterns:
            try:
                return ReUtils._first_search(name, pattern)
            except ValueError:
                # Only "pattern not found" means "try the next pattern";
                # anything else (e.g. KeyboardInterrupt) must propagate.
                continue

        raise ValueError(f"No pattern found in {name}.")

    @staticmethod
    def _first_search(text: str, pattern: re.Pattern) -> str:
        """Return the first occurrence of ``pattern`` in ``text``.

        ``pattern.groups`` is 0 for a pattern without capture groups and 1
        for the single-group patterns above, so ``match.group`` yields the
        whole match or the captured part respectively.

        Raises:
            ValueError: if ``pattern`` does not occur in ``text``.
        """
        match = pattern.search(text)

        if match:
            return match.group(pattern.groups)
        else:
            raise ValueError(f"Cannot find pattern {pattern} in {text}.")


def load_dicom_images_folder(path: str) -> np.ndarray:
    """Read every ``.dcm`` file directly inside ``path`` into one array.

    The files are ordered by their ``InstanceNumber`` attribute and their
    pixel arrays are stacked along a new first axis.
    """
    datasets = [
        pydicom.dcmread(os.path.join(path, name))
        for name in os.listdir(path)
        if ReUtils.is_dicom(name)
    ]
    # Directory listing order is arbitrary; InstanceNumber gives the slice order.
    datasets.sort(key=lambda dataset: dataset.InstanceNumber)

    return np.stack([dataset.pixel_array for dataset in datasets])


def load_dataset(path: str) -> Node:
    """Recursively load a dataset directory into nested dicts of image volumes.

    Each entry of ``path`` whose name matches one of the ``ReUtils`` folder
    patterns becomes a key (the matched text) mapped to the result of loading
    that entry recursively. When no entry could be loaded that way, ``path``
    itself is read as a folder of DICOM files.

    Loading is best-effort: entries that do not match any pattern, or whose
    own loading fails, are skipped.
    """
    node = {}

    for item in os.listdir(path):
        try:
            key = ReUtils.folder_match(item)
            node[key] = load_dataset(os.path.join(path, item))
        except Exception:
            # Skip unusable entries, but let KeyboardInterrupt/SystemExit
            # through so a long-running load can still be interrupted.
            continue

    return node if node else load_dicom_images_folder(path)

In [None]:
# Load the whole raw dataset as nested dicts: subject -> date -> image type -> volume.
data: Data = load_dataset("../data/raw/Brain-Tumor-Progression") # type: ignore

## Images types and shapes per subject

In [None]:
# Map "<subject> <date>" to the shape of every image type recorded on that date.
image_shapes_of_studies = {
    f"{subject_id} {study_date}": {
        image_type: volume.shape for image_type, volume in study.items()
    }
    for subject_id, subject_node in data.items()
    for study_date, study in subject_node.items()
}

In [None]:
# One row per "<subject> <date>", one column per image type, cells hold the volume shape.
image_shapes_of_studies_df = pd.DataFrame(image_shapes_of_studies).T.sort_index()
image_shapes_of_studies_df

The table above shows that:
- the height of the images can be 22, 23 or 24,
- image resolutions are 512x512, 320x260, 320x280, 260x320, 256x256,
- 15 subjects have all 10 image types, 16 subjects have 9 image types in common, and all subjects have 6 image types in common.


Taking this into account, the preprocessing function should: 
- assume a height of 22 (taken from the selected end), 
- stretch/crop images to one resolution, 
- consider subjects that have given image types.

## Visualize studies

In [None]:
def show_3d_image(image):
    """Display a 3-D volume as an interactive viewer, one slice at a time.

    A slider selects the index along the first axis of ``image`` and the
    corresponding 2-D slice is drawn with matplotlib.
    """
    last_index = image.shape[0] - 1

    def draw_slice(x):
        plt.imshow(image[x, :, :])
        plt.title(f"Height {x}")
        plt.show()

    # The keyword name must match the parameter name of ``draw_slice``.
    interact(draw_slice, x=IntSlider(min=0, max=last_index, description="Height"))

### 512x512 resolution examples

In [None]:
# Browse the raw T1post volume of a study stored at 512x512.
show_3d_image(data["PGBM-004"]["01-12-1994"]["T1post"]) # type: ignore

### 320x260 resolution examples

In [None]:
# Browse the raw T1post volume of a study stored at 320x260.
show_3d_image(data["PGBM-009"]["01-03-1991"]["T1post"]) # type: ignore

### 320x280 resolution examples

In [None]:
# Browse the raw T1post volume of a study stored at 320x280.
show_3d_image(data["PGBM-003"]["03-29-1995"]["T1post"]) # type: ignore

### 260x320 resolution examples

In [None]:
# Browse the raw T1post volume of a study stored at 260x320.
show_3d_image(data["PGBM-013"]["09-18-1989"]["T1post"]) # type: ignore

### 256x256 resolution examples

In [None]:
# Browse the raw T1post volume of a study stored at 256x256.
show_3d_image(data["PGBM-011"]["06-29-1989"]["T1post"]) # type: ignore

## Make shape common

### Reshape

In [None]:
def _iter_image_shapes(node):
    """Yield the (height, long, width) shape of every image found under ``node``."""
    if isinstance(node, np.ndarray):
        if node.ndim < 3:
            raise ValueError(
                f"Expected an image with at least 3 dimensions, got shape {node.shape}."
            )
        yield tuple(node.shape[0:3])
    else:
        for child in node.values():
            yield from _iter_image_shapes(child)


def get_target_shape(node: Node) -> tuple[int, int, int]:
    """Return the per-axis minimum (height, long, width) over all images in ``node``.

    ``node`` may be a single image or any nesting of dicts whose leaves are
    images. Branches that contain no image are ignored, so one empty
    sub-dict cannot collapse the whole result.

    Raises:
        ValueError: if ``node`` contains no image at all, or if an image has
            fewer than three dimensions.
    """
    shapes = list(_iter_image_shapes(node))

    if not shapes:
        raise ValueError("Cannot determine a target shape: no images found.")

    # zip(*shapes) groups all heights, all longs and all widths together.
    height, long, width = (min(sizes) for sizes in zip(*shapes))
    return height, long, width

In [None]:
# Smallest extent found along each of the three axes across the whole dataset.
target_shape = get_target_shape(data)
target_shape

In [None]:
def resize_width_2d(image: np.ndarray, width: int) -> np.ndarray:
    """Centre-crop or zero-pad a 2-D image to exactly ``width`` columns.

    The number of rows is left untouched and the dtype of ``image`` is
    preserved in every branch. When the width already matches, ``image``
    itself is returned; when cropping, a view into ``image`` is returned.

    Args:
        image: 2-D array of shape (rows, columns).
        width: Desired number of columns.

    Returns:
        An array of shape (rows, width).
    """
    rows, cols = image.shape

    if width == cols:
        return image

    if width < cols:
        # Drop the same number of columns on both sides (extra one on the right).
        start = (cols - width) // 2
        return image[:, start : start + width]

    # Pad with zeros, keeping the original columns centred (extra one on the right).
    offset = (width - cols) // 2
    padded = np.zeros((rows, width), dtype=image.dtype)
    padded[:, offset : offset + cols] = image
    return padded


def rescale_long_2d(image: np.ndarray, long: int) -> np.ndarray:
    """Rescale a 2-D image to ``long`` rows, keeping its aspect ratio.

    The number of columns is scaled by the same factor as the rows (rounded
    down) and the pixel values are obtained by bilinear interpolation.

    Args:
        image: 2-D array of shape (rows, columns).
        long: Desired number of rows.

    Returns:
        A float array of shape (long, long * columns // rows).
    """
    rows, cols = image.shape
    # Integer arithmetic avoids an off-by-one caused by float rounding.
    new_cols = (long * cols) // rows

    row_coords = np.arange(rows)
    col_coords = np.arange(cols)

    new_row_coords = np.linspace(0, rows - 1, long)
    new_col_coords = np.linspace(0, cols - 1, new_cols)

    # A degree-1 interpolating spline on a regular grid is bilinear
    # interpolation; it replaces the removed ``interpolate.interp2d``.
    spline = interpolate.RectBivariateSpline(
        row_coords, col_coords, np.asarray(image, dtype=float), kx=1, ky=1
    )

    return spline(new_row_coords, new_col_coords)


def reshape_3d(image_3d: np.ndarray, height: int, long: int, width: int) -> np.ndarray:
    """Bring a 3-D volume to the shape (height, long, width).

    Only the first ``height`` slices are kept. Each kept slice is rescaled
    to ``long`` rows and then cropped or padded to ``width`` columns.
    """
    return np.stack(
        [
            resize_width_2d(rescale_long_2d(slice_2d, long), width)
            for slice_2d in image_3d[:height]
        ]
    )


def change_shape(node: Node, height: int, long: int, width: int) -> Node:
    """Reshape every image found under ``node`` to (height, long, width).

    The nested dict structure and its keys are preserved; a new structure is
    returned and ``node`` itself is not modified.
    """
    if not isinstance(node, np.ndarray):
        reshaped = {}
        for key, sub_node in node.items():
            reshaped[key] = change_shape(sub_node, height, long, width)
        return reshaped  # type: ignore

    return reshape_3d(node, height, long, width)

In [None]:
# Reshape every volume to the common height and to a square long x long slice size.
height, long, _ = target_shape
data_reshaped = change_shape(data, height, long, long) # Make long and width the same - will be useful

### 512x512 resolution examples

In [None]:
# T1post volume of a study originally stored at 512x512, after reshaping.
show_3d_image(data_reshaped["PGBM-004"]["01-12-1994"]["T1post"]) # type: ignore

### 320x260 resolution examples

In [None]:
# T1post volume of a study originally stored at 320x260, after reshaping.
show_3d_image(data_reshaped["PGBM-009"]["01-03-1991"]["T1post"]) # type: ignore

### 320x280 resolution examples

In [None]:
# T1post volume of a study originally stored at 320x280, after reshaping.
show_3d_image(data_reshaped["PGBM-003"]["03-29-1995"]["T1post"]) # type: ignore

### 260x320 resolution examples

In [None]:
# T1post volume of a study originally stored at 260x320, after reshaping.
show_3d_image(data_reshaped["PGBM-013"]["09-18-1989"]["T1post"]) # type: ignore

### 256x256 resolution examples

In [None]:
# T1post volume of a study originally stored at 256x256, after reshaping.
show_3d_image(data_reshaped["PGBM-011"]["06-29-1989"]["T1post"]) # type: ignore

## Stack images and store in HDF5 file

In [None]:
def all_keys(d: dict[str, Any], possible_keys: Sequence[str]) -> bool:
    """Return True when every entry of ``possible_keys`` is a key of ``d``.

    An empty ``possible_keys`` yields True.
    """
    for key in possible_keys:
        if key not in d:
            return False
    return True


def stack_images(
    images: dict[str, np.ndarray], order_function: Callable[[str], Any]
) -> np.ndarray:
    """Stack the arrays of ``images`` along a new first axis.

    The arrays are ordered by ``order_function`` applied to their keys; keys
    that compare equal keep their dict order.
    """
    ordered_keys = sorted(images, key=order_function)
    return np.stack([images[key] for key in ordered_keys])


def get_date(date_str: str) -> datetime:
    """Parse a ``month-day-year`` string such as ``"01-12-1994"``."""
    date_format = "%m-%d-%Y"
    return datetime.strptime(date_str, date_format)


def get_time_data(data: Data) -> np.ndarray:
    """Return the day differences between consecutive study dates per subject.

    For each subject the dates are sorted chronologically and the difference
    ``earlier - later`` (in days) is taken for every consecutive pair. The
    per-subject arrays are stacked in order of subject key, so every subject
    must have the same number of dates.
    """
    # NOTE(review): "earlier - later" makes every delta non-positive — confirm
    # that the sign is intended before relying on these values.
    time_deltas = {}

    for subject, node in data.items():
        dates = sorted(get_date(date) for date in node.keys())
        time_deltas[subject] = np.array(
            [(earlier - later).days for earlier, later in zip(dates, dates[1:])]
        )

    return stack_images(time_deltas, lambda x: x)


def preprocess(data: Data, image_types: Sequence[str], target_path: str):
    """Stack the selected image types of all eligible subjects into an HDF5 file.

    Only subjects whose every study contains all of ``image_types`` are kept,
    and only those image types are written. The file at ``target_path`` is
    overwritten and receives three datasets:

    - ``images``: array indexed as (subject, date, image type, ...image axes),
      with subjects ordered by key, dates ordered chronologically and image
      types ordered by name;
    - ``times``: the result of ``get_time_data`` for the kept subjects, so
      its rows line up with the rows of ``images``;
    - ``metadata``: the image type names, in the same order as the image
      type axis of ``images``.

    Raises:
        ValueError: if ``image_types`` is empty or no subject provides all of
            the requested image types.
    """
    if not image_types:
        raise ValueError("At least one image type is required.")

    # Channels are stacked in sorted-name order below, so the names stored as
    # metadata must follow that same order to label the channels correctly.
    ordered_types = sorted(set(image_types))

    data_filtered = {
        subject: {
            # Keep only the requested types so every study has the same channels.
            date: {image_type: study[image_type] for image_type in ordered_types}
            for date, study in node.items()
        }
        for subject, node in data.items()
        if all(all_keys(study, image_types) for study in node.values())
    }

    if not data_filtered:
        raise ValueError("No subject provides all requested image types.")

    studies_stacked = {
        subject: {
            date: stack_images(images, lambda x: x) for date, images in node.items()
        }
        for subject, node in data_filtered.items()
    }

    dates_stacked = {
        subject: stack_images(images, get_date)
        for subject, images in studies_stacked.items()
    }

    all_stacked = stack_images(dates_stacked, lambda x: x)

    # Use the filtered subjects so "times" and "images" describe the same rows.
    time_data = get_time_data(data_filtered)

    image_types_array = np.array(ordered_types, dtype="S")

    target_directory = os.path.dirname(target_path)

    # dirname is "" for a bare file name; there is nothing to create then.
    if target_directory:
        os.makedirs(target_directory, exist_ok=True)

    with h5py.File(target_path, "w") as file:
        file.create_dataset("images", data=all_stacked)
        file.create_dataset("times", data=time_data)
        file.create_dataset("metadata", data=image_types_array)

In [None]:
# Image types that every study of a subject must contain for the subject to be kept.
image_types = [
    "dT1",
    "T1post",
    "T2reg",
    "MaskTumor",
    "sRCBVreg",
    "FLAIRreg",
    "ADCreg",
    "T1prereg",
    "nRCBVreg",
    "nCBFreg",
]

# Write the reshaped data of all eligible subjects to a single HDF5 file.
preprocess(
    data_reshaped, image_types, "../data/preprocessed/Brain-Tumor-Progression/test.h5"
)

## Load preprocessed images