In [1]:
import math
from typing import Optional

import numpy as np
import torch
import torch.nn.functional as F

from torch import nn, Tensor

In [2]:
def equalized_lr_init(weight: Tensor, bias: Optional[Tensor],
                      scale_weights: bool = True, lr_mult: float = 1.0,
                      gain: float = 1.0) -> float:
    """Initialize ``weight``/``bias`` in place for equalized learning rate.

    The He constant ``he_std = gain / sqrt(fan_in)`` is either applied at
    runtime or baked into the initial weights:

    * ``scale_weights=True``: weights ~ N(0, 1 / lr_mult) and the returned
      runtime multiplier is ``he_std * lr_mult``.
    * ``scale_weights=False``: weights ~ N(0, he_std / lr_mult) and the
      returned runtime multiplier is ``lr_mult``.

    In both cases ``weight * returned_scale`` has standard deviation
    ``he_std`` right after initialization.

    Args:
        weight: parameter tensor with at least 2 dims; filled in place.
        bias: optional bias tensor; zeroed in place when not None.
        scale_weights: apply the He constant at runtime instead of at init.
        lr_mult: learning-rate multiplier; must be positive.
        gain: multiplier on the He constant (1.0 reproduces the original
            behavior).

    Returns:
        The multiplier the caller should apply to ``weight`` in ``forward``.

    Raises:
        ValueError: if ``lr_mult`` is not positive or ``weight`` has
            ``fan_in == 0``.
    """
    if lr_mult <= 0:
        raise ValueError("lr_mult must be positive, got {}".format(lr_mult))

    fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weight)
    if fan_in == 0:
        raise ValueError("cannot initialize a weight with fan_in == 0")
    he_std = gain / math.sqrt(fan_in)

    if scale_weights:
        init_std = 1.0 / lr_mult
        scale = he_std * lr_mult
    else:
        init_std = he_std / lr_mult
        scale = lr_mult

    nn.init.normal_(weight, mean=0.0, std=init_std)
    if bias is not None:
        nn.init.zeros_(bias)
    return scale

In [3]:
class EqualLinear(nn.Linear):
    """``nn.Linear`` with equalized learning rate (StyleGAN-style).

    Weights are initialized by ``equalized_lr_init`` and multiplied by
    ``self.scale`` on every forward pass.  The bias is multiplied by
    ``lr_mult`` at runtime so it shares the weight's learning-rate
    multiplier (with the default ``lr_mult=1.0`` this is a no-op).

    Args:
        in_features, out_features, bias: as for ``nn.Linear``.
        scale_weights: apply the He constant at runtime instead of at init.
        lr_mult: learning-rate multiplier for weight and bias.
    """
    def __init__(self, in_features, out_features, bias=True,
                 scale_weights=True, lr_mult=1.0):
        # Must exist before super().__init__, which calls reset_parameters().
        self.scale_weights = scale_weights
        self.lr_mult = lr_mult
        super(EqualLinear, self).__init__(in_features, out_features, bias)

    def reset_parameters(self):
        # self.scale is the runtime multiplier applied to self.weight.
        self.scale = equalized_lr_init(self.weight, self.bias,
                                       self.scale_weights, self.lr_mult)

    def forward(self, x):
        bias = self.bias * self.lr_mult if self.bias is not None else None
        return F.linear(x, self.weight * self.scale, bias)

In [4]:
el = EqualLinear(4, 8)
el

EqualLinear(in_features=4, out_features=8, bias=True)

In [5]:
el(torch.randn(11, 4)).size()

torch.Size([11, 8])

