# GANCA Experiments & Implementation

## Data

In [208]:
# import stuff
import data_helper
import importlib
importlib.reload(data_helper)
import numpy as np
import pickle
import torch
import os
import torch.nn as nn

BLOCK2VEC_OUT_PATH = 'output/block2vec saves/block2vec 64 dim/'
NUM_WORKERS = int(os.cpu_count() / 2)

from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
import pytorch_lightning as pl
import torch.nn as nn

In [209]:
class GANCA3DDataset(Dataset):
    """Dataset wrapper that serves one world per index as a torch tensor.

    Only index 0 of the fifth axis of ``loaded_worlds`` is kept; the first axis
    enumerates the examples.
    """

    def __init__(self, loaded_worlds):
        super().__init__()
        # Drop the fifth axis by keeping its first entry only.
        self.dataset = loaded_worlds[:, :, :, :, 0]

    def __len__(self):
        """Return the number of stored examples (length of the first axis)."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return the example at ``index`` converted with ``torch.tensor``."""
        return torch.tensor(self.dataset[index])

In [210]:
# template
class CIFAR10DataModule(pl.LightningDataModule):
    """Lightning data module that wraps the houses data in ``GANCA3DDataset``.

    NOTE(review): the class name looks like a leftover from the template this
    was adapted from; the data loaded here comes from
    ``data_helper.houses_dataset()``, not CIFAR-10.

    Only the training loader is active; the validation, test and predict
    loaders are intentionally left disabled and return ``None``.
    """

    def __init__(self, batch_size):
        """Store the batch size used by the data loaders.

        Args:
            batch_size: number of examples per batch.
        """
        # The base class sets up internal state in its own __init__, so it
        # must run before any attribute is assigned on this instance.
        super().__init__()

        self.batch_size = batch_size

        # Per-example shape. The original note says this is what ``size()``
        # returns -- NOTE(review): confirm against the installed Lightning version.
        self.dims = (32, 32, 32)

    def prepare_data(self):
        """Download the dataset through ``data_helper.download_dataset()``."""
        data_helper.download_dataset()

    def setup(self, stage=None):
        """Build the full dataset and split it into train/val/test subsets.

        Args:
            stage: accepted for the Lightning hook signature; not used here.
        """
        print('settin gup')
        full_dataset = GANCA3DDataset(data_helper.houses_dataset())
        # Integer lengths passed to random_split must sum to len(full_dataset)
        # (1600 + 192 + 185 = 1977).
        self.train_dataset, self.val_dataset, self.test_dataset = torch.utils.data.random_split(
            full_dataset, [1600, 192, 185]
        )

    # These loader hooks can also be placed directly inside a LightningModule.
    def train_dataloader(self):
        """Return a shuffled loader over the training split."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            pin_memory=True,
            num_workers=NUM_WORKERS,
        )

    def val_dataloader(self):
        """Validation loader (currently disabled; returns ``None``)."""
        # To enable:
        # return DataLoader(self.val_dataset, batch_size=self.batch_size, pin_memory=True, num_workers=NUM_WORKERS)
        pass

    def test_dataloader(self):
        """Test loader (currently disabled; returns ``None``)."""
        # To enable:
        # return DataLoader(self.test_dataset, batch_size=self.batch_size, pin_memory=True, num_workers=NUM_WORKERS)
        pass

    def predict_dataloader(self):
        """Predict loader (not implemented; returns ``None``)."""
        pass

## Embeddings

In [211]:
# Load the pickled representations saved under BLOCK2VEC_OUT_PATH.
# NOTE(review): pickle.load can execute arbitrary code while unpickling;
# only load this file if it comes from a trusted source.
with open(BLOCK2VEC_OUT_PATH + "representations.pkl", 'rb') as f:
	embeddings_dict = pickle.load(f)

In [212]:
# Number of entries in the loaded embeddings object (the recorded cell output is 218).
len(embeddings_dict)

218

## Model

In [213]:
import torch.nn.functional as F
import torch

In [214]:
class VoxelNCA(nn.Module):
    """Placeholder voxel NCA module.

    It records the world shape and hidden size; ``forward`` currently returns
    its input unchanged.
    """

    def __init__(self, world_shape=(32, 32, 32), hidden_size=128):
        super().__init__()

        self.hidden_size = hidden_size
        self.world_shape = world_shape

    def forward(self, input_state):
        """Return ``input_state`` as-is (no update step is implemented yet).

        The original note describes the input as a world of shape
        (hidden_size, world_shape[0], world_shape[1], world_shape[2]).
        """
        return input_state

## Generating 3D artifacts paper experiments

In [215]:
class PrintShape(nn.Module):
    """Debug layer: prints the shape of whatever passes through it.

    The input is returned untouched, so the layer can be dropped anywhere
    inside an ``nn.Sequential`` without affecting the computation.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Print ``x.shape`` and hand ``x`` back unchanged."""
        current_shape = x.shape
        print("-> shape is now", current_shape)
        return x

### Sequential CNN??

### Voxel Perception Net!

In [216]:
class VoxelPerceptionNet(nn.Module):
    """Trainable per-channel perception stage.

    A grouped 3x3x3 convolution (``groups=num_in_channels``) gives every input
    channel its own ``num_perceptions`` filters, so the channel count grows
    from ``num_in_channels`` to ``num_in_channels * num_perceptions``. Stride 1
    with padding 1 keeps the spatial size. A ``PrintShape`` debug layer follows
    the convolution.

    Args:
        num_in_channels: number of channels of the incoming state.
        num_perceptions: filters learned per input channel.
        normal_std: standard deviation for the normal weight initialisation.
        use_normal_init: when true, re-initialise Conv3d weights from a normal
            distribution.
        zero_bias: when a Conv3d has a bias, zero it (true) or draw it from the
            same normal distribution (false). The convolution built here has
            no bias.
    """

    def __init__(
        self, num_in_channels, num_perceptions=3, normal_std=0.02, use_normal_init=True, zero_bias=True
    ):
        super().__init__()

        self.num_in_channels = num_in_channels
        self.normal_std = normal_std

        # One group per input channel: effectively num_perceptions separate
        # convolution filters running side by side on each channel.
        perception_conv = nn.Conv3d(
            in_channels=num_in_channels,
            out_channels=num_in_channels * num_perceptions,
            kernel_size=3,  # neighbourhood radius of 1
            stride=1,  # visit every voxel
            padding=1,  # include edge voxels
            groups=num_in_channels,
            bias=False,
        )
        self.sequence = nn.Sequential(perception_conv, PrintShape())

        # Optional normal initialisation of every Conv3d in this module.
        if use_normal_init:
            with torch.no_grad():
                for layer in self.modules():
                    if not isinstance(layer, nn.Conv3d):
                        continue
                    nn.init.normal_(layer.weight, std=normal_std)

                    layer_bias = getattr(layer, "bias", None)
                    if layer_bias is None:
                        continue
                    if zero_bias:
                        nn.init.zeros_(layer_bias)
                    else:
                        nn.init.normal_(layer_bias, std=normal_std)

    def forward(self, input):
        """Apply the perception convolution (and debug print) to ``input``."""
        return self.sequence(input)

In [217]:
# Smoke-test instance: 3 input channels with 3 perceptions each (9 output channels).
test2 = VoxelPerceptionNet(3, normal_std=0.02, num_perceptions=3, use_normal_init=True)

In [218]:
# Print every trainable parameter of test2 (name and raw values) for debugging.
for name, param in test2.named_parameters():
    if param.requires_grad:
        print(name, param.data)

sequence.0.weight tensor([[[[[-3.8883e-02, -5.1184e-03, -6.4090e-03],
           [-5.3794e-02,  1.2653e-02,  6.2345e-03],
           [-1.6375e-02,  9.6170e-03,  1.3706e-02]],

          [[ 3.4142e-03,  1.6779e-02, -2.0713e-02],
           [ 3.3475e-02, -1.6946e-02, -5.9912e-03],
           [-2.3480e-03,  4.5203e-02, -3.0542e-02]],

          [[-1.0536e-02,  7.1457e-04, -3.8952e-03],
           [-5.8230e-03,  1.2197e-02, -2.4776e-03],
           [-2.2198e-02, -2.3461e-02, -1.5070e-02]]]],



        [[[[ 2.5103e-02,  3.6363e-03, -6.9560e-03],
           [-2.0942e-02,  7.4115e-04, -2.8962e-02],
           [ 3.3001e-02, -4.7898e-03, -5.9040e-03]],

          [[ 2.4569e-02, -5.7138e-03, -1.2699e-02],
           [ 1.1442e-02,  7.4514e-03, -2.3042e-02],
           [-1.8838e-02,  2.4893e-02,  1.4786e-02]],

          [[ 2.9716e-02,  1.8288e-02,  6.1580e-03],
           [ 1.3902e-02,  1.0811e-02, -3.5035e-02],
           [-3.6581e-02,  1.1993e-02, -3.1509e-02]]]],



        [[[[ 4.0772e-03, -

In [219]:
# Run a random (16, 3, 8, 8, 8) batch through test2; its PrintShape layer prints the output shape.
out = test2(torch.rand(16, 3, 8, 8, 8))

-> shape is now torch.Size([16, 9, 8, 8, 8])


### Voxel Update Net

In [246]:
class VoxelUpdateNet(nn.Module):
    """Per-voxel update network built from 1x1x1 3-D convolutions.

    Because every convolution has kernel size 1, this is equivalent to running
    the same small dense network independently on each voxel. It takes the
    visual features produced by a ``VoxelPerceptionNet`` and predicts one
    update value per state channel.

    Input shape:  (N, num_channels * num_perceptions, D, H, W)
    Output shape: (N, num_channels, D, H, W)

    Args:
        num_channels: number of state channels (output channels).
        num_perceptions: perceptions per state channel; the expected number of
            input channels is ``num_channels * num_perceptions``.
        channel_dims: widths of the hidden layers, in order. An empty sequence
            yields a single bias-free 1x1x1 convolution from the input features
            straight to ``num_channels``.
        normal_std: standard deviation for the normal weight initialisation.
        use_normal_init: when true, re-initialise every Conv3d weight from a
            normal distribution.
        zero_bias: when a Conv3d has a bias, zero it (true) or draw it from the
            same normal distribution (false).
    """

    def __init__(
        self,
        num_channels=16,
        num_perceptions=3,
        channel_dims=(64, 64),  # immutable default; any sequence of ints is accepted
        normal_std=0.02,
        use_normal_init=True,
        zero_bias=True,
    ):
        super().__init__()

        # Channel widths seen by the hidden stack: perception features first,
        # then every hidden width in order.
        widths = [num_channels * num_perceptions, *channel_dims]

        layers = []
        # Hidden layers: width[i] -> width[i + 1], each followed by ReLU and a
        # debug shape print.
        for in_width, out_width in zip(widths, widths[1:]):
            layers.extend([
                nn.Conv3d(in_width, out_width, kernel_size=1),
                nn.ReLU(),
                PrintShape(),
            ])

        # Final projection back to the state channels; no bias and no
        # activation, so the output is a raw per-channel update.
        layers.extend([
            nn.Conv3d(widths[-1], num_channels, kernel_size=1, bias=False),
            PrintShape(),
        ])

        self.update_net = nn.Sequential(*layers)

        def init_weights(m):
            if isinstance(m, nn.Conv3d):
                # Normal initialisation for Conv3d weights.
                nn.init.normal_(m.weight, std=normal_std)

                # Only layers that actually carry a bias are touched here.
                if getattr(m, "bias", None) is not None:
                    if zero_bias:
                        nn.init.zeros_(m.bias)
                    else:
                        nn.init.normal_(m.bias, std=normal_std)

        # Weight initialisation.
        if use_normal_init:
            with torch.no_grad():
                self.apply(init_weights)

    def forward(self, x):
        """Return the predicted per-channel update for every voxel of ``x``."""
        return self.update_net(x)

In [260]:
# Smoke-test instance: 16 state channels x 3 perceptions (48 input features), hidden widths 64 and 32.
test3 = VoxelUpdateNet(num_channels = 16, num_perceptions=3, channel_dims=[64, 32], normal_std=0.02, use_normal_init=True, zero_bias=True)

In [261]:
# Run a random (3, 48, 8, 8, 8) batch through test3; each PrintShape layer prints the intermediate shape.
out = test3(torch.rand(3,48,8,8,8))

-> shape is now torch.Size([3, 64, 8, 8, 8])
-> shape is now torch.Size([3, 32, 8, 8, 8])
-> shape is now torch.Size([3, 16, 8, 8, 8])


### VoxelNCA Model