In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import renderstim
import matplotlib.pyplot as plt
import numpy as np

import datajoint as dj
# Let DataJoint serialize native Python objects (dicts, lists, ...) in blob
# attributes (per the DataJoint settings docs); e.g. the dataset config built by
# insert_ds_function below is a plain dict.
dj.config["enable_python_native_blobs"] = True

from renderstim.schema.main import LatentDataset
from renderstim.schema.scenes import RenderedScenes

Connecting nikoskar@at-database.ad.bcm.edu:3306


pybullet build time: May 20 2022 19:45:31


In [3]:
def clean_external(schema):
    """Clean the external-storage tracking tables of ``schema`` for every configured store.

    For each store name in ``dj.config['stores']``, calls
    ``schema.external[store].delete(delete_external_files=True)``, so the files
    on the store are removed as well as the tracking rows. If no stores are
    configured, nothing is deleted.

    NOTE(review): per the DataJoint docs, ``ExternalTable.delete`` only removes
    entries that are no longer referenced by any table (i.e. it garbage-collects
    orphaned external files); confirm for the installed DataJoint version.

    Args:
        schema: a DataJoint schema object, e.g. the result of ``dj.schema(name)``.
    """
    # 'stores' is only present in dj.config when external storage was set up;
    # indexing it directly would raise KeyError on a plain configuration.
    stores = dj.config.get('stores') or {}
    if not stores:
        print('No external stores configured; nothing to delete')
        return
    for store in stores:
        print(f'Deleting store {store}')
        schema.external[store].delete(delete_external_files=True)

# schema = dj.schema("nikoskar_rendered_images")
# clean_external(schema)

# Generating datasets of rendered stimuli

#### Overview

There are 2 tables that we use to prepare the latent variables needed to generate the scenes:
1. The `LatentDataset` table stores the dataset function (`dataset_fn`), the generator function (`generator_fn`), the dataset config (`dataset_config`), and a short description (`dataset_comment`).
2. The `SceneConfig` table is a part table of (1) and stores the configs for all individual scenes.

The rendered outputs themselves are stored in a third table, `RenderedScenes`.

## 1. Insert a dataset

In [4]:
def insert_ds_function(num_scenes, animal, bg_type, comment):
    """Insert a `LatentDataset` entry and fill its `SceneConfig` part table.

    Builds the dataset master config for the requested animal preset, inserts it
    with `LatentDataset().add_entry(...)`, then calls
    `LatentDataset().SceneConfig().fill(key)` to create the per-scene configs.

    Args:
        num_scenes: number of scenes stored in the config as `num_scenes`.
        animal: 'mouse' or 'monkey'; selects the resolution, object counts,
            spawn region, camera position and floor scale.
        bg_type: 'realistic' or 'artificial'; stored as `background_type`.
        comment: short description stored as `dataset_comment`.

    Returns:
        dict: `{'dataset_hash': ...}` restriction key of the inserted dataset,
        usable with `del_ds` or `check_and_correct_errors`.

    Raises:
        AssertionError: if `animal` or `bg_type` is not an allowed value.
    """
    # Scene-geometry settings that differ between the two presets; everything
    # else in the config is shared.
    presets = {
        'mouse': dict(
            resolution=[256, 144],
            min_num_objects=2,
            max_num_objects=5,
            spawn_region=[[-2.5, -3.0, 1.0], [2.5, 1.0, 2.5]],
            camera_position=[0.0, -6.4, 7.67],
            floor_scale=[8.0, 8.0, 0.1],
        ),
        'monkey': dict(
            resolution=[256, 256],
            min_num_objects=3,
            max_num_objects=6,
            spawn_region=[[-2.5, -3.0, 1.0], [2.5, 2.0, 2.5]],
            camera_position=[0.0, -6.1, 7.27],
            floor_scale=[10.0, 10.0, 0.01],
        ),
    }
    bg_types = ('realistic', 'artificial')
    assert animal in presets, f"animal must be one of {sorted(presets)}, got {animal!r}"
    assert bg_type in bg_types, f"bg_type must be one of {list(bg_types)}, got {bg_type!r}"
    preset = presets[animal]

    # function used to generate a dataset of latent configs
    dataset_fn = "renderstim.latents.latent_dataset"

    # Dataset master config used to generate the individual scene configs.
    # NOTE(review): keys are kept in their historical order in case the dataset
    # hash computed by add_entry depends on key order (not visible here).
    dataset_config = dict(
        num_scenes=num_scenes,
        resolution=preset['resolution'],
        min_num_objects=preset['min_num_objects'],
        max_num_objects=preset['max_num_objects'],
        spawn_region=preset['spawn_region'],
        sun_position=[0.0, 0.0, 6.0],
        camera_position=preset['camera_position'],
        camera_look_at=[0.0, 0.0, 0.0],
        camera_focal_length=35.0,
        camera_sensor_width=32.0,
        floor_scale=preset['floor_scale'],
        floor_position=[0.0, 0.0, 0.0],
        background_type=bg_type,
    )

    # function used to generate individual scenes
    generator_fn = "renderstim.generators.render_scene"

    entry = LatentDataset().add_entry(
        dataset_fn=dataset_fn,
        dataset_config=dataset_config,
        generator_fn=generator_fn,
        dataset_comment=comment
    )

    # restrict to exactly this dataset by its hash
    key = dict(dataset_hash=entry['dataset_hash'])
    print(key)

    # fill up entries in the part table
    LatentDataset().SceneConfig().fill(key)
    return key

