In [16]:
import tmrl
import time
import matplotlib.pyplot as plts
import numpy as np
import torch
import torch.nn as nn

In [17]:
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Only query the CUDA device name when a GPU was actually selected:
# torch.cuda.get_device_name() raises for a CPU device.
torch.cuda.get_device_name(device) if device.type == 'cuda' else 'CPU (CUDA not available)'

'NVIDIA GeForce RTX 4090'

In [3]:
# Create the environment and derive the flat input/output sizes used by the networks.
env = tmrl.get_environment()
print('Observation Space:\t', env.observation_space)
print('Action Space:\t\t', env.action_space)
# Total number of scalars across every component of the Tuple observation space.
# np.prod replaces np.product (deprecated alias, removed in NumPy 2.0); int() keeps
# the size a plain Python integer.
observation_space = int(sum(np.prod(value.shape) for value in env.observation_space))
action_space = env.action_space.shape[0]
print('observation_space logits:', observation_space)
print('action_space logits:\t', action_space)

Observation Space:	 Tuple(Box(0.0, 1000.0, (1,), float32), Box(0.0, 1.0, (1,), float32), Box(0.0, inf, (4, 19), float32), Box(-1.0, 1.0, (3,), float32), Box(-1.0, 1.0, (3,), float32))
Action Space:		 Box(-1.0, 1.0, (3,), float32)
observation_space logits: 84
action_space logits:	 3


In [4]:
# Central configuration read by the networks and by the training loop.
hyper_params = dict(
    policy_lr=1e-5,            # Adam learning rate for the policy network
    critic_lr=1e-5,            # Adam learning rate for the critic network
    gamma=0.996,               # discount factor used when accumulating returns
    clip_coef=0.2,             # PPO probability-ratio clipping range
    critic_coef=0.1,           # weight of the value loss in the total loss
    entropy_coef=0.1,          # not referenced by the code in this notebook
    batch_size=256,            # minibatch size for each gradient step
    num_updates=10000,         # number of rollout + optimisation rounds
    epochs_per_update=100,     # passes over each rollout
    hidden_dim=512,            # width of every hidden layer
    max_episode_steps=2400,    # rollout length cap per update
    norm_advantages=True,      # standardise advantages within each minibatch
    grad_clip_val=0.1,         # element-wise gradient clipping threshold
    initial_std=1,             # not referenced by the code in this notebook
    avg_ray=400,               # divisor applied to observations before the networks
)

