In [1]:
import argparse
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
#torch.set_default_tensor_type('torch.cuda.FloatTensor')

In [2]:
parser = argparse.ArgumentParser(description='PyTorch actor-critic example')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G', help='discount factor (default: 0.99)')
parser.add_argument('--seed', type=int, default=543, metavar='N', help='random seed (default: 543)')
# parser.add_argument('--render', action='store_true', help='render the environment')
parser.add_argument('--render', type=bool,default=False, help='render the environment')
parser.add_argument('--trace', type=bool,default=False, help='render the environment')
parser.add_argument('--log-interval', type=int, default=100, metavar='N', help='interval between training status logs (default: 10)')
parser.add_argument('-f','--file',help='Path for input file. (Dummy arg to enable execution in notebook.)' )
args = parser.parse_args() 
torch.manual_seed(args.seed)

<torch._C.Generator at 0x1b1c8e7a710>

In [3]:
class World():
    
    def __init__(self):
        self.env = gym.make('CartPole-v0')
        self.env.seed(args.seed)
        self.reward = 0.0
        self.done = False
        self.reset()
        
    def reset(self):
        self.state = torch.tensor(self.env.reset(), requires_grad=False, dtype=torch.float)
        
    def action_count(self):
        return self.env.action_space.n
    
    def world_dimensions(self):
        return self.env.observation_space.shape[0]
    
    def step(self,action):
        self.state, self.reward, self.done, _ = self.env.step(action.item())
        self.state = torch.tensor(self.state, requires_grad=False, dtype=torch.float)
        if args.render: self.env.render()
        
class Actor(nn.Module):
    
    def __init__(self,world: World):
        super(Actor, self).__init__()
        self.l1 = nn.Linear(world.world_dimensions(),128 - world.world_dimensions())
        self.head = nn.Linear(128, world.action_count())

    def forward(self, state):
        a1 = F.softplus(self.l1(state))
        #Adding bypass connections
        head = self.head(torch.cat([a1,state]))
        action_scores = F.softmax(head, dim=-1)
        if(args.trace): print("action scores:",action_scores)
        return action_scores
    
    def choose_action(self,scores):
        self.categories = Categorical(scores)
        self.action = self.categories.sample()
        if(args.trace): print("Action:",self.action.item())
        return self.action

#The "advantage" is how much better the state is after the action than we expected it would be
    def advantage_loss(self,critic,world):
        #Do not include gradient of prev_value here, just the data.
        advantage = critic.hindsight_value(world) - critic.prev_value.data
        if(args.trace): print("advantage loss:",-self.categories.log_prob(self.action)*advantage)
        return -self.categories.log_prob(self.action)*advantage
    
class Critic(nn.Module):
    
    def __init__(self,world: World):
        super(Critic, self).__init__()
        self.l1 = nn.Linear(world.world_dimensions(),128 - world.world_dimensions())
        self.head = nn.Linear(128, 1)
        self.one = torch.ones([1], requires_grad=False, dtype=torch.float)
        self.value = 0.
        
    def forward(self, state):
        self.prev_value = self.value
        a1 = F.softplus(self.l1(state))
        #Adding bypass connections
        self.value = self.head(torch.cat([a1,state]))
        return self.value
    
#What the previous value should have been knowing what we know after the last state transition
    def hindsight_value(self,world):
        #Do not include gradient of the critic value here, just the data.
        return world.reward * self.one if world.done else world.reward + (args.gamma * self.value.data)
         
#Temporal Difference Loss is for the previous state!
    def td_loss(self,world):
        loss = F.mse_loss(self.prev_value,self.hindsight_value(world))
        if(args.trace): print("Critic value and loss:",self.prev_value,loss)
        return loss

In [4]:
def train(episodes=1000):

    world.reset()
    critic(world.state)
    mave_reward = 10
    for i_episode in range(1,episodes+1):
        ep_reward = 0
        for t in range(1000):
            action_scores = actor(world.state)
            action = actor.choose_action(action_scores)
            world.step(action)
            
            ep_reward += world.reward
            critic(world.state)

            critic_optimizer.zero_grad()
            critic_loss = critic.td_loss(world)
            critic_loss.backward()                                      
            critic_optimizer.step()

            actor_optimizer.zero_grad() 
            actor_loss = actor.advantage_loss(critic,world)
            actor_loss.backward()                                      
            actor_optimizer.step()
            if(world.done): break

        mave_reward = 0.05 * ep_reward + (1 - 0.05) * mave_reward
        if i_episode % args.log_interval == 0:
            print('Episode {}\tLast reward: {:.2f}\tMoving average reward: {:.2f}'.format(
                  i_episode, ep_reward, mave_reward))
        I = 1.           
        world.reset()
        critic(world.state)
        if mave_reward > world.env.spec.reward_threshold:
            print("Solved! Running reward is now {} and "
                  "the last episode runs to {} time steps!".format(mave_reward, t+1))
            break

In [5]:
def reset_trainer():
    args.trace = False
    args.render = False
    global world
    global actor
    global critic
    global actor_optimizer
    global critic_optimizer
    world = World()
    actor = Actor(world)
    critic = Critic(world)
    actor_optimizer = optim.Adam(actor.parameters(), lr=2.7e-4,weight_decay=0.0001)
    critic_optimizer = optim.Adam(critic.parameters(), lr=4.6e-3,weight_decay=0.0001)

In [6]:
reset_trainer()

In [7]:
args.log_interval = 100
args.trace = False
train(1000) 

Episode 100	Last reward: 13.00	Moving average reward: 13.19
Episode 200	Last reward: 20.00	Moving average reward: 18.58
Episode 300	Last reward: 15.00	Moving average reward: 20.24
Episode 400	Last reward: 92.00	Moving average reward: 30.81
Episode 500	Last reward: 18.00	Moving average reward: 38.63
Episode 600	Last reward: 58.00	Moving average reward: 53.36
Episode 700	Last reward: 48.00	Moving average reward: 62.08
Episode 800	Last reward: 12.00	Moving average reward: 46.84
Episode 900	Last reward: 45.00	Moving average reward: 77.46
Episode 1000	Last reward: 200.00	Moving average reward: 120.95


In [8]:
args.log_interval = 10
args.render = True
args.trace = False
train(100) 
print("critic prev_value, value loss, advantage loss")
print(critic.prev_value,critic.td_loss(world),actor.advantage_loss(critic,world))
print(world.done)

Episode 10	Last reward: 140.00	Moving average reward: 70.68
Episode 20	Last reward: 200.00	Moving average reward: 79.03
Episode 30	Last reward: 135.00	Moving average reward: 124.32
Episode 40	Last reward: 200.00	Moving average reward: 147.24
Episode 50	Last reward: 200.00	Moving average reward: 124.46
Episode 60	Last reward: 197.00	Moving average reward: 154.62
Episode 70	Last reward: 200.00	Moving average reward: 161.45
Episode 80	Last reward: 15.00	Moving average reward: 112.55
Episode 90	Last reward: 200.00	Moving average reward: 105.92
Episode 100	Last reward: 196.00	Moving average reward: 133.45
critic prev_value, value loss, advantage loss
tensor([42.2285], grad_fn=<AddBackward0>) tensor(1699.7902, grad_fn=<MseLossBackward>) tensor([-14.9868], grad_fn=<MulBackward0>)
True
