In [1]:
from __future__ import annotations

from typing import Generator
import pandas as pd
from contextlib import contextmanager
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from tqdm.autonotebook import tqdm
import more_itertools as mit
from pathlib import Path
from collections.abc import Iterable
from typing import Any



@contextmanager
def local_spark(config: dict[str, str] | None = None) -> Generator[SparkSession, None, None]:
    """Yield a Hive-enabled ``SparkSession`` named ``local_spark`` and stop it on exit.

    Args:
        config: Optional Spark configuration entries applied to the session
            builder before the session is created. ``None`` means no extra
            configuration.

    Yields:
        The session returned by ``builder.getOrCreate()``.

    Raises:
        Whatever ``getOrCreate()`` raises. The session is created *before* the
        ``try`` block so that a start-up failure propagates unchanged instead of
        being replaced by an ``UnboundLocalError`` from ``spark.stop()``.
    """
    builder = SparkSession.builder.enableHiveSupport().appName("local_spark")
    for key, value in (config or {}).items():
        builder = builder.config(key, value)

    spark = builder.getOrCreate()
    try:
        yield spark
    finally:
        # Only reached once a session exists, so ``spark`` is always bound here.
        spark.stop()

  from tqdm.autonotebook import tqdm


In [2]:

from dataclasses import dataclass
#from protos.fileformat_pb2 import Blob, BlobHeader
#from protos.osmformat_pb2 import HeaderBlock, PrimitiveBlock

@dataclass
class BlobData:
    """A parsed blob header together with its parsed blob message."""

    # Header message; its ``type`` and ``datasize`` fields are read by the readers below.
    header: BlobHeader
    # Blob message; its payload may still be compressed (see ``decompress_blob``).
    blob: Blob


@dataclass
class RawBlobData:
    """A blob's header type plus the blob bytes exactly as read from the stream."""

    # String form of the blob header's ``type`` field.
    blob_type: str
    # Unparsed bytes that followed the header in the stream.
    data: bytes


@dataclass
class OsmInfo:
    """Metadata attached to a decoded OSM node, way or relation."""

    version: int
    timestamp: int
    changeset: int
    uid: int
    # Despite the name, the block decoder stores the string looked up in the
    # string table here, not the numeric string id.
    user_sid: str


@dataclass
class OsmNode:
    """A decoded OSM node: id, metadata, tags and a coordinate pair."""

    id: int
    info: OsmInfo
    # Tag keys mapped to tag values, both as decoded strings.
    tags: dict[str, str]
    latitude: float
    longitude: float


@dataclass
class OsmWay:
    """A decoded OSM way: id, metadata, tags and its list of references."""

    id: int
    info: OsmInfo
    # Tag keys mapped to tag values, both as decoded strings.
    tags: dict[str, str]
    # Delta-decoded ``refs`` of the way (presumably node ids; confirm against the format spec).
    nodes: list[int]



@dataclass
class OsmRelationMember:
    """One member entry of a decoded OSM relation."""

    # Delta-decoded member id.
    id: int
    # Role string looked up in the block's string table.
    role: str
    # Name of the member-type enum value, as a string.
    type: str


@dataclass
class OsmRelation:
    """A decoded OSM relation: id, metadata, tags and its members."""

    id: int
    info: OsmInfo
    # Tag keys mapped to tag values, both as decoded strings.
    tags: dict[str, str]
    members: list[OsmRelationMember]


@dataclass
class OsmData:
    """All entities decoded from a single primitive block, grouped by kind."""

    nodes: list[OsmNode]
    ways: list[OsmWay]
    relations: list[OsmRelation]


In [3]:
import zlib

def decompress_blob(blob: Blob) -> bytes:
    match blob.WhichOneof("data"):
        case "raw":
            return blob.raw
        case "zlib_data":
            return zlib.decompress(blob.zlib_data)
        case _:
            raise ValueError("Blob has no data")


