# Treatment planning

This notebook will demonstrate:

1. How to call the ChohoCloud API to generate prescriptions and then generate the target position based on the prescriptions.
2. How to generate the treatment plan (steps) based on the target position.

To use this notebook, you need to prepare the following for a case:

1. Lateral cephalometric X-ray.
2. Smiling photo.
3. Upper and lower jaw oral scan meshes.

In [None]:
import json
import os
import time
import urllib
import urllib.parse

import numpy as np
import requests
import trimesh

## Configure API access
Modify the following code block using the service URLs, token, and user ID we provided to you.

In [None]:
# Chohotech service request URL, sent with the API documentation.
base_url = "<service request URL>"

# Chohotech file service URL, sent with the API documentation.
file_server_url = "<service file server URL>"

# The authentication header must be passed in. Please keep the TOKEN confidential!!!
# If it is leaked, please contact us immediately to reset it.
# All tasks using this TOKEN will be charged to your account.
zh_token = "<your company's service Token, sent with the contact>" # All API calls must be authenticated with the token.

user_group = "APIClient" # User group, usually named APIClient.

# Your company's user_id, sent with the API documentation.
user_id = "<your company's user_id>"

# If you have received creds.json, its values override the placeholders above.
creds_path = '../../creds.json'
if os.path.exists(creds_path):
    # Use a context manager so the file handle is closed right after parsing.
    with open(creds_path, 'r') as creds_file:
        creds = json.load(creds_file)
    base_url = creds['base_url']
    file_server_url = creds['file_server_url']
    zh_token = creds['zh_token']
    user_id = creds['user_id']
    print("loaded creds from creds.json")

In [None]:
def upload_file(file_name, data_dir='../../data/'):
    """Upload a local file to Chohotech scratch storage and return its file URN.

    The file service first hands out a signed upload URL. The file bytes are
    PUT to that URL, and the URN is derived from the URL's path.

    Args:
        file_name: Name of the file inside ``data_dir``. Its extension (the text
            after the last '.') is sent to the service as ``postfix``.
        data_dir: Directory the file is read from. Defaults to the notebook's
            original ``../../data/`` location.

    Returns:
        A ``urn:zhfile:o:s:{user_group}:{user_id}:{path}`` string, usable as the
        ``data``/``image`` value in job inputs.

    Raises:
        requests.HTTPError: If requesting the upload URL or the upload itself fails.
    """
    ext = file_name.split('.')[-1]
    # Context manager so the local file handle is always closed.
    with open(os.path.join(data_dir, file_name), 'rb') as data_file:
        data = data_file.read()

    # Get a signed upload URL; the postfix (file extension) is mandatory.
    resp = requests.get(file_server_url + f"/scratch/{user_group}/{user_id}/upload_url",
                        params={"postfix": ext},
                        headers={"X-ZH-TOKEN": zh_token})
    resp.raise_for_status()
    # The body is a single JSON string literal. Decode it as JSON instead of
    # slicing off the quotes, so escaped characters are handled correctly.
    upload_url = json.loads(resp.text)

    # No auth header is needed for uploading to the cloud storage service.
    resp = requests.put(upload_url, data)
    resp.raise_for_status()

    # NOTE(review): the first three URL path segments are dropped. Presumably
    # they are a storage bucket/prefix, not part of the file-service path; confirm.
    path = "/".join(urllib.parse.urlparse(upload_url).path.lstrip("/").split("/")[3:])
    return f"urn:zhfile:o:s:{user_group}:{user_id}:{path}"

def run_job_and_get_results(json_call, timeout_sec, poll_interval_sec=0.3):
    """Submit a job to the ``/run`` endpoint, wait for it, and return its data.

    Args:
        json_call: Job specification (spec_group / spec_name / spec_version /
            user info / input_data / ...), sent as the JSON request body.
        timeout_sec: Maximum number of seconds to poll for completion.
        poll_interval_sec: Seconds to sleep between status polls.

    Returns:
        The decoded JSON body of ``GET {base_url}/data/{run_id}``.

    Raises:
        ValueError: If the job reports ``failed``; the message includes ``reason_public``.
        TimeoutError: If the job has not completed within ``timeout_sec``.
        requests.HTTPError: If any of the HTTP requests returns an error status.
    """
    headers = {
        "Content-Type": "application/json",
        "X-ZH-TOKEN": zh_token,
    }

    response = requests.post(base_url + '/run', headers=headers, data=json.dumps(json_call))
    response.raise_for_status()
    run_id = response.json()['run_id']
    print("workflow id is", run_id)
    status_url = base_url + f"/run/{run_id}"

    start_time = time.time()
    # Default status: if the loop body never runs (non-positive timeout), this
    # reports a timeout instead of an UnboundLocalError on `result`.
    result = {'completed': False, 'failed': False}
    while time.time() - start_time < timeout_sec:
        time.sleep(poll_interval_sec)
        response = requests.get(status_url, headers=headers)
        # Surface HTTP errors directly rather than failing later on a missing key.
        response.raise_for_status()
        result = response.json()
        if result['completed'] or result['failed']:
            break

    if not result['completed']:
        if result['failed']:
            # .get so a missing reason does not mask the failure with a KeyError.
            raise ValueError("API failed due to " + str(result.get('reason_public')))
        raise TimeoutError("API timeout")

    print("API finished in {}s".format(time.time() - start_time))
    response = requests.get(base_url + f"/data/{run_id}", headers=headers)
    response.raise_for_status()
    return response.json()

