In [None]:
import os
from pathlib import Path
from typing import Dict, Optional, Union

import numpy as np
import pandas as pd
from numpy.typing import NDArray


def load_file(file_path: Path) -> Union[NDArray, Dict[str, NDArray]]:
    """
    Read one OpenKBP data file and return it in its raw (unshaped) form.

    :param file_path: path of the file to read
    :return: the array read by `np.loadtxt` when the file stem is `voxel_dimensions`; otherwise either the
        index column as an array (when the table contains missing values, i.e. a mask) or a dictionary with
        the keys `indices` and `data` (a sparse matrix)
    """
    # Voxel dimensions are not stored as an indexed table, so they bypass pandas entirely
    if file_path.stem == "voxel_dimensions":
        return np.loadtxt(file_path)

    table = pd.read_csv(file_path, index_col=0)
    has_missing_values = table.isnull().values.any()

    if not has_missing_values:
        # Every row carries a value: the file is a sparse matrix (index -> value)
        return {"indices": table.index.values, "data": table.data.values}

    # Missing values mean only the index column is informative: the file is a mask
    return np.array(table.index).squeeze()


def get_paths(directory_path: Path, extension: Optional[str] = None) -> list[Path]:
    """
    List the non-hidden entries of `directory_path`, optionally restricted to one file extension.

    :param directory_path: the directory to search
    :param extension: when provided (without a leading dot), only entries matching `*.<extension>` are returned
    :return: the matching paths, or an empty list when `directory_path` is not a directory
    """
    if not directory_path.is_dir():
        return []

    if extension is None:
        # Names starting with "." are hidden files and are skipped
        return [directory_path / str(entry) for entry in os.listdir(directory_path) if entry[0] != "."]

    pattern = "*.{}".format(extension)
    return [Path(match) for match in Path(directory_path).glob(pattern) if match.stem[0] != "."]


def sparse_vector_function(x, indices=None) -> dict[str, NDArray]:
    """Convert a tensor into a dictionary of its positive values and their corresponding indices

    The same `x > 0` mask selects both the values and the indices, so the two returned arrays always have
    the same length (previously the indices were taken from every non-zero entry, which included negative
    values that were excluded from the data).

    :param x: the tensor or, if indices is not None, the values that belong at each index
    :param indices: the raveled indices of the tensor; when None they are derived from the position of each
        positive value in the flattened (row-major) tensor
    :return: sparse vector in the form of a dictionary with the keys `data` and `indices`
    """
    positive = x > 0
    if indices is None:
        # Boolean indexing and flatnonzero both walk the array in row-major order, so entries line up
        selected_indices = np.flatnonzero(positive)
    else:
        selected_indices = indices[positive]
    return {"data": x[positive], "indices": selected_indices}

In [None]:
from typing import Union

from numpy.typing import NDArray


