In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import pytorch_lightning as pl

from src.utilities.operators import mean_2d_3ch
from src.data.image_patch import ImagePatchDataModule
from src.loss.layer_loss import LayerLoss
from src.render.mesh_points_renderer import MeshPointsRenderer
from src.utilities.util import (
    make_faces, 
    list_to_grid, 
    grid_to_list,
)
from src.config import get_parser

from torchvision.transforms import ToPILImage

In [6]:
def create_model(no_ch, n_hidden=0, in_ch=3, out_ch=3):
    """Build a small fully-convolutional net that preserves spatial size.

    Layout: Conv(in_ch -> no_ch) + ReLU, then ``n_hidden`` extra
    Conv(no_ch -> no_ch) + ReLU pairs, then Conv(no_ch -> out_ch) + Tanh.
    Every conv is 3x3, stride 1, padding 1 with replicate padding, so H and W
    are unchanged and the output lies in [-1, 1] (Tanh).

    Args:
        no_ch: number of hidden channels.
        n_hidden: number of extra hidden conv layers. The default 0 gives the
            original two-conv network, with identical layer indices (and hence
            identical state_dict keys).
        in_ch: number of input channels.
        out_ch: number of output channels.

    Returns:
        nn.Sequential implementing the layout above.
    """
    def conv(cin, cout):
        return nn.Conv2d(cin, cout, 3, stride=1, padding=1,
                         padding_mode='replicate')

    layers = [conv(in_ch, no_ch), nn.ReLU()]
    # Replaces the previously commented-out extra hidden layer with a knob.
    for _ in range(n_hidden):
        layers += [conv(no_ch, no_ch), nn.ReLU()]
    layers += [conv(no_ch, out_ch), nn.Tanh()]
    return nn.Sequential(*layers)

In [7]:
class RenderToSurface(pl.LightningModule):
    """Fit a surface patch whose rendering matches target image patches.

    A fixed vertex grid (loaded from ``patch_file``) is fed through two small
    conv nets: ``v_model`` predicts vertex positions and ``c_model`` predicts
    per-vertex colors. The result is rendered with ``MeshPointsRenderer`` and
    compared against the training batch with an AlexNet ``LayerLoss``.

    Args:
        hparams: config namespace; saved as hyperparameters and used to build
            the renderer.
        patch_file: path to a ``torch.save``-d dict holding a 4-D
            ``'vertices'`` tensor (bilinear ``F.interpolate`` requires NCHW).
        patch_size: spatial size the vertex grid is resized to.
        no_channels: hidden width of ``v_model`` and ``c_model``.
        mean_size: window size passed to ``mean_2d_3ch``.
        noise_lvl: std of the Gaussian jitter added to the vertices each step.
        lr: Adam learning rate.
    """

    def __init__(self, hparams, patch_file, patch_size, no_channels, mean_size,
                 noise_lvl=0.0003, lr=1e-4):
        super().__init__()
        self.save_hyperparameters(hparams)

        self.noise_lvl = noise_lvl
        self.lr = lr
        # Build the renderer from the config passed in -- not the notebook-global
        # `config` -- so the module does not depend on hidden kernel state.
        self.renderer = MeshPointsRenderer(hparams)

        self.patch_file = patch_file
        # map_location='cpu' lets a patch saved from a GPU session load on any
        # machine; Lightning moves registered buffers to the training device.
        # NOTE: torch.load unpickles the file -- only load trusted patch files.
        patch = torch.load(patch_file, map_location='cpu')
        vertices = F.interpolate(patch['vertices'], size=patch_size,
                                 mode='bilinear', align_corners=True)
        self.register_buffer('vertices', vertices)

        self.no_channels = no_channels
        self.v_model = create_model(no_channels)   # vertices -> new vertex positions
        self.c_model = create_model(no_channels)   # vertices -> per-vertex colors

        # Layout is NCHW, so these are (H, W). They are passed to make_faces in
        # the same positional order as before (the original named them w, h).
        # NOTE(review): confirm make_faces' argument order for non-square patches.
        _, _, rows, cols = vertices.shape
        self.register_buffer('faces', torch.tensor(make_faces(rows, cols)[None]))

        # Perceptual loss on the first two AlexNet ReLU stages only.
        self.loss_fn = LayerLoss('alex', layers=['relu1', 'relu2'])
        self.mean2d3ch = mean_2d_3ch(mean_size)

    def training_step(self, batch, batch_idx):
        """Render the predicted surface once per target image and return the loss.

        Args:
            batch: target images, compared against the renders by ``loss_fn``.
                NOTE(review): presumably (B, 3, H, W) in [-1, 1] to match the
                clamped renders -- confirm against ImagePatchDataModule.
            batch_idx: unused.

        Returns:
            Scalar perceptual loss.
        """
        bs = batch.size(0)
        faces = self.faces.expand(bs, -1, -1)
        vert = self.vertices.expand(bs, -1, -1, -1)
        # Jitter the otherwise constant input surface on every step.
        vert = vert + torch.randn_like(vert) * self.noise_lvl
        mean_orig = self.mean2d3ch(vert)

        res = self.v_model(vert)
        # Replace the mean_2d_3ch-filtered component of the prediction with
        # that of the (noisy) input vertices, presumably so the net mainly
        # shapes the detail the filter removes.
        res = res + mean_orig - self.mean2d3ch(res)
        colors = self.c_model(vert)

        imgs = self.renderer(res, faces, colors=colors, mean=0.5, std=0.5, grayscale=False)
        imgs = torch.clamp(imgs, -1, 1)

        loss = self.loss_fn(imgs, batch)
        # Log the tensor directly: avoids the forced device sync of .item().
        self.log("train/loss", loss)
        return loss

    def configure_optimizers(self):
        """Adam (default betas) over the parameters of v_model and c_model only.

        The loss network and the mean_2d_3ch filter are not handed to the
        optimizer.
        """
        params = list(self.v_model.parameters()) + list(self.c_model.parameters())
        opt_g = torch.optim.Adam(params, lr=self.lr)
        return [opt_g], []

