# The Mokume Dataset and Inverse Texturing of Solid Wood (NCA Training)

This notebook is a step-by-step guide to training a Neural Cellular Automata (NCA) model that generates 3D wood textures.

## Utils and Imports

In [None]:
import os
import numpy as np
import argparse
import yaml
from tqdm import tqdm
import warnings
import shutil
import PIL
import gc
import matplotlib.pyplot as plt
from IPython.display import Image, HTML, Markdown, clear_output, display
from moviepy.video.io.ffmpeg_writer import FFMPEG_VideoWriter
import moviepy.editor as mvp

import torch
import torch.nn.functional as F
import torchvision.models as models

# from COMMON.data_utils import get_unfolded_image

# Debug aid, disabled by default.
# torch.autograd.set_detect_anomaly(True)
# Tensors and module parameters created without an explicit dtype are float32.
torch.set_default_dtype(torch.float32)
# Tensors created without an explicit device are placed on the first CUDA GPU.
torch.set_default_device('cuda:0')
# Explicit dtype used below for the NCA state, its fixed filters, the
# positional grid and the loaded sample data.
prec = torch.float16

# Silences every Python warning for the rest of the session.
warnings.filterwarnings("ignore")

# Name of the ffmpeg executable.
# NOTE(review): moviepy is already imported above; confirm it still honours
# this variable when it is set after the import.
os.environ['FFMPEG_BINARY'] = 'ffmpeg'

class VideoWriter:
    """Append frames one by one to a video file via ``FFMPEG_VideoWriter``.

    The underlying writer is created lazily on the first :meth:`add` call,
    because the frame size is only known once a frame has been seen.
    The object can be used as a context manager; the file is finalised on exit.
    """

    def __init__(self, filename='_autoplay.mp4', fps=30.0, **kw):
        """
        filename: path of the video file to write.
        fps: frames per second.
        **kw: extra keyword arguments forwarded unchanged to ``FFMPEG_VideoWriter``.
        """
        self.writer = None
        self.params = dict(filename=filename, fps=fps, **kw)

    def add(self, img):
        """Append one frame.

        img: array-like [h, w] (grayscale) or [h, w, channels].
            Floating-point frames of any precision are treated as intensities
            in [0, 1] (clipped) and converted to uint8; other dtypes are
            written as they are.
        """
        img = np.asarray(img)
        if self.writer is None:
            h, w = img.shape[:2]
            self.writer = FFMPEG_VideoWriter(size=(w, h), **self.params)
        # issubdtype also covers float16, which an explicit
        # [float32, float64] list would let through unconverted.
        if np.issubdtype(img.dtype, np.floating):
            img = np.uint8(img.clip(0, 1) * 255)
        if img.ndim == 2:
            # Replicate a grayscale frame into three identical channels.
            img = np.repeat(img[..., None], 3, -1)
        self.writer.write_frame(img)

    def close(self):
        """Finalise the video file. Safe to call more than once."""
        if self.writer is not None:
            self.writer.close()
            # Drop the reference so a later close() (for example show()
            # followed by the context-manager exit) is a no-op.
            self.writer = None

    def show(self, **kw):
        """Close the file and display it inline; ``**kw`` goes to ``ipython_display``."""
        self.close()
        fn = self.params['filename']
        display(mvp.ipython_display(fn, **kw))

    def __enter__(self):
        return self

    def __exit__(self, *kw):
        self.close()

        
