# Case study 6 - DQN

Changes compared to OG work by Guillaume :
- `policy` is now an epsilon-greedy policy when taking rollouts (still deterministic when evaluating pairwise comparisons with target weights)
- we keep 2 sets of classifiers (original + target). Target weights get smooth updates in the form of exponentially moving average
- we keep past transitions in a circular replay buffer

These last two changes help with off policiness/exploration and stability.

## Setup

In [1]:
import numpy as np

import scipy.integrate as integrate

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim as optim

from collections import deque
from random import sample
from tqdm.notebook import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Hyperparameters

batchSize = 250
numberOfActions = 4
numberOfEpochs = 10

epsilon        = 0.1      #for the epsilon-greedy policy
buffer_size    = 2000
num_rollouts   = 250
num_iterations = 60
beta           = 0.95     #soft updates

# Some of the following code is based on a PyTorch tutorial in the official PyTorch website:
# Below is the definition of the neural networks used for the pair-wize classification of the actions
class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()

        # an affine operation: y = Wx + b
        self.fc1 = nn.Linear(2, 16).to(device)
        self.fc2 = nn.Linear(16, 16).to(device)
        self.fc3 = nn.Linear(16, 16).to(device)
        self.fc4 = nn.Linear(16, 2).to(device)

    def forward(self, x):

        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        logits = self.fc4(x)
        return logits

# We use the cross entropy loss function
lossFunction = nn.CrossEntropyLoss()

# Initialization of the classifiers for each pairs of actions:
classifiers = {}
target_classifiers = {}

for actionIndex1 in range(numberOfActions):
  for actionIndex2 in range(actionIndex1 + 1, numberOfActions):

    net = Net().to(device)
    input = torch.randn((1, 2), device=device)
    out = net(input)

    net.zero_grad()
    out.backward(torch.randn((1, 2), device=device))

    classifiers[(actionIndex1, actionIndex2)] = net
    target_classifiers[(actionIndex1, actionIndex2)] = net   #We initialize params and target params at the same values.

rng = np.random.default_rng()

# Each action is represented by an index in a dictionary. Such that each action is accessed via this index
actionsDictionary = {}
actionsDictionary.update({0 : 0.1})
actionsDictionary.update({1 : 0.4})
actionsDictionary.update({2 : 0.7})
actionsDictionary.update({3 : 1.0})


# Below are the constants used in the simulations of the cancer treatement plan
a1 = 0.15
a2 = 0.1
b1 = 1.2
b2 = 1.2
c0 = -4
c1 = 1
c2 = 1
d1 = 0.5
d2 = 0.5


# Below is the definition of delta Y as described in the paper. It takes among its arguments the index of an action that corresponds to the amount of chemical given to the patient
def deltaY(XState, YState, initialXState, actionIndex):
  if(YState <= 0):
    return 0

  return a1 * np.maximum(XState, initialXState) - b1 * (actionsDictionary[actionIndex] - d1)

# Below is the definition of delta X as described in the paper. It takes among its arguments the index of an action that corresponds to the amount of chemical given to the patient
def deltaX(XState, YState, initialYState, actionIndex):

  return a2 * np.maximum(YState, initialYState) + b2 * (actionsDictionary[actionIndex] - d2)

# Below is just adding the deltas to the states, for 1 time step
def simulate1Step(XState, YState, initialXState, initialYState, actionIndex):

  return (XState + deltaX(XState, YState, initialYState, actionIndex),
          YState + deltaY(XState, YState, initialXState, actionIndex))

# Below is the function that returns (0) if the patient has died during the present point in time in the simulation. If the patient lives, it returns 1
def checkLifeStatus(previousXState, previousYState, presentXState, presentYState):

  def XAsAFunctionOfTime(time):
    return previousXState + time * (presentXState - previousXState)

  def YAsAFunctionOfTime(time):
    return previousYState + time * (presentYState - previousYState)

  def lambdaAsAFunctionOfTime(time):
    return np.exp(c0 + c1 * YAsAFunctionOfTime(time) + c2 * XAsAFunctionOfTime(time))

  lambdaIntegral = integrate.quad(lambdaAsAFunctionOfTime, 0, 1)[0]

  probabilityOfDeath = 1 - np.exp(-lambdaIntegral)

  if(rng.random() < probabilityOfDeath):
    return 0
  else:
    return 1



