In [None]:
# !python -m pip install "gymnasium[atari]"
# !python -m pip install "gymnasium[accept-rom-license, atari]"
# !pip install shimmy
# !pip install scikit-image

In [None]:
import matplotlib.pyplot as plt
import gymnasium as gym
import seaborn as sns
import torch.nn as nn
import pandas as pd
import numpy as np

import warnings
import ale_py
import shimmy
import joblib
import torch
import os

from gym import wrappers
from skimage.measure import block_reduce
from IPython.display import clear_output
from video_frame_cache import VideoFrameCache

In [None]:
import jax
import jax.numpy as jnp
import flax.linen as ln

| **Value** | **Meaning** |
|:---------:|:-----------:|
| 0 | NOOP |
| 1 | FIRE |
| 2 | RIGHT |
| 3 | LEFT |
| 4 | RIGHTFIRE |
| 5 | LEFTFIRE |

# Run Constants

In [None]:
# --- Display / debugging switches ---------------------------------------
# render: render video output?
# show, no_grad: not referenced by any cell visible in this notebook.
# corner_correct: forwarded to process_probs(); True selects the
#   add_noise + balance_lr path, False selects balance_all.
render = show = no_grad = corner_correct = False
plot_action_every = 0 # 64

# --- Step budgets ---------------------------------------------------------
# timer_i: number of iterations without reward before noise is intentionally
#   greater than signal (also the threshold process_probs uses for `truncated`).
timer_i = 2000
dropout_i = 1000

# --- What to record while running ----------------------------------------
# record_actions will be handled as true if plot_action_every > 0
record_actions = record_probs = False
record_rewards = True

# Model Instantiation

In [None]:
def moving_average(a, window_size):
    """
    Moving average of `a` over consecutive windows of `window_size` samples.

    Args:
    `a` : array-like
    - Values to average; the cumulative sum flattens multi-dimensional input.
    `window_size` : int
    - Number of consecutive samples per window.

    Returns:
    np.ndarray of floats, one entry per full window
    (`len(a) - window_size + 1` entries for a 1-D input).
    """
    totals = np.cumsum(a, dtype=float)
    # From index `window_size` onwards, a window sum is the difference of two
    # running totals; earlier entries are left as plain running totals.
    windowed = np.concatenate(
        (totals[:window_size], totals[window_size:] - totals[:-window_size])
    )
    # The first complete window ends at index window_size - 1.
    return windowed[window_size - 1:] / window_size

In [None]:
def show_obs(obs):
    """
    Draw one observation as an image in a new 16x10 matplotlib figure.

    Args:
    `obs` : np.ndarray
    - Observation from the environment
    """
    figure_size = (16, 10)
    plt.figure(figsize=figure_size)
    plt.imshow(obs)
    plt.show()

In [None]:
def add_noise(probs, i_since_r, timer_i, c=1, sigma=None, buffer=None):
    """
    Perturb action probabilities with zero-mean Gaussian noise whose strength
    grows with the number of steps since the last reward.

    Args:
    `probs` : torch.Tensor
    - Action probabilities; the last dimension indexes the actions
      (the rest of the notebook uses shape (1, n_actions)).
    `i_since_r` : int
    - Steps since the last reward.
    `timer_i` : int
    - Step budget; the noise scale is `c * i_since_r / (timer_i - buffer)`.
    `c` : float
    - Extra multiplier on the noise scale.
    `sigma` : float or None
    - Standard deviation of the noise; defaults to `2 / n_actions`.
    `buffer` : int or None
    - Subtracted from `timer_i` in the scale; defaults to `timer_i // 2`.

    Returns:
    torch.Tensor of shape (1, n_actions). If the perturbation pushes any entry
    below zero the result is shifted so its minimum is zero and renormalised to
    sum to 1; otherwise the centred noise leaves the sum unchanged.

    Raises:
    ZeroDivisionError if `timer_i == buffer`.
    """
    # autograd: no in-place ops on `probs`
    if buffer is None:
        buffer = timer_i // 2
    # Count actions on the last axis. `len(probs)` is the batch dimension (1)
    # for (1, n) input, which produced a single noise value that centred to 0.
    n = probs.shape[-1]
    if sigma is None:
        sigma = 2 / n
    # The noise is a sampled constant, so it does not need its own gradient;
    # gradients still flow through `probs`.
    noise = torch.normal(0., sigma, (1, n))  # shared mean/std, one draw per action
    noise = noise - noise.mean()  # centre so the perturbation preserves the sum

    scale = c * i_since_r / (timer_i - buffer)

    probs = probs + noise * scale
    pmin = torch.min(probs)
    if pmin < 0:
        probs = probs - pmin
        probs = probs / torch.sum(probs)

    return probs