def decode_blob(header: BlobHeader, blob: Blob) -> HeaderBlock | PrimitiveBlock:
    """Decompress ``blob`` and parse it as the block kind named by ``header.type``.

    The payload is decompressed first; ``"OSMHeader"`` is then parsed as a
    ``HeaderBlock`` and ``"OSMData"`` as a ``PrimitiveBlock``.

    Raises:
        ValueError: If ``header.type`` is neither of the two known kinds.
    """
    payload = decompress_blob(blob)
    blob_type = header.type
    if blob_type == "OSMHeader":
        return HeaderBlock.FromString(payload)
    if blob_type == "OSMData":
        return PrimitiveBlock.FromString(payload)
    raise ValueError(f"Unknown blob type: {header.type}")



def decode_blob_data(blob_data: BlobData) -> HeaderBlock | PrimitiveBlock:
    """Decode the header/blob pair carried by ``blob_data`` via ``decode_blob``."""
    return decode_blob(blob_data.header, blob_data.blob)

In [4]:
#from protos.osmformat_pb2 import Way, Node, Relation, DenseNodes, Info, DenseInfo


def delta_decode(values: Iterable[int]) -> list[int]:
    """Return the running sums of ``values``, starting from zero.

    Each output element is the sum of all input elements up to and including
    the same position, which undoes delta encoding.
    """
    running_total = 0
    return [running_total := running_total + delta for delta in values]


class ValueDecoder:
    """Converts raw integer coordinates and timestamps using a block's scaling fields.

    The scaling fields are read once from the block; a falsy (zero or unset)
    field is replaced by the fallback listed in ``__init__``.
    """

    def __init__(self, block: PrimitiveBlock) -> None:
        fallbacks = {
            "granularity": 100,
            "lat_offset": 0,
            "lon_offset": 0,
            "date_granularity": 1000,
        }
        for field, fallback in fallbacks.items():
            setattr(self, field, getattr(block, field) or fallback)

    def lat(self, value: int) -> float:
        """Scale a raw latitude: ``1e-9 * (lat_offset + granularity * value)``."""
        scaled = self.lat_offset + (self.granularity * value)
        return 0.000000001 * scaled

    def lon(self, value: int) -> float:
        """Scale a raw longitude: ``1e-9 * (lon_offset + granularity * value)``."""
        scaled = self.lon_offset + (self.granularity * value)
        return 0.000000001 * scaled

    def timestamp(self, value: int) -> int:
        """Multiply a raw timestamp by the block's ``date_granularity``."""
        return value * self.date_granularity