def get_unfolded_image(imgs, black_bg=False):
    """Arrange up to six cube-face images into one unfolded ("net") image.

    imgs: sequence of face images, each an array such as [h, w, c].
        uint8 images are used as they are. Any other dtype is treated as
        intensities in [0, 1]: values are clipped to that range and scaled
        to uint8. If fewer than six images are given, the missing faces are
        drawn as white tiles shaped like ``imgs[0]``.
    black_bg: fill the unused cells of the layout with black instead of white.

    Layout, listed column by column from top to bottom:
        column 1: blank,   imgs[2], blank
        column 2: imgs[1], imgs[0], imgs[3]
        column 3: blank,   imgs[4], blank
        column 4: blank,   imgs[5], blank

    Returns the assembled uint8 array.
    """
    all_imgs = []
    for i in range(6):
        if i < len(imgs):
            img = np.copy(imgs[i])
            if img.dtype != np.uint8:
                # Clip before the cast: an out-of-range float (e.g. 1.2 or
                # -0.1) would otherwise wrap around in uint8 and show up as
                # speckle artifacts in the rendered face.
                img = (255.0 * np.clip(img, 0.0, 1.0)).astype(np.uint8)
            all_imgs.append(img)
        else:
            # Placeholder for a face that was not supplied.
            all_imgs.append(np.full(all_imgs[0].shape, 255, dtype=np.uint8))

    # Blank filler tiles are sized from the face stacked above the centre one:
    # `shape_f` keeps the centre face's width, `shape_ce` is depth x depth.
    depth = all_imgs[1].shape[0]
    shape_base = list(all_imgs[0].shape)
    shape_base[0] = depth
    shape_f = tuple(shape_base)
    shape_ce = list(shape_base)
    shape_ce[1] = depth
    shape_ce = tuple(shape_ce)

    fill_value = 0 if black_bg else 255
    empty_img_ce = np.full(shape_ce, fill_value, dtype=np.uint8)
    empty_img_f = np.full(shape_f, fill_value, dtype=np.uint8)

    col1 = np.vstack((empty_img_ce, all_imgs[2], empty_img_ce))
    col2 = np.vstack((all_imgs[1], all_imgs[0], all_imgs[3]))
    col3 = np.vstack((empty_img_ce, all_imgs[4], empty_img_ce))
    col4 = np.vstack((empty_img_f, all_imgs[5], empty_img_f))

    return np.hstack((col1, col2, col3, col4))
        
def to_pil_image(x):
    """Render a batch of images as one PIL image, laid out side by side.

    x: tensor with shape [b, 3, h, w] or [b, h, w], values nominally in [0, 1].
       A 3-D input is treated as grayscale and copied into three channels.

    Values are scaled by 255, clamped to [0, 255] and cast to uint8 before
    the batch entries are concatenated along the width.
    """
    batch = x if x.ndim != 3 else x.unsqueeze(1).repeat(1, 3, 1, 1)
    # Channels-last layout, as expected by PIL.
    pixels = (batch.permute(0, 2, 3, 1) * 255).clamp(0, 255).to(torch.uint8)
    strip = np.hstack(pixels.cpu().numpy())
    return PIL.Image.fromarray(strip)

def get_outer_faces(x):
    """Collect the six boundary slices of a volume.

    x: tensor with shape [c, h, w, d] or [h, w, d].
    Returns the faces stacked along a new leading axis: [6, c, h, w] or
    [6, h, w]. For each of the last three axes, in order, the first and the
    last slice are taken; with the file's face labels that is C, E, D, B, F, A.
    All three spatial sizes must match for the faces to be stackable.
    """
    faces = []
    for axis in (-3, -2, -1):
        faces.append(x.select(axis, 0))   # first slice along this axis
        faces.append(x.select(axis, -1))  # last slice along this axis
    return torch.stack(faces, dim=0)

