In [1]:
from utils import *

### Isak Andersson - AI23

# <p style="text-align:center;">Laboration - Deep Learning</p>

---

## Intro

This notebook documents the process of understanding and learning how to implement a reinforcement learning network across multiple frameworks.

The frameworks used are TensorFlow-Keras, PyTorch and JAX, with Keras serving as the baseline when comparing the others. Keras was chosen as the baseline because most of its code was provided before this study began, while TensorFlow seems to be a dying breed, which suggests that PyTorch and JAX should be the focus for future ML engineers.

As the assignment requires, each framework's RL model will be tested on the game Space Invaders. The models will not necessarily use the same settings or size, since this notebook is *not* about which framework is best at playing video games, but about how syntax and workflow differ when defining models in each framework.

But don't worry, there will be graphs of their different scores!


## Method

The given code *(./from_lecturer/Lec5-RL-Gymnasium.py)* for the Keras network, originally set up to play another game, Breakout, was adjusted as follows to make it work for Space Invaders, along with some quality-of-life changes:

- The game was of course set to "SpaceInvadersNoFrameskip-v4" instead of "BreakoutNoFrameskip-v4".
- The number of available actions was set to 6 instead of 4.
- The input method ("layer 0") was changed from a lambda function to a permutation with *layers.Permute*.
- One extra Dense layer with 256 units was added.
- Most storage lists were replaced with *collections.deque*, which can append and remove items at either end in O(1) time (removing from the front of a plain Python list is O(N)).
- Logging was added, and made even more comprehensive for the latter two frameworks.
- The model was trained every 6 actions instead of every 4.
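A replay memory built on `collections.deque` could look roughly like the sketch below. The class name `ReplayBuffer`, its methods and the `(state, action, reward, next_state, done)` tuple layout are illustrative assumptions, not the code in `./trainers/`. Setting `maxlen` lets the deque discard the oldest transition automatically once the memory is full.

```python
import random
from collections import deque


class ReplayBuffer:
    """Minimal experience-replay sketch (hypothetical; not the trainer's class)."""

    def __init__(self, max_length):
        # maxlen makes append() drop the oldest transition in O(1).
        self.memory = deque(maxlen=max_length)

    def add(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def sample(self, batch_size, rng=random):
        # Pick distinct positions, then read them out of the deque.
        indices = rng.sample(range(len(self.memory)), batch_size)
        batch = [self.memory[i] for i in indices]
        # Regroup into (states, actions, rewards, next_states, dones).
        return tuple(zip(*batch))

    def __len__(self):
        return len(self.memory)


buffer = ReplayBuffer(max_length=3)
for step in range(5):
    buffer.add(state=step, action=step % 6, reward=1.0, next_state=step + 1, done=False)

print(len(buffer))          # 3: the two oldest transitions were evicted
print(buffer.memory[0][0])  # 2: oldest remaining state
states, actions, rewards, next_states, dones = buffer.sample(2)
```

One trade-off worth knowing: a deque is fast at its ends, but the Python documentation notes that indexed access slows to O(N) towards the middle, so random sampling from a very large deque is not free.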

The code then ran on CPU for about 100 hours, or 15,000+ playthroughs (episodes). This became the baseline for comparison, both for scores and for compute time.

The working code was then "copied" to a new file and rewritten to run on PyTorch. With some help from Anthropic's Claude LLM, but with *a lot* of human clean-up and rewrites, and with the same settings, this code ran for about 10 hours (on GPU), or 15,000 episodes. There were also another 10 hours of PyTorch training in which the replay functionality was not running due to developer error. That data will be shown for fun, and shame.

Finally, the code was again copied, "anthropiced" and cleaned up in a similar fashion to fit the JAX (Flax-Linen) framework suite. However, this endeavour was abandoned due to time constraints, largely because [JAX does not fully support GPU acceleration on Windows.](https://jax.readthedocs.io/en/latest/installation.html#install-nvidia-gpu) Still, the code is written and it runs, so for this exercise it has served its purpose.

## The Network

All frameworks started out with the same kind of network, which was then adjusted slightly during development. The base model looks like this (keras_trainer.ipynb):

The snippets below are not runnable on their own (for example, `self` and `x` are never defined); they are only meant to show the architecture of the models used. The notebook still runs because each snippet sits inside a function that never gets called.

For full code see ./trainers/.

In [2]:
def KERAS():
    Conv2D(filters=32, kernel_size=8, strides=4, activation="relu"),
    Conv2D(filters=64, kernel_size=4, strides=2, activation="relu"),
    Conv2D(filters=64, kernel_size=3, activation="relu"),
    Flatten(),
    Dense(units=512, activation="relu"),
    Dense(units=256, activation="relu"),
    Dense(units=6, activation="linear")


This is slightly modified from the code given before the assignment: an extra convolutional layer has been added, and the Dense (Keras) / Linear (PyTorch) layers have been doubled and made larger.

The baseline structure for all these models was loosely based on an example that can be found [here.](https://chloeewang.medium.com/using-deep-reinforcement-learning-to-play-atari-space-invaders-8d5159aa69ed)

The **PyTorch** defined network is very similar but with some key differences (torch_trainer.py):

In [3]:
def TORCH():
    self.conv1 = nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4)
    self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2)
    self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3)
    self.fc1 = nn.Linear(in_features=64 * 7 * 7, out_features=512)
    self.fc2 = nn.Linear(in_features=512, out_features=256)
    self.fc3 = nn.Linear(in_features=256, out_features=6)

# These then get called in a forward function as so:

def forward(self, x):
    x = F.relu(self.conv1(x))
    x = F.relu(self.conv2(x))
    x = F.relu(self.conv3(x))
    x = x.view(x.size(0), -1) # The Flatten layer
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x)) # ReLU, matching the Keras Dense(256) layer
    return self.fc3(x)

Here, the activation functions are separated from the layers, which might be the most apparent difference. All three frameworks allow several ways of writing a model; the style shown above is simply the one the author of this study preferred. Take note, though, that in this style the programmer has to work out the input/output size of every layer manually, whereas the Keras and Flax layers above only take the output size. Is this perhaps an oversight? Or does it allow for more flexible options? It seems like it would be easy to automate if the same math is needed every time, and PyTorch does in fact offer `nn.LazyLinear` and `nn.LazyConv2d`, which infer their input sizes on the first forward pass. In any case, the output size of a convolution layer along each spatial dimension is calculated as follows:

<center>

$$
\text{Output Size} = \left\lfloor \frac{\text{Input Size} - \text{Kernel Size} + 2 \cdot \text{Padding}}{\text{Stride}} \right\rfloor + 1
$$

</center>

where padding is the number of border pixels added on each side of the input, and the floor matters whenever the division is not exact.
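Applying this formula layer by layer shows where the `64 * 7 * 7` in the PyTorch `fc1` layer comes from. The sketch below assumes 84x84 input frames (the usual Atari preprocessing, and the size that reproduces 7x7) and no padding; `conv_output_size` is a hypothetical helper name.

```python
def conv_output_size(input_size, kernel_size, stride=1, padding=0):
    """Output size along one spatial dimension; // implements the floor."""
    return (input_size - kernel_size + 2 * padding) // stride + 1


size = 84  # assumption: 84x84 input frames
for kernel_size, stride in [(8, 4), (4, 2), (3, 1)]:  # conv1, conv2, conv3
    size = conv_output_size(size, kernel_size, stride)
    print(size)  # 20, then 9, then 7

fc1_in_features = 64 * size * size  # 64 channels come out of conv3
print(fc1_in_features)  # 3136, i.e. 64 * 7 * 7
```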

As previously stated, the **JAX** model was abandoned. However, it was still a good learning experience, and for full transparency, here is how the layers were defined in JAX:

In [4]:
def JAX():
    x = nn.Conv(features=32, kernel_size=(8, 8), strides=(4, 4))(x)
    x = nn.relu(x)
    x = nn.Conv(features=64, kernel_size=(4, 4), strides=(2, 2))(x)
    x = nn.relu(x)
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = x.reshape((x.shape[0], -1))  # Flatten
    x = nn.Dense(features=512)(x)
    x = nn.relu(x)
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=6)(x)
    return x
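One framework difference that the JAX snippet hides: Flax Linen's `nn.Conv` defaults to `padding="SAME"`, whereas Keras `Conv2D` and PyTorch `nn.Conv2d` default to no padding ("valid"). Since the snippet above does not pass `padding`, and `nn.Dense` infers its input size, this would not raise an error; the network would simply be wider. The stdlib sketch below compares the flattened sizes under both modes, assuming 84x84 input; passing `padding="VALID"` to each `nn.Conv` would be the way to match the other two models.

```python
import math


def valid_size(n, kernel, stride):
    # "VALID": no padding, i.e. the formula above with padding = 0.
    return (n - kernel) // stride + 1


def same_size(n, stride):
    # "SAME": padding is chosen so that the output is ceil(n / stride).
    return math.ceil(n / stride)


conv_layers = [(8, 4), (4, 2), (3, 1)]  # (kernel, stride) of the three conv layers
valid = same = 84                       # assumption: 84x84 input frames
for kernel, stride in conv_layers:
    valid = valid_size(valid, kernel, stride)
    same = same_size(same, stride)

print(64 * valid * valid)  # 3136 features after Flatten with VALID padding
print(64 * same * same)    # 7744 features after Flatten with SAME padding
```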

Both PyTorch and JAX (Flax-Linen) use "nn" as their preferred import shorthand, which was somewhat confusing at first. But considering that no one owns the term "neural network", which "nn" presumably stands for, perhaps it shouldn't be that confusing.

There are, of course, many similarities between the frameworks; after all, they are supposed to do almost the same thing. There are also some glaring mistakes that would be fixed if these long runs were done again. The main one was believed to be the last-layer activation function in both PyTorch and JAX: implementing a linear activation proved harder than first thought, and the problem was noticed far too late. It is worth noting, though, that Keras's `activation="linear"` is simply the identity function, so an output layer with no activation at all, as in `return self.fc3(x)` and the final `nn.Dense(features=6)` above, already behaves the same way.

The optimizer was again the same for all three versions as in the base example, though each was imported from its own framework's library:
- Optimizer: Adam, with a learning rate of 0.00025

The loss function, however, was not as easy to carry over between frameworks. Huber loss was used in the base example but wasn't found, at least not under the same name, in PyTorch. Maybe human error: PyTorch does in fact provide `torch.nn.HuberLoss`, which with its default `delta=1.0` is identical to `SmoothL1Loss` with its default `beta=1.0`. In any case, the substitute loss function `smooth_l1_loss` was suggested since *[both provide smooth gradients for small errors and robust handling of outliers for large errors.](https://pytorch.org/docs/stable/generated/torch.nn.SmoothL1Loss.html)* And [sometimes it seems to be the same thing.](https://mlexplained.blog/2023/07/31/huber-loss-loss-function-to-use-in-regression-when-dealing-with-outliers/)
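The "sometimes the same thing" can be made precise with a few lines of plain Python. The sketch below writes out both losses for a single error, following the piecewise definitions in the PyTorch documentation; the function names are illustrative. With `delta == beta == 1` the two agree exactly, and in general Huber with threshold `delta` equals `delta` times Smooth L1 with `beta = delta`.

```python
def huber(error, delta=1.0):
    """Huber loss for one error value (Keras' Huber also defaults to delta=1.0)."""
    a = abs(error)
    if a <= delta:
        return 0.5 * a * a            # quadratic near zero: smooth gradients
    return delta * (a - 0.5 * delta)  # linear for large errors: robust to outliers


def smooth_l1(error, beta=1.0):
    """Smooth L1 loss for one error value, as defined for PyTorch's SmoothL1Loss."""
    a = abs(error)
    if a < beta:
        return 0.5 * a * a / beta
    return a - 0.5 * beta


for e in [-3.0, -1.0, -0.5, 0.0, 0.2, 1.0, 2.5]:
    # Identical when delta == beta == 1 ...
    assert abs(huber(e) - smooth_l1(e)) < 1e-12
    # ... and in general huber(delta) == delta * smooth_l1(beta=delta).
    assert abs(huber(e, 2.0) - 2.0 * smooth_l1(e, 2.0)) < 1e-12
print("Huber and Smooth L1 agree")
```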

## Settings and scores

### Keras

For every framework there were, of course, more settings than just the network parameters. Again, most of these were kept as in the given example. Some tweaks were made along the way when building in the other frameworks, but the following settings were used for the Keras baseline:

- Runtime:
    - 10,000,000 frames. This roughly equates to 15,000 episodes.
    - "Solved" set to a running reward score of >500.
    - No other early-stopping parameter, i.e., the maximum number of episodes was set to 0 (unlimited).
- Exploration:
    - 50,000 random frames.
    - 1,000,000 greedy frames.
- Memory:
    - Max memory length set to 1,000,000, as suggested in the DeepMind paper.
    - Target network updated every 10,000 frames.
- Logged metrics:
    - Episode count
    - Frame count (every 10,000 frames)
    - Running reward (mean of the last 100 episodes)
    - Max reward (out of the last 100 episodes)
    - Time since start
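The running and max reward metrics, together with the "solved" threshold, can be tracked with a fixed-size window. The sketch below is a hypothetical `RewardTracker` helper, not the logging code in `./trainers/`; it assumes "solved" means the mean of the last 100 episode rewards exceeds 500, as listed above.

```python
from collections import deque


class RewardTracker:
    """Hypothetical helper: rolling reward statistics over the latest episodes."""

    def __init__(self, window=100, solved_threshold=500.0):
        self.rewards = deque(maxlen=window)  # oldest episode falls out automatically
        self.solved_threshold = solved_threshold

    def add(self, episode_reward):
        self.rewards.append(episode_reward)

    @property
    def running_reward(self):
        """Mean reward over the window (0.0 before any episode has finished)."""
        return sum(self.rewards) / len(self.rewards) if self.rewards else 0.0

    @property
    def max_reward(self):
        """Best single-episode reward within the window."""
        return max(self.rewards, default=0.0)

    def solved(self):
        return self.running_reward > self.solved_threshold


# Small window to keep the example readable; the runs above use 100.
tracker = RewardTracker(window=3)
for reward in [100, 200, 300, 900]:
    tracker.add(reward)

print(list(tracker.rewards))  # [200, 300, 900]: the first episode was evicted
print(tracker.max_reward)     # 900
print(tracker.solved())       # False: the mean is about 466.7, not above 500
```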

This took an 8-core CPU roughly 100 hours (four to five days) to finish.

In [5]:
plot_running_max([
    "./logs/keras/modelstats_08-12.csv",
    "./logs/keras/modelstats_09-12.csv",
    "./logs/keras/modelstats_10-12.csv",
    "./logs/keras/modelstats_11-12.csv",
    "./logs/keras/modelstats_12-12.csv"
    ], title="Keras - Reward over time")

As the graph above shows, some impressive high scores were made! However, the mean score never really surpassed 400, which even for the uninitiated isn't a hard score to beat. The models saved in /saved_models/keras/ were chosen by hand, at or just after the peak of the running-reward line.

Were the model to train this long again, more information would be logged, similar to the next example:

### Torch - With broken logic

As promised, here are the stats for the 6,000 episodes (about 10 hours) that ran before it became obvious that something had gone wrong. Still, a great learning experience!

In [6]:
plot_running_max([
    "./logs/torch/modelstats_16-12.csv",
    "./logs/torch/modelstats_17-12-1.csv",
], title="Torch - Reward over time - With broken logic")

In [7]:
plot_all([
    "./logs/torch/modelstats_16-12.csv",
    "./logs/torch/modelstats_17-12-1.csv",
], title="Torch - All recorded stats - With broken logic")

The main problem with the version above was that it never actually learned. Without training, every move was effectively random, yet it still managed some impressive high scores, even though that's hard to admit. The fault was some wrongly indented code in the ReplayBuffer class (/trainers/torch_trainer:58-77).

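What is obvious here, though, is that without learning, the model actually gets worse over time. The first guess was that, had it run for another 10 hours, it might stochastically have improved, since its play should be random. Another plausible explanation is the exploration schedule: as epsilon decays, fewer actions are chosen uniformly at random and more are chosen greedily by the untrained network, whose fixed preferences may well be worse than pure chance.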

Let's have a look at PyTorch with proper settings, and an explanation of the logged values:

### Adjusted torch

In the working PyTorch version, as many settings as possible were kept the same. Most of the changes concerned which values were logged.

- Runtime:
    - 10,000,000 frames. This roughly equates to 15,000 episodes.
    - "Solved" set to a running reward score of >500.
    - No other early-stopping parameter, i.e., the maximum number of episodes was set to 0 (unlimited).
- Exploration:
    - 50,000 random frames.
    - 1,000,000 greedy frames.
- Memory:
    - Max memory length set to 1,000,000, as suggested in the DeepMind paper.
    - Target network updated every 10,000 frames.
- Logged metrics:
    - Episode count
    - Episode reward (that episode) - Added
        - Not very interesting, but neither was the number 10,000 times 2, 3, 4, etc. when showing frame count...
    - Running reward (mean of the last 100 episodes)
    - Max reward (out of the last 100 episodes)
    - Epsilon - Added
        - A more interesting slope was expected, but it dives to its minimum value almost instantly.
        - Deselect the other parameters in the graph's legend to see what happens.
    - Loss - Added
        - A more interesting metric for sure, but it would benefit from not being shown in the same graph.
        - Not interesting enough to get its own graph, though; again, the suggestion is to use the interactive functionality of Plotly Express.
    - Time since start
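The near-instant dive of epsilon follows from the frame counts above. Assuming a linear decay like the one in the Keras DQN Breakout example on keras.io, from 1.0 down to a floor of 0.1 over the 1,000,000 greedy frames, epsilon reaches its minimum after only a tenth of a 10,000,000-frame run. The constants and function names below are illustrative assumptions; the 0.1 floor in particular is not stated in this notebook.

```python
import random

EPSILON_MAX = 1.0
EPSILON_MIN = 0.1           # assumed floor, not stated in this notebook
RANDOM_FRAMES = 50_000      # purely random actions before this frame
GREEDY_FRAMES = 1_000_000   # frames over which epsilon decays linearly
TOTAL_FRAMES = 10_000_000   # length of a full run


def epsilon_at(frame):
    """Exploration rate after `frame` frames, decaying linearly to EPSILON_MIN."""
    decayed = EPSILON_MAX - (EPSILON_MAX - EPSILON_MIN) * frame / GREEDY_FRAMES
    return max(decayed, EPSILON_MIN)


def choose_action(frame, greedy_action, num_actions=6, rng=random):
    """Epsilon-greedy: act randomly during warm-up, or with probability epsilon."""
    if frame < RANDOM_FRAMES or rng.random() < epsilon_at(frame):
        return rng.randrange(num_actions)
    return greedy_action


print(epsilon_at(0), epsilon_at(GREEDY_FRAMES))  # 1.0 0.1
print(GREEDY_FRAMES / TOTAL_FRAMES)              # 0.1: epsilon bottoms out after 10% of the run
```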

Another optimization implemented for this run was a fix for a NumPy-related issue (see line 120 in *torch_trainer.py*). The line

<center>

`state_tensor = torch.from_numpy(np.array(state)).float().unsqueeze(0).to(device)`

</center>

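looks somewhat intimidating. It replaced an earlier version in which the expression *np.arrays()* was used, and the change cut runtime dramatically: compared with the "Broken Logic" run, it went from 6,000 episodes in 10 hours to 15,000 in the same time! A likely reason, assuming the old code built the tensor from a list of arrays, is that creating a PyTorch tensor from a list of NumPy arrays is very slow (PyTorch even warns about it and recommends converting to a single `numpy.ndarray` first), whereas `torch.from_numpy` on one array simply shares its memory.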

This took an *RTX 2070 Super* GPU about 10 hours to finish.

In [8]:
plot_running_max([
    "./logs/torch/modelstats_18-12.csv",
    "./logs/torch/modelstats_19-12.csv",
], title="Torch - Reward over time - (repaired logic)")

In [9]:
plot_all([
    "./logs/torch/modelstats_18-12.csv",
    "./logs/torch/modelstats_19-12.csv",
], title="Torch - All recorded stats - (repaired logic)")

In these graphs, the result finally mirrors the Keras baseline more closely. Success! Noteworthy, however, is that for a similar number of episodes and the same settings, the PyTorch model learns more slowly and never quite reaches the same running reward. But the difference is small enough that it may well be down to the stochastic nature of training.

### JAX

More of an honorable mention at this point, plus even more examples of human error!

This study began with the assumption that a modern framework like JAX would run flawlessly on a Windows machine. That assumption was wrong, and maybe a strange one to make. Nevertheless, here are just under 1,000 episodes (10 hours) of a JAX model.

The same settings and logging were used. But as you will see, this run doesn't really count, mostly because of a bug in the code that was there before it started.

In [10]:
plot_running_max([
    "./logs/jax/modelstats_19-12.csv",
    "./logs/jax/modelstats_20-12.csv",
], title="JAX - Reward over time")

In [11]:
plot_all([
    "./logs/jax/modelstats_19-12.csv",
    "./logs/jax/modelstats_20-12.csv",
], title="JAX - All recorded stats")

There's nothing really interesting to see here, except maybe the more clearly defined epsilon slope. Again, this requires some zooming in, and no real effort was made to present this rather insignificant model's graphs in other ways. What we can see is that epsilon decreases steadily throughout the whole run (likely because just under 1,000 episodes never get past the 1,000,000-frame decay window), while the loss value is all over the place. This is because of human error! The target model was never initialized (see line 178 in */trainers/jax_trainer.py*), so the loss was computed against the model itself! Embarrassing, but again a teachable moment.
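Why a missing target network is so damaging can be sketched without any framework. In DQN the loss compares the online network's Q-value with a target r + gamma * max Q_target(s', a'), where the target network is a frozen copy refreshed every 10,000 frames. The snippet below uses a plain dict as a hypothetical stand-in for network parameters and does not reproduce the actual code in jax_trainer.py; it only shows the difference between a real snapshot and a "target" that is just the online model under another name.

```python
import copy


def td_target(reward, next_q_values, done, gamma=0.99):
    """One-step Q-learning target: r + gamma * max_a Q_target(s', a), or r at episode end."""
    return reward if done else reward + gamma * max(next_q_values)


def maybe_update_target(frame, online_params, target_params, every=10_000):
    """Hard update: replace the target with a fresh copy of the online weights."""
    if frame % every == 0:
        return copy.deepcopy(online_params)
    return target_params


# Hypothetical parameters: a plain dict standing in for network weights.
online = {"w": [0.5, -0.2]}
aliased_target = online                  # bug: same object, moves with every update
snapshot_target = copy.deepcopy(online)  # correct: independent frozen copy

online["w"][0] += 0.1  # pretend a gradient step changed the online network

print(aliased_target is online)             # True: the "target" is the online model
print(snapshot_target["w"] == [0.5, -0.2])  # True: the snapshot did not move
snapshot_target = maybe_update_target(10_000, online, snapshot_target)
print(snapshot_target == online, snapshot_target is online)  # True False
```

With the aliased version, every gradient step also moves the target, so the network is effectively chasing itself, which fits the erratic loss seen above.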

## Summary

Well, to sum up this assignment, one could easily point out that mistakes are expensive. Trial and error has so far been the best teacher in this career, but when it comes to teaching computers to change their own behaviours, and thus letting them eat electricity and compute for nights on end, time quickly rules out trial and error as a viable option.

Were this assignment to be done again, it would be interesting to focus more on hyperparameter tuning of one model instead of spending time getting new frameworks to work at all with identical settings. However, there are no regrets, since the feeling persists that writing models in Keras will soon clearly belong to the last generation. Learning the basics of both PyTorch and JAX seems more valuable than getting a model to actually play well, even though tuning got a lower priority, and even though, in real applications, tuning is probably what is really worth something.

There is something very restrictive about an open assignment such as this. To draw a parallel: ask a piano student to play a specific piece and they will play it, but ask them to play whatever they want and they don't know what to do with themselves. This has been especially true for this assignment: nothing to indicate what is too "out there", and nothing to say what is too small. Maybe that is what the future of this industry has in store, but that seems doubtful.

### Used tools:
- VS Code on Windows
- Google Docs for spelling (cells were copy-pasted)
- Anthropic Claude for getting boilerplate code
- Plotly Express for graphs