<a href="https://colab.research.google.com/github/buganart/BUGAN/blob/master/custom_dataset_of_script_VAEGAN_voxelsize1_lightning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Before starting, please save a copy of this notebook to your Google Drive via `File -> Save a copy in Drive`.

In [1]:
#@markdown Mount google drive.
# `output` is imported here because later cells call `output.clear()` to hide
# noisy install logs.
from google.colab import output
from google.colab import drive
drive.mount('/content/drive')

# Check if we have linked the folder: later cells `%cd` into this shared-drive
# shortcut, so warn early with instructions when it is missing.
from pathlib import Path
if not Path("/content/drive/My Drive/IRCMS_GAN_collaborative_database").exists():
    print(
        "Shortcut to our shared drive folder doesn't exist.\n\n"
        "\t1. Go to the google drive web UI\n"
        "\t2. Right click shared folder IRCMS_GAN_collaborative_database and click \"Add shortcut to Drive\""
    )

Mounted at /content/drive


In [None]:
#@markdown Install wandb and log in
# Installs a pinned wandb version; presumably pinned because the later cells
# were written against this version's API (e.g. `wandb.util.generate_id`).
!pip install wandb==0.9.7
output.clear()  # `output` comes from the google.colab import in the mount cell
import wandb
# Interactive step: prompts for an API key (see the captured output below).
!wandb login
output.clear()
print("ok!")

