In [59]:
import torch
import torch.nn as nn

In [60]:
class ScaleModule(nn.Module):
    """Multiply the input by two constant scale factors held as buffers.

    Both factors are registered as buffers rather than model parameters.
    ``scale1`` is registered with ``persistent=False`` and ``scale2`` with
    ``persistent=True``; the ``persistent`` flag decides whether a buffer is
    part of this module's ``state_dict``.

    Args:
        scale_factor1: Value for the ``scale1`` buffer. Either an existing
            tensor or anything ``torch.tensor`` accepts (e.g. a Python number).
        scale_factor2: Value for the ``scale2`` buffer, same accepted types.
    """

    def __init__(self, scale_factor1, scale_factor2):
        super().__init__()
        # Registration order is kept (scale1, then scale2) so that iteration
        # over the module's buffers yields them in the same order as before.
        self.register_buffer(name="scale1",
                             tensor=self._as_buffer(scale_factor1),
                             persistent=False)
        self.register_buffer(name="scale2",
                             tensor=self._as_buffer(scale_factor2),
                             persistent=True)

    @staticmethod
    def _as_buffer(value):
        """Return a detached tensor copy of ``value`` for use as a buffer."""
        if isinstance(value, torch.Tensor):
            # torch.tensor(existing_tensor) emits a UserWarning that recommends
            # copy-constructing via clone/detach instead, so do exactly that.
            return value.detach().clone()
        return torch.tensor(value)

    def forward(self, x):
        """Return ``x * scale1 * scale2``."""
        return x * self.scale1 * self.scale2

In [61]:
# Build a module with scale factors 2 and 3 and apply it to a small integer tensor.
model = ScaleModule(scale_factor1=2, scale_factor2=3)
x = torch.tensor([-1, -2])
output = model(x)

# Show the input first, then the scaled result, one per line.
print(x, output, sep="\n")


tensor([-1, -2])
tensor([ -6, -12])


In [62]:
def _show_buffers(module, end="\n"):
    """Print both scale buffers, the full buffer list, and the state_dict of ``module``."""
    print(module.scale1)
    print(module.scale2)
    print(list(module.buffers()))
    print("state_dict: ", module.state_dict(), end=end)


# Snapshot before moving the module; the extra newline separates the two reports.
_show_buffers(model, end='\n\n')

# Move the module to the GPU when one is available, otherwise keep it on the CPU.
model.to(device="cuda" if torch.cuda.is_available() else "cpu")

# Snapshot after the move: the buffers switch device automatically with the module.
_show_buffers(model)

tensor(2)
tensor(3)
[tensor(2), tensor(3)]
state_dict:  OrderedDict([('scale2', tensor(3))])

tensor(2, device='cuda:0')
tensor(3, device='cuda:0')
[tensor(2, device='cuda:0'), tensor(3, device='cuda:0')]
state_dict:  OrderedDict([('scale2', tensor(3, device='cuda:0'))])
