Skip to content

Custom model can't get parameters #887

@ChengYen-Tang

Description

@ChengYen-Tang

Python:

import numpy as np
import torch as th
from torch import nn

class NatureCNN(nn.Module):
    def __init__(self, features_dim: int = 512) -> None:
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=8, stride=4, padding=0),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.Flatten(),
        )

        temp = np.zeros((3, 128, 128))

        with th.no_grad():
            n_flatten = self.cnn(th.as_tensor(temp).float()).shape[1]

        self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())

    def forward(self, observations: th.Tensor) -> th.Tensor:
        return self.linear(self.cnn(observations))

cnn = NatureCNN()
print(cnn.parameters())
for param in cnn.parameters():
    print(param.size())

Console:
<generator object Module.parameters at 0x7f5dec222b20>
torch.Size([32, 3, 8, 8])
torch.Size([32])
torch.Size([64, 32, 4, 4])
torch.Size([64])
torch.Size([64, 64, 3, 3])
torch.Size([64])
torch.Size([512, 144])
torch.Size([512])

C#

using TorchSharp.Modules;
using static TorchSharp.torch;

namespace TestApp1;

public class NatureCNN : nn.Module<Tensor, Tensor>
{
    private readonly Sequential cnn;
    private readonly Sequential linear;

    public NatureCNN(int featuresDim = 512)
        : base(nameof(NatureCNN))
    {
        cnn = nn.Sequential(
            nn.Conv2d(3, 32, 8, 4, 0),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, 2, 0),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, 1, 0),
            nn.ReLU(),
            nn.Flatten()
            );


        double[,,] temp = new double[3, 128, 128];

        long nFlatten;
        // Compute shape by doing one forward pass
        using (no_grad())
        {
            nFlatten = cnn.forward(as_tensor(temp).@float()).shape[1];
        }
        linear = nn.Sequential(nn.Linear(nFlatten, featuresDim), nn.ReLU());
    }

    public override Tensor forward(Tensor observations)
        => linear.forward(cnn.forward(observations));

    protected override void Dispose(bool disposing)
    {
        if (disposing)
        {
            cnn.Dispose();
            linear.Dispose();
        }
        base.Dispose(disposing);
    }
}

NatureCNN cNN = new NatureCNN();
var a = cNN.parameters().ToList();
// a nothing

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions