<a href="https://colab.research.google.com/github/buganart/BUGAN/blob/master/notebook_util/generate.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Before starting, please save a copy of this notebook to your Google Drive by clicking `File -> Save a copy in Drive`.

In [None]:
#@markdown Check GPU, should be a Tesla V100
!nvidia-smi -L
import os
print(f"We have {os.cpu_count()} CPU cores.")

In [None]:
#@markdown Mount google drive.
from pathlib import Path

from google.colab import drive
# `output` is used by later cells to clear noisy install logs.
from google.colab import output

drive.mount('/content/drive')

# Later cells read from and write to this shared folder (e.g. the default export locations),
# so warn early -- without stopping the notebook -- if it is not reachable from "My Drive".
shared_folder = Path("/content/drive/My Drive/IRCMS_GAN_collaborative_database")
if not shared_folder.exists():
    print(
        "Shortcut to our shared drive folder doesn't exist.\n\n"
        "\t1. Go to the google drive web UI\n"
        "\t2. Right click shared folder IRCMS_GAN_collaborative_database and click \"Add shortcut to Drive\""
    )

In [None]:
#@title Configure dataset
#@markdown - choose whether to use a wandb run id (`id`) or a checkpoint file path (`ckpt_file_location`) to select the checkpoint file

#@markdown Choice 1: wandb id to select checkpoint file
#@markdown Enter project name
#@markdown - the name of the wandb project in the format of {entity}/{project_name}
#@markdown - "bugan/tree-gan", "bugan/handtool-gan", "bugan/chair-gan" are private projects reserved for bugan.
#@markdown - "bugan/bu-3dgan-open" is an open project, which acts as a testing ground for the public.
#@markdown - In case your run lives in your own project, fill it in project_name_option.
#@markdown - if project_name_option is not empty, project_name_option will overwrite project_name
project_name = "bugan/tree-gan" #@param ["bugan/tree-gan", "bugan/handtool-gan", "bugan/chair-gan", "bugan/bu-3dgan-open"]
project_name_option = "" #@param {type:"string"}
#@markdown - set `id` to the wandb run id whose checkpoint should be loaded (for example: `u9imsvva`)
#@markdown - You can find the id of a run in the wandb web UI.
id = "1xqvp1q4" #@param {type:"string"}

#@markdown Choice 2: file path and model type to select saved checkpoint .ckpt file
#@markdown - For example, use the file browser on the left to locate the file, then right click it to copy its path.
#@markdown - file path example: `/content/drive/My Drive/h/k/checkpoint.ckpt`
ckpt_file_location = "" #@param {type:"string"}
#@markdown - Enter trained neural network model type
#@markdown - (also used with a wandb id when selected_model is not saved in the run config)
selected_model = "CVAEGAN"    #@param ["VAEGAN", "GAN", "VAE", "WGAN", "WGAN_GP", "CGAN", "CVAEGAN"]

#@markdown In case the model is conditional, specify class index here.
class_index = 12 #@param {type:"integer"}
#@markdown Enter how many samples to generate
test_num_samples = 100 #@param {type:"integer"}
#@markdown OPTIONAL: boolean to choose whether to generate from list of history checkpoints or just the latest checkpoint.
generateMeshHistory = False #@param {type:"boolean"}
#@markdown OPTIONAL: If generateMeshHistory enabled, select "num_selected_checkpoint" checkpoints evenly spaced over the saved history checkpoint list.
num_selected_checkpoint = 0 #@param {type:"integer"}
#@markdown whether to zip file and download with browser
browser_download = False #@param {type:"boolean"}
#@markdown Enter export location for generated mesh (folder/directory).
export_location = f"/content/drive/My Drive/IRCMS_GAN_collaborative_database/Experiments/exportObjects/{id}" #@param {type:"string"}

#@markdown ### Post process generated mesh
#@markdown whether to post process meshes
post_process = True #@param {type:"boolean"}
#@markdown remove clusters that have no points inside a sphere of the given radius.
#@markdown - the sphere is centered in the cube voxel space.
#@markdown - clusters with at least 1 point inside the sphere are kept; those with no points inside it are discarded.
#@markdown - For resolution=64, the largest sphere that fits in the cube voxel space has radius 32.
radius = 28 #@param {type:"number"}
#@markdown remove clusters that have fewer than point_threshold points.
point_threshold = 50 #@param {type:"integer"}
#@markdown Enter export location for post-processed generated mesh (folder/directory).
#@markdown - After the mesh is generated, it is in array format. Post-processing removes outliers / floating voxels before export.
export_location_processed = f"/content/drive/My Drive/IRCMS_GAN_collaborative_database/Experiments/exportObjects/{id}_postprocessed" #@param {type:"string"}


