In [None]:
import torch

import pytorch_lightning as pl

from src.config import get_parser
#from src.models.gan import GAN
from src.callback.points_image import PointsImage
from src.data.masked_datamodule import MaskedDataModule

In [None]:
import time
from collections import OrderedDict

import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch3d.structures import Meshes
from pytorch3d.loss import ( 
    mesh_edge_loss, 
    mesh_laplacian_smoothing, 
    mesh_normal_consistency,
)

from src.models.discriminator import Discriminator
from src.models.generator import Generator
from src.render.points_renderer import PointsRenderer
from src.render.pulsar_renderer import PulsarRenderer
from src.loss.edge_loss import EdgeLoss




class GAN(pl.LightningModule):
    """Mesh GAN trained with a least-squares (MSE) adversarial objective.

    The generator ``G`` maps input points to mesh vertices. The vertices are
    rendered by ``R`` and judged by the image discriminator ``D`` against real
    image patches from the batch. An edge-length regulariser (``EdgeLoss``) is
    added to the generator loss.

    Uses Lightning's automatic optimization with two optimizers returned by
    ``configure_optimizers``: index 0 = generator, index 1 = discriminator.
    """

    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams
        # Automatic optimization is intentionally left at Lightning's default
        # (True). training_step branches on `optimizer_idx` and *returns* the
        # loss. That pattern only trains when Lightning performs
        # backward()/step() itself. With `automatic_optimization = False` the
        # returned losses were never applied.

        # Scalar normalisation constants: the channel-averaged mean/std used to
        # normalise renders before they are fed to the discriminator.
        self.mean = sum(hparams.image_mean) / len(hparams.image_mean)
        self.std = sum(hparams.image_std) / len(hparams.image_std)

        self.G = Generator(hparams)
        self.D = Discriminator(hparams)
        # The renderer needs a device; training_step binds it via
        # self.R.setup(device) at the start of every step.
        self.R = PointsRenderer(hparams)
        self.edge_loss = EdgeLoss(hparams)

    def forward(self, shape, style=None):
        """Run the generator.

        training_step calls ``self.G(points)`` with a single argument, so
        ``style`` is optional and is only forwarded when provided.
        """
        if style is None:
            return self.G(shape)
        return self.G(shape, style)

    def adversarial_loss(self, y_hat, y):
        """Least-squares GAN loss: mean squared error between D outputs and targets."""
        return F.mse_loss(y_hat, y)

    def _render_normalized(self, vertices):
        """Render ``vertices`` and normalise the images with self.mean / self.std.

        The result is permuted with (0, 3, 1, 2). This presumably turns the
        renderer's channels-last output into channels-first for ``D``
        (TODO confirm the renderer's output layout).
        """
        # Normals come from detached vertices, so no gradient flows through
        # the normal computation itself.
        calc_normals = self.R.vrt_nrm.vertex_normals_fast(vertices.detach())
        renders = self.R(vertices, normals=calc_normals).permute(0, 3, 1, 2)
        return (renders - self.mean) / self.std

    def _generator_step(self, points, bs):
        """Generator update: push D(render) towards the 'real' label, plus edge regularisation."""
        vertices = self.G(points)
        renders = self._render_normalized(vertices)

        # Target is all-real (ones): the generator wants D to accept its
        # renders. type_as places the new tensor on points' device/dtype.
        valid = torch.ones(bs, 1).type_as(points)
        render_loss = self.adversarial_loss(self.D(renders), valid)

        edge_loss = self.edge_loss(vertices)
        # The edge term used to be added twice (copy/paste), silently doubling
        # its effective weight. It is now counted once; raise
        # hparams.mesh_edge_loss_weight if the old strength is wanted.
        # TODO: a normal-consistency term (hparams.mesh_normal_consistency_weight,
        # against batch['normals']) is not wired in yet.
        g_loss = render_loss + edge_loss * self.hparams.mesh_edge_loss_weight

        self.log("loss/g_loss", g_loss, on_epoch=True, prog_bar=True)
        self.log("loss/render_loss", render_loss, on_epoch=True)
        self.log("loss/edge_loss", edge_loss, on_epoch=True)
        return OrderedDict({'loss': g_loss})

    def _discriminator_step(self, img_patch, points, bs):
        """Discriminator update: real patches -> 1, generator renders -> 0, averaged."""
        valid = torch.ones(bs, 1).type_as(points)
        # Real images for this batch are the image patches. The old code
        # referenced an undefined name `imgs` here, which raised NameError.
        # NOTE(review): assumes img_patch is already normalised like the renders
        # (image_mean / image_std) by the datamodule — confirm.
        real_loss = self.adversarial_loss(self.D(img_patch), valid)

        fake = torch.zeros(bs, 1).type_as(points)
        # G is not updated in this step and its renders were detached anyway,
        # so skip building the generator/renderer autograd graph.
        with torch.no_grad():
            renders = self._render_normalized(self.G(points))
        fake_loss = self.adversarial_loss(self.D(renders), fake)

        # Discriminator loss is the average of the real and fake terms.
        d_loss = (real_loss + fake_loss) / 2
        self.log("loss/d_loss", d_loss, on_step=True, on_epoch=True, prog_bar=True)
        return OrderedDict({'loss': d_loss})

    def training_step(self, batch, batch_idx, optimizer_idx):
        """One optimisation step for the optimizer selected by ``optimizer_idx``.

        Args:
            batch: mapping containing at least 'img_patch' and 'points'.
            batch_idx: index of the batch (unused).
            optimizer_idx: 0 trains the generator, 1 the discriminator
                (order of ``configure_optimizers``).

        Returns:
            ``OrderedDict({'loss': ...})`` for Lightning to backpropagate, or
            None for any other optimizer index.
        """
        img_patch = batch['img_patch']
        points = batch['points']
        bs = img_patch.size(0)
        self.R.setup(points.device)

        if optimizer_idx == 0:
            return self._generator_step(points, bs)
        if optimizer_idx == 1:
            return self._discriminator_step(img_patch, points, bs)
        return None

    def configure_optimizers(self):
        """Two Adam optimizers sharing (beta1, beta2): [generator, discriminator], no schedulers."""
        lr_g = self.hparams.lr_g
        lr_d = self.hparams.lr_d
        b1 = self.hparams.beta1
        b2 = self.hparams.beta2
        opt_g = torch.optim.Adam(self.G.parameters(),
                                 lr=lr_g, betas=(b1, b2))
        opt_d = torch.optim.Adam(self.D.parameters(),
                                 lr=lr_d, betas=(b1, b2))
        return [opt_g, opt_d], []
    
# Build the run configuration from the project parser's defaults. Passing an
# explicit empty list stops parse_args from reading the kernel's sys.argv
# (get_parser presumably returns an argparse parser — confirm in src/config).
config = get_parser().parse_args([])
model = GAN(hparams=config)
model  # last expression: display the module summary

In [None]:
# Instantiate the project datamodule from the shared config and call setup()
# eagerly, before training; the bare `dm` on the last line displays it.
dm = MaskedDataModule(config)
dm.setup()
dm

In [None]:
# Short single-GPU training run. PointsImage is a project callback
# (presumably logs rendered point images — confirm in src/callback/points_image.py).
trainer_kwargs = dict(
    gpus=1,
    max_epochs=5,
    progress_bar_refresh_rate=20,
    callbacks=[PointsImage(config)],
)
trainer = pl.Trainer(**trainer_kwargs)
trainer.fit(model, dm)