In [23]:
import torch.nn as nn
import torch

class ModuleDict(nn.Module):
    """Demo network whose layer and activation are selected by name at call time.

    Both lookup tables are ``nn.ModuleDict`` containers. Every candidate module
    is therefore registered as a submodule (its parameters show up in
    ``parameters()`` and in the printed module tree), even though ``forward``
    only uses one entry from each table per call.
    """

    def __init__(self):
        super().__init__()
        # Candidate layers, addressable by string key (a dict of network layers).
        layer_table = {
            'conv': nn.Conv2d(10, 10, 3),
            'pool': nn.MaxPool2d(3),
        }
        self.choices = nn.ModuleDict(layer_table)

        # Candidate activation functions, addressable by string key.
        activation_table = {
            'relu': nn.ReLU(),
            'prelu': nn.PReLU(),
        }
        self.activations = nn.ModuleDict(activation_table)

    def forward(self, x, choice, act):
        """Run ``x`` through the layer named ``choice``, then the activation named ``act``.

        Args:
            x: input tensor handed to the selected layer.
            choice: key into ``self.choices`` ('conv' or 'pool' as built in ``__init__``).
            act: key into ``self.activations`` ('relu' or 'prelu' as built in ``__init__``).

        Returns:
            The activated output of the selected layer.

        Raises:
            KeyError: if ``choice`` or ``act`` is not a key of its table
                (``nn.ModuleDict`` indexing behaves like a dict lookup).
        """
        selected_layer = self.choices[choice]  # dict-style lookup by name
        hidden = selected_layer(x)
        return self.activations[act](hidden)

In [24]:
net = ModuleDict()  # build the demo network; its layer/activation tables are nn.ModuleDict submodules
net  # last expression: the notebook displays the registered module tree

ModuleDict(
  (choices): ModuleDict(
    (conv): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))
    (pool): MaxPool2d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
  )
  (activations): ModuleDict(
    (prelu): PReLU(num_parameters=1)
    (relu): ReLU()
  )
)

In [25]:
# Feed a random batch through the 'conv' + 'relu' path and check the resulting shape.
img = torch.randn((4, 10, 32, 32))  # 4 samples, 10 channels, 32x32 spatial
output = net(img, 'conv', 'relu')
# Conv2d(10, 10, 3) uses no padding, so each spatial dim shrinks by 3 - 1 = 2: 32 -> 30.
assert output.shape == (4, 10, 30, 30), output.shape
output.shape  # last expression is displayed by the notebook; no print() needed

torch.Size([4, 10, 30, 30])


In [26]:
net.choices['conv'] # index the ModuleDict by key, exactly like a regular dict

Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))

In [27]:
net.choices.keys() # Return an iterable of the ModuleDict keys.

odict_keys(['conv', 'pool'])

In [28]:
# Values are the registered modules themselves, in the same order as keys().
net.choices.values() # Return an iterable of the ModuleDict values.

odict_values([Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1)), MaxPool2d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)])

In [29]:
# Each pair is (key, module), following the same order as keys()/values(): 'conv', then 'pool'.
net.choices.items() # Return an iterable of the ModuleDict key/value pairs.

odict_items([('conv', Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))), ('pool', MaxPool2d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False))])

In [30]:
# NOTE: clear() mutates net.choices in place, so module trees shown by earlier cells are now
# stale; after this cell, looking up 'conv' or 'pool' in net.choices raises KeyError.
net.choices.clear() # Remove all items from the ModuleDict.
net

ModuleDict(
  (choices): ModuleDict()
  (activations): ModuleDict(
    (prelu): PReLU(num_parameters=1)
    (relu): ReLU()
  )
)

In [31]:
net.choices['conv1'] = nn.Conv2d(10, 10, 3) # if key 'conv1' does not exist it is created with this module; otherwise the module stored under it is replaced
net

ModuleDict(
  (choices): ModuleDict(
    (conv1): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))
  )
  (activations): ModuleDict(
    (prelu): PReLU(num_parameters=1)
    (relu): ReLU()
  )
)

In [32]:
# update() merges key/module pairs from a mapping (or an iterable of pairs) into the
# ModuleDict, overwriting existing keys: 'conv1' is replaced, 'pool' is added.
net.choices.update(dict(
    conv1=nn.Conv2d(20, 20, 3),
    pool=nn.MaxPool2d(3),
))
net

ModuleDict(
  (choices): ModuleDict(
    (conv1): Conv2d(20, 20, kernel_size=(3, 3), stride=(1, 1))
    (pool): MaxPool2d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
  )
  (activations): ModuleDict(
    (prelu): PReLU(num_parameters=1)
    (relu): ReLU()
  )
)

In [33]:
# The module returned by pop() is discarded here; only the updated tree is displayed.
net.activations.pop('relu') # Remove key from the ModuleDict and return its module.
net


ModuleDict(
  (choices): ModuleDict(
    (conv1): Conv2d(20, 20, kernel_size=(3, 3), stride=(1, 1))
    (pool): MaxPool2d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
  )
  (activations): ModuleDict(
    (prelu): PReLU(num_parameters=1)
  )
)