In [1]:
import torch
import torch.nn.functional as F
from torch.optim import Adam
import pytorch_lightning as pl

from src.render.mesh_renderer import MeshPointsRenderer
from src.models.discriminator import Discriminator
from src.models.generator import Generator
from src.augment.diffaug import DiffAugment
from src.augment.geoaug import GeoAugment
from src.lpips import PerceptualLoss
from src.callback.image_mesh import ImageMesh

def train_d(labels, decodings=None, reals=None):
    """Hinge-style discriminator loss with a randomised margin.

    Args:
        labels: discriminator scores for the batch being judged.
        decodings: discriminator reconstructions; only read when ``reals``
            is given.
        reals: the real images. When provided the batch is scored as
            "real" and an L1 term between ``decodings`` and ``reals`` is
            added; when ``None`` the batch is scored as "fake".

    Returns:
        ``(loss, log)`` where ``loss`` is a scalar tensor and ``log`` maps
        metric names to Python floats.
    """
    # One random margin per label, computed as rand * 0.2 + 0.8.
    margin = torch.rand_like(labels) * 0.2 + 0.8

    if reals is None:
        # Fake batch: penalise scores that rise above -margin.
        fake_term = F.relu(margin + labels).mean()
        return fake_term, {'d_advr_f': fake_term.item()}

    # Real batch: penalise scores that fall below +margin ...
    real_term = F.relu(margin - labels).mean()
    # ... plus an L1 term tying the reconstructions to the inputs.
    recon_term = F.l1_loss(decodings, reals)
    stats = {'d_adv_r': real_term.item(), 'd_prcp': recon_term.item()}
    return real_term + recon_term, stats
    
    
