<a href="https://colab.research.google.com/github/buganart/BUGAN/blob/master/Visualize_dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Before starting, please save a copy of this notebook to your Drive by clicking `File -> Save a copy in Drive`.

In [None]:
#@markdown Mount google drive.
# `output` is used by later cells to clear noisy install logs.
from google.colab import output
from google.colab import drive
drive.mount('/content/drive')

# Check that the shared folder is reachable from "My Drive"; later cells read
# the dataset from it and `%cd` into its Experiments sub-folder.
from pathlib import Path
shared_folder_path = Path("/content/drive/My Drive/IRCMS_GAN_collaborative_database")
if not shared_folder_path.exists():
    print(
        "Shortcut to our shared drive folder doesn't exist.\n\n"
        "\t1. Go to the google drive web UI\n"
        "\t2. Right click shared folder IRCMS_GAN_collaborative_database and click \"Add shortcut to Drive\"\n"
        "\t3. Re-run this cell"
    )

In [None]:
#@markdown Install wandb and log in
%pip install wandb
output.clear()
import wandb
wandb_drive_netrc_path = Path("drive/My Drive/colab/.netrc")
wandb_local_netrc_path = Path("/root/.netrc")
if wandb_drive_netrc_path.exists():
    import shutil

    print("Wandb .netrc file found, will use that to log in.")
    shutil.copy(wandb_drive_netrc_path, wandb_local_netrc_path)
else:
    print(
        f"Wandb config not found at {wandb_drive_netrc_path}.\n"
        f"Using manual login.\n\n"
        f"To use auto login in the future, finish the manual login first and then run:\n\n"
        f"\t!mkdir -p '{wandb_drive_netrc_path.parent}'\n"
        f"\t!cp {wandb_local_netrc_path} '{wandb_drive_netrc_path}'\n\n"
        f"Then that file will be used to login next time.\n"
    )

!wandb login
output.clear()
print("ok!")

In [None]:
#@title Configure dataset
#@markdown Enter project name (either `chair-gan`, `handtool-gan` or `tree-gan`)
project_name = "tree-gan" #@param ["tree-gan", "handtool-gan", "chair-gan"]
#@markdown Enter dataset location.  
#@markdown - For example, use the file browser on the left to locate it, then right-click it to copy its path.
#@markdown - zipfile example: `/content/drive/My Drive/h/k/a.zip`
#@markdown - file folder example: `/content/drive/My Drive/h/k` 
data_location = "/content/drive/My Drive/IRCMS_GAN_collaborative_database/Research/Peter/Tree_3D_models_obj_auto_generated/sessions/simplified/tree-session-2020-09-14_23-23-Friedrich_2-target-face-num-1000.zip" #@param {type:"string"}
#@markdown - choose on-the-fly rotation augmentation
#@markdown (augmentation only supports a file folder as data_location)
data_augmentation = False    #@param {type:"boolean"}
aug_rotation_type = "random rotation"  #@param ["random rotation", "axis rotation"]
#@markdown - specify the rotation axis [x,y,z] (only for aug_rotation_type = "axis rotation")
rotation_axis_x = 0    #@param {type:"number"}
rotation_axis_y = 1    #@param {type:"number"}
rotation_axis_z = 0    #@param {type:"number"}

#@markdown - resolution of the voxelized array (shape resolution**3)
resolution = "32"    #@param [32, 64]

#@markdown WANDB log
#@markdown - how many samples to log per data batch (size=32)
log_num_samples = 4 #@param {type:"integer"}


# The resolution dropdown value above is a string; the rest of the notebook needs an int.
resolution = int(resolution)

# Warn about option combinations the form cannot prevent.
if data_augmentation and data_location.endswith(".zip"):
    print("WARNING: data_augmentation only supports a file folder as data_location, not a .zip file.")
if (data_augmentation and aug_rotation_type == "axis rotation"
        and (rotation_axis_x, rotation_axis_y, rotation_axis_z) == (0, 0, 0)):
    print("WARNING: the rotation axis is (0, 0, 0); specify a non-zero axis for axis rotation.")
