In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from src.stylegan2.op import (
    fused_leaky_relu,
    upfirdn2d,
)

from src.stylegan2.CustomLayers import (
    Blur,
    EqualizedLinear, 
    EqualizedModConv2d,         
)

from src.stylegan2.Blocks import  (
    ModConvLayer,    
)

In [2]:
# Sanity-check upfirdn2d on CPU and, when a GPU is available, on CUDA too.
# Guarding the .cuda() call keeps Restart & Run All working on CPU-only machines.
# The recorded output shows a 10x10 input with a 3x3 kernel comes out 8x8.
upfirdn_input = torch.randn(1, 3, 10, 10)
upfirdn_kernel = torch.randn(3, 3)
if torch.cuda.is_available():
    print(upfirdn2d(upfirdn_input.cuda(), upfirdn_kernel.cuda()).shape)
print(upfirdn2d(upfirdn_input, upfirdn_kernel).shape)

torch.Size([1, 3, 8, 8])
torch.Size([1, 3, 8, 8])


In [3]:
# EqualizedLinear with activation='lrelu': 64 input features -> 32 outputs,
# checked on a batch of 3 random vectors.
el = EqualizedLinear(64, 32, activation='lrelu')
linear_input = torch.rand(3, 64)
el(linear_input).shape

torch.Size([3, 32])

In [4]:
# Blur smoke test on a (2, 4, 3, 3) batch; the recorded output shows the
# 3x3 spatial map comes out 2x2 with these constructor arguments.
blur = Blur((3, 3), (0, 0))
blur_input = torch.rand((2, 4, 3, 3))
blur(blur_input).shape

torch.Size([2, 4, 2, 2])

In [5]:
# Smoke test for EqualizedModConv2d: 64 -> 128 channels on an 8x8 feature map,
# modulated by 256-d dlatents. The recorded output shows H and W are preserved.
# These globals (dlatent_size, in_channel, out_channel, kernel, bs) are reused
# by later cells.
dlatent_size = 256
in_channel, out_channel = 64, 128
kernel = 3
modconv2d = EqualizedModConv2d(
    dlatent_size, in_channel, out_channel, kernel,
    up=False, down=False, demodulate=True, resample_kernel=None,
    gain=1., use_wscale=True, lrmul=1.,
)

bs = 7
conv_input = torch.rand((bs, in_channel, 8, 8))
conv_dlatents = torch.rand((bs, dlatent_size))
modconv2d(conv_input, conv_dlatents).shape

torch.Size([7, 128, 8, 8])

In [6]:
# Display the layer's repr to check the configuration it reports.
modconv2d

EqualizedModConv2d(64, 128, 3, upsample=False, downsample=False)

In [13]:
# Smoke test for ModConvLayer using the sizes defined in the EqualizedModConv2d cell.
# On a top-to-bottom run this is the ModConvLayer imported from
# src.stylegan2.Blocks; the class cell further down redefines that name.
modconvlayer = ModConvLayer(dlatent_size, in_channel, out_channel, kernel)
layer_input = torch.rand((bs, in_channel, 8, 8))
layer_dlatents = torch.rand((bs, dlatent_size))
modconvlayer(layer_input, layer_dlatents).shape

torch.Size([7, 128, 8, 8])

In [None]:
class ModConvLayer(nn.Module):
    """Modulated convolution followed by optional noise injection, bias and leaky ReLU.

    NOTE: this cell redefines ``ModConvLayer`` and shadows the version imported
    from ``src.stylegan2.Blocks`` in the first cell; cells run after this one
    will use this definition.

    Args:
        dlatent_size: size of the dlatent (style) vectors passed to ``forward``;
            forwarded to ``EqualizedModConv2d``.
        in_channel: number of input channels, forwarded to ``EqualizedModConv2d``.
        out_channel: number of output channels; also the size of the bias.
        kernel: kernel argument forwarded to ``EqualizedModConv2d``.
        up: upsampling flag forwarded to ``EqualizedModConv2d``.
        down: downsampling flag forwarded to ``EqualizedModConv2d``.
        use_noise: if True, add noise scaled by a learned scalar (initialised
            to 0) to the conv output before the activation.
    """

    def __init__(self, dlatent_size, in_channel, out_channel,
                 kernel, up=False, down=False, use_noise=True):
        super().__init__()

        self.conv = EqualizedModConv2d(dlatent_size=dlatent_size,
                                       in_channel=in_channel, out_channel=out_channel,
                                       kernel=kernel, up=up, down=down)
        # One bias value per output channel, applied by fused_leaky_relu.
        self.bias = nn.Parameter(torch.zeros(out_channel))

        self.use_noise = use_noise
        if self.use_noise:
            # Single learned noise scale; starting at 0 means noise has no
            # effect until training moves it.
            self.noise_strength = nn.Parameter(torch.zeros(1))

    def forward(self, x, dlatents_in_range, noise_input=None):
        """Run the modulated conv, add optional noise, then bias + leaky ReLU.

        Args:
            x: input feature map — presumably (batch, in_channel, H, W), the
                layout used by this notebook's smoke tests.
            dlatents_in_range: dlatent vectors passed to the modulated conv —
                presumably (batch, dlatent_size).
            noise_input: optional noise added (times ``noise_strength``) to the
                conv output. If None and ``use_noise`` is set, standard-normal
                noise of shape (batch, 1, H_out, W_out) is drawn. Ignored when
                ``use_noise`` is False.

        Returns:
            ``fused_leaky_relu(conv_output_plus_noise, self.bias)``.
        """
        x = self.conv(x, dlatents_in_range)

        if self.use_noise:
            if noise_input is None:
                batch, _, height, width = x.shape
                noise_input = x.new_empty(batch, 1, height, width).normal_()
            # Out-of-place add: an in-place `+=` on the conv output breaks
            # autograd if the conv's last op saved its result for backward.
            x = x + self.noise_strength * noise_input

        return fused_leaky_relu(x, self.bias)  # act='lrelu'