In [1]:
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

In [2]:
# Hyperparameters (module-level; read by the Policy class below)
learning_rate = 0.0002  # step size passed to the Adam optimizer in Policy.__init__
gamma         = 0.98  # discount factor used when accumulating returns in Policy.train_net

In [3]:
# Policy network that trains an agent with the REINFORCE algorithm.
class Policy(nn.Module):
    """Small softmax policy network with a built-in REINFORCE update.

    The network maps a 4-dimensional state to a probability distribution
    over 2 actions. Per-step ``(reward, action_probability)`` pairs are
    buffered with :meth:`put_data` and consumed by :meth:`train_net`.

    Relies on the module-level ``learning_rate`` (Adam step size) and
    ``gamma`` (discount factor) being defined before use.
    """

    def __init__(self):
        super(Policy, self).__init__()
        # Buffer of (reward, probability-of-chosen-action) pairs for the
        # episode currently being collected.
        self.data = []

        self.fc1 = nn.Linear(4, 128)  # fully connected layer: 4 inputs -> 128 outputs
        self.fc2 = nn.Linear(128, 2)  # fully connected layer: 128 inputs -> 2 outputs
        # Adam optimizer over all parameters of this module.
        self.optimizer = optim.Adam(self.parameters(), lr=learning_rate)

    def forward(self, x):
        """Return action probabilities for state tensor ``x``.

        Args:
            x: Float tensor whose last dimension has size 4. Both a single
                state of shape ``(4,)`` and a batch of shape ``(batch, 4)``
                are accepted.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size 2 that sums to 1.
        """
        x = F.relu(self.fc1(x))
        # Normalize over the action axis (the last one). Using dim=0 would
        # normalize across the batch axis for 2-D input; dim=-1 is identical
        # to dim=0 for the 1-D single-state case.
        x = F.softmax(self.fc2(x), dim=-1)
        return x

    def put_data(self, item):
        """Append one ``(reward, action_probability)`` pair to the episode buffer."""
        self.data.append(item)

    def train_net(self):
        """Run one REINFORCE update from the buffered episode, then clear the buffer.

        Walks the episode backwards so the discounted return ``R`` for each
        step can be accumulated incrementally, back-propagates
        ``-log(prob) * R`` for every step, and applies a single optimizer
        step with the summed gradients.
        """
        R = 0
        self.optimizer.zero_grad()
        for r, prob in self.data[::-1]:
            R = r + gamma * R  # discounted return from this step onward
            loss = -torch.log(prob) * R  # policy-gradient loss for this step
            loss.backward()  # gradients accumulate across steps until step()
        self.optimizer.step()
        self.data = []

In [4]:
def main():
    """Train a ``Policy`` on CartPole-v1 with REINFORCE for 1000 episodes.

    After every episode the policy is updated from that episode's buffered
    data. Every ``print_interval`` completed episodes the average score over
    that interval is printed. The environment is closed even if training
    raises.
    """
    env = gym.make('CartPole-v1')
    pi = Policy()
    score = 0.0
    print_interval = 20

    try:
        for n_epi in range(1000):
            s, _ = env.reset()
            done = False

            while not done:
                # Run the policy network on state s to get action probabilities.
                prob = pi(torch.from_numpy(s).float())
                m = Categorical(prob)  # categorical distribution over actions
                a = m.sample()  # sample an action from the distribution
                # Apply the action; step returns a 5-tuple with separate
                # `terminated` and `truncated` flags.
                s_prime, r, terminated, truncated, info = env.step(a.item())
                pi.put_data((r, prob[a]))
                s = s_prime
                score += r
                # End the episode on either flag; ignoring `truncated` would
                # keep stepping an environment that has already hit its
                # time limit.
                done = terminated or truncated

            pi.train_net()  # update the policy from this episode

            # Report on the number of *completed* episodes so every window
            # contains exactly print_interval episodes (the first window
            # would otherwise sum 21 episodes but divide by 20).
            episodes_done = n_epi + 1
            if episodes_done % print_interval == 0:
                print("# of episode :{}, avg score : {}".format(episodes_done, score/print_interval))
                score = 0.0
    finally:
        env.close()
    
# Start training only when this code is executed directly (not when imported as a module).
if __name__ == '__main__':
    main()

  if not isinstance(terminated, (bool, np.bool8)):


# of episode :20, avg score : 19.1
# of episode :40, avg score : 22.75
# of episode :60, avg score : 19.35
# of episode :80, avg score : 25.55
# of episode :100, avg score : 27.55
# of episode :120, avg score : 31.7
# of episode :140, avg score : 27.75
# of episode :160, avg score : 29.8
# of episode :180, avg score : 27.7
# of episode :200, avg score : 25.0
# of episode :220, avg score : 30.45
# of episode :240, avg score : 35.4
# of episode :260, avg score : 31.95
# of episode :280, avg score : 35.85
# of episode :300, avg score : 29.55
# of episode :320, avg score : 43.35
# of episode :340, avg score : 29.05
# of episode :360, avg score : 38.05
# of episode :380, avg score : 34.8
# of episode :400, avg score : 43.0
# of episode :420, avg score : 42.9
# of episode :440, avg score : 37.2
# of episode :460, avg score : 60.65
# of episode :480, avg score : 52.55
# of episode :500, avg score : 42.75
# of episode :520, avg score : 47.15
# of episode :540, avg score : 46.9
# of episode :56