In [21]:
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

# Step size for the Adam optimiser (passed to ActorCritic as ``lr``).
learning_rate = 2e-4
# Discount factor applied to the bootstrapped next-state value in the TD target.
gamma = 0.98
# Number of environment steps collected between consecutive train_net() calls.
n_rollout = 10

In [22]:
class ActorCritic(nn.Module):
    """Actor-critic network with one shared hidden layer and a rollout buffer.

    The policy head and the value head both sit on top of ``fc1``. Transitions
    are accumulated with :meth:`put_data` and consumed (and cleared) by
    :meth:`train_net`, which performs a single optimiser step on them.

    Args:
        n_observation: Size of one observation vector (input width of ``fc1``).
        n_actions: Number of discrete actions (output width of the policy head).
        lr: Learning rate handed to the Adam optimiser.
        discount: Discount factor used in the TD target. When ``None`` (the
            default) the module-level ``gamma`` is looked up at training time,
            which is what the class did before this argument existed.
    """

    def __init__(self, n_observation: int, n_actions: int, lr: float, discount=None) -> None:
        super().__init__()
        self.data = []
        self.n_obs = n_observation
        self.n_act = n_actions
        self.lr = lr
        self.discount = discount

        self.fc1 = nn.Linear(self.n_obs, 256)
        # The head width must follow n_actions; it used to be hard-coded to 2.
        self.fc_policy = nn.Linear(256, self.n_act)
        self.fc_value = nn.Linear(256, 1)
        self.optim = optim.Adam(self.parameters(), lr=self.lr)

    def policy_net(self, x, softmax_dim = 0):
        """Return action probabilities for ``x``.

        Args:
            x: Observation tensor whose last dimension has size ``n_observation``.
            softmax_dim: Dimension the softmax normalises over; 0 for a single
                unbatched observation, 1 for a ``(batch, n_observation)`` input.
        """
        x = F.relu(self.fc1(x))
        x = self.fc_policy(x)
        prob = F.softmax(x, dim=softmax_dim)
        return prob

    def value_net(self, x):
        """Return the critic's value estimate for ``x`` (last dimension of size 1)."""
        x = F.relu(self.fc1(x))
        v = self.fc_value(x)
        return v

    def put_data(self, transition):
        """Append one ``(s, a, r, s_prime, done)`` transition to the buffer."""
        self.data.append(transition)

    def make_batch(self):
        """Convert the buffered transitions into tensors and clear the buffer.

        Returns:
            ``(s, a, r, s_prime, done_mask)`` where ``s``/``s_prime`` are float
            tensors with one row per transition, ``a`` is a ``(N, 1)`` long
            tensor, ``r`` is a ``(N, 1)`` float tensor of rewards divided by
            100, and ``done_mask`` is a ``(N, 1)`` float tensor that is 0.0
            where the transition was marked done and 1.0 otherwise.
        """
        if not self.data:
            # torch.stack cannot handle an empty list; hand back empty,
            # correctly shaped tensors instead of raising.
            return (
                torch.empty((0, self.n_obs), dtype=torch.float),
                torch.empty((0, 1), dtype=torch.long),
                torch.empty((0, 1), dtype=torch.float),
                torch.empty((0, self.n_obs), dtype=torch.float),
                torch.empty((0, 1), dtype=torch.float),
            )

        s_lst, a_lst, r_lst, s_prime_lst, done_lst = [], [], [], [], []
        for s, a, r, s_prime, done in self.data:
            # Convert each state on its own and stack afterwards: building a
            # tensor directly from a list of numpy arrays is very slow.
            s_lst.append(torch.as_tensor(s, dtype=torch.float))
            a_lst.append([a])
            r_lst.append([r / 100.0])  # rewards are scaled down by a factor of 100
            s_prime_lst.append(torch.as_tensor(s_prime, dtype=torch.float))
            done_lst.append([0.0 if done else 1.0])
        self.data = []

        s_batch = torch.stack(s_lst)
        # gather() in train_net needs int64 indices, so make the dtype explicit.
        a_batch = torch.tensor(a_lst, dtype=torch.long)
        r_batch = torch.tensor(r_lst, dtype=torch.float)
        s_prime_batch = torch.stack(s_prime_lst)
        done_batch = torch.tensor(done_lst, dtype=torch.float)
        return s_batch, a_batch, r_batch, s_prime_batch, done_batch

    def train_net(self):
        """Run one actor-critic update on the buffered transitions.

        Does nothing when the buffer is empty. Otherwise the buffer is consumed,
        the TD target ``r + discount * V(s') * done_mask`` is formed, and a
        single Adam step is taken on the sum of the policy-gradient term
        (weighted by the detached TD error) and the smooth-L1 value loss.
        """
        if not self.data:
            return

        s, a, r, s_prime, done_mask = self.make_batch()
        # Fall back to the module-level ``gamma`` when no discount was given.
        discount = gamma if self.discount is None else self.discount

        value = self.value_net(s)  # evaluated once, reused for delta and the loss
        # The target is treated as a constant: no gradient flows through V(s').
        td_target = (r + discount * self.value_net(s_prime) * done_mask).detach()
        delta = td_target - value

        pi = self.policy_net(s, softmax_dim=1)
        pi_a = pi.gather(1, a)
        loss = -torch.log(pi_a) * delta.detach() + F.smooth_l1_loss(value, td_target)

        self.optim.zero_grad()
        loss.mean().backward()
        self.optim.step()

