MIT License
Copyright (c) 2023 Okyaz Eminaga
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

In [None]:
from lifelines.utils.concordance import concordance_index
from pathlib import Path
import tensorflow as tf
from tensorflow.compat.v1.keras.backend import set_session
from typing import Dict, Iterable, Sequence, Tuple, Optional
import pandas as pd
import numpy as np
import os
# Suppress most TensorFlow C++ log messages.
# NOTE(review): tensorflow is imported above this line, so messages emitted
# while importing it may not be affected — confirm whether this should be set
# before `import tensorflow`.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

# Number GPUs in PCI bus order and expose only device 0 to this process.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0",
})

# TF1-compat Keras session that allocates GPU memory on demand
# (allow_growth) instead of reserving it all up front.
config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True
set_session(tf.compat.v1.Session(config=config))


In [None]:
from albumentations import (
    Compose, RandomBrightness, JpegCompression, HueSaturationValue, RandomContrast, HorizontalFlip,
    Rotate, RandomSizedCrop, CenterCrop
)
# Training-time augmentation pipeline (used when InputFunction(augmentation=True)).
transforms = Compose(
    [
        Rotate(limit=40),
        RandomBrightness(limit=0.1),
        JpegCompression(quality_lower=85, quality_upper=100, p=0.5),
        HueSaturationValue(
            hue_shift_limit=20,
            sat_shift_limit=30,
            val_shift_limit=20,
            p=0.5,
        ),
        RandomContrast(limit=0.2, p=0.5),
        HorizontalFlip(),
    ]
)

# Despite its name this pipeline is stochastic: a 4096x4096 center crop
# followed by a random crop whose side is drawn from [512, 586] and which is
# then resized to 512x512.
no_change_transform = Compose(
    [
        CenterCrop(4096, 4096, always_apply=True),
        RandomSizedCrop([512, 586], 512, 512, p=1.0, always_apply=True),
    ]
)


In [4]:
from PIL import Image


def _make_riskset(time: np.ndarray) -> np.ndarray:
    """Compute mask that represents each sample's risk set.

    Parameters
    ----------
    time : np.ndarray, shape=(n_samples,)
        Observed event time sorted in descending order.

    Returns
    -------
    risk_set : np.ndarray, shape=(n_samples, n_samples)
        Boolean matrix where the `i`-th row denotes the
        risk set of the `i`-th instance, i.e. the indices `j`
        for which the observer time `y_j >= y_i`.
    """
    assert time.ndim == 1, "expected 1D array"

    # sort in descending order
    o = np.argsort(-time, kind="mergesort")
    n_samples = len(time)
    risk_set = np.zeros((n_samples, n_samples), dtype=np.bool_)
    for i_org, i_sort in enumerate(o):
        ti = time[i_sort]
        k = i_org
        while k < n_samples and ti == time[o[k]]:
            k += 1
        risk_set[i_sort, o[:k]] = True
    return risk_set


