# GANCA Experiments & Implementation

## Data

In [1]:
# import stuff
import data_helper
import importlib
importlib.reload(data_helper)
import numpy as np
import pickle
import torch
from torchsummaryX import summary
import os
from einops import rearrange
import torch.nn as nn

BLOCK2VEC_OUT_PATH = 'output/block2vec saves/block2vec 64 dim/'
# Use half of the available CPU cores as DataLoader workers.
# os.cpu_count() may return None when the core count cannot be determined;
# fall back to 0 workers (load in the main process) instead of crashing.
NUM_WORKERS = (os.cpu_count() or 0) // 2

from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
import pytorch_lightning as pl

In [2]:
class GANCA3DDataset(Dataset):
    """Map-style dataset over a stack of loaded worlds.

    Only index 0 of the last axis of ``loaded_worlds`` is kept. Each item is
    returned as a newly constructed ``torch`` tensor.

    NOTE(review): the slicing requires ``loaded_worlds`` to be 5-D, presumably
    (num_worlds, x, y, z, features) — TODO confirm against
    ``data_helper.houses_dataset``.
    """

    def __init__(self, loaded_worlds):
        super(GANCA3DDataset, self).__init__()
        first_feature = 0
        self.dataset = loaded_worlds[:, :, :, :, first_feature]

    def __len__(self):
        worlds = self.dataset
        return len(worlds)

    def __getitem__(self, index):
        # torch.tensor copies the selected world into a new tensor.
        return torch.tensor(self.dataset[index])

