In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [13]:
def fused_leaky_relu(x, bias, negative_slope=0.2, scale=2 ** 0.5):
    """Add a bias along dimension 1, apply leaky ReLU, then multiply by ``scale``.

    ``bias`` is viewed as ``(1, -1, 1, ..., 1)`` with one trailing singleton
    axis for every dimension of ``x`` beyond the first two, so it broadcasts
    over dimension 1 of ``x``.

    Args:
        x: input tensor.
        bias: tensor that is reshaped as described above and added to ``x``.
        negative_slope: slope forwarded to ``F.leaky_relu``.
        scale: factor applied to the activated result.

    Returns:
        ``scale * leaky_relu(x + bias)``.
    """
    trailing_ones = (1,) * (x.dim() - 2)
    broadcast_bias = bias.view((1, -1) + trailing_ones)
    activated = F.leaky_relu(x + broadcast_bias, negative_slope=negative_slope)
    return scale * activated

def upfirdn2d_native(
    x, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
):
    """Upsample by zero insertion, pad/crop, FIR-filter and downsample ``x``.

    Args:
        x: 4-D tensor; it is permuted with ``(0, 2, 3, 1)`` before processing
            and the result is returned with the original axis order.
        kernel: 2-D filter of shape ``(kernel_h, kernel_w)``.
        up_x, up_y: zero-insertion factors along width / height.
        down_x, down_y: strides used to subsample the filtered result.
        pad_x0, pad_x1, pad_y0, pad_y1: padding before/after along width and
            height; negative values crop instead of padding.

    Returns:
        The filtered tensor, subsampled with steps ``down_y`` and ``down_x``
        along its last two axes.
    """
    channels_last = x.permute(0, 2, 3, 1)
    _, in_h, in_w, minor = channels_last.shape
    kernel_h, kernel_w = kernel.shape

    # Zero insertion: give every pixel its own singleton axis, pad that axis
    # with zeros, then fold it back into the spatial dimension.
    stuffed = channels_last.view(-1, in_h, 1, in_w, 1, minor)
    stuffed = F.pad(stuffed, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    stuffed = stuffed.view(-1, in_h * up_y, in_w * up_x, minor)

    # Positive pads grow the map with zeros ...
    grow = [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
    padded = F.pad(stuffed, grow)

    # ... negative pads crop it.
    crop_top, crop_bottom = max(-pad_y0, 0), max(-pad_y1, 0)
    crop_left, crop_right = max(-pad_x0, 0), max(-pad_x1, 0)
    padded = padded[
        :,
        crop_top : padded.shape[1] - crop_bottom,
        crop_left : padded.shape[2] - crop_right,
        :,
    ]

    padded_h = in_h * up_y + pad_y0 + pad_y1
    padded_w = in_w * up_x + pad_x0 + pad_x1

    # Treat every channel as an independent single-channel image.
    planes = padded.permute(0, 3, 1, 2).reshape([-1, 1, padded_h, padded_w])

    # F.conv2d computes a cross-correlation; flipping the kernel makes it a
    # true convolution.
    flipped = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    filtered = F.conv2d(planes, flipped)
    filtered = filtered.reshape(
        -1,
        minor,
        padded_h - kernel_h + 1,
        padded_w - kernel_w + 1,
    )
    return filtered[:, :, ::down_y, ::down_x]


def upfirdn2d(x, kernel, up=1, down=1, pad=(0, 0)):
    """Call ``upfirdn2d_native`` with the same settings on both spatial axes.

    Args:
        x: input tensor forwarded unchanged.
        kernel: 2-D filter forwarded unchanged.
        up: zero-insertion factor used for both width and height.
        down: subsampling step used for both width and height.
        pad: indexable pair; ``pad[0]`` is the leading and ``pad[1]`` the
            trailing padding, applied to both width and height.
    """
    pad_before, pad_after = pad[0], pad[1]
    return upfirdn2d_native(
        x, kernel,
        up, up,
        down, down,
        pad_before, pad_after,
        pad_before, pad_after,
    )

class EqualizedLinear(nn.Module):
    """Linear layer with equalized learning rate and custom learning rate multiplier.

    The stored weight is multiplied by the constant ``w_mul`` and the stored
    bias by ``b_mul`` on every forward pass.

    Args:
        in_dim: number of input features.
        out_dim: number of output features.
        bias: when false, the layer has no bias term (``self.bias`` is ``None``).
        bias_init: initial value of every bias element.
        activation: ``'lrelu'`` routes the output through ``fused_leaky_relu``;
            any other value gives a plain linear layer.
        gain: multiplier of the He standard deviation ``in_dim ** -0.5``.
        use_wscale: when true, the He factor is applied at run time through
            ``w_mul``; otherwise it is folded into the initial weights.
        lrmul: learning rate multiplier; weights are initialised divided by it
            and multiplied by it again at run time.
    """

    def __init__(self, in_dim, out_dim, bias=True, bias_init=0., activation=None,
                 gain=1., use_wscale=True, lrmul=1.):
        super(EqualizedLinear, self).__init__()

        # Equalized learning rate and custom learning rate multiplier.
        he_std = gain * in_dim ** (-0.5)  # He init
        if use_wscale:
            init_std = 1.0 / lrmul
            self.w_mul = he_std * lrmul
        else:
            init_std = he_std / lrmul
            self.w_mul = lrmul

        self.weight = torch.nn.Parameter(torch.randn(
            out_dim, in_dim) * init_std, requires_grad=True)

        # Always defined so that forward() never hits a missing attribute,
        # even when the layer is built without a bias.
        self.b_mul = lrmul
        if bias:
            self.bias = nn.Parameter(torch.zeros(
                out_dim).fill_(bias_init), requires_grad=True)
        else:
            self.bias = None

        self.activation = activation

    def forward(self, x):
        """Apply the scaled linear map and, for ``'lrelu'``, the fused activation."""
        weight = self.weight * self.w_mul
        bias = None if self.bias is None else self.bias * self.b_mul

        if self.activation == 'lrelu':  # act='lrelu'
            out = F.linear(x, weight)
            if bias is None:
                # fused_leaky_relu requires a bias tensor; a zero bias leaves
                # the pre-activation unchanged.
                bias = weight.new_zeros(weight.shape[0])
            out = fused_leaky_relu(out, bias)
        else:
            out = F.linear(x, weight, bias=bias)

        return out
        

# Shape smoke tests. The GPU variant only runs when CUDA is present so the
# cell also executes on CPU-only machines.
if torch.cuda.is_available():
    print(upfirdn2d(torch.randn(1,3,10,10).cuda(),torch.randn(3,3).cuda()).shape)
print(upfirdn2d(torch.randn(1,3,10,10), torch.randn(3,3)).shape)

el = EqualizedLinear(64, 32, activation='lrelu')
el(torch.rand(3, 64)).shape

torch.Size([1, 3, 8, 8])
torch.Size([1, 3, 8, 8])


torch.Size([3, 32])

In [11]:
def make_kernel(k):
    """Build a float32 filter kernel from ``k`` whose entries sum to one.

    A 1-D input is expanded to 2-D by multiplying it with itself as a row
    and as a column; inputs of any other rank are only converted and
    normalised.
    """
    taps = torch.tensor(k, dtype=torch.float32)
    if taps.ndim == 1:
        taps = taps.unsqueeze(0) * taps.unsqueeze(1)
    return taps / taps.sum()

class Blur(nn.Module):
    """Filter the input with a fixed kernel through ``upfirdn2d``.

    Args:
        kernel: filter specification handed to ``make_kernel``.
        pad: padding pair forwarded to ``upfirdn2d`` on every call.
        upsample_factor: when greater than 1, the kernel is multiplied by
            ``upsample_factor ** 2``.
    """

    def __init__(self, kernel, pad, upsample_factor=1):
        super().__init__()

        fir = make_kernel(kernel)
        if upsample_factor > 1:
            fir = fir * (upsample_factor ** 2)

        # Stored as a buffer so it follows the module across devices without
        # being a trainable parameter.
        self.register_buffer('kernel', fir)
        self.pad = pad

    def forward(self, x):
        """Return ``x`` filtered with the stored kernel and padding."""
        return upfirdn2d(x, self.kernel, pad=self.pad)
    
# Shape check: the 1-D spec (3, 3) becomes a 2x2 kernel, so an unpadded 3x3
# map shrinks to 2x2.
blur = Blur(kernel=(3, 3), pad=(0, 0))
blur(torch.rand(2, 4, 3, 3)).shape

torch.Size([2, 4, 2, 2])

In [19]:
class EqualizedModConv2d(nn.Module):
    """Modulated 2-D convolution with equalized learning rate.

    A per-sample style input is mapped by an ``EqualizedLinear`` layer to one
    scale per input channel. The shared weight is multiplied by these scales
    (modulation) and optionally normalised per output channel
    (demodulation). The whole batch is then processed by a single grouped
    convolution with one group per sample. Optional stride-2 up- or
    down-sampling is combined with a ``Blur`` filter.
    """

    def __init__(self, dlatent_size, in_channel, out_channel, kernel,
                 up=False, down=False, demodulate=True, resample_kernel=None,
                 gain=1., use_wscale=True, lrmul=1.):
        """
        Args:
            dlatent_size: feature size of the style input fed to ``self.mod``.
            in_channel: number of input channels.
            out_channel: number of output channels.
            kernel: odd, positive size of the square convolution kernel.
            up: use a stride-2 transposed convolution followed by a blur.
            down: blur first, then use a stride-2 convolution.
                ``up`` and ``down`` are mutually exclusive.
            demodulate: normalise the modulated weights per output channel.
            resample_kernel: 1-D filter taps for the blur; defaults to
                ``[1, 3, 3, 1]``.
            gain: multiplier of the He standard deviation.
            use_wscale: when true, the He factor is applied at run time
                through ``w_mul``; otherwise it is folded into the initial
                weights.
            lrmul: learning rate multiplier for the convolution weight.
        """
        super(EqualizedModConv2d, self).__init__()

        assert not (up and down)
        assert kernel >= 1 and kernel % 2 == 1

        if resample_kernel is None:
            resample_kernel = [1, 3, 3, 1]

        self.in_channel = in_channel
        self.out_channel = out_channel
        self.up = up
        self.down = down
        self.demodulate = demodulate
        self.kernel = kernel

        # Blur padding is derived from the filter length, the resampling
        # factor and the convolution kernel size.
        if up:
            factor = 2
            p = (len(resample_kernel) - factor) - (kernel - 1)
            self.blur = Blur(resample_kernel, pad=(
                (p + 1) // 2 + factor - 1, p // 2 + 1), upsample_factor=factor)

        if down:
            factor = 2
            p = (len(resample_kernel) - factor) + (kernel - 1)
            self.blur = Blur(resample_kernel, pad=((p + 1) // 2, p // 2))

        # Maps the style input to one scale per input channel; the bias
        # starts at 1.
        self.mod = EqualizedLinear(
            in_dim=dlatent_size, out_dim=in_channel, bias_init=1.)

        he_std = gain * (in_channel * kernel ** 2) ** (-0.5)  # He init
        if use_wscale:
            init_std = 1.0 / lrmul
            self.w_mul = he_std * lrmul
        else:
            init_std = he_std / lrmul
            self.w_mul = lrmul

        self.weight = torch.nn.Parameter(
            torch.randn(1, out_channel, in_channel, kernel, kernel) * init_std, requires_grad=True)

    def forward(self, x, y):
        """Convolve ``x`` with weights modulated by the style input ``y``.

        Args:
            x: tensor of shape ``(batch, in_channel, height, width)``.
            y: style input passed to ``self.mod``; the result must be
                viewable as ``(batch, 1, in_channel, 1, 1)``.

        Returns:
            Tensor of shape ``(batch, out_channel, height', width')``.
        """
        batch, in_channel, height, width = x.shape

        # Modulate: [1,O,I,k,k] * [B,1,I,1,1] -> [B,O,I,k,k]
        s = self.mod(y).view(batch, 1, in_channel, 1, 1)
        ww = self.w_mul * self.weight * s

        # Demodulate
        if self.demodulate:
            # [BO] Scaling factor.
            d = torch.rsqrt(ww.pow(2).sum([2, 3, 4]) + 1e-8)
            # [BOIkk] Scale output feature maps. This must NOT be done in
            # place: ``ww.pow(2)`` above saved ``ww`` for the backward pass,
            # and an in-place update makes ``backward()`` raise.
            ww = ww * d.view(batch, self.out_channel, 1, 1, 1)

        # Fold the batch into the output-channel axis for a grouped conv.
        weight = ww.view(batch * self.out_channel,
                         in_channel, self.kernel, self.kernel)

        if self.up:
            # reshape (not view) so non-contiguous inputs are accepted too.
            x = x.reshape(1, batch * in_channel, height, width)
            # conv_transpose2d expects (in, out / groups, k, k) weights.
            weight = weight.view(batch, self.out_channel,
                                 in_channel, self.kernel, self.kernel)
            weight = weight.transpose(1, 2).reshape(batch * in_channel, self.out_channel,
                                                    self.kernel, self.kernel)
            out = F.conv_transpose2d(
                x, weight, padding=0, stride=2, groups=batch)
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)
            out = self.blur(out)
        elif self.down:
            x = self.blur(x)
            _, _, height, width = x.shape
            x = x.reshape(1, batch * in_channel, height, width)
            out = F.conv2d(x, weight, padding=0, stride=2, groups=batch)
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)
        else:
            x = x.reshape(1, batch * in_channel, height, width)
            out = F.conv2d(x, weight, padding=self.kernel // 2, groups=batch)
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)

        return out

    def __repr__(self):
        # The kernel size is stored as ``self.kernel``.
        return (
            f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel}, '
            f'upsample={self.up}, downsample={self.down})'
        )

# Shape check for the plain (no up/down-sampling) modulated convolution.
dlatent_size = 256
in_channel = 64
out_channel = 128
kernel = 3
modconv2d = EqualizedModConv2d(
    dlatent_size,
    in_channel,
    out_channel,
    kernel,
    up=False,
    down=False,
    demodulate=True,
    resample_kernel=None,
    gain=1.,
    use_wscale=True,
    lrmul=1.,
)

bs = 7
modconv2d(
    torch.rand(bs, in_channel, 8, 8),
    torch.rand(bs, dlatent_size),
).shape

torch.Size([7, 128, 8, 8])