# **Deep Learning on 3D Meshes**

In [None]:
# Pick the GPU for this notebook (run `nvidia-smi` in a shell to list devices).
import os
# Sets CUDA_VISIBLE_DEVICES to "3" for this process, overriding any value that
# was already present in the environment.
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

In [None]:
from time import sleep
from pathlib import Path
from itertools import tee
from functools import lru_cache

# import trimesh
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, remove_self_loops
from torch_geometric.transforms import BaseTransform, Compose, FaceToEdge
from torch_geometric.data import Data, InMemoryDataset, extract_zip, DataLoader

In [None]:
import multiprocessing
import time
import warnings

from torch_scatter import scatter
from torch_geometric.loader import DataLoader
from utils.preprocessing_for_tri_mesh import create_data
from utils.preprocessing_for_mesh import create_mesh

In [None]:
# Run on the GPU when PyTorch reports one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# load image file paths
svg_folder = './datasets/svg'
png_folder = './datasets/png'
dataset = []


def collect_image_paths(svg_dir, png_dir):
    """Walk ``svg_dir`` and return index-aligned lists of SVG and PNG paths.

    A file is kept when its extension is exactly ``.svg`` and its name does
    not contain ``checkpoint`` (the filter the original loop applied).

    Args:
        svg_dir: Directory that is walked recursively for ``.svg`` files.
        png_dir: Directory the matching ``.png`` path is built in.

    Returns:
        ``(svg_paths, png_paths)`` where ``png_paths[i]`` belongs to
        ``svg_paths[i]``. Both are empty if ``svg_dir`` does not exist.
    """
    svg_paths, png_paths = [], []
    for root, _, files in os.walk(svg_dir):
        for file in files:
            # splitext copes with names that have no dot or several dots,
            # where indexing the result of split('.') would fail or misfire.
            stem, ext = os.path.splitext(file)
            if ext != '.svg':
                continue
            if 'checkpoint' in file:
                continue

            # Join with `root` (not the top folder) so files found in
            # sub-directories resolve to a path that actually exists.
            svg_paths.append(os.path.join(root, file))

            # Swap only the extension; a plain str.replace would also rewrite
            # an "svg" that appears inside the file stem.
            # NOTE(review): PNGs are looked up flat in png_dir, as before —
            # confirm whether the PNG folder mirrors SVG sub-directories.
            png_paths.append(os.path.join(png_dir, stem + '.png'))
    return svg_paths, png_paths


imgs, png = collect_image_paths(svg_folder, png_folder)

In [None]:
# for i, file_path in enumerate(tqdm(imgs)):
#     # try:
#     #     dataset.append(create_mesh(file_path))
#     # except:
#     #     print(file_path) 
        
#     file_path = "./datasets/svg/032-firewood.svg"
#     data = create_mesh(file_path)
#     print(data)
#     break

In [None]:
# Cap the number of samples; truncate `png` as well so it stays index-aligned
# with `imgs` when both are sliced into train/val/test later on.
imgs = imgs[:2000]
png = png[:2000]
warnings.filterwarnings("ignore")

# Convert every SVG into a mesh sample with 8 worker processes.
# `imap` (rather than `imap_unordered`) returns results in input order, so
# dataset[i] corresponds to imgs[i] / png[i]. The `with` block shuts the
# worker pool down once all results have been collected.
with multiprocessing.Pool(8) as pool:
    dataset = list(tqdm(pool.imap(create_mesh, imgs), total=len(imgs)))

In [8]:
# hyperparameters
torch.manual_seed(16)

batch_size = 1
num_features = 2  # (x, y)
num_output = 1  # (R, G, B) or h
num_epoch = 50

# split boundaries: first 80% for training, next 10% for validation, rest for testing
_train = int(0.8 * len(dataset))
_val = int(len(dataset) * 0.1) + _train
_test = len(dataset) - _val