def retrieve_data(urn):
    """Download the raw bytes of the file-service object identified by ``urn``.

    Args:
        urn: A ``urn:zhfile:...`` identifier, e.g. as returned by ``upload_file``.

    Returns:
        The response body as ``bytes``.

    Raises:
        requests.HTTPError: If the download request fails. Without this check,
            an error page would be returned as if it were the file content.
    """
    resp = requests.get(file_server_url + "/file/download",
                        params={"urn": urn},
                        headers={"X-ZH-TOKEN": zh_token})
    resp.raise_for_status()
    return resp.content

def retrieve_mesh(mesh_file_json):
    """Download a mesh described by a job result entry and load it with trimesh.

    Args:
        mesh_file_json: Dict with ``'data'`` (the file URN) and ``'type'`` (the
            mesh file format, e.g. ``'ply'``), like the values of ``teeth_comp``
            in the job results.

    Returns:
        The object returned by ``trimesh.load`` for the downloaded bytes.

    Raises:
        requests.HTTPError: If the download request fails. Without this check,
            an error page would be handed to trimesh as mesh data.
    """
    resp = requests.get(file_server_url + "/file/download",
                        params={"urn": mesh_file_json['data']},
                        headers={"X-ZH-TOKEN": zh_token})
    resp.raise_for_status()
    return trimesh.load(trimesh.util.wrap_as_stream(resp.content), file_type=mesh_file_json['type'])

## Automatic prescription and target position

The sequence of API calls in this section can be replaced by the following single workflow:

```python
json_call = {
  "spec_group": "mesh-processing",
  "spec_name": "oral-arrangement-medical", 
  "spec_version": "1.0-snapshot",
  "user_group": user_group,
  "user_id": user_id
}
```

### Acquire teeth segmentation

In [None]:
# Upload the upper-jaw scan and run the oral-denoise-prod job on it; the
# per-tooth components ("teeth_comp") are requested in PLY format.
upper_jaw_input = {
    "mesh": {"type": "drc", "data": upload_file("upper_jaw_scan.drc")},
    "jaw_type": "Upper",
}
json_call = {
    "spec_group": "mesh-processing",
    "spec_name": "oral-denoise-prod",
    "spec_version": "1.0-snapshot",
    "user_group": user_group,
    "user_id": user_id,
    "input_data": upper_jaw_input,
    "output_config": {"teeth_comp": {"type": "ply"}},
}
result_upper_jaw = run_job_and_get_results(json_call, 300)

In [None]:
# Upload the lower-jaw scan and run the oral-denoise-prod job on it; the
# per-tooth components ("teeth_comp") are requested in PLY format.
lower_jaw_input = {
    "mesh": {"type": "ply", "data": upload_file("lower_jaw_scan.ply")},
    "jaw_type": "Lower",
}
json_call = {
    "spec_group": "mesh-processing",
    "spec_name": "oral-denoise-prod",
    "spec_version": "1.0-snapshot",
    "user_group": user_group,
    "user_id": user_id,
    "input_data": lower_jaw_input,
    "output_config": {"teeth_comp": {"type": "ply"}},
}
result_lower_jaw = run_job_and_get_results(json_call, 300)

### Acquire analysis results of ceph and frontal smile images

In [None]:
# Upload the lateral cephalometric X-ray and run the ceph-analysis job.
ceph_image_urn = upload_file("ceph.jpg")
json_call = {
    "spec_group": "ceph",
    "spec_name": "ceph-analysis",
    "spec_version": "1.0-snapshot",
    "user_group": user_group,
    "user_id": user_id,
    "input_data": {"image": ceph_image_urn},
}
result_ceph = run_job_and_get_results(json_call, 20)

In [None]:
# Upload the smiling photo and run the smile-analysis job.
smile_image_urn = upload_file("face_smile.jpg")
json_call = {
    "spec_group": "face",
    "spec_name": "smile-analysis",
    "spec_version": "1.0-snapshot",
    "user_group": user_group,
    "user_id": user_id,
    "input_data": {"image": smile_image_urn},
}
result_smile = run_job_and_get_results(json_call, 20)

