# Training GANCA-3D

In [9]:
# import stuff

import data_helper
import utils

%load_ext autoreload
%autoreload 2

import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from einops import rearrange
from torchsummaryX import summary

# Directory passed to utils.get_embedding_info to locate the saved block2vec run.
BLOCK2VEC_OUT_PATH = 'output/block2vec saves/block2vec 64 dim/'
# Half of the logical CPU count, rounded down. os.cpu_count() returns None
# when the count cannot be determined, so fall back to 0 instead of raising
# a TypeError on `None / 2`.
NUM_WORKERS = (os.cpu_count() or 0) // 2

import pytorch_lightning as pl

# When CUDA is usable, print the cuDNN version, the device count, and the
# name and total memory of device 0.
if torch.cuda.is_available():
    total_memory_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    cuda_report = (
        ('CUDNN VERSION:', torch.backends.cudnn.version()),
        ('Number CUDA Devices:', torch.cuda.device_count()),
        ('CUDA Device Name:', torch.cuda.get_device_name(0)),
        ('CUDA Device Total Memory [GB]:', total_memory_gb),
    )
    for label, value in cuda_report:
        print(label, value)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [10]:
# Unpack the four values utils.get_embedding_info returns for the saved
# block2vec output directory.
(
    embedding,
    mcid2block,
    block2embeddingidx,
    embeddingidx2block,
) = utils.get_embedding_info(BLOCK2VEC_OUT_PATH)

In [11]:
from data_helper import GANCA3DDataModule

In [12]:
# Build the data module, handing it the two lookup tables obtained from
# utils.get_embedding_info.
dm = GANCA3DDataModule(
    batch_size=16,
    num_workers=1,
    mcid2block=mcid2block,
    block2embeddingid=block2embeddingidx,
)

  rank_zero_deprecation("DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.")


In [13]:
from models import VoxelPerceptionNet, VoxelUpdateNet, VoxelNCAModel, VoxelDiscriminator

In [16]:
%load_ext tensorboard
# %tensorboard --logdir lightning_logs/

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


