In [None]:
# default_exp qlearning.dist_dqn

In [None]:
#export
import torch.nn.utils as nn_utils
from fastai.torch_basics import *
from fastai.data.all import *
from fastai.basics import *
from dataclasses import field,asdict
from typing import List,Any,Dict,Callable
from collections import deque
import gym
import torch.multiprocessing as mp
from torch.optim import *

from fastrl.data import *
from fastrl.async_data import *
from fastrl.basic_agents import *
from fastrl.learner import *
from fastrl.metrics import *
from fastrl.ptan_extension import *
from fastrl.qlearning.dqn import *
from fastrl.qlearning.dqn_target import *

if IN_NOTEBOOK:
    from IPython import display
    import PIL.Image

# Distributional DQN

In [None]:
# export
Vmax = 10
Vmin = -10
N_ATOMS = 51
DELTA_Z = (Vmax - Vmin) / (N_ATOMS - 1)  # spacing between neighbouring support atoms

class DistributionalDQN(nn.Module):
    """Categorical distributional DQN network.

    An MLP maps a state to `n_actions * N_ATOMS` logits, reshaped to
    `(batch, n_actions, N_ATOMS)`. The atom values live in the `supports` buffer,
    evenly spaced over `[Vmin, Vmax]`. Expected Q-values are the
    softmax-weighted sum of `supports` over the atom dimension.
    """
    def __init__(self, input_shape, n_actions):
        super().__init__()

        self.fc = nn.Sequential(
            nn.Linear(input_shape[0], 512),
            nn.ReLU(),
            nn.Linear(512, n_actions * N_ATOMS)
        )

        # `linspace` always yields exactly N_ATOMS points. A float-step `arange`
        # can add an extra element through rounding at the end point for some
        # Vmin/Vmax/N_ATOMS combinations, which would break the broadcast in `both`.
        self.register_buffer("supports", torch.linspace(Vmin, Vmax, N_ATOMS))
        self.softmax = nn.Softmax(dim=1)

        # NOTE(review): presumably read by the learner; this model defines no loss itself.
        self.loss_func = None

    def set_opt(self, _):
        "No-op; presumably satisfies a learner hook that calls `set_opt` on the model."
        pass

    def forward(self, x, only_qvals=False):
        """Return per-action atom logits `(batch, n_actions, N_ATOMS)`.

        If `only_qvals` is True, return expected Q-values `(batch, n_actions)` instead.
        """
        if only_qvals:
            # Early return: `qvals` runs its own forward pass, so don't compute one here too.
            return self.qvals(x)
        batch_size = x.size()[0]
        fc_out = self.fc(x.float())
        return fc_out.view(batch_size, -1, N_ATOMS)

    def both(self, x):
        "Return `(atom_logits, q_values)` computed from a single forward pass."
        cat_out = self(x)
        probs = self.apply_softmax(cat_out)
        weights = probs * self.supports  # broadcast supports over (batch, n_actions)
        res = weights.sum(dim=2)
        return cat_out, res

    def qvals(self, x):
        "Expected Q-value per action, shape `(batch, n_actions)`."
        return self.both(x)[1]

    def apply_softmax(self, t):
        "Softmax `t` over its last (atom) dimension of size N_ATOMS, keeping `t`'s shape."
        return self.softmax(t.view(-1, N_ATOMS)).view(t.size())

> Note: `distr_projection` below is hard to read. Is there a way to simplify it? Revisit this during the refactor.

In [None]:
# export
def _as_numpy(x):
    "Return `x` as a numpy array; torch tensors are detached and copied to CPU first."
    return x.detach().cpu().numpy() if torch.is_tensor(x) else np.asarray(x)

def _support_neighbours(tz_j, Vmin, delta_z):
    "Fractional support index `b_j` of values `tz_j`, plus its floor `l` and ceiling `u` as int64 indices."
    b_j = (tz_j - Vmin) / delta_z
    return b_j, np.floor(b_j).astype(np.int64), np.ceil(b_j).astype(np.int64)

def distr_projection(next_distr, rewards, dones, Vmin, Vmax, n_atoms, gamma):
    """
    Perform distribution projection aka Categorical Algorithm from the
    "A Distributional Perspective on RL" paper.

    Each atom `z_j = Vmin + j*delta_z` of `next_distr` is shifted to `r + gamma*z_j`
    and clipped to `[Vmin, Vmax]`. Its probability mass is then split between the
    two neighbouring support atoms, linearly by distance, or kept whole if it lands
    exactly on an atom. For rows where `dones` is true, the next-state distribution
    is ignored and all mass is placed at the clipped reward.

    Args:
        next_distr: `(batch, n_atoms)` next-state atom probabilities (numpy array or tensor).
        rewards: per-sample rewards (tensor or array-like), flattened to `(batch,)`.
        dones: per-sample terminal flags (tensor or array-like), flattened and cast to bool.
        Vmin, Vmax, n_atoms: definition of the evenly spaced support.
        gamma: discount applied to the next-state atoms.

    Returns:
        float32 numpy array `(batch, n_atoms)` holding the projected target distribution.

    Note: adapted from https://github.com/PacktPublishing/Deep-Reinforcement-Learning-Hands-On
    """
    next_distr = _as_numpy(next_distr)
    rewards = _as_numpy(rewards).reshape(-1)
    # Cast to bool: a 0/1 int or float mask would otherwise be used as *row indices* below.
    dones = _as_numpy(dones).reshape(-1).astype(bool)

    batch_size = len(rewards)
    proj_distr = np.zeros((batch_size, n_atoms), dtype=np.float32)
    delta_z = (Vmax - Vmin) / (n_atoms - 1)
    for atom in range(n_atoms):
        tz_j = np.clip(rewards + (Vmin + atom * delta_z) * gamma, Vmin, Vmax)
        b_j, l, u = _support_neighbours(tz_j, Vmin, delta_z)
        # Shifted atom lands exactly on a support point: it keeps all of its mass.
        eq_mask = u == l
        proj_distr[eq_mask, l[eq_mask]] += next_distr[eq_mask, atom]
        # Otherwise split the mass linearly between the lower and upper neighbours.
        ne_mask = ~eq_mask
        proj_distr[ne_mask, l[ne_mask]] += next_distr[ne_mask, atom] * (u - b_j)[ne_mask]
        proj_distr[ne_mask, u[ne_mask]] += next_distr[ne_mask, atom] * (b_j - l)[ne_mask]
    if dones.any():
        # Terminal transitions: the target is a point mass at the (clipped) reward.
        proj_distr[dones] = 0.0
        tz_j = np.clip(rewards[dones], Vmin, Vmax)
        b_j, l, u = _support_neighbours(tz_j, Vmin, delta_z)
        eq_mask = u == l
        # Expand the per-done masks back to full batch-length row masks.
        eq_dones = dones.copy()
        eq_dones[dones] = eq_mask
        if eq_dones.any():
            proj_distr[eq_dones, l[eq_mask]] = 1.0
        ne_mask = ~eq_mask
        ne_dones = dones.copy()
        ne_dones[dones] = ne_mask
        if ne_dones.any():
            proj_distr[ne_dones, l[ne_mask]] = (u - b_j)[ne_mask]
            proj_distr[ne_dones, u[ne_mask]] = (b_j - l)[ne_mask]
    return proj_distr


In [None]:
# export
def loss_fn(a, b):
    """Cross-entropy between log-probabilities and target probabilities.

    Multiplies `a` and `b` elementwise, sums over dim 1 and negates that sum
    (via `-a*b`), then averages over the batch. Because the product is
    symmetric, the argument order does not matter.
    """
    per_sample = (-a * b).sum(dim=1)
    return per_sample.mean()