def _split(items):
    """Cut a sequence into (train, val, test) parts at the _train/_val boundaries."""
    return items[:_train], items[_train:_val], items[_val:]


# create dataloader
train_set, val_set, test_set = _split(dataset)
train_svg, val_svg, test_svg = _split(imgs)
train_png, val_png, test_png = _split(png)

# only the training loader shuffles
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
val_loader, test_loader = (
    DataLoader(subset, batch_size=batch_size, shuffle=False)
    for subset in (val_set, test_set)
)

print(f"Training Data: {len(train_set)}\nValidation Data: {len(val_set)}\nTesting Data: {len(test_set)}")

Training Data: 41
Validation Data: 5
Testing Data: 6


# Dataset Definition

In [9]:
class NormalizeUnitSphere(BaseTransform):
    """Centre ``data.x`` on its mean and scale it so the largest row norm is 1.

    Samples whose ``x`` is ``None`` or has no elements are returned unchanged.
    """

    @staticmethod
    def _re_center(x):
        """Subtract the mean over dim 0 from every row of ``x``."""
        centroid = torch.mean(x, dim=0)
        return x - centroid

    @staticmethod
    def _re_scale_to_unit_length(x):
        """Divide ``x`` by its largest row norm.

        If that norm is zero (all rows are the zero vector) ``x`` is returned
        as is, because dividing would fill it with NaN.
        """
        max_dist = torch.max(torch.norm(x, dim=1))
        if max_dist == 0:
            return x
        return x / max_dist

    def __call__(self, data: Data):
        """Normalise ``data.x`` in place and return ``data``."""
        # torch.max raises on an empty tensor, so skip samples without rows.
        if data.x is not None and data.x.numel() > 0:
            data.x = self._re_scale_to_unit_length(self._re_center(data.x))

        return data

    def __repr__(self):
        return "{}()".format(self.__class__.__name__)

In [10]:
def pairwise(iterable):
    """Return an iterator over consecutive overlapping pairs.

    ``[s0, s1, s2, ...]`` gives ``(s0, s1), (s1, s2), ...``; inputs with fewer
    than two items give nothing.
    """
    leading, trailing = tee(iterable)
    # Drop the first item of the second copy so the two streams are offset by one.
    next(trailing, None)
    return zip(leading, trailing)

def get_conv_layers(channels: list, conv: MessagePassing, conv_params: dict):
    """Build one ``conv`` layer per consecutive pair of entries in ``channels``.

    Each layer is created as ``conv(in_ch, out_ch, **conv_params)``; the layers
    are returned as a plain list in channel order.
    """
    conv_layers = []
    for in_ch, out_ch in pairwise(channels):
        conv_layers.append(conv(in_ch, out_ch, **conv_params))
    return conv_layers

In [11]:
def get_mlp_layers(channels: list, activation, output_activation=nn.Identity):
    """Assemble an MLP as an ``nn.Sequential``.

    A ``nn.Linear`` is created for every consecutive pair in ``channels``.
    Every linear layer except the last is followed by ``activation()``; the
    last one is followed by ``output_activation()``. ``channels`` needs at
    least two entries, otherwise the unpacking below raises ``ValueError``.
    """
    *hidden_dims, (last_in, last_out) = pairwise(channels)

    modules = []
    for fan_in, fan_out in hidden_dims:
        modules.extend((nn.Linear(fan_in, fan_out), activation()))
    modules.extend((nn.Linear(last_in, last_out), output_activation()))

    return nn.Sequential(*modules)