def balance_lr(probs, i_since_r, timer_i, beta=.5, buffer=None):
    """
    Pull the probabilities of mirrored actions towards each other the longer
    the agent goes without reward.

    Mirrored pairs (by column index):
        2 : 'RIGHT'      <-> 3 : 'LEFT'
        4 : 'RIGHTFIRE'  <-> 5 : 'LEFTFIRE'

    Each member of a pair moves by `alpha * beta / 2` times the gap to its
    partner, so the total probability is unchanged. `alpha` is stepped by how
    far `i_since_r` is through `timer_i`: below 1/4 nothing happens, then
    0.5, 0.8 and 0.99 for the following quarters.

    Args:
    `probs` : torch.Tensor
    - Action probabilities with the actions on the last axis (at least 6).
    `i_since_r` : int
    - Steps since the last reward.
    `timer_i` : int
    - Step budget used for the quarter thresholds.
    `beta` : float
    - Overall strength of the balancing.
    `buffer` : unused; kept so existing calls keep working.

    Returns:
    torch.Tensor with the same shape as `probs` (the same object when
    `i_since_r < timer_i / 4`).

    Raises:
    AssertionError if the balanced probabilities do not sum to 1 (3 decimals).
    """
    if i_since_r < timer_i / 4:
        return probs
    elif i_since_r < timer_i / 2:
        alpha = .5
    elif i_since_r < timer_i * 3 / 4:
        alpha = .8
    else:
        alpha = .99

    # Build the partner of every column with a permutation instead of writing
    # into a fresh tensor: in-place writes into a leaf that requires grad raise
    # a RuntimeError, and indexing keeps the whole operation differentiable.
    mirror = list(range(probs.shape[-1]))
    mirror[2], mirror[3] = mirror[3], mirror[2]
    mirror[4], mirror[5] = mirror[5], mirror[4]
    partner = probs[..., mirror]

    # Columns without a partner map to themselves, so their delta is zero.
    probs = probs + (partner - probs) * (alpha * beta / 2)
    with torch.no_grad():
        assert torch.sum(probs).round(decimals=3) == 1, torch.sum(probs)
    return probs

def standardize(x):
    """
    Centre `x` on its mean and divide by its standard deviation.

    The float64 machine epsilon is added to the standard deviation so that a
    constant input does not divide by zero.
    """
    tiny = float(np.finfo(np.float64).eps)
    centred = x - x.mean()
    spread = x.std() + tiny
    return centred / spread

def balance_all(probs, i_since_r, timer_i, beta=2):
    """
    Add `2 * i_since_r / timer_i` to every entry of `probs` and apply a
    softmax over the last dimension.

    `beta` is accepted but not used.

    NOTE(review): softmax is invariant to adding the same constant to every
    entry, so the offset does not change the output - the result equals
    softmax(probs). Confirm whether a scaling (temperature) was intended.
    """
    offset = 2 * i_since_r / timer_i
    normalise = nn.Softmax(dim=-1)
    return normalise(probs + offset)

In [None]:
def process_probs(probs, i_since_r, timer_i=1000, corner_correct=True):
    """
    Post-process the model's action probabilities and report whether the
    no-reward budget has been exceeded.

    Args:
    `probs` : torch.Tensor
    - Raw action probabilities from the model.
    `i_since_r` : int
    - Steps since the last reward.
    `timer_i` : int
    - Step budget; exceeding it sets the returned flag.
    `corner_correct` : bool
    - True applies add_noise followed by balance_lr (heavily biases the agent
      away from getting 'stuck' in a corner); False applies balance_all.

    Returns:
    (probs, truncated) - the adjusted probabilities and `i_since_r > timer_i`.

    Raises:
    ValueError if the adjusted probabilities do not sum to 1 (4 decimals).
    """
    timed_out = i_since_r > timer_i

    if not corner_correct:
        adjusted = balance_all(probs, i_since_r, timer_i)
    else:
        noisy = add_noise(probs, i_since_r, timer_i)
        adjusted = balance_lr(noisy, i_since_r, timer_i)

    total = torch.round(torch.sum(adjusted), decimals=4)
    if total != 1:
        raise ValueError('Probs do not sum to 1')

    return adjusted, timed_out

