# Training GANCA-3D

We run five different experiments:

- Use a deconvolutional WGAN-GP as the generator. This serves as the baseline.
- GANCA on MMGAN
- GANCA on WGAN
- GANCA on WGAN-GP
- GANCA on dual-discriminator WGAN-GP

## Downloads

In [3]:
# Enable PyTorch TPU support on Colab
# !pip --quiet install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl
# !pip --quiet install torch==1.9 torchtext==0.10 torchvision==0.10 torchaudio==0.9

In [None]:
# !wget -q https://github.com/chaosarium/3D-GANCA/raw/master/data_helper.py
# !wget -q https://github.com/chaosarium/3D-GANCA/raw/master/utils.py
# !wget -q https://github.com/chaosarium/3D-GANCA/raw/master/models.py
# !wget -q https://github.com/chaosarium/3D-GANCA/raw/master/block_ids_alt.tsv
# !wget -q https://github.com/chaosarium/3D-GANCA/raw/master/output.zip
# !unzip -q output.zip
# !mkdir dataset
# !wget -q https://github.com/chaosarium/3D-GANCA/raw/master/dataset/filtered_houses_stats.pkl4 -O dataset/filtered_houses_stats.pkl4
# !wget -q https://github.com/chaosarium/3D-GANCA/raw/master/dataset/filtered_houses_stats.pkl -O dataset/filtered_houses_stats.pkl

In [None]:
# !pip --quiet install torchsummaryX loguru einops pytorch_lightning

## Imports

In [12]:
# import stuff

import data_helper
from data_helper import GANCA3DDataModule
from models import VoxelPerceptionNet, VoxelUpdateNet, VoxelNCAModel, VoxelDiscriminator, VoxelDeconvGenerator
import visualise_helper
import utils
%load_ext autoreload
from tqdm.notebook import tqdm
%autoreload 2

import numpy as np
import sys
import torch
import random
from torchsummaryX import summary
import matplotlib.pyplot as plt
from pytorch_lightning import Trainer
import pytorch_lightning as pl
import math

from loguru import logger as gurulogger
# Reconfigure loguru: calling `remove()` without a handler id drops the
# handlers registered so far, then a single stdout sink is added that prints
# a blue timestamp followed by the level-coloured message.
gurulogger.remove()
gurulogger.add(sys.stdout, colorize=True, format="<blue>{time}</blue> <level>{message}</level>")
# Override the colour used for INFO records (bold red).
gurulogger.level("INFO", color="<red><bold>")

import os
from einops import rearrange
import torch.nn.functional as F
import pandas as pd

# Directory handed to `utils.get_embedding_info` to load the block2vec embedding.
BLOCK2VEC_OUT_PATH = 'output/block2vec saves/block2vec 64 dim locked air/'
# Use half of the available CPU cores for data loading. `os.cpu_count()`
# returns None when the count cannot be determined, so fall back to one core
# (which gives 0 workers, i.e. loading in the main process) instead of
# raising a TypeError.
NUM_WORKERS = (os.cpu_count() or 1) // 2

%load_ext tensorboard

# Report the GPU set-up; nothing is printed when CUDA is unavailable.
if torch.cuda.is_available():
    _cuda_report = (
        ('CUDNN VERSION:', torch.backends.cudnn.version()),
        ('Number CUDA Devices:', torch.cuda.device_count()),
        ('CUDA Device Name:', torch.cuda.get_device_name(0)),
        ('CUDA Device Total Memory [GB]:', torch.cuda.get_device_properties(0).total_memory/1e9),
    )
    for _label, _value in _cuda_report:
        print(_label, _value)

In [21]:
# Load the block2vec embedding together with its lookup tables.
(
    embedding,
    mcid2block,
    block2embeddingidx,
    embeddingidx2block,
    block2mcid,
) = utils.get_embedding_info(BLOCK2VEC_OUT_PATH)

# Bundle the embedding and the lookup tables into a single converter object.
converter = utils.DataConverter(
    embedding,
    mcid2block,
    block2embeddingidx,
    embeddingidx2block,
    block2mcid,
)

# Look up the vector assigned to air and show it as the cell output.
_air_index = torch.tensor(converter.block2embeddingidx['minecraft:air'])
air_embedding = embedding(_air_index)
air_embedding

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [None]:
# !ps -ef | grep tensorboard | grep -v grep | awk '{print $2}' | xargs kill
# !tensorboard --logdir lightning_logs/
# %tensorboard --logdir lightning_logs/

## MMGAN