In [5]:
class Policy(nn.Module):
    """Gaussian policy over the continuous action space.

    Two separate MLPs read the (rescaled) observation:

    * ``action_mean`` outputs one mean per action dimension, squashed by ``tanh``.
    * ``actor_logvar`` outputs a single log-variance per observation, which is
      shared by every action dimension (isotropic diagonal covariance).

    Observations are expected as a 2-D batch ``(batch, observation_space)`` and are
    divided by ``hyper_params['avg_ray']`` before being fed to either network.
    """

    def __init__(self):
        super().__init__()
        self.action_mean = nn.Sequential(
            #nn.LayerNorm(observation_space),
            #nn.BatchNorm1d(observation_space),
            nn.Linear(observation_space, hyper_params['hidden_dim']),
            nn.ReLU(),
            nn.Linear(hyper_params['hidden_dim'],hyper_params['hidden_dim']),
            nn.ReLU(),
            nn.Linear(hyper_params['hidden_dim'], action_space),
            nn.Tanh()
        )

        self.actor_logvar = nn.Sequential(
            #nn.LayerNorm(observation_space),
            #nn.BatchNorm1d(observation_space),
            nn.Linear(observation_space, hyper_params['hidden_dim']),
            nn.ReLU(),
            nn.Linear(hyper_params['hidden_dim'],hyper_params['hidden_dim']),
            nn.ReLU(),
            nn.Linear(hyper_params['hidden_dim'], 1)
        )

    @staticmethod
    def _rescale(observation):
        """Return a rescaled copy of ``observation``; the caller's tensor is left untouched."""
        # Out-of-place on purpose: an in-place ``/=`` would also rescale the caller's
        # tensor (and any view of it), so reusing that tensor would rescale it twice.
        return observation / hyper_params['avg_ray']

    def sample_action_with_logprobs(self, observation):
        """Sample one action per observation.

        Returns:
            ``(action, log_prob)`` where ``action`` has shape ``(batch, action_space)``
            and ``log_prob`` has shape ``(batch,)``. Sampled actions are not clamped.
        """
        dist = self(observation)
        sample_action = dist.sample()
        return sample_action, dist.log_prob(sample_action)

    def mean_only(self, observation):
        """Return the distribution mean (no sampling), computed without gradients."""
        with torch.no_grad():
            # Apply the same rescaling as forward() so the mean network sees the
            # input scale it is given everywhere else.
            return self.action_mean(self._rescale(observation))

    def get_action_log_prob(self, observation, action):
        """Return the log-probability of ``action`` under the current policy."""
        dist = self(observation)
        return dist.log_prob(action)

    def forward(self, observation):
        """Build the action distribution for a batch of observations.

        Returns:
            ``torch.distributions.MultivariateNormal`` with batch shape ``(batch,)``,
            event shape ``(action_space,)`` and a diagonal covariance matrix.
        """
        observation = self._rescale(observation)
        means = self.action_mean(observation)
        # One variance per observation, repeated across every action dimension.
        variances = self.actor_logvar(observation).exp().view(-1, 1).expand(-1, action_space)
        # diag_embed places the variances on the diagonal and creates the matrix on
        # the same device as the input.
        covar_mat = torch.diag_embed(variances)

        return torch.distributions.MultivariateNormal(means, covar_mat)
        
class Critic(nn.Module):
    """State-value network: an MLP mapping a rescaled observation to one scalar."""

    def __init__(self):
        super().__init__()
        self.network = nn.Sequential(
            #nn.LayerNorm(observation_space),
            #nn.BatchNorm1d(observation_space),
            nn.Linear(observation_space, hyper_params['hidden_dim']),
            nn.ReLU(),
            nn.Linear(hyper_params['hidden_dim'],hyper_params['hidden_dim']),
            nn.ReLU(),
            nn.Linear(hyper_params['hidden_dim'], 1)
        )

    def forward(self, observation):
        """Return the value estimate; the last dimension of the output has size 1."""
        # Out-of-place division: an in-place ``/=`` would rescale the caller's tensor
        # too, so a tensor passed in more than once would be divided repeatedly.
        scaled = observation / hyper_params['avg_ray']
        return self.network(scaled)
    
class Agent(nn.Module):
    """Container holding the policy and the critic as submodules.

    Wrapping both in one module lets them be moved between devices and
    saved/loaded through a single ``state_dict``. The agent itself is not
    callable: use ``agent.policy`` and ``agent.critic`` directly.
    """

    def __init__(self):
        super().__init__()
        self.policy = Policy()
        self.critic = Critic()

    def forward(self, x):
        """Always raises: the two sub-networks must be called individually."""
        # NotImplementedError matches what nn.Module raises for a missing forward;
        # SyntaxError is reserved for code that fails to parse.
        raise NotImplementedError(
            'Propagate through Agent.policy and Agent.critic individually'
        )

In [6]:
def env_obs_to_tensor(observations):
    """Flatten every component of an observation and join them into one 1-D tensor.

    ``observations`` is an iterable of array-likes; each one is copied into a
    tensor, flattened, and concatenated in iteration order.
    """
    flat_parts = []
    for component in observations:
        flat_parts.append(torch.tensor(component).view(-1))
    return torch.cat(tuple(flat_parts), dim=-1)

def env_act_to_tensor(action):
    """Copy an environment action (array-like) into a new tensor."""
    action_tensor = torch.tensor(action)
    return action_tensor

