In [1]:
import numpy as np
import gymnasium as gym
from collections import defaultdict  # required for creating Q(s, a)
from moviepy import ImageSequenceClip # to generate gif
from IPython.display import Image

import matplotlib 
#matplotlib.use('Qt5Agg') # Activate this to show plots in an external, interactive window
import matplotlib.pyplot as plt

import warnings
warnings.filterwarnings("ignore")

In [2]:
# Gymnasium environment ids the SARSA agent is trained on.
# Train_SARSA selects one of them by position (its `envIdx` argument).
envs = ["Taxi-v3", "CliffWalking-v1"]


In [3]:
# These functions are for visualizing episodes after training
# -------------------------
# Render Episodes Using RGB Frames
# -------------------------



def create_gif(frames, filename, fps=5):
    """Write a sequence of frames to ``filename`` as an animated GIF.

    Parameters
    ----------
    frames : list
        Image frames to animate, e.g. the list of ``env.render()`` outputs
        collected from an environment created with ``render_mode='rgb_array'``.
    filename : str
        Destination path of the GIF.
    fps : int
        Playback speed; used both for the clip and for the written file.
    """
    ImageSequenceClip(frames, fps=fps).write_gif(filename, fps=fps)

    
    
def run_multi_episodes(env, Q_table, run_num=10, epsilon=0, max_steps=None):
    """Play ``run_num`` episodes with an epsilon-greedy policy derived from ``Q_table``.

    With the default ``epsilon=0`` the policy is purely greedy: the action is
    ``argmax`` over ``Q_table[state]``.

    Parameters
    ----------
    env :
        Environment to roll out in. Every ``env.render()`` result is stored
        as a frame, so it is presumably created with ``render_mode='rgb_array'``.
    Q_table :
        Action-value table indexed as ``Q_table[state]`` -> per-action values.
    run_num : int
        Number of episodes to play.
    epsilon : float
        Probability of taking a random action (``env.action_space.sample()``)
        instead of the greedy one.
    max_steps : int or None
        Optional per-episode step cap. ``None`` (default) runs each episode
        until the environment reports ``terminated`` or ``truncated``. Set it
        for environments without a time limit, where a greedy policy that
        loops would otherwise never finish.

    Returns
    -------
    total_frames : list
        Frames of all episodes concatenated. Each episode contributes the
        frame rendered right after ``reset()`` plus one frame per step.
    total_reward : list
        Undiscounted return of each episode, in play order.
    """
    total_frames = []
    total_reward = []
    for _episode in range(run_num):
        state, _ = env.reset()
        frames = [env.render()]
        episode_reward = 0
        steps = 0
        done = False

        while not done:
            # Only draw from the RNG when exploration is enabled, so the
            # default greedy rollout does not consume random numbers.
            if epsilon > 0 and np.random.random() < epsilon:
                action = env.action_space.sample()
            else:
                action = np.argmax(Q_table[state])

            state, reward, terminated, truncated, _ = env.step(action)
            steps += 1
            done = (
                terminated
                or truncated
                or (max_steps is not None and steps >= max_steps)
            )

            frames.append(env.render())
            episode_reward += reward

        total_frames.extend(frames)
        total_reward.append(episode_reward)
    return total_frames, total_reward

In [4]:
def _epsilon_greedy(Q_table, state, epsilon, action_space):
    """Return a random action with probability ``epsilon``, else the greedy one.

    The uniform draw from ``np.random`` happens first on every call, so the
    global NumPy RNG is advanced once per call regardless of the outcome.
    """
    if np.random.random() < epsilon:
        return action_space.sample()
    return np.argmax(Q_table[state])


def _plot_training_metrics(episode_rewards, episode_lengths, env_name):
    """Save side-by-side return/length curves to ``metrics_<env_name>.png`` in the CWD."""
    fig, (ax_return, ax_length) = plt.subplots(1, 2, figsize=(12, 5))

    ax_return.plot(episode_rewards)
    ax_return.set(title="Episode Return", xlabel="Episode", ylabel="Total Reward")

    ax_length.plot(episode_lengths)
    ax_length.set(title="Episode Lengths", xlabel="Episode", ylabel="Number of Steps")

    fig.tight_layout()
    fig.savefig(f'metrics_{env_name}.png', bbox_inches='tight', dpi=100)
    plt.close(fig)  # release the figure; it is only written to disk, not shown