In [None]:
class GANCA_MMGAN(pl.LightningModule):
    """GANCA trained with the minimax (binary cross-entropy) GAN objective.

    The generator is a `VoxelNCAModel` that is iterated for a randomly chosen
    number of steps starting from a seed state. The discriminator is a
    `VoxelDiscriminator` built with `use_sigmoid=True` that scores the
    embedding channels of a house.

    Args:
        lr: learning rate used by both Adam optimizers.
        beta1: first Adam beta, used by both optimizers.
        beta2: second Adam beta, used by both optimizers.
        num_embedding_channels: number of block-embedding channels per cell.
        num_hidden_channels: number of hidden channels per cell.
        update_net_channel_dims: forwarded unchanged to `VoxelNCAModel`.
        embedding: block embedding layer; its weights are frozen here. It must
            be provided even though the default is `None`.
        step_range: inclusive `[low, high]` range from which the number of
            generator steps is drawn on every training step.
    """

    def __init__(self,
            lr = 2e-4,
            beta1 = 0.9,
            beta2 = 0.999,
            num_embedding_channels = 64,
            num_hidden_channels = 63,
            update_net_channel_dims = [32, 32],
            embedding: torch.nn.Embedding = None,
            step_range = [16, 20],
        ):

        super().__init__()
        # call this to save args to the checkpoint
        self.save_hyperparameters()

        self.lr = lr
        self.beta1 = beta1
        self.beta2 = beta2
        self.num_embedding_channels = num_embedding_channels
        self.num_hidden_channels = num_hidden_channels
        self.update_net_channel_dims = update_net_channel_dims
        # the channels will be like [alpha, embeddings ... , hiddens ...]
        self.num_channels = 1 + self.num_embedding_channels + self.num_hidden_channels
        self.world_size = (32,32,32)
        self.seed_size = (4,4,4)
        self.embedding = embedding
        self.embedding.weight.requires_grad=False # freeze embeddings
        self.step_range = step_range

        self.generator = VoxelNCAModel(
            alpha_living_threshold = 0.1,
            cell_fire_rate = 0.5,
            num_perceptions = 3,
            perception_requires_grad = True,
            num_embedding_channels = self.num_embedding_channels,
            num_hidden_channels = self.num_hidden_channels,
            normal_std = 0.0002,
            use_normal_init = True,
            zero_bias = True,
            update_net_channel_dims = self.update_net_channel_dims,
        )
        self.discriminator = VoxelDiscriminator(
            num_in_channels = self.num_embedding_channels,
            use_sigmoid=True,
        )

        # fixed seed states (N, channels, x, y, z) reused for every end-of-epoch snapshot
        self.validation_noise = self.make_seed_states(16)

    def make_seed_states(self, batch_size):
        """Return `batch_size` seed states built by `utils.make_seed_state`."""
        return utils.make_seed_state(
            batch_size = batch_size,
            num_channels = self.num_channels,
            alpha_channel_index = 0,
            seed_dim = self.seed_size,
            world_size = self.world_size,
        )

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Run one generator (`optimizer_idx == 0`) or discriminator (`optimizer_idx == 1`) update.

        Returns the loss for whichever network is currently being optimised.
        """

        num_steps = random.randint(*self.step_range)

        real_houses = batch
        type_holder = batch[0,0,0,0].to(torch.float) # float scalar on the batch's device, used as a type_as() template
        size_this_batch = real_houses.shape[0]

        # one seed state per real example in the batch
        seed_states = self.make_seed_states(size_this_batch).type_as(type_holder)

        # train generator
        if optimizer_idx == 0:

            # grow houses from the seeds
            fake_houses_states = self.generator.forward(seed_states, steps=num_steps)

            # the generator wants D to call its output real, so the targets are all 1s
            real_labels = utils.make_real_labels(size_this_batch).type_as(type_holder)

            # keep only the embedding channels (channel 0 is alpha)
            fake_houses = fake_houses_states[:, 1:self.num_embedding_channels+1, :, :, :]

            # see what discriminator thinks
            fake_predictions = self.discriminator.forward(fake_houses)

            # calculate loss
            g_loss = F.binary_cross_entropy(fake_predictions, real_labels) # y_hat, y

            self.log("g_loss", g_loss.detach(), prog_bar=True, logger=True)

            return g_loss

        # train discriminator
        if optimizer_idx == 1:

            # get embeddings for the real house
            real_houses = utils.examples2embedding(real_houses, self.embedding)

            # real houses should be scored as 1
            real_targets = utils.make_real_labels(size_this_batch).type_as(type_holder)
            real_predictions = self.discriminator.forward(real_houses)
            real_loss = F.binary_cross_entropy(real_predictions, real_targets)
            # number of real houses classified as real (a count, not a fraction)
            real_acc = torch.sum(real_predictions > 0.5).item()

            # generated houses should be scored as 0
            fake_targets = utils.make_fake_labels(size_this_batch).type_as(type_holder)
            fake_houses_states = self.generator.forward(seed_states, steps=num_steps).detach() # detach so that gradients don't pass back into generator
            fake_houses = fake_houses_states[:, 1:self.num_embedding_channels+1, :, :, :] # extract the embedding parts
            fake_predictions = self.discriminator.forward(fake_houses)
            fake_loss = F.binary_cross_entropy(fake_predictions, fake_targets)
            # number of generated houses classified as fake; this must look at
            # the predictions for the *fake* batch, not the real one
            fake_acc = torch.sum(fake_predictions < 0.5).item()

            # discriminator loss is the average of the two losses
            d_loss = (real_loss + fake_loss) / 2
            avg_acc = (real_acc + fake_acc) / 2

            self.log("fake_loss", fake_loss.detach(), prog_bar=False, logger=True)
            self.log("real_loss", real_loss.detach(), prog_bar=False, logger=True)
            self.log("d_loss", d_loss.detach(), prog_bar=True, logger=True)
            self.log("real_acc", real_acc, prog_bar=False, logger=True)
            self.log("fake_acc", fake_acc, prog_bar=False, logger=True)
            self.log("avg_acc", avg_acc, prog_bar=True, logger=True)

            return d_loss

    def configure_optimizers(self,):
        """Return the generator optimizer (index 0) and discriminator optimizer (index 1)."""
        g_optimizer = torch.optim.Adam(self.generator.parameters(), lr=self.lr, betas=(self.beta1, self.beta2))
        d_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=self.lr, betas=(self.beta1, self.beta2))
        return [g_optimizer, d_optimizer], []

    def visualise_gen_process(self, step_size = 4, max_steps = 32):
        """Grow the fixed validation seeds and save snapshots of the growth to disk.

        The generator is advanced `ceil(max_steps / step_size)` times by
        `step_size` steps each. The initial state and every intermediate state
        are stacked along the first axis and written to
        `<log_dir>/gen_snapshots/epoch_<n>.npy` (embedding channels) and
        `epoch_<n>_alpha.npy` (alpha channel).
        """
        with torch.no_grad():

            num_steps = math.ceil(max_steps/step_size)

            state = self.validation_noise.to(self.device)
            # collect the pieces and concatenate once at the end
            snapshots = [utils.extract_embedding_channels(state, self.num_embedding_channels)]
            alpha_snapshots = [utils.extract_alpha_channels(state)]

            for _ in tqdm(range(num_steps), desc="forward pass for visualisation", colour='orange', ncols=1000, leave=False):
                state = self.generator.forward(state, steps=step_size)
                snapshots.append(utils.extract_embedding_channels(state, self.num_embedding_channels))
                alpha_snapshots.append(utils.extract_alpha_channels(state))

            snapshots = torch.cat(snapshots)
            alpha_snapshots = torch.cat(alpha_snapshots)

            self.snapshots_folder_path = self.logger.log_dir + '/gen_snapshots'
            # tolerate an existing folder, but let real failures (e.g. permissions) surface
            os.makedirs(self.snapshots_folder_path, exist_ok=True)
            np.save(self.snapshots_folder_path + f'/epoch_{self.current_epoch}.npy', snapshots.cpu().numpy())
            np.save(self.snapshots_folder_path + f'/epoch_{self.current_epoch}_alpha.npy', alpha_snapshots.cpu().numpy())

    def training_epoch_end(self, training_step_outputs = None):
        """Save a growth snapshot at the end of every training epoch."""
        self.visualise_gen_process(step_size=4, max_steps=self.step_range[1])

In [None]:
# Hyper-parameters for the MMGAN experiment.
mmgan_hparams = dict(
    lr=2e-4,
    beta1=0.9,
    beta2=0.999,
    num_embedding_channels=64,
    num_hidden_channels=63,
    update_net_channel_dims=[32, 32],
    embedding=embedding,
    step_range=[32, 36],
)
model = GANCA_MMGAN(**mmgan_hparams)

In [None]:
# Build the data module and run its prepare/setup hooks explicitly.
dm = GANCA3DDataModule(
    batch_size=8,
    num_workers=NUM_WORKERS,
    mcid2block=mcid2block,
    block2embeddingid=block2embeddingidx,
    debug=False,
)
dm.prepare_data()
dm.setup()

In [None]:
# TensorBoard logs for this experiment go to lightning_logs/GANCA_MMGAN/.
logger = pl.loggers.TensorBoardLogger(save_dir='lightning_logs', name='GANCA_MMGAN', default_hp_metric=False)
# Ask for a GPU only when CUDA is present, so the cell does not raise on a
# CPU-only machine (gpus=0 selects CPU training).
trainer = Trainer(
    gpus=1 if torch.cuda.is_available() else 0,
    max_epochs=32,
    fast_dev_run=False,
    log_every_n_steps=1,
    logger=logger,
    profiler=None,
)

In [None]:
# Train the MMGAN model on the data module built above.
trainer.fit(model, datamodule=dm)

## WGAN

In [None]:
class GANCA_WGAN(pl.LightningModule):
    """GANCA trained with the original WGAN objective (weight clipping, RMSprop).

    The generator is a `VoxelNCAModel` that is iterated for a randomly chosen
    number of steps starting from a seed state. The critic is a
    `VoxelDiscriminator` built with `use_sigmoid=False`.

    Args:
        lr: learning rate used by both RMSprop optimizers.
        n_critic: `frequency` given to the critic optimizer, i.e. how many
            consecutive critic batches run per generator batch.
        weight_clip: critic weights are clamped to `[-weight_clip, weight_clip]`.
        num_embedding_channels: number of block-embedding channels per cell.
        num_hidden_channels: number of hidden channels per cell.
        update_net_channel_dims: forwarded unchanged to `VoxelNCAModel`.
        embedding: block embedding layer; its weights are frozen here. It must
            be provided even though the default is `None`.
        step_range: inclusive `[low, high]` range from which the number of
            generator steps is drawn on every training step.
    """

    def __init__(self,
            lr = 2e-4,
            n_critic = 5,
            # beta1 = 0.9,
            # beta2 = 0.999,
            weight_clip = 0.01, # clipping range for the critic weights in WGAN
            num_embedding_channels = 64,
            num_hidden_channels = 63,
            update_net_channel_dims = [32, 32],
            embedding: torch.nn.Embedding = None,
            step_range = [16, 20],
        ):

        super().__init__()
        # call this to save args to the checkpoint
        self.save_hyperparameters()

        self.lr = lr
        self.n_critic = n_critic
        self.weight_clip = weight_clip
        # self.beta1 = beta1
        # self.beta2 = beta2
        self.num_embedding_channels = num_embedding_channels
        self.num_hidden_channels = num_hidden_channels
        self.update_net_channel_dims = update_net_channel_dims
        # the channels will be like [alpha, embeddings ... , hiddens ...]
        self.num_channels = 1 + self.num_embedding_channels + self.num_hidden_channels
        self.world_size = (32,32,32)
        self.seed_size = (2,2,2)
        self.embedding = embedding
        self.embedding.weight.requires_grad=False # freeze embeddings
        self.step_range = step_range

        self.generator = VoxelNCAModel(
            alpha_living_threshold = 0.1,
            cell_fire_rate = 0.5,
            num_perceptions = 3,
            perception_requires_grad = True,
            num_embedding_channels = self.num_embedding_channels,
            num_hidden_channels = self.num_hidden_channels,
            normal_std = 0.0002,
            use_normal_init = True,
            zero_bias = True,
            update_net_channel_dims = self.update_net_channel_dims,
        )
        self.discriminator = VoxelDiscriminator(
            num_in_channels = self.num_embedding_channels,
            use_sigmoid=False,
        )

        # fixed seed states (N, channels, x, y, z) reused for every end-of-epoch snapshot
        self.validation_noise = self.make_seed_states(4)

    def make_seed_states(self, batch_size):
        """Return `batch_size` seed states built by `utils.make_seed_state`."""
        return utils.make_seed_state(
            batch_size = batch_size,
            num_channels = self.num_channels,
            alpha_channel_index = 0,
            seed_dim = self.seed_size,
            world_size = self.world_size,
        )

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Run one generator (`optimizer_idx == 0`) or critic (`optimizer_idx == 1`) update.

        Returns the loss for whichever network is currently being optimised.
        """

        num_steps = random.randint(*self.step_range)

        real_houses_indices = batch
        type_holder = batch[0,0,0,0].to(torch.float) # float scalar on the batch's device, used as a type_as() template
        size_this_batch = real_houses_indices.shape[0]

        # one seed state per real example in the batch
        seed_states = self.make_seed_states(size_this_batch).type_as(type_holder)

        # train generator
        if optimizer_idx == 0:

            # generate fake houses and keep only the embedding channels (channel 0 is alpha)
            fake_houses_states = self.generator.forward(seed_states, steps=num_steps)
            fake_houses = fake_houses_states[:, 1:self.num_embedding_channels+1, :, :, :]

            g_loss = self.train_generator(fake_houses)
            self.log("g_loss", g_loss.detach(), prog_bar=True, logger=True)
            return g_loss

        # train critic
        if optimizer_idx == 1:

            # get embeddings for the real house
            real_houses = utils.examples2embedding(real_houses_indices, self.embedding)

            # generate fake houses
            fake_houses_states = self.generator.forward(seed_states, steps=num_steps).detach() # detach so that gradients don't pass back into generator
            fake_houses = fake_houses_states[:, 1:self.num_embedding_channels+1, :, :, :] # extract the embedding parts

            d_loss, real_loss, fake_loss = self.train_discriminator(real_houses, fake_houses)

            # logging
            self.log("fake_loss", fake_loss.detach(), prog_bar=False, logger=True)
            self.log("real_loss", real_loss.detach(), prog_bar=False, logger=True)
            self.log("d_loss", d_loss.detach(), prog_bar=True, logger=True)

            return d_loss

    def train_discriminator(self, real_houses, fake_houses):
        """Compute the WGAN critic loss.

        Returns:
            `(d_loss, real_loss, fake_loss)` where `real_loss` is the negated
            mean score of the real houses, `fake_loss` is the mean score of
            the generated houses and `d_loss` is their sum.
        """

        # making sure that real and fake have same shape
        assert real_houses.shape == fake_houses.shape

        # Clamp the critic weights before the forward pass. Clamping between the
        # forward and the backward pass would make autograd combine activations
        # from the unclipped weights with the clipped weights themselves.
        with torch.no_grad():
            for param in self.discriminator.parameters():
                param.clamp_(-self.weight_clip, self.weight_clip)

        # For WGAN there is no target and no BCE: the critic maximises the score
        # of real houses, which is the same as minimising the negative.
        real_predictions = self.discriminator.forward(real_houses)
        real_loss = - torch.mean(real_predictions)

        # ...and minimises the score of generated houses. This must score the
        # *fake* batch; scoring the real batch again makes d_loss identically 0.
        fake_predictions = self.discriminator.forward(fake_houses)
        fake_loss = torch.mean(fake_predictions)

        # make loss the sum
        d_loss = real_loss + fake_loss

        return d_loss, real_loss, fake_loss

    def train_generator(self, fake_houses):
        """Return the generator loss: the negated mean critic score of `fake_houses`."""

        # see what D thinks
        fake_predictions = self.discriminator.forward(fake_houses)

        # calc loss. We want to maximise the prediction for this one (only doing so with G's parameters)
        g_loss = -torch.mean(fake_predictions)

        return g_loss

    def configure_optimizers(self):
        """Return RMSprop optimizers: generator with frequency 1, critic with frequency `n_critic`."""

        # for WGAN, use RMSprop
        g_optimizer = torch.optim.RMSprop(self.generator.parameters(), lr=self.lr)
        d_optimizer = torch.optim.RMSprop(self.discriminator.parameters(), lr=self.lr)

        return (
            {"optimizer": g_optimizer, "frequency": 1},
            {"optimizer": d_optimizer, "frequency": self.n_critic}
        )

    def visualise_gen_process(self, step_size = 4, max_steps = 32):
        """Grow the fixed validation seeds and save snapshots of the growth to disk.

        The generator is advanced `ceil(max_steps / step_size)` times by
        `step_size` steps each. The initial state and every intermediate state
        are stacked along the first axis and written to
        `<log_dir>/gen_snapshots/epoch_<n>.npy` (embedding channels) and
        `epoch_<n>_alpha.npy` (alpha channel).
        """
        with torch.no_grad():

            num_steps = math.ceil(max_steps/step_size)

            state = self.validation_noise.to(self.device)
            # collect the pieces and concatenate once at the end
            snapshots = [utils.extract_embedding_channels(state, self.num_embedding_channels)]
            alpha_snapshots = [utils.extract_alpha_channels(state)]

            for _ in tqdm(range(num_steps), desc="forward pass for visualisation", colour='orange', ncols=1000, leave=False):
                state = self.generator.forward(state, steps=step_size)
                snapshots.append(utils.extract_embedding_channels(state, self.num_embedding_channels))
                alpha_snapshots.append(utils.extract_alpha_channels(state))

            snapshots = torch.cat(snapshots)
            alpha_snapshots = torch.cat(alpha_snapshots)

            self.snapshots_folder_path = self.logger.log_dir + '/gen_snapshots'
            # tolerate an existing folder, but let real failures (e.g. permissions) surface
            os.makedirs(self.snapshots_folder_path, exist_ok=True)
            np.save(self.snapshots_folder_path + f'/epoch_{self.current_epoch}.npy', snapshots.cpu().numpy())
            np.save(self.snapshots_folder_path + f'/epoch_{self.current_epoch}_alpha.npy', alpha_snapshots.cpu().numpy())

    def training_epoch_end(self, training_step_outputs = None):
        """Save a growth snapshot at the end of every training epoch."""
        self.visualise_gen_process(step_size=4, max_steps=self.step_range[1])

