# Random dynamic curves

This notebook shows how to load, inspect, and preprocess the CSI (channel state information) data.

Some preliminary details:

- Every CSI data point is a vector $c \in \mathbb{C}^{54}$.
- This dataset contains data with emulated movement.
- Movement is emulated as smooth changes of the CSI amplitudes $|c_i(t)|$ across time.
- To emulate this change, I sampled a random dynamic curve $s(t)$ from a Gaussian process.
- Transmissions are modified such that every component $c_i(t)$ at time $t$ sees a relative amplitude scaling of $s(t)$.
- Movement is emulated, but transmissions are real, i.e. they were captured in a true wireless environment.

Note:

- The same scaling $s(t)$ is applied to all subcarriers $i$. It is possible to apply a different scaling to each $i$ if needed.
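
The emulation described above can be sketched as follows. This is a minimal illustration, not the code that generated the dataset: the squared-exponential kernel, the values of `length_scale`, `mean`, and `std`, and the helper name `sample_gp_curve` are all assumptions, and the random `csi` array only stands in for real captured CSI.

```python
import numpy as np


def sample_gp_curve(n_points=700, length_scale=0.1, mean=1.0, std=0.2, seed=0):
    """Draw one smooth curve s(t) from a Gaussian process (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    t = np.linspace(0.0, 1.0, n_points)
    # Squared-exponential kernel (an assumption): nearby time points are
    # strongly correlated, which makes the sampled curve smooth.
    diff = t[:, None] - t[None, :]
    cov = std**2 * np.exp(-0.5 * (diff / length_scale) ** 2)
    # A small diagonal jitter keeps the Cholesky factorization numerically stable
    chol = np.linalg.cholesky(cov + 1e-6 * np.eye(n_points))
    return mean + chol @ rng.standard_normal(n_points)


s = sample_gp_curve()

# Placeholder CSI: 700 time steps x 54 subcarriers (random, for illustration only)
rng = np.random.default_rng(1)
csi = rng.standard_normal((700, 54)) + 1j * rng.standard_normal((700, 54))

# Apply the same relative amplitude scaling s(t) to every subcarrier
scaled = csi * s[:, None]
```

Applying a different curve per subcarrier would amount to replacing `s[:, None]` with an array of shape `(700, 54)`.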

In [None]:
# Imports
import numpy as np
import polars as pl
import hvplot.polars  # noqa: F401
import seaborn as sns
import matplotlib.pyplot as plt
from IPython.display import display
from sensession.campaign import CampaignProcessor
from loguru import logger

logger.remove()

# Example parameters for this notebook
basedir = "../../data/random_curves_18dbgain"
curves_file = f"{basedir}/curves.parquet"
exp_db_file = f"{basedir}/csi.parquet"
meta_db_file = f"{basedir}/meta.parquet"
ex_curve_num = 1

# Curve inspection

Two dataframes are relevant:

- `curves_file` contains the generated ground-truth curves
- `exp_db_file` contains the collected CSI information with time-evolution emulated by the above curves

First, let's inspect the generated curves that are used for emulation.

In [None]:
curves = pl.read_parquet(curves_file)
print(curves)

print(f"Number of curves in dataframe  : {curves.select('num_curve').n_unique()}")
print(f"Length (#timesamples of curve) : {len(curves.item(0, 'curve'))}\n")

# Select the example curve and turn its list of samples into one row per time step
unwrapped_curve = (
    curves.filter(pl.col("num_curve") == ex_curve_num)
    .explode("curve")
    .with_row_index("time_idx")
)

plt.figure(figsize=(10, 4))
sns.lineplot(
    data=unwrapped_curve,
    x="time_idx",
    y="curve",
).set(title=f"Sampled random curve (number {ex_curve_num})")


plt.show()

# Initial glimpse into data

Next, let's load and inspect the data.

The data contains:

- $20$ different curves
- Each session used one curve ($700$ points)
- In each session, $4$ receivers see the same signal and capture CSI from it
- Each receiver captures a time series $c \in \mathbb{C}^{S \times T}$ from it, where $T \leq 700$
- $T$ is the time dimension, $S=54$ the subcarrier dimension
- Within a session, $T$ is equal for all receivers because I prefiltered them. The only data points present are those
  where all receivers captured CSI simultaneously.
- We store only the absolute value $|c_{i,t}|$
- For every curve, we have $20$ sessions. $10$ of these are with the qca receiver, $10$ with the iwl.
  The other receivers, ax210 and asus, are present in all $20$ sessions. The total session number is thus $400$.

An explanation of the columns:

- `timestamp`: The time at which the CSI value was collected
- `subcarrier_idxs`: This is basically the $i$ from above, i.e. the component index of a single CSI vector
- `csi_abs`: The CSI vector (absolute values)
- `rssi` and `antenna_rssi`: Signal strength indicators
- `receiver_name`: Name of the receiver
- `sequence_number`: Number of WiFi packet from which CSI was extracted
- `num_curve`: Label of the curve used for movement emulation

In [None]:
def glimpse(df: pl.DataFrame, num_vals: int = 5):
    print(f"A glimpse of the first {num_vals} entries in the data frame:")
    with pl.Config() as cfg:
        cfg.set_tbl_cols(-1)  # show all columns
        display(df.head(num_vals))


# Start by reading and displaying data

meta = pl.read_parquet(meta_db_file)
csi = pl.read_parquet(exp_db_file)

# Plot a histogram of how many CSI values are captured per session
# This data was taken on actual hardware. A transmitter sends frames in regular
# intervals, but occasionally a receiver won't capture CSI from it, hence causing
# gaps in the time series.
sess_counts = csi.group_by("meta_id", maintain_order=True).len()

_ = sns.histplot(data=sess_counts, x="len").set(
    title="Number of captured CSI in session (out of 700 sent)"
)


# Preprocess the data
proc = (
    CampaignProcessor(
        csi,
        meta,
        lazy=False,
    )
    # .correct_rssi_by_agc()
    .unwrap()
    .filter("antenna_idxs", 0)
    .drop_contains("collection_name", "warmup")
    .meta_attach("curve_nr", "session_nr")
)


original_df = proc.csi.drop(
    "stream_capture_num",
    "rx_antenna_capture_num",
    "collection_name",
    "schedule_name",
    "stream_idxs",
    "antenna_idxs",
)
glimpse(original_df)

# Selecting data to work on

In the rest of the notebook, I'll visualize the data using the following example:

- Some arbitrary movement-emulating curve
- A session in which that curve was used
- An arbitrary subcarrier/index to visualize CSI of

In [None]:
example_curve_idx = 0
example_session_idx = 0
example_subc = 17
example_sn = 91


def extract_example_session(df: pl.DataFrame) -> pl.DataFrame:
    """
    Function to extract example session df to avoid namespace clutter.
    """
    unique_ids = df.unique("session_nr", maintain_order=True)
    session_id = unique_ids.item(example_session_idx, "session_nr")
    example_df = df.filter(pl.col("session_nr") == session_id)

    print(
        "Finished example df preparation. Sanity check ...\n" + "Receivers:"
        f" {example_df.unique('receiver_name')['receiver_name'].to_list()}\n"
        + f"Df shape: {example_df.shape}\n"
    )

    # glimpse(example_df, 2)
    return example_df


# Extract example session
example_df = extract_example_session(
    original_df.filter(pl.col("curve_nr") == example_curve_idx)
)

# Extract reference curve and the collected one
example_curve = (
    curves.filter(pl.col("num_curve") == example_curve_idx)
    .explode("curve")
    .with_row_index("time_idx")
)

# Investigating a single data point

As mentioned, in every session, receivers capture a time series of CSI $c \in \mathbb{C}^{54 \times T}$. Let's start by examining a single data point, i.e. the CSI vector at one fixed time index.

Note:

- Scales of data points vary between receivers. This is due to intrinsic hardware differences, but also to differing data formats
- Shapes have little similarity; each receiver has its own characteristic shape

In [None]:
def csi_amp_plot(df: pl.DataFrame):
    # One panel per receiver, arranged in two rows
    n_plots = df.select("receiver_name").n_unique()
    n_cols = (n_plots + 1) // 2  # round up so an odd number of receivers still fits
    fig, axs = plt.subplots(ncols=n_cols, nrows=2, figsize=(18, 4))
    axs = axs.flatten()
    fig.suptitle("Amplitudes of single CSI data point", fontsize=16)

    for i, (_, group_df) in enumerate(df.group_by(["receiver_name"])):
        sns.lineplot(
            x="subcarrier_idxs",
            y="csi_abs",
            hue="receiver_name",
            data=group_df.to_pandas(),
            legend="full",
            ax=axs[i],
        )
    plt.tight_layout()
    plt.show()


# ----------------------------------------------------------------------------
# Plot some CSI value
csi_amp_plot(example_df.filter(pl.col("sequence_number") == example_sn))

## Plotting the curve and received CSI

Next, let's examine the data points across time.
For this, we choose one subcarrier $i$ and visualize $|c_i(t)|$ across time. For most $i$, these changes should be highly correlated.

The effect of the emulation is that the resulting curve $|c_i(t)|$ should look like the input curve $s(t)$ from above.

Unfortunately, this is not what we observe in the raw data due to confounding factors:

- CSI amplitudes should be proportional to signal power
- But receivers change scales using Automatic Gain Control (AGC) to deal with varying power levels
- Power is also influenced by movements between transmitter and receiver, e.g. a person
- It is conceptually difficult to distinguish whether changes are due to environment or AGC

On the upside:

- If the frequency of CSI captures is high, we expect a "smooth" transition between them
- Strong discontinuities are likely due to AGC and not movement
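
The smoothness argument above suggests a simple way to flag candidate AGC jumps: look for consecutive samples whose amplitude ratio is implausibly large. The sketch below is only an illustration on synthetic data; the helper name `find_jumps` and the threshold `max_ratio=1.3` are assumptions, and a suitable threshold would depend on the capture rate and the receiver.

```python
import numpy as np


def find_jumps(amplitude, max_ratio=1.3):
    """Indices where the amplitude changes by more than a factor of
    max_ratio (up or down) relative to the previous sample."""
    amplitude = np.asarray(amplitude, dtype=float)
    log_step = np.abs(np.log(amplitude[1:] / amplitude[:-1]))
    return np.flatnonzero(log_step > np.log(max_ratio)) + 1


# Synthetic example: a smooth curve whose gain doubles from sample 120 onwards
t = np.linspace(0.0, 1.0, 200)
smooth = 1.0 + 0.3 * np.sin(2 * np.pi * t)
gain = np.where(np.arange(200) >= 120, 2.0, 1.0)
observed = smooth * gain

jumps = find_jumps(observed)
```

Working with ratios rather than differences makes the check independent of the per-receiver amplitude scale.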

In [None]:
def plot_receiver_timeseries(dataframe, key="csi_abs"):
    # n_receiver = dataframe.select("receiver_name").n_unique()
    # fig, axs = plt.subplots(nrows=n_receiver + 1, figsize=(18, 22))
    df = dataframe.with_columns(
        [
            (pl.col(key) / pl.col(key).mean())
            .over("receiver_name")
            .alias("val_normalized_mean"),
            (pl.col(key) / pl.col(key).max())
            .over("receiver_name")
            .alias("val_normalized_max"),
            # (pl.col("agc") / pl.col("agc").mean()).alias("agc_mean"),
            # (pl.col("fft_gain") / pl.col("fft_gain").max()).alias("fft_gain_mean"),
            # ((pl.col("rssi") / pl.col("rssi").mean()) + 1.5).alias("rssi_mean"),
        ]
    ).rename({"sequence_number": "time_idx"})

    curve = example_curve.with_columns(
        curve_mean=pl.col("curve") / pl.col("curve").mean()
    ).hvplot.line(x="time_idx", y="curve_mean")
    fig = df.hvplot.line(
        x="time_idx",
        y="val_normalized_mean",
        groupby="receiver_name",  # , subplots=True
    )  # .cols(1).opts(width=300)

    # agc = df.hvplot.scatter(x="time_idx", y="agc_mean", groupby="receiver_name")
    # fft_gain = df.hvplot.scatter(x="time_idx", y="fft_gain_mean", groupby="receiver_name", alpha=0.4, color="g")
    # Overlay each receiver's time series with the normalized reference curve
    plot = fig * curve
    display(plot)


# Manual Outlier removal
# In this specific capture, QCA has a large distinct outlier that we can remove manually.
# example_df = example_df.filter(
#     (pl.col("receiver_name") != "qca") | (pl.col("csi_abs") < 300)
# )
example_subc_csi = example_df.filter(pl.col("subcarrier_idxs") == example_subc)

plot_receiver_timeseries(example_subc_csi)

# Manually trying to get rid of the jumps

As mentioned, the AGC jumps seem to be on a larger scale than the changes caused by the environment. We can manually tune the `period` and `discont` parameters of `np.unwrap` for each receiver to remove the jumps and recover the actual CSI curve.

This works somewhat well for the asus receiver, but not really for the others. I played around a lot with the parameters and couldn't find values that give a good-looking result. Also, we have no guarantee that these values hold even across identical devices from the same manufacturer.
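
To make clear what `np.unwrap` does when applied to amplitudes instead of phases, here is a small synthetic sketch. The jump size `step` is assumed to be known and additive, which is an idealization chosen only for this example.

```python
import numpy as np

n = 50
smooth = np.linspace(0.0, 1.0, n)    # slowly varying "true" signal
step = 10.0                          # assumed jump size
jumps = step * (np.arange(n) >= 25)  # a single jump halfway through
observed = smooth + jumps

# Differences larger than period / 2 are shifted by multiples of period
recovered = np.unwrap(observed, period=step)
```

This only recovers the smooth signal when the jumps are close to integer multiples of `period`. If the jump sizes vary, a residual offset remains after every jump, which may be one reason why the manual tuning is so fragile.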

In [None]:
def unwrap(dataframe, receiver: str, period: float, discont: float):
    filtered_df = dataframe.filter(pl.col("receiver_name").str.contains(receiver))

    mask = (
        filtered_df.group_by(["subcarrier_idxs", "receiver_name"], maintain_order=True)
        .agg("csi_abs", "timestamp")
        .with_columns(
            csi_abs=pl.col("csi_abs").map_elements(
                lambda x: np.unwrap(
                    np.array(x), period=period, discont=discont
                ).tolist(),
                return_dtype=pl.List(inner=pl.Float32),
            )
        )
        .explode("csi_abs", "timestamp")
    )

    return mask.join(filtered_df, on="timestamp").drop("csi_phase")


asus_df = unwrap(example_subc_csi, "asus", 130, 130)
ax210df = unwrap(example_subc_csi, "ax210", 10, 5)
iwl_df = unwrap(example_subc_csi, "iwl5300", 2, 1.8)
qca_df = unwrap(example_subc_csi, "qca", 9, 9)


combined = pl.concat([asus_df, qca_df, iwl_df, ax210df])
plot_receiver_timeseries(combined)

# Normalization using RSSI

The literature proposes normalizing by RSSI to remove the AGC scaling. We test whether this at least gets rid of the jumps. However, RSSI is coarser than CSI amplitudes, so we expect this to remove more information than just the AGC contribution.

This seems to be a step in the right direction.
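
The rescaling used here can be written as $\tilde{c}_i = c_i \sqrt{10^{\mathrm{RSSI}/10} / \sum_j |c_j|^2}$, i.e. every CSI vector is scaled such that its total power matches the RSSI converted to linear scale. A minimal numpy sketch of this idea follows; the helper name `rescale_by_rssi` and the example values are made up, and the absolute unit of the RSSI is left open.

```python
import numpy as np


def rescale_by_rssi(csi_abs, rssi_db):
    """Scale one CSI amplitude vector so that its total power equals
    the RSSI converted from dB to linear scale (hypothetical helper)."""
    csi_abs = np.asarray(csi_abs, dtype=float)
    target_power = 10.0 ** (rssi_db / 10.0)
    measured_power = np.sum(csi_abs**2)
    return csi_abs * np.sqrt(target_power / measured_power)


rng = np.random.default_rng(0)
base = rng.uniform(0.5, 1.5, size=54)

# The same channel reported with two different receiver gains
a = rescale_by_rssi(base, rssi_db=-40.0)
b = rescale_by_rssi(3.7 * base, rssi_db=-40.0)
```

Both vectors end up identical, so a gain that is common to all subcarriers cancels. The price is that the overall amplitude is then determined entirely by the (coarse) RSSI value.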

In [None]:
def clean_zscore(df: pl.DataFrame, window: int = 50, thresh: float = 2.0):
    # For every (meta_id, subcarrier) group, perform z-score based filtering
    dfs = []
    for _, dfw in df.group_by(
        ["meta_id", "subcarrier_idxs"], maintain_order=True
    ):
        zscore_df = dfw.with_columns(
            rolling_mean=pl.col("csi_abs").rolling_mean(
                window_size=window, center=True, min_periods=1
            ),
            rolling_std=pl.col("csi_abs").rolling_std(
                window_size=window, center=True, min_periods=1, ddof=0
            ),
        ).with_columns(
            z_score=(pl.col("csi_abs") - pl.col("rolling_mean")) / pl.col("rolling_std")
        )

        dfs.append(zscore_df.filter(pl.col("z_score").abs() < thresh))

    df = pl.concat(dfs)
    return df


# proc_rescale = proc.rescale_csi_by_rssi()

rescaled_csi = proc.csi

rescaled_csi = rescaled_csi.with_columns(
    new_rssi=pl.when(pl.col("agc").is_null())
    .then(pl.col("rssi"))
    .otherwise(-pl.col("agc")),
)


rescaled_csi = (
    rescaled_csi.join(
        rescaled_csi.group_by(["capture_num"], maintain_order=True).agg(
            (pl.col("csi_abs") ** 2).sum().alias("power")
        ),
        on="capture_num",
    )
    .with_columns(
        # Scale each CSI vector so its total power matches the linear-scale RSSI
        (
            pl.col("csi_abs")
            * ((10 ** (pl.col("new_rssi") / 10)) / pl.col("power")) ** 0.5
        ).alias("csi_abs"),
    )
    .drop("power")
)

rescaled_csi = extract_example_session(
    rescaled_csi.filter(pl.col("curve_nr") == example_curve_idx)
).filter(pl.col("subcarrier_idxs") == example_subc)


# # Clean outliers by zscore deviation filtering
# rssi_normed_example_df = clean_zscore(rssi_normed_example_df)


plot_receiver_timeseries(rescaled_csi)

## Double Phase Difference

This correction was proposed recently in [Enabling WiFi Sensing on New-generation WiFi Cards](https://dl.acm.org/doi/pdf/10.1145/3633807).
The authors rely on the fact that AGC scaling affects all subcarriers equally. They first show that taking ratios of CSI on different subcarriers is a conformal transformation, i.e. it preserves angles, and that it removes the scaling.
Using a double ratio, they can also get rid of the phase offsets. For this, they choose a subcarrier distance $k$ and calculate:

$$
\tilde{H}_i = \frac{\frac{H_i}{H_{i+k}}}{\frac{H_{i+k}}{H_{i+2k}}}
$$

It's a very nice procedure, with a few shortcomings that I can think of:

- Effectively, only a third of the bandwidth is now used (at most $18$ values instead of $54$).
- Correlated changes across subcarriers are factored out because of the quotients.
- While it retains angles, it obviously can't retain large-scale amplitude changes. Even in the absence of AGC jumps, these changes would be divided out by the quotients. Consider for example the part of the curve that is close to $0.5$: if two subcarriers follow it, both dropping to approximately half their reference value, their quotient cancels that factor.
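
The cancellation properties of the double ratio can be checked numerically. The sketch below uses random placeholder CSI and a hypothetical helper `double_ratio`; modelling the phase offset as a term that is linear in the subcarrier index is my assumption for this example.

```python
import numpy as np


def double_ratio(h, k):
    """Compute (H_i / H_{i+k}) / (H_{i+k} / H_{i+2k}) for all valid i."""
    h = np.asarray(h)
    n = h.shape[-1] - 2 * k
    first, mid, last = h[..., :n], h[..., k : k + n], h[..., 2 * k : 2 * k + n]
    return (first / mid) / (mid / last)


rng = np.random.default_rng(0)
h = rng.standard_normal(54) + 1j * rng.standard_normal(54)

# Distort with a common gain and a phase that is linear in the subcarrier index
idx = np.arange(54)
distorted = 2.5 * np.exp(1j * (0.7 + 0.05 * idx)) * h

clean = double_ratio(h, k=18)
corrected = double_ratio(distorted, k=18)
```

`clean` and `corrected` agree, but any amplitude change that is common to all subcarriers (such as the emulated $s(t)$) cancels in exactly the same way as the gain of $2.5$ does here.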


In [None]:
def correct_double_phase_diff(row):
    abs_val = np.array(row[0])
    phase_val = np.array(row[1])
    csi = abs_val * np.exp(1j * phase_val)

    # Since we have 54 subcarriers, we can use split to produce three groups of 54/3 = 18
    triple = np.split(csi, 3)
    double_phase_diff = (triple[0] / triple[1]) / (triple[1] / triple[2])

    return np.abs(double_phase_diff).tolist(), np.angle(double_phase_diff).tolist()


num_reduced_subcs = 18  # 54 / 3
double_phase_df = original_df.with_columns(
    original_df[["csi_abs", "csi_phase"]]
    .map_rows(correct_double_phase_diff, return_dtype=pl.List(inner=pl.Float64))
    .rename({"column_0": "csi_abs", "column_1": "csi_phase"}),
    subcarrier_idxs=pl.lit(list(range(num_reduced_subcs))),
)


# double_phase_df = explode(double_phase_df)
example_antenna = 0
example_double_phase_df = (
    extract_example_session(double_phase_df)
    .filter(pl.col("used_antenna_idxs") == example_antenna)
    .filter(pl.col("subcarrier_idxs") == 2)
)

plot_receiver_timeseries(example_double_phase_df)

In [None]:
example_double_phase_df = (
    extract_example_session(double_phase_df)
    .filter(pl.col("used_antenna_idxs") == example_antenna)
    .filter(pl.col("subcarrier_idxs") == 7)
)
plot_receiver_timeseries(example_double_phase_df, key="csi_phase")

# Converting to numpy

Finally, let's take that cleaned data and convert it to numpy arrays.
This is a bit tedious, sorry for that.

Polars requires unwrapping so the data can be used properly. To recover the shapes, I have to rewrap/group everything manually.

- First group CSI vectors (i.e. each row should contain one CSI vector $c \in \mathbb{R}^{54}$)
- Then group antennas (i.e. each row should contain a vector of CSIs $c \in \mathbb{R}^{n_{antenna} \times 54}$)
- Subsample data to equalize dimensions, making sure we keep only $152$ CSI samples
- Then group data (such that $c \in \mathbb{R}^{152 \times n_{antenna} \times 54}$)
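
The subsampling step can be illustrated with plain numpy. The sketch assumes timestamps in seconds and a hypothetical helper `subsample_first_per_bucket`. Unlike the polars `group_by_dynamic` call used in the notebook, it works with non-overlapping buckets, so a bucket without any capture simply yields no sample.

```python
import numpy as np


def subsample_first_per_bucket(timestamps_s, values, duration_s=13.5, n_out=152):
    """Split the first duration_s seconds into n_out equal time buckets and
    keep the first sample that falls into each bucket."""
    timestamps_s = np.asarray(timestamps_s, dtype=float)
    values = np.asarray(values)
    step = duration_s / n_out  # 13.5 s / 152 is roughly 88.82 ms
    bucket = np.floor((timestamps_s - timestamps_s[0]) / step).astype(int)
    keep = bucket < n_out
    # return_index gives the position of the first occurrence of each bucket
    _, first_idx = np.unique(bucket[keep], return_index=True)
    return values[keep][first_idx]


# Idealized capture: 700 frames, one every 20 ms, without any gaps
timestamps = np.arange(700) * 0.02
samples = np.arange(700)
subsampled = subsample_first_per_bucket(timestamps, samples)
```

With gaps in the capture, the output can be shorter than `n_out`, which is why all sessions need to be checked for equal length before stacking them into one array.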

In [None]:
# -- Initial regrouping
# First group by antenna_capture_num (wrapping up all CSI of one antenna into one list)
# Then regroup by capture_num        (wrapping up all CSI lists of all antennas into one list)
numpy_df = (
    rescaled_csi.group_by(
        "antenna_capture_num",
        "capture_num",
        "session_id",
        "receiver_name",
        "num_curve",
        "timestamp",
        "used_antenna_idxs",
        maintain_order=True,
    )
    .agg("csi_abs", "subcarrier_idxs")
    .group_by(
        "capture_num",
        "session_id",
        "receiver_name",
        "num_curve",
        "timestamp",
        "subcarrier_idxs",
        maintain_order=True,
    )
    .agg("csi_abs", "used_antenna_idxs")
)

# ------ Subsampling
# NOTE: We group to get 152 samples per capture.
# At the beginning of each group, we take the first accumulated value.
# 700 samples, including timeout, make for a 14.12 second transmission.
# To be safe, we take 13.5 seconds.
# Subsampling 152 out of that, we need to take one sample every 88.82 ms.
aggs = [
    pl.col("csi_abs").first(),
    pl.col("session_id").first(),
    pl.col("receiver_name").first(),
    pl.col("num_curve").first(),
    pl.col("subcarrier_idxs").first(),
    pl.col("used_antenna_idxs").first(),
]

numpy_df = numpy_df.group_by(
    "session_id", "receiver_name", maintain_order=True
).map_groups(
    lambda group: group.sort("timestamp")
    .group_by_dynamic("timestamp", every="88ms820us", period="500ms")  # 88.82 ms
    .agg(*aggs)
    .head(152)
)


# -- Session regrouping
numpy_df = numpy_df.group_by(
    "session_id",
    "receiver_name",
    "num_curve",
    "subcarrier_idxs",
    "used_antenna_idxs",
    maintain_order=True,
).agg("csi_abs", "timestamp")


# Glimpse
glimpse(numpy_df, 5)

# Now convert to array
# NOTE: to_numpy doesn't handle nested lists, so we take the roundabout way via Python lists
csi_abs = np.array(numpy_df.get_column("csi_abs").to_list())
timestamps = np.array(numpy_df.get_column("timestamp").to_list())
curve_labels = np.array(numpy_df.get_column("num_curve").to_list())
receiver_labels = np.array(numpy_df.get_column("receiver_name").to_list())

print("Dimensions: session x num_capture x antenna x csi_subcarrier")
print(csi_abs.shape)
print(timestamps.shape)
print(curve_labels.shape)
print(receiver_labels.shape)