In [6]:
!pip install git+https://github.com/Farama-Foundation/Gymnasium.git@main
!pip install gymnasium[box2d]
!pip install moviepy --upgrade
!pip install pysdl2
!pip install pyvirtualdisplay

Collecting git+https://github.com/Farama-Foundation/Gymnasium.git@main
  Cloning https://github.com/Farama-Foundation/Gymnasium.git (to revision main) to /tmp/pip-req-build-4ggi2gsu
  Running command git clone --filter=blob:none --quiet https://github.com/Farama-Foundation/Gymnasium.git /tmp/pip-req-build-4ggi2gsu
  Resolved https://github.com/Farama-Foundation/Gymnasium.git to commit 443b1940f11087280663e884edea571f47f72413
  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Installing backend dependencies ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
Building wheels for collected packages: gymnasium
  Building wheel for gymnasium (pyproject.toml) ... [?25ldone
[?25h  Created wheel for gymnasium: filename=gymnasium-1.0.0rc1-py3-none-any.whl size=935227 sha256=9e375aa6ee225e0c98325666b13c8cc3e33b7c1a546963c75e3f5727b1a7a2fc
  Stored in directory: /tmp/pip-ephem-wheel-cache-g69cprli/wheels/92/23/

In [27]:
import gymnasium as gym
from gymnasium import wrappers
#import pyvirtualdisplay
import cv2
import numpy as np
import matplotlib.pyplot as plt
from IPython import display
from tqdm.notebook import tqdm
import torch
from torch import nn
import torch.nn.functional as F
from torch.distributions import Categorical

In [12]:
# Discrete-action CarRacing environment rendering to RGB arrays, wrapped so that
# a video is written to video/car_racing whenever the trigger fires (every index
# divisible by 200), at 30 frames per second.
env = wrappers.RecordVideo(
    gym.make('CarRacing-v2', continuous=False, render_mode="rgb_array"),
    'video/car_racing',
    episode_trigger=lambda episode_index: episode_index % 200 == 0,
    fps=30,
)

  logger.warn(


In [16]:
# Show the observation space and the number of discrete actions of the environment.
env.observation_space, env.action_space.n

(Box(0, 255, (96, 96, 3), uint8), 5)

In [3]:
def render(env, img):
    """Refresh the inline notebook display with the environment's current frame.

    Args:
        env: environment whose ``render()`` result is pushed into ``img``.
        img: image artist (as returned by ``plt.imshow``) updated via ``set_data``.
    """
    frame = env.render()
    img.set_data(frame)
    current_figure = plt.gcf()
    display.display(current_figure)
    # wait=True defers clearing until the next output arrives.
    display.clear_output(wait=True)

In [4]:
class RandomPolicy:
    """Baseline policy that ignores observations and samples random actions.

    Implements the same three-method interface the game loop expects from a
    policy: ``__call__`` to pick an action, ``update`` after each step and
    ``init_game`` at the start of each episode.
    """

    def __init__(self, env):
        """Remember the environment whose action space is sampled from."""
        self.env = env

    def __call__(self, observation):
        """Return an action sampled from the stored environment's action space.

        The observation is ignored. The sample is drawn from ``self.env`` (the
        environment passed to the constructor), not from a global variable.
        """
        return self.env.action_space.sample()

    def update(self, *args):
        """Accept the per-step feedback and do nothing with it."""
        return None

    def init_game(self, observation):
        """No per-episode state to reset."""
        return None
    

In [81]:
def play_game(policy, episodes=2000, do_render = False, seed=100):
    """Run ``policy`` in the notebook-global ``env`` for at most ``episodes`` episodes.

    The loop stops early when the exponentially smoothed episode reward exceeds
    the environment's ``reward_threshold`` (if the env spec defines one) or when
    the user interrupts the kernel, in which case the environment is closed.

    NOTE: this function reads the module-level ``env`` variable; whatever
    environment was most recently assigned to ``env`` is the one being played.

    Args:
        policy: object providing ``init_game(observation)``, ``__call__(observation)``
            returning an action, and ``update(observation, reward, terminated,
            truncated, info, status)``.
        episodes: number of episodes to play; also the progress-bar total.
        do_render: when True, draw every frame inline via ``render``.
        seed: seed passed to the first ``env.reset``.
    """
    # Weight of the newest episode in the running average reward.
    smoothing = 0.05

    observation, info = env.reset(seed=seed)
    policy.init_game(observation)

    if do_render:
        plt.ion()
        plt.axis('off')
        img = plt.imshow(env.render())

    # Shared with the progress bar and with policy.update, which may add keys.
    status = {'steps': 0, 'episode_reward': 0, 'average_reward': None}
    env.metadata['status'] = status

    # Not every environment declares a spec or a reward threshold; without one
    # the "solved" check is skipped instead of comparing against None.
    reward_threshold = getattr(getattr(env, 'spec', None), 'reward_threshold', None)

    episode = 0
    with tqdm(total=episodes) as pbar:
        pbar.set_postfix(status)
        # Stop after exactly `episodes` finished episodes.
        while episode < episodes:
            try:
                action = policy(observation)
                observation, reward, terminated, truncated, info = env.step(action)
                status['steps'] += 1
                status['episode_reward'] += reward
                if do_render:
                    render(env, img)
                policy.update(observation, reward, terminated, truncated, info, status)
                pbar.set_postfix(status)

                if not (terminated or truncated):
                    continue

                # --- episode finished ---
                episode += 1
                if status['average_reward'] is None:
                    status['average_reward'] = status['episode_reward']
                else:
                    status['average_reward'] = (smoothing * status['episode_reward']
                                                + (1 - smoothing) * status['average_reward'])
                pbar.set_postfix(status)
                pbar.update()

                if reward_threshold is not None and status['average_reward'] > reward_threshold:
                    print(f"Solved! Running reward is now {status['average_reward']} and "
                          f"the last episode runs to {status['steps']} time steps!")
                    break

                status['steps'] = 0
                status['episode_reward'] = 0
                # Reset after every counted episode, including the last one, so the
                # environment and the policy are always left at an episode boundary.
                observation, info = env.reset()
                policy.init_game(observation)
            except KeyboardInterrupt:
                env.close()
                break

In [None]:
# Smoke-test the game loop: random policy, episodes=1, with inline rendering of each frame.
policy = RandomPolicy(env)
play_game(policy, episodes=1, do_render=True)

In [82]:
class ActorCriticNetwork(nn.Module):
    """Convolutional actor-critic network with a shared trunk and two heads.

    Two strided convolutions followed by one linear layer produce hidden
    features; the ``policy`` head turns them into action probabilities and the
    ``value`` head into a single scalar per input.
    """

    def __init__(self,
                 input_shape : list=[3,96,96],
                 out_channels1 : int=5,
                 out_channels2 : int=5,
                 kernel_size : int=4,
                 stride : int=2,
                 hidden_size : int=256,
                 num_actions : int=5):
        """Build the layers.

        Args:
            input_shape: ``[channels, height, width]`` of a single input image.
                The default list is only read, never mutated.
            out_channels1: output channels of the first convolution.
            out_channels2: output channels of the second convolution.
            kernel_size: kernel size shared by both convolutions.
            stride: stride shared by both convolutions.
            hidden_size: width of the hidden linear layer.
            num_actions: number of outputs of the policy head.
        """
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels=input_shape[0],
                               out_channels=out_channels1,
                               kernel_size=kernel_size,
                               stride=stride)

        self.conv2 = nn.Conv2d(in_channels=out_channels1,
                               out_channels=out_channels2,
                               kernel_size=kernel_size,
                               stride=stride)

        # Push one dummy image through the convolutions to learn how many
        # features linear1 receives after flattening.
        dummy_input = torch.rand(1, *input_shape)
        with torch.no_grad():
            flat_features = torch.flatten(self.conv2(self.conv1(dummy_input))).shape[0]

        self.linear1 = nn.Linear(flat_features, hidden_size)
        # Softmax over the action dimension. An explicit dim avoids the deprecated
        # implicit-dimension behaviour (and its warning on every forward pass).
        # The Sequential layout is kept so existing checkpoints still load.
        self.policy = nn.Sequential(nn.Linear(hidden_size, num_actions), nn.Softmax(dim=-1))
        self.value = nn.Linear(hidden_size, 1)

        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``(action_probabilities, value)`` for an image or a batch of images.

        Args:
            x: tensor of shape ``[C, H, W]`` (a batch dimension is added) or
                ``[batch, C, H, W]``.

        Returns:
            Tuple ``(p, v)`` with ``p`` of shape ``[batch, num_actions]`` (rows sum
            to 1) and ``v`` of shape ``[batch, 1]``.
        """
        # Adjust tensor to have shape [batch, *image_shape]
        if x.dim() == 3:
            x = x.unsqueeze(0)

        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = torch.flatten(x, 1)
        x = self.relu(self.linear1(x))
        p = self.policy(x)
        v = self.value(x)
        return p, v

In [102]:
from collections import namedtuple

# Per-step record kept until the episode ends: the log-probability of the action
# that was sampled and the critic's value output for the same network input.
SavedAction = namedtuple('SavedAction', ['log_prob', 'value'])


class ACPolicy:
    """Actor-critic policy that learns once per episode from the collected steps.

    ``__call__`` samples an action and records its log-probability and the value
    estimate; ``update`` collects rewards and, when the episode ends, performs one
    optimizer step on the summed policy and value losses.
    """

    def __init__(self, env, gamma=0.99, lr=5e-3):
        """Create the network and optimizer.

        Args:
            env: environment; only ``env.action_space.n`` is read, to size the
                policy head.
            gamma: discount factor used when accumulating returns.
            lr: learning rate of the AdamW optimizer.
        """
        self.net = ActorCriticNetwork(num_actions=int(env.action_space.n))
        self.optimizer = torch.optim.AdamW(self.net.parameters(), lr=lr)
        self.mean_reward = None
        self.games = 0
        self.gamma = gamma
        self.eps = np.finfo(np.float32).eps.item()
        self.best_reward = None
        # Create the episode buffers right away so the policy is usable even if
        # the caller forgets to call init_game first.
        self.init_game(None)

    def __call__(self, observation):
        """Sample an action for ``observation`` and remember what is needed to learn.

        Args:
            observation: image array indexed ``[height, width, channel]``; it is
                moved to channel-first order and divided by 255 before being fed
                to the network.

        Returns:
            The sampled action as a Python int.
        """
        x = torch.tensor(observation).permute(2, 0, 1) / 255
        probs, value = self.net(x)
        m = Categorical(probs)
        action = m.sample()

        self.memory.append(SavedAction(m.log_prob(action), value))
        self.last_observation = observation

        return action.item()

    def init_game(self, observation):
        """Clear the per-episode buffers; ``observation`` is not used."""
        self.memory = []
        self.rewards = []
        self.total_reward = 0

    def update(self, observation, reward, terminated, truncated, info, status):
        """Record one step's reward and learn when the episode has ended.

        Args:
            observation, info: accepted for interface compatibility, not used.
            reward: reward of the step just taken.
            terminated, truncated: episode-end flags; learning happens when either
                is true.
            status: dict shared with the game loop; receives ``'best'``,
                ``'policy_loss'`` and ``'value_loss'`` entries.
        """
        self.total_reward += reward
        self.rewards.append(reward)

        if not (terminated or truncated):
            return

        self.games += 1
        if self.mean_reward is None:
            self.mean_reward = self.total_reward
        else:
            self.mean_reward = self.mean_reward * 0.95 + self.total_reward * (1.0 - 0.95)

        if self.best_reward is None:
            self.best_reward = self.total_reward
        elif self.total_reward > self.best_reward:
            self.best_reward = self.total_reward
            status['best'] = self.best_reward
            self.save('best.pt')

        self._learn_from_episode(status)

        if self.games % 1000 == 0:
            self.save(f"model_{self.games}.pt")

    def _discounted_returns(self):
        """Return the discounted return of every step as a float32 tensor, standardized when possible."""
        returns = []
        running = 0.0
        # Walk backwards so each return reuses the one after it (linear time).
        for r in reversed(self.rewards):
            running = r + self.gamma * running
            returns.append(running)
        returns.reverse()
        returns = torch.tensor(returns, dtype=torch.float32)
        # The standard deviation of a single value is undefined (NaN); standardizing
        # a one-step episode would turn the loss, and then the weights, into NaN.
        if returns.numel() > 1:
            returns = (returns - returns.mean()) / (returns.std() + self.eps)
        return returns

    def _learn_from_episode(self, status):
        """Run one optimizer step on the finished episode and report the losses in ``status``."""
        returns = self._discounted_returns()
        # Pair actions and returns position by position, ignoring any surplus.
        steps = min(len(self.memory), returns.numel())
        if steps == 0:
            # No action was recorded for this episode, so there is nothing to learn from.
            return
        returns = returns[:steps]
        log_probs = torch.cat([m.log_prob.reshape(-1) for m in self.memory[:steps]])
        values = torch.cat([m.value.reshape(-1) for m in self.memory[:steps]])

        # The advantage is a constant for the actor: no gradient flows through the critic here.
        advantages = returns - values.detach()
        policy_loss = -(log_probs * advantages).sum()
        value_loss = F.smooth_l1_loss(values, returns, reduction='sum')

        self.optimizer.zero_grad()
        loss = policy_loss + value_loss
        loss.backward()
        self.optimizer.step()

        status['policy_loss'] = str(policy_loss.item())
        status['value_loss'] = str(value_loss.item())

    def load(self, PATH):
        """Restore network, optimizer and counters from a checkpoint written by ``save``.

        NOTE: ``torch.load`` unpickles the file; only load checkpoints you trust.
        """
        checkpoint = torch.load(PATH)
        self.net.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.games = checkpoint['games']
        self.mean_reward = checkpoint['mean_reward']
        # Older checkpoints have no 'best_reward' entry; keep the current value then.
        self.best_reward = checkpoint.get('best_reward', self.best_reward)

    def save(self, PATH):
        """Write network, optimizer and counters to ``PATH``."""
        torch.save({
                    'games': self.games,
                    'model_state_dict': self.net.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'mean_reward': self.mean_reward,
                    'best_reward': self.best_reward}, PATH)

In [106]:
# Training run: a new discrete-action environment created without a render mode,
# an actor-critic policy resumed from a checkpoint, then the game loop with its
# default episode count.
# NOTE(review): this cell needs 'model_2000.pt' to already exist in the working
# directory; nothing earlier in the notebook creates it on a fresh run — confirm
# where the checkpoint comes from.
env = gym.make('CarRacing-v2', continuous=False)
policy = ACPolicy(env)
policy.load('model_2000.pt')
play_game(policy)

  0%|          | 0/2000 [00:00<?, ?it/s]

In [99]:
# Manually snapshot the current policy to 'best.pt'.
# NOTE(review): ACPolicy.update writes to the same file whenever an episode beats
# the best reward seen so far, so this call overwrites that snapshot.
policy.save('best.pt')

In [105]:
# Evaluation run: replay the checkpointed policy in an environment wrapped with a
# video recorder writing to video/car_racing at 30 frames per second.
env = gym.make('CarRacing-v2', continuous=False, render_mode="rgb_array")
env = wrappers.RecordVideo(env, 'video/car_racing', fps=30)
policy = ACPolicy(env)
policy.load('model_2000.pt')
try:
    play_game(policy, episodes=1)
finally:
    # Always close the environment so the recorder can finalise any video that is
    # still being written, even if the run fails part-way.
    env.close()

  0%|          | 0/1 [00:00<?, ?it/s]