# Imports



In [79]:
from pathlib import Path
import numpy as np
import trimesh
from monai.data import Dataset
from monai.config import print_config
from transforms_templates.utils.utils import watermark, wide_notebook
from transforms_templates.utils.log import log

import torch
import matplotlib.pyplot as plt

# autoreload python modules on the fly when its source is changed
%load_ext autoreload
%autoreload 2


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
print_config()

_ = watermark(packages=['python', 'virtualenv', 'nvidia', 'cudnn', 'hostname', 'torch', 'trimesh', 'transforms_templates'])
wide_notebook()

MONAI version: 0.3.0
Python version: 3.7.5 (default, Apr 19 2020, 20:18:17)  [GCC 9.2.1 20191008]
OS version: Linux (5.3.0-64-generic)
Numpy version: 1.19.2
Pytorch version: 1.4.0
MONAI flags: HAS_EXT = True, USE_COMPILED = False

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 3.1.1
scikit-image version: 0.17.2
Pillow version: 6.2.2
Tensorboard version: 2.4.0a20201021
gdown version: 3.12.2
TorchVision version: 0.5.0
ITK version: 5.1.0
tqdm version: 4.50.2

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies

virtualenv:     (transforms_templates) 
python:         3.7.5
hostname:       GA-970A-UD3
nvidia driver:  b'435.21'
torch:          1.4.0
trimesh:        3.8.11
transforms_templates: 


In [3]:
def show_image(img, figsize=(8, 8), factor=1 / 255):
    fig, ax = plt.subplots(figsize=figsize)
    # ch, h, w --> h, w, ch
    img = img.transpose(1, 2, 0)
    img = img * factor
    ax.imshow(img)

#  Data dir

In [4]:
# Setup data directory

DIR_DATA = Path("/media/Linux_4Tb/synth3D/tablets_30")
assert DIR_DATA.exists()

# Data set

In [5]:
#d = (DIR_DATA / 'corrupted')
train_ply_files = sorted((DIR_DATA / 'corrupted').glob('*.ply'))

def get_corresponding_png_file(fn, foreshortening='top', kind='rgb'):
    assert foreshortening in ['top', 'left', 'right']
    assert kind in ['rgb', 'depth']
    base = fn.with_suffix('').name
    return fn.parent.parent / 'renders' / f'{base}_{foreshortening}_{kind}.png'

def get_corresponding_meta_info_file(fn, foreshortening='top'):
    assert foreshortening in ['top', 'left', 'right']
    base = fn.with_suffix('').name
    return fn.parent.parent / 'renders' / f'{base}_{foreshortening}_meta_info.json'

def get_files_dict(fn):
    return {'mesh': fn, 'image2d': get_corresponding_png_file(fn), 'image2d_meta_dict': get_corresponding_meta_info_file(fn)}

data_dicts = [get_files_dict(fn) for fn in train_ply_files]

In [131]:
ds = Dataset(data=data_dicts)

In [132]:
ds[0]

{'mesh': PosixPath('/media/Linux_4Tb/synth3D/tablets_30/corrupted/tablets_30_00000.ply'),
 'image2d': PosixPath('/media/Linux_4Tb/synth3D/tablets_30/renders/tablets_30_00000_top_rgb.png'),
 'image2d_meta_dict': PosixPath('/media/Linux_4Tb/synth3D/tablets_30/renders/tablets_30_00000_top_meta_info.json')}

In [8]:
ds[0]['mesh'].exists(), ds[0]['image2d'].exists(), ds[0]['image2d_meta_dict'].exists()

(True, True, True)

#  Loader as transform

In [12]:
#from transforms_templates.transforms.io.mesh import LoadPLYd
#from transforms_templates.transforms.io.image2d import LoadPNGandJSONd
#from transforms_templates.transforms.compose import Compose

#from monai.transforms import LoadDatad, LoadPNG
#from monai.config import KeysCollection
#from monai.transforms import Transform

import tt

## Test loader transform

In [27]:

tfms_load = tt.Compose([
    tt.LoadPLYd(keys=['mesh']),
    tt.LoadPNGandJSONd(keys=['image2d'], overwriting=True),
    tt.Transposed(keys=['image2d'], indices=[2, 0, 1]),
    tt.DebugDict(keys=['mesh', 'image2d'])
])

tfms_convert_1 = tt.Compose([
    tt.ScaleIntensityDict(keys=['image2d'], minv=None,maxv=None, factor=-254/255),
    tt.RandFlipX(keys=['mesh', 'image2d'], prob=0.99, image_spatial_axis=1),
    # image_spatial_axis: for 2D images axis starts without channel (ch, 0-h, 1-w), e.g. 1 - width
    # https://github.com/Project-MONAI/MONAI/blob/master/monai/transforms/spatial/array.py#L262  
    tt.ToPointCloud(keys=['mesh'], new_keys=['pc'], method='centers',
                    features_from_vertices=['red', 'green', 'blue'],
                    labels_from_faces=['shape_id']
                   ),
    tt.SamplePoints(keys=['pc'], num_points=10000),
    tt.CropRect(keys=['pc'], sizes=[-25, 10, -5.5, 25, 40, 5.5]), # xyzxyz
    tt.DebugDict(keys=['pc', 'image2d'])
])

