In [87]:
# default_exp qlearning.dqn_noisy

In [88]:
#export
import torch.nn.utils as nn_utils
from fastai.torch_basics import *
from fastai.data.all import *
from fastai.basics import *
from dataclasses import field,asdict
from typing import List,Any,Dict,Callable
from collections import deque
import gym
import torch.multiprocessing as mp
from torch.optim import *

from fastrl.data import *
from fastrl.async_data import *
from fastrl.basic_agents import *
from fastrl.learner import *
from fastrl.metrics import *
from fastrl.ptan_extension import *
from fastrl.qlearning.dqn import *
from fastrl.qlearning.dqn_target import *

if IN_NOTEBOOK:
    from IPython import display
    import PIL.Image

# Noisy DQN

In [89]:
# export
class NoisyLinear(nn.Linear):
    """Linear layer with independent Gaussian parameter noise (NoisyNet, independent variant).

    Every forward pass draws fresh N(0, 1) noise ``epsilon`` for each weight (and bias element).
    It then applies ``weight + sigma_weight * epsilon_weight`` (and likewise for the bias), where
    ``sigma_*`` are learnable parameters initialised to ``sigma_init``.

    Args:
        in_features, out_features: as for ``nn.Linear``.
        sigma_init: initial value of every element of ``sigma_weight`` / ``sigma_bias``.
        bias: whether to learn an additive (noisy) bias.
    """
    def __init__(self, in_features, out_features, sigma_init=0.017, bias=True):
        super(NoisyLinear, self).__init__(in_features, out_features, bias=bias)
        self.sigma_weight = nn.Parameter(torch.full((out_features, in_features), sigma_init))
        # The noise tensors are buffers, not parameters. They follow .to(device) and get no gradient.
        self.register_buffer("epsilon_weight", torch.zeros(out_features, in_features))
        if bias:
            self.sigma_bias = nn.Parameter(torch.full((out_features,), sigma_init))
            self.register_buffer("epsilon_bias", torch.zeros(out_features))
        # nn.Linear.__init__ has already called reset_parameters(). This second call is kept so
        # the RNG stream, and hence the initial weights under a fixed seed, stays unchanged.
        self.reset_parameters()

    def reset_parameters(self):
        "Initialise weight (and bias) ~ U(-sqrt(3/in_features), sqrt(3/in_features))."
        std = math.sqrt(3 / self.in_features)
        self.weight.data.uniform_(-std, std)
        # With bias=False, self.bias is None. This method also runs inside nn.Linear.__init__,
        # so an unconditional bias init crashes construction.
        if self.bias is not None:
            self.bias.data.uniform_(-std, std)

    def forward(self, x):
        "Resample the noise, then apply the perturbed affine map to ``x``."
        self.epsilon_weight.normal_()
        bias = self.bias
        if bias is not None:
            self.epsilon_bias.normal_()
            bias = bias + self.sigma_bias * self.epsilon_bias.data
        return F.linear(x, self.weight + self.sigma_weight * self.epsilon_weight.data, bias)


class NoisyFactorizedLinear(nn.Linear):
    """Linear layer with factorised Gaussian parameter noise (NoisyNet, factorised variant).

    Instead of one noise sample per weight, it draws ``in_features + out_features`` N(0, 1)
    samples. These are scaled by f(x) = sign(x) * sqrt(|x|), and the weight noise is their outer
    product. ``sigma_weight`` / ``sigma_bias`` are initialised to ``sigma_zero / sqrt(in_features)``.

    Args:
        in_features, out_features: as for ``nn.Linear``.
        sigma_zero: scale used to derive the initial sigma value.
        bias: whether to learn an additive (noisy) bias.
    """
    def __init__(self, in_features, out_features, sigma_zero=0.4, bias=True):
        super(NoisyFactorizedLinear, self).__init__(in_features, out_features, bias=bias)
        sigma_init = sigma_zero / math.sqrt(in_features)
        self.sigma_weight = nn.Parameter(torch.full((out_features, in_features), sigma_init))
        self.register_buffer("epsilon_input", torch.zeros(1, in_features))
        self.register_buffer("epsilon_output", torch.zeros(out_features, 1))
        if bias:
            self.sigma_bias = nn.Parameter(torch.full((out_features,), sigma_init))

    @staticmethod
    def _scale_noise(x):
        "f(x) = sign(x) * sqrt(|x|), applied element-wise."
        return torch.sign(x) * torch.sqrt(torch.abs(x))

    def forward(self, input):
        "Resample the factorised noise, then apply the perturbed affine map to ``input``."
        self.epsilon_input.normal_()
        self.epsilon_output.normal_()
        eps_in = self._scale_noise(self.epsilon_input.data)    # (1, in_features)
        eps_out = self._scale_noise(self.epsilon_output.data)  # (out_features, 1)

        bias = self.bias
        if bias is not None:
            # Flatten to (out_features,) so the noisy bias has the same shape as self.bias.
            # The previous eps_out.t() produced (1, out_features), which F.linear cannot add
            # in place for 1-D inputs.
            bias = bias + self.sigma_bias * eps_out.squeeze(1)
        noise_v = eps_in * eps_out  # outer product -> (out_features, in_features)
        return F.linear(input, self.weight + self.sigma_weight * noise_v, bias)


