In [None]:
# default_exp actorcritic.sac

In [None]:
#export
import torch.nn.utils as nn_utils
from fastai.torch_basics import *
import torch.nn.functional as F
from fastai.data.all import *
from fastai.basics import *
from dataclasses import field,asdict
from typing import List,Any,Dict,Callable
from collections import deque
import gym
import torch.multiprocessing as mp
from torch.optim import *
from dataclasses import dataclass

from fastrl.data import *
from fastrl.async_data import *
from fastrl.basic_agents import *
from fastrl.learner import *
from fastrl.metrics import *
from fastai.callback.progress import *
from fastrl.ptan_extension import *

from torch.distributions import *

if IN_NOTEBOOK:
    from IPython import display
    import PIL.Image

  return torch._C._cuda_getDeviceCount() > 0


In [None]:
# hide
from nbdev.showdoc import *
from nbdev.imports import *
from nbdev.export2html import *
# Unless the IN_TEST environment variable is set, insist that the nbdev flags
# report an IPython notebook session that is not Colab.
if not os.environ.get("IN_TEST", None):
    assert IN_NOTEBOOK
    assert not IN_COLAB
    assert IN_IPYTHON

# TRPO

In [None]:
# export
# Width of the two hidden layers used by both ModelActor and ModelCritic.
HID_SIZE = 64

class ModelActor(nn.Module):
    """Policy network that outputs the mean of the action distribution.

    `mu` is a stack of three Linear layers (obs_size -> HID_SIZE -> HID_SIZE -> act_size),
    each followed by Tanh, so the produced means lie in [-1, 1].
    `logstd` is a learnable vector of length `act_size`, initialised to zeros; it is
    a free parameter and does not depend on the observation.
    """

    def __init__(self, obs_size, act_size):
        super().__init__()

        # Layers are created in the same order as they appear in the Sequential.
        body = [
            nn.Linear(obs_size, HID_SIZE), nn.Tanh(),
            nn.Linear(HID_SIZE, HID_SIZE), nn.Tanh(),
        ]
        head = [nn.Linear(HID_SIZE, act_size), nn.Tanh()]
        self.mu = nn.Sequential(*body, *head)
        self.logstd = nn.Parameter(torch.zeros(act_size))

    def forward(self, x):
        """Return the action means for observations `x` (cast to float first)."""
        observations = x.float()
        return self.mu(observations)


class ModelCritic(nn.Module):
    """State-value network: obs_size -> HID_SIZE -> HID_SIZE -> 1 with ReLU in between."""

    def __init__(self, obs_size):
        super().__init__()

        # Layers are created in the same order as they appear in the Sequential.
        hidden = [
            nn.Linear(obs_size, HID_SIZE), nn.ReLU(),
            nn.Linear(HID_SIZE, HID_SIZE), nn.ReLU(),
        ]
        self.value = nn.Sequential(*hidden, nn.Linear(HID_SIZE, 1))

    def forward(self, x):
        """Return one value estimate per observation in `x` (cast to float first)."""
        observations = x.float()
        return self.value(observations)


class AgentA2C(BaseAgent):
    """Stochastic continuous-action agent driven by an actor model.

    The model is expected to return action means when called on a batch of
    states and to expose a `logstd` parameter (as `ModelActor` does). Actions
    are sampled as `mu + exp(logstd) * N(0, 1)` and clipped to [-1, 1].
    """

    preprocessor:Callable=default_states_preprocessor
    def __init__(self, model, device="cpu"):
        self.model = model
        self.device = device

    def __call__(self, states, agent_states):
        """Sample one action per state.

        :param states: sequence of observations; stacked into one batch array
        :param agent_states: passed through unchanged
        :return: (actions ndarray shaped like the model output, agent_states)
        """
        states_v = torch.tensor(np.stack(states)).float().to(self.device)

        # Acting does not need gradients, so skip building the autograd graph.
        with torch.no_grad():
            mu = self.model(states_v).detach().cpu().numpy()
            logstd = self.model.logstd.detach().cpu().numpy()

        # Draw independent noise for every (state, action-dim) entry. Sampling with
        # size=logstd.shape would broadcast one noise vector over the whole batch.
        noise = np.random.normal(size=mu.shape)
        actions = np.clip(mu + np.exp(logstd) * noise, -1, 1)
        return actions, agent_states

