In [2]:
import numpy as np
import gymnasium as gym
from dataclasses import dataclass
import typing as tt
import torch
import torch.nn as nn
import torch.optim as optim

In [3]:
# Hyper-parameters for the training run below.
HIDDEN_SIZE = 128   # width of the single hidden layer in Net
BATCH_SIZE = 16     # completed episodes collected per training batch
PERCENTILE = 70     # episodes whose reward falls below this batch percentile are dropped
MAX_BATCHES = 1     # number of training iterations before the loop stops

In [4]:
class Net(nn.Module):
    """Two-layer MLP mapping an observation vector to unnormalised action scores.

    The final layer has no activation: the sampling code applies a softmax to
    the output, while training feeds the raw logits to ``nn.CrossEntropyLoss``.
    """

    def __init__(self, obs_size: int, hidden_size: int, n_actions: int):
        super().__init__()
        layers = [
            nn.Linear(obs_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, n_actions),
        ]
        # Kept as a single Sequential named `net` so state_dict keys stay
        # "net.0.*" / "net.2.*".
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor):
        logits = self.net(x)
        return logits

In [5]:
@dataclass
class EpisodeStep:
    """One recorded step: the observation the policy saw and the action it took."""

    observation: np.ndarray  # observation passed to the policy before acting
    action: int  # index of the action sampled from the policy


@dataclass
class Episode:
    """A finished episode: the accumulated reward and the steps that produced it."""

    reward: float  # sum of the per-step rewards over the episode
    steps: tt.List[EpisodeStep]  # steps in the order they were played

