In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.stats import norm 
from scipy.stats import lognorm
from tqdm import tqdm

import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter

import tensorflow as tf
from tensorflow.python.summary.summary_iterator import summary_iterator

import gym
from stable_baselines3 import PPO
import mujoco_py


$x_i(x_{i-1}, y, \Delta t) = a+be^{-c\Delta t}+d-e(y-y_0)^2$ should satisfy the following conditions:   
$x_i = x_{i-1} \mbox{ when } \Delta t=0, y=0 \rightarrow a+b=x_{i-1}$    
$x_i = a+be^{-c\Delta t} \mbox{ when } y=0 \rightarrow d=e{y_0}^2$
Therefore, $x_i(x_{i-1}, y, \Delta t)$ can be expressed as $x_{i-1}-b+be^{-c\Delta t}+e(2yy_0-y^2)$ where we have four parameters: $b, c, e, y_0$           
$b$ is the asymptotic change of $x$ without any treatment, higher $b$ means larger change     
$c$ is the decrease rate of $x$, higher $c$ means higher rate        
$e$ is the sensitivity of $x$ to the dosage, higher $e$ means more sensitive     
$y_0$ is the optimal dosage

# Environment
The reward for every step is defined as the surviving days after last visit minus a penalty constant. If the patient did not make it to the next visit, the surviving days is the root of $x_i(\Delta t)=threshold$.     
Note: The root can be negative when the drift part ($e(2yy_0-y^2)$) is large (y deviates too far from $y_0$), but since one cannot live negative days, set surviving days to 0 in this condition.

In [2]:
class SimEnv(gym.Env):
    def __init__(self, b=10, c=0.5, e=3, y0=2, threshold=1, penalty=2, timeout=100):
        '''
        action_space: (y, delta_t)
        observation_space: (x)
        
        Input
        threshold: when x is lower than this threshold, patient die
        penalty: this parameter adds cost to frequent visits
        '''
        super(SimEnv, self).__init__()
        
        self.b = b
        self.c = c
        self.e = e
        self.y0 = y0
        self.threshold = threshold
        self.penalty = penalty
            
        self.action_space = gym.spaces.Box(low=0,high=float("inf"),shape=(2,),dtype=np.float32)
        self.observation_space = gym.spaces.Box(low=-float("inf"), high=float("inf"), shape=(1,), dtype=np.float32)
        self.state = None
        self.timeout = timeout
        self.steps_elapsed = 0
        
    def reset(self):
        self.state = np.random.uniform(self.threshold, 15, 1) # numpy array (1,)
        self.steps_elapsed=0
        return(self.state)
        
        
    def step(self, action):
        delta_t=action[1]
        y=action[0]
        next_obs = self.xfunc(x=self.state, delta_t=delta_t, y=action[0], b=self.b,c= self.c, e=self.e, y0=self.y0)
        
        if next_obs > self.threshold:
            reward = delta_t - self.penalty
        elif (-1/self.c*np.log((self.threshold-self.state[0]+self.b-self.e*(2*y*self.y0-y**2))/self.b) < 0):
            reward = -self.penalty 
        else:
            reward = -1/self.c*np.log((self.threshold-self.state[0]+self.b-self.e*(2*y*self.y0-y**2))/self.b) - self.penalty
            
        self.state = next_obs 
        self.steps_elapsed+=1
        
        return(np.array([self.state]).astype(np.float32), reward, next_obs[0] <= self.threshold or self.steps_elapsed > self.timeout , {"x": self.state, "y": action[0], "delta_t": action[1], "dead": next_obs[0] <= self.threshold})
    
    def xfunc(self, x, delta_t, y, b, c, e, y0):
        '''
        Input
        x: last obs/state
        delta_t: time interval btw visits
        y: dosage
        b, c, e, y0: parameters
        Output
        current obs/state
        '''
        return(x-b+b*np.exp(-c*delta_t)+e*(2*y*y0-y**2)) 
    
    def render(self, mode=None):
        pass
    def close(self):
        pass

In [3]:
s = SimEnv()
s.reset()

array([4.02582458])

In [4]:
type(s.step([1,2])[2])

bool

# Policy - PPO

In [5]:
envgym = gym.make('HalfCheetah-v2')

In [6]:
env = SimEnv()

In [7]:
policy = PPO("MlpPolicy", env, n_steps=128, batch_size=64, verbose=1, tensorboard_log ='smmResults/').learn(total_timesteps=1000)

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Logging to smmResults/PPO_34
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 24.4     |
|    ep_rew_mean     | -38.9    |
| time/              |          |
|    fps             | 793      |
|    iterations      | 1        |
|    time_elapsed    | 0        |
|    total_timesteps | 128      |
---------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 37.2        |
|    ep_rew_mean          | -56.7       |
| time/                   |             |
|    fps                  | 654         |
|    iterations           | 2           |
|    time_elapsed         | 0           |
|    total_timesteps      | 256         |
| train/                  |             |
|    approx_kl            | 0.022998236 |
|    clip_fraction        | 0.152       |
|    clip_range           | 0.2  

In [8]:
# rewards = []
# losses = []
# for e in tf.compat.v1.train.summary_iterator('smmResults/PPO_9/PPO1'): # how to automatically change file name
#     for v in e.summary.value:
#         if v.tag == "rollout/ep_rew_mean":
#             rewards.append(v.simple_value)
#         if v.tag == "train/loss":
#             losses.append(v.simple_value)
# plt.figure()
# plt.plot(rewards)
# plt.xlabel("iterations")
# plt.ylabel("Episode reward")
# plt.title("Mean Episode Reward Over 5 Episodes")

In [9]:
# rewards = []
# length = []
# for e in tf.compat.v1.train.summary_iterator('smmResults/PPO_9/PPO1'):
#     for v in e.summary.value:
#         if v.tag == "rollout/ep_rew_mean":
#             rewards.append(v.simple_value)
#         if v.tag == "rollout/ep_len_mean":
#             length.append(v.simple_value)
# plt.figure()
# plt.plot(length)
# plt.xlabel("iterations")
# plt.ylabel("Episode Len")
# plt.title("Mean Episode Length Over 5 Episodes")