class GAN(pl.LightningModule):
    """Adversarial training of a generator whose output is rendered to images.

    Pipeline per training step:
        ``outline (+ noise) -> G -> (points, colors) -> R -> renders -> D``

    ``G``, ``D`` and ``R`` are all built from the same ``hparams`` object.
    Two optimizers are used (index 0: discriminator, index 1: generator),
    so ``training_step`` is invoked once per optimizer for each batch.
    """

    def __init__(self, hparams):
        """Build generator, discriminator and renderer from ``hparams``.

        Args:
            hparams: parsed configuration; must expose ``fast_image_mean``,
                ``fast_image_std``, ``diffaug_policy`` and ``G_noise_amp``
                in addition to whatever ``Generator``, ``Discriminator``
                and ``MeshPointsRenderer`` read from it.
        """
        super().__init__()
        # Register the hyper-parameters through the supported API. Assigning
        # to ``self.hparams`` directly is deprecated in Lightning 1.x and
        # raises "can't set attribute" in later releases.
        self.save_hyperparameters(hparams)
        #self.automatic_optimization = False
        self.mean = hparams.fast_image_mean
        self.std = hparams.fast_image_std
        # Only referenced by the (currently disabled) DiffAugment calls in
        # ``training_step``.
        self.diffaug = hparams.diffaug_policy
        self.G_noise_amp = hparams.G_noise_amp

        self.G = Generator(hparams)
        self.D = Discriminator(hparams)
        self.R = MeshPointsRenderer(hparams)

        #self.perceptual = PerceptualLoss(model='net-lin', net='squeeze')

    def forward(self, outline):
        """Run the generator on ``outline`` and return its raw output."""
        return self.G(outline)

    def reconstruction(self, points, outline):
        """Loss tying the generated points to the input outline.

        ``points`` is average-pooled by the integer ratio of the last
        dimensions of ``points`` and ``outline`` and then compared to
        ``outline`` with a mean-squared error.

        NOTE(review): ``F.mse_loss`` already averages over every element, so
        dividing by ``pts.size(0)`` (presumably the batch size) scales the
        term down once more -- confirm this weighting is intended.
        """
        pool = points.size(-1) // outline.size(-1)
        pts = F.avg_pool2d(points, pool)
        return F.mse_loss(pts, outline) / pts.size(0)

    def _log_losses(self, log):
        """Send each entry of a ``train_d`` log dict to ``self.log`` under ``loss/``."""
        for key, value in log.items():
            self.log(f"loss/{key}", value)

    def train_generator(self, points, outline, renders):
        """Generator loss: adversarial term plus outline reconstruction.

        Args:
            points: generator output fed to ``reconstruction``.
            outline: the (noise-free) conditioning input.
            renders: rendered images scored by the discriminator.

        Returns:
            Scalar tensor ``g_adv + g_rcn``.
        """
        labels, _ = self.D(renders, False)
        # The generator wants the discriminator score to be as high as possible.
        g_adv = -labels.mean()
        g_rcn = self.reconstruction(points, outline)
        self.log("loss/g_adv", g_adv.item())
        self.log("loss/g_rcn", g_rcn.item())
        return g_adv + g_rcn

    def train_discriminator(self, reals, renders):
        """Discriminator loss over one real batch and one rendered batch.

        The discriminator is called with ``True`` for real images, where its
        second output (the decodings) feeds the reconstruction term, and
        with ``False`` for renders, where the second output is discarded.

        Returns:
            Scalar tensor ``loss_real + loss_fake``.
        """
        # Real images: adversarial term + reconstruction term.
        labels, decodings = self.D(reals, True)
        loss_r, log = train_d(labels, decodings, reals)
        self._log_losses(log)
        # Rendered images: detached so no gradient reaches the generator.
        labels, _ = self.D(renders.detach(), False)
        loss_f, log = train_d(labels)
        self._log_losses(log)
        return loss_r + loss_f

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Compute the loss for the optimizer selected by ``optimizer_idx``.

        Args:
            batch: mapping with ``'image'`` (real images) and ``'outline'``
                (generator conditioning) entries.
            batch_idx: index of the batch (unused).
            optimizer_idx: 0 for the discriminator, 1 for the generator, in
                the order returned by ``configure_optimizers``.

        Returns:
            The loss tensor for the selected optimizer; ``None`` (implicitly)
            for any other index.
        """
        reals, outline = batch['image'], batch['outline']
        #with torch.autograd.detect_anomaly():
        # Perturb the conditioning input with Gaussian noise scaled by G_noise_amp.
        noise = torch.randn_like(outline) * self.G_noise_amp
        points, colors = self.G(outline + noise)
        renders = self.R(points, colors, self.mean, self.std)
        #reals = DiffAugment(reals, policy=self.diffaug)
        #renders = DiffAugment(renders, policy=self.diffaug)
        if optimizer_idx == 0:
            return self.train_discriminator(reals, renders)
        if optimizer_idx == 1:
            return self.train_generator(points, outline, renders)

    def configure_optimizers(self):
        """Return ``[discriminator Adam, generator Adam]`` and no schedulers.

        The list order defines the ``optimizer_idx`` values seen by
        ``training_step``.
        """
        lr, betas = 0.0001, (0.5, 0.999)
        opt_d = Adam(self.D.parameters(), lr=lr, betas=betas)
        opt_g = Adam(self.G.parameters(), lr=lr, betas=betas)
        return [opt_d, opt_g], []
    
from src.config import get_parser

# Seed the Python, NumPy and PyTorch RNGs before any weights are created so
# that Restart & Run All reproduces the same initialisation and noise draws.
SEED = 42
pl.seed_everything(SEED)

# Parse an explicit empty argument list so the parser's defaults are used
# instead of whatever is in sys.argv for the notebook kernel.
config = get_parser().parse_args(args=[])
# Not read anywhere in this notebook; presumably consumed by a callback or
# module built from `config` -- TODO confirm.
config.log_batch_interval = 100
gan = GAN(config)
gan

GAN(
  (G): Generator(
    (points): SinGenerator(
      (head): ConvBlock(
        (conv): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (norm): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (lrelu): LeakyReLU(negative_slope=0.2, inplace=True)
      )
      (main): Sequential(
        (b0): ConvBlock(
          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (norm): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (lrelu): LeakyReLU(negative_slope=0.2, inplace=True)
        )
        (b1): ConvBlock(
          (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (norm): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (lrelu): LeakyReLU(negative_slope=0.2, inplace=True)
        )
        (b2): ConvBlock(
          (conv): Conv2

In [2]:
from src.data.fast_datamodule import FastDataModule

# Build the data module from the same parsed `config` that was used to
# construct the GAN. GAN.training_step reads the 'image' and 'outline'
# entries of every batch it receives from this module.
dm = FastDataModule(config)    
dm

<src.data.fast_datamodule.FastDataModule at 0x7f5e51620d30>

In [3]:
# Configure a short training run and start it.
trainer = pl.Trainer(
    gpus=1,                          # train on a single GPU
    max_epochs=5,
    progress_bar_refresh_rate=20,    # redraw the progress bar every 20 steps
    terminate_on_nan=True,           # stop with an error once NaN/inf is detected
    callbacks=[ImageMesh(config)],
)
trainer.fit(gan, dm)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores

  | Name | Type               | Params
--------------------------------------------
0 | G    | Generator          | 1.0 M 
1 | D    | Discriminator      | 1.5 M 
2 | R    | MeshPointsRenderer | 0     
--------------------------------------------
2.5 M     Trainable params
0         Non-trainable params
2.5 M     Total params
10.096    Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

--- Logging error ---
Traceback (most recent call last):
  File "/home/bobi/miniconda3/envs/pytorch3d/lib/python3.8/logging/__init__.py", line 1081, in emit
    msg = self.format(record)
  File "/home/bobi/miniconda3/envs/pytorch3d/lib/python3.8/logging/__init__.py", line 925, in format
    return fmt.format(record)
  File "/home/bobi/miniconda3/envs/pytorch3d/lib/python3.8/logging/__init__.py", line 664, in format
    record.message = record.getMessage()
  File "/home/bobi/miniconda3/envs/pytorch3d/lib/python3.8/logging/__init__.py", line 369, in getMessage
    msg = msg % self.args
TypeError: not all arguments converted during string formatting
Call stack:
  File "/home/bobi/miniconda3/envs/pytorch3d/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/bobi/miniconda3/envs/pytorch3d/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/bobi/.local/lib/python3.8/site-packages/ipykernel

Message: Parameter containing:
tensor([[[[nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan]],

         [[nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan]],

         [[nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan]],

         ...,

         [[nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan]],

         [[nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan]],

         [[nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan]]],


        [[[nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan]],

         [[nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan]],

         [[nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan]],

         ...,

         [[nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan]],

         [[nan, nan, nan],
          [nan, nan, nan],
          [nan, nan, nan]],

         [[nan, nan, nan




ValueError: Detected nan and/or inf values in `G.points.head.conv.weight_orig`. Check your forward pass for numerically unstable operations.