In [None]:
# OBS_SHAPE = (210, 160)
# Action id -> action name, in the same order as the table at the top of the notebook.
action_dict = dict(enumerate((
    'NOOP',
    'FIRE',
    'RIGHT',
    'LEFT',
    'RIGHTFIRE',
    'LEFTFIRE',
)))
# Sorted list of the valid action ids.
actions = sorted(action_dict.keys())

In [None]:
from typing import Callable, Sequence

from flax import linen as ln

class LinearModel(ln.Module):
    """
    Fully connected flax network: a stack of Dense + swish layers followed by
    a Dense layer with a sigmoid.

    Attributes:
    `num_classes` : int
    - Width of the output layer.
    `hidden_sizes` : Sequence[int]
    - Width of each hidden Dense layer, in order.
    `kernel_init` : Callable
    - Weight initializer passed to every Dense layer.

    Note: the output is an element-wise sigmoid, so the `num_classes` values
    are each in (0, 1) but are not normalised to sum to 1.
    """
    num_classes : int = 6
    hidden_sizes : Sequence[int] = (512, 256, 256, 128)
    # `nn` in this notebook is torch.nn, which has no `linear.default_kernel_init`;
    # use flax's own LeCun-normal initializer (the one ln.Dense defaults to).
    kernel_init : Callable = ln.initializers.lecun_normal()

    @ln.compact
    def __call__(self, x, return_activations=False):
        """
        Run the network on `x`.

        Args:
        `x` : array
        - Input whose last dimension is the feature dimension.
        `return_activations` : bool
        - When True, also return the intermediate values.

        Returns:
        The sigmoid output, or `(output, activations)` where `activations`
        holds, for every hidden layer, the Dense output and then its swish,
        followed by the final sigmoid output.
        """
        activations = []
        for hidden_size in self.hidden_sizes:
            x = ln.Dense(
                hidden_size,
                kernel_init=self.kernel_init
            )(x)
            activations.append(x)
            x = jax.nn.swish(x)
            activations.append(x)

        x = ln.Dense(
            self.num_classes,
            kernel_init=self.kernel_init
        )(x)
        x = jax.nn.sigmoid(x)
        activations.append(x)
        return x if not return_activations else (x, activations)

