In [None]:
import os
import sys
import itertools
# Put the parent of the current working directory on the import path
# (only once) so that modules living one level up can be imported.
module_path = os.path.abspath(os.pardir)
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)

In [None]:
%load_ext autoreload
%autoreload 2
from itertools import product

import matplotlib.pyplot as plt
import numpy as np

# import whynot as wn
import whynot.gym as gym

from scripts import utils

%matplotlib inline
import whynot.simulators.covid19 as covid19
import whynot.simulators.covid19.environments.starter_env

## Starter env

This environment helps us demonstrate our state space, action space, etc.

In [None]:
# Create the starter environment registered as 'COVID19-v0' and seed it with 1.
env = gym.make('COVID19-v0')
env.seed(1);

In [None]:
# Discrete action index -> social-distancing value.
# Action 0 maps to 1.0 and action 5 maps to 0.0.
action_to_social_distancing_map = dict(
    enumerate((1.0, 0.75, 0.5, 0.25, 0.10, 0.0))
)

# Inverse lookup: social-distancing value -> discrete action index.
social_distancing_to_action_map = {
    level: action
    for action, level in action_to_social_distancing_map.items()
}

In [None]:
# Display the inverse mapping (social-distancing value -> action index).
social_distancing_to_action_map

In [None]:
# Passed as `n_iter` to every utils.run_training_loop call in this notebook.
n_iter = 100

In [None]:
class NoTreatmentPolicy:
    """Policy that ignores the observation and always selects action 5."""

    def sample_action(self, obs):
        """Return the constant action index 5, regardless of ``obs``."""
        action = 5
        return action
    
class SocialDistancingPolicy:
    """Policy that always plays the action mapped to one fixed social-distancing value."""

    def __init__(self, social_distance_val):
        # Looked up in ``social_distancing_to_action_map`` on every call to
        # ``sample_action``, so it must be one of that mapping's keys.
        self.social_distance_val = social_distance_val

    def sample_action(self, obs):
        """Return the action index for the stored value; ``obs`` is ignored."""
        level = self.social_distance_val
        return social_distancing_to_action_map[level]

In [None]:
# Baseline fixed-action policies, keyed by a human-readable name.
# plot_sample_trajectory uses each key as the legend label and in the
# printed "Total reward for ..." line.
policies = {
    "No Treatment": NoTreatmentPolicy(),
    "Social Distance 10%": SocialDistancingPolicy(0.1),
    "Social Distance 25%": SocialDistancingPolicy(0.25),
    "Social Distance 50%": SocialDistancingPolicy(0.5),
    "Social Distance 100%": SocialDistancingPolicy(1.0),
}

In [None]:
def augment_policies(default_policies, policy, policy_name='learned_policy'):
    """Return a copy of ``default_policies`` with ``policy`` stored under ``policy_name``.

    ``default_policies`` itself is left untouched. If ``policy_name`` is
    already present, the copy holds the new ``policy`` for that name.
    """
    combined = dict(default_policies)
    combined.update({policy_name: policy})
    return combined

