[feat] add split_dim arg to reversible, remove retain_graph, add benchmark_reversible (#45)

* perf(revnet): remove/fuse cat-split

* style(reversible): run pre-commit, remove unused variables

* feat(benchmark): add basic revnet time benchmark

* style(benchmark): run black over revnet bench

* feat(reversible): make split dim a parameter

* revert(reversible): readd original code/remove split fusion

* feat(reversible): readd split dim arg

* perf(reversible): remove retain_graph

It's not needed here, and keeping it makes execution up to 20% slower (a short note on this follows the reversible.py diff below).

* style(reversible): remove one singular newline
seriously?
ClashLuke committed Nov 1, 2021
1 parent 1fa54fd commit 0ae331a
Showing 2 changed files with 91 additions and 7 deletions.
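Before the diffs, a minimal usage sketch of the new split_dim argument on ReversibleBlock. Everything except the constructor signature shown in the reversible.py diff below is illustrative: K, the toy Linear modules and the shapes are made up.

# Hypothetical sketch, not part of the commit.
import torch
import torch.nn as nn

from xformers.components.reversible import ReversibleBlock

K = 64
# split_dim defaults to -1, matching the previously hard-coded behaviour.
block = ReversibleBlock(nn.Linear(K, K), nn.Linear(K, K), split_dim=-1)

# The two reversible streams sit side by side along the last dimension.
x = torch.rand(2, 16, 2 * K)
y = block(x)  # y1 = x1 + f(x2), y2 = x2 + g(y1), concatenated along split_dim
assert y.shape == x.shape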
83 changes: 83 additions & 0 deletions xformers/benchmarks/benchmark_revnet.py
@@ -0,0 +1,83 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


from typing import Any, Dict

import torch
import triton

from xformers.benchmarks.utils import TestCase, pretty_plot, pretty_print
from xformers.components.reversible import ReversibleSequence

SHAPES = [(16384, 32), (2048, 256), (128, 4096)]

DEPTH = [4, 32, 256]


def bench_revnet(backward: bool):
    device = torch.device("cuda")
    bw = "+bw" if backward else ""

    for dtype in [torch.float16, torch.float32]:
        results: Dict[str, Any] = {}

        for B, K in SHAPES:
            for depth in DEPTH:
                f = torch.nn.Linear(K, K).to(device=device, dtype=dtype)
                g = torch.nn.Linear(K, K).to(device=device, dtype=dtype)
                revseq = ReversibleSequence(
                    torch.nn.ModuleList([torch.nn.ModuleList([f, g])] * depth)
                )
                revseq = revseq.to(device=device, dtype=dtype)

                a = torch.rand(
                    1, B, K, device=device, dtype=dtype, requires_grad=backward
                )
                b = torch.rand(
                    1, B, K * 2, device=device, dtype=dtype, requires_grad=backward
                )

                def normal_step():
                    y = a
                    for _ in range(depth):
                        y = y + f(y)
                        y = y + g(y)
                    if backward:
                        torch.norm(y).backward()
                    return y

                def reversible_step():
                    y = revseq(b)
                    if backward:
                        torch.norm(y).backward()
                    return y

                for testcase in [
                    TestCase(normal_step, f"residual - fw{bw}"),
                    TestCase(reversible_step, f"reversible - fw{bw}"),
                ]:
                    time = triton.testing.do_bench(testcase.function)[0]
                    key = f"Batch={B}, Features={K}, Depth={depth}"
                    if key not in results:
                        results[key] = {}

                    results[key][testcase.name] = f"{time:.2f}"

        pretty_print(
            results,
            title=f"\n --- Type: {dtype} --- ",
            units="runtime in ms, lower is better",
        )
        pretty_plot(
            results,
            title=f"RevNet-FW{bw}-{dtype}",
            units="runtime in ms, lower is better",
            dash_key="pytorch",
        )


for bw in [False, True]:
    bench_revnet(bw)
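A note on the benchmark above: normal_step runs f and g as plain residual updates on a K-wide activation a, while the reversible path is fed b with K * 2 features because a reversible block keeps two K-wide streams side by side (it chunks its input in half, applies f and g, and concatenates the halves back, as the reversible.py diff below shows). The forward pass therefore issues the same number of Linear(K, K) calls per depth step in both variants; what differs is the reversible bookkeeping and, in the backward case, the fact that the reversible path recomputes activations instead of storing them.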
15 changes: 8 additions & 7 deletions xformers/components/reversible.py
@@ -52,10 +52,11 @@ def forward(self, *args, record_rng: bool = False, set_rng: bool = False, **kwar


 class ReversibleBlock(nn.Module):
-    def __init__(self, f: nn.Module, g: nn.Module):
+    def __init__(self, f: nn.Module, g: nn.Module, split_dim: int = -1):
         super().__init__()
         self.f = Deterministic(f)
         self.g = Deterministic(g)
+        self.split_dim = split_dim

     def forward(self, x: torch.Tensor, f_args={}, g_args={}):
         x1, x2 = torch.chunk(x, 2, dim=-1)
@@ -65,13 +66,13 @@ def forward(self, x: torch.Tensor, f_args={}, g_args={}):
         y1 = x1 + self.f(x2, record_rng=self.training, **f_args)
         y2 = x2 + self.g(y1, record_rng=self.training, **g_args)

-        return torch.cat([y1, y2], dim=-1)
+        return torch.cat([y1, y2], dim=self.split_dim)

     def backward_pass(self, y: torch.Tensor, dy: torch.Tensor, f_args={}, g_args={}):
-        y1, y2 = torch.chunk(y, 2, dim=-1)
+        y1, y2 = torch.chunk(y, 2, dim=self.split_dim)
         del y

-        dy1, dy2 = torch.chunk(dy, 2, dim=-1)
+        dy1, dy2 = torch.chunk(dy, 2, dim=self.split_dim)
         del dy

         with torch.enable_grad():
@@ -90,7 +91,7 @@ def backward_pass(self, y: torch.Tensor, dy: torch.Tensor, f_args={}, g_args={})
         with torch.enable_grad():
             x2.requires_grad = True
             fx2 = self.f(x2, set_rng=True, **f_args)
-            torch.autograd.backward(fx2, dx1, retain_graph=True)
+            torch.autograd.backward(fx2, dx1)

         with torch.no_grad():
             x1 = y1 - fx2
@@ -100,8 +101,8 @@ def backward_pass(self, y: torch.Tensor, dy: torch.Tensor, f_args={}, g_args={})
             del dy2
             x2.grad = None

-        x = torch.cat([x1, x2.detach()], dim=2)
-        dx = torch.cat([dx1, dx2], dim=2)
+        x = torch.cat([x1, x2.detach()], dim=self.split_dim)
+        dx = torch.cat([dx1, dx2], dim=self.split_dim)

         return x, dx
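A side note on the retain_graph change above: torch.autograd.backward frees the traversed graph by default, and retain_graph=True is only needed when the same graph must be walked a second time. In backward_pass each recomputed graph (fx2 here) is backpropagated exactly once per block, so the flag only kept buffers alive for nothing, which is presumably where the reported up-to-20% slowdown came from. A generic PyTorch sketch of the default behaviour (not xformers code):

import torch

# Each graph below is consumed by exactly one backward() call, so the default
# retain_graph=False suffices and lets autograd free intermediate buffers
# right away.
w = torch.randn(4, 4, requires_grad=True)
x = torch.randn(8, 4)

loss = (x @ w).sum()
torch.autograd.backward(loss, torch.ones_like(loss))  # graph is freed after this call

loss = (x @ w).sum()  # a fresh graph is built for the next backward
torch.autograd.backward(loss, torch.ones_like(loss))  # again, no retain_graph needed

# retain_graph=True is only required to traverse the *same* graph twice,
# e.g. loss.backward(retain_graph=True) followed by loss.backward().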