class DataShapes:
    """Shapes (without the batch axis) of every kind of data used by the dose prediction pipeline.

    Each property is named after a data type so that `from_data_names` can look shapes up by name.
    """

    def __init__(self, num_rois: int = 10):
        """
        :param num_rois: number of structures stacked in `structure_masks`; the default of 10 matches the
            7 organs-at-risk and 3 targets listed by the data loader in this notebook
        """
        self.num_rois = num_rois
        self.patient_shape = (128, 128, 128)

    @property
    def dose(self) -> tuple[int, int, int, int]:
        """Dose deposited within the patient tensor"""
        return self.patient_shape + (1,)

    @property
    def predicted_dose(self) -> tuple[int, int, int, int]:
        """Predicted dose that should be deposited within the patient tensor"""
        return self.dose

    @property
    def ct(self) -> tuple[int, int, int, int]:
        """CT image grey scale within the patient tensor"""
        return self.patient_shape + (1,)

    # Each structure below is a single-channel volume with the same shape as the CT.

    @property
    def Brainstem(self) -> tuple[int, int, int, int]:
        """Brainstem mask (same shape as the CT)"""
        return self.ct

    @property
    def SpinalCord(self) -> tuple[int, int, int, int]:
        """Spinal cord mask (same shape as the CT)"""
        return self.ct

    @property
    def LeftParotid(self) -> tuple[int, int, int, int]:
        """Left parotid mask (same shape as the CT)"""
        return self.ct

    @property
    def RightParotid(self) -> tuple[int, int, int, int]:
        """Right parotid mask (same shape as the CT)"""
        return self.ct

    @property
    def Esophagus(self) -> tuple[int, int, int, int]:
        """Esophagus mask (same shape as the CT)"""
        return self.ct

    @property
    def Larynx(self) -> tuple[int, int, int, int]:
        """Larynx mask (same shape as the CT)"""
        return self.ct

    @property
    def Mandible(self) -> tuple[int, int, int, int]:
        """Mandible mask (same shape as the CT)"""
        return self.ct

    @property
    def PTV56(self) -> tuple[int, int, int, int]:
        """PTV56 target mask (same shape as the CT)"""
        return self.ct

    @property
    def PTV63(self) -> tuple[int, int, int, int]:
        """PTV63 target mask (same shape as the CT)"""
        return self.ct

    @property
    def PTV70(self) -> tuple[int, int, int, int]:
        """PTV70 target mask (same shape as the CT)"""
        return self.ct

    @property
    def structure_masks(self) -> tuple[int, int, int, int]:
        """Mask of all structures in patient (one channel per structure)"""
        return self.patient_shape + (self.num_rois,)

    @property
    def possible_dose_mask(self) -> tuple[int, int, int, int]:
        """Mask where dose can be deposited"""
        return self.patient_shape + (1,)

    @property
    def voxel_dimensions(self) -> tuple[int]:
        """Shape of the vector holding the physical dimensions of patient voxels (three values)"""
        return (3,)

    def from_data_names(self, data_names: list[str]) -> dict[str, tuple[int, ...]]:
        """
        Look up the shape of each requested data type.

        :param data_names: names of the data types; each must be the name of an attribute of this class
        :return: dictionary mapping each name to its shape, in the order the names were given
        :raises AttributeError: if a name does not correspond to an attribute of this class
        """
        return {name: getattr(self, name) for name in data_names}


In [None]:
from __future__ import annotations

from pathlib import Path
from typing import Optional

import numpy as np
from numpy.typing import NDArray


class DataBatch:
    """Container for one batch of patient data; every data type is an optional attribute."""

    def __init__(
        self,
        dose: Optional[NDArray] = None,
        predicted_dose: Optional[NDArray] = None,
        ct: Optional[NDArray] = None,
        Brainstem: Optional[NDArray] = None,
        SpinalCord: Optional[NDArray] = None,
        RightParotid: Optional[NDArray] = None,
        LeftParotid: Optional[NDArray] = None,
        Esophagus: Optional[NDArray] = None,
        Larynx: Optional[NDArray] = None,
        Mandible: Optional[NDArray] = None,
        PTV56: Optional[NDArray] = None,
        PTV63: Optional[NDArray] = None,
        PTV70: Optional[NDArray] = None,
        possible_dose_mask: Optional[NDArray] = None,
        voxel_dimensions: Optional[NDArray] = None,
        patient_list: Optional[list[str]] = None,
        patient_path_list: Optional[list[Path]] = None,
        structure_mask_names: Optional[list[str]] = None,
    ):
        """
        All arguments default to None so that a batch only carries the data that was requested.

        :param patient_list: identifiers of the patients in the batch
        :param patient_path_list: paths the patients in the batch were loaded from
        :param structure_mask_names: names of the structures associated with the batch
        """
        self.dose = dose
        self.predicted_dose = predicted_dose
        self.ct = ct
        self.Brainstem = Brainstem
        self.SpinalCord = SpinalCord
        self.RightParotid = RightParotid
        self.LeftParotid = LeftParotid
        self.Esophagus = Esophagus
        self.Larynx = Larynx
        self.Mandible = Mandible
        self.PTV56 = PTV56
        self.PTV63 = PTV63
        self.PTV70 = PTV70
        self.possible_dose_mask = possible_dose_mask
        self.voxel_dimensions = voxel_dimensions
        self.patient_list = patient_list
        # Stored under the same name as the constructor argument (and the name the data loader assigns to)
        self.patient_path_list = patient_path_list
        self.structure_mask_names = structure_mask_names

    @property
    def patient_path(self) -> Optional[list[Path]]:
        """Backward-compatible alias of `patient_path_list` (the attribute name used previously)."""
        return self.patient_path_list

    @patient_path.setter
    def patient_path(self, value: Optional[list[Path]]) -> None:
        self.patient_path_list = value

    @classmethod
    def initialize_from_required_data(cls, data_dimensions: dict[str, tuple[int, ...]], batch_size: int) -> DataBatch:
        """
        Create a batch whose requested attributes are zero-filled arrays.

        :param data_dimensions: maps the name of each requested data type (a constructor argument name) to
            the shape of a single sample
        :param batch_size: number of samples in the batch; it becomes the leading axis of every array
        :return: a batch where each requested attribute has shape `(batch_size, *dimensions)`
        """
        attribute_values = {
            data_name: np.zeros((batch_size, *dimensions)) for data_name, dimensions in data_dimensions.items()
        }
        return cls(**attribute_values)

    def set_values(self, data_name: str, batch_index: int, values: NDArray):
        """
        Write the data of one sample into the batch.

        :param data_name: name of the attribute to write to; it must already hold an array
        :param batch_index: position of the sample along the leading (batch) axis
        :param values: the values to store for that sample
        """
        getattr(self, data_name)[batch_index] = values



