In [None]:
#hide
#skip
# Colab-only setup cell (the two lines above are nbdev cell directives).
# Disable jedi-based tab completion in IPython (presumably because it was slow/broken here — TODO confirm).
%config Completer.use_jedi = False
# upgrade fastrl on colab
# The `[ -e /content ]` test gates everything: the installs only run when /content exists (colab's working dir).
# Only the apt-get output is silenced; the redirect applies to the last command in the chain.
! [ -e /content ] && pip install -Uqq fastrl['dev'] pyvirtualdisplay && \
                     apt-get install -y xvfb python-opengl > /dev/null 2>&1 
# NOTE: IF YOU SEE VERSION ERRORS, IT IS SAFE TO IGNORE THEM. COLAB IS BEHIND IN SOME OF THE PACKAGE VERSIONS

In [None]:
# hide
import os
from fastcore.imports import in_colab
# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbdev.showdoc import *
    from nbdev.imports import *
    # Sanity-check that we are in a (non-colab) Jupyter/IPython notebook.
    # Skipped when the IN_TEST environment variable is set (presumably by the test runner).
    if not os.environ.get("IN_TEST", None):
        assert IN_NOTEBOOK
        assert not IN_COLAB
        assert IN_IPYTHON
else:
    # Virtual display is needed for colab (there is no physical X display there).
    from pyvirtualdisplay import Display
    # Named `virtual_display` rather than `display` so IPython's built-in `display()`
    # function is not shadowed for the rest of the notebook.
    virtual_display = Display(visible=0, size=(400, 300))
    virtual_display.start()

In [None]:
# default_exp agents.dqn.dueling

In [2]:
# export
# Python native modules
# Third party libs
from torch.nn import *
from fastcore.all import *
from fastai.learner import *
from fastai.torch_basics import *
from fastai.torch_core import *
from fastai.callback.all import *
# Local modules
from fastrl.data.gym import *
from fastrl.agent import *
from fastrl.core import *
from fastrl.agents.dqn.core import *
from fastrl.agents.dqn.targets import *
from fastrl.agents.dqn.double import *

# Dueling DQN
> DQN with a split ("dueling") head that estimates the state value and the advantage of each action separately, then combines them into Q-values.

In [None]:
# export
class DuelingBlock(nn.Module):
    "Dueling head: combines a state-value stream V(s) and an advantage stream A(s,a) into Q-values."
    def __init__(self,n_actions,hidden=512,lin_cls=nn.Linear):
        super().__init__()
        self.val=lin_cls(hidden,1)          # V(s): a single value per input
        self.adv=lin_cls(hidden,n_actions)  # A(s,a): one advantage per action

    def forward(self,xi):
        "Return `V + (A - mean_a A)`, with the advantage mean taken per sample over the action (last) dim."
        val,adv=self.val(xi),self.adv(xi)
        # Centre the advantages over the action dimension only. A plain `adv.mean()` would also
        # average across the batch, coupling every sample's Q-values to the rest of the batch.
        # `val` (..., 1) broadcasts against the centred advantages (..., n_actions).
        return val+(adv-adv.mean(dim=-1,keepdim=True))
    
class DuelingDQN(DQN):
    "Q-network with one hidden ReLU layer followed by a `DuelingBlock` output head."
    def __init__(self,state_sz:int,n_actions,hidden=512):
        # Start the MRO lookup *after* DQN, i.e. deliberately bypass `DQN.__init__` so only the
        # layers built below are registered (presumably DQN.__init__ builds its own `layers` — TODO confirm).
        super(DQN,self).__init__()
        body=[nn.Linear(state_sz,hidden),nn.ReLU()]
        head=DuelingBlock(n_actions,hidden)
        self.layers=nn.Sequential(*body,head)

In [None]:
# Dueling DQN on CartPole-v1 (4-dim observation, 2 actions), trained with `DQNTargetTrainer`.
# `n_steps=3` is kept equal to the GymLoop's `steps_count=3` (presumably they must agree — TODO confirm).
dqn=DuelingDQN(4,2)
agent=Agent(dqn,cbs=[ArgMaxFeed,DiscreteEpsilonRandomSelect])
source=Source(cbs=[
    GymLoop('CartPole-v1',agent,steps_count=3,seed=0,steps_delta=1),
    FirstLast,
])
dls=SourceDataBlock().dataloaders([source],n=1000,bs=1,num_workers=0)

learn=Learner(
    dls,agent,
    loss_func=MSELoss(),
    cbs=[
        ExperienceReplay(bs=32,max_sz=100000,warmup_sz=32),
        DQNTargetTrainer(n_steps=3),
    ],
    metrics=[Reward,Epsilon,NEpisodes],
)

In [None]:
slow=False  # set to True for the long 47-epoch run
learn.fit(47 if slow else 3,lr=0.0001,wd=0)

epoch,train_loss,train_reward,train_epsilon,valid_loss,valid_reward,valid_epsilon,time
0,0.200769,22.49,0.8,00:20,,,
1,0.461518,20.48,0.6,00:20,,,
2,0.736573,28.13,0.4,00:20,,,


## Double Dueling DQN (DDQN)

