In [None]:
# https://towardsdatascience.com/validate-your-pandas-dataframe-with-pandera-2995910e564

In [1]:
import pandas as pd

# Example inventory: one row per (fruit, store) offer together with its price.
fruit_rows = [
    ("apple", "Aldi", 2),
    ("banana", "Walmart", 1),
    ("apple", "Walmart", 3),
    ("orange", "Aldi", 4),
]
fruits = pd.DataFrame(fruit_rows, columns=["name", "store", "price"])

fruits

Unnamed: 0,name,store,price
0,apple,Aldi,2
1,banana,Walmart,1
2,apple,Walmart,3
3,orange,Aldi,4


In [2]:
#pass validation

import pandera as pa
from pandera import Column, Check

# Allowed values for the two categorical columns.
available_fruits = ["apple", "banana", "orange"]
nearby_stores = ["Aldi", "Walmart"]

# One Column per DataFrame column: the expected dtype plus a single value check.
column_specs = {
    "name": Column(str, Check.isin(available_fruits)),
    "store": Column(str, Check.isin(nearby_stores)),
    "price": Column(int, Check.less_than(5)),
}
schema = pa.DataFrameSchema(column_specs)

# Every price is below 5, so validation succeeds and the frame is displayed.
schema.validate(fruits)

Unnamed: 0,name,store,price
0,apple,Aldi,2
1,banana,Walmart,1
2,apple,Walmart,3
3,orange,Aldi,4


In [3]:
#fail validation

import pandera as pa
from pandera import Column, Check

# Tighten the price bound to "< 4" so that validation of `fruits` fails.
schema = pa.DataFrameSchema(
    {
        "name": Column(str, Check.isin(available_fruits)),
        "store": Column(str, Check.isin(nearby_stores)),
        "price": Column(int, Check.less_than(4)),
    }
)

# Not all prices are less than 4, so a SchemaError is expected here. The error
# is the demonstration: catch it and show the message instead of letting it
# abort a "Restart & Run All".
try:
    schema.validate(fruits)
except pa.errors.SchemaError as err:
    print(f"SchemaError: {err}")

SchemaError: <Schema Column(name=price, type=<class 'int'>)> failed element-wise validator 0:
<Check less_than: less_than(4)>
failure cases:
   index  failure_case
0      3             4

In [5]:
# Several checks can be attached to one column by passing a list:
# a built-in upper bound, and a custom check on the sum of the prices.
price_checks = [
    Check.less_than(5),
    Check(lambda price: sum(price) < 15),
]

schema = pa.DataFrameSchema(
    {
        "name": Column(str, Check.isin(available_fruits)),
        "store": Column(str, Check.isin(nearby_stores)),
        "price": Column(int, price_checks),
    }
)
schema.validate(fruits)

Unnamed: 0,name,store,price
0,apple,Aldi,2
1,banana,Walmart,1
2,apple,Walmart,3
3,orange,Aldi,4


In [7]:
# class

from pandera.typing import Series

class Schema(pa.SchemaModel):
    """Class-based schema for the fruits frame.

    Each annotated attribute describes one column: its dtype comes from the
    ``Series[...]`` annotation and its value constraints from ``pa.Field``.
    """

    # Fruit name must be one of `available_fruits`.
    name: Series[str] = pa.Field(isin=available_fruits)
    # Store must be one of `nearby_stores`.
    store: Series[str] = pa.Field(isin=nearby_stores)
    # Each price must be less than or equal to 5.
    price: Series[int] = pa.Field(le=5)

    @pa.check("price")
    def price_sum_lt_20(cls, price: Series[int]) -> bool:
        """Custom check on the ``price`` column: the total must be below 20."""
        # The comparison yields one boolean for the whole column rather than
        # an element-wise Series, hence the `bool` return annotation.
        return price.sum() < 20


Schema.validate(fruits)

Unnamed: 0,name,store,price
0,apple,Aldi,2
1,banana,Walmart,1
2,apple,Walmart,3
3,orange,Aldi,4


In [9]:
# Rebuild the example frame and the basic schema used by the function below.
fruit_rows = [
    ("apple", "Aldi", 2),
    ("banana", "Walmart", 1),
    ("apple", "Walmart", 3),
    ("orange", "Aldi", 4),
]
fruits = pd.DataFrame(fruit_rows, columns=["name", "store", "price"])

column_specs = {
    "name": Column(str, Check.isin(available_fruits)),
    "store": Column(str, Check.isin(nearby_stores)),
    "price": Column(int, Check.less_than(5)),
}
schema = pa.DataFrameSchema(column_specs)


