In [8]:
import torch

import pytorch_lightning as pl

from src.config import get_parser
#from src.models.gan import GAN
from src.callback.points_image import PointsImage
from src.data.masked_datamodule import MaskedDataModule
from src.callback.export_mesh import ExportMesh
from pytorch3d.loss.chamfer import chamfer_distance

In [9]:
import math
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F

from src.models.layers import ModulateConvBlock
#from src.models.blocks import ConvBlock
from src.utilities.util import grid_to_list

class ConvBlock(nn.Sequential):
    """Conv2d -> BatchNorm2d -> LeakyReLU(0.2) packaged as one sequential unit.

    Args:
        in_channel: number of input channels.
        out_ch: number of output channels (also the BatchNorm width).
        ker_size: convolution kernel size.
        stride: convolution stride.
        padding: convolution zero-padding.

    The submodule names ('conv', 'norm', 'LeakyRelu') determine the state_dict
    keys, so they are kept unchanged for checkpoint compatibility.
    """

    def __init__(self, in_channel, out_ch, ker_size, stride, padding):
        super(ConvBlock, self).__init__()
        self.add_module('conv', nn.Conv2d(in_channel, out_ch,
                                          kernel_size=ker_size,
                                          stride=stride,
                                          padding=padding))
        self.add_module('norm', nn.BatchNorm2d(out_ch))
        self.add_module('LeakyRelu', nn.LeakyReLU(0.2, inplace=True))

#     def weights_init(m):
#         classname = m.__class__.__name__
#         if classname.find('Conv2d') != -1:
#             m.weight.data.normal_(0.0, 0.02)
#         elif classname.find('Norm') != -1:
#             m.weight.data.normal_(1.0, 0.02)
#             m.bias.data.fill_(0) 

class GenBlock(nn.Module):
    """One generator stage: modulated conv features plus a residual vertex update.

    The features pass through ``ModulateConvBlock`` (conditioned on the style
    ``z``) and a ``ConvBlock``. They are then optionally average-pooled and
    projected to 3 channels. That projection is added to the previous stage's
    vertex grid ``prev_vrt`` after being bilinearly upsampled to its spatial
    size when it is coarser.

    Args:
        latent_size: size of the style vector passed to ``ModulateConvBlock``.
        in_ch: input feature channels.
        out_ch: output feature channels.
        kernel: kernel size for the convolutions.
        stride: stride for ``conv`` and ``to_points``.
        padding: padding for ``conv`` and ``to_points``.
        pool: if truthy, average-pool the features by this factor before
            projecting to vertices.
    """

    def __init__(self, latent_size, in_ch, out_ch, kernel, stride, padding, pool=None):
        super(GenBlock, self).__init__()
        self.mod_conv = ModulateConvBlock(latent_size, in_ch, out_ch, kernel)
        self.conv = ConvBlock(out_ch, out_ch, kernel, stride=stride, padding=padding)
        self.to_points = nn.Conv2d(out_ch, 3, kernel, stride=stride, padding=padding)
        self.pool = nn.AvgPool2d(kernel_size=pool, stride=pool) if pool else None

    def upscale(self, x, scale_factor=None, size=None):
        """Bilinearly resize ``x`` (align_corners=True) by ``scale_factor`` or to ``size``.

        Exactly one of ``scale_factor`` / ``size`` must be given
        (``F.interpolate`` raises otherwise).
        """
        return F.interpolate(x, size=size, scale_factor=scale_factor,
                             mode='bilinear', align_corners=True)

    def forward(self, x, z, prev_vrt):
        """Return ``(features, vertices)``.

        ``vertices`` is this stage's 3-channel update added to ``prev_vrt``.
        """
        x = self.mod_conv(x, z)
        x = self.conv(x)
        vrt = self.to_points(self.pool(x) if self.pool is not None else x)

        src_hw, dst_hw = vrt.shape[-2:], prev_vrt.shape[-2:]
        # Resize to prev_vrt's exact (H, W) whenever the update is coarser.
        # The old `prev_W // vrt_W` scale only looked at the last dim and
        # truncated, so it broke on non-square grids and non-integer ratios.
        # With align_corners=True an integer ratio gives identical values.
        # Equal or broadcastable shapes are added as-is, as before.
        if src_hw != dst_hw and all(d >= s for d, s in zip(dst_hw, src_hw)):
            vrt = self.upscale(vrt, size=dst_hw)
        vrt = vrt + prev_vrt
        return x, vrt
        
        

    
