<a href="https://colab.research.google.com/github/buganart/BUGAN/blob/master/custom_dataset_of_script_VAEGAN_voxelsize1_lightning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive under /content/drive; `output` is kept in scope so later cells can
# clear their (noisy) install/login output.
from google.colab import drive, output

drive.mount("/content/drive")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Run information

In [None]:
# Run selection:
#   id = None        -> start a new wandb run
#   id = "run_id"    -> resume that run (for example: id = "u9imsvva")
# The id of the current run is printed by the cell that calls wandb.init().
# NOTE: `id` shadows the Python builtin of the same name for the rest of the notebook.
id = None

# wandb project name: either "handtool-gan" or "tree-gan".
# It also selects the working directory in the environment cell.
project_name = "tree-gan"

# Location of the dataset zip relative to "drive/My Drive/"
# (for example, for drive/My Drive/h/a.zip enter "h/a.zip").
zipfile_location = (
    "IRCMS_GAN_collaborative_database/Research/Peter/Tree_3D_models_obj_auto_generated/"
    "sessions/simplified/tree-session-2020-09-14_23-23-Friedrich_2-target-face-num-1000.zip"
)

# Fail fast on a typo: the environment cell treats every value other than "tree-gan"
# as the handtool project, so a misspelling would otherwise go unnoticed.
assert project_name in ("handtool-gan", "tree-gan"), (
    'project_name must be "handtool-gan" or "tree-gan", got ' + repr(project_name)
)

# Environment setup

In [None]:
#right click shared folder IRCMS_GAN_collaborative_database and "Add shortcut to Drive" to My drive
%cd drive/My Drive/IRCMS_GAN_collaborative_database/Experiments/
if project_name == "tree-gan":
    %cd colab-treegan/
else:
    %cd colab-handtool/

#record paths to resources
zipfile_loc_list = zipfile_location.split("/")
zipfile_name = zipfile_loc_list[-1]
zipfile_title = ".".join(zipfile_name.split(".")[:-1])

# data_path = "../../../../../My Drive/Hand-Tool-Data-Set/"    #take care of .shortcut-targets-by-id/"folder-id"/ folders
data_path = "../../../../../../drive/My Drive/" + ("/".join(zipfile_loc_list[:-1]))
run_path = "./"

!apt-get update

!pip install pytorch-lightning
!pip install trimesh
!pip install wandb==0.9.7
!apt install -y xvfb
!pip install trimesh xvfbwrapper
output.clear()

# Import libraries and log in to wandb

In [None]:
import io
from io import BytesIO
import os
import zipfile
import trimesh
import numpy as np


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchsummary import summary
from torch.utils.data import DataLoader, TensorDataset
# Device the model is moved to in the training cell: first CUDA GPU if present, else CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

# Ignore excessive warnings by raising the root logger's threshold to ERROR.
# Loggers that set their own level are not affected; Lightning's "GPU available" INFO lines,
# for example, are still printed in the training cell's output.
import logging
logging.getLogger().setLevel(logging.ERROR)

# WandB – Import the wandb library
import wandb

# Virtual X display, started in the training cell.
from xvfbwrapper import Xvfb

In [None]:
# Authenticate the wandb CLI (it asks for an API key if none is stored yet), then clear
# the cell output, presumably so the prompt is not kept in the saved notebook.
!wandb login
output.clear()

In [None]:
# A run id supplied in the run-settings cell means "resume that run";
# otherwise mint a fresh id for a new run.
resume = id is not None
if not resume:
    id = wandb.util.generate_id()

# NOTE(review): `resume=True` is passed even for brand-new runs. The `resume` flag above only
# decides (in the training cell) whether a checkpoint is restored. Confirm this wandb resume
# mode is the intended one for the pinned wandb version.
run = wandb.init(
    project=project_name,
    entity="bugan",
    id=id,
    resume=True,
    dir=run_path,
)
active_run = wandb.run
print("run id: " + str(active_run.id))
print("run name: " + str(active_run.name))
wandb.watch_called = False
# NOTE(review): set after init(); whether wandb still honours save_code at this point is unverified.
active_run.save_code = True

