# In [3]:
import sys
import torch  
import gym
import numpy as np  
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
import matplotlib.pyplot as plt
import obstacle_env
import time
import pickle

# Constants
# Discount factor applied to future rewards when computing episode returns.
GAMMA = 0.99
# True: start from freshly initialised weights and train them.
# False: load weights from model_path, skip optimizer steps, and tag the
# result files written by save_results as 'pretrained'.
init = True
# File the policy weights are loaded from (init is False) or saved to (save is True).
model_path = 'reinforce_weights.pt'
# When True, the policy's state dict is written to model_path after each run.
save = False

class PolicyNetwork(nn.Module):
    """Two-layer MLP that maps a state vector to action probabilities.

    Args:
        num_inputs: Size of the state (observation) vector.
        num_actions: Number of discrete actions, i.e. the output size.
        hidden_size: Width of the single hidden layer.
    """

    def __init__(self, num_inputs, num_actions, hidden_size=256):
        super(PolicyNetwork, self).__init__()
        self.num_actions = num_actions
        self.linear1 = nn.Linear(num_inputs, hidden_size)
        self.linear2 = nn.Linear(hidden_size, num_actions)

    def forward(self, state):
        """Return the action distribution for ``state``.

        Args:
            state: Float tensor whose last dimension has size ``num_inputs``;
                either a batch ``(batch, num_inputs)`` or a single
                ``(num_inputs,)`` vector.

        Returns:
            Tensor shaped like ``state`` with the last dimension replaced by
            ``num_actions``; entries along that dimension sum to 1.
        """
        hidden = F.relu(self.linear1(state))
        # Normalise over the last (action) axis. For batched 2-D input this is
        # the same as dim=1, but it also works for an unbatched 1-D state,
        # where dim=1 would raise IndexError.
        return F.softmax(self.linear2(hidden), dim=-1)
    
class Agent:
    """REINFORCE (Monte-Carlo policy gradient) agent for a discrete-action env.

    Args:
        env: Environment exposing ``reset()``, ``step(action)`` (returning a
            4-tuple ``(state, reward, done, info)``), ``action_space.n`` and
            ``observation_space.shape[0]``.
        init: ``True`` to train from freshly initialised weights. ``False`` to
            load weights from ``path``; optimizer steps are then skipped, so
            the loaded policy stays fixed.
        path: File holding a saved ``state_dict``; read only when ``init`` is
            false.
        learning_rate: Step size for the Adam optimizer.
    """

    def __init__(self, env, init, path, learning_rate=0.01):
        self.env = env
        # Keep the mode on the instance so update_policy honours the value
        # passed here instead of whatever a module-level global happens to be.
        self.init = init
        self.num_actions = self.env.action_space.n
        self.policy_network = PolicyNetwork(self.env.observation_space.shape[0], self.env.action_space.n)
        if not init:
            self.policy_network.load_state_dict(torch.load(path))
        self.optimizer = optim.Adam(self.policy_network.parameters(), lr=learning_rate)

    def get_action(self, state):
        """Sample an action from the current policy.

        Args:
            state: 1-D numpy array holding the observation.

        Returns:
            ``(action, log_prob)``: the sampled action index and the log of
            its probability as a tensor still attached to the autograd graph.
        """
        state = torch.from_numpy(state).float().unsqueeze(0)
        # Call the module (not .forward) so registered hooks run.
        probs = self.policy_network(state).squeeze(0)
        # np.random.choice checks in float64 that p sums to 1; a float32
        # softmax can miss that tolerance and raise ValueError, so renormalise
        # a float64 copy used for sampling only.
        sampling_probs = probs.detach().numpy().astype(np.float64)
        sampling_probs /= sampling_probs.sum()
        action = np.random.choice(self.num_actions, p=sampling_probs)
        log_prob = torch.log(probs[action])
        return action, log_prob

    @staticmethod
    def _discounted_returns(rewards, gamma):
        """Return ``[G_0, ..., G_{T-1}]`` where ``G_t = sum_k gamma**k * r_{t+k}``."""
        returns = []
        running = 0.0
        # Accumulate from the last step backwards: G_t = r_t + gamma * G_{t+1}.
        # One pass instead of re-summing the tail for every t.
        for reward in reversed(rewards):
            running = reward + gamma * running
            returns.append(running)
        returns.reverse()
        return returns

    def update_policy(self, rewards, log_probs):
        """Run one REINFORCE gradient computation (and step) for an episode.

        Args:
            rewards: Per-step rewards of the episode, in order.
            log_probs: Log-probability tensors of the actions taken, aligned
                with ``rewards``.

        The optimizer step is applied only when the agent was created with
        ``init`` true; gradients are computed either way. An empty episode is
        a no-op.
        """
        if not rewards:
            return

        discounted_rewards = torch.FloatTensor(self._discounted_returns(rewards, GAMMA))
        # Normalise returns to reduce gradient variance. The sample std of a
        # single value is NaN, which would corrupt every weight, so a one-step
        # episode keeps its raw return.
        if discounted_rewards.numel() > 1:
            discounted_rewards = (discounted_rewards - discounted_rewards.mean()) / (discounted_rewards.std() + 1e-9)

        policy_gradient = torch.stack(
            [-log_prob * Gt for log_prob, Gt in zip(log_probs, discounted_rewards)]
        ).sum()

        self.optimizer.zero_grad()
        policy_gradient.backward()
        if self.init:
            self.optimizer.step()

    def train(self, max_episode=3000, max_step=200):
        """Run ``max_episode`` episodes, updating the policy after each one.

        Args:
            max_episode: Number of episodes to run.
            max_step: Maximum number of environment steps per episode.

        Returns:
            ``(time_start, timestamps, episode_rewards)``: the wall-clock time
            at the start of training, the wall-clock time at the end of each
            episode, and the undiscounted reward sum of each episode.
        """
        timestamps = []
        time_start = time.time()
        episode_rewards = []
        for episode in range(max_episode):
            # Use the agent's own environment rather than a global named env.
            state = self.env.reset()
            log_probs = []
            rewards = []
            episode_reward = 0

            for _ in range(max_step):
                action, log_prob = self.get_action(state)
                state, reward, done, _info = self.env.step(action)

                log_probs.append(log_prob)
                rewards.append(reward)
                episode_reward += reward

                if done:
                    break

            # Update after every episode, including ones cut off by max_step,
            # so truncated episodes still contribute to learning.
            if rewards:
                self.update_policy(rewards, log_probs)
                if episode % 10 == 0:
                    print("episode " + str(episode) + ": " + str(episode_reward))

            episode_rewards.append(episode_reward)
            timestamps.append(time.time())
        return time_start, timestamps, episode_rewards
        