In [None]:
# export
# Reward discount factor used by calc_adv_ref.
GAMMA = 0.99
# Smoothing coefficient of the generalized advantage estimator in calc_adv_ref.
GAE_LAMBDA = 0.95

# NOTE(review): neither learning-rate constant is referenced in the visible code
# (TRPOLearner takes its critic learning rate as an argument) — confirm before relying on them.
LEARNING_RATE_ACTOR = 1e-4
LEARNING_RATE_CRITIC = 1e-3

# NOTE(review): the PPO_* constants are not referenced by the visible TRPO code;
# presumably carried over from a PPO notebook — confirm whether they are still needed.
PPO_EPS = 0.2
PPO_EPOCHES = 10
PPO_BATCH_SIZE = 64

def calc_logprob(mu_v, logstd_v, actions_v):
    """Element-wise log-density of `actions_v` under N(mu_v, exp(logstd_v)**2).

    `logstd_v` is the log of the standard deviation, matching how the agent
    samples (`mu + exp(logstd) * noise`) and how the KL term is computed, so the
    variance is `exp(2 * logstd_v)`. Using `exp(logstd_v)` as the variance is
    only correct when `logstd_v == 0`.

    :param mu_v: tensor of means
    :param logstd_v: tensor of log standard deviations, broadcastable to `mu_v`
    :param actions_v: tensor of actions, same shape as `mu_v`
    :return: tensor of per-dimension log-probabilities (not summed over dimensions)
    """
    var_v = torch.exp(2 * logstd_v)
    # The variance in the quadratic term is floored at 1e-3 to keep the division stable.
    p1 = - ((mu_v - actions_v) ** 2) / (2 * var_v.clamp(min=1e-3))
    p2 = - torch.log(torch.sqrt(2 * math.pi * var_v))
    return p1 + p2


def calc_adv_ref(trajectory, net_crt, states_v, device="cpu"):
    """
    Compute GAE advantages and critic reference values for a trajectory.

    The final trajectory entry is only used to supply the "next value" of the
    entry before it, so both results are one element shorter than `trajectory`.

    :param trajectory: list of 1-tuples, each holding an experience with `reward` and `done`
    :param net_crt: critic network, called on `states_v`
    :param states_v: states tensor, one row per trajectory entry
    :param device: device the returned tensors are moved to
    :return: tuple (advantages tensor, reference values tensor)
    """
    critic_values = net_crt(states_v).squeeze().data.cpu().numpy()

    # Walk the trajectory backwards so the running GAE term can be accumulated.
    backwards = zip(reversed(critic_values[:-1]),
                    reversed(critic_values[1:]),
                    reversed(trajectory[:-1]))
    running_gae = 0.0
    advantages = []
    references = []
    for value, value_next, (step,) in backwards:
        if step.done:
            # Terminal step: no bootstrap and no carry-over from later steps.
            running_gae = step.reward - value
        else:
            td_error = step.reward + GAMMA * value_next - value
            running_gae = td_error + GAMMA * GAE_LAMBDA * running_gae
        advantages.append(running_gae)
        references.append(running_gae + value)

    # Restore chronological order before building the tensors.
    adv_v = torch.FloatTensor(advantages[::-1]).to(device)
    ref_v = torch.FloatTensor(references[::-1]).to(device)
    return adv_v, ref_v


In [None]:
# export
# Passed by loss_func to trpo_step as `max_kl` (scales the full step size).
TRPO_MAX_KL = 0.01
# Passed by loss_func to trpo_step as `damping` (added as `v * damping` to each Fisher-vector product).
TRPO_DAMPING = 0.1

def get_flat_params_from(model):
    """Return all of `model`'s parameters concatenated into one 1-D tensor.

    Parameters are taken in `model.parameters()` order and read through `.data`,
    so the result is not attached to the autograd graph.
    """
    return torch.cat([weights.data.view(-1) for weights in model.parameters()])


def set_flat_params_to(model, flat_params):
    """Copy a flat 1-D tensor back into `model`'s parameters, in place.

    `flat_params` is consumed in `model.parameters()` order; each parameter takes
    as many consecutive entries as it has elements, reshaped to its own size.
    """
    offset = 0
    for param in model.parameters():
        count = param.numel()
        chunk = flat_params[offset:offset + count]
        param.data.copy_(chunk.view(param.size()))
        offset += count


