In [1]:
import pandas as pd

## Problem

In [2]:
def transform_dataframe(df: pd.DataFrame) -> pd.DataFrame:
    df['transformed_column'] = df['required_column'] * 10
    return df

In [3]:
df_valid = pd.DataFrame({
    'required_column': [1, 2, 3],
})

df_invalid = pd.DataFrame({
    'other_column': [10, 20, 30]
})

In [4]:
transform_dataframe(df=df_valid)

Unnamed: 0,required_column,transformed_column
0,1,10
1,2,20
2,3,30


In [5]:
transform_dataframe(df=df_invalid)

KeyError: 'required_column'

## Idea

Instead of having a very general `pd.DataFrame` Typing hint, we want a Typing hint that tells me what kind of schema the DataFrame should have.

In [None]:
from typing import Annotated
from pandas import DataFrame

# Annotated adds metadata to type hint
def transform_dataframe2(
    df: Annotated[DataFrame, "must have column `required_column`"]
) -> DataFrame:
    df['transformed_column'] = df['required_column'] * 10
    return df

In [None]:
from dataclasses import dataclass   

@dataclass
class ValidDataFrame:
    columns = ['required_column']
    data: pd.DataFrame
    
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

In [None]:
def transform_dataframe(df: ValidDataFrame) -> pd.DataFrame:
    df['transformed_column'] = df['required_column'] * 10
    return df

In [None]:
transform_dataframe(df=ValidDataFrame(df_valid))

TypeError: object.__init__() takes exactly one argument (the instance to initialize)

In [None]:
import pandera as pa

In [None]:
schema = pa.DataFrameSchema({
    'col1': pa.Column(pa.Int)
})

df_valid = pd.DataFrame({'col1': [1, 2, 3]})
df_invalid = pd.DataFrame({'col1': '1 2 3'.split()})

In [None]:
schema(df_valid)

Unnamed: 0,col1
0,1
1,2
2,3


In [None]:
schema(df_invalid)

SchemaError: expected series 'col1' to have type int64, got object

In [None]:
import pandas as pd
import pandera as pa
from pandera.typing import DataFrame, Series


class Schema(pa.SchemaModel):
    # item: Series[str] = pa.Field(isin=["apple", "orange"], coerce=True)
    # price: Series[float] = pa.Field(gt=0, coerce=True)
    category: Series[str]

    @pa.check('category')
    def category_check(cls, series: Series[str]) -> Series[bool]:
        return series.isin(['fruit', 'vegetables'])

In [None]:
invalid_data = pd.DataFrame.from_records([
    {"item": "applee", "price": 0.5, "category": 'fruit'},
    {"item": "orange", "price": -1000, "category": 'fruit'},
    {"item": "orange", "price": 1, "category": 'snack'},
])

In [None]:
Schema.validate(invalid_data)

Unnamed: 0,item,price,category
0,applee,0.5,fruit
1,orange,-1000.0,fruit


In [None]:
try:
    transform_data(invalid_data)
except pa.errors.SchemaErrors as exc:
    display(exc.failure_cases)

In [None]:
@pa.check_types(lazy=True)
def transform_data(data: DataFrame[Schema]):
    return None



try:
    transform_data(invalid_data)
except pa.errors.SchemaErrors as exc:
    display(exc.failure_cases)

In [None]:
transform_data(invalid_data)

SchemaErrors: Schema Schema: A total of 2 schema errors were found.

Error Counts
------------
- schema_component_check: 2

Schema Error Summary
--------------------
                                                failure_cases  n_failure_cases
schema_context column check                                                   
Column         item   isin({'orange', 'apple'})      [applee]                1
               price  greater_than(0)               [-1000.0]                1

Usage Tip
---------

Directly inspect all errors by catching the exception:

```
try:
    schema.validate(dataframe, lazy=True)
except SchemaErrors as err:
    err.failure_cases  # dataframe of schema errors
    err.data  # invalid dataframe
```


In [None]:
class BaseSchema(pa.SchemaModel):
    year: Series[str]

class FinalSchema(BaseSchema):
    year: Series[int] = pa.Field(ge=2000, coerce=True)  # overwrite the base type
    passengers: Series[int]
    idx: Index[int] = pa.Field(ge=0)

df = pd.DataFrame({
    "year": ["2000", "2001", "2002"],
})

