In [1]:
import torch

import pytorch_lightning as pl

from src.config import get_parser
#from src.models.gan import GAN
from src.callback.points_image import PointsImage
from src.data.masked_datamodule import MaskedDataModule

In [2]:
import time
from collections import OrderedDict

import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch3d.structures import Meshes
from pytorch3d.loss import ( 
    mesh_edge_loss, 
    mesh_laplacian_smoothing, 
    mesh_normal_consistency,
)

from src.models.discriminator import Discriminator
from src.models.generator import Generator
from src.render.mesh_renderer import MeshPointsRenderer
from src.loss.edge_loss import EdgeLoss
from src.utilities.util import grid_to_list

class GAN(pl.LightningModule):
    """Regresses generator output vertices onto the input point cloud.

    Despite the name, no discriminator is used here: the only trained network is
    ``self.G`` and the objective is an MSE between the generated vertices and a
    noise-perturbed copy of the input points.

    Args:
        hparams: configuration namespace. This class reads ``G_noise_amp``,
            ``lr_g``, ``beta1`` and ``beta2`` from it; the whole object is also
            forwarded to ``Generator`` and ``EdgeLoss``.
    """

    def __init__(self, hparams):
        super().__init__()
        # Stores the namespace as ``self.hparams`` (attribute access keeps working)
        # and, unlike ``self.hparams = hparams``, is supported on newer Lightning
        # versions where the hparams setter was removed.
        self.save_hyperparameters(hparams)
        # Automatic optimization is deliberately left ON: training_step returns the
        # loss and configure_optimizers returns a single optimizer. Setting
        # ``automatic_optimization = False`` without calling manual_backward/step
        # means Lightning never updates the generator's weights.

        self.G = Generator(hparams)
        # Edge-length regularizer; constructed but currently not part of g_loss.
        self.edge_loss = EdgeLoss(hparams)

    def forward(self, shape, style=None):
        """Run the generator.

        ``style`` is optional because training calls the generator with the point
        batch alone; it is only forwarded when given.
        """
        if style is None:
            return self.G(shape)
        return self.G(shape, style)

    def adversarial_loss(self, y_hat, y):
        """Mean-squared error between prediction and target.

        The name is kept for compatibility with existing callers; this is a plain
        reconstruction loss, not an adversarial one.
        """
        return F.mse_loss(y_hat, y)

    def training_step(self, batch, batch_idx):
        """Compute the vertex reconstruction loss for one batch.

        Returns:
            dict with key ``'loss'`` (scalar tensor), which Lightning's automatic
            optimization backpropagates and steps on.
        """
        points = batch['points']
        vertices = self.G(points)

        # Noise is applied to the regression *target*, not to the generator input.
        target = points + torch.randn_like(points) * self.hparams.G_noise_amp
        vrt_loss = self.adversarial_loss(vertices, target.reshape_as(vertices))

        # Edge-length (mesh_edge_loss_weight) and normal-consistency
        # (mesh_normal_consistency_weight) regularizers are currently disabled,
        # so the total loss is the vertex loss alone.
        g_loss = vrt_loss

        self.log("loss/g_loss", g_loss, prog_bar=True)
        self.log("loss/vrt_loss", vrt_loss)
        return {'loss': g_loss}

    def configure_optimizers(self):
        """Adam over the generator parameters with ``lr_g`` and ``(beta1, beta2)``."""
        return torch.optim.Adam(
            self.G.parameters(),
            lr=self.hparams.lr_g,
            betas=(self.hparams.beta1, self.hparams.beta2),
        )
    
# Seed Python/NumPy/torch RNGs before the model is built so the generator's weight
# initialisation and the random noise drawn during training are repeatable.
SEED = 42
pl.seed_everything(SEED)

# Start from the parser defaults (no CLI args inside a notebook), then override.
config = get_parser().parse_args(args=[])
config.batch_size = 32
config.data_patch_size = 96
config.blueprint = 'blueprint_16_512.npz'
config.data_blueprint_size = 2048
config.lr_g = 0.0003
config.G_out_ch = 48  # parser default noted as 256

model = GAN(config)
model

GAN(
  (G): Generator(
    (head): ConvBlock(
      (conv): Conv2d(3, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (norm): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (LeakyRelu): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (body): Sequential(
      (b1): ConvBlock(
        (conv): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (norm): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (LeakyRelu): LeakyReLU(negative_slope=0.2, inplace=True)
      )
      (b2): ConvBlock(
        (conv): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (norm): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (LeakyRelu): LeakyReLU(negative_slope=0.2, inplace=True)
      )
      (b3): ConvBlock(
        (conv): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (norm): BatchNorm2d(48, eps

In [3]:
# Build the data module from the same config as the model.
dm = MaskedDataModule(config)
# Set up eagerly so the dataloaders can be used directly in later cells.
dm.setup()
dm  # shown as the cell output

<src.data.masked_datamodule.MaskedDataModule at 0x7fa50fd6dbb0>

In [None]:
# Use one GPU when CUDA is available; otherwise fall back to CPU instead of raising,
# so the notebook still runs end-to-end on machines without a GPU.
trainer = pl.Trainer(
    gpus=1 if torch.cuda.is_available() else None,
    progress_bar_refresh_rate=20,
)
trainer.fit(model, dm)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores

  | Name      | Type      | Params
----------------------------------------
0 | G         | Generator | 65.4 K
1 | edge_loss | EdgeLoss  | 0     
----------------------------------------
65.4 K    Trainable params
0         Non-trainable params
65.4 K    Total params
0.262     Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

In [None]:
# Show tensors with fixed-point notation and 8 decimals for inspection.
torch.set_printoptions(precision=8, sci_mode=False)

# Pull one training batch and build a noisy copy of its points, using the same
# G_noise_amp scaling that the training step applies to its target.
train_loader = dm.train_dataloader()
batch = next(iter(train_loader))
points = batch['points']
noise = torch.randn_like(points) * config.G_noise_amp
points_noise = points + noise
points.shape

In [None]:
# MSE between the clean points and their noisy copy: roughly the vertex loss a
# generator that reproduced the input exactly would see (expected ~ G_noise_amp ** 2).
noise_floor_mse = model.adversarial_loss(points, points_noise)
noise_floor_mse