In [None]:
# Same dueling network and CartPole-v1 source, but trained with `DoubleDQNTrainer`
# (presumably double-DQN targets) instead of `DQNTargetTrainer`.
dqn=DuelingDQN(4,2)
agent=Agent(dqn,cbs=[ArgMaxFeed,DiscreteEpsilonRandomSelect])
source=Source(cbs=[
    GymLoop('CartPole-v1',agent,steps_count=3,seed=0,steps_delta=1),
    FirstLast,
])
dls=SourceDataBlock().dataloaders([source],n=1000,bs=1,num_workers=0)

learn=Learner(
    dls,agent,
    loss_func=MSELoss(),
    cbs=[
        ExperienceReplay(bs=32,max_sz=100000,warmup_sz=32),
        DoubleDQNTrainer(n_steps=3),
    ],
    metrics=[Reward,Epsilon],
)

In [None]:
slow=False  # set to True for the long 47-epoch run
learn.fit(47 if slow else 3,lr=0.0001,wd=0)

epoch,train_loss,train_reward,train_epsilon,valid_loss,valid_reward,valid_epsilon,time
0,1.327045,31.97,0.2,00:21,,,
1,1.571332,29.19,0.2,00:21,,,
2,2.170591,33.83,0.2,00:20,,,
3,3.319467,35.03,0.2,00:20,,,
4,3.269851,36.74,0.2,00:20,,,
5,3.655275,40.04,0.2,00:23,,,
6,5.695289,40.41,0.2,00:25,,,
7,5.243069,43.33,0.2,00:27,,,
8,6.048951,47.99,0.2,00:27,,,
9,5.926082,46.43,0.2,00:27,,,


If you want to run this using multiple processes, the multiprocessing code looks like the cell below.
However, you will not be able to run it in a notebook; instead, add it to a `.py` file, set `run=True`, and run it from there.

> Warning: There is a bug in the data block that currently prevents this from working. It should be a simple fix.

In [None]:
# Set `run=True` inside a .py script to train with multiple dataloader worker processes.
run=False
# The `__name__=='__main__'` guard is required with the 'spawn' start method: every worker
# process re-imports the main module, and without the guard each worker would re-run this
# whole training script. In a notebook `__name__` is always '__main__', so nothing changes there.
if run and __name__=='__main__':
    from torch.nn import *
    import torch.multiprocessing as mp
    from fastcore.all import *
    from fastai.learner import *
    from fastai.torch_basics import *
    from fastai.torch_core import *
    from fastai.callback.all import *
    # Local modules
    from fastrl.data.block import *
    from fastrl.agent import *
    from fastrl.core import *
    from fastrl.agents.dqn.core import *
    from fastrl.agents.dqn.targets import *
    from fastrl.agents.dqn.double import *
    from fastrl.agents.dqn.dueling import *

    try:
        mp.set_start_method('spawn')
    except RuntimeError:
        # The start method can only be set once per process; if it already is, keep it.
        pass


    dqn=DuelingDQN(4,2)
    # Move the network's tensors into shared memory so they can be shared with worker processes.
    dqn.share_memory()

    agent=Agent(dqn,cbs=[ArgMaxFeed,DiscreteEpsilonRandomSelect])
    source=Src('CartPole-v1',agent,seed=0,steps_count=1,n_envs=1,steps_delta=1,cbs=[GymSrc,FirstLast])

    dls=SourceDataBlock(
        blocks=SourceBlock(source)
    ).dataloaders([source],n=1000,bs=1,num_workers=2)

    learn=Learner(dls,agent,loss_func=MSELoss(),
                  cbs=[ExperienceReplay(bs=32,max_sz=100000,warmup_sz=32),DoubleDQNTrainer],
                  metrics=[Reward,Epsilon])

    full=True
    learn.fit(47 if full else 3,lr=0.0001,wd=0)

In [None]:
# hide
from fastcore.imports import in_colab

# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbdev.export import *
    from nbdev.export2html import *
    from nbdev.cli import make_readme
    # Regenerate README.md from index.ipynb (see the "converting ... index.ipynb to README.md" output).
    make_readme()
    # Export the library code from every notebook in the project ("Converted ...ipynb" lines).
    notebook2script()
    # Convert notebooks to the HTML docs (presumably just the changed ones, per the single "converting:" line).
    notebook2html()

converting /home/fastrl_user/fastrl/nbs/index.ipynb to README.md
Converted 00_core.ipynb.
Converted 00_nbdev_extension.ipynb.
Converted 03_callback.core.ipynb.
Converted 04_agent.ipynb.
Converted 05_data.block.ipynb.
Converted 05_data.test_async.ipynb.
Converted 10a_agents.dqn.core.ipynb.
Converted 10b_agents.dqn.targets.ipynb.
Converted 10c_agents.dqn.double.ipynb.
Converted 10d_agents.dqn.dueling.ipynb.
Converted 20_test_utils.ipynb.
Converted index.ipynb.
Converted nbdev_template.ipynb.
Converted tmp.ipynb.
converting: /home/fastrl_user/fastrl/nbs/10d_agents.dqn.dueling.ipynb
