<a href="https://colab.research.google.com/github/buganart/BUGAN/blob/master/script_GAN_ver3_voxelsize1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#mount google drive
from google.colab import output
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
#right click shared folder IRCMS_GAN_collaborative_database and "Add shortcut to Drive" to My drive
%cd drive/My Drive/IRCMS_GAN_collaborative_database/

#record paths to resources
data_path = "Research/Peter/Tree_3D_models_obj/obj_files/"
run_path = "Experiments/colab-treegan/"

# !ls

/content/drive/.shortcut-targets-by-id/1ylB2p6N0qQ-G4OsBuwcZ9C0tsqVu9ww4/IRCMS_GAN_collaborative_database


In [3]:
!pip install trimesh
!pip install wandb -q
output.clear()

#add libraries, and login to wandb

In [4]:
import io
import os
import trimesh
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchsummary import summary
from torch.utils.data import DataLoader, TensorDataset
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

# Ignore excessive warnings
import logging
logging.propagate = False 
logging.getLogger().setLevel(logging.ERROR)

# WandB – Import the wandb library
import wandb

In [5]:
!wandb login
output.clear()

In [6]:
#id None to start a new run. For resuming run, put the id of the run below
id = "11d67bs4"
resume = False
if id is None:
    id = wandb.util.generate_id()
else:
    resume = True

run = wandb.init(project="tree-gan", id=id, resume="allow", dir=run_path)
print("run id:" + str(wandb.run.id))
wandb.run.name = str(wandb.run.id)
wandb.watch_called = False
wandb.run.save_code = True

wandb.run.group = "GANv3.2"

run id:11d67bs4


Streaming file created twice in same run: Experiments/colab-treegan/wandb/run-20200814_080813-11d67bs4/wandb-history.jsonl
Streaming file created twice in same run: Experiments/colab-treegan/wandb/run-20200814_080813-11d67bs4/wandb-events.jsonl


In [7]:
#keep track of hyperparams
config = wandb.config
config.batch_size = 16
config.epochs = 10
config.g_lr = 0.0001
config.g_layer = 2
config.d_lr = 0.00003           
config.d_layer = 1
config.seed = 1234
config.log_interval = 20
config.data_augmentation = True
config.num_augment_data = 4

#dataset

In [8]:
def mesh2arrayCentered(mesh, voxel_size = 1, array_length = 64):
    #given array length 64, voxel size 2, then output array size is [128,128,128]
    array_size = np.ceil(np.array([array_length, array_length, array_length]) / voxel_size).astype(int)
    vox_array = np.zeros(array_size, dtype=bool)    #tanh: voxel representation [-1,1], sigmoid: [0,1]
    #scale mesh extent to fit array_length
    max_length = np.max(np.array(mesh.extents))
    mesh = mesh.apply_transform(trimesh.transformations.scale_matrix((array_length-1)/max_length))  #now the extent is [array_length**3]
    v = mesh.voxelized(voxel_size)  #max voxel array length = array_length / voxel_size

    #find indices in the v.matrix to center it in vox_array
    indices = ((array_size - v.matrix.shape)/2).astype(int)
    vox_array[indices[0]:indices[0]+v.matrix.shape[0], indices[1]:indices[1]+v.matrix.shape[1], indices[2]:indices[2]+v.matrix.shape[2]] = v.matrix

    return vox_array


def data_augmentation(mesh, array_length = 64, num_augment_data = 4, scale_max_margin = 3):

    retval = np.zeros((num_augment_data, array_length, array_length, array_length))

    for i in range(num_augment_data):

        #first select rotation angle (angle in radian)
        angle = 2 * np.pi * (np.random.rand(1)[0])

        #scale is implemented based on the bounding box with box margin (larger margin, smaller scale)
        box_margin = np.random.randint(scale_max_margin + 1)

        #pick a random starting point within margin as translation
        initial_position = np.random.randint(box_margin + 1, size=3)

        result_array = modify_mesh(mesh, array_length, angle, box_margin, initial_position)
        retval[i] = result_array

    return retval