In [None]:
from pathlib import Path
from typing import Dict, Iterator, List, Optional, Union

import numpy as np
from more_itertools import windowed
from numpy.typing import NDArray
from tqdm import tqdm


class DataLoader:
    """Loads OpenKBP csv data in structured format for dose prediction models."""

    def __init__(self, patient_paths: List[Path], batch_size: int = 2):
        """
        :param patient_paths: list of the paths where data for each patient is stored
        :param batch_size: the number of data points to load in a single batch
        """
        self.patient_paths = patient_paths
        self.batch_size = batch_size

        # Light processing of attributes
        self.paths_by_patient_id = {patient_path.stem: patient_path for patient_path in self.patient_paths}
        self.required_files: Optional[Dict] = None  # populated by `set_mode`
        self.mode_name: Optional[str] = None

        # Parameters that should not be changed unless OpenKBP data is modified
        self.rois = dict(
            oars=["Brainstem", "SpinalCord", "RightParotid", "LeftParotid", "Esophagus", "Larynx", "Mandible"],
            targets=["PTV56", "PTV63", "PTV70"],
        )
        self.full_roi_list = sum(map(list, self.rois.values()), [])  # make a list of all rois
        self.data_shapes = DataShapes()

    @property
    def patient_id_list(self) -> List[str]:
        """Identifiers (path stems) of every patient known to the loader."""
        return list(self.paths_by_patient_id.keys())

    def get_batches(self) -> Iterator[DataBatch]:
        """Yield batches of `batch_size` patients; a trailing incomplete batch is dropped."""
        batches = windowed(self.patient_paths, n=self.batch_size, step=self.batch_size)
        # `windowed` pads the last window with None, which is how incomplete batches are recognised
        complete_batches = (batch for batch in batches if None not in batch)
        number_of_batches = len(self.patient_paths) // self.batch_size
        for batch_paths in tqdm(complete_batches, total=number_of_batches):
            yield self.prepare_data(batch_paths)

    def get_patients(self, patient_list: List[str]) -> DataBatch:
        """
        Load specific patients into a single batch.

        :param patient_list: identifiers of the patients to load
        :return: a batch with one sample per requested patient
        :raises KeyError: if an identifier is not known to the loader
        """
        file_paths_to_load = [self.paths_by_patient_id[patient] for patient in patient_list]
        return self.prepare_data(file_paths_to_load)

    def set_mode(self, mode: str) -> None:
        """
        Set the data that is loaded for each patient based on `mode`.

        Every mode except `training_model` also forces the batch size to one.

        :param mode: one of `training_model`, `predicted_dose`, `evaluation`, or `dose_prediction`
        :raises ValueError: if `mode` is not one of the supported modes
        """
        self.mode_name = mode
        # The structure names come from one list so that each mode requests exactly the same structures
        roi_names = list(self.full_roi_list)
        if mode == "training_model":
            required_data = ["dose", "ct", *roi_names, "possible_dose_mask", "voxel_dimensions"]
        elif mode == "predicted_dose":
            required_data = [mode]
            self._force_batch_size_one()
        elif mode == "evaluation":
            required_data = ["dose", *roi_names, "possible_dose_mask", "voxel_dimensions"]
            self._force_batch_size_one()
        elif mode == "dose_prediction":
            required_data = ["ct", *roi_names, "possible_dose_mask", "voxel_dimensions"]
            self._force_batch_size_one()
        else:
            raise ValueError(f"Mode `{mode}` does not exist. Mode must be either training_model, dose_prediction, predicted_dose, or evaluation")
        self.required_files = self.data_shapes.from_data_names(required_data)

    def _force_batch_size_one(self) -> None:
        """Set the batch size to one, warning the user when that changes the configured value."""
        import warnings

        if self.batch_size != 1:
            self.batch_size = 1
            warnings.warn(f"Batch size has been changed to 1 for `{self.mode_name}` mode")

    def shuffle_data(self) -> None:
        """Shuffle the patient paths in place using numpy's global random state."""
        np.random.shuffle(self.patient_paths)

    def prepare_data(self, file_paths_to_load: List[Path]) -> DataBatch:
        """
        Prepares data containing samples in batch so that they are loaded in the proper shape: (n_samples, *dim, n_channels)

        :param file_paths_to_load: one path per patient to put in the batch
        :return: a batch with one sample per path; `set_mode` must have been called beforehand
        """
        # Size the batch by the number of patients actually requested, so `get_patients` neither overflows
        # nor zero-pads the batch when the request does not match `batch_size`
        batch_data = DataBatch.initialize_from_required_data(self.required_files, len(file_paths_to_load))
        batch_data.patient_list = [patient_path.stem for patient_path in file_paths_to_load]
        batch_data.patient_path_list = file_paths_to_load
        batch_data.structure_mask_names = self.full_roi_list

        # Populate batch with requested data
        for index, patient_path in enumerate(file_paths_to_load):
            raw_data = self.load_data(patient_path)
            for key in self.required_files:
                batch_data.set_values(key, index, self.shape_data(key, raw_data))

        return batch_data

    def load_data(self, path_to_load: Path) -> dict[str, Union[NDArray, dict[str, NDArray]]]:
        """
        Load data in its raw form.

        :param path_to_load: a patient directory, or a single file holding the data named after the current mode
        :return: dictionary mapping a data name to the raw data loaded for it
        """
        data = {}
        if path_to_load.is_dir():
            for file_path in get_paths(path_to_load):
                is_required = file_path.stem in self.required_files
                is_required_roi = file_path.stem in self.full_roi_list
                if is_required or is_required_roi:
                    data[file_path.stem] = load_file(file_path)
        else:
            data[self.mode_name] = load_file(path_to_load)

        return data

    def shape_data(self, key: str, data: dict) -> NDArray:
        """
        Shapes into form that is amenable to tensorflow and other deep learning packages.

        :param key: name of the data type to shape
        :param data: raw data as returned by `load_data`
        :return: array with the shape registered for `key`; all zeros when `key` is missing from `data`
        """
        shaped_data = np.zeros(self.required_files[key])
        if key in data:
            if key == "ct" or key == "dose":
                # Sparse data: scatter the stored values to their raveled indices
                np.put(shaped_data, data[key]["indices"], data[key]["data"])
            elif key == "voxel_dimensions":
                shaped_data = data[key]
            else:
                # Masks: the stored indices mark the voxels that belong to the structure
                np.put(shaped_data, data[key], int(1))

        return shaped_data