def reorder_faces(faces, order="ABCDEF"):
    """Convert six cube faces between the "ABCDEF" and "ZYX" conventions.

    ``order`` names the convention of the *output*; the input must be in the
    other one:

    * ``order="ZYX"``: input is A, B, C, D, E, F. Output is
      C (Z=-1), E (Z=+1), D (Y=-1), B (Y=+1), F (X=-1), A (X=+1), each face
      flipped/rotated to line up with the slices of the 3D volume. This is
      the order used by ``get_outer_faces`` (per its face labels).
    * ``order="ABCDEF"``: input is in the ZYX order above. Output is
      A, B, C, D, E, F. This undoes the "ZYX" conversion.

    faces: tensor [6, c, h, w]; faces must be square (h == w) because some
        of them are rotated by 90 degrees before being stacked.

    Returns a tensor [6, c, h, w].
    Raises ValueError if ``order`` is not one of the two supported names.
    """
    if order == "ABCDEF":
        return torch.stack((
            torch.rot90(faces[5], -3, (1, 2)),  # A, X=+1
            torch.rot90(torch.flip(faces[3], (1, )), -1, (1, 2)),  # B, Y=+1
            torch.flip(faces[0], (1,)),  # C, Z=-1
            torch.rot90(faces[2], -3, (1, 2)),  # D, Y=-1
            torch.flip(faces[1], (1, 2,)),  # E, Z=+1
            torch.rot90(torch.flip(faces[4], (2, )), -1, (1, 2)),  # F, X=-1
        ))

    if order == "ZYX":
        return torch.stack((
            torch.flip(faces[2], (1,)),  # C, Z=-1
            torch.flip(faces[4], (1, 2,)),  # E, Z=+1
            torch.rot90(faces[3], 3, (1, 2)),  # D, Y=-1
            torch.flip(torch.rot90(faces[1], 1, (1, 2)), (1,)),  # B, Y=+1
            torch.flip(torch.rot90(faces[5], 1, (1, 2)), (2,)),  # F, X=-1
            torch.rot90(faces[0], 3, (1, 2)),  # A, X=+1
        ))

    # An explicit exception instead of `assert`: asserts are stripped under
    # `python -O`, which would make an invalid order silently return None.
    raise ValueError(f'order must be "ABCDEF" or "ZYX", got {order!r}')

## VGG-Based Style Loss Function

