In [5]:
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.ppo import MlpPolicy

import torch
import torch.nn as nn
import torch.functional as F


<stable_baselines3.ppo.ppo.PPO at 0x19265656d30>

In [3]:
env = gym.make("CartPole-v1")
expert = PPO(
    policy=MlpPolicy,
    env=env,
    seed=0,
    batch_size=64,
    ent_coef=0.0,
    learning_rate=0.0003,
    n_epochs=10,
    n_steps=64,
)
expert.learn(100000)  # Note: set to 100000 to train a proficient expert

In [5]:
from stable_baselines3.common.evaluation import evaluate_policy

reward, _ = evaluate_policy(expert, env, 10)
print(reward)

500.0


In [30]:
from imitation.data import rollout
from imitation.data.wrappers import RolloutInfoWrapper
from stable_baselines3.common.vec_env import DummyVecEnv
import numpy as np

rng = np.random.default_rng()
rollouts = rollout.rollout(
    expert,
    DummyVecEnv([lambda: RolloutInfoWrapper(env)]),
    rollout.make_sample_until(min_timesteps=None, min_episodes=50),
    rng=rng,
)
transitions = rollout.flatten_trajectories(rollouts)

In [564]:
transitions[0]

{'obs': array([-0.00072172, -0.03225593,  0.00235076,  0.0262831 ], dtype=float32),
 'acts': 1,
 'infos': {},
 'next_obs': array([-0.00136683,  0.16283223,  0.00287643, -0.26565722], dtype=float32),
 'dones': False}

In [7]:
print(
    f"""The `rollout` function generated a list of {len(rollouts)} {type(rollouts[0])}.
After flattening, this list is turned into a {type(transitions)} object containing {len(transitions)} transitions.
The transitions object contains arrays for: {', '.join(transitions.__dict__.keys())}."
"""
)

The `rollout` function generated a list of 50 <class 'imitation.data.types.TrajectoryWithRew'>.
After flattening, this list is turned into a <class 'imitation.data.types.Transitions'> object containing 25000 transitions.
The transitions object contains arrays for: obs, acts, infos, next_obs, dones."



# Behavioral Cloning

- transitions contain all the states and actions, can use to train behavioral cloning
- state is a continuous 4-tuple, actions is a binary variable

In [9]:
for i, x in enumerate(transitions):
    print(x.keys())
    break   

dict_keys(['obs', 'acts', 'infos', 'next_obs', 'dones'])


### define Net for behavioral cloning as benchmark, dataset, learning structure

In [10]:
# simple MLP for behavioral cloning
import torch
import torch.nn as nn
import torch.functional as F

class BC_Net_Cartpole(nn.Module):
    def __init__(self, hidden_size, input_size=4, output_size=1):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, output_size),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        return self.layers(x).flatten()

In [11]:
states = np.array([x['obs'] for x in transitions])
actions = np.array([x['acts'] for x in transitions])

from torch.utils.data import Dataset
class ExpertDataset(Dataset):
        def __init__(self, X, y):
            self.X = torch.tensor(X, dtype=torch.float32)
            self.y = torch.tensor(y, dtype=torch.float32)
    
        def __len__(self):
            return len(self.y)
        
        def __getitem__(self, idx):
            return self.X[idx], self.y[idx]

dataset = ExpertDataset(X=states, y=actions)
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)

#### Train Loop for Behavioral Cloning Model

In [12]:
model = BC_Net_Cartpole(hidden_size=16)
loss_fn = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

num_epochs = 20

for epoch in range(num_epochs):
    for batch_idx, (inputs, targets) in enumerate(train_dataloader):
        # Forward pass
        outputs = model(inputs)
        loss = loss_fn(outputs, targets)
        # Backward pass and optimization
        optimizer.zero_grad()  
        loss.backward()         
        optimizer.step()        
        if batch_idx % 500 == 0:
            print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}")

model.eval()

Epoch 0, Batch 0, Loss: 0.6962
Epoch 1, Batch 0, Loss: 0.3149
Epoch 2, Batch 0, Loss: 0.2566
Epoch 3, Batch 0, Loss: 0.3066
Epoch 4, Batch 0, Loss: 0.3244
Epoch 5, Batch 0, Loss: 0.3091
Epoch 6, Batch 0, Loss: 0.2432
Epoch 7, Batch 0, Loss: 0.1987
Epoch 8, Batch 0, Loss: 0.2136
Epoch 9, Batch 0, Loss: 0.2265
Epoch 10, Batch 0, Loss: 0.3165
Epoch 11, Batch 0, Loss: 0.3276
Epoch 12, Batch 0, Loss: 0.1960
Epoch 13, Batch 0, Loss: 0.2952
Epoch 14, Batch 0, Loss: 0.2852
Epoch 15, Batch 0, Loss: 0.1912
Epoch 16, Batch 0, Loss: 0.2575
Epoch 17, Batch 0, Loss: 0.2461
Epoch 18, Batch 0, Loss: 0.1311
Epoch 19, Batch 0, Loss: 0.2303


BC_Net_Cartpole(
  (layers): Sequential(
    (0): Linear(in_features=4, out_features=16, bias=True)
    (1): Tanh()
    (2): Linear(in_features=16, out_features=16, bias=True)
    (3): Tanh()
    (4): Linear(in_features=16, out_features=1, bias=True)
    (5): Sigmoid()
  )
)

In [20]:
undertrained_model = BC_Net_Cartpole(hidden_size=16)
loss_fn = nn.BCELoss()
optimizer = torch.optim.Adam(undertrained_model.parameters(), lr=0.0001)

num_epochs = 2

for epoch in range(num_epochs):
    for batch_idx, (inputs, targets) in enumerate(train_dataloader):
        # Forward pass
        outputs = undertrained_model(inputs)
        loss = loss_fn(outputs, targets)
        # Backward pass and optimization
        optimizer.zero_grad()  
        loss.backward()         
        optimizer.step()        
        if batch_idx % 500 == 0:
            print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}")

undertrained_model.eval()

Epoch 0, Batch 0, Loss: 0.6878
Epoch 1, Batch 0, Loss: 0.6380


BC_Net_Cartpole(
  (layers): Sequential(
    (0): Linear(in_features=4, out_features=16, bias=True)
    (1): Tanh()
    (2): Linear(in_features=16, out_features=16, bias=True)
    (3): Tanh()
    (4): Linear(in_features=16, out_features=1, bias=True)
    (5): Sigmoid()
  )
)

### Define Custom Policy Class That Takes a Neural Net as a Param to Predict
#### NOTE: I transform the network output to a Long, since I want to test "deterministic" policies as is the case for our RL Shrinkage problem

In [13]:
from typing import Dict, Tuple
from numpy import ndarray
from stable_baselines3.common.policies import BasePolicy

# define custom policy class that takes a net as a param to predict
# is basically actor critic without the critic!?
class CustomPolicy(BasePolicy):
    def __init__(self, observation_space, action_space, behavioral_cloning_net):
        super().__init__(observation_space=observation_space, action_space=action_space)
        self.net = behavioral_cloning_net
        self.action_space = action_space

    def _predict(self, 
                 observation,
                 state=None,
                 episode_start=None,
                 deterministic=True,):
        return self.net.forward(observation).round().long()
    


