Skip to content

Commit

Permalink
Serialize data frame (#1035)
Browse files Browse the repository at this point in the history
  • Loading branch information
mgasner committed Mar 26, 2019
1 parent 0e73ec9 commit 1682c6e
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 3 deletions.
50 changes: 47 additions & 3 deletions python_modules/airline-demo/airline_demo/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,19 @@


import os
import shutil
import tempfile
import zipfile

from collections import namedtuple

import sqlalchemy

from pyspark.sql import DataFrame

from dagster import as_dagster_type, dagster_type, Dict, Field, String
from dagster.core.types.runtime import PythonObjectType, Stringish
from dagster import as_dagster_type, Dict, Field, String
from dagster.core.types.marshal import SerializationStrategy
from dagster.core.types.runtime import Stringish
from dagster.utils import safe_isfile


Expand All @@ -20,8 +24,48 @@
)


class SparkDataFrameParquetSerializationStrategy(SerializationStrategy):
    '''Persist a Spark data frame as a zip archive of its parquet output directory.

    Spark writes parquet as a directory of part files, while this strategy is
    handed a single file object to write to or read from, so the directory is
    zipped on the way out and unzipped on the way back in.
    '''

    def serialize_value(self, _context, value, write_file_obj):
        '''Write ``value`` as parquet, zip the result, and stream it to ``write_file_obj``.

        Args:
            _context: Unused.
            value: Object exposing ``write.parquet(path)`` (a pyspark DataFrame).
            write_file_obj: Writable binary file object that receives the zip bytes.
        '''
        # Everything lives under one private scratch directory so a single
        # rmtree cleans up the parquet files and the archive, even on failure.
        scratch_dir = tempfile.mkdtemp()
        try:
            # The original code deleted its temp dir before writing, presumably
            # because Spark's default save mode rejects an existing path. Writing
            # to a not-yet-created child keeps that property without exposing a
            # predictable, already-released path in the shared temp directory.
            parquet_dir = os.path.join(scratch_dir, 'parquet')
            value.write.parquet(parquet_dir)

            # make_archive appends '.zip' to the base name and returns the full path;
            # entries are stored relative to parquet_dir.
            zipfile_path = shutil.make_archive(
                os.path.join(scratch_dir, 'archive'), 'zip', parquet_dir
            )
            with open(zipfile_path, 'rb') as archive_file_read_obj:
                # Stream in chunks rather than loading the whole archive in memory.
                shutil.copyfileobj(archive_file_read_obj, write_file_obj)
        finally:
            # Best-effort cleanup; do not mask an in-flight exception.
            shutil.rmtree(scratch_dir, ignore_errors=True)

    def deserialize_value(self, context, read_file_obj):
        '''Unzip the archive in ``read_file_obj`` and load it as a Spark data frame.

        Args:
            context: Object exposing the Spark session as ``context.resources.spark``.
            read_file_obj: Readable binary file object positioned at the zip bytes.

        Returns:
            The result of ``context.resources.spark.read.parquet`` on the extracted
            directory.
        '''
        archive_file_obj = tempfile.NamedTemporaryFile(delete=False)
        try:
            try:
                shutil.copyfileobj(read_file_obj, archive_file_obj)
            finally:
                # Always release the handle so the file can be reopened and removed.
                archive_file_obj.close()

            parquet_dir = tempfile.mkdtemp()
            try:
                # try/finally instead of the ZipFile context manager, which the
                # original author avoided for py2 compatibility.
                zipfile_obj = zipfile.ZipFile(archive_file_obj.name)
                try:
                    zipfile_obj.extractall(parquet_dir)
                finally:
                    zipfile_obj.close()
            except Exception:
                # Nothing will ever read a partially extracted directory.
                shutil.rmtree(parquet_dir, ignore_errors=True)
                raise
        finally:
            if os.path.isfile(archive_file_obj.name):
                os.remove(archive_file_obj.name)

        # NOTE(review): parquet_dir is deliberately left on disk here; the returned
        # data frame presumably reads its files lazily -- confirm before adding
        # cleanup on the success path.
        return context.resources.spark.read.parquet(parquet_dir)


# Register the pyspark DataFrame class as a dagster type whose values are
# persisted with SparkDataFrameParquetSerializationStrategy.
SparkDataFrameType = as_dagster_type(
    DataFrame,
    name='SparkDataFrameType',
    description='A Pyspark data frame.',
    serialization_strategy=SparkDataFrameParquetSerializationStrategy(),
)

SqlAlchemyEngineType = as_dagster_type(
Expand Down
44 changes: 44 additions & 0 deletions python_modules/airline-demo/airline_demo_tests/test_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import sys
import tempfile

from airline_demo.resources import spark_session_local, tempfile_resource
from airline_demo.types import SparkDataFrameParquetSerializationStrategy


class MockResources(object):
    '''Minimal stand-in for a resources object with ``spark`` and ``tempfile`` attributes.'''

    def __init__(self, spark, tempfile_):
        # The parameter has a trailing underscore; the attribute it is stored on does not.
        self.spark, self.tempfile = spark, tempfile_


class MockPipelineContext(object):
    '''Minimal stand-in for a pipeline context: only carries a ``resources`` attribute.'''

    def __init__(self, spark, tempfile_):
        resources = MockResources(spark, tempfile_)
        self.resources = resources


def test_spark_data_frame_serialization():
    '''Round-trip a small data frame through the parquet serialization strategy.'''
    # Imported locally because this test module does not import os at file level.
    import os

    spark = spark_session_local.resource_fn(None)
    # The builtin next() advances the resource generator on both Python 2 and 3,
    # so no version check is needed.
    tempfile_ = next(tempfile_resource.resource_fn(None))

    try:
        serialization_strategy = SparkDataFrameParquetSerializationStrategy()

        pipeline_context = MockPipelineContext(spark, tempfile_)

        df = spark.createDataFrame([('Foo', 1), ('Bar', 2)])

        # delete=False so the file survives being closed and can be reopened for
        # reading; it is removed explicitly in the finally clause below.
        write_file_obj = tempfile.NamedTemporaryFile(delete=False)
        serialized_path = write_file_obj.name
        try:
            with write_file_obj:
                serialization_strategy.serialize_value(pipeline_context, df, write_file_obj)

            with open(serialized_path, 'rb') as read_file_obj:
                new_df = serialization_strategy.deserialize_value(pipeline_context, read_file_obj)

            # Collect once and reuse the rows for both assertions.
            rows = new_df.collect()
        finally:
            os.remove(serialized_path)

        assert {row[0] for row in rows} == {'Bar', 'Foo'}
        assert {row[1] for row in rows} == {1, 2}
    finally:
        tempfile_.close()

0 comments on commit 1682c6e

Please sign in to comment.