In [1]:
import numpy as np
import os.path as osp
from scipy.io import loadmat
from numpy.random import choice
from geo_tool import Point_Cloud
from general_tools.in_out.basics import create_dir



In [2]:
# Notebook set-up: render matplotlib figures inline, and let autoreload re-import
# edited library modules (e.g. geo_tool, general_tools) before each cell executes.
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [3]:
# Per-class number of synthetic models one may want to sample. Currently not used to
# size the run: every model stored in the class's .mat file is converted.
n_samples_dict = dict(chair=6778,
                      car=7497,
                      desk=8509,
                      sofa=3173)

In [21]:
# Parameters
class_name = 'sofa'          # which class of synthetic samples to convert
prune_low_voxels = False     # if True, low-occupancy voxels are zeroed before sampling
n_pc_points = 2048           # number of points in every output cloud

# Despite its name, mit_data_dir ends up holding the path of the class's .mat file.
mit_data_dir = osp.join('/orions4-zfs/projects/lins2/Panos_Space/DATA/NIPS/mit_3dgan_synthetic_samples/voxel_grids/2K_models_per_class/',
                        class_name + '_sample.mat')

# Output folder is keyed by the number of points per cloud; create_dir's return
# value (the directory path) is echoed as the cell output.
out_data_dir = osp.join('/orions4-zfs/projects/lins2/Panos_Space/DATA/NIPS/mit_3dgan_synthetic_samples/point_clouds/',
                        str(n_pc_points))
create_dir(out_data_dir)

'/orions4-zfs/projects/lins2/Panos_Space/DATA/NIPS/mit_3dgan_synthetic_samples/point_clouds/2048'

In [22]:
# Load the voxel samples for `class_name`; the .mat file stores them under 'voxels'.
mat_contents = loadmat(mit_data_dir)
mit_data = np.squeeze(mat_contents['voxels'])
print(mit_data.shape)
# np.squeeze would also drop the model axis if the file held a single sample; fail
# loudly instead of silently treating the slices of one grid as separate models.
assert mit_data.ndim == 4, 'expected (n_models, X, Y, Z) voxel grids, got shape %s' % (mit_data.shape,)

# Convert every model in the file (n_samples_dict[class_name] is an alternative
# target count, not used here).
n_models = len(mit_data)
gen_data = np.zeros((n_models, n_pc_points, 3))   # filled row-by-row by the conversion loop

(2000, 64, 64, 64)


In [23]:
def voxel_field_to_point_cloud(voxel_field, n_points):
    """Sample distinct cells of a voxel grid, weighted by cell value, and return their coordinates.

    Each cell is drawn with probability proportional to its value, without
    replacement, so the result contains exactly `n_points` distinct cells.

    Args:
        voxel_field: array of non-negative per-cell weights (e.g. occupancy
            values). Any number of dimensions; callers here pass 3-D grids.
        n_points: number of cells to draw. Must not exceed the number of cells
            with positive weight (numpy's `choice` raises ValueError otherwise).

    Returns:
        float32 array of shape (n_points, voxel_field.ndim) with the grid indices
        of the drawn cells, in row-major (C-order) order of the grid.

    Raises:
        ValueError: if the weights do not sum to a positive number (e.g. an all-zero grid).
    """
    # Normalise in float64 on a fresh array: avoids Py2 integer floor division for
    # integer grids, keeps probabilities within choice()'s sum-to-one tolerance,
    # and never writes into the caller's grid.
    weights = np.asarray(voxel_field, dtype=np.float64).ravel()
    total = weights.sum()
    if not total > 0:  # also catches NaN
        raise ValueError('voxel_field must have a positive total weight, got %r' % (total,))
    weights = weights / total
    # Passing the population size avoids materialising a list of every cell index.
    indices = choice(weights.size, size=n_points, replace=False, p=weights)
    # Sorted flat indices unravel to the same row-major order np.where() yields on a
    # dense mask, without allocating and scanning a full-size grid.
    coords = np.unravel_index(np.sort(indices), np.shape(voxel_field))
    return np.column_stack(coords).astype(np.float32)

In [24]:
# Convert each voxel grid into an n_pc_points x 3 point cloud stored in gen_data.
# Voxel values act as sampling weights; grids with fewer usable cells than
# n_pc_points are sampled fully and then boot-strapped up to n_pc_points.
low_voxel_threshold = 0.1   # pruning cut-off: cells below it are treated as empty
random_seed = 0             # fixed so that Restart & Run All regenerates the same clouds
# `choice` draws from numpy's global RNG; pc.sample presumably does too (TODO confirm).
np.random.seed(random_seed)

for i in xrange(n_models):
    voxel_field = mit_data[i]
    if prune_low_voxels:
        # Work on a copy: assigning into the view would silently modify mit_data,
        # so re-running this cell (e.g. after toggling prune_low_voxels) would no
        # longer start from the loaded grids.
        voxel_field = voxel_field.copy()
        voxel_field[voxel_field < low_voxel_threshold] = 0
    # Sampling is without replacement, so at most one point per positive cell can be
    # drawn; counting them in both modes keeps sparse un-pruned grids from making
    # choice() fail.
    good_cells = np.count_nonzero(voxel_field > 0)
    if good_cells == 0:
        raise ValueError('Voxel grid %d has no positive cells to sample from.' % i)
    n_pc_samples = min(good_cells, n_pc_points)
    pc = Point_Cloud(voxel_field_to_point_cloud(voxel_field, n_pc_samples))
    if n_pc_samples < n_pc_points:  # boot-strap up to the target size
        pc = pc.sample(n_pc_points)[0]
    assert len(pc.points) == n_pc_points
    gen_data[i] = pc.points

In [25]:
# Persist all generated clouds; the file name records whether low voxels were pruned.
pruning_suffix = '_pruned' if prune_low_voxels else '_no_pruned'
out_file_tag = class_name + pruning_suffix
# np.savez appends '.npz' and stores the positional array under the key 'arr_0'.
np.savez(osp.join(out_data_dir, out_file_tag), gen_data)