tfms_convert_2 = tt.Compose([
    tt.ToSparseVoxels(keys=['pc'],
                      coords_range=[-25, 10, -5.5, 25, 40, 5.5],
                      voxel_size=[1, 2.,  0.5],
                      add_local_pos=True,
                      new_keys_prefixes = ['sp']),
    tt.ToTensorD(keys=['sp_coords', 'sp_features', 'sp_shape_id']),
    tt.DebugDict(keys=None)
])

# all transforms
tfms = tt.Compose(tfms_load.transforms + tfms_convert_1.transforms + tfms_convert_2.transforms)

### Test transfroms


In [28]:
# object before transformation
ds[0]

{'mesh': PosixPath('/media/Linux_4Tb/synth3D/tablets_30/corrupted/tablets_30_00000.ply'),
 'image2d': PosixPath('/media/Linux_4Tb/synth3D/tablets_30/renders/tablets_30_00000_top_rgb.png'),
 'image2d_meta_dict': PosixPath('/media/Linux_4Tb/synth3D/tablets_30/renders/tablets_30_00000_top_meta_info.json')}

In [29]:
o = tfms_load(ds[0], debug=True)

Debug 1
  keys: ['mesh', 'image2d', 'image2d_meta_dict', 'mesh_meta_dict']
  self.keys: ('mesh', 'image2d')
'mesh':
<trimesh.Trimesh(vertices.shape=(20424, 3), faces.shape=(40836, 3))>

  x                        shape: (20424,)              dtype: float64        min:  -26.40680,  max:   27.44418,  mean:    0.08402
  y                        shape: (20424,)              dtype: float64        min:   12.92263,  max:   43.61013,  mean:   31.67863
  z                        shape: (20424,)              dtype: float64        min:   12.92263,  max:   43.61013,  mean:   31.67863
  vertex_features:
    red                      shape: (20424,)              dtype: uint8          min:  128.00000,  max:  128.00000,  mean:  128.00000
    green                    shape: (20424,)              dtype: uint8          min:  128.00000,  max:  128.00000,  mean:  128.00000
    blue                     shape: (20424,)              dtype: uint8          min:  128.00000,  max:  128.00000,  mean:  128.00000
   

In [30]:
o.keys()

dict_keys(['mesh', 'image2d', 'image2d_meta_dict', 'mesh_meta_dict'])

In [31]:
#log(o)

In [32]:
#show_image(o['image2d'], factor=1 / 255)

In [33]:
b = tfms_convert_1(o, debug=True)

Debug 1
  keys: ['image2d', 'image2d_meta_dict', 'mesh_meta_dict', 'pc']
  self.keys: ('pc', 'image2d')
'pc':
x                        shape: (9177,)               dtype: float32        min:  -24.99984,  max:   24.99968,  mean:    0.48211
y                        shape: (9177,)               dtype: float32        min:   12.94880,  max:   43.59680,  mean:   32.57383
z                        shape: (9177,)               dtype: float32        min:   -5.49804,  max:    5.49940,  mean:   -0.03761
features                 shape: (9177, 6)             dtype: float32        min:   -0.99999,  max:  128.00000,  mean:   64.00053
shape_id                 shape: (9177,)               dtype: int8           min:    0.00000,  max:    7.00000,  mean:    3.41528
features_original_keys: ['red', 'green', 'blue', 'nx', 'ny', 'nz']
labels_keys: ['shape_id']
'image2d'                shape: (4, 256, 256)         dtype: float32        min:    0.18824,  max:    1.00000,  mean:    0.73259



In [34]:
#show_image(b['image2d'], factor=1.0)

In [35]:
c = tfms_convert_2(b, debug=True)

Debug 1
  keys: ['image2d', 'image2d_meta_dict', 'mesh_meta_dict', 'sp_coords', 'sp_features', 'sp_shape_id', 'sp_features_original_keys', 'sp_labels_keys']
  self.keys: (None,)
   image2d                  shape: (4, 256, 256)         dtype: float32        min:    0.18824,  max:    1.00000,  mean:    0.73259
image2d_meta_dict:
   filename_or_obj          /media/Linux_4Tb/synth3D/tablets_30/renders/tablets_30_00000_top_rgb.png
   spatial_shape            (256, 256)
   format                   'PNG'
   mode                     'RGBA'
   width                    256
   height                   256
info:
   Software                 'Matplotlib version3.3.2, https://matplotlib.org/'
   dpi                      (100, 100)
