# ROS gRPC API generator

This notebook generates a gRPC server implementing the available ROS topics and services. Please see the [README](/README.md) for more detail.


### Steps:
 - [Snapshot ROS topics and services](#snapshot)
 - [Load snapshots](#load)
 - [Generate the .proto file](#proto)
 - [Generate the gRPC server](#server)

In [None]:
import rospy
import rosmsg
import rostopic
import rosservice
import os
import re
import io
import configparser
import doctest

In [None]:
# Config

# Folder receiving every generated artifact (snapshot, .proto, server code),
# resolved against the current working directory
OUT_DIR = os.path.abspath('generated')
# Path to the snapshot of the ROS topics, services and message definitions
SNAPSHOT_FILE = os.path.join(OUT_DIR, 'snapshot.ini')
# Path to the generated .proto file
PROTO_FILE = os.path.join(OUT_DIR, 'ros.proto')

In [None]:
def write_file(path, content):
    """Write the text *content* to *path*, creating missing parent folders.

    An existing file is overwritten.  A path without a folder component
    (e.g. ``'out.txt'``) is written relative to the working directory.

    :param path: destination file path
    :param content: text to write
    """
    folder = os.path.dirname(path)
    # dirname() is '' for bare file names; makedirs('') would raise.
    # exist_ok avoids the race between an exists() check and makedirs().
    if folder:
        os.makedirs(folder, exist_ok=True)

    # `with` closes the file even if write() raises
    with open(path, 'w') as f:
        f.write(content)
    print("file was written to {}".format(path))

<a id="snapshot"></a>
# Snapshot

In [None]:

# Snapshot of the live ROS graph; filled below and later written to disk.
# interpolation=None: message definitions are stored verbatim, so a '%' in a
# constant must not be parsed as configparser interpolation syntax.
config = configparser.ConfigParser(interpolation=None)
# enable case-sensitive keys (for the ROS types)
config.optionxform = str
config['MESSAGE_DEFINITIONS'] = {}
config['TOPICS'] = {}
config['SERVICES'] = {}
config['ACTIONS'] = {}

# get_topic_list() returns the published and the subscribed topics; each entry
# starts with the topic name followed by its type.  Both lists are merged
# through a dict (dropping duplicates) so subscriber-only topics are included.
published, subscribed = rostopic.get_topic_list()
published_topics = dict(
    (entry[0], entry[1]) for entry in published + subscribed).items()

for (message_name, ros_type) in published_topics:
    config['TOPICS'][message_name] = ros_type
    config['MESSAGE_DEFINITIONS'][ros_type] = rosmsg.get_msg_text(ros_type)

services = rosservice.get_service_list()

for service_name in services:
    ros_type = rosservice.get_service_type(service_name)
    if ros_type is None:
        # get_service_type() returns None when the type can't be resolved
        # (e.g. the service went away after get_service_list()); storing
        # None would make configparser raise a TypeError
        print("skipping service {}: its type could not be determined".format(service_name))
        continue
    config['SERVICES'][service_name] = ros_type
    config['MESSAGE_DEFINITIONS'][ros_type] = rosmsg.get_srv_text(ros_type)
    
def _split_nested_fields(text):
    """Split one definition into its top-level text and its nested blocks.

    Lines indented by two spaces belong to the type of the closest preceding
    top-level line; they are returned dedented by one level.

    :returns: ``(top_level_text, [(sub_type_name, sub_text), ...])``
    """
    top_level_fields = []
    nested = []  # [(index of the owning top-level line, [dedented lines])]
    for field in text.split('\n'):
        if field.startswith('  ') and top_level_fields:
            owner = len(top_level_fields) - 1
            if not nested or nested[-1][0] != owner:
                nested.append((owner, []))
            # drop one nesting level so the block becomes a definition itself
            nested[-1][1].append(field[2:])
        else:
            top_level_fields.append(field)

    sub_types = [
        (top_level_fields[owner].split(' ')[0], '\n'.join(lines))
        for owner, lines in nested]
    return '\n'.join(top_level_fields), sub_types


def flatten_types(definitions=None):
    """Split nested ROS message definitions into one entry per type.

    ``rosmsg`` prints the definition of an embedded message type inline,
    indented by two spaces below the field that uses it::

        std_msgs/Header header
          uint32 seq
          time stamp

    Every indented block is moved into its own entry, keyed by the field's
    type exactly as written in the parent line (an array field
    ``foo/Bar[] bars`` yields the key ``foo/Bar[]``), and only the top-level
    fields stay in the parent.  Hoisted blocks are dedented by one level and
    the mapping is processed again until nothing changes, so arbitrarily deep
    nesting ends up flat.

    :param definitions: mutable mapping ``ros_type -> definition text``,
        updated in place.  Defaults to the ``MESSAGE_DEFINITIONS`` section of
        the module-level ``config``.
    """
    if definitions is None:
        definitions = config['MESSAGE_DEFINITIONS']

    changed = True
    while changed:
        changed = False
        # iterate over a copy of the keys: hoisted types are added meanwhile
        for type_name in list(definitions.keys()):
            flat, sub_types = _split_nested_fields(definitions[type_name])
            for sub_type_name, sub_text in sub_types:
                if definitions.get(sub_type_name) != sub_text:
                    definitions[sub_type_name] = sub_text
                    changed = True
            if definitions[type_name] != flat:
                definitions[type_name] = flat
                changed = True

flatten_types()

# Serialize the snapshot into memory, then persist it with write_file().
with io.StringIO() as buffer:
    config.write(buffer)
    write_file(SNAPSHOT_FILE, buffer.getvalue())


<a id="load"></a>
# Load

In [None]:
class RosSnapshot:
    """Parsed snapshot of ROS topics, services and message definitions.

    The snapshot is an INI file with the sections ``MESSAGE_DEFINITIONS``
    (``ros_type -> definition text``), ``TOPICS`` (``topic -> ros_type``) and
    ``SERVICES`` (``service -> ros_type``).
    """

    def __init__(self):
        # interpolation=None: definitions are plain text, so a '%' in a
        # constant must not be treated as interpolation syntax
        self.config = configparser.ConfigParser(interpolation=None)
        # enable case-sensitive keys (for the ROS types)
        self.config.optionxform = str

    def load(self, path):
        """Read the snapshot file at *path*.

        :raises FileNotFoundError: if nothing could be read from *path*.
        """
        if not self.config.read(path):
            # ConfigParser.read() silently skips unreadable files; fail here
            # rather than with a KeyError for a missing section later on
            raise FileNotFoundError("Can't read ROS snapshot '{}'".format(path))

    def get_message_definitions(self):
        """Return the ``ros_type -> definition text`` section."""
        return self.config["MESSAGE_DEFINITIONS"]

    def get_message_definition_packages(self):
        """Return the set of package names of all known message types."""
        return {ros_type.split('/')[0] for ros_type in self.get_message_definitions()}

    def get_topics(self):
        """
        Returns Map<topic, ros_type>
        """
        return self.config["TOPICS"]

    def get_services(self):
        """
        Returns Map<service, ros_type>
        """
        return self.config["SERVICES"]

    def get_sections(self, ros_type):
        """Return the '---' separated parts of a definition as a list.

        Messages have one part, services two (request, response).

        :raises KeyError: if *ros_type* has no definition in the snapshot.
        """
        if ros_type not in self.config["MESSAGE_DEFINITIONS"]:
            raise KeyError("Can't find message definition for '{}'".format(ros_type))

        return self.config["MESSAGE_DEFINITIONS"][ros_type].split('---')

    def get_fields(self, ros_type, section=0):
        """Return the fields of one section of *ros_type*.

        Blank lines and constants (lines containing ``=``) are skipped.

        :returns: list of ``[field_type, field_name]`` pairs; an empty list
            (falsy) when the section has no fields.
        """
        fields = []
        for line in self.get_sections(ros_type)[section].split('\n'):
            line = line.strip()
            if not line or '=' in line:
                continue
            # split() tolerates repeated or trailing whitespace
            fields.append(line.split())
        return fields

    def __str__(self):
        return '<RosSnapshot topics={} services={} message_definitions={}>'.format(
            len(self.get_topics()),
            len(self.get_services()),
            len(self.get_message_definitions()))
    

snap = RosSnapshot()
snap.load(SNAPSHOT_FILE)
print(snap)
print(list(snap.get_message_definitions()))
# Example: show the parsed fields of one well-known type, but only if the
# snapshot contains it (the lookup raises KeyError otherwise)
if 'roscpp/GetLoggers' in snap.get_message_definitions():
    print(list(snap.get_fields('roscpp/GetLoggers')))

In [None]:
# Example ROS snapshot for unit testing (loaded by tsnap() in the doctests).
# Same INI layout as the snapshot file: multi-line definitions continue on
# tab-indented lines, and '---' separates a service request from its response.
# NOTE(review): the field `msg/Bar[5] bar` names package `msg`, not `msgs`;
# the doctest of generate_pb_message_definition_for_package expects exactly
# that output, so both would have to change together.
TEST_SNAPSHOT_FILE = """
[MESSAGE_DEFINITIONS]
msgs/Foo = std_msgs/Header header
	string name
	uint32[] numbers
	msg/Bar[5] bar

msgs/Bar = uint8 number

srvs/Empty = ---

srvs/Baz = string logger
	---
	string level


[TOPICS]
/foo = msgs/Foo

[SERVICES]
/baz = srvs/Baz
"""
def tsnap():
    """Build a RosSnapshot pre-loaded with TEST_SNAPSHOT_FILE (for doctests)."""
    fixture = RosSnapshot()
    fixture.config.read_string(TEST_SNAPSHOT_FILE)
    return fixture

In [None]:
# ROS built-in field types -> protobuf field types.
# protobuf has no 8- or 16-bit integers, so the narrow ROS integers widen to
# 32 bits; `time` and `duration` map to the Time/Duration messages declared in
# the proto header.
scalar_ros2pb = {
    'bool': 'bool',
    'int8': 'int32',
    'uint8': 'uint32',
    'int16': 'int32',
    'uint16': 'uint32',
    'int32': 'int32',
    'uint32': 'uint32',
    'int64': 'int64',
    'uint64': 'uint64',
    'float32': 'float',
    'float64': 'double',
    'string': 'string',
    'time': 'Time',
    'duration': 'Duration',
    'char': 'uint32',
    'byte': 'int32',
}



# TODO rename to strip_ros_array_notation
def strip_array_notation(name):
    """Drop a trailing ROS array suffix (``[]`` or ``[N]``) from a type name.

    >>> strip_array_notation('Foo[]')
    'Foo'
    >>> strip_array_notation('Foo[123]')
    'Foo'
    >>> strip_array_notation('Foo')
    'Foo'
    """
    suffix = re.search(r'\[\d*\]$', name)
    if suffix is None:
        return name
    return name[:suffix.start()] + name[suffix.end():]

def type_ros2pb(ros_type):
    """Map a ROS msg field type to the matching protobuf field type.

    Built-in types are translated through ``scalar_ros2pb``; message types
    keep their name with ``/`` turned into ``.``; arrays become ``repeated``.

    >>> type_ros2pb('byte')
    'int32'
    >>> type_ros2pb('std_msgs/Header')
    'std_msgs.Header'
    >>> type_ros2pb('string[]')
    'repeated string'
    >>> type_ros2pb('time')
    'Time'
    >>> type_ros2pb('string')
    'string'
    >>> type_ros2pb('uint32')
    'uint32'
    """
    is_repeated = ros_type.endswith(']')
    element_type = strip_array_notation(ros_type)
    converted = scalar_ros2pb.get(element_type, element_type).replace('/', '.')
    return ('repeated ' + converted) if is_repeated else converted

def topic2service_name(topic):
    """Convert a ROS topic/service name to a valid protobuf name.

    A leading slash is dropped and the remaining slashes become underscores.
    Names without a leading slash are converted as-is instead of losing
    their first character.

    >>> topic2service_name('/rosout')
    'rosout'
    >>> topic2service_name('/rosout_agg')
    'rosout_agg'
    >>> topic2service_name('/turtle1/pose')
    'turtle1_pose'
    >>> topic2service_name('/turtle1/color_sensor')
    'turtle1_color_sensor'
    """
    if topic.startswith('/'):
        topic = topic[1:]
    return topic.replace('/', '_')

def parse_ros_type(ros_type):
    """Split a ROS field type into ``(package, typename, is_array)``.

    ``package`` is ``None`` for built-in types.  Array notation is removed
    from ``typename`` and reported through ``is_array``.

    >>> parse_ros_type('rosgraph_msgs/Log')
    ('rosgraph_msgs', 'Log', False)
    >>> parse_ros_type('turtlesim/Pose')
    ('turtlesim', 'Pose', False)
    >>> parse_ros_type('uint32')
    (None, 'uint32', False)
    >>> parse_ros_type('time')
    (None, 'time', False)
    >>> parse_ros_type('geometry_msgs/Point[]')
    ('geometry_msgs', 'Point', True)
    """
    base_type = strip_array_notation(ros_type)
    if '/' not in base_type:
        return None, base_type, ros_type.endswith(']')
    package, typename = base_type.split('/')
    return package, typename, ros_type.endswith(']')

# Run the doctests of every function defined so far in this notebook
import doctest
doctest.testmod()

<a id="proto"></a>
# Proto

In [None]:
# proto service generated for each ROS topic: Publish sends one message to the
# topic, Subscribe streams the topic's messages.  Doubled braces are literal
# braces for str.format().
topic_service_template = """
service {service_name} {{
    rpc Publish({pb_type}) returns (Empty);
    rpc Subscribe(Empty) returns (stream {pb_type});
}}

"""

# proto service generated for each ROS service: a single Call that takes the
# <Type>Request message and returns the <Type>Response message.
srv_service_template = """
service {service_name} {{
    rpc Call({pb_type}Request) returns ({pb_type}Response);
}}

"""

# Start of the .proto file (used verbatim, not formatted): package `ros` with
# the shared Empty message and the Time/Duration messages that the ROS
# built-in `time` and `duration` types are mapped to.
header = '''
syntax = 'proto3';

package ros;

message Empty {}

message Time {
    uint32 secs = 1;
    uint32 nsecs = 2;
}

message Duration {
    int32 secs = 1;
    int32 nsecs = 2;
}

'''

In [None]:
def generate_pb_message_definition_for_package(snap: RosSnapshot, pkg):
    """Generate the proto definitions of a ROS package.

    Every type of *pkg* in the snapshot becomes a nested message of one
    ``message <pkg>`` block, in sorted order.  Types with two sections
    (services) produce a ``<Name>Request`` and a ``<Name>Response`` message.
    Array notation is stripped from type names, and a name that was already
    emitted (e.g. when both ``pkg/Foo`` and ``pkg/Foo[]`` are in the snapshot)
    is skipped so the output never defines a message twice.

    >>> print(generate_pb_message_definition_for_package(tsnap(), "msgs"))
    message msgs {
      message Bar {
        uint32 number = 1;
      }
      message Foo {
        std_msgs.Header header = 1;
        string name = 2;
        repeated uint32 numbers = 3;
        repeated msg.Bar bar = 4;
      }
    }
    <BLANKLINE>
    <BLANKLINE>

    >>> print(generate_pb_message_definition_for_package(tsnap(), "srvs"))
    message srvs {
      message BazRequest {
        string logger = 1;
      }
      message BazResponse {
        string level = 1;
      }
      message EmptyRequest {
      }
      message EmptyResponse {
      }
    }
    <BLANKLINE>
    <BLANKLINE>
    """
    lines = ['message %s {' % pkg]
    emitted = set()
    for ros_type in sorted(snap.get_message_definitions().keys()):
        pkgname, msgname = ros_type.split('/')
        msgname = strip_array_notation(msgname)
        if pkgname != pkg or msgname in emitted:
            continue
        emitted.add(msgname)

        section_count = len(snap.get_sections(ros_type))
        is_srv = section_count == 2
        for section in range(section_count):
            postfix = ''
            if is_srv:
                postfix = 'Request' if section == 0 else 'Response'

            lines.append('  message %s {' % (msgname + postfix))
            fields = snap.get_fields(ros_type, section)
            # protobuf field numbers start at 1
            for number, (field_type, field_name) in enumerate(fields, start=1):
                lines.append('    {} {} = {};'.format(
                    type_ros2pb(field_type), field_name, number))
            lines.append('  }')
    lines.append('}')
    return '\n'.join(lines) + '\n\n'
# Re-run the doctests, now including the proto generator defined above
doctest.testmod()
def generate_pb_message_all_definitions(snap: RosSnapshot):
    """Concatenate the proto message blocks of every package, sorted by name."""
    return ''.join(
        generate_pb_message_definition_for_package(snap, package)
        for package in sorted(snap.get_message_definition_packages()))
                

def generate_pb_service(topic, ros_type):
    """Return the proto ``service`` block (Publish/Subscribe) for one topic.

    Uses ``topic_service_template``; the previous reference to the undefined
    ``msg_service_template`` raised NameError.
    """
    return topic_service_template.format(
        service_name=topic2service_name(topic), pb_type=type_ros2pb(ros_type))

def generate_proto_file(snap: RosSnapshot):
    """Render the complete .proto for *snap* and write it to PROTO_FILE.

    The file is the fixed ``header``, one message block per ROS package, a
    Publish/Subscribe service per topic and a Call service per ROS service.
    """
    print('Found {} messages'.format(len(snap.get_message_definitions())))

    chunks = [header, generate_pb_message_all_definitions(snap)]
    chunks.extend(
        topic_service_template.format(
            service_name=topic2service_name(topic),
            pb_type=type_ros2pb(ros_type))
        for topic, ros_type in sorted(snap.get_topics().items()))
    chunks.extend(
        srv_service_template.format(
            service_name=topic2service_name(service),
            pb_type=type_ros2pb(ros_type))
        for service, ros_type in sorted(snap.get_services().items()))

    write_file(PROTO_FILE, ''.join(chunks))



    
# Write the .proto file for the snapshot loaded in the "Load" section
generate_proto_file(snap)

<a id="server"></a>
# Server

In [None]:
# Skeleton of the generated grpc_server.py.  `{classes}` receives the servicer
# classes and `{add_servicers}` the (already indented) registration calls.
# Only modules the generated code actually uses are imported.
frame_template = '''
from concurrent import futures
import argparse

import grpc

import rospy
import roslib.message

import ros_pb2 as ros_pb
import ros_pb2_grpc as ros_grpc


{classes}

def create_server():
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
{add_servicers}
    return server


def run_server(address):
    rospy.init_node('grpc_server', anonymous=True)
    server = create_server()
    server.add_insecure_port(address)
    server.start()
    print("gRPC server is running at %s" % address)
    rospy.spin()


def parse_args_and_run_server():
    run_server(parse_args())


def parse_args(args=None):
    parser = argparse.ArgumentParser(
        description='Generated ROS gRPC server')
    parser.add_argument(
        '-a',
        '--address',
        help='host:port the gRPC server listens on',
        default='[::]:50051')

    results = parser.parse_args(args)
    return results.address


if __name__ == '__main__':
    parse_args_and_run_server()

'''

# Registration call emitted once per servicer inside create_server().
add_servicer_template = 'ros_grpc.add_{servicer_class}_to_server({servicer_class}(), server)'

# Servicer for one ROS topic.  Publish converts the incoming protobuf message
# into a ROS message and publishes it (the publisher is created on first use).
# Subscribe registers a ROS subscriber, buffers the received ROS messages and
# streams them to the client, polling every 0.01 s, until the RPC ends.
# The code substituted for {copy_pb2ros} fills `ros_msg`; the code for
# {copy_ros2pb} must leave the message to send in `pb_msg`.
# Doubled braces are literal braces for str.format().
topic_class_template = '''
class {servicer_class}(ros_grpc.{servicer_class}):
    def __init__(self):
        self.pub = None
        self.Msg = roslib.message.get_message_class('{ros_type}')

    def Publish(self, pb_msg, context):
        if self.pub == None:
            self.pub = rospy.Publisher('{topic}', self.Msg, queue_size=10)

        ros_msg = self.Msg()
{copy_pb2ros}
        print('publishing', ros_msg)
        self.pub.publish(ros_msg)
        return ros_pb.Empty()

    def Subscribe(self, request, context):
        c = {{'unsubscribed': False}}
        ros_messages = []

        def callback(ros_msg):
            print('ros in', ros_msg)
            ros_messages.append(ros_msg)
        subscription = rospy.Subscriber('{topic}', self.Msg, callback)

        def on_rpc_done():
            c['unsubscribed'] = True
            print("Attempting to regain servicer thread...", c)
            subscription.unregister()

        context.add_callback(on_rpc_done)

        while not c['unsubscribed']:
            while ros_messages:
                ros_msg = ros_messages.pop(0)
{copy_ros2pb}
                yield pb_msg
            rospy.sleep(0.01)
'''

# Servicer for one ROS service.  Call converts the protobuf request into a ROS
# request, calls the ROS service, and returns `pb_msg`, which the code
# substituted for {copy_ros2pb} is expected to fill from the ROS response
# held in `ros_msg`.
service_class_template = """
class {servicer_class}(ros_grpc.{servicer_class}):
    def Call(self, pb_msg, context):
        Srv = roslib.message.get_service_class('{ros_type}')
        call = rospy.ServiceProxy('{service}', Srv)
        ros_msg = Srv._request_class()
{copy_pb2ros}
        ros_msg = call(ros_msg)
{copy_ros2pb}
        return pb_msg
"""

def add_tab(lines, tabs=1):
    """Indent every non-empty line of *lines* by ``4 * tabs`` spaces.

    Empty lines are left untouched, so the generated code does not get
    whitespace-only lines.  (The previous ``[^$]`` pattern also matched the
    newline of an empty line and indented it.)
    """
    # (?=.) only matches at the start of a line holding at least one
    # character; '.' never matches '\n'
    return re.sub(r'^(?=.)', '    ' * tabs, lines, flags=re.MULTILINE)


# ROS built-in time types -> (protobuf message declared in the proto header,
# rospy class used on the ROS side)
_TIME_TYPES = {
    'time': ('Time', 'rospy.Time'),
    'duration': ('Duration', 'rospy.Duration'),
}


def _get_fields(snap, ros_type, section=0):
    """Return the fields of *ros_type* as a list.

    Falls back to the type name without array notation when the snapshot
    only holds the definition under that name.
    """
    try:
        return list(snap.get_fields(ros_type, section))
    except KeyError:
        element_type = strip_array_notation(ros_type)
        if element_type == ros_type:
            raise
        return list(snap.get_fields(element_type, section))


def _new_message_expr(ros_type, to_pb):
    """Return an expression that creates an empty message of *ros_type*."""
    package, typename, _ = parse_ros_type(ros_type)
    if not package:
        return 'ros_pb.{}()'.format(typename)
    if to_pb:
        return 'ros_pb.{}.{}()'.format(package, typename)
    return "roslib.message.get_message_class('{}/{}')()".format(package, typename)


def _convert_time(value, typename, to_pb):
    """Return an expression converting the time/duration *value* to the other side."""
    pb_name, ros_class = _TIME_TYPES[typename]
    if to_pb:
        return 'ros_pb.{}(secs={v}.secs, nsecs={v}.nsecs)'.format(pb_name, v=value)
    return '{}({v}.secs, {v}.nsecs)'.format(ros_class, v=value)


def _copy_field(snap, field_type, field_name, left, right, to_pb):
    """Generate the statements copying field *field_name* from *right* into *left*."""
    package, typename, is_array = parse_ros_type(field_type)
    target = '{}.{}'.format(left, field_name)
    source = '{}.{}'.format(right, field_name)

    if package is None:
        # built-in type: plain values, or time/duration (a message in protobuf)
        if is_array:
            if typename in _TIME_TYPES:
                items = '[{} for t_ in {}]'.format(
                    _convert_time('t_', typename, to_pb), source)
            else:
                items = 'list({})'.format(source)
            if to_pb:
                return '{}.extend({})\n'.format(target, items)
            # assign rather than append: ROS fixed-size arrays come pre-filled
            return '{} = {}\n'.format(target, items)
        if typename in _TIME_TYPES:
            converted = _convert_time(source, typename, to_pb)
            if to_pb:
                # protobuf message fields can't be assigned, only copied into
                return '{}.CopyFrom({})\n'.format(target, converted)
            return '{} = {}\n'.format(target, converted)
        return '{} = {}\n'.format(target, source)

    if not is_array:
        # nested message: the destination attribute already holds an instance
        return generate_msg_copier(snap, field_type, target, source, False)

    # array of messages: build each element, then append it
    item_left = '{}_'.format(left.split('.')[0])
    item_right = '{}_'.format(right.split('.')[0])
    body = generate_msg_copier(snap, field_type, item_left, item_right, True)
    body += '{}.append({})\n'.format(target, item_left)
    result = '' if to_pb else '{} = []\n'.format(target)
    result += 'for {} in {}:\n'.format(item_right, source)
    return result + add_tab(body)


def generate_msg_copier(snap: RosSnapshot, ros_type, left='pb_msg', right='ros_msg', new_instance=False, section=0):
    """Generate Python statements copying a message field by field.

    The direction follows the name of *left*: names starting with ``pb_``
    are protobuf messages (ROS -> protobuf), anything else is a ROS message
    (protobuf -> ROS).

    :param snap: snapshot providing the message definitions
    :param ros_type: ROS type of the copied message (array notation allowed)
    :param left: expression of the destination message
    :param right: expression of the source message
    :param new_instance: if true, first assign a fresh empty message of
        *ros_type* to *left*
    :param section: '---' separated section of the definition to copy
        (0 for messages and service requests, 1 for service responses)
    :returns: the generated statements, unindented, each ending in a newline
    """
    to_pb = left.startswith('pb_')
    result = ''
    if new_instance:
        result += '{} = {}\n'.format(left, _new_message_expr(ros_type, to_pb))
    for field_type, field_name in _get_fields(snap, ros_type, section):
        result += _copy_field(snap, field_type, field_name, left, right, to_pb)
    return result


def generate_server(snap: RosSnapshot):
    """Render the source of grpc_server.py for every topic and service in *snap*."""
    add_servicers = []
    servicer_classes = []
    for topic, ros_type in sorted(snap.get_topics().items()):
        servicer_class = topic2service_name(topic) + 'Servicer'
        add_servicers.append(add_servicer_template.format(
            servicer_class=servicer_class))
        # Subscribe yields a fresh protobuf message for every ROS message
        copy_ros2pb = generate_msg_copier(
            snap, ros_type, 'pb_msg', 'ros_msg', new_instance=True)
        # Publish fills the ROS message the template already created
        copy_pb2ros = generate_msg_copier(snap, ros_type, 'ros_msg', 'pb_msg')
        servicer_classes.append(topic_class_template.format(
            servicer_class=servicer_class, ros_type=ros_type, topic=topic,
            copy_ros2pb=add_tab(copy_ros2pb, 4),
            copy_pb2ros=add_tab(copy_pb2ros, 2)))

    for service, ros_type in sorted(snap.get_services().items()):
        servicer_class = topic2service_name(service) + 'Servicer'
        add_servicers.append(add_servicer_template.format(
            servicer_class=servicer_class))
        package, typename, _ = parse_ros_type(ros_type)
        # request: protobuf -> ROS, using the request section
        copy_pb2ros = generate_msg_copier(
            snap, ros_type, 'ros_msg', 'pb_msg', section=0)
        # reply: a new <Type>Response filled from the response section
        copy_ros2pb = 'pb_msg = ros_pb.{}.{}Response()\n'.format(package, typename)
        copy_ros2pb += generate_msg_copier(
            snap, ros_type, 'pb_msg', 'ros_msg', section=1)
        servicer_classes.append(service_class_template.format(
            servicer_class=servicer_class, ros_type=ros_type, service=service,
            copy_ros2pb=add_tab(copy_ros2pb, 2),
            copy_pb2ros=add_tab(copy_pb2ros, 2)))

    return frame_template.format(
        add_servicers=add_tab('\n'.join(add_servicers)),
        classes='\n'.join(servicer_classes))

# An empty __init__.py marks OUT_DIR as a Python package
write_file(os.path.join(OUT_DIR, '__init__.py'), "")
server_source = generate_server(snap)
write_file(os.path.join(OUT_DIR, 'grpc_server.py'), server_source)
print('grpc_server.py file generated')


In [None]:
# Compile the generated .proto into ros_pb2.py (messages) and ros_pb2_grpc.py
# (gRPC stubs) inside OUT_DIR, next to the generated server.
!python3 -m grpc_tools.protoc \
    -I={os.path.relpath(OUT_DIR)} \
    --python_out={os.path.relpath(OUT_DIR)} \
    --grpc_python_out={os.path.relpath(OUT_DIR)} \
    {os.path.relpath(PROTO_FILE)}

### All done! Run `rosrun grpc_api_generator run_server` to start the gRPC server.