In [1]:
import torch

import pytorch_lightning as pl

from src.config import get_parser
#from src.models.gan import GAN
from src.callback.points_image import PointsImage
from src.data.masked_datamodule import MaskedDataModule
from src.callback.export_mesh import ExportMesh

In [9]:
import time
from collections import OrderedDict

import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch3d.structures import Meshes
from pytorch3d.loss import ( 
    mesh_edge_loss, 
    mesh_laplacian_smoothing, 
    mesh_normal_consistency,
)

from src.models.discriminator import Discriminator
from src.models.generator import Generator
from src.render.mesh_renderer import MeshPointsRenderer
from src.loss.edge_loss import EdgeLoss
from src.utilities.util import grid_to_list

class GAN(pl.LightningModule):
    """Lightning module that trains only the generator with a supervised L1 loss.

    No discriminator is built here, despite the class name. The generator
    receives the coarse points plus Gaussian noise and is regressed onto the
    fine points.

    Args:
        hparams: configuration namespace. This class reads ``G_noise_amp``,
            ``lr_g``, ``beta1`` and ``beta2`` from it and also passes it to
            ``Generator`` and ``EdgeLoss``.
    """

    def __init__(self, hparams):
        super().__init__()
        # Newer Lightning releases expose `hparams` as a read-only property, so
        # `self.hparams = hparams` raises there. save_hyperparameters is the
        # supported way to attach the namespace (still read as self.hparams.<name>).
        self.save_hyperparameters(hparams)

        self.G = Generator(hparams)
        # Constructed, but not currently part of the training loss (see training_step).
        self.edge_loss = EdgeLoss(hparams)

    def forward(self, points):
        """Return the generator output for ``points``."""
        return self.G(points)

    def training_step(self, batch, batch_idx):
        """Compute and log the generator loss for one batch.

        Args:
            batch: mapping with ``'points_coarse'`` (generator input) and
                ``'points'`` (regression target).
            batch_idx: index of the batch (unused).

        Returns:
            Scalar L1 loss between the generator output and the fine points.
        """
        points_coarse = batch['points_coarse']
        points_fine = batch['points']

        # Perturb the generator input with Gaussian noise of amplitude G_noise_amp.
        points_noise = (
            points_coarse
            + torch.randn_like(points_coarse) * self.hparams.G_noise_amp
        )
        vertices = self.G(points_noise)

        # Reshape the target to the generator's output layout before comparing.
        vrt_loss = F.l1_loss(vertices, points_fine.reshape_as(vertices))
        # Only the vertex term is active; the edge and normal-consistency
        # regularisers are currently disabled.
        g_loss = vrt_loss

        self.log("loss/g_loss", g_loss, prog_bar=True)
        self.log("loss/vrt_loss", vrt_loss)
        return g_loss

    def configure_optimizers(self):
        """Return Adam over the generator parameters with ``lr_g`` and ``(beta1, beta2)``."""
        return torch.optim.Adam(
            self.G.parameters(),
            lr=self.hparams.lr_g,
            betas=(self.hparams.beta1, self.hparams.beta2),
        )
    
# Parser defaults, then the overrides used for this experiment.
config = get_parser().parse_args(args=[])

# Data / blueprint settings.
config.batch_size = 128
config.blueprint = 'blueprint_16_512.npz'
config.data_blueprint_size = 32
config.data_blueprint_coarse = 16
config.data_patch_size = 4

# Training settings.
config.lr_g = 0.0003
config.log_mesh_interval = 20

# Generator width (256 by default).
config.G_out_ch = 512

model = GAN(config)
model

GAN(
  (G): Generator(
    (head): ConvBlock(
      (conv): Conv2d(3, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (norm): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (Nonlinearity): Hardswish()
    )
    (body): Sequential(
      (b1): ConvBlock(
        (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (norm): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (Nonlinearity): Hardswish()
      )
      (b2): ConvBlock(
        (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (norm): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (Nonlinearity): Hardswish()
      )
      (b3): ConvBlock(
        (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (norm): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (Nonlin

In [11]:
# Build the data module from the same `config` the model was created with,
# and call setup() eagerly rather than waiting for the trainer to do it.
dm = MaskedDataModule(config)
dm.setup()
dm

<src.data.masked_datamodule.MaskedDataModule at 0x7f424ca3a640>

In [12]:
# Train on one GPU when CUDA is available, otherwise fall back to CPU so the
# notebook still runs on machines without a GPU.
trainer = pl.Trainer(
    gpus=1 if torch.cuda.is_available() else 0,
    callbacks=[ExportMesh(config)],
)
trainer.fit(model, dm)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores

  | Name      | Type      | Params
----------------------------------------
0 | G         | Generator | 7.1 M 
1 | edge_loss | EdgeLoss  | 0     
----------------------------------------
7.1 M     Trainable params
0         Non-trainable params
7.1 M     Total params
28.447    Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…




1

In [10]:
# Re-attach a generator captured earlier by the `G = model.G` cell, e.g. after
# redefining the GAN class. That cell appears *below* this one, so on a fresh
# Restart & Run All `G` does not exist yet; skip instead of raising NameError.
if "G" in globals():
    model.G = G

In [5]:
# Keep a module-level handle on the model's generator (reused by the `model.G = G` cell).
G = model.G
G  # display the module structure

Generator(
  (head): ConvBlock(
    (conv): Conv2d(3, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (norm): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (Nonlinearity): Hardswish()
  )
  (body): Sequential(
    (b1): ConvBlock(
      (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (norm): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (Nonlinearity): Hardswish()
    )
    (b2): ConvBlock(
      (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (norm): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (Nonlinearity): Hardswish()
    )
    (b3): ConvBlock(
      (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (norm): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (Nonlinearity): Hardswish()
    )
  )
  (tail): Sequential(

In [None]:
# Points tensor held by the training dataset (not batched by the loader).
points = (
    trainer.datamodule
    .train_dataloader()
    .dataset
    .points
)
points.shape

In [None]:
i = 11  # index of the sample to inspect
pts = points[i][None]  # add a leading size-1 axis (presumably the batch axis grid_to_list expects)
(
    grid_to_list(pts)[0]
    .cpu()
    .numpy()
    .shape
)

In [None]:
# Pull one batch from the training loader and move its points to the model's device.
batch = next(iter(trainer.datamodule.train_dataloader()))
points = batch['points'].to(model.device)
points.shape

In [None]:
# Show which fields the batch mapping provides.
batch.keys()

In [None]:
# Step through the generator by hand to inspect its intermediate output.
# `self` only exists inside the class; at notebook level use `model.G`.
# The head alone maps 3 -> 512 channels (see the printed module above), so
# body and tail must run before the residual add can match the input.
# NOTE(review): assumes Generator.forward is head -> body -> tail, then a
# residual add of the input, then grid_to_list. TODO: confirm against
# src/models/generator.py.
# NOTE(review): if the model is still in train mode, BatchNorm running stats
# will be updated by this call even under no_grad.
with torch.no_grad():
    x = model.G.head(points)
    x = model.G.body(x)
    x = model.G.tail(x)
    x = x + points
    x = grid_to_list(x)