## TODO

- Add metadata validation as a separate section: check that metadata is aligned (i.e. identical) across ALL datasets.
- Update the `Dataset` class so it no longer requires `prefix` or `usecols` arguments.

In [None]:
# Pretty-print cell outputs using rich's IPython extension.
%load_ext rich
# Automatically re-import edited modules (e.g. local helpers such as `metadata`)
# before each cell runs, so changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path
import numpy as np
import polars as pl
import pandera.polars as pa
from pandera.typing import DataFrame, Series
from typing import Any
import pyreadstat
from functools import partial

from pain.read import *
from fastcore.test import *

In [None]:
from metadata import METADATA
# Raw .sav files live in ../data/raw, relative to the notebook's working directory.
data_dir = Path("..", "data", "raw")

## Data Schema

Define the expected data structure (allowed codes) for each variable as shared pandera field definitions.

In [None]:
# Allowed codes, per the metadata value labels: 1 = Yes, 0 = No, -99 = Missing, -88 = N/A.
_YES_NO_MISSING = (-99, 0, 1)
_YES_NO_NA_MISSING = (-88, -99, 0, 1)

# One pandera Field factory per questionnaire item; call it (optionally with extra
# Field kwargs such as nullable=True) when declaring a schema column.
PN17 = partial(pa.Field, isin=_YES_NO_MISSING, coerce=True)
PN25 = partial(pa.Field, isin=_YES_NO_NA_MISSING, coerce=True)
PN34 = partial(pa.Field, isin=_YES_NO_NA_MISSING, coerce=True)
PN35 = partial(pa.Field, isin=_YES_NO_NA_MISSING, coerce=True)
PN36 = partial(pa.Field, isin=_YES_NO_NA_MISSING, coerce=True)

# Items that only appear in G217_PQ and G217_SQ
PN9 = partial(pa.Field, isin=_YES_NO_NA_MISSING, coerce=True)
PN38 = partial(pa.Field, isin=_YES_NO_NA_MISSING, coerce=True)

## G214_PQ

In [None]:
# Load the G214_PQ SPSS file from data_dir. `Dataset` is presumably provided by the
# `pain.read` star import — TODO confirm. `load_data()` returns the data, which is lazy
# given the `.collect()` in the next cell, plus its metadata. The metadata (`meta`) is
# merged with METADATA in the metadata-exploration section.
G214_PQ = Dataset("G214_PQ.sav", data_dir)
df, meta = G214_PQ.load_data()

In [None]:
# Recode raw codes to the schema's codes: 9 -> -99 (Missing) everywhere, plus
# 8 -> -88 (N/A) on the follow-up items. Keep only these five columns and
# materialise the result as an eager DataFrame.
df = df.select(
    pl.col("G214_PQ_PN17").replace({9: -99}),
    *[
        pl.col(f"G214_PQ_{item}").replace({8: -88, 9: -99})
        for item in ("PN25", "PN34", "PN35", "PN36")
    ],
).collect()

In [None]:
class G214PQSchema(pa.DataFrameModel):
    # Schema for the recoded G214_PQ back-pain items. Per-column allowed codes come from
    # the shared PN* field factories. The two dataframe checks enforce the skip pattern
    # between PN17 ("Ever had back pain") and its follow-up items.
    G214_PQ_PN17: Series[int] = PN17()
    G214_PQ_PN25: Series[int] = PN25()
    G214_PQ_PN34: Series[int] = PN34()
    G214_PQ_PN35: Series[int] = PN35()
    G214_PQ_PN36: Series[int] = PN36()

    @pa.dataframe_check
    def check_for_na(cls, data: pa.PolarsData) -> pl.LazyFrame:
        """Verify that when PN17 is 0, all follow-up items are -88.

        If there was no back pain, the other variables should be N/A. The check is
        written as the row-wise implication ``PN17 != 0 OR all follow-ups == -88``, so
        the output has exactly one boolean per input row. Pandera pairs check output
        with input rows when reporting failure cases, and a filtered (shorter) result
        would attribute failures to the wrong rows.
        """
        follow_ups = pl.col("G214_PQ_PN25", "G214_PQ_PN34", "G214_PQ_PN35", "G214_PQ_PN36")
        # ne_missing treats a null PN17 as "not 0", so such rows pass. Previously they
        # were filtered out.
        return data.lazyframe.select(
            pl.col("G214_PQ_PN17").ne_missing(0) | pl.all_horizontal(follow_ups == -88)
        )

    @pa.dataframe_check
    def check_for_na2(cls, data: pa.PolarsData) -> pl.LazyFrame:
        """Verify the reverse: if PN25, PN34 or PN36 is -88, PN17 should be 0.

        PN35 is deliberately excluded because it can legitimately be -88 even when PN17
        is not 0. Like ``check_for_na``, the check is written row-wise
        (``no listed follow-up is -88 OR PN17 == 0``) so each output row lines up with
        an input row.
        """
        any_na = pl.any_horizontal(
            pl.col("G214_PQ_PN25", "G214_PQ_PN34", "G214_PQ_PN36") == -88
        )
        return data.lazyframe.select(any_na.not_() | (pl.col("G214_PQ_PN17") == 0))

