<a href="https://colab.research.google.com/github/buganart/BUGAN/blob/master/notebook_util/latent_space_exploration.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Before starting, please save the notebook to your Drive by clicking `File -> Save a copy in Drive`.

In [None]:
#@markdown Mount google drive.
from google.colab import output
from google.colab import drive
drive.mount('/content/drive')

# Check if we have linked the folder
from pathlib import Path
if not Path("/content/drive/My Drive/IRCMS_GAN_collaborative_database").exists():
    print(
        "Shortcut to our shared drive folder doesn't exist.\n\n"
        "\t1. Go to the google drive web UI\n"
        "\t2. Right click shared folder IRCMS_GAN_collaborative_database and click \"Add shortcut to Drive\""
    )

# $Description:$

This notebook uses a trained BUGAN model to generate meshes along a straight line in latent space. You specify 2 latent vectors [$z_1$, $z_2$], and the latent vectors $Z_0, Z_1, ..., Z_{n-1}$ are then sampled evenly along the line segment between $z_1$ and $z_2$, where $n$ is the **test_num_samples** value set in the **Configure notebook setting** cell below.

$n = test\_num\_samples, i \in [0, n-1]$

$Z_i = \frac{n-1-i}{n-1} z_1 + \frac{i}{n-1} z_2$ (Note that $Z_0 = z_1$ and $Z_{n-1} = z_2$)

Finally, new meshes will be produced from those latent space vectors $Z_0, Z_1, ..., Z_{n-1}$:

$latent \ space \ vector \ Z_i → decoder/generator → generated \ mesh \ m_i$

All meshes $m_0, m_1, ..., m_{n-1}$ will be stored in the **export_location** set in the **Configure notebook setting** cell below.
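
The sampling step can be sketched in a few lines of NumPy. This is only an illustration: the helper name `interpolate_latents` and the small `z_size` and `n` values are assumptions made for the example, not part of BUGAN. It follows the convention that $Z_0 = z_1$ and $Z_{n-1} = z_2$, and it should agree with `np.linspace`, which the notebook uses later.

```python
import numpy as np

def interpolate_latents(z_1, z_2, n):
    """Return n evenly spaced vectors from z_1 (row 0) to z_2 (row n-1)."""
    if n < 2:
        raise ValueError("n must be at least 2 to include both end points")
    # t is the weight on z_2; it grows from 0 to 1 while the weight on z_1 shrinks
    t = np.arange(n).reshape(-1, 1) / (n - 1)
    return (1 - t) * z_1 + t * z_2

z_size = 4  # assumed small size, for illustration only
z_1 = np.ones(z_size)
z_2 = -np.ones(z_size)
Z = interpolate_latents(z_1, z_2, n=5)
print(Z.shape)  # (5, 4)
```

Each row of `Z` is one latent vector $Z_i$ that would be passed to the decoder/generator.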

## Note:

1. See the **Configure notebook setting** cell below to select the BUGAN model using either
    1. the wandb run id (**id**), or 
    2. the path to a saved checkpoint file (**ckpt_file_location**) together with **selected_model**.

2. Set the latent vectors [$z_1$, $z_2$] in the **SET 2 latent vectors for latent space exploration** cell.

3. If the loaded model is conditional, the **class_index** also affects the generator input. The class indices [$c_1$, $c_2$] can be set in the same **SET 2 latent vectors for latent space exploration** cell.

In [None]:
#@title Configure notebook setting
#@markdown - choose whether to use wandb id or use checkpoint_path to select checkpoint file

#@markdown Choice 1: wandb id and project_name to select checkpoint file
#@markdown - set `id` to the wandb run id of the model to load (for example: `u9imsvva`)
#@markdown - The id of the current run is shown below in the cell with `wandb.init()` or you can find it in the wandb web UI.
id = "25z72k0w" #@param {type:"string"}

#@markdown Choice 2: file path and model type to select saved checkpoint .ckpt file
#@markdown - Use the file browser on the left to locate the file, then right click it to copy the path.
#@markdown - file path example: `/content/drive/My Drive/h/k/checkpoint.ckpt` 
ckpt_file_location = "" #@param {type:"string"}
#@markdown - Enter trained neural network model type
#@markdown - (may be necessary for wandb_id if selected_model is not saved in config)
selected_model = "VAEGAN"    #@param ["VAEGAN", "GAN", "VAE", "WGAN", "WGAN_GP"]

#@markdown Enter how many samples to generate
test_num_samples = 100    #@param {type:"integer"}
#@markdown whether to zip file and download with browser
browser_download = False #@param {type:"boolean"}
#@markdown Enter export location (folder/directory).    
export_location = f"/content/drive/My Drive/IRCMS_GAN_collaborative_database/Experiments/exportObjects/{id}" #@param {type:"string"}



if id and ckpt_file_location:
    raise Exception("Only one of id / ckpt_file_location can be set!")
if (not id) and (not ckpt_file_location):
    raise Exception("Please set id / ckpt_file_location!")

if id:
    print("id:", id)
else:
    print("selected_model:", selected_model)
    print("ckpt_file_location:", ckpt_file_location)
print("test_num_samples:", test_num_samples)
print("export_location:", export_location)


In [None]:
from argparse import Namespace, ArgumentParser
#@markdown Install wandb and log in
rev_number = None
if id:
    !pip install wandb
    output.clear()
    #find wandb API key file to auto login
    import wandb
    wandb_drive_netrc_path = Path("drive/My Drive/colab/.netrc")
    wandb_local_netrc_path = Path("/root/.netrc")
    if wandb_drive_netrc_path.exists():
        import shutil

        print("Wandb .netrc file found, will use that to log in.")
        shutil.copy(wandb_drive_netrc_path, wandb_local_netrc_path)
    else:
        print(
            f"Wandb config not found at {wandb_drive_netrc_path}.\n"
            f"Using manual login.\n\n"
            f"To use auto login in the future, finish the manual login first and then run:\n\n"
            f"\t!mkdir -p '{wandb_drive_netrc_path.parent}'\n"
            f"\t!cp {wandb_local_netrc_path} '{wandb_drive_netrc_path}'\n\n"
            f"Then that file will be used to login next time.\n"
        )

    !wandb login

    #read information (run config, etc) stored online
    #all config will be replaced by the stored one in wandb
    api = wandb.Api()
    try:
        project_name = "tree-gan"
        run = api.run(f"bugan/{project_name}/{id}")
    except Exception as e:
        print(e)
        print("set project_name to tree-gan Failed. Try handtool-gan")
        try:
            project_name = "handtool-gan"
            run = api.run(f"bugan/{project_name}/{id}")
        except Exception as e:
            print(e)
            print("set project_name to handtool-gan Failed. Try chair-gan")
            project_name = "chair-gan"
            run = api.run(f"bugan/{project_name}/{id}")


    config = Namespace(**run.config)
    #load selected_model, rev_number in the config
    if hasattr(config, "selected_model"):
        selected_model = config.selected_model
    if hasattr(config, "rev_number"):
        rev_number = config.rev_number

    output.clear()
    print("run id: " + str(run.id))
    print("run name: " + str(run.name))
    print("ok!")

In [None]:
#@markdown package and functions
import numpy as np
import os
import sys
import subprocess
import torch
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
from pathlib import Path
os.environ["WANDB_MODE"] = "dryrun"

%pip install --upgrade git+https://github.com/buganart/BUGAN.git#egg=bugan

run_path = "/content/drive/My Drive/IRCMS_GAN_collaborative_database/Experiments/"

import trimesh  # used by get_encoded_vector below to load mesh files
from bugan.trainPL import _get_models
from bugan.functionsPL import netarray2mesh, mesh2arrayCentered

def install_bugan_package(rev_number=None):
    if rev_number:
        subprocess.check_call(
            [
                sys.executable,
                "-m",
                "pip",
                "install",
                "--upgrade",
                f"git+https://github.com/buganart/BUGAN.git@{rev_number}#egg=bugan",
            ]
        )
    else:
        subprocess.check_call(
            [
                sys.executable,
                "-m",
                "pip",
                "install",
                "--upgrade",
                "git+https://github.com/buganart/BUGAN.git#egg=bugan",
            ]
        )

# load model
def getModelFromCheckpoint(selected_model, ckpt_filePath):
    # in case loading model failure can be fixed by modifying config params, modify here.
    # config.z_size=512
    MODEL_CLASS = _get_models(selected_model)
    try:
        # try newest bugan version
        install_bugan_package()
        model = MODEL_CLASS.load_from_checkpoint(ckpt_filePath, config=config)
    except Exception as e:
        print(e)
        # restore bugan version
        install_bugan_package(rev_number=rev_number)
        model = MODEL_CLASS.load_from_checkpoint(ckpt_filePath, config=config)

    model = model.eval()#.to(device)
    return model

# generateMesh
def generateMesh(model, z, class_index=None):
    test_num_samples = z.shape[0]
    # get generator
    if hasattr(model, "vae"):
        generator = model.vae.vae_decoder
        if class_index is not None:
            embedding_fn = model.vae.embedding
    else:
        generator = model.generator
        if class_index is not None:
            embedding_fn = model.embedding

    z = torch.tensor(z).type_as(generator.gen_fc.weight)
    if class_index is not None:
        if isinstance (class_index,int):
            # turn class vector the same device as z, but with dtype Long
            c = torch.ones(test_num_samples) * class_index
            c = c.type_as(z).to(torch.int64)

            # combine z and c
            z = model.merge_latent_and_class_vector(
                z,
                c,
                model.config.num_classes,
                embedding_fn=embedding_fn,
            )
        else:
            # assume class_index are processed class_vectors
            class_vectors = torch.tensor(class_index).type_as(z)

            # merge with z to be generator input
            z = torch.cat((z, class_vectors), 1)
    
    generated_tree = generator(z)[:, 0, :, :, :]
    generated_tree = generated_tree.detach().cpu().numpy()

    return generated_tree

#for VAEGAN / VAE, get latent vector from the meshfiles
def get_encoded_vector(vae, f_loc, resolution=32):
    f = Path(f_loc)
    m = trimesh.load(f, force="mesh")
    m = mesh2arrayCentered(m, array_length=resolution)
    m = m[np.newaxis,np.newaxis,:,:,:]
    m = torch.Tensor(m).float().type_as(vae.vae_encoder.dis_fc.weight)
    _, z, _ = vae(m, output_all=True)
    return z[0].detach().cpu().numpy()

output.clear()
print('ok!')

In [None]:
#@markdown setup model

# load model
if ckpt_file_location:
    model = getModelFromCheckpoint(selected_model, ckpt_file_location)
else:
    try:
        ckpt_file = run.file("checkpoint.ckpt").download(replace=True)
        model = getModelFromCheckpoint(selected_model, ckpt_file.name)
    except Exception as e:
        print(e)
        print("loading from checkpoint.ckpt failed. Try checkpoint_prev.ckpt")
        ckpt_file = run.file("checkpoint_prev.ckpt").download(replace=True)
        model = getModelFromCheckpoint(selected_model, ckpt_file.name)

z_size = config.z_size

print("latent vector z_size:", z_size)
# print("class_list:", config.class_list)
print("ok!")

# **SET 2 latent vectors for latent space exploration**

## **latent space walk (for VAE/VAEGAN)**

### **Required**:

The code below requires 2 latent vectors [$z_1$, $z_2$]. Set them manually in the following cell as *latent_1 ($z_1$)* and *latent_2 ($z_2$)*; each must have length **z_size**.

To use latent vectors other than the defaults, edit the code that assigns *latent_1 ($z_1$)* and *latent_2 ($z_2$)* below.
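
Two possible ways to code custom vectors are sketched here. The seed, the chosen coordinate `dim`, the end values, and the `z_size` value are all illustrative assumptions; in the notebook `z_size` is read from the loaded model config, and whether a standard normal draw suits a given model depends on how it was trained.

```python
import numpy as np

z_size = 128  # assumption for illustration; the notebook uses config.z_size

# Option A: two reproducible random vectors (assumes a standard normal prior)
rng = np.random.default_rng(seed=0)
random_1 = rng.standard_normal(z_size)
random_2 = rng.standard_normal(z_size)

# Option B: vary a single coordinate and keep all others at zero
dim = 3  # hypothetical coordinate to explore
axis_1 = np.zeros(z_size)
axis_2 = np.zeros(z_size)
axis_1[dim] = -2.0
axis_2[dim] = 2.0

# assign whichever pair you want to explore
latent_1, latent_2 = axis_1, axis_2
```

Option B changes only one latent coordinate along the walk, which can make it easier to see what that coordinate controls.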

If the model is conditional, the class vectors are appended to the latent vectors. For example, with test_num_samples = 100 the latent vectors have shape (100, z_size). If the class index is 2, the class vector for index 2 is retrieved from the trained embedding layer and repeated to shape (100, c_dim). The resulting generator input has shape (100, z_size+c_dim).
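
The shapes in this example can be checked with a small NumPy sketch. The `embedding_table` array is a hypothetical stand-in for the trained embedding layer, and the values of `z_size`, `c_dim` and `num_classes` are assumptions chosen only for illustration; the actual merge inside BUGAN may differ in detail.

```python
import numpy as np

# illustrative sizes, not taken from any real BUGAN config
test_num_samples, z_size, c_dim, num_classes = 100, 128, 16, 5

# hypothetical stand-in for the trained embedding layer: one row per class
embedding_table = np.random.default_rng(0).standard_normal((num_classes, c_dim))

z = np.zeros((test_num_samples, z_size))  # latent vectors, one row per sample
class_index = 2
# repeat the class vector of index 2 for every sample
c = np.tile(embedding_table[class_index], (test_num_samples, 1))

# append the class vectors to the latent vectors along the feature axis
generator_input = np.concatenate([z, c], axis=1)
print(generator_input.shape)  # (100, 144)
```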


### **Option only available for models with VAE**:

The 2 latent vectors [$z_1$, $z_2$] can instead be obtained by encoding **2 mesh files** with the VAE encoder. In that case the manually set *latent_1 ($z_1$)* and *latent_2 ($z_2$)* are replaced by the encoded vectors.

The VAE in the model has 2 components: encoder / decoder.

To generate 2 latent vectors [$z_1$, $z_2$]:

$3D \ mesh \ 1 → encoder → latent \ space \ vector \ z_1$

$3D \ mesh \ 2 → encoder → latent \ space \ vector \ z_2$

Then, a line between $z_1$ and $z_2$ will be drawn.

### **Option only available for conditional models**:

The 2 class indices [$c_1$, $c_2$] can be specified here, then the corresponding class vectors [$C_1$, $C_2$] will be retrieved from the trained embedding layer in the model.

To generate 2 class vectors [$C_1$, $C_2$]:

$class \ index \ c_1 → model \ embedding \ layer → class \ vector \ C_1$

$class \ index \ c_2 → model \ embedding \ layer → class \ vector \ C_2$

Then the class vectors are sampled evenly along the line between $C_1$ and $C_2$, one class vector per latent vector.
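
As a rough sketch of this class-vector walk, the snippet below looks up two class vectors and interpolates between them. Here `embedding_table` is a hypothetical array standing in for the trained embedding layer, and the sizes and class indices are assumptions for illustration only.

```python
import numpy as np

num_classes, c_dim, n = 5, 16, 100  # illustrative sizes
# hypothetical stand-in for the trained embedding layer: one row per class
embedding_table = np.random.default_rng(0).standard_normal((num_classes, c_dim))

c_1, c_2 = 1, 3  # example class indices
C_1, C_2 = embedding_table[c_1], embedding_table[c_2]

# n class vectors evenly spaced from C_1 (row 0) to C_2 (row n-1)
class_vectors = np.linspace(C_1, C_2, num=n)
print(class_vectors.shape)  # (100, 16)
```

Row $i$ of `class_vectors` would then be paired with the latent vector $Z_i$ before being passed to the generator.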

In [None]:
#@markdown # SET 2 latent vectors for latent space exploration

#@markdown ## For VAEGAN/VAE only: Enter 2 mesh file locations to specify 2 latent vectors.   
# meshfile1 = "/content/drive/My Drive/IRCMS_GAN_collaborative_database/Research/Peter/Tree_3D_models_obj/obj_files/old_1.obj" #@param {type:"string"}
meshfile1 = "" #@param {type:"string"}
# meshfile2 = "/content/drive/My Drive/IRCMS_GAN_collaborative_database/Research/Peter/Tree_3D_models_obj/obj_files/maple_example2.obj" #@param {type:"string"}
meshfile2 = "" #@param {type:"string"}
#@markdown For all other models, please specify the latent vector in this cell.

#@markdown If meshfile above is set, and the model contains VAE, latent vector values below will be replaced.
# specify latent vector below. 
latent_1 = np.ones(z_size)
latent_2 = np.ones(z_size)*-1

try:
    if meshfile1 and hasattr(model, "vae"):
        latent_1 = get_encoded_vector(model.vae, meshfile1, resolution=config.resolution)
    else:
        print("meshfile1 is empty or model does not contain VAE. Use specified latent_1 vector in the cell.")
except Exception as e:
    print(e)
    print("Error on processing meshfile1, use specified latent_1 vector in the cell.")

try:
    if meshfile2 and hasattr(model, "vae"):
        latent_2 = get_encoded_vector(model.vae, meshfile2, resolution=config.resolution)
    else:
        print("meshfile2 is empty or model does not contain VAE. Use specified latent_2 vector in the cell.")
except Exception as e:
    print(e)
    print("Error on processing meshfile2, use specified latent_2 vector in the cell.")


latent_array = [latent_1, latent_2]
print("latent_1:")
print(latent_1[:20], "......")
print("latent_2:")
print(latent_2[:20], "......")
print(f"linear interpolate {test_num_samples} vectors between these 2 latent vectors")
latent_vectors = np.linspace(latent_array[0], latent_array[1], num=test_num_samples)

#@markdown ## For conditional model, specify class index here.
class_index1 = 2 #@param {type:"integer"}
class_index2 = 2 #@param {type:"integer"}


if hasattr(model, "classifier"):
    print("conditional model. Process class index.......")
    if class_index1 == class_index2:
        class_index = class_index1
        print(f"class_index:{class_index1}")
    else:
        class_index = [class_index1, class_index2]

        #process class_index
        if hasattr(model, "vae"):
            embedding_fn = model.vae.embedding
        else:
            embedding_fn = model.embedding
        c_start, c_end = class_index
        c_start = embedding_fn(torch.tensor(c_start)).detach().cpu().numpy()
        c_end = embedding_fn(torch.tensor(c_end)).detach().cpu().numpy()

        class_vectors = np.linspace(c_start, c_end, num=test_num_samples)
        print(f"1st class index:{class_index1}, the class vector is:")
        print(c_start[:20], "......")
        print(f"2nd class index:{class_index2}, the class vector is:")
        print(c_end[:20], "......")
        print(f"linear interpolate {test_num_samples} vectors between these 2 class vectors")
else:
    print("unconditional model. Class index will not be processed.")

In [None]:
#@markdown generate samples
# set saving location
temp_location = Path("/tmp/generated/")
temp_location.mkdir(parents=True, exist_ok=True)
# copy files to export location
if export_location:
    # create directory
    export_location = Path(export_location)
    export_location.mkdir(parents=True, exist_ok=True)

meshes = []
batch_size = 8
loops = int(np.ceil(test_num_samples/batch_size))
for i in range(loops):
    start = i*batch_size
    end = (i+1)*batch_size 
    if end > test_num_samples:
        end = test_num_samples
    latent_vector = latent_vectors[start:end]
    try:
        if isinstance (class_index, int):
            mesh = generateMesh(model, latent_vector, class_index=class_index)
        else:
            class_vector = class_vectors[start:end]
            mesh = generateMesh(model, latent_vector, class_index=class_vector)
    except Exception as e:
        print(e)
        print("Generating with class_index failed. Retrying without class_index.")
        mesh = generateMesh(model, latent_vector, class_index=None)
    meshes.append(mesh)
meshes = np.concatenate(meshes)

#store generated meshes to export_location
for n in range(test_num_samples):
    sample_tree_array = meshes[n] > 0
    voxelmesh = netarray2mesh(sample_tree_array)
    save_filename = f"sample_{n}.obj"
    if export_location:
        export_path = export_location / save_filename
        voxelmesh.export(file_obj=export_path, file_type="obj")
    #if zip, save also to temp_location 
    if browser_download:
        temp_path = temp_location / save_filename
        voxelmesh.export(file_obj=temp_path, file_type="obj")


#zip files in the directory and download them with the browser
if browser_download:
    !zip -r /tmp/file.zip /tmp/generated/
    from google.colab import files
    files.download("/tmp/file.zip")
else:
    print("browser_download is disabled; skipping zip and browser download.")