In [6]:
import gym
import random
import warnings
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from collections import defaultdict

warnings.filterwarnings("ignore")

In [7]:
class SARSA:
    def __init__(self, bin_list, num_episodes, gamma, epsilon, max_steps, alpha):
        self.bin_list = bin_list
        self.num_episodes = num_episodes
        self.gamma = gamma
        self.epsilon = epsilon
        self.max_steps = max_steps
        self.alpha = alpha  # 학습률 추가

    def discretize(self, obs):
        idxs = [
            min(max(np.digitize(o, b) - 1, 0), len(b) - 1)
            for o, b in zip(obs, self.bin_list)
        ]
        return tuple(idxs)

    def decay_epsilon(self, rewards_log):
        avg_reward = np.mean(rewards_log[-10:])
        if avg_reward > 450:
            self.epsilon = 0.0001

    def fit(self, env, Q):
        rewards_log = []

        for _ in tqdm(range(self.num_episodes)):
            obs, info = env.reset()
            state = self.discretize(obs)
            total_reward = 0

            # ε-탐욕적으로 초기 행동 선택
            if random.random() < self.epsilon:
                action = random.randint(0, 1)
            else:
                action = np.argmax(Q[state])

            for _ in range(self.max_steps):
                obs_next, reward, done, truncated, info = env.step(action)
                next_state = self.discretize(obs_next)

                # 다음 행동 a'도 ε-탐욕적으로 선택 (SARSA 방식)
                if random.random() < self.epsilon:
                    next_action = random.randint(0, 1)
                else:
                    next_action = np.argmax(Q[next_state])

                # SARSA 업데이트 식 적용
                if done:
                    target = reward  # 종료 상태에서는 다음 상태의 Q값 없음
                else:
                    target = reward + self.gamma * Q[next_state][next_action]

                Q[state][action] += self.alpha * (target - Q[state][action])
                total_reward += reward
                state, action = next_state, next_action  # 다음 상태, 행동으로 이동

            rewards_log.append(total_reward)
            self.decay_epsilon(rewards_log)

        env.close()
        return Q, rewards_log

In [8]:
# 1) 환경 & 이산화 준비
env = gym.make('CartPole-v1')
num_bins = 12
# 관측 범위
ranges = [(-4.8, 4.8), (-3, 3), (-0.418, 0.418), (-4, 4)]
bins_list = [np.linspace(lo, hi, num_bins) for (lo, hi) in ranges]

# 2) Q, 카운트(점진적 평균)
Q = defaultdict(lambda: np.random.uniform(low=0, high=10, size=(env.action_space.n)).astype(np.float32))

In [9]:
SA = SARSA(bins_list, num_episodes=5000, gamma=0.99, epsilon=0.1, max_steps=1000, alpha=0.1)
Q, rewards = SA.fit(env, Q)

 31%|███       | 1534/5000 [02:20<05:16, 10.94it/s]

KeyboardInterrupt



In [None]:
plt.figure(figsize=(15, 3), dpi=400)
plt.plot(np.array(rewards), label="SARSA")
plt.title('CartPole-v1')
plt.xlabel('Episode', fontsize=15)
plt.ylabel('Reward', fontsize=15)
plt.legend(loc='upper right')
plt.tight_layout()
plt.show()

print(f"최종 epsilon: {SA.epsilon}")
print("초반 10개 에피소드 보상:", rewards[:10])
print("마지막 10개 에피소드 보상:", rewards[-10:])
avg_reward_recent = np.mean(rewards[-100:])
print(f"최근 100 에피소드 평균 보상: {avg_reward_recent:.2f}")