In [1]:
import torch

import pytorch_lightning as pl

from src.config import get_parser
from src.models.gan import GAN

from src.data.masked_datamodule import MaskedDataModule

In [2]:
from collections import OrderedDict

import torch
import torch.nn.functional as F
import pytorch_lightning as pl

from src.models.discriminator import Discriminator
from src.models.generator import Generator
from src.models.stylist import Stylist
from src.renderer import Renderer

class GAN(pl.LightningModule):
    
    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams
        self.automatic_optimization = False
        self.mean = sum(hparams.image_mean) / len(hparams.image_mean)
        self.std = sum(hparams.image_std) / len(hparams.image_std)
        
        self.G = Generator(hparams)
        self.D = Discriminator(hparams)        
        # Renderer requires device, created in .to() step
        self.R = Renderer(hparams)
        
     
    def forward(self, shape, style):
        return self.G(shape, style)
    
    def adversarial_loss(self, y_hat, y):
        return F.mse_loss(y_hat, y)
    
    def training_step(self, batch, batch_idx, optimizer_idx):        
        img_patch = batch['img_patch']
        points =  batch['points']
        normals = batch['normals']            
        bs = img_patch.size(0)
        
        self.R.setup(points.device)
        
        # train generator
        if optimizer_idx == 0:            
            vertices = self.G(points)
            print(points.shape, vertices.shape)
            renders =  self.R(vertices).permute(0, 3, 1, 2)             
            renders = (renders - self.mean) / self.std

            # ground truth result (ie: all fake)
            # put on GPU because we created this tensor inside training_loop
            valid = torch.ones(bs, 1).type_as(points)            
           
            g_loss = self.adversarial_loss(self.D(renders), valid)
            tqdm_dict = {'g_loss': g_loss}
            output = OrderedDict({
                'loss': g_loss,
                'progress_bar': tqdm_dict,
                'log': tqdm_dict
            })
            return output

        # train discriminator
        if optimizer_idx == 1:
            # Measure discriminator's ability to classify real from generated samples

            # how well can it label as real?
            valid = torch.ones(bs, 1).type_as(points)         

            real_loss = self.adversarial_loss(self.D(imgs), valid)

            # how well can it label as fake?
            fake = torch.zeros(imgs.size(0), 1).type_as(points)            
                        
            vertices = self.G(points)            
            renders =  self.R(vertices).permute(0, 3, 1, 2)           
            renders = (renders - self.mean) / self.std

            fake_loss = self.adversarial_loss(
                self.D(renders.detach()), fake)

            # discriminator loss is the average of these
            d_loss = (real_loss + fake_loss) / 2
            tqdm_dict = {'d_loss': d_loss}
            output = OrderedDict({
                'loss': d_loss,
                'progress_bar': tqdm_dict,
                'log': tqdm_dict
            })
            return output

    
    def configure_optimizers(self):
        lr_g = self.hparams.lr_g
        lr_d = self.hparams.lr_d
        b1 = self.hparams.beta1
        b2 = self.hparams.beta2      
        opt_g = torch.optim.Adam(self.G.parameters(), 
                                 lr=lr_g, betas=(b1, b2))
        opt_d = torch.optim.Adam(self.D.parameters(), 
                                 lr=lr_d, betas=(b1, b2))
        return [opt_g, opt_d], []
    
config = get_parser().parse_args(args=[])
config.batch_size = 4
model = GAN(config)
model