class Generator(nn.Module):
    """Coarse-to-fine vertex generator built from three ``GenBlock`` stages.

    A ``ConvBlock`` head lifts ``points`` to ``G_out_ch`` features. The three
    stages (b1, b2, b3) then each add a residual to a running vertex grid that
    starts out as ``points`` itself. Stage ``i`` average-pools its features by
    ``pools[i]`` (4, 2, then none) before projecting, so earlier stages
    contribute coarser updates.

    Args:
        opt: config object providing ``latent_size``, ``G_in_ch``,
            ``G_out_ch``, ``ker_size``, ``stride`` and ``padd_size``.
    """

    def __init__(self, opt):
        super(Generator, self).__init__()
        latent_size = opt.latent_size
        in_ch = opt.G_in_ch
        out_ch = opt.G_out_ch
        ker_size = opt.ker_size
        stride = opt.stride
        padding = opt.padd_size

        self.head = ConvBlock(in_ch, out_ch, ker_size, stride=stride, padding=padding)
        self.pools = [4, 2, None]

        def make_stage(pool):
            # Every stage shares the same channel width and conv settings;
            # only the pooling factor differs.
            return GenBlock(latent_size, out_ch, out_ch, ker_size,
                            stride=stride, padding=padding, pool=pool)

        self.b1 = make_stage(self.pools[0])
        self.b2 = make_stage(self.pools[1])
        self.b3 = make_stage(self.pools[2])

    def forward(self, points, latents):
        """Refine ``points`` through b1 -> b2 -> b3 and return the final vertex grid.

        Args:
            points: input grid. It is also the initial vertex grid the stages
                add their 3-channel updates to (presumably 3 channels itself).
            latents: indexable with at least 3 entries; ``latents[i]`` is the
                style fed to stage ``i``.
        """
        features = self.head(points)
        vertices = points
        for level, stage in enumerate((self.b1, self.b2, self.b3)):
            features, vertices = stage(features, latents[level], vertices)
        return vertices

In [15]:
import time
from collections import OrderedDict

import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch3d.structures import Meshes
from pytorch3d.loss import ( 
    mesh_edge_loss, 
    mesh_laplacian_smoothing, 
    mesh_normal_consistency,
)

from src.models.discriminator import Discriminator
#from src.models.generator import Generator
from src.render.mesh_renderer import MeshPointsRenderer
from src.loss.edge_loss import EdgeLoss
from src.utilities.util import grid_to_list

class GAN(pl.LightningModule):
    """LightningModule that trains ``Generator`` to refine noisy coarse points.

    Only the generator is optimised here. The objective is a reconstruction
    loss between the generator output and ``batch['points']``; no
    discriminator is built by this class.

    Args:
        hparams: config object. This class reads ``G_noise_amp``,
            ``latent_size``, ``lr_g`` and, optionally, ``detect_anomaly``;
            ``Generator`` reads its own fields from the same object.
    """

    # Generator.forward indexes latents[0..2]: one style vector per GenBlock.
    NUM_STYLES = 3

    def __init__(self, hparams):
        super().__init__()
        # Assigning to ``self.hparams`` directly is deprecated in Lightning
        # (the setter is removed in later releases). save_hyperparameters
        # stores the values so they stay readable as ``self.hparams.<name>``.
        self.save_hyperparameters(hparams)
        self.G = Generator(hparams)

    def _zero_styles(self, batch_size, device):
        """Return ``NUM_STYLES`` all-zero style tensors of shape (batch_size, latent_size)."""
        return [torch.zeros(batch_size, self.hparams.latent_size, device=device)
                for _ in range(self.NUM_STYLES)]

    def forward(self, points, latents=None):
        """Run the generator on ``points``.

        ``Generator.forward`` requires style latents; the old ``self.G(points)``
        call therefore always raised ``TypeError``. When ``latents`` is omitted,
        zero styles are used, i.e. the same styles ``training_step`` feeds the
        generator.
        """
        if latents is None:
            latents = self._zero_styles(points.size(0), points.device)
        return self.G(points, latents)

    def training_step(self, batch, batch_idx):
        """Add Gaussian noise to ``batch['points_coarse']``, generate, and
        return the reconstruction loss against ``batch['points']``."""
        points_coarse = batch['points_coarse']
        points_fine = batch['points']
        bs = points_fine.size(0)

        # Anomaly detection is very slow, so it is opt-in via hparams.detect_anomaly.
        # NOTE(review): with automatic optimisation, backward() presumably
        # runs after this method returns, i.e. outside this context.
        detect_anomaly = bool(getattr(self.hparams, 'detect_anomaly', False))
        with torch.autograd.set_detect_anomaly(detect_anomaly):
            noise = torch.randn_like(points_coarse) * self.hparams.G_noise_amp
            points_noise = points_coarse + noise
            vertices = self.G(points_noise, self._zero_styles(bs, points_coarse.device))

            # reshape_as requires points_fine to hold exactly as many
            # elements as vertices.
            points_fine = points_fine.reshape_as(vertices)
            vrt_loss = F.l1_loss(vertices, points_fine)
            # NOTE(review): sqrt of an *L1* loss looks like a leftover from an
            # MSE->RMSE variant. Its gradient 1 / (2 * sqrt(loss)) grows without
            # bound as the loss approaches 0. Kept as-is to preserve training behaviour.
            g_loss = torch.sqrt(vrt_loss)

            self.log("loss/g_loss", g_loss)
            self.log("loss/vrt_loss", vrt_loss)
        return g_loss

    def configure_optimizers(self):
        """AdamW over the generator parameters with learning rate ``hparams.lr_g``.

        ``hparams.beta1`` / ``beta2`` are deliberately not passed, so AdamW's
        default betas apply (unchanged from before).
        """
        return torch.optim.AdamW(self.G.parameters(), lr=self.hparams.lr_g)
    
