In [1]:
import numpy as np
import math

import torch
import torch.nn.functional as F

from torch import nn, Tensor

In [2]:
# Small throwaway conv layer used to inspect how PyTorch lays out conv weights.
c = nn.Conv2d(
    in_channels=2,
    out_channels=8,
    kernel_size=(3, 4),
    padding=1,
)

In [3]:
# Conv2d weight layout: (out_channels, in_channels, kernel_height, kernel_width)
c.weight.size()

torch.Size([8, 2, 3, 4])

In [4]:
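# Reference only (TensorFlow, not executed): the dense layer that EqualLinear below
# presumably ports to PyTorch.
# def dense_layer(x, fmaps, gain=1, use_wscale=True, lrmul=1, weight_var='weight'):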
#     if len(x.shape) > 2:
#         x = tf.reshape(x, [-1, np.prod([d.value for d in x.shape[1:]])])
#     w = get_weight([x.shape[1].value, fmaps], gain=gain, use_wscale=use_wscale, 
#                    lrmul=lrmul, weight_var=weight_var)
#     w = tf.cast(w, x.dtype)
#     return tf.matmul(x, w)

In [5]:
class EqualLinear(nn.Module):
    """Linear layer whose weight is rescaled at call time ("equalized learning rate").

    The stored ``weight`` parameter is multiplied by a constant ``scale`` on every
    forward pass. With ``scale_weights=True`` the parameter is drawn from
    ``N(0, (1 / lr_mult)^2)`` and the He factor ``gain / sqrt(in_features)`` is
    applied at runtime; with ``scale_weights=False`` the He factor is baked into
    the initial values instead and only ``lr_mult`` is applied at runtime.

    Args:
        in_features: size of each input sample (must be positive).
        out_features: size of each output sample.
        bias: if ``True``, adds a learnable bias initialised to zero.
        scale_weights: apply the He factor at runtime rather than at init.
        lr_mult: learning-rate multiplier (must be positive). It is applied to
            both the weight and the bias at runtime so that the two parameters
            share the same effective learning rate.
        gain: numerator of the He standard deviation ``gain / sqrt(in_features)``.

    Raises:
        ValueError: if ``in_features`` or ``lr_mult`` is not positive.
    """

    def __init__(self, in_features: int, out_features: int, bias=True,
                 scale_weights=True, lr_mult=1., gain=1.):
        super().__init__()
        # Both values end up in a denominator in reset_parameters(); fail early
        # with a readable message instead of a bare ZeroDivisionError.
        if in_features <= 0:
            raise ValueError(
                'in_features must be positive, got {}'.format(in_features))
        if lr_mult <= 0:
            raise ValueError(
                'lr_mult must be positive, got {}'.format(lr_mult))

        self.in_features = in_features
        self.out_features = out_features
        self.lr_mult = lr_mult
        self.gain = gain
        # Remembered so that reset_parameters() can be called without arguments.
        self.scale_weights = scale_weights

        # torch.empty only allocates; the values are filled in by reset_parameters().
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self, scale_weights=None):
        """Re-initialise ``weight`` (normal) and ``bias`` (zeros) and recompute ``scale``.

        Args:
            scale_weights: optional override for the mode chosen at construction.
                ``None`` (the default) reuses the stored setting; any other value
                is stored and used from now on.
        """
        if scale_weights is None:
            scale_weights = self.scale_weights
        else:
            self.scale_weights = scale_weights

        he_std = self.gain / math.sqrt(self.in_features)

        if scale_weights:
            # He factor is applied in forward(); stored weights are ~N(0, 1/lr_mult^2).
            init_std = 1.0 / self.lr_mult
            self.scale = he_std * self.lr_mult
        else:
            # He factor is baked into the initial values; only lr_mult is applied
            # in forward().
            init_std = he_std / self.lr_mult
            self.scale = self.lr_mult

        nn.init.normal_(self.weight, mean=0.0, std=init_std)
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def forward(self, x):
        """Return ``x @ (weight * scale).T + bias * lr_mult`` (bias term only if present)."""
        # The bias gets the same lr multiplier as the weight; without it the two
        # parameters would train at different effective rates when lr_mult != 1.
        bias = None if self.bias is None else self.bias * self.lr_mult
        return F.linear(x, self.weight * self.scale, bias)

    def extra_repr(self):
        """Summary shown inside the parentheses of the module's repr."""
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None)

In [6]:
# Build a layer with the default options; the bare name on the last line displays its repr.
el = EqualLinear(in_features=4, out_features=8)
el

EqualLinear(in_features=4, out_features=8, bias=True)

In [7]:
# Smoke test: a batch of 11 four-feature inputs should come out as 11 eight-feature outputs.
el(torch.randn(11, 4)).size()

torch.Size([11, 8])