<a href="https://colab.research.google.com/github/buganart/BUGAN/blob/master/notebook_util/generate_handtool(drive).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Before starting, please save a copy of this notebook to your Drive via `File -> Save a copy in Drive`.

Also, check that the Google account whose Drive you mount has access to the shared public BUGAN handtool folder here:
https://drive.google.com/drive/folders/1wYhB81kdrVaKf2DAgNt7tDupgg6JyJ12?usp=sharing

Please make sure the "shared_BUGAN_handtool_folder" above appears in the "Shared with me" tab. (The drive-mount cell below looks for this folder under the name `shared_BUGAN_folder`.)

In [None]:
#@markdown Mount google drive.
from google.colab import output
from google.colab import drive
drive.mount('/content/drive')

# Check that the drive was mounted and that the shared folder is reachable.
# These checks only print a warning; later cells that search `sharedDrive`
# for checkpoints will find none if the folder is missing.
from pathlib import Path
if not Path("/content/drive/My Drive/").exists():
    print(
        "Drive (My Drive) not mounted! \n\n"
    )

# Shortcut path to the shared folder; the id matches the drive link above.
sharedDrive = Path("/content/drive/.shortcut-targets-by-id/1wYhB81kdrVaKf2DAgNt7tDupgg6JyJ12/shared_BUGAN_folder/")
if not sharedDrive.exists():
    print(
        "Shortcut to our shared drive folder doesn't exist.\n\n"
        "\t1. Please click the drive link above.\n"
        "\t2. Make sure the \"shared_BUGAN_folder\" appears in the \"Shared with me\" tab.\n"
    )

# Description

