In [1]:
import sys
import pdb
import os
# this_dir = os.path.dirname(os.path.abspath(__file__))
# project_dir = os.path.join(this_dir, '..', '..')
# sys.path.append(project_dir)

import copy
import numpy as np


# import src.policies.linear_algebra as la


class Glucose(object):
  """Simulated diabetes (A1c) treatment environment.

  Each patient state is a 6-element vector
  ``[NAT, Discontinue, A1c, BP, Weight, C_t]`` where NAT counts the
  treatments started so far (capped at 4), Discontinue is a
  treatment-discontinuation indicator and C_t (the last entry) is the death
  indicator.  Actions are binary: action 1 increments NAT (moving to the next
  entry of ``tau_trts``: metformin, sulfonylurea, glitazone, insulin),
  action 0 keeps NAT unchanged.
  """
  NUM_STATE = 6  # the last state variable is the death indicator
  MAX_STATE = 1000 * np.ones(NUM_STATE)
  MIN_STATE = np.zeros(NUM_STATE)
  NUM_ACTION = 2

  # Generative model parameters
  mu_B0_X0 = 0  # mean of the initial BP and weight draws
  sigma_B0_X0_Y0 = 1  # sd of the initial BP, weight and A1c draws
  mu0 = 7.7  # mean of the initial A1c draw
  prob_L_given_trts = np.array([0.2, 0.2, 0.2, 0.35])  # metf, sulf, glit, insu
  sigma_eps = 0.5  # sd of the per-step transition noise
  tau_trts = np.array([0.14, 0.2, 0.02, 0.14])  # metf, sulf, glit, insu
  # Death-probability linear predictor: intercept, coef on A1c^2 (only when A1c > 7), coef on NAT.
  death_prob_coef = np.array([-10, 0.08, 0.5])

  def __init__(self, nPatients=1, x_initials=None, sx_initials=None):
    """
    :param nPatients: number of simulated patients
    :param x_initials: optional 2-d array of initial features, one row per patient (read by reset())
    :param sx_initials: optional 2-d array of initial state features, one row per patient (read by reset())
    """
    # One independent list per patient: ``[[]] * n`` would alias a single list n times.
    self.R = [[] for _ in range(nPatients)]  # List of rewards at each time step
    self.A = [[] for _ in range(nPatients)]  # List of actions at each time step
    self.X = [[] for _ in range(nPatients)]  # List of features (previous and current states) at each time step
    self.S = [[] for _ in range(nPatients)]  # List of states: < NAT, Discontinue, A1c, BP, Weight, C_t > at each time step
    self.Xprime_X_inv = None
    self.t = -1
    self.current_state = [None] * nPatients
    self.last_state = [None] * nPatients
    self.last_action = [None] * nPatients
    self.nPatients = nPatients

    # Mean A1c offset u_{t-1} read/updated by get_A1c(); one value shared by all patients.
    # NOTE(review): initialised to mu0 (mean of the initial A1c draw) -- confirm intended start value.
    self.prev_u_t = Glucose.mu0

    self.x_initials = x_initials
    self.sx_initials = sx_initials

  @staticmethod
  def reward_function(s):
    """Scalar reward for a single state.

    :param s: state vector at current time step (indices 1, 2 and 5 are read)
    :return: -10.0 if dead (s[5] == 1); 1.0 if A1c < 7; -2.0 if A1c > 7 and
             treatment was discontinued (s[1] == 1); otherwise 0.0
             (including A1c exactly 7).
    """
    if s[5] == 1:
      return -10.0
    elif s[2] < 7:
      return 1.0
    elif s[2] > 7 and s[1] == 1:
      return -2.0
    else:
      return 0.0

  @staticmethod
  def reward_funciton_mHealth(s_news):
    """Mean reward over a batch of states, using the same rules as reward_function.

    :param s_news: a 2-d array of new s values for all patients (dim: (time steps * nPatients) by 6)
    :return: mean of the per-row rewards
    """
    s_news = np.asarray(s_news)
    dead = s_news[:, 5] == 1
    a1c = s_news[:, 2]
    r = np.zeros(s_news.shape[0])
    # Assign through a single combined mask: chained boolean indexing
    # (r[mask1][mask2] = v) writes into a temporary copy and leaves r unchanged.
    r[~dead & (a1c < 7)] = 1.0
    r[~dead & (a1c > 7) & (s_news[:, 1] == 1)] = -2.0
    r[dead] = -10.0
    return np.mean(r, axis=0)

  def get_next_state(self, prev_state, action):
    """Sample the next state for one patient.

    :param prev_state: previous state [NAT, Discontinue, A1c, BP, Weight, C_t] (indices 0-4 are read)
    :param action: 0 or 1; 1 increments NAT while NAT < 4
    :return: 6-element list [NAT, d_t, A1c_t, bp_t, w_t, c_t]
    """
    # NOTE(review): one noise draw is shared by BP, weight and A1c -- confirm that correlation is intended.
    eps = np.random.normal(0, self.sigma_eps, 1)
    bp_t = self.get_bp_t(prev_state[3], eps)[0]
    w_t = self.get_w_t(prev_state[4], eps)[0]
    c_t = self.get_c_t(prev_state[2], prev_state[0])[0]
    if prev_state[0] < 4:
      NAT = prev_state[0] + 1 if action == 1 else prev_state[0]
      d_t = self.get_d_t(action, prev_state[0])[0]
      A1c_t = self.get_A1c(prev_state, action, d_t, eps)[0]
    else:
      # Once NAT >= 4 the action is ignored (treated as 0), making the
      # algorithm indifferent between actions to enable better exploration.
      NAT = 4
      d_t = self.get_d_t(0, prev_state[0])[0]
      A1c_t = self.get_A1c(prev_state, 0, d_t, eps)[0]
    # Both branches return the full 6-dim state (death indicator last), see NUM_STATE.
    return [NAT, d_t, A1c_t, bp_t, w_t, c_t]

  def get_bp_t(self, prev_bp, eps):
    """Next BP value: (prev_bp + eps) rescaled by 1/sqrt(1 + sigma_eps^2)."""
    bp_t = (prev_bp + eps) / np.sqrt(1 + float(self.sigma_eps) ** 2)
    return bp_t

  def get_w_t(self, prev_w, eps):
    """Next weight value: (prev_w + eps) rescaled by 1/sqrt(1 + sigma_eps^2)."""
    w_t = (prev_w + eps) / np.sqrt(1 + float(self.sigma_eps) ** 2)
    return w_t

  @staticmethod
  def exp_helper(x):
    """Logistic link exp(x) / (1 + exp(x)), computed without overflow.

    NOTE(review): get_c_t called this helper but it was not defined in the
    class; the logistic form is assumed from its name -- confirm.
    """
    return np.exp(-np.logaddexp(0, -x))

  def get_c_t(self, prev_A1c, prev_NAT):
    """Sample the death indicator; returns a length-1 array of 0/1."""
    b0, b1, b2 = self.death_prob_coef
    # The A1c term only contributes when the previous A1c exceeds 7.
    A1c_indicator = int(prev_A1c > 7)
    x = b0 + b1 * A1c_indicator * (prev_A1c ** 2) + b2 * prev_NAT
    c_t = np.random.binomial(1, self.exp_helper(x), 1)
    return c_t

  def get_d_t(self, action, prev_NAT):
    """Sample the treatment-discontinuation indicator.

    :return: [0] when action == 0, otherwise a length-1 array of 0/1 drawn
             with probability prob_L_given_trts[prev_NAT] (prob_L_given_trts[0]
             for NAT values outside 0..3).
    """
    if action == 0:
      return [0]
    if prev_NAT in (0, 1, 2, 3):
      p = self.prob_L_given_trts[int(prev_NAT)]
    else:
      p = self.prob_L_given_trts[0]
    d_t = np.random.binomial(1, p, 1)
    return d_t

  def get_A1c(self, prev_state, action, d_t, eps):
    """Next A1c value; may update self.prev_u_t when a new treatment starts.

    A new treatment takes effect (u_t = u_{t-1} * (1 - tau)) only when the
    previous A1c > 7, NAT < 4, action != 0 and the treatment was not
    discontinued (d_t != 1).
    """
    scale = np.sqrt(1 + self.sigma_eps ** 2)
    prev_nat = prev_state[0]
    if prev_state[2] > 7 and prev_nat < 4 and action != 0 and d_t != 1:
      new_u_t = 0
      if prev_nat in (0, 1, 2, 3):
        # tau_trts is ordered metformin, sulfonylurea, glitazone, insulin.
        new_u_t = self.prev_u_t * (1 - self.tau_trts[int(prev_nat)])
      A1c_t = (prev_state[2] - self.prev_u_t + eps) / scale + new_u_t
      self.prev_u_t = new_u_t
      return A1c_t
    else:
      A1c_t = (prev_state[2] - self.prev_u_t + eps) / scale + self.prev_u_t
      return A1c_t

  def reset(self):
    """Clear the observation history and create initial states.

    Without x_initials, each patient's first state is sampled as
    [0, 0, A1c_0, bp_0, w_0, 0] with A1c_0 ~ N(mu0, sigma_B0_X0_Y0) and
    bp_0, w_0 ~ N(mu_B0_X0, sigma_B0_X0_Y0).

    :return: None
    """
    # Reset obs history (independent list per patient)
    self.R = [[] for _ in range(self.nPatients)]  # List of rewards at each time step
    self.A = [[] for _ in range(self.nPatients)]  # List of actions at each time step
    self.X = [[] for _ in range(self.nPatients)]  # state plus action
    self.S = [[] for _ in range(self.nPatients)]  # state is 6 dim
    self.prev_u_t = Glucose.mu0

    # Generate first states for nPatients
    if self.x_initials is None:
      for i in range(self.nPatients):
        bp_0 = np.random.normal(Glucose.mu_B0_X0, Glucose.sigma_B0_X0_Y0, 1)[0]
        w_0 = np.random.normal(Glucose.mu_B0_X0, Glucose.sigma_B0_X0_Y0, 1)[0]
        A1c_0 = np.random.normal(Glucose.mu0, Glucose.sigma_B0_X0_Y0, 1)[0]
        s_0 = [0, 0, A1c_0, bp_0, w_0, 0]
        self.S[i] = np.append(self.S[i], [s_0])
    else:
      # NOTE(review): this column layout (state in columns 1:4, action in
      # column 7) looks like a 3-dim-state environment rather than the
      # 6-dim state above -- confirm.
      self.X = [self.x_initials[i, ] for i in range(self.nPatients)]
      self.S = [self.sx_initials[i, 1:4] for i in range(self.nPatients)]
      for i in range(self.nPatients):
        self.X[i] = np.vstack((self.X[i], self.X[i]))
        self.S[i] = np.vstack((self.S[i], self.S[i]))

      # self.last_state = [self.sx_initials[i, 4:7] for i in range(self.nPatients)]
      # self.current_state = [self.sx_initials[i, 1:4] for i in range(self.nPatients)]
      self.last_action = [self.x_initials[i, 7] for i in range(self.nPatients)]
      self.A = [np.array(self.x_initials[i, 7]) for i in range(self.nPatients)]
      # Previously read the undefined ``self.s_initials``.
      self.R = [np.array(self.reward_function(self.sx_initials[i, :]))
                for i in range(self.nPatients)]
    return


