In [1]:
import numpy as np
import gym
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optimizer
import torchinfo


from IPython import display
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

from a3c import A3CAgent, A3CTrainer

In [2]:

class Viewer(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(4, 128),
        )
    
    def forward(self, x):
        x = self.layers(x)
        return x

class Actor(nn.Module):
    def __init__(self, action_count):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, action_count),
            nn.Softmax()
        )
    
    def forward(self, x):
        x = self.layers(x)
        return x

class Critic(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )
    
    def forward(self, x):
        x = self.layers(x)
        return x


In [3]:
viewer = Viewer().cuda()
actor = Actor(2).cuda()
critic = Critic().cuda()

opt_viewer = optimizer.Adam(viewer.parameters(), lr = 0.0001)
opt_actor = optimizer.Adam(actor.parameters(), lr = 0.0001)
opt_critic = optimizer.Adam(critic.parameters(), lr = 0.0001)

In [4]:
class CartPoleAgent(A3CAgent):
    def __init__(self, action_count):
        super().__init__(action_count)
        self.train_log = [0]

    def actor(self, x):
        return actor(viewer(x))

    def critic(self, x):
        return critic(viewer(x))

    def train_critic(self, loss):
        viewer.train()
        critic.train()
        opt_viewer.zero_grad()
        opt_critic.zero_grad()
        loss.backward()
        opt_viewer.step()
        opt_critic.step()
        viewer.eval()
        critic.eval()

    def train_actor(self, loss):
        viewer.train()
        actor.train()
        opt_viewer.zero_grad()
        opt_actor.zero_grad()
        loss.backward()
        opt_viewer.step()
        opt_actor.step()
        viewer.eval()
        actor.eval()
    
    def onStep(self, s0, a, r, s1, p, done, info):
        self.train_log[-1] += r
        if done:
            self.train_log.append(0)
            

In [5]:
trainer = A3CTrainer("CartPole-v1", CartPoleAgent, 8, 10, gamma = 0.99)

  deprecation(
  deprecation(


In [9]:
img = trainer.episodes[0].env.render(mode='rgb_array')
while True:
    trainer.step()
    print(f"Episode {len(trainer.episodes[0].agent.train_log)}")
    print(f"score: {trainer.episodes[0].agent.train_log[-2:][0]}")
    print(f"high : {max(trainer.episodes[0].agent.train_log)}")
    display.clear_output(wait=True)
    img = trainer.episodes[0].env.render(mode='rgb_array')
    plt.figure(figsize=(12, 10))
    plt.imshow(img)
    plt.show()


KeyboardInterrupt

