In [46]:
import gymnasium as gym
import mediapy as media
import torch
import numpy as np
import pickle

# Actions:
# 0 - noop
# 2 - up
# 3 - down
env = gym.make("ALE/Pong-v5", render_mode="rgb_array", obs_type="rgb")
# env = gym.make("Pong-v0", render_mode="rgb_array", obs_type="rgb")

A.L.E: Arcade Learning Environment (version 0.8.1+unknown)
[Powered by Stella]
Game console created:
  ROM file:  /Users/Eugene/miniconda3/envs/rl_learning/lib/python3.12/site-packages/AutoROM/roms/pong.bin
  Cart Name: Video Olympics (1978) (Atari)
  Cart MD5:  60e0ea3cbe0913d39803477945e9e5ec
  Display Format:  AUTO-DETECT ==> NTSC
  ROM Size:        2048
  Bankswitch Type: AUTO-DETECT ==> 2K

Running ROM file...
Random seed is -1913132087


Pytorch implementation based on https://gist.github.com/xanderex-sid/ae6cd3ea0c3759c1e3f92835ebd6e158
Frame skip details: https://danieltakeshi.github.io/2016/11/25/frame-skipping-and-preprocessing-for-deep-q-networks-on-atari-2600-games/

In [127]:
# Create parameters
in_dim = (80,80)
D = in_dim[0]*in_dim[1]
H = 200
gamma = 0.99
batch_size = 2 # param update everu
lr = 5e-4
device = "cpu"

def create_model():
  model = {}
  model['W1'] = torch.rand((H,D), dtype=torch.float) - 0.5
  # model['b1'] = torch.rand((H,1), dtype=torch.float16) - 0.5
  model['W2'] = torch.rand((1,H), dtype=torch.float) - 0.5

  for k in model:
    model[k] = model[k].to(device).requires_grad_()
  return model

def preprocess_frame(observation):
  # Crop out middle 160x160
  obs_torch = torch.from_numpy(observation[34:194, :])
  obs_grey = obs_torch.double().mean(axis=2).to(dtype=torch.uint8)
  obs_down = obs_grey[::2, ::2]
  obs_bin = (obs_down > 100).to(dtype=torch.float).ravel()

  return obs_bin

def unpreprocess(x):
  return x.reshape(in_dim).numpy()*255

def sigmoid(x):
  return 1.0 / (1 + torch.exp(-x))

def model_forward(x):
  h = model['W1'] @ x
  # h = h + model['b1'] # hidden layer
  h[h<0] = 0 # ReLU
  o = model['W2'] @ h
  p = sigmoid(o)
  return p

def out_to_action(p):
  # up is 1, down is 0
  return p < 0.5
  # discretize to -1, 0, 1
  return (abs(p) < 0.5) * np.sign(p)

  # if not thresh.any(): # noop if both false
  #   return 0
  # elif thresh[0]: # up
  #   return 2
  # else: # down
  #   return 3

def discount_rewards(r):

  d = torch.zeros(len(r))
  discounted_r = 0 # The index in the point
  for i in range(len(rewards)-1, -1, -1):
    if r[i] != 0:
      discounted_r = r[i]
    else:
      discounted_r *= gamma
    d[i] = discounted_r
  return d

frames = []
rewards, targets = [], []
probs = None
step, ep_idx = 0, 0
running_reward = None
prev_x = torch.zeros(D)
model = create_model()
opt = torch.optim.Adam(model.values(), lr=lr, betas=(0.9, 0.999))
loss_fn = torch.nn.BCELoss(reduction='none')


observation, info = env.reset()
for i in range(10):
  # Skip initial frames
  env.step(0)

# frames.append(unpreprocess(preprocess_frame(observation)))
# Loop for x steps:
while ep_idx < 1000000:
  # Preprocess
  curr_x = preprocess_frame(observation)
  # frames.append(unpreprocess(curr_x))

  x = curr_x - prev_x
  prev_x = curr_x

  p = model_forward(x.to(device))
  # Higher p means higher chance of 2
  # Lower p means higher chance of 3
  action = 2 if torch.rand(1) < p.item() else 3
  # with torch.no_grad():
  #   action = out_to_action(p.cpu())
  y = 1 if action == 2 else 0
  # logprobs.append(y - p)
  probs = p if probs is None else torch.cat((probs, p))
  targets.append(y)
  
  observation, reward, terminated, truncated, info = env.step(action) # We directly put next_state = state for recording logic
  rewards.append(reward)
  step += 1

  if terminated or truncated:
    ep_idx += 1

    ep_y = torch.tensor(targets, dtype=torch.float).to(device)

    discounted_ep_r = discount_rewards(rewards).to(device)
    discounted_ep_r -= torch.mean(discounted_ep_r)
    discounted_ep_r /= torch.std(discounted_ep_r)

    loss_per_element = loss_fn(probs, ep_y) # per element loss
    weighted_loss = loss_per_element * discounted_ep_r
    loss = weighted_loss.mean()

    loss.backward()

    if ep_idx % batch_size == 0:
      print(f"loss: {loss}")
      print(f"Reward sum: {reward_sum} running avg: {running_reward}")
      reward_sum = np.sum(rewards)
      running_reward = reward_sum if running_reward is None else running_reward * 0.99 + reward_sum * 0.01

      # Param update
      opt.step()
      opt.zero_grad()

    if ep_idx % 10000 == 0:
      e_reward = eval(vid_path=f'./logs/step_{step}')
      pickle.dump(model, open(f'./logs/pong_pixels_{step}.p', 'wb'))

    # Reset
    prev_x = torch.zeros(D)
    rewards, probs, targets = [], None, []
    observation, info = env.reset()
    for i in range(10):
      # Skip initial frames
      env.step(0)