#   def next_state_and_reward(self, action, i):
#     """

#     :param action:
#     :return:
#     """

#     # Transition to next state
#     #    print(self.current_state[i], self.last_action[i])
#     get_next_state(self,prev_state,action)
#     sx = np.concatenate(([1], self.current_state[i], self.last_state[i],
#                          [self.last_action[i]]))
#     x = np.concatenate((sx, [action]))
#     glucose = np.random.normal(np.dot(x, self.COEF), self.SIGMA_NOISE)
#     food, activity = self.generate_food_and_activity()

#     # Update current and last state and action info
#     self.last_state[i] = copy.copy(self.current_state[i])
#     self.current_state[i] = np.array([glucose, food, activity]).reshape(1, 3)[0]
#     self.last_action[i] = action
#     reward = self.reward_function(self.last_state[i], self.current_state[i])
#     # current_x = np.concatenate()
#     return x, reward

  @staticmethod
  def get_state_at_action(action, x):
    """
    Replace current action entry in x with action.
    :param action: action value written into the last entry
    :param x: feature vector whose last entry is the action (not modified; a shallow copy is returned)
    :return: copy of x with its last entry set to action
    """
    new_x = copy.copy(x)
    new_x[-1] = action
    return new_x

  def get_state_history_as_array(self):
    """
    Return past states as an array with blocks [ lag 1 states, states]
    NOTE(review): self.SX is never assigned in this class, so this raises
    AttributeError unless SX is set elsewhere.
    :return: (X_as_array, SX_as_array, S_as_array) stacked with np.vstack
    """
    X_as_array = np.vstack(self.X)
    SX_as_array = np.vstack(self.SX)
    S_as_array = np.vstack(self.S)
    return X_as_array, SX_as_array, S_as_array

  def get_state_transitions_as_x_y_pair(self, new_state_only=True):
    """
    For estimating transition density.
    :param new_state_only: if True pair X[j][1:] with S[j][1:]; otherwise pair
                           X[j][1:-1] with X[j][2:] (features with next features)
    :return: (X, Sp1) stacked over all patients
    """
    if new_state_only:
      X = np.vstack([self.X[j][1:] for j in range(self.nPatients)])
      Sp1 = np.vstack([self.S[j][1:] for j in range(self.nPatients)])
    else:
      X = np.vstack([self.X[j][1:-1, :] for j in range(self.nPatients)])
      Sp1 = np.vstack([self.X[j][2:, :] for j in range(self.nPatients)])
    return X, Sp1

  def get_current_SX(self):
    """Current state features for all patients: [1, current_state, last_state, last_action].

    NOTE(review): current_state and last_state are never populated by this
    class (they stay [None] * nPatients), so this assumes a caller fills them
    with arrays compatible with np.hstack -- confirm.
    """
    current_sx = np.hstack((np.hstack((np.ones((self.nPatients, 1)), np.hstack((self.current_state, self.last_state)))),
                            np.array(self.last_action).reshape(self.nPatients, 1)))
    return current_sx

#  def get_Xprime_X_inv(self):
#    X, _ = self.get_state_transitions_as_x_y_pair()
#    x_new = self.get_state_history_as_array()[-1,]
#    if self.Xprime_X_inv is None:  # Can't do fast update
#      Xprime_X_new = np.dot(X.T, X)
#      self.Xprime_X_inv = np.linalg.inv(Xprime_X_new + 0.01*np.eye(X.shape[1]))
#    else:
#      # Compute Xprime_X_inv
#      self.Xprime_X_inv= la.sherman_woodbury(self.Xprime_X_inv, x_new, x_new)
#

[1 2]
