In [None]:
# !python -m pip install "gymnasium[atari]"
# !python -m pip install "gymnasium[accept-rom-license, atari]"
# !pip install shimmy
# !pip install scikit-image

In [None]:
import gc
import os
import warnings
from dataclasses import dataclass  # the stdlib module is `dataclasses`, not `dataclass`
from pathlib import Path
from typing import Callable, Sequence

import ale_py
import flax
import gymnasium as gym
import joblib
import matplotlib.pyplot as plt
import numpy as np
import optax
import pandas as pd
import seaborn as sns
import shimmy
import torch
import torch.nn as nn
from gym import wrappers
from IPython.display import clear_output
from skimage.measure import block_reduce

from video_frame_cache import VideoFrameCache

In [None]:
import jax
import jax.numpy as jnp
import flax.linen as ln

Action values for `ALE/DemonAttack-v5` (the game used below):

| **Value** | **Meaning** |
|:---------:|:-----------:|
| 0 | NOOP |
| 1 | FIRE |
| 2 | RIGHT |
| 3 | LEFT |
| 4 | RIGHTFIRE |
| 5 | LEFTFIRE |

# Model Instantiation

In [None]:
def show_obs(obs):
    """Display a single observation as an image on a 16x10 inch figure.

    Parameters
    ----------
    obs : array-like
        Image data accepted by ``plt.imshow``.
    """
    figure_size = (16, 10)
    plt.figure(figsize=figure_size)
    plt.imshow(obs)
    plt.show()

In [None]:
def add_noise(probs, i_since_r, timer_i, c=1, sigma=None, buffer=None):
    """Perturb action probabilities with zero-mean Gaussian noise.

    The noise amplitude grows linearly with ``i_since_r`` (steps since the
    last positive reward), so a policy that has stopped scoring is pushed
    towards exploring.

    Parameters
    ----------
    probs : torch.Tensor
        Action probabilities; the last dimension indexes actions. The rest of
        this notebook passes shape ``(1, n_actions)``.
    i_since_r : int
        Number of steps since the last positive reward.
    timer_i : int
        Reference horizon used to scale the noise.
    c : float, default 1
        Extra multiplier on the noise scale.
    sigma : float, optional
        Standard deviation of the raw noise. Defaults to ``2 / n_actions``.
    buffer : int, optional
        Subtracted from ``timer_i`` in the scale denominator.
        Defaults to ``timer_i // 2``.

    Returns
    -------
    torch.Tensor
        Perturbed probabilities. If any entry became negative, the result is
        shifted so its minimum is 0 and renormalised to sum to 1.
    """
    # autograd: no in-place ops below
    if buffer is None:
        buffer = timer_i // 2

    # Fix: the action count is the size of the LAST dimension. len(probs) is
    # the leading (batch) dimension, i.e. 1 for a (1, n) tensor, which made
    # the noise shape (1, 1) and therefore exactly zero after centring.
    n = probs.shape[-1]
    if sigma is None:
        sigma = 2 / n

    # The noise is a constant w.r.t. the policy, so it does not need to be a
    # gradient-tracking leaf.
    noise = torch.normal(0., sigma, (1, n))
    noise = noise - noise.mean()  # zero-mean, so sum(probs) is preserved

    scale = c * i_since_r / (timer_i - buffer)

    probs = probs + noise * scale
    pmin = torch.min(probs)
    if pmin < 0:
        # Shift into the non-negative range, then renormalise.
        probs = probs - pmin
        probs = probs / torch.sum(probs)

    return probs