In [17]:
class GANCA(pl.LightningModule):
    """GAN pairing a ``VoxelNCAModel`` generator with a ``VoxelDiscriminator``.

    Training alternates between two Adam optimizers (index 0: generator,
    index 1: discriminator). Both are scored with binary cross-entropy on the
    discriminator's predictions.

    Args:
        latent_dim: Size of dimension 1 of the ``(N, latent_dim, 1, 1)`` noise
            fed to the generator.
        image_shape: Stored on the module and in the saved hyperparameters;
            not read anywhere else in this class.
        lr: Learning rate shared by both optimizers.
        batch_size: Stored on the module and in the saved hyperparameters;
            not read anywhere else in this class.
        beta1: First Adam beta, shared by both optimizers.
        beta2: Second Adam beta, shared by both optimizers.
    """

    def __init__(self,
            latent_dim = 64,
            image_shape = (3,32,32),
            lr = 2e-4,
            batch_size = 128,
            beta1 = 0.9,
            beta2 = 0.999
        ):

        super().__init__()
        # call this to save args to the checkpoint
        self.save_hyperparameters()

        self.latent_dim = latent_dim
        self.image_shape = image_shape
        self.lr = lr
        self.batch_size = batch_size
        self.beta1 = beta1
        self.beta2 = beta2

        self.generator = VoxelNCAModel(
            alpha_living_threshold = 0.1,
            cell_fire_rate = 0.5,
            num_perceptions = 3,
            perception_requires_grad = True,
            num_hidden_channels = 127,
            normal_std = 0.0002,
            use_normal_init = True,
            zero_bias = True,
            update_net_channel_dims = [32, 32]
        )
        self.discriminator = VoxelDiscriminator(
            num_in_channels = 64,
            use_sigmoid=True
        )

        # Fixed batch of 16 noise samples, reused at the end of every epoch so
        # that the logged samples are comparable across epochs. Kept as a plain
        # attribute (not a buffer), so it is moved to the right device at use.
        self.validation_noise = torch.randn(16, self.latent_dim, 1, 1)

    def forward(self, z):
        """Run the generator on noise ``z`` of shape ``(N, latent_dim, 1, 1)``."""
        return self.generator(z)

    def adversarial_loss(self, y_hat, y):
        """Binary cross-entropy between predictions ``y_hat`` and targets ``y``.

        ``F.binary_cross_entropy`` expects ``y_hat`` to hold probabilities and
        ``y`` to have the same shape as ``y_hat``.
        """
        return F.binary_cross_entropy(y_hat, y)

    def _log_image_grid(self, tag, grid, step):
        """Send an image grid to the experiment logger, if one is attached."""
        if self.logger is None:
            # Nothing to write to; skip instead of failing on `None.experiment`.
            return
        self.logger.experiment.add_image(tag, grid, step)

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Compute the loss for the optimizer selected by ``optimizer_idx``.

        Args:
            batch: Pair whose first element is the batch of real samples; the
                second element is ignored.
            batch_idx: Index of the batch (unused).
            optimizer_idx: ``0`` for the generator update, ``1`` for the
                discriminator update.

        Returns:
            The generator loss (index 0) or the discriminator loss (index 1);
            ``None`` for any other index.
        """
        images, _ = batch

        # One noise sample per real item, matched to the batch's dtype/device.
        noise = torch.randn(images.size(0), self.latent_dim, 1, 1)
        noise = noise.type_as(images)

        # train generator
        if optimizer_idx == 0:
            fake_images = self.generator(noise)

            # Log a few samples at the current global step so successive steps
            # do not all land on step 0.
            # NOTE(review): make_grid is documented for (B, C, H, W) image
            # batches; confirm the generator output has that layout.
            grid = torchvision.utils.make_grid(fake_images.detach()[:6])
            self._log_image_grid("generated_images", grid, self.global_step)

            # The generator wants the discriminator to call its samples real,
            # so the targets are all ones. They are built from the predictions
            # so that shape, dtype and device always match what BCE requires.
            fake_predictions = self.discriminator(fake_images)
            real_targets = torch.ones_like(fake_predictions)

            g_loss = self.adversarial_loss(fake_predictions, real_targets)

            self.log("g_loss", g_loss.detach(), prog_bar=True, logger=True)

            return g_loss

        # train discriminator
        if optimizer_idx == 1:
            # Real samples should be scored as 1.
            real_predictions = self.discriminator(images)
            real_targets = torch.ones_like(real_predictions)
            real_loss = self.adversarial_loss(real_predictions, real_targets)
            # Mean prediction on real data: closer to 1 is better.
            real_acc = torch.mean(real_predictions).item()

            # Generated samples should be scored as 0. The generator is not
            # updated in this branch, so skip building its autograd graph.
            with torch.no_grad():
                fake_images = self.generator(noise)
            fake_predictions = self.discriminator(fake_images)
            fake_targets = torch.zeros_like(fake_predictions)
            fake_loss = self.adversarial_loss(fake_predictions, fake_targets)
            # 1 - mean prediction on fakes: closer to 1 is better.
            fake_acc = 1 - torch.mean(fake_predictions).item()

            # discriminator loss is the average of the two losses
            d_loss = (real_loss + fake_loss) / 2
            avg_acc = (real_acc + fake_acc) / 2

            self.log("fake_loss", fake_loss.detach(), prog_bar=False, logger=True)
            self.log("real_loss", real_loss.detach(), prog_bar=False, logger=True)
            self.log("d_loss", d_loss.detach(), prog_bar=True, logger=True)
            self.log("real_acc", real_acc, prog_bar=False, logger=True)
            self.log("fake_acc", fake_acc, prog_bar=False, logger=True)
            self.log("avg_acc", avg_acc, prog_bar=True, logger=True)

            return d_loss

    def training_epoch_end(self, training_step_outputs):
        """Generate samples from the fixed validation noise, then plot and log them."""
        # Match the noise to the generator's dtype/device through its first
        # parameter rather than through a specific internal layer.
        reference = next(self.generator.parameters())
        noise = self.validation_noise.type_as(reference)

        # Sampling only: no gradients are needed, and matplotlib cannot
        # convert tensors that still require grad.
        with torch.no_grad():
            sample_imgs = self(noise).cpu()

        grid = torchvision.utils.make_grid(sample_imgs)
        # NOTE(review): the (x + 1) / 2 rescale presumably assumes generator
        # outputs in [-1, 1]; confirm against the generator.
        plt.imshow((grid.permute(1, 2, 0)+1)/2)
        self._log_image_grid("generated_images_epoch_end", grid, self.current_epoch)

    def configure_optimizers(self,):
        """Return ``[generator_optimizer, discriminator_optimizer]`` and no LR schedulers."""
        g_optimizer = torch.optim.Adam(self.generator.parameters(), lr=self.lr, betas=(self.beta1, self.beta2))
        d_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=self.lr, betas=(self.beta1, self.beta2))
        return [g_optimizer, d_optimizer], []