# media.show_images(frames)
# media.show_video(frames, fps=30)

loss: -0.0054567670449614525
Reward sum: -21.0 running avg: None
loss: -0.009529346600174904
Reward sum: -21.0 running avg: -21.0
loss: -0.027547724545001984
Reward sum: -21.0 running avg: -21.0
loss: 0.01867346279323101
Reward sum: -21.0 running avg: -21.0
loss: 0.008285115472972393
Reward sum: -20.0 running avg: -20.99
loss: -0.019528791308403015
Reward sum: -21.0 running avg: -20.990099999999998
loss: -0.01085494551807642
Reward sum: -21.0 running avg: -20.990199
loss: -0.007816975004971027
Reward sum: -19.0 running avg: -20.970297010000003
loss: -0.004102419596165419
Reward sum: -20.0 running avg: -20.960594039900002
loss: -0.002229162026196718
Reward sum: -21.0 running avg: -20.960988099501
loss: 0.03462504595518112
Reward sum: -21.0 running avg: -20.961378218505992
loss: 0.0561051145195961
Reward sum: -21.0 running avg: -20.961764436320934
loss: -0.013055281713604927
Reward sum: -21.0 running avg: -20.962146791957725
loss: -0.0027643165085464716
Reward sum: -20.0 running avg: -20

A.L.E: Arcade Learning Environment (version 0.8.1+unknown)
[Powered by Stella]
Game console created:
  ROM file:  /Users/Eugene/miniconda3/envs/rl_learning/lib/python3.12/site-packages/AutoROM/roms/pong.bin
  Cart Name: Video Olympics (1978) (Atari)
  Cart MD5:  60e0ea3cbe0913d39803477945e9e5ec
  Display Format:  AUTO-DETECT ==> NTSC
  ROM Size:        2048
  Bankswitch Type: AUTO-DETECT ==> 2K

Running ROM file...
Random seed is 675587434


loss: -0.01286369189620018
Reward sum: 14.0 running avg: 11.068415267571812
loss: -0.010021879337728024
Reward sum: 11.0 running avg: 11.067731114896093
loss: -0.00930972583591938
Reward sum: 18.0 running avg: 11.137053803747131
loss: -0.015729622915387154
Reward sum: -5.0 running avg: 10.975683265709659
loss: -0.012941916473209858
Reward sum: 3.0 running avg: 10.895926433052562
loss: -0.026743507012724876
Reward sum: 4.0 running avg: 10.826967168722035
loss: -0.01816624030470848
Reward sum: 11.0 running avg: 10.828697497034815
loss: -0.00481400964781642
Reward sum: 17.0 running avg: 10.890410522064467
loss: -0.019763043150305748
Reward sum: -3.0 running avg: 10.751506416843823
loss: -0.014655865728855133
Reward sum: 14.0 running avg: 10.783991352675384
loss: -0.00910611730068922
Reward sum: 15.0 running avg: 10.826151439148632
loss: -0.018606621772050858
Reward sum: 7.0 running avg: 10.787889924757145
loss: -0.019567860290408134
Reward sum: 13.0 running avg: 10.810011025509574
loss: -

KeyboardInterrupt: 

In [122]:
def eval(vid_path=None):
  # Eval
  eval_env = gym.make("ALE/Pong-v5", render_mode="rgb_array", obs_type="rgb")

  observation, info = eval_env.reset()

  terminated, truncated = False, False
  frames = []
  reward_sum = 0
  prev_x = torch.zeros(D)
  while not terminated or truncated:
      
      curr_x = preprocess_frame(observation)
      x = curr_x - prev_x
      prev_x = curr_x
      frames.append(observation)

      with torch.no_grad():
        action = model_forward(x.to(device)).item()
      
      action = 2 if action > 0.5 else 3
      observation, reward, terminated, truncated, info = eval_env.step(action)
      reward_sum += reward
      
  if vid_path is not None:
    media.write_video(f"{vid_path}.mp4", frames)
  else:
     media.show_video(frames)
  return reward_sum


eval()

A.L.E: Arcade Learning Environment (version 0.8.1+unknown)
[Powered by Stella]
Game console created:
  ROM file:  /Users/Eugene/miniconda3/envs/rl_learning/lib/python3.12/site-packages/AutoROM/roms/pong.bin
  Cart Name: Video Olympics (1978) (Atari)
  Cart MD5:  60e0ea3cbe0913d39803477945e9e5ec
  Display Format:  AUTO-DETECT ==> NTSC
  ROM Size:        2048
  Bankswitch Type: AUTO-DETECT ==> 2K

Running ROM file...
Random seed is -792867331


0
This browser does not support the video tag.


-21.0

In [128]:

e_reward = eval(vid_path=f'./logs/ep_{ep_idx}')
pickle.dump(model, open(f'./logs/pong_pixels_ep_{ep_idx}.p', 'wb'))

A.L.E: Arcade Learning Environment (version 0.8.1+unknown)
[Powered by Stella]
Game console created:
  ROM file:  /Users/Eugene/miniconda3/envs/rl_learning/lib/python3.12/site-packages/AutoROM/roms/pong.bin
  Cart Name: Video Olympics (1978) (Atari)
  Cart MD5:  60e0ea3cbe0913d39803477945e9e5ec
  Display Format:  AUTO-DETECT ==> NTSC
  ROM Size:        2048
  Bankswitch Type: AUTO-DETECT ==> 2K

Running ROM file...
Random seed is -1275226604


: 

In [115]:
f

KeyError: 'render_fps'