Skip to content

Commit

Permalink
shapenet example
Browse files Browse the repository at this point in the history
  • Loading branch information
taiya committed Oct 20, 2021
1 parent 9b16873 commit ced03cb
Show file tree
Hide file tree
Showing 7 changed files with 120 additions and 10 deletions.
4 changes: 3 additions & 1 deletion Makefile
Expand Up @@ -53,7 +53,9 @@ examples/klevr: checkmakeversion
docker run --rm --interactive --user `id -u`:`id -g` --volume `pwd`:/kubric kubricdockerhub/kubruntudev python3 examples/klevr.py
examples/katr: checkmakeversion
docker run --rm --interactive --user `id -u`:`id -g` --volume `pwd`:/kubric kubricdockerhub/kubruntudev python3 examples/katr.py

examples/shapenet: checkmakeversion
docker run --rm --interactive --user `id -u`:`id -g` --volume `pwd`:/kubric kubricdockerhub/kubruntudev python3 examples/shapenet.py

# --- runs the test suite within the dev container (similar to test.yml), e.g.
# USAGE:
# make pytest TEST=test/test_core.py
Expand Down
97 changes: 97 additions & 0 deletions examples/shapenet.py
@@ -0,0 +1,97 @@
import logging
import numpy as np

import kubric as kb
from kubric.renderer import Blender as KubricRenderer

# --- WARNING: this path is not yet public
source_path = "gs://kubric-public/ShapeNetCore.v2"

# --- CLI arguments (and modified defaults)
parser = kb.ArgumentParser()
parser.set_defaults(
seed=1,
frame_start=1,
frame_end=30,
resolution=(256, 256),
)
FLAGS = parser.parse_args()

# --- Common setups
kb.utils.setup_logging(FLAGS.logging_level)
kb.utils.log_my_flags(FLAGS)
job_dir = kb.as_path(FLAGS.job_dir)
rng = np.random.RandomState(FLAGS.seed)
scene = kb.Scene.from_flags(FLAGS)

# --- Add a renderer
renderer = KubricRenderer(scene,
use_denoising=True,
adaptive_sampling=False,
background_transparency=True)

# --- Add Klevr-like lights to the scene
scene += kb.assets.utils.get_clevr_lights(rng=rng)
scene.ambient_illumination = kb.Color(0.05, 0.05, 0.05)

# --- Add floor (~infinitely large sphere)
scene += kb.Sphere(name="floor", scale=1000, position=(0, 0, +1000), background=True, static=True)

# --- Keyframe the camera
scene.camera = kb.PerspectiveCamera()
for frame in range(FLAGS.frame_start, FLAGS.frame_end + 1):
# scene.camera.position = (1, 1, 1) #< frozen camera
scene.camera.position = kb.sample_point_in_half_sphere_shell(1.1, 1.2)
scene.camera.look_at((0, 0, 0))
scene.camera.keyframe_insert("position", frame)
scene.camera.keyframe_insert("quaternion", frame)

# --- Fetch a random (airplane) asset
asset_source = kb.AssetSource(source_path)
ids = list(asset_source.db.loc[asset_source.db['id'].str.startswith('02691156')]['id'])
asset_id = rng.choice(ids) #< e.g. 02691156_10155655850468db78d106ce0a280f87
obj = asset_source.create(asset_id=asset_id)
logging.info(f"selected '{asset_id}'")

# --- make object flat on X/Y and not penetrate floor
obj.quaternion = kb.Quaternion(axis=[1,0,0], degrees=90)
# HACK: bounds are not updated after rotation! supposed to be obj.bounds[0][2]
obj.position = obj.position - (0, 0, obj.bounds[0][1])

obj.metadata = {
"asset_id": obj.asset_id,
"category": asset_source.db[
asset_source.db["id"] == obj.asset_id].iloc[0]["category_name"],
}
scene.add(obj)

# --- Rendering
logging.info("Rendering the scene ...")
renderer.save_state(job_dir / "scene.blend")
data_stack = renderer.render()

# --- Postprocessing
kb.compute_visibility(data_stack["segmentation"], scene.assets)
data_stack["segmentation"] = kb.adjust_segmentation_idxs(
data_stack["segmentation"],
scene.assets,
[obj]).astype(np.uint8)

