In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class VanillaSACPolicy(nn.Module):
    def __init__(self, state_dim, action_dim, lr, device):
        super(VanillaSACPolicy, self).__init__()

        self.state_dim = state_dim
        self.action_dim = action_dim
        self.actor_lr = lr
        self.device = device


        self.fc1 = nn.Linear(state_dim,256).to(device)
        self.fc2 = nn.Linear(256,256).to(device)
        self.fc3 = nn.Linear(256,action_dim*2).to(device)

        nn.init.uniform_(tensor=self.fc3.weight, a = -3e-3, b=3e-3)
        nn.init.uniform_(tensor=self.fc3.bias, a=-3e-3, b=3e-3)

        self.optimizer = optim.Adam(self.parameters(),lr)


    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = torch.tanh(self.fc3(x))

        x = x.view(-1, self.action_dim, 2)

        # It should return mu, sigma

        return x

In [2]:
t = VanillaSACPolicy(1,1,1e-4,torch.device("cuda"))

In [7]:
k = t.forward(torch.FloatTensor([0.0]).to(torch.device("cuda")))

In [14]:
kk = torch.sum(k)

In [15]:
kk.backward()

In [20]:
t.fc3.weight.grad[0][:20]

tensor([0.0000, 0.3183, 0.3409, 0.5602, 0.0139, 0.0000, 0.0000, 0.1902, 0.0000,
        0.0000, 0.1893, 0.0000, 0.4379, 0.2406, 0.0000, 0.1619, 0.0000, 0.2684,
        0.0000, 0.4259], device='cuda:0')

In [45]:
with torch.no_grad():
    a = t.forward(torch.FloatTensor([1.0]).to(torch.device("cuda")))
    aa = torch.mean(a)
    aa.backward()

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

In [46]:
a = t.forward(torch.FloatTensor([1.0]).to(torch.device("cuda")))
aa = torch.mean(a)
aa.backward()

In [47]:
t.fc3.weight.grad[0][:20]

tensor([0.0000, 0.4626, 0.6405, 0.8171, 0.0139, 0.0000, 0.0000, 0.4331, 0.0000,
        0.0350, 0.2309, 0.0000, 0.5991, 0.5030, 0.0884, 0.3874, 0.1658, 0.3437,
        0.0000, 0.8348], device='cuda:0')

In [55]:
a = t.forward(torch.FloatTensor([1.0]).to(torch.device("cuda")))
aa = torch.mean(a, dim=2,keepdim=True)

In [57]:
aa.shape

torch.Size([1, 1, 1])

In [54]:
a.shape

torch.Size([1, 1, 2])

In [52]:
aa.grad_fn

<MeanBackward1 at 0x7f85551be2b0>

In [50]:
k = aa.detach()

In [51]:
k.grad_fn