<a href="https://colab.research.google.com/github/buganart/BUGAN/blob/master/script_GAN_ver2_voxel.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#mount google drive
from google.colab import output
from google.colab import drive
drive.mount('/content/drive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/drive


In [2]:
#right click shared folder IRCMS_GAN_collaborative_database and "Add shortcut to Drive" to My drive
%cd drive/My Drive/IRCMS_GAN_collaborative_database/Research/Peter/Tree_3D_models_obj/generated_files/
!ls

/content/drive/.shortcut-targets-by-id/1ylB2p6N0qQ-G4OsBuwcZ9C0tsqVu9ww4/IRCMS_GAN_collaborative_database/Research/Peter/Tree_3D_models_obj/generated_files
maple_1_voxel_size03.ply  old_1_voxel_size03.ply  wandb
maple_2_voxel_size03.ply  old_2_voxel_size03.ply
maple_3_voxel_size03.ply  old_3_voxel_size03.ply


In [3]:
!pip install open3d
!pip install wandb -q
output.clear()

#add libraries, and login to wandb

In [4]:
import open3d as o3d
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchsummary import summary
from torch.utils.data import DataLoader, TensorDataset
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

# Ignore excessive warnings
import logging
logging.propagate = False 
logging.getLogger().setLevel(logging.ERROR)

# WandB – Import the wandb library
import wandb

In [5]:
!wandb login
output.clear()

In [6]:
wandb.init(project="tree-gan")
wandb.run.name = str(wandb.run.entity) +" modelv2 g1d1 "+ str(wandb.run.id) # number N after g and d means lr are 0.0001 * N
wandb.watch_called = False

In [7]:
#keep track of hyperparams
config = wandb.config
config.batch_size = 1
config.test_batch_size = 1
config.epochs = 100
config.g_lr = 0.0001
config.d_lr = 0.0001            
config.seed = 1234
config.log_interval = 10

#dataset

In [8]:
def voxel2arrayCentered(voxel, tree_size_scale = 1):
    array_size = np.array([250, 250, 250])
    vox_array = np.zeros(array_size, dtype=int)  
    tree_size = np.array(voxel.get_axis_aligned_bounding_box().get_extent())
    tree_size = np.ceil(tree_size / tree_size_scale)    #voxel_size = tree_size_scale
    tree_center = (np.ceil(tree_size / 2)).astype(int)

    vox_list = voxel.get_voxels()
    for vox in vox_list:
        coord = vox.grid_index
        #center the tree
        coord = coord - tree_center + (array_size/2)
        coord = tuple(coord.astype(int))

        vox_array[coord] = 1.

    return vox_array.astype(bool)

In [9]:
#process all files in the generated file folder to generate dataset 
import os

dataset = []
for file_name in os.listdir():
    if file_name.endswith("voxel_size03.ply"):
        #note that the voxel_size of vox is 0.3, so we scale it back to one for indexing
        print(file_name)
        dataset.append(voxel2arrayCentered(o3d.io.read_voxel_grid(file_name), 0.3))


maple_2_voxel_size03.ply
maple_3_voxel_size03.ply
maple_1_voxel_size03.ply
old_1_voxel_size03.ply
old_2_voxel_size03.ply
old_3_voxel_size03.ply


In [10]:
dataset = torch.tensor(dataset)
print(torch.unsqueeze(dataset, -1).shape)
tensor_dataset = TensorDataset(torch.unsqueeze(dataset, 1))
dataloader = DataLoader(tensor_dataset, batch_size=config.batch_size)

torch.Size([6, 250, 250, 250, 1])


#model description

In [11]:
#input: 128-d noise vector
#output: (250,250,250) array with values in [0,1]

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        self.fc_channel = 8 #16
        self.fc_size = 5

        num_unit1 = self.fc_channel   
        num_unit2 = 8   #32
        num_unit3 = 8   #64
        num_unit4 = 8   #128
        num_unit5 = 8   #256
        num_unit6 = 8   #512
        self.gen_fc = nn.Linear(128, num_unit1 * self.fc_size * self.fc_size * self.fc_size)
        self.gen = nn.Sequential(

            nn.ConvTranspose3d(num_unit1, num_unit2, 3, 1),
            nn.BatchNorm3d(num_unit2),
            nn.ReLU(True),
            nn.ConvTranspose3d(num_unit2, num_unit3, 3, 1),
            nn.BatchNorm3d(num_unit3),
            nn.ReLU(True),
            nn.ConvTranspose3d(num_unit3, num_unit4, 3, 1),
            nn.BatchNorm3d(num_unit4),
            nn.ReLU(True),

            nn.Upsample(scale_factor=2, mode='trilinear'),

            nn.ConvTranspose3d(num_unit4, num_unit5, 3, 1),
            nn.BatchNorm3d(num_unit5),
            nn.ReLU(True),
            nn.ConvTranspose3d(num_unit5, num_unit5, 3, 1),
            nn.BatchNorm3d(num_unit5),
            nn.ReLU(True),
            nn.ConvTranspose3d(num_unit5, num_unit6, 3, 1),
            nn.BatchNorm3d(num_unit6),
            nn.ReLU(True),

            nn.Upsample(scale_factor=2, mode='trilinear'),


            nn.ConvTranspose3d(num_unit6, num_unit6, 3, 1),
            nn.BatchNorm3d(num_unit6),
            nn.ReLU(True),
            nn.ConvTranspose3d(num_unit6, num_unit6, 3, 1),
            nn.BatchNorm3d(num_unit6),
            nn.ReLU(True),

            nn.Upsample(scale_factor=2, mode='trilinear'),

            nn.ConvTranspose3d(num_unit6, num_unit5, 3, 1),
            nn.BatchNorm3d(num_unit5),
            nn.ReLU(True),

            nn.Upsample(scale_factor=2, mode='trilinear'),

            nn.ConvTranspose3d(num_unit5, num_unit3, 3, 1),
            nn.BatchNorm3d(num_unit3),
            nn.ReLU(True),
            nn.ConvTranspose3d(num_unit3, num_unit1, 3, 1),
            nn.BatchNorm3d(num_unit1),
            nn.ReLU(True),
            nn.ConvTranspose3d(num_unit1, 1, 3, 1),
            nn.Sigmoid()
        )

    def forward(self, x):

        x = self.gen_fc(x)
        x = x.view(x.shape[0], self.fc_channel, self.fc_size, self.fc_size, self.fc_size)
        x = self.gen(x)
        return x


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        num_unit1 = 8   #16
        num_unit2 = 8   #32
        num_unit3 = 8   #64
        num_unit4 = 8   #128
        num_unit5 = 8   #256
        num_unit6 = 8   #512
        
        self.dis = nn.Sequential(
            nn.Conv3d(1, num_unit1, 3, 1),
            nn.ReLU(True),
            nn.Conv3d(num_unit1, num_unit3, 3, 1),
            nn.ReLU(True),

            nn.MaxPool3d((2, 2, 2)),

            nn.Conv3d(num_unit3, num_unit5, 3, 1),
            nn.ReLU(True),
            nn.Conv3d(num_unit5, num_unit6, 3, 1),
            nn.ReLU(True),

            nn.MaxPool3d((2, 2, 2)),

            nn.Conv3d(num_unit6, num_unit5, 3, 1),
            nn.ReLU(True),
            nn.Conv3d(num_unit5, num_unit4, 3, 1),
            nn.ReLU(True),

            nn.MaxPool3d((2, 2, 2)),

            nn.Conv3d(num_unit4, num_unit3, 3, 1),
            nn.ReLU(True),
            nn.Conv3d(num_unit3, num_unit3, 3, 1),
            nn.ReLU(True),

            nn.MaxPool3d((2, 2, 2)),

            nn.Conv3d(num_unit3, num_unit2, 3, 1),
            nn.ReLU(True),
            nn.Conv3d(num_unit2, num_unit2, 3, 1),
            nn.ReLU(True),
            nn.Conv3d(num_unit2, num_unit1, 3, 1),
            nn.ReLU(True),
            nn.Conv3d(num_unit1, num_unit1, 3, 1),
            nn.ReLU(True),
        )

        self.dis_fc = nn.Sequential(
            nn.Linear(num_unit1 * 3 * 3 * 3, 128),
            nn.LeakyReLU(0.1, True),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

    def forward(self, x):

        x = self.dis(x)
        x = x.view(x.shape[0], -1)
        x = self.dis_fc(x)
        return x

In [12]:
# G = Generator().to(device)
# summary(G, (128,))

In [13]:
# D = Discriminator().to(device)
# summary(D, (1, 250, 250, 250))

#functions for pytorch network

In [14]:
def array2voxel(array):
    coord_list = []
    if len(array.shape) == 5:
        array = array[0][0]
    x,y,z = array.shape
    for i in range(x):
        for j in range(y):
            for k in range(z):
                if array[i,j,k] > 0.5:
                    coord_list.append([i,j,k])
    print(len(coord_list))
    if len(coord_list) == 0:
        return np.array([[0,0,0]])  #return at least one point to prevent wandb 3dobject error
    return np.array(coord_list)

def train_model(generator, discriminator, dataloader):   
    generator.to(device)
    discriminator.to(device)
    
    #loss function
    criterion = nn.BCELoss(reduction='sum')
    #optimizer
    dis_optimizer = optim.Adam(discriminator.parameters(), lr=config.d_lr)
    gen_optimizer = optim.Adam(generator.parameters(), lr=config.g_lr)   

    #log models
    wandb.watch(generator, log="all")
    wandb.watch(discriminator, log="all")

    d_losses = []
    g_losses = []
    for epoch in range(config.epochs):
        generator.train()
        discriminator.train()

        d_ep_loss = 0.
        g_ep_loss = 0.
        for dataset_batch in dataloader:
            dis_optimizer.zero_grad()
            gen_optimizer.zero_grad()
            
            dloss, gloss = compute_loss(generator, discriminator, dataset_batch)


            #optimize generator
            gloss.backward()
            gen_optimizer.step()

            #optimize discriminator
            dloss.backward()
            dis_optimizer.step()
          

            #record loss
            d_ep_loss += dloss.detach()  
            g_ep_loss += gloss.detach()

        #after each epoch, record total loss and sample generated obj
        d_losses.append(d_ep_loss)
        g_losses.append(g_ep_loss)
        print("discriminator, epoch"+str(epoch)+" : "+str(d_ep_loss))
        print("generator, epoch"+str(epoch)+" : "+str(g_ep_loss))

        #save model if necessary
        if epoch % config.log_interval == 0:
            sample_tree = array2voxel(generate_tree(generator))

            wandb.log({
            "discriminator loss": d_ep_loss,
            "generator loss": g_ep_loss,
            "sample_tree": wandb.Object3D(sample_tree)})
            torch.save(generator.state_dict(), os.path.join(wandb.run.dir, 'generator.pth'))
            torch.save(discriminator.state_dict(), os.path.join(wandb.run.dir, 'discriminator.pth'))
        else:
            wandb.log({
            "discriminator loss": d_ep_loss,
            "generator loss": g_ep_loss})
    
    #training end, save model again
    torch.save(generator.state_dict(), os.path.join(wandb.run.dir, 'generator.pth'))
    torch.save(discriminator.state_dict(), os.path.join(wandb.run.dir, 'discriminator.pth'))
    print(d_losses)
    print(g_losses)

# this function calculate loss of the model, 
def compute_loss(generator, discriminator, dataset_batch):
    
    #loss function
    criterion = nn.BCELoss(reduction='sum')   


    dataset_batch = dataset_batch[0]
    dataset_batch = dataset_batch.float().to(device)
    batch_size = dataset_batch.shape[0]
        
    #labels
    real_label = torch.unsqueeze(torch.ones(batch_size),1).float().to(device)
    fake_label = torch.unsqueeze(torch.zeros(batch_size),1).float().to(device)

    ############
    #   discriminator
    ############
    #generate fake trees
    z = torch.randn(batch_size, 128).float().to(device) #128-d noise vector
    tree_fake = generator(z)

    #real data (data from dataloader)
    dout_real = discriminator(dataset_batch)
    dloss_real = criterion(dout_real, real_label)
    score_real = dout_real
    #fake data (data from generator)            
    dout_fake = discriminator(tree_fake.clone().detach())   #detach so no double update on the same batch of tree_fake
    dloss_fake = criterion(dout_fake, fake_label)
    score_fake = dout_fake

    #loss function (discriminator classify real data vs generated data)
    dloss = dloss_real + dloss_fake

    ############
    #   generator
    ############

    #tree_fake is already computed above
    dout_fake = discriminator(tree_fake)
    #generator should generate trees that discriminator think they are real
    gloss = criterion(dout_fake, real_label)    

    return dloss, gloss


def save_model(generator, discriminator, g_path = os.path.join(wandb.run.dir, 'generator.pth') , d_path = os.path.join(wandb.run.dir, 'discriminator.pth')):
    torch.save(generator.state_dict(), g_path)
    torch.save(discriminator.state_dict(), d_path)

def load_model(g_path = os.path.join(wandb.run.dir, 'generator.pth'), d_path = os.path.join(wandb.run.dir, 'discriminator.pth')):
    generator = Generator()
    generator.load_state_dict(torch.load(g_path))
    generator.eval()

    discriminator = Discriminator()
    discriminator.load_state_dict(torch.load(d_path))
    discriminator.eval()
    return generator, discriminator

def generate_tree(generator, num_trees = 1):
    
    #generate noise vector
    z = torch.randn(num_trees, 128).to(device)
    generator.to(device).eval()
    tree_fake = generator(z)
    return tree_fake.detach().cpu().numpy()

#train

In [None]:
#set seed
torch.manual_seed(config.seed)
torch.autograd.set_detect_anomaly(True)

G = Generator()
D = Discriminator()
train_model(G, D, dataloader)        #if dataloader has only 1 tree, the training time is 72s per epoch.

  "See the documentation of nn.Upsample for details.".format(mode))


discriminator, epoch0 : tensor(8.3226, device='cuda:0')
generator, epoch0 : tensor(3.9907, device='cuda:0')
15625000


  elif 'type' in data_or_path:


discriminator, epoch1 : tensor(8.3237, device='cuda:0')
generator, epoch1 : tensor(3.9730, device='cuda:0')
discriminator, epoch2 : tensor(8.3251, device='cuda:0')
generator, epoch2 : tensor(3.9531, device='cuda:0')