class InputFunction(object):
    """Callable input function that computes the risk set for each batch.

    Calling an instance returns a ``tf.data.Dataset`` that yields
    ``(images, labels)`` tuples, where ``labels`` is a dict with the keys
    ``"label_event"``, ``"label_time"`` and ``"label_riskset"``.

    Parameters
    ----------
    x : np.ndarray, shape=(n_samples, height, width, channels)
        Image data. When ``read_file=True`` this is instead a 1-D array of
        image file paths that are opened with PIL.
    time : np.ndarray, shape=(n_samples,)
        Observed time.
    event : np.ndarray, shape=(n_samples,)
        Event indicator.
    augmentation : bool, optional, default=False
        Whether to apply the module-level ``transforms`` pipeline to each image.
    input_size : (int, int), optional, default=(512, 512)
        ``(height, width)`` of the images. Used for the dataset's declared
        shape and as the target size when ``resize_img=True``.
    channel_number : int, optional, default=3
        Number of image channels in the dataset's declared shape.
    batch_size : int, optional, default=32
        Number of samples per batch.
    drop_last : bool, optional, default=False
        Whether to drop the last incomplete batch.
    shuffle : bool, optional, default=False
        Whether to shuffle data.
    k : int, optional, default=1
        Down-scaling factor used with ``read_file=True``: files are resized to
        ``5120 // k`` pixels per side before cropping/augmentation.
    read_file : bool, optional, default=False
        Whether ``x`` holds file paths instead of pixel data.
    repeat : int, optional, default=1
        Number of passes over the data per dataset iteration
        (``Dataset.repeat`` is applied only when > 1).
    resize_img : bool, optional, default=False
        Whether to resize in-memory images to ``input_size``. Ignored when
        ``read_file=True``.
    seed : int, optional, default=89
        Random number seed used for shuffling.
    """

    def __init__(self,
                 x: np.ndarray,
                 time: np.ndarray,
                 event: np.ndarray,
                 augmentation: bool = False,
                 input_size: Tuple[int, int] = (512, 512),
                 channel_number: int = 3,
                 batch_size: int = 32,
                 drop_last: bool = False,
                 shuffle: bool = False,
                 k: int = 1,
                 read_file: bool = False,
                 repeat: int = 1,
                 resize_img: bool = False,
                 seed: int = 89) -> None:
        self.x = x
        self.time = time
        self.input_size = input_size
        self.augmentation = augmentation
        self.event = event
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.shuffle = shuffle
        self.seed = seed
        self.repeat = repeat
        self.k = k
        self.resize_img = resize_img
        self.read_file = read_file
        self.channel_number = channel_number

    def size(self) -> int:
        """Total number of samples."""
        return len(self.x)

    def steps_per_epoch(self) -> int:
        """Number of full batches for one epoch."""
        return int(np.floor(len(self.x) / self.batch_size))

    def _resize_image(self, img: np.ndarray) -> np.ndarray:
        """Resize one image to ``input_size`` with bilinear interpolation.

        The image is converted to uint8 first. The result is a uint8 array of
        shape ``(height, width[, channels])``.
        """
        h, w = self.input_size
        arr = np.asarray(img).astype(np.uint8)
        # PIL cannot wrap (H, W, 1) arrays, so drop the singleton channel axis
        # and restore it afterwards.
        single_channel = arr.ndim == 3 and arr.shape[-1] == 1
        if single_channel:
            arr = arr[..., 0]
        # PIL's resize takes (width, height).
        out = np.array(Image.fromarray(arr).resize((w, h), Image.BILINEAR))
        return out[..., np.newaxis] if single_channel else out

    def _get_data_batch(self, index: np.ndarray) -> Tuple[np.ndarray, Dict[str, np.ndarray]]:
        """Load the images of one batch and compute its labels and risk set."""
        time = self.time[index].copy()
        event = self.event[index].copy()
        x = self.x[index].copy()
        if self.read_file:
            side = 5120 // self.k
            images = []
            for fl in x:
                # The context manager closes the file once the pixels are read.
                with Image.open(fl) as src:
                    img = np.array(src.resize((side, side)))
                data = {"image": img}
                # NOTE(review): only no_change_transform crops to 512x512; the
                # `transforms` pipeline does not appear to crop, so augmented
                # images keep the 5120 // k size and may not match the shape
                # declared in _get_shapes — confirm intent.
                if self.augmentation:
                    aug_data = transforms(**data)
                else:
                    aug_data = no_change_transform(**data)
                images.append(aug_data["image"])
            x = np.array(images)
        else:
            if self.resize_img:
                x = np.array([self._resize_image(img) for img in x])
            if self.augmentation:
                for i in range(x.shape[0]):
                    x[i] = transforms(image=x[i])["image"]

        labels = {
            "label_event": event.astype(np.int32),
            "label_time": time.astype(np.float32),
            "label_riskset": _make_riskset(time)
        }
        return x, labels

    def _iter_data(self) -> Iterable[Tuple[np.ndarray, Dict[str, np.ndarray]]]:
        """Generator that yields one batch at a time."""
        index = np.arange(self.size())
        # NOTE(review): the RandomState is re-created from `self.seed` every
        # time this generator starts, so each pass over the dataset sees the
        # same shuffled order — confirm this is intended.
        rnd = np.random.RandomState(self.seed)

        if self.shuffle:
            rnd.shuffle(index)
        for b in range(self.steps_per_epoch()):
            start = b * self.batch_size
            idx = index[start:(start + self.batch_size)]
            yield self._get_data_batch(idx)

        if not self.drop_last:
            start = self.steps_per_epoch() * self.batch_size
            idx = index[start:]
            # When the sample count is a multiple of batch_size there is no
            # remainder; do not emit an empty batch.
            if len(idx) > 0:
                yield self._get_data_batch(idx)

    def _get_shapes(self) -> Tuple[tf.TensorShape, Dict[str, tf.TensorShape]]:
        """Return shapes of data returned by `self._iter_data`."""
        batch_size = self.batch_size if self.drop_last else None
        h, w = self.input_size
        c = self.channel_number
        images = tf.TensorShape([batch_size, h, w, c])

        labels = {k: tf.TensorShape((batch_size,))
                  for k in ("label_event", "label_time")}
        labels["label_riskset"] = tf.TensorShape((batch_size, batch_size))
        return images, labels

    def _get_dtypes(self) -> Tuple[tf.DType, Dict[str, tf.DType]]:
        """Return dtypes of data returned by `self._iter_data`."""
        labels = {"label_event": tf.int32,
                  "label_time": tf.float32,
                  "label_riskset": tf.bool}
        return tf.float32, labels

    def _make_dataset(self) -> tf.data.Dataset:
        """Create dataset from generator."""
        options = tf.data.Options()
        options.experimental_optimization.noop_elimination = True
        options.experimental_optimization.map_parallelization = True
        ds = tf.data.Dataset.from_generator(
            self._iter_data,
            self._get_dtypes(),
            self._get_shapes()
        )
        ds = ds.with_options(options)
        if self.repeat > 1:
            return ds.repeat(self.repeat)
        else:
            return ds

    def __call__(self) -> tf.data.Dataset:
        return self._make_dataset()


