In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch3d.loss import (
    chamfer_distance,    
)
from pytorch3d.structures import Meshes
from pytorch3d.ops import sample_points_from_meshes
from src.callback.log_render_sample import LogRenderSample
from src.utilities.util import (
    grid_to_list,
    make_faces,
    list_to_grid,
)

In [2]:
from src.models.style_generator import VariationalGenerator
from src.config import get_parser

config = get_parser().parse_args(args=[])  

config.fast_image_size = 128
config.grid_full_size = 16
config.grid_slice_size = 16
config.batch_size = 32
config.stylist_channels = [3, 256, 256]
config.synthesis_channels = [256, 256, 256]
config.style_dim = 128
config.log_mesh_interval = 250
config.initial_input_fixed = True
G = VariationalGenerator(config)    

G#, decoder

VariationalGenerator(
  (encoder): Encoder(
    (trunk): Sequential(
      (conv0): ConvPoolBlock(
        (conv): Conv2d(3, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (lrelu): LeakyReLU(negative_slope=0.2)
        (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
      (conv1): ConvPoolBlock(
        (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (lrelu): LeakyReLU(negative_slope=0.2)
        (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
      (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
      (flatten): Flatten(start_dim=1, end_dim=-1)
    )
    (mu): Linear(in_features=256, out_features=128, bias=True)
    (logVar): Linear(in_features=256, out_features=128, bias=True)
  )
  (synthesis): VariationalSynthesis(
    (input): Slices2D()
    (head): StyledConv2d(
      (conv): ModulatedConv2d(
        (modulation): Line

In [3]:
from src.data.sliced_render import SlicedRenderDataModule

dm = SlicedRenderDataModule(config)
batch = next(iter(dm.train_dataloader()))
for k,v in batch.items():
    print(k, v.shape)

slice_data torch.Size([32, 3, 16, 16])
slice_idx torch.Size([32, 2])


In [4]:
class RenderToSurface(pl.LightningModule):
    
    def __init__(self, hparams):
        super().__init__()        
        self.save_hyperparameters(hparams)
        
        self.G = VariationalGenerator(hparams)
        self.faces = None                     

    def kl_loss(self, mu, logVar):
        return -0.5 * torch.sum(1 + logVar - mu.pow(2) - logVar.exp())
         
        
    def training_step(self, batch, batch_idx):        
        size = batch['slice_data'].shape[-1]
        #v, rcn = self.G(batch['image'], batch['slice_idx'], size)
        v, mu, logVar = self.G(batch['slice_data'], batch['slice_idx'], size)
#         if self.faces is None:
#             _, _, w, h = v.shape
#             faces = make_faces(w, h)[None]
#             self.faces = torch.tensor(faces, device=v.device)
        #print(batch['image'].shape, v.shape, batch['grid'].shape)        
        grid_loss = F.l1_loss(v, batch['slice_data'].reshape_as(v))
        kl_loss = self.kl_loss(mu, logVar)
        #loss_rcn =  F.l1_loss(rcn, batch['slice_data'])
        #loss_rcn =  F.l1_loss(rcn, batch['image'])
        #loss_grid = F.mse_loss(v, batch['grid'].reshape_as(v))
        #loss_grid = F.l1_loss(v, list_to_grid(batch['grid']))
        #v = grid_to_list(v)        
        #meshes = Meshes(verts=v, faces=self.faces.expand(v.size(0), -1, -1))
        #samples = sample_points_from_meshes(meshes)
                
        
        #loss_samples, _ = chamfer_distance(samples, batch['samples'])        
        loss = grid_loss + kl_loss #+ loss_rcn#+ loss_samples
        
        #self.log(f"train/loss_samples", loss_samples.item())
        self.log(f"train/loss_grid", grid_loss.item())
        self.log(f"train/loss_kl", kl_loss.item())
        #self.log(f"train/loss_rcn", loss_rcn.item())
        self.log(f"train/loss", loss.item())
        return loss
        
        
    def configure_optimizers(self):
        lr, betas = 0.0003, (0.5, 0.999)        
        opt_g = torch.optim.Adam(self.G.parameters(),  lr=lr)#, betas=betas)
        return [opt_g], []
                 
model = RenderToSurface(config)
model

RenderToSurface(
  (G): VariationalGenerator(
    (encoder): Encoder(
      (trunk): Sequential(
        (conv0): ConvPoolBlock(
          (conv): Conv2d(3, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (lrelu): LeakyReLU(negative_slope=0.2)
          (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        )
        (conv1): ConvPoolBlock(
          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (lrelu): LeakyReLU(negative_slope=0.2)
          (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        )
        (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
        (flatten): Flatten(start_dim=1, end_dim=-1)
      )
      (mu): Linear(in_features=256, out_features=128, bias=True)
      (logVar): Linear(in_features=256, out_features=128, bias=True)
    )
    (synthesis): VariationalSynthesis(
      (input): Slices2D()
      (head): Style

In [5]:
self = model
size = batch['slice_data'].shape[-1]
v, _, _ = self.G(batch['slice_data'], batch['slice_idx'], size)


Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at  /opt/conda/conda-bld/pytorch_1623448278899/work/c10/core/TensorImpl.h:1156.)



In [6]:
trainer = pl.Trainer(gpus=1, max_epochs=3000, #progress_bar_refresh_rate=5,
                     #terminate_on_nan=True, 
                     #profiler="pytorch",
                     #log_every_n_steps=2, 
                     callbacks=[LogRenderSample(config)],
                     precision=32,
                     #resume_from_checkpoint="./lightning_logs/version_16/checkpoints/epoch=0-step=581.ckpt"
                    )
trainer.fit(model, dm)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name | Type                 | Params
----------------------------------------------
0 | G    | VariationalGenerator | 1.9 M 
----------------------------------------------
1.9 M     Trainable params
0         Non-trainable params
1.9 M     Total params
7.759     Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…


Detected KeyboardInterrupt, attempting graceful shutdown...



In [7]:
self = model
size = batch['slice_data'].shape[-1]
v, rcn = self.G(batch['slice_data'], batch['slice_idx'], size)
v.shape, rcn.shape

AttributeError: 'NoneType' object has no attribute 'shape'

In [None]:
self.G.stylist(batch['slice_data']).shape #[0]

In [None]:
grids = dm.train_dataloader().dataset.grids

In [None]:
res = torch.empty(len(grids), 3, grids[0].size(-2), grids[0].size(-1))
res.shape

In [None]:
for i, t in enumerate(grids):
    print(i)
    res[i] = t

In [None]:
res.shape

In [None]:
torch.save(res, './data/scaled_32.pth')

In [None]:
import os

In [None]:
os.path.exists('./data/scaled_321.pth')