In [1]:
# Train generator and discriminator 
# directly on surface
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch3d.loss import (
    chamfer_distance,    
)

from src.callback.log_mesh import LogMesh
from src.utilities.util import grid_to_list
from src.models.discriminator import Discriminator
from src.models.surface_generator import Generator

In [2]:
class DirectGAN(pl.LightningModule):
    
    def __init__(self, hparams):
        super().__init__()        
        self.save_hyperparameters(hparams)
        
        self.G = Generator(hparams)
        self.D = Discriminator(hparams)
        
        
    def forward(self, outline):
        return self.G(outline)
    
    def log_all(self, r, p, d):
        for k, v in  d.items(): 
            self.log(f"{r}_loss_{p}/{k}", v)        
            
    def adversarial_loss(self, lbl, is_real):
        trg = torch.ones_like(lbl) if is_real else torch.zeros_like(lbl)
        return F.binary_cross_entropy_with_logits(lbl, trg)
    
    def train_generator(self, vertices, batch):
        rcn = F.l1_loss(vertices, batch['baseline'])        
        lbl, _ = self.D(vertices, False)        
        adv = self.adversarial_loss(lbl, False)        
        loss = rcn + adv
        log =  {
            'G_loss' : loss.item(),
            'G_rcn': rcn.item(),            
            'G_adv': adv.item(),            
        }        
        return loss, log
    
    def train_discriminator(self, vertices, batch):
        # Real
        lbl_r, decodings = self.D(batch['baseline'], True)        
        adv_r = self.adversarial_loss(lbl_r, True)                        
        #dcd = F.l1_loss(decodings, vertices)                
        # Fake                
        lbl_f, _ = self.D(vertices, False)                
        adv_f = self.adversarial_loss(lbl_f, False)
        
        loss = adv_r  + adv_f #+ dcd
        log = {
            'D_loss' : loss.item(),
            'D_adv_r' : adv_r.item(),            
            #'D_dcd' : dcd.item(),
            'D_adv_f' : adv_f.item(),
        }        
        return loss, log
    
    def training_step(self, batch, batch_idx, optimizer_idx):
        vertices = self.G(batch['outline'])
        loss = 0
        if optimizer_idx == 0:            
            loss, log = self.train_discriminator(vertices, batch)        
            self.log_all('train', 'D', log)
        elif optimizer_idx == 1:            
            loss, log = self.train_generator(vertices, batch)
            self.log_all('train', 'G', log)            
        return loss
        
#     def validation_step(self, batch, batch_idx):
#         vertices = self.G(batch['outline'])
#         _, log = self.train_discriminator(vertices, batch)        
#         self.log_all('val', 'D', log)   
#         _, log = self.train_generator(vertices, batch)
#         self.log_all('val', 'G', log)     
        
    def configure_optimizers(self):
        lr, betas = 0.0003, (0.5, 0.999)
        opt_d = torch.optim.Adam(self.D.parameters(), lr=lr, betas=betas)
        opt_g = torch.optim.Adam(self.G.parameters(),  lr=lr, betas=betas)
        return [opt_d, opt_g], []
    
    

from src.config import get_parser

config = get_parser().parse_args(args=[])   

config.log_mesh_interval = 100
config.fast_outline_size =  64
#config.fast_baseline_size = 128
#config.fast_image_size = 128
config.fast_batch_size = 16
config.raster_faces_per_pixel = 4
config.G_noise_amp = 0.005 #0.002
#config.geoaug_policy = 'scaling'
#config.fast_discriminator_channels[0] = 3


model = DirectGAN(config)
model

DirectGAN(
  (G): Generator(
    (points): SurfaceGenerator(
      (trunk): Sequential(
        (head): UpConvBlock(
          (upsample): Upsample(scale_factor=2.0, mode=bilinear)
          (conv): Conv2d(3, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (lrelu): LeakyReLU(negative_slope=0.2)
        )
        (main): Sequential(
          (b0): ConvBlock(
            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (lrelu): LeakyReLU(negative_slope=0.2)
          )
          (b1): ConvBlock(
            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (lrelu): LeakyReLU(negative_slope=0.2)
          )
          (b2): ConvBlock(
            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (lrelu): LeakyReLU(negative_slope=0.2)
          )
        )
      )
      (points): Sequential(
        (0): Conv2d(

In [3]:
from src.data.surface_datamodule import SurfaceDataModule

dm = SurfaceDataModule(config)    
dm

<src.data.surface_datamodule.SurfaceDataModule at 0x7f51d30a49d0>

In [4]:
trainer = pl.Trainer(gpus=1, max_epochs=100, progress_bar_refresh_rate=20,
                     terminate_on_nan=True, 
                     #profiler="pytorch",
                     log_every_n_steps=2, 
                     callbacks=[LogMesh(config)],
                    )
trainer.fit(model, dm)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name | Type          | Params
---------------------------------------
0 | G    | Generator     | 1.8 M 
1 | D    | Discriminator | 2.4 M 
---------------------------------------
4.1 M     Trainable params
0         Non-trainable params
4.1 M     Total params
16.599    Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

RuntimeError: CUDA out of memory. Tried to allocate 256.00 MiB (GPU 0; 7.80 GiB total capacity; 787.59 MiB already allocated; 141.88 MiB free; 862.00 MiB reserved in total by PyTorch)