## project_name_option: accept either "{entity}/{project_name}" or a bare project name
if project_name_option:
    project_name = project_name_option
    if "/" in project_name:
        project_list = project_name.split("/")
        # Anything other than exactly one "/" would be silently truncated later on.
        if len(project_list) != 2:
            raise Exception("project_name_option must be in the format {entity}/{project_name}.")
        if len(project_list[1]) < 1:
            raise Exception("\"/\" exists, but project_name is empty.")
        if len(project_list[0]) < 1:
            print("\"/\" exists, but entity is empty.")
            print(f"set entity as login entity, and set project_name: {project_list[1]}")
            project_name = project_list[1]


# Exactly one checkpoint source must be chosen.
if id and ckpt_file_location:
    raise Exception("Only one of id / ckpt_file_location can be set!")
if (not id) and (not ckpt_file_location):
    raise Exception("Please set id / ckpt_file_location!")

if id:
    print("id:", id)
else:
    print("selected_model:", selected_model)
    print("ckpt_file_location:", ckpt_file_location)
print("test_num_samples:", test_num_samples)
print("export_location:", export_location)


In [None]:
from argparse import Namespace, ArgumentParser
#@markdown Install wandb and log in
rev_number = None
if id:
    !pip install wandb
    output.clear()
    #find wandb API key file to auto login
    import wandb
    wandb_drive_netrc_path = Path("drive/My Drive/colab/.netrc")
    wandb_local_netrc_path = Path("/root/.netrc")
    if wandb_drive_netrc_path.exists():
        import shutil

        print("Wandb .netrc file found, will use that to log in.")
        shutil.copy(wandb_drive_netrc_path, wandb_local_netrc_path)
    else:
        print(
            f"Wandb config not found at {wandb_drive_netrc_path}.\n"
            f"Using manual login.\n\n"
            f"To use auto login in the future, finish the manual login first and then run:\n\n"
            f"\t!mkdir -p '{wandb_drive_netrc_path.parent}'\n"
            f"\t!cp {wandb_local_netrc_path} '{wandb_drive_netrc_path}'\n\n"
            f"Then that file will be used to login next time.\n"
        )

    !wandb login

    #read information (run config, etc) stored online
        #all config will be replaced by the stored one in wandb
    # extract entity from project_name if in format "{entity}/{project_name}"
    if "/" in project_name:
        project_list = project_name.split("/")
        entity = project_list[0]
        project_name = project_list[1]
    else:
        # use default login entity
        entity = None

    run_string = f"{project_name}/{id}"
    if entity:
        run_string = f"{entity}/" + run_string

    # all config will be replaced by the stored one in wandb
    api = wandb.Api()
    run = api.run(run_string)
    config = Namespace(**run.config)
        #load selected_model, rev_number in the config
    if hasattr(config, "selected_model"):
        selected_model = config.selected_model
    if hasattr(config, "rev_number"):
        rev_number = config.rev_number

    output.clear()
    print("run id: " + str(run.id))
    print("run name: " + str(run.name))
    print("ok!")

# To just generate samples, no edits should be required in any cells below.

In [None]:
#@markdown package and functions

import os
import sys
import subprocess
import torch
import numpy as np
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
from pathlib import Path
os.environ["WANDB_MODE"] = "dryrun"

%pip install --upgrade git+https://github.com/buganart/BUGAN.git#egg=bugan

run_path = "/content/drive/My Drive/IRCMS_GAN_collaborative_database/Experiments/"

from bugan.trainPL import _get_models
from bugan.functionsPL import netarray2mesh, eval_cluster

def install_bugan_package(rev_number=None):
    """pip-install (with --upgrade) the bugan package from GitHub into the running interpreter.

    Args:
        rev_number: optional git revision to pin (e.g. a commit hash); a falsy value
            installs the repository's default branch.

    Raises:
        subprocess.CalledProcessError: if pip exits with a non-zero status.
    """
    revision_suffix = f"@{rev_number}" if rev_number else ""
    requirement = f"git+https://github.com/buganart/BUGAN.git{revision_suffix}#egg=bugan"
    pip_command = [sys.executable, "-m", "pip", "install", "--upgrade", requirement]
    subprocess.check_call(pip_command)

# load model
def _import_get_models():
    """Import `bugan.trainPL._get_models` from freshly loaded bugan modules.

    `pip install` only changes files on disk. Modules already in `sys.modules` would keep
    running the previously imported code, so every cached `bugan` module is dropped first
    and the next import loads the version that was just installed.
    """
    import importlib
    import sys

    stale_modules = [name for name in sys.modules if name == "bugan" or name.startswith("bugan.")]
    for name in stale_modules:
        del sys.modules[name]
    importlib.invalidate_caches()
    from bugan.trainPL import _get_models as fresh_get_models
    return fresh_get_models