def safe_normalize(x: tf.Tensor) -> tf.Tensor:
    """Shift risk scores so that their column-wise minimum is not negative.

    Only the relative values of risk scores matter, so adding a constant per
    column is harmless; it keeps later ``exp`` calls from underflowing. When
    the minimum along axis 0 is already >= 0, the shift is zero.
    """
    col_min = tf.reduce_min(x, axis=0)
    shift = tf.where(col_min < 0, -col_min, tf.zeros_like(col_min))
    return x + shift


def logsumexp_masked(risk_scores: tf.Tensor,
                     mask: tf.Tensor,
                     axis: int = 0,
                     keepdims: Optional[bool] = None) -> tf.Tensor:
    """Compute logsumexp across `axis` for entries where `mask` is true.

    Parameters
    ----------
    risk_scores : tf.Tensor
        Scores. Must have the same rank as `mask`; the two are multiplied
        element-wise, so they must be broadcast-compatible.
    mask : tf.Tensor
        Boolean (or 0/1) mask. Entries where it is false are excluded from
        the sum.
    axis : int, default=0
        Axis to reduce over.
    keepdims : bool or None, default=None
        If falsy, `axis` is squeezed from the result.

    Returns
    -------
    tf.Tensor
        ``log(sum(exp(risk_scores) * mask))`` along `axis`.

    Notes
    -----
    Masked-out entries are zeroed before the max is taken, so the stabilizing
    shift `amax` is ``max(0, max of unmasked scores)``. The shift cancels
    algebraically, but it only protects against underflow when the unmasked
    scores are not far below zero (CoxPHLoss calls `safe_normalize` first).
    A slice with no true mask entry yields ``log(0) = -inf``.
    """
    risk_scores.shape.assert_same_rank(mask.shape)

    with tf.name_scope("logsumexp_masked"):
        mask_f = tf.cast(mask, risk_scores.dtype)
        risk_scores_masked = tf.math.multiply(risk_scores, mask_f)
        # for numerical stability, subtract the maximum value
        # before taking the exponential
        amax = tf.reduce_max(risk_scores_masked, axis=axis, keepdims=True)
        risk_scores_shift = risk_scores_masked - amax

        # Re-apply the mask: exp(0 - amax) of masked-out entries must not count.
        exp_masked = tf.math.multiply(tf.exp(risk_scores_shift), mask_f)
        exp_sum = tf.reduce_sum(exp_masked, axis=axis, keepdims=True)
        # Undo the shift in log space.
        output = amax + tf.math.log(exp_sum)
        if not keepdims:
            output = tf.squeeze(output, axis=axis)
    return output


class CoxPHLoss(tf.keras.losses.Loss):
    """Negative partial log-likelihood of Cox's proportional hazards model.

    For sample ``i`` with prediction ``f_i`` and risk set ``R_i`` the per-sample
    loss is ``event_i * (log(sum_{j in R_i} exp(f_j)) - f_i)``. Censored
    samples (``event_i == 0``) therefore contribute zero. Minimizing it pushes
    predictions higher for samples whose event occurs earlier, so the model
    output acts as a risk score. Reduction across the batch is left to the
    ``tf.keras.losses.Loss`` base class.
    """

    def __init__(self, **kwargs):
        # Forward keyword arguments (e.g. name, reduction) to the Keras base class.
        super().__init__(**kwargs)

    def call(self,
             y_true: Sequence[tf.Tensor],
             y_pred: tf.Tensor) -> tf.Tensor:
        """Compute loss.

        Parameters
        ----------
        y_true : list|tuple of tf.Tensor
            The first element holds a binary vector where 1
            indicates an event 0 censoring.
            The second element holds the riskset, a
            boolean matrix where the `i`-th row denotes the
            risk set of the `i`-th instance, i.e. the indices `j`
            for which the observed time `y_j >= y_i`.
            Both must be rank 2 tensors: event as (batch, 1),
            riskset as (batch, batch).
        y_pred : tf.Tensor
            The predicted outputs. Must be a rank 2 tensor of shape (batch, 1).

        Returns
        -------
        loss : tf.Tensor
            Loss for each instance in the batch, shape (batch, 1).

        Raises
        ------
        ValueError
            If the static ranks or the last dimension of the inputs do not
            match the shapes above.
        """
        event, riskset = y_true
        predictions = y_pred

        # Static shape checks: predictions must be (batch, 1).
        pred_shape = predictions.shape
        if pred_shape.ndims != 2:
            raise ValueError("Rank mismatch: Rank of predictions (received %s) should "
                             "be 2." % pred_shape.ndims)

        if pred_shape[1] is None:
            raise ValueError("Last dimension of predictions must be known.")

        if pred_shape[1] != 1:
            raise ValueError("Dimension mismatch: Last dimension of predictions "
                             "(received %s) must be 1." % pred_shape[1])

        if event.shape.ndims != pred_shape.ndims:
            raise ValueError("Rank mismatch: Rank of predictions (received %s) should "
                             "equal rank of event (received %s)" % (
                                 pred_shape.ndims, event.shape.ndims))

        if riskset.shape.ndims != 2:
            raise ValueError("Rank mismatch: Rank of riskset (received %s) should "
                             "be 2." % riskset.shape.ndims)

        event = tf.cast(event, predictions.dtype)
        # Shift scores so their minimum is >= 0. A constant shift cancels in
        # `rr - predictions` below but keeps logsumexp_masked stable.
        predictions = safe_normalize(predictions)

        with tf.name_scope("assertions"):
            # Runtime checks: event must be within [0, 1] and riskset boolean.
            # The resulting ops are not wired in as explicit control
            # dependencies of the loss.
            assertions = (
                tf.debugging.assert_less_equal(event, 1.),
                tf.debugging.assert_greater_equal(event, 0.),
                tf.debugging.assert_type(riskset, tf.bool)
            )

        # move batch dimension to the end so predictions get broadcast
        # row-wise when multiplying by riskset
        pred_t = tf.transpose(predictions)
        # compute log of sum over risk set for each row
        rr = logsumexp_masked(pred_t, riskset, axis=1, keepdims=True)
        assert rr.shape.as_list() == predictions.shape.as_list()

        # Per-sample negative partial log-likelihood; zero for censored samples.
        losses = tf.math.multiply(event, rr - predictions)

        return losses  # *0.00001