In [18]:
# Build the agent on the selected device. The policy and the critic each get
# their own Adam optimiser so that they can use separate learning rates.
agent = Agent()
agent.to(device)
policy_optim = torch.optim.Adam(params=agent.policy.parameters(), lr=hyper_params['policy_lr'])
critic_optim = torch.optim.Adam(params=agent.critic.parameters(), lr=hyper_params['critic_lr'])
# Optional warm start from a saved checkpoint:
#agent.load_state_dict(torch.load('130.33999633789062RewardRacer56Update.pt'))

In [19]:
def train_PPO():
    """Train the global ``agent`` with PPO on the global ``env``.

    Each update collects one episode (at most ``hyper_params['max_episode_steps']``
    steps) with the current policy, computes discounted returns and advantages,
    and then runs ``hyper_params['epochs_per_update']`` passes of minibatch
    gradient steps on the clipped PPO objective plus a weighted value loss.
    The agent's state dict is saved to disk whenever an episode's summed reward
    exceeds 200.

    Relies on the module-level ``agent``, ``env``, ``device``, ``hyper_params``,
    ``policy_optim`` and ``critic_optim``.

    Returns:
        tuple of four lists with one entry per completed update:
        ``(cum_rewards, actor_losses, critic_losses, total_losses)``.
    """
    cum_rewards = []
    actor_losses = []
    critic_losses = []
    total_losses = []

    cum_reward = 0
    for update in range(hyper_params['num_updates']):
        ## rollout buffers for this episode (kept on the CPU)
        obs = torch.zeros(hyper_params['max_episode_steps'], observation_space)
        actions = torch.zeros(hyper_params['max_episode_steps'], action_space)
        logprobs = torch.zeros(hyper_params['max_episode_steps'])
        rewards = torch.zeros(hyper_params['max_episode_steps'])
        state_values = torch.zeros(hyper_params['max_episode_steps'])
        returns = torch.zeros(hyper_params['max_episode_steps'])

        # TODO: check whether a learning rate scheduler should be used

        # reset the environment before each new episode
        next_obs = env_obs_to_tensor(env.reset()[0]) # just grab obs

        ## according to the tmrl docs the game window has to be clicked manually
        ## to make sure it has focus while training
        #if update == 0:
            #time.sleep(1.0)

        agent.eval()
        for step in range(hyper_params['max_episode_steps']):
            obs[step] = next_obs

            with torch.no_grad():
                action, logprob = agent.policy.sample_action_with_logprobs(next_obs.to(device).unsqueeze(0))
                state_value = agent.critic(next_obs.to(device).unsqueeze(0))
            actions[step] = action[0]
            logprobs[step] = logprob[0]
            state_values[step] = state_value[0]

            ## actions are sampled from a gaussian distribution - they can exceed
            ## the range allowed by the action space, so clamp before stepping
            clamped_action = np.clip(np.array(action.cpu()),-1,1)

            next_obs, reward, terminated, truncated, info = env.step(clamped_action[0])

            # terminate episode if stuck on a rail
            # (any value <= 40 in observation component 2 ends the episode)
            if next_obs[2][next_obs[2] <= 40].sum() > 0:
                terminated = True

            rewards[step] = torch.tensor(reward)
            next_obs = env_obs_to_tensor(next_obs)
            if terminated or truncated:
                break
        ## pause environment according to tmrl
        env.wait()
        ## index of the last recorded transition, whether the loop broke or ran out
        max_idx = step
        num_samples = max_idx + 1
        ## discounted returns, accumulated backwards from the last step
        with torch.no_grad():
            for t in range(num_samples)[::-1]:
                if t == max_idx:
                    ## no bootstrapping: the final return is just the final reward,
                    ## even when the episode hit the step cap (the critic bootstrap
                    ## gamma * critic(next_obs) is deliberately left disabled)
                    returns[t] = rewards[t]
                else:
                    returns[t] = rewards[t] + (hyper_params['gamma']*returns[t+1])
            advantages = returns - state_values
            cum_reward = rewards.sum().item()
        ## one fixed shuffle of the collected transitions, reused by every epoch
        rand_idxs = np.random.permutation(np.arange(num_samples))
        epochs_values_loss = []
        epochs_ppo_loss = []
        epochs_total_loss = []

        if cum_reward > 200:
            torch.save(agent.state_dict(), f'Y{cum_reward:.2f}RewardRacer{update}Update_2.pt')

        agent.train()
        for epoch in range(hyper_params['epochs_per_update']):
            ## iterate up to num_samples (not max_idx): otherwise the last transition
            ## is dropped whenever max_idx is a multiple of the batch size, and a
            ## one-step episode yields no batch at all (NaN loss means)
            for batch_start_idx in range(0, num_samples, hyper_params['batch_size']):
                batch_end_idx = batch_start_idx + hyper_params['batch_size']
                batch_idxs = rand_idxs[batch_start_idx:batch_end_idx]

                batch_obs = obs[batch_idxs]
                batch_actions = actions[batch_idxs]

                ## PPO objective: probability ratio between new and rollout policy
                batch_new_log_probs = agent.policy.get_action_log_prob(batch_obs.to(device), batch_actions.to(device))
                batch_old_log_probs = logprobs[batch_idxs].to(device)

                log_ratio = batch_new_log_probs - batch_old_log_probs
                ratio = log_ratio.exp()

                ## calculate the surrogate objective function
                batch_advantages = advantages[batch_idxs].to(device)
                ## std() of a single sample is NaN, so only standardise real batches
                if hyper_params['norm_advantages'] and batch_advantages.numel() > 1:
                    batch_advantages = (batch_advantages - batch_advantages.mean()) / (batch_advantages.std() + 1e-8)

                unclipped_obj = -ratio * batch_advantages
                clipped_obj = -torch.clip(ratio, 1 - hyper_params['clip_coef'], 1 + hyper_params['clip_coef']) * batch_advantages
                ppo_loss = torch.max(unclipped_obj, clipped_obj).sum() / hyper_params['batch_size'] ##not .mean() in case of small last batch
                epochs_ppo_loss.append(ppo_loss.item())

                ## value loss: squared error against the on-policy discounted returns;
                ## the advantages are already fixed, only the critic is trained here
                new_state_values = agent.critic(batch_obs.to(device))
                v_loss = ((new_state_values.view(-1) - returns[batch_idxs].to(device))**2).sum() / hyper_params['batch_size'] ##not .mean() in case of small last batch
                epochs_values_loss.append(v_loss.item())

                total_loss = ppo_loss + hyper_params['critic_coef']*v_loss
                epochs_total_loss.append(total_loss.item())

                policy_optim.zero_grad()
                critic_optim.zero_grad()

                total_loss.backward()

                nn.utils.clip_grad.clip_grad_value_(agent.policy.parameters(), clip_value=hyper_params['grad_clip_val'])
                ## zero out NaN policy gradients so one bad batch cannot poison the weights
                for param in agent.policy.parameters():
                    if param.grad is None:
                        continue
                    mask = torch.isnan(param.grad)
                    param.grad[mask] = 0.0
                    ## mask.sum must be *called*; comparing the bound method itself
                    ## to an int is always False and the warning could never fire
                    if mask.sum().item() == param.numel():
                        print('code is broken, yup')
                policy_optim.step()

                nn.utils.clip_grad.clip_grad_value_(agent.critic.parameters(), clip_value=hyper_params['grad_clip_val'])
                critic_optim.step()
        print('Update', update + 1)
        print('actor loss', np.mean(epochs_ppo_loss))
        print('critic loss', np.mean(epochs_values_loss))
        print('total loss', np.mean(epochs_total_loss))
        print('total reward', cum_reward)
        cum_rewards.append(cum_reward)
        actor_losses.append(np.mean(epochs_ppo_loss))
        critic_losses.append(np.mean(epochs_values_loss))
        total_losses.append(np.mean(epochs_total_loss))
    return cum_rewards, actor_losses, critic_losses, total_losses



