# Octree Transformer Hands on Tutorial

*Let's start with some basic imports...*

In [24]:
import os
import math
import torch
import k3d
import numpy as np
import itertools

import pytorch_lightning as pl 
from pytorch_lightning import loggers as pl_loggers

from tqdm import tqdm
from tqdm.auto import trange

from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

import sys
sys.path.append('../')
from utils.save_obj import save_obj

*Furthermore, we need some constants.*

In [25]:
PADDING_VALUE = 0
NUM_VOCAB = 3+1
RESOLUTION = 64
SPATIAL_DIM = 3

# Load the Data

Get the data from the ShapeNet data set. For the purpose of this tutroial, there are 10 shapes already voxelized stored in the "data" folder.

# Octree Data Structure
The Octree data structure is a hierarchical representation of 3D voxel data. It starts with a root node that encompasses the entire object. This root node is then recursively split into eight children, forming an octagonal subdivision. Each child node can further be divided into eight children of its own, and this subdivision process continues until a certain resolution is reached. The resolution determines the maximum level of subdivision, and it impacts the number of leaf nodes in the Octree. Specifically, an Octree with a resolution of R has a maximum of 8^(ld(R))=R³ leaf nodes. Leaf nodes represent the smallest subvolumes in the Octree, containing voxel data.

![Octree Structure visualized](../images/octree_explained.png)