# A return value of (-1) means that action1 is preferable to action2, a return value of (1) means that action2 is preferable to action1
# We follow here the treatement plan for (1) patient under different starting actions (actionIndex1) and (actionIndex2)
# It does not make sense to take (XState) and (YState) different from (initialXState) and (initialYState) repectively; I just though at first that
# I needed code for simulations which would start somewhere in the middle of the treatement plan (somewhere else than at the beginning); but I didn't change the code to make it cleaner yet.
def evaluatePreference(XState, YState,
                       initialXState, initialYState,
                       actionIndex1, actionIndex2,
                       timeIndex,
                       policy, classifiers):

  maximumToxicityWithActionIndex1 = XState
  maximumToxicityWithActionIndex2 = XState

  #print("here0")

  previousXStateWithActionIndex1 = XState
  previousXStateWithActionIndex2 = XState

  previousYStateWithActionIndex1 = YState
  previousYStateWithActionIndex2 = YState

  # We simulate (1) time step here:

  (XStateWithActionIndex1, YStateWithActionIndex1) = simulate1Step(XState, YState, initialXState, initialYState, actionIndex1)

  (XStateWithActionIndex2, YStateWithActionIndex2) = simulate1Step(XState, YState, initialXState, initialYState, actionIndex2)


  # We check the life status here:
  lifeStatusWithActionIndex1 = checkLifeStatus(previousXStateWithActionIndex1, previousYStateWithActionIndex1, XStateWithActionIndex1, YStateWithActionIndex1)
  lifeStatusWithActionIndex2 = checkLifeStatus(previousXStateWithActionIndex2, previousYStateWithActionIndex2, XStateWithActionIndex2, YStateWithActionIndex2)

  # If the patient has died for (1) of the actions, then the following logic gives the pareto dominance relationship:
  if(lifeStatusWithActionIndex2 > lifeStatusWithActionIndex1):
    return 1
  elif(lifeStatusWithActionIndex2 < lifeStatusWithActionIndex1):
    return -1
  elif((lifeStatusWithActionIndex2 == 0) and (lifeStatusWithActionIndex1 == 0)):
    return 0

  # We store the maximum toxicity here, that tells us about pareto dominance:
  maximumToxicityWithActionIndex1 = np.maximum(maximumToxicityWithActionIndex1, XStateWithActionIndex1)
  maximumToxicityWithActionIndex2 = np.maximum(maximumToxicityWithActionIndex2, XStateWithActionIndex2)




  timeIndex = timeIndex + 1

  # The following are the remaining time steps in the simulation; they follow the same logic are previously described
  while(timeIndex < 6):

    # The remaining action indices are chosen according to the policy in all remaining simulation steps:
    actionIndex1 = policy(classifiers, XStateWithActionIndex1, YStateWithActionIndex1)
    actionIndex2 = policy(classifiers, XStateWithActionIndex2, YStateWithActionIndex2)

    # We always store the state X and Y for each action indices. (ActionIndex1) and (ActionIndex2) correspond to the initial action taken at the beginning that have to be compared
    previousXStateWithActionIndex1 = XStateWithActionIndex1
    previousXStateWithActionIndex2 = XStateWithActionIndex2

    previousYStateWithActionIndex1 = YStateWithActionIndex1
    previousYStateWithActionIndex2 = YStateWithActionIndex2

    (XStateWithActionIndex1, YStateWithActionIndex1) = simulate1Step(XStateWithActionIndex1, YStateWithActionIndex1, initialXState, initialYState, actionIndex1)
    (XStateWithActionIndex2, YStateWithActionIndex2) = simulate1Step(XStateWithActionIndex2, YStateWithActionIndex2, initialXState, initialYState, actionIndex2)

    lifeStatusWithActionIndex1 = checkLifeStatus(previousXStateWithActionIndex1, previousYStateWithActionIndex1, XStateWithActionIndex1, YStateWithActionIndex1)
    lifeStatusWithActionIndex2 = checkLifeStatus(previousXStateWithActionIndex2, previousYStateWithActionIndex2, XStateWithActionIndex2, YStateWithActionIndex2)

    if(lifeStatusWithActionIndex2 > lifeStatusWithActionIndex1):
      return 1
    elif(lifeStatusWithActionIndex2 < lifeStatusWithActionIndex1):
      return -1
    elif((lifeStatusWithActionIndex2 == 0) and (lifeStatusWithActionIndex1 == 0)):
      return 0

    maximumToxicityWithActionIndex1 = np.maximum(maximumToxicityWithActionIndex1, XStateWithActionIndex1)
    maximumToxicityWithActionIndex2 = np.maximum(maximumToxicityWithActionIndex2, XStateWithActionIndex2)

    timeIndex = timeIndex + 1

  tumorSizeAtTheEndWithActionIndex1 = YStateWithActionIndex1
  tumorSizeAtTheEndWithActionIndex2 = YStateWithActionIndex2

  # The following logic describes the pareto dominance relationship when the patient has survived under the 2 choices of initial actions:
  if((tumorSizeAtTheEndWithActionIndex2 < tumorSizeAtTheEndWithActionIndex1) and (maximumToxicityWithActionIndex2 < maximumToxicityWithActionIndex1)):
    return 1
  elif((tumorSizeAtTheEndWithActionIndex1 < tumorSizeAtTheEndWithActionIndex2) and (maximumToxicityWithActionIndex1 < maximumToxicityWithActionIndex2)):
    return -1
  else:
    return 0