#### Evaluate trained and untrained BC models on the cartpole problem

In [15]:
model.eval()
base_BC_policy = CustomPolicy(env.observation_space, env.action_space, model)

reward, _ = evaluate_policy(base_BC_policy, env, 10)
print(reward)  

500.0


In [21]:
undertrained_model.eval()
untrained_base_BC_policy = CustomPolicy(env.observation_space, env.action_space, undertrained_model)

reward, _ = evaluate_policy(untrained_base_BC_policy, env, 10)
print(reward)  


51.7


In [16]:
env = gym.make("CartPole-v1", render_mode='human')

vec_env = DummyVecEnv([lambda: env])

In [22]:
from stable_baselines3.common.vec_env import DummyVecEnv

vec_env = DummyVecEnv([lambda: env])
obs = vec_env.reset()
for i in range(400):
    action, _states = base_BC_policy.predict(obs, deterministic=True)
    obs, rewards, dones, info = vec_env.step(action)
    vec_env.render("human")

In [23]:
## LESS TRAINED MODEL
from stable_baselines3.common.vec_env import DummyVecEnv

vec_env = DummyVecEnv([lambda: env])
obs = vec_env.reset()
for i in range(400):
    action, _states = untrained_base_BC_policy.predict(obs, deterministic=True)
    obs, rewards, dones, info = vec_env.step(action)
    vec_env.render("human")

# Generative Adversarial Imitation Learning

##### Basic idea:
- learn a policy from expert demonstrations by matching the expert's occupancy measure (distribution of state-action pairs). By doing so, it bypasses the indirect inverse reinforcement learning (IRL) step, which would first recover a cost function and the perform reinforcement learning.

##### Adversarial Learning Process:
- a Discriminator Network is trained to distinguish between state-action pairs generated by policy (learner) and those from the expert
- the policy is simultaneously trained to "fool" this discriminator, meaning it tries to replicate the expert's policy

##### The Algorithm:
- alternates between two steps:
    1. updating the discriminator params
    2. updating the policy params using a Trust Region Policy Optimization (TRPO) step, where the discriminator's output acts as a cost function, guiding the policy towards expert-like behavior

In [66]:
class DiscriminatorNetwork(nn.Module):
    def __init__(self, input_size, hidden_size, output_size) -> None:
        super().__init__()
        self.layers = torch.nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, output_size),
            nn.Sigmoid() # 0 --> expert, 1--> learned policy
        )
    
    def forward(self, x):
        return self.layers(x).flatten()
    

class PolicyNetworkCartPole(nn.Module):
    def __init__(self, input_size, hidden_size, output_size) -> None:
        super().__init__()
        self.layers = torch.nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, output_size),
            nn.Sigmoid() # 0 or 1 (left or right)
        )
    
    def forward(self, x):
        return self.layers(x).flatten()
    
    def entropy(self, x):
        return 0




In [107]:
# init env, policy network, and discriminator
env = gym.make("CartPole-v1")
policy_net = PolicyNetworkCartPole(input_size=4, hidden_size=32, output_size=1)
# Q: what input size does the discriminator take? a whole trajectory? fixed number of steps, i.e. 128?
# I would say a fixed number of steps, but actually just x = (s,a) pairs in batches, and y in {0,1}
discriminator = DiscriminatorNetwork(input_size=4+1, hidden_size=32, output_size=1)

#### Generate expert policy and corresponding rollout

In [57]:
from imitation.data import rollout
from imitation.data.wrappers import RolloutInfoWrapper
from stable_baselines3.common.vec_env import DummyVecEnv
import numpy as np

rng = np.random.default_rng()
rollouts = rollout.rollout(
    expert,
    DummyVecEnv([lambda: RolloutInfoWrapper(env)]),
    rollout.make_sample_until(min_timesteps=None, min_episodes=50),
    rng=rng,
)
expert_transitions = rollout.flatten_trajectories(rollouts)

##### write a policy wrapper to wrap the learned policy s.t. it works with the imitation library

In [None]:
# define policy wrapper 
from stable_baselines3.common.policies import BasePolicy
class PolicyWrapper(BasePolicy):
    """
    not needed for training, just for rollouts
    takes the actual trained policy net as parameter 
    """
    def __init__(self, observation_space, action_space, net):
        super().__init__(observation_space=observation_space, action_space=action_space)
        self.net = net
        self.action_space = action_space

    def _predict(self, 
                 observation,
                 state=None,
                 episode_start=None,
                 deterministic=True,):
        return self.net.forward(observation).round().long()
wrapped_policy = PolicyWrapper(env.observation_space, env.action_space, policy_net)

In [42]:
from imitation.data import rollout
from imitation.data.wrappers import RolloutInfoWrapper
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.evaluation import evaluate_policy

rng = np.random.default_rng()
rollouts = rollout.rollout(
    wrapped_policy,
    DummyVecEnv([lambda: RolloutInfoWrapper(env)]),
    rollout.make_sample_until(min_timesteps=500, min_episodes=1),
    rng=rng,
)
learner_transitions = rollout.flatten_trajectories(rollouts)

reward, _ = evaluate_policy(wrapped_policy, env, 10)
print("Sanity check: before any learning, evaluate the policy:")
print(reward)  

Sanity check: before any learning, evaluate the policy:
9.2


In [65]:
for x in expert_transitions:
    print(x.keys())
    break

dict_keys(['obs', 'acts', 'infos', 'next_obs', 'dones'])


##### helpers needed for GAIL Trainer

In [218]:
from torch.utils.data import DataLoader, Dataset
class create_discriminator_learning_ds(Dataset):
    def __init__(self, learner_s_a, expert_s_a):
        learner_s_a = np.concatenate([np.array(learner_s_a[0]), np.array(learner_s_a[1]).reshape(-1, 1)], axis=1)
        expert_s_a =  np.concatenate([np.array(expert_s_a[0]), np.array(expert_s_a[1]).reshape(-1, 1)], axis=1)

        self.s_a_full = np.concatenate([learner_s_a, expert_s_a], axis=0)
        self.y_full = np.concatenate([np.repeat(0, learner_s_a.shape[0]), np.repeat(1, expert_s_a.shape[0])])

        self.s_a_full = torch.tensor(self.s_a_full, dtype=torch.float32)
        self.y_full = torch.tensor(self.y_full, dtype=torch.float32)
        

    def __len__(self):
        return len(self.s_a_full)

    def __getitem__(self, idx):
        return self.s_a_full[idx], self.y_full[idx]
    
class BasicTorchDataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.float32)
        

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

class ExpertDataset():
    """
    takes rollouts as transitions from imitation package
    saves them as states and actions so no need to always iterate thru transitions
    """
    def __init__(self, expert_transitions) -> None:
        self.states = np.array([x['obs'] for x in expert_transitions])
        self.actions = np.array([x['acts'] for x in expert_transitions])
        
expert_ds = ExpertDataset(expert_transitions=expert_transitions)