try:
    G214PQSchema.validate(df, lazy=True)
except pa.errors.SchemaErrors as schema_errors:
    # lazy=True makes pandera gather every failing check into one SchemaErrors,
    # so the whole report is printed rather than only the first failure.
    print(schema_errors)

## Explore Metadata Validation

One sensible approach is to transform the collection of metadata, where each variable is currently defined independently and in isolation, into the same format as the dataset's existing metadata structure (i.e. parameters as parents and variables as children). Both sources can then be compared parameter by parameter.

In [None]:
# Reshape the per-variable METADATA records into the nested {parameter: {variable: value}}
# layout. `p` is presumably the variable-name prefix: the resulting keys look like
# 'G214_PQ_PN17' in the output below. Confirm against the helper's definition.
new_meta = convert_metadata_list_to_dict(METADATA, p="G214_PQ_")
# Combine with the metadata loaded from the .sav file. How conflicting entries are
# resolved depends on merge_dictionaries — NOTE(review): confirm which source wins.
m = merge_dictionaries([new_meta, meta])

In [None]:
from pydantic import BaseModel, ValidationError, field_validator, model_validator

# General validation class for metadata in nested dictionary format
class MetaDict(BaseModel):
    """Validate metadata held as ``{parameter: {variable name: value}}`` dicts.

    Each field is a whole ``{variable: value}`` mapping, so the validators below
    iterate over the mapping. Validators written for single elements fail on this
    layout.

    Raises:
        pydantic.ValidationError: if a field has the wrong type, a ``field_type``
            value is not 'Numeric', 'String' or 'Date', or the parameters do not all
            describe the same set of variables.
    """

    label: dict[str, str|None] # TODO: replace None with empty string? Valid error - should have label?
    field_values: dict[str, dict[int, str]]
    field_type: dict[str, str]
    field_width: dict[str, int]
    decimals: dict[str, int]
    variable_type: dict[str, str]

    @field_validator("variable_type")
    @classmethod
    def lowercase(cls, value: dict[str, str]) -> dict[str, str]:
        """Lower-case every variable's type so differently-cased sources compare equal."""
        return {name: vtype.lower() for name, vtype in value.items()}

    @field_validator("field_type")
    @classmethod
    def check_field_type(cls, value: dict[str, str]) -> dict[str, str]:
        """Reject any variable whose field type is not 'Numeric', 'String' or 'Date'."""
        invalid = {
            name: ftype
            for name, ftype in value.items()
            if ftype not in ("Numeric", "String", "Date")
        }
        if invalid:
            raise ValueError(f"{invalid} not one of 'Numeric', 'String' or 'Date'.")
        return value

    @model_validator(mode="after")
    def check_keys_identical(self) -> "MetaDict":
        """Require every parameter dict to cover exactly the same variables as ``label``."""
        expected = set(self.label)
        # Map each misaligned parameter to the variables it has extra or is missing.
        mismatched = {
            name: sorted(set(getattr(self, name)) ^ expected)
            for name in type(self).model_fields
            if set(getattr(self, name)) != expected
        }
        if mismatched:
            raise ValueError(
                f"Variables differ from those in 'label' (symmetric difference): {mismatched}"
            )
        return self

In [None]:
#| export
def munge_keys(m: dict[str, dict[str, Any]]) -> dict[str, dict[str, Any]]:
    """Normalise outer keys into Pydantic-friendly field names.

    Each key is lower-cased and its spaces become underscores, e.g.
    "Field Values" -> "field_values". Inner dicts are passed through
    unchanged (not copied).
    """
    renamed: dict[str, dict[str, Any]] = {}
    for key, value in m.items():
        renamed["_".join(key.lower().split(" "))] = value
    return renamed

In [None]:
# Validate the reshaped METADATA on its own, after normalising parameter names to field names.
MetaDict.model_validate(munge_keys(new_meta))