def balance_lr(probs, i_since_r, timer_i, beta=.5, buffer=None):
    """Pull the LEFT/RIGHT (and LEFTFIRE/RIGHTFIRE) probabilities towards each other.

    Action indices used here:
        2 : 'RIGHT', 3 : 'LEFT', 4 : 'RIGHTFIRE', 5 : 'LEFTFIRE'

    The longer the agent has gone without a reward, the stronger the pull:
    ``alpha`` is 0.5, 0.8 or 0.99 for the second, third and fourth quarter of
    ``timer_i``. During the first quarter ``probs`` is returned untouched.

    Parameters
    ----------
    probs : torch.Tensor
        Action probabilities with at least 6 entries along the last dimension.
    i_since_r : int
        Number of steps since the last positive reward.
    timer_i : int
        Horizon that defines the quarters above.
    beta : float, default 0.5
        Overall strength; each pair moves by ``alpha * beta / 2`` of its gap.
    buffer : optional
        Unused; kept for signature compatibility.

    Returns
    -------
    torch.Tensor
        Rebalanced probabilities (the total mass is unchanged).

    Raises
    ------
    AssertionError
        If the result does not sum to 1 (to 3 decimals).
    """
    if i_since_r < timer_i / 4:
        return probs
    elif i_since_r < timer_i / 2:
        alpha = .5
    elif i_since_r < timer_i * 3 / 4:
        alpha = .8
    else:
        alpha = .99

    # Fix: the original wrote in place into a zeros tensor created with
    # requires_grad=True, which autograd rejects (in-place op on a leaf).
    # Build the mirrored view functionally instead: swap 2<->3 and 4<->5.
    swap = list(range(probs.shape[-1]))
    swap[2], swap[3] = 3, 2
    swap[4], swap[5] = 5, 4
    mirrored = probs[..., swap]

    # (mirrored - probs) is zero for untouched actions and +/- the gap for
    # each swapped pair, so the two adjustments in a pair cancel out.
    adjustment = (mirrored - probs) * alpha * beta / 2

    probs = probs + adjustment
    with torch.no_grad():
        assert torch.sum(probs).round(decimals=3) == 1, torch.sum(probs)
    return probs

def standardize(x):
    """Return ``x`` shifted to zero mean and scaled to (approximately) unit std.

    A float64 machine epsilon is added to the standard deviation so a
    constant input does not divide by zero.
    """
    tiny = np.finfo(np.float64).eps.item()
    centered = x - x.mean()
    spread = x.std() + tiny
    return centered / spread

def balance_all(probs, i_since_r, timer_i, beta=2):
    """Shift every score by ``beta * i_since_r / timer_i`` and apply a softmax.

    Parameters
    ----------
    probs : torch.Tensor
        Scores or probabilities; the softmax runs over the last dimension.
    i_since_r : int
        Number of steps since the last positive reward.
    timer_i : int
        Horizon used to scale the shift.
    beta : float, default 2
        Shift multiplier (previously hard-coded as 2 and ignored).

    Returns
    -------
    torch.Tensor
        Softmax of the shifted input; sums to 1 along the last dimension.

    Notes
    -----
    NOTE(review): softmax is invariant to adding the same constant to every
    entry, so the shift currently has no effect on the result. If the intent
    is to flatten the distribution as ``i_since_r`` grows, the input needs to
    be scaled (a temperature) rather than shifted — confirm the intent.
    """
    shifted = probs + beta * i_since_r / timer_i
    return torch.softmax(shifted, dim=-1)

In [None]:
class LinearModel(ln.Module):
    """Fully connected policy network: Dense -> swish blocks, then a sigmoid head.

    Calling the module returns an array with ``num_classes`` entries along the
    last dimension, each in (0, 1). With ``return_activations=True`` it
    returns ``(output, activations)`` where ``activations`` lists every
    pre-activation and post-activation value in forward order.

    NOTE(review): the sigmoid head does not produce a distribution (the
    outputs need not sum to 1); callers that treat the output as action
    probabilities must normalise it — confirm whether softmax was intended.
    """
    # number of outputs (one per action)
    num_classes : int = 6
    # width of each hidden Dense layer, in order
    hidden_sizes : Sequence = (128, 64)
    # Fix: glorot_normal is a factory and has to be called to obtain the
    # initializer; passing the factory itself makes Dense call it with
    # (key, shape, dtype) and get a function back instead of a kernel.
    kernel_init : Callable = ln.initializers.glorot_normal()

    @ln.compact
    def __call__(self, x, return_activations=False):
        activations = []
        for hidden_size in self.hidden_sizes:
            x = ln.Dense(
                hidden_size,
                kernel_init=self.kernel_init
            )(x)
            activations.append(x)
            x = jax.nn.swish(x)
            activations.append(x)

        x = ln.Dense(
            self.num_classes,
            kernel_init=self.kernel_init
        )(x)
        x = jax.nn.sigmoid(x)
        activations.append(x)
        return x if not return_activations else (x, activations)