In [5]:
class CindexMetric:
    """Computes concordance index across one epoch.

    Observed times, event indicators and predicted risk scores are collected
    batch by batch. Harrell's concordance index is then computed with
    ``lifelines.utils.concordance_index`` over everything collected since the
    last :meth:`reset_states`.

    Predictions are interpreted as *risk* scores (higher = event expected
    sooner), which is the orientation CoxPHLoss trains.
    """

    def __init__(self) -> None:
        # Start with an empty buffer so update_state() also works without an
        # explicit reset_states() call first.
        self.reset_states()

    def reset_states(self) -> None:
        """Clear the buffer of collected values."""
        self._data = {
            "label_time": [],
            "label_event": [],
            "prediction": []
        }

    def update_state(self, y_true: Dict[str, tf.Tensor], y_pred: tf.Tensor) -> None:
        """Collect observed time, event indicator and predictions for a batch.

        Parameters
        ----------
        y_true : dict
            Must have two items:
            `label_time`, a tensor containing observed time for one batch,
            and `label_event`, a tensor containing event indicator for one batch.
            Other keys are ignored.
        y_pred : tf.Tensor
            Tensor containing predicted risk score for one batch, e.g. of
            shape (batch, 1).
        """
        # np.asarray accepts eager tensors as well as NumPy arrays. reshape(-1)
        # keeps a 1-D array even for a batch of one; tf.squeeze would return a
        # 0-d array there, which np.concatenate in result() rejects.
        self._data["label_time"].append(np.asarray(y_true["label_time"]).reshape(-1))
        self._data["label_event"].append(np.asarray(y_true["label_event"]).reshape(-1))
        self._data["prediction"].append(np.asarray(y_pred).reshape(-1))

    def result(self) -> Dict[str, float]:
        """Computes the concordance index across collected values.

        Returns
        ----------
        metrics : dict
            ``{"cindex": float}``.
        """
        data = {key: np.concatenate(values) for key, values in self._data.items()}
        # lifelines' concordance_index expects scores where a higher value
        # means *longer* survival. Risk scores mean the opposite, so negate
        # them; passing them unchanged would report 1 - C.
        cindex = concordance_index(
            data["label_time"],
            -data["prediction"],
            data["label_event"] == 1)
        return {"cindex": cindex}


In [7]:
# Training split: survival times and event indicators are loaded into memory.
# The image stack is memory-mapped read-only, so it is not read all at once.
time_train, event_train = (
    np.load(path)
    for path in ("./time_train_shuffled_10x.npy", "./event_train_shuffled_10x.npy")
)
image_train = np.load("./image_train_shuffled_10x.npy", mmap_mode="r")

# Validation split, stored in the same layout.
time_valid, event_valid = (
    np.load(path)
    for path in ("./time_valid_shuffled_10x.npy", "./event_valid_shuffled_10x.npy")
)
image_valid = np.load("./image_valid_shuffled_10x.npy", mmap_mode="r")