def save_results(rewards, timestamps, seed, time_start, env_name, init, n_run, name='reinforce'):
    """Pickle the results of one training run into the working directory.

    Args:
        rewards: Per-episode rewards of the run.
        timestamps: Per-episode timestamps of the run.
        seed: Seed value; stored and used in the file name.
        time_start: Start time of the run.
        env_name: Name of the environment; stored only.
        init: When false, the file name gets a ``_pretrained`` suffix.
        n_run: Index of the run; stored only.
        name: Algorithm label; stored and used in the file name.

    NOTE(review): ``n_run`` is not part of the file name, so successive runs
    with the same ``name``/``seed``/``init`` overwrite the same file.
    """
    if init:
        filename = 'run_time_%s_%s.pickle' % (name, seed)
    else:
        filename = 'run_time_%s_%s_%s.pickle' % (name, seed, 'pretrained')

    # Key order is kept stable so the pickled payload layout does not change.
    payload = dict(
        name=name,
        rewards=rewards,
        time_start=time_start,
        timestamps=timestamps,
        seed=seed,
        n_run=n_run,
        env_name=env_name,
    )

    with open(filename, 'wb') as out_file:
        pickle.dump(payload, out_file, protocol=pickle.HIGHEST_PROTOCOL)
        
if __name__ == '__main__':
    # Script entry point: train `num_runs` agents on one environment and
    # pickle each run's rewards/timings via save_results.
    env_name = 'CartPole-v0'
    #env_name = 'LunarLander-v2'
    #env_name = 'MountainCar-v0'
    env = gym.make(env_name)
    SEED = 1234
    # Seed the environment, numpy (used for action sampling) and torch.
    env.seed(SEED)
    np.random.seed(SEED)
    #env.render()
    torch.manual_seed(SEED)
    num_runs = 2
    try:
        for run in range(num_runs):
            agent = Agent(env, init, model_path)
            time_start, timestamps, rewards = agent.train(500, 300)
            # NOTE(review): save_results does not put the run index in the
            # file name, so each run overwrites the previous run's file.
            save_results(rewards, timestamps, SEED, time_start, env_name, init, run)
            if save:
                torch.save(agent.policy_network.state_dict(), model_path)
    finally:
        # Release environment resources even if a run fails.
        env.close()

# [cell output] AttributeError: module 'gym.envs.box2d' has no attribute 'LunarLander'

# In [2]:
# def main():
#     env = gym.make('obstacle-v0')
#     policy_net = PolicyNetwork(env.observation_space.shape[0], env.action_space.n, 128)
    
#     max_episode_num = 5000
#     max_steps = 10000
#     numsteps = []
#     avg_numsteps = []
#     all_rewards = []

#     for episode in range(max_episode_num):
#         state = env.reset()
#         log_probs = []
#         rewards = []

#         for steps in range(max_steps):
#             env.render()
#             action, log_prob = policy_net.get_action(state)
#             new_state, reward, done, _ = env.step(action)
#             log_probs.append(log_prob)
#             rewards.append(reward)

#             if done:
#                 update_policy(policy_net, rewards, log_probs)
#                 numsteps.append(steps)
#                 avg_numsteps.append(np.mean(numsteps[-10:]))
#                 all_rewards.append(np.sum(rewards))
#                 if episode % 1 == 0:
#                     sys.stdout.write("episode: {}, total reward: {}, average_reward: {}, length: {}\n".format(episode, np.round(np.sum(rewards), decimals = 3),  np.round(np.mean(all_rewards[-10:]), decimals = 3), steps))

#                 break
            
#             state = new_state
        
#     plt.plot(numsteps)
#     plt.plot(avg_numsteps)
#     plt.xlabel('Episode')
#     plt.show()