## TODO

- Add metadata validation under its own heading: check that the metadata is aligned (i.e. identical) across ALL datasets
- Update Dataset class so it doesn't have to take prefix or usecols

In [None]:
%load_ext rich
%load_ext autoreload
%autoreload 2

The rich extension is already loaded. To reload it, use:
  %reload_ext rich
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
from pain.read import *
from pathlib import Path
import numpy as np

import polars as pl
import pandera.polars as pa
from pandera.typing import DataFrame, Series

In [None]:
# from config import METADATA
# Directory (relative to the working directory) passed to Dataset(...) when
# loading each .sav file below.
data_dir = Path("../data/raw")

## Metadata Validation

In [None]:
# Read and combine relevant datasets

In [None]:
# 

In [None]:
# Schema to validate metadata
# TODO: create a MetadataClass class which contains a list of metadata containers?

# PN17: nominal yes/no item; -99 marks a missing answer.
PN17 = Metadata(
    label="Ever had back pain",
    field_values={-99: "Missing", 0: "No", 1: "Yes"},
    field_type="Numeric",
    field_width=3,
    decimals=0,
    variable_type="Nominal",
)

# PN25: same yes/no coding as PN17, plus -88 for "N/A".
PN25 = Metadata(
    label="Sought professional advice/treatment",
    field_values={-88: "N/A", -99: "Missing", 0: "No", 1: "Yes"},
    field_type="Numeric",
    field_width=3,
    decimals=0,
    variable_type="Nominal",
)


## G214_PQ

In [None]:
# Load G214_PQ.sav from data_dir. load_data() is unpacked as a pair; only the
# first element (the data) is kept, the second is discarded.
# Presumably the data is a polars LazyFrame, since a later cell calls
# .collect() on the result of selecting from it.
G214_PQ = Dataset("G214_PQ.sav", data_dir)
df, _ = G214_PQ.load_data()

In [None]:
# Keep only the two back-pain items, recode their raw special codes to the
# negative sentinels used by the schema (8 -> -88, 9 -> -99), and materialise
# the result.
df = df.select(
    [
        pl.col("G214_PQ_PN17").replace({9: -99}),
        pl.col("G214_PQ_PN25").replace({8: -88, 9: -99}),
    ]
).collect()

In [None]:
# Earlier sketch of the PN17/PN25 cross-check, kept for reference only
# (the conditional expression has no `else`, so it is not valid Python):
# check_var_is_na_when_pn17_is_no = pa.Check(
#     lambda df: df["G214_PQ_PN25"] == -88 if df["G214_PQ_PN17"] == 0
# )

# Element-wise check that passes only for even values.
check_even = pa.Check(lambda x: x % 2 == 0, element_wise=True)

# Cross-column rule for a column `s` given G214_PQ_PN17 in the same frame:
#   PN17 == 0   -> s must be -88
#   PN17 == 1   -> s must be one of 0, 1, -99
#   PN17 == -99 -> any value of s is accepted
# NOTE(review): the callable takes two positional arguments (s, df) and uses
# the pandas-style `.isin`; confirm pandera can invoke it in this form and
# that it works against polars data before wiring it into a schema.
check1 = pa.Check(
    lambda s, df: (
        ((df["G214_PQ_PN17"] == 0) & (s == -88)) |
        ((df["G214_PQ_PN17"] == 1) & (s.isin([0, 1, -99]))) |
        (df["G214_PQ_PN17"] == -99)
    ))

def check_pn25_based_on_pn17(df: DataFrame, var: str) -> Series[bool]:
    """Return a row-wise mask of whether ``var`` is consistent with PN17.

    A row passes when one of the following holds:

    * ``G214_PQ_PN17`` is 0 and ``var`` is -88,
    * ``G214_PQ_PN17`` is 1 and ``var`` is one of 0, 1, -99,
    * ``G214_PQ_PN17`` is -99 (any value of ``var`` is accepted).

    Args:
        df: Frame containing ``G214_PQ_PN17`` and the column named by ``var``.
            Uses the pandas-style ``.isin`` accessor.
        var: Name of the dependent column to check.

    Returns:
        Boolean mask with one entry per row of ``df``.
    """
    pn17 = df["G214_PQ_PN17"]
    dependent = df[var]

    pn17_no_and_na = (pn17 == 0) & (dependent == -88)
    pn17_yes_and_answered = (pn17 == 1) & (dependent.isin([0, 1, -99]))
    pn17_missing = pn17.isin([-99])  # Allow -99 for G214_PQ_PN17

    return pn17_no_and_na | pn17_yes_and_answered | pn17_missing

