<a href="https://colab.research.google.com/github/buganart/BUGAN/blob/master/notebook_util/train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Before starting, please save a copy of this notebook to your Google Drive by clicking `File -> Save a copy in Drive`.

In [None]:
#@markdown Mount google drive.
from google.colab import output
from google.colab import drive
drive.mount('/content/drive')

# Check that the shared folder shortcut exists in "My Drive": later cells `%cd` into
# IRCMS_GAN_collaborative_database/Experiments and the default datasets live under it.
from pathlib import Path
if not Path("/content/drive/My Drive/IRCMS_GAN_collaborative_database").exists():
    print(
        "Shortcut to our shared drive folder doesn't exist.\n\n"
        "\t1. Go to the google drive web UI\n"
        "\t2. Right click shared folder IRCMS_GAN_collaborative_database and click \"Add shortcut to Drive\"\n"
        "\t3. Re-run this cell"
    )

In [None]:
#@markdown Install wandb and log in
%pip install wandb
output.clear()
import wandb
from pathlib import Path
wandb_drive_netrc_path = Path("drive/My Drive/colab/.netrc")
wandb_local_netrc_path = Path("/root/.netrc")
if wandb_drive_netrc_path.exists():
    import shutil

    print("Wandb .netrc file found, will use that to log in.")
    shutil.copy(wandb_drive_netrc_path, wandb_local_netrc_path)
else:
    print(
        f"Wandb config not found at {wandb_drive_netrc_path}.\n"
        f"Using manual login.\n\n"
        f"To use auto login in the future, finish the manual login first and then run:\n\n"
        f"\t!mkdir -p '{wandb_drive_netrc_path.parent}'\n"
        f"\t!cp {wandb_local_netrc_path} '{wandb_drive_netrc_path}'\n\n"
        f"Then that file will be used to login next time.\n"
    )

!wandb login
output.clear()
print("ok!")

In [None]:
#@title Configure dataset
#@markdown - Leave empty if you want to start a new run
#@markdown - Set `resume_id` if you want to resume a run (for example: `u9imsvva`)
#@markdown - The id of the current run is shown below in the cell with `wandb.init()` or you can find it in the wandb web UI.
resume_id = "" #@param {type:"string"}
#@markdown Enter project name (either `chair-gan`, `handtool-gan` or `tree-gan`)
project_name = "tree-gan" #@param ["tree-gan", "handtool-gan", "chair-gan"]
#@markdown Enter dataset location.  
#@markdown - For example, use the file browser on the left to locate the dataset, then right-click it to copy its path.
#@markdown - zipfile example: `/content/drive/My Drive/h/k/a.zip`
#@markdown - file folder example: `/content/drive/My Drive/h/k`
#@markdown - if data_location_option is not empty, it overrides data_location_default
data_location_default = "/content/drive/My Drive/IRCMS_GAN_collaborative_database/Research/Peter/Tree_3D_models_obj_auto_generated/sessions/simplified/tree-session-2020-09-14_23-23-Friedrich_2-target-face-num-1000.zip" #@param ["/content/drive/My Drive/IRCMS_GAN_collaborative_database/Research/Peter/Tree_3D_models_obj_auto_generated/sessions/simplified/tree-session-2020-09-14_23-23-Friedrich_2-target-face-num-1000.zip", "/content/drive/My Drive/IRCMS_GAN_collaborative_database/Research/Peter/Tree_3D_models_obj_auto_generated/sessions/simplified/tree-sessions-2020-09-10-simplified-26k-target-face-num-1000.zip", "/content/drive/My Drive/Hand-Tool-Data-Set/turbosquid_thingiverse_dataset/dataset_ply_out_zipped.zip", "/content/drive/My Drive/IRCMS_GAN_collaborative_database/Research/Peter/Chairs_Princeton/chair_train.zip", "/content/drive/My Drive/IRCMS_GAN_collaborative_database/Research/Peter/Tree_3D_models_obj_auto_generated/sessions/simplified/tree-sessions-2020-09-10-simplified-26k-target-face-num-1000-class-label.zip","/content/drive/My Drive/Hand-Tool-Data-Set/handtool-v3-combined-tnf-1000.zip"] 
data_location_option = "" #@param {type:"string"}
#@markdown - For conditional dataset
#@markdown - maximum number of classes to extract based on the data_location path
num_classes = 10 #@param {type:"integer"}
#@markdown - choose rotation augmentation on-the-fly 
#@markdown (augmentation only supports a file folder as data_location)
data_augmentation = True    #@param {type:"boolean"}
aug_rotation_type = "axis rotation"  #@param ["random rotation", "axis rotation"]
#@markdown - specify the rotation axis [x,y,z] (only for aug_rotation_type = "axis rotation")
rotation_axis_x = 0    #@param {type:"number"}
rotation_axis_y = 1    #@param {type:"number"}
rotation_axis_z = 0    #@param {type:"number"}

#@markdown - resolution of the voxelized array (shape resolution**3)
resolution = "32"    #@param [32, 64]

#@markdown Model
#@markdown - select which model to train
#@markdown - choosing unconditional models will set num_classes = 0
#@markdown to load the dataset in unconditional way
selected_model = "WGAN_GP"    #@param ["VAEGAN", "GAN", "VAE", "WGAN", "WGAN_GP", "CGAN"]

#@markdown WANDB log
#@markdown - how many epochs before logging images/3D objects
log_interval = 20    #@param {type:"integer"}
#@markdown - how many samples per log
log_num_samples = 4    #@param {type:"integer"}

#@markdown WANDB run note
#@markdown - please describe the reason for running this experiment
run_note = "" #@param {type:"string"}


