## TODO

- Add validation for metadata -> should be a separate heading (validate alignment across ALL datasets: should be identical)
- Update Dataset class so it doesn't have to take prefix or usecols

In [None]:
%load_ext rich
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path
import numpy as np
import polars as pl
import pandera.polars as pa
from pandera.typing import DataFrame, Series
from typing import Any
import pyreadstat

from pain.read import *
from fastcore.test import *

In [None]:
from config import METADATA, M
# Directory handed to every Dataset(...) call below together with a .sav file name
# (relative path, so it resolves against the notebook's working directory).
data_dir = Path("../data/raw")

## Data Schema

Define the expected data structure (allowed coded values) for each variable.

In [None]:
def _coded_field(*allowed: int):
    """Return a coercing pandera field restricted to the given integer codes."""
    return pa.Field(isin=allowed, coerce=True)


# Shared per-variable rules, reused by every dataset schema in this notebook.
# -88 is the "not applicable" code checked by the G214 schemas;
# -99 is presumably the "missing" code (raw 9 is recoded to it) - TODO confirm.
PN17 = _coded_field(-99, 0, 1)
PN25 = _coded_field(-88, -99, 0, 1)
PN34 = _coded_field(-88, -99, 0, 1)
PN35 = _coded_field(-88, -99, 0, 1)
PN36 = _coded_field(-88, -99, 0, 1)

# Only applicable for G217_PQ and G217_SQ
PN9 = _coded_field(-88, -99, 0, 1)
PN38 = _coded_field(-88, -99, 0, 1)

## G214_PQ

In [None]:
# Open the G214_PQ file from data_dir and load it; load_data() returns the
# data frame together with a second object kept here as `meta`, which is
# merged with the config-defined metadata further down.
G214_PQ = Dataset("G214_PQ.sav", data_dir)
df, meta = G214_PQ.load_data()

In [None]:
# Keep only the pain variables and translate the raw codes into the shared
# negative codes (8 -> -88, 9 -> -99), then materialise the result.
df = df.select(
    pl.col("G214_PQ_PN17").replace({9: -99}),
    pl.col("G214_PQ_PN25").replace({8: -88, 9: -99}),
    pl.col("G214_PQ_PN34").replace({8: -88, 9: -99}),
    pl.col("G214_PQ_PN35").replace({8: -88, 9: -99}),
    pl.col("G214_PQ_PN36").replace({8: -88, 9: -99}),
).collect()

In [None]:
class G214PQSchema(pa.DataFrameModel):
    """Validation rules for the recoded G214_PQ pain variables.

    Each column must hold one of the codes allowed by its shared field rule,
    and two dataframe-level checks enforce consistency between PN17 and the
    follow-up variables.
    """

    G214_PQ_PN17: Series[int] = PN17
    G214_PQ_PN25: Series[int] = PN25
    G214_PQ_PN34: Series[int] = PN34
    G214_PQ_PN35: Series[int] = PN35
    G214_PQ_PN36: Series[int] = PN36

    @pa.dataframe_check
    def check_for_na(cls, data: pa.PolarsData) -> pl.LazyFrame:
        """Verify when PN17 is 0, all subsequent values are -88 (ie. if no back pain, other variables should be N/A).

        Returns a single boolean column with one value per input row.
        """
        # Rows the rule applies to; a null PN17 is treated as "rule does not apply".
        no_pain = (pl.col("G214_PQ_PN17") == 0).fill_null(False)
        followups = pl.col("G214_PQ_PN25", "G214_PQ_PN34", "G214_PQ_PN35", "G214_PQ_PN36")
        # Written as a row-wise implication instead of filtering, so the output
        # keeps the same height as the input and failure cases stay aligned
        # with the offending rows.
        return data.lazyframe.select(~no_pain | pl.all_horizontal(followups == -88))

    @pa.dataframe_check
    def check_for_na2(cls, data: pa.PolarsData) -> pl.LazyFrame:
        """
        Verify the reverse of the above, where if any subsequent variables has a value of -88, PN17 should be 0.

        PN35 is deliberately left out of this check.
        NOTE(review): presumably because PN35 can validly be -88 even when
        PN17 is not 0 - confirm against the codebook.

        Returns a single boolean column with one value per input row.
        """
        any_na = (
            (pl.col("G214_PQ_PN25") == -88)
            | (pl.col("G214_PQ_PN34") == -88)
            | (pl.col("G214_PQ_PN36") == -88)
        ).fill_null(False)
        # Row-wise implication (any_na -> PN17 == 0) keeps the output aligned
        # with the input rows.
        return data.lazyframe.select(~any_na | (pl.col("G214_PQ_PN17") == 0))