In [None]:
@dataclass
class JAXReinforcementBase():
    """REINFORCE-style agent: policy parameters, optimizer state and rollout buffers.

    Per environment step, ``step`` pre-processes the frame, runs the policy,
    shapes the probabilities for exploration, samples an action and records
    the reward. ``episode_reset`` commits or discards the finished episode and
    ``batch_backward`` applies one optimizer update over all committed steps.

    NOTE(review): the exploration shaping in ``process_probs`` is done with
    the torch helpers and is therefore not differentiated through; the
    gradient uses the (normalised) raw model output for the sampled action.
    """
    model : ln.Module
    # quoted: RewardState is defined in a later cell than this class
    reward_state : 'RewardState'
    # optimizer object exposing .init(params) and .update(grads, state, params)
    optimizer : Callable
    obs_shape : tuple
    # crop window applied to each frame: obs[xmin:xmax, ymin:ymax]
    xmin : int=26
    xmax : int=196
    ymin : int=10
    ymax : int=144
    # 'horizontal' keeps every second row; 'max_pool' applies a 2x2 max pool
    downsample : str='horizontal'
    timer_i : int=1000
    corner_correct : bool=True
    # seed for the JAX PRNG key (parameter init and action sampling)
    seed : int=0

    def __post_init__(self):
        """Create buffers, initialise parameters/optimizer state and build the gradient function."""
        self.episode_rewards = list()   # total reward of each committed episode
        self.episode_probss = list()    # shaped probabilities of the running episode
        self.probss = list()            # shaped probabilities of committed episodes
        # Model inputs and sampled actions are kept so the loss can be
        # re-evaluated as a function of the parameters at update time.
        self.episode_xs = list()
        self.episode_actions = list()
        self.xs = list()
        self.actions = list()
        # Fix: prev_obs was only assigned in episode_reset, so the very first
        # step raised AttributeError.
        self.prev_obs = 0

        # Fix: jax.random.key requires a seed.
        self.key = jax.random.key(self.seed)
        self.key, init_key = jax.random.split(self.key)
        # Only the input shape matters for init; a blank frame avoids running
        # the in-place NumPy pre-processing on an immutable JAX array.
        init_x = self.pre_process(np.zeros(self.obs_shape, dtype=np.uint8))
        self.params = self.model.init(init_key, init_x)

        # Fix: `optimizer` was read as a global instead of the field.
        self.opt_state = self.optimizer.init(self.params)

        # Fix: gradients must be taken w.r.t. the parameters, so the forward
        # pass is recomputed inside the differentiated function. A plain
        # closure (not a bound method) is jitted so nothing has to hash self.
        model, loss = self.model, self.loss

        def policy_loss(params, xs, actions, discounted_rewards):
            probs = model.apply(params, xs)
            probs = probs / jnp.sum(probs, axis=-1, keepdims=True)
            chosen = jnp.take_along_axis(probs, actions[:, None], axis=-1)[:, 0]
            return loss(jnp.log(chosen), discounted_rewards)

        self.loss_grad_fn = jax.jit(jax.value_and_grad(policy_loss))
        return

    def step(self, obs, i_since_r, reward=0, info=None):
        """Choose an action for ``obs`` and record the step.

        Parameters
        ----------
        obs : array of shape ``obs_shape``
            Raw frame from the environment.
        i_since_r : int
            Steps since the last positive reward.
        reward : float, default 0
            Reward delivered by the environment together with ``obs``.
        info : dict, optional
            Environment info dict, forwarded to ``reward_state.step``.

        Returns
        -------
        tuple
            ``(model_probs, probs, action, truncated)``: raw model output,
            shaped probabilities, sampled action (int) and whether
            ``i_since_r`` exceeded ``timer_i``.

        NOTE(review): ``reward`` arrives with ``obs`` and so results from the
        previous action, yet it is stored alongside the action sampled here —
        confirm the intended alignment.
        """
        frame = self.pre_process(obs)
        x = frame - self.prev_obs
        # Fix: remember the current frame, not the difference, so the next
        # input is a true frame-to-frame difference.
        self.prev_obs = frame

        model_probs = self.model.apply(self.params, x)
        # Fix: process_probs returns (probs, truncated); unpack both.
        probs, truncated = self.process_probs(model_probs, i_since_r)
        self.episode_probss.append(probs)

        # Fix: advance the stored key so consecutive draws differ, and use
        # jax.random.choice (jnp.random does not exist).
        self.key, sample_key = jax.random.split(self.key)
        action = int(jax.random.choice(sample_key, probs.shape[-1], p=probs.ravel()).item())

        # Binary frames differ by -1/0/1, so int8 storage is lossless.
        self.episode_xs.append(np.asarray(x).astype(np.int8))
        self.episode_actions.append(action)

        # Positional order: raw reward first, then the arguments of
        # reward_state.modify (action, info, i_since_r).
        self.reward_state.step(
            reward,
            action,
            info,
            i_since_r,
            timer_i=self.timer_i,
        )
        return model_probs, probs, action, truncated

    def log_probss(self):
        """Return the log of all committed shaped probabilities, stacked along axis 0."""
        log_probss = jnp.log(jnp.concatenate(self.probss, axis=0))
        return log_probss

    def show_layers(self):
        """Return a tree with the shape of every parameter array."""
        return jax.tree_util.tree_map(lambda x: x.shape, self.params)

    @staticmethod
    def loss(log_probss, discounted_rewards):
        """REINFORCE loss: negative dot product of log-probabilities and returns.

        Both arguments are 1-D arrays with one entry per step. The minus sign
        makes gradient *descent* raise the probability of actions followed by
        above-average returns.
        """
        return -jnp.dot(log_probss, discounted_rewards)

    def _backward(self, xs, actions, discounted_rewards):
        """Compute one optimizer update and return ``(new_params, new_opt_state)``."""
        loss_val, grads = self.loss_grad_fn(self.params, xs, actions, discounted_rewards)
        updates, opt_state = self.optimizer.update(grads, self.opt_state, self.params)
        params = optax.apply_updates(self.params, updates)
        return params, opt_state

    def batch_backward(self):
        """Apply one policy-gradient update over all committed steps, then clear the batch.

        NOTE(review): the whole batch is materialised as one float32 array,
        so memory grows with the number of committed steps.
        """
        # Fix: the method on reward_state is discount_rewards().
        discounted_rewards = self.reward_state.discount_rewards()
        if self.xs:
            xs = jnp.asarray(np.concatenate(self.xs, axis=0), dtype=jnp.float32)
            actions = jnp.asarray(self.actions, dtype=jnp.int32)
            assert xs.ndim == 2, xs.shape
            assert discounted_rewards.shape == (xs.shape[0],), (discounted_rewards.shape, xs.shape)
            self.params, self.opt_state = self._backward(xs, actions, discounted_rewards)

        self.batch_reset()
        gc.collect()
        return self

    def save(self, bytes_path):
        """Serialise the parameters to ``bytes_path``."""
        model_bytes = flax.serialization.to_bytes(self.params)
        # could chunk this
        with open(bytes_path, 'wb') as f:
            f.write(model_bytes)
        return

    def load(self, bytes_path):
        """Load parameters from ``bytes_path``, using the current parameters as the template."""
        with open(bytes_path, 'rb') as f:
            model_bytes = f.read()

        self.params = flax.serialization.from_bytes(self.params, model_bytes)
        return

    def episode_reset(self, truncated=False):
        """Finish the running episode; commit its data to the batch unless ``truncated``."""
        episode_reward = self.reward_state.episode_reset(truncated=truncated)
        if not truncated:
            self.episode_rewards.append(episode_reward)
            self.probss += self.episode_probss
            self.xs += self.episode_xs
            self.actions += self.episode_actions

        self.episode_probss.clear()
        self.episode_xs.clear()
        self.episode_actions.clear()
        self.prev_obs = 0
        return

    def batch_reset(self):
        """Drop all committed per-step data (called after an update)."""
        self.reward_state.batch_reset()
        self.probss.clear()
        self.xs.clear()
        self.actions.clear()
        return

    def pre_process(self, obs):
        """Crop, downsample and binarise a frame; return a ``(1, n_features)`` float32 array.

        Pixels equal to 144 or 109 (the two background values) become 0 and
        every other pixel becomes 1. The input frame is left unmodified.
        """
        obs = np.asarray(obs)
        assert tuple(obs.shape) == tuple(self.obs_shape), obs.shape
        # Fix: crop from `obs` (x was read before assignment).
        x = obs[self.xmin:self.xmax, self.ymin:self.ymax]

        # Fix: `selfdownsample` typo.
        if self.downsample == 'horizontal':
            x = x[::2,:]
        # Copy before the in-place edits below; slices are views, so the
        # original code overwrote the caller's frame.
        x = x.copy()
        x[x == 144] = 0 # erase background (background type 1)
        x[x == 109] = 0 # erase background (background type 2)
        x[x != 0] = 1 # everything else to 1
        # A frame with nothing but background is legitimate, so allow a
        # single distinct value as well as two.
        assert len(np.unique(x)) <= 2, np.unique(x)

        if self.downsample == 'max_pool':
            # ideally downsampling would be done before changing values in place, but this way the background is ignored easily
            x = block_reduce(x, (2, 2), np.amax)

        # Fix: ndarrays have no .float(); cast with astype.
        x = jnp.asarray(np.expand_dims(x.ravel().astype(np.float32), axis=0))
        return x

    def process_probs(
        self,
        probs,
        i_since_r,
    ):
        """Shape the model output for exploration.

        Returns ``(probs, truncated)`` where ``probs`` is a JAX array that
        sums to 1 and ``truncated`` is True once ``i_since_r`` exceeds
        ``timer_i``.

        Raises
        ------
        ValueError
            If the shaped probabilities do not sum to 1 (to 4 decimals).
        """
        # Fix: timer_i / corner_correct are fields, not globals.
        truncated = i_since_r > self.timer_i

        # add_noise, balance_lr and balance_all are written against torch, so
        # convert at the boundary (np.array makes a writable copy).
        probs = torch.from_numpy(np.array(probs, dtype=np.float32))

        if self.corner_correct: # heavily biases agent from getting 'stuck' in corner
            # These helpers assume a distribution; normalising is a no-op if
            # the model output already sums to 1.
            probs = probs / torch.sum(probs)
            probs = add_noise(probs, i_since_r, self.timer_i)
            probs = balance_lr(probs, i_since_r, self.timer_i)
        else:
            probs = balance_all(probs, i_since_r, self.timer_i)

        if torch.round(torch.sum(probs), decimals=4) != 1:
            raise ValueError('Probs do not sum to 1')

        return jnp.asarray(probs.detach().numpy()), truncated

