From 889eb253b88aa1622b820dd26b0cf7f11f86cd70 Mon Sep 17 00:00:00 2001
From: Yin Huai
Date: Mon, 22 Jun 2015 12:15:05 -0700
Subject: [PATCH] Minor refactoring and add partitionBy to save, saveAsTable,
 and parquet.

---
 python/pyspark/sql/readwriter.py | 43 +++++++++++---------------------
 1 file changed, 14 insertions(+), 29 deletions(-)

diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index 38e3690a4aa17..1c1519b3f1651 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -218,7 +218,10 @@ def mode(self, saveMode):
 
         >>> df.write.mode('append').parquet(os.path.join(tempfile.mkdtemp(), 'data'))
         """
-        self._jwrite = self._jwrite.mode(saveMode)
+        # At the JVM side, the default value of mode is already set to "error".
+        # So, if the given saveMode is None, we will not call JVM-side's mode method.
+        if saveMode is not None:
+            self._jwrite = self._jwrite.mode(saveMode)
         return self
 
     @since(1.4)
@@ -253,11 +256,12 @@ def partitionBy(self, *cols):
         """
         if len(cols) == 1 and isinstance(cols[0], (list, tuple)):
             cols = cols[0]
-        self._jwrite = self._jwrite.partitionBy(_to_seq(self._sqlContext._sc, cols))
+        if len(cols) > 0:
+            self._jwrite = self._jwrite.partitionBy(_to_seq(self._sqlContext._sc, cols))
         return self
 
     @since(1.4)
-    def save(self, path=None, format=None, mode=None, **options):
+    def save(self, path=None, format=None, mode=None, partitionBy=(), **options):
         """Saves the contents of the :class:`DataFrame` to a data source.
 
         The data source is specified by the ``format`` and a set of ``options``.
@@ -276,11 +280,7 @@ def save(self, path=None, format=None, mode=None, **options):
 
         >>> df.write.mode('append').parquet(os.path.join(tempfile.mkdtemp(), 'data'))
         """
-        if mode is not None:
-            # At the JVM side, the default value of mode is already set to "error".
-            # We will only call mode method if the provided mode is not None.
-            self.mode(mode)
-        self.options(**options)
+        self.partitionBy(partitionBy).mode(mode).options(**options)
         if format is not None:
             self.format(format)
         if path is None:
@@ -300,7 +300,7 @@ def insertInto(self, tableName, overwrite=False):
         self._jwrite.mode("overwrite" if overwrite else "append").insertInto(tableName)
 
     @since(1.4)
-    def saveAsTable(self, name, format=None, mode=None, **options):
+    def saveAsTable(self, name, format=None, mode=None, partitionBy=(), **options):
         """Saves the content of the :class:`DataFrame` as the specified table.
 
         In the case the table already exists, behavior of this function depends on the
@@ -318,11 +318,7 @@ def saveAsTable(self, name, format=None, mode=None, **options):
         :param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error)
         :param options: all other string options
         """
-        if mode is not None:
-            # At the JVM side, the default value of mode is already set to "error".
-            # We will only call mode method if the provided mode is not None.
-            self.mode(mode)
-        self.options(**options)
+        self.partitionBy(partitionBy).mode(mode).options(**options)
         if format is not None:
             self.format(format)
         self._jwrite.saveAsTable(name)
@@ -341,14 +337,10 @@ def json(self, path, mode=None):
 
         >>> df.write.json(os.path.join(tempfile.mkdtemp(), 'data'))
         """
-        if mode is not None:
-            # At the JVM side, the default value of mode is already set to "error".
-            # We will only call mode method if the provided mode is not None.
-            self.mode(mode)
-        self._jwrite.json(path)
+        self._jwrite.mode(mode).json(path)
 
     @since(1.4)
-    def parquet(self, path, mode=None):
+    def parquet(self, path, mode=None, partitionBy=()):
         """Saves the content of the :class:`DataFrame` in Parquet format at the specified path.
 
         :param path: the path in any Hadoop supported file system
@@ -361,10 +353,7 @@ def parquet(self, path, mode=None):
 
         >>> df.write.parquet(os.path.join(tempfile.mkdtemp(), 'data'))
         """
-        if mode is not None:
-            # At the JVM side, the default value of mode is already set to "error".
-            # We will only call mode method if the provided mode is not None.
-            self.mode(mode)
+        self.partitionBy(partitionBy).mode(mode)
         self._jwrite.parquet(path)
 
     @since(1.4)
@@ -386,14 +375,10 @@ def jdbc(self, url, table, mode=None, properties={}):
                            arbitrary string tag/value. Normally at least a
                            "user" and "password" property should be included.
         """
-        if mode is not None:
-            # At the JVM side, the default value of mode is already set to "error".
-            # We will only call mode method if the provided mode is not None.
-            self.mode(mode)
         jprop = JavaClass("java.util.Properties", self._sqlContext._sc._gateway._gateway_client)()
         for k in properties:
             jprop.setProperty(k, properties[k])
-        self._jwrite.jdbc(url, table, jprop)
+        self._jwrite.mode(mode).jdbc(url, table, jprop)
 
 
 def _test():
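
Usage sketch (not part of the patch above): with these changes, partition columns can be passed directly to save, saveAsTable, and parquet instead of chaining partitionBy() first. The snippet is illustrative only; it assumes an existing DataFrame named df that has a 'year' column, and the temporary output path and column name are made up for the example.

    import os
    import tempfile

    path = os.path.join(tempfile.mkdtemp(), 'data')

    # New keyword added by this patch: write Parquet output partitioned by 'year'.
    df.write.parquet(path, mode='overwrite', partitionBy=['year'])

    # Equivalent generic form through save(); format, mode, partition columns,
    # and options are forwarded to the JVM-side DataFrameWriter as before.
    df.write.save(path, format='parquet', mode='overwrite', partitionBy=['year'])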