# A return value of (-1) means that policy1 is preferable to policy2, a return value of (1) means that policy2 is preferable to policy1
# The code below is almost identical to the code for function (evaluatePreference) above, except that here, at each time steps, actions are taken
# from each of the (2) different policies (policy1) and (policy2).
def evaluatePreferenceBetween2Policies(XState, YState,
                       initialXState, initialYState,
                       timeIndex,
                       policy1, policy2, classifiers):

  maximumToxicityWithActionIndex1 = XState
  maximumToxicityWithActionIndex2 = XState

  #print("here0")

  previousXStateWithActionIndex1 = XState
  previousXStateWithActionIndex2 = XState

  previousYStateWithActionIndex1 = YState
  previousYStateWithActionIndex2 = YState

  actionIndex1 = policy1(classifiers, XState, YState)
  actionIndex2 = policy2(classifiers, XState, YState)

  (XStateWithActionIndex1, YStateWithActionIndex1) = simulate1Step(XState, YState, initialXState, initialYState, actionIndex1)

  #print("here 0.1")
  (XStateWithActionIndex2, YStateWithActionIndex2) = simulate1Step(XState, YState, initialXState, initialYState, actionIndex2)


  #print("here1")

  lifeStatusWithActionIndex1 = checkLifeStatus(previousXStateWithActionIndex1, previousYStateWithActionIndex1, XStateWithActionIndex1, YStateWithActionIndex1)
  lifeStatusWithActionIndex2 = checkLifeStatus(previousXStateWithActionIndex2, previousYStateWithActionIndex2, XStateWithActionIndex2, YStateWithActionIndex2)

  if(lifeStatusWithActionIndex2 > lifeStatusWithActionIndex1):
    return 1
  elif(lifeStatusWithActionIndex2 < lifeStatusWithActionIndex1):
    return -1
  elif((lifeStatusWithActionIndex2 == 0) and (lifeStatusWithActionIndex1 == 0)):
    return 0

  maximumToxicityWithActionIndex1 = np.maximum(maximumToxicityWithActionIndex1, XStateWithActionIndex1)
  maximumToxicityWithActionIndex2 = np.maximum(maximumToxicityWithActionIndex2, XStateWithActionIndex2)




  timeIndex = timeIndex + 1

  while(timeIndex < 6):

    actionIndex1 = policy1(classifiers, XStateWithActionIndex1, YStateWithActionIndex1)
    actionIndex2 = policy2(classifiers, XStateWithActionIndex2, YStateWithActionIndex2)

    previousXStateWithActionIndex1 = XStateWithActionIndex1
    previousXStateWithActionIndex2 = XStateWithActionIndex2

    previousYStateWithActionIndex1 = YStateWithActionIndex1
    previousYStateWithActionIndex2 = YStateWithActionIndex2

    (XStateWithActionIndex1, YStateWithActionIndex1) = simulate1Step(XStateWithActionIndex1, YStateWithActionIndex1, initialXState, initialYState, actionIndex1)
    (XStateWithActionIndex2, YStateWithActionIndex2) = simulate1Step(XStateWithActionIndex2, YStateWithActionIndex2, initialXState, initialYState, actionIndex2)

    lifeStatusWithActionIndex1 = checkLifeStatus(previousXStateWithActionIndex1, previousYStateWithActionIndex1, XStateWithActionIndex1, YStateWithActionIndex1)
    lifeStatusWithActionIndex2 = checkLifeStatus(previousXStateWithActionIndex2, previousYStateWithActionIndex2, XStateWithActionIndex2, YStateWithActionIndex2)

    if(lifeStatusWithActionIndex2 > lifeStatusWithActionIndex1):
      return 1
    elif(lifeStatusWithActionIndex2 < lifeStatusWithActionIndex1):
      return -1
    elif((lifeStatusWithActionIndex2 == 0) and (lifeStatusWithActionIndex1 == 0)):
      return 0

    maximumToxicityWithActionIndex1 = np.maximum(maximumToxicityWithActionIndex1, XStateWithActionIndex1)
    maximumToxicityWithActionIndex2 = np.maximum(maximumToxicityWithActionIndex2, XStateWithActionIndex2)

    timeIndex = timeIndex + 1

  tumorSizeAtTheEndWithActionIndex1 = YStateWithActionIndex1
  tumorSizeAtTheEndWithActionIndex2 = YStateWithActionIndex2

  if((tumorSizeAtTheEndWithActionIndex2 < tumorSizeAtTheEndWithActionIndex1) and (maximumToxicityWithActionIndex2 < maximumToxicityWithActionIndex1)):
    return 1
  elif((tumorSizeAtTheEndWithActionIndex1 < tumorSizeAtTheEndWithActionIndex2) and (maximumToxicityWithActionIndex1 < maximumToxicityWithActionIndex2)):
    return -1
  else:
    return 0


