In [None]:
from datetime import date, datetime

import pandas as pd
from pyspark.sql import Row, SparkSession

# Obtain the SparkSession used by every cell below: getOrCreate() returns the
# session that is already active, or builds a new one if none exists.
spark = SparkSession.builder.getOrCreate()

### Spark DataFrame from a pandas DataFrame

In [None]:
# Three-row pandas frame with one column per basic type:
# integer (a), float (b), string (c), date (d) and datetime (e).
pandas_df = pd.DataFrame(
    {
        "a": list(range(1, 4)),
        "b": [2.0, 3.0, 4.0],
        "c": [f"string{i}" for i in range(1, 4)],
        "d": [date(2000, month, 1) for month in range(1, 4)],
        "e": [datetime(2000, 1, day, 12, 0) for day in range(1, 4)],
    }
)

# Convert the pandas frame into a Spark DataFrame.
spark_df = spark.createDataFrame(pandas_df)
# Bare expression on the last line so the notebook displays the result.
spark_df

### Spark DataFrame from a Spark RDD

In [None]:
# Six sample rows distributed as an RDD of tuples. Row i holds
# (i, i + 1.0, "string<i>", the first day of month i in 2000, noon on 2000-01-<i>).
rdd = spark.sparkContext.parallelize(
    [
        (i, float(i + 1), f"string{i}", date(2000, i, 1), datetime(2000, 1, i, 12, 0))
        for i in range(1, 7)
    ]
)
# The tuples carry no field names, so the column names are supplied via `schema`.
df = spark.createDataFrame(
    rdd,
    schema=["a", "b", "c", "d", "e"],
)
df

In [None]:
# Print df's schema (column names and their types).
df.printSchema()

In [None]:
# Print only the first four of df's rows.
df.show(n=4)

### Tweak configuration for DataFrame display

In [None]:
# Turn on eager evaluation for the REPL/notebook so that a bare DataFrame
# expression renders its rows instead of only its column summary.
spark.conf.set("spark.sql.repl.eagerEval.enabled", True)
# Cap the number of rows rendered by eager evaluation at 5.
spark.conf.set("spark.sql.repl.eagerEval.maxNumRows", 5)

In [None]:
# Bare expression: let the notebook display df with the settings configured above.
df

### filtering

In [None]:
# Keep only the rows whose "a" column equals 1 and print them.
df.filter(df["a"] == 1).show()

### functions

In [None]:
import pandas
from pyspark.sql.functions import pandas_udf

In [None]:
@pandas_udf("long")
def pandas_plus_ten(series: pd.Series) -> pd.Series:
    """Add ten to every value of ``series``.

    Wrapped by ``pandas_udf`` with a declared return type of ``"long"``, so the
    addition is applied to a whole pandas Series at a time.
    """
    shifted = series + 10
    return shifted

In [None]:
# Apply the pandas UDF to column "a" and print the resulting single-column frame.
df.select(pandas_plus_ten(df["a"])).show()

In [None]:
def pandas_filter_func(iterator):
    """Yield, for each pandas DataFrame in ``iterator``, its rows where ``a == 2``.

    One (possibly empty) frame is yielded per input frame, lazily and in order.
    """
    for batch in iterator:
        keep = batch.a == 2
        yield batch[keep]


# Run pandas_filter_func over df's data as pandas DataFrames; the output keeps
# df's own schema because the function only drops rows.
df.mapInPandas(pandas_filter_func, schema=df.schema).show()

### grouping

In [None]:
# Sample data for the grouping examples: two string keys (color, fruit) and two
# numeric value columns (v1, v2).
# Note: this rebinds `df`, replacing the a-e frame used by the earlier cells.
df = spark.createDataFrame(
    [
        ("red", "banana", 1, 10),
        ("blue", "banana", 2, 20),
        ("red", "carrot", 3, 30),
        ("blue", "grape", 4, 40),
        ("red", "carrot", 5, 50),
        ("black", "carrot", 6, 60),
        ("red", "banana", 7, 70),
        ("red", "grape", 8, 80),
    ],
    schema=["color", "fruit", "v1", "v2"],
)
df

In [None]:
# Average of v1 within each color group.
df.groupBy("color").avg("v1")