### This file is for testing radical changes to the codebase. Upside Down Reinforcement Learning.py is the working version.

In [1]:
import gym
env = gym.make('CartPole-v1')

In [40]:
import numpy as np
def random_policy(obs, command):
    return np.random.randint(env.action_space.n)

In [51]:
import time
from copy import deepcopy
#Visualise agent function
def visualise_agent(policy, command, n=5):
    try:
        for trial_i in range(n):
            current_command = deepcopy(command)
            observation = env.reset()
            done=False
            t=0
            episode_return=0
            while not done:
                env.render()
                action = policy(torch.tensor([observation]).double(), torch.tensor([command]).double())
                observation, reward, done, info = env.step(action)
                episode_return+=reward
                #time.sleep(0.1)
                current_command[0]-= reward
                current_command[1] = max(1, current_command[1]-1)
                
                t+=1
            env.render()
            time.sleep(1.5)
            print("Episode {} finished after {} timesteps. Return = {}".format(trial_i, t, episode_return))
        env.close()
    except KeyboardInterrupt:
        env.close()

In [42]:
#visualise_agent(random_policy, command=[500, 500], 1)

In [123]:
from copy import deepcopy
def collect_experience(policy, replay_buffer, replay_size, last_few, n_episodes=100, log_to_tensorboard=True):
    global i_episode
    init_replay_buffer = deepcopy(replay_buffer)
    try:
        for _ in range(n_episodes):
            command = sample_command(init_replay_buffer, last_few)
            writer.add_scalar('Command desired reward', command[0], i_episode)    # write loss to a graph
            writer.add_scalar('Command horizon', command[1], i_episode)    # write loss to a graph
            observation = env.reset()
            episode_mem = {'observation':[],
#                           'command':[],
                           'action':[],
                           'reward':[],}
            done=False
            while not done:
                action = policy(torch.tensor([observation]).double(), torch.tensor([command]).double())
                new_observation, reward, done, info = env.step(action)
                
                episode_mem['observation'].append(observation)
                #episode_mem['command'].append(command)
                episode_mem['action'].append(action)
                episode_mem['reward'].append(reward)
                
                observation=new_observation
                #command[0]-= reward
                command[0] = max(1, command[0]-reward)
                command[1] = max(1, command[1]-1)
                print(command)
            episode_mem['return']=sum(episode_mem['reward'])
            episode_mem['episode_len']=len(episode_mem['observation'])
            replay_buffer.append(episode_mem)
            i_episode+=1
            if log_to_tensorboard: writer.add_scalar('Return', sum(episode_mem['reward']), i_episode)    # write loss to a graph
            print("Episode {} finished after {} timesteps. Return = {}".format(i_episode, len(episode_mem['observation']), sum(episode_mem['reward'])))
        env.close()
    except KeyboardInterrupt:
        env.close()
    replay_buffer = sorted(replay_buffer, key=lambda x:x['return'])[-replay_size:]
    return replay_buffer

def sample_command(replay_buffer, last_few):
    if len(replay_buffer)==0:
        return [1, 1]
    else:
        command_samples = replay_buffer[-last_few:]
        lengths = [mem['episode_len'] for mem in command_samples]
        returns = [mem['return'] for mem in command_samples]
        mean_return, std_return = np.mean(returns), np.std(returns)
        command_horizon = np.mean(lengths)
        desired_reward = np.random.uniform(mean_return, mean_return+std_return)
        return [desired_reward, command_horizon]

In [124]:
def train_net(policy_net, replay_buffer, n_updates=100, batch_size=64, log_to_tensorboard=True):
    global i_updates
    all_costs = []
    for i in range(n_updates):
        batch_observations = np.zeros((batch_size, np.prod(env.observation_space.shape)))
        batch_commands = np.zeros((batch_size, 2))
        batch_label = np.zeros((batch_size))
        for b in range(batch_size):
            sample_episode = np.random.randint(0, len(replay_buffer))
            sample_t1 = np.random.randint(0, len(replay_buffer[sample_episode]['observation']))
            sample_t2 = len(replay_buffer[sample_episode]['observation'])
            sample_horizon = sample_t2-sample_t1
            sample_mem = replay_buffer[sample_episode]['observation'][sample_t1]
            sample_desired_reward = sum(replay_buffer[sample_episode]['reward'][sample_t1:sample_t2])
            network_input = np.append(sample_mem, [sample_desired_reward, sample_horizon])
            label = replay_buffer[sample_episode]['action'][sample_t1]
            batch_observations[b] = sample_mem
            batch_commands[b] = [sample_desired_reward, sample_horizon]
            batch_label[b] = label
        batch_observations = torch.tensor(batch_observations).double()
        batch_commands = torch.tensor(batch_commands).double()
        batch_label = torch.tensor(batch_label).long()
        pred = policy_net(batch_observations, batch_commands)
        cost = F.cross_entropy(pred, batch_label)
        all_costs.append(cost.item())
        cost.backward()
        policy_net.optimizer.step()
        policy_net.optimizer.zero_grad()
    return np.mean(all_costs)