def del_ds(key):
    """Delete the `LatentDataset` rows matching the restriction `key`.

    Args:
        key: restriction, e.g. ``dict(dataset_hash=...)``.

    NOTE(review): per the DataJoint docs, ``delete()`` cascades to dependent
    tables and may ask for confirmation when safemode is enabled.
    """
    matching = LatentDataset() & key
    matching.delete()
    
def delete_errors_from_jobs(schema_name="pipeline_rendered_images"):
    """Remove the error rows from a schema's jobs table so those keys can be repopulated.

    Only rows with ``status == "error"`` are deleted. Rows in any other status
    (e.g. "reserved" by a worker that may still be rendering) are left alone.

    Args:
        schema_name: DataJoint schema whose jobs table is cleaned; defaults to
            the schema name previously hard-coded here.

    NOTE(review): ``dj.schema(name)`` creates the database if it does not exist
    (DataJoint default ``create_schema=True``); double-check ``schema_name``.
    """
    schema = dj.schema(schema_name)
    # An unrestricted delete would also drop reserved jobs, allowing another
    # worker to pick up a key that is still being computed.
    errors = schema.jobs & {"status": "error"}
    if len(errors) != 0:
        print("Deleting errors from schema.jobs. Ready to repopulate...")
        errors.delete()
    
def get_non_generated_hashes(key):
    """List the scene hashes of a dataset that have a config but no rendered scene.

    Args:
        key: restriction identifying the dataset, e.g. ``dict(dataset_hash=...)``.

    Returns:
        list: scene hashes present in ``LatentDataset.SceneConfig`` but absent
        from ``RenderedScenes`` under the same restriction, in arbitrary order.
    """
    configured = set((LatentDataset().SceneConfig() & key).fetch('scene_hash'))
    rendered = set((RenderedScenes() & key).fetch('scene_hash'))
    return list(configured.difference(rendered))

def check_and_correct_errors(key):
    """Replace the scene configs of a dataset that failed to render.

    Uses ``get_non_generated_hashes(key)`` to find configured scenes without a
    rendered result. If there are none, reports success and returns. Otherwise
    calls ``LatentDataset().SceneConfig().replace(key, bad_hashes)`` (which, per
    the message printed here, swaps them for new configs) and then clears the
    error rows from the jobs table via ``delete_errors_from_jobs()``.

    Args:
        key: restriction identifying the dataset, e.g. ``dict(dataset_hash=...)``.
    """
    # check if there were scenes that were not rendered
    bad_hashes = get_non_generated_hashes(key)
    if not bad_hashes:
        print("All scenes for this dataset were rendered successfully")
        return

    n_bad = len(bad_hashes)
    # A single string: passing two arguments to print() inserts a separator
    # space after the "\n", which indented the second line.
    print(
        f"There were {n_bad} scenes that failed to render.\n"
        f"... Replacing the non-generated scene configs with {n_bad} new ones"
    )
    LatentDataset().SceneConfig().replace(key, bad_hashes)
    delete_errors_from_jobs()

In [5]:
# insert_ds_function(
#     num_scenes=100,
#     animal="monkey", 
#     bg_type="artificial", 
#     comment="monkey d2 test: artificial bg: 100 scenes"
# )