In [None]:
# Hyper-parameters for the WGAN experiment.
wgan_hparams = dict(
    # values proposed in the WGAN paper
    lr=0.00005,
    n_critic=5,
    weight_clip=0.01,
    # cell-state layout and update network
    num_embedding_channels=64,
    num_hidden_channels=63,
    update_net_channel_dims=[32, 32],
    embedding=embedding,
    step_range=[32, 36],
)
model = GANCA_WGAN(**wgan_hparams)

In [None]:
# Build the data module and run its prepare/setup hooks explicitly.
dm = GANCA3DDataModule(
    batch_size=8,
    num_workers=NUM_WORKERS,
    mcid2block=mcid2block,
    block2embeddingid=block2embeddingidx,
    debug=False,
)
dm.prepare_data()
dm.setup()

In [None]:
# TensorBoard logs for this experiment go to lightning_logs/GANCA_WGAN/.
logger = pl.loggers.TensorBoardLogger(save_dir='lightning_logs', name='GANCA_WGAN', default_hp_metric=False)
# Ask for a GPU only when CUDA is present, so the cell does not raise on a
# CPU-only machine (gpus=0 selects CPU training).
trainer = Trainer(
    gpus=1 if torch.cuda.is_available() else 0,
    max_epochs=32,
    fast_dev_run=False,
    log_every_n_steps=1,
    logger=logger,
    profiler=None,
)