In [None]:
class RelaxedOTLoss(torch.nn.Module):
    """Style loss comparing VGG feature distributions of two image batches.

    Loss function proposed in --> https://arxiv.org/abs/1904.12785
    Code taken from --> https://arxiv.org/pdf/2404.06279

    Features of the target images are extracted once at construction time.
    Each forward pass extracts features of the generated images, randomly
    subsamples spatial positions at every selected layer, and sums a relaxed
    earth-mover term (``style_loss``) and a mean/covariance term
    (``moment_loss``) over the layers.
    """

    def __init__(self, vgg, target_image, n_samples=1024):
        """
        vgg: indexable/sliceable sequence of layers (e.g. ``vgg16().features``);
            layers 0..25 are applied.
        target_image: tensor [b, 3, h, w] of reference images.
        n_samples: maximum number of spatial positions sampled per layer.
        """
        super().__init__()
        self.n_samples = n_samples
        self.vgg = vgg
        with torch.no_grad():
            self.target_features = self.get_vgg_features(target_image)

    def get_vgg_features(self, imgs):
        """Return the activations of the selected layers, flattened spatially.

        imgs: tensor [b, 3, h, w].
        Returns a list with one tensor [b, c_l, h_l * w_l] per style layer.
        """
        style_layers = [1, 6, 11, 18, 25]
        # Per-channel normalisation constants (the usual ImageNet statistics).
        # They are created on the input's device so the loss does not depend
        # on the process-wide default device. The dtype is left at the
        # default on purpose so a half-precision input is promoted exactly
        # as before.
        mean = torch.tensor([0.485, 0.456, 0.406], device=imgs.device)[:, None, None]
        std = torch.tensor([0.229, 0.224, 0.225], device=imgs.device)[:, None, None]
        x = (imgs - mean) / std
        features = []
        for i, layer in enumerate(self.vgg[:max(style_layers) + 1]):
            x = layer(x)
            if i in style_layers:
                b, c, h, w = x.shape
                features.append(x.reshape(b, c, h * w))
        return features

    @staticmethod
    def pairwise_distances_cos(x, y):
        """Cosine distance between every row of x and every row of y.

        x: [b, n, c], y: [b, m, c]. Returns [b, n, m].
        """
        x_norm = torch.norm(x, dim=2, keepdim=True)  # (b, n, 1)
        y_t = y.transpose(1, 2)  # (b, c, m) (m may be different from n)
        y_norm = torch.norm(y_t, dim=1, keepdim=True)  # (b, 1, m)
        # The small constant guards against division by zero for null vectors.
        dist = 1. - torch.matmul(x, y_t) / (x_norm * y_norm + 1e-10)  # (b, n, m)
        return dist

    @staticmethod
    def style_loss(x, y):
        """Relaxed earth-mover distance between two feature sets, per batch item.

        x: [b, n, c], y: [b, m, c]. Returns [b].
        """
        pairwise_distance = RelaxedOTLoss.pairwise_distances_cos(x, y)
        nearest_x_per_y = pairwise_distance.min(1).values  # (b, m)
        nearest_y_per_x = pairwise_distance.min(2).values  # (b, n)
        # Take the larger of the two one-sided average matching costs.
        return torch.max(nearest_x_per_y.mean(dim=1), nearest_y_per_x.mean(dim=1))

    @staticmethod
    def moment_loss(x, y):
        """Mean absolute difference of feature means and covariances.

        x: [b, n, c], y: [b, m, c]. Returns [b].
        Needs at least two samples per set (the covariance divides by n - 1).
        """
        mu_x, mu_y = torch.mean(x, 1, keepdim=True), torch.mean(y, 1, keepdim=True)
        mu_diff = torch.abs(mu_x - mu_y).mean(dim=(1, 2))

        x_c, y_c = x - mu_x, y - mu_y
        x_cov = torch.matmul(x_c.transpose(1, 2), x_c) / (x.shape[1] - 1)
        y_cov = torch.matmul(y_c.transpose(1, 2), y_c) / (y.shape[1] - 1)

        cov_diff = torch.abs(x_cov - y_cov).mean(dim=(1, 2))
        return mu_diff + cov_diff

    def forward(self, generated_image):
        """Return the scalar loss between ``generated_image`` and the stored target.

        generated_image: tensor [b, 3, h, w]; its batch size must be
        compatible with the target batch for the elementwise comparisons.
        """
        loss = 0.0
        generated_features = self.get_vgg_features(generated_image)
        # Iterate over the VGG layers
        for x, y in zip(generated_features, self.target_features):
            (b_x, c, n_x), (b_y, _, n_y) = x.shape, y.shape
            n_samples = min(n_x, n_y, self.n_samples)
            # Random subset of spatial positions without replacement:
            # argsort of uniform noise is a random permutation, truncated.
            indices_x = torch.argsort(torch.rand(b_x, 1, n_x, device=x.device), dim=-1)[..., :n_samples]
            x = x.gather(-1, indices_x.expand(b_x, c, n_samples))
            indices_y = torch.argsort(torch.rand(b_y, 1, n_y, device=y.device), dim=-1)[..., :n_samples]
            y = y.gather(-1, indices_y.expand(b_y, c, n_samples))
            x, y = x.transpose(1, 2), y.transpose(1, 2)  # (b, n_samples, c)
            loss += self.style_loss(x, y) + self.moment_loss(x, y)

        return loss.mean()
    
# Pretrained VGG-16 convolutional stack, used only as a fixed feature
# extractor for the style loss. It is put in eval mode and its weights are
# frozen: gradients still flow through it to the input images, but no
# gradient buffers are computed or accumulated for the VGG weights, which no
# optimizer in this notebook updates or zeroes.
vgg = models.vgg16(weights='IMAGENET1K_V1').features.eval().requires_grad_(False)

## Volumetric NCA Model

