Add force type option
Daniel Tom committed Nov 17, 2023
1 parent 99b2af6 commit 6644213
Showing 2 changed files with 64 additions and 1 deletion.
22 changes: 22 additions & 0 deletions src/pydantic_spark/base.py
@@ -1,8 +1,24 @@
from enum import Enum
from typing import List, Tuple

from pydantic import BaseModel


class ForceType(str, Enum):
    integer = "integer"
    long = "long"
    double = "double"
    string = "string"
    boolean = "boolean"
    date = "date"
    timestamp = "timestamp"
    uuid = "uuid"


class ForceTypeError(Exception):
    pass


class SparkBase(BaseModel):
    """This is base pydantic class that will add some methods"""

@@ -50,6 +66,8 @@ def get_type(value: dict) -> Tuple[str, dict]:
f = value.get("format")
r = value.get("$ref")
a = value.get("additionalProperties")
e = value.get("json_schema_extra", {})
ft = e.get("force_type")
metadata = {}
if "default" in value:
metadata["default"] = value.get("default")
@@ -60,6 +78,10 @@ def get_type(value: dict) -> Tuple[str, dict]:
    else:
        spark_type = get_type_of_definition(r, schema)
        classes_seen[class_name] = spark_type
elif ft is not None:
    if not isinstance(ft, ForceType):
        raise ForceTypeError("force_type must be of type ForceType")
    spark_type = ft.value
elif t == "array":
    items = value.get("items")
    tn, metadata = get_type(items)
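For reference, the new force_type option is read from a field's json_schema_extra and, when it is a ForceType member, overrides the inferred Spark type with the member's value. A minimal usage sketch mirroring the tests added in this commit (the Example model name is illustrative):

from pydantic import Field

from pydantic_spark.base import ForceType, SparkBase


class Example(SparkBase):
    # Force this field to the Spark "integer" type instead of the inferred mapping
    c1: int = Field(json_schema_extra={"force_type": ForceType.integer})


# The forced type appears directly in the generated Spark schema
print(Example.spark_schema()["fields"][0]["type"])  # "integer"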
43 changes: 42 additions & 1 deletion tests/test_to_spark.py
@@ -4,11 +4,13 @@
from typing import Dict, List, Optional
from uuid import UUID

from pydantic import Field
from pyspark.sql.types import (
    ArrayType,
    BooleanType,
    DateType,
    DoubleType,
    IntegerType,
    LongType,
    MapType,
    StringType,
@@ -17,7 +19,7 @@
    TimestampType,
)

from pydantic_spark.base import SparkBase
from pydantic_spark.base import ForceType, SparkBase


class Nested2Model(SparkBase):
@@ -283,3 +285,42 @@ def test_enum():
    )
    result = TestEnum.spark_schema()
    assert result == json.loads(expected_schema.json())


def test_force_type():
    class TestForceType(SparkBase):
        c1: int = Field(json_schema_extra={"force_type": ForceType.integer})

    result = TestForceType.spark_schema()
    assert result["fields"][0]["type"] == "integer"


class Nested2ModelForceType(SparkBase):
    c111: str = Field(json_schema_extra={"force_type": ForceType.integer})


class NestedModelForceType(SparkBase):
    c11: Nested2ModelForceType


class ComplexTestModelForceType(SparkBase):
    c1: List[NestedModelForceType]


def test_force_type_complex_spark():
    expected_schema = StructType(
        [
            StructField(
                "c1",
                ArrayType(StructType.fromJson(NestedModelForceType.spark_schema())),
                nullable=False,
                metadata={"parentClass": "ComplexTestModelForceType"},
            )
        ]
    )
    result = ComplexTestModelForceType.spark_schema()
    assert result == json.loads(expected_schema.json())
    # Reading schema with spark library to be sure format is correct
    schema = StructType.fromJson(result)
    assert len(schema.fields) == 1
    assert isinstance(schema.fields[0].dataType.elementType.fields[0].dataType.fields[0].dataType, IntegerType)
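The validation added in get_type above rejects any force_type value that is not a ForceType member. A hypothetical negative test, not part of this commit, sketching that behaviour (it assumes pytest as the test runner and that the json_schema_extra dict reaches the generated schema unchanged):

import pytest
from pydantic import Field

from pydantic_spark.base import ForceTypeError, SparkBase


def test_force_type_rejects_plain_string():
    class BadForceType(SparkBase):
        # A bare string instead of a ForceType member should be rejected
        c1: int = Field(json_schema_extra={"force_type": "integer"})

    with pytest.raises(ForceTypeError):
        BadForceType.spark_schema()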
