In [1]:
import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np

import torch.nn as nn
import torch
import torch.functional as f

In [2]:
from ale_py import ALEInterface
ale = ALEInterface()

A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)
[Powered by Stella]


In [3]:
from ale_py.roms import Skiing
ale.loadROM(Skiing)

Game console created:
  ROM file:  /home/yukinon/code/rl-atari-skiing/rl-skiing/lib/python3.11/site-packages/AutoROM/roms/skiing.bin
  Cart Name: Skiing (1980) (Activision) [!]
  Cart MD5:  b76fbadc8ffb1f83e2ca08b6fb4d6c9f
  Display Format:  AUTO-DETECT ==> NTSC
  ROM Size:        2048
  Bankswitch Type: AUTO-DETECT ==> 2K

Running ROM file...
Random seed is 1713860199


In [4]:
from typing import Any, Text

from ale_py.env import gym as ale_gym

# Patch to allow rendering Atari games.
# The AtariEnv's render method expects the mode to be in self._render_mode
# (usually initialized with env.make) instead of taking mode as a param.
_original_atari_render = ale_gym.AtariEnv.render


def atari_render(self, mode: Text = 'rgb_array') -> Any:
  original_render_mode = self._render_mode
  try:
    self._render_mode = mode
    return _original_atari_render(self)
  finally:
    self._render_mode = original_render_mode


ale_gym.AtariEnv.render = atari_render

In [56]:
env = gym.make('ALE/Skiing-v5', render_mode='human')

In [6]:
def epsilon_greedy_policy(state, Q, epsilon, n_action):
  probs = torch.ones(n_action) * epsilon / n_action
  best_action = torch.argmax(Q[state]).item()
  probs[best_action] += 1.0 - epsilon
  return torch.multinomial(probs, 1).item()

# def epsilon_greedy_policy(observation, Q, epsilon, n_action):
#   A = np.ones(n_action, dtype=int) * epsilon / n_action
#   best_action = np.argmax(Q[observation])
#   A[best_action] += (1.0 - epsilon)
#   return A

In [71]:
from collections import defaultdict

def sarsa(env, gamma, n_episode, alpha, epsilon):
  n_action = env.action_space.n
  Q = defaultdict(lambda : torch.empty(n_action))

  for episode in range(n_episode):
    state = env.reset()

    state = custom_hash(state)
    
    is_done = False

    action = epsilon_greedy_policy(state, Q, epsilon, n_action)

    i = 0

    while not is_done:
      new_state, reward, is_done, _, _ = env.step(action)
      new_state = custom_hash(new_state)
      new_action = epsilon_greedy_policy(new_state, Q, epsilon, n_action)
      td_delta = reward + gamma * torch.max(Q[new_state]) - Q[state][action]
      Q[state][action] += alpha * td_delta
      length_episode[episode] += 1
      total_reward_episode[episode] += reward

      i += 1

      if i % 100 == 0:
        print(new_state)

      state = new_state
      action = new_action
    
    print("Episode:", episode, "Total Reward:", total_reward_episode[episode])

  policy = {}
  for state, actions in Q.items():
    policy[state] = torch.argmax(actions).item()

  return Q, policy

In [75]:
env.reset()

(array([[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0],
         ...,
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],
 
        [[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0],
         ...,
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],
 
        [[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0],
         ...,
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],
 
        ...,
 
        [[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0],
         ...,
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],
 
        [[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0],
         ...,
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],
 
        [[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0],
         ...,
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]]], dtype=uint8),
 {'lives': 0, 'episode_frame_number': 0, 'frame_number': 10600})

In [76]:
new_state, reward, is_done, _, _ = env.step(env.action_space.sample())

In [99]:
new_state = new_state[len(new_state) // 2 : len(new_state) // 2 + 20]

In [100]:
new_state.shape

(20, 160, 3)

In [81]:
new_state2, reward, is_done, _, _ = env.step(env.action_space.sample())

In [53]:
new_state[1][1][0] = 4

In [107]:
test_1 = custom_hash(new_state)
test_2 = custom_hash(new_state2)

print(test_1)
print(test_2)

579328207535684514385190949815799908127862702687037699782937585709161028504465392180198176506681393
579328207535684514385190949815799908127862702687037699782937585709161028504465392180198176506681393


In [91]:
import hashlib

def custom_hash(arr):
    m = hashlib.sha256()
    m.update(str(arr).encode('utf-8'))
    return int(m.hexdigest(), 36)

In [9]:
n_episode = 20

length_episode = [0] * n_episode
total_reward_episode = [0] * n_episode

In [11]:
gamma = 0.1
alpha = 0.2
epsilon = 0.1

In [72]:
optimal_Q, optimal_policy = sarsa(env, gamma, n_episode, alpha, epsilon)

263954951215689104011182707965790003866
263954951215689104011182707965790003866
263954951215689104011182707965790003866
263954951215689104011182707965790003866
263954951215689104011182707965790003866
263954951215689104011182707965790003866
263954951215689104011182707965790003866
263954951215689104011182707965790003866


KeyboardInterrupt: 

In [134]:
new_state, reward, is_done, _, _ = env.step(env.action_space.sample())

In [178]:
n_action = env.action_space.n
Q = defaultdict(lambda : torch.empty(n_action, dtype=torch.int16))

state = env.reset()

state = custom_hash(state)

is_done = False

action = epsilon_greedy_policy(state, Q, epsilon, n_action)


In [179]:
action

array([0.03333333, 0.03333333, 0.93333333])

In [174]:
Q

defaultdict(<function __main__.<lambda>()>,
            {1: tensor([11168,  7351, 32561], dtype=torch.int16)})

In [124]:
np.argmax(Q)

0

In [107]:
Q

defaultdict(<function __main__.<lambda>()>,
            {1: tensor([3.6098e+08, 3.0805e-41, 0.0000e+00])})

In [90]:
new_state.shape

(210, 160, 3)

In [74]:
env.observation_space

Box(0, 255, (210, 160, 3), uint8)