In [None]:
def get_pos_emb_3D(D):
    """Build a 3-channel positional grid for a D x D x D volume.

    Channel k stores the normalised coordinate along spatial axis k, sampled
    at cell centres, so values run from -1 + 1/D to 1 - 1/D.

    Returns a tensor [3, D, D, D] of dtype ``prec``.
    """
    coords = torch.arange(D, dtype=prec) / D
    coords = 2.0 * (coords - 0.5 + 0.5 / D)

    grid = torch.zeros((3, D, D, D), dtype=prec)
    for axis in range(3):
        # Lay the 1-D coordinates along `axis` and broadcast over the others.
        shape = [1, 1, 1]
        shape[axis] = D
        grid[axis] = coords.reshape(shape)
    return grid

def depthwise_conv(x, filters):
    """Apply every filter to every channel of a volume independently.

    x: tensor [b, ch, h, w, d].
    filters: tensor [filter_n, kh, kw, kd]. The input is replicate-padded by
        one voxel on every side, so 3x3x3 kernels keep the spatial size.

    Returns [b, ch * filter_n, h, w, d]; the filter_n responses of a given
    input channel are adjacent along the channel axis.
    """
    batch, channels, height, width, depth = x.shape
    # Fold channels into the batch axis so each one is convolved on its own.
    per_channel = x.reshape(batch * channels, 1, height, width, depth)
    padded = torch.nn.functional.pad(per_channel, (1, 1, 1, 1, 1, 1), mode="replicate")
    response = torch.nn.functional.conv3d(padded, filters.unsqueeze(1))
    return response.reshape(batch, -1, height, width, depth)

class VolumeNCA(torch.nn.Module):
    """3D neural cellular automaton: fixed perception filters + learned per-voxel update.

    Each step perceives the state (plus two conditioning channels) with five
    fixed 3x3x3 filters (identity, Sobel x/y/z, Laplacian), optionally appends
    a 3-channel positional embedding, and feeds the result through two 1x1x1
    convolutions to obtain a residual update of the state.
    """

    def __init__(self, chn=12, fc_dim=128, noise_level=0.1, pemb=True):
        """
        chn: number of state channels.
        fc_dim: width of the hidden 1x1x1 convolution.
        noise_level: amplitude of the uniform noise produced by ``seed``.
        pemb: whether ``forward`` will receive a 3-channel positional embedding.
        """
        super().__init__()
        self.chn = chn
        self.register_buffer("noise_level", torch.tensor([noise_level], dtype=prec))

        # 5 perception filters applied to the state and the 2 conditioning channels.
        input_dim = (self.chn + 2) * 5
        self.pemb = pemb
        if pemb:
            input_dim += 3
        self.w1 = torch.nn.Conv3d(input_dim, fc_dim, 1, bias=True)
        self.w2 = torch.nn.Conv3d(fc_dim, self.chn, 1)

        torch.nn.init.xavier_normal_(self.w1.weight, gain=0.2)
        # Zero output weights: the initial update depends only on w2's bias.
        torch.nn.init.zeros_(self.w2.weight)

        with torch.no_grad():
            delta_one = torch.tensor([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]], dtype=prec)
            delta_two = torch.tensor([[-2.0, 0.0, 2.0], [-4.0, 0.0, 4.0], [-2.0, 0.0, 2.0]], dtype=prec)
            sobel_z = torch.stack([delta_one, delta_two, delta_one]) / 4.0
            # The other two Sobel kernels are axis permutations of the first.
            sobel_y = sobel_z.permute(0, 2, 1)
            sobel_x = sobel_z.permute(2, 1, 0)

            # Laplacian-like kernel; its 27 entries sum to zero.
            lap1 = torch.tensor([[2.0, 3.0, 2.0], [3.0, 6.0, 3.0], [2.0, 3.0, 2.0]], dtype=prec)
            lap2 = torch.tensor([[3.0, 6.0, 3.0], [6.0, -88.0, 6.0], [3.0, 6.0, 3.0]], dtype=prec)
            lap = torch.stack([lap1, lap2, lap1])
            #             lap = (lap / 26.0)
            lap = lap / 8.0

            ident = torch.zeros(3, 3, 3, dtype=prec)
            ident[1, 1, 1] = 1.0

            # Registered as a buffer (like `noise_level`) so the filters move
            # with the module on `.to(...)`. Non-persistent, so the
            # state_dict layout is the same as with a plain attribute.
            self.register_buffer(
                "filters",
                torch.stack([ident, sobel_x, sobel_y, sobel_z, lap]),
                persistent=False,
            )

    def forward(self, s, cond, p=None):
        """Run one NCA step.

        s: state [b, chn, h, w, d].
        cond: conditioning volume concatenated to the state along channels;
            the layer sizes assume it has 2 channels.
        p: optional positional embedding concatenated after perception; the
            layer sizes assume 3 channels and that it is given iff ``pemb``.

        Returns the updated state, same shape as ``s``.
        """
        z = depthwise_conv(torch.cat([s, cond], dim=1), self.filters)  # [b, 5 * (chn + 2), h, w, d]
        if p is not None:
            z = torch.cat([z, p], dim=1)
        delta_s = self.w2(torch.relu(self.w1(z)))
        return s + delta_s

    def seed(self, n, h=128, w=128, d=128):
        """Return ``n`` initial states [n, chn, h, w, d]: uniform noise in
        [-noise_level / 2, noise_level / 2), created on the module's device."""
        with torch.no_grad():
            noise = torch.rand(n, self.chn, h, w, d, dtype=prec, device=self.noise_level.device)
            return (noise - 0.5) * self.noise_level


