# Deep Reinforcement Learning

# Taxi game

[Taxi_game documentation](https://gymnasium.farama.org/environments/toy_text/taxi/#taxi)

<div>
<img src="images/taxi_game.png" width="500"/>
</div>

## State space

## Action space

## Rewards

## Implementacija v Gymnasium

`$ pip install gymnasium`

`$ pip install gymnasium[toy-text]`

Dodatno bomo potrebovali:

```
matplotlib
tensorflow
stable_baselines3[extra]>=2.0.0a9
gym
```

In pa `pytorch` (na njem temelji `stable_baselines3`).

In [2]:
import gymnasium as gym
import time

# Open the Taxi board in a human-mode window, show the seeded start position for
# a few seconds, then close it. `.env` takes the environment one wrapper layer
# below the object returned by gym.make.
env = gym.make("Taxi-v3", render_mode="human").env
env.reset(seed=123)  # fixed seed -> same starting layout on every run
env.render()
time.sleep(3)  # keep the window visible before closing it
env.close()

In [4]:
import gymnasium as gym
import time

# Same seeded start position as above, plus a printout of the observation and
# action spaces (the cell output below shows Discrete(500) and Discrete(6)).
env = gym.make("Taxi-v3", render_mode="human").env
env.reset(seed=123)
env.render()

print("State Space {}".format(env.observation_space))
print("Action Space {}".format(env.action_space))

time.sleep(3)  # keep the window visible before closing it
env.close()

State Space Discrete(500)
Action Space Discrete(6)


In [None]:
import gymnasium as gym
import time

env = gym.make("Taxi-v3", render_mode="human")
env.reset(seed=123)

# `s` (the current encoded state) and `P` (the transition table) are attributes of
# the raw TaxiEnv, not of the wrappers gym.make puts around it. Reading them through
# the wrappers is deprecated and no longer works in Gymnasium 1.x, so go through
# `.unwrapped`.
taxi = env.unwrapped
print(taxi.s)
# P[state][action] -> list of (probability, next_state, reward, terminated) tuples
print(taxi.P[taxi.s])
env.render()

time.sleep(3)  # keep the window visible before closing it
env.close()


In [None]:
import gymnasium as gym
import time

env = gym.make("Taxi-v3", render_mode="human")
env.metadata["render_fps"] = 5  # slow playback so every move is visible
env.reset(seed=123)

total_reward = 0
steps = 0

# Hand-written action sequence for seed 123
# (0=Down, 1=Up, 2=Right, 3=Left, 4=Pickup, 5=Dropoff).
actions = [1, 3, 3, 1, 1, 4, 0, 0, 2, 2, 2, 2, 1, 1, 5]  # shortest path for seed 123

done = False
# Replay the scripted actions. Stop as soon as the episode ends. If the script
# runs out first, end cleanly instead of raising IndexError on actions[steps].
for action in actions:
    env.render()
    observation, reward, done, truncated, info = env.step(action)

    total_reward += reward
    steps += 1
    if done or truncated:
        break

env.render()
time.sleep(2)
env.close()

print("Steps taken", steps)
print("Total reward", total_reward)


## Baseline - Random walking

In [None]:
import gymnasium as gym
import random
import time

env = gym.make("Taxi-v3", render_mode="human")
env.metadata["render_fps"] = 200
env.reset(seed=123)

# The current state `s` and transition table `P` live on the raw TaxiEnv. Reading
# them through the gym.make wrappers is deprecated and fails in Gymnasium 1.x.
taxi = env.unwrapped

total_reward = 0
steps = 0

done = False
# NOTE: `truncated` is not checked, so the walk keeps going past the step limit
# until the passenger is delivered.
while not done:
    env.render()

    # Immediate reward of every action from the current state. Only the first
    # outcome tuple of P[s][a] is used; index 2 of that tuple is the reward.
    immediate_rewards = {
        action: outcomes[0][2] for action, outcomes in taxi.P[taxi.s].items()
    }
    max_reward = max(immediate_rewards.values())  # find max reward in current state
    best_actions = [
        action for action, r in immediate_rewards.items() if r == max_reward
    ]  # collect all actions that will give this max reward
    action = random.choice(best_actions)  # select a random action out of these best actions

    observation, reward, done, truncated, info = env.step(action)  # do the action

    total_reward += reward
    steps += 1

env.render()
time.sleep(2)
env.close()

print("Steps taken", steps)
print("Total reward", total_reward)


## Q-Learning

$new \ Q(state, action) = (1 - \alpha) * Q(state, action) + \alpha(reward + \gamma maxQ(next \ state, all \ actions))$

* $Q(state, action)$ - q-value za state v katerem se trenutno nahajamo in akcijo katero bomo naredili
* $\alpha$ - predstavlja learning rate ($0 < \alpha \le 1$). Določa, za koliko staro q-vrednost premaknemo proti novi oceni.
* $reward$ - reward katerega smo prejeli, ko smo opravili naš action
* $\gamma$ - predstavlja **discount factor**. Ta nam pove, koliko pomembnosti dajemo prihodnji nagradi oziroma Q vrednosti. Vrednost blizu 1 pomeni, da želimo poudariti končni rezultat. Vrednosti blizu 0 pomenijo, da gledamo le trenutno nagrado (kratkoviden oz. *myopic* agent).
* $maxQ(next \ state, all \ actions)$ - pogledamo, v katerem stanju se bomo znašli po naši akciji. Nato vzamemo največjo q-vrednost med vsemi `state-action` pari za to stanje.

### Koda

In [None]:
import random
import gymnasium as gym
import numpy as np
import time

env = gym.make("Taxi-v3", render_mode="human")
env.metadata["render_fps"] = 200

# Q-Table: one row per discrete state, one column per action, all starting at 0
q_table = np.zeros([env.observation_space.n, env.action_space.n])

# Hyperparameters
alpha = 0.1  # learning rate
gamma = 0.6  # discount factor
epsilon = 0.1  # probability of taking a random (exploratory) action

# Performance tracking
total_reward = 0
steps = 0

# Env setup
done = False
state, info = env.reset(seed=123)
# NOTE: `truncated` is only printed here, not acted on, so the loop keeps running
# past the environment's step limit until the passenger is delivered.
while not done:
    env.render()

    if random.uniform(0, 1) <= epsilon:
        action = env.action_space.sample()  # Explore by making random action
    else:
        action = np.argmax(q_table[state])  # Be greedy and make best action
    next_state, reward, done, truncated, info = env.step(action)
    if steps % 100 == 0:
        print("State", state, "Action", action, "Reward", reward, "Next state", next_state, "Done", done, "Turncated", truncated, "Step", steps)

    # Update Q-Table: Q(s,a) <- (1 - alpha) * Q(s,a) + alpha * (r + gamma * max_a' Q(s',a'))
    old_value = q_table[state, action]
    next_max = np.max(q_table[next_state])
    new_value = (1 - alpha) * old_value + alpha * (reward + gamma * next_max)
    q_table[state, action] = new_value

    # Move to the state the taxi actually reached. Without this, every update and
    # every greedy choice would keep using the reset state.
    state = next_state

    # Update performance tracking
    total_reward += reward
    steps += 1

env.render()
time.sleep(2)
env.close()

print("Steps taken", steps)
print("Total reward", total_reward)
print(q_table[341])

In [None]:
import random
import gymnasium as gym
import numpy as np
import time

env = gym.make("Taxi-v3", render_mode="human")
env.metadata["render_fps"] = 200

# Q-Table: one row per discrete state, one column per action, all starting at 0
q_table = np.zeros([env.observation_space.n, env.action_space.n])

# Hyperparameters
alpha = 0.1  # learning rate
gamma = 0.6  # discount factor
epsilon = 0.1  # probability of taking a random (exploratory) action

# Performance tracking
total_reward = 0
steps = 0

# Env setup
done = False
state, info = env.reset(seed=123)
while not done:
    env.render()

    if random.uniform(0, 1) <= epsilon:
        action = env.action_space.sample()  # Explore by making random action
    else:
        action = np.argmax(q_table[state])  # Be greedy and make best action
    next_state, reward, done, truncated, info = env.step(action)
    if steps % 20 == 0:
        print("State", state, "Action", action, "Reward", reward, "Next state", next_state, "Done", done, "Turncated", truncated, "Step", steps)
    # vvvv    HERE HERE HERE    vvvvv
    # End the episode when the step limit is hit, and penalise running out of time
    if truncated:
        done = True
        reward = -20
    # ^^^^    HERE HERE HERE    ^^^^

    # Update Q-Table
    old_value = q_table[state, action]
    next_max = np.max(q_table[next_state])
    new_value = (1 - alpha) * old_value + alpha * (reward + gamma * next_max)
    q_table[state, action] = new_value

    # Move to the state the taxi actually reached. Without this, every update and
    # every greedy choice would keep using the reset state.
    state = next_state

    # Update performance tracking
    total_reward += reward
    steps += 1

env.render()
time.sleep(2)
env.close()

print("Steps taken", steps)
print("Total reward", total_reward)
print(q_table[341])


In [None]:
import random
import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt

import os

env = gym.make("Taxi-v3", render_mode=None)
env.metadata["render_fps"] = 300  # no visible effect, since render_mode is None

# Q-Table
q_table = np.zeros([env.observation_space.n, env.action_space.n])

# Hyperparameters
alpha = 0.1  # learning rate
gamma = 0.6  # discount factor
epsilon = 0.1  # probability of taking a random (exploratory) action

# vvvv    HERE HERE HERE    vvvv
episodes = 10_001
total_rewards = []
total_steps = []
for ep in range(episodes):
    # ^^^^    HERE HERE HERE    ^^^^
    # Performance tracking
    total_reward = 0
    steps = 0

    # Env setup
    done = False
    state, info = env.reset()  # (seed=123) -- unseeded, so start positions vary between episodes
    while not done:
        if random.uniform(0, 1) <= epsilon:
            action = env.action_space.sample()  # Explore by making random action
        else:
            action = np.argmax(q_table[state])  # Be greedy and make best action
        next_state, reward, done, truncated, info = env.step(action)
        if truncated:
            done = True
            reward = -20

        # Update Q-Table
        old_value = q_table[state, action]
        next_max = np.max(q_table[next_state])
        new_value = (1 - alpha) * old_value + alpha * (reward + gamma * next_max)
        q_table[state, action] = new_value
        state = next_state

        # Update performance tracking
        total_reward += reward
        steps += 1
    # vvvv    HERE HERE HERE    vvvv
    total_rewards.append(total_reward)
    total_steps.append(steps)

    if ep % 200 == 0:
        print("Episode", ep)
        print("Steps taken", steps)
        print("Total reward", total_reward)
        print()

env.close()
print(q_table[341])
# np.save raises FileNotFoundError if the target directory does not exist yet.
os.makedirs("models", exist_ok=True)
np.save("models/taxi.npy", q_table)

# Plot every `plot_every`-th episode. The real episode numbers are passed as x so
# the "Episode" axis does not show the index into the subsampled list.
plot_every = 50
plotted_episodes = range(0, episodes, plot_every)

fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True)

