In [None]:
from sensession.campaign import CampaignProcessor

import polars as pl
import numpy as np

from pathlib import Path
from IPython.display import display

# Data loading

Jupyter was a bit finicky when handling the large amounts of data (read: RAM) required for preprocessing. Instead, we have a script in `evaluation/preprocess_single_scs.py` that performs the necessary preprocessing steps and stores checkpoints of the resulting data on disk.

In this notebook, we simply load the data from the checkpoints and visualize the results. In other words, the preprocessing must be run on the recorded data before running this notebook.

## Zeroth preprocessing: "Cleaning"

The first steps only involve discarding invalid subcarriers and dropping values irrelevant for this notebook to save some memory.

In [None]:
# One boolean receiver-selection expression per receiver under evaluation.
all_receivers_expr = [
    pl.col("receiver_name") == receiver
    for receiver in (
        "asus1",
        "asus2",
        "qca",
        "ESP32_S3",
        "ax210",
        "iwl5300",
        "x310",
    )
]

# Lazy campaign processor reading checkpoints from <repo>/../../.cache/...,
# relative to the notebook's working directory.
proc = CampaignProcessor(
    meta_attach_cols={"modified_idx", "session_nr"},
    checkpoint_dir=Path.cwd().joinpath(
        "..", "..", ".cache", "checkpoints", "single_phases"
    ),
    lazy=True,
)


# NOTE(review): unlike the later checkpoints, no .meta_attach("session_nr") is
# applied here — presumably covered by meta_attach_cols above; confirm.
proc = proc.load_checkpoint("00_subc_cleaned")

# Preview the first rows of the metadata and CSI tables with all columns shown.
with pl.Config(tbl_cols=-1):
    display(proc.meta.head().collect())
    display(proc.csi.head().collect())

## Raw Amplitudes

Next, we take a look at the raw amplitudes during warmup sessions.

Warmup sessions consist of transmitting the exact same unmodified WiFi frame $500$ times in a row, i.e. a setting in which we expect flat CSI.

In [None]:
# Scatter the raw amplitudes of the warmup collections (session numbers below 2).
proc.csi.filter(
    pl.col("session_nr") < 2,
    pl.col("collection_name").str.contains("warmup"),
).select(
    ["subcarrier_idxs", "csi_abs", "collection_name", "receiver_name"]
).collect().plot.scatter(
    x="subcarrier_idxs",
    y="csi_abs",
    by="collection_name",
    groupby="receiver_name",
)

## Correlation

We next check how the CSI magnitudes of the warmup sessions are correlated. For each subcarrier $i$, we interpret the corresponding CSI amplitude $A_i$ as a random variable. We then compute the Pearson correlation coefficients between all pairs of them:

$$
\begin{align}
\rho_{i, j} = \frac{
        \mathbb{E}\left[(A_i - \mu_i)(A_j - \mu_j)\right]
    }{
        \sigma_i \sigma_j
    }
\end{align}
$$

where $\mu_i$ and $\sigma_i$ are the mean value and standard deviation of $A_i$, respectively.

The idea is to find out how the "natural fluctuations" are correlated. Does an increase in amplitude in subcarrier $i$ imply an increase in amplitude in subcarrier $j$?

In [None]:
def _warmup_subcarrier_correlations(
    df: pl.LazyFrame, session_nr: int = 0
) -> pl.DataFrame:
    """Correlate warmup CSI amplitudes between all subcarrier pairs, per receiver.

    Args:
        df: Lazy long-format CSI frame with (at least) the columns
            ``collection_name``, ``session_nr``, ``capture_num``,
            ``receiver_name``, ``subcarrier_idxs`` and ``csi_abs``.
        session_nr: Session whose warmup captures are used.

    Returns:
        Long-format frame ``receiver_name | subcarrier_idx_a | subcarrier_idx_b
        | correlation`` holding C_ij(r). The subcarrier indices are strings
        (they originate from pivot column names). NaN coefficients (e.g. from
        zero-variance subcarriers) are replaced by 0.
    """
    # NOTE: This assumes that we only have one stream and one antenna in the data.
    # Pivot to wide format: one row per (capture, receiver), one column per
    # subcarrier. Keeping receiver_name in the pivot index carries it along
    # directly. Re-joining it from the long-format CSI frame would duplicate
    # every row once per subcarrier.
    wide = (
        df.filter(pl.col("collection_name").str.contains("warmup"))
        .filter(pl.col("session_nr") == session_nr)
        # Only materialize the columns the pivot actually needs.
        .select("capture_num", "receiver_name", "subcarrier_idxs", "csi_abs")
        .collect()  # pivot requires manifestation
        .pivot(
            index=["capture_num", "receiver_name"],
            on="subcarrier_idxs",
            values="csi_abs",
        )
        .drop("capture_num")
    )

    def _corr_of_group(group: pl.DataFrame) -> pl.DataFrame:
        # receiver_name must not take part in the correlation. It is
        # re-attached afterwards together with the row label i. corr() emits
        # rows in the same order as amplitudes.columns.
        amplitudes = group.drop("receiver_name")
        return amplitudes.corr().with_columns(
            receiver_name=pl.lit(group["receiver_name"][0]),
            subcarrier_idx_a=pl.Series(amplitudes.columns),
        )

    # DataFrame.corr() (not pl.corr, which handles only two variables) gives all
    # n_subcarrier x n_subcarrier coefficients, computed per receiver.
    corr = wide.group_by("receiver_name").map_groups(_corr_of_group).fill_nan(0)

    # Unpivot to get a DataFrame with columns:
    # r | i | j | C_ij(r)
    return corr.unpivot(
        index=["receiver_name", "subcarrier_idx_a"],
        variable_name="subcarrier_idx_b",
        value_name="correlation",
    )