# --- Discard non-used information
del data_stack["uv"]
del data_stack["forward_flow"]
del data_stack["backward_flow"]
del data_stack["depth"]
del data_stack["normal"]

# --- Save to image files
kb.file_io.write_image_dict(data_stack, job_dir)

# --- Collect metadata
logging.info("Collecting and storing metadata for each object.")
data = {
"metadata": kb.get_scene_metadata(scene),
"camera": kb.get_camera_info(scene.camera),
}
kb.file_io.write_json(filename=job_dir / "metadata.json", data=data)
kb.done()
8 changes: 1 addition & 7 deletions kubric/__init__.py
Expand Up @@ -47,15 +47,9 @@
from kubric.core.objects import Cube
from kubric.core.objects import FileBasedObject

from kubric.core.traits import Vector3D
from kubric.core.traits import Scale
from kubric.core.traits import Quaternion
from kubric.core.traits import RGB
from kubric.core.traits import RGBA
from kubric.core.traits import AssetInstance

from kubric.custom_types import AddAssetFunction
from kubric.custom_types import PathLike
from kubric.custom_types import Quaternion

from kubric import assets
from kubric.assets import AssetSource
Expand Down
10 changes: 10 additions & 0 deletions kubric/assets/asset_source.py
Expand Up @@ -99,6 +99,16 @@ def fetch(self, object_id):
return urdf_path, vis_path, properties

def get_test_split(self, fraction=0.1):
"""
Generates a train/test split for the asset source.
Args:
fraction: the fraction of the asset source to use for the heldout set.
Returns:
train_objects: list of asset ID strings
held_out_objects: list of asset ID strings
"""
held_out_objects = list(self.db.sample(frac=fraction, replace=False, random_state=42)["id"])
train_objects = [i for i in self.db["id"] if i not in held_out_objects]
return train_objects, held_out_objects
Expand Down
4 changes: 2 additions & 2 deletions kubric/core/objects.py
Expand Up @@ -97,7 +97,7 @@ def look_at_quat(
return tuple(pyquat.Quaternion(matrix=(rotation_matrix1.T @ rotation_matrix2)))


def euler_to_quat(euler_angles):
def _euler_to_quat(euler_angles):
""" Convert three (euler) angles around XYZ to a single quaternion."""
q1 = pyquat.Quaternion(axis=[1., 0., 0.], angle=euler_angles[0])
q2 = pyquat.Quaternion(axis=[0., 1., 0.], angle=euler_angles[1])
Expand Down Expand Up @@ -126,7 +126,7 @@ def __init__(self, position=(0., 0., 0.), quaternion=None,
quaternion = look_at_quat(position, look_at, up, front)
elif euler is not None:
assert look_at is None and quaternion is None
quaternion = euler_to_quat(euler)
quaternion = _euler_to_quat(euler)
elif quaternion is None:
quaternion = (1., 0., 0., 0.)

Expand Down
4 changes: 4 additions & 0 deletions kubric/core/traits.py
Expand Up @@ -15,6 +15,7 @@

import numpy as np
import traitlets as tl
import pyquaternion as pyquat

from kubric.core import color
from kubric.core.assets import UndefinedAsset
Expand Down Expand Up @@ -59,6 +60,9 @@ class Quaternion(tl.TraitType):
info_text = "a 4D vector (WXYZ quaternion) of floats"

def validate(self, obj, value):
if isinstance(value, pyquat.Quaternion):
value = tuple(value)

value = np.array(value, dtype=np.float32)
if value.shape != (4,):
self.error(obj, value)
Expand Down
3 changes: 3 additions & 0 deletions kubric/custom_types.py
Expand Up @@ -16,6 +16,7 @@
from typing import Any, Callable, Union, Sequence

import numpy as np
import pyquaternion as pyquat
import tensorflow_datasets.public_api as tfds

from kubric import core # pylint: disable=unused-import
Expand All @@ -25,3 +26,5 @@
PathLike = Union[str, tfds.core.ReadWritePath]

ArrayLike = Union[Sequence[float], np.ndarray]

Quaternion = pyquat.Quaternion

0 comments on commit ced03cb

Please sign in to comment.