def get_total_price(fruits: pd.DataFrame, schema: pa.DataFrameSchema):
    """Validate ``fruits`` against ``schema`` and return the sum of its ``price`` column.

    Any error raised by ``schema.validate`` propagates to the caller.
    """
    return schema.validate(fruits)["price"].sum()


get_total_price(fruits, schema)

10

In [11]:
#for unit test

def test_get_total_price():
    """Check that ``get_total_price`` adds up the prices of a small valid frame."""
    # The schema is an explicit argument of the function under test, so the
    # test has to build one as well.
    price_schema = pa.DataFrameSchema(
        {
            "name": Column(str, Check.isin(available_fruits)),
            "store": Column(str, Check.isin(nearby_stores)),
            "price": Column(int, Check.less_than(5)),
        }
    )
    sample = pd.DataFrame(
        {
            'name': ['apple', 'banana'],
            'store': ['Aldi', 'Walmart'],
            'price': [1, 2],
        }
    )

    assert get_total_price(sample, price_schema) == 3


test_get_total_price()

In [13]:
#decorator

from pandera import check_input

column_specs = {
    "name": Column(str, Check.isin(available_fruits)),
    "store": Column(str, Check.isin(nearby_stores)),
    "price": Column(int, Check.less_than(5)),
}
schema = pa.DataFrameSchema(column_specs)


# The decorator attaches the schema to the function, so callers (and tests)
# no longer pass a schema argument.
@check_input(schema)
def get_total_price(fruits: pd.DataFrame):
    """Return the sum of the ``price`` column of ``fruits``."""
    return fruits["price"].sum()


get_total_price(fruits)

10

In [14]:
# input check

# Prices are deliberately stored as strings here, so the `price` column no
# longer has the integer dtype that `schema` declares.
fruits = pd.DataFrame(
    {
        "name": ["apple", "banana", "apple", "orange"],
        "store": ["Aldi", "Walmart", "Walmart", "Aldi"],
        "price": ["2", "1", "3", "4"],
    }
)


@check_input(schema)
def get_total_price(fruits: pd.DataFrame):
    """Return the sum of the ``price`` column of ``fruits``."""
    return fruits.price.sum()


# The failing input check is the demonstration: catch the error and show its
# message so the notebook still runs from top to bottom.
try:
    get_total_price(fruits)
except pa.errors.SchemaError as err:
    print(f"SchemaError: {err}")

SchemaError: error in check_input decorator of function 'get_total_price': expected series 'price' to have type 'int64', got 'object'

In [15]:
# check_output

from pandera import check_output

# Both frames list the same fruits; they differ only in store and price.
fruit_names = ["apple", "banana", "apple", "orange"]

fruits_nearby = pd.DataFrame(
    {
        "name": fruit_names,
        "store": ["Aldi", "Walmart", "Walmart", "Aldi"],
        "price": [2, 1, 3, 4],
    }
)

fruits_faraway = pd.DataFrame(
    {
        "name": fruit_names,
        "store": ["Whole Foods", "Whole Foods", "Schnucks", "Schnucks"],
        "price": [3, 2, 4, 5],
    }
)

# Schema for the returned frame: `store` must be one of the four known stores.
known_stores = ["Aldi", "Walmart", "Whole Foods", "Schnucks"]
out_schema = pa.DataFrameSchema({"store": Column(str, Check.isin(known_stores))})


@check_output(out_schema)
def combine_fruits(fruits_nearby: pd.DataFrame, fruits_faraway: pd.DataFrame):
    """Concatenate the two frames row-wise, nearby rows first.

    The original row labels are kept, so index values repeat in the result.
    """
    return pd.concat([fruits_nearby, fruits_faraway])


combine_fruits(fruits_nearby, fruits_faraway)

Unnamed: 0,name,store,price
0,apple,Aldi,2
1,banana,Walmart,1
2,apple,Walmart,3
3,orange,Aldi,4
0,apple,Whole Foods,3
1,banana,Whole Foods,2
2,apple,Schnucks,4
3,orange,Schnucks,5


In [16]:
# check input and output both

from pandera import check_io

# Inputs only need a string `store` column; the output additionally restricts
# `store` to the four known stores.
in_schema = pa.DataFrameSchema({"store": Column(str)})

known_stores = ["Aldi", "Walmart", "Whole Foods", "Schnucks"]
out_schema = pa.DataFrameSchema({"store": Column(str, Check.isin(known_stores))})


# Keyword names map each schema to the argument of the same name; `out` is
# the schema for the return value.
@check_io(fruits_nearby=in_schema, fruits_faraway=in_schema, out=out_schema)
def combine_fruits(fruits_nearby: pd.DataFrame, fruits_faraway: pd.DataFrame):
    """Concatenate the two frames row-wise, nearby rows first."""
    return pd.concat([fruits_nearby, fruits_faraway])