#### Delete the dataset entries from all three tables

In [6]:
# key = dict(dataset_hash="8d4433808eda2178ee24ffc6ed441f9e")
# del_ds(key)

#### Find and replace the scenes that were not generated
This can happen if there is a failure to connect to pybullet's physics server, or if not all objects could be placed without overlap.

In [7]:
# key = dict(dataset_hash="86b0f40049d6576502ff06a7cc0f3a30")
# check_and_correct_errors(key)

# Inspect the tables
To delete a dataset's entries, use `del_ds(key)` as shown above.

In [8]:
# Display the registered datasets (dataset/generator functions, hash, comment, timestamp).
LatentDataset()

dataset_fn  name of the dataset loader function,dataset_hash  hash of the config object,generator_fn  name of the generator function,dataset_config  dataset config object,dataset_comment  short description,dataset_ts  UTZ timestamp at time of insertion
renderstim.latents.latent_dataset,62ef5b74d0bc319c9c437c67fa6f9252,renderstim.generators.render_scene,=BLOB=,mouse v2.0 mixed bg: 5100 scenes,2023-02-22 12:18:16
renderstim.latents.latent_dataset,86b0f40049d6576502ff06a7cc0f3a30,renderstim.generators.render_scene,=BLOB=,monkey d1: artificial bg: 25000 scenes,2023-02-24 09:18:31
renderstim.latents.latent_dataset,8d7ef17c3ff0c5e6b6cd10dc31d052ff,renderstim.generators.render_scene,=BLOB=,mouse v1.0 mixed bg: 5100 scenes,2023-02-18 16:36:55
renderstim.latents.latent_dataset,be73e6f5eec1784c971aad85bcf9d9d6,renderstim.generators.render_scene,=BLOB=,mouse v1.0 realistic bg: 5100 scenes,2023-02-18 16:36:14


In [10]:
# Display the per-scene configs stored in the SceneConfig part table.
LatentDataset().SceneConfig()

dataset_fn  name of the dataset loader function,dataset_hash  hash of the config object,generator_fn  name of the generator function,scene_hash  hash of the config object,scene_config  scene config object
renderstim.latents.latent_dataset,62ef5b74d0bc319c9c437c67fa6f9252,renderstim.generators.render_scene,0008fa9185054bcbcb2fbf5f122ddc37,=BLOB=
renderstim.latents.latent_dataset,62ef5b74d0bc319c9c437c67fa6f9252,renderstim.generators.render_scene,000c6c3d22ca5741d5dbefe1e417d828,=BLOB=
renderstim.latents.latent_dataset,62ef5b74d0bc319c9c437c67fa6f9252,renderstim.generators.render_scene,000db97e33ca00b8828e55077a56eb36,=BLOB=
renderstim.latents.latent_dataset,62ef5b74d0bc319c9c437c67fa6f9252,renderstim.generators.render_scene,001d4b5b1d30ebefbffe7b394a0b799c,=BLOB=
renderstim.latents.latent_dataset,62ef5b74d0bc319c9c437c67fa6f9252,renderstim.generators.render_scene,00227177467306196a9cf8e5d78af3a6,=BLOB=
renderstim.latents.latent_dataset,62ef5b74d0bc319c9c437c67fa6f9252,renderstim.generators.render_scene,0046421c2dadf8b3968fe303bcd1702d,=BLOB=
renderstim.latents.latent_dataset,62ef5b74d0bc319c9c437c67fa6f9252,renderstim.generators.render_scene,00647cf259d9ff64babed98743e15e06,=BLOB=
renderstim.latents.latent_dataset,62ef5b74d0bc319c9c437c67fa6f9252,renderstim.generators.render_scene,006acb0b0f70e18e3a877e3a08ab3f99,=BLOB=
renderstim.latents.latent_dataset,62ef5b74d0bc319c9c437c67fa6f9252,renderstim.generators.render_scene,0075c5bfebb43f82e0a00be99d18d75a,=BLOB=
renderstim.latents.latent_dataset,62ef5b74d0bc319c9c437c67fa6f9252,renderstim.generators.render_scene,007c24a70aae586eea7dd4e1c0b2ddf4,=BLOB=


In [11]:
# Display the rendered outputs (scene, segmentation, object coordinates, normals, depth, metadata).
RenderedScenes()

