In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from collections import OrderedDict
import pickle

import numpy as np
import IPython

In [2]:
class MyLinear(nn.Module):
    """Fully connected layer with optional equalized learning rate and an lr multiplier.

    Args:
        input_size: number of input features (the fan-in used for the He constant).
        output_size: number of output features.
        gain: gain of the He constant ``gain * input_size ** -0.5``.
        use_wscale: if True, weights are initialized with std ``1 / lrmul`` and the He
            constant is applied at runtime (equalized learning rate); otherwise the He
            constant is folded into the initialization.
        lrmul: learning-rate multiplier applied to weight and bias at runtime.
        bias: whether to create a learnable bias (``self.bias`` is None otherwise).
    """

    def __init__(self, input_size, output_size, gain=2 ** (0.5), use_wscale=False, lrmul=1, bias=True):
        super().__init__()
        he_std = gain * input_size ** (-0.5)
        if not use_wscale:
            # He constant baked into the init; only the lr multiplier is applied per call.
            init_std, self.w_mul = he_std / lrmul, lrmul
        else:
            # Unit-scale storage; the He constant is applied per call.
            init_std, self.w_mul = 1.0 / lrmul, he_std * lrmul
        self.weight = torch.nn.Parameter(torch.randn(output_size, input_size) * init_std)
        self.bias = torch.nn.Parameter(torch.zeros(output_size)) if bias else None
        if bias:
            self.b_mul = lrmul

    def forward(self, x):
        """Apply ``x @ (weight * w_mul).T + bias * b_mul``."""
        scaled_bias = None if self.bias is None else self.bias * self.b_mul
        return F.linear(x, self.weight * self.w_mul, scaled_bias)