# Seed python/numpy/torch: weight init and the per-step vertex noise are random,
# so without this a Restart & Run All does not reproduce the run.
SEED = 0
pl.seed_everything(SEED)

config = get_parser().parse_args(args=[])   # project defaults; no CLI args in a notebook
config.viewpoint_distance = 1.05
config.fast_image_size = 32
config.batch_size = 4

patch_file = './data/patch_64_64.pth'   # torch.save-d dict with a 'vertices' tensor
patch_size = 256        # vertex grid is resized to patch_size x patch_size
no_channels = 128       # hidden width of v_model / c_model
mean_size = 33          # window size passed to mean_2d_3ch
img_file = './data/skin0.png'   # source image for ImagePatchDataModule

model = RenderToSurface(config, patch_file, patch_size, no_channels, mean_size)
model

Setting up [baseline] perceptual loss: trunk [alex], v[0.1], spatial [off]


RenderToSurface(
  (renderer): MeshPointsRenderer()
  (v_model): Sequential(
    (0): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode=replicate)
    (1): ReLU()
    (2): Conv2d(128, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode=replicate)
    (3): Tanh()
  )
  (c_model): Sequential(
    (0): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode=replicate)
    (1): ReLU()
    (2): Conv2d(128, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode=replicate)
    (3): Tanh()
  )
  (loss_fn): LayerLoss(
    (net): alexnet(
      (slice1): Sequential(
        (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
        (1): ReLU(inplace=True)
      )
      (slice2): Sequential(
        (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
        (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
        (4): ReLU(inplace=True)
      )
  

In [8]:
# Data module supplying the target batches that training_step compares the
# renders against. (Its default repr only shows a memory address, so it is
# not displayed.)
dm = ImagePatchDataModule(config, img_file)

<src.data.image_patch.ImagePatchDataModule at 0x7f74f1c12a60>

In [9]:
# Use a GPU when one is visible; otherwise fall back to CPU (gpus=0) so that
# Restart & Run All also works on CPU-only machines.
n_gpus = 1 if torch.cuda.is_available() else 0
trainer = pl.Trainer(
    gpus=n_gpus,
    max_epochs=3000,
    precision=32,
)
trainer.fit(model, dm)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type               | Params
-------------------------------------------------
0 | renderer  | MeshPointsRenderer | 0     
1 | v_model   | Sequential         | 7.0 K 
2 | c_model   | Sequential         | 7.0 K 
3 | loss_fn   | LayerLoss          | 2.5 M 
4 | mean2d3ch | Conv2d             | 9.8 K 
-------------------------------------------------
14.1 K    Trainable params
2.5 M     Non-trainable params
2.5 M     Total params
9.974     Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…