- [Open with Colab](https://colab.research.google.com/github/Danboruya/my-research-hub/blob/master/Cartpole-v0-pytorch.ipynb)

# CartPorle-v0 implement using PyTorch

## Running environment setup

In [0]:
!apt-get install xvfb ffmpeg
!pip install -q gym imageio PILLOW pyvirtualdisplay 'gym[atari]' JSAnimation
!pip install -U 'pyglet==1.3.2' pyopengl scipy
!apt-get install -y cmake zlib1g-dev libjpeg-dev xvfb ffmpeg xorg-dev python-opengl libboost-all-dev libsdl2-dev swig freeglut3-dev
!pip install 'gym[atari]' opencv-python pillow h5py pyyaml hyperdash pyvirtualdisplay

Reading package lists... Done
Building dependency tree       
Reading state information... Done
ffmpeg is already the newest version (7:3.4.6-0ubuntu0.18.04.1).
The following package was automatically installed and is no longer required:
  libnvidia-common-410
Use 'apt autoremove' to remove it.
The following NEW packages will be installed:
  xvfb
0 upgraded, 1 newly installed, 0 to remove and 4 not upgraded.
Need to get 783 kB of archives.
After this operation, 2,266 kB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu bionic-updates/universe amd64 xvfb amd64 2:1.19.6-1ubuntu4.3 [783 kB]
Fetched 783 kB in 3s (302 kB/s)
Selecting previously unselected package xvfb.
(Reading database ... 131289 files and directories currently installed.)
Preparing to unpack .../xvfb_2%3a1.19.6-1ubuntu4.3_amd64.deb ...
Unpacking xvfb (2:1.19.6-1ubuntu4.3) ...
Setting up xvfb (2:1.19.6-1ubuntu4.3) ...
Processing triggers for man-db (2.8.3-2ubuntu0.1) ...
  Building wheel for JSA

In [0]:
!cp xdpyinfo /usr/bin/
!cp libXxf86dga.* /usr/lib/x86_64-linux-gnu/
!chmod +x /usr/bin/xdpyinfo

## Package import

In [0]:
%matplotlib inline

import random
import os

import numpy as np
import matplotlib.pyplot as plt
import gym
import torch
import torch.nn.functional as F
import pyvirtualdisplay
import base64
import imageio
import IPython
import matplotlib
import matplotlib.pyplot as plt
import PIL.Image
import pyvirtualdisplay


from JSAnimation.IPython_display import display_animation
from matplotlib import animation
from pyvirtualdisplay import Display
from IPython.display import display
from collections import namedtuple

In [0]:
# display = Display(visible=0, size=(1024, 768))
# display.start()
display = pyvirtualdisplay.Display(visible=0, size=(1400, 900)).start()
os.environ["DISPLAY"] = ":" + str(display.display) + "." + str(display.screen)

In [0]:
def display_frames_as_gif(frames, file_name):
    # plt.figure(figsize=(frames[0].shape[1] / 72.0, frames[0].shape[0] / 72.0), dpi=72)
    patch = plt.imshow(frames[0])
    # plt.axes('off')
    
    def animate(i):
        patch.set_data(frames[i])
    
    anim = animation.FuncAnimation(plt.gcf(), animate, frames=len(frames), interval=50)
    
    anim.save(file_name)
    display(display_animation(anim, default_mode='loop'))

In [0]:
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))

## Hyper parameters

In [0]:
ENV = 'CartPole-v0' # @param
GAMMA = 0.99 # @param
MAX_STEPS = 200 #@param
NUM_EPISODES = 500 # @param
BATCH_SIZE = 32 # @param
CAPACITY = 10000 # @param
FILE_NAME = "test.mp4" # @param
LEARNING_RATE = 0.0001 # @param
random.seed(0)

## Requiermented modules

In [0]:
class ReplayMemory:
    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.index = 0
    
    
    def push(self, state, action, next_state, reward):
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        
        self.memory[self.index] = Transition(state, action, next_state, reward)
        self.index = (self.index + 1) % self.capacity
        
    
    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)
    
    
    def __len__(self):
        return len(self.memory)