### Acquire automatic prescriptions

In [None]:
# Combine both jaw segmentations with the ceph and smile keypoints to request
# an automatic prescription ("form") from the auto-form job.
form_input_data = {
    "upper_teeth_dict": result_upper_jaw["teeth_comp"],
    "upper_align_matrix": result_upper_jaw["align_matrix"],
    "upper_axis_matrix_dict": result_upper_jaw["axis"],
    "lower_teeth_dict": result_lower_jaw["teeth_comp"],
    "lower_align_matrix": result_lower_jaw["align_matrix"],
    "lower_axis_matrix_dict": result_lower_jaw["axis"],
    "ceph_metric_pts_dict": result_ceph["result"]["kps"],
    "frontal_smiling_pts_dict": result_smile["result"]["kps"],
    "meta": result_ceph["result"]["meta"],
}
json_call = {
    "spec_group": "mesh-processing",
    "spec_name": "auto-form",
    "spec_version": "1.0-snapshot",
    "user_group": user_group,
    "user_id": user_id,
    "input_data": form_input_data,
}
result_form = run_job_and_get_results(json_call, 500)

In [None]:
def form2str(form):
    """Render an auto-form prescription dict as a human-readable text summary.

    Args:
        form: Prescription dict, e.g. ``json.loads(result_form['result']['form'])``.
            Keys read:
            - ``front_y_axis_position``: sequence (tooth, cusp key, direction key, distance).
            - ``z_axis_position['front_teeth']``: (direction key, distance).
            - ``middle_line_position['U']``: index 1 is a direction key, index 2 a distance.
            - ``collision_removal['U']['1']`` and ``collision_removal['L']['1']``.
            - ``remove_teeth_set`` and ``init_gap``.

    Returns:
        A multi-line summary string. Every line of the template is indented
        by four spaces.

    Raises:
        KeyError: If the form contains a cusp, direction, or collision key that
            has no display name below.
    """
    def collision2str(collisions):
        # Each entry is either a bare collision name or a (name, ...) tuple/list;
        # only the name is displayed.
        collision_name_dict = {
            'left_back_teeth_IPR': 'Left posterior teeth IPR',
            'front_teeth_IPR': 'Front teeth IPR',
            'right_back_teeth_IPR': 'Right posterior teeth IPR',
            'move_left_molar': 'Move left molar distally',
            'move_right_molar': 'Move right molar distally',
            'move_both_molar_forward': 'Move molars mesially'
        }
        return ', '.join(
            collision_name_dict[collision[0] if isinstance(collision, (tuple, list)) else collision]
            for collision in collisions
        )

    cusp_name_dict = {
        'near_cusp': 'Near cusp',
        'mid_cusp': 'Mid cusp',
        'far_cusp': 'Far cusp'
    }
    direction_name_dict = {
        'backward': 'Retrusion',
        'forward': 'Protrusion',
        'down': 'Depression',
        'up': 'Elongation',
        'left': 'Left shift',
        'right': 'Right shift'
    }

    front = form['front_y_axis_position']
    front_teeth_z = form['z_axis_position']['front_teeth']
    upper_midline = form['middle_line_position']['U']
    # Show 'None' for empty lists, consistently for both extracted teeth and reserved gaps.
    removed_teeth = ' '.join(map(str, form['remove_teeth_set'])) or 'None'
    reserved_gap = ' '.join(f'{k}={v}' for k, v in form['init_gap'].items()) or 'None'

    form_str = f"""
    Mandibular Target Position
    1. Anterior teeth sagittal target position: Based on tooth #__, __ and __mm
    {front[0]},{cusp_name_dict[front[1]]},{direction_name_dict[front[2]]} {front[3]:.1f}
    2. Vertical target position: Posterior teeth stationary, anterior teeth __mm
    {direction_name_dict[front_teeth_z[0]]} {front_teeth_z[1]:.1f}

    Upper and Lower Jaw Occlusion Relationship
    1. Occlusion relationship [Left]: Canine to __ relationship, molar to __ relationship
       Class I, Class I
    2. Occlusion relationship [Right]: Canine to __ relationship, molar to __ relationship
       Class I, Class I
    3. Occlusion relationship [Overjet]: __ Overjet
       Standard
    4. Occlusion relationship [Overbite]: __ Overbite, __ posterior teeth
       Standard, Maintain
    5. Occlusion relationship [Molar Class II|Class III]
       Correction

    Midline Target Position
    1. [Upper Jaw] Midline target position
       {direction_name_dict[upper_midline[1]]} {upper_midline[2]:.1f} mm
    2. [Lower Jaw] Midline target position
       Align with upper jaw midline

    Collision Removal|Gap Closure
    1. Upper Jaw
       {collision2str(form['collision_removal']['U']['1'])}
    2. Lower Jaw
       {collision2str(form['collision_removal']['L']['1'])}
    3. Extracted Tooth Numbers
       {removed_teeth}
    4. Reserved Gap
       {reserved_gap}
    """

    return form_str