#adjust parameter datatype
resolution = int(resolution)
# Pasted form values often carry stray whitespace; a lone space would otherwise
# count as "set" and be used as a run id to resume or as the dataset path.
resume_id = resume_id.strip()
data_location_option = data_location_option.strip()
# A non-empty option overrides the dropdown default.
data_location = data_location_option or data_location_default
# Zip archives are named after the archive; folders get a generic dataset name.
if data_location.endswith(".zip"):
    dataset = Path(data_location).stem
else:
    dataset = "dataset_array_custom"
# Unconditional models: load the dataset without class labels.
if selected_model in ["VAEGAN", "GAN", "VAE", "WGAN", "WGAN_GP"]:
    num_classes = 0

colab_config = {
    "aug_rotation_type": aug_rotation_type,
    "data_augmentation": data_augmentation,
    "aug_rotation_axis": (rotation_axis_x,rotation_axis_y,rotation_axis_z),
    "data_location": data_location,
    "dataset": dataset,
    "resume_id": resume_id,
    "selected_model": selected_model,
    "log_interval": log_interval,
    "log_num_samples": log_num_samples,
    "project_name": project_name,
    "resolution": resolution,
    "num_classes": num_classes,
}

for k, v in colab_config.items():
    print(f"=> {k:20}: {v}")

# Catch a typo or an unmounted drive here instead of deep inside training setup.
if not Path(data_location).exists():
    print(f"WARNING: data_location does not exist (is the drive mounted?): {data_location}")


# To just train a model, no edits should be required in any cells below.

In [None]:
import os
from pathlib import Path
# os.environ["WANDB_MODE"] = "dryrun"

%cd /content/drive/My Drive/IRCMS_GAN_collaborative_database/Experiments/
if project_name == "tree-gan":
    %cd colab-treegan/
elif project_name == "handtool-gan":
    %cd colab-handtool/
else:
    %cd colab-chair/

dataset_path = Path(data_location)
run_path = "./"

!apt-get update

!apt install -y xvfb
%pip install --upgrade xvfbwrapper
output.clear()
print('ok!')

In [None]:
from argparse import Namespace, ArgumentParser

from xvfbwrapper import Xvfb
import torch
# import pkg_resources

def get_resume_run_config(project_name, resume_id, entity="bugan"):
    """Fetch the stored config of a previous wandb run so it can be resumed.

    Args:
        project_name: wandb project the run belongs to (e.g. "tree-gan").
        resume_id: id of the run to resume (e.g. "u9imsvva").
        entity: wandb entity that owns the project. Defaults to "bugan",
            the entity this notebook has always used.

    Returns:
        argparse.Namespace with every key of the previous run's config.
        All of these values replace the current config when resuming.
    """
    api = wandb.Api()
    previous_run = api.run(f"{entity}/{project_name}/{resume_id}")
    return Namespace(**previous_run.config)

def get_bugan_package_revision_number():
    # version_str = pkg_resources.get_distribution('bugan').version
    # rev_number = (version_str.split("+g")[1]).split(".")[0]
    import bugan
    import io, sys
    #EXTRACT package version
        #switch stdout to temperary stringIO
    old_stdout = sys.stdout
    temp_stdout = io.StringIO()
    sys.stdout = temp_stdout
        #get version
    %pip freeze | grep bugan
    version = temp_stdout.getvalue()
    rev_number = version.split("+g")[1].rstrip()
        #switch back stdout
    sys.stdout = old_stdout
    return rev_number

#train setup config and package

In [None]:
# Start from the form values and add the fixed training hyper-parameters.
config = Namespace(**colab_config)
config.seed, config.epochs, config.batch_size = 1234, 3000, 32

# When resuming, the config stored with the previous wandb run takes precedence
# over every value above; only resume_id is forced back to the id being resumed.
if config.resume_id:
    project_name, resume_id = config.project_name, config.resume_id
    stored_config = get_resume_run_config(project_name, resume_id)
    merged_config = {**vars(config), **vars(stored_config)}
    merged_config["resume_id"] = resume_id
    config = Namespace(**merged_config)

In [None]:
# load bugan package and record revision_number
if hasattr(config,"rev_number"):
    print("loading BUGAN package rev_number", config.rev_number)
    %pip install --upgrade git+https://github.com/buganart/BUGAN.git@{config.rev_number}#egg=bugan
else:
    print("loading BUGAN package latest")
    %pip install --upgrade git+https://github.com/buganart/BUGAN.git#egg=bugan
output.clear()

import bugan
from bugan.trainPL import (
    init_wandb_run,
    setup_datamodule,
    setup_model,
    train,
)
#record revision number
config.rev_number = get_bugan_package_revision_number()

#training (wandb_init, datamodule, model, train)

In [None]:
# Derive the wandb directory into a new name instead of rebinding `run_path`:
# rebinding made every re-run of this cell climb one more directory level up.
run_dir = Path(run_path).absolute().parent
run, config = init_wandb_run(config, run_dir=run_dir)#, mode="offline")
run.notes = run_note

In [None]:
model, extra_trainer_args = setup_model(config, run)
dataModule = setup_datamodule(config)

if torch.cuda.is_available():
    # presumably forwarded to the PyTorch Lightning Trainer, where gpus=-1 means
    # "use all available GPUs" -- confirm against bugan.trainPL.train
    extra_trainer_args["gpus"] = -1

#render setup: virtual X display so rendering works without a physical screen
vdisplay = Xvfb()
vdisplay.start()
try:
    train(config, run, model, dataModule, extra_trainer_args)
finally:
    # Stop the virtual display even if training raises or is interrupted;
    # otherwise each re-run of this cell leaves another Xvfb process running.
    vdisplay.stop()