In [1]:
# default_exp actorcritic.sac

In [2]:
#export
import torch.nn.utils as nn_utils
from fastai.torch_basics import *
import torch.nn.functional as F
from fastai.data.all import *
from fastai.basics import *
from dataclasses import field,asdict
from typing import List,Any,Dict,Callable
from collections import deque
import gym
import torch.multiprocessing as mp
from torch.optim import *
from dataclasses import dataclass

from fastrl.data import *
from fastrl.async_data import *
from fastrl.basic_agents import *
from fastrl.learner import *
from fastrl.metrics import *
from fastai.callback.progress import *
from fastrl.ptan_extension import *

from torch.distributions import *

if IN_NOTEBOOK:
    from IPython import display
    import PIL.Image

  return torch._C._cuda_getDeviceCount() > 0


In [3]:
# hide
from nbdev.showdoc import *
from nbdev.imports import *
from nbdev.export2html import *
if not os.environ.get("IN_TEST", None):
    assert IN_NOTEBOOK
    assert not IN_COLAB
    assert IN_IPYTHON

# PPO

In [4]:
# export
HID_SIZE = 64

class ModelActor(nn.Module):
    def __init__(self, obs_size, act_size):
        super(ModelActor, self).__init__()

        self.mu = nn.Sequential(
            nn.Linear(obs_size, HID_SIZE),
            nn.Tanh(),
            nn.Linear(HID_SIZE, HID_SIZE),
            nn.Tanh(),
            nn.Linear(HID_SIZE, act_size),
            nn.Tanh(),
        )
        self.logstd = nn.Parameter(torch.zeros(act_size))

    def forward(self, x):
        return self.mu(x.float())


class ModelCritic(nn.Module):
    def __init__(self, obs_size):
        super(ModelCritic, self).__init__()

        self.value = nn.Sequential(
            nn.Linear(obs_size, HID_SIZE),
            nn.ReLU(),
            nn.Linear(HID_SIZE, HID_SIZE),
            nn.ReLU(),
            nn.Linear(HID_SIZE, 1),
        )

    def forward(self, x):
        return self.value(x.float())


class AgentA2C(BaseAgent):
    
    preprocessor:Callable=default_states_preprocessor
    def __init__(self, model, device="cpu"):
        self.model = model
        self.device = device

    def __call__(self, states, agent_states):
        states_v = torch.tensor(np.stack(states)).float().to(self.device) #self.preprocessor(states[0].reshape(1,-1))

        mu_v = self.model(states_v)
        mu = mu_v.data.cpu().numpy()
        logstd = self.model.logstd.data.cpu().numpy()
        actions = mu + np.exp(logstd) * np.random.normal(size=logstd.shape)
        actions = np.clip(actions, -1, 1)
#         print(actions)
        return actions, agent_states

In [5]:
# export
GAMMA = 0.99
GAE_LAMBDA = 0.95

LEARNING_RATE_ACTOR = 1e-4
LEARNING_RATE_CRITIC = 1e-3

PPO_EPS = 0.2
PPO_EPOCHES = 10
PPO_BATCH_SIZE = 64

def calc_logprob(mu_v, logstd_v, actions_v):
    p1 = - ((mu_v - actions_v) ** 2) / (2*torch.exp(logstd_v).clamp(min=1e-3))
    p2 = - torch.log(torch.sqrt(2 * math.pi * torch.exp(logstd_v)))
    return p1 + p2


def calc_adv_ref(trajectory, net_crt, states_v, device="cpu"):
    """
    By trajectory calculate advantage and 1-step ref value
    :param trajectory: trajectory list
    :param net_crt: critic network
    :param states_v: states tensor
    :return: tuple with advantage numpy array and reference values
    """
    values_v = net_crt(states_v)
    values = values_v.squeeze().data.cpu().numpy()
    # generalized advantage estimator: smoothed version of the advantage
    last_gae = 0.0
    result_adv = []
    result_ref = []
    for val, next_val, (exp,) in zip(reversed(values[:-1]), reversed(values[1:]),
                                     reversed(trajectory[:-1])):
        if exp.done:
            delta = exp.reward - val
            last_gae = delta
        else:
            delta = exp.reward + GAMMA * next_val - val
            last_gae = delta + GAMMA * GAE_LAMBDA * last_gae
        result_adv.append(last_gae)
        result_ref.append(last_gae + val)

    adv_v = torch.FloatTensor(list(reversed(result_adv))).to(device)
    ref_v = torch.FloatTensor(list(reversed(result_ref))).to(device)
    return adv_v, ref_v


In [6]:
# export
def loss_func(*yb,learn):
    b=list(learn.xb)+list(learn.yb)
    yxb=b
    trajectory=[(Experience(state=b[0][i],action=b[1][i],reward=b[2][i],
                done=(b[3][i] and learn.max_step!=b[5][i]),episode_reward=b[4][i],
                steps=b[5][i]),) for i in range(len(b[0]))]
    net_crt=learn.net_crt
    net_act=learn.model
    opt_crt=learn.opt_crt
    opt_act=learn.opt_act
    

    traj_states = [t[0].state for t in trajectory]
    traj_actions = [t[0].action for t in trajectory]
