In [53]:
# default_exp qlearning.dqn_noisy

In [54]:
#export
import torch.nn.utils as nn_utils
from fastai.torch_basics import *
from fastai.data.all import *
from fastai.basics import *
from dataclasses import field,asdict
from typing import List,Any,Dict,Callable
from collections import deque
import gym
import torch.multiprocessing as mp
from torch.optim import *

from fastrl.data import *
from fastrl.async_data import *
from fastrl.basic_agents import *
from fastrl.learner import *
from fastrl.metrics import *
from fastrl.ptan_extension import *
from fastrl.qlearning.dqn import *
from fastrl.qlearning.dqn_target import *

if IN_NOTEBOOK:
    from IPython import display
    import PIL.Image

# Noisy DQN

In [73]:
# export
class NoisyLinear(nn.Linear):
    """Linear layer whose weights and bias are perturbed by independent Gaussian noise.

    Every forward pass resamples one standard-normal value per weight (and per
    bias element) and adds it, scaled by the learnable `sigma_weight` /
    `sigma_bias`, to the ordinary `nn.Linear` parameters.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        sigma_init: initial value used to fill `sigma_weight` and `sigma_bias`.
        bias: when False the layer has no bias, no `sigma_bias` and no
            `epsilon_bias` buffer.
    """

    def __init__(self, in_features, out_features, sigma_init=0.017, bias=True):
        super(NoisyLinear, self).__init__(in_features, out_features, bias=bias)
        # Learnable noise scale, one entry per weight.
        self.sigma_weight = nn.Parameter(torch.full((out_features, in_features), sigma_init))
        # Noise sample storage; a buffer so it follows the module across devices.
        self.register_buffer("epsilon_weight", torch.zeros(out_features, in_features))
        if bias:
            self.sigma_bias = nn.Parameter(torch.full((out_features,), sigma_init))
            self.register_buffer("epsilon_bias", torch.zeros(out_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw `weight` (and `bias`, if present) uniformly from [-sqrt(3/in_features), sqrt(3/in_features)]."""
        std = math.sqrt(3 / self.in_features)
        self.weight.data.uniform_(-std, std)
        # `nn.Linear.__init__` invokes this override too, so it must cope with
        # `bias=False`, where `self.bias` is None.
        if self.bias is not None:
            self.bias.data.uniform_(-std, std)

    def forward(self, input):
        """Resample the noise in place and apply the perturbed affine map to `input`."""
        self.epsilon_weight.normal_()
        bias = self.bias
        if bias is not None:
            self.epsilon_bias.normal_()
            bias = bias + self.sigma_bias * self.epsilon_bias.data
        # `.data` is kept exactly as in the original code.
        noisy_weight = self.weight + self.sigma_weight * self.epsilon_weight.data
        return F.linear(input, noisy_weight, bias)


class NoisyFactorizedLinear(nn.Linear):
    """Linear layer with factorized Gaussian noise on its weights and bias.

    Instead of one noise value per weight, each forward pass samples one
    vector of `in_features` values and one of `out_features` values, transforms
    both with `sign(x) * sqrt(|x|)`, and uses their outer product as the weight
    noise. The transformed output vector is reused as the bias noise.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        sigma_zero: numerator of the initial noise scale; `sigma_weight` and
            `sigma_bias` are filled with `sigma_zero / sqrt(in_features)`.
        bias: when False the layer has no bias and no `sigma_bias`.
    """

    def __init__(self, in_features, out_features, sigma_zero=0.4, bias=True):
        super(NoisyFactorizedLinear, self).__init__(in_features, out_features, bias=bias)
        sigma_init = sigma_zero / math.sqrt(in_features)
        self.sigma_weight = nn.Parameter(torch.full((out_features, in_features), sigma_init))
        # Row vector and column vector so that their product broadcasts to
        # the (out_features, in_features) weight shape.
        self.register_buffer("epsilon_input", torch.zeros(1, in_features))
        self.register_buffer("epsilon_output", torch.zeros(out_features, 1))
        if bias:
            self.sigma_bias = nn.Parameter(torch.full((out_features,), sigma_init))

    def forward(self, input):
        """Resample the factorized noise in place and apply the perturbed affine map to `input`."""
        self.epsilon_input.normal_()
        self.epsilon_output.normal_()

        func = lambda x: torch.sign(x) * torch.sqrt(torch.abs(x))
        eps_in = func(self.epsilon_input.data)
        eps_out = func(self.epsilon_output.data)

        bias = self.bias
        if bias is not None:
            # Flatten the (out_features, 1) column to (out_features,) so the
            # noisy bias keeps the 1-D shape `F.linear` documents; the
            # transposed (1, out_features) form only worked through
            # broadcasting and breaks for unbatched 1-D input.
            bias = bias + self.sigma_bias * eps_out.squeeze(1)
        noise_v = torch.mul(eps_in, eps_out)
        return F.linear(input, self.weight + self.sigma_weight * noise_v, bias)


class NoisyDQN(nn.Module):
    """Two-layer fully connected Q-network built from `NoisyLinear` layers.

    Args:
        input_shape: indexable whose first element is the number of input features.
        n_actions: number of outputs of the final layer.
    """

    def __init__(self, input_shape, n_actions):
        super().__init__()

        hidden_layer = NoisyLinear(input_shape[0], 512)
        output_layer = NoisyLinear(512, n_actions)
        # Plain Python list kept for `noisy_layers_sigma_snr`; the layers are
        # registered as sub-modules through `self.fc` below.
        self.noisy_layers = [hidden_layer, output_layer]
        self.fc = nn.Sequential(hidden_layer, nn.ReLU(), output_layer)

    def forward(self, x):
        """Cast `x` to float and run it through the noisy MLP."""
        return self.fc(x.float())

    def noisy_layers_sigma_snr(self):
        """Return, per noisy layer, RMS(weight) / RMS(sigma_weight) as a Python float."""
        ratios = []
        for layer in self.noisy_layers:
            weight_rms = (layer.weight ** 2).mean().sqrt()
            sigma_rms = (layer.sigma_weight ** 2).mean().sqrt()
            ratios.append((weight_rms / sigma_rms).item())
        return ratios



In [56]:
# export
class TargetDQNLearner(AgentLearner):
    """`AgentLearner` trained with `nn.MSELoss` that also holds a deep copy of its model as `target_model`.

    NOTE(review): `fastrl.qlearning.dqn_target` is star-imported at the top of
    this notebook; if that module also defines `TargetDQNLearner`, this
    definition shadows it - confirm the duplication is intended.
    """
    def __init__(self,dls,discount=0.99,n_steps=3,target_sync=300,**kwargs):
        # Presumably fastcore's `store_attr`, saving the `__init__` arguments
        # (dls, discount, n_steps, target_sync) on `self` - TODO confirm.
        store_attr()
        # Starts empty; nothing in this cell shows who fills it.
        self.target_q_v=[]
        super().__init__(dls,loss_func=nn.MSELoss(),**kwargs)
        # `self.model` is not set in this class, so it is assumed to be set by
        # the parent constructor; the copy is taken once, right after it runs.
        self.target_model=deepcopy(self.model)

In [74]:
class TestArgmaxActionSelector(ArgmaxActionSelector):
    """Selector that asserts its input is a NumPy array and picks the arg-max along axis 1."""

    def __call__(self,scores):
        assert isinstance(scores,np.ndarray)
        best_actions = np.argmax(scores, axis=1)
        return best_actions

In [76]:
env='CartPole-v1'

# Noisy Q-network with a (4,) input shape and 2 actions
# (presumably the observation/action sizes of the environment above).
model=NoisyDQN((4,),2)
agent=DiscreteAgent(
    model=model.to(default_device()),
    device=default_device(),
    a_selector=TestArgmaxActionSelector(),
)

# Experience source and data loaders driven by the agent.
block=FirstLastExperienceBlock(
    agent=agent,
    seed=0,
    n_steps=2,
    dls_kwargs={
        'bs':32,
        'num_workers':0,
        'verbose':False,
        'indexed':True,
        'shuffle_train':False,
    },
)
blk=IterableDataBlock(
    blocks=block,
    splitter=FuncSplitter(lambda x:False),
)
dls=blk.dataloaders([env],n=32*100,device=default_device())

# Learner with experience replay and the target-network trainer callback.
learner=TargetDQNLearner(
    dls,
    agent=agent,
    n_steps=1,
    wd_bn_bias=True,
    cbs=[
        ExperienceReplay(
            sz=50000,
            bs=32,
            starting_els=32,
            max_steps=gym.make(env)._max_episode_steps,
        ),
        TargetDQNTrainer,
    ],
    metrics=[AvgEpisodeRewardMetric(experience_cls=ExperienceFirstLast)],
)
learner.fit(6,lr=0.0001,wd=0)

epoch,train_loss,train_avg_episode_r,valid_loss,valid_avg_episode_r,time
0,1.222407,9.400833,,9.400833,00:05
1,0.604335,9.008333,,9.008333,00:05
2,0.377277,9.0325,,9.0325,00:05
3,1.486743,9.0025,,9.0025,00:05
4,1.073201,9.0,,9.0,00:05
5,0.809176,9.0,00:03,,


KeyboardInterrupt: 

In [1]:
# hide
# nbdev build step: the output captured below lists the project's notebooks
# as "Converted ..." followed by "converting: ..." lines.
from nbdev.export import *
from nbdev.export2html import *
# Presumably writes the `#export` cells out as library modules - TODO confirm.
notebook2script()
# Presumably renders the notebooks to HTML documentation - TODO confirm.
notebook2html()

Converted 00_core.ipynb.
Converted 01_wrappers.ipynb.
Converted 03_basic_agents.ipynb.
Converted 04_learner.ipynb.
Converted 05a_ptan_extend.ipynb.
Converted 05b_async_data.ipynb.
Converted 05c_data.ipynb.
Converted 13_metrics.ipynb.
Converted 14a_actorcritic.sac.ipynb.
Converted 14b_actorcritic.diayn.ipynb.
Converted 15_actorcritic.a3c_data.ipynb.
Converted 16_actorcritic.a2c.ipynb.
Converted 17_actorcritc.v1.dads.ipynb.
Converted 18_policy_gradient.ppo.ipynb.
Converted 19_policy_gradient.trpo.ipynb.
Converted 20a_qlearning.dqn.ipynb.
Converted 20b_qlearning.dqn_n_step.ipynb.
Converted 20c_qlearning.dqn_target.ipynb.
Converted 20d_qlearning.dqn_double.ipynb.
Converted 20e_qlearning.dqn_noisy.ipynb.
Converted index.ipynb.
Converted notes.ipynb.


converting: /opt/project/fastrl/nbs/20d_qlearning.dqn_double.ipynb
converting: /opt/project/fastrl/nbs/20c_qlearning.dqn_target.ipynb
converting: /opt/project/fastrl/nbs/20e_qlearning.dqn_noisy.ipynb
