# Transformations for all schemas with a given column using DataSetImplements

Let's illustrate this with an example! First, we'll define some data.

In [1]:
from pyspark.sql import SparkSession

spark = SparkSession.Builder().config("spark.ui.showConsoleProgress", "false").getOrCreate()
spark.sparkContext.setLogLevel("ERROR")

In [2]:
from pyspark.sql.types import LongType, StringType
from typedspark import (
    Schema,
    Column,
    create_empty_dataset,
)


class Person(Schema):
    name: Column[StringType]
    age: Column[LongType]
    job: Column[StringType]


class Pet(Schema):
    name: Column[StringType]
    age: Column[LongType]
    type: Column[StringType]


class Fruit(Schema):
    type: Column[StringType]


person = create_empty_dataset(spark, Person)
pet = create_empty_dataset(spark, Pet)
fruit = create_empty_dataset(spark, Fruit)

Now, suppose we want to define a function `birthday()` that works on all schemas that contain the column `age`. With `DataSet`, we'd have to explicitly list every schema that contains the `age` column, for example:

In [3]:
from typing import TypeVar, Union

from typedspark import DataSet, transform_to_schema

T = TypeVar("T", bound=Union[Person, Pet])


def birthday(df: DataSet[T]) -> DataSet[T]:
    return transform_to_schema(
        df,
        df.typedspark_schema,
        {Person.age: Person.age + 1},
    )

This gets tedious whenever the set of schemas with an `age` column changes, for example when a new schema is added or when `age` is removed from an existing one! It's also not great that we're using `Person.age` to refer to the `age` column, even when the input is a `DataSet[Pet]`...

Fortunately, we can do better! Consider the following example:

In [4]:
from typing import Protocol

from typedspark import DataSetImplements


class Age(Schema, Protocol):
    age: Column[LongType]


T = TypeVar("T", bound=Schema)


def birthday(df: DataSetImplements[Age, T]) -> DataSet[T]:
    return transform_to_schema(
        df,
        df.typedspark_schema,
        {Age.age: Age.age + 1},
    )

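Here, we define `Age` to be both a `Schema` and a `Protocol` ([PEP 544](https://peps.python.org/pep-0544/)). Protocols are matched structurally, so a schema does not need to inherit from `Age` to satisfy it: declaring an `age: Column[LongType]` column is enough.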

We then define `birthday()` to:

1. Take as input a `DataSetImplements[Age, T]`: a `DataSet` whose schema `T` implements the protocol `Age`.
2. Return a `DataSet[T]`: a `DataSet` with the same schema as the one that was provided.
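Both points rest on standard Python typing features rather than anything Spark-specific. The following is a minimal sketch in plain Python, without typedspark, of the same two ideas: a protocol that is satisfied structurally, and a `TypeVar` that makes the return type follow the input type. The names `HasAge`, `PersonRecord`, `FruitRecord`, and `older()` are hypothetical and exist only for this illustration; `runtime_checkable` is used here so the structural match can also be observed at runtime, which is not how typedspark itself checks schemas.

```python
import copy
from dataclasses import dataclass
from typing import Protocol, TypeVar, runtime_checkable


@runtime_checkable
class HasAge(Protocol):
    """Hypothetical protocol: any object that exposes an `age` attribute."""

    age: int


@dataclass
class PersonRecord:
    # Does not inherit from HasAge, but matches it structurally.
    name: str
    age: int


@dataclass
class FruitRecord:
    # No `age` attribute, so it does not satisfy HasAge.
    type: str


R = TypeVar("R", bound=HasAge)


def older(record: R) -> R:
    """Return a copy of `record`, of the same class, with `age` incremented."""
    result = copy.copy(record)
    result.age = record.age + 1
    return result


alice = older(PersonRecord(name="Alice", age=30))
print(type(alice).__name__, alice.age)

# A type checker would reject older(FruitRecord(type="apple")).
# With runtime_checkable, the mismatch is also visible at runtime:
print(isinstance(FruitRecord(type="apple"), HasAge))
```

In this sketch, `older()` plays the role of `birthday()`, `HasAge` the role of `Age`, and `R` the role of `T`: the function accepts anything with an `age`, and the caller gets back the same type that was passed in.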

Let's see this in action!

In [5]:
# returns a DataSet[Person]
happy_person = birthday(person)

# returns a DataSet[Pet]
happy_pet = birthday(pet)

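try:
    # A type checker (e.g. pyright) reports an error on this call:
    # Argument of type "DataSet[Fruit]" cannot be assigned to
    # parameter "df" of type "DataSetImplements[Age, T@birthday]"
    #
    # The call is also expected to fail at runtime, since `fruit` has no
    # `age` column; the try/except only keeps the notebook running.
    birthday(fruit)
except Exception:
    pass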