class NoisyDQN(nn.Module):
    """MLP Q-network: NoisyLinear -> ReLU -> NoisyLinear, one output per action.

    Args:
        input_shape: observation shape; only ``input_shape[0]`` is used, as the input width.
        n_actions: number of discrete actions (output width).
        hidden: width of the hidden layer (previously hard-coded to 512).
        snr_log_every: print ``noisy_layers_sigma_snr()`` on the first forward call and then every
            ``snr_log_every`` calls. ``None`` or ``0`` disables this debug print.
    """
    def __init__(self, input_shape, n_actions, hidden=512, snr_log_every=500):
        super(NoisyDQN, self).__init__()
        # A plain list does not register submodules. The layers are registered once, via
        # self.fc below; the list only gives convenient access in noisy_layers_sigma_snr().
        self.noisy_layers = [
            NoisyLinear(input_shape[0], hidden),
            NoisyLinear(hidden, n_actions)
        ]
        self.fc = nn.Sequential(
            self.noisy_layers[0],
            nn.ReLU(),
            self.noisy_layers[1]
        )
        self.snr_log_every = snr_log_every
        self.counter = 0  # number of forward calls so far; drives the periodic SNR print

    def forward(self, x):
        "Return ``self.fc(x.float())``, optionally printing the noise SNR first."
        if self.snr_log_every and self.counter % self.snr_log_every == 0:
            print(self.noisy_layers_sigma_snr())
        self.counter += 1
        return self.fc(x.float())

    def noisy_layers_sigma_snr(self):
        """For each noisy layer, return RMS(weight) / RMS(sigma_weight) as a Python float.

        Larger values mean the learned noise is small relative to the weights.
        """
        # Diagnostic only: no autograd graph is needed.
        with torch.no_grad():
            return [
                ((layer.weight ** 2).mean().sqrt() / (layer.sigma_weight ** 2).mean().sqrt()).item()
                for layer in self.noisy_layers
            ]



In [100]:
# export
class TargetDQNLearner(AgentLearner):
    "DQN learner trained with MSE loss that also keeps `target_model`, a deep copy of `model` taken at construction."
    # NOTE(review): `fastrl.qlearning.dqn_target` is star-imported at the top of this notebook.
    # If it also defines `TargetDQNLearner`, this definition shadows it. Confirm that is intended.
    def __init__(self,dls,discount=0.99,n_steps=3,target_sync=300,**kwargs):
        # fastcore `store_attr()` stores the named constructor arguments as attributes
        # (self.discount, self.n_steps, self.target_sync, ...) by inspecting this call's frame.
        store_attr()
        # Presumably filled with target Q-values elsewhere (e.g. by a callback); unused here. Verify.
        self.target_q_v=[]
        super().__init__(dls,loss_func=nn.MSELoss(),**kwargs)
        # Snapshot the online network. Syncing it every `target_sync` steps presumably happens
        # in a callback such as TargetDQNTrainer. Confirm.
        self.target_model=deepcopy(self.model)

In [101]:
class TestArgmaxActionSelector(ArgmaxActionSelector):
    "Greedy action selector: argmax of a numpy score array along axis 1 (one action per row)."
    def __call__(self,scores):
        # Only numpy arrays are accepted, so tensors are not silently argmax'ed on the wrong axis.
        assert isinstance(scores,np.ndarray)
        return scores.argmax(axis=1)