def to_rgb(s):
    """Return the first three channels of a state, shifted by +0.5.

    s: tensor [..., chn, h, w, d]; channels sit on the fourth axis from the end.
    """
    rgb_channels = (Ellipsis, slice(None, 3), slice(None), slice(None), slice(None))
    return s[rgb_channels] + 0.5


def rgb_delta(s):
    """Return the first three channels of a state, unshifted (a view of ``s``).

    s: tensor [..., chn, h, w, d]; channels sit on the fourth axis from the end.
    """
    rgb_channels = (Ellipsis, slice(None, 3), slice(None), slice(None), slice(None))
    return s[rgb_channels]

# Report the parameter count of an NCA built with default arguments
# (a throwaway instance is constructed just for counting).
print('Number of NCA parameters:', sum(p.numel() for p in VolumeNCA().parameters()))

## Load and visualize the data

In [None]:
# Sample to load from Samples/<sample_name>/ and the cubic working resolution.
sample_name = "H11"
res = 64
with torch.no_grad():
    # Positional embedding later fed to the NCA, [3, res, res, res].
    grid = get_pos_emb_3D(res)
    # Color volume and a second scalar volume ("gf") stored as .npz archives.
    col = np.load(f'Samples/{sample_name}/col_cube.npz')['arr_0']
    gf = np.load(f'Samples/{sample_name}/gf_cube.npz')['arr_0']
    
    # Photographs of the six outer faces, in A..F order, scaled to [0, 1].
    external = np.stack([PIL.Image.open(f"Samples/{sample_name}/{face}_col.png") for face in "ABCDEF"]) / 255.0 # Unmuted colors

    # Move the last axis (color channels) to the front, then resample to res^3.
    col = torch.tensor(col, dtype=prec).permute(3, 0, 1, 2).contiguous()
    col = F.interpolate(col.unsqueeze(0), size=(res, res, res), mode='trilinear', align_corners=False).squeeze(0) # [3, res, res, res]

    gf = torch.tensor(gf, dtype=prec)
    gf = F.interpolate(gf[None, None, :], size=(res, res, res), mode='trilinear', align_corners=False)[0, 0] # [res, res, res]

    # Face photos: channels-last -> channels-first, then resample to res x res.
    external = torch.tensor(external, dtype=prec)
    external = F.interpolate(external.permute(0, 3, 1, 2), size=(res, res), mode='bilinear', align_corners=False)

    # Conditioning channels for the NCA: channel-averaged color and gf, both shifted by -0.5.
    col_cond = col.unsqueeze(0).mean(dim=1) - 0.5  # instead of * 2 - 1 [1, res, res, res]
    gf_cond = gf.unsqueeze(0) - 0.5  # instead of * 2 - 1 # [1, res, res, res]
    cond = torch.stack([col_cond, gf_cond], dim=1) # [1, 2, res, res, res]

    # Bring the photos into the face order/orientation of the volume slices and
    # use them as the style target; target VGG features are computed here, once.
    external_faces = reorder_faces(external, "ZYX")
    loss_fn = RelaxedOTLoss(vgg, external_faces)

    # Sanity check: photos next to the matching boundary slices of both volumes.
    print("External Faces in ZYX order")
    to_pil_image(external_faces).show()
    to_pil_image(get_outer_faces(col)).show()
    to_pil_image(get_outer_faces(gf)).show()
    

    # Same three sets of faces, converted back to A..F and drawn as unfolded cubes.
    print("Faces in ABCDEF order")
    external_ABCDEF = reorder_faces(external_faces, "ABCDEF")
    PIL.Image.fromarray(get_unfolded_image(external_ABCDEF.permute(0, 2, 3, 1).cpu().numpy())).show()
    
    col_ABCDEF = reorder_faces(get_outer_faces(col), "ABCDEF")
    PIL.Image.fromarray(get_unfolded_image(col_ABCDEF.permute(0, 2, 3, 1).cpu().numpy())).show()

    # gf is single-channel; replicate it to three channels so it renders as gray.
    gf_ABCDEF = reorder_faces(get_outer_faces(gf).unsqueeze(1).repeat(1, 3, 1, 1), "ABCDEF")
    PIL.Image.fromarray(get_unfolded_image(gf_ABCDEF.permute(0, 2, 3, 1).cpu().numpy())).show()



