<a href="https://colab.research.google.com/github/buganart/BUGAN/blob/master/DT_script_VAEGAN_voxelsize1_lightning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Mount Google Drive at /content/drive so the shared project folder is reachable;
# `output` is imported here for the output.clear() calls in later cells.
from google.colab import drive, output

drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
#right click shared folder IRCMS_GAN_collaborative_database and "Add shortcut to Drive" to My drive
%cd drive/My Drive/IRCMS_GAN_collaborative_database/Experiments/colab-treegan/

#record paths to resources
data_path = "../../Research/Peter/Tree_3D_models_obj_auto_generated/sessions/simplified/"
run_path = "./"

!apt-get update

!pip install pytorch-lightning==0.9.0
!pip install trimesh
!pip install wandb==0.9.7
!apt install -y xvfb
!pip install trimesh xvfbwrapper
output.clear()

# Import libraries and log in to wandb

In [3]:
import io
import os
import trimesh
import numpy as np



import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchsummary import summary
from torch.utils.data import DataLoader, TensorDataset
# Use the first GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# A bare `device` expression here would not render: Jupyter only displays the value
# of a cell's last expression, so print it explicitly.
print("device:", device)

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

# Ignore excessive warnings: raise the root logger's threshold to ERROR so lower-level
# records from loggers that defer to the root level are dropped.
# (A former `logging.propagate = False` line only set an unused attribute on the
# logging module itself and had no effect, so it is gone.)
# NOTE(review): PyTorch Lightning's INFO lines (e.g. "GPU available: ...") still show in
# the training output, presumably because it sets its own logger level; silence
# logging.getLogger("lightning") as well if that noise is unwanted — TODO confirm.
import logging
logging.getLogger().setLevel(logging.ERROR)

# WandB – Import the wandb library
import wandb

from xvfbwrapper import Xvfb

In [4]:
# Log in to Weights & Biases through the CLI (interactive API-key prompt).
# NOTE(review): the Python API `wandb.login()` is an alternative that avoids a shell subprocess.
!wandb login
# Clear this cell's output, presumably so the login prompt is not saved with the notebook.
output.clear()

In [5]:
# Set run_id to None to start a new run; to resume an existing run, put its id here.
# (Named run_id rather than id so the builtin id() is not shadowed.)
run_id = None
# Only resume when the user supplied an existing run id.
resume = run_id is not None
if run_id is None:
    run_id = wandb.util.generate_id()

# Pass the computed flag instead of a hard-coded resume=True, so a freshly generated
# id always starts a new run and a user-supplied id resumes that run.
run = wandb.init(project="tree-gan", id=run_id, entity="bugan", resume=resume, dir=run_path)
print("run id: " + str(wandb.run.id))
print("run name: " + str(wandb.run.name))
wandb.watch_called = False
wandb.run.save_code = True