In [12]:
# Training pipeline: shuffled, augmented batches of 16. The incomplete final
# batch is dropped, so every batch has the same static shape.
train_fn = InputFunction(
    image_train,
    time_train,
    event_train,
    batch_size=16,
    input_size=(512, 512),
    shuffle=True,
    augmentation=True,
    drop_last=True,
    resize_img=False,
    repeat=1,
)
# Validation pipeline: otherwise class defaults (batch_size=32, no shuffling,
# no augmentation, final partial batch kept).
eval_fn = InputFunction(
    image_valid,
    time_valid,
    event=event_valid,
    input_size=(512, 512),
    resize_img=False,
)


In [16]:
from plexusnet.architecture import PlexusNet


In [19]:
import tensorflow.compat.v2.summary as summary
import tensorflow_addons as tfa
from sklearn.metrics import roc_auc_score
from tensorflow.python.ops import summary_ops_v2
from tqdm import tqdm

# Directory that receives one saved model per training epoch
# (written by TrainAndEvaluateModel.train_and_evaluate).
path_model = "./PlexusNET_BCR_10x_COX"
# Idiomatic, race-free replacement for "if not exists: mkdir".
os.makedirs(path_model, exist_ok=True)


class TrainAndEvaluateModel:
    """Custom training loop for a Cox proportional-hazards network.

    Trains `model` with CoxPHLoss, using Adam wrapped in
    ``tfa.optimizers.MovingAverage``. TensorBoard summaries are written under
    ``model_dir / "train"`` and ``model_dir / "valid"``. The full model is saved
    after every epoch to ``f"{path_model}/model_{epoch}"`` (module-level
    `path_model`, not `model_dir`). Validation loss and c-index (CindexMetric)
    are reported after each epoch.

    Parameters
    ----------
    model : tf.keras.Model
        Network whose output has shape (batch, 1), as CoxPHLoss requires.
    model_dir : pathlib.Path
        Checkpoint/summary directory. It must support the ``/`` operator.
    train_dataset, eval_dataset : tf.data.Dataset
        Yield ``(images, labels)``. `labels` is a dict with "label_event" and
        "label_riskset"; the validation labels also need "label_time".
    learning_rate : float or learning-rate schedule
        Passed to ``tf.optimizers.Adam``.
    num_epochs : int
        Number of training epochs.
    steps_per_epoch : int
        Only used as the progress-bar total.
    """

    def __init__(self, model, model_dir, train_dataset, eval_dataset,
                 learning_rate, num_epochs, steps_per_epoch):
        self.num_epochs = num_epochs
        self.model_dir = model_dir

        self.model = model

        self.train_ds = train_dataset
        self.val_ds = eval_dataset

        self.optimizer = tfa.optimizers.MovingAverage(
            tf.optimizers.Adam(learning_rate=learning_rate))
        self.loss_fn = CoxPHLoss()

        self.train_loss_metric = tf.keras.metrics.Mean(name="train_loss")
        self.val_loss_metric = tf.keras.metrics.Mean(name="val_loss")
        self.val_cindex_metric = CindexMetric()
        self.steps_per_epoch = steps_per_epoch

    @tf.function
    def train_one_step(self, x, y_event, y_riskset):
        """Run one optimizer step and return (loss, logits) for the batch."""
        # CoxPHLoss expects the event vector as a rank-2 (batch, 1) tensor.
        y_event = tf.expand_dims(y_event, axis=1)
        with tf.GradientTape() as tape:
            logits = self.model(x, training=True)
            train_loss = self.loss_fn(
                y_true=[y_event, y_riskset], y_pred=logits)

        with tf.name_scope("gradients"):

            grads = tape.gradient(train_loss, self.model.trainable_weights)
            self.optimizer.apply_gradients(
                zip(grads, self.model.trainable_weights))
        return train_loss, logits

    def train_and_evaluate(self):
        """Train for `num_epochs` epochs, validating and saving after each one."""
        ckpt = tf.train.Checkpoint(
            step=tf.Variable(0, dtype=tf.int64),
            optimizer=self.optimizer,
            model=self.model)
        ckpt_manager = tf.train.CheckpointManager(
            ckpt, str(self.model_dir), max_to_keep=self.num_epochs)

        # Resume model/optimizer/step from the newest checkpoint, if any.
        # NOTE(review): the epoch loop below still runs all `num_epochs`
        # epochs after a restore; already-trained epochs are not skipped.
        if ckpt_manager.latest_checkpoint:
            ckpt.restore(ckpt_manager.latest_checkpoint)
            print(
                f"Latest checkpoint restored from {ckpt_manager.latest_checkpoint}.")

        train_summary_writer = summary.create_file_writer(
            str(self.model_dir / "train"))
        val_summary_writer = summary.create_file_writer(
            str(self.model_dir / "valid"))

        for epoch in range(self.num_epochs):
            with train_summary_writer.as_default():
                self.train_one_epoch(ckpt.step, epoch)
            # Saved to the module-level path_model directory, not model_dir.
            self.model.save(f"{path_model}/model_{epoch}")

            # Run a validation loop at the end of each epoch.
            with val_summary_writer.as_default():
                v = self.evaluate(ckpt.step)

        # A checkpoint is written only once, after the last epoch.
        save_path = ckpt_manager.save()
        print(f"Saved checkpoint for step {ckpt.step.numpy()}: {save_path}")

    def train_one_epoch(self, step_counter, epoch):
        """Iterate once over the training dataset, logging loss every batch."""
        progress = tqdm(self.train_ds, total=self.steps_per_epoch)
        for x, y in progress:
            train_loss, logits = self.train_one_step(
                x, y["label_event"], y["label_riskset"])

            step = int(step_counter)
            if step == 0:
                # Write the traced graph of train_one_step to TensorBoard once.
                # see https://stackoverflow.com/questions/58843269/display-graph-using-tensorflow-v2-0-in-tensorboard
                func = self.train_one_step.get_concrete_function(
                    x, y["label_event"], y["label_riskset"])
                summary_ops_v2.graph(func.graph)

            # Update training metric.
            self.train_loss_metric.update_state(train_loss)

            # Log every batch. The metric is reset below after each step, so
            # this is the current batch's loss, not a running mean.
            mean_loss = self.train_loss_metric.result()
            progress.set_description_str(
                desc=f"Epoch: {1+epoch}/{self.num_epochs} | Loss: {mean_loss:.4f}")
            # save summaries
            summary.scalar("loss", mean_loss, step=step_counter)
            # Reset training metrics
            self.train_loss_metric.reset_states()

            step_counter.assign_add(1)

    @tf.function
    def evaluate_one_step(self, x, y_event, y_riskset):
        """Return (loss, logits) for one validation batch without training."""
        y_event = tf.expand_dims(y_event, axis=1)
        val_logits = self.model(x, training=False)
        val_loss = self.loss_fn(y_true=[y_event, y_riskset], y_pred=val_logits)
        return val_loss, val_logits

    def evaluate(self, step_counter):
        """Compute validation loss and c-index, log them and return the c-index."""
        self.val_cindex_metric.reset_states()

        for x_val, y_val in self.val_ds:
            val_loss, val_logits = self.evaluate_one_step(
                x_val, y_val["label_event"], y_val["label_riskset"])

            # Update val metrics. The loss is averaged per batch, so a smaller
            # final batch weighs as much as a full one.
            self.val_loss_metric.update_state(val_loss)
            self.val_cindex_metric.update_state(y_val, val_logits)

        val_loss = self.val_loss_metric.result()
        summary.scalar("loss",
                       val_loss,
                       step=step_counter)
        self.val_loss_metric.reset_states()

        val_cindex = self.val_cindex_metric.result()
        for key, value in val_cindex.items():
            summary.scalar(key, value, step=step_counter)
        print(
            f"Validation: loss = {val_loss:.4f}, cindex = {val_cindex['cindex']:.4f}")
        return val_cindex['cindex']