## Initialize and Train the NCA model

In [None]:
# Model, optimizer, LR schedule and AMP gradient scaler are built exactly
# once; the earlier version of this cell constructed all four twice and
# discarded the first set.
model = VolumeNCA(chn=12, fc_dim=128, noise_level=0.1, pemb=True)
opt = torch.optim.Adam(model.parameters(), 0.001, fused=True)
# Learning rate is multiplied by 0.3 after 700 and after 1200 scheduler steps.
lr_sched = torch.optim.lr_scheduler.MultiStepLR(opt, [700, 1200], 0.3)
scaler = torch.cuda.amp.GradScaler()

pool_size = 32  # Number of NCA states in pool
batch_size = 1  # Batch size must be 1 (the training loop squeezes the batch axis)
step_range = (8, 16)  # NCA steps per iteration are drawn from [8, 16)
gradient_checkpoint = True  # Recompute activations in backward to save memory
loss_log = []

# Pool of persistent NCA states, all starting from noise.
with torch.no_grad():
    pool = model.seed(pool_size, res, res, res)


### Training Loop

In [None]:
# One optimisation step per iteration: pick a pooled state, run the NCA for a
# random number of steps, compare the six boundary faces of the resulting
# color volume with the target photos, and update the model.
for epoch in tqdm(range(1500)):
    step_n = np.random.randint(step_range[0], step_range[1])  # step_range[0] .. step_range[1] - 1 (upper bound is exclusive)
    batch_idx = np.random.choice(len(pool), batch_size, replace=False)
    # Indexing with an index array copies, so edits to `s` do not touch the pool.
    s = pool[batch_idx]
    if epoch % 32 == 0:
        # Periodically restart the sampled state from fresh noise.
        s[:1] = model.seed(1, res, res, res)

    if model.pemb:
        pemb = grid[None, ...]
    else:
        pemb = None



    with torch.autocast(device_type='cuda'):
        if not gradient_checkpoint:
            for k in range(step_n):
                s = model(s, cond, pemb)
        else:
            # The step_n applications of the model are split into 4 checkpointed
            # segments; activations inside a segment are recomputed during backward.
            model_forward = lambda x: model(x, cond, pemb)
            # Reentrant checkpointing needs an input that requires grad (see link).
            s.requires_grad = True  # https://github.com/pytorch/pytorch/issues/42812
            s = torch.utils.checkpoint.checkpoint_sequential([model_forward]*step_n, 4, s, use_reentrant=True)
            
    pool[batch_idx] = torch.detach(s)  # update pool

    # Penalise state values that leave [-1, 1].
    overflow_loss = (s - s.clamp(-1.0, 1.0)).abs().mean()
    texture_loss = 0.0

    # Assuming that the batch size is 1
    # The first three state channels are a residual added to the color volume.
    s_rgb = rgb_delta(s).squeeze(0) + col
    # Six boundary slices of the color volume, in the same order as get_outer_faces.
    imgs = torch.stack((s_rgb[:, 0, :, :], s_rgb[:, -1, :, :],
                        s_rgb[:, :, 0, :], s_rgb[:, :, -1, :],
                        s_rgb[:, :, :, 0], s_rgb[:, :, :, -1]), dim=0)
    texture_loss = loss_fn(imgs)


    loss = texture_loss + overflow_loss
    loss_log.append(loss.item())
    scaler.scale(loss).backward()

    with torch.no_grad():
        if prec is torch.float16:
            # Undo the loss scaling before the gradients are inspected/modified.
            scaler.unscale_(opt)
        # Each parameter tensor's gradient is rescaled to (almost) unit norm.
        for p in model.parameters():
            p.grad /= (p.grad.norm() + 1e-8)  # normalize gradients
        if prec is torch.float16:
            scaler.step(opt)
            scaler.update()
        else:
            # NOTE(review): gradients are still loss-scaled on this path; the
            # normalisation above presumably makes that harmless - confirm.
            opt.step()
        lr_sched.step()
        opt.zero_grad()
        
    # Every 50 iterations: refresh the loss plot and show the current faces.
    if epoch % 50 == 0:
        clear_output()
        plt.plot(loss_log, ".", alpha=0.1)
        plt.ylim(top=loss_log[0])
        plt.ylabel("Loss")
        plt.xlabel("Epoch")
        to_pil_image(imgs).show()
        plt.show()
        