In [229]:
class GAILTrainer:
    def __init__(self, env, policy, discriminator, expert_ds, disc_optimizer, disc_loss_fn, λ=1e-3):
        self.env = env
        self.policy = policy
        self.discriminator = discriminator
        self.expert_ds = expert_ds
        self.disc_optimizer = disc_optimizer
        self.disc_loss_fn = disc_loss_fn

    def update_discriminator(self, learner_states_actions, expert_states_actions, NUM_EPOCHS=5):
        discriminator_train_ds = create_discriminator_learning_ds(learner_states_actions, expert_states_actions)
        discriminator_train_ds = DataLoader(discriminator_train_ds, shuffle=True, batch_size=32)
        discriminator.train()
        for epoch in range(NUM_EPOCHS):
            for batch_idx, (inputs, targets) in enumerate(discriminator_train_ds):
                outputs = discriminator(inputs)
                loss = self.disc_loss_fn(outputs, targets)
                self.disc_optimizer.zero_grad()
                loss.backward()
                self.disc_optimizer.step()
                if batch_idx % 10 == 0:
                    print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}")
        discriminator.eval()
        
    def sample_from_policy(self, min_timesteps=500, min_episodes=1):
        rng = np.random.default_rng()
        rollouts = rollout.rollout(
            wrapped_policy,
            DummyVecEnv([lambda: RolloutInfoWrapper(env)]),
            rollout.make_sample_until(min_timesteps=min_timesteps, min_episodes=min_episodes),
            rng=rng,
        )
        learner_transitions = rollout.flatten_trajectories(rollouts)
        learner_states = [x['obs'] for x in learner_transitions]
        learner_acts = [x['acts'] for x in learner_transitions]
        return learner_states, learner_acts

    def update_policy(self, learner_states_actions, expert_states_actions, NUM_EPOCHS=5):
        # PPO / TRPO step with rewards = log D
        pass


    def train(self, iterations=100):
        for i in range(iterations):
            learner_states_actions = self.sample_from_policy()
            rnd_idces = np.random.choice(np.arange(0, len(self.expert_ds.states)), size=len(learner_states_actions[0]))
            expert_states_actions = self.expert_ds.states[rnd_idces], self.expert_ds.actions[rnd_idces]

            self.update_discriminator(learner_states_actions, expert_states_actions)
            self.update_policy(learner_states_actions, expert_states_actions)


    

In [257]:
# given a trained discriminator; train policy and see if discriminator got worse

env = gym.make("CartPole-v1")
policy_net = PolicyNetworkCartPole(input_size=4, hidden_size=32, output_size=1)
policy_optimizer = torch.optim.Adam(policy_net.parameters(), lr=0.001)
# policy_loss_fn = nn.BCELoss()

# Q: what input size does the discriminator take? a whole trajectory? fixed number of steps, i.e. 128?
# I would say a fixed number of steps, but actually just x = (s,a) pairs in batches, and y in {0,1}
discriminator = DiscriminatorNetwork(input_size=4+1, hidden_size=32, output_size=1)
discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.001)
discriminator_loss_fn = nn.BCELoss()

expert_transitions = rollout.flatten_trajectories(rollouts)


gail_trainer = GAILTrainer(
    env=DummyVecEnv([lambda: RolloutInfoWrapper(env)]), 
    policy=wrapped_policy, 
    discriminator=discriminator, 
    expert_ds=expert_ds,
    disc_optimizer=discriminator_optimizer, 
    disc_loss_fn=discriminator_loss_fn, 
    λ=1e-3)

In [258]:
policy_samples1 = gail_trainer.sample_from_policy()
rnd_idces = np.random.choice(np.arange(0, len(expert_ds.states)), size=len(policy_samples1[0]))
expert_states_actions = expert_ds.states[rnd_idces], expert_ds.actions[rnd_idces]
gail_trainer.update_discriminator( (np.array(policy_samples1[0]), np.array(policy_samples1[1])), expert_states_actions)

Epoch 0, Batch 0, Loss: 0.6520
Epoch 0, Batch 10, Loss: 0.5791
Epoch 0, Batch 20, Loss: 0.5012
Epoch 0, Batch 30, Loss: 0.4948
Epoch 1, Batch 0, Loss: 0.4202
Epoch 1, Batch 10, Loss: 0.4098
Epoch 1, Batch 20, Loss: 0.3274
Epoch 1, Batch 30, Loss: 0.2768
Epoch 2, Batch 0, Loss: 0.2325
Epoch 2, Batch 10, Loss: 0.2813
Epoch 2, Batch 20, Loss: 0.2569
Epoch 2, Batch 30, Loss: 0.2293
Epoch 3, Batch 0, Loss: 0.1921
Epoch 3, Batch 10, Loss: 0.1144
Epoch 3, Batch 20, Loss: 0.1291
Epoch 3, Batch 30, Loss: 0.1741
Epoch 4, Batch 0, Loss: 0.1152
Epoch 4, Batch 10, Loss: 0.1781
Epoch 4, Batch 20, Loss: 0.1259
Epoch 4, Batch 30, Loss: 0.1824


In [242]:
policy_samples = gail_trainer.sample_from_policy(min_timesteps=5000)
policy_train_ds = BasicTorchDataset(
             np.array(policy_samples[0]),  
             np.array(policy_samples[1]).reshape(-1, 1)
             )
policy_train_ds = DataLoader(policy_train_ds, shuffle=True, batch_size=32)
for batch_idx, (states, actions) in enumerate(policy_train_ds):
    outputs = discriminator((torch.concat([states, actions], dim=1))).log() * -1
    break

In [250]:
dist = torch.distributions.Binomial(probs=policy_net(states))

In [256]:
dist.log_prob(dist.sample()).mean()

tensor(-0.6904, grad_fn=<MeanBackward0>)

In [262]:
# now update policy, and check if we have larger loss on the same data

def update_policy(NUM_EPOCHS=20):
    # PPO / TRPO step with rewards = log D
        policy_samples = gail_trainer.sample_from_policy(min_timesteps=5000)
        policy_train_ds = BasicTorchDataset(
             np.array(policy_samples[0]),  
             np.array(policy_samples[1]).reshape(-1, 1)
             )
        policy_train_ds = DataLoader(policy_train_ds, shuffle=True, batch_size=32)
        policy_net.train()
        for epoch in range(NUM_EPOCHS):
            for batch_idx, (states, actions) in enumerate(policy_train_ds):
                with torch.no_grad():
                    outputs = discriminator((torch.concat([states, actions], dim=1))).log() * -1
                    rewards = outputs.sum()

                dist = torch.distributions.Binomial(probs=policy_net(states))
                sampled_acts = dist.sample()
                logp = dist.log_prob(sampled_acts)

                loss = -(logp * rewards).mean()
                policy_optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(policy_net.parameters(), 10.)  # gradient clipping
                policy_optimizer.step()
                if batch_idx % 10 == 0:
                    print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}")
        policy_net.eval()


update_policy(NUM_EPOCHS=10)