def modify_mesh(mesh, out_array_length, rot_angle, scale_box_margin, array_init_pos):
    #first copy mesh
    mesh = mesh.copy()
    #rotate mesh by rot_angle in radian
    mesh = mesh.apply_transform(trimesh.transformations.rotation_matrix(rot_angle, (0,1,0)))

    #scale is implemented based on the bounding box with box margin (larger margin, smaller scale)
        #example (assume out_array_length=64): margin = 0, bounding box shape = (64,64,64); margin = 3, bounding box shape = (61,61,61)
    scaled_size = out_array_length - scale_box_margin
    mesh_array = mesh2arrayCentered(mesh, array_length = scaled_size)

    #put them into bounding box (and translation)
    retval = np.zeros((out_array_length, out_array_length, out_array_length))
    #apply translation by selecting initial position
        #example: same mesh array of size (61,61,61) but with two position (0,1,0) and (1,0,0) is just a translation of 2 units
    x,y,z = array_init_pos
    retval[x:x+scaled_size, y:y+scaled_size, z:z+scaled_size] = mesh_array

    return retval

In [9]:
dataset_artifact = run.use_artifact("dataset-tree:full", type='dataset')
dir_dict = dataset_artifact.metadata['dir_dict']
artifact_dir = dataset_artifact.download()
print(dir_dict)