In [6]:
def generate_batches(env: gym.Env,
                     net: Net,
                     batch_size: int) -> tt.Generator[tt.List[Episode], None, None]:
    """Play episodes with the current policy and yield them in fixed-size batches.

    Actions are sampled from the softmax of the network's output, so the
    policy is stochastic. The generator never ends on its own: once
    ``batch_size`` episodes have finished (terminated or truncated) they are
    yielded as one list and a new batch is started.

    Args:
        env: environment to play in. It is reset on the first ``next()`` call
            and again after every finished episode.
        net: policy network returning one row of action logits per observation.
        batch_size: number of finished episodes per yielded batch (must be >= 1).

    Yields:
        Lists of ``batch_size`` ``Episode`` objects.

    Raises:
        ValueError: if ``batch_size`` < 1 (raised on the first ``next()``),
            which would otherwise loop forever without yielding.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    batch: tt.List[Episode] = []
    episode_reward = 0.0
    episode_steps: tt.List[EpisodeStep] = []

    # Turns the network's logits (one row per observation) into probabilities.
    sm = nn.Softmax(dim=1)

    obs, _ = env.reset()

    while True:
        obs_v = torch.tensor(obs, dtype=torch.float32)
        # Sampling needs only the forward pass; disabling autograd avoids
        # building a computation graph on every environment step.
        with torch.no_grad():
            act_probs = sm(net(obs_v.unsqueeze(0)))[0].numpy()

        # Sample an action for the current observation from that distribution.
        action = int(np.random.choice(len(act_probs), p=act_probs))

        next_obs, reward, is_done, is_trunc, _ = env.step(action)

        episode_reward += float(reward)
        episode_steps.append(EpisodeStep(observation=obs, action=action))

        if is_done or is_trunc:
            batch.append(Episode(reward=episode_reward, steps=episode_steps))

            # Start a fresh episode; its first observation replaces next_obs.
            episode_reward = 0.0
            episode_steps = []
            next_obs, _ = env.reset()
            if len(batch) == batch_size:
                yield batch
                batch = []

        obs = next_obs

In [7]:
def filter_batch(batch: tt.List[Episode], percentile: float) -> \
        tt.Tuple[torch.FloatTensor, torch.LongTensor, float, float]:
    """Select the high-reward ("elite") episodes of a batch as training data.

    An episode is kept when its total reward is at or above the
    ``percentile``-th percentile of the batch's rewards. The observations and
    actions of all kept episodes are concatenated in order.

    Args:
        batch: episodes to filter (e.g. one batch from ``generate_batches``).
        percentile: percentile in [0, 100] used as the reward cut-off.

    Returns:
        ``(train_obs_v, train_act_v, reward_bound, reward_mean)``:
        a float32 tensor of stacked observations, an int64 tensor of the
        matching actions, the reward cut-off, and the mean reward of the
        whole batch.

    Raises:
        ValueError: if ``batch`` is empty (no percentile can be computed).
    """
    if not batch:
        raise ValueError("batch must contain at least one episode")

    rewards = [episode.reward for episode in batch]
    reward_bound = float(np.percentile(rewards, percentile))
    reward_mean = float(np.mean(rewards))

    train_obs: tt.List[np.ndarray] = []
    train_act: tt.List[int] = []
    for episode in batch:
        if episode.reward < reward_bound:
            continue
        for step in episode.steps:
            train_obs.append(step.observation)
            train_act.append(step.action)

    # vstack stacks the per-step observation vectors into one row per step.
    # NOTE(review): this assumes 1-D observations; 2-D (e.g. image) observations
    # would be concatenated along their first axis instead of stacked.
    train_obs_v = torch.tensor(np.vstack(train_obs), dtype=torch.float32)
    train_act_v = torch.tensor(train_act, dtype=torch.long)
    return train_obs_v, train_act_v, reward_bound, reward_mean

In [8]:
# Load the Donkey Kong Atari environment and wrap it so recorded episodes are
# written as .mp4 files under ./video. obs_type="ram" makes each observation
# the console's RAM, which the MLP policy consumes directly; render_mode
# "rgb_array" supplies the frames RecordVideo writes to disk.
#
# NOTE(review): the last run stopped with
# `TypeError: must be real number, not NoneType` right after moviepy began
# writing rl-video-episode-0.mp4. This looks like a moviepy/decorator version
# mismatch (clip fps not forwarded) rather than a bug in this notebook.
# TODO: confirm the installed moviepy/decorator versions.
raw_env = gym.make("ALE/DonkeyKong-v5", obs_type="ram", render_mode="rgb_array")
env = gym.wrappers.RecordVideo(raw_env, video_folder="video")

A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)
[Powered by Stella]
  logger.warn(


In [9]:
# Derive the network's input/output sizes from the environment instead of
# hard-coding them. Net is an MLP over flat observation vectors, so the
# observation space must be 1-D (it is for obs_type="ram").
assert len(env.observation_space.shape) == 1, (
    f"expected a flat observation space, got shape {env.observation_space.shape}"
)
obs_size = int(env.observation_space.shape[0])
n_actions = int(env.action_space.n)

print(obs_size)
print(n_actions)

128
18


In [10]:
# Policy network, plus the optimizer and loss used to fit it to the actions of
# the high-reward episodes selected by filter_batch.
net = Net(obs_size, HIDDEN_SIZE, n_actions)
optimizer = optim.Adam(params=net.parameters(), lr=0.01)
# Net outputs raw logits, which is the input CrossEntropyLoss expects.
loss_fn = nn.CrossEntropyLoss()

In [11]:
# Training loop: each iteration plays BATCH_SIZE episodes with the current
# policy, keeps those at or above the PERCENTILE-th reward percentile, and
# trains the network to reproduce the actions taken in them.
batches = 0

for iter_no, batch in enumerate(generate_batches(env, net, BATCH_SIZE)):
    obs_v, acts_v, reward_b, reward_m = filter_batch(batch, PERCENTILE)
    optimizer.zero_grad()
    action_scores_v = net(obs_v)
    loss_v = loss_fn(action_scores_v, acts_v)
    loss_v.backward()
    optimizer.step()
    print("%d: loss=%.3f, reward_mean=%.1f, rw_bound=%.1f" % (
        iter_no, loss_v.item(), reward_m, reward_b))

    batches = iter_no + 1

    # Stop after a fixed training budget. This is a batch cap, not a "solved"
    # criterion, so it is reported as such. `>=` also guarantees termination
    # if MAX_BATCHES is ever set to 0 or a negative value.
    if batches >= MAX_BATCHES:
        print(f"Reached MAX_BATCHES={MAX_BATCHES}; stopping training.")
        break

  logger.warn(


Moviepy - Building video /home/blkdmr/university/projects/rl/nn/video/rl-video-episode-0.mp4.
Moviepy - Writing video /home/blkdmr/university/projects/rl/nn/video/rl-video-episode-0.mp4



TypeError: must be real number, not NoneType