In [3]:
class MyConv2d(nn.Module):
    """Conv layer with equalized lr, custom lr multiplier and optional upscaling.

    Args:
        input_channels: number of input channels.
        output_channels: number of output channels.
        kernel_size: side length of the square kernel; the non-fused paths use
            padding ``kernel_size // 2``.
        gain: gain of the He constant ``gain / sqrt(input_channels * kernel_size**2)``.
        use_wscale: if True, weights are initialized with std ``1 / lrmul`` and the He
            constant is applied at runtime (equalized learning rate); otherwise the He
            constant is folded into the initialization and only ``lrmul`` is applied.
        lrmul: learning-rate multiplier applied to weight and bias at runtime.
        bias: whether to learn a per-output-channel bias.
        intermediate: optional callable applied to the convolution output before the
            bias is added.
        upscale: if True, the input is upscaled before the convolution. Large inputs use
            a fused stride-2 transposed convolution (2x output); smaller ones use
            ``Upscale2d`` followed by a regular convolution.
    """

    # StyleGAN's fused_scale='auto' rule: fuse upscale+conv when the upscaled output
    # is at least this many pixels on its shortest side.
    _FUSED_MIN_OUTPUT_SIZE = 128

    def __init__(self,input_channels,output_channels,kernel_size,gain=2**(0.5),use_wscale=False, lrmul=1, bias=True,intermediate=None,upscale=False):
        super().__init__()
        if upscale:
            # NOTE: Upscale2d is not defined in this cell; it must exist before a layer
            # with upscale=True is constructed.
            self.upscale = Upscale2d()
        else:
            self.upscale = None
        he_std = gain * (input_channels * kernel_size ** 2) ** (-0.5)
        self.kernel_size = kernel_size
        if use_wscale:
            init_std = 1.0 / lrmul
            self.w_mul = he_std * lrmul
        else:
            # He constant is already baked into the init; applying it again at runtime
            # would scale the weights twice.
            init_std = he_std / lrmul
            self.w_mul = lrmul
        self.weight = torch.nn.Parameter(torch.randn(output_channels,input_channels,kernel_size,kernel_size) * init_std)
        if bias:
            self.bias = torch.nn.Parameter(torch.zeros(output_channels))
            self.b_mul = lrmul
        else:
            self.bias = None
        self.intermediate = intermediate

    def forward(self,x):
        """Convolve ``x`` (optionally upscaled first), apply ``intermediate``, add bias."""
        w = self.weight * self.w_mul
        bias = self.bias
        if bias is not None:
            bias = bias * self.b_mul

        if self.upscale is not None and min(x.shape[2:]) * 2 >= self._FUSED_MIN_OUTPUT_SIZE:
            x = self._fused_upscale_conv(x, w)
        else:
            if self.upscale is not None:
                x = self.upscale(x)
            if self.intermediate is None:
                # Nothing runs between conv and bias, so let conv2d add the bias itself.
                return F.conv2d(x, w, bias, padding=self.kernel_size // 2)
            x = F.conv2d(x, w, None, padding=self.kernel_size // 2)

        if self.intermediate is not None:
            x = self.intermediate(x)
        if bias is not None:
            x = x + bias.view(1, -1, 1, 1)
        return x

    @staticmethod
    def _fused_upscale_conv(x, w):
        """2x upscale + conv as one stride-2 transposed convolution (StyleGAN fused scale)."""
        # conv_transpose2d expects weights laid out as (in_channels, out_channels, kH, kW).
        w = w.permute(1, 0, 2, 3)
        # Grow the k x k kernel to (k+1) x (k+1) by summing its four one-pixel shifts.
        w = F.pad(w, (1, 1, 1, 1))
        w = w[:, :, 1:, 1:] + w[:, :, :-1, 1:] + w[:, :, 1:, :-1] + w[:, :, :-1, :-1]
        # With kernel size K = k+1 and stride 2, padding (K-1)//2 yields exactly 2x the input size.
        return F.conv_transpose2d(x, w, stride=2, padding=(w.size(-1) - 1) // 2)
            
            

In [4]:
class NoiseLayer(nn.Module):
    """Add noise scaled by a learned per-channel weight.

    The noise map has a single channel, so the same noise value is shared by all
    channels of a pixel; each channel scales it by its own entry of ``self.weight``
    (initialized to zero, so the layer starts as the identity).
    """

    def __init__(self, channels):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(channels))
        # Optional fixed noise tensor, used when forward() is not given one explicitly.
        self.noise = None

    def forward(self, x, noise=None):
        """Return ``x + weight * noise``; noise precedence: argument, ``self.noise``, fresh N(0, 1)."""
        if noise is None:
            noise = self.noise
        if noise is None:
            batch, _, height, width = x.size(0), x.size(1), x.size(2), x.size(3)
            noise = torch.randn(batch, 1, height, width, device=x.device, dtype=x.dtype)
        per_channel = self.weight.view(1, -1, 1, 1)
        return x + per_channel * noise

In [5]:
class StyleMod(nn.Module):
    """Per-channel modulation of a feature map by a style predicted from a latent.

    A linear layer maps ``latent`` to ``2 * channels`` values, read as a scale and a
    shift per channel; the output is ``x * (scale + 1) + shift``.
    """

    def __init__(self, latent_size, channels, use_wscale):
        super().__init__()
        self.lin = MyLinear(latent_size, channels * 2, gain=1.0, use_wscale=use_wscale)

    def forward(self, x, latent):
        n_channels = x.size(1)
        # Broadcast the style over every dimension after the channel axis.
        trailing_ones = [1] * (x.dim() - 2)
        style = self.lin(latent).view(-1, 2, n_channels, *trailing_ones)  # [batch, 2, n_channels, 1, ...]
        scale, shift = style[:, 0], style[:, 1]
        return x * (scale + 1.) + shift

In [6]:
class PixelNormLayer(nn.Module):
    """Rescale ``x`` so its mean square along dim 1 is (approximately) one.

    ``epsilon`` is added to the mean square before the reciprocal square root to
    avoid division by zero.
    """

    def __init__(self, epsilon=1e-8):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, x):
        mean_square = x.pow(2).mean(dim=1, keepdim=True)
        return x * (mean_square + self.epsilon).rsqrt()

In [None]:
class BlurLayer(nn.Module):
    """Depthwise blur with a separable kernel built from 1-D taps.

    Args:
        kernel: 1-D taps; the 2-D kernel is their outer product. Defaults to (1, 2, 1).
        normalize: if True, divide the 2-D kernel by its sum so its taps sum to 1.
        flip: if True, flip the 2-D kernel along both spatial axes.
        stride: convolution stride.

    The same kernel is applied to every channel independently (``groups=channels``)
    with padding ``(k - 1) // 2``.
    """

    def __init__(self,kernel=(1,2,1),normalize=True,flip=False,stride=1):
        super(BlurLayer,self).__init__()
        # Honour the caller's taps (previously overwritten with [1, 2, 1]); a tuple
        # default avoids the shared-mutable-default pitfall.
        kernel = torch.tensor(kernel,dtype=torch.float32)
        kernel = kernel[:,None] * kernel[None,:]
        kernel = kernel[None,None]  # (1, 1, k, k)
        if normalize:
            kernel = kernel / kernel.sum()
        if flip:
            # Torch tensors do not support negative-step slicing; use torch.flip.
            kernel = torch.flip(kernel, dims=(2, 3))
        self.register_buffer('kernel',kernel)
        self.stride = stride

    def forward(self,x):
        channels = x.size(1)
        # One copy of the kernel per channel for a depthwise (grouped) convolution.
        kernel = self.kernel.expand(channels,-1,-1,-1)
        padding = (self.kernel.size(2) - 1) // 2
        return F.conv2d(x,kernel,stride=self.stride,padding=padding,groups=channels)