# Training GANCA-3D

In [7]:
# import stuff

import data_helper
import utils

%load_ext autoreload
%autoreload 2

import numpy as np
import sys
import pickle
import torch
import random
from torchsummaryX import summary
import matplotlib.pyplot as plt

from loguru import logger as gurulogger
gurulogger.remove()
gurulogger.add(sys.stdout, colorize=True, format="<blue>{time}</blue> <level>{message}</level>")
gurulogger.level("INFO", color="<red><bold>")


import os
from einops import rearrange
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd

BLOCK2VEC_OUT_PATH = 'output/block2vec saves/block2vec 64 dim/'
NUM_WORKERS = int(os.cpu_count() / 2)

import pytorch_lightning as pl

if torch.cuda.is_available():
    print('CUDNN VERSION:', torch.backends.cudnn.version())
    print('Number CUDA Devices:', torch.cuda.device_count())
    print('CUDA Device Name:',torch.cuda.get_device_name(0))
    print('CUDA Device Total Memory [GB]:',torch.cuda.get_device_properties(0).total_memory/1e9)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [4]:
embedding, mcid2block, block2embeddingidx, embeddingidx2block = utils.get_embedding_info(BLOCK2VEC_OUT_PATH)

In [7]:
from data_helper import GANCA3DDataModule

In [8]:
from models import VoxelPerceptionNet, VoxelUpdateNet, VoxelNCAModel, VoxelDiscriminator
from pytorch_lightning import Trainer

In [9]:
%load_ext tensorboard
# %tensorboard --logdir lightning_logs/