def conjugate_gradients(Avp, b, nsteps, residual_tol=1e-10, device="cpu"):
    """Approximately solve `A x = b` with the conjugate-gradient method.

    :param Avp: callable returning the matrix-vector product `A @ v` for a 1-D tensor `v`
    :param b: 1-D right-hand-side tensor
    :param nsteps: maximum number of CG iterations
    :param residual_tol: stop once the squared residual norm drops below this value
    :param device: kept for backward compatibility; the solution is always allocated
        with the dtype and device of `b`, since every update mixes it with `b`-derived tensors
    :return: 1-D tensor `x`, same shape, dtype and device as `b`
    """
    x = torch.zeros_like(b)
    r = b.clone()
    p = b.clone()
    rdotr = torch.dot(r, r)
    for _ in range(nsteps):
        _Avp = Avp(p)
        curvature = torch.dot(p, _Avp)
        # A zero denominator (e.g. b == 0) would turn alpha, and then x, into NaN/inf.
        if curvature == 0:
            break
        alpha = rdotr / curvature
        x += alpha * p
        r -= alpha * _Avp
        new_rdotr = torch.dot(r, r)
        betta = new_rdotr / rdotr
        p = r + betta * p
        rdotr = new_rdotr
        if rdotr < residual_tol:
            break
    return x


def linesearch(model,
               f,
               x,
               fullstep,
               expected_improve_rate,
               max_backtracks=10,
               accept_ratio=.1):
    """Backtracking line search along `fullstep`, starting from flat parameters `x`.

    Step fractions 1, 1/2, 1/4, ... are tried in turn. Each candidate is written
    into `model` before `f` is re-evaluated, so on failure the model is left
    holding the last candidate tried; the caller must restore the parameters.

    :param model: model whose parameters are overwritten with each candidate
    :param f: zero-argument callable returning the loss tensor for the current parameters
    :param x: flat parameter tensor to start from
    :param fullstep: flat tensor giving the full update direction
    :param expected_improve_rate: expected loss decrease for the full step
    :param max_backtracks: number of step fractions to try
    :param accept_ratio: minimum actual/expected improvement ratio to accept a step
    :return: (True, accepted parameters) on success, otherwise (False, x)
    """
    baseline = f().data
    for stepfrac in .5**np.arange(max_backtracks):
        candidate = x + fullstep * stepfrac
        set_flat_params_to(model, candidate)
        actual_improve = baseline - f().data
        ratio = actual_improve / (expected_improve_rate * stepfrac)

        accepted = ratio.item() > accept_ratio and actual_improve.item() > 0
        if accepted:
            return True, candidate
    return False, x

def trpo_step(model, get_loss, get_kl, max_kl, damping, device="cpu"):
    """Perform one TRPO update of `model`'s parameters in place.

    The step direction comes from 10 conjugate-gradient iterations on the damped
    Fisher-vector product, is scaled using `max_kl`, and is then validated by a
    backtracking line search. If the line search fails, the original parameters
    are written back.

    :param model: policy network to update
    :param get_loss: zero-argument callable returning the scalar surrogate loss
    :param get_kl: zero-argument callable returning per-sample KL values (averaged here)
    :param max_kl: KL budget used to scale the full step
    :param damping: coefficient of the `v * damping` term added to each Fisher-vector product
    :param device: device the vector is moved to inside the Fisher-vector product
    :return: the loss tensor evaluated before the update
    """
    loss = get_loss()
    grads = torch.autograd.grad(loss, model.parameters())
    loss_grad = torch.cat([grad.view(-1) for grad in grads]).data

    def Fvp(v):
        # Hessian-of-KL times v, via double backprop, plus damping.
        kl = get_kl()
        kl = kl.mean()

        grads = torch.autograd.grad(kl, model.parameters(), create_graph=True)
        flat_grad_kl = torch.cat([grad.view(-1) for grad in grads])

        # `v` is already a tensor; torch.tensor(v) would copy it and emit a
        # copy-construct UserWarning on every call. detach() gives the same
        # gradient-free view without the warning.
        v_v = v.detach().to(device)
        kl_v = (flat_grad_kl * v_v).sum()
        grads = torch.autograd.grad(kl_v, model.parameters())
        flat_grad_grad_kl = torch.cat([grad.contiguous().view(-1) for grad in grads]).data

        return flat_grad_grad_kl + v * damping

    stepdir = conjugate_gradients(Fvp, -loss_grad, 10, device=device)

    shs = 0.5 * (stepdir * Fvp(stepdir)).sum(0, keepdim=True)

    lm = torch.sqrt(shs / max_kl)
    fullstep = stepdir / lm[0]

    neggdotstepdir = (-loss_grad * stepdir).sum(0, keepdim=True)

    prev_params = get_flat_params_from(model)
    success, new_params = linesearch(model, get_loss, prev_params, fullstep,
                                     neggdotstepdir / lm[0])
    # On failure `new_params` is `prev_params`, which undoes the last candidate
    # that linesearch left in the model.
    set_flat_params_to(model, new_params)

    return loss

