In [1]:
import torch

import pytorch_lightning as pl

from src.config import get_parser
from src.models.gan import GAN
from src.callback import LogExportCallback
from src.data.masked_datamodule import MaskedDataModule

In [2]:
from collections import OrderedDict

import torch
import torch.nn.functional as F
import pytorch_lightning as pl

from src.models.discriminator import Discriminator
from src.models.generator import Generator
from src.renderer import Renderer

class GAN(pl.LightningModule):
    
    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams
        self.automatic_optimization = False
        self.mean = sum(hparams.image_mean) / len(hparams.image_mean)
        self.std = sum(hparams.image_std) / len(hparams.image_std)
        
        self.G = Generator(hparams)
        self.D = Discriminator(hparams)        
        # Renderer requires device, created in .to() step
        self.R = Renderer(hparams)
        
     
    def forward(self, shape, style):
        return self.G(shape, style)
    
    def adversarial_loss(self, y_hat, y):
        return F.mse_loss(y_hat, y)
    
    def training_step(self, batch, batch_idx, optimizer_idx):        
        img_patch = batch['img_patch']
        points =  batch['points']
        normals = batch['normals']            
        bs = img_patch.size(0)
        
        self.R.setup(points.device)
        
        # train generator
        if optimizer_idx == 0:            
            vertices = self.G(points)            
            renders =  self.R(vertices).permute(0, 3, 1, 2)             
            renders = (renders - self.mean) / self.std

            # ground truth result (ie: all fake)
            # put on GPU because we created this tensor inside training_loop
            valid = torch.ones(bs, 1).type_as(points)            
           
            g_loss = self.adversarial_loss(self.D(renders), valid)
            tqdm_dict = {'g_loss': g_loss}
            output = OrderedDict({
                'loss': g_loss,                
            })
            self.log("loss/g_loss", g_loss, on_epoch=True, prog_bar=True)
            return output

        # train discriminator
        if optimizer_idx == 1:
            # Measure discriminator's ability to classify real from generated samples

            # how well can it label as real?
            valid = torch.ones(bs, 1).type_as(points)         

            real_loss = self.adversarial_loss(self.D(imgs), valid)

            # how well can it label as fake?
            fake = torch.zeros(imgs.size(0), 1).type_as(points)            
                        
            vertices = self.G(points)            
            renders =  self.R(vertices).permute(0, 3, 1, 2)           
            renders = (renders - self.mean) / self.std

            fake_loss = self.adversarial_loss(
                self.D(renders.detach()), fake)

            # discriminator loss is the average of these
            d_loss = (real_loss + fake_loss) / 2
            tqdm_dict = {'d_loss': d_loss}
            output = OrderedDict({
                'loss': d_loss,                
            })
            self.log("loss/d_loss", d_loss, on_step=True, on_epoch=True, prog_bar=True)            
            return output

    
    def configure_optimizers(self):
        lr_g = self.hparams.lr_g
        lr_d = self.hparams.lr_d
        b1 = self.hparams.beta1
        b2 = self.hparams.beta2      
        opt_g = torch.optim.Adam(self.G.parameters(), 
                                 lr=lr_g, betas=(b1, b2))
        opt_d = torch.optim.Adam(self.D.parameters(), 
                                 lr=lr_d, betas=(b1, b2))
        return [opt_g, opt_d], []
    
config = get_parser().parse_args(args=[])
#config.batch_size = 8
model = GAN(config)
model

GAN(
  (G): Generator(
    (head): ConvBlock(
      (conv): Conv2d(3, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (norm): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (LeakyRelu): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (body): Sequential(
      (b1): ConvBlock(
        (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (norm): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (LeakyRelu): LeakyReLU(negative_slope=0.2, inplace=True)
      )
      (b2): ConvBlock(
        (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (norm): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (LeakyRelu): LeakyReLU(negative_slope=0.2, inplace=True)
      )
      (b3): ConvBlock(
        (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (norm): BatchNorm

In [3]:
dm = MaskedDataModule(config)
dm.setup()
dm

<src.data.masked_datamodule.MaskedDataModule at 0x7fab591b9dc0>

In [None]:
trainer = pl.Trainer(gpus=1, max_epochs=5, progress_bar_refresh_rate=20, 
                     callbacks=[LogExportCallback(config)])
trainer.fit(model, dm)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores

  | Name | Type          | Params
---------------------------------------
0 | G    | Generator     | 1.8 M 
1 | D    | Discriminator | 1.8 M 
2 | R    | Renderer      | 0     
---------------------------------------
3.6 M     Trainable params
0         Non-trainable params
3.6 M     Total params
14.264    Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…