In [None]:
class EnvIter():
    """Generator-style driver around a gymnasium environment.

    ``iter_all`` yields one tuple per environment step:
    ``(i, i_since_r, obs, reward, terminated, truncated, info)``.
    The consumer selects the action for the *next* step by assigning
    ``env_iter.action`` before advancing the generator, and ends an episode
    by calling ``reset``.
    """
    def __init__(self, game_name, pre_fn, max_episodes=100, **make_kwargs):
        """Create the environment.

        Parameters
        ----------
        game_name : str
            Environment id passed to ``gym.make``.
        pre_fn : callable
            Pre-processing function; stored as ``self.pre_fn``.
        max_episodes : int, default 100
            Number of non-truncated episodes to run.
        **make_kwargs
            Forwarded to ``gym.make``.
        """
        self.env = gym.make(game_name, **make_kwargs)
        self.max_episodes = max_episodes
        self.n_episodes = -1
        # Action applied by the next standard_step; set by the consumer.
        self.action = None
        self.reset()

        # Fix: the parameter is named pre_fn (preprocess_fn was undefined).
        self.pre_fn = pre_fn
        self.prev_obs = None
        return

    def standard_step(self, action=None):
        """Advance the environment by one step.

        Uses ``action`` if given, otherwise ``self.action``.

        Raises
        ------
        RuntimeError
            If no action has been provided.
        """
        if action is None:
            action = self.action
        if action is None:
            raise RuntimeError('No action set: assign env_iter.action before advancing the iterator')

        # Fix: use the instance's environment and counters instead of the
        # undefined globals `env`, `action` and `i`.
        obs, reward, terminated, truncated, info = self.env.step(action)
        if reward > 0:
            self.last_i = self.i
            self.i_since_r = 0
        else:
            self.i_since_r += 1

        self.i += 1
        return self.i, self.i_since_r, obs, reward, terminated, truncated, info

    def reset_step(self):
        """Reset the environment and return the first tuple of a new episode."""
        obs, info = self.env.reset()
        # Fix: same 7-field layout as standard_step (the original returned 8
        # items, one of them an undefined `x`).
        return 0, 0, obs, 0, False, False, info

    def reset(self, truncated=False):
        """Schedule an environment reset; count the episode unless it was truncated."""
        self.reset_ = True
        if not truncated:
            self.n_episodes += 1
        self.i = 0
        self.last_i = 0
        # Fix: i_since_r was never initialised before its first `+= 1`.
        self.i_since_r = 0
        return

    def iter_all(self):
        """Yield step tuples until ``max_episodes`` episodes have been counted, then close the env."""
        try:
            # Fix: `<` runs exactly max_episodes episodes (`<=` ran one more).
            while self.n_episodes < self.max_episodes:
                if self.reset_:
                    self.reset_ = False
                    yield self.reset_step()
                else:
                    yield self.standard_step()
        finally:
            # Also runs when the consumer stops iterating early.
            self.env.close()
        return