In [None]:
def sample_trajectory(env, policy, max_episode_length):
    """Sample a single trajectory, acting according to the specified policy.

    Parameters
    ----------
    env
        Environment exposing ``reset()``, which returns the first observation,
        and ``step(action)``, which returns ``(observation, reward, done, info)``.
    policy
        Object exposing ``sample_action(observation)``.
    max_episode_length : int
        Upper bound on the number of environment steps in the rollout. At
        least one step is always taken, even if this is zero or negative.

    Returns
    -------
    dict
        ``float32`` arrays with one entry per step taken, under the keys
        ``"observation"``, ``"reward"``, ``"action"``, ``"next_observation"``
        and ``"terminal"``. ``"terminal"`` is 1 on the final step, whether the
        rollout ended because the environment reported ``done`` or because
        the step limit was reached, and 0 elsewhere.
    """
    # Initialize env for the beginning of a new rollout.
    ob = env.reset()
    obs, acs, rewards, next_obs, terminals = [], [], [], [], []
    steps = 0
    while True:
        # Use the most recent observation to decide what to do.
        obs.append(ob)
        ac = policy.sample_action(ob)
        acs.append(ac)

        # Take that action and record the result.
        ob, rew, done, _ = env.step(ac)
        steps += 1
        next_obs.append(ob)
        rewards.append(rew)

        # The rollout ends when the environment says so or when the step
        # budget is used up. `>=` (not `>`) keeps the rollout to at most
        # `max_episode_length` steps instead of one more than that.
        rollout_done = 1 if (done or steps >= max_episode_length) else 0
        terminals.append(rollout_done)
        if rollout_done:
            break

    return {"observation": np.array(obs, dtype=np.float32),
            "reward": np.array(rewards, dtype=np.float32),
            "action": np.array(acs, dtype=np.float32),
            "next_observation": np.array(next_obs, dtype=np.float32),
            "terminal": np.array(terminals, dtype=np.float32)}

In [None]:
def _grow_ylim(ax, values):
    """Widen the y-limits of ``ax``, if needed, so that ``values`` are fully visible."""
    ymin, ymax = ax.get_ylim()
    ax.set_ylim(np.minimum(ymin, values.min()), np.maximum(ymax, values.max()))


