diff --git a/docs/source/nn.rst b/docs/source/nn.rst index 3380c5cf3fb0b..c02b3204573b7 100644 --- a/docs/source/nn.rst +++ b/docs/source/nn.rst @@ -22,6 +22,7 @@ These are the basic building blocks for graphs: :nosignatures: :template: classtemplate.rst + ~parameter.Buffer ~parameter.Parameter ~parameter.UninitializedParameter ~parameter.UninitializedBuffer diff --git a/test/distributed/fsdp/test_fsdp_flatten_params.py b/test/distributed/fsdp/test_fsdp_flatten_params.py index f8c148462f20e..df533ca6d1752 100644 --- a/test/distributed/fsdp/test_fsdp_flatten_params.py +++ b/test/distributed/fsdp/test_fsdp_flatten_params.py @@ -59,7 +59,7 @@ def _get_transformer(self, seed=0): dim_feedforward=128, dropout=0.1, ) - module.register_buffer("dummy_buffer", torch.tensor(1.0)) + module.dummy_buffer = nn.Buffer(torch.tensor(1.0)) def get_input(device, dtype): torch.manual_seed(1) # keep everything deterministic diff --git a/test/distributed/fsdp/test_fsdp_misc.py b/test/distributed/fsdp/test_fsdp_misc.py index d2589cc9d7e49..e0a4db47dc563 100644 --- a/test/distributed/fsdp/test_fsdp_misc.py +++ b/test/distributed/fsdp/test_fsdp_misc.py @@ -594,11 +594,11 @@ def init_nested_wrapped_module(): # Check that `device_id` with `sync_module_states=True` works nested_wrapped_module = init_nested_wrapped_module() - nested_wrapped_module.register_buffer( - "buf", torch.ones((2, 2), device="cpu") * self.rank + nested_wrapped_module.buf = nn.Buffer( + torch.ones((2, 2), device="cpu") * self.rank ) - nested_wrapped_module.module[0].register_buffer( - "buf", torch.ones((3, 2), device="cpu") * self.rank + nested_wrapped_module.module[0].buf = nn.Buffer( + torch.ones((3, 2), device="cpu") * self.rank ) nested_wrapped_module = FSDP( nested_wrapped_module, @@ -892,7 +892,7 @@ def __init__(self, rank): torch.manual_seed(rank) torch.cuda.manual_seed(rank) self.lin = nn.Linear(10, 10, bias=False) - self.register_buffer("buffer", torch.ones(1) * rank) + self.buffer = nn.Buffer(torch.ones(1) * 
rank) m = MyModel(self.rank).cuda() _assert_module_states( diff --git a/test/distributed/fsdp/test_fsdp_state_dict.py b/test/distributed/fsdp/test_fsdp_state_dict.py index 6374b06702b50..cae4b9a15479d 100644 --- a/test/distributed/fsdp/test_fsdp_state_dict.py +++ b/test/distributed/fsdp/test_fsdp_state_dict.py @@ -102,7 +102,7 @@ def __init__( super().__init__() self.inner = Linear(*INNER_SHAPE) if register_buffers: - self.inner.register_buffer("buffer", torch.randn(BUFFER_SHAPE)) + self.inner.buffer = nn.Buffer(torch.randn(BUFFER_SHAPE)) self.inner.register_buffer( "non_persistent_buffer", torch.randn(BUFFER_SHAPE), persistent=False ) @@ -121,7 +121,7 @@ def __init__( ) self.outer = Linear(*OUTER_SHAPE) if register_buffers: - self.outer.register_buffer("buffer", torch.randn(BUFFER_SHAPE)) + self.outer.buffer = nn.Buffer(torch.randn(BUFFER_SHAPE)) self.outer.register_buffer( "non_persistent_buffer", torch.randn(BUFFER_SHAPE), persistent=False ) diff --git a/test/distributed/fsdp/test_fsdp_unshard_params.py b/test/distributed/fsdp/test_fsdp_unshard_params.py index 1cef7e8ec889b..0407d7f2f1ff2 100644 --- a/test/distributed/fsdp/test_fsdp_unshard_params.py +++ b/test/distributed/fsdp/test_fsdp_unshard_params.py @@ -436,7 +436,7 @@ def _test_named_parameters_and_buffers(self, prefix: str, recurse: bool): CUDAInitMode.CUDA_BEFORE, deterministic=True, ) - model.register_buffer("buffer", torch.ones(1)) + model.buffer = nn.Buffer(torch.ones(1)) # Wrap the top-level with FSDP since `named_parameters()` and # `named_buffers` will contain FSDP prefixes if called on a non-FSDP # root module @@ -449,7 +449,7 @@ def _test_named_parameters_and_buffers(self, prefix: str, recurse: bool): ), self.process_group, ) - fsdp_model.register_buffer("buffer", torch.ones(1)) + fsdp_model.buffer = nn.Buffer(torch.ones(1)) with FSDP.summon_full_params(fsdp_model): for call in ["named_parameters", "named_buffers"]: for (n1, p1), (n2, p2) in itertools.zip_longest( diff --git 
a/test/distributed/optim/test_zero_redundancy_optimizer.py b/test/distributed/optim/test_zero_redundancy_optimizer.py index 485df8f5b5437..3796a09103ddc 100644 --- a/test/distributed/optim/test_zero_redundancy_optimizer.py +++ b/test/distributed/optim/test_zero_redundancy_optimizer.py @@ -857,8 +857,7 @@ def test_local_optimizer_parity( torch.nn.Linear(HIDDEN_DIM, HIDDEN_DIM), torch.nn.Linear(HIDDEN_DIM, OUTPUT_DIM), ).to(self.device) - model.register_buffer( - "test_buffer", + model.test_buffer = torch.nn.Buffer( torch.ones((1), device=self.device) * self.rank, ) # Define models/optimizers for DDP with ZeRO and DDP with local diff --git a/test/distributed/test_data_parallel.py b/test/distributed/test_data_parallel.py index 8e380ca9df86f..497e2d713e4f8 100644 --- a/test/distributed/test_data_parallel.py +++ b/test/distributed/test_data_parallel.py @@ -43,8 +43,8 @@ def test_data_parallel_buffers_requiring_grad(self): class TestModule(nn.Module): def __init__(self, t): super().__init__() - self.register_buffer("t_rg", t) - self.register_buffer("t_not_rg", t.clone().detach()) + self.t_rg = nn.Buffer(t) + self.t_not_rg = nn.Buffer(t.clone().detach()) def forward(self, x): return x * self.t_rg + self.t_not_rg diff --git a/test/distributed/test_dynamo_distributed.py b/test/distributed/test_dynamo_distributed.py index 058118d6fa4c1..5e2ed8e4736d3 100644 --- a/test/distributed/test_dynamo_distributed.py +++ b/test/distributed/test_dynamo_distributed.py @@ -1159,9 +1159,8 @@ class DuplicateModule(nn.Module): def __init__(self) -> None: super().__init__() self._param = torch.randn((3,), device="cuda") - self.register_buffer( - "_buf", torch.randn((3,), requires_grad=False, device="cuda") - ) + self._buf = torch.nn.Buffer( + torch.randn((3,), requires_grad=False, device="cuda")) def forward(self, x: torch.Tensor) -> torch.Tensor: # Use `_param` and `_buf` each twice in this compiled forward @@ -1190,8 +1189,8 @@ def test_fsdp_dup_tensors_diff_source(self): 
class BufModule(nn.Module): def __init__(self) -> None: super().__init__() - self.register_buffer( - "_buf", torch.randn((3,), requires_grad=False, device="cuda") + self._buf = nn.Buffer( + torch.randn((3,), requires_grad=False, device="cuda") ) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -1203,7 +1202,7 @@ def __init__(self) -> None: self._param = nn.Parameter(torch.randn((1,), device="cuda")) self._buf_module = BufModule() # Share the buffer, meaning same tensor but different source - self.register_buffer("_buf", self._buf_module._buf) + self._buf = self._buf_module._buf def forward(self, x: torch.Tensor) -> torch.Tensor: # Use the same buffer tensor twice in the compiled forward, diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py index e224cd4c4eb4e..9a84dc3788450 100644 --- a/test/dynamo/test_export.py +++ b/test/dynamo/test_export.py @@ -990,7 +990,7 @@ class MyBlock(torch.nn.Module): def __init__(self): super().__init__() self.weight = torch.nn.Parameter(torch.ones(1, 1)) - self.register_buffer("buffer", torch.ones(1, 1)) + self.buffer = torch.nn.Buffer(torch.ones(1, 1)) def forward(self, x): x = torch.nn.functional.linear(x, torch.randn(4, 4)) @@ -3013,7 +3013,7 @@ def test_not_functionalize(self): class Foo(torch.nn.Module): def __init__(self): super().__init__() - self.register_buffer("buffer1", torch.ones(6, 2)) + self.buffer1 = torch.nn.Buffer(torch.ones(6, 2)) def forward(self, x): x.add_(2) diff --git a/test/dynamo/test_modules.py b/test/dynamo/test_modules.py index 0b12767583bdc..e64473cfecc4e 100644 --- a/test/dynamo/test_modules.py +++ b/test/dynamo/test_modules.py @@ -1589,7 +1589,7 @@ def __init__(self): super().__init__() self.relu = torch.nn.ReLU() self.linear = torch.nn.Linear(10, 10) - self.register_buffer("buf0", torch.randn(10, 10)) + self.buf0 = torch.nn.Buffer(torch.randn(10, 10)) def forward(self, x): return self.relu(self.linear(x) + self.buf0) @@ -1802,7 +1802,7 @@ class MockModule(torch.nn.Module): def 
__init__(self): super().__init__() self.linear = torch.nn.Linear(10, 10) - self.register_buffer("buf0", torch.randn(10, 10)) + self.buf0 = torch.nn.Buffer(torch.randn(10, 10)) def forward(self, x): return self.r(torch.sin(x)) + self.buf0 @@ -1829,7 +1829,7 @@ class MockModule(torch.nn.Module): def __init__(self): super().__init__() self.linear = torch.nn.Linear(10, 10) - self.register_buffer("buf0", torch.randn(10, 10)) + self.register_buffer("buf0", torch.nn.Buffer(torch.randn(10, 10))) self.register_parameter( name="param0", param=torch.nn.Parameter(torch.randn(10, 10)) ) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 590fe22ff0564..97a57bfc24762 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -2055,8 +2055,8 @@ def test_sort_out2(self): class MyModule(torch.nn.Module): def __init__(self): super().__init__() - self.register_buffer("sorted", torch.ones(4, 4)) - self.register_buffer("indices", torch.ones(4, 4, dtype=torch.long)) + self.sorted = torch.nn.Buffer(torch.ones(4, 4)) + self.indices = torch.nn.Buffer(torch.ones(4, 4, dtype=torch.long)) def forward(self, x): torch.sort(x, out=(self.sorted, self.indices)) @@ -2087,7 +2087,7 @@ def test_sigmoid_out2(self): class MyModule(torch.nn.Module): def __init__(self): super().__init__() - self.register_buffer("base", torch.ones(4, 4)) + self.base = torch.nn.Buffer(torch.ones(4, 4)) def forward(self, x): torch.sigmoid(x, out=self.base) @@ -2395,8 +2395,8 @@ def test_named_buffers(self): class Foo(torch.nn.Module): def __init__(self): super().__init__() - self.register_buffer("x", torch.ones(3)) - self.register_buffer("y", torch.ones(3)) + self.x = torch.nn.Buffer(torch.ones(3)) + self.y = torch.nn.Buffer(torch.ones(3)) def forward(self, inp): res = 0 diff --git a/test/export/test_verifier.py b/test/export/test_verifier.py index c85e90f1b435f..8ad013b573172 100644 --- a/test/export/test_verifier.py +++ b/test/export/test_verifier.py @@ -166,8 +166,8 @@ def 
__init__(self): self.my_parameter = torch.nn.Parameter(torch.tensor(2.0)) - self.register_buffer("my_buffer1", torch.tensor(3.0)) - self.register_buffer("my_buffer2", torch.tensor(4.0)) + self.my_buffer1 = torch.nn.Buffer(torch.tensor(3.0)) + self.my_buffer2 = torch.nn.Buffer(torch.tensor(4.0)) def forward(self, x1, x2): # Use the parameter, buffers, and both inputs in the forward method @@ -189,8 +189,8 @@ def __init__(self): self.my_parameter = torch.nn.Parameter(torch.tensor(2.0)) - self.register_buffer("my_buffer1", torch.tensor(3.0)) - self.register_buffer("my_buffer2", torch.tensor(4.0)) + self.my_buffer1 = torch.nn.Buffer(torch.tensor(3.0)) + self.my_buffer2 = torch.nn.Buffer(torch.tensor(4.0)) def forward(self, x1, x2): # Use the parameter, buffers, and both inputs in the forward method diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index 9d515606621c0..3b949374c2afd 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -3225,7 +3225,7 @@ def test_real_weights_in_symbolic_mode_with_inplace_ops(self): class M(torch.nn.Module): def __init__(self): super().__init__() - self.register_buffer("buffer", torch.ones(4, 5)) + self.buffer = torch.nn.Buffer(torch.ones(4, 5)) def forward(self, x): y = self.buffer.add_(3) @@ -4121,7 +4121,7 @@ def test_aot_export_forward_mutation_no_buffer_mut(self): class M(torch.nn.Module): def __init__(self): super().__init__() - self.register_buffer("buffer1", torch.ones(6, 4)) + self.buffer1 = torch.nn.Buffer(torch.ones(6, 4)) def forward(self, x): x.add_(4) @@ -4147,7 +4147,7 @@ def test_aot_export_forward_mutation_multiple_mut(self): class M(torch.nn.Module): def __init__(self): super().__init__() - self.register_buffer("buffer1", torch.ones(6, 4)) + self.buffer1 = torch.nn.Buffer(torch.ones(6, 4)) def forward(self, x, y): y.add_(4) diff --git a/test/functorch/test_eager_transforms.py b/test/functorch/test_eager_transforms.py index 
b54b9762fbc32..8a2c0c320ad78 100644 --- a/test/functorch/test_eager_transforms.py +++ b/test/functorch/test_eager_transforms.py @@ -3688,8 +3688,8 @@ def __init__(self): super().__init__() self.bias = nn.Parameter(torch.randn(3)) self.linear = nn.Linear(3, 3) - self.register_buffer("buffer", torch.randn(3)) - self.register_buffer("buffer_tied", self.buffer) + self.buffer = nn.Buffer(torch.randn(3)) + self.buffer_tied = self.buffer def forward(self, x): x = self.linear(x) @@ -3719,7 +3719,7 @@ class Foo(nn.Module): def __init__(self): super().__init__() self.linear = nn.Linear(3, 3) - self.register_buffer("buffer", torch.randn(3)) + self.buffer = nn.Buffer(torch.randn(3)) def forward(self, x): x = self.linear(x) @@ -3741,7 +3741,7 @@ class Foo(nn.Module): def __init__(self): super().__init__() self.linear = nn.Linear(3, 3) - self.register_buffer("buffer", torch.randn(3)) + self.buffer = nn.Buffer(torch.randn(3)) def forward(self, x): x = self.linear(x) @@ -3801,8 +3801,8 @@ def __init__(self): self.linear = nn.Linear(3, 3) self.weight = self.linear.weight self.bias = self.linear.bias - self.register_buffer("buffer", torch.randn(3)) - self.register_buffer("buffer_tied", self.buffer) + self.buffer = nn.Buffer(torch.randn(3)) + self.buffer_tied = self.buffer def forward(self, x): x = self.linear(x) diff --git a/test/inductor/test_cuda_repro.py b/test/inductor/test_cuda_repro.py index db02d19310097..ac7018d668b8d 100644 --- a/test/inductor/test_cuda_repro.py +++ b/test/inductor/test_cuda_repro.py @@ -581,8 +581,7 @@ def __init__(self): start = math.log2(0.5) end = math.log2(1 / (2**8)) - self.register_buffer( - "scales", + self.scales = nn.Buffer( 2 ** torch.arange( start, diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 4fd58bae854dd..6134ae24c7068 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -8403,9 +8403,7 @@ def fn(x, p1, p0): class Repro(torch.nn.Module): def 
__init__(self): super().__init__() - self.register_buffer( - "_tensor_constant0", torch.randn([], dtype=torch.float32) - ) + self._tensor_constant0 = nn.Buffer(torch.randn([], dtype=torch.float32)) def forward(self, arg0_1, arg1_1): convert_element_type = torch.ops.prims.convert_element_type.default( diff --git a/test/jit/test_save_load.py b/test/jit/test_save_load.py index d16f039798895..cbaabb771dffa 100644 --- a/test/jit/test_save_load.py +++ b/test/jit/test_save_load.py @@ -487,6 +487,7 @@ def __init__(self): self.parameter_b = torch.nn.Parameter(torch.randn(4)) self.submodule_b = Submodule() + self.buffer_b = torch.nn.Buffer(torch.randn(4)) m = TestModule() m_loaded = self.getExportImportCopy(torch.jit.script(m)) @@ -526,7 +527,7 @@ def __init__(self): super().__init__() self.foo = torch.nn.Linear(2, 3, device="meta") self.bar = torch.nn.Linear(3, 4) - self.register_buffer("buffer", torch.randn(4, device="meta")) + self.buffer = torch.nn.Buffer(torch.randn(4, device="meta")) def forward(self, x): x = self.foo(x) @@ -1145,6 +1146,7 @@ def __init__(self): self.parameter_b = torch.nn.Parameter(torch.randn(4)) self.submodule_b = Submodule() + self.buffer_b = torch.nn.Buffer(torch.randn(4)) m = TestModule() m_loaded = self.getExportImportCopy(torch.jit.script(m)) diff --git a/test/nn/test_lazy_modules.py b/test/nn/test_lazy_modules.py index 2de0dc656bfce..a7f9e1026ccc5 100644 --- a/test/nn/test_lazy_modules.py +++ b/test/nn/test_lazy_modules.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn -from torch.nn import Parameter +from torch.nn import Buffer, Parameter from torch.nn.parameter import UninitializedBuffer, UninitializedParameter from torch.testing._internal.common_cuda import TEST_CUDA from torch.testing._internal.common_utils import ( @@ -52,29 +52,29 @@ def test_lazy_module_parameter(self): @suppress_warnings def test_lazy_module_buffer(self): module = LazyModule() - module.register_buffer("test_buffer", UninitializedBuffer()) + module.test_buffer = 
UninitializedBuffer() self.assertTrue(module.has_uninitialized_params()) state_dict = module.state_dict() self.assertIsInstance(state_dict["test_buffer"], UninitializedBuffer) new_module = LazyModule() # An error is raised when there is an attempt to replace an existing parameter # with an uninitialized one - new_module.register_buffer("test_buffer", torch.ones(5, 5)) - with self.assertRaisesRegex(RuntimeError, "shape of an uninitialized"): + new_module.test_buffer = Buffer(torch.ones(5, 5)) + with self.assertRaisesRegex(RuntimeError, 'shape of an uninitialized'): new_module.load_state_dict(state_dict) # Uninitialized parameters are overriden when the state dict to be loaded contains a valid one new_module = LazyModule() - new_module.register_buffer("test_buffer", torch.ones(5, 5)) + new_module.test_buffer = Buffer(torch.ones(5, 5)) module.load_state_dict(new_module.state_dict()) self.assertEqual(module.test_buffer, torch.ones((5, 5))) # Uninitialized parameters are left unchanged module = LazyModule() - module.register_buffer("test_buffer", UninitializedBuffer()) + module.test_buffer = UninitializedBuffer() self.assertTrue(module.has_uninitialized_params()) new_module = LazyModule() - new_module.register_buffer("test_buffer", UninitializedBuffer()) + new_module.test_buffer = UninitializedBuffer() module.load_state_dict(new_module.state_dict()) module.load_state_dict(new_module.state_dict()) self.assertTrue(module.has_uninitialized_params()) @@ -90,7 +90,7 @@ def test_lazy_module_jit_param(self): @suppress_warnings def test_lazy_module_jit_buffer(self): module = LazyModule() - module.register_buffer("test_buffer", UninitializedBuffer()) + module.test_buffer = UninitializedBuffer() self.assertTrue(module.has_uninitialized_params()) with self.assertRaisesRegex(RuntimeError, "run a forward pass"): torch.jit.script(module) @@ -106,7 +106,7 @@ def test_lazy_share_memory_param(self): @suppress_warnings def test_lazy_share_memory_buffer(self): module = LazyModule() - 
module.register_buffer("test_buffer", UninitializedBuffer()) + module.test_buffer = UninitializedBuffer() self.assertTrue(module.has_uninitialized_params()) with self.assertRaisesRegex(RuntimeError, "share memory on an uninitialized"): module.share_memory() diff --git a/test/nn/test_parametrization.py b/test/nn/test_parametrization.py index d547d8abb0db0..442cd960a41ee 100644 --- a/test/nn/test_parametrization.py +++ b/test/nn/test_parametrization.py @@ -11,7 +11,7 @@ import torch.nn.utils.parametrize as parametrize from torch import Tensor from torch.__future__ import get_swap_module_params_on_conversion -from torch.nn import Parameter +from torch.nn import Buffer, Parameter from torch.testing._internal.common_cuda import TEST_MULTIGPU from torch.testing._internal.common_device_type import instantiate_device_type_tests from torch.testing._internal.common_nn import NNTestCase @@ -361,7 +361,7 @@ def forward(self, x): # Instantiate parametrizations on buffers. It should work as expected delattr(model, "bias") - model.register_buffer("bias", torch.ones(8)) + model.bias = Buffer(torch.ones(8)) parametrize.register_parametrization(model, "bias", FirstZero()) parametrize.register_parametrization(model, "bias", LastZero()) self.assertTrue(parametrize.is_parametrized(model)) @@ -391,8 +391,8 @@ def test_serialization_parametrization(self): class Orthogonal(nn.Module): def __init__(self, n): super().__init__() - self.register_buffer("id", torch.eye(n)) - self.register_buffer("B", torch.empty(n, n)) + self.id = Buffer(torch.eye(n)) + self.B = Buffer(torch.empty(n, n)) init.orthogonal_(self.B) def forward(self, X): @@ -456,7 +456,7 @@ def right_inverse(self, X): class Orthogonal(nn.Module): def __init__(self, n): super().__init__() - self.register_buffer("B", torch.eye(n)) + self.B = Buffer(torch.eye(n)) def forward(self, X): Id = torch.eye(X.size(0)) diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index 
c6f329942e295..b44d43d5c4246 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -292,7 +292,7 @@ def __init__( self, ): super().__init__() - self.register_buffer("weight", torch.ones(5)) + self.weight = torch.nn.Buffer(torch.ones(5)) def forward(self, x): scale_1 = self.weight.reshape(1, -1, 1, 1) @@ -4337,7 +4337,7 @@ def test_gather_constant_fold(self): class GatherModule(torch.nn.Module): def __init__(self): super().__init__() - self.register_buffer("weight", torch.ones(5)) + self.weight = torch.nn.Buffer(torch.ones(5)) # torch.nn.Embedding is converted to ONNX::Gather. # Constant folding will be triggerred for constant inputs. # This pattern is common for constant mask inputs in transformer models. @@ -4356,7 +4356,7 @@ def forward(self, x): class GatherModule(torch.nn.Module): def __init__(self): super().__init__() - self.register_buffer("weight", torch.ones(2)) + self.weight = torch.nn.Buffer(torch.ones(2)) def forward(self, x): # shape is of rank 0 @@ -4371,7 +4371,7 @@ def forward(self, x): class GatherModule(torch.nn.Module): def __init__(self): super().__init__() - self.register_buffer("rb", torch.randn(1, 1, 3, 1, 1)) + self.rb = torch.nn.Buffer(torch.randn(1, 1, 3, 1, 1)) def forward(self, x): x += self.rb[0] @@ -9645,7 +9645,7 @@ def test_shape_constant_fold(self): class ShapeModule(torch.nn.Module): def __init__(self): super().__init__() - self.register_buffer("weight", torch.ones(5)) + self.weight = torch.nn.Buffer(torch.ones(5)) def forward(self, x): shape = self.weight.shape[0] @@ -11146,7 +11146,7 @@ class InnerModule2(torch.nn.Module): def __init__(self, embedding_dim): super().__init__() self.weights = InnerModule2.get_embedding(embedding_dim) - self.register_buffer("_float_tensor", torch.FloatTensor(1)) + self._float_tensor = torch.nn.Buffer(torch.FloatTensor(1)) self.const = 2 @staticmethod @@ -11208,7 +11208,7 @@ def __init__(self, embedding_dim): self.embedding_dim = embedding_dim 
self.const = 2.5 self.weights = InnerModule.get_embedding(self.embedding_dim) - self.register_buffer("_float_tensor", torch.FloatTensor(1)) + self._float_tensor = torch.nn.Buffer(torch.FloatTensor(1)) @staticmethod def get_embedding(embedding_dim: int): diff --git a/test/onnx/test_utility_funs.py b/test/onnx/test_utility_funs.py index f133b5cf149e2..6ef404f213cae 100644 --- a/test/onnx/test_utility_funs.py +++ b/test/onnx/test_utility_funs.py @@ -545,7 +545,7 @@ def __init__( self, ): super().__init__() - self.register_buffer("weight", torch.ones(5)) + self.weight = torch.nn.Buffer(torch.ones(5)) def forward(self, x): b = self.weight.reshape(1, -1, 1, 1) @@ -568,7 +568,7 @@ def __init__( self, ): super().__init__() - self.register_buffer("weight", torch.ones(5)) + self.weight = torch.nn.Buffer(torch.ones(5)) def forward(self, x): div = self.weight.div(torch.tensor([1, 2, 3, 4, 5])) @@ -591,7 +591,7 @@ def __init__( self, ): super().__init__() - self.register_buffer("weight", torch.ones(5)) + self.weight = torch.nn.Buffer(torch.ones(5)) def forward(self, x): mul = self.weight.mul(torch.tensor([1, 2, 3, 4, 5])) @@ -614,7 +614,7 @@ def __init__( self, ): super().__init__() - self.register_buffer("weight", torch.ones(5)) + self.weight = torch.nn.Buffer(torch.ones(5)) def forward(self, x): add = self.weight + torch.tensor([1, 2, 3, 4, 5]) @@ -645,7 +645,7 @@ def __init__( self, ): super().__init__() - self.register_buffer("weight", torch.ones(5)) + self.weight = torch.nn.Buffer(torch.ones(5)) def forward(self, x): sub = self.weight - torch.tensor([1, 2, 3, 4, 5]) @@ -676,7 +676,7 @@ def __init__( self, ): super().__init__() - self.register_buffer("weight", torch.ones(5)) + self.weight = torch.nn.Buffer(torch.ones(5)) def forward(self, x): sqrt = torch.sqrt(self.weight) @@ -696,7 +696,7 @@ def test_constant_fold_shape(self): class ShapeModule(torch.nn.Module): def __init__(self): super().__init__() - self.register_buffer("weight", torch.ones(5)) + self.weight = 
torch.nn.Buffer(torch.ones(5)) def forward(self, x): shape = self.weight.shape[0] diff --git a/test/quantization/core/test_quantized_tensor.py b/test/quantization/core/test_quantized_tensor.py index 00462023b6b5b..30574eb42e766 100644 --- a/test/quantization/core/test_quantized_tensor.py +++ b/test/quantization/core/test_quantized_tensor.py @@ -1452,7 +1452,7 @@ def __init__(self, per_channel): s = torch.rand(5, dtype=torch.float64) + 0.1 zp = torch.randint(5, 15, (5,)) x_q = torch.quantize_per_channel(x, s, zp, 1, torch.quint8) - self.register_buffer('x', x_q) + self.x = torch.nn.Buffer(x_q) @torch.jit.script_method def forward(self): diff --git a/test/quantization/eager/test_quantize_eager_qat.py b/test/quantization/eager/test_quantize_eager_qat.py index 52f169b1d5b62..78c8802e9cc7e 100644 --- a/test/quantization/eager/test_quantize_eager_qat.py +++ b/test/quantization/eager/test_quantize_eager_qat.py @@ -94,9 +94,9 @@ def __init__(self, self.beta = nn.Parameter(torch.empty(out_channels)) self.affine = True self.track_running_stats = True - self.register_buffer('running_mean', torch.zeros(out_channels)) - self.register_buffer('running_var', torch.ones(out_channels)) - self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long)) + self.running_mean = nn.Buffer(torch.zeros(out_channels)) + self.running_var = nn.Buffer(torch.ones(out_channels)) + self.num_batches_tracked = nn.Buffer(torch.tensor(0, dtype=torch.long)) self.activation_post_process = self.qconfig.activation() self.weight_fake_quant = self.qconfig.weight() if bias: diff --git a/test/test_fx.py b/test/test_fx.py index eadcd750aeded..04980217267bd 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -854,7 +854,7 @@ def __init__(self): self.mm_param = torch.nn.Parameter(torch.randn(d_hid, d_hid)) self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) self.lin = torch.nn.Linear(d_hid, d_hid) - self.register_buffer('buffer', torch.randn(bs + 100, d_hid)) + self.buffer = 
torch.nn.Buffer(torch.randn(bs + 100, d_hid)) def forward(self, x): x = torch.mm(x, self.mm_param) @@ -2703,7 +2703,7 @@ def getitem_inner(self): class GetItemBase(torch.nn.Module): def __init__(self): super().__init__() - self.register_buffer('pe', torch.randn(8, 8)) + self.pe = torch.nn.Buffer(torch.randn(8, 8)) class GetItem1(GetItemBase): def forward(self, x): @@ -3068,7 +3068,7 @@ class B(torch.nn.Module): def __init__(self): super().__init__() self.linear = torch.nn.Linear(100, 200) - self.register_buffer("buf", torch.randn(2, 3)) + self.buf = torch.nn.Buffer(torch.randn(2, 3)) self.net_c = C() def forward(self, x): @@ -3238,7 +3238,7 @@ class MockModule(torch.nn.Module): def __init__(self): super().__init__() self.l1 = torch.nn.Linear(1, 1) - self.register_buffer('buffer', torch.ones(1)) + self.buffer = torch.nn.Buffer(torch.ones(1)) def forward(self, x): return self.l1(x) + self.buffer diff --git a/test/test_fx_experimental.py b/test/test_fx_experimental.py index d3231df353134..28935287148ac 100644 --- a/test/test_fx_experimental.py +++ b/test/test_fx_experimental.py @@ -1248,8 +1248,8 @@ def __init__(self): self.seq = torch.nn.Sequential(torch.nn.BatchNorm1d(2, 2)) self.linear = torch.nn.Linear(2, 2) self.attr = torch.randn(2) - self.register_buffer("attr2", torch.randn(2)) - self.register_buffer("attr3", torch.ones(2, dtype=torch.int32)) + self.attr2 = torch.nn.Buffer(torch.randn(2)) + self.attr3 = torch.nn.Buffer(torch.ones(2, dtype=torch.int32)) def forward(self, x): return self.linear(self.seq(self.W + self.attr + self.attr2 + self.attr3 + x)) diff --git a/test/test_jit.py b/test/test_jit.py index bb6f4e2558885..c72b2bcd3805a 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -481,7 +481,7 @@ def test_restore_device_cuda(self): class MyModule(torch.jit.ScriptModule): def __init__(self): super().__init__() - self.register_buffer('b0', torch.randn(1, 3)) + self.b0 = nn.Buffer(torch.randn(1, 3)) self.p0 = nn.Parameter(torch.randn(2, 3)) 
@torch.jit.script_method @@ -537,7 +537,7 @@ def __init__(self): super().__init__() whole_tensor = torch.randn(4, 5, dtype=torch.float, device='cpu') self.p0 = nn.Parameter(whole_tensor.narrow(0, 0, 1)) - self.register_buffer('b0', whole_tensor.narrow(0, 3, 1)) + self.b0 = nn.Buffer(whole_tensor.narrow(0, 3, 1)) m = Foo() m2 = self.getExportImportCopy(m, map_location=torch.device('cuda:0')) @@ -3926,7 +3926,7 @@ def test_cpp_module_iterator(self): a.p = nn.Parameter(torch.rand(3, 4)) a.foo = nn.Module() a.foo.name = 'foo' - a.foo.register_buffer('b', torch.rand(1, 1)) + a.foo.b = nn.Buffer(torch.rand(1, 1)) a.foo.bar = nn.Module() a.foo.bar.name = 'bar' a.foo.bar.an_int = 4 @@ -8891,7 +8891,7 @@ def test_script_module_param_buffer_mutation(self): class ModuleBufferMutate(torch.jit.ScriptModule): def __init__(self): super().__init__() - self.register_buffer('running_var', torch.tensor(0, dtype=torch.long)) + self.running_var = nn.Buffer(torch.tensor(0, dtype=torch.long)) @torch.jit.script_method def forward(self): @@ -9018,12 +9018,12 @@ class DerivedStateModule(torch.jit.ScriptModule): def __init__(self): super(TestScript.DerivedStateModule, self).__init__() self.param = torch.nn.Parameter(torch.ones(3, 4, dtype=torch.float)) - self.register_buffer('derived', torch.neg(self.param).detach().clone()) + self.derived = nn.Buffer(torch.neg(self.param).detach().clone()) # This is a flag so we can test that the pack method was called - self.register_buffer('pack_called', torch.zeros(1, dtype=torch.long)) + self.pack_called = nn.Buffer(torch.zeros(1, dtype=torch.long)) # This is a flag so we can test that the unpack method was called - self.register_buffer('unpack_called', torch.zeros(1, dtype=torch.long)) + self.unpack_called = nn.Buffer(torch.zeros(1, dtype=torch.long)) @torch.jit.script_method def _pack(self): @@ -9203,7 +9203,7 @@ def test_pack_unpack_nested(self): class SubSubMod(torch.jit.ScriptModule): def __init__(self): super().__init__() - 
self.register_buffer('buf', torch.ones(3, 4) * 3) + self.buf = nn.Buffer(torch.ones(3, 4) * 3) @torch.jit.script_method def _pack(self): @@ -9220,7 +9220,7 @@ def forward(self, x): class SubMod(torch.jit.ScriptModule): def __init__(self): super().__init__() - self.register_buffer('buf', torch.ones(3, 4) * 2) + self.buf = nn.Buffer(torch.ones(3, 4) * 2) self.ssm = SubSubMod() @torch.jit.script_method @@ -9239,7 +9239,7 @@ class Mod(torch.jit.ScriptModule): def __init__(self): super().__init__() self.submod = SubMod() - self.register_buffer('buf', torch.ones(3, 4) * 1) + self.buf = nn.Buffer(torch.ones(3, 4) * 1) @torch.jit.script_method def _pack(self): @@ -13072,7 +13072,7 @@ def __init__(self, in_features, out_features): self.out_features = out_features self.weight = torch.nn.Parameter(torch.empty(out_features, in_features)) self.bias = torch.nn.Parameter(torch.empty(out_features)) - self.register_buffer('counter', torch.ones(out_features)) + self.counter = nn.Buffer(torch.ones(out_features)) self.reset_parameters() def reset_parameters(self): @@ -13125,7 +13125,7 @@ def __init__(self, in_features, out_features): super().__init__() self.weight = torch.nn.Parameter(torch.ones(out_features, in_features)) self.bias = torch.nn.Parameter(torch.ones(out_features)) - self.register_buffer("buffer", torch.ones(out_features)) + self.buffer = nn.Buffer(torch.ones(out_features)) self.submodule = Submodule() def forward(self, x): @@ -13580,8 +13580,8 @@ class Root(torch.jit.ScriptModule): def __init__(self, number): super().__init__() - self.register_buffer('buffer1', torch.ones(2, 2)) - self.register_buffer('buffer2', torch.ones(2, 2)) + self.buffer1 = nn.Buffer(torch.ones(2, 2)) + self.buffer2 = nn.Buffer(torch.ones(2, 2)) self.number = number @torch.jit.script_method @@ -13599,8 +13599,8 @@ class M(torch.jit.ScriptModule): def __init__(self, number, submodule): super().__init__() - self.register_buffer('buffer1', torch.ones(2, 2)) - self.register_buffer('buffer2', 
torch.ones(2, 2)) + self.buffer1 = nn.Buffer(torch.ones(2, 2)) + self.buffer2 = nn.Buffer(torch.ones(2, 2)) self.number = number self.submodule = submodule @@ -13636,8 +13636,8 @@ def __setstate__(self, state): class NoArgState(torch.nn.Module): def __init__(self): super().__init__() - self.register_buffer('buffer1', torch.ones(2, 2)) - self.register_buffer('buffer2', torch.ones(2, 2)) + self.buffer1 = nn.Buffer(torch.ones(2, 2)) + self.buffer2 = nn.Buffer(torch.ones(2, 2)) def forward(self): pass @@ -15088,7 +15088,7 @@ class M(torch.jit.ScriptModule): def __init__(self): super().__init__() tensor = torch.zeros(1, requires_grad=False) - self.register_buffer('some_state', torch.nn.Parameter(tensor)) + self.some_state = nn.Buffer(torch.nn.Parameter(tensor)) @torch.jit.script_method def forward(self, x): @@ -15481,8 +15481,8 @@ def __init__(self): self.mod = (torch.nn.ReLU()) self.mod2 = (torch.nn.ReLU()) self.mod3 = torch.nn.Sequential(torch.nn.Sequential(torch.nn.ReLU())) - self.register_buffer('x', torch.zeros(3)) - self.register_buffer('y', torch.zeros(3)) + self.x = nn.Buffer(torch.zeros(3)) + self.y = nn.Buffer(torch.zeros(3)) self.z = torch.zeros(3) def bleh(self): diff --git a/test/test_mps.py b/test/test_mps.py index 24c4e2d45e48e..04c0f173719c1 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -19,7 +19,7 @@ import itertools from collections import defaultdict from torch import inf -from torch.nn import Parameter +from torch.nn import Buffer, Parameter from torch.testing._internal import opinfo from torch.testing._internal.common_utils import \ (gradcheck, gradgradcheck, parametrize, run_tests, TestCase, download_file, IS_CI, @@ -8500,14 +8500,14 @@ class Layer(nn.Module): def __init__(self): super().__init__() self.layer_dummy_param = Parameter(torch.empty(3, 5)) - self.register_buffer('layer_dummy_buf', torch.zeros(1, 3, 3, 7)) + self.layer_dummy_buf = Buffer(torch.zeros(1, 3, 3, 7)) class Net(nn.Module): def __init__(self): super().__init__() 
self.l1 = Layer() self.dummy_param = Parameter(torch.empty(3, 5)) - self.register_buffer('dummy_buf', torch.zeros(7, 3, 3, 1)) + self.dummy_buf = Buffer(torch.zeros(7, 3, 3, 1)) l = Layer() n = Net() diff --git a/test/test_nn.py b/test/test_nn.py index 008354ad721eb..6976b601afe6d 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -26,7 +26,7 @@ from torch.nn.utils import parameters_to_vector, vector_to_parameters from torch.nn.utils.fusion import fuse_conv_bn_weights from torch.nn.utils.fusion import fuse_linear_bn_weights -from torch.nn import Parameter +from torch.nn import Buffer, Parameter from torch.nn.parallel._functions import Broadcast from torch.testing._internal.common_dtype import integral_types, get_all_math_dtypes, floating_types from torch.testing._internal.common_utils import freeze_rng_state, run_tests, TestCase, skipIfNoLapack, skipIfRocm, \ @@ -360,8 +360,8 @@ def names(named_buffers): class M(nn.Module): def __init__(self): super().__init__() - self.register_buffer("buffer1", torch.empty(3, 5)) - self.register_buffer("buffer2", self.buffer1) + self.buffer1 = Buffer(torch.empty(3, 5)) + self.buffer2 = self.buffer1 m = M() self.assertEqual(names(m.named_buffers()), @@ -420,7 +420,7 @@ def test_dir(self): linear = nn.Linear(2, 2) linear._test_submodule = nn.Linear(2, 2) linear._test_parameter = Parameter(torch.empty(2, 2)) - linear.register_buffer('_test_buffer', torch.empty(2, 2)) + linear._test_buffer = Buffer(torch.empty(2, 2)) keys = dir(linear) self.assertIn('_test_submodule', keys) self.assertIn('_test_parameter', keys) @@ -525,6 +525,9 @@ def test_register_buffer_raises_error_if_attr_exists(self): with self.assertRaises(KeyError): m.register_buffer('attribute_name', torch.rand(5)) + with self.assertRaises(KeyError): + m.attribute_name = Buffer(torch.rand(5)) + del m.attribute_name m.register_parameter('attribute_name', nn.Parameter()) with self.assertRaises(KeyError): @@ -551,12 +554,18 @@ def 
test_register_buffer_allows_overwriting_with_same_name(self): self.assertEqual(m.buffer_name, buffer2) m.register_buffer('buffer_name', buffer3) self.assertEqual(m.buffer_name, buffer3) + m.buffer_name = Buffer(buffer1) + self.assertEqual(m.buffer_name, Buffer(buffer1)) + m.buffer_name = Buffer(buffer2) + self.assertEqual(m.buffer_name, Buffer(buffer2)) + m.buffer_name = Buffer(buffer3) + self.assertEqual(m.buffer_name, Buffer(buffer3)) def test_get_buffer(self): m = nn.Module() buffer1 = torch.randn(2, 3) buffer2 = torch.randn(4, 5) - m.register_buffer('foo', buffer1) + m.foo = Buffer(buffer1) m.register_buffer('bar', buffer2) self.assertEqual(buffer1, m.get_buffer('foo')) self.assertEqual(buffer2, m.get_buffer('bar')) @@ -570,13 +579,13 @@ def __init__(self, foo, bar): class Sub(nn.Module): def __init__(self, foo, bar): super().__init__() - self.register_buffer('foo', foo) + self.foo = Buffer(foo) self.subsub = SubSub(bar) class SubSub(nn.Module): def __init__(self, bar): super().__init__() - self.register_buffer('bar', bar) + self.bar = Buffer(bar) foo = torch.randn(2, 3) bar = torch.randn(4, 5) @@ -586,33 +595,35 @@ def __init__(self, bar): def test_buffer_not_persistent(self): m = nn.Module() - m.register_buffer('buf', torch.rand(5), persistent=False) + m.buf = nn.Buffer(torch.rand(5), persistent=False) self.assertTrue(len(list(m.buffers())) == 1) self.assertTrue(len(m.state_dict()) == 0) def test_buffer_not_persistent_del(self): m = nn.Module() - m.register_buffer('buf', torch.rand(5), persistent=False) + m.buf = nn.Buffer(torch.rand(5), persistent=False) del m.buf self.assertTrue(len(list(m.buffers())) == 0) def test_buffer_not_persistent_overwrite(self): m = nn.Module() - m.register_buffer('buf', torch.rand(5), persistent=False) - m.register_buffer('buf', torch.rand(5)) + m.buf = nn.Buffer(torch.rand(5), persistent=False) + m.buf = nn.Buffer(torch.rand(5)) # can we overwrite a non-persistent buffer with a persistent one? 
self.assertTrue(len(list(m.buffers())) == 1) self.assertTrue(len(m.state_dict()) == 1) # can we overwrite a persistent buffer with a non-persistent one? - m.register_buffer('buf', torch.rand(5), persistent=False) + m.buf = nn.Buffer(torch.rand(5), persistent=False) self.assertTrue(len(list(m.buffers())) == 1) self.assertTrue(len(m.state_dict()) == 0) def test_buffer_not_persistent_assign(self): m = nn.Module() - m.register_buffer('buf', torch.rand(5), persistent=False) + m.buf = nn.Buffer(torch.rand(5), persistent=False) + self.assertTrue(len(list(m.buffers())) == 1) + self.assertTrue(len(m.state_dict()) == 0) # Assigning None removes the buffer but if we then assign a new Tensor # to the same property, it should still be marked as a buffer. @@ -630,7 +641,7 @@ def test_buffer_not_persistent_assign(self): def test_buffer_not_persistent_load(self): m = nn.Module() - m.register_buffer('buf', torch.rand(5), persistent=False) + m.buf = nn.Buffer(torch.rand(5), persistent=False) m.load_state_dict({}) def test_register_parameter_raises_error_if_name_is_not_string(self): @@ -652,6 +663,11 @@ def test_register_parameter_raises_error_if_attr_exists(self): with self.assertRaises(KeyError): m.register_parameter('attribute_name', nn.Parameter()) + del m.attribute_name + m.attribute_name = Buffer(torch.rand(5)) + with self.assertRaises(KeyError): + m.register_parameter('attribute_name', nn.Parameter()) + del m.attribute_name m.add_module('attribute_name', nn.Module()) with self.assertRaises(KeyError): @@ -1616,7 +1632,7 @@ def test_type(self): net.l = l net.l2 = l net.add_module('empty', None) - net.register_buffer('indices', torch.LongTensor(1)) + net.indices = Buffer(torch.LongTensor(1)) net.float() self.assertIsInstance(l.weight.data, torch.FloatTensor) self.assertIsInstance(l.bias.data, torch.FloatTensor) @@ -2493,8 +2509,8 @@ def test_assignments(get_list, a, b, c): del l.a, l.b self.assertEqual(list(l.children()), []) - buf = torch.randn(10) - l.register_buffer('buf', 
buf) + buf = Buffer(torch.randn(10)) + l.buf = buf self.assertIs(l.buf, buf) l.buf = None self.assertIs(l.buf, None) diff --git a/test/test_stateless.py b/test/test_stateless.py index 32ec45937059f..472fb56dff236 100644 --- a/test/test_stateless.py +++ b/test/test_stateless.py @@ -18,7 +18,7 @@ class MockModule(torch.nn.Module): def __init__(self): super().__init__() self.l1 = torch.nn.Linear(1, 1) - self.register_buffer('buffer', torch.ones(1)) + self.buffer = torch.nn.Buffer(torch.ones(1)) self.foo = 0.0 def forward(self, x): @@ -30,8 +30,8 @@ def __init__(self): super().__init__() self.l1 = torch.nn.Linear(1, 1) self.tied_bias = self.l1.bias - self.register_buffer('buffer', torch.ones(1)) - self.register_buffer('tied_buffer', self.buffer) + self.buffer = torch.nn.Buffer(torch.ones(1)) + self.tied_buffer = self.buffer def forward(self, x): return self.l1(x) + self.tied_bias + self.buffer + self.tied_buffer @@ -427,7 +427,7 @@ def __repr__(self): def test_tied_weights_warns(self, functional_call): module = MockModule() module.tied_bias = module.l1.bias - module.register_buffer("tied_buffer", module.buffer) + module.tied_buffer = torch.nn.Buffer(module.buffer) @parametrize("functional_call", [ subtest(torch.func.functional_call, "torch_func"), @@ -632,7 +632,7 @@ def test_setattr(self, functional_call): class Foo(torch.nn.Module): def __init__(self): super().__init__() - self.register_buffer('foo', torch.tensor([0.0])) + self.foo = torch.nn.Buffer(torch.tensor([0.0])) def forward(self, x): self.foo = self.foo + 1 @@ -656,7 +656,7 @@ def test_in_place_operator(self, functional_call): class Foo(torch.nn.Module): def __init__(self): super().__init__() - self.register_buffer('foo', torch.tensor([0.0])) + self.foo = torch.nn.Buffer(torch.tensor([0.0])) def forward(self, x): self.foo.add_(1) @@ -778,7 +778,7 @@ class Module(torch.nn.Module): def __init__(self): super().__init__() self.l1 = torch.nn.Linear(1, 1) - self.register_buffer('buffer', torch.ones(1)) + 
self.buffer = torch.nn.Buffer(torch.ones(1)) def forward(self, x): parameters = tuple(self.parameters()) diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index ff9438085c529..d842db9ffbcd5 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -609,6 +609,7 @@ def istensor(obj): """Check of obj is a tensor""" tensor_list = ( torch.Tensor, + torch.nn.Buffer, torch.nn.Parameter, *config.traceable_tensor_subclasses, ) diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 575ccfa53f8d2..4f387314dd31d 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -357,6 +357,7 @@ def _type_dispatch(cls): ( ( torch.Tensor, + torch.nn.Buffer, torch.nn.Parameter, torch._subclasses.FakeTensor, torch._subclasses.functional_tensor.FunctionalTensor, @@ -1228,6 +1229,7 @@ def wrap_tensor(self, value: torch.Tensor): else: assert type(value) in ( torch.Tensor, + torch.nn.Buffer, torch.nn.Parameter, torch._subclasses.fake_tensor.FakeTensor, torch._subclasses.functional_tensor.FunctionalTensor, @@ -2115,7 +2117,7 @@ def wrap_to_fake_tensor_and_record( e, tx, *, source: Optional[Source], is_tensor: bool, parent_context=None ): if ( - type(e) in (torch.Tensor, torch.nn.Parameter, FakeTensor) + type(e) in (torch.Tensor, torch.nn.Buffer, torch.nn.Parameter, FakeTensor) or isinstance(e, torch.Tensor) or is_traceable_wrapper_subclass(e) ): diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index efeeb446fae1f..c30a924d75a02 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -1957,6 +1957,7 @@ def _check_for_subclass_arg(x): and isinstance(x, torch.Tensor) and type(x) is not torch.Tensor and type(x) is not torch.nn.Parameter + and type(x) is not torch.nn.Buffer ) diff --git a/torch/_utils.py b/torch/_utils.py index 2e48fe9a1a9de..080f1379ff26b 100644 --- a/torch/_utils.py +++ b/torch/_utils.py @@ -417,6 +417,17 @@ def 
_rebuild_qtensor( return tensor +def _rebuild_buffer(data, requires_grad, persistent): + buffer = torch.nn.Buffer(data, requires_grad, persistent) + return buffer + + +def _rebuild_buffer_with_state(data, requires_grad, persistent, state): + buffer = torch.nn.Buffer(data, requires_grad, persistent) + buffer = _set_obj_state(buffer, state) + return buffer + + def _rebuild_parameter(data, requires_grad, backward_hooks): param = torch.nn.Parameter(data, requires_grad) # NB: This line exists only for backwards compatibility; the diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index 07905b0348473..65aaa398de4ca 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -297,7 +297,7 @@ def inner(e): def fetch_object_proxy(tracer): return lambda t: get_proxy_slot(t, tracer, t) -HANDLED_TYPES = (torch.Tensor, torch.nn.Parameter, FakeTensor) +HANDLED_TYPES = (torch.Tensor, torch.nn.Parameter, torch.nn.Buffer, FakeTensor) def proxy_call(proxy_mode, func, pre_dispatch, args, kwargs): unrecognized_types = [] diff --git a/torch/nn/__init__.py b/torch/nn/__init__.py index 3d317b7c09f20..b1b898b172fec 100644 --- a/torch/nn/__init__.py +++ b/torch/nn/__init__.py @@ -1,5 +1,6 @@ from .modules import * # noqa: F403 from .parameter import ( + Buffer as Buffer, Parameter as Parameter, UninitializedParameter as UninitializedParameter, UninitializedBuffer as UninitializedBuffer, diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index 73420c0f32e7b..fae62f001e41d 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -6,7 +6,7 @@ import torch from torch._prims_common import DeviceLikeType -from ..parameter import Parameter +from ..parameter import Parameter, Buffer import torch.utils.hooks as hooks from torch import Tensor, device, dtype @@ -1753,16 +1753,16 @@ def remove_from(*dicts_or_sets): modules[name] = value else: buffers = self.__dict__.get('_buffers') - if 
buffers is not None and name in buffers: + if isinstance(value, Buffer) or buffers is not None and name in buffers: if value is not None and not isinstance(value, torch.Tensor): - raise TypeError(f"cannot assign '{torch.typename(value)}' as buffer '{name}' " - "(torch.Tensor or None expected)" - ) - for hook in _global_buffer_registration_hooks.values(): - output = hook(self, name, value) - if output is not None: - value = output - buffers[name] = value + raise TypeError("cannot assign '{}' as buffer '{}' " + "(torch.nn.Buffer, torch.Tensor or None expected)" + .format(torch.typename(value), name)) + if isinstance(value, Buffer): + persistent = value.persistent + else: + persistent = name not in self._non_persistent_buffers_set + self.register_buffer(name, value, persistent) else: super().__setattr__(name, value) diff --git a/torch/nn/parameter.py b/torch/nn/parameter.py index 43d4f1cf40008..cf2dae283ccab 100644 --- a/torch/nn/parameter.py +++ b/torch/nn/parameter.py @@ -199,6 +199,74 @@ def __deepcopy__(self, memo): memo[id(self)] = result return result +# Metaclass to combine _TensorMeta and the instance check override for Buffer. +class _BufferMeta(torch._C._TensorMeta): + # Make `isinstance(t, Buffer)` return True for custom tensor instances that have the _is_buffer flag. + def __instancecheck__(self, instance): + return isinstance(instance, torch.Tensor) and getattr(instance, '_is_buffer', False) + + +class Buffer(torch.Tensor, metaclass=_BufferMeta): + r"""A kind of Tensor that should not be considered a model + parameter. For example, BatchNorm's ``running_mean`` is not a parameter, but is part of the module's state. + + Buffers are :class:`~torch.Tensor` subclasses, that have a + very special property when used with :class:`Module` s - when they're + assigned as Module attributes they are automatically added to the list of + its buffers, and will appear e.g. in :meth:`~Module.buffers` iterator. + Assigning a Tensor doesn't have such effect. 
One can still assign a Tensor as a buffer explicitly by using + the module's :meth:`~torch.nn.Module.register_buffer` function. + + Args: + data (Tensor): buffer tensor. + requires_grad (bool, optional): if the buffer requires gradient. + Default: `False` + persistent (bool, optional): whether the buffer is part of the module's + :attr:`state_dict`. Default: `True` + """ + def __new__(cls, data=None, requires_grad=False, persistent=True): + if data is None: + data = torch.empty(0) + + # Path for custom tensors: set a flag on the instance to indicate buffer-ness. + t = data.detach().requires_grad_(requires_grad) + if type(t) is not type(data) and not isinstance(data, Parameter): + raise RuntimeError(f"Creating a Buffer from an instance of type {type(data).__name__} " + "requires that detach() returns an instance of the same type, but return " + f"type {type(t).__name__} was found instead. To use the type as a " + "Buffer, please correct the detach() semantics defined by " + "its __torch_dispatch__() implementation.") + t.persistent = persistent + t._is_buffer = True + return t + + def __deepcopy__(self, memo): + if id(self) in memo: + return memo[id(self)] + else: + result = type(self)(self.data.clone(memory_format=torch.preserve_format), self.requires_grad, self.persistent) + memo[id(self)] = result + return result + + def __repr__(self): + return 'Buffer containing:\n' + super().__repr__() + + def __reduce_ex__(self, proto): + state = torch._utils._get_obj_state(self) + + if not state: + return ( + torch._utils._rebuild_buffer, + (self.data, self.requires_grad, self.persistent) + ) + + return ( + torch._utils._rebuild_buffer_with_state, + (self.data, self.requires_grad, self.persistent, state) + ) + + __torch_function__ = _disabled_torch_function_impl + class UninitializedBuffer(UninitializedTensorMixin, torch.Tensor): r"""A buffer that is not initialized. 
@@ -217,7 +285,10 @@ class UninitializedBuffer(UninitializedTensorMixin, torch.Tensor): cls_to_become = torch.Tensor - def __new__(cls, requires_grad=False, device=None, dtype=None) -> None: + def __new__(cls, requires_grad=False, device=None, dtype=None, persistent=True) -> None: factory_kwargs = {'device': device, 'dtype': dtype} data = torch.empty(0, **factory_kwargs) - return torch.Tensor._make_subclass(cls, data, requires_grad) + ret = torch.Tensor._make_subclass(cls, data, requires_grad) + ret.persistent = persistent + ret._is_buffer = True + return ret diff --git a/torch/nn/parameter.pyi b/torch/nn/parameter.pyi index 219bb6d4efa2a..9ef33149fadb3 100644 --- a/torch/nn/parameter.pyi +++ b/torch/nn/parameter.pyi @@ -26,11 +26,22 @@ class UninitializedParameter(Tensor): dtype: Optional[torch.dtype] = None, ): ... +class Buffer(Tensor): + persistent: builtins.bool + def __init__( + self, + data: Tensor = ..., + requires_grad: builtins.bool = ..., + persistent: builtins.bool = ..., + ): ... + class UninitializedBuffer(Tensor): + persistent: builtins.bool def __init__( self, data: Tensor = ..., requires_grad: builtins.bool = ..., + persistent: builtins.bool = ..., ): ... def materialize( self, diff --git a/torch/testing/_internal/common_nn.py b/torch/testing/_internal/common_nn.py index c11314721f27c..177bd7a4745ac 100644 --- a/torch/testing/_internal/common_nn.py +++ b/torch/testing/_internal/common_nn.py @@ -3970,14 +3970,14 @@ class Layer(nn.Module): def __init__(self): super().__init__() self.layer_dummy_param = nn.Parameter(torch.empty(3, 5)) - self.register_buffer('layer_dummy_buf', torch.zeros(1, 3, 3, 7)) + self.layer_dummy_buf = nn.Buffer(torch.zeros(1, 3, 3, 7)) class Net(nn.Module): def __init__(self): super().__init__() self.l1 = Layer() self.dummy_param = nn.Parameter(torch.empty(3, 5)) - self.register_buffer('dummy_buf', torch.zeros(7, 3, 3, 1)) + self.dummy_buf = nn.Buffer(torch.zeros(7, 3, 3, 1)) l = Layer() n = Net()