Make adding Buffers more like adding Parameters
Add semantics for creating a buffer object that mirror those for creating a parameter. This is done by introducing a new Buffer class that can be used for type disambiguation. The underlying functionality of registering a buffer remains the same, since the register_buffer method has not been changed. The persistent argument on the Buffer type indicates whether the buffer should be persistent or not. The other non-test changes make the new Buffer type recognized by inductor and dynamo. The remaining changes update tests to verify that the Buffer type can be used as a drop-in replacement for register_buffer, as assigning a Buffer simply results in register_buffer being called. Plain tensors can still be used as buffers, so these changes are intended to be backwards compatible.

Fixes pytorch#35735
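
For reference, a minimal sketch of the usage pattern this commit enables. The `StatsModule` class and its buffer names below are hypothetical examples, not part of the diff; the sketch assumes the `Buffer(data, persistent=True)` behaviour described above, where assigning an `nn.Buffer` to a module attribute routes through `register_buffer`.

```python
# Sketch only: requires a PyTorch build that includes torch.nn.Buffer (this commit).
import torch
import torch.nn as nn


class StatsModule(nn.Module):
    def __init__(self):
        super().__init__()
        # Existing style: register a plain tensor by name (still supported).
        self.register_buffer("count", torch.zeros(1))
        # New style: assign an nn.Buffer; Module.__setattr__ ends up calling
        # register_buffer() with the attribute name.
        self.mean = nn.Buffer(torch.zeros(10))
        # Non-persistent buffers are kept out of the state_dict.
        self.scratch = nn.Buffer(torch.zeros(10), persistent=False)


m = StatsModule()
print(sorted(name for name, _ in m.named_buffers()))  # ['count', 'mean', 'scratch']
print(sorted(m.state_dict().keys()))                  # ['count', 'mean']
```

Since both spellings should go through the same registration machinery, `named_buffers()`, `state_dict()`, and module moves such as `.to()` are expected to treat buffers created either way identically.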
ekamiti committed May 10, 2024
1 parent e43d656 commit 533c3c8
Showing 39 changed files with 263 additions and 151 deletions.
1 change: 1 addition & 0 deletions docs/source/nn.rst
@@ -22,6 +22,7 @@ These are the basic building blocks for graphs:
:nosignatures:
:template: classtemplate.rst

~parameter.Buffer
~parameter.Parameter
~parameter.UninitializedParameter
~parameter.UninitializedBuffer
2 changes: 1 addition & 1 deletion test/distributed/fsdp/test_fsdp_flatten_params.py
@@ -59,7 +59,7 @@ def _get_transformer(self, seed=0):
dim_feedforward=128,
dropout=0.1,
)
module.register_buffer("dummy_buffer", torch.tensor(1.0))
module.dummy_buffer = nn.Buffer(torch.tensor(1.0))

def get_input(device, dtype):
torch.manual_seed(1) # keep everything deterministic
10 changes: 5 additions & 5 deletions test/distributed/fsdp/test_fsdp_misc.py
@@ -594,11 +594,11 @@ def init_nested_wrapped_module():

# Check that `device_id` with `sync_module_states=True` works
nested_wrapped_module = init_nested_wrapped_module()
nested_wrapped_module.register_buffer(
"buf", torch.ones((2, 2), device="cpu") * self.rank
nested_wrapped_module.buf = nn.Buffer(
torch.ones((2, 2), device="cpu") * self.rank
)
nested_wrapped_module.module[0].register_buffer(
"buf", torch.ones((3, 2), device="cpu") * self.rank
nested_wrapped_module.module[0].buf = nn.Buffer(
torch.ones((3, 2), device="cpu") * self.rank
)
nested_wrapped_module = FSDP(
nested_wrapped_module,
@@ -892,7 +892,7 @@ def __init__(self, rank):
torch.manual_seed(rank)
torch.cuda.manual_seed(rank)
self.lin = nn.Linear(10, 10, bias=False)
self.register_buffer("buffer", torch.ones(1) * rank)
self.buffer = nn.Buffer(torch.ones(1) * rank)

m = MyModel(self.rank).cuda()
_assert_module_states(
4 changes: 2 additions & 2 deletions test/distributed/fsdp/test_fsdp_state_dict.py
@@ -102,7 +102,7 @@ def __init__(
super().__init__()
self.inner = Linear(*INNER_SHAPE)
if register_buffers:
self.inner.register_buffer("buffer", torch.randn(BUFFER_SHAPE))
self.inner.buffer = nn.Buffer(torch.randn(BUFFER_SHAPE))
self.inner.register_buffer(
"non_persistent_buffer", torch.randn(BUFFER_SHAPE), persistent=False
)
@@ -121,7 +121,7 @@ def __init__(
)
self.outer = Linear(*OUTER_SHAPE)
if register_buffers:
self.outer.register_buffer("buffer", torch.randn(BUFFER_SHAPE))
self.outer.buffer = nn.Buffer(torch.randn(BUFFER_SHAPE))
self.outer.register_buffer(
"non_persistent_buffer", torch.randn(BUFFER_SHAPE), persistent=False
)
4 changes: 2 additions & 2 deletions test/distributed/fsdp/test_fsdp_unshard_params.py
@@ -436,7 +436,7 @@ def _test_named_parameters_and_buffers(self, prefix: str, recurse: bool):
CUDAInitMode.CUDA_BEFORE,
deterministic=True,
)
model.register_buffer("buffer", torch.ones(1))
model.buffer = nn.Buffer(torch.ones(1))
# Wrap the top-level with FSDP since `named_parameters()` and
# `named_buffers` will contain FSDP prefixes if called on a non-FSDP
# root module
@@ -449,7 +449,7 @@ def _test_named_parameters_and_buffers(
),
self.process_group,
)
fsdp_model.register_buffer("buffer", torch.ones(1))
fsdp_model.buffer = nn.Buffer(torch.ones(1))
with FSDP.summon_full_params(fsdp_model):
for call in ["named_parameters", "named_buffers"]:
for (n1, p1), (n2, p2) in itertools.zip_longest(
3 changes: 1 addition & 2 deletions test/distributed/optim/test_zero_redundancy_optimizer.py
@@ -857,8 +857,7 @@ def test_local_optimizer_parity(
torch.nn.Linear(HIDDEN_DIM, HIDDEN_DIM),
torch.nn.Linear(HIDDEN_DIM, OUTPUT_DIM),
).to(self.device)
model.register_buffer(
"test_buffer",
model.test_buffer = torch.nn.Buffer(
torch.ones((1), device=self.device) * self.rank,
)
# Define models/optimizers for DDP with ZeRO and DDP with local
4 changes: 2 additions & 2 deletions test/distributed/test_data_parallel.py
@@ -43,8 +43,8 @@ def test_data_parallel_buffers_requiring_grad(self):
class TestModule(nn.Module):
def __init__(self, t):
super().__init__()
self.register_buffer("t_rg", t)
self.register_buffer("t_not_rg", t.clone().detach())
self.t_rg = nn.Buffer(t, t.requires_grad)
self.t_not_rg = nn.Buffer(t.clone().detach())

def forward(self, x):
return x * self.t_rg + self.t_not_rg
11 changes: 5 additions & 6 deletions test/distributed/test_dynamo_distributed.py
@@ -1159,9 +1159,8 @@ class DuplicateModule(nn.Module):
def __init__(self) -> None:
super().__init__()
self._param = torch.randn((3,), device="cuda")
self.register_buffer(
"_buf", torch.randn((3,), requires_grad=False, device="cuda")
)
self._buf = torch.nn.Buffer(
torch.randn((3,), requires_grad=False, device="cuda"))

def forward(self, x: torch.Tensor) -> torch.Tensor:
# Use `_param` and `_buf` each twice in this compiled forward
@@ -1190,8 +1189,8 @@ def test_fsdp_dup_tensors_diff_source(self):
class BufModule(nn.Module):
def __init__(self) -> None:
super().__init__()
self.register_buffer(
"_buf", torch.randn((3,), requires_grad=False, device="cuda")
self._buf = nn.Buffer(
torch.randn((3,), requires_grad=False, device="cuda")
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -1203,7 +1202,7 @@ def __init__(self) -> None:
self._param = nn.Parameter(torch.randn((1,), device="cuda"))
self._buf_module = BufModule()
# Share the buffer, meaning same tensor but different source
self.register_buffer("_buf", self._buf_module._buf)
self._buf = self._buf_module._buf

def forward(self, x: torch.Tensor) -> torch.Tensor:
# Use the same buffer tensor twice in the compiled forward,
4 changes: 2 additions & 2 deletions test/dynamo/test_export.py
@@ -990,7 +990,7 @@ class MyBlock(torch.nn.Module):
def __init__(self):
super().__init__()
self.weight = torch.nn.Parameter(torch.ones(1, 1))
self.register_buffer("buffer", torch.ones(1, 1))
self.buffer = torch.nn.Buffer(torch.ones(1, 1))

def forward(self, x):
x = torch.nn.functional.linear(x, torch.randn(4, 4))
@@ -3013,7 +3013,7 @@ def test_not_functionalize(self):
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("buffer1", torch.ones(6, 2))
self.buffer1 = torch.nn.Buffer(torch.ones(6, 2))

def forward(self, x):
x.add_(2)
6 changes: 3 additions & 3 deletions test/dynamo/test_modules.py
@@ -1589,7 +1589,7 @@ def __init__(self):
super().__init__()
self.relu = torch.nn.ReLU()
self.linear = torch.nn.Linear(10, 10)
self.register_buffer("buf0", torch.randn(10, 10))
self.buf0 = torch.nn.Buffer(torch.randn(10, 10))

def forward(self, x):
return self.relu(self.linear(x) + self.buf0)
@@ -1802,7 +1802,7 @@ class MockModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(10, 10)
self.register_buffer("buf0", torch.randn(10, 10))
self.buf0 = torch.nn.Buffer(torch.randn(10, 10))

def forward(self, x):
return self.r(torch.sin(x)) + self.buf0
@@ -1829,7 +1829,7 @@ class MockModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(10, 10)
self.register_buffer("buf0", torch.randn(10, 10))
self.register_buffer("buf0", torch.nn.Buffer(torch.randn(10, 10)))
self.register_parameter(
name="param0", param=torch.nn.Parameter(torch.randn(10, 10))
)
10 changes: 5 additions & 5 deletions test/dynamo/test_repros.py
@@ -2055,8 +2055,8 @@ def test_sort_out2(self):
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("sorted", torch.ones(4, 4))
self.register_buffer("indices", torch.ones(4, 4, dtype=torch.long))
self.sorted = torch.nn.Buffer(torch.ones(4, 4))
self.indices = torch.nn.Buffer(torch.ones(4, 4, dtype=torch.long))

def forward(self, x):
torch.sort(x, out=(self.sorted, self.indices))
@@ -2087,7 +2087,7 @@ def test_sigmoid_out2(self):
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("base", torch.ones(4, 4))
self.base = torch.nn.Buffer(torch.ones(4, 4))

def forward(self, x):
torch.sigmoid(x, out=self.base)
@@ -2395,8 +2395,8 @@ def test_named_buffers(self):
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("x", torch.ones(3))
self.register_buffer("y", torch.ones(3))
self.x = torch.nn.Buffer(torch.ones(3))
self.y = torch.nn.Buffer(torch.ones(3))

def forward(self, inp):
res = 0
8 changes: 4 additions & 4 deletions test/export/test_verifier.py
@@ -166,8 +166,8 @@ def __init__(self):

self.my_parameter = torch.nn.Parameter(torch.tensor(2.0))

self.register_buffer("my_buffer1", torch.tensor(3.0))
self.register_buffer("my_buffer2", torch.tensor(4.0))
self.my_buffer1 = torch.nn.Buffer(torch.tensor(3.0))
self.my_buffer2 = torch.nn.Buffer(torch.tensor(4.0))

def forward(self, x1, x2):
# Use the parameter, buffers, and both inputs in the forward method
@@ -189,8 +189,8 @@ def __init__(self):

self.my_parameter = torch.nn.Parameter(torch.tensor(2.0))

self.register_buffer("my_buffer1", torch.tensor(3.0))
self.register_buffer("my_buffer2", torch.tensor(4.0))
self.my_buffer1 = torch.nn.Buffer(torch.tensor(3.0))
self.my_buffer2 = torch.nn.Buffer(torch.tensor(4.0))

def forward(self, x1, x2):
# Use the parameter, buffers, and both inputs in the forward method
6 changes: 3 additions & 3 deletions test/functorch/test_aotdispatch.py
@@ -3225,7 +3225,7 @@ def test_real_weights_in_symbolic_mode_with_inplace_ops(self):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("buffer", torch.ones(4, 5))
self.buffer = torch.nn.Buffer(torch.ones(4, 5))

def forward(self, x):
y = self.buffer.add_(3)
@@ -4121,7 +4121,7 @@ def test_aot_export_forward_mutation_no_buffer_mut(self):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("buffer1", torch.ones(6, 4))
self.buffer1 = torch.nn.Buffer(torch.ones(6, 4))

def forward(self, x):
x.add_(4)
@@ -4147,7 +4147,7 @@ def test_aot_export_forward_mutation_multiple_mut(self):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("buffer1", torch.ones(6, 4))
self.buffer1 = torch.nn.Buffer(torch.ones(6, 4))

def forward(self, x, y):
y.add_(4)
12 changes: 6 additions & 6 deletions test/functorch/test_eager_transforms.py
@@ -3688,8 +3688,8 @@ def __init__(self):
super().__init__()
self.bias = nn.Parameter(torch.randn(3))
self.linear = nn.Linear(3, 3)
self.register_buffer("buffer", torch.randn(3))
self.register_buffer("buffer_tied", self.buffer)
self.buffer = nn.Buffer(torch.randn(3))
self.buffer_tied = self.buffer

def forward(self, x):
x = self.linear(x)
@@ -3719,7 +3719,7 @@ class Foo(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(3, 3)
self.register_buffer("buffer", torch.randn(3))
self.buffer = nn.Buffer(torch.randn(3))

def forward(self, x):
x = self.linear(x)
@@ -3741,7 +3741,7 @@ class Foo(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(3, 3)
self.register_buffer("buffer", torch.randn(3))
self.buffer = nn.Buffer(torch.randn(3))

def forward(self, x):
x = self.linear(x)
@@ -3801,8 +3801,8 @@ def __init__(self):
self.linear = nn.Linear(3, 3)
self.weight = self.linear.weight
self.bias = self.linear.bias
self.register_buffer("buffer", torch.randn(3))
self.register_buffer("buffer_tied", self.buffer)
self.buffer = nn.Buffer(torch.randn(3))
self.buffer_tied = self.buffer

def forward(self, x):
x = self.linear(x)
3 changes: 1 addition & 2 deletions test/inductor/test_cuda_repro.py
@@ -581,8 +581,7 @@ def __init__(self):
start = math.log2(0.5)
end = math.log2(1 / (2**8))

self.register_buffer(
"scales",
self.scales = nn.Buffer(
2
** torch.arange(
start,
4 changes: 1 addition & 3 deletions test/inductor/test_torchinductor.py
@@ -8403,9 +8403,7 @@ def fn(x, p1, p0):
class Repro(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer(
"_tensor_constant0", torch.randn([], dtype=torch.float32)
)
self._tensor_constant0 = nn.Buffer(torch.randn([], dtype=torch.float32))

def forward(self, arg0_1, arg1_1):
convert_element_type = torch.ops.prims.convert_element_type.default(
4 changes: 3 additions & 1 deletion test/jit/test_save_load.py
@@ -487,6 +487,7 @@ def __init__(self):

self.parameter_b = torch.nn.Parameter(torch.randn(4))
self.submodule_b = Submodule()
self.buffer_b = torch.nn.Buffer(torch.randn(4))

m = TestModule()
m_loaded = self.getExportImportCopy(torch.jit.script(m))
@@ -526,7 +527,7 @@ def __init__(self):
super().__init__()
self.foo = torch.nn.Linear(2, 3, device="meta")
self.bar = torch.nn.Linear(3, 4)
self.register_buffer("buffer", torch.randn(4, device="meta"))
self.buffer = torch.nn.Buffer(torch.randn(4, device="meta"))

def forward(self, x):
x = self.foo(x)
@@ -1145,6 +1146,7 @@ def __init__(self):

self.parameter_b = torch.nn.Parameter(torch.randn(4))
self.submodule_b = Submodule()
self.buffer_b = torch.nn.Buffer(torch.randn(4))

m = TestModule()
m_loaded = self.getExportImportCopy(torch.jit.script(m))