In [None]:
class RewardState():
    """Reward bookkeeping: shaping, per-episode and per-batch buffers, discounting.

    Data flow: ``step`` appends the shaped reward to ``episode_rewards``;
    ``episode_reset`` moves the episode into ``batch_rewards`` (unless it was
    truncated); ``discount_rewards`` turns the batch into normalised
    discounted returns and empties it.

    NOTE(review): the batch is discounted as one flat sequence, so returns
    leak across episode boundaries — confirm whether a per-episode restart of
    the running sum is wanted.
    """
    def __init__(self, gamma=.99, reward_dict=None):
        """
        Parameters
        ----------
        gamma : float, default 0.99
            Discount factor.
        reward_dict : dict, optional
            Shaping magnitudes with keys 'life_penalty', 'nofire_penalty' and
            'comeback_reward'.
        """
        self.gamma = gamma
        if reward_dict is None:
            reward_dict = {
                'life_penalty' : 15,
                'nofire_penalty' : .1,
                'comeback_reward' : 10,
            }

        self.reward_dict = reward_dict
        self.reward_sum = 0
        self.adj_reward_sum = 0
        # Fix: prev_lives and batch_rewards were used but never created here.
        self.prev_lives = 3
        self.episode_rewards = list()
        self.batch_rewards = list()
        self.probss = list()
        return

    def episode_reset(self, truncated=False):
        """End the running episode and return its raw (unshaped) reward total.

        The episode's shaped rewards are added to the batch unless
        ``truncated`` is True, in which case they are discarded.
        """
        # Fix: `episode_reward` was returned without ever being assigned.
        episode_reward = self.reward_sum
        self.reward_sum = 0
        self.adj_reward_sum = 0
        self.prev_lives = 3
        if not truncated:
            self.batch_rewards += self.episode_rewards
        self.episode_rewards.clear()
        return episode_reward

    def batch_reset(self):
        """Reset the running episode and drop every reward held in the batch."""
        self.episode_reset()
        self.batch_rewards.clear()
        return

    def step(self, reward, *args, **kwargs):
        """Record one step and return the shaped reward.

        ``reward`` is the raw environment reward; the remaining arguments are
        those of ``modify`` without ``reward``, i.e.
        ``step(reward, action, info, i_since_r, timer_i=...)``.
        """
        # Fix: modify() takes `action` first and `reward` second; forwarding
        # (reward, *args) shifted every argument by one position.
        if args:
            action, *rest = args
            adj_reward = self.modify(action, reward, *rest, **kwargs)
        else:
            adj_reward = self.modify(reward=reward, **kwargs)
        self.reward_sum += reward
        self.adj_reward_sum += adj_reward
        # Fix: append to the per-episode buffer that episode_reset reads.
        self.episode_rewards.append(adj_reward)
        return adj_reward

    def modify(self, action, reward, info, i_since_r, timer_i=1000):
        """Apply reward shaping and return the adjusted reward.

        * losing a life subtracts ``life_penalty``;
        * a firing action (1, 4 or 5) on a step without positive reward
          subtracts ``nofire_penalty``;
        * scoring after more than ``timer_i / 2`` reward-less steps adds
          ``comeback_reward``.

        All conditions look at the raw reward, so one adjustment cannot
        trigger or suppress another. ``info`` may be None, in which case the
        life check is skipped.
        """
        lives = self.prev_lives if info is None else info['lives']
        scored = reward > 0

        adjusted = reward
        # Fix: penalties are subtracted (they were added, which rewarded
        # losing a life), and `reward_dict` is read from the instance.
        if lives < self.prev_lives:
            adjusted -= self.reward_dict['life_penalty']
        if not scored and action in [1,4,5]:
            adjusted -= self.reward_dict['nofire_penalty']
        if scored and i_since_r > timer_i / 2:
            adjusted += self.reward_dict['comeback_reward']

        self.prev_lives = lives
        return adjusted

    @staticmethod
    def _discount_rewards(rewards, gamma):
        """Return the discounted returns of ``rewards``, normalised to zero mean and unit std.

        An empty input gives an empty array; a constant input gives zeros
        (a small epsilon keeps the division finite).
        """
        # Fix: @numba.jit removed — the body builds JAX arrays, which numba
        # cannot compile.
        n_steps = len(rewards)
        if n_steps == 0:
            return jnp.zeros((0,), dtype=jnp.float32)

        discounted_rewards = np.zeros(n_steps, dtype=np.float64)
        running_add = 0.
        for idx in range(n_steps - 1, -1, -1):
            running_add = running_add * gamma + rewards[idx]
            discounted_rewards[idx] = running_add

        discounted_rewards = discounted_rewards - discounted_rewards.mean()
        eps = np.finfo(np.float32).eps
        discounted_rewards = discounted_rewards / (discounted_rewards.std() + eps)
        return jnp.asarray(discounted_rewards, dtype=jnp.float32)

    def discount_rewards(self):
        """Return normalised discounted returns for the whole batch and empty the batch."""
        discounted_rewards = self._discount_rewards(self.batch_rewards, self.gamma)
        self.batch_rewards.clear()
        return discounted_rewards

    # Alias: the agent class calls this method under the name
    # discounted_rewards().
    discounted_rewards = discount_rewards

    # the result is combined with log-probabilities via .dot to form the loss