In [None]:
class G214PQDataSchema(pa.DataFrameModel):
    """Validate the recoded G214_PQ back-pain items.

    Column-level rules restrict each item to its allowed codes; the
    frame-level check enforces the skip pattern between the two items.
    """

    # Plain `int` annotations, as in the other schemas in this notebook.
    G214_PQ_PN17: int = pa.Field(isin=(-99, 0, 1), coerce=True)
    G214_PQ_PN25: int = pa.Field(isin=(-88, -99, 0, 1), coerce=True)
    # G214_PQ_PN34: int = pa.Field(coerce=True)
    # G214_PQ_PN35: int = pa.Field(coerce=True)
    # G214_PQ_PN36: int = pa.Field(coerce=True)

    @pa.dataframe_check
    def check_pn25_based_on_pn17(cls, data: pa.PolarsData) -> pl.LazyFrame:
        """Require PN25 to be -88 wherever PN17 is 0.

        Returns a LazyFrame with a single boolean column holding one
        pass/fail value per row, so failures are reported row by row.
        Rows where PN17 is not 0 always pass this check.
        """
        pn17_is_no = pl.col("G214_PQ_PN17") == 0
        pn25_is_na = pl.col("G214_PQ_PN25") == -88
        # Implication "PN17 == 0 -> PN25 == -88" written as (not A) or B.
        return data.lazyframe.select(~pn17_is_no | pn25_is_na)

In [None]:
# Inspect the distinct G214_PQ_PN25 values among rows where G214_PQ_PN17 == 0.
df.filter(pl.col("G214_PQ_PN17") == 0).select("G214_PQ_PN25").unique().to_dict(as_series=False)["G214_PQ_PN25"]

