<a href="https://colab.research.google.com/github/buganart/BUGAN/blob/master/custom_dataset_of_script_VAEGAN_voxelsize1_lightning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Before starting please save the notebook in your drive by clicking on `File -> Save a copy in drive`

In [1]:
#@markdown Mount google drive.
from google.colab import output
from google.colab import drive
drive.mount('/content/drive')

# Check if we have linked the folder
from pathlib import Path
if not Path("/content/drive/My Drive/IRCMS_GAN_collaborative_database").exists():
    print(
        "Shortcut to our shared drive folder doesn't exits.\n\n"
        "\t1. Go to the google drive web UI\n"
        "\t2. Right click shared folder IRCMS_GAN_collaborative_database and click \"Add shortcut to Drive\""
    )

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
#@markdown Install wandb and log in
!pip install wandb==0.9.7
output.clear()
import wandb
!wandb login
output.clear()
print("ok!")

ok!


In [3]:
#@title Configure dataset
#@markdown - set `None`if you want to start a new run
#@markdown - set `"run_id"` if you want to resume a run (for example: `u9imsvva`)
#@markdown - The id of the current run is shown below in the cell with `wandb.init()` or you can find it in the wandb web UI.
id = None #@param {type:"string"}
#@markdown Enter project name (either `handtool-gan` or `tree-gan`)
project_name = "tree-gan" #@param ["tree-gan", "handtool-gan"]
#@markdown Enter dataset location.  
#@markdown - For example via the file browser on the left to locate and right click to copy the path.)
#@markdown - zipfile example: `/content/drive/My Drive/h/k/a.zip`
#@markdown - file folder example: `/content/drive/My Drive/h/k` 
data_location = "/content/drive/My Drive/Hand-Tool-Data-Set/turbosquid_thingiverse_dataset/dataset_ply_out" #@param {type:"string"}
#@markdown - choose rotation augmentation on-the-fly 
#@markdown (augmentation only support file folder in data_location)
data_augmentation = False    #@param {type:"boolean"}
aug_rotation_type = "random rotation"  #@param ["random rotation", "axis rotation"]
#@markdown - specify the rotation axis [x,y,z] (only for aug_rotation_type = "axis rotation")
rotation_axis_x = 0    #@param {type:"number"}
rotation_axis_y = 1    #@param {type:"number"}
rotation_axis_z = 0    #@param {type:"number"}

#@markdown - resolution of the voxelized array (shape resolution**3)
resolution = "32"    #@param [32, 64]

#@markdown Model
#@markdown - select which model to train
selected_model = "VAEGAN"    #@param ["VAEGAN", "GAN", "VAE"]

#@markdown WANDB log
#@markdown - how many epochs before logging images/3D objects
log_interval = 5    #@param {type:"integer"}
#@markdown - how many samples per log
log_num_samples = 3    #@param {type:"integer"}


print("id:", id)
print("project_name:", project_name)
print("data_location:", data_location)
print("data_augmentation:", data_augmentation)
print("resolution:", resolution)
print("log_interval:", log_interval)
print("log_num_samples:", log_num_samples)


id: None
project_name: tree-gan
data_location: /content/drive/My Drive/Hand-Tool-Data-Set/turbosquid_thingiverse_dataset/dataset_ply_out
data_augmentation: False
resolution: 32
log_interval: 5
log_num_samples: 3


# To just train a model, no edits should be required in any cells below.

In [4]:
import os
from pathlib import Path
# os.environ["WANDB_MODE"] = "dryrun"

%cd /content/drive/My Drive/IRCMS_GAN_collaborative_database/Experiments/
if project_name == "tree-gan":
    %cd colab-treegan/
else:
    %cd colab-handtool/

dataset_path = Path(data_location)
run_path = "./"

if data_location.endswith(".zip"):
    dataset_name = dataset_path.stem
else:
    dataset_name = "dataset_array_custom"

!apt-get update

!pip install pytorch-lightning==0.9.0
!pip install trimesh
!apt install -y xvfb
!pip install trimesh xvfbwrapper
output.clear()
print('ok!')

ok!


In [5]:
import io
from io import BytesIO
import zipfile
import trimesh
import numpy as np
from argparse import Namespace


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchsummary import summary
from torch.utils.data import DataLoader, TensorDataset
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

# Ignore excessive warnings
import logging
logging.propagate = False 
logging.getLogger().setLevel(logging.ERROR)

from xvfbwrapper import Xvfb

