In [None]:
# Pin this notebook to a single GPU (list available GPUs with `nvidia-smi`).
# CUDA reads this variable when it is initialized, so this cell must run before
# torch touches CUDA. A value already exported in the environment (e.g. by a job
# scheduler or the user) takes precedence over the default below.
import os
GPU_ID = "2"
os.environ.setdefault("CUDA_VISIBLE_DEVICES", GPU_ID)

In [None]:
import numpy as np
import torch
import torch.optim as optim
import torch.nn.functional as F
import multiprocessing
import time
import warnings

from tqdm import tqdm
from torch import nn
from torch_scatter import scatter
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import MessagePassing, GCNConv, global_mean_pool
from torch_geometric.utils import add_self_loops, degree
from utils.preprocessing_for_two_step_mesh import create_data

In [None]:
# load image file paths
# `imgs[i]` is an input SVG and `png[i]` is the PNG with the same file stem.
# Both lists are index-aligned.
svg_folder = './datasets/svg'
png_folder = './datasets/png'
imgs = []
png = []
dataset = []

for root, folders, files in os.walk(svg_folder):
    # Sort in place so os.walk visits sub-folders in a deterministic order.
    # The later truncation and train/val/test split then pick the same files on every machine.
    folders.sort()
    for file in sorted(files):
        stem, ext = os.path.splitext(file)
        # splitext handles names with no dot or several dots. The old split('.')[1]
        # raised IndexError or rejected valid files in those cases.
        if ext != '.svg':
            continue
        # skip Jupyter autosave copies (e.g. files under .ipynb_checkpoints)
        if 'checkpoint' in file:
            continue

        # join with the walked directory, not the top folder, so nested files resolve correctly
        imgs.append(os.path.join(root, file))
        # replace only the extension; str.replace would also rewrite "svg" inside the stem
        png.append(os.path.join(png_folder, stem + '.png'))

In [None]:
# warnings.filterwarnings("ignore")
# for i, file_path in enumerate(tqdm(imgs)):
#     # try:
#     #     dataset.append(create_data(file_path))
#     # except:
#     #     print(file_path)
#     #     raise SystemExit
        
#     file_path = "./datasets/svg/032-firewood.svg"
#     data = create_data(file_path)
#     print(data)
#     break

In [None]:
# Convert the SVGs into graph Data objects in parallel.
# Truncate `png` together with `imgs` so both path lists stay the same length.
imgs = imgs[:2000]
png = png[:2000]
warnings.filterwarnings("ignore")
dataset = []
# `imap` (unlike `imap_unordered`) yields results in input order. This keeps
# dataset[i] aligned with imgs[i] / png[i], which the split and visualization cells rely on.
# The `with` block shuts the worker pool down once all results are consumed.
with multiprocessing.Pool(8) as pool:
    for data in tqdm(pool.imap(create_data, imgs), total=len(imgs)):
        dataset.append(data)

In [None]:
# hyperparameters
torch.manual_seed(16)

batch_size = 1
num_epoch = 50

_train = int(len(dataset) * 0.9)
_val = _train + int(len(dataset) * 0.05)
_test = len(dataset) - _val

# create dataloader
train_set, val_set, test_set = dataset[:_train], dataset[_train:_val], dataset[_val:]
train_svg, val_svg, test_svg = imgs[:_train], imgs[_train:_val], imgs[_val:]
train_png, val_png, test_png = png[:_train], png[_train:_val], png[_val:]

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

print(f"Training Data: {len(train_set)}\nValidation Data: {len(val_set)}\nTesting Data: {len(test_set)}")