[34m[1mwandb[0m: Wandb version 0.10.5 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


run id: 351kaqjq
run name: 351kaqjq


In [6]:
from argparse import Namespace

# Hyperparameters for this run, gathered into one attribute-style namespace
# (later logged to wandb with wandb.config.update).
config = Namespace(
    # data / network shape
    batch_size=8,
    array_size=32,
    z_size=128,
    gen_num_layer_unit=[256, 1024, 512, 128],
    dis_num_layer_unit=[32, 64, 128, 128],
    leakyReLU=False,    # leakyReLU implementation still not in modelPL
    balance_voxel_in_space=False,
    # optimisation schedule
    epochs=1000,
    vae_lr=0.0025,
    vae_encoder_layer=1,
    vae_decoder_layer=2,
    d_lr=0.00005,
    d_layer=1,
    vae_recon_loss_factor=1,
    seed=1234,
    # logging cadence (in epochs)
    log_image_interval=5,
    log_mesh_interval=50,
    # augmentation
    data_augmentation=True,
    num_augment_data=4,
    # optimiser choices
    vae_opt="Adam",
    dis_opt="Adam",
    # dataset file inside data_path
    dataset="dataset_array_Friedrich_2_8710.npy",
)

# Dataset

In [7]:
### Install and import the project package (bugan)

# Alternative: clone the repo, then install it in editable mode
# !git clone https://github.com/buganart/BUGAN repo
# !pip install -e ./repo/
# import site
# site.main()

# Install directly from GitHub with pip.
# NOTE(review): unpinned — this installs whatever is on the default branch; pin a tag or
# commit (…BUGAN.git@<sha>#egg=bugan) for reproducible runs.
!pip install git+https://github.com/buganart/BUGAN.git#egg=bugan
output.clear()
# NOTE(review): `Callback` is not used in the visible cells.
from pytorch_lightning.callbacks.base import Callback

# NOTE(review): star import; SaveWandbCallback and load_checkpoint_from_cloud (used in the
# training cell) presumably come from here — import them explicitly once confirmed.
from bugan.functionsPL import *
from bugan.modelsDT import VAEGAN, VAE, Discriminator, Generator

# Local (non-installed) module alternative:
# from functionsPL import *
# from modelsPL import VAEGAN, VAE, Discriminator, Generator

# Tag and group this wandb run under "VAEGAN".
run.tags.append("VAEGAN")
run.group = "VAEGAN"

### load dataset
# Seed numpy's global RNG with the configured seed.
np.random.seed(config.seed)

class DataModule_custom(pl.LightningDataModule):
    """LightningDataModule serving one pre-built ``.npy`` array of training samples.

    Args:
        config: hyperparameter namespace; ``config.batch_size`` sets the loader batch size.
        run: the wandb run (kept on the instance; not used by this class).
        filename: name of the ``.npy`` file inside the global ``data_path`` directory.

    ``size`` and ``dataset`` are only populated once :meth:`setup` has run.
    """

    def __init__(self, config, run, filename):
        super().__init__()
        self.config = config
        self.run = run
        self.dataset_artifact = None  # NOTE(review): not assigned anywhere else in this class
        self.dataset = None
        self.size = 0
        self.filename = filename

    def prepare_data(self):
        # Nothing to download or preprocess; the array is loaded in setup().
        return

    def setup(self, stage=None):
        """Load the array and insert a channel axis: (N, ...) -> (N, 1, ...).

        Args:
            stage: Lightning stage name; ignored, the same data is used for every stage.
        """
        samples = np.load(os.path.join(data_path, self.filename))
        # The array holds multiple samples stacked along axis 0.
        self.size = samples.shape[0]
        self.dataset = torch.unsqueeze(torch.tensor(samples), 1)

    def train_dataloader(self):
        """Return a shuffled DataLoader over the samples, batched by ``config.batch_size``."""
        # Load lazily if the loader is requested before setup() has run, instead of
        # crashing inside TensorDataset(None).
        if self.dataset is None:
            self.setup()
        tensor_dataset = TensorDataset(self.dataset)
        return DataLoader(tensor_dataset, batch_size=self.config.batch_size, shuffle=True)
        
dataModule = DataModule_custom(config, run, config.dataset)
# `size` is only filled in by setup(); load the data now so num_data records the real
# sample count instead of the placeholder 0 set in __init__.
# (If Lightning calls setup() again during fit, the array is simply reloaded.)
dataModule.setup()
config.num_data = dataModule.size

# Train

In [None]:
#set seed
torch.manual_seed(config.seed)
# NOTE(review): anomaly detection is a debugging aid that slows every backward pass;
# consider disabling it once training is stable.
torch.autograd.set_detect_anomaly(True)

#render setup: start a virtual X display (Xvfb) for headless rendering
vdisplay = Xvfb()
vdisplay.start()

try:
    #wandb logger setup
    wandb_logger = WandbLogger(experiment=run, log_model=True)
    #log config
    wandb.config.update(config)

    checkpoint_path = os.path.join(wandb.run.dir, 'checkpoint.ckpt')
    callbacks = [SaveWandbCallback(config.log_image_interval, checkpoint_path)]

    if resume:
        #get checkpoint file from the wandb cloud
        load_checkpoint_from_cloud(checkpoint_path='checkpoint.ckpt')

    # One Trainer for both cases: resume_from_checkpoint=None (the Trainer default) starts
    # fresh, a path restores the training state completely. checkpoint_callback=False
    # disables Lightning's own checkpointing; SaveWandbCallback handles checkpoints.
    trainer = pl.Trainer(
        max_epochs=config.epochs,
        logger=wandb_logger,
        callbacks=callbacks,
        checkpoint_callback=False,
        resume_from_checkpoint=checkpoint_path if resume else None,
        gpus=-1,
        distributed_backend='ddp',
    )

    #model
    vaegan = VAEGAN(config).to(device)
    wandb_logger.watch(vaegan)

    trainer.fit(vaegan, dataModule)
finally:
    # Always shut down the virtual display, even if training fails, so re-running
    # this cell does not leak Xvfb processes.
    vdisplay.stop()

[34m[1mwandb[0m: Wandb version 0.10.5 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]
initializing ddp: GLOBAL_RANK: 0, MEMBER: 1/1
----------------------------------------------------------------------------------------------------
distributed_backend=ddp
All DDP processes registered. Starting ddp with 1 processes
----------------------------------------------------------------------------------------------------

  | Name          | Type          | Params
------------------------------------------------
0 | vae           | VAE           | 62 M  
1 | discriminator | Discriminator | 1 M   


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

  "See the documentation of nn.Upsample for details.".format(mode))
  "See the documentation of nn.Upsample for details.".format(mode))
