In [None]:
%load_ext rich
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path
import polars as pl
import pointblank as pb
import numpy as np
from rich import print as rprint

from odyssey.core import *
from odyssey.explore import *

In [None]:
from config.paths import RAW_DATA

In [None]:
# Point a Dataset at the raw G220 questionnaire file inside RAW_DATA.
# NOTE(review): `Dataset` is not imported by name; presumably it comes from one
# of the `odyssey` star imports — confirm.
g220 = Dataset("G220_Q.sav", RAW_DATA)
# `load_data()` returns a pair: `lf` is what later cells select from and collect,
# `meta` presumably holds the file's metadata — confirm against `Dataset`.
lf, meta = g220.load_data()

In [None]:
# Columns pulled from the questionnaire for the IPAQ checks, in selection order.
cols = [
    "ID",
    # Raw IPAQ answers per activity (`_D`, `_HPD`, `_MPD`, `_W`), grouped by activity.
    *(
        f"G220_IPAQ_{activity}_{item}"
        for activity in ("MOD", "VIG", "WALK")
        for item in ("D", "HPD", "MPD", "W")
    ),
    # Sitting items.
    "G220_IPAQ_SIT_WD_HPD",
    "G220_IPAQ_SIT_WD_MPD",
    "G220_SIT_WD_TRUNC",
    "G220_IPAQ_SIT_COM",
    # Derived MET / minutes columns per activity.
    *(
        f"G220_{activity}_{measure}"
        for activity in ("VIG", "MOD", "WALK")
        for measure in ("MET", "MINS")
    ),
    # Overall category and total MET.
    "G220_IPAQ_CAT",
    "G220_TOT_MET",
]

In [None]:
# Keep only the columns named in `cols` and collect the result into `df`
# (`lf` is presumably a polars LazyFrame — confirm against `Dataset.load_data`).
df = lf.select(cols).collect()

In [None]:
from typing import Callable

In [None]:
def check_total_mins(hpd_column: str, mpd_column: str) -> Callable:
    """
    Build a preprocessing function that recomputes a category's total minutes.

    The returned function adds a `check` column holding
    `hpd_column * 60 + mpd_column`. Totals above 180 are replaced by 180,
    and null totals are kept as null.
    """
    def preprocessor(df: pl.DataFrame) -> pl.DataFrame:
        total_mins = pl.col(hpd_column) * 60 + pl.col(mpd_column)
        # Only totals strictly above the cap are rewritten; everything else
        # (including nulls) passes through the `otherwise` branch untouched.
        capped_mins = pl.when(total_mins > 180).then(180).otherwise(total_mins)
        return df.with_columns(capped_mins.alias("check"))

    return preprocessor

def check_met(
    mins_column: str, 
    n_days_column: str, 
    met_column: str,
    factor: int|float # the corresponding factor for the activity (Vig: 8, Mod: 4, Walk: 3.3)
    ) -> Callable:
    """Returns a preprocessing function to verify the calculated MET value for a given category."""
    def preprocessor(df: pl.DataFrame) -> pl.DataFrame:
        return df.with_columns(
            (pl.col(mins_column).fill_null(0) * pl.col(n_days_column).fill_null(0) * factor).alias("check"),
            pl.col(met_column).fill_null(0)
        )
    return preprocessor

def check_tot_met(
    met_columns: list[str],
    tot_met_column: str
    ) -> Callable:
    """
    Returns a preprocessing function to verify the calculated total MET value.

    The returned function adds a `check` column holding the row-wise sum of
    `met_columns` (nulls counted as 0) and replaces nulls in `tot_met_column`
    with 0.

    Raises:
        ValueError: if `met_columns` is empty. Without this guard the built-in
            `sum()` would return the plain integer 0, and the failure would only
            surface later as an `AttributeError` on `.alias` during interrogation.
    """
    # Snapshot the names once: this lets the emptiness check work for any iterable
    # and keeps the preprocessor reusable even if a one-shot iterator was passed.
    component_columns = tuple(met_columns)
    if not component_columns:
        raise ValueError("`met_columns` must contain at least one column name")

    def preprocessor(df: pl.DataFrame) -> pl.DataFrame:
        expr = sum(pl.col(col).fill_null(0) for col in component_columns)

        return df.with_columns(
            expr.alias("check"),
            pl.col(tot_met_column).fill_null(0) # Fill nulls with 0; otherwise the validation skips if one value in a comparison is null
        )
    return preprocessor

