<a href="https://colab.research.google.com/github/alecseiterr/pleural_effusion/blob/main/Dmitrii_Utkin/tf_dataset_creation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


In [1]:
# from google.colab import drive
# drive.mount('/content/drive')


[Information on SimpleITK image directions](https://simpleitk.org/doxygen/latest/html/classitk_1_1simple_1_1DICOMOrientImageFilter.html)


In [1]:
import tensorflow as tf
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import glob, json, os, shutil
import SimpleITK as sitk # install beforehand
from functools import reduce
from datetime import datetime
import ipywidgets as widgets
from enum import Enum

%matplotlib inline

## Setting up path constants


In [65]:
# Root folders of the local project checkout.
LOCAL_FOLDER = os.path.join("/Users", "dutking", "LOCAL")
PROJECT_FOLDER = os.path.join(LOCAL_FOLDER, "AI_uni", "radlogix")

# Effusion masks, e.g. <EFFUSION_PATH>/LUNG1-001/LUNG1-001_effusion_first_reviewer.nii.gz
EFFUSION_PATH = os.path.join(PROJECT_FOLDER, "dataset", "effusions_052023")
# CT series, e.g. <CT_PATH>/LUNG1-001/09-18-2008-StudyID-NA-69331/0.000000-NA-82046
CT_PATH = os.path.join(PROJECT_FOLDER, "dataset", "features")
# Lung masks, one sub-folder per patient.
LUNGS_PATH = os.path.join(PROJECT_FOLDER, "dataset", "lungs_labels")
# Cleaned patient table, indexed by PatientID when it is read.
CSV_DF_PATH = os.path.join(
    PROJECT_FOLDER,
    "_github",
    "pleural_effusion",
    "Dmitrii_Utkin",
    "_docs",
    "clean_df_on_latest_ds.csv",
)

# The orientations of the source images are inconsistent, so every CT
# orientation code is paired with the mask orientation code used alongside it.
CT_COORDS = ["RPS", "RPI", "LPS", "LPI"]
MASK_COORDS = ["RAS", "RAI", "LAS", "LAI"]
COORDS = [(ct_code, mask_code) for ct_code, mask_code in zip(CT_COORDS, MASK_COORDS)]
print(COORDS)


[('RPS', 'RAS'), ('RPI', 'RAI'), ('LPS', 'LAS'), ('LPI', 'LAI')]


In [3]:
class Target(str, Enum):
    """Kinds of volume the loaders can be asked for: the CT scan or one of its masks."""
    # Members are also plain strings (str mixin), so they compare equal to
    # their string value.
    CT = "ct"
    EFFUSION = "effusion"
    LUNGS = "lungs"


class Mask(str, Enum):
    """Label mask kinds; the string values equal those of the matching Target members."""
    # get_dataset passes a Mask as the ``target`` of get_final_image, where it
    # is matched against Target members, so the values must stay in sync.
    EFFUSION = "effusion"
    LUNGS = "lungs"


In [4]:
class Config:
    """Preprocessing options and output locations for one model's dataset.

    Creating an instance also creates the model's folder tree under
    ``PROJECT_FOLDER/models/<model_name>`` and derives ``target_z`` and
    ``dimension`` from the resampling flags.

    Parameters
    ----------
    model_name : name of the model; used as folder name and JSON file name.
    resample_spacing : whether volumes are resampled with ``resample_spacing``.
    resample_size : whether volumes are downsized with ``resample_size``.
    resize_factor : factor passed to ``resample_size``.
    normalize : whether CT values are rescaled to [0, 1].
    use_positive_values : whether CT values are shifted by +1024.
    crop_values : whether CT values are clamped to ``crop_ranges``.
    crop_ranges : list of ``[low, high]`` intensity windows.
    pad_filler : value intended for padding.
    mask : which ``Mask`` is used as the label.
    ratio : ratio forwarded to ``get_effusion_dfs`` by the driver cell.
    coords : list of ``(ct_orientation, mask_orientation)`` pairs.
    """

    def __init__(
        self,
        model_name="my_model",
        resample_spacing=False,
        resample_size=True,
        resize_factor=2,
        normalize=False,
        use_positive_values=False,
        crop_values=True,
        crop_ranges=[[-1024, 150]],
        pad_filler=-1024,
        mask=Mask.EFFUSION,
        ratio=1 / 1,
        coords=COORDS,
    ):
        self.model_name = model_name
        self.resample_spacing = resample_spacing
        self.resample_size = resample_size
        self.resize_factor = resize_factor
        self.normalize = normalize
        self.use_positive_values = use_positive_values
        self.crop_values = crop_values
        # Copy so that instances never share (or mutate) the default list object.
        self.crop_ranges = [list(bounds) for bounds in crop_ranges]
        self.pad_filler = pad_filler
        self.mask = mask
        self.ratio = ratio
        self.coords = coords
        self.create_folders()
        self.set_target_z()
        self.set_dimension()

    def create_folders(self):
        """Set the model's path attributes and make sure every folder exists."""
        self.model_path = os.path.join(PROJECT_FOLDER, "models", self.model_name)
        self.dataset_path = os.path.join(self.model_path, "tfdataset")
        self.logs_path = os.path.join(self.model_path, "logs")
        self.saved_models_path = os.path.join(self.model_path, "saved_models")
        # Create each folder independently: the model folder may already exist
        # (e.g. it holds only "configs") while the sub-folders are still missing.
        for folder in (
            self.model_path,
            self.dataset_path,
            self.logs_path,
            self.saved_models_path,
        ):
            os.makedirs(folder, exist_ok=True)

    def set_target_z(self):
        """Pick the number of slices every volume is cropped/padded to."""
        if (
            self.resample_spacing and self.resample_size
        ):  # resize factor assumed to be 2
            self.target_z = 192
            return

        if self.resample_spacing:
            self.target_z = 384
            return

        if self.resample_size:  # resize factor assumed to be 2
            self.target_z = 64
            return

        self.target_z = 128

    def set_dimension(self):
        """Pick the in-plane edge length every slice is padded to."""
        if self.resample_size:  # resize factor assumed to be 2
            self.dimension = 256
            return

        self.dimension = 512

    def save_to_JSON(self, suffix=""):
        """Dump all attributes to ``<model_path>/configs/<model_name><suffix>.json``."""
        path = os.path.join(self.model_path, "configs")
        os.makedirs(path, exist_ok=True)

        with open(os.path.join(path, f"{self.model_name}{suffix}.json"), "w") as file:
            json.dump(self.__dict__, file)

    def load_from_JSON(self, filepath):
        """Overwrite this config with the values stored by ``save_to_JSON``.

        Raises ``KeyError`` if the file lacks one of the expected keys.
        """
        with open(filepath) as json_file:
            data = json.load(json_file)

        # json.load returns a dict, so values are read by key, not by attribute.
        self.model_name = data["model_name"]
        self.resample_spacing = data["resample_spacing"]
        self.resample_size = data["resample_size"]
        self.resize_factor = data["resize_factor"]
        self.normalize = data["normalize"]
        self.use_positive_values = data["use_positive_values"]
        self.crop_values = data["crop_values"]
        self.crop_ranges = data["crop_ranges"]
        self.pad_filler = data["pad_filler"]
        self.target_z = data["target_z"]
        self.dimension = data["dimension"]
        # JSON stores the enum as its string value and tuples as lists;
        # restore the types the constructor works with.
        self.mask = Mask(data["mask"])
        self.ratio = data["ratio"]
        self.coords = [tuple(pair) for pair in data["coords"]]
        self.create_folders()


In [5]:
def get_lungs_dfs(use_latest=False):
    """Return the ``(train, val, test)`` patient tables for lung segmentation.

    Each returned DataFrame carries a ``name`` attribute ("train", "val" or
    "test"). New splits are sampled at random (50 validation and 25 test
    patients) and written to ``<PROJECT_FOLDER>/models/lungs_dfs`` with the
    global ``DATE_TIME`` stamp in the file name.

    Parameters
    ----------
    use_latest : when True, reuse the most recently written split files
        instead of sampling new ones. Falls back to sampling when no complete
        train/val/test set is found.
    """
    dfs_path = os.path.join(PROJECT_FOLDER, "models", "lungs_dfs")
    os.makedirs(dfs_path, exist_ok=True)

    split_names = ("train", "val", "test")
    if use_latest:
        restored_dfs = {}
        # Oldest first: files of the most recent run overwrite earlier ones,
        # so the result does not depend on the arbitrary glob order.
        csv_paths = sorted(
            glob.glob(os.path.join(dfs_path, "*.csv")), key=os.path.getmtime
        )
        for csv_path in csv_paths:
            # File names look like "lungs_<split>_df_<timestamp>.csv".
            parts = os.path.basename(csv_path).split("_")
            if len(parts) < 2 or parts[1] not in split_names:
                continue
            restored = pd.read_csv(csv_path, index_col="PatientID")
            restored.name = parts[1]
            restored_dfs[parts[1]] = restored

        if all(name in restored_dfs for name in split_names):
            return restored_dfs["train"], restored_dfs["val"], restored_dfs["test"]
        print("No previous dataframes found. Creating new ones...")

    df = pd.read_csv(CSV_DF_PATH, index_col="PatientID")

    val_df = df.sample(n=50)
    val_df.to_csv(os.path.join(dfs_path, f"lungs_val_df_{DATE_TIME}.csv"))
    val_df.name = "val"
    print("validation set size:", len(val_df))
    df = df.drop(val_df.index)

    test_df = df.sample(n=25)
    test_df.to_csv(os.path.join(dfs_path, f"lungs_test_df_{DATE_TIME}.csv"))
    test_df.name = "test"
    print("test set size:", len(test_df))
    df = df.drop(test_df.index)

    # Whatever is left after removing validation and test patients.
    train_df = df
    train_df.to_csv(os.path.join(dfs_path, f"lungs_train_df_{DATE_TIME}.csv"))
    train_df.name = "train"
    print("train set size:", len(train_df))
    return train_df, val_df, test_df


def get_effusion_dfs(features_to_labels_ratio=1, use_latest=False):
    """Return class-balanced ``(train, val, test)`` patient tables for effusion.

    All patients with ``Effusion.Event == 1`` are kept and patients with
    ``Effusion.Event == 0`` are sampled at random. Validation gets 10 + 10 and
    test 5 + 5 patients (positive + negative); the rest is the training set.
    Each returned DataFrame carries a ``name`` attribute. The tables are
    written to ``<PROJECT_FOLDER>/models/dfs`` with the global ``DATE_TIME``
    stamp in the file name.

    Parameters
    ----------
    features_to_labels_ratio : number of negative patients sampled per
        positive patient (capped at the number of negatives available).
    use_latest : when True, reuse the most recently written split files
        instead of sampling new ones. Falls back to sampling when no complete
        train/val/test set is found.
    """
    dfs_path = os.path.join(PROJECT_FOLDER, "models", "dfs")
    os.makedirs(dfs_path, exist_ok=True)

    split_names = ("train", "val", "test")
    if use_latest:
        restored_dfs = {}
        # Oldest first: files of the most recent run overwrite earlier ones,
        # so the result does not depend on the arbitrary glob order.
        csv_paths = sorted(
            glob.glob(os.path.join(dfs_path, "*.csv")), key=os.path.getmtime
        )
        for csv_path in csv_paths:
            # File names look like "<split>_df_<timestamp>.csv"; the folder
            # also holds "balanced_df_..." snapshots, which are not splits.
            name = os.path.basename(csv_path).split("_")[0]
            if name not in split_names:
                continue
            restored = pd.read_csv(csv_path, index_col="PatientID")
            restored.name = name
            restored_dfs[name] = restored

        if all(name in restored_dfs for name in split_names):
            return restored_dfs["train"], restored_dfs["val"], restored_dfs["test"]
        print("No previous dataframes found. Creating new ones...")

    df = pd.read_csv(CSV_DF_PATH, index_col="PatientID")
    positives = df.loc[df["Effusion.Event"] == 1]
    negatives = df.loc[df["Effusion.Event"] == 0]
    amount_of_effusions = len(positives)
    print(f"Amount of labels: {amount_of_effusions}")
    print(f"Amount of features: {len(df)}")

    # Never ask for more negatives than exist; DataFrame.sample would raise.
    n_negatives = min(
        int(amount_of_effusions * features_to_labels_ratio), len(negatives)
    )
    balanced_df = pd.concat([positives, negatives.sample(n=n_negatives)])

    balanced_df.to_csv(os.path.join(dfs_path, f"balanced_df_{DATE_TIME}.csv"))
    print(f"Balanced df size: {len(balanced_df)}")

    val_df = pd.concat(
        [
            balanced_df.loc[balanced_df["Effusion.Event"] == 1].sample(n=10),
            balanced_df.loc[balanced_df["Effusion.Event"] == 0].sample(n=10),
        ]
    )
    val_df.to_csv(os.path.join(dfs_path, f"val_df_{DATE_TIME}.csv"))
    val_df.name = "val"
    print("validation set size:", len(val_df))
    balanced_df = balanced_df.drop(val_df.index)

    test_df = pd.concat(
        [
            balanced_df.loc[balanced_df["Effusion.Event"] == 1].sample(n=5),
            balanced_df.loc[balanced_df["Effusion.Event"] == 0].sample(n=5),
        ]
    )
    test_df.to_csv(os.path.join(dfs_path, f"test_df_{DATE_TIME}.csv"))
    test_df.name = "test"
    print("test set size:", len(test_df))
    balanced_df = balanced_df.drop(test_df.index)

    # Whatever is left after removing validation and test patients.
    train_df = balanced_df
    train_df.to_csv(os.path.join(dfs_path, f"train_df_{DATE_TIME}.csv"))
    train_df.name = "train"
    print("train set size:", len(train_df))
    return train_df, val_df, test_df


In [6]:
def get_dicom_folder(id):
    """Return the first folder below ``CT_PATH/<id>`` that contains a ``.dcm`` file.

    Folders are visited bottom-up. Returns ``None`` when no folder holds a
    DICOM file.
    """
    for dicom_folder, _, files in os.walk(os.path.join(CT_PATH, id), topdown=False):
        # Look at every file: a folder may be empty, or may list a non-DICOM
        # file first, and neither case should hide the series or raise.
        if any(file_name.endswith(".dcm") for file_name in files):
            return dicom_folder
    return None


def get_nifti_file(id, path):
    """Return the path of a ``.gz`` file inside ``<path>/<id>``.

    Returns ``None`` when the patient folder does not exist or contains no
    ``.gz`` file. If several files match, the alphabetically first one is
    returned so the choice is deterministic.
    """
    patient_folder = os.path.join(path, id)
    if not os.path.exists(patient_folder):
        return None

    matches = sorted(glob.glob(os.path.join(patient_folder, "*.gz")))
    if not matches:
        return None
    return matches[0]


In [7]:
def get_feature_image(dicom_folder, coord=None):
    """Read the DICOM series stored in ``dicom_folder`` as one SimpleITK image.

    Parameters
    ----------
    dicom_folder : folder that holds the ``.dcm`` files of one series.
    coord : accepted for symmetry with ``get_label_image``; currently unused
        because re-orientation is done later by ``set_direction``.

    Raises
    ------
    FileNotFoundError : if no DICOM series files are found in the folder.
    """
    # Remove a "__MACOSX" sub-folder if there is one; a missing folder is fine.
    shutil.rmtree(os.path.join(dicom_folder, "__MACOSX"), ignore_errors=True)

    reader = sitk.ImageSeriesReader()
    dicom_names = reader.GetGDCMSeriesFileNames(dicom_folder)
    if not dicom_names:
        raise FileNotFoundError(f"No DICOM series found in {dicom_folder!r}")
    reader.SetFileNames(dicom_names)
    image = reader.Execute()
    # image = sitk.DICOMOrient(image, coord)
    return image


def get_label_image(file, coord=None):
    """Read the label volume stored at ``file``; ``coord`` is accepted but unused."""
    # Re-orientation is disabled here: sitk.DICOMOrient(image, coord)
    return sitk.ReadImage(file)


def get_initial_image(target, id):
    match target:
        case Target.CT:
            x = get_dicom_folder(id)
            x = get_feature_image(x)
            return x
        case Target.EFFUSION:
            x = get_nifti_file(id, EFFUSION_PATH)
            if x is None:
                return x
            x = sitk.ReadImage(x)
            return x
        case Target.LUNGS:
            x = get_nifti_file(id, LUNGS_PATH)
            x = sitk.ReadImage(x)
            return x
        case _:
            return None


In [27]:
def resample_spacing(image, new_spacing=(1, 1, 1), interpolator=None):
    """Resample ``image`` to ``new_spacing`` while keeping origin and direction.

    The output size is chosen so that the resampled grid covers the same
    physical extent as the input (rounded up to whole voxels).

    Parameters
    ----------
    image : SimpleITK image to resample.
    new_spacing : target voxel spacing per axis; defaults to 1 unit on each
        of three axes, as before.
    interpolator : SimpleITK interpolator constant. Defaults to
        ``sitk.sitkLinear``; pass ``sitk.sitkNearestNeighbor`` for label masks
        so that label values are not blended.

    Returns
    -------
    The resampled SimpleITK image.
    """
    if interpolator is None:
        interpolator = sitk.sitkLinear
    new_spacing = [float(step) for step in new_spacing]

    orig_size = np.array(image.GetSize(), dtype=np.int32)
    orig_spacing = np.array(image.GetSpacing())
    # Image dimensions are integers, so round the scaled size up.
    new_size = np.ceil(orig_size * (orig_spacing / np.array(new_spacing)))

    resample = sitk.ResampleImageFilter()
    resample.SetInterpolator(interpolator)
    resample.SetOutputDirection(image.GetDirection())
    resample.SetOutputOrigin(image.GetOrigin())
    resample.SetOutputSpacing(new_spacing)
    resample.SetSize([int(edge) for edge in new_size])
    return resample.Execute(image)


def resample_size(patient_CT, resize_factor):
    """Resample an image onto a grid with ``resize_factor`` times fewer voxels per axis.

    A reference image is built with the same origin, direction and pixel type
    as the input, a size of ``round(size / resize_factor)`` per axis and a
    spacing chosen to span the input's physical extent. The input is then
    resampled onto that reference with linear interpolation.

    Parameters
    ----------
    patient_CT : SimpleITK image (despite the name it is used for masks as
        well by ``get_final_image``).
    resize_factor : divisor applied to the size of every axis.

    Returns
    -------
    The result of ``sitk.Resample`` on the reference grid; voxels outside the
    input are filled with 0.0.
    """
    # original_CT = sitk.ReadImage(patient_CT,sitk.sitkInt32)
    original_CT = patient_CT
    dimension = original_CT.GetDimension()
    # Physical extent per axis: (size - 1) * spacing. The array is freshly
    # zeroed, so ``mx`` is 0 and the else-branch only applies when
    # size * spacing is not positive.
    reference_physical_size = np.zeros(original_CT.GetDimension())
    reference_physical_size[:] = [
        (sz - 1) * spc if sz * spc > mx else mx
        for sz, spc, mx in zip(
            original_CT.GetSize(), original_CT.GetSpacing(), reference_physical_size
        )
    ]

    reference_origin = original_CT.GetOrigin()
    reference_direction = original_CT.GetDirection()

    # Fewer voxels over the same physical extent, hence a larger spacing.
    # NOTE(review): divides by zero if any axis of reference_size equals 1.
    reference_size = [round(sz / resize_factor) for sz in original_CT.GetSize()]
    reference_spacing = [
        phys_sz / (sz - 1)
        for sz, phys_sz in zip(reference_size, reference_physical_size)
    ]

    # Empty image that only defines the output grid for the resampling.
    reference_image = sitk.Image(reference_size, original_CT.GetPixelIDValue())
    reference_image.SetOrigin(reference_origin)
    reference_image.SetSpacing(reference_spacing)
    reference_image.SetDirection(reference_direction)

    # Physical position of the centre of the reference grid.
    reference_center = np.array(
        reference_image.TransformContinuousIndexToPhysicalPoint(
            np.array(reference_image.GetSize()) / 2.0
        )
    )

    # Affine transform whose matrix is the input's direction matrix.
    transform = sitk.AffineTransform(dimension)
    transform.SetMatrix(original_CT.GetDirection())

    # reference_origin was read from original_CT above, so this difference is zero.
    transform.SetTranslation(np.array(original_CT.GetOrigin()) - reference_origin)

    # Translation that lines up the input's centre with the reference centre.
    centering_transform = sitk.TranslationTransform(dimension)
    img_center = np.array(
        original_CT.TransformContinuousIndexToPhysicalPoint(
            np.array(original_CT.GetSize()) / 2.0
        )
    )
    centering_transform.SetOffset(
        np.array(transform.GetInverse().TransformPoint(img_center) - reference_center)
    )
    # Combine the affine transform and the centering translation.
    centered_transform = sitk.CompositeTransform(transform)
    centered_transform.AddTransform(centering_transform)

    # sitk.Show(sitk.Resample(original_CT, reference_image, centered_transform, sitk.sitkLinear, 0.0))

    # Linear interpolation; 0.0 is the value used outside the input image.
    return sitk.Resample(
        original_CT, reference_image, centered_transform, sitk.sitkLinear, 0.0
    )


def set_final_size(image, target_slices, dimension, pad_filler):
    """Bring ``image`` to ``target_slices`` slices, then pad it in-plane.

    The last axis is cropped or padded (via ``crop_z`` / ``pad_z``) until it
    has ``target_slices`` entries; the result is handed to ``pad_xy`` together
    with ``dimension`` and ``pad_filler``.
    """
    z_diff = target_slices - image.GetSize()[-1]

    if z_diff < 0:
        # Too many slices: remove the surplus.
        resized = crop_z(image, np.abs(z_diff))
    elif z_diff > 0:
        # Too few slices: add filler slices.
        resized = pad_z(image, np.abs(z_diff), pad_filler)
    else:
        resized = image

    return pad_xy(resized, dimension, pad_filler)


def pad_z(image, z_diff, pad_filler):
    """Add ``z_diff`` slices filled with ``pad_filler`` along the third axis.

    The slices are split between both ends; for an odd ``z_diff`` the upper
    bound receives the extra slice.
    """
    half = z_diff / 2
    lower_bound = (0, 0, int(np.floor(half)))
    upper_bound = (0, 0, int(np.ceil(half)))
    return sitk.ConstantPad(image, lower_bound, upper_bound, pad_filler)


def crop_z(image, z_diff):
    """Remove ``z_diff`` slices along the third axis.

    The slices are taken from both ends; for an odd ``z_diff`` the upper
    bound loses the extra slice.
    """
    half = z_diff / 2
    lower_bound = (0, 0, int(np.floor(half)))
    upper_bound = (0, 0, int(np.ceil(half)))
    return sitk.Crop(image, lower_bound, upper_bound)


def pad_xy(image, dimension, pad_filler):
    """Center ``image`` in a ``dimension`` x ``dimension`` in-plane frame.

    Each of the first two axes is handled on its own: an axis shorter than
    ``dimension`` is padded with ``pad_filler``, a longer one is center-cropped.
    For an odd difference the lower bound gets the extra voxel. The third axis
    is left untouched (the image is assumed to be 3D, as in the callers).

    Returns
    -------
    SimpleITK image whose first two axes both have length ``dimension``.
    """
    width, height = image.GetSize()[:2]

    # Crop first, so that the padding amounts below can never be negative.
    excess_x = max(width - dimension, 0)
    excess_y = max(height - dimension, 0)
    if excess_x or excess_y:
        image = sitk.Crop(
            image,
            ((excess_x + 1) // 2, (excess_y + 1) // 2, 0),
            (excess_x // 2, excess_y // 2, 0),
        )
        width, height = image.GetSize()[:2]

    # The two axes are padded independently, so non-square slices also end up
    # with the requested size.
    missing_x = dimension - width
    missing_y = dimension - height
    return sitk.ConstantPad(
        image,
        ((missing_x + 1) // 2, (missing_y + 1) // 2, 0),
        (missing_x // 2, missing_y // 2, 0),
        pad_filler,
    )


def crop_values(arr, config):
    """Clamp ``arr`` to every window in ``config.crop_ranges`` and sum the results.

    When ``config.use_positive_values`` is set, the window bounds are shifted
    by +1024 before clamping. ``arr`` itself is not modified; with a single
    window the result is simply the clamped copy.
    """
    windows = config.crop_ranges
    if config.use_positive_values:
        windows = np.array(config.crop_ranges) + 1024

    clamped = []
    for bounds in windows:
        low, high = bounds[0], bounds[1]
        window = arr.copy()
        window[window < low] = low
        window[window > high] = high
        clamped.append(window)

    return reduce(lambda left, right: left + right, clamped)


def normalize(image_array, config):
    """Rescale ``image_array`` linearly to [0, 1] over the configured window.

    The window runs from the lower bound of the first entry of
    ``config.crop_ranges`` to the upper bound of the last one. When
    ``config.use_positive_values`` is set, the bounds are shifted by +1024,
    matching the shift ``crop_values`` applies to the same ranges. Values
    outside the window are clipped to 0 or 1.

    Raises
    ------
    ValueError : if the window has zero width.
    """
    # Shifted data needs equally shifted bounds, otherwise everything is
    # squeezed into the top of the range.
    offset = 1024 if getattr(config, "use_positive_values", False) else 0
    low = config.crop_ranges[0][0] + offset
    high = config.crop_ranges[-1][-1] + offset
    if high == low:
        raise ValueError("crop_ranges must span a non-empty intensity window")

    norm_image_array = (image_array - low) / (high - low)
    return np.clip(norm_image_array, 0.0, 1.0)


def set_positive_values(image_array):
    """Return ``image_array`` with 1024 added to every element."""
    return image_array + 1024


def set_direction(image, coord):
    """Return ``image`` re-oriented to the orientation code ``coord`` (e.g. "LPS")."""
    return sitk.DICOMOrient(image, coord)


def set_dtype(image, target):
    match target:
        case Target.CT:
            x = image.astype(dtype=np.float64)
        case Target.LUNGS:
            x = image.astype(dtype=np.int16)
        case Target.EFFUSION:
            x = image.astype(dtype=np.int16)
    return x


In [52]:
def get_final_image(
    target=Target.CT, id="LUNG1-001", config=Config(), coords=("LPS", "LAS")
):
    x = get_initial_image(target, id)
    if x is None:
        return np.zeros(
            (config.target_z, config.dimension, config.dimension, 1), dtype=np.int16
        )

    if config.resample_spacing:
        x = resample_spacing(x)

    if config.resample_size:
        x = resample_size(x, config.resize_factor)

    x = set_final_size(x, config.target_z, config.dimension, 0)

    if target == Target.CT:
        x = set_direction(x, coords[0])
    else:
        x = set_direction(x, coords[1])

    x = sitk.GetArrayFromImage(x)

    match target:
        case Target.CT:
            if config.use_positive_values:
                x = set_positive_values(x)

            if config.crop_values:
                x = crop_values(x, config)

            if config.normalize:
                x = normalize(x)

        case Target.LUNGS:
            x = np.where(x > 0, 1, 0)

    x = x.reshape(config.target_z, config.dimension, config.dimension, 1)
    x = set_dtype(x, target)
    return x


In [13]:
def _bytes_feature(value):
    """Wrap one string / bytes value in a ``tf.train.Feature`` holding a bytes_list."""
    eager_tensor_type = type(tf.constant(0))
    if isinstance(value, eager_tensor_type):
        # Take the value out of the tensor before wrapping it.
        value = value.numpy()
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)


def _float_feature(value):
    """Wrap one float / double value in a ``tf.train.Feature`` holding a float_list."""
    float_list = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=float_list)


def _int64_feature(value):
    """Wrap one bool / enum / int / uint value in a ``tf.train.Feature`` holding an int64_list."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)


def create_example(feature, label, shape):
    """Build a ``tf.train.Example`` from three arrays.

    Each array is stored as its raw bytes (``tobytes()``) under the keys
    "feature", "label" and "shape".
    """
    arrays = (("feature", feature), ("label", label), ("shape", shape))
    serialized = {key: _bytes_feature(array.tobytes()) for key, array in arrays}
    return tf.train.Example(features=tf.train.Features(feature=serialized))


In [60]:
def get_dataset(df, config):
    """Write one GZIP-compressed TFRecord file for the patients in ``df``.

    The file is ``<config.dataset_path>/<df.name>.tfrecords``. For every
    patient id in ``df.index`` and every orientation pair in ``config.coords``
    one example is written, holding the processed CT, the processed mask
    selected by ``config.mask`` and the CT's shape.
    """
    record_file = os.path.join(config.dataset_path, f"{df.name}.tfrecords")
    gzip_options = tf.io.TFRecordOptions(compression_type="GZIP")

    with tf.io.TFRecordWriter(record_file, options=gzip_options) as writer:
        for patient_id in df.index:
            for orientation_pair in config.coords:
                ct = get_final_image(
                    target=Target.CT,
                    id=patient_id,
                    config=config,
                    coords=orientation_pair,
                )
                mask = get_final_image(
                    target=config.mask,
                    id=patient_id,
                    config=config,
                    coords=orientation_pair,
                )
                example = create_example(ct, mask, np.array(ct.shape))
                writer.write(example.SerializeToString())


In [61]:
# One dataset configuration per model. Both share the same preprocessing and
# differ only in the model name and in the mask used as label.
CONFIGS = [
    Config(
        model_name=model_name,
        resample_spacing=False,
        resample_size=True,
        resize_factor=2,
        normalize=False,
        use_positive_values=False,
        crop_values=True,
        crop_ranges=[[-1024, 200]],
        pad_filler=-1024,
        mask=mask,
        ratio=1 / 1,
        coords=COORDS,
    )
    for model_name, mask in (
        ("effusions_aug", Mask.EFFUSION),
        ("lungs_aug", Mask.LUNGS),
    )
]

"""
# FOR TESTING
CONFIGS = [
    Config(
        model_name="test_effusion",
        resample_spacing=False,
        resample_size=True,
        resize_factor=2,
        normalize=False,
        use_positive_values=False,
        crop_values=True,
        crop_ranges=[[-1024, 200]],
        pad_filler=-1024,
        mask=Mask.EFFUSION,
        ratio=1 / 1,
        coords=COORDS,
    ),
    Config(
        model_name="test_lungs",
        resample_spacing=False,
        resample_size=True,
        resize_factor=2,
        normalize=False,
        use_positive_values=False,
        crop_values=True,
        crop_ranges=[[-1024, 200]],
        pad_filler=-1024,
        mask=Mask.LUNGS,
        coords=COORDS,
    ),
]
"""


'\n# FOR TESTING\nCONFIGS = [\n    Config(\n        model_name="test_effusion",\n        resample_spacing=False,\n        resample_size=True,\n        resize_factor=2,\n        normalize=False,\n        use_positive_values=False,\n        crop_values=True,\n        crop_ranges=[[-1024, 200]],\n        pad_filler=-1024,\n        mask=Mask.EFFUSION,\n        ratio=1 / 1,\n        coords=COORDS,\n    ),\n    Config(\n        model_name="test_lungs",\n        resample_spacing=False,\n        resample_size=True,\n        resize_factor=2,\n        normalize=False,\n        use_positive_values=False,\n        crop_values=True,\n        crop_ranges=[[-1024, 200]],\n        pad_filler=-1024,\n        mask=Mask.LUNGS,\n        coords=COORDS,\n    ),\n]\n'

In [62]:
DATE_TIME = datetime.now().strftime("%d_%m_%Y__%H_%M_%S")

for config in CONFIGS:
    print(f"\nWriting datasets for {config.model_name}")
    match config.mask:
        case Mask.LUNGS:
            train_df, val_df, test_df = get_lungs_dfs(use_latest=False)
        case Mask.EFFUSION:
            train_df, val_df, test_df = get_effusion_dfs(config.ratio, use_latest=False)
    config.save_to_JSON()
    get_dataset(train_df, config)
    get_dataset(val_df, config)
    get_dataset(test_df, config)



Writing datasets for effusions_aug
Amount of labels: 54
Amount of features: 358
Balanced df size: 108
validation set size: 20
test set size: 10
train set size: 78

Writing datasets for lungs_aug
validation set size: 50
test set size: 25
train set size: 283


# READ


In [57]:
def parse_record(record):
    """Parse one serialized example into its "feature", "label" and "shape" byte strings."""
    string_scalar = tf.io.FixedLenFeature([], tf.string)
    name_to_features = {
        key: string_scalar for key in ("feature", "label", "shape")
    }
    return tf.io.parse_single_example(record, name_to_features)


def decode_record(record, config):
    """Decode a parsed record into ``(feature, label)`` tensors.

    Parameters
    ----------
    record : dict returned by ``parse_record``.
    config : ``Config`` whose ``target_z`` and ``dimension`` give the volume shape.

    Returns
    -------
    ``(feature, label)``, both of shape ``(target_z, dimension, dimension, 1)``.
    The feature is ``float64`` as stored; the label is stored as ``int16`` and
    cast to ``float32``.
    """
    volume_shape = (config.target_z, config.dimension, config.dimension, 1)

    # The stored "shape" entry is not needed: the shape comes from the config,
    # and reshaping with a fixed shape already makes the static shape known.
    feature = tf.io.decode_raw(
        record["feature"], out_type=tf.float64, little_endian=True
    )
    label = tf.io.decode_raw(record["label"], out_type=tf.int16, little_endian=True)

    feature = tf.reshape(feature, volume_shape)
    label = tf.cast(tf.reshape(label, volume_shape), dtype=tf.float32)
    return (feature, label)


In [58]:
def plot_image(image, slice_num):
    """Show slice ``slice_num`` (index along the first axis) of ``image``."""
    my_slice = image[slice_num, :, :]
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(12, 5))
    # Plot the slice selected above (the old code referenced an undefined name).
    ax.imshow(my_slice, cmap="bone", interpolation="none")
    ax.set_title(f"Slice {slice_num}")
    plt.show()


def plot_overlay(feature, label):
    """Show the middle slice of ``feature`` next to the same slice with ``label`` on top.

    The left panel shows the feature slice only. The right panel adds the
    label slice, drawn with a per-pixel alpha of half the label value.
    """
    mid_slice = int(feature.shape[0] / 2)
    feature_slice = feature[mid_slice, :, :]
    label_slice = np.squeeze(label[mid_slice, :, :])
    overlay_alpha = 0.5 * (np.squeeze(label_slice))

    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 5))
    for panel in (ax[0], ax[1]):
        panel.imshow(feature_slice, cmap="bone", interpolation="none")
    ax[1].imshow(label_slice, cmap="prism", vmin=0, vmax=1, alpha=overlay_alpha)

    plt.title(f"Slice {mid_slice}")
    plt.show()


def plot_3d(
    images,
    colors=[[0.5, 0.5, 1], [0.9, 0.1, 0.9]],
    alpha=[0.1, 0.9],
    threshold=[-600, 0],
):
    """
    Create a 3D visualization of the images based on the given threshold values.

    NOTE(review): ``measure`` and ``Poly3DCollection`` are not imported in the
    notebook's import cell, so this function raises NameError as it stands.
    Presumably they are ``skimage.measure`` and
    ``mpl_toolkits.mplot3d.art3d.Poly3DCollection`` - confirm and import them.

    Parameters:
    images: [ndarray]
        3D arrays representing the images.
    colors: [float]
        Array of colour values for the shapes.
    alpha: [float]
        Array of transparency values for the shapes.
    threshold: x, optional
        Threshold value used to build the 3D model (one per image).
    """
    fig = plt.figure(figsize=(9, 9))
    ax = fig.add_subplot(111, projection="3d")

    for idx, image in enumerate(images):
        # Position the scan upright,
        # so that the patient's head is at the top, facing the camera
        image = image.transpose(2, 0, 1)
        image = image[:, :, ::-1]

        if idx == 0:
            # Set the limits of every axis according to the shape of the transposed image
            ax.set_xlim(0, image.shape[0])
            ax.set_ylim(0, image.shape[1])
            ax.set_zlim(0, image.shape[2])

        # Get the vertices and faces of the 3D model using marching_cubes
        verts, faces, _, _ = measure.marching_cubes(image, threshold[idx])
        # (The figure and its 3D subplot were already created above the loop.)

        # Build a collection of triangles from the vertices and faces, set its colour and transparency, add it to the subplot
        mesh = Poly3DCollection(verts[faces], alpha=alpha[idx])
        face_color = colors[idx]
        mesh.set_facecolor(face_color)
        ax.add_collection3d(mesh)

    # Show the visualization
    plt.show()


def overlay_slices_slider(images, titles=[], cmaps=["bone", "prism", "jet"]):
    """Show an interactive slider that scrolls through the slices of ``images``.

    The first image is drawn on its own in the first panel. Every further
    image gets a panel in which it is overlaid on the first image, using a
    per-pixel alpha of half its value.

    Parameters
    ----------
    images : list of arrays indexed as ``[slice, row, column]``; the slider
        range is taken from the first one.
    titles : panel titles; missing entries are shown as empty titles. The
        list passed in is not modified.
    cmaps : colormap per image; reused cyclically if there are more images
        than colormaps.
    """
    num_plots = len(images)
    # Pad a private copy: appending to ``titles`` itself would grow the shared
    # default list (and the caller's list) on every call.
    panel_titles = list(titles) + [""] * max(num_plots - len(titles), 0)

    @widgets.interact(axial_slice=(0, images[0].shape[0] - 1))
    def axial_slicer(axial_slice=0):
        fig, ax = plt.subplots(1, num_plots, figsize=(5 * num_plots, 5), squeeze=False)
        base_slice = images[0][axial_slice, :, :]
        ax[0][0].imshow(base_slice, cmap=cmaps[0])
        ax[0][0].set_title(panel_titles[0])
        ax[0][0].axis("off")
        for idx in range(1, num_plots):
            ax[0][idx].imshow(base_slice, cmap=cmaps[0])
            ax[0][idx].imshow(
                images[idx][axial_slice, :, :],
                cmap=cmaps[idx % len(cmaps)],
                alpha=0.5 * images[idx][axial_slice],
            )
            ax[0][idx].set_title(panel_titles[idx])
            ax[0][idx].axis("off")
        plt.show()


In [None]:
# Read the written test split back and inspect every record interactively.
for config in CONFIGS[:1]:
    print("Reading data for model:", config.model_name)
    print(config.dataset_path)
    record_file = os.path.join(config.dataset_path, "test.tfrecords")
    dataset = tf.data.TFRecordDataset(record_file, compression_type="GZIP")

    # Title the overlay after the mask this config uses, instead of always
    # calling it "Lungs".
    mask_title = str(getattr(config.mask, "value", config.mask)).capitalize()

    for record in dataset:
        parsed_record = parse_record(record)
        feature, label = decode_record(parsed_record, config)
        feature = feature.numpy().squeeze()
        label = label.numpy().squeeze()
        # plot_overlay(feature, label)
        overlay_slices_slider(
            images=[feature, label],
            titles=["Feature", mask_title],
            cmaps=["bone", "prism"],
        )
