In [None]:
# Imports
import polars as pl
from pathlib import Path
from loguru import logger
import numpy as np
from dataclasses import dataclass

# Generic Helpers

Some generic helpers to simplify the code below

In [None]:
def on_unwrapped(df: pl.DataFrame, func, unwrap_cols: list) -> pl.DataFrame:
    """
    Explode list columns, run a function on the flat frame, then re-nest the result.

    Only the columns named in ``unwrap_cols`` are collected back into lists and
    written to the returned frame. Any other column that ``func`` adds or changes
    on the exploded frame is not carried over.

    Args:
        df  : DataFrame holding the list columns
        func : Callable that takes the exploded DataFrame and returns a DataFrame
        unwrap_cols : Names of the list columns to explode and re-nest

    Returns:
        ``df`` with the ``unwrap_cols`` columns updated from the re-nested output
        of ``func``.
    """
    row_key = "remap_index"

    # Tag every row so the exploded entries can be routed back to their origin
    indexed = df.with_row_index(row_key)

    flat = func(indexed.explode(*unwrap_cols))

    # Collect the exploded entries of each original row back into lists
    nested = flat.group_by(row_key).agg(*unwrap_cols)

    merged = indexed.update(nested, on=row_key)
    return merged.drop(row_key)

# Aligning Sequence Numbers

Our data consists of multiple receivers collecting CSI harvested from the exact same signal, split from one antenna using an RF-waveguide-splitter. The cards do not always see the same thing, though. For a direct 1-1 comparison between values, we might only want to consider data points that all devices have captured.

Unmodified frames are sent with increasing sequence numbers, which serve as unique identifiers of frames. Hence, by comparing sequence numbers, we can figure out which packets were indeed received by all receivers.

In [None]:
def align_sequence_nums(df: pl.DataFrame) -> pl.DataFrame:
    """
    Keep only the frames that every receiver of a session has captured.

    Sequence numbers enumerate the packets sent to invoke CSI collection. A
    (session_id, sequence_number) pair is kept when its count of receiver entries
    equals the number of distinct receivers seen in that session; all other rows
    are dropped.

    The returned frame additionally carries the helper column ``n_receivers``.
    """
    logger.info("Aligning sequence numbers ...")

    # Number of distinct receivers that took part in each session
    receivers_per_session = df.group_by("session_id", maintain_order=True).agg(
        n_receivers=pl.col("receiver_name").n_unique()
    )
    annotated = receivers_per_session.join(df, on="session_id")

    # Frames for which the capture count matches the session's receiver count
    frame_keys = ["sequence_number", "session_id"]
    complete_frames = (
        annotated.group_by(frame_keys, maintain_order=True)
        .agg(
            sns_count=pl.count("receiver_name"),
            n_receivers=pl.col("n_receivers").first(),
        )
        .filter(pl.col("n_receivers") == pl.col("sns_count"))
    )

    return annotated.join(complete_frames, on=frame_keys, how="semi")

# Discard non-precoded frames

We precode frames to emulate channel changes. To have a baseline, we often also use "training" or "warmup" runs, in which we do send the unmodified frame. For some postprocessing, these frames are irrelevant and can be dropped.

In [None]:
def discard_unmodified(df: pl.DataFrame) -> pl.DataFrame:
    """
    Drop all captures of unmodified ("warmup") frames.

    Before sending curve precoded frames, a sequence of unmodified frames is sent.
    Those captures, taken without any emulated movement or obstruction, could
    later serve for normalization (e.g. of scalings), but are not needed for the
    visualization at hand. Only rows whose ``mask_name`` differs from
    "_unmodified" are kept.
    """
    logger.info("Discarding unmodified frame captures ...")

    unmodified_mask: str = "_unmodified"
    is_precoded = pl.col("mask_name") != unmodified_mask
    return df.filter(is_precoded)

# Merging session pairs

One issue in our setup is that two of the cards are incompatible (iwl5300 and qca). More precisely, the frames which the corresponding tools require to yield CSI are incompatible. Hence, we can not collect CSI from these cards concurrently.

In our experiments, we run sessions including either of those cards back-to-back.

Since the remaining cards can always capture (and we make use of that), this creates an imbalance in the amount of data available per receiver. In session pair merging, we try to "merge" the data from two back-to-back sessions, pretending as if they were from one.