def randomPolicy(classifiers, XState, YState):

  return np.random.randint(numberOfActions)


# (possibly) epsilon-greedy policy with the pair-wise classifiers
def policy(classifiers, XState, YState, eps_greedy = False):

  # We first choose a random index:
  randomInitialActionIndex = np.random.randint(numberOfActions)

  if np.random.rand() < epsilon and eps_greedy:
      return randomInitialActionIndex


  bestActionIndex = randomInitialActionIndex

  # (actionIndicesToCheck) gives the series of actions to check successively to find the best action
  actionIndicesToCheck = [0, 1, 2, 3]
  actionIndicesToCheck.remove(bestActionIndex)

  # (classifiersAppliedOnTheStateChoices) is a dictionary containing the classifiers applied on the input state (XState, YState)
  # The results are going to be used to classify the actions
  classifiersAppliedOnTheStateChoices = {}

  for actionIndex1 in range(numberOfActions):
    for actionIndex2 in range(actionIndex1 + 1, numberOfActions):

      classifier = classifiers[(actionIndex1, actionIndex2)]

      # The best action is given by taking the argmax of the classifier applied on the state. A value of (0) means that, for the classifier at hand,
      # (actionIndex1) is pareto dominant to (actionIndex2). A value of (1) give the opposite dominance
      classifiersAppliedOnTheStateChoices[(actionIndex1, actionIndex2)] = torch.argmax(
          classifier(torch.tensor([[XState, YState]], device=device, dtype=torch.float32))
      )

  for actionIndexToCheck in actionIndicesToCheck:

    # (bestActionIndex) is checked against all possibilities of (actionIndexToCheck) in (actionIndicesToCheck)
    # The key (bestActionIndex, actionIndexToCheck) might not correspond to a classifier in the classifiers dictionary, because switching the 2 indices would have just given the reversed classifier.
    # Since we might not have learned the classifier corresponding to (bestActionIndex, actionIndexToCheck), because of symmetry, we first check we have it in the outer (if) statement
    if (bestActionIndex, actionIndexToCheck) in classifiersAppliedOnTheStateChoices:

      # The following means that actionIndexToCheck is pareto dominant to the previous bestActionIndex. So we store it, and continue the outer (for) loop for the other actionIndicesToCheck.
      # We do this until we have verified all actions and found the dominant one
        if(classifiersAppliedOnTheStateChoices[(bestActionIndex, actionIndexToCheck)] == 1):

          bestActionIndex = actionIndexToCheck
    else:
      if(classifiersAppliedOnTheStateChoices[(actionIndexToCheck, bestActionIndex)] == 0):

          bestActionIndex = actionIndexToCheck

  return bestActionIndex