if log_num_samples < 1:
    print("WARNING: log_num_samples < 1, no dataset samples will be logged.")

# Everything a run depends on; later converted into the run's hyperparameters.
colab_config = {
    "aug_rotation_type": aug_rotation_type,
    "data_augmentation": data_augmentation,
    "aug_rotation_axis": (rotation_axis_x,rotation_axis_y,rotation_axis_z),
    "data_location": data_location,
    "project_name": project_name,
    "log_num_samples": log_num_samples,
    "resolution": resolution
}

for k, v in colab_config.items():
    print(f"=> {k:20}: {v}")

# To just visualize the dataset, no edits should be required in any of the cells below.

In [None]:
import os
from pathlib import Path
# os.environ["WANDB_MODE"] = "dryrun"

%cd /content/drive/My Drive/IRCMS_GAN_collaborative_database/Experiments/
if project_name == "tree-gan":
    %cd colab-treegan/
elif project_name == "handtool-gan":
    %cd colab-handtool/
else:
    %cd colab-chair/

dataset_path = Path(data_location)
run_path = "./"

!apt-get update

%pip install pytorch-lightning
%pip install trimesh
!apt install -y xvfb
%pip install trimesh xvfbwrapper
output.clear()
print('ok!')

In [None]:
import io
from io import BytesIO
import sys
import zipfile
import trimesh
import numpy as np
from argparse import Namespace, ArgumentParser

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchsummary import summary
from torch.utils.data import DataLoader, TensorDataset
# Prefer the first GPU when one is available.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

import logging
# Only let ERROR-level (and above) records through the root logger.
# (A former `logging.propagate = False` only set an unused attribute on the
# logging module itself and had no effect, so it was dropped.)
logging.getLogger().setLevel(logging.ERROR)

from xvfbwrapper import Xvfb

In [None]:
# Hyperparameters tracked for this run: the form values plus the datamodule batch size.
config = Namespace(**{**colab_config, "batch_size": 32})

# wandb entity (user or team) the run is logged under.
entity = "bugan"

In [None]:
# Start a wandb run under a freshly generated id; visualization runs share one group.
run_id = wandb.util.generate_id()
run = wandb.init(
    project=project_name,
    id=run_id,
    entity=entity,
    resume=True,
    dir=run_path,
    group="Visualize Data",
)

print(f"run id: {wandb.run.id}")
print(f"run name: {wandb.run.name}")
wandb.watch_called = False

# Dataset

In [None]:
### load our package
#directly install using pip
print("loading BUGAN package latest")
%pip install --upgrade git+https://github.com/buganart/BUGAN.git#egg=bugan

import bugan
#EXTRACT package version
    #switch stdout to temperary stringIO
old_stdout = sys.stdout
temp_stdout = io.StringIO()
sys.stdout = temp_stdout
    #get version
%pip freeze | grep bugan
version = temp_stdout.getvalue()
rev_number = version.split("+g")[1].rstrip()
    #switch back stdout
sys.stdout = old_stdout
output.clear()

print("bugan package version: ",version)
print("bugan package revision number: ",rev_number)
config.rev_number = rev_number


from bugan.functionsPL import *

###     load dataset
dataset_path = Path(config.data_location)

# A .zip dataset is named after its file stem; a plain folder gets a generic name.
config.dataset = (
    dataset_path.stem
    if config.data_location.endswith(".zip")
    else "dataset_array_custom"
)

dataModule = DataModule_process(config, run, dataset_path)
config.num_data = dataModule.size

print("dataset name: ", config.dataset)
print("dataset path: ", dataset_path)

# Base model

