In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch3d.loss import (
    chamfer_distance,    
)
from pytorch3d.structures import Meshes
from pytorch3d.ops import sample_points_from_meshes
from src.callback.log_render_sample import LogRenderSample
from src.utilities.util import (
    grid_to_list,
    make_faces,
    list_to_grid,
)

In [2]:
from src.models.style_generator import StyleGenerator
from src.config import get_parser

# Parse an empty argument list so the config starts from the parser's
# defaults, then apply this experiment's overrides in a fixed order.
config = get_parser().parse_args(args=[])

experiment_overrides = {
    "fast_image_size": 128,
    "grid_full_size": 32,
    "grid_slice_size": 16,
    "batch_size": 32,
    "stylist_channels": [1, 64, 64, 64],
    "synthesis_channels": [128, 128, 128],
    "style_dim": 8,
    "log_mesh_interval": 1000,
    "initial_input_fixed": True,
}
for option, setting in experiment_overrides.items():
    setattr(config, option, setting)

# Stand-alone generator built from the same config, for inspection.
G = StyleGenerator(config)

G  # display the module tree

StyleGenerator(
  (stylist): Stylist(
    (conv0): ConvPoolBlock(
      (conv): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (lrelu): LeakyReLU(negative_slope=0.2)
      (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (conv1): ConvPoolBlock(
      (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (lrelu): LeakyReLU(negative_slope=0.2)
      (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (conv2): ConvPoolBlock(
      (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (lrelu): LeakyReLU(negative_slope=0.2)
      (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
    (flatten): Flatten(start_dim=1, end_dim=-1)
    (linear): Linear(in_features=64, out_features=8, bias=True)
  )
  (synthesis):

In [3]:
from src.data.sliced_render import SlicedRenderDataModule

# Build the data module and pull one training batch to inspect.
dm = SlicedRenderDataModule(config)
train_loader = dm.train_dataloader()
batch = next(iter(train_loader))

# Print the shape of every entry in the batch dictionary.
for key, tensor in batch.items():
    print(key, tensor.shape)

image torch.Size([32, 1, 128, 128])
slice_data torch.Size([32, 3, 16, 16])
slice_idx torch.Size([32, 2])


In [4]:
class RenderToSurface(pl.LightningModule):
    """Lightning module that trains a ``StyleGenerator`` with two L1 losses.

    On every training step the generator receives the batch's image, slice
    indices and slice size and returns two tensors: ``v``, which is compared
    against ``batch['slice_data']``, and ``rcn``, which is compared against
    ``batch['image']``. The optimized loss is the sum of both L1 terms.
    """

    def __init__(self, hparams):
        """Save the hyperparameters and build the generator.

        Args:
            hparams: Configuration object. It is recorded through
                ``save_hyperparameters`` and passed to ``StyleGenerator``.
        """
        super().__init__()
        self.save_hyperparameters(hparams)

        self.G = StyleGenerator(hparams)
        # Placeholder for a face-index tensor; only the disabled mesh loss
        # (see the note in ``training_step``) would populate it.
        self.faces = None

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch.

        Args:
            batch: Mapping with the keys ``'image'``, ``'slice_data'`` and
                ``'slice_idx'``.
            batch_idx: Index of the batch within the epoch (unused).

        Returns:
            The scalar loss tensor ``loss_grid + loss_rcn``.
        """
        # The slice resolution is read from the last axis of the target.
        size = batch['slice_data'].shape[-1]
        v, rcn = self.G(batch['image'], batch['slice_idx'], size)

        loss_grid = F.l1_loss(v, batch['slice_data'].reshape_as(v))
        loss_rcn = F.l1_loss(rcn, batch['image'])
        # NOTE: a mesh-based term (make_faces -> Meshes ->
        # sample_points_from_meshes -> chamfer_distance against
        # batch['samples']) was prototyped here and is currently disabled.
        loss = loss_grid + loss_rcn

        # Log detached tensors instead of ``.item()`` floats: ``.item()``
        # blocks on a device-to-host copy for every logged value.
        self.log("train/loss_grid", loss_grid.detach())
        self.log("train/loss_rcn", loss_rcn.detach())
        self.log("train/loss", loss.detach())
        return loss

    def configure_optimizers(self):
        """Create the Adam optimizer for the generator's parameters.

        Returns:
            A ``([optimizer], [])`` pair: one optimizer, no schedulers.
        """
        # Adam's default betas are used; a (0.5, 0.999) override was left
        # disabled in the original code.
        opt_g = torch.optim.Adam(self.G.parameters(), lr=0.0003)
        return [opt_g], []
                 
# Instantiate the Lightning module from the notebook's config; the bare name
# on the last line displays its layer tree as the cell output.
model = RenderToSurface(config)
model

RenderToSurface(
  (G): StyleGenerator(
    (stylist): Stylist(
      (conv0): ConvPoolBlock(
        (conv): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (lrelu): LeakyReLU(negative_slope=0.2)
        (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
      (conv1): ConvPoolBlock(
        (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (lrelu): LeakyReLU(negative_slope=0.2)
        (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
      (conv2): ConvPoolBlock(
        (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (lrelu): LeakyReLU(negative_slope=0.2)
        (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
      (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
      (flatten): Flatten(start_dim=1, end_dim=-1)
      (linear): Linear

In [5]:
# Train on one GPU with 16-bit precision for up to 3000 epochs, with the
# render-sample logging callback attached.
#
# Disabled options kept for reference:
#   progress_bar_refresh_rate=5
#   terminate_on_nan=True
#   profiler="pytorch"
#   log_every_n_steps=2
#   resume_from_checkpoint="./lightning_logs/version_9/checkpoints/epoch=1999-step=23999.ckpt"
trainer_callbacks = [LogRenderSample(config)]
trainer = pl.Trainer(
    gpus=1,
    max_epochs=3000,
    callbacks=trainer_callbacks,
    precision=16,
)
trainer.fit(model, dm)

Using native 16bit precision.
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name | Type           | Params
----------------------------------------
0 | G    | StyleGenerator | 453 K 
----------------------------------------
453 K     Trainable params
0         Non-trainable params
453 K     Total params
1.815     Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…


Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at  /opt/conda/conda-bld/pytorch_1623448278899/work/c10/core/TensorImpl.h:1156.)


Detected KeyboardInterrupt, attempting graceful shutdown...



In [8]:
# Alias the module as `self` so the first two lines of `training_step` can be
# replayed verbatim on the batch fetched earlier; later cells use this alias.
# NOTE(review): this forward pass runs outside `torch.no_grad()`, so the
# outputs presumably carry autograd state — confirm that is intended.
self = model
size = batch['slice_data'].shape[-1]
v, rcn = self.G(batch['image'], batch['slice_idx'], size)

In [9]:
# Show the shapes of both generator outputs for comparison with the batch's
# 'slice_data' and 'image' shapes printed earlier.
v.shape, rcn.shape

(torch.Size([32, 3, 16, 16]), torch.Size([32, 1, 128, 128]))

In [10]:
# Run only the stylist sub-network on the batch images to inspect the style
# codes it produces (one row per image).
self.G.stylist(batch['image'])

tensor([[-1.0226,  0.8807, -0.3657, -0.6904,  1.0401,  0.8720,  1.1700, -0.5188],
        [-0.8217,  0.7772,  0.0413, -0.8641,  1.0386,  0.5999,  1.1482, -0.3408],
        [-1.0579,  0.9156, -0.4911, -0.6044,  1.0112,  0.9509,  1.1178, -0.5456],
        [-0.6509,  0.6294,  0.4034, -1.0065,  1.0236,  0.3012,  1.0982, -0.2688],
        [-0.5713,  0.4965,  0.6352, -1.1897,  1.0227,  0.1270,  1.2363, -0.2466],
        [-0.7948,  0.8238,  0.0101, -0.8180,  1.0484,  0.6456,  1.1308, -0.2480],
        [-0.7577,  0.7039,  0.1555, -0.9168,  1.0161,  0.4917,  1.1491, -0.3155],
        [-1.1769,  1.0255, -0.7767, -0.4522,  1.0176,  1.1695,  1.0952, -0.5868],
        [-0.9000,  0.6225, -0.1199, -0.8851,  1.0262,  0.6168,  1.3350, -0.5491],
        [-0.7998,  0.4916,  0.1014, -0.9935,  0.9912,  0.4122,  1.3441, -0.5564],
        [-1.1851,  1.0156, -0.7473, -0.4690,  1.0285,  1.1374,  1.0693, -0.6395],
        [-1.0189,  0.8386, -0.3561, -0.7060,  1.0420,  0.8357,  1.1737, -0.5898],
        [-0.7803