<a href="https://colab.research.google.com/github/henry-bokyum-kim/NNStudy/blob/bokyum/%5BRL%5D%5BPG%5D%5BREINFORCE%5D%5BBK%5Dreinforce_basic.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
#!/usr/bin/env python3
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as opt
import numpy as np
import random

class PGN(nn.Module):
    """Policy network: a small MLP that maps an observation to one logit per action.

    Layer layout is ``in_size -> hidden -> hidden/2 -> out_size`` with a ReLU
    between consecutive linear layers. No softmax is applied here; callers
    normalise the returned logits themselves.
    """

    def __init__(self, in_size, out_size, hidden = 256):
        super().__init__()
        half_width = int(hidden / 2)
        layers = [
            nn.Linear(in_size, hidden),
            nn.ReLU(),
            nn.Linear(hidden, half_width),
            nn.ReLU(),
            nn.Linear(half_width, out_size),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the raw action logits produced for input ``x``."""
        logits = self.net(x)
        return logits

# Training hyperparameters.
EPOCH = 4000            # number of episodes; the training loop runs one episode per epoch
LR = 0.0005             # Adam learning rate
GAMMA = 0.99            # discount factor used for reward-to-go
BATCH_EP = 3            # episodes collected before each gradient step
ENTROPY_WEIGHT = 1      # scale of the entropy bonus subtracted from the loss
GAME = "CartPole-v1"    # gym environment id

env = gym.make(GAME)

# FrozenLake observations are integer states that the training loop one-hot
# encodes to length `observation_space.n`; a Discrete space has shape (), so
# `shape[0]` would raise IndexError there. Other games feed the raw vector.
if GAME == "FrozenLake-v0":
    obs_size = env.observation_space.n
else:
    obs_size = env.observation_space.shape[0]

net = PGN(obs_size, env.action_space.n, 128)
optim = opt.Adam(net.parameters(), lr = LR)

In [0]:
def get_onehot(i, size):
    """Return a one-hot encoding of index ``i`` as a list of length ``size``.

    Elements are numpy float64 values: 1.0 at position ``i``, 0.0 elsewhere.
    ``i`` follows numpy indexing rules, so negative indices count from the end
    and out-of-range indices raise ``IndexError``.
    """
    vector = np.zeros(size)
    vector[i] = 1
    return [*vector]

In [0]:
def _discounted_returns(rewards, gamma):
    """Return the discounted reward-to-go for each step of one episode.

    Element ``t`` of the returned float64 array is
    ``rewards[t] + gamma * rewards[t + 1] + gamma**2 * rewards[t + 2] + ...``,
    in the same (forward) order as ``rewards``.
    """
    returns = np.zeros(len(rewards))
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = running * gamma + rewards[t]
        returns[t] = running
    return returns


def _policy_gradient_loss(logits, acts, qvals, entropy_weight):
    """Return the REINFORCE loss plus an entropy bonus for one batch.

    Args:
        logits: (N, n_actions) unnormalised action scores from the policy net.
        acts: (N, 1) long tensor holding the action taken at each step.
        qvals: (N, 1) float tensor of baseline-subtracted returns.
        entropy_weight: scale of the entropy bonus; larger values push the
            policy towards more uniform action probabilities.

    Returns:
        Scalar loss tensor to minimise.
    """
    log_probs = F.log_softmax(logits, dim=1)
    # log pi(a_t | s_t) for the action actually taken at each step.
    taken_log_prob = log_probs.gather(1, acts)
    policy_loss = -(qvals * taken_log_prob).mean()
    # Entropy of the whole action distribution at each state,
    # -sum_a p(a) log p(a), rather than only the sampled action's term.
    entropy = -(log_probs.exp() * log_probs).sum(dim=1)
    entropy_loss = -entropy_weight * entropy.mean()
    return policy_loss + entropy_loss


if __name__ == "__main__":
    batch_size = 0      # episodes accumulated in the current batch
    batch_obs = []
    batch_act = []
    batch_qval = []
    done_ = 0           # episodes whose final-step reward was positive
    total_rew = 0       # sum of discounted returns over every step seen so far
    total_step = 0      # environment steps taken so far (across all episodes)
    for epoch in range(EPOCH):
        obs = env.reset()
        if GAME == "FrozenLake-v0":
            obs = get_onehot(obs, env.observation_space.n)
        cur_rew = []
        count = 0
        while True:
            with torch.no_grad():
                out = net(torch.as_tensor(np.asarray(obs), dtype=torch.float32))
                probs = F.softmax(out, dim=-1).numpy()
                act = np.random.choice(env.action_space.n, p=probs)

            next_obs, rew, done, _ = env.step(act)
            count += 1
            total_step += 1

            if GAME == "FrozenLake-v0":
                # Reward shaping: -100 when the episode ends with zero reward or
                # step 100 is reached, +10 on reaching reward 1, small step cost.
                if done and rew == 0 or count == 100:
                    rew = -100
                elif done and rew == 1:
                    rew = 10
                else:
                    rew = -0.1

            batch_obs.append(obs)
            batch_act.append(act)
            cur_rew.append(rew)

            obs = next_obs
            if GAME == "FrozenLake-v0":
                obs = get_onehot(obs, env.observation_space.n)

            if done:
                qval = _discounted_returns(cur_rew, GAMMA)
                # Baseline: running mean of the discounted return per step.
                total_rew += qval.sum()
                baseline = total_rew / total_step
                batch_qval.extend(qval - baseline)
                batch_size += 1
                break
        print("epoch : %d count : %d rew : %d" % (epoch, count, rew))
        if rew > 0:
            done_ += 1

        if batch_size == BATCH_EP:
            # Stack into one ndarray first; building a tensor from a list of
            # ndarrays element by element is very slow in PyTorch.
            obss = torch.as_tensor(np.asarray(batch_obs), dtype=torch.float32)
            acts = torch.as_tensor(np.asarray(batch_act), dtype=torch.long).unsqueeze(-1)
            qvals = torch.as_tensor(np.asarray(batch_qval), dtype=torch.float32).unsqueeze(-1)

            optim.zero_grad()
            loss = _policy_gradient_loss(net(obss), acts, qvals, ENTROPY_WEIGHT)
            loss.backward()
            optim.step()

            batch_size = 0
            batch_obs = []
            batch_act = []
            batch_qval = []



epoch : 0 count : 15 rew : 1
epoch : 1 count : 30 rew : 1
epoch : 2 count : 15 rew : 1
epoch : 3 count : 23 rew : 1
epoch : 4 count : 22 rew : 1
epoch : 5 count : 19 rew : 1
epoch : 6 count : 21 rew : 1
epoch : 7 count : 13 rew : 1
epoch : 8 count : 11 rew : 1
epoch : 9 count : 8 rew : 1
epoch : 10 count : 14 rew : 1
epoch : 11 count : 17 rew : 1
epoch : 12 count : 12 rew : 1
epoch : 13 count : 72 rew : 1
epoch : 14 count : 17 rew : 1
epoch : 15 count : 29 rew : 1
epoch : 16 count : 22 rew : 1
epoch : 17 count : 13 rew : 1
epoch : 18 count : 13 rew : 1
epoch : 19 count : 16 rew : 1
epoch : 20 count : 27 rew : 1
epoch : 21 count : 21 rew : 1
epoch : 22 count : 22 rew : 1
epoch : 23 count : 34 rew : 1
epoch : 24 count : 11 rew : 1
epoch : 25 count : 11 rew : 1
epoch : 26 count : 14 rew : 1
epoch : 27 count : 17 rew : 1
epoch : 28 count : 24 rew : 1
epoch : 29 count : 25 rew : 1
epoch : 30 count : 15 rew : 1
epoch : 31 count : 16 rew : 1
epoch : 32 count : 9 rew : 1
epoch : 33 count : 19 