In [None]:
# Recompute each derived column from its inputs and require an exact match.
# Each step's `pre` hook adds a `check` column that the target column is compared to.
validation = (
    pb.Validate(data=df)
    # Minutes per day for each activity: `HPD*60 + MPD`, capped at 180.
    .col_vals_eq(
        columns="G220_VIG_MINS",
        pre=check_total_mins("G220_IPAQ_VIG_HPD", "G220_IPAQ_VIG_MPD"),
        value=pb.col("check"),
        brief="Check total mins/day equals `HPD*60 + MPD`",
    )
    .col_vals_eq(
        columns="G220_MOD_MINS",
        pre=check_total_mins("G220_IPAQ_MOD_HPD", "G220_IPAQ_MOD_MPD"),
        value=pb.col("check"),
        brief="Check total mins/day equals `HPD*60 + MPD`",
    )
    .col_vals_eq(
        columns="G220_WALK_MINS",
        pre=check_total_mins("G220_IPAQ_WALK_HPD", "G220_IPAQ_WALK_MPD"),
        value=pb.col("check"),
        brief="Check total mins/day equals `HPD*60 + MPD`",
    )
    # MET for each activity: mins/day * days * factor, with nulls counted as 0 on both sides.
    .col_vals_eq(
        columns="G220_VIG_MET",
        pre=check_met("G220_VIG_MINS", "G220_IPAQ_VIG_D", "G220_VIG_MET", factor=8),
        value=pb.col("check"),
    )
    .col_vals_eq(
        columns="G220_MOD_MET",
        pre=check_met("G220_MOD_MINS", "G220_IPAQ_MOD_D", "G220_MOD_MET", factor=4),
        value=pb.col("check"),
    )
    .col_vals_eq(
        columns="G220_WALK_MET",
        pre=check_met("G220_WALK_MINS", "G220_IPAQ_WALK_D", "G220_WALK_MET", factor=3.3),
        value=pb.col("check"),
    )
    # Total MET: the three per-activity MET columns added together.
    .col_vals_eq(
        columns="G220_TOT_MET",
        pre=check_tot_met(["G220_VIG_MET", "G220_MOD_MET", "G220_WALK_MET"], "G220_TOT_MET"),
        value=pb.col("check"),
        brief="Check `TOT_MET` equals the sum of `VIG_MET`, `MOD_MET`, and `WALK_MET`",
    )
    .interrogate()
)

# Show the validation report as the cell output.
validation

In [None]:
# Consistency checks between the `VIG_W` / `VIG_D` answers and the derived VIG columns.
validation = (
    pb.Validate(data=df, brief="Step {step}: {auto}")
    # Where `VIG_W` was answered (0 or 1), `VIG_MINS` must be between 0 and 180
    # and `VIG_MET` must be >= 0. Skipped (null) answers are left out of these segments.
    .col_vals_between(
        columns=pb.matches("VIG_MINS"),
        segments=("G220_IPAQ_VIG_W", [0, 1]),
        left=0,
        right=180,
    )
    .col_vals_ge(
        columns=pb.matches("VIG_MET"),
        segments=("G220_IPAQ_VIG_W", [0, 1]),
        value=0,
    )
    # Where `VIG_W` is null, `VIG_MINS` and `VIG_MET` must be null as well.
    # Pointblank doesn't seem to like segmenting on null values, so nulls are
    # mapped to -1 first and the -1 segment is checked instead.
    # TODO: simplify so it just accepts and segments by null values, rather than filling null with -1
    .col_vals_null(
        columns=pb.matches("VIG_(MINS|MET)"),
        segments=("G220_IPAQ_VIG_W", -1),
        pre=lambda df: df.with_columns(pl.col("G220_IPAQ_VIG_W").fill_null(-1)),
    )
    # Where `VIG_D` is 0, `VIG_MET` must be exactly 0.
    .col_vals_eq(
        columns=pb.matches("VIG_MET"),
        segments=("G220_IPAQ_VIG_D", [0]),
        value=0,
    )
    .interrogate()
)

# Show the validation report as the cell output.
validation