Epoch 0, Batch 0, Loss: 100.9789
Epoch 0, Batch 10, Loss: 102.7491
Epoch 0, Batch 20, Loss: 96.0607
Epoch 0, Batch 30, Loss: 88.0709
Epoch 0, Batch 40, Loss: 98.2936
Epoch 0, Batch 50, Loss: 100.2688
Epoch 0, Batch 60, Loss: 97.6443
Epoch 0, Batch 70, Loss: 84.3338
Epoch 0, Batch 80, Loss: 106.9212
Epoch 0, Batch 90, Loss: 99.8957
Epoch 0, Batch 100, Loss: 100.8091
Epoch 0, Batch 110, Loss: 96.0502
Epoch 0, Batch 120, Loss: 103.0953
Epoch 0, Batch 130, Loss: 102.7319
Epoch 0, Batch 140, Loss: 93.4437
Epoch 0, Batch 150, Loss: 102.6603
Epoch 1, Batch 0, Loss: 87.7436
Epoch 1, Batch 10, Loss: 102.4170
Epoch 1, Batch 20, Loss: 100.3865
Epoch 1, Batch 30, Loss: 94.2882
Epoch 1, Batch 40, Loss: 95.7877
Epoch 1, Batch 50, Loss: 100.3239
Epoch 1, Batch 60, Loss: 89.4863
Epoch 1, Batch 70, Loss: 98.0538
Epoch 1, Batch 80, Loss: 104.3178
Epoch 1, Batch 90, Loss: 100.1101
Epoch 1, Batch 100, Loss: 90.3274
Epoch 1, Batch 110, Loss: 110.9038
Epoch 1, Batch 120, Loss: 94.2125
Epoch 1, Batch 130, Lo

### Final GAIL Class: 
- add all the functionalities into the GAIL Trainer class
- Helpers from Above

In [375]:
from torch.utils.data import DataLoader, Dataset
class create_discriminator_learning_ds(Dataset):
    def __init__(self, learner_s_a, expert_s_a):
        learner_s_a = np.concatenate([np.array(learner_s_a[0]), np.array(learner_s_a[1]).reshape(-1, 1)], axis=1)
        expert_s_a =  np.concatenate([np.array(expert_s_a[0]), np.array(expert_s_a[1]).reshape(-1, 1)], axis=1)

        self.s_a_full = np.concatenate([learner_s_a, expert_s_a], axis=0)
        self.y_full = np.concatenate([np.repeat(0, learner_s_a.shape[0]), np.repeat(1, expert_s_a.shape[0])])

        self.s_a_full = torch.tensor(self.s_a_full, dtype=torch.float32)
        self.y_full = torch.tensor(self.y_full, dtype=torch.float32)
        

    def __len__(self):
        return len(self.s_a_full)

    def __getitem__(self, idx):
        return self.s_a_full[idx], self.y_full[idx]
    
class BasicTorchDataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.float32)
        

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

class ExpertDataset():
    """
    takes rollouts as transitions from imitation package
    saves them as states and actions so no need to always iterate thru transitions
    """
    def __init__(self, expert_transitions) -> None:
        self.states = np.array([x['obs'] for x in expert_transitions])
        self.actions = np.array([x['acts'] for x in expert_transitions])
        
expert_ds = ExpertDataset(expert_transitions=expert_transitions)

In [376]:
class GAILTrainer:
    def __init__(self, env, policy, discriminator, expert_ds, disc_optimizer, disc_loss_fn, λ=1e-3):
        self.env = env
        self.policy = policy
        self.discriminator = discriminator
        self.expert_ds = expert_ds
        self.disc_optimizer = disc_optimizer
        self.disc_loss_fn = disc_loss_fn

    def update_discriminator(self, learner_states_actions, expert_states_actions, NUM_EPOCHS=5):
        discriminator_train_ds = create_discriminator_learning_ds(learner_states_actions, expert_states_actions)
        discriminator_train_ds = DataLoader(discriminator_train_ds, shuffle=True, batch_size=32)
        discriminator.train()
        for epoch in range(NUM_EPOCHS):
            for batch_idx, (inputs, targets) in enumerate(discriminator_train_ds):
                outputs = discriminator(inputs)
                loss = self.disc_loss_fn(outputs, targets)
                self.disc_optimizer.zero_grad()
                loss.backward()
                self.disc_optimizer.step()
            print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}")
        discriminator.eval()
    
    def discriminator_forward_pass(self, learner_states_actions, expert_states_actions):
        pass
        
    def sample_from_policy(self, min_timesteps=500, min_episodes=1):
        rng = np.random.default_rng()
        rollouts = rollout.rollout(
            wrapped_policy,
            DummyVecEnv([lambda: RolloutInfoWrapper(env)]),
            rollout.make_sample_until(min_timesteps=min_timesteps, min_episodes=min_episodes),
            rng=rng,
        )
        learner_transitions = rollout.flatten_trajectories(rollouts)
        learner_states = [x['obs'] for x in learner_transitions]
        learner_acts = [x['acts'] for x in learner_transitions]
        return learner_states, learner_acts

    def update_policy(self, NUM_EPOCHS=5):
    # PPO / TRPO step with rewards = log D
        policy_samples = self.sample_from_policy(min_timesteps=5000)
        policy_train_ds = BasicTorchDataset(
             np.array(policy_samples[0]),  
             np.array(policy_samples[1]).reshape(-1, 1)
             )
        policy_train_ds = DataLoader(policy_train_ds, shuffle=True, batch_size=32)
        policy_net.train()
        for epoch in range(NUM_EPOCHS):
            for batch_idx, (states, actions) in enumerate(policy_train_ds):
                with torch.no_grad():
                    outputs = discriminator((torch.concat([states, actions], dim=1))).log() * -1
                    rewards = outputs.sum()

                dist = torch.distributions.Binomial(probs=policy_net(states))
                sampled_acts = dist.sample()
                logp = dist.log_prob(sampled_acts)

                loss = -(logp * rewards).mean()
                policy_optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(policy_net.parameters(), 10.)  # gradient clipping
                policy_optimizer.step()
            print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}")
        policy_net.eval()


    def train(self, iterations=100):
        for i in range(iterations):
            learner_states_actions = self.sample_from_policy()
            rnd_idces = np.random.choice(np.arange(0, len(self.expert_ds.states)), size=len(learner_states_actions[0]))
            expert_states_actions = self.expert_ds.states[rnd_idces], self.expert_ds.actions[rnd_idces]
            print(f"Update iteration number {i}")
            print("updating discriminator...")
            self.update_discriminator(learner_states_actions, expert_states_actions)
            print("done!")
            print("updating policy...")
            self.update_policy()
            print("done!")

