In [None]:
from typing import Callable, Dict, Tuple, Iterator

from pandas import DataFrame

import pyarrow as pa
from pyarrow.flight import (
    FlightInfo,
    FlightServerBase,
    ServerCallContext,
    FlightDescriptor,
    FlightEndpoint,
    MetadataRecordBatchReader,
    MetadataRecordBatchWriter,
)
from sklearn.linear_model import LinearRegression


def linear_regression(data: DataFrame) -> DataFrame:
    """Fit a linear model of ``y`` on ``x`` and evaluate it on a fixed grid.

    The model is sklearn's ``LinearRegression``, refit from scratch on every call.

    Args:
        data: Frame providing numeric ``x`` and ``y`` columns. Extra columns
            are ignored.

    Returns:
        A new frame with column ``x`` holding 0.0, 100.0, ..., 900.0 and
        column ``predict`` holding the fitted model's prediction at each ``x``.
    """
    model = LinearRegression()
    model.fit(data[["x"]], data["y"])

    # Evaluation grid: ten evenly spaced points, stored as floats.
    grid = DataFrame({"x": [float(step) for step in range(0, 1000, 100)]})
    grid["predict"] = model.predict(grid[["x"]])
    return grid


# Registry of operations the server exposes, keyed by Flight descriptor path.
# Each value pairs the schema that ``list_flights`` advertises for the
# operation with the function ``do_exchange`` applies to the uploaded data.
OPERATIONS: Dict[str, Tuple[pa.Schema, Callable[[DataFrame], DataFrame]]] = dict(
    linear_regression=(
        pa.schema([("x", pa.float64()), ("y", pa.float64())]),
        linear_regression,
    ),
)


class CustomServer(FlightServerBase):
    """Arrow Flight server that exposes the entries of ``OPERATIONS``.

    Clients discover operations with ``list_flights``. They invoke one by
    opening a DoExchange stream whose descriptor path names the operation.
    The uploaded data is converted to pandas and passed to the operation, and
    the resulting table is streamed back.
    """

    def __init__(self, location: str):
        """Create the server for ``location`` (e.g. ``"grpc://0.0.0.0:8080"``)."""
        super().__init__(location)
        # Reused verbatim as the endpoint location advertised by list_flights.
        # NOTE(review): a wildcard bind host such as 0.0.0.0 is not
        # necessarily an address clients can connect to. Consider advertising
        # a reachable hostname instead.
        self.location = location

    def list_flights(
        self,
        context: ServerCallContext,
        criteria: bytes,
    ) -> Iterator[FlightInfo]:
        """Yield one ``FlightInfo`` per registered operation.

        ``criteria`` is ignored, so every operation is always listed. Record
        and byte counts are reported as -1 (unknown).
        """
        for name, (schema, _) in OPERATIONS.items():
            yield FlightInfo(
                schema=schema,
                descriptor=FlightDescriptor.for_path(name),
                endpoints=[FlightEndpoint(name, [self.location])],
                total_records=-1,
                total_bytes=-1,
            )

    def do_exchange(
        self,
        context: ServerCallContext,
        descriptor: FlightDescriptor,
        reader: MetadataRecordBatchReader,
        writer: MetadataRecordBatchWriter,
    ):
        """Run the operation named by ``descriptor`` on the uploaded data.

        The first path component selects an entry of ``OPERATIONS``. The
        uploaded stream is read into a pandas DataFrame and the operation is
        applied to it. The result is converted back to an Arrow table and
        written to ``writer``.

        Raises:
            ValueError: if the descriptor carries no path (e.g. a command
                descriptor) or names an operation that is not registered.
        """
        # Non-path (command) descriptors have no path components. Guard before
        # indexing so the client sees a clear error rather than a
        # TypeError/IndexError.
        if not descriptor.path:
            raise ValueError("descriptor must be a path descriptor naming an operation")
        path = descriptor.path[0].decode()

        try:
            _, op = OPERATIONS[path]
        except KeyError:
            raise ValueError(f"invalid path: {path}") from None

        data = reader.read_pandas()
        result = op(data)

        table = pa.Table.from_pandas(result)
        writer.begin(table.schema)
        writer.write_table(table)

In [None]:
# Bind on every interface at port 8080. Per the pyarrow docs, serve() blocks
# until the server is shut down, so this cell does not return on its own.
server = CustomServer(location="grpc://0.0.0.0:8080")
server.serve()