In [None]:
# Train the WGAN model on the data module built above.
trainer.fit(model, datamodule=dm)

## WGAN-GP

In [None]:
class GANCA_WGANGP(pl.LightningModule):
    """GANCA trained with the WGAN-GP objective (gradient penalty, Adam).

    The generator is a `VoxelNCAModel` that is iterated for a randomly chosen
    number of steps starting from a seed state. The critic is a
    `VoxelDiscriminator` built with `use_sigmoid=False`.

    Args:
        lr: learning rate used by both Adam optimizers.
        alpha_living_threshold: forwarded unchanged to `VoxelNCAModel`.
        beta1: first Adam beta, used by both optimizers.
        beta2: second Adam beta, used by both optimizers.
        n_gen: `frequency` given to the generator optimizer.
        n_critic: `frequency` given to the critic optimizer.
        lambda_gp: weight of the gradient penalty in the critic loss.
        num_embedding_channels: number of block-embedding channels per cell.
        num_hidden_channels: number of hidden channels per cell.
        update_net_channel_dims: forwarded unchanged to `VoxelNCAModel`.
        embedding: block embedding layer; its weights are frozen here. It must
            be provided even though the default is `None`.
        converter: stored on the module as `self.converter`; not used by the
            methods of this class.
        step_range: inclusive `[low, high]` range from which the number of
            generator steps is drawn on every training step.
    """

    def __init__(self,
            lr = 0.0001,
            alpha_living_threshold = 0.1,
            beta1 = 0,
            beta2 = 0.9,
            n_gen = 1,
            n_critic = 5,
            lambda_gp = 10,
            num_embedding_channels = 64,
            num_hidden_channels = 63,
            update_net_channel_dims = [32, 32],
            embedding: torch.nn.Embedding = None,
            converter = None,
            step_range = [16, 20],
        ):

        super().__init__()
        # call this to save args to the checkpoint
        self.save_hyperparameters()

        self.alpha_living_threshold = alpha_living_threshold
        self.lr = lr
        self.beta1 = beta1
        self.beta2 = beta2
        self.n_gen = n_gen
        self.n_critic = n_critic
        self.lambda_gp = lambda_gp
        self.num_embedding_channels = num_embedding_channels
        self.num_hidden_channels = num_hidden_channels
        self.update_net_channel_dims = update_net_channel_dims
        # the channels will be like [alpha, embeddings ... , hiddens ...]
        self.num_channels = 1 + self.num_embedding_channels + self.num_hidden_channels
        self.world_size = (32,32,32)
        self.seed_size = (2,2,2)
        self.embedding = embedding
        self.embedding.weight.requires_grad=False # freeze embeddings
        self.step_range = step_range
        self.converter = converter

        self.generator = VoxelNCAModel(
            alpha_living_threshold = self.alpha_living_threshold,
            cell_fire_rate = 0.5,
            num_perceptions = 3,
            perception_requires_grad = True,
            num_embedding_channels = self.num_embedding_channels,
            num_hidden_channels = self.num_hidden_channels,
            normal_std = 0.0002,
            use_normal_init = True,
            zero_bias = True,
            update_net_channel_dims = self.update_net_channel_dims,
        )
        self.discriminator = VoxelDiscriminator(
            num_in_channels = self.num_embedding_channels,
            use_sigmoid=False,
        )

        # fixed seed states (N, channels, x, y, z) reused for every end-of-epoch snapshot
        self.validation_noise = self.make_seed_states(4)

    def make_seed_states(self, batch_size):
        """Return `batch_size` seed states built by `utils.make_seed_state`."""
        return utils.make_seed_state(
            batch_size = batch_size,
            num_channels = self.num_channels,
            alpha_channel_index = 0,
            seed_dim = self.seed_size,
            world_size = self.world_size,
        )

    def compute_gradient_penalty(self, real_samples, fake_samples):
        """Return the WGAN-GP penalty `mean((||grad D(x_hat)||_2 - 1)^2)`.

        `x_hat` is a per-sample random interpolation between `real_samples`
        and `fake_samples`, which must have the same 5-D shape.
        """
        batch_size = real_samples.size(0)

        # One interpolation weight per sample, shape (N, 1, 1, 1, 1), drawn directly
        # on the samples' device with torch's RNG (no NumPy round trip).
        alpha = torch.rand(batch_size, 1, 1, 1, 1, device=real_samples.device, dtype=real_samples.dtype)
        # Get random interpolation between real and fake samples
        interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
        # calc predictions for interpolated samples
        interpolates_predictions = self.discriminator.forward(interpolates)
        # Get gradient w.r.t. interpolates. The all-ones grad_outputs follows the
        # critic's actual output shape instead of assuming (N, 1).
        gradients = torch.autograd.grad(
            outputs=interpolates_predictions,
            inputs=interpolates,
            grad_outputs=torch.ones_like(interpolates_predictions),
            create_graph=True, # the penalty itself is differentiated w.r.t. the critic weights
            retain_graph=True,
            only_inputs=True,
        )[0]

        # flatten to one gradient vector per sample before taking its L2 norm
        gradients = gradients.reshape(batch_size, -1)
        gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
        return gradient_penalty

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Run one generator (`optimizer_idx == 0`) or critic (`optimizer_idx == 1`) update.

        Returns the loss for whichever network is currently being optimised.
        """

        num_steps = random.randint(*self.step_range)

        real_houses_indices = batch
        type_holder = batch[0,0,0,0].to(torch.float) # float scalar on the batch's device, used as a type_as() template
        size_this_batch = real_houses_indices.shape[0]

        # one seed state per real example in the batch
        seed_states = self.make_seed_states(size_this_batch).type_as(type_holder)

        # train generator
        if optimizer_idx == 0:

            # generate fake houses and keep only the embedding channels (channel 0 is alpha)
            fake_houses_states = self.generator.forward(seed_states, steps=num_steps)
            fake_houses = fake_houses_states[:, 1:self.num_embedding_channels+1, :, :, :]

            g_loss = self.train_generator(fake_houses)

            self.log("g_loss", g_loss.detach(), prog_bar=True, logger=True)
            return g_loss

        # train critic
        if optimizer_idx == 1:

            # get embeddings for the real house
            real_houses = utils.examples2embedding(real_houses_indices, self.embedding)

            # generate fake houses
            fake_houses_states = self.generator.forward(seed_states, steps=num_steps).detach() # detach so that gradients don't pass back into generator
            fake_houses = fake_houses_states[:, 1:self.num_embedding_channels+1, :, :, :] # extract the embedding parts

            d_loss = self.train_discriminator(real_houses, fake_houses)

            # logging
            self.log("d_loss", d_loss.detach(), prog_bar=True, logger=True)
            return d_loss

    def train_discriminator(self, real_houses, fake_houses):
        """Return the WGAN-GP critic loss and log the mean real/fake scores.

        The loss is `-mean(D(real)) + mean(D(fake)) + lambda_gp * gradient_penalty`.
        """
        assert real_houses.shape == fake_houses.shape

        real_predictions = self.discriminator.forward(real_houses)
        fake_predictions = self.discriminator.forward(fake_houses)
        self.log("real_validity", torch.mean(real_predictions).detach(), prog_bar=True, logger=True)
        self.log("fake_validity", torch.mean(fake_predictions).detach(), prog_bar=True, logger=True)

        # the penalty is computed on samples cut off from any upstream graph
        gradient_penalty = self.compute_gradient_penalty(real_houses.detach(), fake_houses.detach())

        d_loss = -torch.mean(real_predictions) + torch.mean(fake_predictions) + self.lambda_gp * gradient_penalty

        return d_loss

    def train_generator(self, fake_houses):
        """Return the generator loss: the negated mean critic score of `fake_houses`."""
        # see what D thinks
        fake_predictions = self.discriminator.forward(fake_houses)

        # calc loss. We want to maximise the prediction for this one (only doing so with G's parameters)
        g_loss = -torch.mean(fake_predictions)

        return g_loss

    def configure_optimizers(self):
        """Return Adam optimizers: generator with frequency `n_gen`, critic with frequency `n_critic`."""

        g_optimizer = torch.optim.Adam(self.generator.parameters(), lr=self.lr, betas=(self.beta1, self.beta2))
        d_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=self.lr, betas=(self.beta1, self.beta2))

        return [
            {"optimizer": g_optimizer, "frequency": self.n_gen},
            {"optimizer": d_optimizer, "frequency": self.n_critic},
        ]

    def visualise_gen_process(self, step_size = 4, max_steps = 32):
        """Grow the fixed validation seeds and save snapshots of the growth to disk.

        The generator is advanced `ceil(max_steps / step_size)` times by
        `step_size` steps each. The initial state and every intermediate state
        are stacked along the first axis and written to
        `<log_dir>/gen_snapshots/epoch_<n>.npy` (embedding channels) and
        `epoch_<n>_alpha.npy` (alpha channel).
        """
        with torch.no_grad():

            num_steps = math.ceil(max_steps/step_size)

            state = self.validation_noise.to(self.device)
            # collect the pieces and concatenate once at the end
            snapshots = [utils.extract_embedding_channels(state, self.num_embedding_channels)]
            alpha_snapshots = [utils.extract_alpha_channels(state)]

            for _ in tqdm(range(num_steps), desc="forward pass for visualisation", colour='orange', ncols=1000, leave=False):
                state = self.generator.forward(state, steps=step_size)
                snapshots.append(utils.extract_embedding_channels(state, self.num_embedding_channels))
                alpha_snapshots.append(utils.extract_alpha_channels(state))

            snapshots = torch.cat(snapshots)
            alpha_snapshots = torch.cat(alpha_snapshots)

            self.snapshots_folder_path = self.logger.log_dir + '/gen_snapshots'
            # tolerate an existing folder, but let real failures (e.g. permissions) surface
            os.makedirs(self.snapshots_folder_path, exist_ok=True)
            np.save(self.snapshots_folder_path + f'/epoch_{self.current_epoch}.npy', snapshots.cpu().numpy())
            np.save(self.snapshots_folder_path + f'/epoch_{self.current_epoch}_alpha.npy', alpha_snapshots.cpu().numpy())

    def training_epoch_end(self, training_step_outputs = None):
        """Save a growth snapshot at the end of every training epoch."""
        self.visualise_gen_process(step_size=4, max_steps=self.step_range[1])

In [None]:
# Hyper-parameters for the WGAN-GP experiment.
wgangp_hparams = dict(
    alpha_living_threshold=0,
    # optimizer settings as proposed in the WGAN-GP paper
    lr=0.0001,
    beta1=0,
    beta2=0.9,
    # update frequencies chosen heuristically
    n_gen=2,
    n_critic=3,
    lambda_gp=10,
    # cell-state layout and update network
    num_embedding_channels=64,
    num_hidden_channels=63,
    update_net_channel_dims=[32, 32],
    embedding=embedding,
    converter=converter,
    step_range=[32, 36],
)
model = GANCA_WGANGP(**wgangp_hparams)

In [None]:
# Build the data module and run its prepare/setup hooks explicitly.
dm = GANCA3DDataModule(
    batch_size=8,
    num_workers=NUM_WORKERS,
    mcid2block=mcid2block,
    block2embeddingid=block2embeddingidx,
    debug=False,
)
dm.prepare_data()
dm.setup()

In [None]:
# TensorBoard logs for this experiment go to lightning_logs/GANCA_WGANGP/.
logger = pl.loggers.TensorBoardLogger(save_dir='lightning_logs', name='GANCA_WGANGP', default_hp_metric=False)
# Ask for a GPU only when CUDA is present, so the cell does not raise on a
# CPU-only machine (gpus=0 selects CPU training).
trainer = Trainer(
    gpus=1 if torch.cuda.is_available() else 0,
    max_epochs=32,
    fast_dev_run=False,
    log_every_n_steps=1,
    logger=logger,
    profiler=None,
)

In [None]:
# Train the WGAN-GP model on the data module built above.
trainer.fit(model, datamodule=dm)

## WGAN-GP-Dual-D

In [2]:
class GANCA_WGANGP_DUAL_D(pl.LightningModule):
    """GANCA trained with WGAN-GP against two critics.

    A ``VoxelNCAModel`` generator grows a voxel state from a small seed. The
    state channels are laid out as ``[alpha, embeddings..., hiddens...]``.
    Two ``VoxelDiscriminator`` critics score the grown state:

    * ``discriminator`` sees the embedding channels only.
    * ``aliveness_discriminator`` sees the single alpha channel only.

    Three Adam optimizers (generator, critic, aliveness critic) are alternated
    by Lightning according to ``n_gen``, ``n_critic`` and ``n_a_critic``.
    """

    def __init__(self,
            lr = 0.0001,
            alpha_living_threshold = 0.1,
            beta1 = 0,
            beta2 = 0.9,
            n_gen = 1,
            n_critic = 5,
            n_a_critic = 3,
            lambda_gp = 10,
            num_embedding_channels = 64,
            num_hidden_channels = 63,
            update_net_channel_dims = [32, 32],
            embedding: torch.nn.Embedding = None,
            converter = None,
            step_range = [16, 20],
            content_loss_weight = 0.0,
            alpha_loss_weight = 1.0,
        ):
        """Build the generator, both critics and the fixed validation seeds.

        Args:
            lr: learning rate shared by all three Adam optimizers.
            alpha_living_threshold: forwarded to ``VoxelNCAModel``.
            beta1, beta2: Adam betas shared by all three optimizers.
            n_gen, n_critic, n_a_critic: optimizer ``frequency`` values for the
                generator, the embedding critic and the aliveness critic.
            lambda_gp: weight of the gradient penalty in both critic losses.
            num_embedding_channels: number of embedding channels in the state.
            num_hidden_channels: number of hidden channels in the state.
            update_net_channel_dims: forwarded to ``VoxelNCAModel``.
            embedding: block embedding table; required. Its weights are frozen.
            converter: stored on the module; not used by this class itself.
            step_range: inclusive ``[min, max]`` number of NCA steps sampled
                for every training step.
            content_loss_weight: weight of the embedding critic's score in the
                generator loss (default 0.0, i.e. the original behaviour).
            alpha_loss_weight: weight of the aliveness critic's score in the
                generator loss (default 1.0, i.e. the original behaviour).

        Raises:
            ValueError: if ``embedding`` is not provided.
        """
        super().__init__()
        # call this to save args to the checkpoint
        self.save_hyperparameters()

        if embedding is None:
            raise ValueError("GANCA_WGANGP_DUAL_D requires an `embedding` (torch.nn.Embedding)")

        self.alpha_living_threshold = alpha_living_threshold
        self.lr = lr
        self.beta1 = beta1
        self.beta2 = beta2
        self.n_gen = n_gen
        self.n_critic = n_critic
        self.n_a_critic = n_a_critic
        self.lambda_gp = lambda_gp
        self.num_embedding_channels = num_embedding_channels
        self.num_hidden_channels = num_hidden_channels
        # copy the list arguments so instances never share (or mutate) the default objects
        self.update_net_channel_dims = list(update_net_channel_dims)
        # the channels will be like [alpha, embeddings ... , hiddens ...]
        self.num_channels = 1 + self.num_embedding_channels + self.num_hidden_channels
        self.world_size = (32,32,32)
        self.seed_size = (2,2,2)
        self.embedding = embedding
        self.embedding.weight.requires_grad=False # freeze embeddings
        self.step_range = list(step_range)
        self.converter = converter
        self.content_loss_weight = content_loss_weight
        self.alpha_loss_weight = alpha_loss_weight

        self.generator = VoxelNCAModel(
            alpha_living_threshold = self.alpha_living_threshold,
            cell_fire_rate = 0.5,
            num_perceptions = 3,
            perception_requires_grad = True,
            num_embedding_channels = self.num_embedding_channels,
            num_hidden_channels = self.num_hidden_channels,
            normal_std = 0.0002,
            use_normal_init = True,
            zero_bias = True,
            update_net_channel_dims = self.update_net_channel_dims,
            dont_mask_alpha = True,
        )
        self.discriminator = VoxelDiscriminator(
            num_in_channels = self.num_embedding_channels,
            use_sigmoid=False,
        )
        self.aliveness_discriminator = VoxelDiscriminator(
            num_in_channels = 1,
            use_sigmoid=False,
        )

        # generate some fixed seeds (N, channels, x, y, z) reused for every visualisation
        self.validation_noise = self.make_seed_states(4)

    def make_seed_states(self, batch_size):
        """Return ``batch_size`` seed states built by ``utils.make_seed_state``."""
        return utils.make_seed_state(
            batch_size = batch_size,
            num_channels = self.num_channels,
            alpha_channel_index = 0,
            seed_dim = self.seed_size,
            world_size = self.world_size,
        )

    def _gradient_penalty(self, critic, real_samples, fake_samples):
        """WGAN-GP penalty of ``critic`` on random real/fake interpolations.

        Returns the batch mean of ``(||grad||_2 - 1) ** 2``, where the gradient
        is taken w.r.t. the interpolated inputs.
        """
        batch_size = real_samples.size(0)
        # One mixing weight per sample, shape (N, 1, 1, 1, 1), created directly
        # on the samples' device instead of going through numpy and the CPU.
        mix = torch.rand(batch_size, 1, 1, 1, 1, device=real_samples.device, dtype=real_samples.dtype)
        interpolates = (mix * real_samples + (1 - mix) * fake_samples).requires_grad_(True)
        scores = critic.forward(interpolates)
        gradients = torch.autograd.grad(
            outputs=scores,
            inputs=interpolates,
            # ones shaped like the critic output, whatever that shape is
            grad_outputs=torch.ones_like(scores),
            create_graph=True, # the penalty itself must be differentiable
            retain_graph=True,
            only_inputs=True,
        )[0]
        gradients = gradients.reshape(batch_size, -1)
        return ((gradients.norm(2, dim=1) - 1) ** 2).mean()

    def compute_gradient_penalty(self, real_samples, fake_samples):
        """Gradient penalty for the embedding critic (``self.discriminator``)."""
        return self._gradient_penalty(self.discriminator, real_samples, fake_samples)

    def compute_alpha_gradient_penalty(self, real_samples, fake_samples):
        """Gradient penalty for the aliveness critic (``self.aliveness_discriminator``)."""
        return self._gradient_penalty(self.aliveness_discriminator, real_samples, fake_samples)

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Run one optimisation step for the optimizer selected by ``optimizer_idx``.

        ``optimizer_idx`` 0 trains the generator, 1 the embedding critic and
        2 the aliveness critic (the order returned by ``configure_optimizers``).
        ``batch`` holds block indices; it is indexed as a 4-D tensor here.
        Returns the loss for the selected optimizer.
        """
        num_steps = random.randint(*self.step_range)

        real_houses_indices = batch
        type_holder = batch[0,0,0,0].to(torch.float) # this is a dummy used to copy dtype/device onto new tensors
        size_this_batch = real_houses_indices.shape[0]

        # make seeds, same batch size as the incoming batch
        seed_states = self.make_seed_states(size_this_batch).type_as(type_holder)

        # train generator
        if optimizer_idx == 0:

            # grow fake houses and split the state into its embedding and alpha parts
            fake_houses_states = self.generator.forward(seed_states, steps=num_steps)
            fake_houses = fake_houses_states[:, 1:self.num_embedding_channels+1, :, :, :]
            fake_alpha = fake_houses_states[:, 0:1, :, :, :]

            g_loss = self.train_generator(fake_houses, fake_alpha)

            self.log("g_loss", g_loss.detach(), prog_bar=True, logger=True)
            return g_loss

        # train the embedding critic
        if optimizer_idx == 1:

            # get embeddings for the real house
            real_houses = utils.examples2embedding(real_houses_indices, self.embedding)

            # generate fake houses; detach so that gradients don't pass back into the generator
            fake_houses_states = self.generator.forward(seed_states, steps=num_steps).detach()
            fake_houses = fake_houses_states[:, 1:self.num_embedding_channels+1, :, :, :] # extract the embedding parts

            d_loss = self.train_discriminator(real_houses, fake_houses)

            self.log("d_loss", d_loss.detach(), prog_bar=True, logger=True)
            return d_loss

        # train the aliveness critic
        if optimizer_idx == 2:
            # Design notes:
            #
            # Assigning -1 to all dead cells and 1 to all alive cells in a real example
            # doesn't work well because fake alpha is close to 0 by default, so the task
            # is very easy for the aliveness critic.
            # Idea: reward the generator for an alive configuration that looks like a
            # real house without punishing it for dead cells (e.g. via a leaky ReLU).
            #
            # Another problem: empty cells that haven't grown yet are punished for having
            # a 0 alpha while they shouldn't be.
            # Proposed solution: mask the real alpha by selecting the grown cells that
            # have a non-zero alpha.
            #
            # Another problem: too much reward is given to an alpha value close to 1,
            # while all we want is a positive alpha value.
            # Proposed solution: use a concave-down activation so that the difference
            # between 0.1 and 1.0 isn't large.

            # generate fake houses
            fake_houses_states = self.generator.forward(seed_states, steps=num_steps).detach()
            fake_alpha = fake_houses_states[:, 0:1, :, :, :] # extract alpha channel

            # real alpha target: +1 where the block index is non-zero, -1 where it is 0
            real_alpha = (real_houses_indices != 0).type_as(type_holder) * 2 - 1
            real_alpha = real_alpha.unsqueeze(1) # (N, x, y, z) -> (N, 1, x, y, z)

            d_a_loss = self.train_aliveness_discriminator(real_alpha, fake_alpha)

            self.log("d_a_loss", d_a_loss.detach(), prog_bar=True, logger=True)
            return d_a_loss

    def train_discriminator(self, real_houses, fake_houses):
        """WGAN-GP loss of the embedding critic; also logs mean real/fake scores."""
        assert real_houses.shape == fake_houses.shape

        real_predictions = self.discriminator.forward(real_houses)
        fake_predictions = self.discriminator.forward(fake_houses)
        self.log("real_validity", torch.mean(real_predictions).detach(), prog_bar=True, logger=True)
        self.log("fake_validity", torch.mean(fake_predictions).detach(), prog_bar=True, logger=True)

        gradient_penalty = self.compute_gradient_penalty(real_houses.data, fake_houses.data)

        d_loss = -torch.mean(real_predictions) + torch.mean(fake_predictions) + self.lambda_gp * gradient_penalty

        return d_loss

    def train_aliveness_discriminator(self, real_alpha, fake_alpha):
        """WGAN-GP loss of the aliveness critic; also logs mean real/fake scores."""
        assert real_alpha.shape == fake_alpha.shape

        real_predictions = self.aliveness_discriminator.forward(real_alpha)
        fake_predictions = self.aliveness_discriminator.forward(fake_alpha)
        self.log("real_alpha_validity", torch.mean(real_predictions).detach(), prog_bar=True, logger=True)
        self.log("fake_alpha_validity", torch.mean(fake_predictions).detach(), prog_bar=True, logger=True)

        gradient_penalty = self.compute_alpha_gradient_penalty(real_alpha.data, fake_alpha.data)

        d_a_loss = -torch.mean(real_predictions) + torch.mean(fake_predictions) + self.lambda_gp * gradient_penalty

        return d_a_loss

    def train_generator(self, fake_houses, fake_alpha):
        """Generator loss: negated weighted sum of both critics' mean scores.

        With the default weights (0.0 for the embedding critic, 1.0 for the
        aliveness critic) only the aliveness critic drives the generator.
        """
        # see what both critics think
        fake_predictions = self.discriminator.forward(fake_houses)
        fake_alpha_predictions = self.aliveness_discriminator.forward(fake_alpha)

        # We want to maximise the critics' scores (only G's parameters are stepped here)
        g_loss = - (
            self.content_loss_weight * torch.mean(fake_predictions)
            + self.alpha_loss_weight * torch.mean(fake_alpha_predictions)
        )

        return g_loss

    def configure_optimizers(self):
        """Return one Adam optimizer per network with its update frequency."""
        betas = (self.beta1, self.beta2)
        g_optimizer = torch.optim.Adam(self.generator.parameters(), lr=self.lr, betas=betas)
        d_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=self.lr, betas=betas)
        d_a_optimizer = torch.optim.Adam(self.aliveness_discriminator.parameters(), lr=self.lr, betas=betas)

        return [
            {"optimizer": g_optimizer, "frequency": self.n_gen},
            {"optimizer": d_optimizer, "frequency": self.n_critic},
            {"optimizer": d_a_optimizer, "frequency": self.n_a_critic},
        ]

    def visualise_gen_process(self, step_size = 4, max_steps = 32):
        """Grow the fixed validation seeds and save snapshots of the process.

        The generator is advanced ``step_size`` steps at a time for
        ``ceil(max_steps / step_size)`` rounds. The initial state and the state
        after every round are concatenated along the batch dimension and
        written to ``<log_dir>/gen_snapshots/epoch_<n>.npy`` (embedding
        channels) and ``epoch_<n>_alpha.npy`` (alpha channel).
        """
        with torch.no_grad():

            num_steps = math.ceil(max_steps/step_size)

            state = self.validation_noise.to(self.device)
            # collect snapshots in lists and concatenate once at the end
            snapshots = [utils.extract_embedding_channels(state, self.num_embedding_channels)]
            alpha_snapshots = [utils.extract_alpha_channels(state)]

            for _ in tqdm(range(num_steps), desc="forward pass for visualisation", colour='orange', ncols=1000, leave=False):
                state = self.generator.forward(state, steps=step_size)
                snapshots.append(utils.extract_embedding_channels(state, self.num_embedding_channels))
                alpha_snapshots.append(utils.extract_alpha_channels(state))

            snapshots = torch.cat(snapshots)
            alpha_snapshots = torch.cat(alpha_snapshots)

            self.snapshots_folder_path = os.path.join(self.logger.log_dir, 'gen_snapshots')
            # exist_ok tolerates reruns; genuine failures (e.g. permissions) still surface
            os.makedirs(self.snapshots_folder_path, exist_ok=True)
            np.save(os.path.join(self.snapshots_folder_path, f'epoch_{self.current_epoch}.npy'), snapshots.cpu().numpy())
            np.save(os.path.join(self.snapshots_folder_path, f'epoch_{self.current_epoch}_alpha.npy'), alpha_snapshots.cpu().numpy())

    def training_epoch_end(self, training_step_outputs = None):
        """Save a growth snapshot at the end of every training epoch."""
        # roll out up to the upper bound of the training step range
        self.visualise_gen_process(step_size=4, max_steps=self.step_range[1])

