From 0ae331affe5665ec197ec8fcb8476d41e084239b Mon Sep 17 00:00:00 2001
From: Lucas Nestler <39779310+ClashLuke@users.noreply.github.com>
Date: Mon, 1 Nov 2021 10:35:39 -0700
Subject: [PATCH] [feat] add split_dim arg to reversible, remove retain_grad,
 add benchmark_reversible (#45)

* perf(revnet): remove/fuse cat-split

* style(reversible): run pre-commit, remove unused variables

* feat(benchmark): add basic revnet time benchmark

* style(benchmark): run black over revnet bench

* feat(reversible): make split dim a parameter

* revert(reversible): re-add original code/remove split fusion

* feat(reversible): re-add split dim arg

* perf(reversible): remove retain graph

  It's not needed, and it makes execution up to 20% slower.

* style(reversible): remove one singular newline

  seriously?
---
 xformers/benchmarks/benchmark_revnet.py | 83 +++++++++++++++++++++++++
 xformers/components/reversible.py       | 15 ++---
 2 files changed, 91 insertions(+), 7 deletions(-)
 create mode 100644 xformers/benchmarks/benchmark_revnet.py

diff --git a/xformers/benchmarks/benchmark_revnet.py b/xformers/benchmarks/benchmark_revnet.py
new file mode 100644
index 000000000..512a861b2
--- /dev/null
+++ b/xformers/benchmarks/benchmark_revnet.py
@@ -0,0 +1,83 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from typing import Any, Dict
+
+import torch
+import triton
+
+from xformers.benchmarks.utils import TestCase, pretty_plot, pretty_print
+from xformers.components.reversible import ReversibleSequence
+
+SHAPES = [(16384, 32), (2048, 256), (128, 4096)]
+
+DEPTH = [4, 32, 256]
+
+
+def bench_revnet(backward: bool):
+    device = torch.device("cuda")
+    bw = "+bw" if backward else ""
+
+    for dtype in [torch.float16, torch.float32]:
+        results: Dict[str, Any] = {}
+
+        for B, K in SHAPES:
+            for depth in DEPTH:
+                f = torch.nn.Linear(K, K).to(device=device, dtype=dtype)
+                g = torch.nn.Linear(K, K).to(device=device, dtype=dtype)
+                revseq = ReversibleSequence(
+                    torch.nn.ModuleList([torch.nn.ModuleList([f, g])] * depth)
+                )
+                revseq = revseq.to(device=device, dtype=dtype)
+
+                a = torch.rand(
+                    1, B, K, device=device, dtype=dtype, requires_grad=backward
+                )
+                b = torch.rand(
+                    1, B, K * 2, device=device, dtype=dtype, requires_grad=backward
+                )
+
+                def normal_step():
+                    y = a
+                    for _ in range(depth):
+                        y = y + f(y)
+                        y = y + g(y)
+                    if backward:
+                        torch.norm(y).backward()
+                    return y
+
+                def reversible_step():
+                    y = revseq(b)
+                    if backward:
+                        torch.norm(y).backward()
+                    return y
+
+                for testcase in [
+                    TestCase(normal_step, f"residual - fw{bw}"),
+                    TestCase(reversible_step, f"reversible - fw{bw}"),
+                ]:
+                    time = triton.testing.do_bench(testcase.function)[0]
+                    key = f"Batch={B}, Features={K}, Depth={depth}"
+                    if key not in results:
+                        results[key] = {}
+
+                    results[key][testcase.name] = f"{time:.2f}"
+
+        pretty_print(
+            results,
+            title=f"\n --- Type: {dtype} --- ",
+            units="runtime in ms, lower is better",
+        )
+        pretty_plot(
+            results,
+            title=f"RevNet-FW{bw}-{dtype}",
+            units="runtime in ms, lower is better",
+            dash_key="pytorch",
+        )
+
+
+for bw in [False, True]:
+    bench_revnet(bw)
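A note on the benchmark above: ReversibleSequence carries two residual streams concatenated along the feature dimension, which is why the reversible input b is allocated with K * 2 features while the residual baseline a uses K. Below is a minimal sketch of the same wiring, runnable on CPU without triton; the sizes are illustrative and not the benchmark's own.

import torch

from xformers.components.reversible import ReversibleSequence

K, depth = 64, 4
f = torch.nn.Linear(K, K)
g = torch.nn.Linear(K, K)

# Same construction as the benchmark: each layer is an [f, g] pair.
revseq = ReversibleSequence(
    torch.nn.ModuleList([torch.nn.ModuleList([f, g]) for _ in range(depth)])
)

x = torch.rand(1, 8, 2 * K)  # (batch, tokens, 2 * K): two streams of K features
y = revseq(x)
assert y.shape == x.shape  # each block re-concatenates y1, y2 along split_dim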
diff --git a/xformers/components/reversible.py b/xformers/components/reversible.py
index 59fcf1ed7..0e05a0f15 100644
--- a/xformers/components/reversible.py
+++ b/xformers/components/reversible.py
@@ -52,10 +52,11 @@ def forward(self, *args, record_rng: bool = False, set_rng: bool = False, **kwargs):
 
 
 class ReversibleBlock(nn.Module):
-    def __init__(self, f: nn.Module, g: nn.Module):
+    def __init__(self, f: nn.Module, g: nn.Module, split_dim: int = -1):
         super().__init__()
         self.f = Deterministic(f)
         self.g = Deterministic(g)
+        self.split_dim = split_dim
 
     def forward(self, x: torch.Tensor, f_args={}, g_args={}):
         x1, x2 = torch.chunk(x, 2, dim=-1)
@@ -65,13 +66,13 @@ def forward(self, x: torch.Tensor, f_args={}, g_args={}):
             y1 = x1 + self.f(x2, record_rng=self.training, **f_args)
             y2 = x2 + self.g(y1, record_rng=self.training, **g_args)
 
-        return torch.cat([y1, y2], dim=-1)
+        return torch.cat([y1, y2], dim=self.split_dim)
 
     def backward_pass(self, y: torch.Tensor, dy: torch.Tensor, f_args={}, g_args={}):
-        y1, y2 = torch.chunk(y, 2, dim=-1)
+        y1, y2 = torch.chunk(y, 2, dim=self.split_dim)
         del y
 
-        dy1, dy2 = torch.chunk(dy, 2, dim=-1)
+        dy1, dy2 = torch.chunk(dy, 2, dim=self.split_dim)
         del dy
 
         with torch.enable_grad():
@@ -90,7 +91,7 @@ def backward_pass(self, y: torch.Tensor, dy: torch.Tensor, f_args={}, g_args={}):
         with torch.enable_grad():
             x2.requires_grad = True
             fx2 = self.f(x2, set_rng=True, **f_args)
-            torch.autograd.backward(fx2, dx1, retain_graph=True)
+            torch.autograd.backward(fx2, dx1)
 
         with torch.no_grad():
             x1 = y1 - fx2
@@ -100,8 +101,8 @@ def backward_pass(self, y: torch.Tensor, dy: torch.Tensor, f_args={}, g_args={}):
             dx2 = dy2 + x2.grad
             del dy2
             x2.grad = None
 
-            x = torch.cat([x1, x2.detach()], dim=2)
-            dx = torch.cat([dx1, dx2], dim=2)
+            x = torch.cat([x1, x2.detach()], dim=self.split_dim)
+            dx = torch.cat([dx1, dx2], dim=self.split_dim)
 
         return x, dx
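Two remarks on the reversible.py changes. First, split_dim defaults to -1, matching the previously hard-coded dims: dim=-1 and dim=2 agree for the 3-d (batch, sequence, feature) tensors the benchmark feeds through. Second, retain_graph=True could be dropped because backward_pass rebuilds the f and g subgraphs from scratch for every block, so each graph is backpropagated through exactly once; per the commit message, retaining it made execution up to 20% slower. Below is a small sketch of the coupling identity backward_pass relies on, using stand-in linear layers rather than the Deterministic-wrapped modules from the patch:

import torch

torch.manual_seed(0)
f = torch.nn.Linear(8, 8)
g = torch.nn.Linear(8, 8)

with torch.no_grad():
    x1, x2 = torch.rand(4, 8), torch.rand(4, 8)

    # Forward coupling, as in ReversibleBlock.forward:
    y1 = x1 + f(x2)
    y2 = x2 + g(y1)

    # Exact inversion, as in ReversibleBlock.backward_pass: the inputs are
    # recomputed from the outputs instead of being stored as activations.
    x2_rec = y2 - g(y1)
    x1_rec = y1 - f(x2_rec)

    assert torch.allclose(x1, x1_rec, atol=1e-5)
    assert torch.allclose(x2, x2_rec, atol=1e-5)

Because the inversion recomputes g(y1) and f(x2), the sub-modules must be deterministic between forward and backward; that is what the Deterministic wrapper's record_rng/set_rng machinery guarantees in the real code.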