def plot_sample_trajectory(env, policies, max_episode_length=400):
    """Plot sample trajectories from policies.

    One rollout is sampled per policy with ``sample_trajectory``. Each state
    variable gets its own panel, followed by one panel for the action
    (plotted as ``1 - action_to_social_distancing_map[action]``); the last
    panel shows the per-step reward. The total reward of each rollout is
    printed as well.

    Parameters
    ----------
    env
        Environment passed through to ``sample_trajectory``.
    policies : dict
        Maps a display name to a policy object. The name is used as the
        legend label and in the printed reward total.
    max_episode_length : int, optional
        Step limit passed to ``sample_trajectory`` (default 400).
    """
    obs_dim_names = covid19.State.variable_names()
    n_states = len(obs_dim_names)

    # Panels needed: one per state variable, one for the action and one for
    # the reward. Keep the 4x3 grid whenever it is large enough; add rows
    # only if it would overflow, so the action and reward panels never collide.
    n_cols = 3
    n_rows = max(4, -(-(n_states + 2) // n_cols))
    fig, axes = plt.subplots(
        n_rows, n_cols, sharex=True, figsize=[30, 15 * n_rows / 4]
    )
    axes = axes.flatten()
    state_axes = axes[:n_states]
    action_ax = axes[n_states]
    reward_ax = axes[-1]

    # Axis labels and tick formatting do not depend on the policy, so they
    # are set once here instead of on every policy iteration.
    for ax, dim_name in zip(state_axes, obs_dim_names):
        ax.set_ylabel(dim_name)
    action_ax.set_ylabel("beta_scale")
    reward_ax.set_ylabel("reward")
    reward_ax.ticklabel_format(scilimits=(-2, 2))

    for name, policy in policies.items():
        trajectory = sample_trajectory(env, policy, max_episode_length)

        # Plot state evolution
        obs = trajectory["observation"]
        for i, ax in enumerate(state_axes):
            y = obs[:, i]
            ax.plot(y, label=name)
            _grow_ylim(ax, y)

        # Plot actions
        action_vals = [
            1 - action_to_social_distancing_map[action]
            for action in trajectory["action"]
        ]
        action_ax.plot(action_vals, label=name)

        # Plot reward
        reward = trajectory["reward"]
        reward_ax.plot(reward, label=name)
        _grow_ylim(reward_ax, reward)
        print(f"Total reward for {name}: {np.sum(reward):.2f}")

    for ax in axes:
        # Only panels that actually contain labelled lines get a legend;
        # calling legend() on an empty panel just triggers a warning.
        handles, _ = ax.get_legend_handles_labels()
        if handles:
            ax.legend()
        ax.set_xlabel("Day")
    fig.subplots_adjust(hspace=0.4)
    plt.show()

In [None]:
# Roll out every baseline policy in the starter environment and plot the results.
plot_sample_trajectory(env, policies)

## Let's look at a second environment - Prosperous place

Each death carries a cost, but social distancing carries a much higher cost here because of how much economic output this place generates. Social distancing is therefore very expensive in this location, and the policy learns that.

In [None]:
import whynot.simulators.covid19.environments.rich_place_env

In [None]:
# Create the environment registered as 'COVID19-RICH-v0' and seed it with 1.
rich_env = gym.make('COVID19-RICH-v0')
rich_env.seed(1);

In [None]:
# Train a policy on the rich-place environment.
rich_learned_policy = utils.run_training_loop(
    env=rich_env,
    n_iter=n_iter,
    max_episode_length=150,
    batch_size=1000,
    learning_rate=1e-3,
)

In [None]:
# Compare the baseline policies with the learned policy in the rich-place environment.
plot_sample_trajectory(rich_env, augment_policies(policies, rich_learned_policy))

## Now we look at a place without as much economic output

This place has the same human cost (the cost of an individual dying). However, it does not produce nearly as much economic output, so it can afford social distancing practices, and we see that the model learns that.

In [None]:
import whynot.simulators.covid19.environments.poor_place_env

In [None]:
# Create the environment registered as 'COVID19-POOR-v0' and seed it with 1.
poor_env = gym.make('COVID19-POOR-v0')
poor_env.seed(1);

In [None]:
# Train a policy on the poor-place environment.
poor_learned_policy = utils.run_training_loop(
    env=poor_env,
    n_iter=n_iter,
    max_episode_length=150,
    batch_size=1000,
    learning_rate=1e-3,
)

In [None]:
# Compare the baseline policies with the learned policy in the poor-place environment.
plot_sample_trajectory(poor_env, augment_policies(policies, poor_learned_policy))

## Confounded environment

All of the environments have an unobserved confounder: the location, which dictates the cost of social distancing and the rewards that you get. However, the confounder level is fixed in the earlier examples.

In this example, the confounder level varies.

In [None]:
import whynot.simulators.covid19.environments.confounded_env

In [None]:
# Create the environment registered as 'COVID19-CONFOUNDED-v0' and seed it with 1.
confounded_env = gym.make('COVID19-CONFOUNDED-v0')
confounded_env.seed(1);

In [None]:
# Train a policy on the confounded environment.
confounded_learned_policy = utils.run_training_loop(
    env=confounded_env,
    n_iter=n_iter,
    max_episode_length=150,
    batch_size=1000,
    learning_rate=1e-3,
)

In [None]:
# Compare the baseline policies with the learned policy in the confounded environment.
plot_sample_trajectory(confounded_env, augment_policies(policies, confounded_learned_policy))

### Possible extension:

One possible extension to this idea is a stochastic action space. Currently our action space is deterministic; instead, we could have an action space where we sample an action with some probability. This would allow us to have a location that confounds both the action and the reward.

Let me give you an example:

According to our earlier definition, India does not have as high an economic output as the US, so it can afford social distancing, and so its number of deaths is lower. However, social distancing is just a policy that is enacted; in a place like India, people may not follow such policies, and hence social distancing can never reach 100% as it does in our models. So instead of having a fixed action space, we could sample an action from a stochastic policy set: for example, apply 100% social distancing 90% of the time and, the remaining 10% of the time, select some other policy, such as 10% social distancing.

An RL agent will never be able to understand what is happening. Because the place is a confounder, we can have it determine both how much social distancing is possible and what the cost of social distancing is.

References:

1. https://wwwnc.cdc.gov/eid/article/26/6/20-0233_article
2. https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6332839/