In [61]:
import torch
import torch.nn as nn

In [124]:
class Conv2d(nn.Module):
    """Module allows us to save the name of our layers for cache."""
    def __init__(self, kernel_size, in_channels, out_channels):
        super(Conv2d, self).__init__()
        
        self.conv = nn.Sequential()
        self.conv.add_module(
            f"conv2d_{str(kernel_size)}_{str(out_channels)}",
            nn.Conv2d(in_channels, out_channels, kernel_size)
        )
    
    def forward(self, x):
        return self.conv(x)

In [132]:
class CNN(nn.Module):
    
    def __init__(self, hparams=None):
        super(CNN, self).__init__()
        self.conv1 = Conv2d(3, 300, 300)
        self.conv2 = Conv2d(7, 300, 300)
        
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x

In [128]:
x = torch.randn(300, 300, 1)

In [129]:
m = Conv2d(3, 300, 300)

In [130]:
m

Conv2d(
  (conv): Sequential(
    (conv2d_3_300): Conv2d(300, 300, kernel_size=(3, 3), stride=(1, 1))
  )
)

In [131]:
for name, param in m.named_parameters():
    if param.requires_grad:
        print(name)

conv.conv2d_3_300.weight
conv.conv2d_3_300.bias


In [74]:
torch.save(m.state_dict(), 'model')

In [108]:
save = torch.load('model')
model = CNN()
model.load_state_dict(save)

RuntimeError: Error(s) in loading state_dict for CNN:
	Missing key(s) in state_dict: "conv2.conv.conv2d_7_300.weight", "conv2.conv.conv2d_7_300.bias". 
	Unexpected key(s) in state_dict: "conv2.conv.conv2d_5_300.weight", "conv2.conv.conv2d_5_300.bias". 

In [106]:
model

CNN(
  (conv1): Conv2d(
    (conv): Sequential(
      (conv2d_3_300): Conv2d(300, 300, kernel_size=(3, 3), stride=(1, 1))
    )
  )
  (conv2): Conv2d(
    (conv): Sequential(
      (conv2d_5_300): Conv2d(300, 300, kernel_size=(5, 5), stride=(1, 1))
    )
  )
)

In [104]:
torch.save(cnn.state_dict(), 'model')

In [133]:
for name, param in cnn.named_parameters():
    if param.requires_grad:
        print(name)

conv1.conv.conv2d_3_300.weight
conv1.conv.conv2d_3_300.bias
conv2.conv.conv2d_5_300.weight
conv2.conv.conv2d_5_300.bias


In [109]:
save.keys()

odict_keys(['conv1.conv.conv2d_3_300.weight', 'conv1.conv.conv2d_3_300.bias', 'conv2.conv.conv2d_5_300.weight', 'conv2.conv.conv2d_5_300.bias'])

In [139]:
list(cnn.named_parameters())[0][0]

'conv1.conv.conv2d_3_300.weight'

In [141]:
save.keys()

odict_keys(['conv1.conv.conv2d_3_300.weight', 'conv1.conv.conv2d_3_300.bias', 'conv2.conv.conv2d_5_300.weight', 'conv2.conv.conv2d_5_300.bias'])