class PrimitiveBlockDecoder:
    """Decode the entities of one OSM PBF ``PrimitiveBlock`` into ``Osm*`` dataclasses.

    The block's string table is decoded to UTF-8 once on construction; tags,
    user names and relation roles are then resolved by index. Coordinates and
    timestamps are converted with a ``ValueDecoder`` built from the same block.
    """

    def __init__(self, block: PrimitiveBlock) -> None:
        self.block = block
        self.string_table = [s.decode("utf-8") for s in block.stringtable.s]
        self.value_decoder = ValueDecoder(block)

    def decode_string(self, index: int) -> str:
        """Return the string-table entry at ``index``."""
        return self.string_table[index]

    def decode_info(self, info: Info) -> OsmInfo:
        """Convert a per-element ``Info`` message into an ``OsmInfo``.

        The timestamp is scaled by the block's date granularity, so it is
        expressed in the same unit for every block regardless of granularity.
        """
        return OsmInfo(
            version=info.version,
            timestamp=self.value_decoder.timestamp(info.timestamp),
            changeset=info.changeset,
            uid=info.uid,
            user_sid=self.decode_string(info.user_sid),
        )

    def decode_tags(self, keys: list[int], vals: list[int]) -> dict[str, str]:
        """Resolve parallel key/value string ids into a ``{key: value}`` dict."""
        return {
            self.decode_string(key): self.decode_string(val)
            for key, val in zip(keys, vals)
        }

    def decode_node(self, node: Node) -> OsmNode:
        """Convert a non-dense ``Node`` message into an ``OsmNode``."""
        return OsmNode(
            id=node.id,
            info=self.decode_info(node.info),
            tags=self.decode_tags(node.keys, node.vals),
            latitude=self.value_decoder.lat(node.lat),
            longitude=self.value_decoder.lon(node.lon),
        )

    def decode_dense_info(self, dense: DenseInfo) -> list[OsmInfo]:
        """Convert the column-wise ``DenseInfo`` arrays into one ``OsmInfo`` per node.

        ``version`` is used as stored; the other columns are delta-decoded.
        Timestamps are additionally scaled by the block's date granularity,
        matching ``decode_info``.
        """
        timestamps = (
            self.value_decoder.timestamp(raw)
            for raw in delta_decode(dense.timestamp)
        )
        return [
            OsmInfo(
                version=version,
                timestamp=timestamp,
                changeset=changeset,
                uid=uid,
                user_sid=self.decode_string(user_sid),
            )
            for version, timestamp, changeset, uid, user_sid in zip(
                dense.version,
                timestamps,
                delta_decode(dense.changeset),
                delta_decode(dense.uid),
                delta_decode(dense.user_sid),
            )
        ]

    def _decode_dense_tags(self, keys_vals: Iterable[int], count: int) -> list[dict[str, str]]:
        """Split the packed ``keys_vals`` array into one tag dict per dense node.

        The array is a flat run of ``key_id, value_id`` pairs in which a single
        ``0`` ends the tags of the current node. An empty array therefore
        yields ``count`` empty dicts.

        Raises:
            ValueError: If a key id is the last element, i.e. has no value id.
        """
        packed = list(keys_vals)
        tags: list[dict[str, str]] = [{} for _ in range(count)]
        position = 0
        for node_tags in tags:
            while position < len(packed) and packed[position] != 0:
                if position + 1 >= len(packed):
                    raise ValueError("Malformed DenseNodes.keys_vals: key without a value")
                key = self.decode_string(packed[position])
                node_tags[key] = self.decode_string(packed[position + 1])
                position += 2
            # Step over the 0 that terminates this node's pairs.
            position += 1
        return tags

    def decode_dense_nodes(self, dense: DenseNodes) -> list[OsmNode]:
        """Convert a ``DenseNodes`` message into one ``OsmNode`` per encoded id.

        Ids and coordinates are delta-decoded, and coordinates are then scaled
        to floats exactly like non-dense nodes. Each node receives its own tag
        dict. Nodes for which the block carries no dense metadata get
        ``info=None`` instead of being silently dropped.
        """
        ids = delta_decode(dense.id)
        infos: list[OsmInfo | None] = list(self.decode_dense_info(dense.denseinfo))
        # A negative multiplier yields an empty list, so this only ever pads.
        infos.extend([None] * (len(ids) - len(infos)))
        tags = self._decode_dense_tags(dense.keys_vals, len(ids))
        return [
            OsmNode(
                id=node_id,
                info=info,
                tags=node_tags,
                latitude=self.value_decoder.lat(raw_lat),
                longitude=self.value_decoder.lon(raw_lon),
            )
            for node_id, info, node_tags, raw_lat, raw_lon in zip(
                ids,
                infos,
                tags,
                delta_decode(dense.lat),
                delta_decode(dense.lon),
            )
        ]

    def decode_way(self, way: Way) -> OsmWay:
        """Convert a ``Way`` message into an ``OsmWay`` with delta-decoded refs."""
        return OsmWay(
            id=way.id,
            info=self.decode_info(way.info),
            tags=self.decode_tags(way.keys, way.vals),
            nodes=delta_decode(way.refs),
        )

    def decode_relation(self, relation: Relation) -> OsmRelation:
        """Convert a ``Relation`` message into an ``OsmRelation``.

        Member ids are delta-decoded, roles are resolved through the string
        table, and member types are reported by enum name.
        """
        # The enum is reached through the message's own class so this method
        # does not depend on ``Relation`` being importable at module scope.
        member_types = type(relation).MemberType
        return OsmRelation(
            id=relation.id,
            info=self.decode_info(relation.info),
            tags=self.decode_tags(relation.keys, relation.vals),
            members=[
                OsmRelationMember(
                    id=member_id,
                    role=self.decode_string(role_sid),
                    type=member_types.Name(member_type),
                )
                for role_sid, member_id, member_type in zip(
                    relation.roles_sid, delta_decode(relation.memids), relation.types
                )
            ],
        )



