<a href="https://colab.research.google.com/github/kuick1kim/kim/blob/main/DQN_FIRST_01.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

# Hyperparameters
learning_rate = 0.0009  # step size handed to the Adam optimizer
gamma         = 0.98    # discount factor applied to future rewards


class Policy(nn.Module):
    """Two-layer policy network trained with the REINFORCE update.

    The network maps a 4-dimensional observation to a probability
    distribution over 2 discrete actions.  Transitions collected during an
    episode are buffered in ``self.data`` as ``(reward, action_probability)``
    pairs and consumed by :meth:`train_net`.
    """

    def __init__(self):
        super().__init__()
        # Per-episode buffer of (reward, probability-of-chosen-action) pairs.
        self.data = []

        self.fc1 = nn.Linear(4, 128)
        self.fc2 = nn.Linear(128, 2)
        self.optimizer = optim.Adam(self.parameters(), lr=learning_rate)

    def forward(self, x):
        """Return action probabilities for ``x``.

        ``x`` may be a single observation of shape ``(4,)`` or a batch of
        shape ``(..., 4)``.  The softmax is taken over the last (action)
        dimension so every observation gets its own distribution; for a 1-D
        input this is identical to normalising over dimension 0.
        """
        x = F.relu(self.fc1(x))
        return F.softmax(self.fc2(x), dim=-1)

    def put_data(self, item):
        """Append one ``(reward, action_probability)`` pair to the buffer."""
        self.data.append(item)

    def train_net(self):
        """Run one policy-gradient step over the buffered episode.

        Walks the buffer backwards to accumulate the discounted return
        ``R = r + gamma * R`` and minimises ``sum(-log(prob) * R)``.  The
        buffer is cleared afterwards.  Does nothing when the buffer is empty.
        """
        if not self.data:
            # Nothing collected: skip the optimizer step entirely.
            return

        R = 0
        losses = []
        for r, prob in reversed(self.data):
            R = r + gamma * R
            losses.append(-torch.log(prob) * R)

        self.optimizer.zero_grad()
        # A single backward pass over the summed loss yields the same
        # accumulated gradient as one backward per step, and also works when
        # several stored probabilities come from the same forward pass.
        torch.stack(losses).sum().backward()
        self.optimizer.step()
        self.data = []

def _reset_env(env):
    """Reset ``env`` and return only the initial observation."""
    result = env.reset()
    # Gym >= 0.26 returns (observation, info); older releases return the
    # observation alone.
    return result[0] if isinstance(result, tuple) else result


def _step_env(env, action):
    """Step ``env`` and return ``(observation, reward, done)``."""
    result = env.step(action)
    if len(result) == 5:
        # Gym >= 0.26: (obs, reward, terminated, truncated, info).
        obs, reward, terminated, truncated, _ = result
        return obs, reward, bool(terminated or truncated)
    # Older Gym: (obs, reward, done, info).
    obs, reward, done, _ = result
    return obs, reward, done


def main():
    """Train the policy on CartPole-v1 and print the running average score.

    Runs 10000 episodes.  Each episode samples actions from the policy,
    buffers ``(reward, probability)`` pairs, and performs one training step
    when the episode ends.  Whenever the episode index is a non-zero multiple
    of ``print_interval`` the average score since the last report is printed.
    """
    env = gym.make('CartPole-v1')
    pi = Policy()
    score = 0.0
    episodes_in_window = 0
    print_interval = 20

    try:
        for n_epi in range(10000):
            s = _reset_env(env)
            done = False

            while not done:  # CartPole-v1 is forced to terminate at 500 steps.
                prob = pi(torch.from_numpy(s).float())
                a = Categorical(prob).sample()
                s, r, done = _step_env(env, a.item())
                pi.put_data((r, prob[a]))
                score += r

            pi.train_net()
            episodes_in_window += 1

            if n_epi % print_interval == 0 and n_epi != 0:
                # Divide by the number of episodes actually accumulated: the
                # first report covers episodes 0..print_interval inclusive,
                # one more than later reports.
                print("# of episode :{}, avg score : {}".format(n_epi, score / episodes_in_window))
                score = 0.0
                episodes_in_window = 0
    finally:
        # Release the environment even if training is interrupted.
        env.close()

if __name__ == '__main__':
    main()

# of episode :20, avg score : 23.05
# of episode :40, avg score : 30.35
# of episode :60, avg score : 39.65
# of episode :80, avg score : 51.8
# of episode :100, avg score : 44.0
# of episode :120, avg score : 41.0
# of episode :140, avg score : 48.7
# of episode :160, avg score : 50.4
# of episode :180, avg score : 68.7
# of episode :200, avg score : 86.2
# of episode :220, avg score : 124.0
# of episode :240, avg score : 115.5
# of episode :260, avg score : 146.15
# of episode :280, avg score : 199.75
# of episode :300, avg score : 199.45
# of episode :320, avg score : 227.7
# of episode :340, avg score : 206.85
# of episode :360, avg score : 283.9
# of episode :380, avg score : 249.15
# of episode :400, avg score : 220.55
# of episode :420, avg score : 272.2
# of episode :440, avg score : 240.6
# of episode :460, avg score : 338.4
# of episode :480, avg score : 315.95
# of episode :500, avg score : 285.1
# of episode :520, avg score : 189.25
# of episode :540, avg score : 199.7
# of