In [10]:
#hide
#skip
%config Completer.use_jedi = False
# upgrade fastrl on colab
! [ -e /content ] && pip install -Uqq fastrl['dev'] pyvirtualdisplay && \
                     apt-get install -y xvfb python-opengl > /dev/null 2>&1 
# NOTE: IF YOU SEE VERSION ERRORS, IT IS SAFE TO IGNORE THEM. COLAB IS BEHIND IN SOME OF THE PACKAGE VERSIONS

In [11]:
# hide
from fastcore.imports import in_colab
# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbdev.showdoc import *
    from nbdev.imports import *
    if not os.environ.get("IN_TEST", None):
        assert IN_NOTEBOOK
        assert not IN_COLAB
        assert IN_IPYTHON
else:
    # Virutual display is needed for colab
    from pyvirtualdisplay import Display
    display = Display(visible=0, size=(400, 300))
    display.start()

In [12]:
# default_exp agents.dqn.dueling

In [13]:
# export
# Python native modules
# Third party libs
from torch.nn import *
from fastcore.all import *
from fastai.learner import *
from fastai.torch_basics import *
from fastai.torch_core import *
from fastai.callback.all import *
# Local modules
from fastrl.data.block_simple import *
from fastrl.data.gym import *
from fastrl.agent import *
from fastrl.core import *
from fastrl.agents.dqn.core import *
from fastrl.agents.dqn.targets import *
from fastrl.agents.dqn.double import *

# Dueling DQN
> DQN using a split head for comparing the davantage of different actions.

In [14]:
# export
class DuelingBlock(nn.Module):
    def __init__(self,n_actions,hidden=512,lin_cls=nn.Linear):
        super().__init__()
        self.val=lin_cls(hidden,1)
        self.adv=lin_cls(hidden,n_actions)

    def forward(self,xi):
        val,adv=self.val(xi),self.adv(xi)
        xi=val.expand_as(adv)+(adv-adv.mean()).squeeze(0)
        return xi
    
class DuelingDQN(DQN):
    def __init__(self,state_sz:int,n_actions,hidden=512):
        super(DQN,self).__init__()
        self.layers=nn.Sequential(
            nn.Linear(state_sz,hidden),
            nn.ReLU(),
            DuelingBlock(n_actions,hidden)
        )

In [15]:
dqn=DuelingDQN(4,2)

agent=Agent(dqn,cbs=[ArgMaxFeed,DiscreteEpsilonRandomSelect])
source=Source(cbs=[GymLoop('CartPole-v1',agent,steps_count=3,seed=0,
                           steps_delta=1),FirstLast])
dls=SourceDataBlock().dataloaders([source],n=1000,bs=1,num_workers=0)

learn=Learner(dls,agent,loss_func=MSELoss(),
              cbs=[ExperienceReplay(bs=32,max_sz=100000,warmup_sz=32),DQNTargetTrainer(n_steps=3)],
              metrics=[Reward,Epsilon,NEpisodes])

Could not do one pass in your dataloader, there is something wrong in it


In [16]:
slow=False
learn.fit(3 if not slow else 47,lr=0.0001,wd=0)

epoch,train_loss,train_reward,train_epsilon,train_n_episodes,valid_loss,valid_reward,valid_epsilon,valid_n_episodes,time
0,3.859209,23.11,0.5996,988,00:30,,,,
1,7.212608,24.52,0.2,2307,00:27,,,,
2,9.551857,30.65,0.2,3446,00:28,,,,


## Double Dueling DQN (DDQN)

In [17]:
dqn=DuelingDQN(4,2)

agent=Agent(dqn,cbs=[ArgMaxFeed,DiscreteEpsilonRandomSelect])
source=Source(cbs=[GymLoop('CartPole-v1',agent,steps_count=3,seed=0,
                           steps_delta=1),FirstLast])
dls=SourceDataBlock().dataloaders([source],n=1000,bs=1,num_workers=0)

learn=Learner(dls,agent,loss_func=MSELoss(),
              cbs=[ExperienceReplay(bs=32,max_sz=100000,warmup_sz=32),DoubleDQNTrainer(n_steps=3)],
              metrics=[Reward,Epsilon])

Could not do one pass in your dataloader, there is something wrong in it


In [None]:
slow=False
learn.fit(3 if not slow else 47,lr=0.0001,wd=0)

epoch,train_loss,train_reward,train_epsilon,valid_loss,valid_reward,valid_epsilon,time
0,4.161854,20.26,0.5996,00:28,,,


If you want to run this using multiple processess, the multiprocessing code looks like below.
However you will not be able to run this in a notebook, instead add this to a py file and run it from there.

> Warning: There is a bug in data block that prevents this. Should be a simple fix.

In [None]:
# hide
from fastcore.imports import in_colab

# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbdev.export import *
    from nbdev.export2html import *
    from nbdev.cli import make_readme
    make_readme()
    notebook2script()
    notebook2html()