[34m[1mwandb[0m: Downloading large artifact dataset-tree:full, 566.02MB. 216 files... Done. 19.4s


{'old': ['old_1.obj', 'old_2.obj', 'old_3.obj'], 'raft': ['raft_1_1.obj', 'raft_1_2.obj', 'raft_1_4.obj', 'raft_1_3.obj', 'raft_1_5.obj', 'raft_1_6.obj', 'raft_1_7.obj', 'raft_1_8.obj', 'raft_1_9.obj', 'raft_1_10.obj'], 'group': ['group_1_1.obj', 'group_1_2.obj', 'group_1_3.obj', 'group_1_4.obj', 'group_1_5.obj', 'group_1_6.obj', 'group_1_7.obj', 'group_1_8.obj', 'group_1_9.obj', 'group_1_10.obj'], 'leaning': ['leaning_1_1.obj', 'leaning_1_2.obj', 'leaning_1_3.obj', 'leaning_1_4.obj', 'leaning_1_5.obj', 'leaning_1_6.obj', 'leaning_1_7.obj', 'leaning_1_8.obj', 'leaning_1_9.obj', 'leaning_1_10.obj', 'leaning_2_1.obj', 'leaning_2_2.obj', 'leaning_2_3.obj', 'leaning_2_4.obj', 'leaning_2_5.obj', 'leaning_2_6.obj', 'leaning_2_8.obj', 'leaning_2_10.obj', 'leaning_2_7.obj', 'leaning_2_9.obj'], 'windswept': ['windswept_1_1.obj', 'windswept_1_2.obj', 'windswept_1_3.obj', 'windswept_1_4.obj', 'windswept_1_5.obj', 'windswept_1_6.obj', 'windswept_1_7.obj', 'windswept_1_8.obj', 'windswept_1_9.obj', 

In [10]:
dataset = []

for data_cat in dir_dict:
    filename_list = dir_dict[data_cat]
    for filename in filename_list:
        filename = artifact_dir + "/" + data_cat + "/" + filename
        m = trimesh.load(filename, force='mesh')
        #augment data
        if config.data_augmentation:
            array = data_augmentation(m, num_augment_data = config.num_augment_data)
        else:
            array = mesh2arrayCentered(mesh)[np.newaxis, :, :, :]
        dataset.append(array)
#now all the returned array contains multiple samples
dataset = np.concatenate(dataset)

In [11]:
dataset = torch.tensor(dataset)
print(torch.unsqueeze(dataset, -1).shape)
tensor_dataset = TensorDataset(torch.unsqueeze(dataset, 1))

dataloader = DataLoader(tensor_dataset, batch_size=config.batch_size, shuffle=True)

torch.Size([864, 64, 64, 64, 1])


#model description

In [12]:
#input: 128-d noise vector
#output: (250,250,250) array with values in [0,1]

class Generator(nn.Module):
    def __init__(self, layer_per_block=1):
        super(Generator, self).__init__()


        #layer_per_block must be >= 1
        if layer_per_block < 1:
            layer_per_block = 1

        self.fc_channel = 8 #16
        self.fc_size = 4

        num_unit1 = self.fc_channel   
        num_unit2 = 16   #32
        num_unit3 = 32   #64
        num_unit4 = 64   #128
        num_unit5 = 32   #256
        num_unit6 = 16   #512

        num_layer_unit_list = [num_unit1, num_unit2, num_unit3, num_unit4, num_unit5, num_unit6]
        gen_module = []
        #5 blocks (need 4 pool to reduce size)
        for i in range(5):
            num_layer_unit1, num_layer_unit2 = num_layer_unit_list[i], num_layer_unit_list[i+1]

            gen_module.append(nn.ConvTranspose3d(num_layer_unit1, num_layer_unit2, 3, 1, padding = 1))
            gen_module.append(nn.BatchNorm3d(num_layer_unit2))
            gen_module.append(nn.ReLU(True))

            for _ in range(layer_per_block):
                gen_module.append(nn.ConvTranspose3d(num_layer_unit2, num_layer_unit2, 3, 1, padding = 1))
                gen_module.append(nn.BatchNorm3d(num_layer_unit2))
                gen_module.append(nn.ReLU(True))

            gen_module.append(nn.Upsample(scale_factor=2, mode='trilinear'))

        #remove extra pool layer
        gen_module = gen_module[:-1]

        #add final sigmoid 
        gen_module.append(nn.ConvTranspose3d(num_unit6, 1, 3, 1, padding = 1))
        gen_module.append(nn.Sigmoid())

        

        self.gen_fc = nn.Linear(128, num_unit1 * self.fc_size * self.fc_size * self.fc_size)
        self.gen = nn.Sequential(*gen_module)

    def forward(self, x):

        x = self.gen_fc(x)
        x = x.view(x.shape[0], self.fc_channel, self.fc_size, self.fc_size, self.fc_size)
        x = self.gen(x)
        return x


class Discriminator(nn.Module):
    def __init__(self, layer_per_block=1):
        super(Discriminator, self).__init__()

         #layer_per_block must be >= 1
        if layer_per_block < 1:
            layer_per_block = 1

        num_unit1 = 1   #input channel number
        num_unit2 = 8   #32
        num_unit3 = 16   #64
        num_unit4 = 32  #128
        num_unit5 = 16   #256
        num_unit6 = 8   #512


        num_layer_unit_list = [num_unit1, num_unit2, num_unit3, num_unit4, num_unit5, num_unit6]
        dis_module = []
        #5 blocks (need 4 pool to reduce size)
        for i in range(5):
            num_layer_unit1, num_layer_unit2 = num_layer_unit_list[i], num_layer_unit_list[i+1]

            dis_module.append(nn.Conv3d(num_layer_unit1, num_layer_unit2, 3, 1, padding = 1))
            dis_module.append(nn.BatchNorm3d(num_layer_unit2))
            dis_module.append(nn.ReLU(True))

            for _ in range(layer_per_block):
                dis_module.append(nn.Conv3d(num_layer_unit2, num_layer_unit2, 3, 1, padding = 1))
                dis_module.append(nn.BatchNorm3d(num_layer_unit2))
                dis_module.append(nn.ReLU(True))

            dis_module.append(nn.MaxPool3d((2, 2, 2)))

        #remove extra pool layer
        dis_module = dis_module[:-1]

        
        self.dis = nn.Sequential(*dis_module)

        self.dis_fc1 = nn.Sequential(
            nn.Linear(num_unit6 * 4 * 4 * 4, 128),
            nn.ReLU(True)
        )
        self.dis_fc2 = nn.Sequential(
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

    def forward(self, x):

        x = self.dis(x)
        x = x.view(x.shape[0], -1)
        fx = self.dis_fc1(x)
        x = self.dis_fc2(fx)
        return x, fx


class GAN(nn.Module):
    def __init__(self, g_layer = config.g_layer, d_layer = config.d_layer):
        super(GAN, self).__init__()
        self.generator = Generator(g_layer)
        self.discriminator = Discriminator(d_layer)

    def forward(self, x):
        x = self.generator(x)
        x = self.discriminator(x)
        return x

In [13]:
G = GAN(config.g_layer, config.d_layer).to(device)
summary(G, (128,))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                  [-1, 512]          66,048
   ConvTranspose3d-2          [-1, 16, 4, 4, 4]           3,472
       BatchNorm3d-3          [-1, 16, 4, 4, 4]              32
              ReLU-4          [-1, 16, 4, 4, 4]               0
   ConvTranspose3d-5          [-1, 16, 4, 4, 4]           6,928
       BatchNorm3d-6          [-1, 16, 4, 4, 4]              32
              ReLU-7          [-1, 16, 4, 4, 4]               0
   ConvTranspose3d-8          [-1, 16, 4, 4, 4]           6,928
       BatchNorm3d-9          [-1, 16, 4, 4, 4]              32
             ReLU-10          [-1, 16, 4, 4, 4]               0
         Upsample-11          [-1, 16, 8, 8, 8]               0
  ConvTranspose3d-12          [-1, 32, 8, 8, 8]          13,856
      BatchNorm3d-13          [-1, 32, 8, 8, 8]              64
             ReLU-14          [-1, 32, 

  "See the documentation of nn.Upsample for details.".format(mode))


#functions for pytorch network

In [14]:
def netarray2indices(array):
    coord_list = []
    if len(array.shape) == 5:
        array = array[0][0]
    x,y,z = array.shape
    for i in range(x):
        for j in range(y):
            for k in range(z):
                if array[i,j,k] > 0.5:        #tanh: voxel representation [-1,1], sigmoid: [0,1]
                    coord_list.append([i,j,k])
    print(len(coord_list))
    if len(coord_list) == 0:
        return np.array([[0,0,0]])  #return at least one point to prevent wandb 3dobject error
    return np.array(coord_list)

# array should be 3d
def netarray2mesh(array):
    if len(array.shape) != 3:
        raise Exception("netarray2mesh: input array should be 3d")

    #convert to bool dtype
    array = array > 0.5
    #array all zero gives error
    if np.sum(array) == 0:
        array[0,0,0] = True
    voxelmesh = trimesh.voxel.base.VoxelGrid(trimesh.voxel.encoding.DenseEncoding(array)).marching_cubes
    voxelmeshfile = voxelmesh.export(file_type='obj')
    voxelmeshfile = wandb.Object3D(io.StringIO(voxelmeshfile),file_type='obj')

    return voxelmesh, voxelmeshfile

def train_model(GAN, dataloader):

    
    torch.save(GAN, os.path.join(wandb.run.dir, 'GAN_model.pth'))
    wandb.save(os.path.join(wandb.run.dir, 'GAN_model.pth'))

    #start training
    GAN.to(device)
    generator = GAN.generator.to(device)
    discriminator = GAN.discriminator.to(device)

    
    
    #loss function
    criterion = nn.BCELoss(reduction='mean')
    #optimizer 
    dis_optimizer = optim.Adam(discriminator.parameters(), lr=config.d_lr)
    gen_optimizer = optim.Adam(generator.parameters(), lr=config.g_lr)

    #log models
    wandb.watch(GAN, log="all")

    d_losses = []
    g_losses = []
    for epoch in range(config.epochs):
        generator.train()
        discriminator.train()

        d_ep_loss = 0.
        g_ep_loss = 0.
        for dataset_batch in dataloader:

            dataset_batch = dataset_batch[0]        #dataset_batch was a list: [array], so just take the array inside                
            dataset_batch = dataset_batch.float().to(device)
            dloss, gloss = compute_loss(generator, discriminator, dataset_batch)


            #optimize generator
            gen_optimizer.zero_grad()
            gloss.backward(retain_graph=True)
            gen_optimizer.step()

            #optimize discriminator
            dis_optimizer.zero_grad()
            dloss.backward(retain_graph=False)
            dis_optimizer.step()
          

            #record loss
            d_ep_loss += dloss.detach()  
            g_ep_loss += gloss.detach()

        #after each epoch, record total loss and sample generated obj
        d_losses.append(d_ep_loss)
        g_losses.append(g_ep_loss)
        print("discriminator, epoch"+str(epoch)+" : "+str(d_ep_loss))
        print("generator, epoch"+str(epoch)+" : "+str(g_ep_loss))

        #save model if necessary
        if epoch % config.log_interval == 0:

            sample_tree_array = generate_tree(GAN)[0]  #only 1 tree
            sample_tree_indices = netarray2indices(sample_tree_array)
            _, voxelmeshfile = netarray2mesh(sample_tree_array)

            wandb.log({
            "discriminator loss": d_ep_loss,
            "generator loss": g_ep_loss,
            "sample_tree_indices": sample_tree_indices,
            "sample_tree_voxelmesh": voxelmeshfile})
            save_model(GAN)

        else:
            wandb.log({
            "discriminator loss": d_ep_loss,
            "generator loss": g_ep_loss})
    
    #training end, save model again
    sample_tree_array = generate_tree(GAN)[0] #only 1 tree
    sample_tree_indices = netarray2indices(sample_tree_array)
    _, voxelmeshfile = netarray2mesh(sample_tree_array)

    wandb.log({
    "discriminator loss": d_ep_loss,
    "generator loss": g_ep_loss,
    "sample_tree_indices": sample_tree_indices,
    "sample_tree_voxelmesh": voxelmeshfile})
    save_model(GAN)
    
    
    print(d_losses)
    print(g_losses)

# this function calculate loss of the model, 
def compute_loss(generator, discriminator, dataset_batch):
    
    #loss function
    criterion = nn.BCELoss(reduction='mean')   

    batch_size = dataset_batch.shape[0]
        
    #labels
    real_label = torch.unsqueeze(torch.ones(batch_size),1).float().to(device)
    fake_label = torch.unsqueeze(torch.zeros(batch_size),1).float().to(device)

    ############
    #   discriminator
    ############
    #generate fake trees
    z = torch.randn(batch_size, 128).float().to(device) #128-d noise vector
    tree_fake = generator(z)

    #real data (data from dataloader)
    dout_real, features_real = discriminator(dataset_batch)
    dloss_real = criterion(dout_real, real_label)
    score_real = dout_real
    #fake data (data from generator)            
    dout_fake, _ = discriminator(tree_fake.clone().detach())   #detach so no update to generator
    dloss_fake = criterion(dout_fake, fake_label)
    score_fake = dout_fake

    #loss function (discriminator classify real data vs generated data)
    dloss = (dloss_real + dloss_fake)/2

    ############
    #   generator
    ############

    #tree_fake is already computed above
    dout_fake, features_fake = discriminator(tree_fake)
    #generator should generate trees that discriminator think they are real
    gloss = criterion(dout_fake, real_label)
    #add feature matching
    # mseloss = nn.MSELoss(reduction="sum")
    # gloss += mseloss(torch.mean(features_fake), torch.mean(features_real))

    return dloss, gloss


def save_model(model, model_path = os.path.join(wandb.run.dir, 'model_dict.pth')):

    torch.save(model.state_dict(), model_path)
    wandb.save(model_path)

def load_model(model_path = 'model_dict.pth'):
    model = GAN()

    model_file = wandb.restore(model_path)
    model.load_state_dict(torch.load(model_file.name))

    return model

# def generate_tree(generator, num_trees = 1):
    
#     #generate noise vector
#     z = torch.randn(num_trees, 128).to(device)
#     generator.to(device).eval()
#     tree_fake = generator(z)
#     return tree_fake.detach().cpu().numpy()

def generate_tree(model, check_D = False, num_trees = 1, num_try = 100):
    #num_try is number of trial to generate a tree that can fool D
    #total number of sample generated = num_trees * num_try
    

    model.to(device).eval()
    generator = model.generator.to(device).eval()
    discriminator = model.discriminator.to(device).eval()

    result = None


    if not check_D:
        num_tree_total = num_trees
        num_runs = int(np.ceil(num_tree_total / config.batch_size))
        #ignore discriminator
        for i in range(num_runs):
            #generate noise vector
            z = torch.randn(config.batch_size, 128).to(device)
            
            tree_fake = generator(z)[:,0,:,:,:]
            selected_trees = tree_fake.detach().cpu().numpy()
            if result is None:
                result = selected_trees
            else:
                result = np.concatenate((result, selected_trees), axis=0)
    else:
        num_tree_total = num_trees * num_try
        num_runs = int(np.ceil(num_tree_total / config.batch_size))
        #only show samples can fool discriminator
        for i in range(num_runs):
            #generate noise vector
            z = torch.randn(config.batch_size, 128).to(device)
            
            tree_fake = generator(z)
            dout, _ = discriminator(tree_fake)
            dout = dout > 0.5
            selected_trees = tree_fake[dout].detach().cpu().numpy()
            if result is None:
                result = selected_trees
            else:
                result = np.concatenate((result, selected_trees), axis=0)
    #select at most num_trees
    if result.shape[0] > num_trees:
        result = result[:num_trees]
    #in case no good result
    if result.shape[0] <= 0:
        result = np.zeros((1,64,64,64))
        result[0,0,0,0] = 1
    return result

#train

In [15]:
# check if resume
if resume:
    gan = load_model()
else:
    gan = GAN(config.g_layer, config.d_layer)

#set seed
torch.manual_seed(config.seed)
torch.autograd.set_detect_anomaly(True)

train_model(gan, dataloader)        #if dataloader has only 1 tree, the training time is 72s per epoch.

  "See the documentation of nn.Upsample for details.".format(mode))


discriminator, epoch0 : tensor(0.0101, device='cuda:0')
generator, epoch0 : tensor(598.7605, device='cuda:0')
1809
discriminator, epoch1 : tensor(0.0068, device='cuda:0')
generator, epoch1 : tensor(526.7350, device='cuda:0')
discriminator, epoch2 : tensor(0.0032, device='cuda:0')
generator, epoch2 : tensor(541.1215, device='cuda:0')
discriminator, epoch3 : tensor(0.0060, device='cuda:0')
generator, epoch3 : tensor(565.5887, device='cuda:0')
discriminator, epoch4 : tensor(0.0004, device='cuda:0')
generator, epoch4 : tensor(639.7161, device='cuda:0')
discriminator, epoch5 : tensor(0.0005, device='cuda:0')
generator, epoch5 : tensor(626.5934, device='cuda:0')
discriminator, epoch6 : tensor(0.0021, device='cuda:0')
generator, epoch6 : tensor(592.4025, device='cuda:0')
discriminator, epoch7 : tensor(0.0002, device='cuda:0')
generator, epoch7 : tensor(646.8451, device='cuda:0')
discriminator, epoch8 : tensor(0.0001, device='cuda:0')
generator, epoch8 : tensor(667.7565, device='cuda:0')
discr