def plot_corr(df: pl.LazyFrame, session_nr: int = 0):
    """Plot per-receiver heatmaps of warmup subcarrier amplitude correlations.

    Args:
        df: Lazy long-format CSI frame; see ``_warmup_subcarrier_correlations``
            for the required columns. Only ``df`` is read (no global state).
        session_nr: Session whose warmup captures are correlated.

    Returns:
        The plot object returned by ``DataFrame.plot.heatmap``.
    """
    corr = _warmup_subcarrier_correlations(df, session_nr)

    # Plot!
    return corr.plot.heatmap(
        x="subcarrier_idx_a",
        y="subcarrier_idx_b",
        C="correlation",
        groupby="receiver_name",
        colormap="rainbow",
    )


plot_corr(proc.csi)

## First preprocessing: Scaling and trending

Next, we normalize every CSI packet amplitude individually, using

$$
\begin{align}
A_i \mapsto \frac{A_i}{\sum_j A_j}
\end{align}
$$

The sum across subcarriers is a voltage quantity. Dividing by it removes any scaling factor common to all subcarriers of a packet, so this normalization is akin to ignoring voltage scaling.

We also normalize the phases $\varphi_i$ by fixing the outermost subcarrier phases to zero and linearly interpolating the correction in between.
Assume the subcarriers are present symmetrically from $-K$ to $K$. Then:

$$
\begin{align}
\varphi_i \mapsto \varphi_i - (i+K) \cdot \frac{\varphi_K - \varphi_{-K}}{2K} - \varphi_{-K}
\end{align}
$$

After these steps, we again inspect the amplitudes, a few phases, and the correlation from above.

In [None]:
# Switch to the detrended checkpoint and attach the session_nr column.
proc = proc.load_checkpoint("01_detrended")
proc = proc.meta_attach("session_nr")

# Warmup amplitudes after scaling/detrending (session numbers below 3).
proc.csi.filter(
    pl.col("session_nr") < 3,
    pl.col("collection_name").str.contains("warmup"),
).collect().plot.scatter(
    x="subcarrier_idxs",
    y="csi_abs",
    by="collection_name",
    groupby="receiver_name",
)

In [None]:
# Phases of session 0, sequence numbers below 10, restricted to qca schedules
# and warmup collections.
proc.csi.filter(
    pl.col("session_nr") == 0,
    pl.col("sequence_number") < 10,
    pl.col("schedule_name").str.contains("qca"),
    pl.col("collection_name").str.contains("warmup"),
).collect().plot.scatter(
    x="subcarrier_idxs",
    y="csi_phase",
    by="collection_name",
    groupby="receiver_name",
)

In [None]:
# Same phase view for session 2, but for the non-warmup collections.
proc.csi.filter(
    pl.col("session_nr") == 2,
    pl.col("sequence_number") < 10,
    pl.col("schedule_name").str.contains("qca"),
    pl.col("collection_name").str.contains("warmup").not_(),
).collect().plot.scatter(
    x="subcarrier_idxs",
    y="csi_phase",
    by="collection_name",
    groupby="receiver_name",
)

In [None]:
plot_corr(proc.csi, session_nr=0)  # correlation heatmaps on the detrended data

## Second Preprocessing: Shape equalization

In every experiment, we first run a warmup session with $M=500$ unmodified frames, followed by the actual run with $1000$ modified frames. From the warmups, we compute an average shape profile for magnitudes and phases:

$$
\begin{align}
\bar{\phi}_i &= \frac{1}{M} \sum_{t=1}^{M} \varphi_i(t) \\
\bar{A}_i &= \frac{1}{M} \sum_{t=1}^{M} A_i(t)
\end{align}
$$

Then, we equalize each subcarrier of each packet individually to "flatten" out the inherent average shape, which is likely introduced by filters and similar hardware components.

$$
\begin{align}
\varphi_i(t) &\mapsto \frac{\varphi_i(t)}{\bar{\phi}_i} \\
A_i(t) &\mapsto \frac{A_i(t)}{\bar{A}_i}
\end{align}
$$

In [None]:
# Switch to the equalized checkpoint and attach the session_nr column.
proc = proc.load_checkpoint("02_equalized")
proc = proc.meta_attach("session_nr")

