In [1]:
import torch

import pytorch_lightning as pl

from src.config import get_parser
#from src.models.gan import GAN
from src.callback.points_image import PointsImage
from src.data.masked_datamodule import MaskedDataModule
from src.callback.export_mesh import ExportMesh
from pytorch3d.loss.chamfer import chamfer_distance

In [2]:
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F

from src.stylegan2.op import fused_leaky_relu
from src.stylegan2.Blocks import ModConvLayer
#from src.models.blocks import ConvBlock
from src.utilities.util import grid_to_list

class ConvBlock(nn.Sequential):
    """Conv2d -> BatchNorm2d -> LeakyReLU(0.2) as a single Sequential block.

    Args:
        in_channel: number of input channels.
        out_ch: number of output channels (conv output and BatchNorm size).
        ker_size: convolution kernel size.
        padd: convolution padding.
        stride: convolution stride.

    Note the positional order is ``(..., padd, stride)`` -- padding BEFORE
    stride, the reverse of ``nn.Conv2d`` -- so callers should pass these two
    by keyword.
    """

    def __init__(self, in_channel, out_ch, ker_size, padd, stride):
        super().__init__()
        # Submodule names are part of the state_dict keys; keep them stable.
        self.add_module('conv', nn.Conv2d(in_channel, out_ch,
                                          kernel_size=ker_size,
                                          stride=stride,
                                          padding=padd))
        self.add_module('norm', nn.BatchNorm2d(out_ch))
        self.add_module('LeakyRelu', nn.LeakyReLU(0.2, inplace=True))

    @staticmethod
    def weights_init(m):
        """Initialise one module in place; intended for ``net.apply(ConvBlock.weights_init)``.

        - Modules whose class name contains ``Conv2d``: weight ~ N(0.0, 0.02).
        - Modules whose class name contains ``Norm`` and that have an affine
          weight: weight ~ N(1.0, 0.02), bias = 0.
        Any other module is left unchanged.

        Declared as a staticmethod so it also works when called through an
        instance (``block.weights_init(m)``) instead of binding ``m`` to the block.
        """
        classname = m.__class__.__name__
        if classname.find('Conv2d') != -1:
            m.weight.data.normal_(0.0, 0.02)
        elif classname.find('Norm') != -1:
            # Non-affine norm layers (e.g. InstanceNorm2d by default) have
            # weight/bias set to None; there is nothing to initialise.
            if getattr(m, 'weight', None) is not None:
                m.weight.data.normal_(1.0, 0.02)
            if getattr(m, 'bias', None) is not None:
                m.bias.data.fill_(0)


    
class Generator(nn.Module):
    """Residual refiner for grids of points.

    Pipeline: a ``head`` ConvBlock, then three ConvBlocks ``b1``..``b3``.
    After each of ``b1``..``b3`` a Conv2d (``s1``..``s3``) projects the
    features to 3 channels and the projection is added to a running sum that
    starts from the input ``points``. The final sum is passed through
    ``grid_to_list``.

    Args:
        opt: config object; reads ``G_in_ch``, ``G_out_ch``, ``ker_size``,
            ``stride`` and ``padd_size``.

    The residual additions in ``forward`` only line up when every conv keeps
    the spatial size of ``points`` and ``points`` has 3 channels (or
    broadcasts against 3).
    """

    def __init__(self, opt):
        super().__init__()
        in_ch, out_ch, ker_size, stride, padd_size = (opt.G_in_ch,
            opt.G_out_ch, opt.ker_size, opt.stride, opt.padd_size)
        # ConvBlock's signature is (in, out, ker_size, padd, stride) -- padding
        # before stride -- so pass those two by keyword to avoid swapping them.
        self.head = ConvBlock(in_ch, out_ch, ker_size, padd=padd_size, stride=stride)
        self.b1 = ConvBlock(out_ch, out_ch, ker_size, padd=padd_size, stride=stride)
        self.b2 = ConvBlock(out_ch, out_ch, ker_size, padd=padd_size, stride=stride)
        self.b3 = ConvBlock(out_ch, out_ch, ker_size, padd=padd_size, stride=stride)

        # nn.Conv2d's positional order is (in, out, kernel_size, stride, padding).
        self.s1 = nn.Conv2d(out_ch, 3, ker_size, stride, padd_size)
        self.s2 = nn.Conv2d(out_ch, 3, ker_size, stride, padd_size)
        self.s3 = nn.Conv2d(out_ch, 3, ker_size, stride, padd_size)

    def forward(self, points):
        """Refine ``points`` and return ``grid_to_list`` of the refined grid.

        ``points`` is fed straight into Conv2d layers, so it is presumably a
        (B, G_in_ch, H, W) grid -- TODO confirm against the data module.
        """
        x = self.head(points)

        # Each stage adds a 3-channel correction on top of the previous estimate.
        x = self.b1(x)
        vrt = self.s1(x) + points

        x = self.b2(x)
        vrt = self.s2(x) + vrt

        x = self.b3(x)
        vrt = self.s3(x) + vrt

        return grid_to_list(vrt)

In [3]:
import time
from collections import OrderedDict

import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch3d.structures import Meshes
from pytorch3d.loss import ( 
    mesh_edge_loss, 
    mesh_laplacian_smoothing, 
    mesh_normal_consistency,
)

from src.models.discriminator import Discriminator
#from src.models.generator import Generator
from src.render.mesh_renderer import MeshPointsRenderer
from src.loss.edge_loss import EdgeLoss
from src.utilities.util import grid_to_list

