In [0]:
# Databricks notebook source
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.functions import col
from pyspark.sql.types import (
    BooleanType,
    DoubleType,
    IntegerType,
    LongType,
    StringType,
    StructField,
    StructType,
    TimestampType,
)


def map_to_spark_type(data_type):
    """Translate a column dtype name into a PySpark ``DataType`` instance.

    Matching is by case-insensitive substring, so both the Spark SQL names
    reported by ``DataFrame.dtypes`` (``'bigint'``, ``'double'``,
    ``'boolean'``, ``'timestamp'``) and pandas-style names (``'int64'``,
    ``'float64'``, ``'bool'``, ``'datetime64[ns]'``) are recognised.

    Args:
        data_type: The dtype name; anything whose ``str()`` is such a name.

    Returns:
        ``LongType()`` for 64-bit integer names, ``IntegerType()`` for any
        other name containing ``int``, ``DoubleType()`` for float/double
        names, ``BooleanType()`` for boolean names, ``TimestampType()`` for
        datetime/timestamp names, and ``StringType()`` for everything else.
    """
    type_name = str(data_type).lower()

    if 'int' in type_name:
        # 64-bit integers must not be narrowed to a 32-bit IntegerType.
        if 'bigint' in type_name or 'int64' in type_name:
            return LongType()
        return IntegerType()
    # Spark reports floating point columns as 'double', pandas as 'float64'.
    if 'float' in type_name or 'double' in type_name:
        return DoubleType()
    if 'bool' in type_name:
        return BooleanType()
    # Spark reports 'timestamp', pandas reports 'datetime64[ns]'.
    if 'datetime' in type_name or 'timestamp' in type_name:
        return TimestampType()
    return StringType()


def generate_schema_string(df, struct_type_name):
    """Render Python source text declaring a ``StructType`` that mirrors ``df``.

    Args:
        df: Object exposing ``dtypes`` as a sequence of
            ``(column_name, type_name)`` pairs, e.g. a PySpark DataFrame.
        struct_type_name: Identifier placed on the left-hand side of the
            generated assignment.

    Returns:
        A string of the form ``<struct_type_name> = StructType([...])`` with
        one nullable ``StructField`` line per column. Each field type is the
        string form of whatever ``map_to_spark_type`` returns for the column.
    """
    field_lines = [
        # repr() quotes and escapes the name, so a column name containing a
        # quote character still yields valid Python source.
        f"    StructField({name!r}, {map_to_spark_type(type_name)}, True)"
        for name, type_name in df.dtypes
    ]

    schema_string = f"{struct_type_name} = StructType(["
    if field_lines:
        schema_string += "\n" + ",\n".join(field_lines)
    return schema_string + "\n])"


from sqlalchemy.types import (
    BigInteger,
    Boolean,
    Date,
    DateTime,
    DECIMAL,
    Float,
    Integer,
    Numeric,
    SmallInteger,
    String,
    Unicode,
    Text,
    TIMESTAMP,
)

from pyspark.sql.types import (
    BooleanType,
    DateType,
    DecimalType,
    FloatType,
    IntegerType,
    LongType,
    ShortType,
    StringType,
    StructField,
    StructType,
    TimestampType,
)

# Lookup table: SQLAlchemy column type class -> PySpark DataType class.
# Both keys and values are classes, not instances; whoever reads a value out
# of this table still has to instantiate it.
# The name lives at module level, so its leading double underscore is not
# subject to class-body name mangling.
__sqlalchemy_to_spark_type_mapping: dict = {
    # Integer widths.
    Integer: IntegerType,
    SmallInteger: ShortType,
    BigInteger: LongType,
    # Floating point.
    Float: FloatType,
    # Text-like types all collapse to a Spark string.
    String: StringType,
    Text: StringType,
    Unicode: StringType,
    Boolean: BooleanType,
    # Date and time.
    Date: DateType,
    DateTime: TimestampType,
    # Fixed-point numbers.
    Numeric: DecimalType,
    DECIMAL: DecimalType,
    TIMESTAMP: TimestampType,
}


