In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch3d.loss import (
    chamfer_distance,    
)
from pytorch3d.structures import Meshes
from pytorch3d.ops import sample_points_from_meshes
from src.callback.log_render_sample import LogRenderSample
from src.utilities.util import (
    grid_to_list,
    make_faces,
)

In [2]:
# pyright: reportMissingImports=false

from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm

from src.models.util import ConvBlock
from src.utilities.operators import mean_distance
from src.data.sample_render import SampleRenderDataModule
        
class Generator(nn.Module):
    """Convolutional generator mapping a 1-channel outline to a 3-channel map.

    Layout (submodule names form the state_dict keys and must stay stable):
      trunk.head      ConvBlock(1 -> channels[0])
      trunk.upsample  x2 spatial upsampling (nn.Upsample default mode)
      trunk.main.bK   ConvBlock(channels[K] -> channels[K+1]) for every
                      consecutive pair in ``config.fast_generator_channels``
      points          spectral-normalised Conv2d(channels[-1] -> 3,
                      kernel 3, stride 1, padding 1, no bias) followed by
                      Tanh, so every output value lies in (-1, 1)

    Args:
        config: object exposing ``fast_generator_channels``, a sequence of
            channel widths; the first feeds ``head``, the last feeds ``points``.
    """

    def __init__(self, config):
        super().__init__()
        widths = config.fast_generator_channels

        # Construction order (head, upsample, main blocks, final conv) is
        # deliberate: it fixes the RNG sequence used for weight initialisation.
        trunk_layers = OrderedDict()
        trunk_layers['head'] = ConvBlock(1, widths[0])
        trunk_layers['upsample'] = nn.Upsample(scale_factor=2)
        trunk_layers['main'] = nn.Sequential(OrderedDict(
            ('b%d' % idx, ConvBlock(c_in, c_out))
            for idx, (c_in, c_out) in enumerate(zip(widths, widths[1:]))
        ))
        self.trunk = nn.Sequential(trunk_layers)

        to_points = nn.Conv2d(widths[-1], 3, 3, 1, 1, bias=False)
        self.points = nn.Sequential(spectral_norm(to_points), nn.Tanh())

    def forward(self, outline):
        """Run ``outline`` through the trunk, then project features to 3 channels."""
        features = self.trunk(outline)
        return self.points(features)
from src.config import get_parser

# Parse an explicit empty argument list so the parser (presumably argparse-
# compatible) does not read the notebook kernel's own argv; every option
# therefore takes its default value.
config = get_parser().parse_args(args=[])

# Build a standalone generator and display its module tree.
G = Generator(config)
G

