In [1]:
import sys

import numpy as np
import math

import torch
import torch.nn.functional as F

from torch import nn, Tensor

In [2]:
# Make the project's `src/` directory importable (so `models.stylegan2` resolves).
# Re-running this cell must not stack duplicate entries, so drop any existing
# copy before prepending it again at highest priority.
SRC_DIR = '../src'
sys.path[:] = [entry for entry in sys.path if entry != SRC_DIR]
sys.path.insert(0, SRC_DIR)

In [3]:
from models.stylegan2.layers import EqualConv2d, EqualLinear, EqualLeakyReLU, \
    ModulatedConv2d, RandomGaussianNoise

from models.stylegan2.net import Layer, ToRGB, Input, SynthesisNet, MappingNet

In [4]:
import random

In [5]:
class Generator(nn.Module):
    """Generator composed of a mapping network followed by a synthesis network.

    ``mapping(z, y)`` turns latents ``z`` (and optional labels ``y``) into
    styles ``w``; ``synthesis(w)`` renders the output from per-layer styles.
    While the module is in training mode, style mixing is applied with
    probability ``p_style_mix`` (see ``mix_styles``); in eval mode no mixing
    is performed.

    Args:
        mapping: network called as ``mapping(z, y)``. If it returns a 2-D
            tensor (one style per sample) that style is shared across all
            synthesis layers; otherwise its output is assumed to already be
            per-layer with the layer axis first (TODO confirm for MappingNet).
        synthesis: network called as ``synthesis(w)``; must expose ``num_layers``.
        p_style_mix: probability of style mixing per forward pass during
            training; ``None`` disables mixing entirely.
    """

    def __init__(self, mapping: "MappingNet", synthesis: "SynthesisNet", p_style_mix=0.9):
        super().__init__()
        self.mapping = mapping
        self.synthesis = synthesis
        self.p_style_mix = p_style_mix

    @property
    def num_layers(self) -> int:
        """Number of per-layer style inputs, as reported by the synthesis network."""
        return self.synthesis.num_layers

    def mix_styles(self, z1: Tensor, y: Tensor, w1: Tensor) -> Tensor:
        """Randomly replace the styles of the later layers with those of a second latent.

        With probability ``p_style_mix`` an integer cutoff ``c`` is drawn
        uniformly from ``[1, num_layers)``: layers ``< c`` keep ``w1`` and
        layers ``>= c`` take styles mapped from a fresh latent
        ``z2 = randn_like(z1)``. Otherwise (or when there are fewer than two
        layers, or ``p_style_mix`` is ``None``) ``w1`` is returned unchanged and
        no second latent is mapped.

        All randomness comes from torch's RNG, so ``torch.manual_seed`` alone
        makes the result reproducible.

        Args:
            z1: latents that produced ``w1``; only its shape/dtype/device are used.
            y: labels forwarded to the mapping network for ``z2`` (may be None).
            w1: per-layer styles with the layer axis first (``w1[layer]``).

        Returns:
            The mixed styles (broadcast of ``w1`` and ``mapping(z2, y)``), or
            ``w1`` itself when no mixing happens.
        """
        num_layers = self.num_layers
        p = self.p_style_mix
        # A cutoff must lie in [1, num_layers), so at least two layers are needed.
        if p is None or num_layers < 2 or torch.rand(()).item() >= p:
            return w1
        cutoff = int(torch.randint(1, num_layers, ()).item())

        z2 = torch.randn_like(z1)
        w2 = self.mapping(z2, y)
        # Build the mask on w1's device: torch.where rejects a CPU condition
        # combined with CUDA operands.
        layer_idx = torch.arange(num_layers, device=w1.device)
        mask = layer_idx[:, None, None] < cutoff
        return torch.where(mask, w1, w2)

    def forward(self, z, y=None):
        """Map ``z`` (and optional labels ``y``) to styles and synthesize from them.

        Args:
            z: latent batch passed to the mapping network.
            y: optional labels passed to the mapping network.

        Returns:
            Whatever ``synthesis(w)`` returns.
        """
        w = self.mapping(z, y)
        if w.ndim == 2:
            # One style per sample: share it across all layers (layer axis first).
            w = w.expand(self.num_layers, -1, -1)
        # Style mixing is a training-time regularizer; skip it in eval mode so
        # inference is not randomly re-styled.
        if self.training and self.p_style_mix is not None:
            w = self.mix_styles(z, y, w)
        return self.synthesis(w)

In [6]:
# Small configuration so the generator can be smoke-tested quickly.
# Seed both RNGs used in this notebook so Restart & Run All reproduces
# the module initialization and every sample drawn afterwards.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

latent_dim = 16
style_dim = 16
label_dim = 1

synthesis = SynthesisNet(
    img_res=32,
    fmap_base=64,
    style_dim=style_dim,
)

mapping = MappingNet(
    latent_dim=latent_dim,
    label_dim=label_dim,
    style_dim=style_dim,
    num_layers=3,
    hidden_dim=16,
)

g = Generator(mapping, synthesis, p_style_mix=0.9)

In [7]:
# Smoke test: a small batch of latents and labels should yield a batch of
# 3-channel images at the configured 32x32 resolution.
N = 3
z = torch.randn(N, latent_dim)  # standard-normal latents, matching the torch.randn_like sampling in mix_styles
y = torch.rand(N, 1)  # one label value per sample; MappingNet above was built with label_dim=1
imgs = g(z, y)
assert imgs.shape == (N, 3, 32, 32), imgs.shape
imgs.shape

torch.Size([3, 3, 32, 32])