In [1]:
import torch.nn as nn
import torch.nn.functional as F
import torch

In [2]:
import torch
import torch.nn as nn
import torch.functional as tf
from torch.nn.modules.activation import ReLU


from .helpers.conv2d import Conv2d


class StyleController(nn.Module):
    """
    
    Style Controller network.
    
    """

    def __init__(self, batch_size: int, input_size = 8):
        """
        
        Style Controller Network
        :param batch_size      : number of examples in a batch
        :param input_size      : dimension of the style vectors
        
        """
        super().__init__()
        
        self.input_size = input_size
        self.batch_size = batch_size
        # Used in output channel calculations
        # Authors of the paper set it to 64 
        self.k = 64

        # inp: (in_batch, input_size)
        # out: (in_batch, 128)
        self.fc1 = nn.Linear(self.input_size, 128, bias = True)
        self.initialize_weights_with_he_biases_with_zero(self.fc1)
        
        # inp: (in_batch, 128)
        # out: (in_batch, 128)
        self.ln1 = nn.LayerNorm(128)
        
        # inp: (in_batch, 128)
        # out: (in_batch, 128)
        self.relu1 = nn.ReLU()
        
        # inp: (in_batch, 128)
        # out: (in_batch, 128)
        self.fc2 = nn.Linear(128, 128, bias = True)
        self.initialize_weights_with_he_biases_with_zero(self.fc2)
        
        # inp: (in_batch, 128)
        # out: (in_batch, 128)
        self.ln2 = nn.LayerNorm(128)
        
        # inp: (in_batch, 128)
        # out: (in_batch, 128)
        self.relu2 = nn.ReLU()
        
        # inp: (in_batch, 128)
        # out: (in_batch, 4 * k)
        self.fc3 = nn.Linear(128, 4 * self.k, bias = True)
        self.initialize_weights_with_he_biases_with_zero(self.fc3) 
        
        # inp: (in_batch, 128)
        # out: (in_batch, 4 * k)
        self.fc4 = nn.Linear(128, 4 * self.k, bias = True)
        self.initialize_weights_with_he_biases_with_zero(self.fc4)

        
    def forward(self, x) -> torch.Tensor:
        """
        
        Forward function for Style Controller.
        Returns two concatenated (batch_size, 1, 1, 4 * k) shaped tensors, gamma and beta coefficients
        
        :param x: style encodings
            :shape: (batch_size, input_size)
        :return : out
            :shape: (batch_size, 2, 1, 4 * k)
        
        """
        
        def forward(self, x):
        
        if x is None:
            x = torch.randn((self.batch_size, self.input_size))
        
        # inp: (batch_size, input_size)
        # out: (batch_size, 128)
        x = self.fc1(x)
        x = self.ln1(x)
        x = self.relu1(x)
        
        # inp: (batch_size, 128)
        # out: (batch_size, 128)
        x = self.fc2(x)
        x = self.ln2(x)
        x = self.relu2(x)
        
        # inp: (batch_size, 128)
        # out: (batch_size, 1, 1, 4 * k)
        gamma = self.fc3(x)
        gamma = torch.reshape(gamma, [-1, 1, 1, 4 * self.k])
        
        # inp: (batch_size, 128)
        # out: (batch_size, 1, 1, 4 * k)
        beta = self.fc4(x)
        beta = torch.reshape(beta, [-1, 1, 1, 4 * self.k])
        
        return torch.cat((beta, gamma), 0)
        
        
        


IndentationError: expected an indented block (Temp/ipykernel_13820/2469161808.py, line 85)