In [None]:
# Action id -> meaning for the environment; ids follow the order listed here.
action_dict = dict(enumerate((
    'NOOP',
    'FIRE',
    'RIGHT',
    'LEFT',
    'RIGHTFIRE',
    'LEFTFIRE',
)))
actions = sorted(action_dict.keys())

# Frame size handed to the agent and optimizer step size.
obs_shape = (210, 160)
lr = 1e-2

# Exploration shaping.
corner_correct = False
# number of iterations without reward before noise is intentionally greater than signal
timer_i = 200

# Training schedule.
batch_size = 64
max_i = 5000

In [None]:
# Policy network with one output per action.
model = LinearModel(
    num_classes=len(actions),
)

reward_state = RewardState()

# Agent: owns the parameters, the optimizer state and the rollout buffers.
base = JAXReinforcementBase(
    model,
    reward_state,
    optax.adam(learning_rate=lr),
    obs_shape,
    corner_correct=corner_correct,
    timer_i=timer_i,
)

video_cache = VideoFrameCache()

# EnvIter should change so that preprocessing is only in base
env_iter = EnvIter(
    'ALE/DemonAttack-v5', 
    base.pre_process,
    # Fix: EnvIter's parameter is `max_episodes`; `n_episodes=10` was swallowed
    # by **make_kwargs and forwarded to gym.make.
    max_episodes=10,
    obs_type='grayscale', 
    render_mode=None,
)

