In [1]:
import gym
import numpy as np
import torch
from typing import Iterable, Mapping, Optional, Sequence, Set, Tuple, Union
from reagent.ope.estimators.sequential_estimators import (
    Mdp,
    Model,
    RLPolicy,
    State,
    StateReward,
    Transition,
    ActionSpace,
    ActionDistribution,
    Action,
    RandomRLPolicy,
    RLEstimatorInput,
    IPSEstimator,
    NeuralDualDICE,
)
from reagent.models.dqn import FullyConnectedDQN
from reagent.ope.utils import Clamper, RunningAverage
from gym import wrappers
import matplotlib
import matplotlib.pyplot as plt

# Experiment configuration shared by the cells below.
NUM_EPISODES, MAX_HORIZON = 200, 250  # rollouts per run / step cap per rollout
GAMMA = 0.99  # discount applied to rewards and passed to the estimators
ALPHA = 0.66  # controls the behavior-policy mixing weights further down

# Use CUDA when it is available; otherwise fall back to None.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = None
print(f"Device - {device}")

# Load the TorchScript export and keep only its `.dqn_with_preprocessor.model`
# submodule.
# NOTE(review): absolute, environment-specific path — the notebook cannot be
# re-run elsewhere without changing it; consider a configurable location.
model = torch.jit.load("/mnt/vol/gfsfblearner-nebraska/flow/data/2020-07-24/18eeebdf-b0ed-4f93-b079-95f7c58656ff/207187922_207187922_0.pt").dqn_with_preprocessor.model
model.to(device)

INFO:reagent.core.dataclasses:USE_VANILLA_DATACLASS: False
INFO:reagent.core.dataclasses:ARBITRARY_TYPES_ALLOWED: True
INFO:reagent.core.registry_meta:Adding REGISTRY to type LearningRateSchedulerConfig
INFO:reagent.core.registry_meta:Not Registering LearningRateSchedulerConfig to LearningRateSchedulerConfig. Abstract method [] are not implemented.
INFO:reagent.core.registry_meta:Registering LambdaLR to LearningRateSchedulerConfig
INFO:reagent.core.registry_meta:Registering MultiplicativeLR to LearningRateSchedulerConfig
INFO:reagent.core.registry_meta:Registering StepLR to LearningRateSchedulerConfig
INFO:reagent.core.registry_meta:Registering MultiStepLR to LearningRateSchedulerConfig
INFO:reagent.core.registry_meta:Registering ExponentialLR to LearningRateSchedulerConfig
INFO:reagent.core.registry_meta:Registering CosineAnnealingLR to LearningRateSchedulerConfig
INFO:reagent.core.registry_meta:Registering CyclicLR to LearningRateSchedulerConfig
INFO:reagent.core.registry_meta:Regist

# Define the policy classes

In [2]:
class ComboPolicy(RLPolicy):
    """Policy formed from a weighted combination of several sub-policies.

    For a given state, each sub-policy's action-distribution values are
    scaled by the matching weight, the scaled tensors are summed, and the sum
    is passed through a softmax before being wrapped by the action space.

    NOTE(review): if the sub-policies already return normalized
    probabilities, the final softmax means the result is not the plain linear
    mixture of them — confirm this is the intended combination rule.
    """

    def __init__(self, action_space: ActionSpace, weights: Sequence[float], policies: Sequence[RLPolicy]):
        """
        Args:
            action_space: action space used to build the returned distribution
            weights: one scalar weight per policy, in the same order as `policies`
            policies: the policies to combine

        Raises:
            AssertionError: if `weights` and `policies` differ in length
        """
        assert len(weights) == len(policies), "weights and policies must have the same length"
        self._weights = weights
        self._policies = policies
        self._action_space = action_space
        # Pass `dim` explicitly: relying on Softmax's implicit dimension is
        # deprecated and raises a UserWarning on every call. For a 1-D value
        # vector dim=-1 selects the same axis the implicit rule did.
        # Assumes the action axis is the last one — TODO confirm.
        self._softmax = torch.nn.Softmax(dim=-1)

    def action_dist(self, state: State) -> ActionDistribution:
        """Return the combined action distribution for `state`."""
        weighted_policies = [w * p(state).values for w, p in zip(self._weights, self._policies)]
        # Element-wise sum of the weighted per-policy value tensors.
        weighted = torch.stack(weighted_policies).sum(0)
        dist = self._softmax(weighted)
        return self._action_space.distribution(dist)
    
class PyTorchPolicy(RLPolicy):
    """Policy that softmaxes the output of a PyTorch model into an action distribution."""

    def __init__(self, action_space: ActionSpace, model):
        """
        Args:
            action_space: action space used to build the returned distribution
            model: callable applied to a (1, n) float tensor built from the
                state value; row 0 of its output is softmaxed (presumably one
                score per action — confirm against the exported model)
        """
        self._action_space = action_space
        self._model = model
        # Pass `dim` explicitly: relying on Softmax's implicit dimension is
        # deprecated and raises a UserWarning on every call. Row 0 of the
        # model output is 1-D here, so dim=-1 matches the old implicit choice.
        self._softmax = torch.nn.Softmax(dim=-1)

    def _model_device(self) -> Optional[torch.device]:
        """Return the device of the model's first parameter, or None if it has none."""
        parameters = getattr(self._model, "parameters", None)
        if parameters is None:
            return None
        first = next(iter(parameters()), None)
        return None if first is None else first.device

    def action_dist(self, state: State) -> ActionDistribution:
        """Return the softmax of the model output for `state` as a distribution."""
        features = torch.tensor(state.value, dtype=torch.float).reshape(1, -1)
        # The model may have been moved to an accelerator; feed it an input on
        # the same device instead of failing with a device-mismatch error.
        device = self._model_device()
        if device is not None:
            features = features.to(device)
        scores = self._model(features)[0]
        # Return on CPU (a no-op when already there) so the values can be
        # combined with tensors from other policies, e.g. the torch.stack in
        # ComboPolicy — presumably those live on CPU; verify.
        return self._action_space.distribution(self._softmax(scores).cpu())

# Utility Functions

In [3]:
def generate_logs(episodes: int, max_horizon: int, policy: RLPolicy) -> Sequence[Mdp]:
    """Roll out `policy` in CartPole-v0 and record every step as a Transition.

    Args:
        episodes: number of episodes to generate
        max_horizon: max horizon of each episode
        policy: RLPolicy which uses real-valued states

    Returns:
        One list of Transitions per episode, in the order the steps were taken.
    """
    log = []
    env = gym.make('CartPole-v0')
    try:
        for _ in range(episodes):
            cur_state = env.reset()
            mdp = []
            for _step in range(max_horizon):
                action_dist = policy(State(cur_state))
                # NOTE(review): the executed action comes from greedy() rather
                # than from sampling, while action_prob stores the policy's
                # probability of that action — confirm this logging scheme is
                # what the importance-weighted estimators expect.
                action = action_dist.greedy().value
                action_prob = action_dist.probability(Action(action))
                next_state, reward, done, _info = env.step(action)
                mdp.append(Transition(last_state=State(cur_state),
                                      action=Action(action),
                                      action_prob=action_prob,
                                      state=State(next_state),
                                      reward=reward,
                                      # 2 once the env reports done, else 1.
                                      # NOTE(review): presumably these are the
                                      # Transition status codes for terminal /
                                      # normal steps — confirm.
                                      status=2 if done else 1))
                if done:
                    break
                cur_state = next_state
            log.append(mdp)
    finally:
        # Release the environment's resources even if a rollout raises.
        env.close()
    return log

def zeta_nu_loss_callback(
    losses: Sequence[Tuple[float, float]],
    estimated_values: Sequence,
    input: RLEstimatorInput,
):
    """Build a loss callback that records progress into caller-owned lists.

    Args:
        losses: list that receives one (zeta_loss, nu_loss) tuple per call
        estimated_values: list that receives the result of
            estimator._compute_estimates(input) per call
        input: estimator input handed to _compute_estimates on every call

    Returns:
        A function taking (zeta_loss, nu_loss, estimator) that appends to
        both lists and returns nothing.
    """
    def callback_fn(zeta_loss, nu_loss, estimator):
        # Record the loss pair first, then the estimate computed from the
        # estimator's current state.
        loss_pair = (zeta_loss, nu_loss)
        losses.append(loss_pair)
        current_estimate = estimator._compute_estimates(input)
        estimated_values.append(current_estimate)

    return callback_fn

# Create the trained policy, target policy, and behavior policy

In [4]:
# Building blocks: a random policy and the policy backed by the loaded model.
random_policy = RandomRLPolicy(ActionSpace(2))
model_policy = PyTorchPolicy(action_space=ActionSpace(2), model=model)

# Target policy: all of the weight on the model policy.
target_policy = ComboPolicy(
    action_space=ActionSpace(2),
    weights=[1.0, 0.0],
    policies=[model_policy, random_policy],
)

# Behavior policy: ALPHA shifts weight between the model and random policies;
# the two weights sum to 1 for any ALPHA.
behavior_policy = ComboPolicy(
    action_space=ActionSpace(2),
    weights=[0.55 + 0.15 * ALPHA, 0.45 - 0.15 * ALPHA],
    policies=[model_policy, random_policy],
)

# Generate the logged dataset

In [5]:
# Logged dataset: NUM_EPISODES rollouts of the behavior policy.
log = generate_logs(
    episodes=NUM_EPISODES,
    max_horizon=MAX_HORIZON,
    policy=behavior_policy,
)

  del sys.path[0]


# Estimate the value of the target policy

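Since the states are real-valued, we do not estimate $v^\pi(s)$ for each individual state. Instead, we average the sum of discounted rewards over many rollouts of the target policy, which gives a Monte Carlo estimate of $\mathbb{E}_{s_0}\left[v^\pi(s_0)\right]$, the expected value over initial states.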

In [6]:
def estimate_value(episodes: int, max_horizon: int, policy: RLPolicy, gamma: float):
    """Estimate a policy's discounted return in CartPole-v0 by averaging rollouts.

    Args:
        episodes: number of rollouts to average over
        max_horizon: maximum number of steps per rollout
        policy: policy to evaluate; its greedy() action is executed at each step
        gamma: per-step discount factor applied to rewards

    Returns:
        The running average of the per-episode discounted reward sums.
    """
    avg = RunningAverage()
    env = gym.make('CartPole-v0')
    try:
        for _ in range(episodes):
            cur_state = env.reset()
            discounted_return = 0.0
            discount = 1.0
            for _step in range(max_horizon):
                action = policy(State(cur_state)).greedy().value
                next_state, reward, done, _info = env.step(action)
                # Reward at step t is weighted by gamma ** t.
                discounted_return += reward * discount
                discount *= gamma
                if done:
                    break
                cur_state = next_state
            avg.add(discounted_return)
    finally:
        # Release the environment's resources even if a rollout raises.
        env.close()
    return avg.average

# Reference value of the target policy, obtained from fresh rollouts.
ground_truth = estimate_value(
    episodes=NUM_EPISODES,
    max_horizon=MAX_HORIZON,
    policy=target_policy,
    gamma=GAMMA,
)
print(f"Target Policy Ground Truth value: {ground_truth}")

  del sys.path[0]


Target Policy Ground Truth value: 70.20302794198436


In [7]:
# Lists that the DualDICE loss callback appends to.
dualdice_losses, dualdice_values = [], []

# Shared estimator input: the behavior-policy log evaluated against the target policy.
inp = RLEstimatorInput(gamma=GAMMA, log=log, target_policy=target_policy, discrete_states=False)

ips = IPSEstimator()

# NOTE(review): the positional 4 and 2 are presumably the CartPole state
# dimension and the number of actions — confirm against NeuralDualDICE.
dualdice = NeuralDualDICE(
    4,
    2,
    deterministic_env=True,
    value_lr=0.003,
    zeta_lr=0.003,
    batch_size=2048,
    loss_callback_fn=zeta_nu_loss_callback(dualdice_losses, dualdice_values, inp),
    device=device,
)

In [None]:
# Evaluate both estimators on the same input so their results are comparable.
ips_result = ips.evaluate(inp)
# The loss callback passed to NeuralDualDICE appends to dualdice_losses and
# dualdice_values (presumably while this call runs — confirm).
dd_result = dualdice.evaluate(inp)

INFO:root:IPSEstimator(device(None),weighted[True]}: start evaluating
  del sys.path[0]
INFO:root:  Append estimate [1]: log=69.98554447789161, estimated=74.5090560913086, ground_truth=0.0
INFO:root:IPSEstimator(device(None),weighted[True]}: finishing evaluating[process_time=13.853707919000001]
INFO:root:  Append estimate [1]: log=69.98554447789161, estimated=12.197612311945937, ground_truth=0.0
INFO:root:Samples 100 Avg Zeta Loss 0.013515950131695717, Avg Value Loss -0.011872679508778674
INFO:root:  Append estimate [1]: log=69.98554447789161, estimated=21.359412562633842, ground_truth=0.0
INFO:root:Samples 200 Avg Zeta Loss 0.032867668516701073, Avg Value Loss -0.03195237421035925
INFO:root:  Append estimate [1]: log=69.98554447789161, estimated=31.3605478482464, ground_truth=0.0
INFO:root:Samples 300 Avg Zeta Loss 0.06170809593284501, Avg Value Loss -0.060989961180688
INFO:root:  Append estimate [1]: log=69.98554447789161, estimated=39.15435264085474, ground_truth=0.0
INFO:root:Sampl

INFO:root:Samples 3900 Avg Zeta Loss 0.6993606929072077, Avg Value Loss -0.695912910153744
INFO:root:  Append estimate [1]: log=69.98554447789161, estimated=54.62992315297156, ground_truth=0.0
INFO:root:Samples 4000 Avg Zeta Loss 0.7185064421679705, Avg Value Loss -0.7149408297876721
INFO:root:  Append estimate [1]: log=69.98554447789161, estimated=55.02106215485004, ground_truth=0.0
INFO:root:Samples 4100 Avg Zeta Loss 0.7377599143193148, Avg Value Loss -0.7340898682761199
INFO:root:  Append estimate [1]: log=69.98554447789161, estimated=54.03011925073041, ground_truth=0.0
INFO:root:Samples 4200 Avg Zeta Loss 0.7571012823719949, Avg Value Loss -0.753317346207699
INFO:root:  Append estimate [1]: log=69.98554447789161, estimated=55.094718872419435, ground_truth=0.0
INFO:root:Samples 4300 Avg Zeta Loss 0.7765199167863462, Avg Value Loss -0.7726287578741903
INFO:root:  Append estimate [1]: log=69.98554447789161, estimated=56.51668158622166, ground_truth=0.0
INFO:root:Samples 4400 Avg Zeta

INFO:root:  Append estimate [1]: log=69.98554447789161, estimated=54.42537439635173, ground_truth=0.0
INFO:root:Samples 8200 Avg Zeta Loss 1.5770045495220326, Avg Value Loss -1.567101864874857
INFO:root:  Append estimate [1]: log=69.98554447789161, estimated=53.78077885487664, ground_truth=0.0
INFO:root:Samples 8300 Avg Zeta Loss 1.5984987575646874, Avg Value Loss -1.5884598440166549
INFO:root:  Append estimate [1]: log=69.98554447789161, estimated=54.39514056925266, ground_truth=0.0
INFO:root:Samples 8400 Avg Zeta Loss 1.6199902530693813, Avg Value Loss -1.609815690150345
INFO:root:  Append estimate [1]: log=69.98554447789161, estimated=54.91772938974874, ground_truth=0.0
INFO:root:Samples 8500 Avg Zeta Loss 1.641498719962718, Avg Value Loss -1.631196654041149
INFO:root:  Append estimate [1]: log=69.98554447789161, estimated=54.53888944436647, ground_truth=0.0
INFO:root:Samples 8600 Avg Zeta Loss 1.6630436982566066, Avg Value Loss -1.6525983797136272
INFO:root:  Append estimate [1]: l

INFO:root:Samples 12400 Avg Zeta Loss 2.5044362658385912, Avg Value Loss -2.4879692679271863
INFO:root:  Append estimate [1]: log=69.98554447789161, estimated=54.80880858501469, ground_truth=0.0
INFO:root:Samples 12500 Avg Zeta Loss 2.526911814873565, Avg Value Loss -2.510299583455398
INFO:root:  Append estimate [1]: log=69.98554447789161, estimated=54.279315619973666, ground_truth=0.0
INFO:root:Samples 12600 Avg Zeta Loss 2.549356189486397, Avg Value Loss -2.5325833088455934
INFO:root:  Append estimate [1]: log=69.98554447789161, estimated=56.1465278174524, ground_truth=0.0
INFO:root:Samples 12700 Avg Zeta Loss 2.5718504080367475, Avg Value Loss -2.554933076825973
INFO:root:  Append estimate [1]: log=69.98554447789161, estimated=55.537210861582004, ground_truth=0.0
INFO:root:Samples 12800 Avg Zeta Loss 2.594342621967998, Avg Value Loss -2.5772745685105436
INFO:root:  Append estimate [1]: log=69.98554447789161, estimated=56.10314162738863, ground_truth=0.0
INFO:root:Samples 12900 Avg Z

INFO:root:  Append estimate [1]: log=69.98554447789161, estimated=55.54123819166792, ground_truth=0.0
INFO:root:Samples 16800 Avg Zeta Loss 3.5010147787225088, Avg Value Loss -3.477726552019034
INFO:root:  Append estimate [1]: log=69.98554447789161, estimated=55.467344331550805, ground_truth=0.0
INFO:root:Samples 16900 Avg Zeta Loss 3.5238301834925103, Avg Value Loss -3.500386344766703
INFO:root:  Append estimate [1]: log=69.98554447789161, estimated=55.28887473531543, ground_truth=0.0
INFO:root:Samples 17000 Avg Zeta Loss 3.54673501305921, Avg Value Loss -3.5231460407911483
INFO:root:  Append estimate [1]: log=69.98554447789161, estimated=55.92693015615721, ground_truth=0.0
INFO:root:Samples 17100 Avg Zeta Loss 3.569526340886768, Avg Value Loss -3.545794418388864
INFO:root:  Append estimate [1]: log=69.98554447789161, estimated=56.03025080977607, ground_truth=0.0
INFO:root:Samples 17200 Avg Zeta Loss 3.592361263059505, Avg Value Loss -3.5684765164944454
INFO:root:  Append estimate [1]

INFO:root:Samples 21000 Avg Zeta Loss 4.462873749933285, Avg Value Loss -4.433544035719286
INFO:root:  Append estimate [1]: log=69.98554447789161, estimated=56.04397401368982, ground_truth=0.0
INFO:root:Samples 21100 Avg Zeta Loss 4.485874533784833, Avg Value Loss -4.4564147196117085
INFO:root:  Append estimate [1]: log=69.98554447789161, estimated=54.7484523091965, ground_truth=0.0
INFO:root:Samples 21200 Avg Zeta Loss 4.508806750477983, Avg Value Loss -4.479207391413474
INFO:root:  Append estimate [1]: log=69.98554447789161, estimated=56.133593411011255, ground_truth=0.0
INFO:root:Samples 21300 Avg Zeta Loss 4.531784399256959, Avg Value Loss -4.502044722542113
INFO:root:  Append estimate [1]: log=69.98554447789161, estimated=55.992835973327, ground_truth=0.0
INFO:root:Samples 21400 Avg Zeta Loss 4.554780917052289, Avg Value Loss -4.524900238392105
INFO:root:  Append estimate [1]: log=69.98554447789161, estimated=56.8102095312572, ground_truth=0.0
INFO:root:Samples 21500 Avg Zeta Loss

INFO:root:  Append estimate [1]: log=69.98554447789161, estimated=57.73369957405821, ground_truth=0.0
INFO:root:Samples 25300 Avg Zeta Loss 5.453037416978743, Avg Value Loss -5.417410488873751
INFO:root:  Append estimate [1]: log=69.98554447789161, estimated=56.448979572895794, ground_truth=0.0
INFO:root:Samples 25400 Avg Zeta Loss 5.476074773811406, Avg Value Loss -5.440278601525014
INFO:root:  Append estimate [1]: log=69.98554447789161, estimated=55.42735983271223, ground_truth=0.0
INFO:root:Samples 25500 Avg Zeta Loss 5.499157419519285, Avg Value Loss -5.463165989212574
INFO:root:  Append estimate [1]: log=69.98554447789161, estimated=57.76284807003968, ground_truth=0.0
INFO:root:Samples 25600 Avg Zeta Loss 5.522222958893005, Avg Value Loss -5.48607737767571
INFO:root:  Append estimate [1]: log=69.98554447789161, estimated=54.829572343881004, ground_truth=0.0
INFO:root:Samples 25700 Avg Zeta Loss 5.545283859320147, Avg Value Loss -5.508961940823504
INFO:root:  Append estimate [1]: l

INFO:root:Samples 29500 Avg Zeta Loss 6.421703718425019, Avg Value Loss -6.3789572640036
INFO:root:  Append estimate [1]: log=69.98554447789161, estimated=56.54017058346755, ground_truth=0.0
INFO:root:Samples 29600 Avg Zeta Loss 6.4447026176520685, Avg Value Loss -6.401785587644748
INFO:root:  Append estimate [1]: log=69.98554447789161, estimated=55.67606859196099, ground_truth=0.0
INFO:root:Samples 29700 Avg Zeta Loss 6.467757576239587, Avg Value Loss -6.4246584097469
INFO:root:  Append estimate [1]: log=69.98554447789161, estimated=56.30032560124734, ground_truth=0.0
INFO:root:Samples 29800 Avg Zeta Loss 6.490880439342621, Avg Value Loss -6.447595181905674
INFO:root:  Append estimate [1]: log=69.98554447789161, estimated=55.22865621620491, ground_truth=0.0
INFO:root:Samples 29900 Avg Zeta Loss 6.514036437146596, Avg Value Loss -6.470569968423409
INFO:root:  Append estimate [1]: log=69.98554447789161, estimated=56.45233338038213, ground_truth=0.0
INFO:root:Samples 30000 Avg Zeta Loss 

INFO:root:  Append estimate [1]: log=69.98554447789161, estimated=58.41675860060401, ground_truth=0.0
INFO:root:Samples 33800 Avg Zeta Loss 7.411793868764934, Avg Value Loss -7.361129669572205
INFO:root:  Append estimate [1]: log=69.98554447789161, estimated=56.93623280494192, ground_truth=0.0
INFO:root:Samples 33900 Avg Zeta Loss 7.4348754381616216, Avg Value Loss -7.384032983942099
INFO:root:  Append estimate [1]: log=69.98554447789161, estimated=56.724670704307705, ground_truth=0.0
INFO:root:Samples 34000 Avg Zeta Loss 7.457692771810964, Avg Value Loss -7.406668788089903
INFO:root:  Append estimate [1]: log=69.98554447789161, estimated=56.70628098687136, ground_truth=0.0
INFO:root:Samples 34100 Avg Zeta Loss 7.480636672677502, Avg Value Loss -7.429425143713561
INFO:root:  Append estimate [1]: log=69.98554447789161, estimated=54.75372079863446, ground_truth=0.0
INFO:root:Samples 34200 Avg Zeta Loss 7.50353605338457, Avg Value Loss -7.452143118862954
INFO:root:  Append estimate [1]: l

INFO:root:Samples 38000 Avg Zeta Loss 8.375849528850017, Avg Value Loss -8.317375089689683
INFO:root:  Append estimate [1]: log=69.98554447789161, estimated=56.279387802663166, ground_truth=0.0
INFO:root:Samples 38100 Avg Zeta Loss 8.398751506490301, Avg Value Loss -8.34007971234873
INFO:root:  Append estimate [1]: log=69.98554447789161, estimated=54.73150390934831, ground_truth=0.0
INFO:root:Samples 38200 Avg Zeta Loss 8.421663998094793, Avg Value Loss -8.362816867822508
INFO:root:  Append estimate [1]: log=69.98554447789161, estimated=56.289580549960036, ground_truth=0.0
INFO:root:Samples 38300 Avg Zeta Loss 8.444572094237799, Avg Value Loss -8.385532709369837
INFO:root:  Append estimate [1]: log=69.98554447789161, estimated=56.229642015453784, ground_truth=0.0
INFO:root:Samples 38400 Avg Zeta Loss 8.467486678098915, Avg Value Loss -8.408262741013504
INFO:root:  Append estimate [1]: log=69.98554447789161, estimated=57.77178224907841, ground_truth=0.0
INFO:root:Samples 38500 Avg Zeta 

In [None]:
def plot_dualdice_losses(losses):
    """Plot the zeta and nu loss curves recorded during DualDICE training.

    Args:
        losses: sequence of (zeta_loss, nu_loss) pairs, one per recorded step
    """
    zeta_losses = [pair[0] for pair in losses]
    nu_losses = [pair[1] for pair in losses]
    plt.plot(zeta_losses, label="Zeta Loss")
    plt.plot(nu_losses, label="Nu Loss")
    plt.ylabel("Loss")
    plt.xlabel("Epochs")
    # Without a legend the `label=` arguments above are never drawn and the
    # two curves cannot be told apart.
    plt.legend()
    plt.show()

# Plot the (zeta, nu) loss pairs accumulated in dualdice_losses.
plot_dualdice_losses(dualdice_losses)
        