try:
    # Keep the validated frame (fields are declared with coerce=True), matching
    # how the other dataset sections reassign `df`. lazy=True collects every
    # failure before raising instead of stopping at the first one.
    df = G214PQSchema.validate(df, lazy=True)
except pa.errors.SchemaErrors as err:
    # Report the collected failures without aborting the notebook run.
    print(err)

One sensible approach would be to take the collection of metadata and transform it so that, rather than each variable being defined independently and in isolation, it follows the same format as the dataset's current metadata structure (i.e. parameters as parents, variables as children).

In [None]:
# Convert the config-defined metadata list M into a dict for the "G214_PQ_"
# prefix, then merge it with `meta` as loaded from the G214_PQ file.
# NOTE(review): which side wins on conflicting keys depends on
# merge_dictionaries (defined elsewhere) - confirm the intended precedence.
new_meta = convert_metadata_list_to_dict(M, p="G214_PQ_")
m = merge_dictionaries([new_meta, meta])

TODO:
- Verify that for each parameter, the final metadata matches the defined metadata from `config.py`
- Verify all other metadata remains identical

Use Pydantic model for validation?

## G214_SQ

In [None]:
# Open the G214_SQ file from data_dir and load it; the second value returned
# by load_data() is not needed in this section and is discarded.
G214_SQ = Dataset("G214_SQ.sav", data_dir)
df, _ = G214_SQ.load_data()

In [None]:
# Keep only the pain variables and translate the raw codes into the shared
# negative codes (8 -> -88, 9 -> -99).
df = (
    df
    .select(
        pl.col("G214_SQ_PN17").replace({9: -99}),
        pl.col("G214_SQ_PN25").replace({8: -88, 9: -99}),
        pl.col("G214_SQ_PN34").replace({8: -88, 9: -99}),
        pl.col("G214_SQ_PN35").replace({8: -88, 9: -99}),
        pl.col("G214_SQ_PN36").replace({8: -88, 9: -99}),
    )
    # Materialise before validation: pandera only runs schema-level checks on a
    # LazyFrame by default, which would skip the value-set and dataframe checks.
    # Assumes load_data() returns a LazyFrame, as the G214_PQ cell's .collect()
    # implies.
).collect()

In [None]:
class G214SQSchema(pa.DataFrameModel):
    """Validation rules for the recoded G214_SQ pain variables.

    Each column must hold one of the codes allowed by its shared field rule,
    and two dataframe-level checks enforce consistency between PN17 and the
    follow-up variables.
    """

    G214_SQ_PN17: Series[int] = PN17
    G214_SQ_PN25: Series[int] = PN25
    G214_SQ_PN34: Series[int] = PN34
    G214_SQ_PN35: Series[int] = PN35
    G214_SQ_PN36: Series[int] = PN36

    @pa.dataframe_check
    def check_for_na(cls, data: pa.PolarsData) -> pl.LazyFrame:
        """Verify when PN17 is 0, all subsequent values are -88 (ie. if no back pain, other variables should be N/A).

        Returns a single boolean column with one value per input row.
        """
        # Rows the rule applies to; a null PN17 is treated as "rule does not apply".
        no_pain = (pl.col("G214_SQ_PN17") == 0).fill_null(False)
        followups = pl.col("G214_SQ_PN25", "G214_SQ_PN34", "G214_SQ_PN35", "G214_SQ_PN36")
        # Written as a row-wise implication instead of filtering, so the output
        # keeps the same height as the input and failure cases stay aligned
        # with the offending rows.
        return data.lazyframe.select(~no_pain | pl.all_horizontal(followups == -88))

    @pa.dataframe_check
    def check_for_na2(cls, data: pa.PolarsData) -> pl.LazyFrame:
        """
        Verify the reverse of the above, where if any subsequent variables has a value of -88, PN17 should be 0.

        PN35 is deliberately left out of this check.
        NOTE(review): presumably because PN35 can validly be -88 even when
        PN17 is not 0 - confirm against the codebook.

        Returns a single boolean column with one value per input row.
        """
        any_na = (
            (pl.col("G214_SQ_PN25") == -88)
            | (pl.col("G214_SQ_PN34") == -88)
            | (pl.col("G214_SQ_PN36") == -88)
        ).fill_null(False)
        # Row-wise implication (any_na -> PN17 == 0) keeps the output aligned
        # with the input rows.
        return data.lazyframe.select(~any_na | (pl.col("G214_SQ_PN17") == 0))