In [3]:
# dataset template
class CIFAR10DataModule(pl.LightningDataModule):
    """Lightning data module for the GANCA 3D houses dataset.

    Despite its name (kept so existing callers keep working), this module
    does not load CIFAR-10.
    ``prepare_data`` calls ``data_helper.download_dataset()``. ``setup``
    wraps ``data_helper.houses_dataset()`` in a ``GANCA3DDataset`` and
    randomly splits it into train/val/test subsets. Only the training loader
    is wired up; the val/test/predict hooks return ``None``.

    Args:
        batch_size: batch size used by ``train_dataloader``.
        split_lengths: (train, val, test) sizes passed to ``random_split``.
            They must sum to the dataset length, otherwise ``random_split``
            raises ``ValueError``.
    """

    def __init__(self, batch_size, split_lengths=(1600, 192, 185)):
        # The LightningDataModule base class must be initialised; skipping it
        # leaves the internal state Lightning relies on unset.
        super().__init__()

        self.batch_size = batch_size
        self.split_lengths = list(split_lengths)

        # NOTE(review): spatial size of a world; older Lightning versions
        # returned this from size() — confirm it matches the dataset.
        self.dims = (32, 32, 32)

        # Filled in by setup().
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    def prepare_data(self):
        """Download the raw data via ``data_helper``."""
        data_helper.download_dataset()

    def setup(self, stage=None):
        """Build the dataset and split it into train/val/test subsets.

        Lightning can call this hook once per stage (fit, validate, test, ...).
        The random split is made only on the first call so that samples never
        move between subsets — re-splitting would leak test examples into
        training.
        """
        if self.train_dataset is not None:
            return
        print('setting up')
        full_dataset = GANCA3DDataset(data_helper.houses_dataset())
        self.train_dataset, self.val_dataset, self.test_dataset = torch.utils.data.random_split(
            full_dataset, self.split_lengths
        )

    # these funcs can also be placed directly inside a LightningModule
    def train_dataloader(self):
        """Shuffled, pinned-memory loader over the training subset."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            pin_memory=True,
            num_workers=NUM_WORKERS,
        )

    def val_dataloader(self):
        """Validation is currently disabled; returns ``None``."""
        # To enable: return DataLoader(self.val_dataset, batch_size=self.batch_size,
        #                              pin_memory=True, num_workers=NUM_WORKERS)
        return None

    def test_dataloader(self):
        """Testing is currently disabled; returns ``None``."""
        # To enable: return DataLoader(self.test_dataset, batch_size=self.batch_size,
        #                              pin_memory=True, num_workers=NUM_WORKERS)
        return None

    def predict_dataloader(self):
        """Prediction is not supported; returns ``None``."""
        return None

## Embeddings

In [4]:
# Load the block2vec representations saved under BLOCK2VEC_OUT_PATH.
# NOTE(review): pickle.load can execute arbitrary code from the file; only
# load representations.pkl files you produced yourself or otherwise trust.
with open(os.path.join(BLOCK2VEC_OUT_PATH, "representations.pkl"), "rb") as f:
    embeddings_dict = pickle.load(f)

In [5]:
# Number of entries in the loaded block2vec embeddings dict.
num_embeddings = len(embeddings_dict)
num_embeddings

218

## Model

## Generating 3D artifacts paper experiments

### Voxel Perception Net!

In [6]:
class VoxelPerceptionNet(nn.Module):
    """Trainable per-channel perception stage of a voxel NCA.

    A single channel-grouped ``Conv3d`` is used, with a 3x3x3 kernel, stride 1,
    padding 1 and no bias. It gives every input channel ``num_perceptions``
    independent filters. An input of shape (N, C, X, Y, Z) therefore becomes
    (N, C * num_perceptions, X, Y, Z). These features feed the NCA update
    network.

    Args:
        num_in_channels: number of channels C in the incoming state.
        num_perceptions: number of filters applied to each channel.
        normal_std: std of the normal weight initialisation.
        use_normal_init: if True, re-initialise conv weights from N(0, normal_std).
        zero_bias: bias initialisation policy; the conv is built with
            ``bias=False``, so this currently has no effect.
    """

    def __init__(
        self,
        num_in_channels,
        num_perceptions=3,
        normal_std=0.02,
        use_normal_init=True,
        zero_bias=True,
    ):
        super().__init__()

        self.num_in_channels = num_in_channels
        self.normal_std = normal_std

        expanded_channels = num_in_channels * num_perceptions
        # groups == in_channels: each channel is convolved only with its own
        # num_perceptions filters, with no mixing between channels.
        perception_conv = nn.Conv3d(
            num_in_channels,
            expanded_channels,
            kernel_size=3,  # neighbourhood radius of 1 voxel
            stride=1,  # visit every voxel
            padding=1,  # keep X, Y, Z unchanged so edge voxels are perceived too
            groups=num_in_channels,
            bias=False,
        )
        # Wrapped in a Sequential so parameter names stay "sequence.0.*".
        self.sequence = nn.Sequential(perception_conv)

        if not use_normal_init:
            return

        with torch.no_grad():
            for module in self.modules():
                if not isinstance(module, nn.Conv3d):
                    continue
                nn.init.normal_(module.weight, std=normal_std)
                bias = getattr(module, "bias", None)
                if bias is None:
                    continue
                if zero_bias:
                    nn.init.zeros_(bias)
                else:
                    nn.init.normal_(bias, std=normal_std)

    def forward(self, input):
        """Apply the perception convolution to ``input``."""
        return self.sequence(input)

In [7]:
# Small perception net over a 3-channel input, used for inspection.
test2 = VoxelPerceptionNet(
    num_in_channels=3,
    num_perceptions=3,
    normal_std=0.02,
    use_normal_init=True,
)

In [8]:
# Print every trainable parameter of test2 (debugging aid).
trainable_params = ((n, p) for n, p in test2.named_parameters() if p.requires_grad)
for name, param in trainable_params:
    print(name, param.data)

sequence.0.weight tensor([[[[[-2.6198e-03, -1.5750e-02, -3.8632e-02],
           [ 3.1331e-02,  1.0204e-02, -1.5932e-02],
           [-1.3426e-02,  2.2175e-02,  7.1230e-05]],

          [[-2.6037e-02, -1.6551e-02, -2.4201e-02],
           [-1.5151e-03, -2.9265e-02,  1.2596e-02],
           [ 2.9972e-02,  1.3883e-02, -2.0188e-02]],

          [[ 2.2803e-02,  2.0536e-02,  4.6237e-02],
           [ 9.9740e-03, -1.2579e-02,  6.9827e-03],
           [ 2.3991e-02,  1.4006e-02,  1.1868e-02]]]],



        [[[[-1.6638e-02,  1.1156e-02, -1.6559e-02],
           [-2.3454e-02,  2.6331e-02,  2.8104e-03],
           [ 3.1758e-02,  3.0760e-02,  2.8987e-02]],

          [[-2.2075e-02,  2.7824e-02, -7.2198e-04],
           [ 1.7351e-02, -1.1078e-02, -1.7559e-02],
           [ 9.3460e-03, -2.2787e-02, -2.6558e-03]],

          [[ 8.4337e-03, -1.7452e-02,  2.3922e-02],
           [ 8.8808e-03, -5.5788e-03, -7.6895e-04],
           [ 1.2141e-02, -1.7628e-03, -1.9574e-03]]]],



        [[[[-8.8659e-03, -

In [9]:
# Layer summary of test2 on a batch of 16 random 3-channel 8x8x8 volumes.
perception_probe = torch.rand(16, 3, 8, 8, 8)
_ = summary(test2, perception_probe)

                        Kernel Shape      Output Shape  Params  Mult-Adds
Layer                                                                    
0_sequence.Conv3d_0  [1, 9, 3, 3, 3]  [16, 9, 8, 8, 8]     243     124416
-------------------------------------------------------------------------
                      Totals
Total params             243
Trainable params         243
Non-trainable params       0
Mult-Adds             124416


### Voxel Update Net

In [10]:
class VoxelUpdateNet(nn.Module):
    """Per-voxel MLP that turns perception features into state updates.

    The network is built from 1x1x1 ``Conv3d`` layers, so the same small dense
    network runs independently at every voxel.
    The input has shape (N, num_channels * num_perceptions, X, Y, Z), for
    example the output of a ``VoxelPerceptionNet``. The output has shape
    (N, num_channels, X, Y, Z).

    Args:
        num_channels: number of state channels (width of the output).
        num_perceptions: perception filters per state channel; the input has
            ``num_channels * num_perceptions`` channels.
        channel_dims: widths of the hidden layers, each followed by ReLU. May
            be empty, in which case a single linear 1x1x1 conv maps the input
            straight to the output.
        normal_std: std of the normal initialisation.
        use_normal_init: if True, re-initialise conv weights from N(0, normal_std).
        zero_bias: if True, hidden-layer biases start at 0; otherwise they are
            drawn from N(0, normal_std). The output layer has no bias.
    """

    def __init__(
        self,
        num_channels=16,
        num_perceptions=3,
        channel_dims=(64, 64),
        normal_std=0.02,
        use_normal_init=True,
        zero_bias=True,
    ):
        super(VoxelUpdateNet, self).__init__()

        # Immutable default plus a local copy: never share a mutable default.
        channel_dims = list(channel_dims)

        layers = []
        in_channels = num_channels * num_perceptions
        # Hidden layers: 1x1x1 conv + ReLU, chaining each width into the next.
        for width in channel_dims:
            layers.extend([nn.Conv3d(in_channels, width, kernel_size=1), nn.ReLU()])
            in_channels = width
        # Final linear projection back to the state width, without bias.
        layers.append(nn.Conv3d(in_channels, num_channels, kernel_size=1, bias=False))

        # Layer order (conv, relu, ..., conv) keeps parameter names "update_net.<i>.*".
        self.update_net = nn.Sequential(*layers)

        def init_weights(m):
            if isinstance(m, nn.Conv3d):
                nn.init.normal_(m.weight, std=normal_std)
                if getattr(m, "bias", None) is not None:
                    if zero_bias:
                        nn.init.zeros_(m.bias)
                    else:
                        nn.init.normal_(m.bias, std=normal_std)

        if use_normal_init:
            with torch.no_grad():
                self.apply(init_weights)

    def forward(self, x):
        """Map perception features ``x`` to per-voxel update deltas."""
        return self.update_net(x)

In [11]:
# Update net for a 16-channel state with 3 perceptions per channel
# (48 input features) and hidden widths 64 -> 32.
test3 = VoxelUpdateNet(
    num_channels=16,
    num_perceptions=3,
    channel_dims=[64, 32],
    normal_std=0.02,
    use_normal_init=True,
    zero_bias=True,
)

In [12]:
# 48 input channels = 16 state channels * 3 perceptions (see test3 above).
update_probe = torch.rand(3, 48, 8, 8, 8)
_ = summary(test3, update_probe)

                            Kernel Shape      Output Shape  Params  Mult-Adds
Layer                                                                        
0_update_net.Conv3d_0  [48, 64, 1, 1, 1]  [3, 64, 8, 8, 8]  3.136k  1.572864M
1_update_net.ReLU_1                    -  [3, 64, 8, 8, 8]       -          -
2_update_net.Conv3d_2  [64, 32, 1, 1, 1]  [3, 32, 8, 8, 8]   2.08k  1.048576M
3_update_net.ReLU_3                    -  [3, 32, 8, 8, 8]       -          -
4_update_net.Conv3d_4  [32, 16, 1, 1, 1]  [3, 16, 8, 8, 8]   512.0   262.144k
-----------------------------------------------------------------------------
                         Totals
Total params             5.728k
Trainable params         5.728k
Non-trainable params        0.0
Mult-Adds             2.883584M


  df_sum = df.sum()


### VoxelNCA Model

In [29]:
from typing import Any, Dict, List, Optional
import torch.nn.functional as F

In [161]:
class VoxelNCAModel(nn.Module):
    """Voxel neural cellular automaton: repeatedly perceive, update and prune cells.

    The state tensor has shape (N, 1 + num_hidden_channels, X, Y, Z). Channel
    ``alpha_channel_index`` (0) is the alpha channel that decides which cells
    are alive. One update step does the following:

    1. Marks a cell alive if the max alpha over its 3x3x3 neighbourhood
       exceeds ``alpha_living_threshold``.
    2. Runs ``VoxelPerceptionNet`` and then ``VoxelUpdateNet`` to predict a
       per-voxel delta.
    3. Scales the delta by ``step_size`` and applies it only to a random
       subset of cells, each chosen with probability ``cell_fire_rate``.
    4. Zeroes every cell that is not alive both before and after the update.
    """

    def __init__(
        self,
        alpha_living_threshold: float = 0.1,  # neighbourhood max alpha must exceed this for a cell to be alive
        cell_fire_rate: float = 0.5,  # probability that a cell receives its update on a given step
        step_size: float = 1.0,  # scale applied to every predicted update delta
        num_perceptions=3,  # perception filters per state channel
        perception_requires_grad: bool = True,  # if False, the perception filters are frozen
        num_hidden_channels: int = 127,  # state channels besides the alpha channel
        normal_std: float = 0.0002,  # std of the normal weight initialisation
        use_normal_init: bool = True,  # whether to apply the normal initialisation
        zero_bias: bool = True,  # passed to the sub-networks' bias initialisation
        update_net_channel_dims: Optional[List[int]] = None,  # hidden widths of the VoxelUpdateNet; None -> [32, 32]
    ):
        super(VoxelNCAModel, self).__init__()
        # Build a fresh list per instance instead of sharing a mutable default.
        if update_net_channel_dims is None:
            update_net_channel_dims = [32, 32]

        self.alpha_living_threshold = alpha_living_threshold
        self.cell_fire_rate = cell_fire_rate
        self.step_size = step_size
        self.num_perceptions = num_perceptions
        self.perception_requires_grad = perception_requires_grad
        self.num_hidden_channels = num_hidden_channels
        self.normal_std = normal_std
        self.use_normal_init = use_normal_init
        self.zero_bias = zero_bias
        self.update_net_channel_dims = update_net_channel_dims

        # Channel layout: 1 alpha channel followed by num_hidden_channels hidden channels.
        self.alpha_channel_index = 0
        self.num_channels = 1 + self.num_hidden_channels

        self.perception_net = VoxelPerceptionNet(
            num_in_channels=self.num_channels,
            num_perceptions=self.num_perceptions,
            normal_std=self.normal_std,
            use_normal_init=self.use_normal_init,
            zero_bias=self.zero_bias,
        )
        if not self.perception_requires_grad:
            # Freeze the perception filters; only the update network trains.
            for p in self.perception_net.parameters():
                p.requires_grad = False

        self.update_network = VoxelUpdateNet(
            num_channels=self.num_channels,
            num_perceptions=self.num_perceptions,
            channel_dims=self.update_net_channel_dims,
            normal_std=self.normal_std,
            use_normal_init=self.use_normal_init,
            zero_bias=self.zero_bias,
        )

        # NOTE(review): not used by any method of this class; kept so the
        # module structure (and repr) stays unchanged.
        self.tanh = nn.Tanh()

    def check_alive(self, state):
        """Return the max alpha over each cell's 3x3x3 neighbourhood.

        The output has shape (N, 1, X, Y, Z); stride 1 and padding 1 keep the
        spatial size unchanged.
        """
        # Slice out the alpha channel, keeping the channel axis.
        alpha = state[:, self.alpha_channel_index : self.alpha_channel_index + 1, :, :, :]
        return F.max_pool3d(alpha, kernel_size=3, stride=1, padding=1)

    def perceive(self, state):
        """Extract per-voxel perception features from ``state``."""
        return self.perception_net(state)

    def update(self, state):
        """Run one NCA step.

        Returns:
            ``(new_state, life_mask)``. ``life_mask`` is a float tensor of shape
            (N, 1, X, Y, Z) that is 1 where a cell is alive both before and
            after the update, and 0 elsewhere.
        """
        # Boolean tensor of cells alive before the update.
        pre_update_mask = self.check_alive(state) > self.alpha_living_threshold

        # Perceive the neighbourhood, then predict the update delta.
        perception = self.perceive(state)
        delta = self.update_network(perception) * self.step_size

        # Stochastic update: each cell fires with probability cell_fire_rate.
        # The (N, 1, X, Y, Z) mask broadcasts over all channels.
        rand_mask = torch.rand_like(state[:, 0:1, :, :, :]) < self.cell_fire_rate
        delta = delta * rand_mask.float()

        state = state + delta

        # Boolean tensor of cells alive after the update.
        post_update_mask = self.check_alive(state) > self.alpha_living_threshold

        # Cells survive only if alive both before and after; zero the rest.
        life_mask = (pre_update_mask & post_update_mask).float()
        state = state * life_mask

        return state, life_mask

    def forward(
        self,
        state,  # the world state before update, (N, channels, x, y, z)
        steps=1,  # how many NCA steps to run
        get_final_mask=False,  # if True, also return the last life mask
    ):
        """Run ``steps`` NCA updates on ``state``.

        Returns the new state. If ``get_final_mask`` is True, returns
        ``(state, life_mask)`` instead. When ``steps`` is 0, the mask is the
        current alive mask of the unchanged state.
        """
        life_mask = None
        for _ in range(steps):
            state, life_mask = self.update(state)
        if get_final_mask:
            if life_mask is None:
                # No update ran; report which cells are currently alive.
                life_mask = (self.check_alive(state) > self.alpha_living_threshold).float()
            return state, life_mask
        return state

In [155]:
# Hyper-parameters for the experiment model: 1 alpha + 127 hidden channels.
nca_config = dict(
    alpha_living_threshold=0.1,
    cell_fire_rate=0.5,
    step_size=1.0,
    num_perceptions=3,
    perception_requires_grad=True,
    num_hidden_channels=127,
    normal_std=0.0002,
    use_normal_init=True,
    zero_bias=True,
    update_net_channel_dims=[32, 32],
)
voxel_nca_model = VoxelNCAModel(**nca_config)

In [156]:
# Random batch of 16 worlds; 128 channels = 1 alpha + 127 hidden, 8x8x8 voxels.
sample_shape = (16, 128, 8, 8, 8)
sample_state = torch.rand(sample_shape)

In [159]:
# Run 24 NCA steps. Call the module itself rather than .forward() so that
# nn.Module.__call__ also runs any registered forward hooks.
new_state = voxel_nca_model(sample_state, steps=24)
new_state.shape

torch.Size([16, 128, 8, 8, 8])

In [None]:
from torchviz import make_dot

# Render the autograd graph that produced new_state.
computation_graph = make_dot(new_state)
computation_graph

In [153]:
from torch.utils.tensorboard import SummaryWriter

# TensorBoard writer using its default log directory.
writer = SummaryWriter(log_dir=None)

In [158]:
# Trace the model on the sample batch and log its graph. The original
# "(sample_state)" was just a parenthesised tensor, not a tuple. Flush so the
# event file is written to disk now rather than at the next periodic flush.
writer.add_graph(voxel_nca_model, input_to_model=sample_state)
writer.flush()