In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torchvision.utils import make_grid
import pytorch_lightning as pl

from src.render.mesh_renderer import MeshPointsRenderer
from src.models.generator import Generator
from src.augment.diffaug import DiffAugment
from src.augment.geoaug import GeoAugment
from src.callback.image_mesh import ImageMesh

from torchvision import transforms
from PIL import Image

from src.data.fast_datamodule import FastDataModule
from src.data.baseline_dataset import BaselineDataset
from src.utilities.operators import make_laplacian

def get_img_t(file, config):
    """Load an image file and return it as a normalized tensor with a leading batch axis.

    Args:
        file: Path (or file object) accepted by ``PIL.Image.open``.
        config: Object providing ``fast_image_size``, ``fast_image_mean``
            and ``fast_image_std``.

    Returns:
        The resized, grayscale, normalized image tensor indexed with
        ``[None]``, i.e. with one extra leading dimension.
    """
    size = config.fast_image_size
    transform = transforms.Compose([
        transforms.Resize([size, size]),
        transforms.Grayscale(),
        transforms.ToTensor(),
        transforms.Normalize(mean=config.fast_image_mean, std=config.fast_image_std),
    ])
    # Image.open is lazy; the context manager guarantees the underlying file
    # handle is released once the pixels have been converted to a tensor.
    with Image.open(file) as img:
        return transform(img)[None]

# Directory holding the reference photographs, one per viewpoint.
# NOTE(review): absolute, machine-specific path - consider making it configurable.
file_root = '/home/bobi/Desktop/pic2mesh/data/stan_lee'

# View tag -> reference image. Tags follow the renderer names in RSP
# (p = positive azimuth, n = negative azimuth, number = degrees).
files = {
    'p00': os.path.join(file_root, 'stan_lee_p00.png'),
    'p45': os.path.join(file_root, 'stan_lee_p45.png'),
    'n45': os.path.join(file_root, 'stan_lee_n45.png'),
    'p90': os.path.join(file_root, 'stan_lee_p90.png'),
    # Previously pointed at the n45 image (copy-paste slip), so the -90 degree
    # render was being matched against the -45 degree photograph.
    'n90': os.path.join(file_root, 'stan_lee_n90.png'),
}



