[Feature] Add Ops of StyleGAN3 #2290

Merged on Mar 13, 2023 (53 commits). The file changes shown below are from 10 of the 53 commits.

Commits (all authored by plyfager):
- 2444457 (Sep 27, 2022): add bias_act
- 4519e33 (Oct 24, 2022): support bias_act
- b62654f (Oct 26, 2022): support filtered_lrelu
- a1c5879 (Oct 26, 2022): resolve conflict
- 4307a51 (Oct 27, 2022): support filtered_lrelu and upfirdn2d
- 6445745 (Oct 31, 2022): support conv2d_gradfix and fix filtered_lrelu
- ad1eea0 (Oct 31, 2022): resolve conflict
- 596cb5c (Oct 31, 2022): fix lint
- 44f1aee (Nov 2, 2022): fix lint
- d0706a1 (Nov 2, 2022): fix c++ lint
- c90de37 (Nov 7, 2022): fix part comments
- fb38992 (Nov 7, 2022): fix lint
- 99dbb73 (Nov 7, 2022): rm redundant header
- f87b882 (Nov 7, 2022): fix upgrade pip
- 6ef5039 (Nov 9, 2022): fix as comment
- 2d354c8 (Nov 9, 2022): fix c++ lint
- 5dcb280 (Nov 9, 2022): fix ci
- eff99f8 (Nov 9, 2022): fix-ut
- 82ca81a (Nov 11, 2022): fix as comments
- 6810e81 (Nov 11, 2022): add grad check
- 39cedd5 (Nov 11, 2022): remove redundant template
- 56302d0 (Nov 15, 2022): Update mmcv/ops/bias_act.py
- ed90c57 (Nov 15, 2022): add typehint
- ea06b45 (Nov 15, 2022): Merge branch 'plyfager/s3-ops' of github.com:plyfager/mmcv into plyfa…
- cb8356f (Nov 15, 2022): fix as comment:
- 37b8044 (Nov 15, 2022): complete type hints
- 2035b50 (Nov 15, 2022): fix lint
- 309d4de (Nov 16, 2022): add test for conv_gradfix
- 2af887b (Nov 16, 2022): add test for conv_gradfix
- 2249194 (Nov 16, 2022): fix lint
- 69ef946 (Nov 16, 2022): modify licenses and ops.md
- 8f2e75b (Nov 16, 2022): add zh op md
- 3cdd3e9 (Nov 16, 2022): add torch version policy for conv2d_gradfix
- 58c4560 (Nov 16, 2022): fix lint
- 06ee7a0 (Nov 17, 2022): fix as comments
- 81f927b (Nov 18, 2022): rename impl
- bbc5276 (Dec 5, 2022): resolve conflict
- 1256efc (Dec 9, 2022): rm redudant function and add ut
- b000977 (Dec 9, 2022): Merge branch '2.x' of github.com:open-mmlab/mmcv into plyfager/s3-ops
- e39a276 (Jan 3, 2023): Merge branch '2.x' of github.com:open-mmlab/mmcv into plyfager/s3-ops
- 9c1a2eb (Jan 3, 2023): fix as comment
- b4d74a0 (Jan 4, 2023): fix lint
- 00ed8f1 (Jan 4, 2023): fix lint
- afab2e3 (Jan 9, 2023): Merge branch '2.x' of github.com:open-mmlab/mmcv into plyfager/s3-ops
- c1b1d0d (Feb 13, 2023): fix as comments
- 81ee582 (Feb 13, 2023): resolve conflict
- c8bc875 (Feb 13, 2023): fix lint
- c5ace42 (Feb 28, 2023): Merge branch '2.x' of github.com:open-mmlab/mmcv into plyfager/s3-ops
- ec3fa61 (Feb 28, 2023): fix ut
- b9e1564 (Mar 3, 2023): fix as comment
- e1ef9aa (Mar 13, 2023): Merge branch '2.x' of github.com:open-mmlab/mmcv into plyfager/s3-ops
- 1b60d8f (Mar 13, 2023): fix as comment
- fb9d709 (Mar 13, 2023): fix as comment
mmcv/ops/__init__.py (8 changes: 6 additions & 2 deletions)

@@ -3,13 +3,15 @@
 from .assign_score_withk import assign_score_withk
 from .ball_query import ball_query
 from .bbox import bbox_overlaps
+from .bias_act import bias_act
 from .border_align import BorderAlign, border_align
 from .box_iou_quadri import box_iou_quadri
 from .box_iou_rotated import box_iou_rotated
 from .carafe import CARAFE, CARAFENaive, CARAFEPack, carafe, carafe_naive
 from .cc_attention import CrissCrossAttention
 from .chamfer_distance import chamfer_distance
 from .contour_expand import contour_expand
+from .conv2d_gradfix import conv2d, conv_transpose2d
 from .convex_iou import convex_giou, convex_iou
 from .corner_pool import CornerPool
 from .correlation import Correlation
@@ -21,6 +23,7 @@
 from .deprecated_wrappers import Linear_deprecated as Linear
 from .deprecated_wrappers import MaxPool2d_deprecated as MaxPool2d
 from .diff_iou_rotated import diff_iou_rotated_2d, diff_iou_rotated_3d
+from .filtered_lrelu import filtered_lrelu
 from .focal_loss import (SigmoidFocalLoss, SoftmaxFocalLoss,
                          sigmoid_focal_loss, softmax_focal_loss)
 from .furthest_point_sample import (furthest_point_sample,
@@ -67,7 +70,7 @@
 from .three_interpolate import three_interpolate
 from .three_nn import three_nn
 from .tin_shift import TINShift, tin_shift
-from .upfirdn2d import upfirdn2d
+from .upfirdn2d import filter2d, upfirdn2d, upsample2d
 from .voxelize import Voxelization, voxelization

 __all__ = [
@@ -102,5 +105,6 @@
     'points_in_boxes_cpu', 'points_in_boxes_all', 'points_in_polygons',
     'min_area_polygons', 'active_rotated_filter', 'convex_iou', 'convex_giou',
     'diff_iou_rotated_2d', 'diff_iou_rotated_3d', 'chamfer_distance',
-    'PrRoIPool', 'prroi_pool'
+    'PrRoIPool', 'prroi_pool', 'bias_act', 'filtered_lrelu', 'conv2d',
+    'conv_transpose2d', 'filter2d', 'upsample2d'
 ]
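
For orientation, a minimal usage sketch of the newly exported ops (a sketch only, assuming an mmcv build that includes the compiled `_ext` extension; only `bias_act` is actually called here, via its pure-PyTorch `impl='ref'` path so it also runs on CPU):

import torch

from mmcv.ops import (bias_act, conv2d, conv_transpose2d, filter2d,
                      filtered_lrelu, upfirdn2d, upsample2d)

x = torch.randn(2, 8, 16, 16)  # NCHW activations
b = torch.randn(8)             # one bias value per channel

# Fused bias + leaky ReLU along the channel dimension (dim=1).
y = bias_act(x, b, dim=1, act='lrelu', impl='ref')
print(y.shape)  # torch.Size([2, 8, 16, 16])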
mmcv/ops/bias_act.py (new file: 303 additions)

@@ -0,0 +1,303 @@
# Modified from
# https://github.com/NVlabs/stylegan3/blob/main/torch_utils/ops/bias_act.py

# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

# source: https://github.com/open-mmlab/mmediting/blob/dev-1.x/mmedit/models/editors/stylegan3/stylegan3_ops/ops/bias_act.py # noqa
"""Custom PyTorch ops for efficient bias and activation."""

from typing import Any, Dict

import numpy as np
import torch

from ..utils import ext_loader

ext_module = ext_loader.load_ext('_ext', ['bias_act'])


class EasyDict(dict):
    """Convenience class that behaves like a dict but allows access with the
    attribute syntax."""

    def __getattr__(self, name: str) -> Any:
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)

    def __setattr__(self, name: str, value: Any) -> None:
        self[name] = value

    def __delattr__(self, name: str) -> None:
        del self[name]


activation_funcs = {
    'linear':
    EasyDict(
        func=lambda x, **_: x,
        def_alpha=0,
        def_gain=1,
        cuda_idx=1,
        ref='',
        has_2nd_grad=False),
    'relu':
    EasyDict(
        func=lambda x, **_: torch.nn.functional.relu(x),
        def_alpha=0,
        def_gain=np.sqrt(2),
        cuda_idx=2,
        ref='y',
        has_2nd_grad=False),
    'lrelu':
    EasyDict(
        func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha),
        def_alpha=0.2,
        def_gain=np.sqrt(2),
        cuda_idx=3,
        ref='y',
        has_2nd_grad=False),
    'tanh':
    EasyDict(
        func=lambda x, **_: torch.tanh(x),
        def_alpha=0,
        def_gain=1,
        cuda_idx=4,
        ref='y',
        has_2nd_grad=True),
    'sigmoid':
    EasyDict(
        func=lambda x, **_: torch.sigmoid(x),
        def_alpha=0,
        def_gain=1,
        cuda_idx=5,
        ref='y',
        has_2nd_grad=True),
    'elu':
    EasyDict(
        func=lambda x, **_: torch.nn.functional.elu(x),
        def_alpha=0,
        def_gain=1,
        cuda_idx=6,
        ref='y',
        has_2nd_grad=True),
    'selu':
    EasyDict(
        func=lambda x, **_: torch.nn.functional.selu(x),
        def_alpha=0,
        def_gain=1,
        cuda_idx=7,
        ref='y',
        has_2nd_grad=True),
    'softplus':
    EasyDict(
        func=lambda x, **_: torch.nn.functional.softplus(x),
        def_alpha=0,
        def_gain=1,
        cuda_idx=8,
        ref='y',
        has_2nd_grad=True),
    'swish':
    EasyDict(
        func=lambda x, **_: torch.sigmoid(x) * x,
        def_alpha=0,
        def_gain=np.sqrt(2),
        cuda_idx=9,
        ref='x',
        has_2nd_grad=True),
}

_plugin = None
_null_tensor = torch.empty([0])


def bias_act(x,
             b=None,
             dim=1,
             act='linear',
             alpha=None,
             gain=None,
             clamp=None,
             impl='cuda'):
    r"""Fused bias and activation function.

    Adds bias `b` to activation tensor `x`, evaluates activation function
    `act`, and scales the result by `gain`. Each of the steps is optional.
    In most cases, the fused op is considerably more efficient than performing
    the same calculation using standard PyTorch ops. It supports first and
    second order gradients, but not third order gradients.

    Args:
        x: Input activation tensor. Can be of any shape.
        b: Bias vector, or `None` to disable. Must be a 1D tensor of the
            same type as `x`. The shape must be known, and it must match
            the dimension of `x` corresponding to `dim`.
        dim: The dimension in `x` corresponding to the elements of `b`.
            The value of `dim` is ignored if `b` is not specified.
        act: Name of the activation function to evaluate, or `"linear"` to
            disable. Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`,
            `"sigmoid"`, `"swish"`, etc. See `activation_funcs` for a full
            list. `None` is not allowed.
        alpha: Shape parameter for the activation function, or `None` to use
            the default.
        gain: Scaling factor for the output tensor, or `None` to use default.
            See `activation_funcs` for the default scaling of each
            activation function. If unsure, consider specifying 1.
        clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to
            disable the clamping (default).
        impl: Name of the implementation to use. Can be `"ref"` or `"cuda"`
            (default).

    Returns:
        Tensor of the same shape and datatype as `x`.
    """
    assert isinstance(x, torch.Tensor)
    assert impl in ['ref', 'cuda']
    if impl == 'cuda' and x.device.type == 'cuda':
        return _bias_act_cuda(
            dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
    return _bias_act_ref(
        x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
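
As a quick illustration of the dispatch above: on a CUDA tensor with impl='cuda' the fused kernel is used, otherwise the call falls through to the reference path. A hedged consistency sketch, assuming a CUDA-enabled mmcv build and an available GPU:

import torch

from mmcv.ops import bias_act

if torch.cuda.is_available():
    x = torch.randn(2, 4, 8, 8, device='cuda')
    b = torch.randn(4, device='cuda')
    y_fused = bias_act(x, b, act='lrelu', impl='cuda')  # custom CUDA kernel
    y_ref = bias_act(x, b, act='lrelu', impl='ref')     # plain PyTorch ops
    assert torch.allclose(y_fused, y_ref, atol=1e-5)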


def _bias_act_ref(x,
                  b=None,
                  dim=1,
                  act='linear',
                  alpha=None,
                  gain=None,
                  clamp=None):
    """Slow reference implementation of `bias_act()` using standard PyTorch
    ops."""
    assert isinstance(x, torch.Tensor)
    assert clamp is None or clamp >= 0
    spec = activation_funcs[act]
    alpha = float(alpha if alpha is not None else spec.def_alpha)
    gain = float(gain if gain is not None else spec.def_gain)
    clamp = float(clamp if clamp is not None else -1)

    # Add bias.
    if b is not None:
        assert isinstance(b, torch.Tensor) and b.ndim == 1
        assert 0 <= dim < x.ndim
        assert b.shape[0] == x.shape[dim]
        x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])

    # Evaluate activation function.
    alpha = float(alpha)
    x = spec.func(x, alpha=alpha)

    # Scale by gain.
    gain = float(gain)
    if gain != 1:
        x = x * gain

    # Clamp.
    if clamp >= 0:
        # pylint: disable=invalid-unary-operand-type
        x = x.clamp(-clamp, clamp)
    return x
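
To make the three optional steps above concrete, a small hedged check (assuming the same mmcv build) that reproduces the reference path by hand for act='lrelu' with its defaults (def_alpha=0.2, def_gain=sqrt(2)) and an explicit clamp:

import numpy as np
import torch
import torch.nn.functional as F

from mmcv.ops import bias_act

x = torch.randn(4, 3, 8, 8, dtype=torch.float64)
b = torch.randn(3, dtype=torch.float64)

out = bias_act(x, b, dim=1, act='lrelu', clamp=0.5, impl='ref')

manual = x + b.reshape(1, -1, 1, 1)   # add bias along dim=1
manual = F.leaky_relu(manual, 0.2)    # def_alpha of 'lrelu'
manual = manual * np.sqrt(2)          # def_gain of 'lrelu'
manual = manual.clamp(-0.5, 0.5)      # clamp to [-clamp, +clamp]

assert torch.allclose(out, manual)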


_bias_act_cuda_cache: Dict = dict()


def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
    """Fast CUDA implementation of `bias_act()` using custom ops."""
    # Parse arguments.
    assert clamp is None or clamp >= 0
    spec = activation_funcs[act]
    alpha = float(alpha if alpha is not None else spec.def_alpha)
    gain = float(gain if gain is not None else spec.def_gain)
    clamp = float(clamp if clamp is not None else -1)

    # Lookup from cache.
    key = (dim, act, alpha, gain, clamp)
    if key in _bias_act_cuda_cache:
        return _bias_act_cuda_cache[key]

    # Forward op.
    class BiasActCuda(torch.autograd.Function):

        @staticmethod
        def forward(ctx, x, b):  # pylint: disable=arguments-differ
            ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride(
                1) == 1 else torch.contiguous_format
            x = x.contiguous(memory_format=ctx.memory_format)
            b = b.contiguous() if b is not None else _null_tensor.to(x.device)
            y = x
            if act != 'linear' or gain != 1 or clamp >= 0 or (
                    b is not _null_tensor.to(x.device)):
                y = ext_module.bias_act(x, b, _null_tensor.to(x.device),
                                        _null_tensor.to(x.device),
                                        _null_tensor.to(x.device), 0, dim,
                                        spec.cuda_idx, alpha, gain, clamp)
            ctx.save_for_backward(
                x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor.to(
                    x.device), b if 'x' in spec.ref or spec.has_2nd_grad else
                _null_tensor.to(x.device),
                y if 'y' in spec.ref else _null_tensor.to(x.device))
            return y

        @staticmethod
        def backward(ctx, dy):  # pylint: disable=arguments-differ
            dy = dy.contiguous(memory_format=ctx.memory_format)
            x, b, y = ctx.saved_tensors
            dx = None
            db = None

            if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
                dx = dy
                if act != 'linear' or gain != 1 or clamp >= 0:
                    dx = BiasActCudaGrad.apply(dy, x, b, y)
[Inline review comment]
Member: What will be called? BiasActCudaGrad.forward or BiasActCudaGrad.backward?
plyfager (author): BiasActCudaGrad.forward, since this is the first order grad.

            if ctx.needs_input_grad[1]:
                db = dx.sum([i for i in range(dx.ndim) if i != dim])

            return dx, db

    # Backward op.
    class BiasActCudaGrad(torch.autograd.Function):

        @staticmethod
        def forward(ctx, dy, x, b, y):  # pylint: disable=arguments-differ
            ctx.memory_format = torch.channels_last if dy.ndim > 2 and (
                dy.stride(1) == 1) else torch.contiguous_format
            dx = ext_module.bias_act(dy, b, x, y, _null_tensor.to(x.device), 1,
                                     dim, spec.cuda_idx, alpha, gain, clamp)
            ctx.save_for_backward(
                dy if spec.has_2nd_grad else _null_tensor.to(x.device), x, b,
                y)
            return dx

        @staticmethod
        def backward(ctx, d_dx):  # pylint: disable=arguments-differ
            d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
            dy, x, b, y = ctx.saved_tensors
            d_dy = None
            d_x = None
            d_b = None
            d_y = None

            if ctx.needs_input_grad[0]:
                d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)

            if spec.has_2nd_grad and (ctx.needs_input_grad[1]
                                      or ctx.needs_input_grad[2]):
                d_x = ext_module.bias_act(d_dx, b, x, y, dy, 2, dim,
                                          spec.cuda_idx, alpha, gain, clamp)

            if spec.has_2nd_grad and ctx.needs_input_grad[2]:
                d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])

            return d_dy, d_x, d_b, d_y

    # Add to cache.
    _bias_act_cuda_cache[key] = BiasActCuda
    return BiasActCuda
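
Finally, in the spirit of the "add grad check" commit and the review exchange above, a hedged sketch of first and second order gradient checks. The float64/CUDA combination is an assumption (the upstream StyleGAN3 kernel dispatches over double); on a CPU-only machine the same check runs against the reference path instead.

import torch
from torch.autograd import gradcheck, gradgradcheck

from mmcv.ops import bias_act

device = 'cuda' if torch.cuda.is_available() else 'cpu'
impl = 'cuda' if device == 'cuda' else 'ref'

x = torch.randn(2, 3, 4, 4, dtype=torch.float64, device=device,
                requires_grad=True)
b = torch.randn(3, dtype=torch.float64, device=device, requires_grad=True)

# First-order check. On the CUDA path, BiasActCuda.backward calls
# BiasActCudaGrad.apply, i.e. BiasActCudaGrad.forward (see the review
# exchange above).
assert gradcheck(lambda x, b: bias_act(x, b, act='lrelu', impl=impl), (x, b))

# Second-order check. 'lrelu' is registered with has_2nd_grad=False, so use
# an activation such as 'tanh' to exercise BiasActCudaGrad.backward.
assert gradgradcheck(lambda x, b: bias_act(x, b, act='tanh', impl=impl),
                     (x, b))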