NameError: name 'pl' is not defined

In [None]:
# Instantiate the dual-critic WGAN-GP GANCA model for this experiment.
# `embedding` and `converter` must already be defined by earlier cells.
model = GANCA_WGANGP_DUAL_D(
    alpha_living_threshold = 0,
    lr = 0.0001, # as proposed in WGAN-GP paper
    beta1 = 0, # as proposed in WGAN-GP paper
    beta2 = 0.9, # as proposed in WGAN-GP paper
    n_gen = 2, # heuristically 2
    n_critic = 3, # heuristically 3
    n_a_critic = 1, # heuristically 1
    lambda_gp = 10, # as proposed in WGAN-GP paper
    num_embedding_channels = 64,
    num_hidden_channels = 63,
    update_net_channel_dims = [32, 32],
    embedding = embedding,
    converter = converter,
    step_range = [32, 36],
)

In [None]:
# Build the data module and run its prepare/setup hooks up front.
# NUM_WORKERS, mcid2block and block2embeddingidx must already be defined by earlier cells.
dm = GANCA3DDataModule(batch_size=8, num_workers=NUM_WORKERS, mcid2block = mcid2block, block2embeddingid = block2embeddingidx, debug=False)
dm.prepare_data()
dm.setup()

In [None]:
# TensorBoard logs go under lightning_logs/GANCA_WGANGP_DUAL_D.
logger = pl.loggers.TensorBoardLogger(save_dir='lightning_logs', name='GANCA_WGANGP_DUAL_D', default_hp_metric=False)
# Train on a single GPU for at most 32 epochs, logging every step.
trainer = Trainer(gpus=1, max_epochs=32, fast_dev_run=False, log_every_n_steps=1, logger=logger, profiler=None)