**NOTE**: This only works as long as no session is missing. Sessions have identifiers but are not enumerated, so when one of the pair sessions is missing, we do not detect that.

In [None]:
def merge_session_pairs(df: pl.DataFrame) -> pl.DataFrame:
    """
    Merge back-to-back session pairs to balance the data across receivers.

    qca and iwl5300 cannot capture concurrently, while ax210 and asus always can.
    Sessions are therefore run in pairs, one for iwl and one for qca, which leaves
    the always-capturing cards overrepresented.

    Sessions are paired in order of first appearance (1st with 2nd, 3rd with 4th,
    ...). From each pair, all rows of the first session are kept, and from the
    second session only the rows whose ``receiver_name`` is "qca".

    **NOTE**: This assumes that the data really consists of such pairs!
    If any session is missing, the pairing is shifted and the result is invalid.
    """
    logger.info("Balancing data by merging pairs of captures...")

    # Enumerate sessions by first appearance; two consecutive ones share an index
    pair_lookup = (
        df.select("session_id")
        .unique("session_id", maintain_order=True)
        .with_row_index("num_sess_pair")
        .with_columns(num_sess_pair=pl.col("num_sess_pair") // 2)
    )
    paired = df.join(pair_lookup, on="session_id")

    def session_merger(pair_rows):
        # Session ids of this pair, in order of appearance
        session_ids = (
            pair_rows.get_column("session_id").unique(maintain_order=True).to_list()
        )
        assert len(session_ids) == 2, "Session merger must be fed pairs of sessions!"
        first_id = session_ids[0]
        second_id = session_ids[1]

        # Whole first session, plus only the qca captures of the following one
        from_first = pair_rows.filter(pl.col("session_id") == first_id)
        qca_from_second = pair_rows.filter(
            (pl.col("session_id") == second_id) & (pl.col("receiver_name") == "qca")
        )
        return pl.concat([from_first, qca_from_second])

    merged = paired.group_by("num_sess_pair", maintain_order=True).map_groups(
        session_merger
    )
    return merged.drop("num_sess_pair")

# Discarding low counts

Sometimes the cards perform subpar. When cards drop lots of frames, the data is expected to look pretty bad, especially in terms of temporal coherence. It is thus usually desirable to drop such data.

In some experiments, e.g. the curve experiment, we also want to drop data if, possibly as a result of dropping low session counts, not enough sessions are present to represent a precoding curve.

In [None]:
def discard_low_counts(df: pl.DataFrame, min_count: int = 600, grouped: bool = False):
    """
    Discard all sessions in which the number of CSI captured is below threshold.

    A session is discarded as soon as any of its receivers has fewer than
    ``min_count`` distinct sequence numbers in that session.

    Args:
        df        : DataFrame with at least the columns session_id,
            sequence_number and receiver_name
        min_count : The threshold minimum count of CSI values
        grouped   : Whether to discard in neighboring pairs of two. Should be done only
            when groups are present and haven't been merged. Sessions are paired in
            order of first appearance in ``df`` (1st with 2nd, 3rd with 4th, ...);
            if one session of a pair is below threshold, both are discarded.

    Returns:
        ``df`` without the rows of the discarded sessions.
    """
    logger.info(f"Discarding sessions with low  (<{min_count}) CSI capture counts ...")

    # Assumes sequence numbers were aligned; This way we can find the captured
    # frame counts per session by considering unique SNs per session
    sess_counts = (
        df.unique(
            ["session_id", "sequence_number", "receiver_name"], maintain_order=True
        )
        .group_by("session_id", "receiver_name", maintain_order=True)
        .len()
    )

    # Sessions in which at least one receiver stayed below the threshold
    low_sessions = (
        sess_counts.filter(pl.col("len") < min_count)
        .select("session_id")
        .unique("session_id", maintain_order=True)
    )

    if grouped:
        # Enumerate ALL sessions (not only the low ones), so that the pair index
        # reflects the real capture order of back-to-back sessions (iwl and qca)
        session_pairs = (
            df.select("session_id")
            .unique("session_id", maintain_order=True)
            .with_row_index("session_idx")
            .with_columns(pair_idx=pl.col("session_idx") // 2)
        )

        # Pairs that contain at least one low-count session ...
        bad_pairs = (
            session_pairs.join(low_sessions, on="session_id", how="semi")
            .select("pair_idx")
            .unique()
        )

        # ... are discarded as a whole, i.e. including the partner session
        to_discard = session_pairs.join(bad_pairs, on="pair_idx", how="semi").select(
            "session_id"
        )
    else:
        to_discard = low_sessions

    logger.info(
        "Discarding sessions with insufficient capture rates : \n"
        + f"{to_discard.get_column('session_id').to_list()}"
    )

    return df.join(to_discard, on="session_id", how="anti")


def discard_low_session_counts(df: pl.DataFrame, min_num_sessions: int = 10):
    """
    Drop curves that are backed by too few sessions.

    In our experiments, curves are used to precode WiFi frames. One transmission
    of a curve precoded frame sequence is a session, and each curve is sent in
    multiple sessions.

    Sessions are counted per (num_curve, receiver_name). A curve is removed
    completely, i.e. for all receivers, as soon as one of its receivers has
    fewer than ``min_num_sessions`` sessions.

    Args:
        min_num_sessions: Minimum number of sessions to ensure are present for each
            curve.
    """
    logger.info(
        f"Discarding data for curves with less than {min_num_sessions} sessions ..."
    )

    # One row per (session, curve, receiver) combination present in the data
    session_rows = df.unique(
        ["session_id", "num_curve", "receiver_name"], maintain_order=True
    )
    session_tally = session_rows.group_by(
        "num_curve", "receiver_name", maintain_order=True
    ).len()

    curves_to_remove = session_tally.filter(pl.col("len") < min_num_sessions).drop(
        "len"
    )

    removed_curves = curves_to_remove.get_column("num_curve").to_list()
    logger.info("Discarding curves: \n" + f"{removed_curves}")

    return df.join(curves_to_remove, on="num_curve", how="anti")

# Removing Edge subcarriers

In Nexmon, specifically the rt-ac86u asus receiver, one of the subcarriers is broken, i.e. never yields sensible data. To still work with a balanced dataset of equal shape, we may want to discard the corresponding edge subcarriers (28 and -28) for all receivers.

In [None]:
def remove_edge_subcs(df: pl.DataFrame) -> pl.DataFrame:
    """
    Remove the outermost subcarriers from the nested CSI columns.

    Only entries whose subcarrier index lies strictly between -28 and 28 are kept
    in ``subcarrier_idxs``, ``csi_abs`` and ``csi_phase``. All other columns are
    left untouched.
    """
    logger.info(
        "Removing edge subcarriers for aligned data (asus has a broken one) ..."
    )
    nested_cols = ["subcarrier_idxs", "csi_abs", "csi_phase"]

    subc = pl.col("subcarrier_idxs")
    is_inner = (subc > -28) & (subc < 28)

    # Flatten the lists, drop the edge entries, then rebuild one list per row
    flat = df.with_row_index().explode(*nested_cols)
    trimmed = (
        flat.filter(is_inner)
        .group_by("index", maintain_order=True)
        .agg(nested_cols)
        .drop("index")
    )

    # Replace the nested columns in the original frame by their trimmed versions
    return df.with_columns(trimmed.get_columns())

# AGC correction

The intel iwl5300 tool yields a special value for AGC correction. Currently, it is the only card to do so. Judging from experiments with precoding curves and visual comparison, however, the iwl5300 is the only card to not include an AGC corrected RSSI.

The following method incorporates the iwl5300 AGC value into a corrected RSSI.

In [None]:
def agc_correct_rssi(df: pl.DataFrame) -> pl.DataFrame:
    """
    Apply an AGC based correction to the RSSI columns where an AGC value exists.

    For every ``capture_num``, ``antenna_rssi`` and ``rssi`` are converted to
    linear scale and summed. In rows with a non-null ``agc``, both columns are
    replaced by ``10 * log10(sum) - 44 - agc``; rows without ``agc`` keep their
    values.

    NOTE: Input must be the original (still wrapped) dataframe
    """
    logger.info("Performing AGC RSSI correction ...")

    # dB -> linear scale, so that values of one capture can be added up
    linear = df.with_columns(
        antenna_rssi_db=10 ** (pl.col("antenna_rssi") / 10),
        rssi_db=10 ** (pl.col("rssi") / 10),
    )
    per_capture = linear.group_by(["capture_num"], maintain_order=True).agg(
        combined_antenna_rssi=pl.col("antenna_rssi_db").sum(),
        combined_rssi=pl.col("rssi_db").sum(),
    )
    df = df.join(per_capture, on="capture_num")

    has_agc = pl.col("agc").is_not_null()

    def corrected(combined_col):
        # Back to dB, then subtract a constant offset and the AGC value
        # NOTE(review): 44 is presumably the iwl5300 specific offset - confirm
        return (10 * pl.col(combined_col).log(base=10)) - 44 - pl.col("agc")

    return df.with_columns(
        antenna_rssi=pl.when(has_agc)
        .then(corrected("combined_antenna_rssi"))
        .otherwise(pl.col("antenna_rssi")),
        rssi=pl.when(has_agc)
        .then(corrected("combined_rssi"))
        .otherwise(pl.col("rssi")),
    ).drop("combined_rssi", "combined_antenna_rssi")

# Phase Detrending

Sampling Time offsets introduce a random-offset linear-slope phase shift across subcarriers. To rephrase this, the very first phase value on a subcarrier is completely uniformly distributed. To make phases even comparable, we need some phase sanitization procedures.

The simplest one is to fix two reference values to get rid of the introduced linear offsets. Specifically, by fixing the outermost subcarriers to have phases $0$, we can linearly detrend them.

In [None]:
def detrend_phase(df: pl.DataFrame) -> pl.DataFrame:
    """
    Linear phase correction to get rid of the STO effect.

    For every row, ``csi_phase`` is unwrapped and a straight line through the
    phase values of the first and the last subcarrier is subtracted, so that both
    outermost subcarriers end up with phase 0.

    The result is attached as the new list column ``normed_csi_phase``;
    ``csi_phase`` itself is left untouched.

    Args:
        df : DataFrame with the list columns ``csi_phase`` and ``subcarrier_idxs``,
            holding equally many entries per row

    Returns:
        ``df`` with the additional column ``normed_csi_phase``.
    """
    logger.trace(
        "Detrending phase by fitting line through first and last phase values "
        + "and subtracting that."
    )

    # Assign row index to allow matching from derived dataframe back to here
    df = df.with_row_index("phase_row_idx")

    first_phase = pl.col("unwrapped_phase").list.first()
    last_phase = pl.col("unwrapped_phase").list.last()
    first_sc = pl.col("subcarrier_idxs").list.first()
    last_sc = pl.col("subcarrier_idxs").list.last()

    # Unwrap phases and derive the per-row parameters of the line to subtract
    phase_helpers = (
        df.select(
            "phase_row_idx",
            "subcarrier_idxs",
            unwrapped_phase=pl.col("csi_phase").map_elements(
                lambda x: np.unwrap(np.array(x)).tolist(),
                return_dtype=pl.List(pl.Float64),
            ),
        )
        .with_columns(
            sc_offset=first_phase,
            lowest_sc=first_sc,
            phase_rise=last_phase - first_phase,
            sc_span=last_sc - first_sc,
        )
        .with_columns(
            # The slope refers to the subcarrier index span, not to the number of
            # subcarriers; only then does the last subcarrier land exactly on 0,
            # also when indices are not contiguous. A single subcarrier has no slope.
            phase_slope=pl.when(pl.col("sc_span") != 0)
            .then(pl.col("phase_rise") / pl.col("sc_span"))
            .otherwise(0.0)
        )
    )

    # To correct linear trend: Shift down by first value, then remove linear slope
    # in dependence of subcarrier number
    lin_corr_expr = (
        pl.col("unwrapped_phase")
        - pl.col("sc_offset")
        - ((pl.col("subcarrier_idxs") - pl.col("lowest_sc")) * pl.col("phase_slope"))
    )

    # Apply the correction per subcarrier on the exploded frame and re-nest it
    detrended = (
        phase_helpers.explode("unwrapped_phase", "subcarrier_idxs")
        .with_columns(normed_csi_phase=lin_corr_expr)
        .group_by("phase_row_idx", maintain_order=True)
        .agg("normed_csi_phase")
    )

    # Match back: a join is required here, since the result is a new column
    return df.join(detrended, on="phase_row_idx", how="left").drop("phase_row_idx")

# Subselection

Sharing of gigabyte sized databases is not always needed. Sometimes we just want to take the first $n$ of something, e.g. the first $5$ sessions.

In [None]:
def first_n_curves(df: pl.DataFrame, num_curves: int) -> pl.DataFrame:
    """
    Restrict the data to the first ``num_curves`` curves.

    "First" refers to the order in which distinct ``num_curve`` values appear in
    ``df``.
    """
    logger.info(f"Extracting data of the first {num_curves} curves present.")

    distinct_curves = df.select("num_curve").unique("num_curve", maintain_order=True)
    leading_curves = distinct_curves.head(num_curves)
    return df.join(leading_curves, on="num_curve")


def first_n_sessions(df: pl.DataFrame, num_sessions: int) -> pl.DataFrame:
    """
    Restrict the data to the first ``num_sessions`` sessions of every curve.

    The selection is done separately for each (num_curve, receiver_name) group;
    "first" refers to the order in which session ids appear within the group.
    """
    logger.info(f"Extracting first {num_sessions} sessions for each curve")

    def keep_leading_sessions(curve_rows: pl.DataFrame) -> pl.DataFrame:
        leading_ids = (
            curve_rows.select("session_id")
            .unique("session_id", maintain_order=True)
            .head(num_sessions)
        )
        return curve_rows.join(leading_ids, on="session_id", how="semi")

    per_curve = df.group_by("num_curve", "receiver_name", maintain_order=True)
    return per_curve.map_groups(keep_leading_sessions)

# Data shape alignment

One of the consequences of cards dropping frames is that unequal shapes of DataFrames/arrays occur.

Consider for example a session with $1000$ packets sent. If two cards drop $10$ and $100$ values, respectively, the resulting
shapes for those sessions would be $(990, n_{subcs})$ and $(900, n_{subcs})$. If we want to join these in a numpy array, we will run into errors.

Similar to the sequence number alignment above, we can use sequence numbers to figure out which values are missing and simply insert Nulls. Follow-up routines can then decide whether to get rid of nulls by subsampling or maybe performing interpolation/imputation.

In [None]:
def insert_nulls(df: pl.DataFrame, sequence_len: int) -> pl.DataFrame:
    """
    Pad every capture group to a full sequence by adding rows for missing frames.

    Groups are formed per (session_id, receiver_name, mask_name). Each group is
    left-joined onto the sequence numbers 0 .. sequence_len - 1, so that sequence
    numbers without a capture yield a row of nulls. Groups whose mask is
    "_unmodified" are passed through without padding.

    In the padded rows, session_id, num_curve, receiver_name, subcarrier_idxs and
    mask_name are restored from neighboring rows, while csi_abs and csi_phase
    receive a list of nulls with as many entries as the group's first row has
    subcarriers. The remaining columns are not filled by this function.

    Args:
        df : DataFrame of captures
        sequence_len : Number of sequence numbers every group is padded to
    """
    logger.info("Inserting Nulls for missing sequence number values")

    full_sequence = pl.DataFrame(
        {"sequence_number": range(0, sequence_len)},
        schema={"sequence_number": pl.UInt16},
    )

    # Metadata is known even for missing frames, hence filled back in immediately
    metadata_cols = (
        "session_id",
        "num_curve",
        "receiver_name",
        "subcarrier_idxs",
        "mask_name",
    )
    csi_cols = ("csi_abs", "csi_phase")

    def pad_group(rows):
        if rows.item(0, "mask_name") == "_unmodified":
            # Only match the column order of the padded groups
            return rows.select("sequence_number", pl.exclude("sequence_number"))

        empty_csi = [None] * len(rows.item(0, "subcarrier_idxs"))
        padded = full_sequence.join(rows, on="sequence_number", how="left")

        restored = [
            padded[name].fill_null(strategy="forward").fill_null(strategy="backward")
            for name in metadata_cols
        ]
        restored.extend(padded[name].fill_null(value=empty_csi) for name in csi_cols)

        return padded.with_columns(restored)

    # Fill!
    return df.group_by(
        "session_id", "receiver_name", "mask_name", maintain_order=True
    ).map_groups(pad_group)

# Linear interpolation / imputation

The simplest method, which can be done in polars natively, is linear interpolation. Note that this doesn't work at the edges: it linearly interpolates between two neighboring valid values only, not over larger subsets. On the edges, we therefore simply extend the last valid value.

This method is only suitable for very clean datasets in which only few values at a time are missing.

In [None]:
def interpolate(df: pl.DataFrame) -> pl.DataFrame:
    """
    Fill missing (Null) values by linear interpolation, extending constantly at
    the edges.

    The nested columns are exploded, and each series given by
    (subcarrier_idxs, receiver_name, session_id) is interpolated along its rows.
    Nulls remaining at the start or end of a series are filled with the nearest
    valid value.

    NOTE(review): timestamp and antenna_rssi are interpolated on the exploded
    frame as well, but on_unwrapped only writes csi_abs, csi_phase and
    subcarrier_idxs back, so those two results do not reach the returned frame.
    Confirm whether that is intended.

    TODO: Implement further methods, such as Gaussian Process Regression:
    https://scikit-learn.org/stable/modules/generated/sklearn.gaussian_process.GaussianProcessRegressor.html#sklearn.gaussian_process.GaussianProcessRegressor
    """
    logger.info("Interpolating missing (Null) values")

    fill_cols = ["timestamp", "csi_abs", "csi_phase", "antenna_rssi"]
    series_keys = ["subcarrier_idxs", "receiver_name", "session_id"]

    def fill_series(series_rows):
        # Linear fit between valid values, then constant extension at both ends
        filled = (
            series_rows.select(fill_cols)
            .interpolate()
            .fill_null(strategy="forward")
            .fill_null(strategy="backward")
        )
        return series_rows.with_columns(filled)

    def interpolator(unwrapped_df):
        return unwrapped_df.group_by(series_keys, maintain_order=True).map_groups(
            fill_series
        )

    return on_unwrapped(df, interpolator, ["csi_abs", "csi_phase", "subcarrier_idxs"])

# Assembly of preprocessing

Finally, let's assemble all the preprocessing steps

In [None]:
@dataclass(frozen=True)
class PreprocessingConfig:
    """
    Switches and parameters of the preprocessing pipeline.

    All attributes are dataclass fields and can hence be set through the
    constructor, e.g. ``PreprocessingConfig(num_curves=5)``.
    """

    # Drop sessions with few captures and curves with few sessions
    discard_low_counts: bool = False
    # Drop captures of unmodified ("warmup") frames
    discard_unmodified: bool = False
    # Keep only frames captured by all receivers of a session
    align_sequence_nums: bool = False
    # Remove the outermost subcarriers from the nested CSI columns
    remove_edge_subcs: bool = True
    # Merge back-to-back session pairs
    merge_session_pairs: bool = True
    # Apply the AGC based RSSI correction
    correct_agc: bool = True
    squash_antennas: bool = True  # Only use when just one antenna used!
    # Insert null rows for missing sequence numbers
    insert_nulls: bool = True
    # Interpolate null values
    impute_missing: bool = True
    # Remove the linear phase trend across subcarriers
    detrend_phase: bool = False
    # Number of sequence numbers each session is padded to
    sequence_len: int = 4000
    # Minimum capture count of a session (used when discard_low_counts is set)
    low_count_threshold: int = 1900
    # Number of curves to keep
    num_curves: int = 20
    # Number of sessions to keep per curve; also the minimum number of sessions
    # a curve needs when discard_low_counts is set
    num_sessions: int = 10


def preprocess(
    df: "pl.DataFrame | pl.LazyFrame", config: PreprocessingConfig
) -> pl.DataFrame:
    """
    Run the preprocessing steps selected in ``config``.

    The raw dataframe contains a lot of duplicate junk data. That part still needs
    a rewrite for deduplication; until then, the junk columns are simply dropped.

    Steps, in order of execution (most of them optional, see PreprocessingConfig):
    - Discard captures of unmodified frames
    - Align sequence numbers
    - Unnest the per-antenna CSI and extract the curve number from the session name
    - Remove unnecessary columns
    - Balance out data by merging session pairs (ax210, asus are overrepresented)
    - Remove captures of insufficient size and curves with insufficient captures
    - Extract the first n sessions and curves
    - Perform corrections (edge subcarriers, AGC)
    - Insert "missing" data with Nulls (based on missing sequence numbers)
    - Interpolate the inserted Nulls and detrend the phase

    Args:
        df : Raw captures, either as eager DataFrame or as LazyFrame
        config : Selection and parameters of the preprocessing steps

    Returns:
        The preprocessed (eager) DataFrame.
    """
    # Columns without use for later steps. Only list columns that are always
    # present, since dropping a missing column is an error in recent polars.
    unused_columns = [
        "experiment_start_time",
        "mask_id",
        "bandwidth",
        "tx_gain",
        "channel",
        "session_timestamp",
        "frame_id",
        "experiment_id",
        "session_name",
    ]

    # The first stages run lazily; this makes eager input work just as well
    df = df.lazy()

    if config.discard_unmodified:
        df = discard_unmodified(df)

    if config.align_sequence_nums:
        df = align_sequence_nums(df)
        # Helper column that exists only after the alignment
        unused_columns.append("n_receivers")

    # NOTE: csi_abs and csi_phase are nested for each of the receive antennas
    # Here we unwrap this nesting into separate CSI per antenna
    df = df.with_row_index("capture_num").explode(
        "used_antenna_idxs", "csi_abs", "csi_phase", "antenna_rssi"
    )

    # We store labels in the session name. Extract them into separate columns
    df = df.with_columns(
        num_curve=pl.col("session_name")
        .str.extract("^.*shape_([0-9]+)")
        .cast(pl.UInt8),
    )

    df = df.drop(unused_columns)
    df = df.collect()

    # Either merge pairs of iwl/qca captures, or treat as wanting double the
    # amount of sessions
    if config.merge_session_pairs:
        df = merge_session_pairs(df)

    # Discarding low counts and then curves with insufficient session numbers
    if config.discard_low_counts:
        df = discard_low_counts(df, config.low_count_threshold, grouped=False)
        df = discard_low_session_counts(df, config.num_sessions)

    # Extracting first n sessions and curves
    df = first_n_sessions(df, config.num_sessions)
    df = first_n_curves(df, config.num_curves)

    if config.remove_edge_subcs:
        df = remove_edge_subcs(df)

    if config.correct_agc:
        df = agc_correct_rssi(df)
        df = df.drop("agc")

    if config.squash_antennas:
        df = df.drop("rssi", "used_antenna_idxs")

    if config.insert_nulls:
        df = insert_nulls(df, config.sequence_len)

    # Without unmodified frames, mask_name is of no further use. It can only be
    # dropped at this point, though, since insert_nulls groups by it.
    if config.discard_unmodified:
        df = df.drop("mask_name")

    if config.impute_missing:
        df = interpolate(df)

    if config.detrend_phase:
        df = detrend_phase(df)

    # Drop temporary capture num row index
    df = df.drop("capture_num")
    return df

In [None]:
# Directory of the experiment data and the parquet file that receives the
# preprocessed result
exp_dir = Path.cwd().parent / "data" / "random_dyncurves_diffspeeds"
tmp_db_file = exp_dir / "damped_preprocessed.parquet"

# Scan the raw capture database; preprocess() collects it into a DataFrame.
# NOTE(review): there is no check for an already existing output file, so the
# preprocessing is rerun (and the file rewritten) on every execution.
df = pl.scan_parquet(exp_dir / "damped.parquet")

# Run the pipeline with the default settings
config = PreprocessingConfig()
df = preprocess(df, config)

# NOTE: 700 CSI per session, 5 receiver, 10 curves and 20 sessions = 700_000 rows
# NOTE(review): the PreprocessingConfig defaults are 20 curves and 10 sessions,
# and insert_nulls pads sessions to sequence_len - confirm the expected shape.
logger.info(f"FINISHED PREPROCESSING. Shape of preprocessed dataframe: {df.shape}")
logger.info(f"Number of sessions in dataframe: {df.n_unique('session_id')}")
logger.info(f"Number of curves in dataframe: {df.n_unique('num_curve')}")
print(df.columns)

# Persist the preprocessed data
df.write_parquet(tmp_db_file)