def calc_dist_target_batch(learn,trainer,s,a,sp,r,d):
    """Build the two inputs of the categorical-DQN loss for one batch.

    Returns `(state_log_sm_v, proj_distr_v)`:
      - `state_log_sm_v`: log-softmax over atoms of `learn.model(s)` at the taken actions `a`.
      - `proj_distr_v`: Bellman-projected target distribution (see `distr_projection`).
        It is built from `learn.target_model` at next states `sp`, following the
        greedy next action.
    `loss_fn` combines the two as a cross-entropy. `trainer` is unused here; it is
    kept for the `TargetDQNTrainer(target_fn=...)` call signature.
    """
    batch_idx = range(s.shape[0])

    # The target side is a fixed label: no autograd graph is needed for it.
    with torch.no_grad():
        next_distr_v, next_qvals_v = learn.target_model.both(sp)
        next_actions = next_qvals_v.max(1)[1].cpu().numpy()
        next_distr = learn.target_model.apply_softmax(next_distr_v).cpu().numpy()
    next_best_distr = next_distr[batch_idx, next_actions]

    # Project the greedy next-state distribution using the Bellman update.
    # NOTE(review): gamma is `learn.discount`, not `learn.discount**learn.n_steps`;
    # if experiences span several steps, confirm this is intended.
    proj_distr = distr_projection(next_best_distr, r, d, Vmin, Vmax, N_ATOMS, learn.discount)

    # Online network: atom logits for the actions actually taken.
    # Assumes `a` holds integer action indices into dim 1 of the model output — TODO confirm.
    distr_v = learn.model(s).to(device=default_device())
    state_action_values = distr_v[batch_idx, a.data]
    state_log_sm_v = F.log_softmax(state_action_values, dim=1)
    proj_distr_v = torch.tensor(proj_distr).to(device=default_device())
    return state_log_sm_v, proj_distr_v

In [None]:
from dataclasses import dataclass
@dataclass
class DistDiscreteAgent(BaseAgent):
    """Discrete-action agent for models that accept `only_qvals=True` (e.g. `DistributionalDQN`).

    Calls `self.model(x, only_qvals=True)` to get per-action Q-values and passes them,
    as a numpy array, to `a_selector` to pick actions.
    `self.model` is presumably a field inherited from `BaseAgent`.
    """
    a_selector:ActionSelector=None
    device:str=None  # device that tensor inputs are moved to before the forward pass
    preprocessor:Callable=default_states_preprocessor
    apply_softmax:bool=False  # if True, softmax the Q-values over dim 1 before selection

    def safe_unbatch(self,o:np.ndarray)->np.ndarray:
        "Drop a leading size-1 batch dim from a multi-dim array; otherwise return `o` unchanged."
        return o[0] if o.shape[0]==1 and len(o.shape)>1 else o

    def split_v(self,v,asl):
        "Hook for models whose output is a tuple; the default returns `(v, asl)` unchanged."
        return v,asl

    @torch.no_grad()
    def __call__(self,x,asl=None,include_batch_dim=True):
        """Select actions for states `x`.

        Args:
            x: states, passed through `preprocessor` when one is set.
            asl: agent states, returned alongside the actions; defaults to zeros shaped like `x`.
            include_batch_dim: if True return `(actions, asl)` as at-least-1-D arrays,
                otherwise return the first action and first agent state.
        """
        # Bug fix: the fallback previously referenced the undefined name `s` (NameError).
        x=self.preprocessor(x) if self.preprocessor is not None else x
        asl=np.zeros(x.shape) if asl is None or len(asl)==0 else asl
        if torch.is_tensor(x):
            x=x.to(self.device)
        v=self.model(x,only_qvals=True)
        if isinstance(v,tuple):v,asl=self.split_v(v,asl)
        if self.apply_softmax:
            v=F.softmax(v,dim=1)
        q=v.data.cpu().numpy()
        al=self.a_selector(q)

        if include_batch_dim:
            al=np.array(al)
            asl=np.array(asl)
            # Promote 0-d results so callers can always index a batch dimension.
            if len(al.shape)==0: al=al.reshape(1,)
            if len(asl.shape)==0: asl=asl.reshape(1,)
            return al,asl

        al=self.safe_unbatch(al).tolist()
        return (al[0],asl[0])