class GAN(pl.LightningModule):
    """LightningModule that trains ``Generator`` to map noisy coarse point
    grids onto the fine ones with an L1 loss.

    Despite the class name, only the generator ``self.G`` is built and
    optimised here; there is no discriminator.

    Args:
        hparams: config object; read for the Generator/EdgeLoss settings and
            for ``G_noise_amp``, ``lr_g``, ``beta1`` and ``beta2``.
    """

    def __init__(self, hparams):
        super().__init__()
        # NOTE(review): direct assignment works with the Lightning version this
        # notebook ran on; newer releases make ``hparams`` read-only and expect
        # ``self.save_hyperparameters(hparams)`` instead -- confirm before upgrading.
        self.hparams = hparams

        self.G = Generator(hparams)
        # Constructed but not used by training_step (only the L1 loss is active).
        self.edge_loss = EdgeLoss(hparams)

    def forward(self, points):
        """Run the generator on ``points``."""
        return self.G(points)

    def training_step(self, batch, batch_idx):
        """Compute and log the generator loss for one batch.

        Reads ``batch['points_coarse']`` (generator input) and
        ``batch['points']`` (target). Gaussian noise scaled by
        ``hparams.G_noise_amp`` is added to the coarse input before the
        generator runs.

        Returns:
            The scalar L1 loss between the generator output and the target.
        """
        points_coarse = batch['points_coarse']
        points_fine = batch['points']

        points_noise = points_coarse + torch.randn_like(points_coarse) * self.hparams.G_noise_amp
        vertices = self.G(points_noise)
        # The generator returns its output after grid_to_list, so reshape the
        # target to the same layout before comparing.
        points_fine = points_fine.reshape_as(vertices)
        vrt_loss = F.l1_loss(vertices, points_fine)
        g_loss = vrt_loss

        self.log("loss/g_loss", g_loss, prog_bar=True)
        self.log("loss/vrt_loss", vrt_loss)
        return g_loss

    def configure_optimizers(self):
        """AdamW over the generator parameters with ``lr_g`` and ``(beta1, beta2)``."""
        lr_g = self.hparams.lr_g
        betas = (self.hparams.beta1, self.hparams.beta2)
        return torch.optim.AdamW(self.G.parameters(), lr=lr_g, betas=betas)
    
# Start from the parser defaults (no CLI arguments) and apply this notebook's
# overrides before building the model.
config = get_parser().parse_args(args=[])
vars(config).update(
    batch_size=128,
    blueprint='blueprint_16_512.npz',
    data_blueprint_size=256,
    data_blueprint_coarse=128 + 64,
    data_patch_size=32,
    lr_g=0.001,
    log_mesh_interval=1,
    G_out_ch=256,  # 256 default
)

model = GAN(config)
model

GAN(
  (G): Generator(
    (head): ConvBlock(
      (conv): Conv2d(3, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (norm): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (LeakyRelu): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (b1): ConvBlock(
      (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (norm): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (LeakyRelu): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (b2): ConvBlock(
      (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (norm): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (LeakyRelu): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (b3): ConvBlock(
      (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (norm): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, trac

In [4]:
# Data module built from the same `config` the model was created with; only
# `num_workers` is overridden in this cell.
config.num_workers = 6
dm = MaskedDataModule(config)
dm.setup()
dm

torch.Size([16, 3, 512, 512])


<src.data.masked_datamodule.MaskedDataModule at 0x7f6b47e2cd90>

In [None]:
# Single-GPU, 16-bit training; ExportMesh is given the same config as the model.
trainer = pl.Trainer(
    gpus=1,
    precision=16,
    callbacks=[ExportMesh(config)],
)
trainer.fit(model, dm)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
Using native 16bit precision.

  | Name      | Type      | Params
----------------------------------------
0 | G         | Generator | 1.8 M 
1 | edge_loss | EdgeLoss  | 0     
----------------------------------------
1.8 M     Trainable params
0         Non-trainable params
1.8 M     Total params
7.201     Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

points.size(-2), points.size(-1) 256 256


In [None]:
# Author's note: training ran at ~3.33 it/s.
# Shape of the training dataset's `entries` (`.shape1` is not an attribute).
dm.train_dataloader().dataset.entries.shape

In [None]:
# Handle to the generator owned by the LightningModule (same object as model.G).
G = model.G
G

In [None]:
# `points` attribute of the training dataset behind the trainer's data module.
# NOTE(review): a later cell rebinds `points` to a batch tensor; re-run this
# cell before indexing `points` by sample.
points = trainer.datamodule.train_dataloader().dataset.points
points.shape

In [None]:
# Take sample `i` from `points`, add a leading axis with [None], and inspect
# the shape of the first element of grid_to_list's result.
i = 11
pts = points[i][None]
grid_to_list(pts)[0].cpu().numpy().shape

In [None]:
# Pull one batch from the training loader and move its fine points to the
# model's device.
# NOTE(review): this rebinds `points`, shadowing the dataset points from the
# earlier cell.
batch  = next(iter(trainer.datamodule.train_dataloader()))
points = batch['points'].to(model.device)
points.shape

In [None]:
# Keys in a training batch (GAN.training_step reads 'points' and 'points_coarse').
batch.keys()

In [None]:
# Sanity-check the generator on the batch loaded above. Generator.forward runs
# head -> b1..b3 with 3-channel residual updates -> grid_to_list, so call it
# directly rather than re-implementing the pass by hand.
# NOTE(review): the generator contains BatchNorm2d layers; if the model is in
# train mode this forward pass updates their running statistics.
with torch.no_grad():
    x = model.G(points)
x.shape