In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ActorNetwork(nn.Module):
    """Two-hidden-layer MLP mapping a state to an (unbounded) action vector.

    Architecture: Linear(n_input, n_features) -> ReLU
                  -> Linear(n_features, n_features) -> ReLU
                  -> Linear(n_features, n_output)
    Layer weights use Xavier-uniform init (ReLU gain for the hidden layers,
    linear gain for the output layer); biases keep nn.Linear's default init.

    Args:
        input_shape: number of input features (used directly as an int).
        output_shape: number of outputs (used directly as an int).
        n_features: width of both hidden layers.
        **kwargs: accepted and ignored (presumably for compatibility with a
            framework that passes extra keyword arguments -- TODO confirm).
    """

    def __init__(self, input_shape, output_shape, n_features, **kwargs):
        super().__init__()

        n_input = input_shape
        n_output = output_shape

        self._h1 = nn.Linear(n_input, n_features)
        self._h2 = nn.Linear(n_features, n_features)
        self._h3 = nn.Linear(n_features, n_output)

        nn.init.xavier_uniform_(self._h1.weight,
                                gain=nn.init.calculate_gain('relu'))
        nn.init.xavier_uniform_(self._h2.weight,
                                gain=nn.init.calculate_gain('relu'))
        nn.init.xavier_uniform_(self._h3.weight,
                                gain=nn.init.calculate_gain('linear'))

    @property
    def n_params(self):
        """Total number of scalar parameters (all weights and biases), as an int."""
        return sum(p.numel() for p in self.parameters())

    @property
    def device(self):
        """Device of the first parameter (all parameters are assumed co-located)."""
        return next(self.parameters()).device

    def _update_weights(self, new_weights, tau=1):
        """Load every parameter from one flat vector, optionally as a soft update.

        The vector is consumed in ``self.parameters()`` order (registration
        order: _h1.weight, _h1.bias, _h2.weight, _h2.bias, _h3.weight,
        _h3.bias); each slice is reshaped to its parameter's shape and cast to
        the parameter's dtype, so e.g. a float64 numpy array does not turn the
        network into float64.

        Args:
            new_weights: array-like or tensor with exactly ``n_params`` values
                (any shape; it is flattened).
            tau: blending factor. ``tau == 1`` (default) copies the new values
                in; otherwise each parameter becomes
                ``(1 - tau) * old + tau * new``.

        Raises:
            ValueError: if ``new_weights`` does not hold exactly ``n_params``
                values.
        """
        # In-place writes to leaf tensors that require grad are only allowed
        # with autograd disabled.
        with torch.no_grad():
            # as_tensor avoids the copy-construct warning torch.tensor() emits
            # when handed an existing tensor.
            flat = torch.as_tensor(new_weights, device=self.device).reshape(-1)
            expected = self.n_params
            if flat.numel() != expected:
                raise ValueError(
                    f'new_weights has {flat.numel()} elements, expected {expected}')

            idx = 0
            for param in self.parameters():
                n_elems = param.numel()
                chunk = flat[idx:idx + n_elems].reshape(param.shape).to(param.dtype)
                if tau == 1:
                    # Plain copy: avoids 0 * inf = nan if old weights diverged.
                    param.copy_(chunk)
                else:
                    # param + tau * (chunk - param) == (1 - tau) * param + tau * chunk
                    param.lerp_(chunk, tau)
                idx += n_elems

    def forward(self, state):
        """Compute the action for a batch of states.

        ``state`` is squeezed along dim 1 (presumably to drop a singleton axis
        of a (batch, 1, n_input) input -- TODO confirm) and cast to float32
        before the first layer.
        """
        features1 = F.relu(self._h1(torch.squeeze(state, 1).float()))
        features2 = F.relu(self._h2(features1))
        a = self._h3(features2)

        return a

In [23]:
# Smoke test: build a small actor, run a forward pass on a random batch, and
# check that autograd yields one gradient per parameter tensor.
torch.manual_seed(0)  # make the random init, states and noise reproducible

actr = ActorNetwork(input_shape=3, output_shape=1, n_features=10)
states = torch.rand(10, 3)        # batch of 10 states with 3 features each
output = torch.sum(actr(states))  # scalar, so autograd.grad needs no grad_outputs

# One uniform sample per scalar parameter. Not used in this cell (presumably
# meant for perturbing weights via actr._update_weights -- TODO confirm).
noise_vector = torch.rand(actr.n_params)

grad = torch.autograd.grad(output, actr.parameters())
for i in grad:
    print(i.shape)

# TODO: _update_target(self, current, noise_vector) -- planned helper, not defined here.

torch.Size([10, 3])
torch.Size([10])
torch.Size([10, 10])
torch.Size([10])
torch.Size([1, 10])
torch.Size([1])


In [28]:
# 3x10 grid of 1..30 in the default float dtype. torch.arange's end is
# exclusive (unlike the deprecated torch.range), hence 31; float bounds give a
# float result, and no torch.tensor(...) re-wrap is needed.
grid = torch.arange(1., 31.).reshape(3, 10)
grid

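UserWarning: torch.range is deprecated and will be removed in a future release because its behavior is inconsistent with Python's range builtin. Instead, use torch.arange, which produces values in [start, end).
  torch.tensor(torch.range(1,30)).reshape(3,10)
UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  torch.tensor(torch.range(1,30)).reshape(3,10)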


tensor([[ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.],
        [11., 12., 13., 14., 15., 16., 17., 18., 19., 20.],
        [21., 22., 23., 24., 25., 26., 27., 28., 29., 30.]])