In [5]:
class GANCA(pl.LightningModule):

    def __init__(self,
            lr = 2e-4,
            beta1 = 0.9,
            beta2 = 0.999,
            batch_size = 16,
            num_embedding_channels = 64,
            num_hidden_channels = 63,
            update_net_channel_dims = [32, 32],
            embedding: torch.nn.Embedding = None,
            step_range = [16, 20],
        ):
        
        super().__init__()
        # call this to save args to the checkpoint
        self.save_hyperparameters()
        
        self.lr = lr
        self.beta1 = beta1
        self.beta2 = beta2
        self.batch_size = batch_size
        self.num_embedding_channels = num_embedding_channels
        self.num_hidden_channels = num_hidden_channels
        self.update_net_channel_dims = update_net_channel_dims
        # the channels will be like [alpha, embeddings ... , hiddens ...]
        self.num_channels = 1 + self.num_embedding_channels + self.num_hidden_channels
        self.world_size = (32,32,32)
        self.embedding = embedding
        self.step_range = step_range
        
        self.generator = VoxelNCAModel(
            alpha_living_threshold = 0.1,
            cell_fire_rate = 0.5,
            num_perceptions = 3,
            perception_requires_grad = True,
            num_embedding_channels = self.num_embedding_channels,
            num_hidden_channels = self.num_hidden_channels,
            normal_std = 0.0002,
            use_normal_init = True,
            zero_bias = True,
            update_net_channel_dims = self.update_net_channel_dims,
        )
        self.discriminator = VoxelDiscriminator(
            num_in_channels = self.num_embedding_channels, 
            use_sigmoid=True,
        )
        
        # generate some random seeds (N, channels, x, y, z)
        self.validation_noise = self.make_seed_states(16)
        
    def make_seed_states(self, batch_size):
        return utils.make_seed_state(
            batch_size = batch_size,
            num_channels = self.num_channels, 
            alpha_channel_index = 0,
            seed_dim = (4, 4, 4), 
            world_size = self.world_size,
        )
    
    def training_step(self, batch, batch_idx, optimizer_idx):
        
        num_steps = random.randint(self.step_range)
            
        real_houses = batch
        type_holder = batch[0,0,0,0].to(torch.float) # this is a dummy type for creating labels
        gurulogger.info(f'created typeholder tensor {type_holder} for new tensors with type {type_holder.type()} on device {type_holder.get_device()}')
        size_this_batch = real_houses.shape[0]
        gurulogger.info(f'got shape {real_houses.shape} as real example')
                
        # make noise
        
        seed_states = self.make_seed_states(size_this_batch).type_as(type_holder) # same batch size as those coming in
        gurulogger.info(f'made seed states with shape {seed_states.shape} and type {seed_states.type()}')
            
        # train generator
        if optimizer_idx == 0:
            
            gurulogger.info("ENTERING GENERATOR TRAIN LOOP")

            # generate images
            fake_houses_states = self.generator.forward(seed_states, steps=num_steps)
            gurulogger.info(f'generated houses of shape {fake_houses_states.shape}')
                        
            # create ground truth result (all fake results) we want D to say the generated ones are real so they are all 1s
            real_labels = utils.make_real_labels(size_this_batch).type_as(type_holder)
            gurulogger.info(f'real labels looks like: {real_labels}')
            
            # now get the embedding parts out of the fake states
            fake_houses = fake_houses_states[:, 1:self.num_embedding_channels+1, :, :, :]
            gurulogger.info(f'extracted fake_houses with shape: {fake_houses.shape}')
            
            # see what discriminator thinks
            fake_predictions = self.discriminator.forward(fake_houses)
            gurulogger.info(f"the discriminator predicted {fake_predictions} for the houses. let's compare that to the real label {real_labels}")
            
            # calculate loss
            g_loss = F.binary_cross_entropy(fake_predictions, real_labels) # y_hat, y  
            gurulogger.info(f'gooooooooooooooooooooooooooooooooooooooood so far for generator. step done with g_loss **{g_loss}**')
                                    
            self.log("g_loss", g_loss.detach(), prog_bar=True, logger=True)

            return g_loss
        
        if optimizer_idx == 1:
            
            gurulogger.info("ENTERING DISCRIMINATOR TRAIN LOOP")
            
            # get embeddings for the real house
            real_houses = utils.examples2embedding(real_houses, self.embedding)
            
            gurulogger.info(f"got embedded real houses with shape {real_houses.shape}")

            # pass in real image to D and try to make it predict all 1s
            real_targets = utils.make_real_labels(size_this_batch).type_as(type_holder)
            real_predictions = self.discriminator.forward(real_houses)
            gurulogger.info(f"on a real house, D gave predictions: {real_predictions}; and we want to match the label {real_targets}")
            real_loss = F.binary_cross_entropy(real_predictions, real_targets)
            real_acc = torch.mean(real_predictions > 0.5).item() # the average prediction. The closer to 1 the more accurate
            gurulogger.info(f"real loss is therefore: {real_loss}; and real accuracy is {real_acc}")
            
            # meanwhile, we also want D to be able to tell that outputs from G are all fake
            fake_targets = utils.make_fake_labels(size_this_batch).type_as(type_holder)
            fake_houses_states = self.generator.forward(seed_states, steps=num_steps).detach() # detach so that gradients don't pass back into generator
            fake_houses = fake_houses_states[:, 1:self.num_embedding_channels+1, :, :, :] # extract the embedding parts
            gurulogger.info(f"created fake target: {fake_targets}; generated from seeds fake houses output {fake_houses_states.shape} and extracted the embeddings to {fake_houses.shape}")
            fake_predictions = self.discriminator.forward(fake_houses)
            gurulogger.info(f"D predicted: {fake_predictions} on fake houses")
            fake_loss = F.binary_cross_entropy(fake_predictions, fake_targets)
            fake_acc = 1 - torch.mean(fake_predictions < 0.5).item() # the 1 - average prediction. The closer to 1 the more accurate
            gurulogger.info(f"fake loss is therefore: {fake_loss}; and fake accuracy is {fake_acc}")
                        
            # discriminator loss is the average of the two losses
            d_loss = (real_loss + fake_loss) / 2
            avg_acc = (real_acc + fake_acc) / 2
            gurulogger.info(f"avg D loss is: {d_loss}; and avg accuracy is {avg_acc}")
                    
            self.log("fake_loss", fake_loss.detach(), prog_bar=False, logger=True)
            self.log("real_loss", real_loss.detach(), prog_bar=False, logger=True)
            self.log("d_loss", d_loss.detach(), prog_bar=True, logger=True)
            self.log("real_acc", real_acc, prog_bar=False, logger=True)
            self.log("fake_acc", fake_acc, prog_bar=False, logger=True)
            self.log("avg_acc", avg_acc, prog_bar=True, logger=True)

            return d_loss

    def configure_optimizers(self,):
        # define pytorch optimizers here. return [list of optimizers], [list of LR schedulers]
        g_optimizer = torch.optim.Adam(self.generator.parameters(), lr=self.lr, betas=(self.beta1, self.beta2))
        d_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=self.lr, betas=(self.beta1, self.beta2))
        return [g_optimizer, d_optimizer], []

In [6]:
model = GANCA(
    lr = 2e-4,
    beta1 = 0.9,
    beta2 = 0.999,
    batch_size = 16,
    num_embedding_channels = 64,
    num_hidden_channels = 63,
    update_net_channel_dims = [16, 16],
    embedding = embedding,
    step_range = [16, 20],
)

NameError: name 'embedding' is not defined