In [20]:
# Pause for one second, then reset the environment before training starts.
# NOTE(review): presumably the pause leaves time to click the game window so it
# has focus — confirm against the tmrl docs.
time.sleep(1)
env.reset()



((array([0.], dtype=float32),
  array([0.], dtype=float32),
  array([[457., 465., 410., 305., 247., 212., 191., 197., 174., 171., 173.,
          182., 191., 212., 247., 304., 413., 465., 458.],
         [457., 465., 410., 305., 247., 212., 191., 197., 174., 171., 173.,
          182., 191., 212., 247., 304., 413., 465., 458.],
         [457., 465., 410., 305., 247., 212., 191., 197., 174., 171., 173.,
          182., 191., 212., 247., 304., 413., 465., 458.],
         [457., 465., 410., 305., 247., 212., 191., 197., 174., 171., 173.,
          182., 191., 212., 247., 304., 413., 465., 458.]], dtype=float32),
  array([0., 0., 0.], dtype=float32),
  array([0., 0., 0.], dtype=float32)),
 {})

In [21]:
# Run PPO training; each returned list holds one entry per completed update.
cum_rewards, actor_losses, critic_losses, total_losses = train_PPO()



Update 1
actor loss -0.009400872273836285
critic loss 2.023207831978798
total loss 0.19291991442441941
total reward 5.279999732971191
Update 2
actor loss -0.010008172020316125
critic loss 0.9862292709946633
total loss 0.08861475693061947
total reward 7.090000152587891
Update 3
actor loss -0.01111858032643795
critic loss 0.2712934844195843
total loss 0.016010768599808215
total reward 3.2699999809265137
Update 4
actor loss -0.009975759582594036
critic loss 0.30344767570495607
total loss 0.020369008490815757
total reward 5.539999961853027
Update 5
actor loss -0.017283723820000887
critic loss 0.41214772647246717
total loss 0.023931049574166537
total reward 0.5199999809265137
Update 6
actor loss -0.00830787731334567
critic loss 0.02597520312294364
total loss -0.0057103569549508395
total reward 0.7399999499320984
Update 7
actor loss -0.0051650963537395005
critic loss 0.025153762307018043
total loss -0.0026497200899757447
total reward 0.6799999475479126
Update 8
actor loss -0.0038840771280229