In [0]:
class AgentCore:
    def __init__(self, num_state, num_actions):
        self.num_actions = num_actions
        self.memory = ReplayMemory(CAPACITY)
        
        # Agent neural network
        self.model = torch.nn.Sequential()
        self.model.add_module('fc0', torch.nn.Linear(num_state, 32))
        self.model.add_module('relu0', torch.nn.ReLU())
        self.model.add_module('fc1', torch.nn.Linear(32, 32))
        self.model.add_module('relu1', torch.nn.ReLU())
        self.model.add_module('fc2', torch.nn.Linear(32, num_actions))
        
        print(self.model)
        
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=LEARNING_RATE)
        
    
    def replay(self):
        if len(self.memory) < BATCH_SIZE:
            return

        transitions = self.memory.sample(BATCH_SIZE)
        batch = Transition(*zip(*transitions))
        state_batch = torch.cat(batch.state)
        action_batch = torch.cat(batch.action)
        reward_batch = torch.cat(batch.reward)
        non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])

        self.model.eval()
        state_action_values = self.model(state_batch).gather(1, action_batch)
        non_final_mask = torch.ByteTensor(tuple(map(lambda s: s is not None, batch.next_state)))
        next_state_values = torch.zeros(BATCH_SIZE)
        next_state_values[non_final_mask] = self.model(non_final_next_states).max(1)[0].detach()
        expected_state_action_values = reward_batch + GAMMA * next_state_values

        self.model.train()
        loss = F.smooth_l1_loss(state_action_values, expected_state_action_values.unsqueeze(1))
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
    
    
    def decide_action(self, state, episode):
        epsilon = 0.5 * (1 / (episode + 1))
        
        if epsilon <= np.random.uniform(0, 1):
            self.model.eval()
            with torch.no_grad():
                action = self.model(state).max(1)[1].view(1, 1)
        else:
            action = torch.LongTensor([[random.randrange(self.num_actions)]])
        return action

In [0]:
class Agent:
    def __init__(self, num_states, num_actions):
        self.core = AgentCore(num_states, num_actions)
        
        
    def update_q_function(self):
        self.core.replay()
    
    
    def get_action(self, state, episode):
        action = self.core.decide_action(state, episode)
        return action
    
    
    def memorize(self, state, action, next_state, reward):
        self.core.memory.push(state, action, next_state, reward)

In [0]:
class Environment:

    def __init__(self):
        self.env = gym.make(ENV)
        num_states = self.env.observation_space.shape[0]
        num_actions = self.env.action_space.n
        self.agent = Agent(num_states, num_actions)

        
    def run(self):
        episode_10_list = np.zeros(10)
        complete_episodes = 0
        episode_final = False
        frames = []

        for episode in range(NUM_EPISODES):
            observation = self.env.reset()

            state = observation
            state = torch.from_numpy(state).type(torch.FloatTensor)
            state = torch.unsqueeze(state, 0)

            for step in range(MAX_STEPS):
                if episode_final is True:
                    frames.append(self.env.render(mode='rgb_array'))

                action = self.agent.get_action(state, episode)
                next_observation, _, done, _ = self.env.step(action.item())

                if done:
                    next_state = None

                    episode_10_list = np.hstack((episode_10_list[1:], step + 1))

                    if step < 195:
                        reward = torch.FloatTensor([-1.0])
                        complete_episodes = 0
                    else:
                        reward = torch.FloatTensor([1.0])
                        complete_episodes = complete_episodes + 1
                else:
                    reward = torch.FloatTensor([0.0])
                    next_state = next_observation
                    next_state = torch.from_numpy(next_state).type(torch.FloatTensor)
                    next_state = torch.unsqueeze(next_state, 0)

                self.agent.memorize(state, action, next_state, reward)
                self.agent.update_q_function()
                state = next_state

                if done:
                    print('%d Episode: Finished after %d steps：10 step average = %.1lf' % (episode, step + 1, episode_10_list.mean()))
                    break

            if episode_final is True:
                display_frames_as_gif(frames, FILE_NAME)
                break

            if complete_episodes >= 10:
                print("10 consecutive successes")
                episode_final = True

## Evaluation of the model

In [0]:
cartpole_env = Environment()
cartpole_env.run()

Sequential(
  (fc0): Linear(in_features=4, out_features=32, bias=True)
  (relu0): ReLU()
  (fc1): Linear(in_features=32, out_features=32, bias=True)
  (relu1): ReLU()
  (fc2): Linear(in_features=32, out_features=2, bias=True)
)
0 Episode: Finished after 12 steps：10 step average = 1.2
1 Episode: Finished after 12 steps：10 step average = 2.4
2 Episode: Finished after 10 steps：10 step average = 3.4
3 Episode: Finished after 10 steps：10 step average = 4.4
4 Episode: Finished after 10 steps：10 step average = 5.4
5 Episode: Finished after 9 steps：10 step average = 6.3
6 Episode: Finished after 10 steps：10 step average = 7.3
7 Episode: Finished after 8 steps：10 step average = 8.1
8 Episode: Finished after 9 steps：10 step average = 9.0
9 Episode: Finished after 10 steps：10 step average = 10.0
10 Episode: Finished after 10 steps：10 step average = 9.8
11 Episode: Finished after 10 steps：10 step average = 9.6
12 Episode: Finished after 10 steps：10 step average = 9.6
13 Episode: Finished after 10 

TypeError: ignored

## Video visualization

In [0]:
def embed_mp4(filename):
    """Embeds an mp4 file in the notebook."""
    
    video = open(filename,'rb').read()
    b64 = base64.b64encode(video)
    tag = '''
    <video width="640" height="480" controls>
        <source src="data:video/mp4;base64,{0}" type="video/mp4">
    Your browser does not support the video tag.
    </video>'''.format(b64.decode())

    return IPython.display.HTML(tag)

In [0]:
embed_mp4(FILE_NAME)