MetaDict(
    label={
        'G214_PQ_PN17': 'Ever had back pain',
        'G214_PQ_PN25': 'Sought professional advice/treatment',
        'G214_PQ_PN34': 'Took medication to relieve pain',
        'G214_PQ_PN35': 'Missed work due to pain',
        'G214_PQ_PN36': 'Pain interfered with normal activities',
        'G214_PQ_PN9': 'Ever had neck/shoulder pain',
        'G214_PQ_PN38': 'Ever had low back pain'
    },
    field_values={
        'G214_PQ_PN17': {-99: 'Missing', 0: 'No', 1: 'Yes'},
        'G214_PQ_PN25': {-88: 'N/A', -99: 'Missing', 0: 'No', 1: 'Yes'},
        'G214_PQ_PN34': {-8
[output truncated]

In [None]:
# Munge variable names to suit Pydantic
# TODO: convert this into an actual reusable function
m2 = munge_keys(m)
try:
    MetaDict(**m2)
except ValidationError as err:
    print(err)

In [None]:
class Label(BaseModel):
    # Experimental: expected label text for the G214_PQ variables.
    # NOTE(review): these strings are only defaults, used when a key is absent from the
    # input. A supplied label is accepted as any str, so this does not check that labels
    # match. Input keys not declared here (e.g. G214_PQ_PN9 / G214_PQ_PN38, which appear
    # in new_meta) are ignored under pydantic's default `extra` handling. The Meta(...)
    # output shows only these five.
    G214_PQ_PN17: str = 'Ever had back pain'
    G214_PQ_PN25: str = 'Sought professional advice/treatment'
    G214_PQ_PN34: str = 'Took medication to relieve pain'
    G214_PQ_PN35: str = 'Missed work due to pain'
    G214_PQ_PN36: str = 'Pain interfered with normal activities'

In [None]:
class Meta(BaseModel):
    # Experimental wrapper that models only the `label` parameter, as a Label.
    # Other parameters in the input (field_values, field_type, ...) are ignored under
    # pydantic's default `extra` handling.
    label: Label

In [None]:
# Validate just the labels of the reshaped METADATA via the Label model.
Meta.model_validate(munge_keys(new_meta))


Meta(
    label=Label(
        G214_PQ_PN17='Ever had back pain',
        G214_PQ_PN25='Sought professional advice/treatment',
        G214_PQ_PN34='Took medication to relieve pain',
        G214_PQ_PN35='Missed work due to pain',
        G214_PQ_PN36='Pain interfered with normal activities'
    )
)


## G214_SQ

In [None]:
# Load the G214_SQ SPSS file. `Dataset` is presumably from the `pain.read` star import —
# TODO confirm. The returned frame is lazy, since it is `.collect()`ed at validation below.
# The metadata is not used in this section.
G214_SQ = Dataset("G214_SQ.sav", data_dir)
df, _ = G214_SQ.load_data()

In [None]:
# Recode raw codes in place: 9 -> -99 (Missing) everywhere, plus 8 -> -88 (N/A)
# on the follow-up items. All other columns are kept and the frame stays lazy.
df = df.with_columns(
    pl.col("G214_SQ_PN17").replace({9: -99}),
    *[
        pl.col(f"G214_SQ_{item}").replace({8: -88, 9: -99})
        for item in ("PN25", "PN34", "PN35", "PN36")
    ],
)

In [None]:
class DataSchema_G214_SQ(pa.DataFrameModel):
    # Schema for the recoded G214_SQ back-pain items. Columns may contain nulls.
    # Per-column allowed codes come from the shared PN* field factories. The two dataframe
    # checks enforce the skip pattern between PN17 and its follow-up items.
    G214_SQ_PN17: Series[int] = PN17(nullable=True)
    G214_SQ_PN25: Series[int] = PN25(nullable=True)
    G214_SQ_PN34: Series[int] = PN34(nullable=True)
    G214_SQ_PN35: Series[int] = PN35(nullable=True)
    G214_SQ_PN36: Series[int] = PN36(nullable=True)

    @pa.dataframe_check
    def check_for_na(cls, data: pa.PolarsData) -> pl.LazyFrame:
        """Verify that when PN17 is 0, all follow-up items are -88.

        If there was no back pain, the other variables should be N/A. The check is
        written as the row-wise implication ``PN17 != 0 OR all follow-ups == -88``, so
        the output has exactly one boolean per input row. Pandera pairs check output
        with input rows when reporting failure cases, and a filtered (shorter) result
        would attribute failures to the wrong rows.
        """
        follow_ups = pl.col("G214_SQ_PN25", "G214_SQ_PN34", "G214_SQ_PN35", "G214_SQ_PN36")
        # ne_missing treats a null PN17 as "not 0", so such rows pass. Previously they
        # were filtered out.
        return data.lazyframe.select(
            pl.col("G214_SQ_PN17").ne_missing(0) | pl.all_horizontal(follow_ups == -88)
        )

    @pa.dataframe_check
    def check_for_na2(cls, data: pa.PolarsData) -> pl.LazyFrame:
        """Verify the reverse: if PN25, PN34 or PN36 is -88, PN17 should be 0.

        PN35 is deliberately excluded because it can legitimately be -88 even when PN17
        is not 0. Like ``check_for_na``, the check is written row-wise
        (``no listed follow-up is -88 OR PN17 == 0``) so each output row lines up with
        an input row.
        """
        any_na = pl.any_horizontal(
            pl.col("G214_SQ_PN25", "G214_SQ_PN34", "G214_SQ_PN36") == -88
        )
        return data.lazyframe.select(any_na.not_() | (pl.col("G214_SQ_PN17") == 0))

try:
    # Collect the lazy frame so pandera validates the actual values.
    DataSchema_G214_SQ.validate(df.collect(), lazy=True)
except pa.errors.SchemaErrors as schema_errors:
    # lazy=True gathers every failing check into one SchemaErrors report.
    print(schema_errors)

## G217_PQ

In [None]:
# Load the G217_PQ SPSS file. `Dataset` is presumably from the `pain.read` star import —
# TODO confirm. The returned frame is assumed lazy, like the G214 loads; nothing in this
# section collects it before validation. The metadata is not used here.
G217_PQ = Dataset("G217_PQ.sav", data_dir)
df, _ = G217_PQ.load_data()

In [None]:
# Recode raw codes: PN17 maps 9 -> -99 (Missing). Every other item maps both 7 and 9
# to -99. Only these seven columns are kept.
df = df.select(
    pl.col("G217_PQ_PN17").replace({9: -99}),
    *[
        pl.col(f"G217_PQ_{item}").replace({7: -99, 9: -99})
        for item in ("PN9", "PN38", "PN25", "PN34", "PN35", "PN36")
    ],
)

In [None]:
class G217PQSchema(pa.DataFrameModel):
    # Schema for the recoded G217_PQ items, including PN9 and PN38, which are only used
    # by the G217 datasets. Only per-column allowed codes are checked. Unlike the G214
    # schemas, no cross-item (PN17 -> follow-up N/A) dataframe checks are defined here.
    G217_PQ_PN17: Series[int] = PN17()
    G217_PQ_PN9: Series[int] = PN9()
    G217_PQ_PN38: Series[int] = PN38()
    G217_PQ_PN25: Series[int] = PN25()
    G217_PQ_PN34: Series[int] = PN34()
    G217_PQ_PN35: Series[int] = PN35()
    G217_PQ_PN36: Series[int] = PN36()

# By default, pandera validates a polars LazyFrame only at the schema level, so the
# data-level checks (e.g. each column's `isin` allowed codes) would be skipped on the
# lazy `df`. Collect first so the values are actually validated, as in the G214 sections.
try:
    df = G217PQSchema.validate(df.collect(), lazy=True)
except pa.errors.SchemaErrors as err:
    # lazy=True reports every failing check at once rather than only the first.
    print(err)

## G217_SQ

In [None]:
# Load the G217_SQ SPSS file. `Dataset` is presumably from the `pain.read` star import —
# TODO confirm. The returned frame is assumed lazy, like the G214 loads. The metadata is
# not used here.
G217_SQ = Dataset("G217_SQ.sav", data_dir)
df, _ = G217_SQ.load_data()

In [None]:
# Recode raw code 9 -> -99 (Missing) on all seven items, keeping only these columns.
df = df.select(
    [
        pl.col(f"G217_SQ_{item}").replace({9: -99})
        for item in ("PN17", "PN9", "PN38", "PN25", "PN34", "PN35", "PN36")
    ]
)

In [None]:
class G217SQDataSchema(pa.DataFrameModel):
    # Schema for the recoded G217_SQ items, including PN9 and PN38, which are only used
    # by the G217 datasets. Only per-column allowed codes are checked. No cross-item
    # dataframe checks are defined here.
    G217_SQ_PN17: Series[int] = PN17()
    G217_SQ_PN9: Series[int] = PN9()
    G217_SQ_PN38: Series[int] = PN38()
    G217_SQ_PN25: Series[int] = PN25()
    G217_SQ_PN34: Series[int] = PN34()
    G217_SQ_PN35: Series[int] = PN35()
    G217_SQ_PN36: Series[int] = PN36()

# By default, pandera validates a polars LazyFrame only at the schema level, so the
# data-level checks (e.g. each column's `isin` allowed codes) would be skipped on the
# lazy `df`. Collect first so the values are actually validated, as in the G214 sections.
try:
    df = G217SQDataSchema.validate(df.collect(), lazy=True)
except pa.errors.SchemaErrors as err:
    # lazy=True reports every failing check at once rather than only the first.
    print(err)