# Reinforcement Learning

# 4. Online control

This notebook presents the **online control** of an agent by SARSA and Q-learning.

In [82]:
import numpy as np

In [83]:
from model import TicTacToe, Nim, ConnectFour
from agent import Agent, OnlineControl
from dynamic import ValueIteration

In [141]:
import random

## To do

* Complete the class ``SARSA`` and test it on Tic-Tac-Toe.
* Complete the class ``QLearning`` and test it on Tic-Tac-Toe.
* Compare these algorithms on Tic-Tac-Toe (play first) and Nim (play second), using a random adversary, then a perfect adversary. Comment on your results.
* Test these algorithms on Connect 4 against a random adversary. Comment on your results.

## SARSA

In [494]:
class SARSA(OnlineControl):
    """Online control by SARSA.

    On-policy temporal-difference control: the value of the pair
    (state, action) is moved towards the sampled gain
    ``reward + gamma * Q(next_state, next_action)``, where ``next_action``
    is the action that is actually played at the next step.
    """

    def update_values(self, state=None, horizon=100, epsilon=0.5):
        """Learn the action-value function online over one episode.

        Parameters
        ----------
        state :
            Initial state, forwarded to ``self.model.reset`` (``None`` by default).
        horizon : int
            Maximum number of steps played in the episode.
        epsilon : float
            Exploration parameter forwarded to ``self.randomize_best_action``
            (presumably the probability of a random action -- confirm in
            ``OnlineControl``).

        Returns
        -------
        None
            ``self.action_value`` and ``self.action_count`` are updated in place.
        """
        self.model.reset(state)
        state = self.model.state
        if self.model.is_terminal(state):
            # Nothing to learn from an episode that starts in a terminal state.
            return

        action = self.randomize_best_action(state, epsilon=epsilon)
        for _ in range(horizon):
            code = self.model.encode(state)
            self.action_count[code][action] += 1
            reward, stop = self.model.step(action)

            next_state = self.model.state
            code_next = self.model.encode(next_state)
            terminal = self.model.is_terminal(next_state)
            if terminal:
                # No future reward after a terminal state.
                gain = reward
            else:
                next_action = self.randomize_best_action(next_state, epsilon=epsilon)
                gain = reward + self.gamma * self.action_value[code_next][next_action]

            # Incremental average: the step size is 1 / (number of visits of the pair).
            diff = gain - self.action_value[code][action]
            count = self.action_count[code][action]
            self.action_value[code][action] += diff / count

            # Also stop on a terminal next state: ``next_action`` is only bound in
            # the non-terminal branch, so continuing would reuse a stale action
            # (or fail on the first step) if the model did not report ``stop``.
            if stop or terminal:
                break
            state, action = next_state, next_action

## Q-learning

In [495]:
class QLearning(OnlineControl):
    """Online control by Q-learning.

    Off-policy temporal-difference control: the value of the pair
    (state, action) is moved towards the sampled gain
    ``reward + gamma * max_a Q(next_state, a)``, independently of the action
    that is played at the next step.
    """

    def update_values(self, state=None, horizon=100, epsilon=0.5):
        """Learn the action-value function online over one episode.

        Parameters
        ----------
        state :
            Initial state, forwarded to ``self.model.reset`` (``None`` by default).
        horizon : int
            Maximum number of steps played in the episode.
        epsilon : float
            Exploration parameter forwarded to ``self.randomize_best_action``
            (presumably the probability of a random action -- confirm in
            ``OnlineControl``).

        Returns
        -------
        None
            ``self.action_value`` and ``self.action_count`` are updated in place.
        """
        self.model.reset(state)
        state = self.model.state
        if self.model.is_terminal(state):
            # Nothing to learn from an episode that starts in a terminal state.
            return

        for _ in range(horizon):
            action = self.randomize_best_action(state, epsilon=epsilon)
            code = self.model.encode(state)
            self.action_count[code][action] += 1
            reward, stop = self.model.step(action)

            next_state = self.model.state
            code_next = self.model.encode(next_state)
            terminal = self.model.is_terminal(next_state)
            if terminal:
                # No future reward after a terminal state.
                gain = reward
            else:
                # NOTE(review): the max runs over the values currently stored for
                # the next state (0 if none is stored). If entries are created
                # lazily, actions never tried there are ignored -- verify that
                # this matches the intended initial value of unseen actions.
                best_next = max(self.action_value[code_next].values(), default=0)
                gain = reward + self.gamma * best_next

            # Incremental average: the step size is 1 / (number of visits of the pair).
            diff = gain - self.action_value[code][action]
            count = self.action_count[code][action]
            self.action_value[code][action] += diff / count

            # Also stop on a terminal next state so that no action is ever
            # selected from it, even if the model did not report ``stop``.
            if stop or terminal:
                break
            state = next_state



## TEST SARSA

In [518]:
# Environment for the SARSA test: Tic-Tac-Toe with adversary_policy="random".
Game = TicTacToe
game = Game(adversary_policy="random")

In [519]:
# Baseline before learning: mean of the gains of an agent built without an
# explicit policy (presumably a random default policy -- confirm in agent.py).
agent = Agent(game)
np.mean(agent.get_gains())

0.23

In [520]:
# Train SARSA on the current game, then evaluate the resulting policy.
Control = SARSA
algo = Control(game)
n_games = 1000

# Each call resets the model and plays one episode of at most `horizon` steps.
for i in range(n_games):
    algo.update_values(epsilon=0.1)
policy = algo.get_policy()

# Mean gain of an agent following the policy returned by the learner.
agent = Agent(game, policy)
np.mean(agent.get_gains())

0.9

Let's try a perfect adversary policy.

In [521]:
# Get the adversary's policy from value iteration on the current game.
algo = ValueIteration(game)
policy, ad_policy = algo.get_perfect_players()

# Rebuild the game against that adversary and train a fresh learner
# (Control is whatever class was selected in a previous cell).
game = Game(adversary_policy=ad_policy)
# NOTE(review): this agent is reassigned below before being used.
agent = Agent(game)
algo = Control(game)

n_games = 1000
for i in range(n_games):
    algo.update_values(epsilon=0.1)
policy = algo.get_policy()
agent = Agent(game, policy)

# Mean gain of the learned policy against the value-iteration adversary.
np.mean(agent.get_gains())

-0.5

In [531]:
# Switch to Nim, built with the constructor's default arguments
# (no adversary_policy is passed here).
Game = Nim
game = Game()

In [532]:
# Baseline on Nim: mean of the gains of an agent built without an explicit
# policy (presumably a random default policy -- confirm in agent.py).
agent = Agent(game)
np.mean(agent.get_gains())

0.04

In [533]:
# Train SARSA on Nim, then evaluate the resulting policy.
Control = SARSA
algo = Control(game)
n_games = 1000

# One learning episode per call, with exploration parameter epsilon=0.1.
for i in range(n_games):
    algo.update_values(epsilon=0.1)
policy = algo.get_policy()

# Mean gain of an agent following the policy returned by the learner.
agent = Agent(game, policy)
np.mean(agent.get_gains())

0.72

In [534]:
# Get the adversary's policy from value iteration on the current Nim game.
algo = ValueIteration(game)
policy, ad_policy = algo.get_perfect_players()

# Rebuild the game against that adversary and train a fresh learner
# (Control is whatever class was selected in a previous cell).
game = Game(adversary_policy=ad_policy)
# NOTE(review): this agent is reassigned below before being used.
agent = Agent(game)
algo = Control(game)

n_games = 1000
for i in range(n_games):
    algo.update_values(epsilon=0.1)
policy = algo.get_policy()
agent = Agent(game, policy)

# Mean gain of the learned policy against the value-iteration adversary.
np.mean(agent.get_gains())

-1.0

## TEST Q-learning

In [528]:
# Environment for the Q-learning test: Tic-Tac-Toe with adversary_policy="random".
Game = TicTacToe
game = Game(adversary_policy="random")

In [529]:
# Baseline before learning: mean of the gains of an agent built without an
# explicit policy (presumably a random default policy -- confirm in agent.py).
agent = Agent(game)
np.mean(agent.get_gains())

0.27

In [530]:
# Train Q-learning on the current game, then evaluate the resulting policy.
Control = QLearning
algo = Control(game)

n_games = 1000
# One learning episode per call, with exploration parameter epsilon=0.1.
for i in range(n_games):
    algo.update_values(epsilon=0.1)
policy = algo.get_policy()
agent = Agent(game, policy)

# Mean gain of an agent following the policy returned by the learner.
np.mean(agent.get_gains())

0.86

In [507]:
# Get the adversary's policy from value iteration on the current game.
algo = ValueIteration(game)
policy, ad_policy = algo.get_perfect_players()

# Rebuild the game against that adversary.
game = Game(adversary_policy=ad_policy)
# NOTE(review): this agent is reassigned below before being used.
agent = Agent(game)

Control = QLearning
algo = Control(game)

# Five times more learning episodes than against the random adversary.
n_games = 5000
for i in range(n_games):
    algo.update_values(epsilon=0.1)
policy = algo.get_policy()
agent = Agent(game, policy)

# Mean gain of the learned policy against the value-iteration adversary.
np.mean(agent.get_gains())

-0.07

In [535]:
# Switch to Nim, built with the constructor's default arguments
# (no adversary_policy is passed here).
Game = Nim
game = Game()

In [536]:
# Baseline on Nim: mean of the gains of an agent built without an explicit
# policy (presumably a random default policy -- confirm in agent.py).
agent = Agent(game)
np.mean(agent.get_gains())

-0.06

In [537]:
# Train Q-learning on Nim, then evaluate the resulting policy.
Control = QLearning
algo = Control(game)

n_games = 1000
# One learning episode per call, with exploration parameter epsilon=0.1.
for i in range(n_games):
    algo.update_values(epsilon=0.1)
policy = algo.get_policy()
agent = Agent(game, policy)

# Mean gain of an agent following the policy returned by the learner.
np.mean(agent.get_gains())

0.76

In [538]:
# Get the adversary's policy from value iteration on the current Nim game.
algo = ValueIteration(game)
policy, ad_policy = algo.get_perfect_players()

# Rebuild the game against that adversary.
game = Game(adversary_policy=ad_policy)
# NOTE(review): this agent is reassigned below before being used.
agent = Agent(game)

Control = QLearning
algo = Control(game)

# Five times more learning episodes than against the default adversary.
n_games = 5000
for i in range(n_games):
    algo.update_values(epsilon=0.1)
policy = algo.get_policy()
agent = Agent(game, policy)

# Mean gain of the learned policy against the value-iteration adversary.
np.mean(agent.get_gains())

-1.0

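Against a random adversary (Tic-Tac-Toe and Nim), both SARSA and Q-learning consistently achieve high gains (about 0.9 at Tic-Tac-Toe and 0.7–0.8 at Nim, versus roughly -0.1 to 0.3 for the untrained agent) because the opponent frequently makes suboptimal moves, which the learned policies easily exploit. Against a perfect adversary, who always plays optimally, both algorithms struggle: at Tic-Tac-Toe the mean gain is -0.5 for SARSA (1,000 games) and -0.07 for Q-learning (5,000 games), where a draw is the best achievable outcome against perfect play, and at Nim both always lose (gain -1.0). The opponent's strategy is too strong and leaves little room for error, so any exploratory or suboptimal move is punished.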

In [539]:
# Environment for the last experiment: Connect Four with adversary_policy="random".
Game = ConnectFour
game = Game(adversary_policy="random")

In [540]:
# Baseline on Connect Four: mean of the gains of an agent built without an
# explicit policy (presumably a random default policy -- confirm in agent.py).
agent = Agent(game)
np.mean(agent.get_gains())

0.08

In [541]:
# Train Q-learning on Connect Four, then evaluate the resulting policy.
Control = QLearning
algo = Control(game)

# Ten times more learning episodes than in the Tic-Tac-Toe / Nim runs above.
n_games = 10000
for i in range(n_games):
    algo.update_values(epsilon=0.1)
policy = algo.get_policy()
agent = Agent(game, policy)

# Mean gain of an agent following the policy returned by the learner.
np.mean(agent.get_gains())

0.48

In [542]:
# Train SARSA on Connect Four with the same budget, then evaluate the policy.
Control = SARSA
algo = Control(game)
n_games = 10000

# One learning episode per call, with exploration parameter epsilon=0.1.
for i in range(n_games):
    algo.update_values(epsilon=0.1)
policy = algo.get_policy()

# Mean gain of an agent following the policy returned by the learner.
agent = Agent(game, policy)
np.mean(agent.get_gains())

0.24

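Connect Four requires a large number of episodes to achieve a reasonable gain: even after 10,000 games against a random adversary, the mean gain is only 0.48 for Q-learning and 0.24 for SARSA (versus 0.08 for the untrained agent). The game has a vast state space, even though only a handful of actions are available in each state, so most state–action pairs are visited rarely or never, which makes it challenging for tabular algorithms like SARSA and Q-learning to explore and learn optimal strategies effectively.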