ReplayBuffer = deque([], maxlen = buffer_size)
#when trying to append elements in ReplayBuffer if len(RB) == maxlen, past samples will be thrown out from the other side

## DQN

In [None]:
for _ in tqdm(range(num_iterations)):

  # Sample collection by rollouts
  rolloutIndex = 0
  while(rolloutIndex < num_rollouts):
    # The initial states of the patient are taken to be random, as stated in the paper
    initialXState = rng.random() * 2
    initialYState = rng.random() * 2

    actionIndex1 = policy(classifiers, initialXState, initialYState, eps_greedy=True)   #Choose best action w.r.t. epsilon-greedy policy
    actionIndex2 = policy(classifiers, initialXState, initialYState, eps_greedy=True)

    if actionIndex1 == actionIndex2:
      continue #quick sanity check to keep comparing different actions

    # Ordering actions to be consistent with the rest of the framework
    actionIndex1, actionIndex2 = min(actionIndex1, actionIndex2), max(actionIndex1, actionIndex2)

    # We specify (2) actions (actionIndex1) and (actionIndex2) that are to be compared by the present state of the present pair-wise target classifier
    preferenceViaParetoDominance = evaluatePreference(initialXState, initialYState,
                    initialXState, initialYState,
                    actionIndex1, actionIndex2,
                    0,
                    policy, target_classifiers)

    # We only store cases that have a definite pareto dominance for training. A value of (0) returned by (evaluatePreference) means that none of (actionIndex1) or (actionIndex2) is preferable over the other for the present state
    if(preferenceViaParetoDominance != 0):

      # We naturally use one-hot encoding for training via the cross-entropy loss
      onehotEncoding = torch.tensor([0,0], device=device, dtype=torch.float32)
      if(preferenceViaParetoDominance == 1):
        onehotEncoding[1] = 1
      else:
        onehotEncoding[0] = 1

      ReplayBuffer.append(
          ((actionIndex1, actionIndex2), torch.tensor([initialXState, initialYState], device=device), onehotEncoding)
      )

      rolloutIndex = rolloutIndex + 1

  # Training classifiers
  Minibatch = sample(ReplayBuffer, batchSize)

  for actionIndex1 in range(numberOfActions):
    for actionIndex2 in range(actionIndex1 + 1, numberOfActions):

      relevant_samples = list(filter(lambda tup: tup[0] == (actionIndex1, actionIndex2), Minibatch))

      if not relevant_samples:
          continue

      train_dataloader = DataLoader(relevant_samples, batch_size = len(relevant_samples), shuffle = True)

      classifierToTrain = classifiers[(actionIndex1, actionIndex2)]

      optimizer = optim.SGD(classifierToTrain.parameters(), lr=0.01)

      # Training for (1) specific classifier (classifierToTrain):
      for _ in range(numberOfEpochs):
        for (_, inputState, preference) in train_dataloader:

            optimizer.zero_grad()   # zero the gradient buffers
            output = classifierToTrain(inputState)

            loss = lossFunction(output, preference)
            loss.backward()
            optimizer.step()

      #soft target update
      target_state_dict = target_classifiers[(actionIndex1, actionIndex2)].state_dict()
      policy_state_dict = classifiers[(actionIndex1, actionIndex2)].state_dict()

      for key in policy_state_dict:
          target_state_dict[key] = beta*target_state_dict[key] + (1-beta)*policy_state_dict[key]
          target_classifiers[(actionIndex1, actionIndex2)].load_state_dict(target_state_dict)



  0%|          | 0/60 [00:00<?, ?it/s]