GAN(
  (G): Generator(
    (head): ConvBlock(
      (conv): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (norm): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (LeakyRelu): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (body): Sequential(
      (b1): ConvBlock(
        (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (norm): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (LeakyRelu): LeakyReLU(negative_slope=0.2, inplace=True)
      )
      (b2): ConvBlock(
        (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (norm): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (LeakyRelu): LeakyReLU(negative_slope=0.2, inplace=True)
      )
      (b3): ConvBlock(
        (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (norm): BatchNorm

In [3]:
dm = MaskedDataModule(config)
dm.setup()
dm

<src.data.masked_datamodule.MaskedDataModule at 0x7f7c72f28a60>

In [5]:
trainer = pl.Trainer(gpus=1, max_epochs=5, progress_bar_refresh_rate=20)
trainer.fit(model, dm)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores

  | Name | Type          | Params
---------------------------------------
0 | G    | Generator     | 450 K 
1 | D    | Discriminator | 447 K 
2 | R    | Renderer      | 0     
---------------------------------------
898 K     Trainable params
0         Non-trainable params
898 K     Total params
3.593     Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

torch.Size([4, 3, 256, 256]) torch.Size([4, 3, 256, 256])


Please use self.log(...) inside the lightningModule instead.
# log on a step or aggregate epoch metric to the logger and/or progress bar (inside LightningModule)
self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
Please use self.log(...) inside the lightningModule instead.
# log on a step or aggregate epoch metric to the logger and/or progress bar (inside LightningModule)
self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)


torch.Size([4, 3, 256, 256]) torch.Size([4, 3, 256, 256])
torch.Size([4, 3, 256, 256]) torch.Size([4, 3, 256, 256])
torch.Size([4, 3, 256, 256]) torch.Size([4, 3, 256, 256])
torch.Size([4, 3, 256, 256]) torch.Size([4, 3, 256, 256])
torch.Size([4, 3, 256, 256]) torch.Size([4, 3, 256, 256])
torch.Size([4, 3, 256, 256]) torch.Size([4, 3, 256, 256])
torch.Size([4, 3, 256, 256]) torch.Size([4, 3, 256, 256])
torch.Size([4, 3, 256, 256]) torch.Size([4, 3, 256, 256])
torch.Size([4, 3, 256, 256]) torch.Size([4, 3, 256, 256])
torch.Size([4, 3, 256, 256]) torch.Size([4, 3, 256, 256])
torch.Size([4, 3, 256, 256]) torch.Size([4, 3, 256, 256])
torch.Size([4, 3, 256, 256]) torch.Size([4, 3, 256, 256])
torch.Size([4, 3, 256, 256]) torch.Size([4, 3, 256, 256])
torch.Size([4, 3, 256, 256]) torch.Size([4, 3, 256, 256])
torch.Size([4, 3, 256, 256]) torch.Size([4, 3, 256, 256])
torch.Size([4, 3, 256, 256]) torch.Size([4, 3, 256, 256])
torch.Size([4, 3, 256, 256]) torch.Size([4, 3, 256, 256])
torch.Size([4,

Traceback (most recent call last):
  File "/home/bobi/miniconda3/envs/pytorch3d/lib/python3.8/multiprocessing/queues.py", line 245, in _feed
    send_bytes(obj)
  File "/home/bobi/miniconda3/envs/pytorch3d/lib/python3.8/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/home/bobi/miniconda3/envs/pytorch3d/lib/python3.8/multiprocessing/connection.py", line 411, in _send_bytes
    self._send(header + buf)
  File "/home/bobi/miniconda3/envs/pytorch3d/lib/python3.8/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe


1

In [4]:
batch = next(iter(dm.train_dataloader()))
for key in batch.keys():
    print(key, batch[key].shape)

style_img torch.Size([4, 1, 192, 192])
img_patch torch.Size([4, 1, 256, 256])
points torch.Size([4, 3, 512, 512])
normals torch.Size([4, 3, 512, 512])


In [6]:
from torchsummary import summary

summary(model.S, (1, 192, 192));

------------------------------------------------------------------------------------------
Layer (type:depth-idx)                   Output Shape              Param #
├─Sequential: 1-1                        [-1, 512, 6, 6]           --
|    └─Conv2d: 2-1                       [-1, 64, 192, 192]        640
|    └─BatchNorm2d: 2-2                  [-1, 64, 192, 192]        128
|    └─ReLU: 2-3                         [-1, 64, 192, 192]        --
|    └─MaxPool2d: 2-4                    [-1, 64, 96, 96]          --
|    └─Conv2d: 2-5                       [-1, 128, 96, 96]         73,856
|    └─BatchNorm2d: 2-6                  [-1, 128, 96, 96]         256
|    └─ReLU: 2-7                         [-1, 128, 96, 96]         --
|    └─MaxPool2d: 2-8                    [-1, 128, 48, 48]         --
|    └─Conv2d: 2-9                       [-1, 256, 48, 48]         295,168
|    └─BatchNorm2d: 2-10                 [-1, 256, 48, 48]         512
|    └─ReLU: 2-11                        [-1, 256, 

In [3]:
from torchsummary import summary

summary(model.S, (1, 192, 192));

------------------------------------------------------------------------------------------
Layer (type:depth-idx)                   Output Shape              Param #
├─Sequential: 1-1                        [-1, 512, 12, 12]         --
|    └─Conv2d: 2-1                       [-1, 64, 192, 192]        640
|    └─BatchNorm2d: 2-2                  [-1, 64, 192, 192]        128
|    └─ReLU: 2-3                         [-1, 64, 192, 192]        --
|    └─MaxPool2d: 2-4                    [-1, 64, 96, 96]          --
|    └─Conv2d: 2-5                       [-1, 128, 96, 96]         73,856
|    └─BatchNorm2d: 2-6                  [-1, 128, 96, 96]         256
|    └─ReLU: 2-7                         [-1, 128, 96, 96]         --
|    └─MaxPool2d: 2-8                    [-1, 128, 48, 48]         --
|    └─Conv2d: 2-9                       [-1, 256, 48, 48]         295,168
|    └─BatchNorm2d: 2-10                 [-1, 256, 48, 48]         512
|    └─ReLU: 2-11                        [-1, 256, 

In [8]:
summary(model.G, [(3, 512, 512), (3, 512, 512), (config.dlatent_size,)]);

------------------------------------------------------------------------------------------
Layer (type:depth-idx)                   Output Shape              Param #
├─ModConvLayer: 1-1                      [-1, 128, 512, 512]       --
|    └─EqualizedModConv2d: 2-1           [-1, 128, 512, 512]       --
|    |    └─EqualizedLinear: 3-1         [-1, 6]                   774
├─Sequential: 1-2                        [-1, 128, 512, 512]       --
|    └─ConvBlock: 2-2                    [-1, 128, 512, 512]       --
|    |    └─Conv2d: 3-2                  [-1, 128, 512, 512]       147,584
|    |    └─BatchNorm2d: 3-3             [-1, 128, 512, 512]       256
|    |    └─LeakyReLU: 3-4               [-1, 128, 512, 512]       --
|    └─ConvBlock: 2-3                    [-1, 128, 512, 512]       --
|    |    └─Conv2d: 3-5                  [-1, 128, 512, 512]       147,584
|    |    └─BatchNorm2d: 3-6             [-1, 128, 512, 512]       256
|    |    └─LeakyReLU: 3-7               [-1, 128, 

In [9]:
rsz = config.raster_image_size
summary(model.D, (3, rsz, rsz));

------------------------------------------------------------------------------------------
Layer (type:depth-idx)                   Output Shape              Param #
├─Sequential: 1-1                        [-1, 128]                 --
|    └─ConvBlock: 2-1                    [-1, 128, 512, 512]       --
|    |    └─Conv2d: 3-1                  [-1, 128, 512, 512]       3,584
|    |    └─BatchNorm2d: 3-2             [-1, 128, 512, 512]       256
|    |    └─LeakyReLU: 3-3               [-1, 128, 512, 512]       --
|    └─ConvBlock: 2-2                    [-1, 128, 512, 512]       --
|    |    └─Conv2d: 3-4                  [-1, 128, 512, 512]       147,584
|    |    └─BatchNorm2d: 3-5             [-1, 128, 512, 512]       256
|    |    └─LeakyReLU: 3-6               [-1, 128, 512, 512]       --
|    └─ConvBlock: 2-3                    [-1, 128, 512, 512]       --
|    |    └─Conv2d: 3-7                  [-1, 128, 512, 512]       147,584
|    |    └─BatchNorm2d: 3-8             [-1, 128