def decode_block(block: PrimitiveBlock) -> OsmData:
    """Decode every primitive group of ``block`` into a single ``OsmData``.

    Nodes are listed with all non-dense nodes of all groups first, followed by
    all dense nodes of all groups; ways and relations follow group order.
    """
    decoder = PrimitiveBlockDecoder(block)

    plain_nodes = [
        decoder.decode_node(entry)
        for group in block.primitivegroup
        for entry in group.nodes
    ]

    dense_nodes = []
    for group in block.primitivegroup:
        dense_nodes.extend(decoder.decode_dense_nodes(group.dense))

    ways = []
    for group in block.primitivegroup:
        ways.extend(decoder.decode_way(entry) for entry in group.ways)

    relations = []
    for group in block.primitivegroup:
        relations.extend(decoder.decode_relation(entry) for entry in group.relations)

    return OsmData(nodes=plain_nodes + dense_nodes, ways=ways, relations=relations)

In [5]:
from io import BytesIO


def read_blob_data(source: BytesIO) -> BlobData | None:
    """Read and parse the next length-prefixed blob header and blob from ``source``.

    Layout consumed from the stream: a 4-byte big-endian header length, the
    header bytes, then ``datasize`` bytes of blob as announced by the header.

    Args:
        source: Binary stream positioned at the start of a blob record.

    Returns:
        The parsed header/blob pair, or ``None`` when the stream is already at
        its end.

    Raises:
        EOFError: If the stream ends in the middle of a record.
    """
    # Imported locally, as read_raw_blob_data does, because the module-level
    # import of these names is commented out in this notebook.
    from protos.fileformat_pb2 import Blob, BlobHeader

    size_prefix = source.read(4)
    if not size_prefix:
        return None
    if len(size_prefix) < 4:
        raise EOFError("Truncated PBF stream: incomplete blob header length")

    header_size = int.from_bytes(size_prefix, "big")
    header_bytes = source.read(header_size)
    if len(header_bytes) < header_size:
        raise EOFError("Truncated PBF stream: incomplete blob header")
    blob_header = BlobHeader.FromString(header_bytes)

    blob_bytes = source.read(blob_header.datasize)
    if len(blob_bytes) < blob_header.datasize:
        raise EOFError("Truncated PBF stream: incomplete blob")

    return BlobData(header=blob_header, blob=Blob.FromString(blob_bytes))


def read_raw_blob_data(source: BytesIO) -> RawBlobData | None:
    """Read the next blob record from ``source`` without parsing the blob itself.

    Layout consumed from the stream: a 4-byte big-endian header length, the
    header bytes, then ``datasize`` bytes of blob as announced by the header.
    Only the header is parsed; the blob bytes are returned untouched.

    Args:
        source: Binary stream positioned at the start of a blob record.

    Returns:
        The header type and raw blob bytes, or ``None`` when the stream is
        already at its end.

    Raises:
        EOFError: If the stream ends in the middle of a record. Without this
            check a short payload would be passed on as if it were complete.
    """
    from protos.fileformat_pb2 import BlobHeader

    size_prefix = source.read(4)
    if not size_prefix:
        return None
    if len(size_prefix) < 4:
        raise EOFError("Truncated PBF stream: incomplete blob header length")

    header_size = int.from_bytes(size_prefix, "big")
    header_bytes = source.read(header_size)
    if len(header_bytes) < header_size:
        raise EOFError("Truncated PBF stream: incomplete blob header")
    blob_header = BlobHeader.FromString(header_bytes)

    payload = source.read(blob_header.datasize)
    if len(payload) < blob_header.datasize:
        raise EOFError("Truncated PBF stream: incomplete blob")

    return RawBlobData(blob_type=str(blob_header.type), data=payload)