Update 10
actor loss -0.024083445649594068
critic loss 3.761578140258789
total loss 0.3520743757486343
total reward 7.210000514984131
Update 11
actor loss -0.006635284470394254
critic loss 1.2799148745834827
total loss 0.12135620509274304
total reward 0.47999995946884155
Update 12
actor loss -0.019263005927205087
critic loss 0.35521884948015214
total loss 0.016258879620581864
total reward 3.2099997997283936
Update 13
actor loss -0.01254948424641043
critic loss 0.15089648276567458
total loss 0.0025401642266660927
total reward 3.4100000858306885
Update 14
actor loss -0.009174697633522253
critic loss 1.0001399705807368
total loss 0.09083930103729168
total reward 9.109999656677246
Update 15
actor loss -0.024157336093485356
critic loss 2.4907375955581665
total loss 0.22491642817854882
total reward 6.369999408721924




Update 16
actor loss -0.011489919777959585
critic loss 2.442085521221161
total loss 0.23271863639354706
total reward 10.5
Update 17
actor loss -0.009684189958497881
critic loss 1.1443269261717797
total loss 0.10474850432947277
total reward 9.74000072479248
Update 18
actor loss -0.006086992784403265
critic loss 1.096986515522003
total loss 0.10361166054848582
total reward 10.069999694824219
Update 19
actor loss -0.011931394655257464
critic loss 1.1214669926464558
total loss 0.10021530678495764
total reward 9.739999771118164
Update 20
actor loss -0.004262599921785295
critic loss 1.1946022525429725
total loss 0.11519762700423598
total reward 9.760000228881836
Update 21
actor loss -0.007044869031524285
critic loss 1.0002464285492898
total loss 0.09297977544600144
total reward 9.90999984741211
Update 22
actor loss -0.011514572785235941
critic loss 0.732220496237278
total loss 0.061707477830350396
total reward 9.039999008178711
Update 23
actor loss -0.023350348994135857
critic loss 0.4023406