try:
    # lazy=True gathers every schema failure before raising.
    df = G214SQSchema.validate(df, lazy=True)
except pa.errors.SchemaErrors as schema_errors:
    # Show the collected failures and let the notebook carry on.
    print(schema_errors)

## G217_PQ

In [None]:
# Open the G217_PQ file from data_dir and load it; the second value returned
# by load_data() is not needed in this section and is discarded.
G217_PQ = Dataset("G217_PQ.sav", data_dir)
df, _ = G217_PQ.load_data()

In [None]:
# Keep only the pain variables and translate the raw codes into -99.
# NOTE(review): here both 7 and 9 map to -99 and nothing maps to -88, unlike
# the G214 cells (8 -> -88) - confirm against the G217_PQ codebook.
df = (
    df
    .select(
        pl.col("G217_PQ_PN17").replace({9: -99}),
        pl.col("G217_PQ_PN9").replace({7: -99, 9: -99}),
        pl.col("G217_PQ_PN38").replace({7: -99, 9: -99}),
        pl.col("G217_PQ_PN25").replace({7: -99, 9: -99}),
        pl.col("G217_PQ_PN34").replace({7: -99, 9: -99}),
        pl.col("G217_PQ_PN35").replace({7: -99, 9: -99}),
        pl.col("G217_PQ_PN36").replace({7: -99, 9: -99})
    )
    # Materialise before validation: pandera only runs schema-level checks on a
    # LazyFrame by default, which would skip the value-set checks.
    # Assumes load_data() returns a LazyFrame, as the G214_PQ cell's .collect()
    # implies.
).collect()

In [None]:
class G217PQSchema(pa.DataFrameModel):
    """Validation rules for the recoded G217_PQ pain variables.

    Each column reuses the shared field rule with the matching PN number.
    Unlike the G214 schemas, no cross-column dataframe checks are declared.
    """
    G217_PQ_PN17: Series[int] = PN17
    G217_PQ_PN9: Series[int] = PN9
    G217_PQ_PN38: Series[int] = PN38
    G217_PQ_PN25: Series[int] = PN25
    G217_PQ_PN34: Series[int] = PN34
    G217_PQ_PN35: Series[int] = PN35
    G217_PQ_PN36: Series[int] = PN36

try:
    # lazy=True gathers every schema failure before raising.
    df = G217PQSchema.validate(df, lazy=True)
except pa.errors.SchemaErrors as schema_errors:
    # Show the collected failures and let the notebook carry on.
    print(schema_errors)

## G217_SQ

In [None]:
# Open the G217_SQ file from data_dir and load it; the second value returned
# by load_data() is not needed in this section and is discarded.
G217_SQ = Dataset("G217_SQ.sav", data_dir)
df, _ = G217_SQ.load_data()

In [None]:
# Keep only the pain variables and translate the raw code 9 into -99.
df = (
    df
    .select(
        pl.col("G217_SQ_PN17").replace({9: -99}),
        pl.col("G217_SQ_PN9").replace({9: -99}),
        pl.col("G217_SQ_PN38").replace({9: -99}),
        pl.col("G217_SQ_PN25").replace({9: -99}),
        pl.col("G217_SQ_PN34").replace({9: -99}),
        pl.col("G217_SQ_PN35").replace({9: -99}),
        pl.col("G217_SQ_PN36").replace({9: -99})
    )
    # Materialise before validation: pandera only runs schema-level checks on a
    # LazyFrame by default, which would skip the value-set checks.
    # Assumes load_data() returns a LazyFrame, as the G214_PQ cell's .collect()
    # implies.
).collect()

In [None]:
class G217SQDataSchema(pa.DataFrameModel):
    """Validation rules for the recoded G217_SQ pain variables.

    Each column reuses the shared field rule with the matching PN number.
    Unlike the G214 schemas, no cross-column dataframe checks are declared.
    """
    G217_SQ_PN17: Series[int] = PN17
    G217_SQ_PN9: Series[int] = PN9
    G217_SQ_PN38: Series[int] = PN38
    G217_SQ_PN25: Series[int] = PN25
    G217_SQ_PN34: Series[int] = PN34
    G217_SQ_PN35: Series[int] = PN35
    G217_SQ_PN36: Series[int] = PN36

try:
    # lazy=True gathers every schema failure before raising.
    df = G217SQDataSchema.validate(df, lazy=True)
except pa.errors.SchemaErrors as schema_errors:
    # Show the collected failures and let the notebook carry on.
    print(schema_errors)