## Visualize the NCA output

In [None]:
# Roll the trained NCA out from a fresh seed for 200 steps and record the
# unfolded cube faces of the color volume after every step as a video.
# NOTE(review): hspacer, vspacer and face_images are not used in this cell.
hspacer = np.ones((res * 3 + 20, 10, 3))
vspacer = np.ones((10, res, 3))
with VideoWriter(fps=24, ffmpeg_params=['-pix_fmt', 'yuv420p'],
                 bitrate='10000k') as vid, torch.no_grad():
    s = model.seed(1, res, res, res)
    face_images = external_faces.cpu()

    for step in tqdm(range(200)):
        with torch.autocast(device_type='cuda'):
            # Same computation as model.forward, written out so intermediates
            # can be deleted and the state updated in place.
            z = depthwise_conv(torch.cat([s, cond], dim=1), model.filters)  # [b, 5 * (chn + 2), h, w, d]
            if model.pemb:
                z = torch.cat([z, grid[None, ...]], dim=1)

            h = torch.relu(model.w1(z))
            # Un-comment the deletes if you run out of memory
#             del z
            delta_s = model.w2(h)
#             del h
            s[:] += delta_s
#             del delta_s

        # Color volume = base colors + first three state channels, [1, 3, res, res, res].
        rgb_vol = col + s[0:, :3, :, :, :]        
        rgb_ABCDEF = reorder_faces(get_outer_faces(rgb_vol[0]), "ABCDEF")
        img = get_unfolded_image(rgb_ABCDEF.permute(0, 2, 3, 1).cpu().numpy())

        
        vid.add(img)

    # show() closes the writer before displaying; leaving the `with` block closes it again.
    vid.show()