In [None]:
!pip install mypy nb_mypy
%load_ext nb_mypy




Version 1.0.5
INFO:nb-mypy:Version 1.0.5


In [None]:
import seaborn as sns
sns.get_dataset_names()

['anagrams',
 'anscombe',
 'attention',
 'brain_networks',
 'car_crashes',
 'diamonds',
 'dots',
 'dowjones',
 'exercise',
 'flights',
 'fmri',
 'geyser',
 'glue',
 'healthexp',
 'iris',
 'mpg',
 'penguins',
 'planets',
 'seaice',
 'taxis',
 'tips',
 'titanic']

In [None]:
from typing import Callable
from typing import get_type_hints, Any, List
import polars as pl
import seaborn as sns

def get_df() -> pl.DataFrame:
        # Loading toy dataset as Pandas DataFrame using Seaborn
    df_pd = sns.load_dataset('iris')

    # Converting Pandas DataFrame to Polars DataFrame
    df = pl.DataFrame(df_pd)
    return df

def is_iterable(var: Any) -> bool:
    try:
        iter(var)
        return True
    except TypeError:
        # not iterable
        return False

    return False

class AbstractSchema(pl.DataFrame):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, schema=self._get_schema(), **kwargs)

    def _validate(self, df: pl.DataFrame, func_name: str):
        # TODO: Make this tidier
        other_schema = df.schema
        this_schema = self._get_schema()

        errors: list[str] = []
        if isinstance(other_schema, property):
            errors.append(f"The other schema for function [{func_name}] was empty.")
        elif len(other_schema) == 0:
            errors.append(f"The other schema for function [{func_name}] had zero len.")
        else:
            for key, val in this_schema.items():
                if key not in other_schema:
                    errors.append(f"{key} not found in schema for function [{func_name}].")
                elif other_schema[key] != val:
                    errors.append(f"{key} was not the right type in schema for function [{func_name}].")

            if len(other_schema) != len(this_schema):
                errors.append(f"Schema for function [{func_name}] has extra columns.")

        if len(errors) > 0:
            chars = "\n\t - "
            error_list = chars.join(errors)
            raise AssertionError(f"Found error(s) in schema: {chars}{error_list}")

    def _get_schema(self):
        raise NotImplementedError("This class did not implement the get_schema method.")

class IrisSchema(AbstractSchema):
    def _get_schema(self):
        return {
            'sepal_length': pl.Float64,
            'sepal_width': pl.Float64,
            'petal_length': pl.Float64,
            'petal_width': pl.Float64,
            'species': pl.Utf8
        }

class ExtraIrisSchema(AbstractSchema):
    def _get_schema(self):
        return {
            'sepal_length': pl.Float64,
            'sepal_width': pl.Float64,
            'petal_length': pl.Float64,
            'petal_width': pl.Float64,
            'species': pl.Utf8,
            'other_species': pl.Utf8
        }
#new decorator
def schema_check(func):
    # wrapper function
    func_type_hints = get_type_hints(func)
    func_name = func.__name__

    assert 'return' in func_type_hints, f"Func {func_name} did not have a return type."
    return_type = func_type_hints["return"]

    # TODO: Extend this to internal?
    def wrapper(*args, **kwargs):
        result = func(*args, **kwargs)

        return_type()._validate(result, func_name)
        return result
    return wrapper

@schema_check
def get_iris_schema_fail() -> IrisSchema:
    return  get_df().with_columns(pl.col("species").alias('other_species'))

@schema_check
def get_iris_schema_pass() -> IrisSchema:
    df = get_df()
    instance = IrisSchema
    instance._df = df._df

    return instance


@schema_check
def get_iris_schema_extra_fail() -> ExtraIrisSchema:
    return get_df()

@schema_check
def get_iris_schema_extra_pass() -> ExtraIrisSchema:

    df = get_df().with_columns(pl.col("species").alias('other_species'))
    instance = ExtraIrisSchema
    instance._df = df._df

    return instance

@schema_check
def get_iris_schema_extra_fail2() -> ExtraIrisSchema:
    return get_df().with_columns(pl.col("species").alias('species_copy'))


#new decorator
def schema_check_single_type(arg_name: str, type_hint: Any, type_var: Any, func_name: str):
    if issubclass(type_hint, AbstractSchema):
        type_hint()._validate(type_var, func_name)
        print(f"Argument {arg_name} was validated to be {type_hint}")
    else:
        print(f"Argument {arg_name} was not a schema, skip.")

def check_single_argument_or_result(type_hint_name: str, type_hint_type: Any, var_to_check: Any, func_name: str):
    type_hint_instance = type_hint_type()
    if isinstance(type_hint_instance, tuple):
        print("Found a type that is a tuple")
        subtypes = type_hint_type.__args__
        for j, subtype in enumerate(subtypes):
            subresult = var_to_check[j]
            schema_check_single_type(f"type_hint_name[{j}]", subtype, subresult, func_name)
    else:
        schema_check_single_type(type_hint_name, type_hint_type, var_to_check, func_name)