In [20]:
import shutil

# Start from an empty checkpoint directory so that train_and_evaluate() does
# not silently restore a stale checkpoint from a previous run. The former
# shell command `rm - rf` (with a space) passed "-" and "rf" as file names
# and, lacking -r, never removed the directory.
shutil.rmtree("ckpts-PlexusNET_BCR_10x_COX", ignore_errors=True)
os.makedirs("ckpts-PlexusNET_BCR_10x_COX", exist_ok=True)


In [21]:
tf.keras.backend.clear_session()  # reset Keras global state left by any earlier model

# PlexusNet with a single linear output unit: one unbounded risk score per
# 512x512 image, matching what CoxPHLoss expects.
model = PlexusNet(
    depth=5,
    length=2,
    junction=3,
    n_class=1,
    final_activation="linear",
    initial_filter=6,
    filter_num_for_first_convlayer=4,
    input_shape=(512, 512),
    ApplyLayerNormalization=True,
    run_all_BN=False,
    type_of_block="soft_att",
    GlobalPooling="avg",
).model


In [None]:
# Print a layer-by-layer overview of the network and its parameter counts.
model.summary()


In [22]:
# Batches per training epoch (full batches only, since train_fn drops the last).
_steps_per_epoch = train_fn.steps_per_epoch()

# Triangular cyclical learning rate (constant scale_fn) between 1e-6 and 1e-3;
# step_size is the number of steps to go from the lower to the upper bound,
# here 4 epochs.
lr_sh = tfa.optimizers.CyclicalLearningRate(
    initial_learning_rate=1e-6,
    maximal_learning_rate=1e-3,
    step_size=_steps_per_epoch * 4,
    scale_fn=lambda x: 1.,
    scale_mode="cycle",
    name="MyCyclicScheduler",
)

trainer = TrainAndEvaluateModel(
    model=model,
    model_dir=Path("./ckpts-PlexusNET_BCR_10x_COX"),
    train_dataset=train_fn(),
    eval_dataset=eval_fn(),
    learning_rate=lr_sh,
    num_epochs=50,
    steps_per_epoch=_steps_per_epoch,
)


