# In [None]:
import sys
sys.path.append("..")
import os


os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
from model_cpp.model_env_cpp import CellEnvironment, transform_densities
import cppCellModel

from DDPG_DENSNET.OUActionNoise import OUActionNoise
from DDPG_DENSNET.algorithm import get_actor, get_critic, policy, update_target, learn
import tensorflow as tf
from DDPG_DENSNET.Buffer import Buffer
# from model.cell_environment import CellEnvironment, transform_densities
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcol
import json
import os
from tensorflow.keras.models import load_model
import tensorflow as tf
from tqdm import tqdm
from misc.draw_treatment import make_img, make_img3
from sklearn.preprocessing import MinMaxScaler

# import tensorflow as tf
# physical_devices = tf.config.list_physical_devices('GPU')
# for device in physical_devices:
#     tf.config.experimental.set_memory_growth(device, True)


# NOTE(review): ``scaler`` is not referenced anywhere in this section; the
# observations below are scaled with hand-written constants instead. A later
# cell may still use it — confirm before removing.
scaler = MinMaxScaler(feature_range=(0, 255))


def save_tumor_image(data, tick, out_dir='tmp'):
    """Render a cell-density grid as a borderless image and save it to disk.

    The grid is passed through ``transform_densities`` and drawn with
    ``imshow`` on an axes that fills the whole figure with the axis hidden,
    so the saved file contains only the picture.

    Args:
        data: Grid accepted by ``transform_densities``.
        tick: Simulation tick; the file is named ``t<tick>``.
        out_dir: Directory the image is written to; created if missing.
            Defaults to ``'tmp'``, the location used previously.
    """
    data = transform_densities(data)
    sizes = np.shape(data)
    fig = plt.figure()
    try:
        # Height 1 inch, width sizes[0] / sizes[1] inches (1x1 for square data).
        # NOTE(review): for non-square data this ratio looks transposed
        # (rows / cols rather than cols / rows) — confirm before relying on it.
        fig.set_size_inches(1. * sizes[0] / sizes[1], 1, forward=False)
        ax = plt.Axes(fig, [0., 0., 1., 1.])
        ax.set_axis_off()
        fig.add_axes(ax)
        ax.imshow(data)
        os.makedirs(out_dir, exist_ok=True)
        # Save through the figure handle rather than pyplot's "current figure".
        fig.savefig(os.path.join(out_dir, 't' + str(tick)), dpi=500)
    finally:
        # Release the figure even if drawing or saving failed.
        plt.close(fig)

def save_dose_map(data, tick, out_dir='tmp'):
    """Save a dose map as a heat-map image with a colorbar labelled in Gy.

    The colour scale is fixed to 0..70 (ticks at 0, 35, 70) and runs from
    dark blue to red.

    Args:
        data: Grid of dose values passed to ``imshow``.
        tick: Simulation tick; the file is named ``d<tick>``.
        out_dir: Directory the image is written to; created if missing.
            Defaults to ``'tmp'``, the location used previously.
    """
    cmap = mcol.LinearSegmentedColormap.from_list("MyCmapName", [[0, 0, 0.6], "r"])
    # Draw on a dedicated figure instead of pyplot's implicit "current" one, so
    # a figure left open elsewhere (e.g. by another notebook cell) is neither
    # drawn on nor closed here.
    fig, ax = plt.subplots()
    try:
        pos = ax.imshow(data, vmin=0, vmax=70, cmap=cmap)
        cb = fig.colorbar(pos, ax=ax, ticks=[0, 35, 70])
        cb.set_label(label='[Gy]', size='large', weight='bold')
        cb.ax.tick_params(labelsize='large')
        ax.axis('off')
        fig.tight_layout(pad=0.05)
        os.makedirs(out_dir, exist_ok=True)
        fig.savefig(os.path.join(out_dir, 'd' + str(tick)))
    finally:
        # Release the figure even if drawing or saving failed.
        plt.close(fig)