def _load_model(selected_model, ckpt_filePath, package_rev_number=None):
    """Install bugan and load `selected_model` from `ckpt_filePath`.

    First tries the bugan revision given by `package_rev_number` (latest when None). If
    installing or loading fails, it falls back to the latest bugan version.
    """
    # The wandb cell defines `config` only when a run id is used. In checkpoint-file mode,
    # `config` is not passed at all (presumably the model then restores the hyperparameters
    # saved in the checkpoint -- TODO confirm).
    run_config = globals().get("config")
    load_kwargs = {} if run_config is None else {"config": run_config}
    try:
        # try the bugan revision the run was trained with
        install_bugan_package(rev_number=package_rev_number)
        MODEL_CLASS = _import_get_models()(selected_model)
        return MODEL_CLASS.load_from_checkpoint(ckpt_filePath, **load_kwargs)
    except Exception as e:
        print(e)
        # fall back to the newest bugan version
        install_bugan_package()
        MODEL_CLASS = _import_get_models()(selected_model)
        return MODEL_CLASS.load_from_checkpoint(ckpt_filePath, **load_kwargs)


def _export_obj(mesh, filename, directories):
    """Export `mesh` in OBJ format as `directory / filename` for every directory given."""
    for directory in directories:
        mesh.export(file_obj=directory / filename, file_type="obj")


def generateFromCheckpoint(selected_model, ckpt_filePath, class_index=None, num_samples=1, package_rev_number=None, output_file_name_dict=None):
    """Load a bugan model from a checkpoint, generate samples and export them as .obj meshes.

    Args:
        selected_model: model type name passed to bugan's `_get_models` (e.g. "CVAEGAN").
        ckpt_filePath: path of the checkpoint file to load.
        class_index: class label used when the model accepts one (`generate_tree(c=...)`).
        num_samples: number of samples to generate and export.
        package_rev_number: bugan git revision to install before loading; None installs the latest.
        output_file_name_dict: extra key/value pairs appended to every file name as
            `_{key}_{value}`. The dict is copied and never modified.

    Where meshes are written is decided by the notebook globals export_location,
    browser_download, temp_location, post_process and export_location_processed.

    Returns:
        The samples returned by `model.generate_tree`.
    """
    # Copy so neither the caller's dict nor a shared default picks up the "class" entry below.
    name_parts = dict(output_file_name_dict or {})

    model = _load_model(selected_model, ckpt_filePath, package_rev_number)
    model = model.eval()#.to(device)
    try:
        # assume a conditional model first
        sample_trees = model.generate_tree(c=class_index, num_trees=num_samples)
        name_parts["class"] = class_index
    except Exception as e:
        print(e)
        print("generate with class label does not work. Now generate without label")
        # assume an unconditional model
        sample_trees = model.generate_tree(num_trees=num_samples)

    save_filename_header = "".join(f"_{k}_{v}" for k, v in name_parts.items())

    # Destination folders are the same for every sample, so resolve them once.
    raw_dirs = []
    if export_location:
        raw_dirs.append(export_location)
    if browser_download:
        # also stage a copy for the zip download
        raw_dirs.append(temp_location)
    processed_dirs = []
    if post_process and export_location_processed:
        processed_dirs.append(export_location_processed)
        if browser_download:
            processed_dirs.append(temp_location)

    for n in range(num_samples):
        sample_tree_array = sample_trees[n]
        _export_obj(netarray2mesh(sample_tree_array), f"sample_{n}{save_filename_header}.obj", raw_dirs)

        if processed_dirs:
            # remove outlier clusters, then export the cleaned mesh
            processed_voxelmesh = netarray2mesh(post_process_array(sample_tree_array))
            _export_obj(processed_voxelmesh, f"sample_{n}{save_filename_header}_processed.obj", processed_dirs)

    return sample_trees


# post processing array
def cluster_in_sphere(voxel_index_list, center, radius):
    """Tell whether any voxel of a cluster lies strictly inside a sphere.

    Args:
        voxel_index_list: iterable of voxel coordinates, each convertible to a numpy array.
        center: sphere centre, convertible to a numpy array.
        radius: sphere radius; a voxel exactly on the surface does not count.

    Returns:
        True as soon as one voxel's Euclidean distance to `center` is below `radius`,
        otherwise False (also for an empty cluster).
    """
    sphere_center = np.array(center)
    return any(
        np.linalg.norm(np.array(voxel) - sphere_center) < radius
        for voxel in voxel_index_list
    )