In [24]:
# Load the TensorBoard notebook extension (IPython magic) so the summaries the
# trainer writes under ./ckpts-PlexusNET_BCR_10x_COX/{train,valid} can be viewed.
%load_ext tensorboard


In [None]:
# Run the full training loop (50 epochs), validating after every epoch.
trainer.train_and_evaluate()


In [None]:
from plexusnet.architecture import LoadModel
import PIL
from tqdm import tqdm
from collections import defaultdict
valid_set = pd.read_csv("valid_set.csv")
# Exclude benign samples (rows whose Filename contains "/B/").
valid_set = valid_set[valid_set.Filename.str.contains("/B/").eq(False)]
# BCR_status is 0 when the "X1st.BCR.Type" entry contains "-", otherwise 1.
valid_set["BCR_status"] = 1 - valid_set["X1st.BCR.Type"].str.contains("-")
print(valid_set["X1st.BCR.Type"].value_counts())
print(valid_set["BCR_status"].value_counts())


def GetResult(results):
    """Aggregate patch-level predictions per case and score them on valid_set.

    Parameters
    ----------
    results : dict
        Maps each image filename to its patch predictions, as returned by
        ``RunAnalyses``. Entries are matched to ``valid_set`` rows by
        *position*, so iteration order must follow ``valid_set`` row order.

    Returns
    -------
    dict
        ``{'roc': ROC AUC of the mean case score vs. BCR_status,
        'cindex': concordance index of the case score vs. follow-up time}``.
        Both values are also printed.

    Notes
    -----
    The case id is the third "-"-separated token of the filename. Each case's
    label and time come from its first image; its score is the mean of all
    patch predictions of all its images.
    NOTE(review): duplicate filenames in valid_set collapse in the `results`
    dict and break the positional alignment — confirm filenames are unique.
    """
    bcr_status = valid_set.BCR_status
    follow_up = valid_set["Interval.RP.to.BCR.or.last.contact.death"]

    case_preds = defaultdict(list)
    case_label = {}
    case_time = {}
    # `fl` already is the dict key; no need to rebuild list(results.keys()) per row.
    for row, fl in enumerate(results):
        case_id = fl.split("-")[2]
        case_preds[case_id].extend(results[fl])
        if case_id not in case_label:
            case_label[case_id] = bcr_status.iloc[row]
            case_time[case_id] = follow_up.iloc[row]

    case_ids = list(case_label)
    y_true_lst = [case_label[c] for c in case_ids]
    y_pred_lst = [np.mean(case_preds[c]) for c in case_ids]
    y_time_lst = [case_time[c] for c in case_ids]

    roc = roc_auc_score(y_true_lst, y_pred_lst)
    # lifelines expects higher score = longer survival; 1 - risk reverses the
    # ordering of the risk scores accordingly.
    cindex = concordance_index(y_time_lst, 1 - np.array(y_pred_lst), y_true_lst)
    print(roc, cindex)
    return {'roc': roc, 'cindex': cindex}


