In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch3d.loss import (
    chamfer_distance,    
)
from pytorch3d.structures import Meshes
from pytorch3d.ops import sample_points_from_meshes
from src.callback.log_render_sample import LogRenderSample
from src.utilities.util import (
    grid_to_list,
    make_faces,
    list_to_grid,
)

In [2]:
from src.models.style_generator import StyleGenerator
from src.config import get_parser

# Build the default config from an empty argument list, so the kernel's own
# command-line arguments (sys.argv) are not parsed, then override the `fast_*`
# settings used in this notebook (32px images, batch size 8, five 256-channel stages).
config = get_parser().parse_args(args=[])
vars(config).update(
    fast_image_size=32,
    fast_batch_size=8,
    fast_generator_channels=[256, 256, 256, 256, 256],
)

# Standalone generator, built only to inspect its architecture.
G = StyleGenerator(config)
G

StyleGenerator(
  (stylist): Stylist(
    (conv0): ConvBlock(
      (conv): Conv2d(1, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (lrelu): LeakyReLU(negative_slope=0.2)
    )
    (conv1): ConvBlock(
      (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (lrelu): LeakyReLU(negative_slope=0.2)
    )
    (conv2): ConvBlock(
      (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (lrelu): LeakyReLU(negative_slope=0.2)
    )
    (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
    (flatten): Flatten(start_dim=1, end_dim=-1)
    (linear): Linear(in_features=256, out_features=256, bias=True)
  )
  (synthesis): Synthesis(
    (input): ConstantInput()
    (head): StyledConv2d(
      (conv): ModulatedConv2d(
        (modulation): Linear(in_features=256, out_features=3, bias=True)
      )
      (noise): NoiseInjection()
      (act): LeakyReLU(negative_slope=0.2)
    )
    (trunk)

In [3]:
from src.data.sample_render import SampleRenderDataModule

# Pull a single training batch and list the shape of every field it carries.
dm = SampleRenderDataModule(config)
batch = next(iter(dm.train_dataloader()))
for name, tensor in batch.items():
    print(name, tensor.shape)

image torch.Size([8, 1, 32, 32])
label torch.Size([8])
samples torch.Size([8, 1024, 3])
grid torch.Size([8, 3, 32, 32])


In [4]:
class RenderToSurface(pl.LightningModule):
    """Train a StyleGenerator to regress a surface grid from a rendered image.

    Each training step feeds ``batch['image']`` through the generator and
    minimises the L1 distance between its output and ``batch['grid']``.

    Args:
        hparams: Configuration namespace (e.g. the result of ``get_parser().parse_args``).
            It is recorded with ``save_hyperparameters`` and passed to ``StyleGenerator``.
    """

    def __init__(self, hparams):
        super().__init__()
        self.save_hyperparameters(hparams)

        self.G = StyleGenerator(hparams)
        # Face-index cache for turning a predicted grid into a mesh. It is only
        # needed by the (currently disabled) sampled-point Chamfer loss.
        self.faces = None

    def training_step(self, batch, batch_idx):
        """Compute the grid-regression loss for one batch.

        Args:
            batch: Dict with at least ``'image'`` (generator input) and ``'grid'``
                (regression target). The data module in this notebook yields image
                as (B, 1, H, W) and grid as (B, 3, H, W).
            batch_idx: Index of the batch within the epoch (unused).

        Returns:
            Scalar loss tensor for Lightning to backpropagate.
        """
        pred_grid = self.G(batch['image'])
        # NOTE(review): reshape_as only reinterprets the target's shape. It assumes
        # batch['grid'] already has the same element layout as the generator output.
        loss_grid = F.l1_loss(pred_grid, batch['grid'].reshape_as(pred_grid))

        # TODO: an extra Chamfer loss was tried here and is disabled. It compared
        # points sampled from the predicted mesh against batch['samples'].
        # Only the grid L1 term is active.
        loss = loss_grid

        # Log the tensors directly: calling .item() here forces a device-to-host
        # sync on every step.
        self.log("train/loss_grid", loss_grid)
        self.log("train/loss", loss)
        return loss

    def configure_optimizers(self):
        """Return a single Adam optimizer over the generator and no LR schedulers."""
        lr, betas = 0.0001, (0.5, 0.999)
        opt_g = torch.optim.Adam(self.G.parameters(), lr=lr, betas=betas)
        return [opt_g], []
                 
# Build the Lightning module from the notebook config and display its layer summary.
model = RenderToSurface(hparams=config)
model

RenderToSurface(
  (G): StyleGenerator(
    (stylist): Stylist(
      (conv0): ConvBlock(
        (conv): Conv2d(1, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (lrelu): LeakyReLU(negative_slope=0.2)
      )
      (conv1): ConvBlock(
        (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (lrelu): LeakyReLU(negative_slope=0.2)
      )
      (conv2): ConvBlock(
        (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (lrelu): LeakyReLU(negative_slope=0.2)
      )
      (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
      (flatten): Flatten(start_dim=1, end_dim=-1)
      (linear): Linear(in_features=256, out_features=256, bias=True)
    )
    (synthesis): Synthesis(
      (input): ConstantInput()
      (head): StyledConv2d(
        (conv): ModulatedConv2d(
          (modulation): Linear(in_features=256, out_features=3, bias=True)
        )
        (noise): Noise

In [None]:
# Train on a single GPU in full (32-bit) precision. LogRenderSample is attached
# as a callback. To resume a run, add resume_from_checkpoint=<path to a .ckpt>.
trainer = pl.Trainer(
    gpus=1,
    max_epochs=3000,
    precision=32,
    callbacks=[LogRenderSample(config)],
)
trainer.fit(model, dm)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name | Type           | Params
----------------------------------------
0 | G    | StyleGenerator | 3.0 M 
----------------------------------------
3.0 M     Trainable params
0         Non-trainable params
3.0 M     Total params
12.128    Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

Traceback (most recent call last):
  File "/home/bobi/miniconda3/envs/pytorch3d_05/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/home/bobi/miniconda3/envs/pytorch3d_05/lib/python3.8/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/home/bobi/miniconda3/envs/pytorch3d_05/lib/python3.8/multiprocessing/util.py", line 133, in _remove_temp_dir
    rmtree(tempdir)
  File "/home/bobi/miniconda3/envs/pytorch3d_05/lib/python3.8/shutil.py", line 722, in rmtree
    onerror(os.rmdir, path, sys.exc_info())
  File "/home/bobi/miniconda3/envs/pytorch3d_05/lib/python3.8/shutil.py", line 720, in rmtree
    os.rmdir(path)
OSError: [Errno 39] Directory not empty: '/tmp/pymp-2c4diprj'
Traceback (most recent call last):
  File "/home/bobi/miniconda3/envs/pytorch3d_05/lib/python3.8/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/home/bobi/miniconda3/envs/pytorch3d_05/l