In [None]:
from keras.layers import Layer, Softmax, Multiply, Add
from keras.layers import  Conv3D
from keras.models import Model
import keras.backend as K
import tensorflow as tf


class SelfAttention(Layer):
    """Residual attention layer for 5D inputs built from three 1x1x1 convolutions (query, key and value).

    `tf.matmul` multiplies over the last two axes, so the attention weights relate positions along the
    fourth axis (labelled depth in the assertion message) for every (batch, height, width) location.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        """Create the projection layers; the input must be 5D."""
        super().build(input_shape)
        assert len(input_shape) == 5, "Input shape should be 5D (batch, height, width, depth, channels)"

        def make_projection():
            # Point-wise convolution producing 128 channels
            return Conv3D(filters=128, kernel_size=(1, 1, 1), padding='same')

        self.query_conv = make_projection()
        self.key_conv = make_projection()
        self.value_conv = make_projection()
        self.softmax = Softmax(axis=-1)

    def call(self, x):
        """Return the attended features added to the input (residual connection)."""
        queries = self.query_conv(x)
        keys = self.key_conv(x)
        values = self.value_conv(x)

        scores = tf.matmul(queries, keys, transpose_b=True)
        attended = tf.matmul(self.softmax(scores), values)

        return attended + x

In [None]:
from keras.layers import Layer
class MatrixMultyply(Layer):
    """Layer that multiplies its first input by the transpose (over the last two axes) of its second input."""

    def __init__(self, **kwargs):
        # Forward the standard layer arguments (name, trainable, dtype, ...) to the base class; Keras passes
        # them when it rebuilds a layer from its saved configuration, so rejecting them breaks model loading
        super().__init__(**kwargs)

    def call(self, inputs):
        """
        :param inputs: a sequence of exactly two tensors
        :return: `inputs[0] @ transpose(inputs[1])`, with the transpose applied to the last two axes
        """
        assert len(inputs) == 2, "This layer requires exactly two inputs."

        return tf.matmul(inputs[0], inputs[1], transpose_b=True)


In [None]:
from typing import Optional
from keras.layers import Activation, AveragePooling3D, Conv3D, Conv3DTranspose, Input, LeakyReLU, SpatialDropout3D, concatenate, MaxPooling3D, Dense, UpSampling3D, Dot, Conv1D, Reshape, Multiply
from keras.layers import BatchNormalization
from keras.models import Model

class DefineDoseFromCT:
    """Defines the generator that predicts dose from a CT volume and one mask volume per structure.

    The generator encodes every input separately, combines the encodings with a feature enhance module and
    a feature fusion module, passes both through a graph convolution block and decodes the result back to
    the input resolution.
    """

    # Structure inputs in the order they follow the CT in the generator's list of inputs
    roi_input_order = (
        "Brainstem",
        "LeftParotid",
        "RightParotid",
        "Mandible",
        "SpinalCord",
        "Esophagus",
        "Larynx",
        "PTV56",
        "PTV63",
        "PTV70",
    )

    def __init__(
        self,
        data_shapes: DataShapes,
        gen_optimizer,
    ):
        """
        :param data_shapes: provides the shape of every model input
        :param gen_optimizer: optimizer the generator is compiled with
        """
        self.data_shapes = data_shapes
        self.gen_optimizer = gen_optimizer
        self.input_shape = (128, 128, 128, 1)

    def _conv_bn_lrelu(self, x, filters: int):
        """Apply a 3x3x3 convolution, batch normalization and a LeakyReLU with its default slope."""
        x = Conv3D(filters, (3, 3, 3), strides=1, padding="same")(x)
        x = BatchNormalization()(x)
        return LeakyReLU()(x)

    def make_feature_enhance_module(self, x):
        """
        :param x: list of encoded feature maps to concatenate along the channel axis
        :return: the concatenated features after two convolution blocks (11 then 128 filters)
        """
        x = concatenate(x)
        x = self._conv_bn_lrelu(x, 11)
        x = self._conv_bn_lrelu(x, 128)
        return x

    def make_feauture_fusion_module(self, ct_encode, brainstem_encode, leftParotid_encode, rightParotid_encode, mandible_encode, spinalCord_encode, esophagus_encode, larynx_encode, ptv56_encode, ptv63_encode, ptv70_encode):
        """Apply a separate self-attention layer to every encoding, sum the results and mix them with a 1x1x1 convolution."""
        encodings = [
            ct_encode,
            brainstem_encode,
            leftParotid_encode,
            rightParotid_encode,
            mandible_encode,
            spinalCord_encode,
            esophagus_encode,
            larynx_encode,
            ptv56_encode,
            ptv63_encode,
            ptv70_encode,
        ]
        attended = [SelfAttention()(encoding) for encoding in encodings]
        fused_features = Add()(attended)

        ffm = Conv3D(128, (1, 1, 1), strides=1, padding="same")(fused_features)
        ffm = BatchNormalization()(ffm)
        ffm = LeakyReLU(alpha=0)(ffm)  # a slope of zero makes this behave as a plain ReLU
        return ffm

    def make_graph_convolution_block(self, ffm, fem):
        """
        :param ffm: output of the feature fusion module
        :param fem: output of the feature enhance module
        :return: features produced by combining `ffm` and `fem` through two matrix products and convolutions
        """
        x = MatrixMultyply()([ffm, fem])
        x = Conv3D(128, (1, 1, 1), strides=1)(x)
        x = MatrixMultyply()([x, fem])

        x = Conv3D(filters=128, kernel_size=1, strides=1)(x)

        x = Conv3D(128, (3, 3, 3), strides=1, padding="same")(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(alpha=0)(x)

        # Skip connection from the fused features
        x = concatenate([x, ffm])
        x = Conv3D(128, (1, 1, 1), strides=1)(x)
        x = BatchNormalization()(x)

        x = Conv3D(128, (3, 3, 3), strides=1, padding="same")(x)
        # These two layers must consume `x`; feeding them `ffm` discards everything computed above
        x = BatchNormalization()(x)
        x = LeakyReLU(alpha=0)(x)
        x = Conv3D(128, (1, 1, 1), strides=1, padding="same")(x)

        return x

    def make_encode_block(self, y):
        """
        Encode one input volume with three pairs of convolution blocks.

        :param y: the input tensor
        :return: features with 128 channels; two max-pooling steps halve each spatial axis twice
        """
        x = self._conv_bn_lrelu(y, 16)
        x = self._conv_bn_lrelu(x, 32)
        x = MaxPooling3D((2, 2, 2), strides=2)(x)

        x = self._conv_bn_lrelu(x, 32)
        x = self._conv_bn_lrelu(x, 64)
        x = MaxPooling3D((2, 2, 2), strides=2)(x)

        x = self._conv_bn_lrelu(x, 64)
        x = self._conv_bn_lrelu(x, 128)

        return x

    def define_generator(self) -> Model:
        """
        Build, compile and summarise the generator.

        :return: compiled model whose inputs are the CT followed by the structures in `roi_input_order`
        """
        print(self.data_shapes.ct)
        ct = Input(self.data_shapes.ct)
        ct_encode = self.make_encode_block(ct)

        # Each structure gets its own input (shaped by its own entry in data_shapes) and its own encoder
        roi_inputs = []
        roi_encodings = []
        for roi_name in self.roi_input_order:
            roi_input = Input(getattr(self.data_shapes, roi_name))
            roi_inputs.append(roi_input)
            roi_encodings.append(self.make_encode_block(roi_input))

        fem = self.make_feature_enhance_module(roi_encodings)
        ffm = self.make_feauture_fusion_module(ct_encode, *roi_encodings)

        graph_convolution = self.make_graph_convolution_block(ffm, fem)

        # Decoder: each transposed convolution doubles the spatial axes, undoing one max-pooling step
        x = Conv3DTranspose(filters=128, kernel_size=(3, 3, 3), strides=2, padding="same")(graph_convolution)
        x = self._conv_bn_lrelu(x, 128)
        x = self._conv_bn_lrelu(x, 64)

        x = Conv3DTranspose(filters=64, kernel_size=(3, 3, 3), strides=2, padding="same")(x)
        x = self._conv_bn_lrelu(x, 64)
        x = self._conv_bn_lrelu(x, 32)

        x = self._conv_bn_lrelu(x, 1)
        x = self._conv_bn_lrelu(x, 1)

        # NOTE(review): sigmoid bounds predictions to (0, 1) while the loader stores dose values unscaled;
        # confirm the training targets are normalised to that range
        output = Conv3D(1, 1, activation="sigmoid")(x)

        generator = Model(inputs=[ct, *roi_inputs], outputs=output, name="generator")
        generator.compile(loss="mean_absolute_error", optimizer=self.gen_optimizer)
        generator.summary()
        return generator


In [None]:
import os
from pathlib import Path
from typing import Optional

import numpy as np
import pandas as pd
from keras.models import load_model
from keras.optimizers import Adam

class PredictionModel(DefineDoseFromCT):
    """Trains the dose prediction generator, manages its saved models and writes dose predictions to disk."""

    def __init__(self, data_loader: DataLoader, results_patent_path: Path, model_name: str, stage: str) -> None:
        """
        :param data_loader: An object that loads batches of image data
        :param results_patent_path: The path at which all results and generated models will be saved
        :param model_name: The name of your model, used when saving and loading data
        :param stage: Identify stage of model development (train, validation, test)
        """
        super().__init__(
            data_shapes=data_loader.data_shapes,
            gen_optimizer=Adam(learning_rate=0.0002, beta_1=0.5, beta_2=0.999),
        )

        # set attributes for data shape from data loader
        self.generator = None
        self.model_name = model_name
        self.data_loader = data_loader
        self.full_roi_list = data_loader.full_roi_list

        # Define training parameters
        self.current_epoch = 0
        self.last_epoch = 200

        # Make directories for data and models
        model_results_path = results_patent_path / model_name
        self.model_dir = model_results_path / "models"
        self.model_dir.mkdir(parents=True, exist_ok=True)
        self.prediction_dir = model_results_path / f"{stage}-predictions"
        self.prediction_dir.mkdir(parents=True, exist_ok=True)

        # Make template for model path
        self.model_path_template = self.model_dir / "epoch_"

    @staticmethod
    def _generator_inputs(batch) -> list:
        """Return the batch arrays in the order of the generator's inputs: the CT followed by each structure."""
        return [
            batch.ct,
            batch.Brainstem,
            batch.LeftParotid,
            batch.RightParotid,
            batch.Mandible,
            batch.SpinalCord,
            batch.Esophagus,
            batch.Larynx,
            batch.PTV56,
            batch.PTV63,
            batch.PTV70,
        ]

    def _load_generator(self, epoch: int):
        """Load the generator saved for `epoch`, registering the custom layers the model is built from."""
        custom_layers = {"SelfAttention": SelfAttention, "MatrixMultyply": MatrixMultyply}
        return load_model(self._get_generator_path(epoch), custom_objects=custom_layers)

    def train_model(self, epochs: int = 200, save_frequency: int = 5, keep_model_history: int = 2) -> None:
        """
        :param epochs: the number of epochs the model will be trained over
        :param save_frequency: how often the model will be saved (older models will be deleted to conserve storage)
        :param keep_model_history: how many models are kept on a rolling basis (deletes older than save_frequency * keep_model_history epochs)
        """
        self._set_epoch_start()
        self.last_epoch = epochs
        self.initialize_networks()
        if self.current_epoch >= epochs:
            print(f"The model has already been trained for {epochs}, so no more training will be done.")
            return
        self.data_loader.set_mode("training_model")
        for epoch in range(self.current_epoch, epochs):
            self.current_epoch = epoch
            print(f"Beginning epoch {self.current_epoch}")
            self.data_loader.shuffle_data()

            for idx, batch in enumerate(self.data_loader.get_batches()):
                model_loss = self.generator.train_on_batch(self._generator_inputs(batch), [batch.dose])
                print(f"Model loss at epoch {self.current_epoch} batch {idx} is {model_loss:.3f}")

            self.manage_model_storage(save_frequency, keep_model_history)

    def _set_epoch_start(self) -> None:
        """Resume from the highest epoch number found among the saved `epoch_<n>.h5` models."""
        all_model_paths = get_paths(self.model_dir, extension="h5")
        for model_path in all_model_paths:
            *_, epoch_number = model_path.stem.split("epoch_")
            if epoch_number.isdigit():
                self.current_epoch = max(self.current_epoch, int(epoch_number))

    def initialize_networks(self) -> None:
        """Load the generator saved for the current epoch, or define a new one when starting from scratch."""
        if self.current_epoch >= 1:
            self.generator = self._load_generator(self.current_epoch)
        else:
            self.generator = self.define_generator()

    def manage_model_storage(self, save_frequency: int = 1, keep_model_history: Optional[int] = None) -> None:
        """
        Manage the model storage while models are trained. Note that old models are deleted based on how many models the users has asked to keep.
        We overwrite old files (rather than deleting them) to ensure the Collab users don't fill up their Google Drive trash.
        :param save_frequency: how often the model will be saved (older models will be deleted to conserve storage)
        :param keep_model_history: how many models back are kept (older models will be deleted to conserve storage); None keeps every saved model
        """
        import warnings

        effective_epoch_number = self.current_epoch + 1  # Epoch number + 1 because we're at the start of the next epoch
        if 0 < np.mod(effective_epoch_number, save_frequency) and effective_epoch_number != self.last_epoch:
            warnings.warn(f"Model at the end of epoch {self.current_epoch} was not saved because it is skipped when save frequency {save_frequency}.")
            return

        final_model_path = self._get_generator_path(effective_epoch_number)
        if keep_model_history is None:
            epoch_to_overwrite = -1  # no limit on the history, so there is never a model to overwrite
        else:
            epoch_to_overwrite = effective_epoch_number - keep_model_history * (save_frequency or float("inf"))

        # The code below is clunky and was only included to bypass the Google Drive trash, which fills quickly with normal save/delete functions
        if epoch_to_overwrite >= 0:
            initial_model_path = self._get_generator_path(epoch_to_overwrite)
            self.generator.save(initial_model_path)
            os.replace(initial_model_path, final_model_path)  # Helps bypass Google Drive trash
        else:  # Save via more conventional method because there is no model to overwrite
            self.generator.save(final_model_path)

    def _get_generator_path(self, epoch: Optional[int] = None) -> Path:
        """
        :param epoch: epoch number of the model; None selects the current epoch
        :return: path of the generator saved for that epoch
        """
        # Compare against None explicitly: epoch 0 is a valid value and must not fall back to the current epoch
        epoch = self.current_epoch if epoch is None else epoch
        return self.model_dir / f"epoch_{epoch}.h5"

    def predict_dose(self, epoch: int = 1) -> None:
        """Predicts the dose for the given epoch number and saves one csv file per patient"""
        self.generator = self._load_generator(epoch)
        os.makedirs(self.prediction_dir, exist_ok=True)
        self.data_loader.set_mode("dose_prediction")

        print("Predicting dose with generator.")
        for batch in self.data_loader.get_batches():
            # The generator takes the same inputs, in the same order, as during training
            dose_pred = self.generator.predict(self._generator_inputs(batch))
            dose_pred = dose_pred * batch.possible_dose_mask  # zero the dose where none can be deposited
            dose_pred = np.squeeze(dose_pred)
            dose_to_save = sparse_vector_function(dose_pred)
            dose_df = pd.DataFrame(data=dose_to_save["data"].squeeze(), index=dose_to_save["indices"].squeeze(), columns=["data"])
            (patient_id,) = batch.patient_list
            dose_df.to_csv(self.prediction_dir / f"{patient_id}.csv")