In [None]:
# Warmup collections with session numbers below 3; `filtered` is reused by the
# phase plot in the next cell.
filtered = proc.csi.filter(
    pl.col("session_nr") < 3,
    pl.col("collection_name").str.contains("warmup"),
)

filtered.collect().plot.scatter(
    x="subcarrier_idxs",
    y="csi_abs",
    by="collection_name",
    groupby="receiver_name",
)

In [None]:
# Equalized phases of the same warmup selection.
filtered.collect().plot.scatter(
    groupby="receiver_name",
    by="collection_name",
    y="csi_phase",
    x="subcarrier_idxs",
)

In [None]:
# Restrict to session numbers 0-9; plotting every session is too cluttered.
sessions_to_plot = [*range(10)]

# Mean phase per (receiver, collection, subcarrier), drawn as one line per collection.
proc.csi.filter(pl.col("session_nr").is_in(sessions_to_plot)).group_by(
    ["receiver_name", "collection_name", "subcarrier_idxs"], maintain_order=True
).agg(pl.mean("csi_phase")).collect().plot.line(
    x="subcarrier_idxs",
    y="csi_phase",
    by="collection_name",
    groupby="receiver_name",
)

## Third Preprocessing

Here we only discard data that is not needed for the remaining visualizations (e.g. the warmup sessions).

In [None]:
# Switch to the "03_filtered" checkpoint; its CSI feeds the scaling-detection cell below.
proc = proc.load_checkpoint("03_filtered")

## Scaling detection

Finally, we check how well the applied per-subcarrier scaling is detected. For each modified subcarrier, we plot two values after all our normalization steps. The first is the average detected phase jump, measured relative to the subcarrier directly below it. The second is the theoretically applied scaling. The applied scaling is marked with a square marker, the detected one with a cross. Colors encode the different scale values.

In [None]:
# Measure the phase jump at each modified subcarrier and compare it with the
# applied scaling.
csi = proc.csi
# Only the column names are needed: read them from the lazy schema. The former
# `csi.collect().columns` materialized the entire CSI frame for this, and its
# result was discarded (it was not the cell's last expression).
display(csi.collect_schema().names())

# Attach the per-recording modification parameters. scale_factor is multiplied
# by pi — presumably it is stored in units of pi rad (TODO confirm).
csi = csi.join(
    proc.meta.select("meta_id", "scale_factor", "modified_idx"), on="meta_id"
).with_columns(pl.col("scale_factor") * np.pi)

# Phase on the modified subcarrier itself (after the jump) ...
jump_df = csi.filter(pl.col("subcarrier_idxs") == pl.col("modified_idx")).rename(
    {"csi_phase": "post_jump_phase"}
)
# ... and on the subcarrier directly below it (before the jump). If that
# neighbour is absent from the data, the inner join below drops the capture.
prej_df = (
    csi.filter(pl.col("subcarrier_idxs") == pl.col("modified_idx") - 1)
    .rename({"csi_phase": "pre_jump_phase"})
    .select(
        "meta_id",
        "sequence_number",
        "pre_jump_phase",
    )
)

# NOTE: Assumes one antenna and one stream
# NOTE(review): fill_null(0) applies to every column, so a missing phase enters
# the jump as 0 rad; dropping such rows may be preferable — confirm.
diff = (
    jump_df.join(prej_df, on=["meta_id", "sequence_number"])
    .fill_null(0)
    .with_columns(phase_jump=pl.col("post_jump_phase").sub(pl.col("pre_jump_phase")))
)

# Average detected jump per schedule, receiver, modified subcarrier and scale.
mean_diff = (
    diff.group_by(
        "schedule_name",
        "receiver_name",
        "modified_idx",
        "scale_factor",
        maintain_order=True,
    )
    .agg(pl.col("phase_jump").mean())
    .collect()
)


# Detected jump: cross markers.
scat1 = mean_diff.plot.scatter(
    x="modified_idx",
    y="phase_jump",
    c="scale_factor",
    groupby="receiver_name",
    marker="x",
    size=10,
    cmap="rainbow",
)
# Applied scaling: square markers.
scat2 = mean_diff.plot.scatter(
    x="modified_idx",
    y="scale_factor",
    c="scale_factor",
    groupby="receiver_name",
    marker="s",
    size=10,
    cmap="rainbow",
)
display(scat1 * scat2)

In [None]:
# Violin plot of the mean detected phase jump per modified subcarrier,
# grouped by receiver and applied scale factor.
mean_diff.sort(by="modified_idx").plot.violin(
    y="phase_jump",
    by="modified_idx",
    cmap="rainbow",
    groupby=["receiver_name", "scale_factor"],
)

In [None]:
# Box plot of the same per-subcarrier mean phase jumps.
mean_diff.sort(by="modified_idx").plot.box(
    y="phase_jump",
    by="modified_idx",
    cmap="rainbow",
    groupby=["receiver_name", "scale_factor"],
)