In [6]:
resume = False
if id is None:
    id = wandb.util.generate_id()
else:
    resume = True

run = wandb.init(project=project_name, id=id, entity="bugan", resume=True, dir=run_path)
print("run id: " + str(wandb.run.id))
print("run name: " + str(wandb.run.name))
wandb.watch_called = False
wandb.run.save_code = True

[34m[1mwandb[0m: Wandb version 0.10.9 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


run id: q9gzw9xe
run name: q9gzw9xe


In [7]:
#keep track of hyperparams
config = Namespace()

config.epochs = 3000
config.batch_size = 8
config.array_size = int(resolution)

#model param
config.z_size = 128
if config.array_size == 32:
    config.gen_num_layer_unit = [1024, 512, 256, 128]
    config.dis_num_layer_unit = [32, 64, 128, 128]
else:
    config.gen_num_layer_unit = [1024, 512, 256, 128, 128]
    config.dis_num_layer_unit = [32, 64, 64, 128, 128]

if selected_model == "VAEGAN":
    config.vae_recon_loss_factor = 1
    config.vae_opt = "Adam"
    config.dis_opt = "Adam"
    config.vae_lr = 0.0025
    config.vae_encoder_layer = 1
    config.vae_decoder_layer = 2
    config.d_lr = 0.00005            
    config.d_layer = 1
elif selected_model == "GAN":
    config.gen_opt = "Adam"
    config.dis_opt = "Adam"
    config.g_lr = 0.0025
    config.g_layer = 2
    config.d_lr = 0.00005           
    config.d_layer = 1
else:
    config.vae_opt = "Adam"
    config.vae_lr = 0.0025
    config.vae_encoder_layer = 1
    config.vae_decoder_layer = 2

config.seed = 1234
config.log_interval = log_interval
config.log_num_samples = log_num_samples
config.data_augmentation = data_augmentation
config.aug_rotation_type = aug_rotation_type
config.aug_rotation_axis = (rotation_axis_x,rotation_axis_y,rotation_axis_z)



#dataset

In [8]:
### load our package

#directly install using pip
!pip install -U git+https://github.com/buganart/BUGAN.git#egg=bugan
output.clear()

from bugan.functionsPL import *
from bugan.modelsPL import VAEGAN, VAE_train, GAN, VAE, Discriminator, Generator


run.tags.append(selected_model)
run.group = selected_model

###     load dataset
np.random.seed(config.seed)

config.dataset = dataset_name
dataModule = DataModule_process(config, run, dataset_path)
config.num_data = dataModule.size

#train

In [10]:
#set seed
torch.manual_seed(config.seed)
torch.autograd.set_detect_anomaly(True)

#render setup
vdisplay = Xvfb()
vdisplay.start()

#wandb logger setup
wandb_logger = WandbLogger(experiment=run, log_model=True)
#log config
wandb.config.update(config)

checkpoint_path = os.path.join(wandb.run.dir, 'checkpoint.ckpt')
callbacks = [SaveWandbCallback(config.log_interval, checkpoint_path)]

if resume:
    #get file from the wandb cloud
    load_checkpoint_from_cloud(checkpoint_path = 'checkpoint.ckpt')
    #restore training state completely
    trainer = pl.Trainer(max_epochs = config.epochs, logger=wandb_logger, callbacks=callbacks, checkpoint_callback = None, resume_from_checkpoint = checkpoint_path)
else:
    trainer = pl.Trainer(max_epochs = config.epochs, logger=wandb_logger, callbacks=callbacks, checkpoint_callback = None)

#model
if selected_model == "VAEGAN":
    model = VAEGAN(config).to(device)
elif selected_model == "GAN":
    model = GAN(config).to(device)
else:
    model = VAE_train(config).to(device)
wandb_logger.watch(model)


#train
trainer.fit(model, dataModule)

[34m[1mwandb[0m: Wandb version 0.10.9 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade
GPU available: True, used: False
TPU available: False, using: 0 TPU cores

  | Name          | Type          | Params
------------------------------------------------
0 | vae           | VAE           | 58 M  
1 | discriminator | Discriminator | 1 M   


Processed dataset /content/drive/My Drive/Hand-Tool-Data-Set/turbosquid_thingiverse_dataset/dataset_ply_out/dataset_array_processed_res32.npy already exists.




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

  "See the documentation of nn.Upsample for details.".format(mode))
  "See the documentation of nn.Upsample for details.".format(mode))
Saving latest checkpoint..





1