def read_blobs(source: BytesIO) -> Generator[BlobData, None, None]:
    """Yield each parsed blob of ``source`` in order, skipping ``OSMHeader`` blobs.

    Iteration ends when ``read_blob_data`` reports the end of the stream.
    """
    while (blob_data := read_blob_data(source)) is not None:
        if data_is_header := (blob_data.header.type == "OSMHeader"):
            continue
        yield blob_data
        


def read_raw_blobs(source: BytesIO) -> Generator[RawBlobData, None, None]:
    """Yield every raw blob record of ``source`` in order, headers included.

    Iteration ends when ``read_raw_blob_data`` reports the end of the stream.
    """
    while (record := read_raw_blob_data(source)) is not None:
        yield record
        

# def read_pbf(filename: Path) -> None:
#     stats = dict(nodes=0, ways=0, relations=0)
#     with open(filename, "rb") as fin:
#         for data in tqdm(read_blobs(fin)):
#             block = decode_blob_data(data)
#             nodes, ways, relations = decode_block(block)
#             stats["nodes"] += len(nodes)
#             stats["ways"] += len(ways)
#             stats["relations"] += len(relations)

#     return stats
# #read_pbf(Path("/workspaces/data/osm/us-latest.osm.pbf"))
# read_pbf(Path("/workspaces/data/osm/nevada-latest.osm.pbf"))

In [6]:
from collections.abc import Iterator
from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition
from pyspark.sql.types import StructField, StructType, IntegerType, StringType, MapType, ArrayType, DoubleType, LongType



# Alias for a string-keyed dict of arbitrary values; not referenced by the cells visible here.
ReadState = dict[str, Any]

class PbfDataSource(DataSource):
    """Spark Python data source, registered as ``pbf``, exposing raw PBF blobs as rows."""

    @classmethod
    def name(cls) -> str:
        """Return the short format name used with ``spark.read.format(...)``."""
        return "pbf"

    def schema(self) -> str:
        """Return the row schema: the blob's header type and its bytes."""
        return "header_type string, data binary"

    def reader(self, schema: StructType) -> PbfDataSourceReader:
        """Create the reader, handing it the schema and this source's options."""
        return PbfDataSourceReader(schema, self.options)
    

class PbfDataSourceReader(DataSourceReader):
    """Reads the file named by the ``filename`` option as one partition of raw blobs."""

    def __init__(self, schema: StructType, options: dict[str, Any]) -> None:
        self.schema, self.options = schema, options

    def partitions(self) -> list[InputPartition]:
        """Return a single partition, so the whole file is read by one task."""
        only_partition = InputPartition(0)
        return [only_partition]

    def read(self, partition: InputPartition) -> Iterator[tuple[str, bytes]]:
        """Yield ``(header type, blob bytes)`` for every blob record in the file."""
        filename = self.options["filename"]
        with open(filename, "rb") as pbf_file:
            yield from (
                (str(record.blob_type), record.data)
                for record in read_raw_blobs(pbf_file)
            )


# Spark schemas for the value returned by the decoding UDF. The field names
# are meant to line up with the fields of the Osm* dataclasses.

# Tags: string keys mapped to string values.
TagType = MapType(StringType(), StringType())

# Element metadata.
# NOTE(review): the last field is named "user" while the OsmInfo dataclass
# field is "user_sid". If Spark maps returned objects to struct fields by
# attribute name, this column would come back null - confirm.
InfoType = StructType([
    StructField("version", IntegerType()),
    StructField("timestamp", LongType()),
    StructField("changeset", LongType()),
    StructField("uid", LongType()),
    StructField("user", StringType()),
])