In [6]:
class EqualConv2d(nn.Conv2d):
    """``nn.Conv2d`` with equalized learning rate (StyleGAN-style).

    Weights are initialized by ``equalized_lr_init`` and multiplied by
    ``self.scale`` on every forward pass.  The bias is multiplied by
    ``lr_mult`` at runtime so it shares the weight's learning-rate
    multiplier (with the default ``lr_mult=1.0`` this is a no-op).

    ``forward`` calls ``F.conv2d`` directly rather than the private
    ``nn.Conv2d.conv2d_forward`` helper, which newer PyTorch releases no
    longer provide (it was renamed ``_conv_forward``).

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation,
        groups, bias, padding_mode: as for ``nn.Conv2d``.
        scale_weights: apply the He constant at runtime instead of at init.
        lr_mult: learning-rate multiplier for weight and bias.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1,
                 bias=True, padding_mode='zeros',
                 scale_weights=True, lr_mult=1.0):
        # Must exist before super().__init__, which calls reset_parameters().
        self.scale_weights = scale_weights
        self.lr_mult = lr_mult
        super(EqualConv2d, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding,
            dilation, groups, bias, padding_mode)

    def reset_parameters(self):
        # self.scale is the runtime multiplier applied to self.weight.
        self.scale = equalized_lr_init(self.weight, self.bias,
                                       self.scale_weights, self.lr_mult)

    def forward(self, input):
        weight = self.weight * self.scale
        bias = self.bias * self.lr_mult if self.bias is not None else None

        padding = self.padding
        if self.padding_mode != 'zeros':
            # Non-zero padding modes are applied explicitly with F.pad and the
            # convolution itself then runs unpadded.  Newer nn.Conv2d caches
            # the F.pad argument; otherwise pad symmetrically by self.padding.
            pad = getattr(self, '_reversed_padding_repeated_twice', None)
            if pad is None:
                pad_h, pad_w = self.padding
                pad = (pad_w, pad_w, pad_h, pad_h)
            input = F.pad(input, pad, mode=self.padding_mode)
            padding = 0

        return F.conv2d(input, weight, bias, self.stride, padding,
                        self.dilation, self.groups)

In [7]:
ec = EqualConv2d(4, 8, kernel_size=3, padding=1)
ec

EqualConv2d(4, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

In [8]:
ec(torch.randn(3, 4, 12, 12)).size()

torch.Size([3, 8, 12, 12])

In [9]:
# class ModConv2d(nn.Module):
#     def __init__(self, in_channels, out_channels, kernel_size, style_dim, 
#                  demodulate=True, scale_weights=True, lr_mult=1):
#         super(ModConv2d, self).__init__()
#         self.style = EqualLinear(style_dim, in_channels, bias=True, 
#                                  scale_weights=scale_weights, 
#                                  lr_mult=lr_mult)
#         self.conv = EqualConv2d(in_channels, out_channels, kernel_size, 
#                                 padding=(kernel_size // 2), bias=True, 
#                                 scale_weights=scale_weights, 
#                                 lr_mult=lr_mult)
#         self.reset_parameters()
        
#     def reset_parameters(self):
#         nn.init.ones_(self.style.bias)
        
#     def forward(self, x, y):
#         pass

In [10]:
class ModConv2d(nn.Conv2d):
    """Weight-modulated 2-D convolution with equalized learning rate.

    Every sample in the batch gets its own copy of the runtime-scaled weight,
    multiplied per input channel by that sample's ``style`` vector.  With
    ``demodulate=True`` (the default, matching the behavior before the flag
    existed) each modulated filter is then rescaled per output channel to
    approximately unit L2 norm.  All samples are convolved in a single
    grouped ``F.conv2d`` call.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation,
        bias: as for ``nn.Conv2d``; ``groups`` is fixed to 1 and
            ``padding_mode`` to ``'zeros'``.
        scale_weights: apply the He constant at runtime instead of at init.
        lr_mult: learning-rate multiplier for weight and bias.
        demodulate: rescale each modulated filter to unit norm.  Set to
            False for layers that should only modulate.
        eps: constant added inside the rsqrt during demodulation.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, bias=True,
                 scale_weights=True, lr_mult=1.0, demodulate=True, eps=1e-8):
        # Must exist before super().__init__, which calls reset_parameters().
        self.scale_weights = scale_weights
        self.lr_mult = lr_mult
        self.demodulate = demodulate
        self.eps = eps
        super(ModConv2d, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding,
            dilation, groups=1, bias=bias, padding_mode='zeros')

    def reset_parameters(self):
        # w_mult is the runtime multiplier applied to self.weight in forward.
        self.w_mult = equalized_lr_init(
            self.weight, self.bias, self.scale_weights, self.lr_mult)

    def conv2d_forward(self, input, style, weight, bias):
        """Convolve each sample with its own style-modulated weight.

        Args:
            input: feature maps of shape (N, C_in, H, W).
            style: per-sample input-channel scales of shape (N, C_in).
            weight: conv weight of shape (C_out, C_in, kH, kW), already
                multiplied by the runtime scale.
            bias: bias of shape (C_out,) already multiplied by ``lr_mult``,
                or None.

        Returns:
            Tensor of shape (N, C_out, H_out, W_out).
        """
        _, _, H, W = input.shape
        # OIkk -> 1OIkk, times NI -> N1I11: one weight per sample, NOIkk.
        w = weight[None] * style[:, None, :, None, None]

        if self.demodulate:
            # Normalize each (sample, output channel) filter over I, kH, kW.
            d = torch.rsqrt(w.pow(2).sum(dim=(2, 3, 4), keepdim=True) + self.eps)
            w = w * d

        # Fold the batch into the channel axis and run one grouped conv with
        # one group per sample.  reshape (not view) so non-contiguous inputs,
        # e.g. tensors permuted over batch/channel, are accepted.
        N, C_out, C_in, k_h, k_w = w.shape
        w = w.reshape(N * C_out, C_in, k_h, k_w)
        x = input.reshape(1, -1, H, W)
        out = F.conv2d(x, w, None, self.stride, self.padding,
                       self.dilation, groups=N)
        out = out.reshape(N, C_out, out.shape[-2], out.shape[-1])

        if bias is not None:
            out = out + bias[:, None, None]
        return out

    def forward(self, input, style):
        weight = self.weight * self.w_mult
        bias = self.bias * self.lr_mult if self.bias is not None else None
        return self.conv2d_forward(input, style, weight, bias)

In [11]:
mc = ModConv2d(4, 8, 3, padding=1, bias=True)
mc

ModConv2d(4, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

In [12]:
demo_input = torch.randn(7, 4, 12, 12)
demo_style = torch.randn(7, 4)
mc(demo_input, demo_style).size()

torch.Size([7, 8, 12, 12])