ax1.plot(plotted_episodes, total_rewards[::plot_every], c="lightskyblue")
ax1.set_ylabel("Rewards")

ax2.plot(plotted_episodes, total_steps[::plot_every], c="pink")
ax2.set_xlabel("Episode")
ax2.set_ylabel("Steps")

fig.tight_layout()
plt.show()
# ^^^^    HERE HERE HERE    ^^^^

In [None]:
import gymnasium as gym
import numpy as np

# Play Taxi greedily with a saved Q-table and report steps and reward per episode.
env = gym.make("Taxi-v3", render_mode="human")
env.metadata["render_fps"] = 5  # slow playback so each move is visible

# Q-Table
# Previously saved table; judging by the file name, presumably a well-trained one.
q_table = np.load("models/taxi_good.npy")

episodes = 1

for ep in range(episodes):
    # Performance tracking
    total_reward = 0
    steps = 0

    # Env setup
    done = False
    state, info = env.reset(seed=123)
    while not done:
        action = np.argmax(q_table[state])  # Be greedy and make best action

        state, reward, done, truncated, info = env.step(action)
        # Treat hitting the step limit as the end of the episode, with the same
        # -20 penalty used during training.
        if truncated:
            done = True
            reward = -20

        # Update performance tracking
        total_reward += reward
        steps += 1

    print("Episode", ep)
    print("Steps taken", steps)
    print("Total reward", total_reward)
    print()

env.render()
env.close()

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Saved Q-table; judging by the file name, presumably a well-trained one.
q_table = np.load("models/taxi_good.npy")

# Create the figure and axes
# One image row per state and one column per action, coloured by Q-value.
fig, ax = plt.subplots()
im = ax.imshow(q_table, cmap="hot", interpolation="none", aspect="auto")
plt.colorbar(im)
plt.title("Q-Table")

# Add column names to the heatmap
# Listed in action-id order 0..5.
column_names = ["Down", "Up", "Right", "Left", "Pickup", "Dropoff"]
ax.set_xticks(np.arange(len(column_names)))
ax.set_xticklabels(column_names, rotation=45)

# Display the heatmap
plt.show()

In [None]:
import gymnasium as gym
import numpy as np

env = gym.make("Taxi-v3", render_mode="human")
env.metadata["render_fps"] = 5

# Q-Table
q_table = np.load("models/taxi_good.npy")

# vvvv    HERE HERE HERE    vvvv
# Holds every state the taxi visits plus every state reachable from it in one action
states = set()
# ^^^^    HERE HERE HERE    ^^^^

# The transition table lives on the raw TaxiEnv. Reading `env.P` through the
# gym.make wrappers is deprecated and fails in Gymnasium 1.x.
transitions = env.unwrapped.P

# Env setup
done = False
state, info = env.reset(seed=123)
while not done:
    # vvvv    HERE HERE HERE    vvvv
    states.add(state)  # Add the state taxi is in
    # transitions[state][action][0][1] is the state that action leads to
    neighbour_states = [outcomes[0][1] for outcomes in transitions[state].values()]
    states.update(neighbour_states)  # grab neighboring states
    # ^^^^    HERE HERE HERE    ^^^^

    action = np.argmax(q_table[state])  # Be greedy and make best action

    state, reward, done, truncated, info = env.step(action)
    if truncated:
        done = True
        reward = -20

env.render()
env.close()

# vvvv    HERE HERE HERE    vvvv
print("States:", list(states))
# ^^^^    HERE HERE HERE    ^^^^


In [None]:
import numpy as np
import matplotlib.pyplot as plt

import gymnasium as gym

# Taxi states whose Q-table rows are drawn (states on and around the seed-123 route).
states = [1, 257, 137, 397, 17, 21, 277, 157, 421, 37, 297, 177, 441, 321, 197, 201, 77, 337, 341, 85, 217, 221, 97, 101, 357, 377, 237, 241, 117, 121]


env = gym.make("Taxi-v3", render_mode="human")
env.metadata["render_fps"] = 5

# Q-Table
q_table = np.load("models/taxi_good.npy")

# Create the figure and axes.
# The image is built from the same row subset that the loop later passes to
# set_data(), so its extent and row ticks match what is drawn. The colour limits
# come from the whole table (as autoscaling on the full table did before), which
# also covers the highlight value written below.
fig, ax = plt.subplots()
im = ax.imshow(
    q_table[states, :],
    cmap="hot",
    interpolation="none",
    aspect="auto",
    vmin=np.min(q_table),
    vmax=np.max(q_table),
)
plt.colorbar(im)
plt.title("Q-Table")
# Add column names to the heatmap
column_names = ["Down", "Up", "Right", "Left", "Pickup", "Dropoff"]
ax.set_xticks(np.arange(len(column_names)))
ax.set_xticklabels(column_names, rotation=45)
# Label each row with the Taxi state it shows
ax.set_yticks(np.arange(len(states)))
ax.set_yticklabels(states, fontsize=6)
# Display the initial heatmap
plt.show(block=False)

# Env setup
done = False
state, info = env.reset(seed=123)
while not done:
    action = np.argmax(q_table[state])  # Be greedy and make best action

    # Paint the chosen (state, action) cell with the table's maximum so it stands
    # out, then keep only the rows listed in `states`.
    matrix_ = q_table.copy()
    matrix_[state, action] = np.max(q_table)
    matrix_ = matrix_[states, :]  # select only our path states
    im.set_data(matrix_)
    plt.pause(0.5)

    state, reward, done, truncated, info = env.step(action)
    if truncated:
        done = True
        reward = -20

env.render()
env.close()

plt.close(fig)


In [None]:
import random
import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt

env = gym.make("Taxi-v3", render_mode=None)
env.metadata["render_fps"] = 300  # no visible effect, since render_mode is None

# Q-Table
q_table = np.zeros([env.observation_space.n, env.action_space.n])

# Hyperparameters
alpha = 0.1  # learning rate
gamma = 0.6  # discount factor
epsilon = 0.1  # probability of taking a random (exploratory) action


# Taxi states whose Q-table rows are drawn (states on and around the seed-123 route).
states = [
    1, 257, 137, 397, 17, 21, 277, 157, 421, 37,
    297, 177, 441, 321, 197, 201, 77, 337, 341, 85,
    217, 221, 97, 101, 357, 377, 237, 241, 117, 121,
]

# Create the figure and axes.
# The image is built from the same row subset the loop draws, so its extent and
# row ticks match the data passed to set_data() later.
fig, ax = plt.subplots()
im = ax.imshow(q_table[states, :], cmap="hot", interpolation="none", aspect="auto")
plt.colorbar(im)
plt.title("Q-Table")
# Add column names to the heatmap
column_names = ["Down", "Up", "Right", "Left", "Pickup", "Dropoff"]
ax.set_xticks(np.arange(len(column_names)))
ax.set_xticklabels(column_names, rotation=45)
# Label each row with the Taxi state it shows
ax.set_yticks(np.arange(len(states)))
ax.set_yticklabels(states, fontsize=6)
# Display the initial heatmap
plt.show(block=False)

episodes = 10_001
total_rewards = []
total_steps = []
for ep in range(episodes):
    # Performance tracking
    total_reward = 0
    steps = 0

    # Env setup
    done = False
    # Same seed every episode, so every episode starts from the same layout
    # (presumably so the drawn rows stay the relevant ones).
    state, info = env.reset(seed=123)
    while not done:
        if random.uniform(0, 1) <= epsilon:
            action = env.action_space.sample()  # Explore by making random action
        else:
            action = np.argmax(q_table[state])  # Be greedy and make best action
        next_state, reward, done, truncated, info = env.step(action)
        if truncated:
            done = True
            reward = -20

        # Update Q-Table
        old_value = q_table[state, action]
        next_max = np.max(q_table[next_state])
        new_value = (1 - alpha) * old_value + alpha * (reward + gamma * next_max)
        q_table[state, action] = new_value
        state = next_state

        # Update performance tracking
        total_reward += reward
        steps += 1

    total_rewards.append(total_reward)
    total_steps.append(steps)

    if ep % 200 == 0:
        print("Episode", ep)
        print("Steps taken", steps)
        print("Total reward", total_reward)
        print()
        im.set_data(q_table[states, :])  # select only our path states (fancy indexing copies)
        # imshow fixed the colour limits from the all-zero starting table
        # (vmin == vmax == 0). Without rescaling to the current values, the
        # heatmap would never change colour.
        im.autoscale()
        plt.pause(0.2)