In [12]:
class FeatureSteeredConvolution(MessagePassing):
    """Feature-steered graph convolution with mean aggregation.

    For every edge, ``num_heads`` soft assignment weights are computed from
    the endpoint features (softmax over heads). The neighbour feature is
    passed through ``num_heads`` linear maps and the results are blended with
    those weights. Messages are averaged per target node and an optional bias
    is added.

    Args:
        in_channels: Size of each input node feature.
        out_channels: Size of each output node feature.
        num_heads: Number of linear maps / assignment weights per edge.
        ensure_trans_invar: If True the assignment logits depend only on
            ``x_i - x_j`` (``u(x_i - x_j) + c``); otherwise a second
            projection ``v`` is created and the logits are
            ``u(x_i) + v(x_j) + c``.
        bias: Whether to add a learnable bias to the output.
        with_self_loops: Whether to replace existing self loops with exactly
            one self loop per node before propagating.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_heads: int,
        ensure_trans_invar: bool = True,
        bias: bool = True,
        with_self_loops: bool = True,
    ):
        super().__init__(aggr="mean")

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_heads = num_heads
        self.with_self_loops = with_self_loops

        # One output block of size out_channels per head, stored side by side.
        self.linear = torch.nn.Linear(
            in_features=in_channels,
            out_features=out_channels * num_heads,
            bias=False,
        )
        self.u = torch.nn.Linear(
            in_features=in_channels,
            out_features=num_heads,
            bias=False,
        )
        # Allocated uninitialised; filled in by reset_parameters().
        self.c = torch.nn.Parameter(torch.Tensor(num_heads))

        if not ensure_trans_invar:
            self.v = torch.nn.Linear(
                in_features=in_channels,
                out_features=num_heads,
                bias=False,
            )
        else:
            self.register_parameter("v", None)

        if bias:
            self.bias = torch.nn.Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter("bias", None)

        self.reset_parameters()

    def reset_parameters(self):
        """(Re-)initialise all learnable parameters."""
        torch.nn.init.uniform_(self.linear.weight)
        torch.nn.init.uniform_(self.u.weight)
        torch.nn.init.normal_(self.c, mean=0.0, std=0.1)
        if self.bias is not None:
            torch.nn.init.normal_(self.bias, mean=0.0, std=0.1)
        if self.v is not None:
            torch.nn.init.uniform_(self.v.weight)

    def forward(self, x, edge_index, edge_attr):
        """Run the convolution and return the updated node features.

        ``edge_attr`` is forwarded to ``propagate`` but ``message`` does not
        consume it.
        """
        if self.with_self_loops:
            # Normalise to exactly one self loop per node.
            edge_index, _ = remove_self_loops(edge_index)
            edge_index, _ = add_self_loops(edge_index=edge_index, num_nodes=x.shape[0])

        out = self.propagate(edge_index, x=x, edge_attr=edge_attr)
        return out if self.bias is None else out + self.bias

    def _compute_attention_weights(self, x_i, x_j):
        """Return per-edge assignment weights, softmax-normalised over heads.

        Raises:
            ValueError: If the feature size of ``x_j`` is not ``in_channels``.
        """
        if x_j.shape[-1] != self.in_channels:
            raise ValueError(
                f"Expected input features with {self.in_channels} channels."
                f" Instead received features with {x_j.shape[-1]} channels."
            )
        if self.v is None:
            attention_logits = self.u(x_i - x_j) + self.c
        else:
            # The second projection is registered as `v` in __init__; the
            # previous `self.b` did not exist and raised AttributeError.
            attention_logits = self.u(x_i) + self.v(x_j) + self.c
        # return F.relu(attention_logits)
        return F.softmax(attention_logits, dim=1)

    def message(self, x_i, x_j):
        """Blend the per-head projections of ``x_j`` with the assignment weights."""
        attention_weights = self._compute_attention_weights(x_i, x_j)
        x_j = self.linear(x_j).view(-1, self.num_heads, self.out_channels)
        return (attention_weights.view(-1, self.num_heads, 1) * x_j).sum(dim=1)

In [13]:
class GraphFeatureEncoder(torch.nn.Module):
    """Stack of ``FeatureSteeredConvolution`` layers.

    Every layer except the last is followed by ReLU and, when
    ``apply_batch_norm`` is set, by ``BatchNorm1d``. The last layer's output
    is returned without activation or normalisation.

    Args:
        in_features: Feature size fed to the first convolution.
        conv_channels: Output size of each convolution, in order. Any
            non-empty sequence is accepted.
        num_heads: Passed to every convolution.
        apply_batch_norm: Whether to batch-normalise after each non-final
            convolution.
        ensure_trans_invar: Passed to every convolution.
        bias: Passed to every convolution.
        with_self_loops: Passed to every convolution.
    """

    def __init__(
        self,
        in_features,
        conv_channels,
        num_heads,
        apply_batch_norm: bool = True,
        ensure_trans_invar: bool = True,
        bias: bool = True,
        with_self_loops: bool = True,
    ):
        super().__init__()

        conv_params = dict(
            num_heads=num_heads,
            ensure_trans_invar=ensure_trans_invar,
            bias=bias,
            with_self_loops=with_self_loops,
        )
        self.apply_batch_norm = apply_batch_norm

        # Batch norm is only applied after the non-final convolutions.
        *first_conv_channels, _ = conv_channels
        conv_layers = get_conv_layers(
            # Unpacking instead of list concatenation also accepts tuples.
            channels=[in_features, *conv_channels],
            conv=FeatureSteeredConvolution,
            conv_params=conv_params,
        )
        self.conv_layers = nn.ModuleList(conv_layers)

        if apply_batch_norm:
            self.batch_layers = nn.ModuleList(
                [nn.BatchNorm1d(channel) for channel in first_conv_channels]
            )
        else:
            # Placeholders keep the zip in forward() aligned with the layers.
            self.batch_layers = [None for _ in first_conv_channels]

    def forward(self, x, edge_index, edge_attr):
        """Apply all convolutions in order and return the final node features."""
        *first_conv_layers, final_conv_layer = self.conv_layers
        for conv_layer, batch_layer in zip(first_conv_layers, self.batch_layers):
            x = conv_layer(x, edge_index, edge_attr)
            x = F.relu(x)
            if batch_layer is not None:
                x = batch_layer(x)
        return final_conv_layer(x, edge_index, edge_attr)

In [14]:
class MeshSeg(torch.nn.Module):
    """Per-node MLP encoder, graph convolution stack, cluster pooling, MLP head.

    ``forward`` encodes ``data.x`` with an MLP, runs the graph encoder over
    ``data.edge_index``, mean-pools the node features by ``data.cluster`` and
    projects each pooled row to ``num_classes`` values.
    """

    def __init__(
        self,
        in_features,
        encoder_features,
        conv_channels,
        encoder_channels,
        decoder_channels,
        num_classes,
        num_heads,
        apply_batch_norm=True,
    ):
        super().__init__()

        # node-wise input MLP
        encoder_dims = [in_features] + encoder_channels
        self.input_encoder = get_mlp_layers(
            channels=encoder_dims,
            activation=nn.ReLU,
        )

        # graph convolution stack; its input size is given by encoder_features
        self.gnn = GraphFeatureEncoder(
            in_features=encoder_features,
            conv_channels=conv_channels,
            num_heads=num_heads,
            apply_batch_norm=apply_batch_norm,
        )

        # output head, starting from the width of the last convolution
        head_dims = [conv_channels[-1]] + decoder_channels + [num_classes]
        self.final_projection = get_mlp_layers(
            head_dims,
            activation=nn.ReLU,
        )

    def forward(self, data):
        """Return one ``num_classes``-sized prediction per cluster id in ``data.cluster``."""
        node_feats = self.input_encoder(data.x)
        node_feats = self.gnn(node_feats, data.edge_index, data.edge_attr)
        pooled = scatter(node_feats, data.cluster, dim=0, reduce='mean')
        return self.final_projection(pooled)

In [15]:
def train(net, train_data, optimizer, loss_fn, device):
    """Run one training epoch and return the mean loss per processed batch.

    Args:
        net: Model called as ``net(data)``; put into training mode.
        train_data: Iterable of batches exposing ``to(device)``, ``x`` and ``y``.
        optimizer: Optimizer stepped once per processed batch.
        loss_fn: Called as ``loss_fn(out, data.y)``.
        device: Device every batch is moved to.

    Returns:
        The summed batch loss divided by the number of processed batches.
        Batches whose ``x`` has zero rows are skipped and not counted. If no
        batch was processed at all, ``nan`` is returned instead of raising
        ``ZeroDivisionError``.
    """
    net.train()
    cumulative_loss = 0.0
    num_batches = 0
    for data in train_data:
        data = data.to(device)

        # skip samples without any node
        if data.x.shape[0] == 0:
            continue

        optimizer.zero_grad()
        out = net(data)

        # # for hsv
        # h = data.h.type(torch.LongTensor).to(device)
        # loss = loss_fn(out, h)

        # for rgb
        loss = loss_fn(out, data.y)

        # # for rgb w/ scatter
        # rgb = scatter(data.y, data.cluster, dim=0, reduce='mean')
        # loss = loss_fn(out, rgb)

        loss.backward()
        cumulative_loss += loss.item()
        optimizer.step()
        num_batches += 1

    if num_batches == 0:
        return float("nan")
    return cumulative_loss / num_batches

In [16]:
@torch.no_grad()
def test(net, train_data, loss_fn, device):
    """Evaluate ``net`` without gradients and return the mean loss per batch.

    Args:
        net: Model called as ``net(data)``; put into eval mode.
        train_data: Iterable of batches exposing ``to(device)``, ``x`` and
            ``y`` (despite the name, any split can be passed).
        loss_fn: Called as ``loss_fn(out, data.y)``.
        device: Device every batch is moved to.

    Returns:
        The summed batch loss divided by the number of processed batches.
        Batches whose ``x`` has zero rows are skipped and not counted. If no
        batch was processed at all (for example an empty validation split),
        ``nan`` is returned instead of raising ``ZeroDivisionError``.
    """
    net.eval()
    cumulative_loss = 0.0
    num_batches = 0
    for data in train_data:
        data = data.to(device)
        # skip samples without any node
        if data.x.shape[0] == 0:
            continue
        out = net(data)

        # h = data.h.type(torch.LongTensor).to(device)
        # loss = loss_fn(out, h)

        loss = loss_fn(out, data.y)

        # rgb = scatter(data.y, data.cluster, dim=0, reduce='mean')
        # loss = loss_fn(out, rgb)

        cumulative_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        return float("nan")
    return cumulative_loss / num_batches

In [17]:
# Model configuration; encoder_features equals the last entry of encoder_channels.
model_params = {
    "in_features": 2,
    "encoder_features": 16,
    "conv_channels": [32, 64, 128, 64],
    "encoder_channels": [16],
    "decoder_channels": [32],
    "num_classes": 3,
    "num_heads": 12,
    "apply_batch_norm": True,
}

# Build the network and move it to the selected device.
net = MeshSeg(**model_params)
net = net.to(device)

In [18]:
# Print the module hierarchy to check the layer sizes.
print(net)

MeshSeg(
  (input_encoder): Sequential(
    (0): Linear(in_features=2, out_features=16, bias=True)
    (1): Identity()
  )
  (gnn): GraphFeatureEncoder(
    (conv_layers): ModuleList(
      (0): FeatureSteeredConvolution(16, 32)
      (1): FeatureSteeredConvolution(32, 64)
      (2): FeatureSteeredConvolution(64, 128)
      (3): FeatureSteeredConvolution(128, 64)
    )
    (batch_layers): ModuleList(
      (0): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (final_projection): Sequential(
    (0): Linear(in_features=64, out_features=32, bias=True)
    (1): ReLU()
    (2): Linear(in_features=32, out_features=3, bias=True)
    (3): Identity()
  )
)


In [None]:
# optimisation settings
lr = 0.001
num_epochs = 50

# per-epoch loss histories and the lowest validation loss seen so far
train_losses, val_losses = [], []
best_loss = float('inf')

# loss_fn = torch.nn.CrossEntropyLoss()
loss_fn = torch.nn.MSELoss(reduction='mean')
optimizer = torch.optim.Adam(net.parameters(), lr=lr)

In [None]:
# Train for num_epochs epochs, validating after each one and keeping the
# weights with the lowest validation loss in "seg_ckpt.pth".
with tqdm(range(num_epochs), unit="Epoch") as tepochs:
    for epoch in tepochs:
        train_loss = train(net, train_loader, optimizer, loss_fn, device)
        val_loss = test(net, val_loader, loss_fn, device)
        # train_acc, test_acc = test(net, train_loader, val_loader, device)

        # record the curves and show the latest values on the progress bar
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        tepochs.set_postfix(train_loss=train_loss, val_loss=val_loss)
        sleep(0.1)

        improved = val_loss < best_loss
        if improved:
            best_loss = val_loss
            torch.save(net.state_dict(), "seg_ckpt.pth")

In [None]:
import matplotlib.pyplot as plt

# Draw both loss curves against the epoch index.
_x = list(range(len(train_losses)))
for _series, _label in (
    (train_losses, 'Training Loss'),
    (val_losses, 'Validation Loss'),
):
    plt.plot(_x, _series, label=_label)

plt.title('Training and Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

In [None]:
# visualize testing results
import colorsys

vis_loader = DataLoader(test_set, batch_size=1, shuffle=False)

# map_location remaps the stored tensors to the current device, so a
# checkpoint written on a GPU can also be loaded on a CPU-only machine.
net.load_state_dict(torch.load('./seg_ckpt.pth', map_location=device))
net.eval()

# Inference only: no autograd graph is needed.
with torch.no_grad():
    for data in vis_loader:
        data = data.to(device)
        out = net(data)
        pos = data.x.cpu().numpy()[:, :2]
        # The network output is unbounded, but matplotlib only accepts RGB
        # components within [0, 1], so clamp before plotting.
        H = out.clamp(0.0, 1.0).cpu().numpy()

        # # for rgb w/ scatter
        # cluster = data.cluster.cpu().detach().numpy()
        # for (x, y), clu in zip(pos, cluster):
        #     plt.scatter(x, y, color=H[clu])

        # for rgb w/o scatter
        # One scatter call for all points instead of one call per point.
        # Like the former zip(pos, H), only the common prefix is drawn.
        # NOTE(review): `out` has one row per cluster id while `pos` has one
        # row per node — confirm these line up for this dataset.
        n = min(len(pos), len(H))
        plt.scatter(pos[:n, 0], pos[:n, 1], c=H[:n])

        # # get pos and hex color
        # cc = []
        # for h in H:
        #     if h[0] == 0: hh = 30 / 360
        #     elif h[0] == 1: hh = 60 / 360
        #     elif h[0] == 2: hh = 90 / 360
        #     elif h[0] == 3: hh = 120 / 360
        #     elif h[0] == 4: hh = 150 / 360
        #     elif h[0] == 5: hh = 180 / 360
        #     elif h[0] == 6: hh = 210 / 360
        #     elif h[0] == 7: hh = 240 / 360
        #     elif h[0] == 8: hh = 270 / 360
        #     elif h[0] == 9: hh = 300 / 360
        #     elif h[0] == 10: hh = 330 / 360
        #     else: hh = 360 / 360
        #     r, g, b = np.array(colorsys.hsv_to_rgb(hh, 1, 1))
        #     cc.append([r, g, b])

        # # plt.axis([0, 25, 33, -8])
        # for (x, y), c in zip(pos, cc):
        #     plt.scatter(x, y, color=c)

        plt.axis("off")
        plt.show()
        plt.close()
        # only the first test sample is visualised
        break