def loss_func(*yb,learn):
    """Run one critic update and one TRPO actor update on the learner's current batch.

    The positional arguments are ignored; the batch is read from `learn.xb` and
    `learn.yb`, whose concatenation is indexed as
    0: state, 1: action, 2: reward, 3: done, 4: episode_reward, 5: steps.
    A step's `done` flag is cleared when its step counter equals `learn.max_step`
    (presumably so that hitting the episode step limit is not treated as a true
    terminal state — confirm against the environment).

    Side effects: steps `learn.opt_crt` on `learn.net_crt` and updates
    `learn.model` in place through `trpo_step`.

    :param learn: learner providing `xb`, `yb`, `max_step`, `net_crt`, `opt_crt` and `model`
    :return: detached copy of the surrogate policy loss evaluated before the actor update
    """
    b=list(learn.xb)+list(learn.yb)
    trajectory=[(Experience(state=b[0][i],action=b[1][i],reward=b[2][i],
                done=(b[3][i] and learn.max_step!=b[5][i]),episode_reward=b[4][i],
                steps=b[5][i]),) for i in range(len(b[0]))]
    net_crt=learn.net_crt
    net_act=learn.model
    opt_crt=learn.opt_crt

    traj_states = [t[0].state for t in trajectory]
    traj_actions = [t[0].action for t in trajectory]
    traj_states_v = torch.stack(traj_states).float().to(default_device())
    traj_actions_v = torch.stack(traj_actions).float().to(default_device())
    traj_adv_v, traj_ref_v = calc_adv_ref(trajectory, net_crt, traj_states_v, device=default_device())
    mu_v = net_act(traj_states_v)
    old_logprob_v = calc_logprob(mu_v, net_act.logstd, traj_actions_v)

    # normalize advantages
    traj_adv_v = (traj_adv_v - torch.mean(traj_adv_v)) / torch.std(traj_adv_v)

    # drop the last entry: advantages and reference values were calculated without it
    old_logprob_v = old_logprob_v[:-1].detach()
    traj_states_v = traj_states_v[:-1]
    traj_actions_v = traj_actions_v[:-1]

    # critic step
    opt_crt.zero_grad()
    value_v = net_crt(traj_states_v)
    loss_value_v = F.mse_loss(value_v.squeeze(-1), traj_ref_v)
    loss_value_v.backward()
    opt_crt.step()

    # actor step
    def get_loss():
        # Surrogate objective: advantage-weighted probability ratio (negated for minimisation).
        mu_v = net_act(traj_states_v)
        logprob_v = calc_logprob(mu_v, net_act.logstd, traj_actions_v)
        action_loss_v = -traj_adv_v.unsqueeze(dim=-1) * torch.exp(logprob_v - old_logprob_v)
        return action_loss_v.mean()

    def get_kl():
        # KL between the current policy and a detached copy of itself; its value
        # is zero but its second derivative gives the Fisher matrix used by trpo_step.
        mu_v = net_act(traj_states_v)
        logstd_v = net_act.logstd
        mu0_v = mu_v.detach()
        logstd0_v = logstd_v.detach()
        std_v = torch.exp(logstd_v)
        std0_v = std_v.detach()
        kl = logstd_v - logstd0_v + (std0_v ** 2 + (mu0_v - mu_v) ** 2) / (2.0 * std_v ** 2) - 0.5
        return kl.sum(1, keepdim=True)

    loss=trpo_step(net_act, get_loss, get_kl, TRPO_MAX_KL, TRPO_DAMPING, device=default_device())

    # `loss` is already a tensor: torch.tensor(loss) would emit a copy-construct
    # warning. detach().clone() yields the same gradient-free copy.
    return loss.detach().clone()