camera_info:
   projection_matrix        [[3.7290582683229117, 0.0, 0.0, 0.0], [0.0, 3.7290582683229117, 0.0, 0.0], [0.0, 0.0, -1.0, -0.1], [0.0, 0.0, -1.0, 0.0]]
   pos                      [[0.5207025366977089, 0.0, 0.0, 0.06674625131234299], [0.0, 0.5470937647723557, 0.0

# Dataset

In [39]:
ds = Dataset(data=data_dicts, transform=tfms)


In [41]:
o = ds[0]

In [42]:
log(o)

   image2d                  shape: (4, 256, 256)         dtype: float32        min:    0.18824,  max:    1.00000,  mean:    0.73259
image2d_meta_dict:
   filename_or_obj          /media/Linux_4Tb/synth3D/tablets_30/renders/tablets_30_00000_top_rgb.png
   spatial_shape            (256, 256)
   format                   'PNG'
   mode                     'RGBA'
   width                    256
   height                   256
info:
   Software                 'Matplotlib version3.3.2, https://matplotlib.org/'
   dpi                      (100, 100)
camera_info:
   projection_matrix        [[3.7290582683229117, 0.0, 0.0, 0.0], [0.0, 3.7290582683229117, 0.0, 0.0], [0.0, 0.0, -1.0, -0.1], [0.0, 0.0, -1.0, 0.0]]
   pos                      [[0.5207025366977089, 0.0, 0.0, 0.06674625131234299], [0.0, 0.5470937647723557, 0.0, -0.04717951268335549], [0.0, 0.0, 0.22825356909249078, 2.0187850111351904], [0.0, 0.0, 0.0, 1.0]]
   mesh_transform           [[0.02997280864274262, 0.0008099670012058981, -0.0

# Collate/merge

In [50]:
#from fastai_sparse.data_items import extract_data

In [51]:
ex1 = ds[0]
ex2 = ds[1]
ex3 = ds[2]

In [52]:
examples = [ex1, ex2, ex3]

In [124]:
# TODO: dtypes are converted in transforms before
# TODO: work with numpy (work with tensors now)


class Collater():
    """
    For use in configs (hydra)
    """
    def __init__(self,
                 as_list = ['id', 'num_points'],
                 as_stack = ['image2d'],  # tensor [(C, H, W), (C, H, W), ] ---> [B, C, H, W, ...]
                 as_pack=['sp_features', 'sp_shape_id'],
                 as_pack_with_index=['sp_coords'],  # for SparseConNet
                 num_points_source_key='sp_coords',
                 ):
        self.as_list = as_list
        self.as_stack = as_stack
        self.as_pack = as_pack
        self.as_pack_with_index = as_pack_with_index
        self.num_points_source_key = num_points_source_key
        
    # __call__ ?
    def collate(self, examples):
        res = {}
    
        for key in self.as_pack_with_index:
            a = [d[key] for d in examples]
            # TODO: 
            ones = torch.from_numpy(np.vstack([idx * np.ones((x.shape[0], 1), dtype="int64") for idx, x in enumerate(a)]))
            a = torch.cat(a, dim=0)
            a = torch.cat([a, ones], dim=1)
            res[key] = a

        for key in self.as_pack:
            a = [d[key] for d in examples]
            a = torch.cat(a, dim=0)
            res[key] = a

        if "num_points" in self.as_list:
            num_points = [len(d[self.num_points_source_key]) for d in examples]
            res["num_points"] = num_points

        return res

    def num_points_of_example(self, example, key=['sp_coords']):
        # TODO: as transform which fill aditioanal key 'num_points'
        return len(example[key])

    
def custom_merge_fn(
    examples,              # List of examples(items) of dataset to be merged
    # separate_labels=True,  # Return {'x1': any, 'x2': any, 'y': any} or ({'x1': any, 'x2': any }, y)  if True
    # TODO: names of params
    # https://github.com/facebookresearch/pytorch3d/blob/master/docs/notes/batching.md

    as_list = ['id', 'num_points'],
    as_stack = ['image2d'],  # tensor [(C, H, W), (C, H, W), ] ---> [B, C, H, W, ...]
    as_pack_with_index=['sp_coords'],  # for SparseConNet
    as_pack=['sp_features', 'sp_shape_id'],
    num_points_source_key='sp_coords',
):
    
    res = {}
    
    for key in as_pack_with_index:
        a = [d[key] for d in examples]
        ones = torch.from_numpy(np.vstack([idx * np.ones((x.shape[0], 1), dtype="int64") for idx, x in enumerate(a)]))
        a = torch.cat(a, dim=0)
        a = torch.cat([a, ones], dim=1)
        res[key] = a
    
    for key in as_pack:
        a = [d[key] for d in examples]
        a = torch.cat(a, dim=0)
        res[key] = a
    
    if "num_points" in as_list:
        num_points = [len(d[num_points_source_key]) for d in examples]
        res["num_points"] = num_points

    return res

In [125]:
[len(d['sp_coords']) for d in examples]

[9151, 9222, 9037]

In [126]:
batch = custom_merge_fn(examples)

In [127]:
collater = Collater()

In [128]:
batch = collater.collate(examples)

In [129]:
log(batch)

   sp_coords                shape: (27410, 4)            dtype: torch.int64    min:          0,  max:         49,  mean:   11.64326
   sp_features              shape: (27410, 9)            dtype: torch.float32  min:  -25.49997,  max:  128.00000,  mean:   40.97920
   sp_shape_id              shape: (27410,)              dtype: torch.int8     min:    0.00000,  max:    7.00000,  mean:    3.46738
   num_points               [9151, 9222, 9037]


In [134]:
from transforms_templates.collate.collate import Collater