<small>Picture by [The Infinite Loop](https://geidav.wordpress.com/2014/07/18/advanced-octrees-1-preliminaries-insertion-strategies-and-max-tree-depth/)</small>


To optimize the Octree's efficiency, we utilize modifications to the traditional Octree structure. One key modification is the representation of subvolumes with the same value for all their voxels. In such cases, instead of representing each individual voxel, we can replace them with a single node that represents the common value. This pruning technique significantly reduces the number of nodes in the Octree, resulting in a more compact representation of the voxel data. 

Implementation Steps:

    Determine the node value:
        1: If all child elements are empty
        2: If there are at least two child elements with different values
        3: If all child elements are occupied
    Prune the tree if the node value is 1 or 3:
        Remove all child nodes of the current node and replace them with a single node representing the value of the current node.
    Recursively split the tree if the node value is 2:
        Create eight child nodes representing the subvolumes within the current node.
        Repeat this splitting process for each child node until the desired resolution is reached.

In [26]:
def linearize(array: np.ndarray,
                    max_resolution: int = 8096):
    def recursive_linearise(array: np.ndarray, pos: np.ndarray, dep: int = 1):
        """ Recursive internal function to linearise given voxel array.

        Note: Uses variables of parent function to store values.

        Args:
            array (np.ndarray): Numpy array (or subarray) holding pixels/voxels of a discretized shape.
            dep (int, optional): Current recursion depth. Defaults to 0.
        """
        # split input into an octree/quadtree
        subarrays = split(array)
        num_subarrays = len(subarrays)

        # initialize dictionary only on first pass
        if dep not in value:
            value[dep] = []
            depth[dep] = []
            position[dep] = []

        # compute values for each subarray
        for idx, sub in enumerate(subarrays):
            value[dep] += [1] if np.max(sub) == 0 else [3] if np.min(sub) > 0 else [2]
            depth[dep] += [dep]
            position[dep] += [2 * pos + dirs[idx]]

        # process each subarray recursivelly
        for idx, sub in enumerate(subarrays):
            cur_idx = -num_subarrays + idx
            if value[dep][cur_idx] == 2 and dep < max_dep:
                recursive_linearise(sub, position[dep][cur_idx], dep + 1)

    # initialise memory
    value = {}
    depth = {}
    position = {}
    dirs = np.array(list(itertools.product([1, 2], repeat=array.ndim)))
    init_pos = np.array(array.ndim * [0])
    max_dep = int(math.log2(max_resolution))

    # call function recursivelly
    recursive_linearise(array, init_pos)

    # flatten dictionaries
    value = np.array(list(itertools.chain(*value.values())))
    depth = np.array(list(itertools.chain(*depth.values())))
    position = np.array(list(itertools.chain(*position.values())))

    return value, depth, position

We have to implement some further function...

In [27]:
def split(array: np.ndarray) -> np.ndarray:
    """ Splits the given array along each axis in half.

    Args:
        elements (np.ndarray): Numpy array of arbitary dimension.

    Returns:
        np.ndarray: Array of splited elements with an additional dimension along the first axis.
    """
    ndim = array.ndim
    array = np.expand_dims(array, axis=0)
    for i in range(ndim, 0, -1):
        array = np.concatenate(np.split(array, indices_or_sections=2, axis=i), axis=0)
    return array

In [28]:
class kdTree():
    """ Implements a kd-tree data structure for volumetric/spatial objects. Works with arrays of spatial data as well
    as linearised token sequence representations.

    This class allows to transform array with spatial elements into kd-trees, where k can by any natural number. Each
    node represents mixed elements, which can be split in its branches. Each leaf represents a final element which is
    either completly empty or completly occupied. These structure can be than linearized as a sequence of tokens, which
    is equivalent to the kd-tree. In the same way, as arrays with spatial elements can be transformed into kd-trees,
    token sequences can be transformed into kd-trees. This allows to seamlessly transform arrays of spatial data into
    token sequences and vice versa.
    """
    def __init__(self):
        """ Initializes the kd-tree for the right spatial dimensionality.
        """
        super().__init__()
        # 3D -> spatial dim = 3
        self.SPATIAL_DIM = 3
        # array with directions for the child nodes (x,y,z)
        self.dirs = np.array(list(itertools.product([1, 2], repeat=self.SPATIAL_DIM)))


    def concat(self, array: np.ndarray) -> np.ndarray:
        """ Concats elements of the array along each dimension, where each subarray is given in the first axis.

        Args:
            array (np.ndarray): Numpy array, holding subarrays in the first axis.

        Return:
            np.ndarray: Array of elements with concatenated subarrays along each axis.
        """
        for i in range(1, self.SPATIAL_DIM + 1):
            array = np.concatenate(np.split(array, indices_or_sections=2, axis=0), axis=i)
            #remove first dimension
        return np.squeeze(array, axis=0)

    def insert_element_array(self, elements, max_depth=float('Inf'), depth=0, pos=None):
        """ Inserts an array of element values which is converted into a kd-tree.

        Args:
            elements: A numpy array of element values, with the dimensionality of the kd-tree.
            max_depth: The maximum depth of the resulting kd-tree. All nodes at `max_depth` are marked as final.
            depth: The current depth of the kd-tree. Used to recursively define the tree depth.

        Return:
            The current node containing inserted values. The returned node should be the root node of the kd-tree.
        """
        self.depth = depth
        self.resolution = np.array(elements.shape[0])
        self.final = True
        self.pos = np.array(self.SPATIAL_DIM * [0]) if pos is None else pos
        # '1' - all elements are empty
        # '2' - elements are empty and occupied
        # '3' - all elements are occupied
        self.value = 1 if np.max(elements) == 0 else 3 if np.min(elements) > 0 else 2

        # input has a resolution of 1 and cannot be splitt anymore
        if self.resolution <= 1:
            return self

        # splitt only when elements are mixed and we are not at maximum depth
        if self.value == 2 and depth <= max_depth:
            self.final = False

            # split elements into subarrays
            sub_elements = split(elements)

            # compute new positions for future nodes
            # layerwise intertwined_positions
            new_pos = [2 * self.pos + d for d in self.dirs]

            # create child nodes
            self.child_nodes = [
                kdTree().insert_element_array(e, max_depth, depth + 1, p)
                for e, p in zip(sub_elements, new_pos)
            ]

        return self

    def get_element_array(self, depth=float('Inf'), mode='occupancy'):
        """ Converts the kd-tree into an array of elements.

        Args:
            depth: Defines the maximum depth of the children nodes, of which the value will be returned in the array.

        Return:
            A numpy array with the dimensionality of the kd-tree, which hold values defined by `mode`.

        """
        res = self.SPATIAL_DIM * [self.resolution]
        if self.final or self.depth == depth:
            # return empty array if all elements are empty
            if self.value == 1:
                return np.tile(0, res)
            # else return value based on `mode`
            elif mode == 'occupancy':
                return np.tile(1, res)

        return self.concat(np.array([node.get_element_array(depth, mode) for node in self.child_nodes]))

    def insert_token_sequence(self, value, resolution, max_depth=float('Inf')):
        """ Inserts a token sequence which is parsed into a kd-tree.

        Args:
            value: A token sequence representing a spatial object. The values should consist only of '1', '2' and '3'.
                The sequence can be eiter a string or an array of strings or integers.
            resolution: The resolution of the token sequence. This value should be a power of 2.
            max_depth: The maximum depth up to which the token sequence will be parsed.

        Return:
            A node which represents the given token sequence. The returned node should be the root node of the kd-tree.
        """
        # fail-fast: malformed input sequence
        all_tokens_valid = all(str(c) in '123' for c in value)
        if not all_tokens_valid:
            raise ValueError(
                "ERROR: Input sequence consists of invalid tokens. Check token values and array type." +
                f"Valid tokens consist of 1 (white), 2 (mixed) and 3 (black). Sequence: {value}."
            )

        # initialize self
        self.value = 0
        self.depth = 0
        self.pos = np.array(self.SPATIAL_DIM * [0])
        self.resolution = np.array(resolution)
        self.final = False

        # initialize parser
        depth = 1
        final_layer = False
        resolution = resolution // 2

        # initialize first nodes
        open_set = []
        self.child_nodes = [kdTree() for _ in range(2**self.SPATIAL_DIM)]
        open_set.extend(self.child_nodes)
        node_counter = len(open_set)

        # compute new positions for future nodes
        # layerwise intertwined_positions
        pos_set = [2 * self.pos + d for d in self.dirs]

        while len(value) > 0 and depth <= max_depth and len(open_set) > 0:
            # consume first token of sequence
            head = int(value[0])
            value = value[1:] if len(value) > 0 else value
            node_counter -= 1

            # get next node that should be populated
            node = open_set.pop(0)

            # assign values to node
            node.value = head
            node.depth = depth
            node.pos = pos_set.pop(0)
            node.resolution = np.array(resolution)

            # final node:
            # - head is '1' or '3', thus all elements have the same value
            # - the resolution is 1, thus the elements cannot be split anymore
            # - we are in the last depth layer, thus all nodes are final
            node.final = head in (1, 3) or np.array_equal(resolution, [1]) or final_layer
            if not node.final:
                node.child_nodes = [kdTree() for _ in range(2**self.SPATIAL_DIM)]
                open_set.extend(node.child_nodes)

                # TODO: add 'intertwined' position encoding
                # compute new positions for future nodes - center of all pixels
                pos_set.extend([node.pos + node.resolution // 2 * d for d in self.dirs])

            # update depth
            if node_counter <= 0:
                depth += 1
                resolution = np.array(resolution // 2)
                # return if the resolution becomes less than 1 - no visible elements
                if resolution < 1:
                    return self

                #TODO delete?
                node_counter = len(open_set)
                # fail-fast: malformed input sequence
                if len(value) < node_counter:
                    # perform simple sequence repair by appending missing tokens
                    value = np.append(value, [0 for _ in range(node_counter - len(value))])

                if len(value) == node_counter:
                    final_layer = True

        return self

    def get_token_sequence(self, depth=float('Inf'), return_depth=False, return_pos=False):
        """ Returns a linearised sequence representation of the kd-tree.

        Args:
            depth: Defines the maximum depth of the nodes, up to which the tree is parsed.
            return_depth: Selects if the corresponding depth sequence should be returned.
            return_pos: Selects if the corresponding position sequence should be returned.

        Return
            A numpy array consisting of integer values representing the linearised kd-tree. Returns additionally the
            corresponding depth and position sequence if specified in `return_depth` or `return_pos`. The values are
            returned in the following order: (value, depth, position).
        """
        seq_value = []
        seq_depth = []
        seq_pos = []
        open_set = []

        # start with root node
        open_set.extend(self.child_nodes)

        while len(open_set) > 0:
            node = open_set.pop(0)

            # reached sufficient depth - return sequence so far
            if node.depth > depth:
                break

            seq_value += [node.value]
            seq_depth += [node.depth]
            seq_pos += [node.pos]

            if not node.final:
                open_set += node.child_nodes

        seq_value = np.asarray(seq_value)
        seq_depth = np.asarray(seq_depth)
        seq_pos = np.asarray(seq_pos)

        # output format depends in flags 'return_depth' and 'return_pos'
        output = [seq_value]
        if return_depth:
            output += [seq_depth]
        if return_pos:
            output += [seq_pos]
        return output


## Data Set
To create a pytorch data set, we can inheret from the torch.utils.data.Dataset class. This class requires us to implement the __len__ and __getitem__ methods. The __len__ method returns the number of samples in the data set and the __getitem__ method returns a sample from the data set at a given index.
In our case, before returning an item from the data set, we first have to create an Octree from the voxel data and linearize it.

In [29]:
class ShapeNet(Dataset):
    """
    Custom dataset class for handling ShapeNet voxel grid data.
    
    Args:
        shape_dir (str): Directory containing the voxel grid files.
    
    Attributes:
        shape_dir (str): Directory containing the voxel grid files.
        file_list (list): List of filenames in the shape directory.
        path (list): List of paths to individual shape files.
    """

    def __init__(self, shape_dir):
        """
        Initialize the ShapeNet dataset.

        Args:
            shape_dir (str): Directory containing the voxel grid files.
            transform (callable, optional): A function/transform to apply to each voxel grid.
            target_transform (callable, optional): A function/transform to apply to each target.
        """
        self.shape_dir = shape_dir
        self.file_list = os.listdir(shape_dir)
        self.path = [os.path.join(shape_dir, file) for file in self.file_list]
        
    def __len__(self):
        """
        Get the total number of voxel grids in the dataset.

        Returns:
            int: Total number of voxel grids.
        """
        return len(self.path)

    def __getitem__(self, idx):
        """
        Get the voxel grid data and its corresponding sequence representation.

        Args:
            idx (int): Index of the desired data sample.

        Returns:
            np.ndarray: Sequence representation of the voxel grid.
        """
        # load the voxel grid
        voxels = np.load(self.path[idx])
        # linearize the grid with fast function to speed up training
        seq = linearize(voxels)
        return seq


In this section, we will implement a collate function that takes a batch of sequences with varying lengths, pads them to equal length, and transforms them into three tensors: one for values, one for depth, and one for position. This function is crucial for preprocessing our data before feeding it into a deep learning model.


![Sequence](../images/sequence.png)

In [30]:
class EncoderOnlyCollate():
    """ Creates a collate module, which pads batched sequences to equal length with the padding token '0'. """

    def __call__(self, batch):
        """ Pads and packs a list of samples for the 'encoder_only' architecture. """
        # pad batched sequences with '0' to same length
        seq = pad_batch(batch)
        # update : müsste doch stimmen? das eine ist unser input gebatch das andere die labels nicht im batch oder seq[-1]
        return seq, seq[0]

def to_sequences(batch):
    """ Transform a list on numpy arrays into sequences of pytorch tensors. """
    batch = [(torch.tensor(v), torch.tensor(d), torch.tensor(p)) for v, d, p in batch]

    # unpack batched sequences
    return zip(*batch)


def pad_batch(batch):
    """ Unpack batch and pad each sequence to a tensor of equal length. """
    val, dep, pos = to_sequences(batch)

    # pad each sequence
    val_pad = pad_sequence(val, batch_first=True, padding_value=PADDING_VALUE)
    dep_pad = pad_sequence(dep, batch_first=True, padding_value=PADDING_VALUE)
    pos_pad = pad_sequence(pos, batch_first=True, padding_value=PADDING_VALUE)

    return val_pad, dep_pad, pos_pad

# Octree Transformer Architecture


![Overview](../images/overview.png)

In this section, we will delve into the architecture of the Octree Transformer model, which is composed of three key components that work collaboratively to process sequences with an octree structure. Let's explore each component in detail.

## 1. Embedding Layer

The embedding layer is the starting point of our Octree Transformer. Here, we embed both the input values and positions into a vector space. However, due to the exponential growth of linearized octrees, we need to employ compression techniques to manage computational complexity effectively.

## 2. Transformer Stack

The heart of our model lies within the Transformer Stack. This stack is comprised of multiple transformer blocks, each playing a crucial role in generating representations for our sequences. Each transformer block is made up of two pivotal components: the Multi-Head Attention layer and the Feed Forward layer. For further information on the transformer architecture we refer to the [Attention is All You Need](https://arxiv.org/abs/1706.03762) paper.

## 3. Generative Head

The final phase of the Octree Transformer is the Generative Head, which is responsible for generating the output sequence. This head consists of two crucial components: the Linear layer and the Softmax layer.

### Linear Layer
The Linear layer functions as a bridge, mapping the output of the transformer stack to the vocabulary size. It enables the model to project the learned representations into a format compatible with the vocabulary.

### Softmax Layer
The Softmax layer takes the transformed output and generates a probability distribution over the vocabulary. By assigning probabilities to different tokens, the Softmax layer facilitates the generation of the final sequence.

## Addressing Compression and Upsampling

The Octree Transformer employs compression techniques to handle the expansion of linearized octrees, ensuring computational feasibility. Additionally, to counteract the effects of compression, certain layers of the octree require the Generative Head to generate multiple tokens. This challenge is tackled through upsampling methods like deconvolution, allowing the model to generate the necessary tokens for a given layer.

<span style="font-style: italic;">Now, let's get started!</span>


## Embedding Layer

When working with tokens, we have to use a prober embedding for the tokens. Furthermore, we will use learned positional encodings.
The learned positional encoding represent the position of the token in the Octree. They are learned and can thereby be trained as part of the model. We create the positional encoding by **adding** the encoded value of the x,y,z coordinates.

In [31]:
class PositionalEncoding(nn.Module):
    """
    Positional Encoding module for adding positional information to input sequences.
    
    Args:
        resolution (int): The resolution of the 3D grid.
        embed_dim (int): The dimensionality of the embedding.
    
    Attributes:
        x_encoding (nn.Embedding): Embedding layer for x-coordinate positional encoding.
        y_encoding (nn.Embedding): Embedding layer for y-coordinate positional encoding.
        z_encoding (nn.Embedding): Embedding layer for z-coordinate positional encoding.
    """

    def __init__(self, resolution, embed_dim):
        """
        Initialize the PositionalEncoding module.
        
        Args:
            resolution (int): The resolution of the 3D grid.
            embed_dim (int): The dimensionality of the embedding.
        """
        super(PositionalEncoding, self).__init__()
        # Using *2 because the sequence length can grow to 2*resolution
        self.x_encoding = nn.Embedding(2 * resolution, embed_dim, padding_idx=PADDING_VALUE)
        self.y_encoding = nn.Embedding(2 * resolution, embed_dim, padding_idx=PADDING_VALUE)
        self.z_encoding = nn.Embedding(2 * resolution, embed_dim, padding_idx=PADDING_VALUE)

    def forward(self, position):
        """
        Apply positional encoding to the input position tensor.
        
        Args:
            position (torch.Tensor): Input tensor containing position information.
                Shape: (batch_size, sequence_length, 3)
        
        Returns:
            torch.Tensor: The positional embeddings for each position in the input tensor.
                Shape: (batch_size, sequence_length, embed_dim)
        """
        x = self.x_encoding(position[:, :, 0])  # Get x-coordinate embeddings
        y = self.y_encoding(position[:, :, 1])  # Get y-coordinate embeddings
        z = self.z_encoding(position[:, :, 2])  # Get z-coordinate embeddings
        return x + y + z  # Combine embeddings for positional encoding


Now we can create the whole embedding layer for the input using the positional encoding layer and another **embedding for the values** of the nodes.
Finally, we will add both outputs and return the result.

In [32]:
class EmbeddingLayer(nn.Module):
    """
    Embedding Layer combining value and positional embeddings.

    Args:
        num_vocab (int): Number of vocabulary items.
        embed_dim (int): Dimensionality of the embeddings.
        resolution (int): Resolution of the 3D grid.

    Attributes:
        positional_encoding (PositionalEncoding): Positional encoding module.
        value_embedding (nn.Embedding): Embedding layer for value embeddings.
        mask (torch.Tensor): Mask to handle padding values.
    """

    def __init__(self, num_vocab, embed_dim, resolution):
        """
        Initialize the EmbeddingLayer.

        Args:
            num_vocab (int): Number of vocabulary items.
            embed_dim (int): Dimensionality of the embeddings.
            resolution (int): Resolution of the 3D grid.
        """
        super(EmbeddingLayer, self).__init__()
        self.positional_encoding = PositionalEncoding(resolution, embed_dim)
        self.value_embedding = nn.Embedding(num_vocab, embed_dim, padding_idx=PADDING_VALUE)

    def forward(self, value, depth, position):
        """
        Apply the embedding layer to the input value, depth, and position tensors.

        Args:
            value (torch.Tensor): Input tensor containing value indices.
                Shape: (batch_size, sequence_length)
            depth (torch.Tensor): Input tensor containing depth values.
                Shape: (batch_size, sequence_length)
            position (torch.Tensor): Input tensor containing position information.
                Shape: (batch_size, sequence_length, 3)

        Returns:
            torch.Tensor: Combined embeddings of positional and value information.
                Shape: (batch_size, sequence_length, embed_dim)
        """
        self.mask = (value == PADDING_VALUE)  # Create mask for padding values
        pos = self.positional_encoding(position)  # Get positional embeddings
        val = self.value_embedding(value)  # Get value embeddings
        return pos + val  # Combine positional and value embeddings

    def get_masks(self):
        """
        Get the mask tensor used for handling padding values.

        Returns:
            torch.Tensor: Mask tensor.
        """
        return self.mask


*There is still one problem left...*


**The Challenge**: As the octree structure unfolds, the size of the linearized representation grows exponentially. This escalation in size can impede the model's efficiency, rendering it computationally burdensome, particularly in the lower levels of the octree.

**The Solution**: Our remedy involves the application of convolutions across the embedded sequence. Convolutional operations allow us to effectively compress the information while retaining essential features and patterns. By concentrating on the lower octree levels, where the exponential growth is most pronounced, we strike a balance between efficient computation and information retention.

**Convolutional** Compression: The essence of this technique lies in leveraging convolutional layers to extract essential features and patterns from the sequence of embedded tokens. By doing so, we attain a compressed representation that captures crucial information while alleviating the computational intensity associated with extensive linearized structures.


![Compression](../images/compression.png)

In [33]:
class ConvolutionEmbedding(nn.Module):
    """
    Convolutional Embedding module that combines positional and value embeddings
    and applies 1D convolution on the embedded sequence.

    Args:
        num_vocab (int): Number of vocabulary items.
        embed_dim (int): Dimensionality of the embeddings.
        resolution (int): Resolution of the 3D grid.
        conv_size (int): Kernel size for convolution.

    Attributes:
        num_vocab (int): Number of vocabulary items.
        embed_dim (int): Dimensionality of the embeddings.
        resolution (int): Resolution of the 3D grid.
        chunk_size (int): Size of chunks for convolution.
        embedding (EmbeddingLayer): Embedding layer for combining embeddings.
        conv (nn.Conv1d): 1D convolutional layer.
        mask (torch.Tensor): Mask to handle padding values.
    """

    def __init__(self, num_vocab, embed_dim, resolution, conv_size):
        """
        Initialize the ConvolutionEmbedding module.

        Args:
            num_vocab (int): Number of vocabulary items.
            embed_dim (int): Dimensionality of the embeddings.
            resolution (int): Resolution of the 3D grid.
            conv_size (int): Kernel size for convolution.
        """
        super(ConvolutionEmbedding, self).__init__()
        self.num_vocab = num_vocab
        self.embed_dim = embed_dim
        self.resolution = resolution
        self.chunk_size = conv_size
        self.embedding = EmbeddingLayer(num_vocab, embed_dim, resolution)
        self.conv = nn.Conv1d(embed_dim, embed_dim, kernel_size=conv_size, stride=conv_size)

    def forward(self, value, depth, position):
        """
        Apply ConvolutionEmbedding to input value, depth, and position tensors.

        Args:
            value (torch.Tensor): Input tensor containing value indices.
                Shape: (batch_size, sequence_length)
            depth (torch.Tensor): Input tensor containing depth values.
                Shape: (batch_size, sequence_length)
            position (torch.Tensor): Input tensor containing position information.
                Shape: (batch_size, sequence_length, 3)

        Returns:
            torch.Tensor: Resulting tensor after applying convolution.
                Shape: (batch_size, new_sequence_length, embed_dim)
        """
        # create a mask in the same size of the resulting embedding (chunk_size steps)
        self.mask = self.padding_mask(value[:, ::self.chunk_size])
        # create embedding for sequence
        embedded_seq = self.embedding(value, depth, position)
        # apply convolution
        return self.conv(embedded_seq.permute(0, 2, 1)).permute(0, 2, 1)

    def padding_mask(self, value):
        """
        Create a mask to handle padding values.

        Args:
            value (torch.Tensor): Input tensor containing value indices.
                Shape: (batch_size, sequence_length)

        Returns:
            torch.Tensor: Mask tensor for padding values.
        """
        return value == PADDING_VALUE


As we reach a critical point in our deep learning journey, we're focusing on creating embeddings that cover our entire sequence. This move is a significant step towards effectively utilizing octree structures. Our strategy involves using convolutional techniques tailored to each layer of the octree, which is a key aspect of our approach.

**Essential Strategy**: At the core of our method is the idea of layer-specific embeddings. We recognize that different layers of the octree need different kinds of information. To address this, we're using convolutional embeddings customized for each layer.

**Varied Information Gathering**: What makes our approach unique is its ability to gather information at different levels. This is achieved by applying different sizes of convolutional filters to each layer. These filters help us capture information while adapting to the specific details of each layer.

**Step-by-Step Layer Embeddings**: Our process for creating these embeddings happens step by step, layer by layer. For each layer, we create a specific convolutional embedding that caters to that layer's specific information needs. This method ensures that every part of the octree sequence is well represented in the embeddings.

In [34]:
class SequenceEmbedding(nn.Module):
    """
    Sequence Embedding module that combines different ConvolutionEmbedding layers.

    Args:
        num_vocab (int): Number of vocabulary items.
        embed_dim (int): Dimensionality of the embeddings.
        resolution (int): Resolution of the 3D grid.

    Attributes:
        embeddings (list): List of ConvolutionEmbedding instances for different layers.
        mask (torch.Tensor): Padding mask for the sequence.
    """

    def __init__(self, num_vocab, embed_dim, resolution):
        """
        Initialize the SequenceEmbedding module.

        Args:
            num_vocab (int): Number of vocabulary items.
            embed_dim (int): Dimensionality of the embeddings.
            resolution (int): Resolution of the 3D grid.
        """
        super(SequenceEmbedding, self).__init__()
        self.embeddings = [
            ConvolutionEmbedding(num_vocab, embed_dim, resolution, conv_size=1),
            ConvolutionEmbedding(num_vocab, embed_dim, resolution, conv_size=1),
            ConvolutionEmbedding(num_vocab, embed_dim, resolution, conv_size=1),
            ConvolutionEmbedding(num_vocab, embed_dim, resolution, conv_size=4),
            ConvolutionEmbedding(num_vocab, embed_dim, resolution, conv_size=8),
            ConvolutionEmbedding(num_vocab, embed_dim, resolution, conv_size=8),
        ]

    def forward(self, value, depth, position):
        """
        Apply the SequenceEmbedding to input value, depth, and position tensors.

        Args:
            value (torch.Tensor): Input tensor containing value indices.
                Shape: (batch_size, sequence_length)
            depth (torch.Tensor): Input tensor containing depth values.
                Shape: (batch_size, sequence_length)
            position (torch.Tensor): Input tensor containing position information.
                Shape: (batch_size, sequence_length, 3)

        Returns:
            torch.Tensor: Padded sequence embedding tensor.
                Shape: (batch_size, padded_sequence_length, embed_dim)
        """

        padding_mask = []
        # extract value, depth, and position sequence of the current sample
        val, dep, pos = value[0], depth[0], position[0]
        b_emb = torch.tensor([])

        # embed layerwise
        for layer_idx, embedding in enumerate(self.embeddings):
            layer_depth = layer_idx + 1
            # extract layer sequence
            val_seq = val[dep == layer_depth]
            dep_seq = dep[dep == layer_depth]
            pos_seq = pos[dep == layer_depth]

            if val_seq.shape[0] == 0:
                break

            # compute layer embedding
            layer_emb = embedding(
                val_seq.unsqueeze(0),
                dep_seq.unsqueeze(0),
                pos_seq.unsqueeze(0),
            )[0]
            # append layer embedding to embedding
            b_emb = torch.cat([b_emb, layer_emb])

        sequence = b_emb.unsqueeze(0)
        # create padding mask
        padding_mask = [torch.zeros(b_emb.shape[0], dtype=torch.bool)]
        # pad the sequence with 1s so the transformer will not attend to these tokens
        self.mask = pad_sequence(padding_mask, batch_first=True, padding_value=1)
        # pad embedding sequence
        return pad_sequence(sequence, batch_first=True, padding_value=0.0)

    def get_mask(self):
        """
        Get the padding mask for the sequence.

        Returns:
            torch.Tensor: Padding mask, where padding tokens '0' of the value sequence are masked out.
        """
        return self.mask


In [35]:
def padding_mask(input_sequence):
    """ Create a padding mask for the given input.

        Always assumens '0' as a padding value. `input_sequence` has the shape (N, S).
    """
    return torch.zeros_like(input_sequence).masked_fill(input_sequence == 0, 1).bool()


## Generative Head

For the time being, we'll set aside the transformer aspect and direct our attention to the generative head. The concept here shares **similarity with the embedding process**, albeit in the opposite direction. We initiate the process with a linear head, incorporating a GELU activation function followed by a linear layer. To augment the expressive capabilities of this head, we introduce positional encoding. This integration of positional encoding serves to enhance the model's ability to understand and generate contextually meaningful outputs.

In [36]:
class LinearHead(nn.Module):
    """
    Linear Head module for generating output tokens.

    Args:
        num_vocab (int): Number of vocabulary items.
        embed_dim (int): Dimensionality of the embeddings.
        resolution (int): Resolution of the 3D grid.

    Attributes:
        linear (nn.Linear): Linear layer for token generation.
        pos_enc (PositionalEncoding): Positional encoding module.
        activation (nn.GELU): Activation function for intermediate processing.
    """

    def __init__(self, num_vocab, embed_dim, resolution):
        """
        Initialize the LinearHead module.

        Args:
            num_vocab (int): Number of vocabulary items.
            embed_dim (int): Dimensionality of the embeddings.
            resolution (int): Resolution of the 3D grid.
        """
        super().__init__()
        self.linear = nn.Linear(embed_dim, num_vocab)
        self.pos_enc = PositionalEncoding(resolution, num_vocab)
        self.activation = nn.GELU()

    def forward(self, x, value, depth, pos):
        """
        Apply LinearHead to input data.

        Args:
            x (torch.Tensor): Input tensor.
            value (torch.Tensor): Input tensor containing value indices.
                Shape: (batch_size, sequence_length)
            depth (torch.Tensor): Input tensor containing depth values.
                Shape: (batch_size, sequence_length)
            pos (torch.Tensor): Input tensor containing position information.
                Shape: (batch_size, sequence_length, 3)

        Returns:
            torch.Tensor: Generated token tensor.
        """
        x = self.activation(x)  # Apply activation function
        x = self.linear(x)  # Apply linear transformation
        pos_enc = self.pos_enc(pos)  # Get positional encoding
        x = x + pos_enc  # Add positional encoding
        return x


As previously explained, we need to **reverse** the compression applied in the sequence embedding. To achieve this, we employ a convolutional head that facilitates **deconvolutions** for upsampling the input to a specified conv_size. In order to enhance this upsampling procedure with extra contextual information, we leverage the sequence embedding by passing it through a convolutional layer. This approach integrates the benefits of both upsampling and embedding, contributing to the overall effectiveness of the model.

In [37]:
class ConvolutionalHead(nn.Module):
    """
    Convolutional Head module for generating output tokens using convolutional operations.

    Args:
        num_vocab (int): Number of vocabulary items.
        embed_dim (int): Dimensionality of the embeddings.
        head_dim (int): Dimensionality of the intermediate representation.
        resolution (int): Resolution of the 3D grid.
        conv_size (int): Kernel size for convolution.

    Attributes:
        activation (nn.GELU): Activation function for intermediate processing.
        deconvolution (nn.ConvTranspose1d): Transposed convolution layer.
        convolution (BlockConvolution): BlockConvolution layer for convolution.
        embed (EmbeddingLayer): Embedding layer for combining embeddings.
        linear (nn.Linear): Linear layer for token generation.
    """

    def __init__(self, num_vocab, embed_dim, head_dim, resolution, conv_size):
        """
        Initialize the ConvolutionalHead module.

        Args:
            num_vocab (int): Number of vocabulary items.
            embed_dim (int): Dimensionality of the embeddings.
            head_dim (int): Dimensionality of the intermediate representation.
            resolution (int): Resolution of the 3D grid.
            conv_size (int): Kernel size for convolution.
        """
        super(ConvolutionalHead, self).__init__()
        self.activation = nn.GELU()
        self.deconvolution = nn.ConvTranspose1d(embed_dim, head_dim, conv_size, stride=conv_size)
        self.convolution = BlockConvolution(head_dim, head_dim, conv_size)
        self.embed = EmbeddingLayer(num_vocab, head_dim, resolution)
        self.linear = nn.Linear(head_dim, num_vocab)

    def forward(self, x, value, depth, pos):
        """
        Apply ConvolutionalHead to input data.

        Args:
            x (torch.Tensor): Input tensor.
            value (torch.Tensor): Input tensor containing value indices.
                Shape: (batch_size, sequence_length)
            depth (torch.Tensor): Input tensor containing depth values.
                Shape: (batch_size, sequence_length)
            pos (torch.Tensor): Input tensor containing position information.
                Shape: (batch_size, sequence_length, 3)

        Returns:
            torch.Tensor: Generated token tensor.
        """
        x = self.activation(x)  # Apply activation function
        x = self.deconvolution(x.transpose(1, 2)).transpose(1, 2)  # Apply deconvolution, transpoe dimesnions to convolute over the featrues
        embed = self.embed(value, depth, pos)  # Get embedded representation
        embed = self.activation(x)  # Apply activation function
        embed = self.convolution(embed[:, :x.shape[1]])  # Apply convolution (left bound should probably not be necessary)
        x = x + embed  # Add convolution output to the deconvolution output
        x = self.activation(x)  # Apply activation function
        return self.linear(x)  # Generate output tokens using linear layer


In Block Convolution, we perform a 1D convolution operation on the input and advance features in a block-wise manner, unraveling insights block by block.

In [38]:
class BlockConvolution(nn.Module):
    def __init__(self, source_dim, target_dim, block_size):
        """ Performs masked blockwise convolution on an input sequence.
            The mask is always an upper right triangle matrix with zeros on the diagonal.

        Args:
            source_dim: Defines the embedding dimension of the input sequence.
            target_dim: Defines the embedding dimension of the output sequence.
            block_size: Defines the size of the block over which we convolute.
        """
        super(BlockConvolution, self).__init__()

        self.block_size = block_size
        self.convolution = nn.Conv1d(source_dim, target_dim, (1,), bias=False)
        sigma = math.sqrt(1. / (block_size * source_dim))
        self.bias = nn.Parameter(torch.empty(block_size))
        nn.init.uniform_(self.bias, -sigma, sigma)

    def forward(self, seq_vector):
        """ Convolute tokens to reduce sequence length

        Args:
            seq_vector: Sequence vector with elements of the shape [N, S, E].

        Return:
            Sequence vector with the same length and target embedding dimension [N, S, E']
        """
        #basic convolution with some transpose for dimension fit
        features = self.convolution(seq_vector.transpose(1, 2)).transpose(1, 2)
        #vecor in the from of sequence vector filled wtih 0
        out = torch.zeros_like(seq_vector)
        for i in range(self.block_size):
            #loop over block element
            for j in range(i):
                #add to one element the feature vector of all previous elements but not the elemnts infron of it
                out[:, i::self.block_size] += features[:, j::self.block_size]
            #add bias i to every following block elemnt on pos i
            out[:, i::self.block_size] += self.bias[i]

        return out

To achieve the final generative output, we employ a linear head for the first three layers, which **don't involve compression**. Following that, we switch to the convolutional head for the next three layers to **reverse the compression**. This results in **autoregressive** and **layer-wise** token generation.

In [39]:
class GenerativeHead(nn.Module):
    """
    Generative Head module for transforming the output of the transformer into target value logits.

    Args:
        num_vocab (int): Number of vocabulary items.
        embed_dim (int): Dimensionality of the embeddings.
        head_dim (int): Dimensionality of the intermediate representation.
        resolution (int): Resolution of the 3D grid.

    Attributes:
        num_vocab (int): Number of vocabulary items.
        embed_dim (int): Dimensionality of the embeddings.
        head_dim (int): Dimensionality of the intermediate representation.
        resolution (int): Resolution of the 3D grid.
        fc (nn.Linear): Linear layer for transformation.
        heads (list): List of different heads for transformation.
        reduction_factor (dict): Dictionary containing reduction factors for each layer depth.
    """

    def __init__(self, num_vocab, embed_dim, head_dim, resolution):
        """
        Initialize the GenerativeHead module.

        Args:
            num_vocab (int): Number of vocabulary items.
            embed_dim (int): Dimensionality of the embeddings.
            head_dim (int): Dimensionality of the intermediate representation.
            resolution (int): Resolution of the 3D grid.
        """
        super().__init__()
        self.num_vocab = num_vocab
        self.embed_dim = embed_dim
        self.head_dim = head_dim
        self.resolution = resolution
        self.fc = nn.Linear(embed_dim, num_vocab)
        self.heads = [
            LinearHead(num_vocab, embed_dim, resolution),
            LinearHead(num_vocab, embed_dim, resolution),
            LinearHead(num_vocab, embed_dim, resolution),
            ConvolutionalHead(num_vocab, embed_dim, head_dim, resolution, conv_size=4),
            ConvolutionalHead(num_vocab, embed_dim, head_dim, resolution, conv_size=8),
            ConvolutionalHead(num_vocab, embed_dim, head_dim, resolution, conv_size=8)
        ]
        self.reduction_factor = {
            1: 1,
            2: 1,
            3: 1,
            4: 4,
            5: 8,
            6: 8
        }

    def forward(self, x, value, depth, position):
        """
        Transform the output of the transformer into target value logits.

        Args:
            x (torch.Tensor): Output of the transformer, the latent vector [N, T, E].
            value (torch.Tensor): Target value token sequence [N, T].
            depth (torch.Tensor): Target depth token sequence [N, T].
            position (torch.Tensor): Target position token sequence [N, T, A].

        Returns:
            torch.Tensor: Logits of target value sequence.
        """
        out = []

        # Process each sample individually (N=1), also squeeze redundant
        for latent_vec, val, dep, pos in zip(x, value, depth, position):

            logits = torch.tensor([])
            vector_idx = 0

            # Compute logits layerwise
            for layer_idx, head in enumerate(self.heads):
                layer_depth = layer_idx + 1
                # Get value, depth, and position sequence of current layer
                layer_val = val[dep == layer_depth]
                layer_dep = dep[dep == layer_depth]
                layer_pos = pos[dep == layer_depth]
                # stop if no more tokens in current layer
                if layer_pos.shape[0] == 0:
                    break
                # Compute number of vectors in latent vector of current layer
                # because we might have X tokens but only X/red_factor many feature vectors due to reduction
                num_vectors = torch.sum(dep == layer_depth) // self.reduction_factor[layer_depth]

                # Filter latent vector of current layer
                layer_vec = latent_vec[vector_idx:vector_idx + num_vectors]
                # Compute layer logits
                layer_logits = head(
                    layer_vec.unsqueeze(0),
                    layer_val.unsqueeze(0),
                    layer_dep.unsqueeze(0),
                    layer_pos.unsqueeze(0),
                )[0]
                logits = torch.cat([logits, layer_logits])

                # Discard processed tokens
                vector_idx += num_vectors
            out += [logits]

        # Pad embedding sequence
        return pad_sequence(out, batch_first=True, padding_value=0.0)


## Transformer Stack
![Transfromer](https://tfwiki.net/mediawiki/images2/thumb/3/37/Optimusg1.jpg/350px-Optimusg1.jpg)

In this stage, we construct the transformer stack using the PyTorch implementation. We focus solely on the encoder component, as is commonly done. It's important to recognize that even though our model uses the encoder, it behaves like a decoder by autoregressively generating tokens. The transfromer enables us to generate features out of the sequence embeddings.

In [40]:
class Transformer(nn.Module):
    """
    Transformer module for sequence processing using self-attention mechanism.

    Args:
        embed_dim (int): Dimensionality of the embeddings.
        num_heads (int): Number of attention heads.
        dropout (float): Dropout rate for attention layers.
        num_layers (int): Number of transformer layers.

    Attributes:
        sos (nn.Parameter): Start of sequence token.
        transformer (nn.TransformerEncoder): Transformer encoder stack.
    """

    def __init__(self, embed_dim, num_heads, dropout, num_layers):
        """
        Initialize the Transformer module.

        Args:
            embed_dim (int): Dimensionality of the embeddings.
            num_heads (int): Number of attention heads.
            dropout (float): Dropout rate for attention layers.
            num_layers (int): Number of transformer layers.
        """
        super().__init__()
        self.sos = nn.Parameter(torch.zeros(embed_dim))
        nn.init.normal_(self.sos)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=4 * embed_dim,
            dropout=dropout,
            activation='gelu',
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_layers,
            norm=nn.LayerNorm(embed_dim))

    def forward(self, input_seq, padding_mask):
        """
        Apply the Transformer to input sequence.

        Args:
            input_seq (torch.Tensor): Input sequence tensor.
                Shape: (batch_size, sequence_length, embed_dim)
            padding_mask (torch.Tensor): Padding mask for the input sequence.
                Shape: (batch_size, sequence_length)

        Returns:
            torch.Tensor: Output sequence tensor after processing by the Transformer.
        """
        # Create a tensor of shape batch_size x 1 x embed_dim filled with the SOS token
        sos = self.sos.unsqueeze(0).unsqueeze(0).repeat(input_seq.shape[0], 1, 1)
        batch_size, seq_len, _ = input_seq.shape
        # Concatenate the SOS token with the input sequence (except the last token)
        input_seq = torch.cat([sos, input_seq[:, :-1]], dim=1)
        # Generate mask for the attention mechanism
        mask = get_mask(seq_len)
        # Process input sequence through the Transformer stack to get the output sequence
        output_seq = self.transformer(
            src=input_seq,
            mask=mask,  # [L, L]
            src_key_padding_mask=padding_mask,  # [N, L]
        )
        return output_seq

def get_mask(seq_len):
    """
    Creates a diagonal mask to prevent self-attention from looking ahead.

    Args:
        seq_len (int): Length of the sequence.

    Returns:
        torch.Tensor: Diagonal mask tensor.
    """
    # create a matrix of dimension seq_len x seq_len filled with -Inf
    attn_mask = torch.full((seq_len, seq_len), -float("Inf"))
    # just keep the upper triangular part of the matrix and use this as mask
    return torch.triu(attn_mask, diagonal=1)



## OctreeTransformer in PyTorch Lightning
<div>
<img src="https://cdn-fgnbn.nitrocdn.com/CExiFXHDXAeTGrlIkvRSSLnZISOqDumi/assets/images/optimized/rev-bd09e47/wp-content/uploads/2022/08/wx_mining-lightning-strike_features-benefits_1200x630.png" width="500"/>
</div>

With our three crucial building blocks in place, it's time to connect the dots and form the **OctreeTransformer**. To keep things neat and simple, we've fashioned the OctreeTransformer as a **PyTorch Lightning** module. Thereby, we don't have to worry about training loops.

In [41]:
class OctreeTransformer_pl(pl.LightningModule):
    """
    Lightning module implementing an OctreeTransformer for sequence-to-sequence tasks.

    Args:
        embed_dim (int): Dimensionality of the embeddings.
        num_heads (int): Number of attention heads.
        dropout (float): Dropout rate for attention layers.
        num_layers (int): Number of transformer layers.
        num_vocab (int): Number of vocabulary items, including the padding index.
        resolution (int): Resolution of the 3D grid.

    Attributes:
        embedding (SequenceEmbedding): Sequence embedding module.
        transformer (Transformer): Transformer module.
        head (GenerativeHead): Generative head module.
    """

    def __init__(self, embed_dim, num_heads, dropout, num_layers, num_vocab, resolution):
        """
        Initialize the OctreeTransformer_pl module.

        Args:
            embed_dim (int): Dimensionality of the embeddings.
            num_heads (int): Number of attention heads.
            dropout (float): Dropout rate for attention layers.
            num_layers (int): Number of transformer layers.
            num_vocab (int): Number of vocabulary items, including the padding index.
            resolution (int): Resolution of the 3D grid.
        """
        super().__init__()
        print("Please note that num_vocab should include the padding index")
        self.embedding = SequenceEmbedding(num_vocab, embed_dim, resolution)
        self.transformer = Transformer(embed_dim, num_heads, dropout, num_layers)
        self.head = GenerativeHead(num_vocab, embed_dim, head_dim, resolution)
        self.save_hyperparameters()

    def forward(self, sequence):
        """
        Forward pass of the OctreeTransformer_pl module.

        Args:
            sequence (tuple): Tuple containing (value, depth, position) sequences.

        Returns:
            torch.Tensor: Output tensor after processing through the OctreeTransformer.
        """
        value, depth, position = sequence
        # create embedding
        embeddings = self.embedding(value, depth, position)
        # pass through transformer, with respective mask calculated from embedding
        encoder_output = self.transformer(embeddings, self.embedding.get_mask())
        # pass through head
        x = self.head(encoder_output, value, depth, position)
        return x

    def step(self, batch, batch_idx):
        """
        Step function for training and evaluation steps.

        Args:
            batch (tuple): Tuple containing input sequence and target tensor.
            batch_idx (int): Index of the batch.

        Returns:
            torch.Tensor: Calculated loss value.
        """
        sequence, target = batch
        # get logits
        output = self(sequence)
        # calculate loss
        loss = self.calculate_loss(output, target)
        return loss

    def training_step(self, batch, batch_idx):
        # training step for pl module
        loss = self.step(batch, batch_idx)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        # validation step for pl module
        loss = self.step(batch, batch_idx)
        self.log('val_loss', loss)

    def test_step(self, batch, batch_idx):
        # test step for pl module
        loss = self.step(batch, batch_idx)
        self.log('test_loss', loss)
    
    def configure_optimizers(self):
        # conifgure optimizer and lr scheduler
        # AdamW optimizer 
        optimizer = AdamW(self.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01)
        # Cosine annealing warm restarts scheduler
        lr_scheduler = CosineAnnealingWarmRestarts(optimizer, 200)
        return [optimizer], [lr_scheduler]

    def calculate_loss(self, output, target):
        """
        Calculate the loss between the output and target sequences.

        Args:
            output (torch.Tensor): Output tensor from the model.
            target (torch.Tensor): Target tensor containing true values.

        Returns:
            torch.Tensor: Calculated loss value.
        """
        # cross entropy loss
        loss_function = nn.CrossEntropyLoss(ignore_index=PADDING_VALUE)
        # Change shape to batch size x class x sequence length for cross entropy loss
        output = output.permute(0, 2, 1)
        # calculate loss
        return loss_function(output, target)


# Train the Model
Now everything is set up...

*Let's start the training*

In [90]:
tb_logger = pl_loggers.TensorBoardLogger('../logs/')

resolution = 64
embed_dim = 128
num_heads = 2
dropout = 0.1
num_layers = 2
#print all hyperparameters
print(f"Resolution: {resolution}")
print(f"Embedding Dimension: {embed_dim}")
print(f"Number of Heads: {num_heads}")
print(f"Dropout: {dropout}")
print(f"Number of Layers: {num_layers}")

model = OctreeTransformer_pl(embed_dim, num_heads, dropout, num_layers, NUM_VOCAB, resolution)

# Load the dataset
dataset = ShapeNet("../data")
# Create data loaders
train_loader = DataLoader(dataset, collate_fn=EncoderOnlyCollate(), batch_size=1, shuffle=True)

# Initialize Trainer
trainer = pl.Trainer(max_epochs=10, log_every_n_steps=1, logger=tb_logger)
trainer.fit(model, train_loader)

Resolution: 64
Embedding Dimension: 128
Number of Heads: 2
Dropout: 0.1
Number of Layers: 2
Please note that num_vocab should include the padding index


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name        | Type              | Params
--------------------------------------------------
0 | embedding   | SequenceEmbedding | 0     
1 | transformer | Transformer       | 396 K 
2 | head        | GenerativeHead    | 516   
--------------------------------------------------
397 K     Trainable params
0         Non-trainable params
397 K     Total params
1.590     Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.


# Generate Shapes

![shapes](../images/samples.png)


To generate shapes, we have to create a token generator. This generator samples tokens in an **autoregressive** way with respect to the kernel size which was used for compression. Thereby, it samples **kernel size** many tokens each time step.

In [86]:
class Generator:
    def __init__(self, model, num_tokens=1, **_):
        """ Create token generator instance which samples 'num_tokens' in one pass.

        Args:
            model: OctreeTransformer model instance.
            num_tokens: Defines the number of sampled tokens in each step.
        """
        self.model = model
        self.kernel_size = num_tokens

    def __call__(self, val, dep, pos, temperature=1.0):
        """ Sample autoregressive current value token sequence and return updated value sequence.

        Args:
            val: Value token sequence of current layer.
            dep: Depth token sequence of current layer.
            pos: Position token sequence of current layer.
            temperature: Defines the randomness of the samples.

        Return:
            Sampled token sequence with values of the current layer.
        """
        # compute indices
        token_idx = 0
        sampled_idx = len(torch.cat(val[:-1])) if len(val) > 1 else 0

        # sample tokens autoregressive
        for _ in trange(len(val[-1]) // self.kernel_size, leave=False, desc="Tokens"):
            for block_idx in range(self.kernel_size):
                # concat layers and slice sequence for speed_up
                seq = (
                    torch.cat(val)[:sampled_idx + token_idx + self.kernel_size].unsqueeze(0),
                    torch.cat(dep)[:sampled_idx + token_idx + self.kernel_size].unsqueeze(0),
                    torch.cat(pos)[:sampled_idx + token_idx + self.kernel_size].unsqueeze(0),
                )

                logits = self.model(seq)[0]

                # retrieve only logits for for current index
                sampled_token_logits = logits[sampled_idx + token_idx + block_idx]

                # compute token probabilities from logits
                sampled_token_logits[0] = -float("Inf")  # 'padding' token
                probs = torch.nn.functional.softmax(sampled_token_logits / temperature, dim=-1)  # [t, V]
                # sample next sequence token
                val[-1][token_idx + block_idx] = torch.multinomial(probs, num_samples=1)[0]

            # update indices
            token_idx += self.kernel_size

        return val[-1]

Now, we're introducing a Sampler class into the mix. For each layer, we have a Generator that predicts additional tokens by building upon the compression we performed earlier. When we use the sampler, it follows a specific sequence. Initially, it selects a **random voxel** using the precondition_resolution. Afterward, the sampler generates **additional layers** as needed to achieve the intended target_resolution. For each of these layers, there's a **dedicated generator** responsible for producing tokens based on predictions made by the model. In essence, the process involves progressively generating more details while moving towards the desired resolution.

In [87]:
class Sampler:
    def __init__(self, model, max_resolution, **_):
        """ Provides a basic implementation of the sampler for the 'encoder_only' architecture.

        Args:
            model: Model which is used for sampling.
            max_resolution: Maximum resolution the model is trained on.
        """
        self.generators = [
            Generator(model,1),
            Generator(model,1),
            Generator(model,1), 
            Generator(model,4), 
            Generator(model,8),
            Generator(model,8), 
        ]

        self.max_resolution = max_resolution

    def __call__(self, precondition_resolution, target_resolution, temperature):
        """ Perform an iterative sampling of the given sequence until reaching the end of sequence, the maximum sequence
            length or the desired resolution.

        Args:
            precondition: An array of elements (pixels/voxels) as an numpy array.
            precondition_resolution: Resolution at which the autoencoder will reconstruct the layer.
            target_resolution: Resolution up to which an object should be sampled.
            temperature: Defines the randomness of the samples.

        Return:
            A token sequence with values, encoding the final sample.
        """
        #get sample
        val, dep, pos = self.generate_sample(precondition_resolution)

        # compute the number of finished (current) layers and the maximum sampleable layer
        cur_layer = len(val)
        max_layer = int(math.log2(min(target_resolution, self.max_resolution)))

        with torch.no_grad():

            # sample layer-wise
            for idx in tqdm(range(cur_layer, max_layer), initial=cur_layer, total=max_layer, leave=True, desc="Layers"):

                # init sequences for next layer
                next_val, next_dep, next_pos = next_layer_tokens(
                    val, dep, pos, self.max_resolution
                )
                # predict value tokens for current layer
                next_val = self.generators[idx](
                    val=val + [next_val],
                    dep=dep + [next_dep],
                    pos=pos + [next_pos],
                    temperature=temperature
                )

                # append sampled tokens to current sequence
                val += [next_val]
                dep += [next_dep]
                pos += [next_pos]

                if torch.sum(next_val == 2) == 0:
                    break  # early-out, no mixed tokens sampled

        return self.postprocess(val, target_resolution)
    
    def generate_sample(self, precondition_resolution):
        """ Generate a random sample from the model.

        Args:
            precondition_resolution: Resolution at which the autoencoder will reconstruct the layer.

        Return:
            A token sequence with values, encoding the final sample.
        """

        #generate random numpy array with values between 0-2
        array_size = 3 * [self.max_resolution]
        precondition = torch.randint(low=0, high=2, size=array_size, dtype=torch.long).numpy()

        # convert input array into token sequence
        tree = kdTree()
        tree = tree.insert_element_array(precondition, max_depth=math.log2(precondition_resolution) + 1)
        value, depth, position = tree.get_token_sequence(
            depth=math.log2(precondition_resolution), return_depth=True, return_pos=True
        )

        val = []
        dep = []
        pos = []
        # extract each depth layer separately and convert to PyTorch as a long tensor
        for d in range(1, max(depth) + 1):
            val += [torch.tensor(value[depth == d], dtype=torch.long)]
            dep += [torch.tensor(depth[depth == d], dtype=torch.long)]
            pos += [torch.tensor(position[depth == d], dtype=torch.long)]

        return val, dep, pos
    
    def postprocess(self,value, target_resolution):
        """ Transform sequence of value tokens into an array of elements (voxels/pixels).

        Args:
            value: List of value token sequences for each layer as pytorch tensors.
            target_resolution: Resolution up to which an object should be sampled.

        Return:
            An array of elements as a numpy array.
        """
        # concat all layers
        value = torch.cat(value)

        # move value sequence to the cpu and convert to numpy array
        value = value.cpu().numpy()

        # insert the sequence into a kd-tree
        tree = kdTree().insert_token_sequence(
            value,
            resolution=target_resolution
        )

        # retrive pixels/voxels from the kd-tree
        return tree.get_element_array(mode="occupancy")


This function generates the **next layer** of tokens with their positions and depth, used by the sampler.

In [88]:
def next_layer_tokens(value, depth, position, max_resolution):
    """ Creates artificial tokens for the next layer of the value sequence, to match the predefined shape. Precomputes
    corresponding depth and position tokens of the sequence, too.

    Args:
        value: List of value token sequences for each layer as pytorch tensors.
        depth: List of depth token sequences for each layer as pytorch tensors.
        position: List of position token sequences for each layer as pytorch tensors.
        max_resolution: The maximal resolution the corresponding model is trained for.

    Return:
        Pre-initialised next layer sequence (value, depth, position).
    """
    dirs = np.array(list(itertools.product([1, 2], repeat=3)))
    num_children = 2**SPATIAL_DIM

    # got an empty input - initialize with default values and return
    if len(value[0]) == 0:
        value = torch.tensor(num_children * [1], dtype=torch.long)
        depth = torch.tensor(num_children * [1], dtype=torch.long)
        pos = (
            torch.ones(num_children, SPATIAL_DIM, dtype=torch.long) *
            torch.tensor(dirs)
        )
    # compute next layer depth and number of future tokens
    cur_depth = len(value)
    num_future_tokens = num_children * torch.sum(value[-1] == 2)

    # compute future sequence (non padding token) and future depth sequence
    nl_value = torch.tensor([1], dtype=torch.long).repeat(num_future_tokens)
    nl_depth = torch.tensor([cur_depth + 1], dtype=torch.long).repeat(num_future_tokens)

    # retrive and copy mixed tokens positions
    pos_token = position[-1][value[-1] == 2]
    nl_pos = torch.repeat_interleave(pos_token, num_children, dim=0)

    # compute position difference and add it to future positions with respect to predefined pattern

    nl_pos = 2 * nl_pos + torch.tensor(dirs).repeat(pos_token.shape[0], 1)
    return nl_value, nl_depth, nl_pos

Finally, let's create a Sampler instance and sample a voxel. 

*Let's see what our model has learned*

In [91]:
# create sampler
sampler = Sampler(model,max_resolution=RESOLUTION)
# create sample
sample = sampler(2,64,1)
# save sample as obj
save_obj(sample, "sample")



Tokens:   0%|          | 0/64 [00:00<?, ?it/s]



Tokens:   0%|          | 0/80 [00:00<?, ?it/s]



Tokens:   0%|          | 0/66 [00:00<?, ?it/s]



Tokens:   0%|          | 0/109 [00:00<?, ?it/s]



Tokens:   0%|          | 0/296 [00:00<?, ?it/s]

Layers: 100%|██████████| 6/6 [01:11<00:00, 14.22s/it]


# Visualization of Results
In this step we can visualize our samples and find out, if the model was able to learn what a plane looks like.

In [92]:
# use k3d to plot the sampled shape
plt_voxels = k3d.voxels(sample,
                        color_map=[0xfdc192, 0xa15525],
                        outlines_color=0xffffff)

plot = k3d.plot()
plot += plt_voxels
plot.display()



Output()