In [None]:
# Main loop: one iteration per environment step.
# NOTE(review): `reward` is delivered together with `obs`, so it results from
# the action chosen on the previous iteration — confirm how rewards should be
# aligned with actions.
for i, i_since_r, obs, reward, terminated, truncated, info in env_iter.iter_all(): 
    video_cache.cache_append(obs)
    # Fix: pass the frame (`x` was undefined) and hand reward/info to the
    # agent explicitly instead of through notebook globals.
    model_probs, probs, action, loop_truncated = base.step(obs, i_since_r, reward, info)
    # The iterator applies this action on its next environment step.
    env_iter.action = action
    truncated = truncated or loop_truncated or i >= max_i
    ######################################################
        
    if terminated: # an episode finished intentionally
        # Read the total before episode_reset zeroes it.
        episode_reward = base.reward_state.reward_sum
        # Fix: episode_number / n_episodes / reward_sum were undefined names.
        print(f'\nEpisode {env_iter.n_episodes} of {env_iter.max_episodes}, Iterations : {i}, Reward : {episode_reward}       \n\n', end='\r')
        base.episode_reset(truncated=False)
        # NOTE(review): n_episodes is 0 for the first finished episode, so this
        # also fires after a single episode — confirm that is intended.
        if not env_iter.n_episodes % batch_size:           
            base.batch_backward()
    elif truncated: # an episode terminated unexpectedly, shouldn't maintain results
        base.episode_reset(truncated=True)
    else:
        print(i, end='                          \r')
        
    if terminated or truncated:
        video_cache.finish(f'episode_vids/episode_{env_iter.n_episodes}.mp4')
        env_iter.reset(truncated=truncated)
        
