In [9]:
import numpy as np
import torch
from torch.distributions import Categorical
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import gym

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [34]:
class Policy(nn.Module):
    def __init__(self, state_size, action_size):
        super(Policy, self).__init__()
        
        self.layer_1 = nn.Linear(state_size, 32)
        self.layer_2 = nn.Linear(32, 16)
        self.critic_layer = nn.Linear(16, 1)
        self.actor_layer = nn.Linear(16, action_size)
    
    def forward(self, state):
        if state.ndim == 1:
            input =  torch.from_numpy(state).float().unsqueeze(0).to("cpu")
        else:
            input =  torch.from_numpy(state).float().to("cpu")
        
        x = F.relu(self.layer_1(input))
        x = F.relu(self.layer_2(x))
        
        return self.actor_head(x), self.critic_layer(x)
    
    def actor_head(self, input_state):
        x = self.actor_layer(input_state)
        #print("input state = ", input_state)
        #print(" x = ", x)
        prob = Categorical(logits = x)
        #print("prob = ", prob)
        action = prob.sample()
        log_prob = prob.log_prob(action).unsqueeze(-1)
        entropy = prob.entropy().unsqueeze(-1)
        action = action.item() if len(action) == 1 else action.data.numpy()
        return {"action": action, "log_prob":log_prob, "entropy":entropy}

In [35]:
env = gym.make('CartPole-v0')
old_policy = Policy(4, 2)
new_policy = Policy(4, 2)
optimizer = optim.Adam(new_policy.parameters(), lr=2e-4)

In [36]:
def collect_trajectory(policy, env, max_steps):
    rewards = []
    values = []
    states = []
    dones = []
    actions = []
    entropys = []
    log_probs = []

    i = 0
    done = True
    state = env.reset()
    reward_sum = 0
    
    while True:
        actor, value = policy(state)
        
        if i > 0 and i >= max_steps:
            yield {"states": states, "actions": actions, "rewards": rewards, "log_probs": log_probs,
                  "last_value": value*(1-done), "dones":dones, "values":values}

        actions.append(actor["action"])
        states.append(state)
        values.append(value)
        action = actor["action"]
        
        state, reward, done, _ = env.step(action)

        log_probs.append(actor["log_prob"])
        
        dones.append(done)
        rewards.append(reward)
        reward_sum += reward
        i+=1
        if done:
            reward_sum = 0
            state = env.reset()

def calculate_gae(obs, gamma, tau):
    values = obs["values"]
    values.append(obs["last_value"])
    rewards = obs["rewards"]
    dones = obs["dones"]
    gae = torch.zeros(len(rewards))
    refs = []
    
    last_gae = 0
    for i in reversed(range(len(rewards))):
        td_error = rewards[i] + gamma*values[i+1]*(1-dones[i]) - values[i]
        gae[i] = gamma*tau*last_gae + td_error
        last_gae = gae[i]
        refs.append(last_gae - values[i])
    obs["advantages"] = torch.FloatTensor(gae).to(device)
    obs["ref"] = torch.FloatTensor(refs).to(device)

In [37]:
data = collect_trajectory(old_policy, env, 10)
observations = data.__next__()
calculate_gae(observations, 1, 1)

In [38]:
def surrogate_method(new_policy, obs, epsilon = 0.1, beta = 0.01):
    
    actions = torch.tensor(obs["actions"], dtype=torch.int8, device=device)
    states = torch.FloatTensor(obs["states"]).to(device)
    old_log_probs = torch.FloatTensor(obs["log_probs"]).to(device)
    
    actor, value = new_policy(states.cpu().numpy())
    
    new_log_probs = torch.FloatTensor(actor["log_prob"].squeeze(0)).to(device)
    action = actor["action"]
    
    ratio = torch.exp(new_log_probs - old_log_probs)
    
    normalized_adv = (obs["advantages"] - obs["advantages"].mean())/ (obs["advantages"].std() + 1e-5)
    
    cliped_ratio = torch.clamp(ratio, 1-epsilon, 1+epsilon)*normalized_adv
    no_clipped_ratio = ratio*normalized_adv
    action_loss = -torch.min(no_clipped_ratio, cliped_ratio).mean()
    policy_loss = (obs["ref"] - value.to(device)).pow(2).mean()
    
    return torch.mean(ratio + policy_loss - beta*actor["entropy"].to(device))

In [39]:
Loss = surrogate_method(new_policy, observations)
print(Loss)

tensor(41.9892, grad_fn=<MeanBackward0>)


In [40]:
def prepareEnviroment():
    return 1