In [None]:
class ReinforcementBase(nn.Module):
    """
    Shared bookkeeping and frame preprocessing for a policy-gradient agent.

    The class stores per-episode log-probabilities and rewards, accumulates
    one loss per finished episode, and turns a raw (210, 160) grayscale frame
    into a flat binary feature row. `episode_end()` calls `self.reward_loss()`,
    which is not defined here and is expected to come from a subclass.

    Args:
    `gamma` : float
    - Stored on the instance; not used by the methods of this class.
    `lr` : float
    - Stored on the instance; not used by the methods of this class.
    `xmin`, `xmax` : int
    - Row range of the crop applied to the frame (first array axis).
    `ymin`, `ymax` : int
    - Column range of the crop (second array axis).
    `downsample` : str
    - 'horizontal' keeps every other row of the crop; 'max_pool' applies a
      2x2 max pooling after binarisation; any other value leaves the crop as is.
    """
    def __init__(
        self,
        gamma=.99,
        lr=.1,
        xmin=26,
        xmax=196,
        ymin=10,
        ymax=144,
        downsample = 'horizontal',
    ):
        # Zero-argument super(): the previous call named an undefined class.
        super().__init__()
        self.gamma = gamma
        self.lr = lr
        self.xmin = xmin
        self.xmax = xmax
        self.ymin = ymin
        self.ymax = ymax
        self.downsample = downsample

        self.log_probss = list()
        self.rewards = list()
        self.episode_losses = list()
        return

    def episode_end(self):
        """Turn the finished episode into one loss entry and clear its buffers."""
        assert len(self.log_probss)
        episode_loss = self.reward_loss()
        # Keep a leading axis so the per-episode losses can be concatenated.
        self.episode_losses.append(jnp.expand_dims(episode_loss, 0))

        self.log_probss.clear()
        self.rewards.clear()
        return self

    def batch_backward(self):
        """
        Sum the stored episode losses into `self.last_batch_loss` and clear them.

        TODO: the optimizer step has not been ported to JAX yet; the summed
        loss is only recorded, no parameters are updated here.
        """
        assert len(self.episode_losses)
        # jnp.sum does not accept a Python list, so join the (1,)-shaped
        # per-episode losses into one array first.
        self.last_batch_loss = jnp.sum(jnp.concatenate(self.episode_losses))
        self.episode_losses.clear()
        return self

    def reset(self):
        """Drop the current episode's log-probabilities and rewards."""
        self.log_probss.clear()
        self.rewards.clear()
        return

    def preprocess(self, x):
        """
        Crop, optionally downsample and binarise a (210, 160) frame.

        Pixel values 144 and 109 (the two background shades) become 0 and
        every other pixel becomes 1. The input array is left untouched.

        Returns:
        A float32 jax array of shape (1, n_features).
        """
        x = np.asarray(x)
        assert x.shape == (210, 160)
        # Copy the crop: slicing yields a view, and the masking below writes in
        # place, which would otherwise overwrite the caller's observation.
        x = x[self.xmin:self.xmax, self.ymin:self.ymax].copy()
        if self.downsample == 'horizontal':
            x = x[::2, :]
        x[x == 144] = 0 # erase background (background type 1)
        x[x == 109] = 0 # erase background (background type 2)
        x[x != 0] = 1 # everything else to 1
        if self.downsample == 'max_pool':
            # ideally downsampling would be done before changing values in place,
            # but this way the background is ignored easily
            x = block_reduce(x, (2, 2), np.amax)
        # Flatten to a single row of float features.
        return jnp.asarray(x.ravel(), dtype=jnp.float32)[None, :]

In [None]:
from torch.utils.tensorboard import SummaryWriter
from pathlib import Path
import gc

# Checkpoint location; the training loop calls model.save(save_path) when this is truthy.
save_path = 'model.pt'
# resume training from previous checkpoint
resume = False
# always_collect, clear: not referenced by any cell visible in this notebook.
always_collect = clear = True

In [None]:
class EnvIter():
    """
    Generator-style wrapper around a gymnasium environment.

    `iter_all()` yields one `(x, obs, reward, terminated, truncated, info)`
    tuple per environment step, where `x` is the preprocessed frame on the
    first step of an episode and the difference to the previous preprocessed
    frame afterwards. The consumer calls `reset()` to end an episode; the
    environment itself is reset on the next iteration.

    Args:
    `game_name` : str
    - Environment id passed to `gym.make`.
    `pre_fn` : callable
    - Maps a raw observation to the model input.
    `max_episodes` : int
    - Number of episodes `iter_all()` produces before stopping.
    `**make_kwargs`
    - Forwarded to `gym.make`.

    The action for the next step is read from `self.next_action` when it has
    been set; otherwise the notebook-level name `action` is used.
    """
    def __init__(self, game_name, pre_fn, max_episodes=100, **make_kwargs):
        self.env = gym.make(game_name, **make_kwargs)
        self.max_episodes = max_episodes
        self.pre_fn = pre_fn
        self.prev_obs = None
        self.next_action = None
        # reset() increments the counter, so start at -1 to begin with episode 0.
        self.n_episodes = -1
        self.reset()
        return

    def standard_step(self):
        """Advance the environment by one step and return the frame difference."""
        chosen = getattr(self, 'next_action', None)
        if chosen is None:
            # Fall back to the action the notebook's training loop left in the
            # module-level `action` variable.
            chosen = action
        obs, reward, terminated, truncated, info = self.env.step(chosen)
        frame = self.pre_fn(obs)
        if self.prev_obs is None:
            raise ValueError('Reset problem')
        x = frame - self.prev_obs
        # Remember the frame itself (not the difference) for the next step.
        self.prev_obs = frame
        return x, obs, reward, terminated, truncated, info

    def reset_step(self):
        """Reset the environment and return the first preprocessed frame."""
        obs, info = self.env.reset()
        x = self.pre_fn(obs)
        self.prev_obs = x
        return x, obs, 0, False, False, info

    def reset(self):
        """Flag that the next iteration starts a new episode and count it."""
        self.reset_ = True
        self.n_episodes += 1
        return

    def iter_all(self):
        """Yield step tuples until `max_episodes` episodes have been started."""
        # Strict comparison: episodes are counted from 0, so `<=` would run
        # max_episodes + 1 of them.
        while self.n_episodes < self.max_episodes:
            if self.reset_:
                self.reset_ = False
                yield self.reset_step()
            else:
                yield self.standard_step()
        