def schema_check_full(func: Callable) -> Callable:
    # wrapper function
    func_type_hints = get_type_hints(func)
    func_name = func.__name__

    assert 'return' in func_type_hints, f"Func {func_name} did not have a return type."

    # TODO: Extend this to internal?
    def wrapper(*args, **kwargs):
        result = func(*args, **kwargs)

        i = 0
        for type_hint_name, type_hint_type in func_type_hints.items():
            print(f"CHECKING {type_hint_name}: {type_hint_type}")
            var_to_check = result

            if type_hint_name != 'return':
                    var_to_check = args[i]

            check_single_argument_or_result(
                type_hint_name,
                type_hint_type,
                var_to_check,
                func_name
            )

            i += 1
    return wrapper

@schema_check_full
def test(a: IrisSchema, b: str) -> tuple[IrisSchema, ExtraIrisSchema, str]:
    return a, get_iris_schema_extra_pass(), b

@schema_check_full
def test_fail(a: IrisSchema, b: str) -> tuple[IrisSchema, ExtraIrisSchema, str]:
    return a, get_iris_schema_extra_pass()

print("Test should fail: Do we get an iris schema?")
try:
    get_iris_schema_fail()
except AssertionError as e:
    print(f"ERROR: {e}")

print("Test should pass: We get an iris schema")
try:
    get_iris_schema_pass()
except AssertionError as e:
    print(f"ERROR: {e}")


print("Test should fail: Do we get an extra iris schema ?")
try:
    get_iris_schema_extra_fail()
except AssertionError as e:
    print(f"ERROR: {e}")

print("Test should pass: We get an extra iris schema")

try:
    get_iris_schema_extra_pass()
except AssertionError as e:
    print(f"ERROR: {e}")

print("Test should fail: Do we get an extra iris schema ? (2)")
try:
    get_iris_schema_extra_fail2()
except AssertionError as e:
    print(f"ERROR: {e}")

print("Test should pass: We have the right schemas on input and output")
try:
    test(get_iris_schema_pass(), "cookies")
except AssertionError as e:
    print(f"ERROR: {e}")

print("Test should fail: We have the an empty dataframe as input")
try:
    test(pl.DataFrame(), "cookies")
except AssertionError as e:
    print(f"ERROR: {e}")

print("Test should fail: I didn't return the right things")
try:
    test_fail(get_iris_schema_pass(), "cookies")
except IndexError as e:
    print(f"ERROR: {e}")
except AssertionError as e:
    print(f"ERROR: {e}")


<cell>95: error: Incompatible return value type (got "DataFrame", expected "IrisSchema")  [return-value]
ERROR:nb-mypy:<cell>95: error: Incompatible return value type (got "DataFrame", expected "IrisSchema")  [return-value]
<cell>103: error: Incompatible return value type (got "type[IrisSchema]", expected "IrisSchema")  [return-value]
ERROR:nb-mypy:<cell>103: error: Incompatible return value type (got "type[IrisSchema]", expected "IrisSchema")  [return-value]
<cell>108: error: Incompatible return value type (got "DataFrame", expected "ExtraIrisSchema")  [return-value]
ERROR:nb-mypy:<cell>108: error: Incompatible return value type (got "DataFrame", expected "ExtraIrisSchema")  [return-value]
<cell>117: error: Incompatible return value type (got "type[ExtraIrisSchema]", expected "ExtraIrisSchema")  [return-value]
ERROR:nb-mypy:<cell>117: error: Incompatible return value type (got "type[ExtraIrisSchema]", expected "ExtraIrisSchema")  [return-value]
<cell>121: error: Incompatible return va

Test should fail: Do we get an iris schema?
ERROR: Found error(s) in schema: 
	 - Schema for function [get_iris_schema_fail] has extra columns.
Test should pass: We get an iris schema
ERROR: Found error(s) in schema: 
	 - The other schema for function [get_iris_schema_pass] was empty.
Test should fail: Do we get an extra iris schema ?
ERROR: Found error(s) in schema: 
	 - other_species not found in schema for function [get_iris_schema_extra_fail].
	 - Schema for function [get_iris_schema_extra_fail] has extra columns.
Test should pass: We get an extra iris schema
ERROR: Found error(s) in schema: 
	 - The other schema for function [get_iris_schema_extra_pass] was empty.
Test should fail: Do we get an extra iris schema ? (2)
ERROR: Found error(s) in schema: 
	 - other_species not found in schema for function [get_iris_schema_extra_fail2].
Test should pass: We have the right schemas on input and output
ERROR: Found error(s) in schema: 
	 - The other schema for function [get_iris_schema_pa