<a href="https://colab.research.google.com/github/buganart/BUGAN/blob/master/notebook_util/generate_tree(drive).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Before starting, please save a copy of this notebook to your Drive via `File -> Save a copy in Drive`.

Also check that the Google account whose Drive you mount has access to the public shared BUGAN folder:
https://drive.google.com/drive/folders/1wYhB81kdrVaKf2DAgNt7tDupgg6JyJ12?usp=sharing

Please make sure the "shared_BUGAN_folder" linked above appears in the "Shared with me" tab.

In [None]:
#@markdown Mount Google Drive.
from google.colab import output
from google.colab import drive
drive.mount('/content/drive')

# Check that Drive is mounted (accept both the "MyDrive" and "My Drive" folder names used in this notebook)
from pathlib import Path
if not (Path("/content/drive/MyDrive/").exists() or Path("/content/drive/My Drive/").exists()):
    print(
        "Drive (My Drive) is not mounted!\n\n"
    )

# Check that the shared BUGAN folder is reachable
sharedDrive = Path("/content/drive/.shortcut-targets-by-id/1wYhB81kdrVaKf2DAgNt7tDupgg6JyJ12/shared_BUGAN_folder/")
if not sharedDrive.exists():
    print(
        "Shortcut to our shared drive folder doesn't exist.\n\n"
        "\t1. Please click the drive link above.\n"
        "\t2. Make sure the \"shared_BUGAN_folder\" appears in the \"Shared with me\" tab.\n"
    )


# Description

This notebook generates 3D mesh objects from runs trained in the wandb project "bugan/tree-gan". To train models, see [train.ipynb](https://github.com/buganart/BUGAN/blob/master/notebook_util/train.ipynb).

The user specifies **one trained model** (through a run_id or a checkpoint file produced by the BUGAN training notebook linked above) and the **number of samples to generate**; the notebook then writes the generated samples to the **export_location** specified in the form panel.



---



# Instructions for specifying the trained model
The trained model is loaded from a saved checkpoint.

To specify the trained model:

> set **selected_tree_category**; one of the preset models for that category is then chosen at random, and its checkpoint (`<tree_class>.ckpt`) is loaded from the shared BUGAN folder.
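In the model-selection cell, a category maps to a list of preset model names, one name is drawn at random, and `.ckpt` is appended to form the checkpoint file name. The sketch below reproduces that lookup with a hypothetical two-entry excerpt of `class_categories_dict` (`CLASS_CATEGORIES` and `pick_preset_checkpoint` are illustrative names, and `random.choice` stands in for the notebook's `numpy.random.randint` indexing).

```python
import random

# Hypothetical excerpt of class_categories_dict from the model-selection cell
CLASS_CATEGORIES = {
    "leaning": ["leaning_1", "leaning_2"],
    "maple": ["pn_maple_1"],
}

def pick_preset_checkpoint(category, rng=random):
    """Return (tree_class, checkpoint_file_name) for one randomly chosen preset of `category`."""
    tree_class = rng.choice(CLASS_CATEGORIES[category])
    return tree_class, f"{tree_class}.ckpt"

tree_class, ckpt_name = pick_preset_checkpoint("maple")
print(tree_class, ckpt_name)  # pn_maple_1 pn_maple_1.ckpt
```

Categories with a single preset (such as `maple` here) always resolve to the same checkpoint; passing a seeded `random.Random` as `rng` makes the choice reproducible for categories with several presets.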


<!-- 1.   set **run_id** and **class_index** directly, then the saved checkpoint of the run will be downloaded -->

<!-- 3.   specify **checkpoint_path** and **class_index** -->

Note that **class_index** is required if the loaded model is conditional. If the model is unconditional, setting **class_index** does not break generation: the notebook falls back to generating without a class label.
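The conditional/unconditional fallback can be sketched with two hypothetical stand-in models (`DummyConditionalModel` and `DummyUnconditionalModel` are illustrations, not BUGAN classes). As in `generateFromCheckpoint`, generation is first attempted with a class label and retried without one if that call fails. The sketch catches only `TypeError`, which is what the stand-in raises, whereas the notebook catches any `Exception`.

```python
class DummyConditionalModel:
    """Stand-in for a conditional model: generate_tree() accepts a class label `c`."""
    def generate_tree(self, c=None, num_trees=1):
        return [f"tree_{i}_class_{c}" for i in range(num_trees)]

class DummyUnconditionalModel:
    """Stand-in for an unconditional model: generate_tree() has no `c` parameter."""
    def generate_tree(self, num_trees=1):
        return [f"tree_{i}" for i in range(num_trees)]

def generate_with_fallback(model, class_index, num_samples):
    try:
        # First assume the model is conditional and pass the class label
        return model.generate_tree(c=class_index, num_trees=num_samples)
    except TypeError:
        # The unconditional stand-in rejects the `c` keyword, so retry without a label
        return model.generate_tree(num_trees=num_samples)

print(generate_with_fallback(DummyConditionalModel(), 3, 2))    # ['tree_0_class_3', 'tree_1_class_3']
print(generate_with_fallback(DummyUnconditionalModel(), 3, 2))  # ['tree_0', 'tree_1']
```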



---



# Instructions for other generation parameters

**test_num_samples**: 
*   number of samples to generate from the loaded model.

*   If the model is unconditional, it will generate **[test_num_samples]** samples.

*   If the model is conditional, it will generate **[test_num_samples]** samples of class **[class_index]**.

**generateMeshHistory** (optional):
*   Boolean that chooses whether to generate from a list of history checkpoints or only from the latest checkpoint.

*   During training, if *`history_checkpoint_frequency`* is set, the model saves a history checkpoint at a fixed epoch interval (e.g. one checkpoint every 40 epochs). Because this produces many history checkpoints, only **num_selected_checkpoint** of them are selected from the list. Enabling this option makes the model generate **[test_num_samples]** samples per selected history checkpoint.

**num_selected_checkpoint** (optional):

*   If **generateMeshHistory** is True, the number of checkpoints to select from the saved history checkpoint list.
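The text does not say how the **num_selected_checkpoint** entries are picked from the history list. The sketch below assumes evenly spaced picks that include the first and last checkpoint; `select_history_checkpoints` and the `epoch_*.ckpt` names are hypothetical, not BUGAN's actual selection logic.

```python
def select_history_checkpoints(checkpoints, num_selected):
    """Pick `num_selected` evenly spaced entries (first and last included) from a sorted list."""
    n = len(checkpoints)
    if num_selected <= 0:
        return []
    if num_selected >= n:
        return list(checkpoints)
    if num_selected == 1:
        # Only one checkpoint requested: take the most recent one
        return [checkpoints[-1]]
    step = (n - 1) / (num_selected - 1)
    return [checkpoints[round(i * step)] for i in range(num_selected)]

# One history checkpoint every 40 epochs over 400 epochs -> 10 checkpoints
history = [f"epoch_{e}.ckpt" for e in range(40, 401, 40)]
print(select_history_checkpoints(history, 4))
# ['epoch_40.ckpt', 'epoch_160.ckpt', 'epoch_280.ckpt', 'epoch_400.ckpt']
```

With **test_num_samples** = 10 and 4 selected checkpoints, this setup would produce 10 × 4 = 40 meshes in total.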

**export_location**:
*   The Drive folder where the generated 3D meshes are exported.
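Each mesh is written as an OBJ file inside **export_location**. In `generateFromCheckpoint`, the file name is `sample_{n}` followed by `_{key}_{value}` for every entry of `output_file_name_dict` (for a conditional model the notebook adds `class`). A minimal sketch of that naming scheme, with `mesh_export_path` as a hypothetical helper name:

```python
from pathlib import Path

def mesh_export_path(export_location, n, output_file_name_dict=None):
    """Build <export_location>/sample_<n>[_<key>_<value>...].obj as generateFromCheckpoint does."""
    header = "".join(f"_{k}_{v}" for k, v in (output_file_name_dict or {}).items())
    return Path(export_location) / f"sample_{n}{header}.obj"

print(mesh_export_path("/content/drive/MyDrive/exportTree/maple", 0, {"class": 3}).as_posix())
# /content/drive/MyDrive/exportTree/maple/sample_0_class_3.obj
```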

**browser_download**:
*   If **browser_download** is True, after the meshes are generated into **export_location**, they are also zipped into a single file that the browser then downloads.
*   **browser_download** is implemented with Colab's `files.download()`, so the download becomes very slow once the zip file exceeds **300 MB** (one mesh is around 4 MB).
*   If **browser_download** is False, the message `Exception: will not download files from the browser. Terminate here.` is expected: the cell stops deliberately, and the meshes are already saved in **export_location**.
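Using the figures above (about 4 MB per mesh, slow downloads above 300 MB), roughly 75 meshes is where browser downloads start to drag. The notebook zips with the shell command `zip -r`; the sketch below swaps in Python's standard `shutil.make_archive` to show the same step, and the helper names (`estimated_raw_mb`, `download_will_be_slow`, `zip_generated`, `demo`) are hypothetical. The size estimate is for uncompressed meshes, so the actual zip is likely smaller.

```python
import shutil
import tempfile
import zipfile
from pathlib import Path

APPROX_MESH_MB = 4       # rough size of one mesh, as quoted above
SLOW_DOWNLOAD_MB = 300   # above this, files.download() becomes very slow

def estimated_raw_mb(num_meshes):
    """Rough uncompressed size (MB) of `num_meshes` exported meshes."""
    return num_meshes * APPROX_MESH_MB

def download_will_be_slow(num_meshes):
    return estimated_raw_mb(num_meshes) > SLOW_DOWNLOAD_MB

def zip_generated(src_dir, zip_base):
    """Zip every file under `src_dir` into `<zip_base>.zip` and return the archive path."""
    return shutil.make_archive(str(zip_base), "zip", root_dir=str(src_dir))

def demo():
    """Zip two tiny placeholder OBJ files in a temporary directory and list the archive members."""
    with tempfile.TemporaryDirectory() as tmp:
        src = Path(tmp) / "generated"
        src.mkdir()
        for n in range(2):
            (src / f"sample_{n}.obj").write_text("v 0 0 0\n")
        archive = zip_generated(src, Path(tmp) / "file")
        with zipfile.ZipFile(archive) as zf:
            return sorted(zf.namelist())
```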


<img src="https://drive.google.com/uc?id=1KId8CncC6Ze_BsiUP-YPl-jhjSFFNCQy"></img>

In [None]:
#@title #Specify Trained Model
#@markdown - choose which option to specify trained model

# #@markdown Choice 1: wandb run_id and to select checkpoint file (leave empty if you want to use preset model option 2)
# #@markdown - set `"run_id"` to resume a run (for example: `u9imsvva`)
# #@markdown - The id of the current run is shown below in the cell with `wandb.init()` or you can find it in the wandb web UI.
# resume_id = "3mzlr9u3" #@param {type:"string"}
# #@markdown - select class index for conditional model (will be ignored for choice 2 or if the model is unconditional)
# class_index = 3 #@param {type:"integer"}
#(please make sure resume_id is empty)
#@markdown Generate trees from preset models based on the tree category
#@markdown - Select a tree category: one preset model (checkpoint) of the selected_tree_category is chosen at random
selected_tree_category = "zan_gentlemen"    #@param ['double_trunk', 'formal_upright', 'friedrich', 'group', 'informal_upright', 'leaning', 'mustard_reaching', 'mustard_sapling', 'banyan', 'maple', 'old', 'pine', 'tall_straight', 'raft', 'semi_cascade', 'chen_lin', 'constable', 'holten', 'mondrian', 'schiele', 'windswept', 'zan_gentlemen']

class_categories_dict = {'double_trunk': ['double_trunk_1', 'double_trunk_2'],
                         'formal_upright': ['formal_upright_1', 'formal_upright_2'],
                         'friedrich': ['friedrich_1','friedrich_2','friedrich_3','sept_friedrich_4', 'sept_friedrich_5'],
                         'group': ['group_1', 'group_2'],
                         'informal_upright': ['informal_upright_1', 'double_trunk_2'],
                         'leaning': ['leaning_1', 'leaning_2'],
                         'mustard_reaching': ['mustard_reaching_1'],
                         'mustard_sapling': ['mustard_sapling_1', 'mustard_sapling_2'],
                         'banyan': ['pn_banyan_1'],
                         'maple': ['pn_maple_1'],
                         'old': ['pn_old_1'],
                         'pine': ['pn_pine_1', 'pn_pine_2', 'pn_pine_3','pn_pine_4','pn_pine_5','pn_pine_6'],
                         'tall_straight': ['pn_tall_straight', 'pn_tall_straight_old'],
                         'raft': ['raft_1', 'raft_2'],
                         'semi_cascade': ['semi_cascade_1', 'semi_cascade_2'],
                         'chen_lin': ['sept_chen_lin_1', 'sept_chen_lin_2', 'sept_chen_lin_3'],
                         'constable': ['sept_constable_1'],
                         'holten': ['sept_holten_a', 'sept_holten_b', 'sept_holten_c','sept_holten_d','sept_holten_e','sept_holten_f', 'sept_holten_g', 'sept_holten_h' , 'sept_holten_i', 'sept_holten_j', 'sept_holten_k', 'sept_holten_l', 'sept_holten_m', 'sept_holten_n', 'sept_holten_o', 'sept_holten_p', 'sept_holten_q', 'sept_holten_r', 'sept_holten_s', 'sept_holten_t', 'sept_holten_u', 'sept_holten_v', 'sept_holten_w', 'sept_holten_x', 'sept_holten_y', 'sept_holten_z'],
                         'mondrian': ['sept_mondrian_1', 'sept_mondrian_2', 'sept_mondrian_3'],
                         'schiele': ['sept_schiele_1', 'sept_schiele_2', 'sept_schiele_3'],
                         'windswept': ['windswept_1', 'windswept_2'],
                         'zan_gentlemen': ['zan_gentlemen_5'],#['zan_gentlemen_1', 'zan_gentlemen_2', 'zan_gentlemen_3', 'zan_gentlemen_4', 'zan_gentlemen_5'],
                         }


from numpy.random import randint
#from selected_tree_category, find tree_class_list
tree_class_list = class_categories_dict[selected_tree_category]
selected_tree_class_index = randint(len(tree_class_list))
selected_tree_class = tree_class_list[selected_tree_class_index]
#from selected_tree_class, find preset_model
preset_model_checkpoint = str(selected_tree_class) + ".ckpt"
print("selected_tree_class:", selected_tree_class)
print("preset_model_checkpoint:", preset_model_checkpoint)

project_name = "tree-gan"
selected_model = "VAEGAN"

In [None]:
#@title ##Other generation parameters

#@markdown number of samples to generate from the loaded model.
test_num_samples = 10    #@param {type:"integer"}

# #@markdown OPTIONAL: boolean to choose whether to generate from list of history checkpoints or just the latest checkpoint.
# generateMeshHistory = False #@param {type:"boolean"}
# #@markdown OPTIONAL: If generateMeshHistory enabled, select "num_selected_checkpoint" checkpoints in the saved history checkpoint list.
# num_selected_checkpoint = 4 #@param {type:"integer"}

#@markdown Enter export location (folder/directory).   
export_location = f"/content/drive/MyDrive/exportTree/{selected_tree_category}" #@param {type:"string"}

#@markdown Whether to zip the generated meshes and download them through the browser
browser_download = False #@param {type:"boolean"}

print("test_num_samples:", test_num_samples)
print("export_location:", export_location)

# To just load the model, no edits should be required in any cells below.

In [None]:
#@markdown package and functions

import os
import sys
import subprocess
import torch
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
from pathlib import Path
os.environ["WANDB_MODE"] = "dryrun"

%pip install --upgrade git+https://github.com/buganart/BUGAN.git#egg=bugan

run_path = "/content/drive/My Drive/IRCMS_GAN_collaborative_database/Experiments/"

from bugan.trainPL import _get_models
from bugan.functionsPL import netarray2mesh

def install_bugan_package(rev_number=None):
    if rev_number:
        subprocess.check_call(
            [
                sys.executable,
                "-m",
                "pip",
                "install",
                "--upgrade",
                f"git+https://github.com/buganart/BUGAN.git@{rev_number}#egg=bugan",
            ]
        )
    else:
        subprocess.check_call(
            [
                sys.executable,
                "-m",
                "pip",
                "install",
                "--upgrade",
                "git+https://github.com/buganart/BUGAN.git#egg=bugan",
            ]
        )

# load the model from a checkpoint and export the generated meshes
# (uses the globals export_location, browser_download and temp_location set in other cells)
def generateFromCheckpoint(selected_model, ckpt_filePath, class_index=None, num_samples=1, package_rev_number=None, output_file_name_dict=None):
    # a mutable default ({}) would be shared across calls, and the dict is modified below
    if output_file_name_dict is None:
        output_file_name_dict = {}
    MODEL_CLASS = _get_models(selected_model)
    try:
        # restore the bugan version the checkpoint was trained with
        install_bugan_package(rev_number=package_rev_number)
        model = MODEL_CLASS.load_from_checkpoint(ckpt_filePath)
    except Exception as e:
        print(e)
        # fall back to the newest bugan version
        install_bugan_package()
        model = MODEL_CLASS.load_from_checkpoint(ckpt_filePath)

    model = model.eval()#.to(device)
    try:
        # assume a conditional model
        sample_trees = model.generate_tree(c=class_index, num_trees=num_samples)
        output_file_name_dict["class"] = class_index
    except Exception as e:
        print(e)
        print("Generating with a class label failed. Generating without a label instead.")
        # assume an unconditional model
        sample_trees = model.generate_tree(num_trees=num_samples)

    # e.g. {"class": 3} -> "_class_3"
    save_filename_header = "".join(f"_{k}_{v}" for k, v in output_file_name_dict.items())

    for n in range(num_samples):
        sample_tree_array = sample_trees[n]
        voxelmesh = netarray2mesh(sample_tree_array)
        save_filename = f"sample_{n}{save_filename_header}.obj"
        if export_location:
            export_path = export_location / save_filename
            voxelmesh.export(file_obj=export_path, file_type="obj")
        # if zipping for browser download, also save to temp_location
        if browser_download:
            temp_path = temp_location / save_filename
            voxelmesh.export(file_obj=temp_path, file_type="obj")

output.clear()
print('ok!')
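`install_bugan_package` builds two nearly identical pip commands; the only difference is the optional `@<revision>` suffix on the Git URL. The sketch below factors that into hypothetical helpers (`bugan_pip_spec`, `pip_install_command`) that only build the command and do not install anything; `"abc123"` is a placeholder revision, not a real BUGAN commit.

```python
import sys

BUGAN_REPO = "git+https://github.com/buganart/BUGAN.git"

def bugan_pip_spec(rev_number=None):
    """pip requirement for BUGAN, optionally pinned to a Git revision (commit, tag or branch)."""
    suffix = f"@{rev_number}" if rev_number else ""
    return f"{BUGAN_REPO}{suffix}#egg=bugan"

def pip_install_command(rev_number=None):
    """Argument list equivalent to the one install_bugan_package passes to subprocess.check_call."""
    return [sys.executable, "-m", "pip", "install", "--upgrade", bugan_pip_spec(rev_number)]

print(bugan_pip_spec())          # git+https://github.com/buganart/BUGAN.git#egg=bugan
print(bugan_pip_spec("abc123"))  # git+https://github.com/buganart/BUGAN.git@abc123#egg=bugan
```

Passing the result to `subprocess.check_call(pip_install_command(rev))` would perform the same installation as `install_bugan_package(rev_number=rev)`.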

In [None]:
#@markdown generate samples
temp_location = Path("/tmp/generated/")
temp_location.mkdir(parents=True, exist_ok=True)
# create the export directory in Drive (if one was given)
if export_location:
    export_location = Path(export_location)
    export_location.mkdir(parents=True, exist_ok=True)

# preset checkpoints live in the shared BUGAN folder
sharedFolder = sharedDrive / "tree/checkpoint"
ckpt_file = sharedFolder / preset_model_checkpoint
generateFromCheckpoint(
    selected_model,
    ckpt_file,
    class_index=None,
    num_samples=test_num_samples,
    package_rev_number=None,
    output_file_name_dict={},
)

# stop here unless the meshes should also be zipped and downloaded through the browser
if not browser_download:
    raise Exception("will not download files from the browser. Terminate here.")

!zip -r /tmp/file.zip /tmp/generated/
from google.colab import files
files.download("/tmp/file.zip")

print("ok!")