def RunAnalyses(model_best):
    """Predict every valid_set image on a grid of overlapping 512x512 patches.

    Each image is resized to 1280x1280 and a 128 px border is removed. Patches
    of 512x512 are then taken every 256 px in row-major order, which gives a
    3x3 grid for the 1024x1024 crop.

    Parameters
    ----------
    model_best : tf.keras.Model
        Model whose ``predict`` returns one score per patch.

    Returns
    -------
    heatmaps : dict
        Filename -> float64 array of patch scores laid out on the patch grid.
    results : dict
        Filename -> raw ``predict`` output for that image's patches.
    """
    results = defaultdict(list)
    heatmaps = defaultdict(list)
    for fl in tqdm(valid_set.Filename):
        # The context manager closes the file once the resized pixels exist.
        with PIL.Image.open(fl) as src:
            img = np.array(src.resize((5120 // 4, 5120 // 4)))
        img = img[128:-128, 128:-128]
        rows = range(0, img.shape[0] - 256, 256)
        cols = range(0, img.shape[1] - 256, 256)
        patch = [img[j:j + 512, i:i + 512] for j in rows for i in cols]
        pr = model_best.predict(np.array(patch), verbose=0)
        # Row-major patch order matches the (row, col) grid. np.float64 replaces
        # the np.float alias, which was removed in NumPy 1.24. Reshaping avoids
        # assigning size-1 arrays to scalar slots.
        heatmaps[fl] = np.asarray(pr, dtype=np.float64).reshape(len(rows), len(cols))
        results[fl] = pr
    return heatmaps, results


# Run the patch analysis for each saved epoch model (model_1 ... model_49) on
# valid_set and keep per-epoch heatmaps, raw predictions, c-index and ROC AUC.
# NOTE(review): range(1, 50) skips model_0, which the trainer also saves —
# confirm this is intended.
heatmaps_model, results_model, cindex_model, roc_model = {}, {}, {}, {}

for epoch in range(1, 50):
    print(epoch)
    model_best = LoadModel(f"./PlexusNET_BCR_10x_COX/model_{epoch}")
    heatmaps, results = RunAnalyses(model_best)
    v = GetResult(results)
    heatmaps_model[epoch], results_model[epoch] = heatmaps, results
    cindex_model[epoch], roc_model[epoch] = v["cindex"], v["roc"]


In [None]:
# Load the weights saved after the final training epoch (epochs run 0-49).
# The trainer saves to "model_{epoch}" without zero padding, so the path must
# not be padded either: "{epoch:02d}" would point at a non-existent
# "model_05" for single-digit epochs.
epoch = 49
model.load_weights(f"./PlexusNET_BCR_10x_COX/model_{epoch}")


In [43]:
# Tumor-only test and development cohorts (each has a Filename column).
test_set, development_set = (
    pd.read_csv(csv_path)
    for csv_path in ("./test_set_OnlyTumor.csv", "./development_set_OnlyTumor.csv")
)


In [44]:
# Root directory prepended to each dataset Filename after its first character
# is dropped.
folder_path = "./dataset/"


def RunAnalyses(model_best, dataset):
    """Predict every image of `dataset` on a grid of overlapping 512x512 patches.

    Each image is loaded from ``folder_path + Filename[1:]`` and resized to
    1280x1280 (about 10x). A 128 px border is removed because the TMA core is
    centered. Patches of 512x512 are taken every 256 px in row-major order,
    which gives a 3x3 grid for the 1024x1024 crop.

    Parameters
    ----------
    model_best : tf.keras.Model
        Model whose ``predict`` returns one score per patch.
    dataset : pandas.DataFrame
        Must have a ``Filename`` column.

    Returns
    -------
    heatmaps : dict
        Full path -> float64 array of patch scores laid out on the patch grid.
    results : dict
        Full path -> raw ``predict`` output for that image's patches.
    """
    results = defaultdict(list)
    heatmaps = defaultdict(list)
    for fl_ in tqdm(dataset.Filename):
        fl = folder_path + fl_[1:]
        # resize to 1280x1280 = ~10x; the context manager closes the file.
        with PIL.Image.open(fl) as src:
            img = np.array(src.resize((5120 // 4, 5120 // 4)))
        # reduce white area as the TMA core is centered.
        img = img[128:-128, 128:-128]
        rows = range(0, img.shape[0] - 256, 256)
        cols = range(0, img.shape[1] - 256, 256)
        # Resizing guarantees 512x512 patches even if a slice were clipped.
        patch = [
            np.array(PIL.Image.fromarray(img[j:j + 512, i:i + 512]).resize((512, 512)),
                     dtype=np.uint8)
            for j in rows for i in cols
        ]
        pr = model_best.predict(np.array(patch), verbose=0)
        # Row-major patch order matches the (row, col) grid. np.float64 replaces
        # the np.float alias, which was removed in NumPy 1.24.
        heatmaps[fl] = np.asarray(pr, dtype=np.float64).reshape(len(rows), len(cols))
        results[fl] = pr
    return heatmaps, results


In [None]:
# Patch-level predictions of the currently loaded `model` for every
# development-set image (heatmaps are discarded).
_, results_development_set = RunAnalyses(model, development_set)


In [None]:
# Patch-level predictions of the currently loaded `model` for every test-set
# image (heatmaps are discarded).
_, results_testset = RunAnalyses(model, test_set)


In [51]:
# One column per patch position: columns_development_set[j] lists, in
# results_development_set order, the prediction for patch j of every image.
# Assumes 9 patches per image (the 3x3 grid RunAnalyses produces for 1280 px
# resizes).
columns_development_set = defaultdict(list)
for j in range(9):
    for k in results_development_set:
        columns_development_set[j].append(
            results_development_set[k][j].flatten()[0])


In [52]:
# Same reshaping for the test set: columns_test[j] holds patch j's prediction
# for every image, in results_testset order (9 patches per image assumed).
columns_test = defaultdict(list)
for j in range(9):
    for k in results_testset:
        columns_test[j].append(results_testset[k][j].flatten()[0])


In [53]:
# Attach the per-patch prediction columns (keyed by patch index) to the test
# set and write it to CSV. Plain string literal: the former f-string had no
# placeholders.
for col in columns_test:
    test_set[col] = columns_test[col]
test_set.to_csv("PlexusNet_COX_test_set.csv")


In [54]:
# Attach the per-patch prediction columns (keyed by patch index) to the
# development set and write it to CSV. Plain string literal: the former
# f-string had no placeholders.
for col in columns_development_set:
    development_set[col] = columns_development_set[col]
development_set.to_csv("PlexusNet_COX_development_set.csv")
