<a href="https://colab.research.google.com/github/buganart/BUGAN/blob/master/script_GAN_ver3_voxelsize1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#mount google drive
from google.colab import output
from google.colab import drive
drive.mount('/content/drive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/drive


In [2]:
#right click shared folder IRCMS_GAN_collaborative_database and "Add shortcut to Drive" to My drive
%cd drive/My Drive/IRCMS_GAN_collaborative_database/Research/Peter/Tree_3D_models_obj/generated_files/
!ls

/content/drive/.shortcut-targets-by-id/1ylB2p6N0qQ-G4OsBuwcZ9C0tsqVu9ww4/IRCMS_GAN_collaborative_database/Research/Peter/Tree_3D_models_obj/generated_files
old_1_voxel_size1.ply  old_2_voxel_size1.ply  old_3_voxel_size1.ply  wandb


In [3]:
!pip install open3d
!pip install wandb -q
output.clear()

#add libraries, and login to wandb

In [5]:
import open3d as o3d
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchsummary import summary
from torch.utils.data import DataLoader, TensorDataset
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

# Ignore excessive warnings
import logging
logging.propagate = False 
logging.getLogger().setLevel(logging.ERROR)

# WandB – Import the wandb library
import wandb

In [6]:
!wandb login
output.clear()

In [7]:
wandb.init(project="tree-gan")
wandb.run.name = str(wandb.run.entity) +" modelv3 AUG g0.05d0.01 "+ str(wandb.run.id) # number N after g and d means lr are 0.0001 * N
wandb.watch_called = False

In [8]:
#keep track of hyperparams
config = wandb.config
config.batch_size = 3
config.test_batch_size = 1
config.epochs = 10000
config.g_lr = 0.000005
config.d_lr = 0.000001           
config.seed = 1234
config.log_interval = 200

#dataset

In [9]:
def voxel2arrayCentered(voxel, tree_size_scale = 1):
    array_size = np.array([64, 64, 64])
    vox_array = np.zeros(array_size, dtype=int)  
    tree_size = np.array(voxel.get_axis_aligned_bounding_box().get_extent())
    tree_size = np.ceil(tree_size / tree_size_scale)    #voxel_size = tree_size_scale
    tree_center = (np.ceil(tree_size / 2)).astype(int)

    vox_list = voxel.get_voxels()
    for vox in vox_list:
        coord = vox.grid_index
        #center the tree
        coord = coord - tree_center + (array_size/2)
        coord = tuple(coord.astype(int))

        vox_array[coord] = 1.

    return vox_array.astype(bool)

In [10]:
#process all files in the generated file folder to generate dataset 
import os

dataset = []
for file_name in os.listdir():
    if file_name.endswith("voxel_size1.ply"):
        #note that the voxel_size of vox is 0.3, so we scale it back to one for indexing
        print(file_name)
        dataset.append(voxel2arrayCentered(o3d.io.read_voxel_grid(file_name), 1))


old_1_voxel_size1.ply
old_2_voxel_size1.ply
old_3_voxel_size1.ply


In [11]:
dataset = torch.tensor(dataset)
print(torch.unsqueeze(dataset, -1).shape)
tensor_dataset = TensorDataset(torch.unsqueeze(dataset, 1))

#augment data

dataloader = DataLoader(tensor_dataset, batch_size=config.batch_size)

torch.Size([3, 64, 64, 64, 1])


#model description

In [12]:
#input: 128-d noise vector
#output: (250,250,250) array with values in [0,1]

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        self.fc_channel = 8 #16
        self.fc_size = 4

        num_unit1 = self.fc_channel   
        num_unit2 = 16   #32
        num_unit3 = 32   #64
        num_unit4 = 64   #128
        num_unit5 = 128   #256
        num_unit6 = 256   #512
        self.gen_fc = nn.Linear(128, num_unit1 * self.fc_size * self.fc_size * self.fc_size)
        self.gen = nn.Sequential(

            nn.ConvTranspose3d(num_unit1, num_unit2, 3, 1, padding = 1),
            nn.BatchNorm3d(num_unit2),
            nn.ReLU(True),
            nn.ConvTranspose3d(num_unit2, num_unit3, 3, 1, padding = 1),
            nn.BatchNorm3d(num_unit3),
            nn.ReLU(True),
            nn.ConvTranspose3d(num_unit3, num_unit4, 3, 1, padding = 1),
            nn.BatchNorm3d(num_unit4),
            nn.ReLU(True),

            nn.Upsample(scale_factor=2, mode='trilinear'),

            nn.ConvTranspose3d(num_unit4, num_unit4, 3, 1, padding = 1),
            nn.BatchNorm3d(num_unit4),
            nn.ReLU(True),
            nn.ConvTranspose3d(num_unit4, num_unit5, 3, 1, padding = 1),
            nn.BatchNorm3d(num_unit5),
            nn.ReLU(True),
            nn.ConvTranspose3d(num_unit5, num_unit6, 3, 1, padding = 1),
            nn.BatchNorm3d(num_unit6),
            nn.ReLU(True),

            nn.Upsample(scale_factor=2, mode='trilinear'),

            nn.ConvTranspose3d(num_unit6, num_unit6, 3, 1, padding = 1),
            nn.BatchNorm3d(num_unit6),
            nn.ReLU(True),
            nn.ConvTranspose3d(num_unit6, num_unit5, 3, 1, padding = 1),
            nn.BatchNorm3d(num_unit5),
            nn.ReLU(True),

            nn.Upsample(scale_factor=2, mode='trilinear'),


            nn.ConvTranspose3d(num_unit5, num_unit4, 3, 1, padding = 1),
            nn.BatchNorm3d(num_unit4),
            nn.ReLU(True),
            nn.ConvTranspose3d(num_unit4, num_unit3, 3, 1, padding = 1),
            nn.BatchNorm3d(num_unit3),
            nn.ReLU(True),

            nn.Upsample(scale_factor=2, mode='trilinear'),

            
            nn.ConvTranspose3d(num_unit3, num_unit2, 3, 1, padding = 1),
            nn.BatchNorm3d(num_unit2),
            nn.ReLU(True),
            nn.ConvTranspose3d(num_unit2, num_unit1, 3, 1, padding = 1),
            nn.BatchNorm3d(num_unit1),
            nn.ReLU(True),
            nn.ConvTranspose3d(num_unit1, 1, 3, 1, padding = 1),
            nn.Sigmoid()
        )

    def forward(self, x):

        x = self.gen_fc(x)
        x = x.view(x.shape[0], self.fc_channel, self.fc_size, self.fc_size, self.fc_size)
        x = self.gen(x)
        return x


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        num_unit1 = 8   #16
        num_unit2 = 16   #32
        num_unit3 = 32   #64
        num_unit4 = 64  #128
        num_unit5 = 128   #256
        num_unit6 = 256   #512
        
        self.dis = nn.Sequential(
            nn.Conv3d(1, num_unit1, 3, 1, padding = 1),
            nn.BatchNorm3d(num_unit1),
            nn.ReLU(True),
            nn.Conv3d(num_unit1, num_unit2, 3, 1, padding = 1),
            nn.BatchNorm3d(num_unit2),
            nn.ReLU(True),
            nn.Conv3d(num_unit2, num_unit3, 3, 1, padding = 1),
            nn.BatchNorm3d(num_unit3),
            nn.ReLU(True),

            nn.MaxPool3d((2, 2, 2)),

            nn.Conv3d(num_unit3, num_unit3, 3, 1, padding = 1),
            nn.BatchNorm3d(num_unit3),
            nn.ReLU(True),
            nn.Conv3d(num_unit3, num_unit4, 3, 1, padding = 1),
            nn.BatchNorm3d(num_unit4),
            nn.ReLU(True),

            nn.MaxPool3d((2, 2, 2)),

            nn.Conv3d(num_unit4, num_unit5, 3, 1, padding = 1),
            nn.BatchNorm3d(num_unit5),
            nn.ReLU(True),
            nn.Conv3d(num_unit5, num_unit5, 3, 1, padding = 1),
            nn.BatchNorm3d(num_unit5),
            nn.ReLU(True),
            nn.Conv3d(num_unit5, num_unit4, 3, 1, padding = 1),
            nn.BatchNorm3d(num_unit4),
            nn.ReLU(True),

            nn.MaxPool3d((2, 2, 2)),

            nn.Conv3d(num_unit4, num_unit3, 3, 1, padding = 1),
            nn.BatchNorm3d(num_unit3),
            nn.ReLU(True),            
            nn.Conv3d(num_unit3, num_unit2, 3, 1, padding = 1),
            nn.BatchNorm3d(num_unit2),
            nn.ReLU(True),

            nn.MaxPool3d((2, 2, 2)),


            nn.Conv3d(num_unit2, num_unit1, 3, 1, padding = 1),
            nn.BatchNorm3d(num_unit1),
            nn.ReLU(True),
            nn.Conv3d(num_unit1, num_unit1, 3, 1, padding = 1),
            nn.BatchNorm3d(num_unit1),
            nn.ReLU(True),
        )

        self.dis_fc = nn.Sequential(
            nn.Linear(num_unit1 * 4 * 4 * 4, 128),
            nn.LeakyReLU(0.1, True),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

    def forward(self, x):

        x = self.dis(x)
        x = x.view(x.shape[0], -1)
        x = self.dis_fc(x)
        return x

In [13]:
G = Generator().to(device)
summary(G, (128,))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                  [-1, 512]          66,048
   ConvTranspose3d-2          [-1, 16, 4, 4, 4]           3,472
       BatchNorm3d-3          [-1, 16, 4, 4, 4]              32
              ReLU-4          [-1, 16, 4, 4, 4]               0
   ConvTranspose3d-5          [-1, 32, 4, 4, 4]          13,856
       BatchNorm3d-6          [-1, 32, 4, 4, 4]              64
              ReLU-7          [-1, 32, 4, 4, 4]               0
   ConvTranspose3d-8          [-1, 64, 4, 4, 4]          55,360
       BatchNorm3d-9          [-1, 64, 4, 4, 4]             128
             ReLU-10          [-1, 64, 4, 4, 4]               0
         Upsample-11          [-1, 64, 8, 8, 8]               0
  ConvTranspose3d-12          [-1, 64, 8, 8, 8]         110,656
      BatchNorm3d-13          [-1, 64, 8, 8, 8]             128
             ReLU-14          [-1, 64, 

  "See the documentation of nn.Upsample for details.".format(mode))


In [14]:
D = Discriminator().to(device)
summary(D, (1, 64, 64, 64))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1        [-1, 8, 64, 64, 64]             224
       BatchNorm3d-2        [-1, 8, 64, 64, 64]              16
              ReLU-3        [-1, 8, 64, 64, 64]               0
            Conv3d-4       [-1, 16, 64, 64, 64]           3,472
       BatchNorm3d-5       [-1, 16, 64, 64, 64]              32
              ReLU-6       [-1, 16, 64, 64, 64]               0
            Conv3d-7       [-1, 32, 64, 64, 64]          13,856
       BatchNorm3d-8       [-1, 32, 64, 64, 64]              64
              ReLU-9       [-1, 32, 64, 64, 64]               0
        MaxPool3d-10       [-1, 32, 32, 32, 32]               0
           Conv3d-11       [-1, 32, 32, 32, 32]          27,680
      BatchNorm3d-12       [-1, 32, 32, 32, 32]              64
             ReLU-13       [-1, 32, 32, 32, 32]               0
           Conv3d-14       [-1, 64, 32,

#functions for pytorch network

In [15]:
def array2voxel(array):
    coord_list = []
    if len(array.shape) == 5:
        array = array[0][0]
    x,y,z = array.shape
    for i in range(x):
        for j in range(y):
            for k in range(z):
                if array[i,j,k] > 0.5:
                    coord_list.append([i,j,k])
    print(len(coord_list))
    if len(coord_list) == 0:
        return np.array([[0,0,0]])  #return at least one point to prevent wandb 3dobject error
    return np.array(coord_list)

def train_model(generator, discriminator, dataloader):   
    generator.to(device)
    discriminator.to(device)
    
    #loss function
    criterion = nn.BCELoss(reduction='sum')
    #optimizer
    # dis_optimizer = optim.SGD(discriminator.parameters(), lr=config.d_lr)
    # gen_optimizer = optim.SGD(generator.parameters(), lr=config.g_lr)   
    dis_optimizer = optim.Adam(discriminator.parameters(), lr=config.d_lr)
    gen_optimizer = optim.Adam(generator.parameters(), lr=config.g_lr)

    #log models
    wandb.watch(generator, log="all")
    wandb.watch(discriminator, log="all")

    d_losses = []
    g_losses = []
    for epoch in range(config.epochs):
        generator.train()
        discriminator.train()

        d_ep_loss = 0.
        g_ep_loss = 0.
        for dataset_batch in dataloader:

            dataset_batch = dataset_batch[0]
            #add data augmentation (rotate 90/180/270)
            #this modify the whole batch at the same time
            rand_num = torch.randn(2)
            if rand_num[0] > 0.5:
                dataset_batch = dataset_batch.permute(0, 1, 3,2,4)
            if rand_num[1] > 0.5: 
                dataset_batch = dataset_batch.flip([4])
                
            dataset_batch = dataset_batch.float().to(device)
            dloss, gloss = compute_loss(generator, discriminator, dataset_batch)


            #optimize generator
            gen_optimizer.zero_grad()
            gloss.backward()
            gen_optimizer.step()

            #optimize discriminator
            dis_optimizer.zero_grad()
            dloss.backward()
            dis_optimizer.step()
          

            #record loss
            d_ep_loss += dloss.detach()  
            g_ep_loss += gloss.detach()

        #after each epoch, record total loss and sample generated obj
        d_losses.append(d_ep_loss)
        g_losses.append(g_ep_loss)
        print("discriminator, epoch"+str(epoch)+" : "+str(d_ep_loss))
        print("generator, epoch"+str(epoch)+" : "+str(g_ep_loss))

        #save model if necessary
        if epoch % config.log_interval == 0:
            sample_tree = array2voxel(generate_tree(generator))

            wandb.log({
            "discriminator loss": d_ep_loss,
            "generator loss": g_ep_loss,
            "sample_tree": wandb.Object3D(sample_tree)})
            torch.save(generator.state_dict(), os.path.join(wandb.run.dir, 'generator.pth'))
            torch.save(discriminator.state_dict(), os.path.join(wandb.run.dir, 'discriminator.pth'))
        else:
            wandb.log({
            "discriminator loss": d_ep_loss,
            "generator loss": g_ep_loss})
    
    #training end, save model again
    sample_tree = array2voxel(generate_tree(generator))

    wandb.log({
    "discriminator loss": d_ep_loss,
    "generator loss": g_ep_loss,
    "sample_tree": wandb.Object3D(sample_tree)})
    torch.save(generator.state_dict(), os.path.join(wandb.run.dir, 'generator.pth'))
    torch.save(discriminator.state_dict(), os.path.join(wandb.run.dir, 'discriminator.pth'))
    
    print(d_losses)
    print(g_losses)

# this function calculate loss of the model, 
def compute_loss(generator, discriminator, dataset_batch):
    
    #loss function
    criterion = nn.BCELoss(reduction='sum')   

    batch_size = dataset_batch.shape[0]
        
    #labels
    real_label = torch.unsqueeze(torch.ones(batch_size),1).float().to(device)
    fake_label = torch.unsqueeze(torch.zeros(batch_size),1).float().to(device)

    ############
    #   discriminator
    ############
    #generate fake trees
    z = torch.randn(batch_size, 128).float().to(device) #128-d noise vector
    tree_fake = generator(z)

    #real data (data from dataloader)
    dout_real = discriminator(dataset_batch)
    dloss_real = criterion(dout_real, real_label)
    score_real = dout_real
    #fake data (data from generator)            
    dout_fake = discriminator(tree_fake.clone().detach())   #detach so no double update on the same batch of tree_fake
    dloss_fake = criterion(dout_fake, fake_label)
    score_fake = dout_fake

    #loss function (discriminator classify real data vs generated data)
    dloss = dloss_real + dloss_fake

    ############
    #   generator
    ############

    #tree_fake is already computed above
    dout_fake = discriminator(tree_fake)
    #generator should generate trees that discriminator think they are real
    gloss = criterion(dout_fake, real_label)    

    return dloss, gloss


def save_model(generator, discriminator, g_path = os.path.join(wandb.run.dir, 'generator.pth') , d_path = os.path.join(wandb.run.dir, 'discriminator.pth')):
    torch.save(generator.state_dict(), g_path)
    torch.save(discriminator.state_dict(), d_path)

def load_model(g_path = os.path.join(wandb.run.dir, 'generator.pth'), d_path = os.path.join(wandb.run.dir, 'discriminator.pth')):
    generator = Generator()
    generator.load_state_dict(torch.load(g_path))
    generator.eval()

    discriminator = Discriminator()
    discriminator.load_state_dict(torch.load(d_path))
    discriminator.eval()
    return generator, discriminator

def generate_tree(generator, num_trees = 1):
    
    #generate noise vector
    z = torch.randn(num_trees, 128).to(device)
    generator.to(device).eval()
    tree_fake = generator(z)
    return tree_fake.detach().cpu().numpy()

#train

In [None]:
#set seed
torch.manual_seed(config.seed)
torch.autograd.set_detect_anomaly(True)

G = Generator()
D = Discriminator()
train_model(G, D, dataloader)        #if dataloader has only 1 tree, the training time is 72s per epoch.

  "See the documentation of nn.Upsample for details.".format(mode))


discriminator, epoch0 : tensor(3.9215, device='cuda:0')
generator, epoch0 : tensor(2.3877, device='cuda:0')
0


  elif 'type' in data_or_path:


discriminator, epoch1 : tensor(4.0577, device='cuda:0')
generator, epoch1 : tensor(2.2979, device='cuda:0')
discriminator, epoch2 : tensor(3.9657, device='cuda:0')
generator, epoch2 : tensor(2.3923, device='cuda:0')
discriminator, epoch3 : tensor(4.1879, device='cuda:0')
generator, epoch3 : tensor(2.2806, device='cuda:0')
discriminator, epoch4 : tensor(3.8914, device='cuda:0')
generator, epoch4 : tensor(2.3577, device='cuda:0')
discriminator, epoch5 : tensor(4.0257, device='cuda:0')
generator, epoch5 : tensor(2.4557, device='cuda:0')
discriminator, epoch6 : tensor(3.8058, device='cuda:0')
generator, epoch6 : tensor(2.4418, device='cuda:0')
discriminator, epoch7 : tensor(3.8089, device='cuda:0')
generator, epoch7 : tensor(2.4209, device='cuda:0')
discriminator, epoch8 : tensor(3.8015, device='cuda:0')
generator, epoch8 : tensor(2.4144, device='cuda:0')
discriminator, epoch9 : tensor(3.9274, device='cuda:0')
generator, epoch9 : tensor(2.3534, device='cuda:0')
discriminator, epoch10 : ten