In [125]:
def create_greedy_policy(policy_network):
    def policy(obs, command):
        action_prob = policy_network(obs, command)
        action = np.argmax(action_prob.detach().numpy())
        return action
    return policy

def create_stochastic_policy(policy_network):
    def policy(obs, command):
        action_prob = policy_network(obs, command)
        action_sample = np.random.multinomial(1, action_prob.detach().numpy())
        action = np.argmax(action_sample)
        return action
    return policy

In [126]:
import torch
import torch.nn.functional as F

class FCNN_AGENT(torch.nn.Module):
    def __init__(self):
        super().__init__()
        hidden_size=64
        self.observation_embedding = torch.nn.Sequential(
            torch.nn.Linear(np.prod(env.observation_space.shape), hidden_size),
            torch.nn.ReLU()
        )
        self.command_embedding = torch.nn.Sequential(
            torch.nn.Linear(2, hidden_size),
            torch.nn.Sigmoid()
        )
        self.to_output = torch.nn.Sequential(
            torch.nn.Linear(hidden_size, hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_size, env.action_space.n)
        )
    
    def forward(self, observation, command):
        obs_emebdding = self.observation_embedding(observation)
        #print(obs_emebdding.shape)
        cmd_embedding = self.command_embedding(command)
        print(cmd_embedding.shape)
        embedding = torch.mul(obs_emebdding, cmd_embedding)
        #print(embedding.shape)
        action_probs = F.softmax(self.to_output(embedding), dim=-1)
        #print(action_probs.shape)
        #gg
        return action_probs
    
    def create_optimizer(self, lr):
        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)

In [127]:
i_episode=0
i_updates=0
replay_buffer = []
replay_size = 500
last_few = 50
log_to_tensorboard = True 

batch_size = 128
n_warm_up_episodes = 100
n_episodes_per_iter = 50
n_updates_per_iter = 100

lr = 0.001
agent = FCNN_AGENT().double()
agent.create_optimizer(lr)

stochastic_policy = create_stochastic_policy(agent)
greedy_policy = create_greedy_policy(agent)

In [128]:
# SET UP TRAINING VISUALISATION
# SET UP TRAINING VISUALISATION
if log_to_tensorboard: from torch.utils.tensorboard import SummaryWriter
if log_to_tensorboard: writer = SummaryWriter() # we will use this to show our models performance on a graph using tensorboard

In [129]:
#Collect warm up episodes
replay_buffer = collect_experience(random_policy, replay_buffer, replay_size, last_few, n_warm_up_episodes, log_to_tensorboard)
train_net(agent, replay_buffer, n_updates=n_updates_per_iter, batch_size=batch_size, log_to_tensorboard)

SyntaxError: positional argument follows keyword argument (<ipython-input-129-09b1fadbae1d>, line 3)

In [122]:
n_iters = 10000
for i in range(n_iters):
    replay_buffer = collect_experience(stochastic_policy, replay_buffer, replay_size, last_few, n_episodes_per_iter, log_to_tensorboard)
    train_net(agent, replay_buffer, n_updates=n_updates_per_iter, batch_size=batch_size, log_to_tensorboard)

torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
Episode 101 finished after 10 timesteps. Return = 10.0
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
Episode 102 finished after 10 timesteps. Return = 10.0
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
Episode 103 finished after 10 timesteps. Return = 10.0
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
Episode 104 finished after 10 times

torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
Episode 142 finished after 10 timesteps. Return = 10.0
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
Episode 143 finished after 10 timesteps. Return = 10.0
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
Episode 144 finished after 9 timesteps. Return = 9.0
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
Episode 145 finished after 8 timesteps. Return = 8.0
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torc

torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
Episode 185 finished after 10 timesteps. Return = 10.0
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
Episode 186 finished after 9 timesteps. Return = 9.0
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
Episode 187 finished after 8 timesteps. Return = 8.0
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
Episode 188 finished after 10 timesteps. Return = 10.0
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torc

torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
Episode 235 finished after 9 timesteps. Return = 9.0
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
Episode 236 finished after 9 timesteps. Return = 9.0
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
Episode 237 finished after 10 timesteps. Return = 10.0
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
Episode 238 finished after 9 timesteps. Return = 9.0
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.

torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
Episode 285 finished after 10 timesteps. Return = 10.0
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
Episode 286 finished after 9 timesteps. Return = 9.0
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
Episode 287 finished after 9 timesteps. Return = 9.0
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
Episode 288 finished after 10 timesteps. Return = 10.0
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torc

torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
Episode 333 finished after 10 timesteps. Return = 10.0
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
Episode 334 finished after 9 timesteps. Return = 9.0
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
Episode 335 finished after 10 timesteps. Return = 10.0
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
Episode 336 finished after 10 timesteps. Return = 10.0
torch.Size([1, 64])
torch.Size([1, 64])
to

torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
Episode 364 finished after 8 timesteps. Return = 8.0
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
Episode 365 finished after 10 timesteps. Return = 10.0
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
Episode 366 finished after 9 timesteps. Return = 9.0
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
Episode 367 finished after 9 timesteps. Return = 9.0
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.

torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size

torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size

torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
Episode 498 finished after 10 timesteps. Return = 10.0
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
Episode 499 finished after 10 timesteps. Return = 10.0
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
torch.Size([1, 64])
Episode 500 finished after 10 timesteps. Return = 10.0
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128

torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size

torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([128, 64])


KeyboardInterrupt: 

In [None]:
#agent.load_state_dict(torch.load('checkpoints/lunar_lander_32x32_checkpoint_0.pt'))

In [72]:
visualise_agent(greedy_policy, command=[250, 200], n=5)

Episode 0 finished after 84 timesteps. Return = -134.1428887684655
Episode 1 finished after 62 timesteps. Return = -143.282848576283
Episode 2 finished after 72 timesteps. Return = -0.15610989654310004
Episode 3 finished after 106 timesteps. Return = -72.03684838919166
Episode 4 finished after 105 timesteps. Return = -37.44217835732849


In [73]:
visualise_agent(stochastic_policy, command=[250, 200], n=5)

Episode 0 finished after 97 timesteps. Return = -89.47575725612703


In [220]:
#torch.save(agent.state_dict(), 'checkpoints/lunar_lander_32x32_checkpoint_0.pt')

In [1]:
#print([mem['return'] for mem in replay_buffer])

# Previous Code

In [None]:
def train_net(policy_net, replay_buffer, n_updates=100, batch_size=64):
    all_costs = []
    for i in range(n_updates):
        batch_input = np.zeros((batch_size, np.prod(env.observation_space.shape)+2))
        batch_label = np.zeros((batch_size))
        for b in range(batch_size):
            sample_episode = np.random.randint(0, len(replay_buffer))
            sample_horizon = np.random.randint(1, len(replay_buffer[sample_episode]['observation'])+1)
            sample_mem_idx = np.random.randint(0, len(replay_buffer[sample_episode]['observation'])+1-sample_horizon)
            sample_mem = replay_buffer[sample_episode]['observation'][sample_mem_idx]
            sample_desired_reward = sum(replay_buffer[sample_episode]['reward'][sample_mem_idx:sample_mem_idx+sample_horizon])
            network_input = np.append(sample_mem, [sample_desired_reward, sample_horizon])
            label = replay_buffer[sample_episode]['action'][sample_mem_idx]
            batch_input[b] = network_input
            batch_label[b] = label
        batch_input = torch.tensor(batch_input).double()
        batch_label = torch.tensor(batch_label).long()
        pred = policy_net(batch_input)
        cost = F.cross_entropy(pred, batch_label)
        all_costs.append(cost.item())
        cost.backward()
        policy_net.optimizer.step()
        policy_net.optimizer.zero_grad()
    return np.mean(all_costs)

In [8]:
def train_net(policy_net, episode_mem, n_samples = 5): #stochastic gradient descent
    all_costs = []
    for i in range(n_samples):
        sample_horizon = np.random.randint(1, len(episode_mem['observation'])+1)
        sample_mem_idx = np.random.randint(0, len(episode_mem['observation'])+1-sample_horizon)
        sample_mem = episode_mem['observation'][sample_mem_idx]
        sample_desired_reward = sum(episode_mem['reward'][sample_mem_idx:sample_mem_idx+sample_horizon])
        network_input = torch.tensor(np.append(sample_mem, [sample_desired_reward, sample_horizon])).double()
        label = torch.tensor([episode_mem['action'][sample_mem_idx]]).double()
        
        pred = policy_net(network_input)
        cost = F.binary_cross_entropy(pred, label)
        all_costs.append(cost.item())
        cost.backward()
        policy_net.optimizer.step()
        policy_net.optimizer.zero_grad()
    return np.mean(all_costs)
    

In [33]:
def train(policy_net, n_episodes=100):
    global i_episode
    global epsilon
    try:
        for _ in range(n_episodes):
            observation = env.reset()
            episode_mem = {'observation':[],
                            'action':[],
                            'reward':[],
                            'done':[]}
            done=False
            while not done:
                network_input = torch.tensor(np.append(observation, [desired_reward, command_horizon])).double()
                action_prob = policy_net(network_input)
                action = np.random.binomial(1, action_prob.item())
                #action = int(action_prob.item()>0.5)
                if np.random.rand()<epsilon: action = np.random.randint(0, 2)
                new_observation, reward, done, info = env.step(action)
                
                episode_mem['observation'].append(observation)
                episode_mem['action'].append(action)
                episode_mem['reward'].append(reward)
                episode_mem['done'].append(done)
                
                observation=new_observation
                epsilon*=0.999
            episode_mem['return']=sum(episode_mem['reward'])
            episode_mem['episode_len']=len(episode_mem['observation'])
            mean_cost = train_net(policy_net, episode_mem)
            
            i_episode+=1
            print("Episode {} finished after {} timesteps. Epsilon={} Mean Cost={}".format(i_episode, len(episode_mem['observation']), epsilon, mean_cost))
        env.close()
    except KeyboardInterrupt:
        env.close()