[34m[1mwandb[0m: Wandb version 0.10.5 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


run id: 11jq3pv3
run name: 11jq3pv3


In [None]:
# Hyperparameters of this run. They are written onto the wandb run config (in this order)
# so they are logged with the run and readable later as config.<name>.
config = wandb.config

_hparams = {
    # data / model size
    "batch_size": 8,
    "array_size": 32,
    "z_size": 128,
    "gen_num_layer_unit": [256, 1024, 512, 128],
    "dis_num_layer_unit": [32, 64, 128, 128],
    "leakyReLU": False,  # leakyReLU implementation still not in modelPL
    "balance_voxel_in_space": False,
    # training schedule, learning rates and layer counts
    "epochs": 3000,
    "vae_lr": 0.0025,
    "vae_encoder_layer": 1,
    "vae_decoder_layer": 2,
    "d_lr": 0.00005,
    "d_layer": 1,
    "vae_recon_loss_factor": 1,
    "seed": 1234,
    # logging intervals and data augmentation
    "log_image_interval": 5,
    "log_mesh_interval": 50,
    "data_augmentation": True,
    "num_augment_data": 4,
    # optimizers
    "vae_opt": "Adam",
    "dis_opt": "Adam",
}
for _name, _value in _hparams.items():
    setattr(config, _name, _value)

[34m[1mwandb[0m: Wandb version 0.10.5 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


# Dataset

In [None]:
### Load our package (bugan), which provides the models and helpers used below.

# Option 1 (disabled): clone the repository, then install it in editable mode
# !git clone https://github.com/buganart/BUGAN repo
# !pip install -e ./repo/
# import site
# site.main()

# Option 2 (active): install directly from GitHub with pip (no ref pinned, so the default branch is used)
!pip install git+https://github.com/buganart/BUGAN.git#egg=bugan
output.clear()


# NOTE(review): this star import presumably supplies names used later in the notebook that are
# not defined in it (e.g. mesh2arrayCentered, load_checkpoint_from_cloud); importing them by
# name would make the dependency explicit.
from bugan.functionsPL import *
from bugan.modelsPL import VAEGAN, VAE, Discriminator, Generator

# Alternative imports for a local, non-installed copy of the modules:
# from functionsPL import *
# from modelsPL import VAEGAN, VAE, Discriminator, Generator

# Tag and group this wandb run as a VAEGAN experiment.
run.tags.append("VAEGAN")
run.group = "VAEGAN"

###     load dataset
# Seed numpy's global RNG from the run config.
np.random.seed(config.seed)
# Previous data module (presumably from bugan.functionsPL), superseded by DataModule_custom below:
# dataModule = DataModule(config, run)
# config.num_data = dataModule.size

class DataModule_custom(pl.LightningDataModule):
    """LightningDataModule for a zip archive of .obj meshes.

    ``prepare_data`` voxelizes every ``.obj`` member of the zip at ``filepath`` with
    ``mesh2arrayCentered`` and caches the stacked arrays as ``<zip title>.npy`` in the same
    directory as the zip. Later calls reuse that cache. ``setup`` loads the cache into
    ``self.dataset``, inserting a channel axis at dim 1. ``train_dataloader`` serves it in
    shuffled batches of ``config.batch_size``.

    Args:
        config: run configuration; ``batch_size`` is read from it.
        run: the wandb run (stored on the instance, not used by this class).
        filepath: path of the dataset zip, or of an already-built .npy cache.
    """

    def __init__(self, config, run, filepath):
        super().__init__()
        self.config = config
        self.run = run
        self.dataset_artifact = None
        self.dataset = None
        # Number of samples; only known after setup() has loaded the cache.
        self.size = 0
        self.filepath = filepath

    @staticmethod
    def _cache_path(filepath):
        """Return the .npy cache path for ``filepath``: same directory, extension replaced by ``.npy``."""
        directory, file_name = os.path.split(filepath)
        title = os.path.splitext(file_name)[0]
        return os.path.join(directory, title + ".npy")

    def prepare_data(self):
        """Build (or reuse) the .npy voxel cache and point ``self.filepath`` at it.

        Safe to call more than once: once ``self.filepath`` is the cache itself, the cache
        path resolves to the same file and is reused. Meshes whose loading or voxelization
        raises ``IndexError`` are skipped and reported.

        Raises:
            ValueError: if none of the archive's .obj meshes could be voxelized.
        """
        # The cache lives next to the zip (the old code dropped the "/" separator and wrote
        # it next to the parent directory instead).
        npy_path = self._cache_path(self.filepath)
        if os.path.isfile(npy_path):
            print(os.path.basename(npy_path) + " file already exists!")
            self.filepath = npy_path
            return

        voxel_arrays = []
        failed = []
        with zipfile.ZipFile(self.filepath, "r") as zf:
            for file_name in zf.namelist():
                if not file_name.endswith(".obj"):
                    continue
                try:
                    # Read the member into memory so trimesh gets a seekable file object.
                    with zf.open(file_name, "r") as member:
                        mesh_file = BytesIO(member.read())
                    mesh = trimesh.load(mesh_file, file_type="obj", force="mesh")
                    # NOTE(review): the voxel grid size is hard-coded to 32; it presumably has
                    # to match config.array_size.
                    voxel_arrays.append(mesh2arrayCentered(mesh, array_length=32))
                except IndexError:
                    # Only IndexError marks a mesh as bad; any other error propagates.
                    failed.append(file_name)
                    print(file_name + " failed")

        if not voxel_arrays:
            raise ValueError("no .obj mesh in " + self.filepath + " could be voxelized")

        dataset_array = np.stack(voxel_arrays, axis=0)
        np.save(npy_path, dataset_array)
        self.filepath = npy_path
        print("processed dataset_array shape: " + str(dataset_array.shape))
        print("number of failed data: " + str(len(failed)))

    def setup(self, stage=None):
        """Load the .npy cache into ``self.dataset`` (channel axis added at dim 1) and record ``self.size``."""
        voxels = np.load(self.filepath)
        self.size = voxels.shape[0]
        self.dataset = torch.unsqueeze(torch.tensor(voxels), 1)

    def train_dataloader(self):
        """Return a shuffling DataLoader over ``self.dataset`` with ``config.batch_size`` samples per batch."""
        return DataLoader(
            TensorDataset(self.dataset), batch_size=self.config.batch_size, shuffle=True
        )
        
# Record which dataset this run uses, then build the data module for the dataset zip.
config.dataset = zipfile_title
dataModule = DataModule_custom(config, run, os.path.join(data_path, zipfile_name))
# `size` is only filled in by setup(). Build/load the dataset now so the run records the real
# sample count instead of 0.
# NOTE(review): this relies on Lightning skipping datamodule hooks that already ran
# (has_prepared_data / has_setup_fit) when trainer.fit() is called; verify for the installed version.
dataModule.prepare_data()
dataModule.setup()
config.num_data = dataModule.size

[34m[1mwandb[0m: Wandb version 0.10.5 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


# Training

In [None]:
# Seed torch's RNG from the run config (numpy was seeded from the same value in the dataset cell).
torch.manual_seed(config.seed)
# NOTE(review): anomaly detection makes every backward pass considerably slower; it looks like a
# debugging aid, so consider turning it off for long runs.
torch.autograd.set_detect_anomaly(True)

# Start a virtual X display, presumably for off-screen mesh rendering while logging.
# It is not stopped in this cell.
vdisplay = Xvfb()
vdisplay.start()

# Log through the already-initialised wandb run.
wandb_logger = WandbLogger(experiment=run, log_model=True)

checkpoint_path = os.path.join(wandb.run.dir, 'checkpoint.ckpt')

# Without `gpus`, Lightning trains on the CPU even when CUDA is available (it logs
# "GPU available: True, used: False"), which contradicts moving the model to `device` below.
trainer_kwargs = dict(
    max_epochs=config.epochs,
    logger=wandb_logger,
    checkpoint_callback=None,
    gpus=1 if torch.cuda.is_available() else None,
)
if resume:
    # Get the checkpoint file from the wandb cloud, then let Lightning restore the full
    # training state from it.
    load_checkpoint_from_cloud(checkpoint_path='checkpoint.ckpt')
    trainer_kwargs["resume_from_checkpoint"] = checkpoint_path
trainer = pl.Trainer(**trainer_kwargs)

# Model
vaegan = VAEGAN(config, trainer, save_model_path=checkpoint_path).to(device)
wandb_logger.watch(vaegan)

trainer.fit(vaegan, dataModule)

GPU available: True, used: False
TPU available: False, using: 0 TPU cores


tree-session-2020-09-14_23-23-Friedrich_2-target-face-num-1000.npy file already exists!


[34m[1mwandb[0m: Wandb version 0.10.5 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade

  | Name          | Type          | Params
------------------------------------------------
0 | vae           | VAE           | 62 M  
1 | discriminator | Discriminator | 1 M   


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

  "See the documentation of nn.Upsample for details.".format(mode))
        TrainResult and EvalResult were deprecated in 0.9.1 and support will drop in 1.0.0.
        Use self.log and .write from the LightningModule to log metrics and write predictions.
        training_step can now only return a scalar (for the loss) or a dictionary with anything you want.

        Option 1:
        return loss

        Option 2:
        return {'loss': loss, 'anything_else': ...}

        Option 3:
        return {'loss': loss, 'hiddens': hiddens, 'anything_else': ...}
            







1