In [27]:
# Train the actor-critic agent on CartPole-v1, updating every n_rollout steps.
env = gym.make('CartPole-v1', render_mode="rgb_array")
n_actions = int(env.action_space.n)
# Read the observation size from the space instead of resetting the env for it.
n_observation = env.observation_space.shape[0]
model = ActorCritic(n_observation=n_observation, n_actions=n_actions, lr=learning_rate)
print_interval = 20
score = 0.0

for episode in range(6000):
    count = 0  # environment steps taken in the current episode
    terminated = False
    truncated = False
    state = env.reset()[0]
    # An episode ends on failure (terminated) OR when the time limit is hit
    # (truncated); ignoring the latter lets a good policy run without bound.
    while not (terminated or truncated):
        for _ in range(n_rollout):
            # Sampling needs no gradients; train_net re-evaluates the policy.
            with torch.no_grad():
                prob = model.policy_net(torch.from_numpy(state).float())
            m = Categorical(prob)
            action = m.sample().item()
            s_prime, r, terminated, truncated, info = env.step(action)
            count += 1
            # Store only `terminated` as the done flag: a truncated episode did
            # not fail, so its last transition should still bootstrap from V(s').
            model.put_data((state, action, r, s_prime, terminated))

            state = s_prime
            score += r
            if terminated or truncated:
                break
        model.train_net()
    # Report after every `print_interval` completed episodes so the average is
    # taken over exactly that many (episode 0 used to be folded into the first one).
    if (episode + 1) % print_interval == 0:
        print(f"# episode : {episode + 1}, avg score : {score/print_interval}, count : {count}")
        score = 0.0
env.close()

  if not isinstance(terminated, (bool, np.bool8)):


# episode : 20, avg score : 22.4, count : 15
# episode : 40, avg score : 19.65, count : 19
# episode : 60, avg score : 15.65, count : 17
# episode : 80, avg score : 19.4, count : 19
# episode : 100, avg score : 22.2, count : 22
# episode : 120, avg score : 29.8, count : 13
# episode : 140, avg score : 28.65, count : 22
# episode : 160, avg score : 35.65, count : 35
# episode : 180, avg score : 34.2, count : 16
# episode : 200, avg score : 49.35, count : 20
# episode : 220, avg score : 45.9, count : 30
# episode : 240, avg score : 39.8, count : 77
# episode : 260, avg score : 44.0, count : 68
# episode : 280, avg score : 51.45, count : 48
# episode : 300, avg score : 55.8, count : 53
# episode : 320, avg score : 63.15, count : 97
# episode : 340, avg score : 54.3, count : 92
# episode : 360, avg score : 77.3, count : 44
# episode : 380, avg score : 88.3, count : 31
# episode : 400, avg score : 101.7, count : 110
# episode : 420, avg score : 101.55, count : 331
# episode : 440, avg score

In [65]:
from PIL import Image
def show_gif(images: list, image_file: str = 'test.gif', duration: int = 1) -> None:
    """Write ``images`` to disk as a looping animated GIF.

    Args:
        images: Frames in playback order. Each frame must provide a
            PIL-style ``save`` method; the first frame drives the write and
            the rest are appended to it.
        image_file: Destination path of the GIF.
        duration: Value forwarded as the ``duration`` argument of ``save``
            (display time per frame).

    Raises:
        ValueError: If ``images`` is empty, since there is no frame to save.
    """
    if not images:
        raise ValueError("show_gif needs at least one frame to write a GIF")
    first_frame, *remaining_frames = images
    first_frame.save(
        image_file,
        save_all=True,
        append_images=remaining_frames,
        loop=0,  # 0 = repeat forever
        duration=duration,
    )

# Roll out one episode with the trained policy and save it as a GIF.
# A fresh environment is created here because the training cell has already
# closed its own; this also keeps the cell re-runnable on its own.
env = gym.make('CartPole-v1', render_mode="rgb_array")
images = []
count = 0  # environment steps taken in this episode

try:
    state, _ = env.reset()
    terminated = False
    truncated = False
    # Stop on failure (terminated) or on the time limit (truncated).
    while not (terminated or truncated):
        # Keep one frame every n_rollout steps (same cadence as before).
        if count % n_rollout == 0:
            screen = env.render()
            images += [Image.fromarray(screen)]
        # Pure inference: no gradients are needed to sample an action.
        with torch.no_grad():
            prob = model.policy_net(torch.from_numpy(state).float())
        m = Categorical(prob)
        a = m.sample().item()
        s_prime, r, terminated, truncated, info = env.step(a)
        count += 1
        state = s_prime
finally:
    # Release the environment even if the rollout raises.
    env.close()

print(count)
show_gif(images)

335
