<a href="https://colab.research.google.com/github/buganart/BUGAN/blob/master/notebook_util/generate_tree.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Before starting, please save the notebook to your Drive by clicking `File -> Save a copy in Drive`.

In [None]:
#@markdown Mount google drive.
from google.colab import output
from google.colab import drive
drive.mount('/content/drive')

# Check if we have linked the folder
from pathlib import Path
if not Path("/content/drive/My Drive/IRCMS_GAN_collaborative_database").exists():
    print(
        "Shortcut to our shared drive folder doesn't exist.\n\n"
        "\t1. Go to the google drive web UI\n"
        "\t2. Right click shared folder IRCMS_GAN_collaborative_database and click \"Add shortcut to Drive\""
    )

# Description

This notebook generates 3D mesh objects from trained runs in the wandb project "bugan/tree-gan". To train models, please go to [train.ipynb](https://github.com/buganart/BUGAN/blob/master/notebook_util/train.ipynb).

You need to specify **one trained model**, through the run_id/checkpoint_file of a model trained in the BUGAN train notebook above, and provide the **number of samples to generate**. The notebook then writes the generated samples to the **export_location** specified in the panel.



---



# Instructions for specifying the trained model
The trained model is loaded from a saved checkpoint. There are 2 options to specify the trained model:

1.   set **run_id** (the `resume_id` field in the form) and **class_index** directly; the saved checkpoint of that run will then be downloaded
2.   leave the run id empty and set **selected_tree_class**, which loads the preset *`run_id`* and *`class_index`* for the **selected_tree_class**.
<!-- 3.   specify **checkpoint_path** and **class_index** -->

Note that **class_index** is required if the loaded model is conditional. If the model is unconditional, setting **class_index** will not cause 3D object generation to fail.
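
The two options above can be sketched as a small lookup. This is only an illustration of the selection logic, not the notebook's actual cell: the helper name `resolve_model` is hypothetical, and the `presets` dictionary holds just two entries copied from the preset table defined later in this notebook.

```python
def resolve_model(resume_id, selected_tree_class, class_index, preset_models):
    """Return (run_id, class_index) following the two options described above."""
    if resume_id:
        # Option 1: an explicit run id wins; keep the class_index set by the user.
        return resume_id, class_index
    # Option 2: look up the preset [run_id, class_index] pair for the tree class.
    run_id, preset_class_index = preset_models[selected_tree_class]
    return run_id, preset_class_index


# Two entries copied from the preset table in this notebook.
presets = {"friedrich_3": ["1v3odhkm", 2], "group_1": ["1fj7x4dk", 4]}

print(resolve_model("", "friedrich_3", 3, presets))          # preset run id and class
print(resolve_model("3mzlr9u3", "friedrich_3", 3, presets))  # explicit run id, class 3
```

With an empty run id the user's **class_index** is replaced by the preset one, which is why it is ignored for option 2.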



---



# Instructions for other generation parameters
 

**test_num_samples**: 
*   number of samples to generate from the loaded model.

*   If the model is unconditional, it will generate **[test_num_samples]** samples.

*   If the model is conditional, it will generate **[test_num_samples]** samples of class **[class_index]**.

**generateMeshHistory**: 
*   boolean that chooses whether to generate from a list of history checkpoints or only from the latest checkpoint.

*   During training, if *`history_checkpoint_frequency`* is set, the model saves a history checkpoint every fixed number of epochs (e.g. 1 checkpoint per 40 epochs). Because many history checkpoints may be saved, only **num_selected_checkpoint** of them are selected from the list. Enabling this option makes the model generate **[test_num_samples]** samples per selected history checkpoint.

**num_selected_checkpoint**: 

*   If **generateMeshHistory** is True, select **num_selected_checkpoint** checkpoints from the saved history checkpoint list.
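
As a rough sketch of how the checkpoints are picked, the snippet below spreads the selection evenly over the sorted epoch list, using the same rounding formula as the generation cell at the end of this notebook. The helper name `select_evenly_spaced` and the example epoch list are hypothetical, and returning the latest checkpoint when only one is requested is an assumption made here to avoid a division by zero.

```python
def select_evenly_spaced(epochs, num_selected):
    """Pick num_selected epochs spread evenly across the sorted epoch list."""
    epochs = sorted(epochs)
    num_selected = min(num_selected, len(epochs))
    if num_selected <= 0:
        return []
    if num_selected == 1:
        # Assumption: a single pick means the latest history checkpoint.
        return [epochs[-1]]
    last = len(epochs) - 1
    # Map i = 0..num_selected-1 onto 0..last and round to the nearest index.
    indices = [int(i / (num_selected - 1) * last + 0.5) for i in range(num_selected)]
    return [epochs[i] for i in indices]


# Hypothetical run that saved one history checkpoint every 40 epochs up to 400.
history_epochs = list(range(40, 401, 40))
print(select_evenly_spaced(history_epochs, 4))  # [40, 160, 280, 400]
```

The first and last checkpoints are always included when at least two are selected, so the output covers the whole training history.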

**export_location**:
*   the export location of generated 3D meshes in the Drive.

**browser_download**:
*   if **browser_download** is True, the meshes generated into **export_location** are also zipped into a single file, which the browser then downloads.
*   Note that **browser_download** is implemented with Colab's `files.download()`. The download will be very slow if the zip file is **> 300 MB** (1 mesh is around 4 MB).
*   Also, if **browser_download** is False, the last cell stops with `Exception: will not download files from the browser. Terminate here.` This is expected; the meshes are already saved in **export_location**.
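
To judge in advance whether a browser download is practical, the figures quoted above (about 4 MB per mesh, slow beyond 300 MB) can be turned into a quick estimate. This is a back-of-the-envelope sketch: the helper names are hypothetical, and the estimate ignores zip compression, so it should be read as an upper bound.

```python
MESH_SIZE_MB = 4        # rough size of one mesh, as quoted above
SLOW_DOWNLOAD_MB = 300  # zip size above which the download becomes very slow


def estimate_export_mb(test_num_samples, generate_history=False, num_selected_checkpoint=1):
    """Estimate the uncompressed size in MB of all meshes one run will produce."""
    # With history enabled, test_num_samples meshes are generated per selected checkpoint.
    checkpoints = num_selected_checkpoint if generate_history else 1
    return test_num_samples * checkpoints * MESH_SIZE_MB


def download_is_likely_slow(size_mb):
    """True if the estimated size exceeds the slow-download threshold."""
    return size_mb > SLOW_DOWNLOAD_MB


print(estimate_export_mb(10))                       # 40
print(estimate_export_mb(25, True, 4))              # 400
print(download_is_likely_slow(estimate_export_mb(25, True, 4)))  # True
```

If the estimate is well above the threshold, leaving **browser_download** off and copying the files from **export_location** in Drive is probably the more practical route.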


In [None]:
#@title specify trained model
#@markdown - choose which option to specify trained model

#@markdown - select class index for conditional model (will be ignored for choice 2 or if the model is unconditional)
class_index = 3 #@param {type:"integer"}

#@markdown Choice 1: wandb run id used to select the checkpoint file (leave empty if you want to use the preset models of choice 2)
#@markdown - set `resume_id` to the id of the trained run (for example: `u9imsvva`)
#@markdown - The run id is shown in the training notebook in the cell with `wandb.init()`, or you can find it in the wandb web UI.
resume_id = "3mzlr9u3" #@param {type:"string"}

#@markdown Choice 2: generate tree from preset models based on class (please make sure resume_id is empty)
#@markdown - Select tree class: load the preset run_id and class_index based on the selected_tree_class 
selected_tree_class = "mustard_reaching_1"    #@param ['double_trunk_1', 'double_trunk_2', 'formal_upright_1', 'formal_upright_2', 'friedrich_1', 'friedrich_2', 'friedrich_3', 'group_1', 'group_2', 'informal_upright_1', 'informal_upright_2', 'leaning_1', 'leaning_2', 'mustard_reaching_1', 'mustard_sapling_1', 'mustard_sapling_2', 'pn_banyan_1', 'pn_maple_1', 'pn_old_1', 'pn_pine_1', 'pn_pine_2', 'pn_pine_3', 'pn_pine_4', 'pn_pine_5', 'pn_pine_6', 'pn_tall_straight', 'pn_tall_straight_old', 'raft_1', 'raft_2', 'semi_cascade_1', 'semi_cascade_2', 'sept_chen_lin_1', 'sept_chen_lin_2', 'sept_chen_lin_3', 'sept_constable_1', 'sept_friedrich_4', 'sept_friedrich_5', 'sept_holten_a', 'sept_holten_b', 'sept_holten_c', 'sept_holten_d', 'sept_holten_e', 'sept_holten_f', 'sept_holten_g', 'sept_holten_h', 'sept_holten_i', 'sept_holten_j', 'sept_holten_k', 'sept_holten_l', 'sept_holten_m', 'sept_holten_n', 'sept_holten_o', 'sept_holten_p', 'sept_holten_q', 'sept_holten_r', 'sept_holten_s', 'sept_holten_t', 'sept_holten_u', 'sept_holten_v', 'sept_holten_w', 'sept_holten_x', 'sept_holten_y', 'sept_holten_z', 'sept_mondrian_1', 'sept_mondrian_2', 'sept_mondrian_3', 'sept_schiele_1', 'sept_schiele_2', 'sept_schiele_3', 'windswept_1', 'windswept_2', 'zan_gentlemen_1', 'zan_gentlemen_2', 'zan_gentlemen_3', 'zan_gentlemen_4', 'zan_gentlemen_5']

preset_models = {
'double_trunk_1': ['vtcf6k3t', 0],
'double_trunk_2': ['vtcf6k3t', 0],
'formal_upright_1': ['vtcf6k3t', 0],
'formal_upright_2': ['vtcf6k3t', 0],
'friedrich_1': ['vtcf6k3t', 0],
'friedrich_2': ['vtcf6k3t', 0],
'friedrich_3': ['1v3odhkm', 2],
'group_1': ['1fj7x4dk', 4],
'group_2': ['1fj7x4dk', 1],
'informal_upright_1': ['1fj7x4dk', 2],
'informal_upright_2': ['vtcf6k3t', 0],
'leaning_1': ['vtcf6k3t', 0],
'leaning_2': ['vtcf6k3t', 2],
'mustard_reaching_1': ['29k3qjns', 4],
'mustard_sapling_1': ['vtcf6k3t', 0],
'mustard_sapling_2': ['1fj7x4dk', 3],
'pn_banyan_1': ['1v3odhkm', 3],
'pn_maple_1': ['1fj7x4dk', 0],
'pn_old_1': ['vtcf6k3t', 0],
'pn_pine_1': ['29k3qjns', 2],
'pn_pine_2': ['1v3odhkm', 0],
'pn_pine_3': ['29k3qjns', 0],
'pn_pine_4': ['29k3qjns', 1],
'pn_pine_5': ['vtcf6k3t', 0],
'pn_pine_6': ['vtcf6k3t', 4],
'pn_tall_straight': ['29k3qjns', 3],
'pn_tall_straight_old': ['1v3odhkm', 4],
'raft_1': ['vtcf6k3t', 0],
'raft_2': ['1v3odhkm', 1],
'semi_cascade_1': ['vtcf6k3t', 3],
'semi_cascade_2': ['vtcf6k3t', 1],
'sept_chen_lin_1': ['vtcf6k3t', 0],
'sept_chen_lin_2': ['vtcf6k3t', 0],
'sept_chen_lin_3': ['vtcf6k3t', 0],
'sept_constable_1': ['vtcf6k3t', 0],
'sept_friedrich_4': ['vtcf6k3t', 0],
'sept_friedrich_5': ['vtcf6k3t', 0],
'sept_holten_a': ['vtcf6k3t', 0],
'sept_holten_b': ['vtcf6k3t', 0],
'sept_holten_c': ['vtcf6k3t', 0],
'sept_holten_d': ['vtcf6k3t', 0],
'sept_holten_e': ['vtcf6k3t', 0],
'sept_holten_f': ['vtcf6k3t', 0],
'sept_holten_g': ['vtcf6k3t', 0],
'sept_holten_h': ['vtcf6k3t', 0],
'sept_holten_i': ['vtcf6k3t', 0],
'sept_holten_j': ['vtcf6k3t', 0],
'sept_holten_k': ['vtcf6k3t', 0],
'sept_holten_l': ['vtcf6k3t', 0],
'sept_holten_m': ['vtcf6k3t', 0],
'sept_holten_n': ['vtcf6k3t', 0],
'sept_holten_o': ['vtcf6k3t', 0],
'sept_holten_p': ['vtcf6k3t', 0],
'sept_holten_q': ['vtcf6k3t', 0],
'sept_holten_r': ['vtcf6k3t', 0],
'sept_holten_s': ['vtcf6k3t', 0],
'sept_holten_t': ['vtcf6k3t', 0],
'sept_holten_u': ['vtcf6k3t', 0],
'sept_holten_v': ['vtcf6k3t', 0],
'sept_holten_w': ['vtcf6k3t', 0],
'sept_holten_x': ['vtcf6k3t', 0],
'sept_holten_y': ['vtcf6k3t', 0],
'sept_holten_z': ['vtcf6k3t', 0],
'sept_mondrian_1': ['vtcf6k3t', 0],
'sept_mondrian_2': ['vtcf6k3t', 0],
'sept_mondrian_3': ['vtcf6k3t', 0],
'sept_schiele_1': ['vtcf6k3t', 0],
'sept_schiele_2': ['vtcf6k3t', 0],
'sept_schiele_3': ['vtcf6k3t', 0],
'windswept_1': ['vtcf6k3t', 0],
'windswept_2': ['vtcf6k3t', 0],
'zan_gentlemen_1': ['vtcf6k3t', 0],
'zan_gentlemen_2': ['vtcf6k3t', 0],
'zan_gentlemen_3': ['vtcf6k3t', 0],
'zan_gentlemen_4': ['vtcf6k3t', 0],
'zan_gentlemen_5': ['vtcf6k3t', 0]}

if resume_id == "":
    #find preset_model_id
    preset_model_list = preset_models[selected_tree_class]
    preset_model_id, class_index = preset_model_list
    resume_id = preset_model_id
    print("selected_tree_class:", selected_tree_class)
    print("preset_model_id:", preset_model_id)
else:
    print("resume_id:", resume_id)

project_name = "tree-gan"
print("class_index:", class_index)

In [None]:
#@title other generate parameters

#@markdown number of samples to generate from the loaded model.
test_num_samples = 10    #@param {type:"integer"}

#@markdown OPTIONAL: boolean to choose whether to generate from list of history checkpoints or just the latest checkpoint.
generateMeshHistory = False #@param {type:"boolean"}
#@markdown OPTIONAL: If generateMeshHistory enabled, select "num_selected_checkpoint" checkpoints in the saved history checkpoint list.
num_selected_checkpoint = 4 #@param {type:"integer"}

#@markdown Enter export location (folder/directory).   
export_location = f"/content/drive/MyDrive/IRCMS_GAN_collaborative_database/Experiments/exportObjects/{resume_id}_{str(class_index)}" #@param {type:"string"}

#@markdown whether to zip file and download with browser
browser_download = False #@param {type:"boolean"}

print("test_num_samples:", test_num_samples)
print("export_location:", export_location)

In [None]:
from argparse import Namespace, ArgumentParser
#@markdown Install wandb and log in
entity="bugan"
rev_number = None
if resume_id:
    !pip install wandb
    output.clear()
    #find wandb API key file to auto login
    import wandb
    wandb_drive_netrc_path = Path("drive/My Drive/colab/.netrc")
    wandb_local_netrc_path = Path("/root/.netrc")
    if wandb_drive_netrc_path.exists():
        import shutil

        print("Wandb .netrc file found, will use that to log in.")
        shutil.copy(wandb_drive_netrc_path, wandb_local_netrc_path)
    else:
        print(
            f"Wandb config not found at {wandb_drive_netrc_path}.\n"
            f"Using manual login.\n\n"
            f"To use auto login in the future, finish the manual login first and then run:\n\n"
            f"\t!mkdir -p '{wandb_drive_netrc_path.parent}'\n"
            f"\t!cp {wandb_local_netrc_path} '{wandb_drive_netrc_path}'\n\n"
            f"Then that file will be used to login next time.\n"
        )

    !wandb login

    # read information (run config, etc.) stored online;
    # all config will be replaced by the one stored in wandb
    api = wandb.Api()
    run = api.run(f"{entity}/{project_name}/{resume_id}")
    config = Namespace(**run.config)
    # load selected_model and rev_number from the config
    if hasattr(config, "selected_model"):
        selected_model = config.selected_model
    if hasattr(config, "rev_number"):
        rev_number = config.rev_number

    output.clear()
    print("run id: " + str(run.id))
    print("run name: " + str(run.name))
    print("ok!")

# To just load the model, no edits should be required in any cells below.

In [None]:
#@markdown package and functions

import os
import sys
import subprocess
import torch
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
from pathlib import Path
os.environ["WANDB_MODE"] = "dryrun"

%pip install --upgrade git+https://github.com/buganart/BUGAN.git#egg=bugan

run_path = "/content/drive/My Drive/IRCMS_GAN_collaborative_database/Experiments/"

from bugan.trainPL import _get_models
from bugan.functionsPL import netarray2mesh

def install_bugan_package(rev_number=None):
    # pin the package to the given git revision if provided, otherwise install the latest
    revision = f"@{rev_number}" if rev_number else ""
    subprocess.check_call(
        [
            sys.executable,
            "-m",
            "pip",
            "install",
            "--upgrade",
            f"git+https://github.com/buganart/BUGAN.git{revision}#egg=bugan",
        ]
    )

# load model
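def generateFromCheckpoint(selected_model, ckpt_filePath, class_index=None, num_samples=1, package_rev_number=None, output_file_name_dict=None):
    # avoid a shared mutable default argument: the dict is modified below
    if output_file_name_dict is None:
        output_file_name_dict = {}
    MODEL_CLASS = _get_models(selected_model)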
    try:
        # restore bugan version
        install_bugan_package(rev_number=package_rev_number)
        model = MODEL_CLASS.load_from_checkpoint(ckpt_filePath)
    except Exception as e:
        print(e)
        # try newest bugan version
        install_bugan_package()
        model = MODEL_CLASS.load_from_checkpoint(ckpt_filePath)

    model = model.eval()#.to(device)
    try:
        #assume conditional model
        # use the num_samples argument rather than the global test_num_samples
        sample_trees = model.generate_tree(c=class_index, num_trees=num_samples)
        output_file_name_dict["class"] = class_index
    except Exception as e:
        print(e)
        print("generate with class label does not work. Now generate without label")
        #assume unconditional model
        sample_trees = model.generate_tree(num_trees=num_samples)
    
    save_filename_header = ""
    for k,v in output_file_name_dict.items():
        save_filename_header = save_filename_header + f"_{str(k)}_{str(v)}"

    for n in range(num_samples):
        sample_tree_array = sample_trees[n]
        voxelmesh = netarray2mesh(sample_tree_array)
        save_filename = f"sample_{n}{save_filename_header}.obj"
        if export_location:
            export_path = export_location / save_filename
            voxelmesh.export(file_obj=export_path, file_type="obj")
        #if zip, save also to temp_location 
        if browser_download:
            temp_path = temp_location / save_filename
            voxelmesh.export(file_obj=temp_path, file_type="obj")

output.clear()
print('ok!')

In [None]:
#@markdown generate samples
temp_location = Path("/tmp/generated/")
temp_location.mkdir(parents=True, exist_ok=True)
# copy files to export location
if export_location:
    # create directory
    export_location = Path(export_location)
    export_location.mkdir(parents=True, exist_ok=True)

if generateMeshHistory:
    # find necessary checkpoint file
    epoch_list = []
    epoch_file_dict = {}
    for file in run.files():
        filename = file.name
        if ".ckpt" not in filename:
            continue
        if (filename == "checkpoint.ckpt") or (filename == "checkpoint_prev.ckpt"):
            continue
        file_epoch = str((filename.split("_")[1]).split(".")[0])
        epoch_list.append(int(file_epoch))
        epoch_file_dict[file_epoch] = file

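    epoch_list = sorted(epoch_list)
    num_selected_checkpoint = min(num_selected_checkpoint, len(epoch_list))
    if num_selected_checkpoint <= 1:
        # a single checkpoint cannot be spread evenly (and would divide by zero below),
        # so fall back to the latest history checkpoint, or to nothing if none exist
        selected_epoch_index = [len(epoch_list) - 1] if num_selected_checkpoint == 1 else []
    else:
        selected_epoch_index = [
            int(i / (num_selected_checkpoint - 1) * (len(epoch_list) - 1) + 0.5)
            for i in range(num_selected_checkpoint)
        ]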

    #download checkpoint and generate mesh
    for checkpoint_epoch_index in selected_epoch_index:
        file_epoch = str(epoch_list[checkpoint_epoch_index])
        print(f"generate mesh for epoch {file_epoch}......")
        try:
            ckpt_file = epoch_file_dict[file_epoch]
            ckpt_file.download(replace=True)
            generateFromCheckpoint(selected_model, ckpt_file.name, class_index, test_num_samples, rev_number, {"epoch":file_epoch})
        except Exception as e:
            print(e)
            print(f"generate mesh for epoch {file_epoch} FAILED !!!")
else:
    try:
        ckpt_file = run.file("checkpoint.ckpt").download(replace=True)
        generateFromCheckpoint(selected_model, ckpt_file.name, class_index, test_num_samples, rev_number, {})
    except Exception as e:
        print(e)
        print("loading from checkpoint.ckpt failed. Try checkpoint_prev.ckpt")
        ckpt_file = run.file("checkpoint_prev.ckpt").download(replace=True)
        generateFromCheckpoint(selected_model, ckpt_file.name, class_index, test_num_samples, rev_number, {})

#zip files in the directory
if not browser_download:
    raise Exception("will not download files from the browser. Terminate here.")

!zip -r /tmp/file.zip /tmp/generated/
from google.colab import files
files.download("/tmp/file.zip")

print("ok!")