This notebook is used for generating 3D mesh objects based on trained runs in the wandb project "bugan/handtool-gan". For training models, please go to [train.ipynb](https://github.com/buganart/BUGAN/blob/master/notebook_util/train.ipynb). 

---



# Instructions for generating a handtool mesh
The trained model is loaded from a pretrained checkpoint stored in the shared drive above. One checkpoint is picked at random each time the load-model cell runs.

After running all the cells in the notebook, the final cell generates and displays one handtool mesh. A Download button is shown there so you can download the mesh through the browser.

The mesh is generated randomly from the model. To generate another handtool mesh, just rerun the last cell.

In [None]:
#@markdown ### Post process generated mesh
#@markdown whether to post process meshes
post_process = True #@param {type:"boolean"}
#@markdown remove clusters that have no points inside a sphere of the given radius.
#@markdown - the sphere is centered in the middle of the cubic voxel space.
#@markdown - clusters with at least 1 point inside the sphere are kept; clusters with no points inside are discarded.
#@markdown - For resolution=64, the largest sphere that fits inside the voxel cube has radius 32.
radius=28 #@param {type:"number"}
#@markdown remove clusters that have fewer than point_threshold points.
point_threshold = 50 #@param {type:"integer"}

# Imported here too so this cell does not rely on an import from an earlier cell.
from pathlib import Path

# Generated .obj files are written here before being offered for download.
temp_location = Path("/tmp/generated/")
temp_location.mkdir(parents=True, exist_ok=True)
# Counter incremented by the generate cell to give each sample its own file name.
mesh_index = 0

# To just load the model, no edits should be required in any cells below.

In [None]:
#@markdown package and functions
import numpy as np
from numpy.random import randint
import os
import sys
import subprocess
import torch
import ipywidgets
from google.colab import files
from IPython.display import display
# Prefer the first CUDA GPU when available.
# NOTE(review): no visible code moves the model to `device`; confirm whether GPU inference is intended.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
from pathlib import Path
# wandb "dryrun" mode — presumably keeps wandb offline so nothing is synced while loading/generating.
os.environ["WANDB_MODE"] = "dryrun"

# Install the latest bugan from GitHub before importing it below.
# NOTE(review): unpinned; `load_model` may reinstall bugan again when loading a checkpoint.
%pip install --upgrade git+https://github.com/buganart/BUGAN.git#egg=bugan

# NOTE(review): `run_path` is not referenced anywhere else in the visible notebook.
run_path = "/content/drive/My Drive/IRCMS_GAN_collaborative_database/Experiments/"

# Model-class lookup and the voxel-array -> mesh / clustering helpers used below.
from bugan.trainPL import _get_models
from bugan.functionsPL import netarray2mesh, eval_cluster

def install_bugan_package(rev_number=None):
    """Pip-install the bugan package from GitHub into the running interpreter.

    Args:
        rev_number: optional git revision placed after `@` in the pip VCS URL.
            When falsy (None or empty), no revision is given and pip installs
            the repository's default branch.

    Raises:
        subprocess.CalledProcessError: if pip exits with a non-zero status.
    """
    requirement = (
        f"git+https://github.com/buganart/BUGAN.git@{rev_number}#egg=bugan"
        if rev_number
        else "git+https://github.com/buganart/BUGAN.git#egg=bugan"
    )
    # Use the current interpreter's pip so the kernel's environment is updated.
    pip_command = [sys.executable, "-m", "pip", "install", "--upgrade"]
    subprocess.check_call(pip_command + [requirement])


def _get_models_from_fresh_import():
    """Re-import bugan from disk and return its `_get_models` helper.

    `pip install` only replaces files on disk. Modules already present in
    `sys.modules` keep running the code that was imported earlier. Dropping the
    cached `bugan` modules forces the next import to load the version that was
    just installed.
    """
    import importlib

    for module_name in list(sys.modules):
        if module_name == "bugan" or module_name.startswith("bugan."):
            del sys.modules[module_name]
    importlib.invalidate_caches()
    from bugan.trainPL import _get_models as fresh_get_models

    return fresh_get_models


def load_model(selected_model, ckpt_filePath, package_rev_number=None):
    """Install a bugan version and load a model checkpoint with it.

    First installs bugan at `package_rev_number` (latest when None) and tries
    to load the checkpoint with that version. If anything in that attempt
    raises, the error is printed, the newest bugan is installed, and loading
    is retried once.

    Args:
        selected_model: model name passed to bugan's `_get_models`
            (e.g. "VAEGAN").
        ckpt_filePath: path to the `.ckpt` checkpoint file.
        package_rev_number: optional git revision of bugan to install first.

    Returns:
        The loaded model, switched to eval mode.
    """
    try:
        # restore the requested bugan version
        install_bugan_package(rev_number=package_rev_number)
        # Resolve the model class only after installing, from a fresh import,
        # so the just-installed version is the one actually used.
        MODEL_CLASS = _get_models_from_fresh_import()(selected_model)
        model = MODEL_CLASS.load_from_checkpoint(ckpt_filePath)
    except Exception as e:
        print(e)
        # try newest bugan version
        install_bugan_package()
        MODEL_CLASS = _get_models_from_fresh_import()(selected_model)
        model = MODEL_CLASS.load_from_checkpoint(ckpt_filePath)

    model = model.eval()
    return model

# generate meshes from a loaded model
def generateFromCheckpoint(model, class_index=None, num_samples=1, return_all=False):
    """Sample voxel arrays from `model` and convert them to meshes.

    Args:
        model: a loaded bugan model exposing `generate_tree`.
        class_index: class label passed as `c=` to `generate_tree`. If that
            call raises (e.g. the model is unconditional), generation is
            retried without a label.
        num_samples: number of samples to generate; must be at least 1.
        return_all: if True, convert every sample and return a list of meshes.
            By default only the last sample is converted and its mesh is
            returned.

    Returns:
        The mesh built by `netarray2mesh` for the last sample, or a list of
        meshes (one per sample) when `return_all` is True.

    Raises:
        ValueError: if `num_samples` is less than 1.
    """
    if num_samples < 1:
        # Otherwise the result would never be assigned (UnboundLocalError).
        raise ValueError(f"num_samples must be at least 1, got {num_samples}")
    try:
        # assume conditional model
        sample_trees = model.generate_tree(c=class_index, num_trees=num_samples)
    except Exception as e:
        print(e)
        print("generate with class label does not work. Now generate without label")
        # assume unconditional model
        sample_trees = model.generate_tree(num_trees=num_samples)

    # When only the last mesh is returned, skip post-processing the others.
    indices = range(num_samples) if return_all else [num_samples - 1]
    voxelmeshes = []
    for n in indices:
        sample_tree_array = sample_trees[n]
        if post_process:
            sample_tree_array = post_process_array(sample_tree_array)
        voxelmeshes.append(netarray2mesh(sample_tree_array))
    return voxelmeshes if return_all else voxelmeshes[-1]

# post processing array
def cluster_in_sphere(voxel_index_list, center, radius):
    """Return True if any voxel of a cluster lies strictly inside a sphere.

    Args:
        voxel_index_list: iterable of voxel coordinates (e.g. (i, j, k) tuples).
        center: sphere center; anything `np.array` can convert.
        radius: sphere radius; points at exactly this distance do not count.

    Returns:
        True as soon as one voxel is closer than `radius` to `center`,
        otherwise False (including for an empty cluster).
    """
    sphere_center = np.array(center)
    return any(
        np.linalg.norm(np.array(voxel) - sphere_center) < radius
        for voxel in voxel_index_list
    )

def post_process_array(boolarray, min_points=None, sphere_radius=None):
    """Remove small or off-center clusters from a generated voxel array.

    The array is binarized (`> 0`) and split into clusters with `eval_cluster`.
    A cluster is kept only if it has at least `min_points` voxels AND at least
    one voxel strictly within `sphere_radius` of the array center
    (`shape / 2`).

    Args:
        boolarray: 3D voxel array (indexed as [i, j, k]); presumably a numpy
            array — TODO confirm for tensor inputs.
        min_points: minimum cluster size; defaults to the notebook-level
            `point_threshold` setting.
        sphere_radius: radius of the keep-sphere; defaults to the
            notebook-level `radius` setting.

    Returns:
        An array shaped like the binarized input with only the kept clusters set.
    """
    if min_points is None:
        min_points = point_threshold
    if sphere_radius is None:
        sphere_radius = radius

    boolarray = boolarray > 0
    cluster = eval_cluster(boolarray)
    # Loop-invariant: the sphere is centered in the middle of the voxel space.
    center = np.array(boolarray.shape) / 2

    #post process
    process_cluster = []
    for l in cluster:
        l = list(l)
        if len(l) < min_points:
            continue
        if not cluster_in_sphere(l, center, sphere_radius):
            continue
        process_cluster.append(l)

    #point form back to array form
    processed_tree = np.zeros_like(boolarray)
    for c in process_cluster:
        for i, j, k in c:
            processed_tree[i, j, k] = 1
    return processed_tree

# Presumably clears this cell's output (e.g. the pip install log) before printing a short status.
output.clear()
print('ok!')

In [None]:
#@markdown load model
selected_model = "VAEGAN"


sharedFolder = sharedDrive / "handtool/checkpoint"
# Sorted so that the index-to-file mapping is deterministic.
ckpt_file_list = sorted(sharedFolder.rglob("*.ckpt"))
if not ckpt_file_list:
    # randint(0) would otherwise fail with an unhelpful ValueError.
    raise FileNotFoundError(
        f"No .ckpt files found under {sharedFolder}. "
        "Check that the shared drive folder is accessible (see the drive-mount cell)."
    )
# Pick one of the available checkpoints at random.
selected_ckpt_index = randint(len(ckpt_file_list))
selected_ckpt_file = ckpt_file_list[selected_ckpt_index]
model = load_model(selected_model, selected_ckpt_file)

print("checkpoint name:", selected_ckpt_file.stem)

In [None]:
#@markdown generate samples
mesh_index += 1

voxelmesh = generateFromCheckpoint(model)
save_filename = f"sample{mesh_index}.obj"
export_path = temp_location / save_filename
# str() so exporters that only accept string paths also work.
voxelmesh.export(file_obj=str(export_path), file_type="obj")


def make_download_handler(path):
    """Return a button callback that downloads `path`.

    The path is bound here rather than read from the global `export_path` at
    click time. That keeps each button tied to the mesh generated in the same
    run, even after the cell has been re-run.
    """
    def on_buttonDownload_clicked(b):
        files.download(str(path))
    return on_buttonDownload_clicked


# Create the button only after the file is written, so it never points at a
# missing file or a previous run's file.
buttonDownload = ipywidgets.widgets.Button(description="Download")
buttonDownload.on_click(make_download_handler(export_path))
display(buttonDownload)

#fix rendering by copy mesh
voxelmesh = voxelmesh.copy()
voxelmesh.show()