combine_fruits(fruits_nearby, fruits_faraway)

Unnamed: 0,name,store,price
0,apple,Aldi,2
1,banana,Walmart,1
2,apple,Walmart,3
3,orange,Aldi,4
0,apple,Whole Foods,3
1,banana,Whole Foods,2
2,apple,Schnucks,4
3,orange,Schnucks,5


In [17]:
# nullable, default is not allowed

import numpy as np

# The last row has a missing store (NaN) to exercise `nullable`.
fruits = pd.DataFrame(
    {
        "name": ["apple", "banana", "apple", "orange"],
        "store": ["Aldi", "Walmart", "Walmart", np.nan],
        "price": [2, 1, 3, 4],
    }
)

schema = pa.DataFrameSchema(
    {
        "name": Column(str, Check.isin(available_fruits)),
        # nullable=True lets the missing store value through validation.
        "store": Column(str, Check.isin(nearby_stores), nullable=True),
        "price": Column(int, Check.less_than(5)),
    }
)
schema.validate(fruits)

Unnamed: 0,name,store,price
0,apple,Aldi,2
1,banana,Walmart,1
2,apple,Walmart,3
3,orange,,4


In [18]:
# allow_duplicates

fruits = pd.DataFrame(
    {
        "name": ["apple", "banana", "apple", "orange"],
        "store": ["Aldi", "Walmart", "Walmart", np.nan],
        "price": [2, 1, 3, 4],
    }
)

schema = pa.DataFrameSchema(
    {
        "name": Column(str, Check.isin(available_fruits)),
        # allow_duplicates=False requires the store values to be unique.
        "store": Column(
            str, Check.isin(nearby_stores), nullable=True, allow_duplicates=False
        ),
        "price": Column(int, Check.less_than(5)),
    }
)

# "Walmart" appears twice, so a SchemaError is expected. The error is the
# demonstration: catch it and show the message so the notebook keeps running.
try:
    schema.validate(fruits)
except pa.errors.SchemaError as err:
    print(f"SchemaError: {err}")

SchemaError: series 'store' contains duplicate values: {2: 'Walmart'}

In [19]:
# convert data type

fruit_rows = [
    ("apple", "Aldi", 2),
    ("banana", "Walmart", 1),
    ("apple", "Walmart", 3),
    ("orange", "Aldi", 4),
]
fruits = pd.DataFrame(fruit_rows, columns=["name", "store", "price"])

# With coerce=True the integer `price` column is converted to the declared
# string dtype during validation instead of being rejected.
price_as_text = Column(str, coerce=True)
schema = pa.DataFrameSchema({"price": price_as_text})

validated = schema.validate(fruits)
validated.dtypes

name     object
store    object
price    object
dtype: object

In [20]:
#regex pattern

favorite_stores = ["Aldi", "Walmart", "Whole Foods", "Schnucks"]

fruits = pd.DataFrame(
    {
        "name": ["apple", "banana", "apple", "orange"],
        "store_nearby": ["Aldi", "Walmart", "Walmart", "Aldi"],
        "store_far": ["Whole Foods", "Schnucks", "Whole Foods", "Schnucks"],
    }
)

# With regex=True the dictionary key is treated as a pattern, so a single
# Column definition is applied to every matching column.
# NOTE(review): as a regex, "store_+" means "store" followed by one or more
# underscores; it presumably selects both store_* columns only because the
# match is anchored at the start of the name — confirm, "store_.+" may be
# what was intended.
store_columns = Column(str, Check.isin(favorite_stores), regex=True)

schema = pa.DataFrameSchema(
    {
        "name": Column(str, Check.isin(available_fruits)),
        "store_+": store_columns,
    }
)
schema.validate(fruits)

Unnamed: 0,name,store_nearby,store_far
0,apple,Aldi,Whole Foods
1,banana,Walmart,Schnucks
2,apple,Walmart,Whole Foods
3,orange,Aldi,Schnucks


In [21]:
# YAML
from pathlib import Path

# Serialize the schema to a YAML string
yaml_schema = schema.to_yaml()

# Save to a file. Create the target folder first: write_text() creates the
# file itself but not missing parent directories.
f = Path("data/schema.yml")
f.parent.mkdir(parents=True, exist_ok=True)
f.write_text(yaml_schema)

528

In [22]:
# Read the saved YAML text back and rebuild a schema object from it.
yaml_schema = f.read_text()

schema = pa.io.from_yaml(yaml_schema)
schema

<Schema DataFrameSchema(columns={'name': <Schema Column(name=name, type=object)>, 'store_+': <Schema Column(name=store_+, type=object)>}, checks=[], index=None, coerce=False, pandas_dtype=None,strict=False,name=None,ordered=False)>