# Experiment configuration: start from the parser defaults, then override.
# ker_size / stride / padd_size are left at their parser defaults.
config = get_parser().parse_args(args=[])

# Data (other blueprints tried: 'vezuvio255.npz', 'blueprint_noise_1_512.npz')
config.batch_size = 4
config.blueprint = 'blueprint_16_512.npz'
config.data_blueprint_size = 256
config.data_blueprint_coarse = 192  # 128 + 64
config.data_patch_size = 64

# Optimisation / logging
config.lr_g = 0.01
config.log_mesh_interval = 1

# Generator
config.latent_size = 32
config.G_out_ch = 128  # 256 default

model = GAN(config)
model

GAN(
  (G): Generator(
    (head): ConvBlock(
      (conv): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (norm): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (LeakyRelu): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (b1): GenBlock(
      (mod_conv): ModulateConvBlock(
        (style): DenseBlock(
          (fc): Linear(in_features=32, out_features=128, bias=False)
          (activate): Identity()
        )
        (activate): LeakyReLU(negative_slope=0.2, inplace=True)
      )
      (conv): ConvBlock(
        (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (norm): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (LeakyRelu): LeakyReLU(negative_slope=0.2, inplace=True)
      )
      (to_points): Conv2d(128, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (pool): AvgPool2d(kernel_size=4, stride=4, padding=0)
    )
    (b2): 

In [5]:
# NOTE(review): the recorded output below (256) comes from execution count 5,
# i.e. before cell 15 set config.latent_size = 32 and rebuilt `model`.
# After Restart & Run All this should show 32.
model.hparams.latent_size

256

In [16]:
# Data-module overrides. These mutate the same `config` object the model was
# built from (and that ExportMesh receives later), replacing the earlier
# data_patch_size / batch_size values.
config.data_patch_size = 16
config.batch_size = 128
config.num_workers = 6

dm = MaskedDataModule(config)
dm.setup()
dm

torch.Size([16, 3, 512, 512])


<src.data.masked_datamodule.MaskedDataModule at 0x7f8aa4196be0>

In [17]:
# Train on a single GPU with the ExportMesh callback attached.
mesh_callbacks = [ExportMesh(config)]
trainer = pl.Trainer(gpus=1, callbacks=mesh_callbacks)
trainer.fit(model, dm)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores

  | Name | Type      | Params
-----------------------------------
0 | G    | Generator | 913 K 
-----------------------------------
913 K     Trainable params
0         Non-trainable params
913 K     Total params
3.653     Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

points.size(-2), points.size(-1) 256 256



1

In [None]:
torch.rand(1, 3, 64, 64).dim()  # number of dimensions (4)

In [None]:
# Scratch check: shape of grid_to_list's result for a random (1, 3, 64, 64) grid.
# No output was captured for this cell, so the expected layout is not recorded here.
grid_to_list(torch.rand(1, 3, 64, 64)).shape

In [None]:
F.leaky_relu(torch.rand(1))  # default negative_slope (0.01); F is torch.nn.functional