# Environment iterator for the training loop.
# NOTE(review): `model` is not created in any cell visible above - confirm it
# is instantiated before this cell runs.
env_iter = EnvIter(
    'ALE/DemonAttack-v5',
    model.preprocess,
    # EnvIter's keyword is `max_episodes`; `n_episodes` would fall into
    # **make_kwargs and be forwarded to gym.make instead.
    max_episodes=10,
    obs_type='grayscale',
    render_mode=None,
)

In [None]:
class RewardState():
    """
    Tracks raw and shaped rewards plus the action probabilities of an episode.

    Args:
    `reward_dict` : dict or None
    - Shaping magnitudes under the keys 'life_penalty', 'nofire_penalty' and
      'comeback_reward'; the defaults below are used when None.
    `gamma` : float
    - Discount factor used by `discount_rewards()`.
    `initial_lives` : int
    - Life count assumed at the start of every episode.
    """
    def __init__(self, reward_dict=None, gamma=.99, initial_lives=3):
        if reward_dict is None:
            reward_dict = {
                'life_penalty' : 15,
                'nofire_penalty' : .1,
                'comeback_reward' : 10,
            }

        self.reward_dict = reward_dict
        self.gamma = gamma
        self.initial_lives = initial_lives
        self.rewards = list()
        self.probss = list()
        # Also initialises prev_lives, which modify() reads on the first step.
        self.reset()
        return

    def reset(self):
        """Zero the running sums and restore the life counter for a new episode."""
        self.reward_sum = 0
        self.adj_reward_sum = 0
        self.prev_lives = self.initial_lives
        return

    def step(self, action, probs, reward, *args, **kwargs):
        """
        Record one environment step.

        Positional order matches the training loop's call:
        `step(action, probs, reward, info, prev_lives, i_since_r, timer_i=...)`.
        Everything after `reward` is forwarded to `modify()`.

        Returns:
        The shaped reward for this step.
        """
        adj_reward = self.modify(action, reward, *args, **kwargs)

        # Accumulate the raw environment reward (not the action id).
        self.reward_sum += reward
        self.adj_reward_sum += adj_reward
        self.probss.append(probs)
        self.rewards.append(adj_reward)
        return adj_reward

    def modify(self, action, reward, info, prev_lives, i_since_r, timer_i=1000):
        """
        Apply reward shaping to a raw reward.

        - Losing a life (info['lives'] below the stored count) subtracts
          'life_penalty'.
        - A non-positive reward while using a FIRE action (1, 4, 5) adds
          'nofire_penalty'.
          NOTE(review): this adds a value named "penalty" - confirm the
          intended sign and condition.
        - A positive reward after more than `timer_i / 2` steps without one
          adds 'comeback_reward'.

        `prev_lives` is accepted for call compatibility; the life count stored
        on the instance is what is compared and updated.
        """
        if info['lives'] < self.prev_lives:
            # A penalty must lower the reward; adding it rewarded losing a life.
            reward -= self.reward_dict['life_penalty']
        if reward <= 0 and action in [1,4,5]:
            reward += self.reward_dict['nofire_penalty']
        if reward > 0 and i_since_r > timer_i / 2:
            reward += self.reward_dict['comeback_reward']

        self.prev_lives = info['lives']
        return reward

    def discount_rewards(self):
        """
        Return the standardised discounted returns of the stored rewards and
        clear the reward buffer.
        """
        running_add = 0
        discounted_rewards = list()
        # Walk backwards so each return is reward + gamma * (next return).
        for reward in reversed(self.rewards):
            running_add = running_add * self.gamma + reward
            discounted_rewards.append(running_add)

        discounted_rewards = jnp.asarray(list(reversed(discounted_rewards)))
        discounted_rewards = discounted_rewards - discounted_rewards.mean()
        # Epsilon keeps a constant reward sequence from dividing by zero.
        eps = float(jnp.finfo(jnp.float32).eps)
        discounted_rewards = discounted_rewards / (discounted_rewards.std() + eps)
        self.rewards.clear()
        return discounted_rewards

    def log_probss(self):
        """
        Return the log of the stored probabilities, concatenated along axis 0,
        and clear the probability buffer.
        """
        log_probss = jnp.log(jnp.concatenate(self.probss, axis=0))
        self.probss.clear()
        return log_probss
    # uses .dot with discounted rewards as a loss
    