# One node: id, metadata, tags and a coordinate pair.
NodeType = StructType([
    StructField("id", LongType()),
    StructField("info", InfoType),
    StructField("tags", TagType),
    StructField("latitude", DoubleType()),
    StructField("longitude", DoubleType()),
])

# One way: id, metadata, tags and its list of references.
WayType = StructType([
    StructField("id", LongType()),
    StructField("info", InfoType),
    StructField("tags", TagType),
    StructField("nodes", ArrayType(LongType())),
])


# One relation: id, metadata, tags and its members (id, role, type name).
RelationType = StructType([
    StructField("id", LongType()),
    StructField("info", InfoType),
    StructField("tags", TagType),
    StructField("members", ArrayType(StructType([
        StructField("id", LongType()),
        StructField("role", StringType()),
        StructField("type", StringType()),
    ]))),
])


# Everything decoded from one blob, in the order the UDF returns it:
# nodes, ways, relations.
OsmDataType = StructType([
    StructField("nodes", ArrayType(NodeType)),
    StructField("ways", ArrayType(WayType)),
    StructField("relations", ArrayType(RelationType)),
])

@F.udf(returnType=OsmDataType)
def decode_data(data: bytes) -> tuple[list[OsmNode], list[OsmWay], list[OsmRelation]]:
    """Decode the bytes of one ``OSMData`` blob into its nodes, ways and relations.

    Args:
        data: Serialized ``Blob`` message, as produced by the ``pbf`` data source.

    Returns:
        ``(nodes, ways, relations)``, in the field order of ``OsmDataType``.

    Raises:
        ValueError: Propagated from ``decompress_blob`` when the blob payload
            cannot be obtained.
    """
    # The protobuf classes are imported inside the function rather than at
    # module scope; only the two names actually used here are imported.
    from protos.fileformat_pb2 import Blob
    from protos.osmformat_pb2 import PrimitiveBlock

    blob = Blob.FromString(data)
    block = PrimitiveBlock.FromString(decompress_blob(blob))
    decoded = decode_block(block)

    return decoded.nodes, decoded.ways, decoded.relations
    

def run(filename: str) -> None:
    """Decode the nodes of the PBF file at ``filename`` with Spark and print a preview.

    Registers the ``pbf`` data source, keeps only ``OSMData`` blobs, spreads
    them over 100 partitions, decodes each blob and explodes the decoded nodes
    into one row per node.
    """
    with local_spark() as spark_session:
        spark_session.dataSource.register(PbfDataSource)
        nodes = (
            spark_session.read.format("pbf")
            .option("filename", filename)
            .load()
            .where(F.col("header_type") == "OSMData")
            .repartition(100)
            .withColumn("info", decode_data("data"))
            .select(F.explode("info.nodes").alias("node"))
        )

        nodes.show()

# Runs the whole pipeline on the file at this hardcoded, machine-specific path.
run("/workspaces/data/osm/nevada-latest.osm.pbf")

Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/07/10 22:06:00 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
25/07/10 22:06:24 ERROR Utils: uncaught error in thread Spark Context Cleaner, stopping SparkContext
java.lang.OutOfMemoryError: Java heap space
	at java.base/java.lang.invoke.DirectMethodHandle.allocateInstance(DirectMethodHandle.java:501)
	at java.base/java.lang.invoke.DirectMethodHandle$Holder.newInvokeSpecial(DirectMethodHandle$Holder)
	at java.base/java.lang.invoke.Invokers$Holder.linkToTargetMethod(Invokers$Holder)
	at org.apache.spark.ContextCleaner.$anonfun$keepCleaning$1(ContextCleaner.scala:196)
	at org.apache.spark.ContextCleaner$$Lambda/0x00007458685548d0.apply$mcV$sp(Unknown Source)
	at org.apache.spark.util.Utils$.tryOrStopS

ConnectionRefusedError: [Errno 111] Connection refused