[34m[1mwandb[0m: You can find your API key in your browser here: https://app.wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter: 

In [None]:
#@title Configure dataset
#@markdown - set `None` if you want to start a new run
#@markdown - set `"run_id"` if you want to resume a run (for example: `u9imsvva`)
#@markdown - The id of the current run is shown below in the cell with `wandb.init()` or you can find it in the wandb web UI.
id = None #@param {type:"string"}
#@markdown Enter project name (either `handtool-gan` or `tree-gan`)
project_name = "tree-gan" #@param ["tree-gan", "handtool-gan"]
#@markdown Enter file location.  
#@markdown - For example, locate the file in the file browser on the left, right-click it and choose "Copy path".
#@markdown - zipfile example: `/content/drive/My Drive/h/k/a.zip`
#@markdown - file folder example: `/content/drive/My Drive/h/k` 
file_location = "/content/drive/My Drive/IRCMS_GAN_collaborative_database/Research/Peter/Chairs_Princeton/chair_train.zip" #@param {type:"string"}
#@markdown - choose rotation augmentation on-the-fly 
#@markdown (augmentation only supports a file folder in `file_location`, not a zip file)
data_augmentation = True    #@param {type:"boolean"}
#@markdown - specify the rotation axis (x,y,z)
rotation_axis_x = 0    #@param {type:"number"}
rotation_axis_y = 1    #@param {type:"number"}
rotation_axis_z = 0    #@param {type:"number"}

#@markdown - resolution of the voxelized array (shape resolution**3)
resolution = "64"    #@param [32, 64]

#@markdown WANDB log
#@markdown - how many epochs before logging images/3D objects
log_interval = 5    #@param {type:"integer"}
#@markdown - how many samples per log
log_num_samples = 3    #@param {type:"integer"}


# Echo every form value so the chosen configuration is visible in the output.
print("id:", id)
print("project_name:", project_name)
print("file_location:", file_location)
print("data_augmentation:", data_augmentation)
print("rotation_axis:", (rotation_axis_x, rotation_axis_y, rotation_axis_z))
print("resolution:", resolution)
print("log_interval:", log_interval)
print("log_num_samples:", log_num_samples)

# Augmentation is documented above as folder-only, while the default
# file_location is a zip file: flag that combination before training starts.
if data_augmentation and file_location.endswith(".zip"):
    print(
        "WARNING: data_augmentation only supports a file folder, "
        "but file_location points to a zip file."
    )

# To just train a model, no edits should be required in any cells below.

In [None]:
import os
from pathlib import Path
os.environ["WANDB_MODE"] = "dryrun"

%cd /content/drive/My Drive/IRCMS_GAN_collaborative_database/Experiments/
if project_name == "tree-gan":
    %cd colab-treegan/
else:
    %cd colab-handtool/

dataset_path = Path(file_location)
run_path = "./"

if file_location.endswith(".zip"):
    dataset_name = dataset_path.stem
else:
    dataset_name = "dataset_array_custom"

!apt-get update

!pip install pytorch-lightning==0.9.0
!pip install trimesh
!apt install -y xvfb
!pip install trimesh xvfbwrapper
output.clear()
print('ok!')

In [None]:
import io
from io import BytesIO
import zipfile
import trimesh
import numpy as np


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchsummary import summary
from torch.utils.data import DataLoader, TensorDataset
# Use the first GPU when one is available; the model is moved here before training.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("device:", device)

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

# Ignore excessive log output: raise the root logger's threshold to ERROR.
# Loggers that set their own level explicitly are not affected by this.
import logging
logging.getLogger().setLevel(logging.ERROR)

from xvfbwrapper import Xvfb

In [None]:
# Decide between starting a new run and resuming the run named by `id`.
# A cleared Colab string field is presumably written back as "" (and typing
# None gives the text "None"), so treat both like None rather than as a run id.
resume = False
if isinstance(id, str) and id.strip().lower() in ("", "none"):
    id = None
if id is None:
    id = wandb.util.generate_id()
else:
    id = str(id).strip()
    resume = True

# NOTE(review): resume=True is passed even for a freshly generated id;
# confirm wandb 0.9.7 simply starts a new run in that case.
run = wandb.init(project=project_name, id=id, entity="bugan", resume=True, dir=run_path)
print("run id: " + str(wandb.run.id))
print("run name: " + str(wandb.run.name))
wandb.watch_called = False
# NOTE(review): save_code is normally a wandb.init option; setting it on the
# run afterwards may have no effect — confirm against wandb 0.9.7.
wandb.run.save_code = True

In [None]:
#keep track of hyperparams
config = wandb.config

config.batch_size = 8
config.array_size = int(resolution)

config.z_size = 128
if config.array_size == 32:
    config.gen_num_layer_unit = [1024, 512, 256, 128]
    config.dis_num_layer_unit = [32, 64, 128, 128]
else:
    config.gen_num_layer_unit = [1024, 512, 256, 128, 128]
    config.dis_num_layer_unit = [32, 64, 64, 128, 128]

config.leakyReLU = False    #leakyReLU implementation still not in modelPL
config.balance_voxel_in_space = False

config.epochs = 3000
config.vae_lr = 0.0025
config.vae_encoder_layer = 1
config.vae_decoder_layer = 2
config.d_lr = 0.00005            
config.d_layer = 1
config.vae_recon_loss_factor = 1
config.seed = 1234
config.log_interval = log_interval
config.log_num_samples = log_num_samples
config.data_augmentation = data_augmentation
config.aug_rotation_axis = (rotation_axis_x,rotation_axis_y,rotation_axis_z)

config.vae_opt = "Adam"
config.dis_opt = "Adam"

# Dataset

In [None]:
### load our package
# Installs the BUGAN package, then builds the data module from the form values.

#clone then install (alternative, kept for reference)
# !git clone https://github.com/buganart/BUGAN repo
# !pip install -e ./repo/
# import site
# site.main()

#directly install using pip
# NOTE(review): no branch/commit is given, so this installs whatever the
# repository's default branch currently holds — consider pinning a commit.
!pip install -U git+https://github.com/buganart/BUGAN.git#egg=bugan
output.clear()

# The wildcard import presumably provides DataModule_augmentation,
# DataModule_process and load_checkpoint_from_cloud (used below and in the
# training cell); they are not defined anywhere else in this notebook.
from bugan.functionsPL import *
from bugan.modelsPL import VAEGAN, VAE, Discriminator, Generator

# from functionsPL import *
# from modelsPL import VAEGAN, VAE, Discriminator, Generator

# NOTE(review): tags and group are changed after wandb.init(); confirm that
# wandb 0.9.7 syncs these (they are usually passed to wandb.init instead).
run.tags.append("VAEGAN")
run.group = "VAEGAN"

###     load dataset
np.random.seed(config.seed)
# dataModule = DataModule(config, run)
# config.num_data = dataModule.size

config.dataset = dataset_name
# Choose the on-the-fly rotation augmentation loader or the plain loader.
if config.data_augmentation:
    dataModule = DataModule_augmentation(config, run, dataset_path)
else:
    dataModule = DataModule_process(config, run, dataset_path)
config.num_data = dataModule.size

# Train

In [None]:
#set seed
torch.manual_seed(config.seed)
# NOTE(review): anomaly detection adds checks to every backward pass and slows
# training noticeably; it looks like a debugging leftover — confirm it is wanted.
torch.autograd.set_detect_anomaly(True)

#render setup: virtual X display, stopped again once training ends
vdisplay = Xvfb()
vdisplay.start()

#wandb logger setup
wandb_logger = WandbLogger(experiment=run, log_model=True)

# Passed to the model as save_model_path; for a resumed run this is presumably
# where load_checkpoint_from_cloud places the restored file — TODO confirm.
checkpoint_path = os.path.join(wandb.run.dir, 'checkpoint.ckpt')

if resume:
    #get file from the wandb cloud
    load_checkpoint_from_cloud(checkpoint_path = 'checkpoint.ckpt')

# One Trainer definition for both cases: resume_from_checkpoint=None is
# Lightning's default (start fresh), so only a resumed run restores the
# complete training state from the checkpoint.
trainer = pl.Trainer(
    max_epochs = config.epochs,
    logger=wandb_logger,
    checkpoint_callback = None,
    resume_from_checkpoint = checkpoint_path if resume else None,
)

#model
# NOTE(review): the Trainer gets no `gpus` argument while the model is moved to
# `device` (possibly CUDA) here — confirm the model handles device placement.
vaegan = VAEGAN(config, trainer, save_model_path = checkpoint_path).to(device)
wandb_logger.watch(vaegan)

# Always shut the virtual display down, even if training fails or is
# interrupted, so re-running this cell does not leave stray Xvfb servers behind.
try:
    trainer.fit(vaegan, dataModule)
finally:
    vdisplay.stop()