# Environment for this evaluation cell.
# NOTE(review): the evaluation loop rebinds ``env`` at the start of every
# episode, so this instance is only used by code that runs without that loop.
env = CellEnvironment('segmentation', False, 'dose', 'AC', True)
env.init_dataset()
env.init_dose_map()

# Action noise handed to ``policy``: zero mean, std 0.2, one action dimension.
std_dev = 0.2
ou_noise = OUActionNoise(mean=np.zeros(1), std_deviation=np.full(1, float(std_dev)))

# Networks under evaluation, loaded from checkpoints in ./tmp.
actor_model = load_model('tmp/3actor_DENSNET.h5')
critic_model = load_model('tmp/3critic_model_DENSNET.h5')

# Target networks: freshly built, then synchronised with the loaded weights.
target_actor = get_actor()
target_critic = get_critic()
target_actor.set_weights(actor_model.get_weights())
target_critic.set_weights(critic_model.get_weights())

# Optimisers and learning hyper-parameters carried over from training;
# the evaluation loop in this section does not call ``learn``.
critic_lr, actor_lr = 0.002, 0.001
critic_optimizer = tf.keras.optimizers.Adam(critic_lr)
actor_optimizer = tf.keras.optimizers.Adam(actor_lr)

k = 1                 # number of evaluation episodes
steps_per_epoch = 50  # upper bound on actions per episode
gamma = 0.99          # discount factor for future rewards
tau = 0.005           # rate for soft target-network updates

# Replay buffer (arguments presumably capacity and batch size — confirm);
# the loop below records transitions into it but never samples from it.
buffer = Buffer(50000, 64)

# Per-episode bookkeeping.
ep_reward_list, avg_reward_list = [], []
count = 0
length_success, avg_rad, avg_h_cell_killed = [], [], []
max_KH = np.zeros(k)
avg_percentage, avg_doses = [], []

def _observe_state(environment):
    """Build the network input from the environment's current observations.

    Three maps are combined: the cell observation scaled by 255, glucose
    scaled by 255/5300 and oxygen scaled by 255/170000. The result is
    reshaped to (50, 50, 3) and returned as a TensorFlow tensor.
    """
    cells = np.array(environment.observe()).squeeze() * 255.0
    glucose = cppCellModel.observeGlucose(environment.controller_capsule) * (255 / 5300)
    oxygen = cppCellModel.observeOxygen(environment.controller_capsule) * (255 / 170000)
    # NOTE(review): assuming each map is 50x50, ``np.array([...])`` has shape
    # (3, 50, 50); ``reshape`` to (50, 50, 3) reinterprets memory instead of
    # moving the channel axis, so the three maps end up interleaved. Kept as-is
    # because the loaded models were presumably trained on this exact layout —
    # confirm against the training code before switching to
    # ``np.stack(..., axis=-1)``.
    return tf.convert_to_tensor(np.array([cells, glucose, oxygen]).reshape((50, 50, 3)))


# Evaluate the loaded actor for ``k`` episodes (no learning step runs here).
for i in range(k):
    # Fresh environment for every episode.
    env = CellEnvironment('segmentation', False, 'dose', 'AC', True)
    env.init_dataset()
    env.init_dose_map()
    _ = env.reset(-1)

    prev_state = _observe_state(env)

    for t in tqdm(range(steps_per_epoch)):
        tf_prev_state = tf.expand_dims(tf.convert_to_tensor(prev_state), 0)

        action, saction = policy(actor_model, tf_prev_state, ou_noise, cond=False)
        # Receive reward and treatment information from the environment.
        reward, dose, time, KH = env.act(action)
        # Track the largest per-action KH value seen in this episode.
        max_KH[i] = max(max_KH[i], KH)

        state = _observe_state(env)

        done, which_terminal = env.inTerminalState()
        buffer.record((prev_state, action, reward, state))

        # End the episode on ANY terminal state: continuing to act after the
        # environment reports termination is presumably invalid. Only the 'W'
        # terminal is counted towards the TCP reported later.
        if done or which_terminal == 'W':
            if which_terminal == 'W':
                count += 1
            break

        prev_state = state

    # NOTE(review): recorded for every episode even though the report labels
    # its mean "length in successes"; 350 looks like a fixed tick offset —
    # confirm both against the environment.
    length_success.append(env.get_tick() - 350)
    avg_rad.append(env.total_dose)
    avg_percentage.append(env.surviving_fraction())
    avg_h_cell_killed.append(env.radiation_h_killed)
    avg_doses.append(env.num_doses)