In [None]:
class Decoder(nn.Module):
    def __init__(self, in_channels, style_controller):
         
        super().__init__()
        
        self.style_controller = style_controller 
        self.k = self.style_controller.k
        
        self.conv1 = nn.Conv2d(in_channels = in_channels, out_channels = 4 * self.k, kernel_size=3)
        self.initialize_weights_and_biases(self.conv1)
        
        self.instance_norm_layer1 = nn.InstanceNorm2d(4 * self.k)
        self.relu1 = nn.ReLU()
        
        self.conv2 = nn.Conv2d(in_channels = 4 * self.k, out_channels = 4 * self.k, kernel_size = 3, bias = False)
        nn.init.kaiming_normal_(self.conv2.weight)
        
        self.instance_norm_layer2 = nn.InstanceNorm2d(4 * self.k)
        
        self.conv3 = nn.Conv2d(in_channels = 4 * self.k, out_channels = 4 * self.k, kernel_size = 3)
        self.initialize_weights_and_biases(self.conv3)
        
        self.instance_norm_layer3 = nn.InstanceNorm2d(4 * self.k)
        self.relu2 = nn.ReLU()
        
        self.conv4 = nn.Conv2d(in_channels = 4 * self.k, out_channels = 4 * self.k, kernel_size = 3, bias = False)
        nn.init.kaiming_normal_(self.conv4.weight)
        
        self.instance_norm_layer4 = nn.InstanceNorm2d(4 * self.k)
        
        self.conv5 = nn.Conv2d(in_channels = 4 * self.k, out_channels = 4 * self.k, kernel_size = 3)
        self.initialize_weights_and_biases(self.conv5)
        
        self.instance_norm_layer5 = nn.InstanceNorm2d(4 * self.k)
        self.relu3 = nn.ReLU()
        
        self.conv6 = nn.Conv2d(in_channels = 4 * self.k, out_channels = 4 * self.k, kernel_size = 3, bias = False)
        nn.init.kaiming_normal_(self.conv6.weight)
        
        self.instance_norm_layer6 = nn.InstanceNorm2d(4 * self.k)
        
        self.conv7 = nn.Conv2d(in_channels = 4 * self.k, out_channels = 4 * self.k, kernel_size = 5)
        self.initialize_weights_and_biases(self.conv7)
        
        self.instance_norm_layer7 = nn.InstanceNorm2d(2 * self.k)
        self.relu4 = nn.ReLU()
        
        self.conv8 = nn.Conv2d(2 * self.k, self.k, 5)
        self.initialize_weights_and_biases(self.conv8)
        
        self.instance_norm_layer8 = nn.InstanceNorm2d(self.k)
        self.relu5 = nn.ReLU()
        
        self.conv9 = nn.Conv2d(self.k, 3, 7)
        self.initialize_weights_and_biases(self.conv9, True)
        
        self.tanh = nn.Tanh()
        
    
    def initialize_weights_and_biases(self, layer : nn.Module, bothWeightAndBias = False):
        if not bothWeightAndBias:
            nn.init.kaiming_normal_(layer.weight)
            layer.bias.data.fill_(0.0)
        else:
            layer.weight.data.fill_(0.0)
            layer.bias.data.fill_(0.0)
            
    def forward(self, styles, x):
        style_controller_res = self.style_controller.forward(styles)
        gamma = style_controller_res[0:self.style_controller.batch_size, :, :, :]
        beta = style_controller_res[self.style_controller.batch_size :, :, :, :]
        
        print("X shape after padding: ", padding2(x, 1, pad_mode="zero").shape)
        x_ = self.conv1(padding2(x, 1, pad_mode="zero"))
        print("X shape after conv: ", x_.shape)
        x_ = self.instance_norm_layer1(x_)
        print("X shape after instance norm: ", x_.shape)
        x_ = gamma * x_ + beta
        print("2: ", x_.shape)
        x_ = self.relu1(x_)
        
        print(self.instance_norm_layer2(self.conv2(padding2(x_, 1, pad_mode="zero"))).shape)
        print(x.shape)
        x += self.instance_norm_layer2(self.conv2(padding2(x_, 1, pad_mode="zero")))
        
        print('module res{} shape:'.format(1), [dim for dim in x.shape])
        
        x = padding2(x, 1, pad_mode="zero")
        x_ = self.conv3(x)
        x_ = self.instance_norm_layer3(x_)
        x_ = gamma * x_ + beta
        x_ = self.relu2(x_)
        
        x_ = padding2(x_, 1, pad_mode="zero")
        x += self.instance_norm_layer4(self.conv4(x_))
        
        print('module res{} shape:'.format(2), [dim for dim in x.shape])
        
        x = padding2(x, 1, pad_mode="zero")
        x_ = self.conv5(x)
        x_ = self.instance_norm_layer5(x_)
        print("Gmma shape: ", gamma.shape)
        x_ = gamma * x_ + beta
        x_ = self.relu3(x_)
        
        x_ = padding2(x_, 1, pad_mode="zero")
        x += self.instance_norm_layer6(self.conv6(x_))
        
        print('module res{} shape:'.format(3), [dim for dim in x.shape])
        
        x = upscale2d(x, 2, pad_mode="zero")
        x = padding2(x, 2)
        x = self.conv7(x)
        x = self.instance_norm_layer7(x)
        x = self.relu4(x)
        
        print('module deconv1 shape:', [dim for dim in x.shape])

        x = upscale2d(x, 2)
        x = padding2(x, 2, pad_mode="zero")
        x = self.conv8(x)
        x = self.instance_norm_layer8(x)
        x = self.relu5(x)
        
        x = padding2(x, 3, pad_mode="zero")
        x = self.conv9(x)
        return self.tanh(x)
        
        

In [None]:
controller = StyleController(10,8)
styles = torch.tensor(np.ones((10, 8), dtype=float)).float()
decoder = Decoder(256, controller)

In [None]:
decoder.forward(styles, torch.tensor(np.ones((10, 256, 256, 256))).float())