In [216]:
import torch.nn as nn
import torch.optim as optim
import gym
import matplotlib.pyplot as plt
import torch.nn.functional as F
import numpy as np
import torch
from torch.distributions import Categorical
import random as rd

In [217]:
# Hyperparameters
learning_rate = 0.00001  # Adam step size used by both the actor and the critic
gamma = 0.98             # discount factor applied to the bootstrapped TD target
episodes = 1000          # number of training episodes
seed = 777               # seed for the Python, NumPy and PyTorch RNGs

# Seed every RNG the notebook touches. The trailing ';' suppresses the repr of
# the Generator returned by torch.manual_seed, which would otherwise be shown
# as this cell's output.
rd.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed);

<torch._C.Generator at 0x22bc2d0fc70>

In [218]:
class Actor(nn.Module):
    """Policy network: maps a state to a categorical distribution over actions.

    Architecture: state_dim -> hidden_dim -> hidden_dim -> action_dim with ReLU
    activations and a softmax output. The network owns its Adam optimizer.

    Args:
        state_dim: size of the last dimension of the input state (4 for CartPole).
        action_dim: number of discrete actions (2 for CartPole).
        hidden_dim: width of the two hidden layers.
        lr: Adam learning rate. When None, the notebook-level ``learning_rate``
            is read at construction time (the original behaviour).
    """

    def __init__(self, state_dim=4, action_dim=2, hidden_dim=256, lr=None):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, action_dim)
        if lr is None:
            lr = learning_rate
        self.optimizer = optim.Adam(self.parameters(), lr=lr)

    def forward(self, x):
        """Return action probabilities for state(s) ``x``.

        The softmax is taken explicitly over the last (action) dimension.
        Calling ``F.softmax`` without ``dim`` is deprecated, and for 1-D or 3-D
        inputs it would normalize over dimension 0 instead.
        """
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        prob = F.softmax(self.fc3(x), dim=-1)
        return prob

    def optimization(self, Q, prob):
        """Take one policy-gradient step minimizing ``-Q * log(prob)``.

        Args:
            Q: weight of the log-probability, e.g. a TD error / advantage.
                A tensor is detached here, so the actor loss never backpropagates
                into whatever network produced it. Plain floats are accepted too.
            prob: probability of the taken action(s). It must still be attached
                to this network's graph. A batch of probabilities is averaged.
        """
        self.optimizer.zero_grad()

        if torch.is_tensor(Q):
            Q = Q.detach()
        # Clamp to the smallest positive float so prob == 0 gives a large but
        # finite penalty instead of -inf (and NaN gradients).
        log_prob = torch.log(prob.clamp(min=torch.finfo(prob.dtype).tiny))
        # Reduce to a scalar: backward() needs one, and this also lets callers
        # pass a batch of transitions.
        cost = (-Q * log_prob).mean()

        cost.backward()
        self.optimizer.step()

class Critic(nn.Module):
    """State-value network V(s): maps a state to a single value estimate.

    Architecture: state_dim -> hidden_dim -> hidden_dim -> 1 with ReLU
    activations and a linear output. The network owns its Adam optimizer.

    Args:
        state_dim: size of the last dimension of the input state (4 for CartPole).
        hidden_dim: width of the two hidden layers.
        lr: Adam learning rate. When None, the notebook-level ``learning_rate``
            is read at construction time (the original behaviour).
    """

    def __init__(self, state_dim=4, hidden_dim=256, lr=None):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, 1)
        if lr is None:
            lr = learning_rate
        self.optimizer = optim.Adam(self.parameters(), lr=lr)

    def forward(self, x):
        """Return V(x); the output keeps a trailing dimension of size 1."""
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def optimization(self, delta):
        """Take one gradient step minimizing the critic loss ``delta``.

        Args:
            delta: the loss to minimize, e.g. the squared TD error
                ``(y - V(s)) ** 2``. It must still be attached to this network's
                graph. A batch of per-transition losses is averaged.
        """
        self.optimizer.zero_grad()

        # Reduce to a scalar: backward() needs one, and this also lets callers
        # pass a batch of transitions.
        cost = delta.mean()

        cost.backward()
        self.optimizer.step()

In [219]:
# Inline (Jupyter) backends need IPython.display to refresh figures inside the
# output cell; other backends draw in their own window and do not need it.
is_ipython = plt.get_backend().find('inline') != -1
if is_ipython:
    from IPython import display

def plot_durations(score_list, show_result=False, window=100):
    """Plot per-episode scores and their trailing moving average in figure 1.

    Args:
        score_list: per-episode scores (episode durations for CartPole).
        show_result: False while training (title 'Training...'; the output cell
            is replaced on each call). True for the final 'Result' plot, which
            is left in the output.
        window: moving-average window. The average is drawn once at least
            ``window`` episodes exist. It is left-padded with ``window - 1``
            zeros so it lines up with the episode axis.
    """
    plt.figure(1)
    # Always redraw from a clean canvas. Without this, the final 'Result' call
    # plotted a second copy of every line on top of the last training frame.
    plt.clf()
    plt.title('Result' if show_result else 'Training...')
    plt.xlabel('Episode')
    plt.ylabel('Duration')

    durations_t = torch.tensor(score_list, dtype=torch.float)
    plt.plot(durations_t.numpy())
    # Plot the mean over the most recent `window` episodes
    if len(durations_t) >= window:
        means = durations_t.unfold(0, window, 1).mean(1)
        means = torch.cat((torch.zeros(window - 1), means))
        plt.plot(means.numpy())

    plt.pause(0.001)  # pause briefly so the figure is updated
    if is_ipython:
        display.display(plt.gcf())
        # While training, replace the previous frame instead of stacking figures.
        if not show_result:
            display.clear_output(wait=True)

In [220]:
env = gym.make('CartPole-v1')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
critic_net = Critic().to(device)
actor_net = Actor().to(device)
score_list = []

# One-step (TD(0)) actor-critic, updated after every environment step.
try:
    for episode in range(episodes):
        # Seed only the first reset. Later resets continue the environment's own
        # RNG stream, so the run is reproducible without replaying one start state.
        state, _ = env.reset(seed=seed if episode == 0 else None)
        state = torch.tensor(state, device=device)
        score = 0
        done = False

        while not done:
            prob = actor_net(state)
            distribution = Categorical(prob)
            action = distribution.sample()
            next_state, reward, terminated, truncated, _ = env.step(action.item())
            next_state = torch.tensor(next_state, device=device)

            # The episode ends on failure (terminated) OR on the time limit
            # (truncated). Only a real termination cuts off bootstrapping;
            # a truncated state still has future value.
            done = terminated or truncated
            done_num = 0 if terminated else 1

            with torch.no_grad():
                y = reward + gamma * critic_net(next_state) * done_num
            value = critic_net(state)
            # TD error y - V(s) serves as the advantage estimate for the actor.
            # Weighting by V(s) alone would give a zero expected policy gradient,
            # because V(s) does not depend on the chosen action.
            td_error = y - value

            actor_net.optimization(td_error.detach(), prob[action])
            critic_net.optimization(td_error ** 2)

            state = next_state
            score += reward

        # CartPole gives +1 per step, including the final one, so `score`
        # already equals the episode length.
        score_list.append(score)
        plot_durations(score_list)
finally:
    # Close once, after training, even if the loop is interrupted.
    env.close()

plot_durations(score_list, show_result=True)
plt.ioff()
plt.show()


KeyboardInterrupt: 