-
Notifications
You must be signed in to change notification settings - Fork 212
Closed
Description
Python:
import numpy as np
import torch as th
from torch import nn
class NatureCNN(nn.Module):
def __init__(self, features_dim: int = 512) -> None:
super().__init__()
self.cnn = nn.Sequential(
nn.Conv2d(3, 32, kernel_size=8, stride=4, padding=0),
nn.ReLU(),
nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
nn.ReLU(),
nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
nn.ReLU(),
nn.Flatten(),
)
temp = np.zeros((3, 128, 128))
with th.no_grad():
n_flatten = self.cnn(th.as_tensor(temp).float()).shape[1]
self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())
def forward(self, observations: th.Tensor) -> th.Tensor:
return self.linear(self.cnn(observations))
cnn = NatureCNN()
print(cnn.parameters())
for param in cnn.parameters():
print(param.size())
Console:
<generator object Module.parameters at 0x7f5dec222b20>
torch.Size([32, 3, 8, 8])
torch.Size([32])
torch.Size([64, 32, 4, 4])
torch.Size([64])
torch.Size([64, 64, 3, 3])
torch.Size([64])
torch.Size([512, 144])
torch.Size([512])
C#
using TorchSharp.Modules;
using static TorchSharp.torch;
namespace TestApp1;
public class NatureCNN : nn.Module<Tensor, Tensor>
{
private readonly Sequential cnn;
private readonly Sequential linear;
public NatureCNN(int featuresDim = 512)
: base(nameof(NatureCNN))
{
cnn = nn.Sequential(
nn.Conv2d(3, 32, 8, 4, 0),
nn.ReLU(),
nn.Conv2d(32, 64, 4, 2, 0),
nn.ReLU(),
nn.Conv2d(64, 64, 3, 1, 0),
nn.ReLU(),
nn.Flatten()
);
double[,,] temp = new double[3, 128, 128];
long nFlatten;
// Compute shape by doing one forward pass
using (no_grad())
{
nFlatten = cnn.forward(as_tensor(temp).@float()).shape[1];
}
linear = nn.Sequential(nn.Linear(nFlatten, featuresDim), nn.ReLU());
}
public override Tensor forward(Tensor observations)
=> linear.forward(cnn.forward(observations));
protected override void Dispose(bool disposing)
{
if (disposing)
{
cnn.Dispose();
linear.Dispose();
}
base.Dispose(disposing);
}
}
NatureCNN cNN = new NatureCNN();
var a = cNN.parameters().ToList();
// a nothing
Metadata
Metadata
Assignees
Labels
No labels