In [None]:
import shutil
from pathlib import Path

# Entry point of the notebook: configure the run, locate the training data and train the dose prediction model.
if __name__ == "__main__":

    prediction_name = "baseline"  # Name of the model; also the name of the results sub-directory it is saved in
    test_time = False  # Only change this to True when the model has been fully tuned on the validation set (not read by the code visible here)
    num_epochs = 15  # This should probably be increased to 100-200 after your dry run

    # Define project directories
    # NOTE(review): absolute Google Drive path; presumably requires Drive to be mounted in Colab before this cell runs
    primary_directory = Path("/content/drive/MyDrive/Project II")  # directory where everything is stored
    provided_data_dir = primary_directory / "provided-data"
    training_data_dir = provided_data_dir / "train-pats"
    validation_data_dir = provided_data_dir / "validation-pats"  # not used by the code visible here
    testing_data_dir = provided_data_dir / "test-pats"  # not used by the code visible here
    results_dir = primary_directory / "results_2"  # where any data generated by this code (e.g., predictions, models) are stored

    # Prepare the data directory
    # get_paths returns an empty list when the directory does not exist, in which case no batches are produced
    training_plan_paths = get_paths(training_data_dir)  # gets the path of each plan's directory

    # Train a model
    data_loader_train = DataLoader(training_plan_paths)  # default batch size of 2
    dose_prediction_model_train = PredictionModel(data_loader_train, results_dir, prediction_name, "train")
    # A model is saved after every epoch; the history of 20 exceeds num_epochs, so no saved model is overwritten
    dose_prediction_model_train.train_model(num_epochs, save_frequency=1, keep_model_history=20)


(128, 128, 128, 1)
Model: "generator"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_34 (InputLayer)       [(None, 128, 128, 128, 1)]   0         []                            
                                                                                                  
 input_35 (InputLayer)       [(None, 128, 128, 128, 1)]   0         []                            
                                                                                                  
 input_36 (InputLayer)       [(None, 128, 128, 128, 1)]   0         []                            
                                                                                                  
 input_37 (InputLayer)       [(None, 128, 128, 128, 1)]   0         []                            
                                                                       

0it [00:00, ?it/s]