In [1]:
import torch
import torch.nn.functional as F

from torch import nn, Tensor
from typing import Any, Callable, Optional

In [2]:
# Seed the RNG so Restart & Run All reproduces the same weights, styles and gradients.
torch.manual_seed(0)
# Conv weight in the normal (non-transposed) layout: (C_out=4, C_in=2, kH=10, kW=10).
w1 = nn.Parameter(torch.randn(4, 2, 10, 10))

In [3]:
# Independent leaf copy of w1 (same values, its own .grad) for the in-place run.
w2 = nn.Parameter(w1.detach().clone())

In [4]:
def _modulate(weight: Tensor, style: Tensor, transposed=False) -> Tensor:
    w = weight[None, :]  # batch dim
    # out channels dim: 1 - normal, 2 - transposed
    C_out = 1 + int(transposed)
    s = style.unsqueeze(C_out)
    s = s[:, :, :, None, None]
    return w * s


def _demodulate(w: Tensor, eps=1e-8, inplace=False, transposed=False) -> Tensor:
    # in channels dim: 2 - normal, 1 - transposed
    C_in = 2 - int(transposed)
    if inplace:
        d = torch.rsqrt_(w.pow(2).sum(dim=(C_in, 3, 4), keepdim=True).add_(eps))
    else:
        d = torch.rsqrt(w.pow(2).sum(dim=(C_in, 3, 4), keepdim=True) + eps)
    return w * d

In [5]:
# Style vectors, shape (batch=3, C_in=2); C_in matches w1.shape[1].
s = torch.randn(3, 2)

# Two independent leaf copies so each path accumulates its own .grad.
s1 = nn.Parameter(s.detach().clone())
s2 = nn.Parameter(s.detach().clone())

In [6]:
# Reference path: out-of-place demodulation of w1 modulated by s1.
out1 = _demodulate(
    _modulate(weight=w1, style=s1),
    inplace=False,
)

In [7]:
# Path under test: same computation on the copies, with in-place demodulation.
out2 = _demodulate(
    _modulate(weight=w2, style=s2),
    inplace=True,
)

In [8]:
# Backprop a scalar loss through the out-of-place path; populates w1.grad and s1.grad.
torch.mean(out1).backward()

In [9]:
# Backprop the same scalar loss through the in-place path; populates w2.grad and s2.grad.
torch.mean(out2).backward()

In [10]:
# Style-vector gradients from the out-of-place (s1) and in-place (s2) runs, side by side.
(s1.grad, s2.grad)

(tensor([[-4.4983e-04,  7.0795e-04],
         [-3.7993e-05, -3.1789e-04],
         [-7.0059e-04, -1.7370e-03]]),
 tensor([[-4.4983e-04,  7.0795e-04],
         [-3.7993e-05, -3.1789e-04],
         [-7.0059e-04, -1.7370e-03]]))

In [11]:
# Both paths run the same ops (pow, sum, +eps, rsqrt); only whether the normalizer's
# temporaries are overwritten differs, so the gradients should be bitwise identical.
# torch.equal also requires matching shapes, unlike a broadcasting `==` + `.all()`.
s_grads_equal = torch.equal(s1.grad, s2.grad)
assert s_grads_equal, "s1.grad and s2.grad differ between in-place and out-of-place demodulation"
s_grads_equal

True

In [12]:
# Same check for the conv-weight gradients; torch.equal requires matching shapes
# and identical elements, so a silent broadcast cannot mask a mismatch.
w_grads_equal = torch.equal(w1.grad, w2.grad)
assert w_grads_equal, "w1.grad and w2.grad differ between in-place and out-of-place demodulation"
w_grads_equal

True