def convert_to_spark_type(sqlalchemy_type):
    """Return the PySpark ``DataType`` class mapped to a SQLAlchemy type instance.

    The instance's class is tried first, then each of its base classes in
    method-resolution order, so a subclass of a mapped type (for example
    SQLAlchemy's ``VARCHAR``, which derives from ``String``) resolves to the
    entry of its nearest mapped ancestor instead of failing.

    Args:
        sqlalchemy_type: A SQLAlchemy type *instance* (e.g. ``column.type``).

    Returns:
        The mapped PySpark ``DataType`` class (uninstantiated).

    Raises:
        KeyError: If neither the class nor any of its bases is mapped.
    """
    sqlalchemy_class = type(sqlalchemy_type)
    for candidate in sqlalchemy_class.__mro__:
        spark_type = __sqlalchemy_to_spark_type_mapping.get(candidate)
        if spark_type is not None:
            return spark_type
    # Same exception type and key as a plain dict lookup would have produced.
    raise KeyError(sqlalchemy_class)

def generate_model_schema(df, struct_type_name):
    """Build a ``StructType`` whose fields mirror the columns of ``df``.

    Each ``(column_name, type_name)`` pair from ``df.dtypes`` becomes a
    nullable ``StructField`` whose data type is the instance returned by
    ``map_to_spark_type`` for that type name. The dtype entries are plain
    strings, so they are resolved with the string-based mapper rather than
    the SQLAlchemy class lookup.

    Args:
        df: Object exposing ``dtypes`` as a sequence of
            ``(column_name, type_name)`` pairs, e.g. a PySpark DataFrame.
        struct_type_name: Not used by this function; kept so the signature
            stays parallel to ``generate_schema_string``.

    Returns:
        A ``StructType`` with one field per column, in column order.
    """
    return StructType([
        StructField(name, map_to_spark_type(type_name), True)
        for name, type_name in df.dtypes
    ])


# Demo: render schema text for the same rows with and without column names.
spark = SparkSession.builder.appName("ASG").getOrCreate()

# Rows shared by both demo DataFrames: (id, name, status, amount, timestamp text).
sample_rows = [
    (1, 'John', True, 100.5, '2022-01-01'),
    (2, 'Jane', False, 200.0, '2022-02-01'),
]

df = spark.createDataFrame(sample_rows, ['id', 'name', 'status', 'amount', 'timestamp'])

generated_schema_string = generate_schema_string(df, 'MyCustomSchema')

print('Generated Schema String:')
print(generated_schema_string)

# Sample of the printed text. The type shown for each field is whatever
# map_to_spark_type returns for that column's dtype, so treat the types below
# as illustrative only.
# MyCustomSchema = StructType([
#     StructField('id', IntegerType(), True),
#     StructField('name', StringType(), True),
#     StructField('status', BooleanType(), True),
#     StructField('amount', StringType(), True),
#     StructField('timestamp', StringType(), True)
# ])

# No column names are supplied here, so the generated schema uses whatever
# names Spark assigns on its own (positional names such as _1, _2, ...).
df_no_column_names = spark.createDataFrame(sample_rows)

generated_schema_string = generate_schema_string(df_no_column_names, 'MyCustomSchema')
print('Generated Schema without column name:')
print(generated_schema_string)

In [0]:
from typing import List, Type

import pyspark
from pyspark.sql.functions import lit
from pyspark.sql.types import (
    BooleanType,
    DateType,
    DecimalType,
    FloatType,
    IntegerType,
    LongType,
    ShortType,
    StringType,
    StructField,
    StructType,
    TimestampType,
)
from sqlalchemy import Column, inspect
from sqlalchemy.sql.type_api import TypeEngine
from sqlalchemy.types import (
    BigInteger,
    Boolean,
    Date,
    DateTime,
    DECIMAL,
    Float,
    Integer,
    Numeric,
    SmallInteger,
    String,
    Unicode,
    Text,
    TIMESTAMP,
)