In [121]:
# dm = GANCA3DDataModule(batch_size=16, num_workers=1, mcid2block = mcid2block, block2embeddingid = block2embeddingidx, debug=True)

In [122]:
logger = pl.loggers.TensorBoardLogger(save_dir='lightning_logs', name='GANCA', default_hp_metric=False)
trainer = Trainer(gpus=0, max_epochs=1, fast_dev_run=True, log_every_n_steps=1, logger=logger, profiler="simple")

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
Running in fast_dev_run mode: will run a full train, val, test and prediction loop using 1 batch(es).


In [115]:
trainer.fit(model, dm)


  | Name          | Type               | Params
-----------------------------------------------------
0 | embedding     | Embedding          | 14.0 K
1 | generator     | VoxelNCAModel      | 18.8 K
2 | discriminator | VoxelDiscriminator | 11.3 M
-----------------------------------------------------
11.3 M    Trainable params
14.0 K    Non-trainable params
11.3 M    Total params
45.244    Total estimated model params size (MB)
  rank_zero_warn(


Epoch 0:   0%|          | 0/1 [00:00<?, ?it/s] [34m2022-02-28T20:32:31.099838+0800[0m [31m[1mcreated typeholder tensor 0.0 for new tensors with type torch.FloatTensor on device -1[0m
[34m2022-02-28T20:32:31.100758+0800[0m [31m[1mgot shape torch.Size([8, 32, 32, 32]) as real example[0m
[34m2022-02-28T20:32:31.123795+0800[0m [31m[1mmade seed states with shape torch.Size([8, 128, 32, 32, 32]) and type torch.FloatTensor[0m
[34m2022-02-28T20:32:31.124390+0800[0m [31m[1mENTERING GENERATOR TRAIN LOOP[0m
[34m2022-02-28T20:32:31.131597+0800[0m [31m[1mit's shape torch.Size([8, 128, 32, 32, 32]) coming in to the perception net[0m
[34m2022-02-28T20:32:32.505888+0800[0m [33m[1mit's shape torch.Size([8, 384, 32, 32, 32]) coming in to the update net[0m
[34m2022-02-28T20:32:33.011011+0800[0m [31m[1mgenerated houses of shape torch.Size([8, 128, 32, 32, 32])[0m
[34m2022-02-28T20:32:33.014577+0800[0m [31m[1mreal labels looks like: tensor([[1.],
        [1.],
       

FIT Profiler Report

Action                             	|  Mean duration (s)	|Num calls      	|  Total time (s) 	|  Percentage %   	|
--------------------------------------------------------------------------------------------------------------------------------------
Total                              	|  -              	|_              	|  20.896         	|  100 %          	|
--------------------------------------------------------------------------------------------------------------------------------------
run_training_epoch                 	|  20.206         	|1              	|  20.206         	|  96.699         	|
run_training_batch                 	|  13.096         	|1              	|  13.096         	|  62.675         	|
training_step_and_backward         	|  6.5306         	|2              	|  13.061         	|  62.507         	|
backward                           	|  3.4064         	|2              	|  6.8128         	|  32.604         	|
optimizer_step_with_closure_1      




---
debug

---

In [111]:
sample_state = torch.rand(1, 128, 32, 32, 32)
sample_seed_state = utils.make_seed_state(
    batch_size = 1,
    num_channels = 128, 
    alpha_channel_index = 0,
    seed_dim = (4, 4, 4), 
    world_size = (32,32,32),
)

In [112]:
# before train
out1 = model.generator.forward(sample_state, steps=1)
torch.sum(out1)

[34m2022-02-28T20:32:09.898136+0800[0m [31m[1mit's shape torch.Size([1, 128, 32, 32, 32]) coming in to the perception net[0m
[34m2022-02-28T20:32:10.059550+0800[0m [33m[1mit's shape torch.Size([1, 384, 32, 32, 32]) coming in to the update net[0m


tensor(2097196.2500, grad_fn=<SumBackward0>)

In [116]:
# after train
out3 = model.generator.forward(sample_state, steps=1)
torch.sum(out3)

[34m2022-02-28T20:32:54.688500+0800[0m [31m[1mit's shape torch.Size([1, 128, 32, 32, 32]) coming in to the perception net[0m
[34m2022-02-28T20:32:54.819261+0800[0m [33m[1mit's shape torch.Size([1, 384, 32, 32, 32]) coming in to the update net[0m


tensor(2097196.2500, grad_fn=<SumBackward0>)

In [117]:
torch.equal(out1, out3)

False

In [87]:
out2 = model.generator.forward(sample_state, steps=8)
torch.sum(out2)

tensor(2097251.2500, grad_fn=<SumBackward0>)