# Training GANCA-3D

## Downloads

In [None]:
# !wget -q https://github.com/chaosarium/3D-GANCA/raw/master/data_helper.py
# !wget -q https://github.com/chaosarium/3D-GANCA/raw/master/utils.py
# !wget -q https://github.com/chaosarium/3D-GANCA/raw/master/models.py
# !wget -q https://github.com/chaosarium/3D-GANCA/raw/master/block_ids_alt.tsv
# !wget -q https://github.com/chaosarium/3D-GANCA/raw/master/output.zip
# !unzip -q output.zip
# !mkdir dataset
# !wget -q https://github.com/chaosarium/3D-GANCA/blob/master/dataset/filtered_houses_stats.pkl4 -O dataset/filtered_houses_stats.pkl4
# !wget -q https://github.com/chaosarium/3D-GANCA/blob/master/dataset/filtered_houses_stats.pkl -O dataset/filtered_houses_stats.pkl

dataset/filtered_houses_stats.pkl4: No such file or directory
dataset/filtered_houses_stats.pkl: No such file or directory


In [None]:
# !pip --quiet install torchsummaryX loguru einops pytorch_lightning

[?25l[K     |█████▋                          | 10 kB 34.7 MB/s eta 0:00:01[K     |███████████▎                    | 20 kB 15.2 MB/s eta 0:00:01[K     |████████████████▉               | 30 kB 13.8 MB/s eta 0:00:01[K     |██████████████████████▌         | 40 kB 13.7 MB/s eta 0:00:01[K     |████████████████████████████    | 51 kB 8.6 MB/s eta 0:00:01[K     |████████████████████████████████| 58 kB 4.3 MB/s 
[K     |████████████████████████████████| 527 kB 17.2 MB/s 
[K     |████████████████████████████████| 134 kB 88.6 MB/s 
[K     |████████████████████████████████| 596 kB 69.7 MB/s 
[K     |████████████████████████████████| 397 kB 87.2 MB/s 
[K     |████████████████████████████████| 829 kB 86.4 MB/s 
[K     |████████████████████████████████| 952 kB 77.2 MB/s 
[K     |████████████████████████████████| 1.1 MB 51.0 MB/s 
[K     |████████████████████████████████| 271 kB 73.5 MB/s 
[K     |████████████████████████████████| 94 kB 3.9 MB/s 
[K     |████████████████████████

## Imports

In [1]:
# import stuff

import data_helper
import utils

%load_ext autoreload
%autoreload 2

import numpy as np
import sys
import pickle
import torch
import random
from torchsummary import summary
import matplotlib.pyplot as plt

from loguru import logger as gurulogger
gurulogger.remove()
gurulogger.add(sys.stdout, colorize=True, format="<blue>{time}</blue> <level>{message}</level>")
gurulogger.level("INFO", color="<red><bold>")


import os
from einops import rearrange
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd

BLOCK2VEC_OUT_PATH = 'output/block2vec saves/block2vec 64 dim/'
NUM_WORKERS = int(os.cpu_count() / 2)

import pytorch_lightning as pl

if torch.cuda.is_available():
    print('CUDNN VERSION:', torch.backends.cudnn.version())
    print('Number CUDA Devices:', torch.cuda.device_count())
    print('CUDA Device Name:',torch.cuda.get_device_name(0))
    print('CUDA Device Total Memory [GB]:',torch.cuda.get_device_properties(0).total_memory/1e9)

In [2]:
embedding, mcid2block, block2embeddingidx, embeddingidx2block = utils.get_embedding_info(BLOCK2VEC_OUT_PATH)

In [3]:
from data_helper import GANCA3DDataModule

In [4]:
from models import VoxelPerceptionNet, VoxelUpdateNet, VoxelNCAModel, VoxelDiscriminator
from pytorch_lightning import Trainer

In [7]:
%load_ext tensorboard
# %tensorboard --logdir lightning_logs/

## MM GAN

In [None]:
class GANCA_MMGAN(pl.LightningModule):

    def __init__(self,
            lr = 2e-4,
            beta1 = 0.9,
            beta2 = 0.999,
            batch_size = 16,
            num_embedding_channels = 64,
            num_hidden_channels = 63,
            update_net_channel_dims = [32, 32],
            embedding: torch.nn.Embedding = None,
            step_range = [16, 20],
        ):
        
        super().__init__()
        # call this to save args to the checkpoint
        self.save_hyperparameters()
        
        self.lr = lr
        self.beta1 = beta1
        self.beta2 = beta2
        self.batch_size = batch_size
        self.num_embedding_channels = num_embedding_channels
        self.num_hidden_channels = num_hidden_channels
        self.update_net_channel_dims = update_net_channel_dims
        # the channels will be like [alpha, embeddings ... , hiddens ...]
        self.num_channels = 1 + self.num_embedding_channels + self.num_hidden_channels
        self.world_size = (32,32,32)
        self.embedding = embedding
        self.embedding.weight.requires_grad=False # freeze embeddings
        self.step_range = step_range
        
        self.generator = VoxelNCAModel(
            alpha_living_threshold = 0.1,
            cell_fire_rate = 0.5,
            num_perceptions = 3,
            perception_requires_grad = True,
            num_embedding_channels = self.num_embedding_channels,
            num_hidden_channels = self.num_hidden_channels,
            normal_std = 0.0002,
            use_normal_init = True,
            zero_bias = True,
            update_net_channel_dims = self.update_net_channel_dims,
        )
        self.discriminator = VoxelDiscriminator(
            num_in_channels = self.num_embedding_channels, 
            use_sigmoid=True,
        )
        
        # generate some random seeds (N, channels, x, y, z)
        self.validation_noise = self.make_seed_states(16)
        
    def make_seed_states(self, batch_size):
        return utils.make_seed_state(
            batch_size = batch_size,
            num_channels = self.num_channels, 
            alpha_channel_index = 0,
            seed_dim = (4, 4, 4), 
            world_size = self.world_size,
        )
    
    def training_step(self, batch, batch_idx, optimizer_idx):
        
        num_steps = random.randint(*self.step_range)
            
        real_houses = batch
        type_holder = batch[0,0,0,0].to(torch.float) # this is a dummy type for creating labels
        size_this_batch = real_houses.shape[0]
                
        # make noise
        
        seed_states = self.make_seed_states(size_this_batch).type_as(type_holder) # same batch size as those coming in
            
        # train generator
        if optimizer_idx == 0:
            
            # generate images
            fake_houses_states = self.generator.forward(seed_states, steps=num_steps)
                        
            # create ground truth result (all fake results) we want D to say the generated ones are real so they are all 1s
            real_labels = utils.make_real_labels(size_this_batch).type_as(type_holder)
            
            # now get the embedding parts out of the fake states
            fake_houses = fake_houses_states[:, 1:self.num_embedding_channels+1, :, :, :]
            
            # see what discriminator thinks
            fake_predictions = self.discriminator.forward(fake_houses)
            
            # calculate loss
            g_loss = F.binary_cross_entropy(fake_predictions, real_labels) # y_hat, y  
                                    
            self.log("g_loss", g_loss.detach(), prog_bar=True, logger=True)

            return g_loss
        
        if optimizer_idx == 1:
            
            # get embeddings for the real house
            real_houses = utils.examples2embedding(real_houses, self.embedding)
            
            # pass in real image to D and try to make it predict all 1s
            real_targets = utils.make_real_labels(size_this_batch).type_as(type_holder)
            real_predictions = self.discriminator.forward(real_houses)
            real_loss = F.binary_cross_entropy(real_predictions, real_targets)
            # TODO fix accuracy formula
            real_acc = torch.mean(real_predictions).item() # the average prediction. The closer to 1 the more accurate
            
            # meanwhile, we also want D to be able to tell that outputs from G are all fake
            fake_targets = utils.make_fake_labels(size_this_batch).type_as(type_holder)
            fake_houses_states = self.generator.forward(seed_states, steps=num_steps).detach() # detach so that gradients don't pass back into generator
            fake_houses = fake_houses_states[:, 1:self.num_embedding_channels+1, :, :, :] # extract the embedding parts
            fake_predictions = self.discriminator.forward(fake_houses)
            fake_loss = F.binary_cross_entropy(fake_predictions, fake_targets)
            # TODO fix accuracy formula
            fake_acc = 1 - torch.mean(fake_predictions).item() # the 1 - average prediction. The closer to 1 the more accurate
                        
            # discriminator loss is the average of the two losses
            d_loss = (real_loss + fake_loss) / 2
            avg_acc = (real_acc + fake_acc) / 2
                    
            self.log("fake_loss", fake_loss.detach(), prog_bar=False, logger=True)
            self.log("real_loss", real_loss.detach(), prog_bar=False, logger=True)
            self.log("d_loss", d_loss.detach(), prog_bar=True, logger=True)
            self.log("real_acc", real_acc, prog_bar=False, logger=True)
            self.log("fake_acc", fake_acc, prog_bar=False, logger=True)
            self.log("avg_acc", avg_acc, prog_bar=True, logger=True)

            return d_loss

    def configure_optimizers(self,):
        # define pytorch optimizers here. return [list of optimizers], [list of LR schedulers]
        g_optimizer = torch.optim.Adam(self.generator.parameters(), lr=self.lr, betas=(self.beta1, self.beta2))
        d_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=self.lr, betas=(self.beta1, self.beta2))
        return [g_optimizer, d_optimizer], []

In [None]:
model = GANCA_MMGAN(
    lr = 2e-4,
    beta1 = 0.9,
    beta2 = 0.999,
    batch_size = 8,
    num_embedding_channels = 64,
    num_hidden_channels = 63,
    update_net_channel_dims = [16, 16],
    embedding = embedding,
    step_range = [16, 20],
)

In [None]:
dm = GANCA3DDataModule(batch_size=8, num_workers=1, mcid2block = mcid2block, block2embeddingid = block2embeddingidx, debug=False)
dm.prepare_data()
dm.setup()

Dataset already exists


  rank_zero_deprecation("DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.")


  0%|          | 0/1977 [00:00<?, ?it/s]

loaded 1977 houses
Turning MC id into embedding idx. This could take up to a minute.


In [None]:
logger = pl.loggers.TensorBoardLogger(save_dir='lightning_logs', name='GANCA', default_hp_metric=False)
trainer = Trainer(gpus=1, max_epochs=1, fast_dev_run=False, log_every_n_steps=1, logger=logger, profiler="simple")

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [None]:
trainer.fit(model, dm)

  rank_zero_warn("You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.")
  f"DataModule.{name} has already been called, so it will not be called again. "
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Missing logger folder: lightning_logs/GANCA

  | Name          | Type               | Params
-----------------------------------------------------
0 | embedding     | Embedding          | 14.0 K
1 | generator     | VoxelNCAModel      | 18.8 K
2 | discriminator | VoxelDiscriminator | 11.3 M
-----------------------------------------------------
11.3 M    Trainable params
14.0 K    Non-trainable params
11.3 M    Total params
45.244    Total estimated model params size (MB)
  f"The dataloader, {name}, does not have many workers which may be a bottleneck."


Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


## WGAN

In [28]:
class GANCA_WGAN(pl.LightningModule):

    def __init__(self,
            lr = 2e-4,
            # beta1 = 0.9,
            # beta2 = 0.999,
            weight_clip = 0.01, # glipping gradient in WGAN
            batch_size = 16,
            num_embedding_channels = 64,
            num_hidden_channels = 63,
            update_net_channel_dims = [32, 32],
            embedding: torch.nn.Embedding = None,
            step_range = [16, 20],
        ):
        
        super().__init__()
        # call this to save args to the checkpoint
        self.save_hyperparameters()
        
        self.lr = lr
        self.weight_clip = weight_clip
        # self.beta1 = beta1
        # self.beta2 = beta2
        self.batch_size = batch_size
        self.num_embedding_channels = num_embedding_channels
        self.num_hidden_channels = num_hidden_channels
        self.update_net_channel_dims = update_net_channel_dims
        # the channels will be like [alpha, embeddings ... , hiddens ...]
        self.num_channels = 1 + self.num_embedding_channels + self.num_hidden_channels
        self.world_size = (32,32,32)
        self.seed_size = (2,2,2)
        self.embedding = embedding
        self.embedding.weight.requires_grad=False # freeze embeddings
        self.step_range = step_range
        
        self.generator = VoxelNCAModel(
            alpha_living_threshold = 0.1,
            cell_fire_rate = 0.5,
            num_perceptions = 3,
            perception_requires_grad = True,
            num_embedding_channels = self.num_embedding_channels,
            num_hidden_channels = self.num_hidden_channels,
            normal_std = 0.0002,
            use_normal_init = True,
            zero_bias = True,
            update_net_channel_dims = self.update_net_channel_dims,
        )
        self.discriminator = VoxelDiscriminator(
            num_in_channels = self.num_embedding_channels, 
            use_sigmoid=False,
        )
        
        # generate some random seeds (N, channels, x, y, z)
        self.validation_noise = self.make_seed_states(16)
        
    def make_seed_states(self, batch_size):
        return utils.make_seed_state(
            batch_size = batch_size,
            num_channels = self.num_channels, 
            alpha_channel_index = 0,
            seed_dim = self.seed_size, 
            world_size = self.world_size,
        )
    
    def training_step(self, batch, batch_idx, optimizer_idx):
                
        num_steps = random.randint(*self.step_range)

        real_houses_indices = batch
        type_holder = batch[0,0,0,0].to(torch.float) # this is a dummy type for creating labels
        size_this_batch = real_houses_indices.shape[0]
                
        # make noise
        seed_states = self.make_seed_states(size_this_batch).type_as(type_holder) # same batch size as those coming in
            
        # train generator
        if optimizer_idx == 0:
            
            # generate fake houses and get the embedding parts out of it
            fake_houses_states = self.generator.forward(seed_states, steps=num_steps)
            fake_houses = fake_houses_states[:, 1:self.num_embedding_channels+1, :, :, :]
            
            # train gen
            g_loss = self.train_generator(fake_houses)              
            self.log("g_loss", g_loss.detach(), prog_bar=True, logger=True)
            return g_loss
        
        if optimizer_idx == 1:
                        
            # get embeddings for the real house
            real_houses = utils.examples2embedding(real_houses_indices, self.embedding)

            # generate fake houses
            fake_houses_states = self.generator.forward(seed_states, steps=num_steps).detach() # detach so that gradients don't pass back into generator
            fake_houses = fake_houses_states[:, 1:self.num_embedding_channels+1, :, :, :] # extract the embedding parts
            
            # train D
            d_loss, real_loss, fake_loss = self.train_discriminator(real_houses, fake_houses)

            # logging
            self.log("fake_loss", fake_loss.detach(), prog_bar=False, logger=True)
            self.log("real_loss", real_loss.detach(), prog_bar=False, logger=True)
            self.log("d_loss", d_loss.detach(), prog_bar=True, logger=True)
            # TODO implement on acc calculations
            # self.log("real_acc", real_acc, prog_bar=False, logger=True)
            # self.log("fake_acc", fake_acc, prog_bar=False, logger=True)
            # self.log("avg_acc", avg_acc, prog_bar=True, logger=True)

            return d_loss

    def train_discriminator(self, real_houses, fake_houses):
        
        # making sure that real and fake have same shape
        assert real_houses.shape == fake_houses.shape

        # make predictions on real houses and see what D thinks
        real_predictions = self.discriminator.forward(real_houses)
        # For WGAN, we no longer use binary cross entropy. There is no target here.
        real_loss = - torch.mean(real_predictions) # maximising is the same as minimising the negative
        
        # see what D thinks on fake houses
        fake_predictions = self.discriminator.forward(real_houses)
        # once again, not BCE here
        fake_loss = torch.mean(fake_predictions)
        
        # make loss the sum
        d_loss = real_loss + fake_loss
        
        # clamp parameters
        for param in self.discriminator.parameters():
            param.data.clamp_(-self.weight_clip, self.weight_clip)
        
        return d_loss, real_loss, fake_loss
    
    def train_generator(self, fake_houses):
        
        # see what D thinks 
        fake_predictions = self.discriminator.forward(fake_houses)
        
        # calc loss. We want to maximise the prediction for this one (only doing so with G's parameters)
        g_loss = -torch.mean(fake_predictions)
        
        return g_loss

    def configure_optimizers(self):
        n_critic = 5
        
        # for WGAN, use RMSprop
        g_optimizer = torch.optim.RMSprop(self.generator.parameters(), lr=self.lr)
        d_optimizer = torch.optim.RMSprop(self.discriminator.parameters(), lr=self.lr)
        
        return (
            {"optimizer": g_optimizer, "frequency": 1},
            {"optimizer": d_optimizer, "frequency": n_critic}
        )

In [34]:
model = GANCA_WGAN(
    lr = 2e-4,
    weight_clip = 0.01,
    batch_size = 2,
    num_embedding_channels = 64,
    num_hidden_channels = 63,
    update_net_channel_dims = [8, 8],
    embedding = embedding,
    step_range = [1, 3],
)

In [35]:
dm = GANCA3DDataModule(batch_size=2, num_workers=1, mcid2block = mcid2block, block2embeddingid = block2embeddingidx, debug=True)
dm.prepare_data()
dm.setup()

Dataset already exists


  rank_zero_deprecation("DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.")


  0%|          | 0/1977 [00:00<?, ?it/s]

loaded 1977 houses
Turning MC id into embedding idx. This could take up to a minute.


In [36]:
logger = pl.loggers.TensorBoardLogger(save_dir='lightning_logs', name='GANCA', default_hp_metric=False)
trainer = Trainer(gpus=0, max_epochs=2, fast_dev_run=False, log_every_n_steps=1, logger=logger, profiler="simple")

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [37]:
trainer.fit(model, dm)

  rank_zero_warn("You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.")
  rank_zero_deprecation(

  | Name          | Type               | Params
-----------------------------------------------------
0 | embedding     | Embedding          | 14.0 K
1 | generator     | VoxelNCAModel      | 14.5 K
2 | discriminator | VoxelDiscriminator | 11.3 M
-----------------------------------------------------
11.3 M    Trainable params
14.0 K    Non-trainable params
11.3 M    Total params
45.227    Total estimated model params size (MB)
  rank_zero_warn(


Epoch 0:   0%|          | 0/8 [00:00<?, ?it/s] [34m2022-03-03T12:01:51.444992+0800[0m [31m[1mtraining step with optimizer_idx 0[0m
[34m2022-03-03T12:01:51.450620+0800[0m [31m[1mtraining G[0m
Epoch 0:  12%|█▎        | 1/8 [00:03<00:25,  3.60s/it, loss=-0.0409, v_num=4, g_loss=-.0409][34m2022-03-03T12:01:52.942144+0800[0m [31m[1mtraining step with optimizer_idx 1[0m
[34m2022-03-03T12:01:52.943449+0800[0m [31m[1mtraining D[0m
Epoch 0:  25%|██▌       | 2/8 [00:04<00:13,  2.31s/it, loss=-0.0205, v_num=4, g_loss=-.0409, d_loss=0.000][34m2022-03-03T12:01:53.832096+0800[0m [31m[1mtraining step with optimizer_idx 1[0m
[34m2022-03-03T12:01:53.834149+0800[0m [31m[1mtraining D[0m
Epoch 0:  38%|███▊      | 3/8 [00:05<00:09,  1.96s/it, loss=-0.0136, v_num=4, g_loss=-.0409, d_loss=0.000][34m2022-03-03T12:01:55.073479+0800[0m [31m[1mtraining step with optimizer_idx 1[0m
[34m2022-03-03T12:01:55.075047+0800[0m [31m[1mtraining D[0m
Epoch 0:  50%|█████     | 4/8 [00

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


## WGAN-GP

In [52]:
class GANCA_WGANGP(pl.LightningModule):

    def __init__(self,
            lr = 0.0001,
            beta1 = 0,
            beta2 = 0.9,
            n_critic = 5,
            lambda_gp = 10,
            weight_clip = 0.01,
            batch_size = 16,
            num_embedding_channels = 64,
            num_hidden_channels = 63,
            update_net_channel_dims = [32, 32],
            embedding: torch.nn.Embedding = None,
            step_range = [16, 20],
        ):
        
        super().__init__()
        # call this to save args to the checkpoint
        self.save_hyperparameters()
        
        self.lr = lr
        self.weight_clip = weight_clip
        self.beta1 = beta1
        self.beta2 = beta2
        self.n_critic = n_critic
        self.lambda_gp = lambda_gp
        self.batch_size = batch_size
        self.num_embedding_channels = num_embedding_channels
        self.num_hidden_channels = num_hidden_channels
        self.update_net_channel_dims = update_net_channel_dims
        # the channels will be like [alpha, embeddings ... , hiddens ...]
        self.num_channels = 1 + self.num_embedding_channels + self.num_hidden_channels
        self.world_size = (32,32,32)
        self.seed_size = (2,2,2)
        self.embedding = embedding
        self.embedding.weight.requires_grad=False # freeze embeddings
        self.step_range = step_range
        
        self.generator = VoxelNCAModel(
            alpha_living_threshold = 0.1,
            cell_fire_rate = 0.5,
            num_perceptions = 3,
            perception_requires_grad = True,
            num_embedding_channels = self.num_embedding_channels,
            num_hidden_channels = self.num_hidden_channels,
            normal_std = 0.0002,
            use_normal_init = True,
            zero_bias = True,
            update_net_channel_dims = self.update_net_channel_dims,
        )
        self.discriminator = VoxelDiscriminator(
            num_in_channels = self.num_embedding_channels, 
            use_sigmoid=False,
        )
        
        # generate some random seeds (N, channels, x, y, z)
        self.validation_noise = self.make_seed_states(16)
        
    def make_seed_states(self, batch_size):
        return utils.make_seed_state(
            batch_size = batch_size,
            num_channels = self.num_channels, 
            alpha_channel_index = 0,
            seed_dim = self.seed_size, 
            world_size = self.world_size,
        )
    
    def compute_gradient_penalty(self, real_samples, fake_samples):
        # Random weight term for interpolation between real and fake samples. We get a tensor of shape (N, 1, 1, 1, 1)
        alpha = torch.Tensor(np.random.random((real_samples.size(0), 1, 1, 1, 1))).to(self.device)
        # Get random interpolation between real and fake samples
        interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
        interpolates = interpolates.to(self.device)
        # calc predictions for interpolated samples
        interpolates_predictions = self.discriminator.forward(interpolates)
        fake = torch.Tensor(real_samples.shape[0], 1).fill_(1.0).to(self.device)
        # Get gradient w.r.t. interpolates
        gradients = torch.autograd.grad(
            outputs=interpolates_predictions,
            inputs=interpolates,
            grad_outputs=fake,
            create_graph=True,
            retain_graph=True,
            only_inputs=True,
        )[0]
        gradients = gradients.view(gradients.size(0), -1).to(self.device)
        gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
        return gradient_penalty

    def training_step(self, batch, batch_idx, optimizer_idx):
        
        num_steps = random.randint(*self.step_range)

        real_houses_indices = batch
        type_holder = batch[0,0,0,0].to(torch.float) # this is a dummy type for creating labels
        size_this_batch = real_houses_indices.shape[0]
                
        # make noise
        seed_states = self.make_seed_states(size_this_batch).type_as(type_holder) # same batch size as those coming in
            
        # train generator
        if optimizer_idx == 0:

            # generate fake houses and get the embedding parts out of it
            fake_houses_states = self.generator.forward(seed_states, steps=num_steps)
            fake_houses = fake_houses_states[:, 1:self.num_embedding_channels+1, :, :, :]
            
            # train gen
            g_loss = self.train_generator(fake_houses)  
                        
            self.log("g_loss", g_loss.detach(), prog_bar=True, logger=True)
            return g_loss
        
        if optimizer_idx == 1:
            
            # get embeddings for the real house
            real_houses = utils.examples2embedding(real_houses_indices, self.embedding)

            # generate fake houses
            fake_houses_states = self.generator.forward(seed_states, steps=num_steps).detach() # detach so that gradients don't pass back into generator
            fake_houses = fake_houses_states[:, 1:self.num_embedding_channels+1, :, :, :] # extract the embedding parts
            
            # train D
            d_loss = self.train_discriminator(real_houses, fake_houses)

            # logging
            self.log("d_loss", d_loss.detach(), prog_bar=True, logger=True)
            return d_loss

    def train_discriminator(self, real_houses, fake_houses):
        assert real_houses.shape == fake_houses.shape

        real_predictions = self.discriminator.forward(real_houses)
        fake_predictions = self.discriminator.forward(real_houses)
        
        gradient_penalty = self.compute_gradient_penalty(real_houses.data, fake_houses.data)
                
        d_loss = -torch.mean(real_predictions) + torch.mean(fake_predictions) + self.lambda_gp * gradient_penalty
                
        return d_loss
    
    def train_generator(self, fake_houses):
        # see what D thinks 
        fake_predictions = self.discriminator.forward(fake_houses)
        
        # calc loss. We want to maximise the prediction for this one (only doing so with G's parameters)
        g_loss = -torch.mean(fake_predictions)
        
        return g_loss

    def configure_optimizers(self):
        
        g_optimizer = torch.optim.Adam(self.generator.parameters(), lr=self.lr, betas=(self.beta1, self.beta2))
        d_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=self.lr, betas=(self.beta1, self.beta2))
        
        return [
            {"optimizer": g_optimizer, "frequency": 1},
            {"optimizer": d_optimizer, "frequency": self.n_critic},
        ]

In [57]:
model = GANCA_WGANGP(
    lr = 0.0001,
    beta1 = 0,
    beta2 = 0.9,
    n_critic = 5,
    lambda_gp = 10,
    weight_clip = 0.01,
    batch_size = 2,
    num_embedding_channels = 64,
    num_hidden_channels = 63,
    update_net_channel_dims = [16, 16],
    embedding = embedding,
    step_range = [16, 20],
)

In [58]:
dm = GANCA3DDataModule(batch_size=2, num_workers=1, mcid2block = mcid2block, block2embeddingid = block2embeddingidx, debug=True)
dm.prepare_data()
dm.setup()

Dataset already exists


  rank_zero_deprecation("DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.")


  0%|          | 0/1977 [00:00<?, ?it/s]

loaded 1977 houses
Turning MC id into embedding idx. This could take up to a minute.
Done with that.


In [60]:
logger = pl.loggers.TensorBoardLogger(save_dir='lightning_logs', name='GANCA', default_hp_metric=False)
trainer = Trainer(gpus=0, max_epochs=8, fast_dev_run=False, log_every_n_steps=1, logger=logger, profiler=None)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [61]:
trainer.fit(model, dm)

  rank_zero_warn("You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.")
  rank_zero_deprecation(

  | Name          | Type               | Params
-----------------------------------------------------
0 | embedding     | Embedding          | 14.0 K
1 | generator     | VoxelNCAModel      | 18.8 K
2 | discriminator | VoxelDiscriminator | 11.3 M
-----------------------------------------------------
11.3 M    Trainable params
14.0 K    Non-trainable params
11.3 M    Total params
45.244    Total estimated model params size (MB)
  rank_zero_warn(


Epoch 0:  25%|██▌       | 1/4 [00:23<01:11, 23.91s/it, loss=0.658, v_num=5, g_loss=0.658][34m2022-03-03T14:09:53.074093+0800[0m [31m[1mgradient_penalty calculated to be 64.43315887451172[0m
Epoch 0:  50%|█████     | 2/4 [00:33<00:33, 16.98s/it, loss=322, v_num=5, g_loss=0.658, d_loss=644.0][34m2022-03-03T14:10:03.014123+0800[0m [31m[1mgradient_penalty calculated to be 14.945566177368164[0m
Epoch 0:  75%|███████▌  | 3/4 [00:43<00:14, 14.63s/it, loss=265, v_num=5, g_loss=0.658, d_loss=149.0][34m2022-03-03T14:10:16.508604+0800[0m [31m[1mgradient_penalty calculated to be 12.260957717895508[0m
Epoch 1:  25%|██▌       | 1/4 [00:26<01:20, 26.96s/it, loss=183, v_num=5, g_loss=0.331, d_loss=123.0][34m2022-03-03T14:10:52.053256+0800[0m [31m[1mgradient_penalty calculated to be 1.7487647533416748[0m
Epoch 1:  50%|█████     | 2/4 [00:35<00:35, 17.65s/it, loss=156, v_num=5, g_loss=0.331, d_loss=17.50][34m2022-03-03T14:11:00.636030+0800[0m [31m[1mgradient_penalty calculated to

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


## Debug

In [None]:
data_helper.download_dataset()

Dataset already exists
Extracting dataset to dataset/house_data


In [None]:
sample_state = torch.rand(1, 128, 32, 32, 32)
sample_seed_state = utils.make_seed_state(
    batch_size = 1,
    num_channels = 128, 
    alpha_channel_index = 0,
    seed_dim = (4, 4, 4), 
    world_size = (32,32,32),
)

In [None]:
# before train
out1 = model.generator.forward(sample_state, steps=1)
torch.sum(out1)

[34m2022-02-28T20:32:09.898136+0800[0m [31m[1mit's shape torch.Size([1, 128, 32, 32, 32]) coming in to the perception net[0m
[34m2022-02-28T20:32:10.059550+0800[0m [33m[1mit's shape torch.Size([1, 384, 32, 32, 32]) coming in to the update net[0m


tensor(2097196.2500, grad_fn=<SumBackward0>)

In [None]:
# after train
out3 = model.generator.forward(sample_state, steps=1)
torch.sum(out3)

[34m2022-02-28T20:32:54.688500+0800[0m [31m[1mit's shape torch.Size([1, 128, 32, 32, 32]) coming in to the perception net[0m
[34m2022-02-28T20:32:54.819261+0800[0m [33m[1mit's shape torch.Size([1, 384, 32, 32, 32]) coming in to the update net[0m


tensor(2097196.2500, grad_fn=<SumBackward0>)

In [None]:
torch.equal(out1, out3)

False

In [13]:
out2 = model.generator.forward(sample_state, steps=8)
torch.sum(out2)

NameError: name 'sample_state' is not defined