# Data loader design

In [1]:
import numpy as np
import os

from meshparty import trimesh_io
from neuronencodings.data import transform, cell_dataset, Phase

import ipyvolume.pylab as p3
import ipyvolume as ipv

import torch
import torch.utils

HOME = os.path.expanduser("~")

## Data Viz

In [2]:
def random_color():
    return np.random.rand(3)

def visualize_points(ps):   
    fig = p3.figure(width=1024, height=1024)

    for ps_i in ps:
        X, Y, Z = np.array(ps_i.T*1000)
        ipv.scatter(X, Y, Z, marker='sphere', color=random_color(), size=1)

    ipv.squarelim()
    ipv.show()

## Data loading 

In [3]:
cv_path = "graphene://https://www.dynamicannotationframework.com/segmentation/1.0/pinky100_sv16"
disk_cache_path = f"{HOME}/.meshcache/"

In [4]:
meshmeta = trimesh_io.MeshMeta(cv_path=cv_path, disk_cache_path=disk_cache_path)

In [5]:
ids = [648518346342073945,648518346342317292,648518346349496323,648518346341408751,648518346349508876]

In [6]:
mesh = meshmeta.mesh(seg_id=ids[0])



In [7]:
def get_block(mesh, n_points=750, sample_n_points=2000, n_views=1):
    vertex_list = []
    center_vertex_ids = []
    for i_view in range(n_views):
        vertices, v_id = mesh.get_local_view(n_points, pc_align=True, pc_norm=False,
                                             sample_n_points=sample_n_points)
        vertex_list.append(vertices.squeeze())
        center_vertex_ids.append(v_id[0])
    return np.array(vertex_list), np.array(center_vertex_ids)

In [8]:
vl, cvl = get_block(mesh)
# visualize_points([mesh.vertices[::1000], mesh.vertices[cvl]])
vl_n = transform.norm_to_unit_sphere_many(vl)
visualize_points(vl_n)

mesh.vertices[cvl] / np.array([4, 4, 40])

VBox(children=(Figure(camera=PerspectiveCamera(fov=46.0, position=(0.0, 0.0, 2.0), quaternion=(0.0, 0.0, 0.0, …

TrackedArray([[83401.953125 , 73733.4609375,   432.0378418]])

## CellDataset

In [9]:
train_dset = cell_dataset.CellDataset(gt_dirs=[disk_cache_path],
                                      phase=Phase.FULL, 
                                      n_points=500,
                                      n_views_per_batch=5,
                                      sample_n_points=1500,
                                      random_seed=0,
                                      train_split=.8,
                                      val_split=.1,
                                      test_split=.1)

In [10]:
train_loader = torch.utils.data.DataLoader(train_dset,
                                           batch_size=2,
                                           shuffle=True,
                                           num_workers=4,
                                           drop_last=True)

In [11]:
for i_batch, d in enumerate(train_loader):
    print(i_batch)
    break



0


In [28]:
visualize_points(d.numpy()[1])

VBox(children=(Figure(camera=PerspectiveCamera(fov=46.0, position=(0.0, 0.0, 2.0), quaternion=(0.0, 0.0, 0.0, …

In [29]:
visualize_points(d.reshape(-1, 500, 3).numpy()[5:])

VBox(children=(Figure(camera=PerspectiveCamera(fov=46.0, position=(0.0, 0.0, 2.0), quaternion=(0.0, 0.0, 0.0, …