# The auto-form job returns the prescription as a JSON-encoded string.
prescription_form = json.loads(result_form['result']['form'])
print(form2str(prescription_form))

### Generate the target position based on the prescription

Note: you can adjust the generated prescription before this step, or write one manually.

In [None]:
# Arrange the teeth according to the prescription ("form") produced above;
# the arranged per-tooth meshes are requested in PLY format.
arrange_input_data = {
    "upper_teeth_dict": result_upper_jaw["teeth_comp"],
    "upper_align_matrix": result_upper_jaw["align_matrix"],
    "upper_axis_matrix_dict": result_upper_jaw["axis"],
    "lower_teeth_dict": result_lower_jaw["teeth_comp"],
    "lower_align_matrix": result_lower_jaw["align_matrix"],
    "lower_axis_matrix_dict": result_lower_jaw["axis"],
    "form": result_form["result"]["form"],
    "matrix_3d": result_form["result"]["matrix_3d"],
}
json_call = {
    "spec_group": "mesh-processing",
    "spec_name": "arrange-with-form",
    "spec_version": "1.0-snapshot",
    "user_group": user_group,
    "user_id": user_id,
    "input_data": arrange_input_data,
    "output_config": {"arranged_comp": {"type": "ply"}},
}
result_arrangement = run_job_and_get_results(json_call, 500)

In [None]:
# Download every arranged tooth mesh and display them merged into one scene.
arranged_meshes = [retrieve_mesh(mesh_json)
                   for mesh_json in result_arrangement['result']['arranged_comp'].values()]
sum(arranged_meshes).show()

## Automatic treatment planning (stepping)


In [None]:
# Plan the treatment steps from the original teeth to the arranged target
# position, using the per-tooth transformations from the arrangement job.
step_input_data = {
    "upper_teeth_dict": result_upper_jaw["teeth_comp"],
    "upper_align_matrix": result_upper_jaw["align_matrix"],
    "upper_axis_matrix_dict": result_upper_jaw["axis"],
    "lower_teeth_dict": result_lower_jaw["teeth_comp"],
    "lower_align_matrix": result_lower_jaw["align_matrix"],
    "lower_axis_matrix_dict": result_lower_jaw["axis"],
    "transformation_dict": result_arrangement["result"]["transformation_dict"],
}
json_call = {
    "spec_group": "mesh-processing",
    "spec_name": "auto-step",
    "spec_version": "1.0-snapshot",
    "user_group": user_group,
    "user_id": user_id,
    "input_data": step_input_data,
}
result_step = run_job_and_get_results(json_call, 2200)

In [None]:
# Number of treatment steps produced by the auto-step job.
step_count = len(result_step['result']['step_dict'])
print("step count:", step_count)

In [None]:
# Download every segmented tooth mesh from both jaws once, keyed by its
# teeth_comp key, so each step can be rendered without re-downloading.
original_comp = {**{k: retrieve_mesh(v) for k, v in result_upper_jaw["teeth_comp"].items()},
                 **{k: retrieve_mesh(v) for k, v in result_lower_jaw["teeth_comp"].items()}}


def show_step(step_index):
    """Return all teeth moved to their pose at one treatment step, merged into one mesh.

    Args:
        step_index: Index into ``result_step['result']['step_dict']``. It is used
            with ``[]`` directly, so it must match that container's index/key type.
            NOTE(review): if ``step_dict`` is a JSON object, its keys will be
            strings; confirm against the API output.

    Returns:
        A single trimesh mesh concatenating transformed copies of every tooth
        that has a transform at this step. The cached meshes in
        ``original_comp`` are not modified.

    Raises:
        ValueError: If no downloaded tooth has a transform at this step.
    """
    step_transforms = result_step['result']['step_dict'][step_index]
    moved_teeth = [
        mesh.copy().apply_transform(np.asarray(step_transforms[tooth_id], dtype=float))
        for tooth_id, mesh in original_comp.items()
        if tooth_id in step_transforms
    ]
    # The original accumulated into `None`, which relied on obscure trimesh
    # __radd__ behaviour and returned None when nothing matched.
    if not moved_teeth:
        raise ValueError(f"no tooth transforms found for step {step_index}")
    return trimesh.util.concatenate(moved_teeth)

In [None]:
# Render the combined tooth mesh at treatment step 10.
step_mesh = show_step(10)
step_mesh.show()