from bh_shared.exceptions import DataTypeMappingException
from bh_shared.models_types import (
    ANY_SCHEMA,
    HistoricalColumn,
    OutputSchema,
    RecordIntegrationStatus,
    StagingSchema,
)

# Lookup table: SQLAlchemy column type class -> PySpark DataType class.
# Both keys and values are classes, not instances; whoever reads a value out
# of this table still has to instantiate it.
# The name lives at module level, so its leading double underscore is not
# subject to class-body name mangling.
__sqlalchemy_to_spark_type_mapping: dict = {
    # Integer widths.
    Integer: IntegerType,
    SmallInteger: ShortType,
    BigInteger: LongType,
    # Floating point.
    Float: FloatType,
    # Text-like types all collapse to a Spark string.
    String: StringType,
    Text: StringType,
    Unicode: StringType,
    Boolean: BooleanType,
    # Date and time.
    Date: DateType,
    DateTime: TimestampType,
    # Fixed-point numbers.
    Numeric: DecimalType,
    DECIMAL: DecimalType,
    TIMESTAMP: TimestampType,
}


def convert_to_spark_type(sqlalchemy_type: TypeEngine):
    """Return the PySpark ``DataType`` class mapped to a SQLAlchemy type instance.

    The instance's class is tried first, then each of its base classes in
    method-resolution order, so a subclass of a mapped type (for example
    SQLAlchemy's ``VARCHAR``, which derives from ``String``) resolves to the
    entry of its nearest mapped ancestor instead of failing.

    Args:
        sqlalchemy_type: A SQLAlchemy type *instance* (e.g. ``column.type``).

    Returns:
        The mapped PySpark ``DataType`` class (uninstantiated).

    Raises:
        KeyError: If neither the class nor any of its bases is mapped.
    """
    sqlalchemy_class = type(sqlalchemy_type)
    for candidate in sqlalchemy_class.__mro__:
        spark_type = __sqlalchemy_to_spark_type_mapping.get(candidate)
        if spark_type is not None:
            return spark_type
    # Same exception type and key as a plain dict lookup would have produced.
    raise KeyError(sqlalchemy_class)


def model_schema(model: Type[StagingSchema]) -> StructType:
    """Build the PySpark ``StructType`` corresponding to a SQLAlchemy model.

    Every column reported by ``inspect(model).columns`` becomes one
    ``StructField`` carrying the column's name, its translated Spark type and
    its ``nullable`` flag, in the order the inspector yields them.

    Args:
        model: The mapped model class whose columns are translated.

    Returns:
        A ``StructType`` with one field per model column.

    Raises:
        DataTypeMappingException: If a column's SQLAlchemy type has no Spark
            counterpart (the lookup's ``KeyError`` is chained as the cause).
    """
    spark_fields = []

    for column in inspect(model).columns:
        # The lookup is the only step that raises KeyError, so it is the step
        # that has to sit inside the try for the translation to take effect.
        try:
            spark_data_type = convert_to_spark_type(column.type)
        except KeyError as exc:
            raise DataTypeMappingException(
                f"Do not know how to translate alembic's {str(column.type)} to Spark's data type."
            ) from exc

        if spark_data_type is DecimalType:
            # Decimal is the only mapped type built with arguments; precision
            # and scale are forwarded exactly as declared on the column.
            # NOTE(review): a type declared without precision/scale would pass
            # None through here - confirm whether that case can occur.
            spark_type = spark_data_type(
                precision=column.type.precision,
                scale=column.type.scale,
            )
        else:
            spark_type = spark_data_type()

        spark_fields.append(StructField(column.name, spark_type, column.nullable))

    return StructType(spark_fields)