In [None]:
# export
class TRPOTrainer(Callback):
    """Callback that cancels the rest of every batch right after the loss is computed.

    NOTE(review): presumably this stops the learner's own backward/optimizer step,
    since `loss_func` already updates both networks — confirm against the learner loop.
    """

    def after_loss(self):
        raise CancelBatchException

In [None]:
# export
class TRPOLearner(AgentLearner):
    """Learner that trains `agent.model` with TRPO and owns a separate critic.

    The critic (`net_crt`) and its Adam optimizer (`opt_crt`) are created here
    and used by `loss_func`, which is bound to this learner.
    """
    def __init__(self,dls,agent=None,crtic_lr=1e-3,max_step=1,obs_size=26,**kwargs):
        """
        :param dls: dataloaders passed on to AgentLearner
        :param agent: agent whose `.model` is the policy network (required despite the None default)
        :param crtic_lr: learning rate of the critic's Adam optimizer
        :param max_step: step count at which `loss_func` clears a step's `done` flag
        :param obs_size: observation dimensionality of the critic's input; the default
            of 26 is the value that was previously hard-coded
        :param kwargs: forwarded to AgentLearner
        """
        store_attr()
        self.net_crt=ModelCritic(obs_size).to(default_device())
        self.opt_crt = Adam(self.net_crt.parameters(), lr=crtic_lr)

        super().__init__(dls,loss_func=partial(loss_func,learn=self),model=agent.model,**kwargs)

In [None]:
# Example run: train the TRPO learner on the PyBullet HalfCheetah environment.
# NOTE(review): `gym` is already imported in the first cell; these imports belong there too.
import gym
import pybulletgym
env='HalfCheetahPyBulletEnv-v0'
# Actor with 26 observation inputs and 6 action outputs; 26 has to match the
# observation size TRPOLearner uses for its critic.
agent=AgentA2C(ModelActor(26, 6).to(default_device()), device=default_device())

# Batches of 2049 steps, unshuffled; loss_func drops the last step of each batch
# after using it to bootstrap the advantage estimate.
block=ExperienceBlock(agent=agent,seed=0,n_steps=1,
                      dls_kwargs={'bs':2049,'num_workers':0,'verbose':False,'indexed':True,'shuffle_train':False})
# NOTE(review): `(block)` is just `block`, not a 1-tuple — confirm IterableDataBlock accepts a bare block.
blk=IterableDataBlock(blocks=(block),splitter=FuncSplitter(lambda x:False))
dls=blk.dataloaders([env]*1,n=2049*10,device=default_device())

# max_step is the environment's episode step limit; loss_func uses it to clear `done` on such steps.
learner=TRPOLearner(dls,agent=agent,cbs=[TRPOTrainer], metrics=[AvgEpisodeRewardMetric(Experience)],max_step=gym.make(env)._max_episode_steps)
# NOTE(review): lr and wd are passed to fit, but TRPOTrainer cancels each batch after the loss,
# so they presumably have no effect on the actor update — confirm.
learner.fit(20,lr=0.001,wd=0)

WalkerBase::__init__
WalkerBase::__init__




epoch,train_loss,train_avg_episode_r,valid_loss,valid_avg_episode_r,time
0,0.0,1.867225,,1.867225,00:37
1,0.0,4.581237,,4.581237,00:38
2,0.0,6.013871,,6.013871,00:40
3,0.0,6.744021,,6.744021,00:38
4,0.0,7.103516,,7.103516,00:38
5,0.0,7.459056,,7.459056,00:38
6,0.0,7.711311,,7.711311,00:37
7,0.0,8.075411,,8.075411,00:37
8,0.0,8.610297,,8.610297,00:37
9,0.0,9.437673,,9.437673,00:37


  warn("Your generator is empty.")