In [None]:
# End-to-end run: train a distributional DQN agent on CartPole.
env='CartPole-v1'
model=DistributionalDQN((4,),2)
agent=DistDiscreteAgent(model=model.to(default_device()),
                        device=default_device(),
                        a_selector=EpsilonGreedyActionSelector())

# Single-worker, unshuffled experience stream; the splitter sends nothing to validation.
dls_kwargs={'bs':1,'num_workers':0,'verbose':False,'indexed':True,'shuffle_train':False}
block=FirstLastExperienceBlock(agent=agent,seed=0,n_steps=1,dls_kwargs=dls_kwargs)
blk=IterableDataBlock(blocks=block,splitter=FuncSplitter(lambda _:False))
dls=blk.dataloaders([env],n=1000,device=default_device())

# NOTE(review): the data block uses n_steps=1 while the learner gets n_steps=3 — confirm which is intended.
replay=ExperienceReplay(sz=100000,bs=32,starting_els=32,max_steps=gym.make(env)._max_episode_steps)
target_trainer=TargetDQNTrainer(target_fn=calc_dist_target_batch)
reward_metric=AvgEpisodeRewardMetric(experience_cls=ExperienceFirstLast,always_extend=True)
learner=TargetDQNLearner(dls,agent=agent,n_steps=3,loss_func=loss_fn,
                         cbs=[EpsilonTracker,replay,target_trainer],
                         metrics=[reward_metric])
learner.fit(47,lr=0.0001,wd=0)

epoch,train_loss,train_avg_episode_r,valid_loss,valid_avg_episode_r,time
0,3.458412,21.428571,,21.428571,00:18
1,3.04431,25.164384,,25.164384,00:18
2,2.57193,30.340659,,30.340659,00:18
3,1.739827,36.06,,36.06,00:18
4,1.252485,41.34,,41.34,00:19
5,0.912612,47.68,,47.68,00:20
6,0.714993,53.64,,53.64,00:18
7,0.612818,59.39,,59.39,00:18
8,0.524101,66.49,,66.49,00:18
9,0.454261,70.3,,70.3,00:17


  warn("Your generator is empty.")


In [None]:
# hide
# Regenerate the library modules from the notebooks' `# export` cells (notebook2script)
# and rebuild the HTML docs (notebook2html).
from nbdev.export import *
from nbdev.export2html import *
notebook2script()
notebook2html()

Converted 00_core.ipynb.
Converted 01_wrappers.ipynb.
Converted 03_basic_agents.ipynb.
Converted 04_learner.ipynb.
Converted 05a_ptan_extend.ipynb.
Converted 05b_data.ipynb.
Converted 05c_async_data.ipynb.
Converted 13_metrics.ipynb.
Converted 14a_actorcritic.sac.ipynb.
Converted 14b_actorcritic.diayn.ipynb.
Converted 14c_actorcritic.dads.ipynb.
Converted 15_actorcritic.a3c_data.ipynb.
Converted 16_actorcritic.a2c.ipynb.
Converted 18_policy_gradient.ppo.ipynb.
Converted 19_policy_gradient.trpo.ipynb.
Converted 20a_qlearning.dqn.ipynb.
Converted 20b_qlearning.dqn_n_step.ipynb.
Converted 20c_qlearning.dqn_target.ipynb.
Converted 20d_qlearning.dqn_double.ipynb.
Converted 20e_qlearning.dqn_noisy.ipynb.
Converted 20f_qlearning.dqn_dueling.ipynb.
Converted 20g_qlearning.dddqn.ipynb.
Converted 20h_qlearning.dist_dqn.ipynb.
Converted index.ipynb.
Converted notes.ipynb.
converting: /opt/project/fastrl/nbs/20h_qlearning.dist_dqn.ipynb