In [102]:
env='CartPole-v1'
# Make the NoisyDQN weight init (and the noise draws) reproducible, matching the env seed used below.
torch.manual_seed(0)
# Size the network from the environment's spaces instead of hard-coding CartPole's 4-dim
# observation and 2 actions, so changing `env` cannot silently mis-size the model.
_probe_env=gym.make(env)
model=NoisyDQN(_probe_env.observation_space.shape,_probe_env.action_space.n)
_probe_env.close()

In [103]:
# Shared hyperparameters. Keeping them in one place stops `n_steps` or the batch size from
# drifting between the experience block, the replay buffer and the learner.
N_STEPS,BS=2,32

agent=DiscreteAgent(model=model.to(default_device()),device=default_device(),
                    a_selector=TestArgmaxActionSelector())

block=FirstLastExperienceBlock(agent=agent,seed=0,n_steps=N_STEPS,
                               dls_kwargs={'bs':BS,'num_workers':0,'verbose':False,'indexed':True,'shuffle_train':False})
# `(block)` is just `block`, not a tuple; pass an explicit 1-tuple of blocks.
blk=IterableDataBlock(blocks=(block,),
                      # fastai FuncSplitter sends items where the function is True to validation,
                      # so this presumably keeps everything in the training split.
                      splitter=FuncSplitter(lambda x:False),
                     )
dls=blk.dataloaders([env],n=BS*100,device=default_device())

# Read the episode cap from the registered env spec. This avoids instantiating (and never
# closing) a throwaway env just to read its private `_max_episode_steps`.
max_steps=gym.spec(env).max_episode_steps
learner=TargetDQNLearner(dls,agent=agent,n_steps=N_STEPS,wd_bn_bias=True,cbs=[
                                        ExperienceReplay(sz=100000,bs=BS,starting_els=BS,max_steps=max_steps),
                                        TargetDQNTrainer],metrics=[AvgEpisodeRewardMetric(experience_cls=ExperienceFirstLast)])

[29.51614761352539, 2.6565346717834473]


In [104]:
N_EPOCHS,LR=6,1e-4
learner.fit(N_EPOCHS,lr=LR,wd=0)  # wd=0: weight decay disabled

epoch,train_loss,train_avg_episode_r,valid_loss,valid_avg_episode_r,time
0,1.383956,10.0,,10.0,00:14
1,0.772038,10.0,,10.0,00:13
2,0.521154,10.0,,10.0,00:14
3,1.972121,10.0,,10.0,00:14
4,1.743856,10.0,00:06,,


[29.503910064697266, 2.6564838886260986]
[29.49394989013672, 2.656461715698242]
[29.489553451538086, 2.6570115089416504]
[29.491273880004883, 2.658151388168335]
[29.49207878112793, 2.659501075744629]
[29.49216651916504, 2.6610922813415527]
[29.49304962158203, 2.6626505851745605]
[29.487321853637695, 2.6641170978546143]
[29.48851203918457, 2.6653125286102295]
[29.486324310302734, 2.6662514209747314]
[29.48097038269043, 2.6675326824188232]
[29.47698974609375, 2.6693923473358154]
[29.47264862060547, 2.670992851257324]
[29.468135833740234, 2.672858476638794]
[29.463985443115234, 2.675305128097534]
[29.468393325805664, 2.677896738052368]
[29.473608016967773, 2.680013656616211]
[29.474868774414062, 2.6814517974853516]
[29.47462272644043, 2.6830427646636963]
[29.477842330932617, 2.6840872764587402]
[29.47919464111328, 2.6855201721191406]
[29.48183822631836, 2.6881930828094482]
[29.478355407714844, 2.6903135776519775]
[29.469249725341797, 2.6928770542144775]
[29.468219757080078, 2.695256710052

KeyboardInterrupt: 

In [None]:
# hide
# nbdev housekeeping, presumably: `notebook2script` writes the `# export` cells to the module
# named by `# default_exp`, and `notebook2html` renders the docs page.
from nbdev.export import *
from nbdev.export2html import *
notebook2script()
notebook2html()