def post_process_array(boolarray, min_points=None, sphere_radius=None):
    """Remove small and off-centre voxel clusters from a generated voxel array.

    Args:
        boolarray: generated voxel grid; entries > 0 are treated as occupied.
        min_points: clusters with fewer voxels are dropped. Defaults to the notebook's
            `point_threshold` setting.
        sphere_radius: clusters with no voxel strictly inside a sphere of this radius,
            centred at `shape / 2`, are dropped. Defaults to the notebook's `radius` setting.

    Returns:
        A boolean array of the same shape that contains only the kept clusters.
    """
    if min_points is None:
        min_points = point_threshold
    if sphere_radius is None:
        sphere_radius = radius

    occupied = boolarray > 0
    center = np.array(occupied.shape) / 2
    processed_tree = np.zeros_like(occupied)
    for cluster in eval_cluster(occupied):
        points = list(cluster)
        if len(points) < min_points:
            continue
        if not cluster_in_sphere(points, center, sphere_radius):
            continue
        # Each point is a voxel index; transposing gives one index array per axis,
        # so the whole cluster is written in a single assignment.
        processed_tree[tuple(np.asarray(points).T)] = True
    return processed_tree

# Clear the pip install logs printed by this cell, then confirm it finished.
output.clear()
print("ok!")

# Generate samples

In [None]:
#@markdown generate samples
import shutil

from google.colab import files

temp_location = Path("/tmp/generated/")
# Start from an empty staging folder so the zip only contains files from this run.
shutil.rmtree(temp_location, ignore_errors=True)
temp_location.mkdir(parents=True, exist_ok=True)

# create the export directories
if export_location:
    export_location = Path(export_location)
    export_location.mkdir(parents=True, exist_ok=True)
if post_process and export_location_processed:
    export_location_processed = Path(export_location_processed)
    export_location_processed.mkdir(parents=True, exist_ok=True)


def _evenly_spaced_indices(total, count):
    """Return `count` indices spread evenly over range(total), including both ends when count >= 2.

    `count` is clamped to [0, total]. A single requested index selects the last (latest) one.
    Without that special case, count == 1 would divide by zero.
    """
    count = max(0, min(count, total))
    if count == 0:
        return []
    if count == 1:
        return [total - 1]
    return [int(i / (count - 1) * (total - 1) + 0.5) for i in range(count)]


if ckpt_file_location:
    # Choice 2: load directly from a checkpoint file (no wandb run involved).
    if generateMeshHistory:
        print("generateMeshHistory needs a wandb run id; generating from ckpt_file_location only.")
    generateFromCheckpoint(selected_model, ckpt_file_location, class_index, test_num_samples, rev_number, {})
elif generateMeshHistory:
    # Choice 1 with history: collect the per-epoch checkpoints stored in the wandb run.
    epoch_file_dict = {}
    for wandb_file in run.files():
        filename = wandb_file.name
        if ".ckpt" not in filename:
            continue
        if filename in ("checkpoint.ckpt", "checkpoint_prev.ckpt"):
            continue
        # assumes history checkpoints are named like "<prefix>_<epoch>.ckpt"
        file_epoch = int(filename.split("_")[1].split(".")[0])
        epoch_file_dict[file_epoch] = wandb_file

    epoch_list = sorted(epoch_file_dict)
    selected_epoch_index = _evenly_spaced_indices(len(epoch_list), num_selected_checkpoint)
    print(f"select {len(selected_epoch_index)} out of {len(epoch_list)} checkpoints......")

    # download each selected checkpoint and generate meshes from it
    for checkpoint_epoch_index in selected_epoch_index:
        file_epoch = epoch_list[checkpoint_epoch_index]
        print(f"generate mesh for epoch {file_epoch}......")
        try:
            ckpt_file = epoch_file_dict[file_epoch]
            ckpt_file.download(replace=True)
            generateFromCheckpoint(selected_model, ckpt_file.name, class_index, test_num_samples, rev_number, {"epoch": file_epoch})
        except Exception as e:
            print(e)
            print(f"generate mesh for epoch {file_epoch} FAILED !!!")
else:
    # Choice 1: latest checkpoint of the wandb run, falling back to the previous one.
    try:
        ckpt_file = run.file("checkpoint.ckpt").download(replace=True)
        generateFromCheckpoint(selected_model, ckpt_file.name, class_index, test_num_samples, rev_number, {})
    except Exception as e:
        print(e)
        print("loading from checkpoint.ckpt failed. Try checkpoint_prev.ckpt")
        ckpt_file = run.file("checkpoint_prev.ckpt").download(replace=True)
        generateFromCheckpoint(selected_model, ckpt_file.name, class_index, test_num_samples, rev_number, {})

# Zip the staged files and offer them as a browser download when requested.
if browser_download:
    zip_path = shutil.make_archive("/tmp/file", "zip", root_dir="/tmp", base_dir="generated")
    files.download(zip_path)
else:
    print("will not download files from the browser.")

print("ok!")