class RSP(pl.LightningModule):
    """Optimizes a learnable point grid so its renders match five reference images.

    The module owns one ``MeshPointsRenderer`` per viewpoint (azimuth as given
    in ``hparams``, +45, -45, +90, -90) and the matching target images, loaded
    from the module-level ``files`` mapping. Only ``self.points`` is handed to
    the optimizer; ``self.G`` is constructed and reachable through ``forward``
    but is not used by ``training_step``.
    """

    # View tags in the order used for rendering, loss summation and logging.
    VIEWS = ('p00', 'p45', 'n45', 'p90', 'n90')

    def __init__(self, hparams):
        """Build renderers, target buffers and the learnable point grid.

        Args:
            hparams: Config object; reads ``fast_image_mean``, ``fast_image_std``,
                ``diffaug_policy``, ``G_noise_amp``, ``log_render_interval`` and
                ``viewpoint_azimuth``, and is forwarded to ``Generator``,
                ``MeshPointsRenderer``, ``get_img_t`` and ``BaselineDataset``.
        """
        super().__init__()
        self.hparams = hparams
        self.mean = hparams.fast_image_mean
        self.std = hparams.fast_image_std
        self.diffaug = hparams.diffaug_policy
        self.G_noise_amp = hparams.G_noise_amp
        self.log_render_interval = hparams.log_render_interval

        self.G = Generator(hparams)

        # Each renderer is configured through hparams.viewpoint_azimuth at
        # construction time. Restore the caller's value afterwards so the
        # shared config is not left at -90 (which would make a second
        # RSP(hparams) build its frontal renderer at -90).
        original_azimuth = hparams.viewpoint_azimuth
        try:
            self.Rp00 = MeshPointsRenderer(hparams)
            hparams.viewpoint_azimuth = +45
            self.Rp45 = MeshPointsRenderer(hparams)
            hparams.viewpoint_azimuth = -45
            self.Rn45 = MeshPointsRenderer(hparams)
            hparams.viewpoint_azimuth = +90
            self.Rp90 = MeshPointsRenderer(hparams)
            hparams.viewpoint_azimuth = -90
            self.Rn90 = MeshPointsRenderer(hparams)
        finally:
            hparams.viewpoint_azimuth = original_azimuth

        # Target images are buffers (tp00, tp45, ...) so they follow the
        # module across devices without being trained.
        for view in self.VIEWS:
            self.register_buffer(f't{view}', get_img_t(files[view], hparams))

        # Only updated in training_epoch_end; not read anywhere else in this class.
        self.ratio = 0.925
        # Blur is currently the identity; both renders and targets pass through it.
        #self.blur = transforms.GaussianBlur(17, 3)
        self.blur = lambda x: x

        # The first dataset item (with a batch axis added) is the initial,
        # trainable point grid; colors are a constant all-ones buffer.
        ds = BaselineDataset(hparams)
        self.points = nn.Parameter(ds[0][None])
        self.register_buffer('colors', torch.ones_like(self.points))

        self.laplacian = make_laplacian(3)

    def forward(self, baseline):
        """Run the generator ``self.G`` on ``baseline`` and return its output."""
        return self.G(baseline)

    def laplacian_loss(self, t):
        """Return the L1 distance between ``t`` and its filtered, detached copy.

        The filtered target is computed from ``t.detach()``, so gradients flow
        only through ``t`` itself.
        """
        return F.l1_loss(t, self.laplacian(t.detach()))

    def render(self, points, colors, mean=None, std=None):
        """Render ``points``/``colors`` from every viewpoint.

        Returns:
            Dict mapping each tag in ``VIEWS`` to the output of that view's renderer.
        """
        return {
            view: getattr(self, f'R{view}')(points, colors, mean, std)
            for view in self.VIEWS
        }

    def loss(self, renders):
        """Sum the per-view MSE between renders and target images, logging each term.

        Args:
            renders: Dict as returned by ``render``.

        Returns:
            Scalar tensor: sum of the five per-view MSE losses.
        """
        bs = renders['p00'].size(0)

        view_losses = []
        for view in self.VIEWS:
            # Targets have batch size 1; expand them to the render batch size.
            target = getattr(self, f't{view}').expand(bs, -1, -1, -1)
            view_losses.append(
                F.mse_loss(self.blur(renders[view]), self.blur(target))
            )

        for view, view_loss in zip(self.VIEWS, view_losses):
            self.log(f"loss/l{view}", view_loss.item())

        render_loss = view_losses[0]
        for view_loss in view_losses[1:]:
            render_loss = render_loss + view_loss
        self.log("loss/render_loss", render_loss.item())
        return render_loss

    def log_renders(self, points, colors, batch_idx):
        """Every ``log_render_interval`` batches, log a grid of all views as an image."""
        if batch_idx % self.log_render_interval == 0:
            renders = self.render(points.detach(), colors.detach())
            stacked = torch.cat(list(renders.values()))
            grid = make_grid(tensor=stacked, nrow=len(renders))
            self.logger.experiment.add_image('renders', grid, self.global_step)

    def training_step(self, batch, batch_idx):
        """Compute render loss plus laplacian loss for the learnable points.

        ``batch`` is not used; the optimized quantity is ``self.points``.
        """
        points, colors = self.points, self.colors
        laplacian_loss = self.laplacian_loss(points)
        self.log("loss/laplacian_loss", laplacian_loss.item())
        self.log_renders(points, colors, batch_idx)
        renders = self.render(points, colors, self.mean, self.std)
        loss = self.loss(renders) + laplacian_loss
        self.log("loss/loss", loss.item())
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over ``self.points`` only."""
        return torch.optim.Adam([self.points], lr=0.0003)

    def training_epoch_end(self, training_step_outputs):
        """Raise ``self.ratio`` by 0.025 per epoch, capped at 1."""
        self.ratio = min(self.ratio + 0.025, 1)
    
from src.config import get_parser

# Start from the parser defaults (args=[] means no command-line arguments are
# parsed), then override the fields this experiment changes.
config = get_parser().parse_args(args=[])
config.fast_baseline_size = 64
config.fast_image_size = 64
config.viewpoint_distance = 3.25
config.G_noise_amp = 0.001
# Constructing RSP reads the reference images listed in `files`.
rsp = RSP(config)
#rsp

In [2]:
# Set the batch size on the shared config, then build the data module on
# BaselineDataset. The bare `dm` displays its repr as the cell output.
config.fast_batch_size = 1
dm = FastDataModule(config, BaselineDataset)    
dm

<src.data.fast_datamodule.FastDataModule at 0x7feb32bd3af0>

In [3]:
# pyright: reportMissingImports=false
import os

import torch
import torch.nn.functional as F
import torchvision
import trimesh
import pytorch_lightning as pl

from src.utilities.util import (
    grid_to_list,
    make_faces,
)

class ImageMesh(pl.callbacks.Callback):
    """Callback that periodically exports the module's point grid as an STL mesh.

    Every ``log_batch_interval`` training batches the current
    ``pl_module.points`` is converted to a vertex list, combined with a cached
    face index array and written to ``<trainer.log_dir>/mesh/``.
    """

    def __init__(self, opt):
        """Store logging options.

        Args:
            opt: Config object providing ``log_grid_samples``, ``log_grid_rows``,
                ``log_grid_padding``, ``log_pad_value`` and ``log_batch_interval``.
                Only ``log_batch_interval`` is used by the methods of this class.
        """
        super().__init__()
        self.num_samples = opt.log_grid_samples
        self.nrow = opt.log_grid_rows
        self.padding = opt.log_grid_padding
        self.pad_value = opt.log_pad_value
        self.log_batch_interval = opt.log_batch_interval
        # Face indices depend only on the grid size, so they are built once.
        self.faces = None

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        """Export the current mesh every ``log_batch_interval`` batches.

        Export is best-effort: failures are reported on stdout and training
        continues. Only ``Exception`` subclasses are caught, so
        ``KeyboardInterrupt`` still stops the run.
        """
        # Use the hook's own batch index rather than a trainer attribute.
        if (batch_idx % self.log_batch_interval) != 0:
            return

        with torch.no_grad():
            # The mesh comes straight from the learnable parameter; no batch
            # or forward pass is needed, so none is fetched.
            points = pl_module.points.detach()

            try:
                if self.faces is None:
                    self.faces = make_faces(points.size(-2), points.size(-1))
                vertices = grid_to_list(points)[0].cpu().numpy()
                mesh = trimesh.Trimesh(vertices=vertices, faces=self.faces)
                mesh_dir = os.path.join(trainer.log_dir, 'mesh')
                os.makedirs(mesh_dir, exist_ok=True)
                file_path = os.path.join(mesh_dir, f'mesh_{trainer.current_epoch}_{trainer.global_step}.stl')
                mesh.export(file_path)
            except Exception as exc:
                # Include the actual error so a failed export can be diagnosed.
                print('Exception', points.shape, repr(exc))
        

In [None]:
# Train `rsp` on one GPU for at most 100 epochs, stopping on NaN, with the
# ImageMesh callback attached.
# NOTE(review): `ImageMesh` is both imported in the first cell and redefined in
# the cell above; whichever ran last in the kernel is the one used here.
trainer = pl.Trainer(gpus=1, max_epochs=100, progress_bar_refresh_rate=20,
                     terminate_on_nan=True, callbacks=[ImageMesh(config)])
trainer.fit(rsp, dm)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores

  | Name      | Type               | Params
-------------------------------------------------
0 | G         | Generator          | 1.0 M 
1 | Rp00      | MeshPointsRenderer | 0     
2 | Rp45      | MeshPointsRenderer | 0     
3 | Rn45      | MeshPointsRenderer | 0     
4 | Rp90      | MeshPointsRenderer | 0     
5 | Rn90      | MeshPointsRenderer | 0     
6 | laplacian | Conv2d             | 225   
-------------------------------------------------
1.1 M     Trainable params
225       Non-trainable params
1.1 M     Total params
4.226     Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

In [None]:
# Scratch: a bias-free 3->3 channel, 3x3 convolution with padding 1, shown as
# the cell output.
dist = nn.Conv2d(3, 3, 3, stride=1, padding=1, bias=False)
dist

In [None]:
# Inspect the weight tensor shape of the scratch convolution above.
dist.weight.data.shape

In [None]:
def make_laplacian(channels=3):
    """Build a frozen per-channel 8-neighbour averaging convolution.

    Each output channel is the mean of the eight neighbours of the same input
    channel (centre weight 0, every neighbour 0.125); channels are not mixed.
    Borders use replicate padding, so the output has the input's spatial size.

    The ``channels`` argument keeps this notebook-local definition callable as
    ``make_laplacian(3)``, the way ``RSP.__init__`` calls the imported helper
    of the same name that this definition shadows once the cell has run.

    Args:
        channels: Number of input/output channels. Defaults to 3.

    Returns:
        ``nn.Conv2d`` whose parameters have ``requires_grad`` set to False.
    """
    hood = torch.tensor([[0.125, 0.125, 0.125],
                         [0.125, 0.000, 0.125],
                         [0.125, 0.125, 0.125]])

    # weights[out, in] is the neighbourhood kernel when out == in, zero otherwise.
    weights = torch.zeros(channels, channels, 3, 3)
    for channel in range(channels):
        weights[channel, channel] = hood

    res = nn.Conv2d(channels, channels, 3, stride=1, padding=1,
        bias=False, padding_mode='replicate')
    res.requires_grad_(False)
    res.weight.data = weights
    return res

# Build the filter with the notebook-local make_laplacian and display its weights.
laplacian = make_laplacian()
laplacian.weight

In [None]:
# Alias `rsp` as `self` - presumably so an expression copied from an RSP
# method can be evaluated at top level.
self = rsp
# Mean absolute difference between the point grid and its filtered version.
torch.abs(self.points - self.laplacian(self.points)).mean()

In [None]:
import trimesh
from src.utilities.util import (
    grid_to_list,
    make_faces,
)
# Export the first baseline item and its filtered version as STL meshes for
# visual comparison.
# NOTE: this mutates the shared `config` and relies on `laplacian` from an
# earlier cell.
sz = 64
config.fast_baseline_size = sz
config.G_noise_amp = 0.001
faces = make_faces(sz, sz)
ds = BaselineDataset(config)
# First dataset item with a leading batch axis added.
base_vert = ds[0][None]
smooth_vert = laplacian(base_vert)

mesh = trimesh.Trimesh(vertices=grid_to_list(base_vert)[0], faces=faces)
mesh.export('./base.stl')


mesh = trimesh.Trimesh(vertices=grid_to_list(smooth_vert)[0], faces=faces)
# Trailing semicolon suppresses the cell output of export().
mesh.export('./smooth.stl');

In [None]:
# Shape of the vertex list for the first (only) batch element.
grid_to_list(base_vert)[0].shape

In [None]:
# Largest absolute per-element change introduced by the filter.
torch.abs(base_vert -smooth_vert).max()

In [30]:
# Shape check: 2x2 average pooling of a (1, 3, 8, 8) tensor.
F.avg_pool2d(torch.rand(1, 3, 8, 8), 2).shape

torch.Size([1, 3, 4, 4])

In [35]:
# Scratch arithmetic for normalized kernel weights - presumably for a 5x5
# neighbourhood (rows/columns 1..5 sketched below); TODO confirm.
# 1 2 3 4 5
# 2
# 3
# 4
# 5
# 16 cells get weight x
#  8 cells get weight 2 * x
# 16 * x + 8 * 2 * x =  1
# 32x  = 1
# x= 0.03125
# 2x = 0.0625
# Check that the two weight groups sum to 1:
0.03125 * 16 +  0.0625 * 8

1.0

In [41]:

# NOTE(review): `res` and `weights` are not defined at top level in any cell
# visible here (only as locals inside make_laplacian), so this cell depends on
# kernel state from a cell that is not shown.
# Compare the shapes, then overwrite the convolution's weights.
print(res.weight.data.shape, weights.shape)
res.weight.data = weights

torch.Size([3, 3, 5, 5]) torch.Size([3, 3, 5, 5])


In [42]:
# Shape check: output shape of `res` for a (1, 3, 8, 8) input.
res(torch.rand(1, 3, 8, 8)).shape

torch.Size([1, 3, 8, 8])