In [None]:
import os
import glob
import numpy as np
import pandas as pd

from pathlib import Path
from PIL import Image

import tensorflow as tf
from tensorflow import keras

In [None]:
# All inputs live under <current working directory>/data.
DATA_DIR = Path.cwd() / "data"
LABEL_PATH = DATA_DIR / "labels"  # directory of per-patient label files
SURVIVAL_PATH = DATA_DIR / "survival_data.csv"  # survival / age table

In [None]:
# Every entry in the label directory, sorted so it can be walked in step with
# the sorted patient IDs below.
image_list = sorted(glob.glob(os.path.join(str(LABEL_PATH), "*")))

# Survival table ordered by patient ID, keeping only rows that have both a
# Survival and an Age value, re-indexed 0..n-1.
survival_raw = pd.read_csv(SURVIVAL_PATH)
df_surv = survival_raw.sort_values(by="BraTS19ID")
df_surv = df_surv.dropna(subset=["Survival", "Age"])
df_surv = df_surv.reset_index(drop=True)

# Targets and the age feature as standalone arrays; age is a (n, 1) column.
y = df_surv["Survival"].to_numpy(copy=True)
age = df_surv["Age"].to_numpy(copy=True).reshape(-1, 1)

In [None]:
def match_label_files(patient_ids, label_paths):
    """Pair each patient ID with the next label file whose name contains it.

    Both sequences are expected to be sorted consistently (the loading cell
    sorts both), so a single forward scan pairs them. Files that belong to
    patients without survival data are skipped, and each file is used at most
    once. ``label_paths`` is not modified, so re-running the cell is safe.

    NOTE(review): matching is by substring on the file name. If one ID could
    be a substring of another file's name, verify that the naming scheme keeps
    the pairing unambiguous.

    Args:
        patient_ids: iterable of patient ID strings, in order.
        label_paths: sequence of label file paths, in the same order.

    Returns:
        list of paths, one per ID, aligned with ``patient_ids``.

    Raises:
        ValueError: if an ID has no matching file at or after the current
            scan position.
    """
    names = [Path(path).name for path in label_paths]
    matched = []
    pos = 0
    for patient_id in patient_ids:
        # Advance past files for patients that were dropped from the table.
        while pos < len(names) and patient_id not in names[pos]:
            pos += 1
        if pos == len(names):
            raise ValueError(f"No label file found for patient {patient_id!r}")
        matched.append(label_paths[pos])
        pos += 1
    return matched


image_list_clean = match_label_files(df_surv["BraTS19ID"], image_list)
assert len(image_list_clean) == len(df_surv), "label files and survival rows misaligned"

In [None]:
# Load each matched file and replicate it along a new trailing axis into three
# identical channels, so it fits the 3-channel image input used by the model.
# NOTE(review): assumes each file holds a single-channel 240x240 array - confirm.
X = [np.stack((np.load(path),) * 3, axis=-1) for path in image_list_clean]

In [None]:
def build_model(input_img, age, weights="imagenet"):
    """Build a ResNet50-based regressor that also conditions on age.

    The image branch is a ResNet50 backbone without its classification head,
    followed by global average pooling. The pooled features are concatenated
    with the age input, batch-normalised, and mapped to one linear output.

    Args:
        input_img: Keras input tensor for the image; used as the backbone's
            ``input_tensor``.
        age: Keras input tensor for age; concatenated to the pooled features.
        weights: forwarded to ``ResNet50``. ``"imagenet"`` (default) loads
            pretrained weights, ``None`` initialises randomly (no download
            needed), or pass a path to a weights file.

    Returns:
        (model, base_model): the full two-input model and the ResNet50
        backbone, returned so callers can toggle the backbone's trainability.
    """
    base_model = keras.applications.resnet50.ResNet50(
        weights=weights, include_top=False, input_tensor=input_img
    )
    # NOTE(review): inputs are not passed through
    # keras.applications.resnet50.preprocess_input, which the ImageNet weights
    # expect. Confirm this is intentional.
    features = keras.layers.GlobalAveragePooling2D()(base_model.output)
    features = keras.layers.concatenate([features, age])
    features = keras.layers.BatchNormalization()(features)
    # One unit, no activation: an unbounded regression output.
    output = keras.layers.Dense(1)(features)

    model = keras.models.Model(inputs=[input_img, age], outputs=output)
    return model, base_model

In [None]:
# Image input: 240x240 with 3 channels, the channel count ResNet50 expects.
input_img = keras.layers.Input(shape=(240, 240, 3))
# Age input: one scalar per sample. Note `(1)` is just the int 1, not a tuple;
# Keras expects the shape as the 1-tuple `(1,)`.
age_m = keras.layers.Input(shape=(1,))

In [None]:
model, base_model = build_model(input_img, age_m)

# Fine-tune the whole backbone: mark every ResNet50 layer as trainable.
for backbone_layer in base_model.layers:
    backbone_layer.trainable = True

# Plain MSE regression on the survival target, tracked as a metric too.
optimizer = tf.keras.optimizers.Adam()
model.compile(
    optimizer=optimizer,
    metrics=["mean_squared_error"],
    loss="mean_squared_error",
)
model.summary()

In [None]:
# `keras.backend.constant` is not part of the Keras 3 backend API (TF >= 2.16).
# `tf.constant` with Keras' float type yields the same float32 tensors and
# works across TF 2.x versions.
FLOATX = keras.backend.floatx()
X_ten = tf.constant(np.stack(X), dtype=FLOATX)  # stacking makes shape mismatches fail loudly
age_ten = tf.constant(age, dtype=FLOATX)
# Survival values are converted element-wise with int() before casting to float.
target = tf.constant(np.array([int(lab) for lab in y]), dtype=FLOATX)

assert X_ten.shape[0] == age_ten.shape[0] == target.shape[0], "inputs/targets misaligned"

In [None]:
# Keras' `validation_split` holds out the *last* 20% of samples, taken before
# any per-epoch shuffling. The samples here are ordered by BraTS19ID, so without
# a permutation the validation set would be one contiguous block of IDs.
# Permute once with a fixed seed so the split is random but reproducible.
SPLIT_SEED = 42
n_samples = int(target.shape[0])
split_order = np.random.default_rng(SPLIT_SEED).permutation(n_samples)

history = model.fit(
    x=[tf.gather(X_ten, split_order), tf.gather(age_ten, split_order)],
    y=tf.gather(target, split_order),
    epochs=25,
    validation_split=0.2,
    batch_size=8,
)

In [None]:
# Build the output path with pathlib, like the config cell, rather than
# concatenating strings with hard-coded "/" separators.
# NOTE(review): this extensionless path targets the TF2 SavedModel format;
# Keras 3's model.save expects a ".keras" filename - confirm the target version.
MODEL_PATH = Path(os.getcwd()) / "models" / "survival"
model.save(str(MODEL_PATH))