#     traj_states = [t.state for t in trajectory]
#     traj_actions = [t.action for t in trajectory]
#     traj_states_v = torch.FloatTensor(traj_states).to(device)
#     traj_actions_v = torch.FloatTensor(traj_actions).to(device)
    traj_states_v = torch.stack(traj_states).float().to(default_device())
    traj_actions_v = torch.stack(traj_actions).float().to(default_device())
#     print(traj_states_v.size(),net_act)
    traj_adv_v, traj_ref_v = calc_adv_ref(trajectory, net_crt, traj_states_v, device=default_device())
    mu_v = net_act(traj_states_v)
    old_logprob_v = calc_logprob(mu_v, net_act.logstd, traj_actions_v)

    # normalize advantages
    traj_adv_v = (traj_adv_v - torch.mean(traj_adv_v)) / torch.std(traj_adv_v)

    # drop last entry from the trajectory, an our adv and ref value calculated without it
    trajectory = trajectory[:-1]
    old_logprob_v = old_logprob_v[:-1].detach()

    sum_loss_value = 0.0
    sum_loss_policy = 0.0
    count_steps = 0

    for epoch in range(PPO_EPOCHES):
        for batch_ofs in range(0, len(trajectory), PPO_BATCH_SIZE):
            states_v = traj_states_v[batch_ofs:batch_ofs + PPO_BATCH_SIZE]
            actions_v = traj_actions_v[batch_ofs:batch_ofs + PPO_BATCH_SIZE]
            batch_adv_v = traj_adv_v[batch_ofs:batch_ofs + PPO_BATCH_SIZE].unsqueeze(-1)
            batch_ref_v = traj_ref_v[batch_ofs:batch_ofs + PPO_BATCH_SIZE]
            batch_old_logprob_v = old_logprob_v[batch_ofs:batch_ofs + PPO_BATCH_SIZE]

            # critic training
            opt_crt.zero_grad()
            value_v = net_crt(states_v)
            loss_value_v = F.mse_loss(value_v.squeeze(-1), batch_ref_v)
            loss_value_v.backward()
            opt_crt.step()

            # actor training
            opt_act.zero_grad()
            mu_v = net_act(states_v)
            logprob_pi_v = calc_logprob(mu_v, net_act.logstd, actions_v)
            ratio_v = torch.exp(logprob_pi_v - batch_old_logprob_v)
            surr_obj_v = batch_adv_v * ratio_v
            clipped_surr_v = batch_adv_v * torch.clamp(ratio_v, 1.0 - PPO_EPS, 1.0 + PPO_EPS)
            loss_policy_v = -torch.min(surr_obj_v, clipped_surr_v).mean()
            loss_policy_v.backward()
            opt_act.step()

            sum_loss_value += loss_value_v.item()
            sum_loss_policy += loss_policy_v.item()
            count_steps += 1
    return torch.tensor(sum_loss_value+sum_loss_policy)

In [7]:
# export
class PPOTrainer(Callback):
    def after_loss(self):raise CancelBatchException

In [8]:
# export
class PPOLearner(AgentLearner):
    def __init__(self,dls,agent=None,actr_lr=1e-4,crtic_lr=1e-3,max_step=1,**kwargs):
        store_attr()
        self.net_crt=ModelCritic(3).to(default_device())
        self.opt_act = Adam(agent.model.parameters(), lr=actr_lr)
        self.opt_crt = Adam(self.net_crt.parameters(), lr=crtic_lr)

        super().__init__(dls,loss_func=partial(loss_func,learn=self),model=agent.model,**kwargs)

In [15]:
env='Pendulum-v0'
agent=AgentA2C(ModelActor(3, 1).to(default_device()), device=default_device())

block=ExperienceBlock(agent=agent,seed=0,n_steps=1,
                      dls_kwargs={'bs':2049,'num_workers':0,'verbose':False,'indexed':True,'shuffle_train':False})
blk=IterableDataBlock(blocks=(block),splitter=FuncSplitter(lambda x:False))
dls=blk.dataloaders([env]*1,n=2049,device=default_device())

learner=PPOLearner(dls,agent=agent,cbs=[PPOTrainer], metrics=[AvgEpisodeRewardMetric(Experience)],max_step=gym.make(env)._max_episode_steps)
learner.fit(30,lr=0.001,wd=0)

epoch,train_loss,train_avg_episode_r,valid_loss,valid_avg_episode_r,time
0,2447894.25,-833.402907,,-833.402907,00:02
1,1443182.875,-1079.645498,,-1079.645498,00:02
2,1004285.0625,-1214.675004,,-1214.675004,00:02
3,771823.125,-1296.550904,,-1296.550904,00:02
4,629101.875,-1336.413165,,-1336.413165,00:02
5,534776.8125,-1362.607398,,-1362.607398,00:02
6,484733.28125,-1370.539334,,-1370.539334,00:02
7,485329.96875,-1365.002604,,-1365.002604,00:02
8,446202.34375,-1364.49923,,-1364.49923,00:02
9,428917.25,-1354.56139,,-1354.56139,00:02


  warn("Your generator is empty.")