Exception in thread Thread-12638 (__send_act_get_obs_and_wait):
Traceback (most recent call last):
  File "C:\Program Files\WindowsApps\PythonSoftwareFoundation.Python.3.10_3.10.3056.0_x64__qbz5n2kfra8p0\lib\threading.py", line 1016, in _bootstrap_inner
    self.run()
  File "C:\Program Files\WindowsApps\PythonSoftwareFoundation.Python.3.10_3.10.3056.0_x64__qbz5n2kfra8p0\lib\threading.py", line 953, in run
    self._target(*self._args, **self._kwargs)
  File "C:\Users\xande\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\LocalCache\local-packages\Python310\site-packages\rtgym\envs\real_time_env.py", line 439, in __send_act_get_obs_and_wait
    self.__update_obs_rew_terminated_truncated()  # capture observation
  File "C:\Users\xande\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\LocalCache\local-packages\Python310\site-packages\rtgym\envs\real_time_env.py", line 453, in __update_obs_rew_terminated_truncated
    o, r, d, i = self.inte

KeyboardInterrupt: 

In [None]:
# Scratch check of MultivariateNormal with a random covariance.
# The covariance must be symmetric positive-definite; an element-wise exp() of a
# random 3x3 matrix is not symmetric and is rejected by argument validation, so
# build a diagonal matrix with strictly positive entries instead.
scratch_dist = torch.distributions.MultivariateNormal(torch.zeros(3), torch.diag(torch.randn(3).exp()))
scratch_dist

In [12]:
def evaluate_model(model, env):
    """Run one episode in ``env`` with actions sampled from ``model.policy``.

    The model is moved to the global ``device`` and put in eval mode. Actions are
    clamped to ``[-1, 1]`` before being sent to the environment. The episode ends
    when the environment reports either ``terminated`` or ``truncated``; the
    environment is then paused with ``env.wait()``, as the training loop does.

    Args:
        model: agent exposing ``policy.sample_action_with_logprobs``.
        env: environment with the ``reset`` / ``step`` / ``wait`` interface used
            by the training loop.

    Returns:
        float: sum of the rewards received during the episode.
    """
    time.sleep(1.0)
    next_state = env_obs_to_tensor(env.reset()[0]).to(device)
    model.to(device)
    model.eval()
    total_reward = 0.0
    done = False
    # Pure inference: no autograd graph is needed while driving.
    with torch.no_grad():
        while not done:
            action, logprob = model.policy.sample_action_with_logprobs(next_state.to(device).unsqueeze(0))
            #action = model.policy.mean_only(next_state.to(device).unsqueeze(0))
            clamped_action = np.clip(action.cpu().numpy()[0], -1, 1)
            next_state, reward, terminated, truncated, info = env.step(clamped_action)
            total_reward += float(reward)
            next_state = env_obs_to_tensor(next_state)
            # A truncated episode is over too; checking only the terminated flag
            # keeps stepping after the episode has ended.
            done = terminated or truncated
    env.wait()
    return total_reward

In [13]:
# Fresh agent (created on the CPU) restored from a saved training checkpoint.
agent = Agent()
# map_location='cpu' lets a checkpoint saved from a CUDA run load on a machine
# without a GPU; the agent can be moved to another device afterwards.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
agent.load_state_dict(torch.load('Y206.65RewardRacer155Update_2.pt', map_location='cpu'))

<All keys matched successfully>

In [14]:
# Reset the environment; the trailing semicolon keeps the returned observation
# from being echoed as the cell output.
env.reset();



In [15]:
# Drive the environment with the agent restored from the checkpoint.
evaluate_model(agent, env)



KeyboardInterrupt: 