def Train_SARSA(envIdx, episodes_num, alpha=0.1, gamma=0.99, epsilon=0.1,
                epsilon_decay=0.99, log_every=5000):
    """Train a tabular SARSA agent on ``envs[envIdx]``, then plot and replay it.

    Side effects:
    - prints the space sizes and periodic training progress;
    - writes ``metrics_<env>.png`` (return and length per episode) to the
      current working directory;
    - writes ``trained_<env>.gif`` (5 episodes played with the learned
      Q-table) to the current working directory.

    Parameters
    ----------
    envIdx : int
        Index into the module-level ``envs`` list of Gymnasium environment ids.
        The environment must have discrete observation and action spaces
        (``.n`` is read from both).
    episodes_num : int
        Number of training episodes.
    alpha : float
        Learning rate of the TD update.
    gamma : float
        Discount factor.
    epsilon : float
        Initial exploration rate of the epsilon-greedy behaviour policy.
    epsilon_decay : float
        Multiplicative factor applied to ``epsilon`` after every episode.
    log_every : int or None
        Print the 100-episode moving averages every ``log_every`` episodes.
        ``0`` or ``None`` disables progress logging.

    Returns
    -------
    numpy.ndarray
        The learned Q-table of shape ``(n_states, n_actions)``. In a notebook,
        assign the result (or end the line with ``;``) to avoid displaying it.
    """
    env_name = envs[envIdx]

    # Training environment: no rendering is needed while learning.
    env = gym.make(env_name)
    try:
        n_states = env.observation_space.n
        n_actions = env.action_space.n
        Q_table = np.zeros((n_states, n_actions))
        print(f'Environment {env_name}:\n    Size of observation space:{n_states}\n    Size of action space:{n_actions}')

        episode_rewards = []
        episode_lengths = []

        for episode in range(episodes_num):
            state, _ = env.reset()
            action = _epsilon_greedy(Q_table, state, epsilon, env.action_space)

            done = False
            total_reward = 0
            steps = 0

            while not done:
                next_state, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated

                # SARSA is on-policy: a' comes from the same epsilon-greedy
                # policy and is the action actually executed on the next step.
                next_action = _epsilon_greedy(Q_table, next_state, epsilon, env.action_space)

                # A true terminal state has no future value, so the bootstrap
                # term is dropped. A time-limit truncation is not a real
                # terminal state, so we still bootstrap from Q(s', a') there.
                bootstrap = 0.0 if terminated else gamma * Q_table[next_state, next_action]
                Q_table[state, action] += alpha * (reward + bootstrap - Q_table[state, action])

                state, action = next_state, next_action
                total_reward += reward
                steps += 1

            episode_rewards.append(total_reward)
            episode_lengths.append(steps)

            epsilon *= epsilon_decay

            if log_every and episode % log_every == 0:  # training progress
                avg_reward = np.mean(episode_rewards[-100:])
                avg_length = np.mean(episode_lengths[-100:])
                print(f"Episode {episode}, Avg Reward: {avg_reward:.2f}, Avg Length: {avg_length:.2f}")
    finally:
        env.close()

    _plot_training_metrics(episode_rewards, episode_lengths, env_name)

    # Replay a few episodes with the learned Q-table in a rendering env and
    # save them as a GIF. Close the env even if the rollout fails.
    env_vis = gym.make(env_name, render_mode='rgb_array')
    try:
        frames, total_reward = run_multi_episodes(env_vis, Q_table, run_num=5)
    finally:
        env_vis.close()
    create_gif(frames, f"trained_{env_name}.gif", fps=5)
    print(f"Episodes completed with total rewards: {total_reward}")

    return Q_table
     