Skip to content

Commit

Permalink
Merge pull request #39 from dan1elt0m/add-force-type-option
Browse files Browse the repository at this point in the history
Add Coerce type option
  • Loading branch information
timvancann committed Nov 17, 2023
2 parents 99b2af6 + 3b4c10a commit 975e786
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 87 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.7", "3.8", "3.9", "3.10"]
python-version: ["3.8", "3.9", "3.10", "3.11"]

steps:
- uses: actions/checkout@v2
Expand Down
17 changes: 17 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,23 @@ schema_dict: dict = TestModel.spark_schema()
print(json.dumps(schema_dict))

```
#### Coerce type
Pydantic-spark provides a `coerce_type` option that allows type coercion.
When applied to a field, pydantic-spark converts the column's data type to the specified coercion type.

```python
import json
from pydantic import Field
from pydantic_spark.base import SparkBase, CoerceType

class TestModel(SparkBase):
    # The option must be passed under the ``json_schema_extra`` keyword:
    # that is the key the schema builder looks up to find ``coerce_type``.
    key1: str = Field(json_schema_extra={"coerce_type": CoerceType.integer})

schema_dict: dict = TestModel.spark_schema()
print(json.dumps(schema_dict))

```


### Install for developers

Expand Down
99 changes: 15 additions & 84 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ include = [
packages = [{ include = "pydantic_spark", from = "src"}]

[tool.poetry.dependencies]
python = ">=3.7,<4.0"
python = ">=3.8,<4.0"
pydantic = "^1.4.0"

#spark
Expand Down
21 changes: 21 additions & 0 deletions src/pydantic_spark/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,23 @@
from enum import Enum
from typing import List, Tuple

from pydantic import BaseModel


class CoerceType(str, Enum):
    """Type names accepted by a field's ``coerce_type`` option.

    When a field's schema extras carry one of these members, the member's
    string value is used as the field's Spark type.
    """

    integer = "integer"
    long = "long"
    double = "double"
    string = "string"
    boolean = "boolean"
    date = "date"
    timestamp = "timestamp"


class CoerceTypeError(Exception):
    """Raised when a field's ``coerce_type`` is not a ``CoerceType`` member."""

    pass


class SparkBase(BaseModel):
"""This is base pydantic class that will add some methods"""

Expand Down Expand Up @@ -50,6 +65,8 @@ def get_type(value: dict) -> Tuple[str, dict]:
f = value.get("format")
r = value.get("$ref")
a = value.get("additionalProperties")
e = value.get("json_schema_extra", {})
ft = e.get("coerce_type")
metadata = {}
if "default" in value:
metadata["default"] = value.get("default")
Expand All @@ -60,6 +77,10 @@ def get_type(value: dict) -> Tuple[str, dict]:
else:
spark_type = get_type_of_definition(r, schema)
classes_seen[class_name] = spark_type
elif ft is not None:
if not isinstance(ft, CoerceType):
raise CoerceTypeError("coerce_type must be of type CoerceType")
spark_type = ft.value
elif t == "array":
items = value.get("items")
tn, metadata = get_type(items)
Expand Down
52 changes: 51 additions & 1 deletion tests/test_to_spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@
from typing import Dict, List, Optional
from uuid import UUID

import pytest
from pydantic import Field
from pyspark.sql.types import (
ArrayType,
BooleanType,
DateType,
DoubleType,
IntegerType,
LongType,
MapType,
StringType,
Expand All @@ -17,7 +20,7 @@
TimestampType,
)

from pydantic_spark.base import SparkBase
from pydantic_spark.base import CoerceType, CoerceTypeError, SparkBase


class Nested2Model(SparkBase):
Expand Down Expand Up @@ -283,3 +286,50 @@ def test_enum():
)
result = TestEnum.spark_schema()
assert result == json.loads(expected_schema.json())


def test_coerce_type():
    """A field declaring ``coerce_type`` is emitted with that Spark type."""

    class TestCoerceType(SparkBase):
        c1: int = Field(json_schema_extra={"coerce_type": CoerceType.integer})

    # The model has a single field, so it is the first schema entry.
    first_field = TestCoerceType.spark_schema()["fields"][0]
    assert first_field["type"] == "integer"


def test_coerce_type_invalid():
    """A plain string given as ``coerce_type`` is rejected with CoerceTypeError."""

    class TestCoerceType(SparkBase):
        # Deliberately the raw string "integer" rather than CoerceType.integer.
        c1: int = Field(json_schema_extra={"coerce_type": "integer"})

    with pytest.raises(CoerceTypeError):
        TestCoerceType.spark_schema()


# Innermost fixture model: a ``str`` field that is coerced to the integer type.
class Nested2ModelCoerceType(SparkBase):
    c111: str = Field(json_schema_extra={"coerce_type": CoerceType.integer})


# Middle fixture model: wraps Nested2ModelCoerceType as a nested field.
class NestedModelCoerceType(SparkBase):
    c11: Nested2ModelCoerceType


# Outermost fixture model: a list of NestedModelCoerceType items.
class ComplexTestModelCoerceType(SparkBase):
    c1: List[NestedModelCoerceType]


def test_coerce_type_complex_spark():
    """Coercion declared on a doubly nested model shows up in the parsed Spark schema."""
    nested_struct = StructType.fromJson(NestedModelCoerceType.spark_schema())
    c1_field = StructField(
        "c1",
        ArrayType(nested_struct),
        nullable=False,
        metadata={"parentClass": "ComplexTestModelCoerceType"},
    )
    expected_struct = StructType([c1_field])

    actual = ComplexTestModelCoerceType.spark_schema()
    assert actual == json.loads(expected_struct.json())

    # Re-parse the generated schema with pyspark itself to be sure the format is valid.
    parsed = StructType.fromJson(actual)
    assert len(parsed.fields) == 1

    # c1 (array) -> element struct -> c11 (struct) -> c111 (coerced leaf field)
    element_struct = parsed.fields[0].dataType.elementType
    inner_struct = element_struct.fields[0].dataType
    leaf_type = inner_struct.fields[0].dataType
    assert isinstance(leaf_type, IntegerType)

0 comments on commit 975e786

Please sign in to comment.