[1m[[0m[1;36m-88.0[0m[1m][0m

In [None]:
# Create df designed to return an error
# Round-trip through pandas so single cells can be overwritten by position.
# Presumably column 0 is G214_PQ_PN17 and column 1 is G214_PQ_PN25 (the order
# used when the columns were selected) — verify before relying on it.
fake_df = df.to_pandas()
fake_df.iloc[0, 0] = 1  # first row, first column
fake_df.iloc[6, 1] = 1  # seventh row, second column
fake_df = pl.from_pandas(fake_df)

In [None]:
# Validate against G214PQDataSchema with lazy=True and print the
# SchemaErrors report if one is raised.
# NOTE(review): this validates `df`, not the `fake_df` that was built to
# trigger a failure — confirm which frame is intended here.
try:
    G214PQDataSchema.validate(df, lazy=True)
except pa.errors.SchemaErrors as err:
    print(err)

## G214_SQ

In [None]:
# Load G214_SQ.sav from data_dir. load_data() is unpacked as a pair; only the
# first element (the data) is kept, the second is discarded.
G214_SQ = Dataset("G214_SQ.sav", data_dir)
df, _ = G214_SQ.load_data()

In [None]:
# Keep only the two back-pain items and recode their raw special codes to
# the negative sentinels (8 -> -88, 9 -> -99).
# Collect afterwards, as the G214_PQ section does, so that schema validation
# runs its data-level checks (e.g. isin) on a concrete DataFrame rather than
# on a LazyFrame.
df = (
    df
    .select(
        pl.col("G214_SQ_PN17").replace({9: -99}),
        pl.col("G214_SQ_PN25").replace({8: -88, 9: -99}),
    )
).collect()

In [None]:
# Schema for the recoded G214_SQ items. Only PN17 is validated for now; the
# remaining columns are parked as commented-out placeholders.
class G214SQDataSchema(pa.DataFrameModel):
    # PN17 is cast to int (coerce=True) and must be one of -99, 0, 1.
    G214_SQ_PN17: int = pa.Field(isin=(-99, 0, 1), coerce=True)
    # NOTE(review): the PN25 placeholder below still lists the raw codes 8/9,
    # whereas the select step recodes them to -88/-99 — update before enabling.
    # G214_SQ_PN25: int = pa.Field(isin=(8, 9, 0, 1), coerce=True)
    # G214_SQ_PN34: int = pa.Field(coerce=True)
    # G214_SQ_PN35: int = pa.Field(coerce=True)
    # G214_SQ_PN36: int = pa.Field(coerce=True)

# Validate and keep the returned frame; any exception is printed rather than
# raised so the notebook keeps running.
# NOTE(review): if `df` is still a LazyFrame at this point, pandera may only
# run schema-level checks — confirm the isin rule is actually exercised.
try:
    df = G214SQDataSchema.validate(df, lazy=True)
except Exception as e:
    print(e)

## G217_PQ

In [None]:
# Load G217_PQ.sav from data_dir. load_data() is unpacked as a pair; only the
# first element (the data) is kept, the second is discarded.
G217_PQ = Dataset("G217_PQ.sav", data_dir)
df, _ = G217_PQ.load_data()

In [None]:
# Keep only the two back-pain items and recode their raw special codes to
# -99 (PN17: 9; PN25: 7 and 9).
# Collect afterwards, as the G214_PQ section does, so that schema validation
# runs its data-level checks (e.g. isin) on a concrete DataFrame rather than
# on a LazyFrame.
df = (
    df
    .select(
        pl.col("G217_PQ_PN17").replace({9: -99}),
        pl.col("G217_PQ_PN25").replace({7: -99, 9: -99}),
    )
).collect()

In [None]:
# Schema for the recoded G217_PQ items. Only PN17 is validated for now; the
# remaining columns are parked as commented-out placeholders.
class G217PQDataSchema(pa.DataFrameModel):
    # PN17 is cast to int (coerce=True) and must be one of -99, 0, 1.
    G217_PQ_PN17: int = pa.Field(isin=(-99, 0, 1), coerce=True)
    # G217_PQ_PN9: int = pa.Field(coerce=True)
    # G217_PQ_PN38: int = pa.Field(coerce=True)
    # NOTE(review): the PN25 placeholder below lists raw codes 8/9, but the
    # select step recodes 7 and 9 to -99 — update before enabling.
    # G217_PQ_PN25: int = pa.Field(isin=(8, 9, 0, 1), coerce=True)
    # G217_PQ_PN34: int = pa.Field(coerce=True)
    # G217_PQ_PN35: int = pa.Field(coerce=True)
    # G217_PQ_PN36: int = pa.Field(coerce=True)

# Validate and keep the returned frame; any exception is printed rather than
# raised so the notebook keeps running.
# NOTE(review): if `df` is still a LazyFrame at this point, pandera may only
# run schema-level checks — confirm the isin rule is actually exercised.
try:
    df = G217PQDataSchema.validate(df, lazy=True)
except Exception as e:
    print(e)

## G217_SQ

In [None]:
# Load G217_SQ.sav from data_dir. load_data() is unpacked as a pair; only the
# first element (the data) is kept, the second is discarded.
G217_SQ = Dataset("G217_SQ.sav", data_dir)
df, _ = G217_SQ.load_data()

In [None]:
# Keep only the two back-pain items and recode the raw code 9 to -99 in both.
# Collect afterwards, as the G214_PQ section does, so that schema validation
# runs its data-level checks (e.g. isin) on a concrete DataFrame rather than
# on a LazyFrame.
df = (
    df
    .select(
        pl.col("G217_SQ_PN17").replace({9: -99}),
        pl.col("G217_SQ_PN25").replace({9: -99}),
    )
).collect()

In [None]:
# Schema for the recoded G217_SQ items. Only PN17 is validated for now; the
# remaining columns are parked as commented-out placeholders.
class G217SQDataSchema(pa.DataFrameModel):
    # PN17 is cast to int (coerce=True) and must be one of -99, 0, 1.
    G217_SQ_PN17: int = pa.Field(isin=(-99, 0, 1), coerce=True)
    # G217_SQ_PN9: int = pa.Field(coerce=True)
    # G217_SQ_PN38: int = pa.Field(coerce=True)
    # NOTE(review): the PN25 placeholder below lists raw codes 8/9, but the
    # select step recodes 9 to -99 — update before enabling.
    # G217_SQ_PN25: int = pa.Field(isin=(8, 9, 0, 1), coerce=True)
    # G217_SQ_PN34: int = pa.Field(coerce=True)
    # G217_SQ_PN35: int = pa.Field(coerce=True)
    # G217_SQ_PN36: int = pa.Field(coerce=True)

# Validate and keep the returned frame; any exception is printed rather than
# raised so the notebook keeps running.
# NOTE(review): if `df` is still a LazyFrame at this point, pandera may only
# run schema-level checks — confirm the isin rule is actually exercised.
try:
    df = G217SQDataSchema.validate(df, lazy=True)
except Exception as e:
    print(e)