# Deep Learning on Synthetic Data

## 1. Create an interactive visualization of the cup

In [None]:
import plotly.graph_objects as go
import plotly.io as pio
import plotly.figure_factory as ff
import trimesh

In [None]:
mesh = trimesh.load("../data/cup/cup_triangle.ply")

# Split the (N, 3) vertex array into per-axis coordinate arrays; the triangle
# index array is passed to create_trisurf as `simplices`.
x, y, z = mesh.vertices.T
faces = mesh.faces

fig = go.Figure(ff.create_trisurf(x=x, y=y, z=z,
                                  simplices=faces,
                                  plot_edges=True,
                                  edges_color="black",
                                  colormap="rgb(200, 200, 200)",
                                  show_colorbar=False).data)

# Each button sets the visibility of the figure's traces. The lists assume exactly
# two traces in the order [shaded surface, triangle edges]: both, surface only,
# or edges only ("wireframe").
buttons = [dict(label="w/ mesh", method="update", args=[dict(visible=[True, True])]),
           dict(label="w/o mesh", method="update", args=[dict(visible=[True, False])]),
           dict(label="wireframe", method="update", args=[dict(visible=[False, True])])]

camera = dict(eye=dict(x=2, y=2, z=2))

# Hide the axes, keep the mesh's true proportions (aspectmode='data'), use orbit
# rotation, and put the visibility buttons in the top-left corner.
fig.update_layout(scene=dict(
                    xaxis=dict(visible=False),
                    yaxis=dict(visible=False),
                    zaxis=dict(visible=False),
                    aspectmode='data'),
                  height=500,
                  margin=dict(r=0, l=0, b=0, t=0, pad=0),
                  scene_dragmode="orbit",
                  scene_camera=camera,
                  updatemenus=[dict(buttons=buttons, x=0.1, y=1)],
                  showlegend=False)

In [None]:
# Save figure
# Save the figure as an HTML fragment (a div, not a full page) for inclusion in
# the site, loading plotly.js from the CDN rather than embedding it.
fig.write_html('../_includes/figures/cup.html',
               full_html=False,
               include_plotlyjs='cdn')

## 2. Train Detectron2's Mask R-CNN on the cup data

In [None]:
import os
import random
import yaml
import cv2
import matplotlib.pyplot as plt
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.data.datasets import load_coco_json, register_coco_instances
from detectron2.engine import DefaultTrainer, DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer, ColorMode

In [None]:
# Setup
path_to_coco_json = "/path/to/blenderproc/coco_annotations.json"
path_to_images = "/path/to/blenderproc/images/coco_data"
path_to_config_yaml = "/path/to/detectron2/config/mask_rcnn_R_50_FPN_3x.yaml"

# Use this for training. Use the below two lines instead for inference if you want "cup" as label instead of "1".
register_coco_instances("cup", {}, path_to_coco_json, path_to_images)
# DatasetCatalog.register("cup", lambda: load_coco_json(path_to_coco_json, path_to_images))
# MetadataCatalog.get("cup").set(thing_classes=["cup"], json_file=path_to_coco_json, image_root=path_to_images)

In [None]:
# Config settings
cfg = get_cfg()
cfg.merge_from_file(path_to_config_yaml)
cfg.INPUT.MASK_FORMAT="bitmask"  # Standard output format of BlenderProc
cfg.DATASETS.TRAIN = ("cup",)
cfg.DATASETS.TEST = ()
cfg.DATALOADER.NUM_WORKERS = 8
# initialize from model zoo
cfg.MODEL.WEIGHTS = "detectron2://COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x/137849600/model_final_f10217.pkl"
cfg.SOLVER.IMS_PER_BATCH = 4
cfg.SOLVER.BASE_LR = 0.0025
cfg.SOLVER.MAX_ITER = 300
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

In [None]:
# Train model
trainer = DefaultTrainer(cfg)
trainer.resume_or_load(resume=False)
trainer.train()

In [None]:
# Load trained weights and run inference (on train data; just for visualization)
cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # set the testing threshold for this model
cfg.DATASETS.TEST = ("cup",)
predictor = DefaultPredictor(cfg)

metadata = MetadataCatalog.get("cup")
dataset_dicts = DatasetCatalog.get("cup")

# Show up to three random images; random.sample raises ValueError when asked for
# more items than the dataset contains.
num_samples = min(3, len(dataset_dicts))
figure, axes = plt.subplots(1, 3, figsize=(16, 16), tight_layout=True)
for axis in axes:
    axis.axis('off')  # also leaves any unused panel blank when num_samples < 3

# Fill the panels left to right, one sampled image each.
for axis, d in zip(axes, random.sample(dataset_dicts, num_samples)):
    im = cv2.imread(d["file_name"])  # BGR array, or None if the file can't be read
    if im is None:
        raise FileNotFoundError(f"Could not read image: {d['file_name']}")
    outputs = predictor(im)
    visualizer = Visualizer(im[:, :, ::-1],  # BGR -> RGB for drawing
                            metadata=metadata,
                            instance_mode=ColorMode.IMAGE_BW)  # gray out pixels outside predicted masks
    drawn = visualizer.draw_instance_predictions(outputs["instances"].to("cpu"))
    axis.imshow(drawn.get_image())
plt.show()