nn.Module save and state_dict method error #912

Closed
toolgood opened this issue Feb 14, 2023 · 5 comments · Fixed by #915

@toolgood

nn.Module save and state_dict method error

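    using System;
    using System.Collections.Generic;
    using System.Linq;
    using TorchSharp;
    using TorchSharp.Modules;
    using static TorchSharp.torch;

    internal partial class Program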
    {
        static void Main(string[] args)
        {
            var test = new Test(2);
            Dictionary<string, Tensor> sd = test.state_dict(); // Missing: no keys under "layer2.modules"

            test.save("1.pth");
        }
    }

    public class Test : nn.Module
    {
        public nn.Module layer;
        public nn.Module layer2;

        public Test(int layernum) : base("Test")
        {
            layer = nn.Linear(16, 2);
            layer2 = new Test2((from l in Enumerable.Range(0, layernum)
                                select new Test3()).ToArray(),
                                (from l in Enumerable.Range(0, layernum)
                                select new Test3()).ToArray());
            this.RegisterComponents();
        }
    }
    public class Test2 : nn.Module
    {
        public ModuleList<Test3> modules;
        public ModuleList<nn.Module> modules2;
        public nn.Module layer;
        public Test2(Test3[] ms, nn.Module[] ms2) : base("Test2")
        {
            layer = nn.Linear(16, 2);
            modules = nn.ModuleList(ms);
            modules2 = nn.ModuleList(ms2);
            this.RegisterComponents();
        }
    }
    public class Test3 : nn.Module
    {
        public nn.Module layer;
        public Test3() : base("Test3")
        {
            layer = nn.Linear(16, 2);
            this.RegisterComponents();
        }
    }
@GeorgeS2019

GeorgeS2019 commented Feb 14, 2023

@toolgood

I assume you have read this.

Minor feedback

There have been some discussions about the choice of file extension ("pth" vs. "pt").

In TorchSharp, a saved state_dict typically uses the "dat" extension. As far as I know (I could be wrong), this format is specific to TorchSharp and cannot be loaded by PyTorch without further conversion.
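For reference, a minimal save/load round trip might look like the sketch below. It assumes a TorchSharp version that provides `Module.save(string)` (already used in the repro) and `Module.load(string)`, plus `Tensor.allclose` for the comparison; the file name `linear.dat` is arbitrary and only follows the convention mentioned above.

```csharp
using System;
using TorchSharp;
using static TorchSharp.torch;

public static class SaveLoadDemo
{
    public static void Main()
    {
        // Save the parameters of a small module to a TorchSharp-format file.
        var original = nn.Linear(16, 2);
        original.save("linear.dat");

        // Load them into a freshly (randomly) initialized module of the same shape.
        var restored = nn.Linear(16, 2);
        restored.load("linear.dat");

        // After loading, the weights should match the saved ones.
        Console.WriteLine(original.weight.allclose(restored.weight));
    }
}
```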

@toolgood
Author

The parameters are not collected correctly.

In class Test2:
public ModuleList<nn.Module> modules2; is registered
public ModuleList<Test3> modules; is NOT registered

@NiklasGustafsson
Contributor

NiklasGustafsson commented Feb 14, 2023

Interesting, thanks for the repro. What happens if you declare Test2.modules as a ModuleList<nn.Module>, instead?

@toolgood
Author

> Interesting, thanks for the repro. What happens if you declare Test2.modules as a ModuleList<nn.Module>, instead?

Yes, it works correctly when declared as ModuleList<nn.Module>.

In class Test2:
public ModuleList<nn.Module> modules2; is registered
public ModuleList<Test3> modules; is NOT registered

@NiklasGustafsson
Contributor

Found it. The fix will be in the next release. In the meantime, declare the ModuleList<T> with T = nn.Module instead.
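Applied to the repro, the suggested workaround might look like the sketch below. It assumes the same TorchSharp API the repro uses (`nn.ModuleList`, `RegisterComponents`, `state_dict`); the `Program.Main` driver is only there to print the keys so you can check that entries under `modules.` now appear. One C# detail matters: generic classes are invariant, so `nn.ModuleList(ms)` with `ms` typed as `Test3[]` infers `T = Test3` and yields a `ModuleList<Test3>`, which cannot be assigned to a `ModuleList<nn.Module>` field. The type argument is therefore given explicitly.

```csharp
using System;
using System.Linq;
using TorchSharp;
using TorchSharp.Modules;
using static TorchSharp.torch;

public class Test2 : nn.Module
{
    // Workaround: both lists use T = nn.Module (was ModuleList<Test3> for `modules`).
    public ModuleList<nn.Module> modules;
    public ModuleList<nn.Module> modules2;
    public nn.Module layer;

    public Test2(Test3[] ms, nn.Module[] ms2) : base("Test2")
    {
        layer = nn.Linear(16, 2);
        // Explicit <nn.Module>: otherwise T is inferred as Test3. The Test3[] argument
        // is accepted as nn.Module[] through C# array covariance.
        modules = nn.ModuleList<nn.Module>(ms);
        modules2 = nn.ModuleList(ms2);
        RegisterComponents();
    }
}

public class Test3 : nn.Module
{
    public nn.Module layer;

    public Test3() : base("Test3")
    {
        layer = nn.Linear(16, 2);
        RegisterComponents();
    }
}

public static class Program
{
    public static void Main()
    {
        var ms = Enumerable.Range(0, 2).Select(_ => new Test3()).ToArray();
        var ms2 = Enumerable.Range(0, 2).Select(_ => new Test3()).ToArray();
        var t2 = new Test2(ms, ms2);

        // Print every state_dict key; look for entries starting with "modules.".
        foreach (var key in t2.state_dict().Keys)
            Console.WriteLine(key);
    }
}
```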

@NiklasGustafsson NiklasGustafsson linked a pull request Feb 15, 2023 that will close this issue