In [None]:
class GAILTrainer:
    def __init__(self, env, policy, discriminator, expert_ds, disc_optimizer, disc_loss_fn, λ=1e-3):
        self.env = env
        self.policy = policy
        self.discriminator = discriminator
        self.expert_ds = expert_ds
        self.disc_optimizer = disc_optimizer
        self.disc_loss_fn = disc_loss_fn

    def update_discriminator(self, learner_states_actions, expert_states_actions, NUM_EPOCHS=5):
        discriminator_train_ds = create_discriminator_learning_ds(learner_states_actions, expert_states_actions)
        discriminator_train_ds = DataLoader(discriminator_train_ds, shuffle=True, batch_size=32)
        discriminator.train()
        for epoch in range(NUM_EPOCHS):
            for batch_idx, (inputs, targets) in enumerate(discriminator_train_ds):
                outputs = discriminator(inputs)
                loss = self.disc_loss_fn(outputs, targets)
                self.disc_optimizer.zero_grad()
                loss.backward()
                self.disc_optimizer.step()
            print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}")
        discriminator.eval()
    
    def discriminator_forward_pass(self, learner_states_actions, expert_states_actions):
        pass
        
    def sample_from_policy(self, min_timesteps=500, min_episodes=1):
        rng = np.random.default_rng()
        rollouts = rollout.rollout(
            wrapped_policy,
            DummyVecEnv([lambda: RolloutInfoWrapper(env)]),
            rollout.make_sample_until(min_timesteps=min_timesteps, min_episodes=min_episodes),
            rng=rng,
        )
        learner_transitions = rollout.flatten_trajectories(rollouts)
        learner_states = [x['obs'] for x in learner_transitions]
        learner_acts = [x['acts'] for x in learner_transitions]
        return learner_states, learner_acts

    def update_policy(self, NUM_EPOCHS=5):
    # PPO / TRPO step with rewards = log D
        policy_samples = self.sample_from_policy(min_timesteps=5000)
        policy_train_ds = BasicTorchDataset(
             np.array(policy_samples[0]),  
             np.array(policy_samples[1]).reshape(-1, 1)
             )
        policy_train_ds = DataLoader(policy_train_ds, shuffle=True, batch_size=32)
        policy_net.train()
        for epoch in range(NUM_EPOCHS):
            for batch_idx, (states, actions) in enumerate(policy_train_ds):
                with torch.no_grad():
                    outputs = discriminator((torch.concat([states, actions], dim=1)))

                rewards = -torch.log(torch.clamp(1. - outputs, 1e-6))

                dist = torch.distributions.Bernoulli(probs=policy_net(states))
                #sampled_acts = dist.sample()
                logp = dist.log_prob(actions)

                loss = -(logp.squeeze() * rewards).mean()
                policy_optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(policy_net.parameters(), 10.)  # gradient clipping
                policy_optimizer.step()
            print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}")
        policy_net.eval()


    def train(self, iterations=100):
        for i in range(iterations):
            learner_states_actions = self.sample_from_policy()
            rnd_idces = np.random.choice(np.arange(0, len(self.expert_ds.states)), size=len(learner_states_actions[0]))
            expert_states_actions = self.expert_ds.states[rnd_idces], self.expert_ds.actions[rnd_idces]
            print(f"Update iteration number {i}")
            print("updating discriminator...")
            self.update_discriminator(learner_states_actions, expert_states_actions)
            print("done!")
            print("updating policy...")
            self.update_policy()
            print("done!")

##### version 3
mix switch up 0 and 1 in expert and learner

In [412]:
from torch.utils.data import DataLoader, Dataset
class create_discriminator_learning_ds(Dataset):
    def __init__(self, learner_s_a, expert_s_a):
        learner_s_a = np.concatenate([np.array(learner_s_a[0]), np.array(learner_s_a[1]).reshape(-1, 1)], axis=1)
        expert_s_a =  np.concatenate([np.array(expert_s_a[0]), np.array(expert_s_a[1]).reshape(-1, 1)], axis=1)

        self.s_a_full = np.concatenate([learner_s_a, expert_s_a], axis=0)
        self.y_full = np.concatenate([np.repeat(1, learner_s_a.shape[0]), np.repeat(0, expert_s_a.shape[0])])

        self.s_a_full = torch.tensor(self.s_a_full, dtype=torch.float32)
        self.y_full = torch.tensor(self.y_full, dtype=torch.float32)
        

    def __len__(self):
        return len(self.s_a_full)

    def __getitem__(self, idx):
        return self.s_a_full[idx], self.y_full[idx]
    
class BasicTorchDataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.float32)
        

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

class ExpertDataset():
    """
    takes rollouts as transitions from imitation package
    saves them as states and actions so no need to always iterate thru transitions
    """
    def __init__(self, expert_transitions) -> None:
        self.states = np.array([x['obs'] for x in expert_transitions])
        self.actions = np.array([x['acts'] for x in expert_transitions])
        
expert_ds = ExpertDataset(expert_transitions=expert_transitions)

In [531]:
# version 3

class GAILTrainer:
    def __init__(self, env, policy, discriminator, expert_ds, disc_optimizer, disc_loss_fn, λ=1e-3):
        self.env = env
        self.policy = policy
        self.discriminator = discriminator
        self.expert_ds = expert_ds
        self.disc_optimizer = disc_optimizer
        self.disc_loss_fn = disc_loss_fn

    def update_discriminator(self, learner_states_actions, expert_states_actions, NUM_EPOCHS=5):
        discriminator_train_ds = create_discriminator_learning_ds(learner_states_actions, expert_states_actions)
        discriminator_train_ds = DataLoader(discriminator_train_ds, shuffle=True, batch_size=32)
        discriminator.train()
        for epoch in range(NUM_EPOCHS):
            for batch_idx, (inputs, targets) in enumerate(discriminator_train_ds):
                outputs = discriminator(inputs)
                loss = self.disc_loss_fn(outputs, targets)
                self.disc_optimizer.zero_grad()
                loss.backward()
                self.disc_optimizer.step()
            print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}")
        discriminator.eval()
    
    def discriminator_forward_pass(self, learner_states_actions, expert_states_actions):
        pass
        
    def sample_from_policy(self, min_timesteps=500, min_episodes=1):
        rng = np.random.default_rng()
        rollouts = rollout.rollout(
            wrapped_policy,
            DummyVecEnv([lambda: RolloutInfoWrapper(env)]),
            rollout.make_sample_until(min_timesteps=min_timesteps, min_episodes=min_episodes),
            rng=rng,
        )
        learner_transitions = rollout.flatten_trajectories(rollouts)
        learner_states = [x['obs'] for x in learner_transitions]
        learner_acts = [x['acts'] for x in learner_transitions]
        return learner_states, learner_acts

    def update_policy(self, NUM_EPOCHS=5):
    # PPO / TRPO step with rewards = log D
        policy_samples = self.sample_from_policy(min_timesteps=5000)
        policy_train_ds = BasicTorchDataset(
             np.array(policy_samples[0]),  
             np.array(policy_samples[1]).reshape(-1, 1)
             )
        policy_train_ds = DataLoader(policy_train_ds, shuffle=True, batch_size=32)
        policy_net.train()
        for epoch in range(NUM_EPOCHS):
            for batch_idx, (states, actions) in enumerate(policy_train_ds):
                with torch.no_grad():
                    outputs = discriminator((torch.concat([states, actions], dim=1)))

                rewards = -torch.log(1. - outputs)

                dist = torch.distributions.Bernoulli(probs=policy_net(states))
                #sampled_acts = dist.sample()
                logp = dist.log_prob(actions)

                loss = -(logp.squeeze() * rewards).mean()
                policy_optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(policy_net.parameters(), 10.)  # gradient clipping
                policy_optimizer.step()
            print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}")
        policy_net.eval()


    def train(self, iterations=100):
        for i in range(iterations):
            learner_states_actions = self.sample_from_policy()
            rnd_idces = np.random.choice(np.arange(0, len(self.expert_ds.states)), size=len(learner_states_actions[0]))
            expert_states_actions = self.expert_ds.states[rnd_idces], self.expert_ds.actions[rnd_idces]
            print(f"Update iteration number {i}")
            print("updating discriminator...")
            self.update_discriminator(learner_states_actions, expert_states_actions)
            print("done!")
            print("updating policy...")
            self.update_policy()
            print("done!")

In [532]:
env = gym.make("CartPole-v1")
policy_net = PolicyNetworkCartPole(input_size=4, hidden_size=32, output_size=1)
policy_optimizer = torch.optim.Adam(policy_net.parameters(), lr=0.001)
wrapped_policy = PolicyWrapper(
    observation_space=env.observation_space,
    action_space=env.action_space,
    net=policy_net
)
# policy_loss_fn = nn.BCELoss()

# Q: what input size does the discriminator take? a whole trajectory? fixed number of steps, i.e. 128?
# I would say a fixed number of steps, but actually just x = (s,a) pairs in batches, and y in {0,1}
discriminator = DiscriminatorNetwork(input_size=4+1, hidden_size=32, output_size=1)
discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0001)
discriminator_loss_fn = nn.BCELoss()

expert_transitions = rollout.flatten_trajectories(rollouts)

gail_trainer = GAILTrainer(
    env=DummyVecEnv([lambda: RolloutInfoWrapper(env)]), 
    policy=wrapped_policy, 
    discriminator=discriminator, 
    expert_ds=expert_ds,
    disc_optimizer=discriminator_optimizer, 
    disc_loss_fn=discriminator_loss_fn, 
    λ=1e-3)

#gail_trainer.train(5)



In [554]:
learner_states_actions = gail_trainer.sample_from_policy(min_timesteps=2500)
rnd_idces = np.random.choice(np.arange(0, len(gail_trainer.expert_ds.states)), size=len(learner_states_actions[0]))
expert_states_actions = gail_trainer.expert_ds.states[rnd_idces], gail_trainer.expert_ds.actions[rnd_idces]
print("updating discriminator...")
gail_trainer.update_discriminator(learner_states_actions, expert_states_actions, NUM_EPOCHS=4)

updating discriminator...
Epoch 0, Batch 156, Loss: 0.2353
Epoch 1, Batch 156, Loss: 0.0562
Epoch 2, Batch 156, Loss: 0.0880
Epoch 3, Batch 156, Loss: 0.1188


In [552]:
idx = np.random.choice(expert_ds.states.shape[0], 32)
exp_states = expert_ds.states[idx]
exp_actions = expert_ds.actions[idx]

learner_s_a = gail_trainer.sample_from_policy(2000)
idx = np.random.choice(len(learner_s_a[0]), 32)

learner_state = np.array(learner_s_a[0])[idx]
learner_act = np.array(learner_s_a[1])[idx]


discriminator_train_ds = create_discriminator_learning_ds((learner_state, learner_act), (exp_states, exp_actions))
discriminator_train_ds = DataLoader(discriminator_train_ds, shuffle=False, batch_size=64)


discriminator_train_ds = create_discriminator_learning_ds((learner_state, learner_act), (learner_state, learner_act))
discriminator_train_ds = DataLoader(discriminator_train_ds, shuffle=False, batch_size=64)
discriminator.eval()
with torch.no_grad():
    for i, (X, y) in enumerate(discriminator_train_ds):
        with torch.no_grad():
            out = discriminator(X)
        break

out  # expert get values close to 0, policy gets values close to 1, hence need to modify policy s.t. its not that easy to learn

tensor([0.4969, 0.9715, 0.9538, 0.6780, 0.6981, 0.9753, 0.9593, 0.8414, 0.9133,
        0.8374, 0.4945, 0.4809, 0.6451, 0.4226, 0.8381, 0.9510, 0.9757, 0.9500,
        0.4658, 0.9666, 0.9796, 0.9464, 0.8149, 0.9786, 0.9516, 0.9793, 0.9805,
        0.4471, 0.9737, 0.6328, 0.6374, 0.9399, 0.4969, 0.9715, 0.9538, 0.6780,
        0.6981, 0.9753, 0.9593, 0.8414, 0.9133, 0.8374, 0.4945, 0.4809, 0.6451,
        0.4226, 0.8381, 0.9510, 0.9757, 0.9500, 0.4658, 0.9666, 0.9796, 0.9464,
        0.8149, 0.9786, 0.9516, 0.9793, 0.9805, 0.4471, 0.9737, 0.6328, 0.6374,
        0.9399])

In [553]:
out.sum()

tensor(51.5545)

In [545]:
out.sum()

tensor(54.5504)

In [556]:

policy_net.train()

for epoch in range(500):
    policy_samples = gail_trainer.sample_from_policy(min_timesteps=500)
    policy_train_ds = BasicTorchDataset(
             np.array(policy_samples[0]),  
             np.array(policy_samples[1]).reshape(-1, 1)
             )
    policy_train_ds = DataLoader(policy_train_ds, shuffle=True, batch_size=64)
    for batch_idx, (states, actions) in enumerate(policy_train_ds):
        with torch.no_grad():
            outputs = discriminator((torch.concat([states, actions], dim=1)))  #outputs should all be close to 1
            
        rewards = -torch.log(outputs + 1e-8)  # rewards need to be large when D thinks its expert (even tho its not here!)
        dist = torch.distributions.Bernoulli(logits=policy_net(states))
        logp = dist.log_prob(actions.squeeze())


        loss = -1*(logp.squeeze() * rewards).mean()

        if batch_idx == 0:
            print(f"outputs {outputs}\n")
            print(f"loss {loss}\n")
            print(f"rewards {rewards}\n")

        policy_optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(policy_net.parameters(), 10.)  # gradient clipping
        policy_optimizer.step()

        if batch_idx % (64*10) == 0:
            print(f"outputs {outputs}\n")
            print(f"loss {loss}\n")
            print(f"rewards {rewards}\n")
            

    print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}")

policy_net.eval()

outputs tensor([0.9943, 0.9952, 0.8558, 0.9950, 0.9956, 0.9954, 0.9908, 0.9944, 0.8732,
        0.4948, 0.9612, 0.9831, 0.9903, 0.9932, 0.9641, 0.4232, 0.9843, 0.9929,
        0.9949, 0.9900, 0.8748, 0.9942, 0.9521, 0.9892, 0.9949, 0.9624, 0.9853,
        0.9896, 0.9956, 0.9954, 0.9635, 0.9954, 0.9901, 0.9950, 0.9848, 0.8856,
        0.9948, 0.8716, 0.9890, 0.9944, 0.9556, 0.9956, 0.9940, 0.9552, 0.9896,
        0.9655, 0.9933, 0.4802, 0.9942, 0.4580, 0.9928, 0.9816, 0.9956, 0.5172,
        0.9549, 0.9930, 0.9950, 0.9904, 0.9486, 0.9952, 0.8692, 0.9925, 0.9942,
        0.9954])

loss 0.026037102565169334