# Reward bookkeeping for the training loop; no reward_dict is passed, so
# RewardState falls back to its built-in default shaping values.
reward_state = RewardState()
# Frame cache from the project-local `video_frame_cache` module, built with its
# defaults; the training loop appends observations to it and calls finish()
# with an .mp4 path at the end of each episode.
video_cache = VideoFrameCache()

In [None]:
# Training loop: one iteration per environment step.
batch_size = 64      # episodes per batch update / checkpoint
apply_stop = False   # when True, also truncate once process_probs reports a timeout

last_i = 0           # value of `i` at the most recent positive reward
drop_i = 16
i = 0                # global step counter
prev_lives = 3       # lives at the start of an episode

for x, obs, reward, terminated, truncated, info in env_iter.iter_all():
    video_cache.cache_append(obs)
    model_probs = model(x)

    # Steps since the last positive reward drive the exploration adjustments.
    if reward > 0:
        last_i = i
    i_since_r = i - last_i
    probs, ptruncated = process_probs(model_probs, i_since_r, timer_i=timer_i, corner_correct=corner_correct)

    # Sample an action with NumPy: it needs a plain 1-D float64 vector that
    # sums to 1, so detach/flatten and renormalise away float32 rounding.
    p = np.asarray(probs.detach() if hasattr(probs, 'detach') else probs, dtype=np.float64).ravel()
    action = int(np.random.choice(actions, p=p / p.sum()))
    adj_reward = reward_state.step(action, probs, reward, info, prev_lives, i_since_r, timer_i=timer_i)
    i += 1

    ######################################################
    if apply_stop:
        truncated = truncated or ptruncated
    elif i_since_r > 100_000:
        warnings.warn(f'No reward for {i_since_r} steps; truncating the episode')
        truncated = True

    if terminated: # an episode finished
        print(f'\nEpisode {env_iter.n_episodes} of {env_iter.max_episodes}, Iterations : {i}, Reward : {reward_state.reward_sum}       \n\n', end='\r')
        # Finish The Episode.save
        # NOTE(review): ReinforcementBase defines episode_end(), not
        # episode_backward(), and no save() - confirm the model class provides both.
        model.episode_backward()
        # Finish the Batch: n_episodes is the 0-based index of the episode that
        # just ended, so add 1 to trigger after every `batch_size` episodes
        # rather than right after the very first one.
        if (env_iter.n_episodes + 1) % batch_size == 0:
            model.batch_backward()
            if save_path:
                model.save(save_path)
            gc.collect()
    elif truncated: # an episode terminated unexpectedly, shouldn't maintain results
        model.episode_losses.clear()

    else:
        print(i, end='                          \r')

    if terminated or truncated:
        video_cache.finish(f'episode_vids/episode_{env_iter.n_episodes}.mp4')
        env_iter.reset()
        reward_state.reset()

        prev_lives = 3 # for new episode adjustment
        last_i = i     # the no-reward timer restarts with the new episode

# The environment lives on the iterator; there is no notebook-level `env`.
env_iter.env.close()

In [None]:
!tensorboard --logdir=runs

if record_probs:
    plot_probs(prob_list)
if record_rewards:
    plot_rewards(reward_list, window_size=200)

[TensorBoard](http://localhost:6006/)