In [None]:
class PrintDatasetModel(pl.LightningModule):
    """Placeholder LightningModule that logs dataset samples to wandb.

    It learns nothing. Each training step converts the first
    ``config.log_num_samples`` samples of the batch into meshes and images and
    logs them, together with summary statistics, via ``wandb.log``. The
    returned loss is always zero, so the dummy layer never changes.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # torch optimizers reject an empty parameter list, so keep one dummy layer.
        self.layer = nn.Linear(1, 1)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.01)

    def training_step(self, dataset_batch, batch_idx):
        """Log up to ``config.log_num_samples`` samples of ``dataset_batch``; return a zero loss."""
        # Map values v -> 2v - 1 (assumes voxel occupancies in [0, 1] -- TODO
        # confirm against DataModule_process), then move to numpy.
        voxels = dataset_batch[0] * 2 - 1
        voxels = voxels.float().detach().cpu().numpy()

        # Take index 0 along axis 1 of the 5-D batch (presumably a single channel axis).
        sample_trees = voxels[:, 0, :, :, :]

        # The final batch can hold fewer samples than log_num_samples;
        # never index past its end.
        num_samples = min(self.config.log_num_samples, sample_trees.shape[0])

        sample_tree_numpoints = []
        eval_num_cluster = []
        sample_tree_image = []
        sample_tree_voxelmesh = []
        for sample_tree in sample_trees[:num_samples]:
            # After the 2v - 1 mapping, positive values mean v > 0.5.
            sample_tree_bool_array = sample_tree > 0
            # number of points in the voxel array
            sample_tree_numpoints.append(netarray2indices(sample_tree_bool_array).shape[0])
            # count number of clusters in the tree (grouped with dist_inf = 1)
            eval_num_cluster.append(eval_count_cluster(sample_tree_bool_array))

            voxelmesh = netarray2mesh(sample_tree_bool_array)
            # mesh2wandbImage can return None; such images are skipped.
            image = mesh2wandbImage(voxelmesh)
            if image is not None:
                sample_tree_image.append(image)
            sample_tree_voxelmesh.append(mesh2wandb3D(voxelmesh))

        # np.mean of an empty list is NaN, so only log when something was sampled.
        if num_samples > 0:
            wandb.log({
                "sample_tree_numpoints": np.mean(sample_tree_numpoints),
                "eval_num_cluster": np.mean(eval_num_cluster),
                "sample_tree_image": sample_tree_image,
                "sample_tree_voxelmesh": sample_tree_voxelmesh,
            })

        # Zero loss that is still attached to the dummy layer's graph, so
        # backward() succeeds but yields zero gradients. The input is created on
        # the module's device to avoid a CPU/GPU mismatch when Lightning moves
        # the model to an accelerator.
        loss = self.layer(torch.zeros(1, device=self.device))
        return loss - loss


# Train

In [None]:
import inspect

# Start a virtual X display so meshes can be rendered to images without a screen.
vdisplay = Xvfb()
vdisplay.start()

try:
    # wandb logger setup: reuse the run created earlier.
    wandb_logger = WandbLogger(experiment=run, log_model=True)
    # Log the config and this notebook's session history to the run.
    wandb.config.update(config)
    wandb.jupyter.Notebook(Namespace(save_code=True)).save_history()
    wandb.save(os.path.join(wandb.run.dir, "code", "_session_history.ipynb"), base_path=wandb.run.dir)

    model = PrintDatasetModel(config)

    # Checkpointing is not wanted for this dummy model. Newer Lightning releases
    # replaced `checkpoint_callback` with `enable_checkpointing` (and later
    # removed the old argument), so pass whichever this version accepts.
    trainer_params = inspect.signature(pl.Trainer.__init__).parameters
    if "enable_checkpointing" in trainer_params:
        checkpoint_kwargs = {"enable_checkpointing": False}
    else:
        checkpoint_kwargs = {"checkpoint_callback": None}

    trainer = pl.Trainer(
        max_epochs=1,
        logger=wandb_logger,
        default_root_dir=wandb.run.dir,
        **checkpoint_kwargs,
    )

    # One epoch: every batch passes through training_step once and gets logged.
    trainer.fit(model, dataModule)
finally:
    # Always shut the virtual display down so re-running the cell does not leak Xvfb processes.
    vdisplay.stop()