In [1]:
import torch
import torch.nn.functional as F

from torch import nn, Tensor
from typing import Any, Callable, Optional

In [2]:
def _modulate(weight: Tensor, style: Tensor, transposed=False) -> Tensor:
    w = weight[None, :]  # batch dim
    # out channels dim: 1 - normal, 2 - transposed
    C_out = 1 + int(transposed)
    s = style.unsqueeze(C_out)
    s = s[:, :, :, None, None]
    return w * s


def _demodulate(w: Tensor, eps=1e-8, inplace=False, transposed=False) -> Tensor:
    # in channels dim: 2 - normal, 1 - transposed
    C_in = 2 - int(transposed)
    if inplace:
        d = torch.rsqrt_(w.pow(2).sum(dim=(C_in, 3, 4), keepdim=True).add_(eps))
    else:
        d = torch.rsqrt(w.pow(2).sum(dim=(C_in, 3, 4), keepdim=True) + eps)
    return w * d

In [3]:
# Shared random weight for the non-inplace run: dim 0 is out-channels (4) and
# dim 1 is in-channels (2), the non-transposed layout that _modulate and
# _demodulate index; the trailing 10x10 dims are summed over by _demodulate.
w1 = nn.Parameter(torch.randn(4, 2, 10, 10))

In [4]:
# Independent leaf copy of w1 (same values, separate storage and .grad) for
# the inplace run.
w2 = nn.Parameter(w1.detach().clone())

In [5]:
# One random style matrix: 3 samples x 2 in-channels (matches dim 1 of w1).
s = torch.randn(3, 2)

# Two independent leaf copies with identical values, one per code path, so
# each path accumulates its own gradient.
s1, s2 = (nn.Parameter(s.detach().clone()) for _ in range(2))

In [6]:
# Reference path: modulate w1 by s1, then demodulate with the out-of-place
# branch (inplace=False).
out1 = _demodulate(_modulate(w1, s1), inplace=False)

In [7]:
# Same computation on the copies (w2, s2), but through the inplace=True
# branch of _demodulate.
out2 = _demodulate(_modulate(w2, s2), inplace=True)

In [8]:
# Backpropagate the mean of the out-of-place result; this fills w1.grad and
# s1.grad, which the comparison cells below read.
out1.mean().backward()
# w1.grad  (uncomment to inspect the weight gradient)

In [9]:
# Backpropagate the mean of the inplace result; this fills w2.grad and
# s2.grad, which the comparison cells below read.
out2.mean().backward()
# w2.grad  (uncomment to inspect the weight gradient)

In [10]:
# Show the style gradients of both paths side by side.
s1.grad, s2.grad

(tensor([[-7.4626e-04,  5.4799e-04],
         [ 4.5596e-05, -4.7382e-06],
         [ 4.0299e-04,  4.2420e-04]]),
 tensor([[-7.4626e-04,  5.4799e-04],
         [ 4.5596e-05, -4.7382e-06],
         [ 4.0299e-04,  4.2420e-04]]))

In [11]:
# Exact (element-wise ==) comparison of the style gradients of the two paths;
# no tolerance is applied.
(s1.grad == s2.grad).all().item()

True

In [12]:
# Exact (element-wise ==) comparison of the weight gradients of the two paths;
# no tolerance is applied.
(w1.grad == w2.grad).all().item()

True

In [13]:
import torch.autograd.profiler as P

In [14]:
def benchmark(n_iter, inplace, n_warm=100, N=32, C0=16, C1=32, device=None):
    """Profile modulate + demodulate forward/backward and print the op table.

    Fresh random inputs are generated for every iteration. The forward pass
    (``_modulate`` -> ``_demodulate`` -> ``mean``) and the backward pass are
    recorded under the labels "forward" and "backward".

    Args:
        n_iter: number of profiled iterations.
        inplace: forwarded to ``_demodulate`` as its ``inplace`` flag.
        n_warm: number of un-profiled warm-up iterations run first.
        N: batch size, i.e. number of rows of the style tensor.
        C0: in-channels (dim 1 of the weight, dim 1 of the style).
        C1: out-channels (dim 0 of the weight).
        device: device for the generated tensors; ``None`` lets torch pick
            its default device.

    Returns:
        None. The averaged profiler table is printed.
    """
    def _generate_w_s():
        # New leaf tensors each call, so gradients never accumulate across
        # iterations.
        w = torch.randn(C1, C0, 3, 3, requires_grad=True, device=device)
        s = torch.randn(N, C0, requires_grad=True, device=device)
        return w, s

    # Probe where the tensors will actually live (device=None means the
    # default device) instead of assuming CUDA: CUDA timing is only requested,
    # and only used as the sort key, when the work runs on a CUDA device.
    on_cuda = torch.empty(0, device=device).is_cuda

    for _ in range(n_warm):
        w, s = _generate_w_s()
        out = _demodulate(_modulate(w, s), inplace=inplace).mean()
        out.backward()

    with P.profile(use_cuda=on_cuda) as prof:
        for _ in range(n_iter):
            # Input generation is deliberately outside both labelled regions.
            w, s = _generate_w_s()

            with P.record_function("forward"):
                out = _demodulate(_modulate(w, s), inplace=inplace).mean()

            with P.record_function("backward"):
                out.backward()

    sort_key = "cuda_time" if on_cuda else "cpu_time"
    print(prof.key_averages().table(sort_by=sort_key))

In [15]:
# Benchmark device. The second GPU (cuda:1) is used when it exists, as
# originally hard-coded; otherwise fall back to the first GPU, then to the CPU,
# so the notebook still runs on machines with fewer than two GPUs.
if torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [16]:
# 10,000 profiled iterations through the out-of-place demodulation branch.
benchmark(n_iter=10_000, inplace=False, device=device)

------------------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  
Name                                  Self CPU total %  Self CPU total   CPU total %      CPU total        CPU time avg     CUDA total %     CUDA total       CUDA time avg    Number of Calls  
------------------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  
backward                              47.39%           8.466s           47.71%           8.523s           852.301us        34.82%           8.523s           852.312us        10000            
forward                               5.91%            1.055s           17.15%           3.065s           306.483us        12.52%           3.064s           306.376us        10000            
PowBackward0                          0

In [17]:
# 10,000 profiled iterations through the in-place demodulation branch.
benchmark(n_iter=10_000, inplace=True, device=device)

------------------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  
Name                                  Self CPU total %  Self CPU total   CPU total %      CPU total        CPU time avg     CUDA total %     CUDA total       CUDA time avg    Number of Calls  
------------------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  
backward                              47.45%           8.114s           47.70%           8.157s           815.730us        34.71%           8.171s           817.141us        10000            
forward                               6.17%            1.055s           16.60%           2.839s           283.886us        12.02%           2.830s           282.955us        10000            
PowBackward0                          0