env.close()
plt.close(fig)

# Deep Q Learning

## CartPole

[CartPole-v1](https://gymnasium.farama.org/environments/classic_control/cart_pole/)

<div>
<img src="images/cartpole.png" width="500"/>
</div>

### Action Space

### Observation Space

### Rewards

In [None]:
import gymnasium as gym

# Random-action baseline: play a few rendered CartPole episodes with uniformly
# sampled actions and report the reward collected in each.
env = gym.make("CartPole-v1", render_mode="human")
env.metadata["render_fps"] = 24  # playback speed of the human-mode window
print("Action space:", env.action_space)
print("Observation space:", env.observation_space)

episodes = 5

for ep in range(episodes):
    # Performance tracking
    total_reward = 0

    done = False
    env.reset()
    while not done:
        next_state, reward, done, truncated, info = env.step(env.action_space.sample())
        # Also stop when the episode is cut off by the step limit.
        if truncated:
            done = True

        # Performance tracking
        total_reward += reward
    
    print("Episode: ", ep)
    print("Total reward: ", total_reward)

env.close()

---

In [None]:
import gymnasium as gym

import random
import numpy as np
from collections import deque
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import Adam

# Rendered CartPole environment used by the agent and the play loop in this cell.
env = gym.make("CartPole-v1", render_mode="human")
env.metadata["render_fps"] = 24  # playback speed of the human-mode window


class DQNAgent:
    """Epsilon-greedy agent backed by a small Keras Q-network (no training step yet)."""

    def __init__(self, env):
        self.env = env
        # Hyperparameters
        self.epsilon = 0.1  # probability of picking a random action

        self.create_model()

    def create_model(self):
        """Build and compile an MLP mapping an observation to one Q-value per action."""
        self.model = Sequential()
        self.model.add(Dense(32, input_dim=self.env.observation_space.shape[0], activation="relu"))
        self.model.add(Dense(16, activation="relu"))
        self.model.add(Dense(self.env.action_space.n, activation="linear"))
        # `learning_rate` is the supported keyword; `lr` is deprecated and rejected by Keras 3.
        self.model.compile(loss="mse", optimizer=Adam(learning_rate=0.001))

    def action(self, state):
        """Return an epsilon-greedy action for a batched state of shape (1, obs_dim)."""
        if np.random.random() <= self.epsilon:
            # Sample from this agent's own env, not a notebook-global `env`.
            return self.env.action_space.sample()
        # verbose=0 stops predict() from printing a progress bar on every call.
        return np.argmax(self.model.predict(state, verbose=0)[0])

agent = DQNAgent(env)
# summary() prints the table itself and returns None; wrapping it in print()
# would only add a stray "None" line.
agent.model.summary()

episodes = 5

# The network is never fitted in this cell; this only exercises action selection.
for ep in range(episodes):
    # Performance tracking
    total_reward = 0

    done = False
    state, _ = env.reset()
    while not done:
        # The network expects a batch, so reshape the observation to (1, obs_dim).
        state = np.reshape(state, [1, env.observation_space.shape[0]])
        action = agent.action(state)
        state, reward, done, truncated, info = env.step(action)
        if truncated:
            done = True

        # Performance tracking
        total_reward += reward

    print("Episode: ", ep)
    print("Total reward: ", total_reward)

env.close()


---

In [None]:
import random
import gymnasium as gym
import numpy as np
from collections import deque
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import Adam
import matplotlib.pyplot as plt

# Headless CartPole for training (render_mode=None, so nothing is drawn).
env = gym.make("CartPole-v1", render_mode=None)
env.metadata["render_fps"] = 200  # no visible effect without a render mode


class DQNAgent:
    """Epsilon-greedy DQN agent with an experience-replay buffer and a Keras Q-network."""

    def __init__(self, env):
        self.env = env
        # Most recent transitions; the oldest are dropped once 2000 are stored.
        self.memory = deque(maxlen=2000)
        # Hyperparameters
        self.epsilon = 0.1  # probability of a random (exploratory) action
        # vvvv    HERE HERE HERE    vvvv
        self.gamma = 0.95  # discount factor for future rewards
        # ^^^^    HERE HERE HERE    ^^^^

        self.create_model()

    def create_model(self):
        """Build and compile an MLP mapping an observation to one Q-value per action."""
        self.model = Sequential()
        self.model.add(Dense(32, input_dim=self.env.observation_space.shape[0], activation="relu"))
        self.model.add(Dense(16, activation="relu"))
        self.model.add(Dense(self.env.action_space.n, activation="linear")) # Try something else than linear
        # `learning_rate` is the supported keyword; `lr` is deprecated and rejected by Keras 3.
        self.model.compile(loss="mse", optimizer=Adam(learning_rate=0.001))

    def action(self, state):
        """Return an epsilon-greedy action for a batched state of shape (1, obs_dim)."""
        if np.random.random() <= self.epsilon:
            # Sample from this agent's own env, not a notebook-global `env`.
            return self.env.action_space.sample()
        # verbose=0 stops predict() from printing a progress bar on every call.
        return np.argmax(self.model.predict(state, verbose=0)[0])

    # vvvv    HERE HERE HERE    vvvv
    def remember(self, state, action, reward, next_state, done):
        """Store one transition in the replay buffer.

        ``done`` should be True only for real terminal states. replay() does not
        bootstrap from ``next_state`` for transitions marked done.
        """
        self.memory.append((state, action, reward, next_state, done))

    def replay(self):
        """Fit the Q-network on one random minibatch sampled from memory.

        Returns None without training while fewer than ``batch_size``
        transitions are stored.
        """
        batch_size = 64

        if len(self.memory) < batch_size:
            return None # We don't have enough data to train

        minibatch = random.sample(self.memory, batch_size)

        obs_dim = self.env.observation_space.shape[0]
        # States may be stored as (1, obs_dim) batches (the training loop reshapes
        # them that way); stack everything into (batch_size, obs_dim).
        state = np.array([sample[0] for sample in minibatch]).reshape(batch_size, obs_dim)
        next_state = np.array([sample[3] for sample in minibatch]).reshape(batch_size, obs_dim)
        action = np.array([sample[1] for sample in minibatch])
        reward = np.array([sample[2] for sample in minibatch], dtype=float)
        done = np.array([sample[4] for sample in minibatch], dtype=bool)

        # Start from the current predictions so only the taken action's Q-value changes.
        target = self.model.predict(state, verbose=0)
        target_next = self.model.predict(next_state, verbose=0)

        # Bellman target: r for terminal transitions, r + gamma * max_a' Q(s', a') otherwise.
        rows = np.arange(batch_size)
        target[rows, action] = np.where(
            done, reward, reward + self.gamma * np.amax(target_next, axis=1)
        )

        self.model.fit(state, target, batch_size=batch_size, verbose=0)
    # ^^^^    HERE HERE HERE    ^^^^



import os

agent = DQNAgent(env)
# summary() prints the table itself and returns None; wrapping it in print()
# would only add a stray "None" line.
agent.model.summary()

episodes = 25
# vvvv    HERE HERE HERE    vvvv
total_rewards = []
# ^^^^    HERE HERE HERE    ^^^^

for ep in range(episodes):
    # Performance tracking
    total_reward = 0

    done = False
    state, _ = env.reset()
    # The network works on batches, so keep states shaped (1, obs_dim).
    state = np.reshape(state, [1, env.observation_space.shape[0]])

    while not done:
        action = agent.action(state)

        next_state, reward, terminated, truncated, info = env.step(action)
        next_state = np.reshape(next_state, [1, env.observation_space.shape[0]])

        # Performance tracking: count the environment's own reward before the
        # shaping below, so logged/plotted totals are the real episode returns.
        total_reward += reward

        # vvvv    HERE HERE HERE    vvvv
        if terminated:
            # If agent lost give big negative reward, because default reward is 1
            reward = -10
        # Stop on either outcome; truncation only keeps the episode from running forever.
        done = terminated or truncated

        # Store `terminated`, not `done`. A time-limit cut is not a real terminal
        # state, so the replay target should still bootstrap from next_state.
        agent.remember(state, action, reward, next_state, terminated)
        state = next_state
        agent.replay()
        # ^^^^    HERE HERE HERE    ^^^^

    total_rewards.append(total_reward)
    print("Episode: ", ep)
    print("Total reward: ", total_reward)


env.close()
# vvvv    HERE HERE HERE vvvv
# Saving fails if the target directory does not exist yet.
os.makedirs("models", exist_ok=True)
agent.model.save("models/cartpole.h5")

# Plot total rewards
plt.plot(total_rewards)
plt.xlabel("Episode")
plt.ylabel("Total reward")
plt.show()
# ^^^^    HERE HERE HERE    ^^^^

<div>
<img src="./images/cartpole_learning.png" width="500">
</div>

In [None]:
import gymnasium as gym
import numpy as np
from keras.models import load_model


env = gym.make("CartPole-v1", render_mode="human")
env.metadata["render_fps"] = 30

# Saved Keras Q-network; judging by the file name, presumably a well-trained one.
agent = load_model("models/cartpole_good.h5")

episodes = 5

for ep in range(episodes):
    # Performance tracking
    total_reward = 0

    done = False
    state, _ = env.reset()
    # The network expects a batch, so keep states shaped (1, obs_dim).
    state = np.reshape(state, [1, env.observation_space.shape[0]])

    while not done:
        # Greedy action. verbose=0 stops predict() from printing a progress bar
        # on every single step.
        action = np.argmax(agent.predict(state, verbose=0)[0])
        state, reward, done, truncated, info = env.step(action)
        state = np.reshape(state, [1, env.observation_space.shape[0]])

        if truncated:
            # Don't give negative reward, but stop playing so we don't have infinite episode
            done = True

        # Performance tracking
        total_reward += reward

    print("Episode: ", ep)
    print("Total reward: ", total_reward)


env.close()

---

# Stable baseline 3

Installing:
* `pip install stable-baselines3`

In [None]:
import gymnasium as gym

from stable_baselines3 import DQN

# Train an SB3 DQN on CartPole (SB3 builds its own env from the id string), then save it.
model = DQN("MlpPolicy", "CartPole-v1").learn(100_000, progress_bar=True)
model.save("models/DQN_SB3")

# Separate rendered env for watching the trained policy.
env = gym.make("CartPole-v1", render_mode="human")
env.metadata["render_fps"] = 30

episodes = 5

for ep in range(episodes):
    # Performance tracking
    total_reward = 0

    done = False
    obs, _ = env.reset()
    while not done:
        action, _state = model.predict(obs, deterministic=True)
        obs, reward, done, truncated, info = env.step(action)
        if truncated:
            done = True
        # Performance tracking
        total_reward += reward

    print("Episode: ", ep)
    print("Total reward: ", total_reward)

# Release the render window, as the other evaluation cells do.
env.close()


In [None]:
import gymnasium as gym

from stable_baselines3 import DQN

env = gym.make("CartPole-v1", render_mode="human")
env.metadata["render_fps"] = 30

# Previously saved SB3 DQN model.
model = DQN.load("models/DQN_SB3_Best")


episodes = 5

for ep in range(episodes):
    # Performance tracking
    total_reward = 0

    done = False
    obs, _ = env.reset()
    while not done:
        action, _state = model.predict(obs, deterministic=True)
        obs, reward, done, truncated, info = env.step(action)
        if truncated:
            done = True
        # Performance tracking
        total_reward += reward

    print("Episode: ", ep)
    print("Total reward: ", total_reward)

# Release the render window, as the other evaluation cells do.
env.close()


---

# Drugi algoritmi

![rl algos](./images/rl_algos_classification.svg)

## Model-based vs. Model-free

## Policy learning vs. Value learning

In [None]:
import gymnasium as gym

from stable_baselines3 import PPO

# Train an SB3 PPO agent on CartPole for 10k steps, then save it.
model = PPO("MlpPolicy", "CartPole-v1").learn(10_000)
model.save("models/ppo")

# Separate rendered env for watching the trained policy.
env = gym.make("CartPole-v1", render_mode="human")
env.metadata["render_fps"] = 30

episodes = 5

for ep in range(episodes):
    # Performance tracking
    total_reward = 0

    done = False
    obs, _ = env.reset()
    while not done:
        action, _state = model.predict(obs, deterministic=True)
        obs, reward, done, truncated, info = env.step(action)
        if truncated:
            done = True
        # Performance tracking
        total_reward += reward

    print("Episode: ", ep)
    print("Total reward: ", total_reward)

# Release the render window, as the other evaluation cells do.
env.close()


In [None]:
import gymnasium as gym

from stable_baselines3 import DQN, A2C, PPO

# Pick one algorithm to compare by (un)commenting the matching line.
#model = DQN("MlpPolicy", "CartPole-v1").learn(10_000, progress_bar=True)
model = A2C("MlpPolicy", "CartPole-v1").learn(10_000, progress_bar=True)
#model = PPO("MlpPolicy", "CartPole-v1").learn(10_000, progress_bar=True)

env = gym.make("CartPole-v1", render_mode="human")
env.metadata["render_fps"] = 30

episodes = 5
total_rewards = []

for ep in range(episodes):
    # Performance tracking
    total_reward = 0

    done = False
    obs, _ = env.reset()
    while not done:
        action, _state = model.predict(obs, deterministic=True)
        obs, reward, done, truncated, info = env.step(action)
        if truncated:
            done = True
        # Performance tracking
        total_reward += reward

    total_rewards.append(total_reward)
    print("Episode: ", ep)
    print("Total reward: ", total_reward)

# Release the render window, as the other evaluation cells do.
env.close()

print("Average reward: ", sum(total_rewards) / len(total_rewards))

<table class="docutils align-default">
<thead>
<tr class="row-odd"><th class="head"><p>Name</p></th>
<th class="head"><p><code class="docutils literal notranslate"><span class="pre">Box</span></code></p></th>
<th class="head"><p><code class="docutils literal notranslate"><span class="pre">Discrete</span></code></p></th>
<th class="head"><p><code class="docutils literal notranslate"><span class="pre">MultiDiscrete</span></code></p></th>
<th class="head"><p><code class="docutils literal notranslate"><span class="pre">MultiBinary</span></code></p></th>
<th class="head"><p>Multi Processing</p></th>
</tr>
</thead>
<tbody>
<tr class="row-even"><td><p>ARS <a class="footnote-reference brackets" href="#f1" id="id1" role="doc-noteref"><span class="fn-bracket">[</span>1<span class="fn-bracket">]</span></a></p></td>
<td><p>✔️</p></td>
<td><p>✔️</p></td>
<td><p>❌</p></td>
<td><p>❌</p></td>
<td><p>✔️</p></td>
</tr>
<tr class="row-odd"><td><p>A2C</p></td>
<td><p>✔️</p></td>
<td><p>✔️</p></td>
<td><p>✔️</p></td>
<td><p>✔️</p></td>
<td><p>✔️</p></td>
</tr>
<tr class="row-even"><td><p>DDPG</p></td>
<td><p>✔️</p></td>
<td><p>❌</p></td>
<td><p>❌</p></td>
<td><p>❌</p></td>
<td><p>✔️</p></td>
</tr>
<tr class="row-odd"><td><p>DQN</p></td>
<td><p>❌</p></td>
<td><p>✔️</p></td>
<td><p>❌</p></td>
<td><p>❌</p></td>
<td><p>✔️</p></td>
</tr>
<tr class="row-even"><td><p>HER</p></td>
<td><p>✔️</p></td>
<td><p>✔️</p></td>
<td><p>❌</p></td>
<td><p>❌</p></td>
<td><p>✔️</p></td>
</tr>
<tr class="row-odd"><td><p>PPO</p></td>
<td><p>✔️</p></td>
<td><p>✔️</p></td>
<td><p>✔️</p></td>
<td><p>✔️</p></td>
<td><p>✔️</p></td>
</tr>
<tr class="row-even"><td><p>QR-DQN <a class="footnote-reference brackets" href="#f1" id="id2" role="doc-noteref"><span class="fn-bracket">[</span>1<span class="fn-bracket">]</span></a></p></td>
<td><p>❌</p></td>
<td><p>✔️</p></td>
<td><p>❌</p></td>
<td><p>❌</p></td>
<td><p>✔️</p></td>
</tr>
<tr class="row-odd"><td><p>RecurrentPPO <a class="footnote-reference brackets" href="#f1" id="id3" role="doc-noteref"><span class="fn-bracket">[</span>1<span class="fn-bracket">]</span></a></p></td>
<td><p>✔️</p></td>
<td><p>✔️</p></td>
<td><p>✔️</p></td>
<td><p>✔️</p></td>
<td><p>✔️</p></td>
</tr>
<tr class="row-even"><td><p>SAC</p></td>
<td><p>✔️</p></td>
<td><p>❌</p></td>
<td><p>❌</p></td>
<td><p>❌</p></td>
<td><p>✔️</p></td>
</tr>
<tr class="row-odd"><td><p>TD3</p></td>
<td><p>✔️</p></td>
<td><p>❌</p></td>
<td><p>❌</p></td>
<td><p>❌</p></td>
<td><p>✔️</p></td>
</tr>
<tr class="row-even"><td><p>TQC <a class="footnote-reference brackets" href="#f1" id="id4" role="doc-noteref"><span class="fn-bracket">[</span>1<span class="fn-bracket">]</span></a></p></td>
<td><p>✔️</p></td>
<td><p>❌</p></td>
<td><p>❌</p></td>
<td><p>❌</p></td>
<td><p>✔️</p></td>
</tr>
<tr class="row-odd"><td><p>TRPO  <a class="footnote-reference brackets" href="#f1" id="id5" role="doc-noteref"><span class="fn-bracket">[</span>1<span class="fn-bracket">]</span></a></p></td>
<td><p>✔️</p></td>
<td><p>✔️</p></td>
<td><p>✔️</p></td>
<td><p>✔️</p></td>
<td><p>✔️</p></td>
</tr>
<tr class="row-even"><td><p>Maskable PPO <a class="footnote-reference brackets" href="#f1" id="id6" role="doc-noteref"><span class="fn-bracket">[</span>1<span class="fn-bracket">]</span></a></p></td>
<td><p>❌</p></td>
<td><p>✔️</p></td>
<td><p>✔️</p></td>
<td><p>✔️</p></td>
<td><p>✔️</p></td>
</tr>
</tbody>
</table>

---

# Custom Environments

In [None]:
# source: https://github.com/TheAILearner/Snake-Game-using-OpenCV-Python/blob/master/snake_game_using_opencv.ipynb
import numpy as np
import cv2
import random
import time


def collision_with_apple(apple_position, score):
    """Respawn the apple on a random 10-px grid cell and bump the score.

    The incoming ``apple_position`` is not used; a fresh position is drawn.
    Returns ``(new_apple_position, score + 1)``.
    """
    new_x = random.randrange(1, 50) * 10
    new_y = random.randrange(1, 50) * 10
    return [new_x, new_y], score + 1


def collision_with_boundaries(snake_head):
    """Return 1 if the head has left the 500x500 board, otherwise 0."""
    x, y = snake_head[0], snake_head[1]
    inside = 0 <= x < 500 and 0 <= y < 500
    return 0 if inside else 1


def collision_with_self(snake_position):
    """Return 1 if the head (first segment) sits on another segment, otherwise 0."""
    head = snake_position[0]
    return int(head in snake_position[1:])


# Blank 500x500 3-channel canvas (the very first frame shown is empty)
img = np.zeros((500, 500, 3), dtype="uint8")
# Initial Snake and Apple position
# Each segment is the top-left corner of a 10x10 cell; index 0 is the head.
snake_position = [[250, 250], [240, 250], [230, 250]]
apple_position = [random.randrange(1, 50) * 10, random.randrange(1, 50) * 10]

score = 0
# Direction codes: 0=left, 1=right, 2=down, 3=up; the snake starts moving right.
prev_button_direction = 1
button_direction = 1

# Mutable head position; a copy of it is pushed onto snake_position every step.
snake_head = [250, 250]
while True:
    # Show the previous frame, then start drawing a fresh one.
    cv2.imshow("a", img)
    cv2.waitKey(1)
    img = np.zeros((500, 500, 3), dtype="uint8")
    # Display Apple
    cv2.rectangle(
        img,
        (apple_position[0], apple_position[1]),
        (apple_position[0] + 10, apple_position[1] + 10),
        (0, 0, 255),
        3,
    )
    # Display Snake
    for position in snake_position:
        cv2.rectangle(
            img,
            (position[0], position[1]),
            (position[0] + 10, position[1] + 10),
            (0, 255, 0),
            3,
        )

    # Takes step after fixed time
    # Poll the keyboard for ~50 ms and keep only the first key pressed; after a
    # key is read the loop just spins until the time slice is over.
    t_end = time.time() + 0.05
    k = -1
    while time.time() < t_end:
        if k == -1:
            k = cv2.waitKey(1)
        else:
            continue

    # 0-Left, 1-Right, 3-Up, 2-Down, q-Break
    # a-Left, d-Right, w-Up, s-Down
    # A key that would reverse the snake straight back onto itself is ignored.

    if k == ord("a") and prev_button_direction != 1:
        button_direction = 0
    elif k == ord("d") and prev_button_direction != 0:
        button_direction = 1
    elif k == ord("w") and prev_button_direction != 2:
        button_direction = 3
    elif k == ord("s") and prev_button_direction != 3:
        button_direction = 2
    elif k == ord("q"):
        break
    else:
        # No (valid) key this step: keep the current direction.
        button_direction = button_direction
    prev_button_direction = button_direction

    # Change the head position based on the button direction
    # ("down" increases y, i.e. y grows towards the bottom of the image)
    if button_direction == 1:
        snake_head[0] += 10
    elif button_direction == 0:
        snake_head[0] -= 10
    elif button_direction == 2:
        snake_head[1] += 10
    elif button_direction == 3:
        snake_head[1] -= 10

    # Increase Snake length on eating apple
    # Eating: add the new head and keep the tail (the snake grows by one).
    # Otherwise: add the new head and drop the tail (the snake moves).
    if snake_head == apple_position:
        apple_position, score = collision_with_apple(apple_position, score)
        snake_position.insert(0, list(snake_head))

    else:
        snake_position.insert(0, list(snake_head))
        snake_position.pop()

    # On collision kill the snake and print the score
    if (
        collision_with_boundaries(snake_head) == 1
        or collision_with_self(snake_position) == 1
    ):
        font = cv2.FONT_HERSHEY_SIMPLEX
        img = np.zeros((500, 500, 3), dtype="uint8")
        cv2.putText(
            img,
            "Your Score is {}".format(score),
            (140, 250),
            font,
            1,
            (255, 255, 255),
            2,
            cv2.LINE_AA,
        )
        cv2.imshow("a", img)
        # waitKey(0) blocks until a key is pressed, leaving the score on screen.
        cv2.waitKey(0)
        break

cv2.destroyAllWindows()

```python
import gymnasium as gym
import numpy as np
from gymnasium import spaces


class CustomEnv(gym.Env):
    """Custom Environment that follows gym interface."""

    metadata = {"render_modes": ["human"], "render_fps": 30}

    def __init__(self, arg1, arg2, ...):
        super().__init__()
        # Define action and observation space
        # They must be gym.spaces objects
        # Example when using discrete actions:
        self.action_space = spaces.Discrete(N_DISCRETE_ACTIONS)
        # Example for using image as input (channel-first; channel-last also works):
        self.observation_space = spaces.Box(low=0, high=255,
                                            shape=(N_CHANNELS, HEIGHT, WIDTH), dtype=np.uint8)

    def step(self, action):
        ...
        return observation, reward, terminated, truncated, info

    def reset(self, seed=None, options=None):
        ...
        return observation, info

    def render(self):
        ...

    def close(self):
        ...
```

---

In [None]:
import random

import gymnasium as gym
import numpy as np
from gymnasium import spaces


def collision_with_apple(apple_position, score):
    """Spawn a fresh apple and bump the score.

    The incoming ``apple_position`` is not used; a new cell is drawn on the
    10-pixel grid with each coordinate in ``range(10, 500, 10)`` (x first,
    then y, from the module-level ``random`` generator).

    Returns:
        tuple: ``(new_apple_position, score + 1)``.
    """
    new_x = random.randrange(1, 50) * 10
    new_y = random.randrange(1, 50) * 10
    return [new_x, new_y], score + 1


def collision_with_boundaries(snake_head):
    """Return 1 if ``snake_head`` lies outside the 500x500 board, else 0.

    Only ``snake_head[0]`` (x) and ``snake_head[1]`` (y) are inspected; a
    coordinate is on the board when it is neither ``>= 500`` nor ``< 0``.
    """
    for axis in (0, 1):
        coord = snake_head[axis]
        if coord >= 500 or coord < 0:
            return 1
    return 0


def collision_with_self(snake_position):
    """Return 1 if the head (first segment) also occurs in the rest of the body, else 0."""
    rest_of_body = snake_position[1:]
    overlaps = snake_position[0] in rest_of_body
    return int(overlaps)


class SnekEnv(gym.Env):
    """Skeleton of the Snake environment: the spaces are defined, the logic is not.

    Action space: 4 discrete actions. Observation space: 8 float32 values in
    [0, 1]. Every lifecycle method is still a placeholder that returns ``None``.
    """

    metadata = {"render_modes": ["human"], "render_fps": 30}

    def __init__(self):
        super().__init__()
        # Spaces must be gym.spaces objects.
        self.action_space = spaces.Discrete(4)
        # A flat vector of 8 values, each bounded to [0, 1].
        self.observation_space = spaces.Box(
            low=0, high=1, shape=(8,), dtype=np.float32
        )

    def step(self, action):
        """Placeholder; implemented in a later version of the environment."""

    def reset(self, seed=None, options=None):
        """Placeholder; implemented in a later version of the environment."""

    def render(self, mode="human"):
        """Placeholder; implemented in a later version of the environment."""

    def close(self):
        """Placeholder; nothing to release yet."""

In [None]:
import random

import gymnasium as gym
import numpy as np
from gymnasium import spaces


def collision_with_apple(apple_position, score):
    """Spawn a fresh apple and bump the score.

    The incoming ``apple_position`` is not used; a new cell is drawn on the
    10-pixel grid with each coordinate in ``range(10, 500, 10)`` (x first,
    then y, from the module-level ``random`` generator).

    Returns:
        tuple: ``(new_apple_position, score + 1)``.
    """
    new_x = random.randrange(1, 50) * 10
    new_y = random.randrange(1, 50) * 10
    return [new_x, new_y], score + 1


def collision_with_boundaries(snake_head):
    """Return 1 if ``snake_head`` lies outside the 500x500 board, else 0.

    Only ``snake_head[0]`` (x) and ``snake_head[1]`` (y) are inspected; a
    coordinate is on the board when it is neither ``>= 500`` nor ``< 0``.
    """
    for axis in (0, 1):
        coord = snake_head[axis]
        if coord >= 500 or coord < 0:
            return 1
    return 0


def collision_with_self(snake_position):
    """Return 1 if the head (first segment) also occurs in the rest of the body, else 0."""
    rest_of_body = snake_position[1:]
    overlaps = snake_position[0] in rest_of_body
    return int(overlaps)


class SnekEnv(gym.Env):
    """Snake environment, stage 2: ``reset`` and observation building.

    ``step``, ``render`` and ``close`` are still placeholders returning ``None``.

    Board: 500x500 pixels on a 10-pixel grid. Actions: ``Discrete(4)``.
    Observations: 8 binary flags stored as float32 (see ``constructObservation``).
    """

    metadata = {"render_modes": ["human"], "render_fps": 30}

    def __init__(self):
        super().__init__()
        # Spaces must be gym.spaces objects.
        self.action_space = spaces.Discrete(4)
        # Eight 0/1 flags describing apple direction and nearby dangers.
        self.observation_space = spaces.Box(low=0, high=1, shape=(8,), dtype=np.float32)

    def step(self, action):
        """Placeholder; not implemented at this stage."""
        pass

    def reset(self, seed=None, options=None):
        """Start a new episode.

        Args:
            seed: Optional seed. It seeds ``self.np_random`` through
                ``gym.Env.reset`` and also reseeds the module-level ``random``
                generator, which is what places the apple.
            options: Unused; accepted for Gymnasium API compatibility.

        Returns:
            tuple: ``(observation, info)`` with an empty ``info`` dict.
        """
        # vvvv    HERE HERE HERE    vvvv
        # Gymnasium API requirement: seeds self.np_random when a seed is given.
        super().reset(seed=seed)
        if seed is not None:
            # Apple positions are drawn with the global `random` module,
            # so seed it too to make episodes reproducible.
            random.seed(seed)

        self.img = np.zeros((500, 500, 3), dtype="uint8")  # black 500x500 canvas

        # Initial snake: head at (250, 250), two segments trailing to the left.
        self.snake_position = [[250, 250], [240, 250], [230, 250]]
        self.apple_position = [
            random.randrange(1, 50) * 10,
            random.randrange(1, 50) * 10,
        ]

        self.score = 0
        # Separate list from snake_position[0]; moving it does not touch the body.
        self.snake_head = self.snake_position[0].copy()

        self.done = False
        self.truncated = False

        observation = self.constructObservation()

        return observation, {}  # observation, info
        # ^^^^    HERE HERE HERE    ^^^^

    def render(self, mode="human"):
        """Placeholder; not implemented at this stage."""
        pass

    def close(self):
        """Placeholder; nothing to release yet."""
        pass

    # vvvv    HERE HERE HERE    vvvv
    def constructObservation(self):
        """Build the observation vector for the current state (all entries 0 or 1).

        * [0] apple x < head x (apple to the left)
        * [1] otherwise (apple to the right or in the same column)
        * [2] apple y < head y (apple above)
        * [3] otherwise (apple below or in the same row)
        * [4] a wall or body segment is directly left of the head
        * [5] a wall or body segment is directly right of the head
        * [6] a wall or body segment is directly above the head
        * [7] a wall or body segment is directly below the head

        The segment right behind the head (the neck) is part of the body, so
        the direction pointing back into it always reads as blocked.

        Returns:
            np.ndarray: float32 array of shape (8,).
        """
        observation = [1, 0] if self.apple_position[0] < self.snake_head[0] else [0, 1]
        observation += [1, 0] if self.apple_position[1] < self.snake_head[1] else [0, 1]

        head = self.snake_position[0]
        body = self.snake_position[1:]
        # Probe the neighbouring cells in the order left, right, up, down.
        for dx, dy in ((-10, 0), (10, 0), (0, -10), (0, 10)):
            probe = [head[0] + dx, head[1] + dy]
            # Pretend the head moved to `probe`; blocked if off-board or on the body.
            blocked = collision_with_boundaries(probe) or collision_with_self(
                [probe] + body
            )
            observation.append(blocked)

        return np.array(observation, dtype=np.float32)
        # ^^^^    HERE HERE HERE    ^^^^


In [None]:
import random

import gymnasium as gym
import numpy as np
from gymnasium import spaces
import cv2


def collision_with_apple(apple_position, score):
    """Spawn a fresh apple and bump the score.

    The incoming ``apple_position`` is not used; a new cell is drawn on the
    10-pixel grid with each coordinate in ``range(10, 500, 10)`` (x first,
    then y, from the module-level ``random`` generator).

    Returns:
        tuple: ``(new_apple_position, score + 1)``.
    """
    new_x = random.randrange(1, 50) * 10
    new_y = random.randrange(1, 50) * 10
    return [new_x, new_y], score + 1


def collision_with_boundaries(snake_head):
    """Return 1 if ``snake_head`` lies outside the 500x500 board, else 0.

    Only ``snake_head[0]`` (x) and ``snake_head[1]`` (y) are inspected; a
    coordinate is on the board when it is neither ``>= 500`` nor ``< 0``.
    """
    for axis in (0, 1):
        coord = snake_head[axis]
        if coord >= 500 or coord < 0:
            return 1
    return 0


def collision_with_self(snake_position):
    """Return 1 if the head (first segment) also occurs in the rest of the body, else 0."""
    rest_of_body = snake_position[1:]
    overlaps = snake_position[0] in rest_of_body
    return int(overlaps)


class SnekEnv(gym.Env):
    """Snake environment, stage 3: ``reset``, observations and OpenCV rendering.

    ``step`` is still a stub that returns five ``None`` values.

    Board: 500x500 pixels on a 10-pixel grid. Actions: ``Discrete(4)``.
    Observations: 8 binary flags stored as float32 (see ``constructObservation``).
    """

    metadata = {"render_modes": ["human"], "render_fps": 30}

    def __init__(self):
        super().__init__()
        # Spaces must be gym.spaces objects.
        self.action_space = spaces.Discrete(4)
        # Eight 0/1 flags describing apple direction and nearby dangers.
        self.observation_space = spaces.Box(low=0, high=1, shape=(8,), dtype=np.float32)

    def step(self, action):
        """Stub; returns ``(None, None, None, None, None)`` at this stage."""
        return None, None, None, None, None

    def reset(self, seed=None, options=None):
        """Start a new episode.

        Args:
            seed: Optional seed. It seeds ``self.np_random`` through
                ``gym.Env.reset`` and also reseeds the module-level ``random``
                generator, which is what places the apple.
            options: Unused; accepted for Gymnasium API compatibility.

        Returns:
            tuple: ``(observation, info)`` with an empty ``info`` dict.
        """
        # Gymnasium API requirement: seeds self.np_random when a seed is given.
        super().reset(seed=seed)
        if seed is not None:
            # Apple positions are drawn with the global `random` module,
            # so seed it too to make episodes reproducible.
            random.seed(seed)

        self.img = np.zeros((500, 500, 3), dtype="uint8")  # black 500x500 canvas

        # Initial snake: head at (250, 250), two segments trailing to the left.
        self.snake_position = [[250, 250], [240, 250], [230, 250]]
        self.apple_position = [
            random.randrange(1, 50) * 10,
            random.randrange(1, 50) * 10,
        ]

        self.score = 0
        # Separate list from snake_position[0]; moving it does not touch the body.
        self.snake_head = self.snake_position[0].copy()

        self.done = False
        self.truncated = False

        observation = self.constructObservation()

        return observation, {}  # observation, info

    def render(self, mode="human"):
        """Show ``self.img``, then draw the next frame into it.

        The frame shown is the one prepared by the previous call, so the
        window lags the game state by one call. ``mode`` is accepted but ignored.
        """
        # vvvv    HERE HERE HERE    vvvv
        cv2.imshow("Snake", self.img)  # Show the current frame
        # Wait 50 ms so a human can follow the game (waitKey also lets the
        # OpenCV window repaint).
        cv2.waitKey(50)

        # Create the next frame: empty board, the apple, then the snake on top.
        self.img = np.zeros((500, 500, 3), dtype="uint8")
        self._draw_cell(self.apple_position, (0, 0, 255))  # red in OpenCV's BGR order
        for position in self.snake_position:
            self._draw_cell(position, (0, 255, 0))  # green body segment
        # ^^^^    HERE HERE HERE ^^^^

    def _draw_cell(self, top_left, color):
        """Outline the 10x10 grid cell whose top-left corner is ``top_left`` on ``self.img``."""
        x, y = top_left[0], top_left[1]
        cv2.rectangle(self.img, (x, y), (x + 10, y + 10), color, 3)

    def close(self):
        """Nothing to release at this stage."""
        pass

    def constructObservation(self):
        """Build the observation vector for the current state (all entries 0 or 1).

        * [0] apple x < head x (apple to the left)
        * [1] otherwise (apple to the right or in the same column)
        * [2] apple y < head y (apple above)
        * [3] otherwise (apple below or in the same row)
        * [4] a wall or body segment is directly left of the head
        * [5] a wall or body segment is directly right of the head
        * [6] a wall or body segment is directly above the head
        * [7] a wall or body segment is directly below the head

        The segment right behind the head (the neck) is part of the body, so
        the direction pointing back into it always reads as blocked.

        Returns:
            np.ndarray: float32 array of shape (8,).
        """
        observation = [1, 0] if self.apple_position[0] < self.snake_head[0] else [0, 1]
        observation += [1, 0] if self.apple_position[1] < self.snake_head[1] else [0, 1]

        head = self.snake_position[0]
        body = self.snake_position[1:]
        # Probe the neighbouring cells in the order left, right, up, down.
        for dx, dy in ((-10, 0), (10, 0), (0, -10), (0, 10)):
            probe = [head[0] + dx, head[1] + dy]
            # Pretend the head moved to `probe`; blocked if off-board or on the body.
            blocked = collision_with_boundaries(probe) or collision_with_self(
                [probe] + body
            )
            observation.append(blocked)

        return np.array(observation, dtype=np.float32)


In [None]:
import random

import gymnasium as gym
import numpy as np
from gymnasium import spaces
import cv2


def collision_with_apple(apple_position, score):
    """Spawn a fresh apple and bump the score.

    The incoming ``apple_position`` is not used; a new cell is drawn on the
    10-pixel grid with each coordinate in ``range(10, 500, 10)`` (x first,
    then y, from the module-level ``random`` generator).

    Returns:
        tuple: ``(new_apple_position, score + 1)``.
    """
    new_x = random.randrange(1, 50) * 10
    new_y = random.randrange(1, 50) * 10
    return [new_x, new_y], score + 1


def collision_with_boundaries(snake_head):
    """Return 1 if ``snake_head`` lies outside the 500x500 board, else 0.

    Only ``snake_head[0]`` (x) and ``snake_head[1]`` (y) are inspected; a
    coordinate is on the board when it is neither ``>= 500`` nor ``< 0``.
    """
    for axis in (0, 1):
        coord = snake_head[axis]
        if coord >= 500 or coord < 0:
            return 1
    return 0


def collision_with_self(snake_position):
    """Return 1 if the head (first segment) also occurs in the rest of the body, else 0."""
    rest_of_body = snake_position[1:]
    overlaps = snake_position[0] in rest_of_body
    return int(overlaps)


class SnekEnv(gym.Env):
    """Snake environment, stage 4: playable ``step`` with a length-based reward.

    Board: 500x500 pixels on a 10-pixel grid. Actions (``Discrete(4)``):
    0 = left, 1 = right, 2 = down, 3 = up. Observations: 8 binary flags stored
    as float32 (see ``constructObservation``). Episodes terminate when the
    snake hits a wall or itself; ``truncated`` is never set at this stage.
    """

    metadata = {"render_modes": ["human"], "render_fps": 30}

    def __init__(self):
        super().__init__()
        # Spaces must be gym.spaces objects.
        self.action_space = spaces.Discrete(4)
        # Eight 0/1 flags describing apple direction and nearby dangers.
        self.observation_space = spaces.Box(low=0, high=1, shape=(8,), dtype=np.float32)
        # vvvv    HERE HERE HERE    vvvv
        # When True, every step() also renders, so training can be watched.
        self.render_training = False
        # ^^^^    HERE HERE HERE    ^^^^

    def step(self, action):
        """Advance the game by one move.

        Args:
            action: 0 = left, 1 = right, 2 = down, 3 = up; any other value
                leaves the head where it is.

        Returns:
            tuple: ``(observation, reward, terminated, truncated, info)``.
            The reward is the snake's length after the move.
        """
        # vvvv    HERE HERE HERE    vvvv
        if self.render_training:  # optionally watch the agent while it trains
            self.render()

        # Move the head one grid cell (10 px); y grows downwards on screen.
        if action == 1:  # move RIGHT
            self.snake_head[0] += 10
        elif action == 0:  # move LEFT
            self.snake_head[0] -= 10
        elif action == 2:  # move DOWN
            self.snake_head[1] += 10
        elif action == 3:  # move UP
            self.snake_head[1] -= 10

        self.reward = 0

        # The new head always becomes the first segment. Eating the apple keeps
        # the tail (the snake grows); otherwise the tail is dropped (it moves).
        self.snake_position.insert(0, list(self.snake_head))
        if self.snake_head == self.apple_position:
            self.apple_position, self.score = collision_with_apple(
                self.apple_position, self.score
            )
        else:
            self.snake_position.pop()

        # A crash into a wall or into its own body ends the episode.
        if (
            collision_with_boundaries(self.snake_head) == 1
            or collision_with_self(self.snake_position) == 1
        ):
            self._draw_game_over()
            self.done = True

        self.reward += len(self.snake_position)

        observation = self.constructObservation()
        info = {}
        # observation, reward, terminated, truncated, info
        return observation, self.reward, self.done, self.truncated, info
        # ^^^^    HERE HERE HERE    ^^^^

    def _draw_game_over(self):
        """Replace ``self.img`` with a black screen showing the final score."""
        self.img = np.zeros((500, 500, 3), dtype="uint8")
        cv2.putText(
            self.img,
            "Your Score is {}".format(self.score),
            (140, 250),
            cv2.FONT_HERSHEY_SIMPLEX,
            1,
            (255, 255, 255),
            2,
            cv2.LINE_AA,
        )

    def reset(self, seed=None, options=None):
        """Start a new episode.

        Args:
            seed: Optional seed. It seeds ``self.np_random`` through
                ``gym.Env.reset`` and also reseeds the module-level ``random``
                generator, which is what places the apple.
            options: Unused; accepted for Gymnasium API compatibility.

        Returns:
            tuple: ``(observation, info)`` with an empty ``info`` dict.
        """
        # Gymnasium API requirement: seeds self.np_random when a seed is given.
        super().reset(seed=seed)
        if seed is not None:
            # Apple positions are drawn with the global `random` module,
            # so seed it too to make episodes reproducible.
            random.seed(seed)

        self.img = np.zeros((500, 500, 3), dtype="uint8")  # black 500x500 canvas

        # Initial snake: head at (250, 250), two segments trailing to the left.
        self.snake_position = [[250, 250], [240, 250], [230, 250]]
        self.apple_position = [
            random.randrange(1, 50) * 10,
            random.randrange(1, 50) * 10,
        ]

        self.score = 0
        # Separate list from snake_position[0]; moving it does not touch the body.
        self.snake_head = self.snake_position[0].copy()

        self.done = False
        self.truncated = False

        observation = self.constructObservation()

        return observation, {}  # observation, info

    def render(self, mode="human"):
        """Show ``self.img``, then draw the next frame into it.

        The frame shown is the one prepared by the previous call (or the
        game-over screen written by ``step``), so the window lags the game
        state by one call. ``mode`` is accepted but ignored.
        """
        cv2.imshow("Snake", self.img)  # Show the current frame
        # Wait 50 ms so a human can follow the game (waitKey also lets the
        # OpenCV window repaint).
        cv2.waitKey(50)

        # Create the next frame: empty board, the apple, then the snake on top.
        self.img = np.zeros((500, 500, 3), dtype="uint8")
        self._draw_cell(self.apple_position, (0, 0, 255))  # red in OpenCV's BGR order
        for position in self.snake_position:
            self._draw_cell(position, (0, 255, 0))  # green body segment

    def _draw_cell(self, top_left, color):
        """Outline the 10x10 grid cell whose top-left corner is ``top_left`` on ``self.img``."""
        x, y = top_left[0], top_left[1]
        cv2.rectangle(self.img, (x, y), (x + 10, y + 10), color, 3)

    def close(self):
        """Nothing to release at this stage."""
        pass

    def constructObservation(self):
        """Build the observation vector for the current state (all entries 0 or 1).

        * [0] apple x < head x (apple to the left)
        * [1] otherwise (apple to the right or in the same column)
        * [2] apple y < head y (apple above)
        * [3] otherwise (apple below or in the same row)
        * [4] a wall or body segment is directly left of the head
        * [5] a wall or body segment is directly right of the head
        * [6] a wall or body segment is directly above the head
        * [7] a wall or body segment is directly below the head

        The segment right behind the head (the neck) is part of the body, so
        the direction pointing back into it always reads as blocked.

        Returns:
            np.ndarray: float32 array of shape (8,).
        """
        observation = [1, 0] if self.apple_position[0] < self.snake_head[0] else [0, 1]
        observation += [1, 0] if self.apple_position[1] < self.snake_head[1] else [0, 1]

        head = self.snake_position[0]
        body = self.snake_position[1:]
        # Probe the neighbouring cells in the order left, right, up, down.
        for dx, dy in ((-10, 0), (10, 0), (0, -10), (0, 10)):
            probe = [head[0] + dx, head[1] + dy]
            # Pretend the head moved to `probe`; blocked if off-board or on the body.
            blocked = collision_with_boundaries(probe) or collision_with_self(
                [probe] + body
            )
            observation.append(blocked)

        return np.array(observation, dtype=np.float32)


In [None]:
from stable_baselines3.common.env_checker import check_env
from snakeenv import SnekEnv

env = SnekEnv()
# Validate SnekEnv against the Gymnasium API using SB3's checker;
# warn=True makes it print additional warnings when it finds potential issues.
check_env(env=env, warn=True)

In [None]:
from snakeenv import SnekEnv

# Baseline: play a few episodes with uniformly random actions.
env = SnekEnv()

episodes = 5
for _ in range(episodes):
    env.reset()
    total_reward = 0
    steps = 0

    done = False
    while not done:
        env.render()
        observation, reward, terminated, truncated, info = env.step(
            env.action_space.sample()
        )
        # Gymnasium ends an episode on either termination or truncation.
        done = terminated or truncated

        total_reward += reward
        steps += 1

    env.render()  # show the final (game-over) frame
    print("Steps taken", steps)
    print("Total reward", total_reward)

env.close()


In [None]:
from stable_baselines3 import PPO
from snakeenv import SnekEnv

env = SnekEnv()
policy = "MlpPolicy"
model_name = "snek"
model_path = f"models/{model_name}"
# Set to True to resume training from the model saved at `model_path`;
# False always starts from a freshly initialised PPO model.
load_existing = False
# Upper bound for the evaluation rollout: if the imported env version has no
# step limit, a deterministic policy could otherwise circle forever.
max_eval_steps = 10_000

if load_existing:
    try:
        model = PPO.load(model_path, env)
    except FileNotFoundError:
        model = PPO(policy, env)
else:
    model = PPO(policy, env)

model.learn(3_000, progress_bar=True)
model.save(model_path)

# Watch one episode of the deterministic policy.
obs, _ = env.reset()
for _ in range(max_eval_steps):
    env.render()
    action, _state = model.predict(obs, deterministic=True)
    obs, reward, terminated, truncated, info = env.step(action)
    if terminated or truncated:
        break
env.render()  # show the final frame

In [None]:
from stable_baselines3 import PPO
from snakeenv2 import SnekEnv

env = SnekEnv()
policy = "MlpPolicy"
model_name = "snek_V1"
env.render_training = False  # set True to watch the agent while it trains
train = True  # set False to only evaluate an already saved model

# Resume from the saved model when one exists, otherwise start from scratch.
try:
    model = PPO.load(f"./models/{model_name}", env)
except FileNotFoundError:
    model = PPO(policy, env)

if train:
    model.learn(3_000, progress_bar=True)
    model.save(f"models/{model_name}")

# Watch one episode of the deterministic policy.
done = False
obs, _ = env.reset()
while not done:
    env.render()

    action, _state = model.predict(obs, deterministic=True)
    obs, reward, terminated, truncated, info = env.step(action)
    # Gymnasium ends an episode on either termination or truncation.
    done = terminated or truncated
env.render()  # show the final frame

In [None]:
import random

import gymnasium as gym
import numpy as np
from gymnasium import spaces
import cv2

EP_LENGTH = 300


def collision_with_apple(apple_position, score):
    """Spawn a fresh apple and bump the score.

    The incoming ``apple_position`` is not used; a new cell is drawn on the
    10-pixel grid with each coordinate in ``range(10, 500, 10)`` (x first,
    then y, from the module-level ``random`` generator).

    Returns:
        tuple: ``(new_apple_position, score + 1)``.
    """
    new_x = random.randrange(1, 50) * 10
    new_y = random.randrange(1, 50) * 10
    return [new_x, new_y], score + 1


def collision_with_boundaries(snake_head):
    """Return 1 if ``snake_head`` lies outside the 500x500 board, else 0.

    Only ``snake_head[0]`` (x) and ``snake_head[1]`` (y) are inspected; a
    coordinate is on the board when it is neither ``>= 500`` nor ``< 0``.
    """
    for axis in (0, 1):
        coord = snake_head[axis]
        if coord >= 500 or coord < 0:
            return 1
    return 0


def collision_with_self(snake_position):
    """Return 1 if the head (first segment) also occurs in the rest of the body, else 0."""
    rest_of_body = snake_position[1:]
    overlaps = snake_position[0] in rest_of_body
    return int(overlaps)


class SnekEnv(gym.Env):
    """Snake environment with reward shaping and an episode step limit.

    Board: 500x500 pixels on a 10-pixel grid. Actions (``Discrete(4)``):
    0 = left, 1 = right, 2 = down, 3 = up. Observations: 8 binary flags stored
    as float32 (see ``constructObservation``). Episodes terminate when the
    snake hits a wall or itself and are truncated after ``EP_LENGTH`` steps.
    """

    metadata = {"render_modes": ["human"], "render_fps": 30}

    def __init__(self):
        super().__init__()
        # Spaces must be gym.spaces objects.
        self.action_space = spaces.Discrete(4)
        # Eight 0/1 flags describing apple direction and nearby dangers.
        self.observation_space = spaces.Box(low=0, high=1, shape=(8,), dtype=np.float32)
        # When True, every step() also renders, so training can be watched.
        self.render_training = False

    def step(self, action):
        """Advance the game by one move.

        Reward per step:
          * +1 if the head got strictly closer to the apple, otherwise -1;
          * +100 for eating the apple;
          * -1000 for crashing into a wall or the body (terminates);
          * + the snake's length after the move.

        Args:
            action: 0 = left, 1 = right, 2 = down, 3 = up; any other value
                leaves the head where it is.

        Returns:
            tuple: ``(observation, reward, terminated, truncated, info)``.
        """
        if self.render_training:  # optionally watch the agent while it trains
            self.render()

        # Move the head one grid cell (10 px); y grows downwards on screen.
        if action == 1:  # move RIGHT
            self.snake_head[0] += 10
        elif action == 0:  # move LEFT
            self.snake_head[0] -= 10
        elif action == 2:  # move DOWN
            self.snake_head[1] += 10
        elif action == 3:  # move UP
            self.snake_head[1] -= 10
        self.ep_steps += 1

        # Distance shaping. snake_head was moved in place and is a separate list
        # from the body, which is only updated further below, so
        # snake_position[0] is still the head *before* this move.
        # (snake_position[1] would be the neck, two moves back.)
        # Squared distances keep the comparison exact and preserve ordering.
        prev_head = self.snake_position[0]
        apple_x, apple_y = self.apple_position[0], self.apple_position[1]
        prev_dist_sq = (apple_x - prev_head[0]) ** 2 + (apple_y - prev_head[1]) ** 2
        head_dist_sq = (apple_x - self.snake_head[0]) ** 2 + (
            apple_y - self.snake_head[1]
        ) ** 2
        self.reward = 1 if head_dist_sq < prev_dist_sq else -1

        # The new head always becomes the first segment. Eating the apple keeps
        # the tail (the snake grows); otherwise the tail is dropped (it moves).
        self.snake_position.insert(0, list(self.snake_head))
        if self.snake_head == self.apple_position:
            self.apple_position, self.score = collision_with_apple(
                self.apple_position, self.score
            )
            self.reward += 100
        else:
            self.snake_position.pop()

        # A crash into a wall or into its own body ends the episode.
        if (
            collision_with_boundaries(self.snake_head) == 1
            or collision_with_self(self.snake_position) == 1
        ):
            self._draw_game_over()
            self.done = True
            self.reward -= 1_000
            print("Suicide")

        # Length bonus added on every step.
        self.reward += len(self.snake_position)

        if self.ep_steps >= EP_LENGTH:
            self.truncated = True

        observation = self.constructObservation()
        info = {}
        # observation, reward, terminated, truncated, info
        return observation, self.reward, self.done, self.truncated, info

    def _draw_game_over(self):
        """Replace ``self.img`` with a black screen showing the final score."""
        self.img = np.zeros((500, 500, 3), dtype="uint8")
        cv2.putText(
            self.img,
            "Your Score is {}".format(self.score),
            (140, 250),
            cv2.FONT_HERSHEY_SIMPLEX,
            1,
            (255, 255, 255),
            2,
            cv2.LINE_AA,
        )

    def reset(self, seed=None, options=None):
        """Start a new episode.

        Args:
            seed: Optional seed. It seeds ``self.np_random`` through
                ``gym.Env.reset`` and also reseeds the module-level ``random``
                generator, which is what places the apple.
            options: Unused; accepted for Gymnasium API compatibility.

        Returns:
            tuple: ``(observation, info)`` with an empty ``info`` dict.
        """
        # Gymnasium API requirement: seeds self.np_random when a seed is given.
        super().reset(seed=seed)
        if seed is not None:
            # Apple positions are drawn with the global `random` module,
            # so seed it too to make episodes reproducible.
            random.seed(seed)

        self.ep_steps = 0  # steps taken since the last reset
        self.img = np.zeros((500, 500, 3), dtype="uint8")  # black 500x500 canvas

        # Initial snake: head at (250, 250), two segments trailing to the left.
        self.snake_position = [[250, 250], [240, 250], [230, 250]]
        self.apple_position = [
            random.randrange(1, 50) * 10,
            random.randrange(1, 50) * 10,
        ]

        self.score = 0
        # Separate list from snake_position[0]; moving it does not touch the body.
        self.snake_head = self.snake_position[0].copy()

        self.done = False
        self.truncated = False

        observation = self.constructObservation()

        return observation, {}  # observation, info

    def render(self, mode="human"):
        """Show ``self.img``, then draw the next frame into it.

        The frame shown is the one prepared by the previous call (or the
        game-over screen written by ``step``), so the window lags the game
        state by one call. ``mode`` is accepted but ignored.
        """
        cv2.imshow("Snake", self.img)  # Show the current frame
        # Wait 50 ms so a human can follow the game (waitKey also lets the
        # OpenCV window repaint).
        cv2.waitKey(50)

        # Create the next frame: empty board, the apple, then the snake on top.
        self.img = np.zeros((500, 500, 3), dtype="uint8")
        self._draw_cell(self.apple_position, (0, 0, 255))  # red in OpenCV's BGR order
        for position in self.snake_position:
            self._draw_cell(position, (0, 255, 0))  # green body segment

    def _draw_cell(self, top_left, color):
        """Outline the 10x10 grid cell whose top-left corner is ``top_left`` on ``self.img``."""
        x, y = top_left[0], top_left[1]
        cv2.rectangle(self.img, (x, y), (x + 10, y + 10), color, 3)

    def close(self):
        """Nothing to release at this stage."""
        pass

    def constructObservation(self):
        """Build the observation vector for the current state (all entries 0 or 1).

        * [0] apple x < head x (apple to the left)
        * [1] otherwise (apple to the right or in the same column)
        * [2] apple y < head y (apple above)
        * [3] otherwise (apple below or in the same row)
        * [4] a wall or body segment is directly left of the head
        * [5] a wall or body segment is directly right of the head
        * [6] a wall or body segment is directly above the head
        * [7] a wall or body segment is directly below the head

        The segment right behind the head (the neck) is part of the body, so
        the direction pointing back into it always reads as blocked.

        Returns:
            np.ndarray: float32 array of shape (8,).
        """
        observation = [1, 0] if self.apple_position[0] < self.snake_head[0] else [0, 1]
        observation += [1, 0] if self.apple_position[1] < self.snake_head[1] else [0, 1]

        head = self.snake_position[0]
        body = self.snake_position[1:]
        # Probe the neighbouring cells in the order left, right, up, down.
        for dx, dy in ((-10, 0), (10, 0), (0, -10), (0, 10)):
            probe = [head[0] + dx, head[1] + dy]
            # Pretend the head moved to `probe`; blocked if off-board or on the body.
            blocked = collision_with_boundaries(probe) or collision_with_self(
                [probe] + body
            )
            observation.append(blocked)

        return np.array(observation, dtype=np.float32)