In [None]:
class GraphEncoder(nn.Module):
    """Node-level graph encoder followed by per-cluster mean pooling.

    Applies two improved-GCN layers (2 -> 16 -> 64 channels, ReLU after each)
    over the node graph. It then averages the node embeddings that share a
    cluster id, producing one 64-d row per cluster.

    NOTE(review): with batch_size > 1, cluster ids of different graphs stay
    distinct only if batching offsets them. Confirm before raising batch_size.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are kept so saved state_dicts still load.
        self.conv1 = GCNConv(2, 16, improved=True)
        self.conv2 = GCNConv(16, 64, improved=True)

    def forward(self, x, edge_index, cluster):
        """Return the cluster-pooled embeddings; row k is the mean over nodes with cluster == k."""
        h = x
        for conv in (self.conv1, self.conv2):
            h = F.relu(conv(h, edge_index))
        return scatter(h, cluster, dim=0, reduce='mean')

class MLP(nn.Module):
    """Two-layer head mapping 64-d features to 3 outputs (64 -> 16 -> 3).

    Each Linear layer is followed by ReLU, so every output is non-negative.
    NOTE(review): if a pre-activation of the final layer stays negative, the
    last ReLU passes no gradient for it. If targets lie in [0, 1], a sigmoid
    may be a better fit; confirm the target range.
    """

    def __init__(self):
        super().__init__()
        # The (Linear, ReLU, Linear, ReLU) order fixes the state_dict keys
        # layers.0.* and layers.2.* that saved checkpoints depend on.
        stages = [nn.Linear(64, 16), nn.ReLU(), nn.Linear(16, 3), nn.ReLU()]
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        y = self.layers(x)
        return y
    
# class SVGConv(MessagePassing):
#     def __init__(self,
#                  in_channels: int,
#                  out_channels: int,
#                  improved: bool = True,
#                  add_self_loops: bool = True,
#                  normalize: bool = True,
#                  bias: bool = True,
#                  **kwargs
#     ):
#         super(SVGConv, self).__init__(aggr="mean", **kwargs)
#         self.in_channels = in_channels
#         self.out_channels = out_channels
#         self.improved = improved
#         self.add_self_loops = add_self_loops
#         self.normalize = normalize
        
#         self.lin = nn.Linear(in_channels, out_channels, bias=False)
        
#         if bias:
#             self.bias = nn.Parameter(torch.Tensor(out_channels))
#         else:
#             self.register_parameter("bias", None)
        
#         self.reset_parameters()
        
#     def reset_parameters(self):
#         torch.nn.init.uniform_(self.lin.weight)
#         if self.bias is not None:
#             torch.nn.init.normal_(self.bias, mean=0.0, std=0.1)
        
#     def forward(self, x, edge_index, edge_attr):
#         out = self.propagate(edge_index, x=x, edge_attr=edge_attr)
#         return out
    
#     def message(self, x_i, x_j, edge_attr):
#         if edge_attr == 2:  # overlap
            
#         elif edge_attr == 3:  # contain
#         else:  # adjacent
            
#         return super().message(x_j)
    
class GraphNet(nn.Module):
    """Full model: per-cluster encoder, GCN over the cluster graph, and an MLP head.

    1. `GraphEncoder` turns the node graph (`data.x`, `data.edge_index`) into one
       64-d embedding per cluster (`data.cluster`).
    2. If the sample has any group edges, two GCN layers (64 -> 128 -> 64, ReLU)
       propagate over the cluster graph `data.group_edge_index`. Otherwise the
       embeddings go straight to the head.
    3. `MLP` maps each cluster embedding to 3 outputs.

    NOTE(review): `data.group_edge_attr` is passed as GCNConv's third positional
    argument (`edge_weight`), so its values act as numeric edge weights. If it
    encodes edge categories, confirm this is intended.
    """

    def __init__(self):
        super().__init__()
        # Registration order is kept so parameters() ordering and init match saved checkpoints.
        self.encode = GraphEncoder()
        self.gcn1 = GCNConv(64, 128)
        self.gcn2 = GCNConv(128, 64)
        self.fc = MLP()

    def forward(self, data):
        cluster = data.cluster.type(torch.long)
        out = self.encode(data.x, data.edge_index, cluster)

        group_edge_attr = data.group_edge_attr
        if len(group_edge_attr) != 0:
            group_edge_index = data.group_edge_index
            for gcn in (self.gcn1, self.gcn2):
                out = F.relu(gcn(out, group_edge_index, group_edge_attr))

        return self.fc(out)

In [None]:
# Use the (visible) GPU when available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = GraphNet().to(device)

# Adam with L2 weight decay; mean-squared-error loss averaged over all elements.
optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
criterion = nn.MSELoss(reduction='mean')

In [None]:
# training
train_losses = []
val_losses = []
best_loss = float('inf')
train_err = 0
val_err = 0

for epoch in range(num_epoch):  # num_epoch
    train_loss = 0
    val_loss = 0
    
    model.train()
    for i, data in enumerate(tqdm(train_loader)):
        data = data.to(device)
        if data.x.shape[0] == 0:
            train_err += 1
            continue
        
        optimizer.zero_grad()
        out = model(data)
        
        cluster = data.cluster.type(torch.long)
        rgb = scatter(data.rgb, cluster, dim=0, reduce='mean')
        loss = criterion(out, rgb)
        
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        
    model.eval()
    for i, data in enumerate(tqdm(val_loader)):
        data = data.to(device)
        if data.x.shape[0] == 0:
            val_err += 1
            continue
        out = model(data)
        
        cluster = data.cluster.type(torch.long)
        rgb = scatter(data.rgb, cluster, dim=0, reduce='mean')
        loss = criterion(out, rgb)
        
        val_loss += loss.item()
    
    train_avg = train_loss / (len(train_loader)-train_err)
    val_avg = val_loss / (len(val_loader)-val_err)
    train_losses.append(loss.item())
    val_losses.append(loss.item())
    
    print(f'Epoch {epoch}\tTraining Loss: {train_avg}\tValidation Loss: {val_avg}')
    
    if val_avg < best_loss:
        print(f'Validation Loss Decreased({best_loss:.6f}--->{val_avg:.6f})\tSaving The Model')
        best_loss = val_avg
        torch.save(model.state_dict(), 'best_checkpoint.pth')

In [None]:
# plot losses
import matplotlib.pyplot as plt

_x = list(range(num_epoch))  # num_epoch
plt.plot(_x, train_losses, label='Training Loss')
plt.plot(_x, val_losses, label='Validation Loss')
 
plt.title('Training and Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

In [None]:
# testing
model.load_state_dict(torch.load('./best_checkpoint.pth'))
model.eval()

test_loss = 0
test_err = 0
for i, data in enumerate(tqdm(test_loader)):
    data = data.to(device)
    if data.x.shape[0] == 0:
        tes_err += 1
        continue
    out = model(data)
    cluster = data.cluster.type(torch.long)
    rgb = scatter(data.rgb, cluster, dim=0, reduce='mean')
    loss = criterion(out, rgb)
    test_loss += loss.item()
    
print(f'Testing Loss: {test_loss / (len(test_loader)-test_err)}')

In [None]:
# visualize testing results
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

vis_loader = DataLoader(test_set, batch_size=1, shuffle=False)
model.load_state_dict(torch.load('./best_checkpoint.pth'))
model.eval()

fig = plt.figure("plot")
ax = fig.add_subplot(1, 1, 1)
plt.axis("off")

for data, png in zip(vis_loader, test_png):
    data = data.to(device)
    out = model(data)
    
    pos = data.x.cpu().detach().numpy()
    rgb = out.cpu().detach().numpy()
    edge = np.transpose(data.edge_index.cpu().detach().numpy())
    cluster = data.cluster.cpu().detach().numpy()

    for n1, n2 in edge:
        x1, y1 = pos[n1]
        x2, y2 = pos[n2]
        l = Line2D([x1,x2], [y1,y2], alpha=0.2)
        ax.add_line(l)
    
    for (x, y), clu in zip(pos, cluster):
        plt.scatter(x, y, color=rgb[int(clu)], s=30)
        
    plt.axis("off")
    plt.show()
    plt.close()
    break