rads = np.array(avg_rad)
percentages = np.array(avg_percentage)
fracs = np.array(avg_doses)
durations = np.array(length_success)


def _print_mean_and_spread(label, values):
    """Print *label*, the mean of *values*, and their standard deviation.

    ``np.std`` is the (population) standard deviation, not the standard error
    of the mean, so it is labelled "Std dev" rather than "Std error".
    """
    print(label, np.mean(values), "Std dev:", np.std(values))


print("TCP = ", count / k)
_print_mean_and_spread("Avg rad", rads)
_print_mean_and_spread("Avg length in successes", durations)
_print_mean_and_spread("Avg number of doses", fracs)
_print_mean_and_spread("Avg hcells killed", avg_h_cell_killed)
_print_mean_and_spread("Avg surviving fraction: ", percentages)
_print_mean_and_spread("Avg MAX hcells killed per fraction : ", max_KH)

print("done")

path = "./tmp"
# exist_ok avoids the check-then-create race and is a no-op if it exists.
os.makedirs(path, exist_ok=True)


def _snapshot_positions(length):
    """Return indices 0, length/3, length/2, 2*length/3 (truncated) and -1."""
    return [0, length // 3, length // 2, length * 2 // 3, -1]


# Positions are computed from each list's own length; previously dose_maps was
# indexed with positions derived from len(env.tumor_images).
tumor_positions = _snapshot_positions(len(env.tumor_images))
dose_positions = _snapshot_positions(len(env.dose_maps))
for tumor_pos, dose_pos in zip(tumor_positions, dose_positions):
    save_tumor_image(env.tumor_images[tumor_pos][1], env.tumor_images[tumor_pos][0])
    save_dose_map(env.dose_maps[dose_pos][1], env.dose_maps[dose_pos][0])

# Ticks of the tumour snapshots at the positions listed above.
snapshot_ticks = [env.tumor_images[pos][0] for pos in tumor_positions]
ticks4 = [snapshot_ticks[0], snapshot_ticks[1], snapshot_ticks[3], snapshot_ticks[4]]
make_img(ticks4, 'fig4')
ticks3 = [snapshot_ticks[0], snapshot_ticks[2], snapshot_ticks[4]]
make_img3(ticks3, 'fig3')


# Treatment schedule: dose per fraction over time, in red on the y-axis.
ticks, counts, doses = env.dataset
fig, ax1 = plt.subplots()

color = 'tab:red'
ax1.set_xlabel('time (h)')
ax1.set_ylabel('Dose (Gy)', color=color)
ax1.set(ylim=(0, 5), xlim=(0, 450))
ax1.plot(ticks, doses, color=color, marker='o', mew=8, linewidth=4)
ax1.tick_params(axis='y', labelcolor=color)

# Each tick duplicated and paired with both entries of its count record
# (these series are not plotted in this section).
d_ticks = []
d_counts = []
for i, tick_value in enumerate(ticks):
    d_ticks.extend((tick_value, tick_value))
    d_counts.extend((counts[i][0], counts[i][1]))

# Tighten margins so the axis labels are not clipped in the saved PDF.
fig.tight_layout()
fig.savefig(path + '/fig_treat.pdf', format='pdf')