@pa.check_types
def transform(df: DataFrame[BaseSchema]) -> DataFrame[FinalSchema]:
    return (
        df.assign(passengers=[61000, 50000, 45000])
        .set_index(pd.Index([1, 2, 3]))
        .astype({"year": int})
    )

print(transform(df))

NameError: name 'Index' is not defined

## Notes

* decorators
    * `@pa.check_output(schema)` checks if the output-dataframe has expected schema
    * `@pa.check_input(schema)` checks if input-dataframe has expected schema
    * `@pa.check_types`

## Example

In [21]:
import pandera as pa
from pandera import SchemaModel, Column, Check, Float, Field
from pandera.typing import Series, DataFrame
from typing import List

class SchemaNPuncStops(SchemaModel):
    n_stops: Series[Float]
    n_punctual_stops: Series[Float]

class SchemaPunc(SchemaNPuncStops):
    # punctualities must be Floats and values between 0.0 and 1.0
    punctuality: Series[Float] = Field(ge=0.0, le=1.0)

# @pa.check_types
@pa.check_types
def calculate_punctuality(df: DataFrame[SchemaNPuncStops]) -> DataFrame[SchemaPunc]:
    dfpunc = df.copy() 
    dfpunc['punctuality'] = dfpunc['n_punctual_stops'] / dfpunc['n_stops']
    return dfpunc

TypeError: check_input() got an unexpected keyword argument 'coerce'

In [60]:
def my_decorator(func):
    def wrapper(*args, **kwargs):
        greet(*args, **kwargs)
    return func

In [61]:
def greet(name):
    print(f'hello {name}')

In [63]:
greet = my_decorator(greet)
greet(name='nils')

hello nils


In [201]:
from typing import List
from functools import wraps

def require_columns(columns: List[str]):
    def decorator_require_columns(wrapped):
        @wraps(wrapped)
        def wrapper_require_columns(*args, **kwargs):
            # find the function argument that is of type pd.DataFrame
            df_arg_name = None
            for arg_name, arg_type_annotation in typing.get_type_hints(wrapped).items():
                if arg_type_annotation == pd.DataFrame:
                    df_arg_name = arg_name
            
            # raise error if no argument of type pd.DataFrame has been found
            if df_arg_name is None:
                raise KeyError("Function has no argument of type pandas.DataFrame")

            # check if the dataframe has all required columns
            input_dataframe = kwargs[df_arg_name]
            for col in columns:
                if col not in input_dataframe.columns:
                    err = f"Missing required column: {col}"
                    raise KeyError(err)

            return wrapped(*args, **kwargs)
        return wrapper_require_columns
    return decorator_require_columns

In [204]:
@require_columns(columns=['A', 'B'])
def transform(data: pd.DataFrame):
    # do something that requires column 'A'
    return df

df = pd.DataFrame({'A': [1, 2, 3], 'B': [4, 5, 6]})
transform(data=df)

Unnamed: 0,A,B
0,1,4
1,2,5
2,3,6


In [205]:
df = pd.DataFrame({'A': [1, 2, 3]})
transform(data=df)

KeyError: 'Missing required column: B'

In [174]:
import typing
import pandera as pa

def func(a: pd.DataFrame, b: pa.typing.DataFrame, c: int) -> bool:
    return True

def is_dataframe(annotation) -> bool:
    if (annotation == pd.DataFrame) or (annotation == pa.typing.DataFrame):
        return True
    else:
        return False

for arg_name, annotation in typing.get_type_hints(func).items():
    print([arg_name, annotation])
    if is_dataframe(annotation):
        print(True)
    else:
        print(False)

['a', <class 'pandas.core.frame.DataFrame'>]
True
['b', <class 'pandera.typing.pandas.DataFrame'>]
True
['c', <class 'int'>]
False
['return', <class 'bool'>]
False


In [None]:
@check_columns('n_stops')
def calculate_punctuality(df: DataFrame[SchemaNPuncStops]) -> DataFrame[SchemaPunc]:
    dfpunc = df.copy() 
    dfpunc['punctuality'] = dfpunc['n_punctual_stops'] / dfpunc['n_stops']
    return dfpunc

In [19]:
df = pd.DataFrame({
    'n_stops': [100.0, 100.0, 100.0],
    # 'n_punctual_stops': [80.0, 60.0, 50.0],
    'n_punctual_stops': ['10', '5', '10'],
})

calculate_punctuality(df)

SchemaError: error in check_types decorator of function 'calculate_punctuality': expected series 'n_punctual_stops' to have type float64, got object