rewards tensor([0.0057, 0.0049, 0.1557, 0.0050, 0.0044, 0.0046, 0.0093, 0.0057, 0.1356,
        0.7036, 0.0396, 0.0171, 0.0098, 0.0068, 0.0365, 0.8599, 0.0159, 0.0071,
        0.0051, 0.0101, 0.1337, 0.0058, 0.0490, 0.0109, 0.0051, 0.0383, 0.0148,
        0.0104, 0.0044, 0.0046, 0.0372, 0.0046, 0.0099, 0.0050, 0.0154, 0.1215,
        0.0052, 0.1375, 0.0111, 0.0056, 0.0454, 0.0044, 0.0

PolicyNetworkCartPole(
  (layers): Sequential(
    (0): Linear(in_features=4, out_features=32, bias=True)
    (1): Tanh()
    (2): Linear(in_features=32, out_features=32, bias=True)
    (3): Tanh()
    (4): Linear(in_features=32, out_features=1, bias=True)
    (5): Sigmoid()
  )
)

In [462]:
# test policy 
wrapped_policy = PolicyWrapper(env.observation_space, env.action_space, policy_net)
reward, _ = evaluate_policy(wrapped_policy, env, 10)
print("After learning, evaluate the policy:")
print(reward)  

After learning, evaluate the policy:
9.2




In [308]:
reward, _ = evaluate_policy(expert, env, 10)
print(reward)

500.0


In [305]:
idx = np.random.choice(expert_ds.states.shape[0], 500)
exp_states = expert_ds.states[idx]
exp_actions = expert_ds.actions[idx]

learner_s_a = gail_trainer.sample_from_policy(500)


In [317]:
discriminator_train_ds = create_discriminator_learning_ds(learner_s_a, (exp_states, exp_actions))
discriminator_train_ds = DataLoader(discriminator_train_ds, shuffle=True, batch_size=64)
discriminator.eval()
with torch.no_grad():
    for i, (X, y) in enumerate(discriminator_train_ds):
        out = discriminator(X)

In [318]:
y

tensor([1., 0., 0., 0., 1., 0., 0., 0., 0., 1., 1., 1., 0., 0., 0., 0., 1., 0.,
        0., 0., 0., 1., 0., 0., 1., 1., 0., 0., 0., 1., 0., 0., 1., 1., 0., 0.,
        0., 0., 1., 1., 0., 1., 1., 1., 1., 0., 0.])

In [319]:
out.round()

tensor([1., 0., 0., 0., 1., 0., 0., 0., 0., 1., 1., 1., 0., 0., 0., 0., 1., 0.,
        0., 0., 0., 1., 0., 0., 1., 1., 0., 0., 0., 1., 0., 0., 1., 1., 0., 0.,
        0., 0., 1., 1., 0., 1., 1., 1., 1., 0., 0.])

In [337]:
expert.predict(learner_s_a[0][:32])[0]


array([1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0,
       0, 0, 0, 0, 0, 0, 0, 1, 1, 0], dtype=int64)

In [338]:
learner_s_a[1][:32]

[1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1]

In [332]:
expert.predict(X.detach()[:, 0:4])[0]


array([1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1,
       1, 0, 0], dtype=int64)

In [331]:
expert.predict(X.detach()[:, 0:4])

[i-j for i,j in zip(expert.predict(X.detach().numpy()[:, 0:4])[0], out.round().numpy())]

[-1.0,
 0.0,
 0.0,
 0.0,
 -1.0,
 0.0,
 0.0,
 0.0,
 0.0,
 -1.0,
 -1.0,
 0.0,
 0.0,
 0.0,
 0.0,
 1.0,
 -1.0,
 1.0,
 0.0,
 0.0,
 0.0,
 -1.0,
 0.0,
 0.0,
 -1.0,
 -1.0,
 0.0,
 0.0,
 0.0,
 -1.0,
 0.0,
 0.0,
 0.0,
 -1.0,
 0.0,
 0.0,
 1.0,
 0.0,
 0.0,
 -1.0,
 0.0,
 -1.0,
 -1.0,
 0.0,
 0.0,
 0.0,
 0.0]

In [292]:
gail_trainer.sample_from_policy(500)

([array([ 0.03837674,  0.00712525,  0.01037566, -0.03539691], dtype=float32),
  array([ 0.03851925, -0.18814394,  0.00966772,  0.2605415 ], dtype=float32),
  array([ 0.03475637, -0.38340256,  0.01487855,  0.556258  ], dtype=float32),
  array([ 0.02708832, -0.5787302 ,  0.02600371,  0.8535912 ], dtype=float32),
  array([ 0.01551372, -0.77419674,  0.04307554,  1.1543361 ], dtype=float32),
  array([ 2.9781397e-05, -9.6985316e-01,  6.6162258e-02,  1.4602088e+00],
        dtype=float32),
  array([-0.01936728, -1.165721  ,  0.09536643,  1.7728053 ], dtype=float32),
  array([-0.04268171, -1.3617804 ,  0.13082254,  2.0935533 ], dtype=float32),
  array([-0.06991731, -1.5579551 ,  0.17269361,  2.4236531 ], dtype=float32),
  array([-0.03879801,  0.01556669, -0.00526927,  0.04146906], dtype=float32),
  array([-0.03848668, -0.1794793 , -0.00443989,  0.33248484], dtype=float32),
  array([-0.04207627, -0.37453777,  0.0022098 ,  0.62376434], dtype=float32),
  array([-0.04956702, -0.5696905 ,  0.014685

#### Baseline: GAIL implementation from the imitation package

In [35]:
# baseline form imitation package
from imitation.algorithms.adversarial.gail import GAIL
from imitation.rewards.reward_nets import BasicRewardNet
from imitation.util.networks import RunningNorm


SEED=841
vec_env = DummyVecEnv([lambda: env])
learner = PPO(
    env=vec_env,
    policy=MlpPolicy,
    batch_size=64,
    ent_coef=0.0,
    learning_rate=0.0004,
    gamma=0.95,
    n_epochs=5,
    seed=SEED,
)

reward_net = BasicRewardNet(
    observation_space=vec_env.observation_space,
    action_space=vec_env.action_space,
    normalize_input_layer=RunningNorm,
)
gail_trainer = GAIL(
    demonstrations=rollouts,
    demo_batch_size=1024,
    gen_replay_buffer_capacity=512,
    n_disc_updates_per_round=8,
    venv=vec_env,
    gen_algo=learner,
    reward_net=reward_net,
    allow_variable_horizon=True
)

# evaluate the learner before training
vec_env.seed(SEED)
learner_rewards_before_training, _ = evaluate_policy(
    learner, env, 100, return_episode_rewards=True,
)

# train the learner and evaluate again
gail_trainer.train(20000)  # Train for 800_000 steps to match expert.
vec_env.seed(SEED)
learner_rewards_after_training, _ = evaluate_policy(
    learner, env, 100, return_episode_rewards=True,
)

print("mean reward after training:", np.mean(learner_rewards_after_training))
print("mean reward before training:", np.mean(learner_rewards_before_training))

Running with `allow_variable_horizon` set to True. Some algorithms are biased towards shorter or longer episodes, which may significantly confound results. Additionally, even unbiased algorithms can exploit the information leak from the termination condition, producing spuriously high performance. See https://imitation.readthedocs.io/en/latest/getting-started/variable-horizon.html for more information.


round:   0%|          | 0/9 [00:00<?, ?it/s]

--------------------------------------
| raw/                        |      |
|    gen/time/fps             | 648  |
|    gen/time/iterations      | 1    |
|    gen/time/time_elapsed    | 3    |
|    gen/time/total_timesteps | 2048 |
--------------------------------------
--------------------------------------------------
| raw/                                |          |
|    disc/disc_acc                    | 0.5      |
|    disc/disc_acc_expert             | 0        |
|    disc/disc_acc_gen                | 1        |
|    disc/disc_entropy                | 0.683    |
|    disc/disc_loss                   | 0.713    |
|    disc/disc_proportion_expert_pred | 0        |
|    disc/disc_proportion_expert_true | 0.5      |
|    disc/global_step                 | 1        |
|    disc/n_expert                    | 1.02e+03 |
|    disc/n_generated                 | 1.02e+03 |
--------------------------------------------------
--------------------------------------------------
| raw/       

round:  11%|█         | 1/9 [00:12<01:39, 12.43s/it]

-----------------------------------------------------
| raw/                               |              |
|    gen/rollout/ep_rew_wrapped_mean | 11.8         |
|    gen/time/fps                    | 634          |
|    gen/time/iterations             | 1            |
|    gen/time/time_elapsed           | 3            |
|    gen/time/total_timesteps        | 4096         |
|    gen/train/approx_kl             | 0.0070472434 |
|    gen/train/clip_fraction         | 0.0674       |
|    gen/train/clip_range            | 0.2          |
|    gen/train/entropy_loss          | -0.688       |
|    gen/train/explained_variance    | 0.0358       |
|    gen/train/learning_rate         | 0.0004       |
|    gen/train/loss                  | 0.777        |
|    gen/train/n_updates             | 5            |
|    gen/train/policy_gradient_loss  | -0.0105      |
|    gen/train/value_loss            | 8.75         |
-----------------------------------------------------
----------------------------

round:  22%|██▏       | 2/9 [00:26<01:34, 13.50s/it]

---------------------------------------------------
| raw/                               |            |
|    gen/rollout/ep_rew_wrapped_mean | 15.5       |
|    gen/time/fps                    | 625        |
|    gen/time/iterations             | 1          |
|    gen/time/time_elapsed           | 3          |
|    gen/time/total_timesteps        | 6144       |
|    gen/train/approx_kl             | 0.00750826 |
|    gen/train/clip_fraction         | 0.0521     |
|    gen/train/clip_range            | 0.2        |
|    gen/train/entropy_loss          | -0.668     |
|    gen/train/explained_variance    | 0.0743     |
|    gen/train/learning_rate         | 0.0004     |
|    gen/train/loss                  | 1.64       |
|    gen/train/n_updates             | 10         |
|    gen/train/policy_gradient_loss  | -0.016     |
|    gen/train/value_loss            | 3.47       |
---------------------------------------------------
--------------------------------------------------
| raw/       

round:  33%|███▎      | 3/9 [00:40<01:21, 13.58s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_rew_wrapped_mean | 18.9        |
|    gen/time/fps                    | 652         |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 3           |
|    gen/time/total_timesteps        | 8192        |
|    gen/train/approx_kl             | 0.014666978 |
|    gen/train/clip_fraction         | 0.143       |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.627      |
|    gen/train/explained_variance    | 0.237       |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | 1.24        |
|    gen/train/n_updates             | 15          |
|    gen/train/policy_gradient_loss  | -0.0251     |
|    gen/train/value_loss            | 3.64        |
----------------------------------------------------
----------------------------------------------

round:  44%|████▍     | 4/9 [00:55<01:10, 14.07s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_rew_wrapped_mean | 22.8        |
|    gen/time/fps                    | 663         |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 3           |
|    gen/time/total_timesteps        | 10240       |
|    gen/train/approx_kl             | 0.011562597 |
|    gen/train/clip_fraction         | 0.106       |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.602      |
|    gen/train/explained_variance    | 0.329       |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | 1.05        |
|    gen/train/n_updates             | 20          |
|    gen/train/policy_gradient_loss  | -0.0192     |
|    gen/train/value_loss            | 2.53        |
----------------------------------------------------
----------------------------------------------

round:  56%|█████▌    | 5/9 [01:07<00:54, 13.51s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_rew_wrapped_mean | 28.2        |
|    gen/time/fps                    | 648         |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 3           |
|    gen/time/total_timesteps        | 12288       |
|    gen/train/approx_kl             | 0.013904316 |
|    gen/train/clip_fraction         | 0.134       |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.567      |
|    gen/train/explained_variance    | 0.614       |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | 1.1         |
|    gen/train/n_updates             | 25          |
|    gen/train/policy_gradient_loss  | -0.0192     |
|    gen/train/value_loss            | 1.6         |
----------------------------------------------------
----------------------------------------------

round:  67%|██████▋   | 6/9 [01:18<00:37, 12.53s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_rew_wrapped_mean | 36          |
|    gen/time/fps                    | 672         |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 3           |
|    gen/time/total_timesteps        | 14336       |
|    gen/train/approx_kl             | 0.008169535 |
|    gen/train/clip_fraction         | 0.0899      |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.561      |
|    gen/train/explained_variance    | 0.758       |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | 0.457       |
|    gen/train/n_updates             | 30          |
|    gen/train/policy_gradient_loss  | -0.0102     |
|    gen/train/value_loss            | 1.06        |
----------------------------------------------------
----------------------------------------------

round:  78%|███████▊  | 7/9 [01:31<00:25, 12.58s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_rew_wrapped_mean | 42.5        |
|    gen/time/fps                    | 664         |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 3           |
|    gen/time/total_timesteps        | 16384       |
|    gen/train/approx_kl             | 0.006491432 |
|    gen/train/clip_fraction         | 0.0783      |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.535      |
|    gen/train/explained_variance    | 0.953       |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | 0.0628      |
|    gen/train/n_updates             | 35          |
|    gen/train/policy_gradient_loss  | -0.00708    |
|    gen/train/value_loss            | 0.363       |
----------------------------------------------------
----------------------------------------------

round:  89%|████████▉ | 8/9 [01:43<00:12, 12.49s/it]

----------------------------------------------------
| raw/                               |             |
|    gen/rollout/ep_rew_wrapped_mean | 48.4        |
|    gen/time/fps                    | 663         |
|    gen/time/iterations             | 1           |
|    gen/time/time_elapsed           | 3           |
|    gen/time/total_timesteps        | 18432       |
|    gen/train/approx_kl             | 0.009983783 |
|    gen/train/clip_fraction         | 0.0707      |
|    gen/train/clip_range            | 0.2         |
|    gen/train/entropy_loss          | -0.52       |
|    gen/train/explained_variance    | 0.989       |
|    gen/train/learning_rate         | 0.0004      |
|    gen/train/loss                  | 0.0226      |
|    gen/train/n_updates             | 40          |
|    gen/train/policy_gradient_loss  | -0.00563    |
|    gen/train/value_loss            | 0.0693      |
----------------------------------------------------
----------------------------------------------

round: 100%|██████████| 9/9 [01:53<00:00, 12.67s/it]


AttributeError: 'CartPoleEnv' object has no attribute 'seed'