In [None]:
numberOfRolloutsToTestIfLearnedPolicyIsBetterThanRandomPolicy = 500

finalTumorSizesLearnedPolicy = []
finalTumorSizesRandomPolicy = []

maximumToxicityWithLearnedPolicy = []
maximumToxicityWithRandomPolicy = []

for rolloutIndex in range(numberOfRolloutsToTestIfLearnedPolicyIsBetterThanRandomPolicy):

    initialXState = rng.random() * 2
    initialYState = rng.random() * 2

    (xStateLearnedPolicy, yStateLearnedPolicy) = (initialXState, initialYState)
    (xStateRandomPolicy, yStateRandomPolicy) = (initialXState, initialYState)

    maximumToxicityWithLearnedPolicy.append(0)
    maximumToxicityWithRandomPolicy.append(0)

    for timeStepIndex in range(6):

      actionIndexLearnedPolicy = policy(classifiers, xStateLearnedPolicy, yStateLearnedPolicy)
      actionIndexRandomPolicy = np.random.randint(4)

      # print("(initialXState, initialYState): ", (initialXState, initialYState))
      # print("(xStateLearnedPolicy, yStateLearnedPolicy): ", (xStateLearnedPolicy, yStateLearnedPolicy))
      # print("(xStateRandomPolicy, yStateRandomPolicy): ", (xStateRandomPolicy, yStateRandomPolicy))

      (xStateLearnedPolicy, yStateLearnedPolicy) = simulate1Step(xStateLearnedPolicy, yStateLearnedPolicy, initialXState, initialYState, actionIndexLearnedPolicy)
      (xStateRandomPolicy, yStateRandomPolicy) = simulate1Step(xStateRandomPolicy, yStateRandomPolicy, initialXState, initialYState, actionIndexRandomPolicy)

      maximumToxicityWithLearnedPolicy[rolloutIndex] = np.maximum(xStateLearnedPolicy, maximumToxicityWithLearnedPolicy[rolloutIndex])
      maximumToxicityWithRandomPolicy[rolloutIndex] = np.maximum(xStateRandomPolicy, maximumToxicityWithRandomPolicy[rolloutIndex])

    finalTumorSizesLearnedPolicy.append(yStateLearnedPolicy)
    finalTumorSizesRandomPolicy.append(yStateRandomPolicy)

print("Average Final Tumor Sizes for Learned Policy: ", np.mean(np.array(finalTumorSizesLearnedPolicy)))
print("Average Final Tumor Sizes for Random Policy: ", np.mean(np.array(finalTumorSizesRandomPolicy)))
print("Average Maximum Toxicity for Learned Policy: ", np.mean(np.array(maximumToxicityWithLearnedPolicy)))
print("Average Maximum Toxicity for Random Policy: ", np.mean(np.array(maximumToxicityWithRandomPolicy)))

Average Final Tumor Sizes for Learned Policy:  4.783999098226286
Average Final Tumor Sizes for Random Policy:  1.8526111010850634
Average Maximum Toxicity for Learned Policy:  0.6628490298834889
Average Maximum Toxicity for Random Policy:  2.4249693134132073
