In [3]:
import torch
import torch.nn as nn
import pytorch_lightning as pl

from src.models.discriminator import Discriminator
from src.models.generator import Generator
from src.render.mesh_renderer import MeshPointsRenderer
from src.models.stylist import Stylist

class GAN(pl.LightningModule):
    """Adversarial training module built from four sub-modules.

    A ``Stylist`` turns a batch of images into a style tensor, a
    ``Generator`` turns input points plus that style into vertices, a
    ``MeshPointsRenderer`` turns the vertices into images, and a
    ``Discriminator`` scores images as real or generated.

    Two optimizers are used (see ``configure_optimizers``): index 0 updates
    the generator and stylist together, index 1 updates the discriminator.
    """

    def __init__(self, hparams):
        """Build the sub-modules from a shared hyper-parameter object.

        Args:
            hparams: Configuration object passed unchanged to every
                sub-module. ``configure_optimizers`` reads ``lr_gs`` and
                ``lr_d`` from it.
        """
        super().__init__()
        # Store the hyper-parameters on the module so that
        # ``self.hparams.lr_gs`` / ``self.hparams.lr_d`` resolve later.
        self.save_hyperparameters(hparams)
        self.generator = Generator(hparams)
        self.discriminator = Discriminator(hparams)
        self.stylist = Stylist(hparams)
        self.renderer = MeshPointsRenderer(hparams)

    @staticmethod
    def _adversarial_loss(d_output, is_real):
        """Return the BCE loss of ``d_output`` against an all-real/all-fake target.

        The target is created with ``ones_like`` / ``zeros_like`` so it always
        matches the shape, dtype and device of the discriminator output,
        including the 0-dim tensor that ``torch.squeeze`` yields for a batch
        of one.
        """
        if is_real:
            target = torch.ones_like(d_output)
        else:
            target = torch.zeros_like(d_output)
        return nn.BCELoss()(d_output, target)

    def forward(self, points, style):
        """Generate vertices with the generator from input points and a style."""
        return self.generator(points, style)

    def _render_fake_images(self, points, images):
        """Run stylist -> generator -> renderer and return the rendered images."""
        style = self.stylist(images)
        vertices = self(points, style)
        return self.renderer(vertices)

    def generator_step(self, points, images):
        """Training step for the generator and stylist.

        1. Compute a style from the images with the stylist
        2. Generate vertices from the points and style, then render them
        3. Classify the rendered images with the discriminator
        4. Return the BCE loss against "real" labels

        Args:
            points: Input points forwarded to the generator.
            images: Image batch forwarded to the stylist.

        Returns:
            The generator loss (also logged as ``loss/g_loss``).
        """
        generated_imgs = self._render_fake_images(points, images)

        # Classify generated images using the discriminator
        d_output = torch.squeeze(self.discriminator(generated_imgs))

        # We want to maximize the discriminator's loss, which is equivalent
        # to minimizing the loss with the true labels flipped (i.e. target 1
        # for generated images), since the optimizer can only minimize.
        g_loss = self._adversarial_loss(d_output, is_real=True)
        self.log("loss/g_loss", g_loss)
        return g_loss

    def discriminator_step(self, points, images):
        """Training step for the discriminator.

        1. Score the actual images and get the BCE loss against "real"
        2. Render generated images from the points and the images' style
        3. Score the generated images and get the BCE loss against "fake"
        4. Return the average of both losses

        Args:
            points: Input points forwarded to the generator.
            images: Batch of real images.

        Returns:
            The discriminator loss (also logged as ``loss/d_loss``; the two
            components are logged as ``loss/loss_real`` and
            ``loss/loss_fake``).
        """
        # Real images
        d_output = torch.squeeze(self.discriminator(images))
        loss_real = self._adversarial_loss(d_output, is_real=True)

        # Fake images
        generated_imgs = self._render_fake_images(points, images)
        d_output = torch.squeeze(self.discriminator(generated_imgs))
        loss_fake = self._adversarial_loss(d_output, is_real=False)

        d_loss = (loss_real + loss_fake) / 2
        self.log("loss/loss_real", loss_real)
        self.log("loss/loss_fake", loss_fake)
        self.log("loss/d_loss", d_loss)
        return d_loss

    def training_step(
        self,
        batch,
        batch_idx,
        optimizer_idx,
        ):
        """Dispatch one training step to the generator or the discriminator.

        Args:
            batch: Mapping with the keys ``'images'`` and ``'points'``.
            batch_idx: Index of the batch (unused).
            optimizer_idx: 0 trains the generator/stylist, 1 trains the
                discriminator (the order returned by
                ``configure_optimizers``).

        Returns:
            The loss of the selected step.

        Raises:
            ValueError: If ``optimizer_idx`` is neither 0 nor 1.
        """
        images = batch['images']
        points = batch['points']

        # train generator
        if optimizer_idx == 0:
            return self.generator_step(points, images)

        # train discriminator
        if optimizer_idx == 1:
            return self.discriminator_step(points, images)

        raise ValueError(f"Unexpected optimizer_idx: {optimizer_idx!r}")

    def configure_optimizers(self):
        """Create the two Adam optimizers.

        Returns:
            A ``([gs_optimizer, d_optimizer], [])`` tuple: the first optimizer
            covers the generator and stylist parameters with learning rate
            ``hparams.lr_gs``, the second covers the discriminator parameters
            with learning rate ``hparams.lr_d``; no schedulers are used.
        """
        gsp = list(self.generator.parameters()) + list(self.stylist.parameters())
        gs_optimizer = torch.optim.Adam(gsp,
                lr=self.hparams.lr_gs)
        d_optimizer = torch.optim.Adam(self.discriminator.parameters(),
                lr=self.hparams.lr_d)
        return ([gs_optimizer, d_optimizer], [])
    
    
from src.config import get_parser
# Parse an explicitly empty argument list instead of the process's own
# command line (presumably so that only the parser's defaults are used
# inside the notebook -- confirm against src.config.get_parser).
config = get_parser().parse_args(args=[])
# Optional override, left disabled:
#config.batch_size = 8
model = GAN(config)
# Bare expression: as the last statement of a notebook cell it displays the
# module's repr (the pasted output that follows).
model

GAN(
  (generator): Generator(
    (head): ConvBlock(
      (conv): Conv2d(3, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (norm): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (LeakyRelu): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (b1): GenBlock(
      (mod_conv): ModulateConvBlock(
        (style): DenseBlock(
          (fc): Linear(in_features=256, out_features=256, bias=False)
          (activate): Identity()
        )
        (activate): LeakyReLU(negative_slope=0.2, inplace=True)
      )
      (conv1): ConvBlock(
        (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (norm): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (LeakyRelu): LeakyReLU(negative_slope=0.2, inplace=True)
      )
      (conv2): ConvBlock(
        (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (norm): BatchNorm2d(256, eps=1e-0

In [6]:
# Smoke test: feed the stylist a random tensor of shape (5, 1, 64, 64) and
# show the shape of the result (the pasted output below reports [5, 256]).
model.stylist(torch.rand(5, 1, 64, 64)).shape

torch.Size([5, 256])