In [None]:
# Run training with the `model`, `dm` and `trainer` objects currently bound in the notebook.
trainer.fit(model, dm)

## WGAN-GP Deconv Generator Baseline

In [15]:
# Stand-alone deconvolutional generator, built here only to check its output shape below.
generator = VoxelDeconvGenerator(
    latent_dim = 512,
    world_shape = (32,32,32),
    num_embedding_channels = 64,
)

In [25]:
# Sanity check: feed a batch of 4 uniform-noise latents of shape (4, 512, 1, 1, 1)
# through the generator and display the shape of the result.
out = generator(torch.rand(4, 512, 1, 1, 1)).detach()
out.shape

torch.Size([4, 64, 32, 32, 32])

In [27]:
class Deconv_WGANGP(pl.LightningModule):
    """WGAN-GP baseline with a one-shot deconvolutional generator.

    A ``VoxelDeconvGenerator`` maps latent noise of shape
    ``(N, latent_dim, 1, 1, 1)`` to a voxel grid of block embeddings, and a
    ``VoxelDiscriminator`` critic scores real versus generated grids. Two Adam
    optimizers (generator, critic) are alternated by Lightning according to
    ``n_gen`` and ``n_critic``.
    """

    def __init__(self,
            lr = 0.0001,
            beta1 = 0,
            beta2 = 0.9,
            n_gen = 1,
            n_critic = 5,
            lambda_gp = 10,
            latent_dim = 512,
            embedding: torch.nn.Embedding = None,
            num_embedding_channels = 64,
        ):
        """Build the generator, the critic and the fixed validation noise.

        Args:
            lr: learning rate shared by both Adam optimizers.
            beta1, beta2: Adam betas shared by both optimizers.
            n_gen, n_critic: optimizer ``frequency`` values for the generator
                and the critic.
            lambda_gp: weight of the gradient penalty in the critic loss.
            latent_dim: number of channels of the latent noise.
            embedding: block embedding table; required. Its weights are frozen.
            num_embedding_channels: number of channels the generator emits and
                the critic consumes.

        Raises:
            ValueError: if ``embedding`` is not provided.
        """
        super().__init__()
        # call this to save args to the checkpoint
        self.save_hyperparameters()

        if embedding is None:
            raise ValueError("Deconv_WGANGP requires an `embedding` (torch.nn.Embedding)")

        self.lr = lr
        self.beta1 = beta1
        self.beta2 = beta2
        self.n_gen = n_gen
        self.n_critic = n_critic
        self.lambda_gp = lambda_gp
        self.latent_dim = latent_dim
        self.num_embedding_channels = num_embedding_channels

        self.embedding = embedding
        self.embedding.weight.requires_grad=False # freeze embeddings

        self.world_size = (32,32,32)

        # Must be stored on `self`: training_step and configure_optimizers use
        # self.generator, and only attributes are registered as submodules.
        self.generator = VoxelDeconvGenerator(
            latent_dim = self.latent_dim,
            world_shape = self.world_size,
            num_embedding_channels = self.num_embedding_channels,
        )

        self.discriminator = VoxelDiscriminator(
            num_in_channels = self.num_embedding_channels,
            use_sigmoid=False,
        )

        # fixed random noise (N, self.latent_dim, 1, 1, 1) reused for every visualisation
        self.validation_noise = torch.rand(4, self.latent_dim, 1, 1, 1)

    def compute_gradient_penalty(self, real_samples, fake_samples):
        """WGAN-GP penalty of the critic on random real/fake interpolations.

        Returns the batch mean of ``(||grad||_2 - 1) ** 2``, where the gradient
        is taken w.r.t. the interpolated inputs.
        """
        batch_size = real_samples.size(0)
        # One mixing weight per sample, shape (N, 1, 1, 1, 1), created directly
        # on the samples' device instead of going through numpy and the CPU.
        mix = torch.rand(batch_size, 1, 1, 1, 1, device=real_samples.device, dtype=real_samples.dtype)
        interpolates = (mix * real_samples + (1 - mix) * fake_samples).requires_grad_(True)
        scores = self.discriminator.forward(interpolates)
        gradients = torch.autograd.grad(
            outputs=scores,
            inputs=interpolates,
            # ones shaped like the critic output, whatever that shape is
            grad_outputs=torch.ones_like(scores),
            create_graph=True, # the penalty itself must be differentiable
            retain_graph=True,
            only_inputs=True,
        )[0]

        gradients = gradients.reshape(batch_size, -1)
        return ((gradients.norm(2, dim=1) - 1) ** 2).mean()

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Run one optimisation step for the optimizer selected by ``optimizer_idx``.

        ``optimizer_idx`` 0 trains the generator and 1 the critic (the order
        returned by ``configure_optimizers``). ``batch`` holds block indices;
        it is indexed as a 4-D tensor here. Returns the loss for the selected
        optimizer.
        """
        real_houses_indices = batch
        type_holder = batch[0,0,0,0].to(torch.float) # this is a dummy used to copy dtype/device onto new tensors
        size_this_batch = real_houses_indices.shape[0]

        # make noise, same batch size as the incoming batch
        noise = torch.rand(size_this_batch, self.latent_dim, 1, 1, 1).type_as(type_holder)

        # train generator
        if optimizer_idx == 0:

            fake_houses = self.generator.forward(noise)

            g_loss = self.train_generator(fake_houses)

            self.log("g_loss", g_loss.detach(), prog_bar=True, logger=True)
            return g_loss

        # train critic
        if optimizer_idx == 1:

            # get embeddings for the real house
            real_houses = utils.examples2embedding(real_houses_indices, self.embedding)

            # generate fake houses; detach so that gradients don't pass back into the generator
            fake_houses = self.generator.forward(noise).detach()

            d_loss = self.train_discriminator(real_houses, fake_houses)

            self.log("d_loss", d_loss.detach(), prog_bar=True, logger=True)
            return d_loss

    def train_discriminator(self, real_houses, fake_houses):
        """WGAN-GP critic loss; also logs the mean real/fake scores."""
        assert real_houses.shape == fake_houses.shape

        real_predictions = self.discriminator.forward(real_houses)
        fake_predictions = self.discriminator.forward(fake_houses)
        self.log("real_validity", torch.mean(real_predictions).detach(), prog_bar=True, logger=True)
        self.log("fake_validity", torch.mean(fake_predictions).detach(), prog_bar=True, logger=True)

        gradient_penalty = self.compute_gradient_penalty(real_houses.data, fake_houses.data)

        d_loss = -torch.mean(real_predictions) + torch.mean(fake_predictions) + self.lambda_gp * gradient_penalty

        return d_loss

    def train_generator(self, fake_houses):
        """Generator loss: the negated mean critic score on the fakes."""
        # see what D thinks
        fake_predictions = self.discriminator.forward(fake_houses)

        # We want to maximise the critic's score (only G's parameters are stepped here)
        g_loss = -torch.mean(fake_predictions)

        return g_loss

    def configure_optimizers(self):
        """Return one Adam optimizer per network with its update frequency."""
        betas = (self.beta1, self.beta2)
        g_optimizer = torch.optim.Adam(self.generator.parameters(), lr=self.lr, betas=betas)
        d_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=self.lr, betas=betas)

        return [
            {"optimizer": g_optimizer, "frequency": self.n_gen},
            {"optimizer": d_optimizer, "frequency": self.n_critic},
        ]

    def visualise_gen_process(self, step_size = 4, max_steps = 32):
        """Generate from the fixed validation noise and save the result.

        This generator produces its output in a single forward pass, so there
        is no step-wise growth and no alpha channel to record. ``step_size``
        and ``max_steps`` are accepted only for signature compatibility and
        are ignored. The generated embeddings are written to
        ``<log_dir>/gen_snapshots/epoch_<n>.npy``.
        """
        with torch.no_grad():
            noise = self.validation_noise.to(self.device)
            snapshots = self.generator.forward(noise)

            self.snapshots_folder_path = os.path.join(self.logger.log_dir, 'gen_snapshots')
            # exist_ok tolerates reruns; genuine failures (e.g. permissions) still surface
            os.makedirs(self.snapshots_folder_path, exist_ok=True)
            np.save(os.path.join(self.snapshots_folder_path, f'epoch_{self.current_epoch}.npy'), snapshots.cpu().numpy())

    def training_epoch_end(self, training_step_outputs = None):
        """Save a generation snapshot at the end of every training epoch."""
        # this model has no step range, so call with the defaults
        self.visualise_gen_process()

In [29]:
# Instantiate the deconvolutional WGAN-GP baseline; latent_dim keeps its default.
# `embedding` must already be defined by an earlier cell.
model = Deconv_WGANGP(
    lr = 0.0001, # as proposed in WGAN-GP paper
    beta1 = 0, # as proposed in WGAN-GP paper
    beta2 = 0.9, # as proposed in WGAN-GP paper
    n_gen = 2, # heuristically 2
    n_critic = 3, # heuristically 3
    lambda_gp = 10,
    num_embedding_channels = 64,
    embedding = embedding,
)

In [30]:
# Build the data module and run its prepare/setup hooks up front.
# NUM_WORKERS, mcid2block and block2embeddingidx must already be defined by earlier cells.
dm = GANCA3DDataModule(batch_size=8, num_workers=NUM_WORKERS, mcid2block = mcid2block, block2embeddingid = block2embeddingidx, debug=False)
dm.prepare_data()
dm.setup()

Dataset already exists


  0%|          | 0/1977 [00:00<?, ?it/s]

loaded 1977 houses
Turning MC id into embedding idx. This could take up to a minute.
Done with that.


In [None]:
# TensorBoard logs go under lightning_logs/Deconv_WGANGP.
logger = pl.loggers.TensorBoardLogger(save_dir='lightning_logs', name='Deconv_WGANGP', default_hp_metric=False)
# Train on a single GPU for at most 32 epochs, logging every step.
trainer = Trainer(gpus=1, max_epochs=32, fast_dev_run=False, log_every_n_steps=1, logger=logger, profiler=None)