Generator(
  (trunk): Sequential(
    (head): ConvBlock(
      (conv): Conv2d(1, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (lrelu): LeakyReLU(negative_slope=0.2)
    )
    (upsample): Upsample(scale_factor=2.0, mode=nearest)
    (main): Sequential(
      (b0): ConvBlock(
        (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (lrelu): LeakyReLU(negative_slope=0.2)
      )
      (b1): ConvBlock(
        (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (lrelu): LeakyReLU(negative_slope=0.2)
      )
      (b2): ConvBlock(
        (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (lrelu): LeakyReLU(negative_slope=0.2)
      )
    )
  )
  (points): Sequential(
    (0): Conv2d(256, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): Tanh()
  )
)

In [3]:
dm = SampleRenderDataModule(config)
batch = next(iter(dm.train_dataloader()))

# Show each field of one training batch alongside its tensor shape.
for name, tensor in batch.items():
    print(name, tensor.shape)

image torch.Size([8, 1, 128, 128])
label torch.Size([8])
samples torch.Size([8, 65536, 3])


In [4]:
class RenderToSurface(pl.LightningModule):
    """Lightning module that fits a surface mesh to a rendered image.

    ``G`` maps ``batch['image']`` to a grid with 3 output channels. The grid
    is converted to vertices with ``grid_to_list`` and triangulated with
    ``make_faces``. Points are sampled from the resulting meshes and compared
    to ``batch['samples']`` with the chamfer distance.

    Args:
        hparams: parsed config; stored via ``save_hyperparameters`` and
            forwarded to ``Generator``.
    """

    def __init__(self, hparams):
        super().__init__()
        self.save_hyperparameters(hparams)

        self.G = Generator(hparams)
        # Batched grid triangulation, built lazily on first use because it
        # depends on the generator's output resolution. Plain attribute (not a
        # registered buffer), so it is not written into checkpoints.
        self.faces = None
        # (rows, cols, device) that ``self.faces`` was built for.
        self._faces_key = None

    def _grid_faces(self, grid):
        """Return the cached (1, F, 3)-style face tensor for ``grid``, rebuilding if stale.

        The cache is invalidated whenever the grid's spatial size or device
        differs from what the faces were built for. This prevents a change in
        input resolution, or a device move, from reusing mismatched indices.
        """
        rows, cols = grid.shape[-2], grid.shape[-1]
        key = (rows, cols, grid.device)
        if self.faces is None or self._faces_key != key:
            # NOTE(review): the generator output is NCHW, so these are
            # (height, width); the original unpacked them as ``w, h``. The
            # argument order is kept as-is. Confirm against make_faces'
            # signature (this is a no-op for square grids).
            faces = make_faces(rows, cols)[None]
            # as_tensor accepts an ndarray or an existing tensor without the
            # copy-construct warning that torch.tensor(<tensor>) emits.
            self.faces = torch.as_tensor(faces, device=grid.device)
            self._faces_key = key
        return self.faces

    def training_step(self, batch, batch_idx):
        """Chamfer loss between points sampled on the predicted meshes and targets.

        Args:
            batch: dict with at least ``'image'`` (generator input) and
                ``'samples'`` (target point clouds). Presumably shaped
                (N, 1, H, W) and (N, P, 3); confirm against the data module.
            batch_idx: unused.

        Returns:
            Scalar chamfer-distance loss (default mean reductions).
        """
        grid = self.G(batch['image'])
        faces = self._grid_faces(grid)
        verts = grid_to_list(grid)
        meshes = Meshes(verts=verts, faces=faces.expand(verts.size(0), -1, -1))
        # NOTE(review): sample_points_from_meshes defaults to 10000 points per
        # mesh, which may be far fewer than the target clouds contain; consider
        # passing num_samples explicitly.
        samples = sample_points_from_meshes(meshes)

        loss, _ = chamfer_distance(samples, batch['samples'])
        # Log the tensor itself. Calling ``loss.item()`` would force a
        # host/device synchronisation on every step; self.log accepts tensors.
        self.log("train/loss", loss)
        return loss

    def configure_optimizers(self):
        """Adam over the generator's parameters; no LR scheduler."""
        # Hard-coded optimiser settings; these are not read from hparams.
        lr, betas = 0.0003, (0.5, 0.999)
        opt_g = torch.optim.Adam(self.G.parameters(), lr=lr, betas=betas)
        return [opt_g], []
                 
# Instantiate the Lightning module; the bare name below displays its layers.
model = RenderToSurface(hparams=config)
model

RenderToSurface(
  (G): Generator(
    (trunk): Sequential(
      (head): ConvBlock(
        (conv): Conv2d(1, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (lrelu): LeakyReLU(negative_slope=0.2)
      )
      (upsample): Upsample(scale_factor=2.0, mode=nearest)
      (main): Sequential(
        (b0): ConvBlock(
          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (lrelu): LeakyReLU(negative_slope=0.2)
        )
        (b1): ConvBlock(
          (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (lrelu): LeakyReLU(negative_slope=0.2)
        )
        (b2): ConvBlock(
          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (lrelu): LeakyReLU(negative_slope=0.2)
        )
      )
    )
    (points): Sequential(
      (0): Conv2d(256, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      

In [5]:
MAX_EPOCHS = 3000

# Single-GPU, full-precision training with LogRenderSample as a callback.
# To continue an earlier run, also pass resume_from_checkpoint=<path to .ckpt>.
trainer = pl.Trainer(
    gpus=1,
    max_epochs=MAX_EPOCHS,
    callbacks=[LogRenderSample(config)],
    precision=32,
)
trainer.fit(model, dm)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name | Type      | Params
-----------------------------------
0 | G    | Generator | 1.0 M 
-----------------------------------
1.0 M     Trainable params
0         Non-trainable params
1.0 M     Total params
4.161     Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

Traceback (most recent call last):

Detected KeyboardInterrupt, attempting graceful shutdown...