dataset_fn  name of the dataset loader function,dataset_hash  hash of the config object,generator_fn  name of the generator function,scene_hash  hash of the config object,scene,segmentation,object_coordinates,normals,depth,metadata  dict containing metadata about the scene,rendering_ts  UTZ timestamp at time of insertion
renderstim.latents.latent_dataset,62ef5b74d0bc319c9c437c67fa6f9252,renderstim.generators.render_scene,0008fa9185054bcbcb2fbf5f122ddc37,=BLOB=,=BLOB=,=BLOB=,=BLOB=,=BLOB=,=BLOB=,2023-02-22 12:22:04
renderstim.latents.latent_dataset,62ef5b74d0bc319c9c437c67fa6f9252,renderstim.generators.render_scene,000c6c3d22ca5741d5dbefe1e417d828,=BLOB=,=BLOB=,=BLOB=,=BLOB=,=BLOB=,=BLOB=,2023-02-22 12:21:54
renderstim.latents.latent_dataset,62ef5b74d0bc319c9c437c67fa6f9252,renderstim.generators.render_scene,000db97e33ca00b8828e55077a56eb36,=BLOB=,=BLOB=,=BLOB=,=BLOB=,=BLOB=,=BLOB=,2023-02-22 12:21:55
renderstim.latents.latent_dataset,62ef5b74d0bc319c9c437c67fa6f9252,renderstim.generators.render_scene,001d4b5b1d30ebefbffe7b394a0b799c,=BLOB=,=BLOB=,=BLOB=,=BLOB=,=BLOB=,=BLOB=,2023-02-22 12:21:58
renderstim.latents.latent_dataset,62ef5b74d0bc319c9c437c67fa6f9252,renderstim.generators.render_scene,00227177467306196a9cf8e5d78af3a6,=BLOB=,=BLOB=,=BLOB=,=BLOB=,=BLOB=,=BLOB=,2023-02-22 12:22:17
renderstim.latents.latent_dataset,62ef5b74d0bc319c9c437c67fa6f9252,renderstim.generators.render_scene,0046421c2dadf8b3968fe303bcd1702d,=BLOB=,=BLOB=,=BLOB=,=BLOB=,=BLOB=,=BLOB=,2023-02-22 12:22:16
renderstim.latents.latent_dataset,62ef5b74d0bc319c9c437c67fa6f9252,renderstim.generators.render_scene,00647cf259d9ff64babed98743e15e06,=BLOB=,=BLOB=,=BLOB=,=BLOB=,=BLOB=,=BLOB=,2023-02-22 12:22:22
renderstim.latents.latent_dataset,62ef5b74d0bc319c9c437c67fa6f9252,renderstim.generators.render_scene,006acb0b0f70e18e3a877e3a08ab3f99,=BLOB=,=BLOB=,=BLOB=,=BLOB=,=BLOB=,=BLOB=,2023-02-22 12:22:25
renderstim.latents.latent_dataset,62ef5b74d0bc319c9c437c67fa6f9252,renderstim.generators.render_scene,0075c5bfebb43f82e0a00be99d18d75a,=BLOB=,=BLOB=,=BLOB=,=BLOB=,=BLOB=,=BLOB=,2023-02-22 12:22:53
renderstim.latents.latent_dataset,62ef5b74d0bc319c9c437c67fa6f9252,renderstim.generators.render_scene,007c24a70aae586eea7dd4e1c0b2ddf4,=BLOB=,=BLOB=,=BLOB=,=BLOB=,=BLOB=,=BLOB=,2023-02-22 12:22:42


In [33]:
# shs = (RenderedScenes() & 'dataset_hash="8a807fb3c3534868dd872f622aecd27b"').fetch('scene_hash')[:10]

# outs = [
#     "scene", 
#     "segmentation", 
#     "object_coordinates", 
#     "normals", 
#     "depth", 
#     "metadata"
# ]

# for sh in shs:
#     fig, axs = plt.subplots(1, 5, figsize=(12, 3), dpi=100)
#     axs = axs.ravel()

#     img_outs = [(RenderedScenes() & dict(scene_hash=sh)).fetch1(out) for out in outs]

#     for i, ax in enumerate(axs):
#         ax.imshow(img_outs[i], cmap="gray")
#         ax.set_title(outs[i], fontsize=10)
#         ax.axis("off")

#     plt.tight_layout()
#     plt.show()