<a href="https://colab.research.google.com/github/buganart/BUGAN/blob/master/custom_dataset_of_script_VAEGAN_voxelsize1_lightning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#mount google drive
from google.colab import output
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
# WandB – Import the wandb library
!pip install wandb==0.9.7
output.clear()
import wandb
!wandb login
output.clear()

#Information

In [3]:
#set id = None if you want to start a new run
#set id = "run_id" if you want to resume a run (for example: id = "u9imsvva")
#id of the current run is shown below in the cell with wandb.init() (the cell 6)
id = None
#please enter project name (either "handtool-gan" or "tree-gan")
project_name = "tree-gan"
#enter file location after drive/My Drive/ 
    # zipfile example: drive/My Drive/h/k/a.zip, then enter "h/k/a.zip"
    # file folder example: drive/My Drive/h/k/a.obj,drive/My Drive/h/k/b.obj, ...  then enter "h/k"
file_location = "IRCMS_GAN_collaborative_database/Research/Peter/Tree_3D_models_obj_auto_generated/sessions/simplified/tree-session-2020-09-14_23-23-Friedrich_2-target-face-num-1000.zip"  
# file_location = "IRCMS_GAN_collaborative_database/Research/Peter/sample_off_files"

#env

In [4]:
#right click shared folder IRCMS_GAN_collaborative_database and "Add shortcut to Drive" to My drive
%cd drive/My Drive/IRCMS_GAN_collaborative_database/Experiments/
if project_name == "tree-gan":
    %cd colab-treegan/
else:
    %cd colab-handtool/

#record paths to resources
file_loc_list = file_location.split("/")

# data_dir = "../../../../../My Drive/Hand-Tool-Data-Set/"    #take care of .shortcut-targets-by-id/"folder-id"/ folders
run_path = "./"

import os
if file_location.endswith(".zip"):
    zipfile_name = file_loc_list[-1]
    dataset_name = ".".join(zipfile_name.split(".")[:-1])
    data_dir = os.path.join("../../../../../../drive/My Drive/", ("/".join(file_loc_list[:-1])))
    data_filepath = os.path.join(data_dir, zipfile_name)
else:
    dataset_name = "dataset_array_custom"
    data_dir = os.path.join("../../../../../../drive/My Drive/", file_location)
    data_filepath = data_dir

!apt-get update

!pip install pytorch-lightning==0.9.0
!pip install trimesh
!apt install -y xvfb
!pip install trimesh xvfbwrapper
output.clear()

#add libraries, and login to wandb

In [5]:
import io
from io import BytesIO
import zipfile
import trimesh
import numpy as np


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchsummary import summary
from torch.utils.data import DataLoader, TensorDataset
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

# Ignore excessive warnings
import logging
logging.propagate = False 
logging.getLogger().setLevel(logging.ERROR)

from xvfbwrapper import Xvfb

In [6]:
#id None to start a new run. For resuming run, put the id of the run below
# id = None
resume = False
if id is None:
    id = wandb.util.generate_id()
else:
    resume = True

run = wandb.init(project=project_name, id=id, entity="bugan", resume=True, dir=run_path)
print("run id: " + str(wandb.run.id))
print("run name: " + str(wandb.run.name))
wandb.watch_called = False
wandb.run.save_code = True

[34m[1mwandb[0m: Wandb version 0.10.5 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


run id: kxdireeg
run name: kxdireeg


In [7]:
#keep track of hyperparams
config = wandb.config

config.batch_size = 8
config.array_size = 32

config.z_size = 128
config.gen_num_layer_unit = [1024, 512, 256, 128]
config.dis_num_layer_unit = [32, 64, 128, 128]
config.leakyReLU = False    #leakyReLU implementation still not in modelPL
config.balance_voxel_in_space = False

config.epochs = 3000
config.vae_lr = 0.0025
config.vae_encoder_layer = 1
config.vae_decoder_layer = 2
config.d_lr = 0.00005            
config.d_layer = 1
config.vae_recon_loss_factor = 1
config.seed = 1234
config.log_image_interval = 5
config.log_mesh_interval = 50
config.data_augmentation = True
config.num_augment_data = 4

config.vae_opt = "Adam"
config.dis_opt = "Adam"

[34m[1mwandb[0m: Wandb version 0.10.5 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


#dataset

In [8]:
### load our package

#clone then install
# !git clone https://github.com/buganart/BUGAN repo
# !pip install -e ./repo/
# import site
# site.main()

#directly install using pip
!pip install git+https://github.com/buganart/BUGAN.git#egg=bugan
output.clear()


from bugan.functionsPL import *
from bugan.modelsPL import VAEGAN, VAE, Discriminator, Generator

# from functionsPL import *
# from modelsPL import VAEGAN, VAE, Discriminator, Generator

run.tags.append("VAEGAN")
run.group = "VAEGAN"

###     load dataset
np.random.seed(config.seed)
# dataModule = DataModule(config, run)
# config.num_data = dataModule.size

config.dataset = dataset_name
dataModule = DataModule_process(config, run, data_filepath)
config.num_data = dataModule.size

[34m[1mwandb[0m: Wandb version 0.10.5 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


#train

In [None]:
#set seed
torch.manual_seed(config.seed)
torch.autograd.set_detect_anomaly(True)

#render setup
vdisplay = Xvfb()
vdisplay.start()

#wandb logger setup
wandb_logger = WandbLogger(experiment=run, log_model=True)

checkpoint_path = os.path.join(wandb.run.dir, 'checkpoint.ckpt')

if resume:
    #get file from the wandb cloud
    load_checkpoint_from_cloud(checkpoint_path = 'checkpoint.ckpt')
    #restore training state completely
    trainer = pl.Trainer(max_epochs = config.epochs, logger=wandb_logger, checkpoint_callback = None, resume_from_checkpoint = checkpoint_path)
else:
    trainer = pl.Trainer(max_epochs = config.epochs, logger=wandb_logger, checkpoint_callback = None)

#model
vaegan = VAEGAN(config, trainer, save_model_path = checkpoint_path).to(device)
wandb_logger.watch(vaegan)

trainer.fit(vaegan, dataModule)

GPU available: True, used: False
TPU available: False, using: 0 TPU cores


processed dataset_array shape: (8710, 32, 32, 32)
number of failed data: 0


[34m[1mwandb[0m: Wandb version 0.10.5 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade

  | Name          | Type          | Params
------------------------------------------------
0 | vae           | VAE           | 58 M  
1 | discriminator | Discriminator | 1 M   


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

  "See the documentation of nn.Upsample for details.".format(mode))
  "See the documentation of nn.Upsample for details.".format(mode))
