---
title: PARAFAC 2 on Wine
description: first proper attempt to decompose wine dataset.
project: parafac2
conclusion: ""
status: open
cdt: 2024-09-26T15:07:44
---

In [None]:
%reload_ext autoreload
%autoreload 2


import polars as pl
import duckdb as db
from pca_analysis.get_sample_data import get_shiraz_data
from pca_analysis.definitions import DB_PATH_UV
import altair as alt

alt.data_transformers.enable("vegafusion")
pl.Config.set_tbl_rows(99)


con = db.connect(DB_PATH_UV, read_only=True)

shiraz_data = get_shiraz_data(con=con)

shiraz_data[0][0].head()
len(shiraz_data)


In [None]:
ids = [sample[0].get_column("id")[0] for sample in shiraz_data]  # first `id` of each sample's image frame


As we can see, sample 82 is suspiciously small in scale relative to the other samples.

Make the tensor.


Select a subset to begin with.

In [None]:
# def prepare_tensor(
#     data: list[tuple[pl.DataFrame, pl.DataFrame]],
# ) -> tuple[list[str], NDArray]:
#     """
#     todo: add test for numeric col only in x[0]
#     """
#     t: list[pl.DataFrame] = [x[0].drop(["mins", "path", "runid", "id"]) for x in data]

#     ids: list[str] = [x[0]["runid"][0] for x in data]

#     tt = np.stack([x[1].to_numpy() for x in t])

#     return ids, tt


# tt = prepare_tensor(data=shiraz_data)
# len(tt[0])


We expect a tensor with dimensions of $m$ observations × $n$ wavelengths × $o$ samples. For the **shiraz** dataset, this is 7800 × 106 × 10.

In [None]:
# tt[1].shape


In [None]:
# decomp = parafac2(
#     tensor_slices=tt[1], rank=22, return_errors=True, n_iter_max=500, nn_modes="all"
# )
# decomp


In [None]:
def prepare_data(
    data: list[tuple[pl.DataFrame, pl.DataFrame]],
) -> tuple[pl.DataFrame, pl.DataFrame]:
    """
    Flatten a list of per-sample (image, metadata) pairs into two tables.

    Parameters
    ----------
    data : list of (pl.DataFrame, pl.DataFrame)
        One pair per sample. The first frame is expected to hold ``runid``,
        ``mins``, ``path`` and ``id`` plus one column per wavelength (the column
        name is the wavelength); the second frame must contain ``runid``.

    Returns
    -------
    tuple[pl.DataFrame, pl.DataFrame]
        ``(imgs, mta)``:

        - ``imgs`` is in long format with columns ``runid``, ``wavelength`` (int),
          ``mins`` and ``abs``, sorted by ``runid, wavelength, mins``.
        - ``mta`` is the concatenated second frames with each run's ``path``
          left-joined on ``runid``.
    """
    imgs: pl.DataFrame = pl.concat([sample[0] for sample in data])
    mta: pl.DataFrame = pl.concat([sample[1] for sample in data])

    # keep the file path on the metadata before dropping it (and `id`) from the images
    mta = mta.join(imgs.select("runid", "path").unique(), on="runid", how="left")
    imgs = imgs.drop("path", "id")

    # wide (one column per wavelength) -> long (one row per runid/wavelength/mins)
    imgs_ = (
        imgs.unpivot(
            index=["runid", "mins"],
            variable_name="wavelength",
            value_name="abs",
        )
        .select(
            "runid",
            pl.col("wavelength").cast(int),
            "mins",
            "abs",
        )
        .sort("runid", "wavelength", "mins")
    )

    return imgs_, mta


# long-format signal (runid, wavelength, mins, abs) and per-run metadata with `path`
imgs, mta = prepare_data(data=shiraz_data)
mta.shape


Plot the data at wavelengths 256 and 330.

In [None]:
# overlay wavelengths 256 and 330 for every run, one panel per run
imgs.filter(pl.col("wavelength").is_in([256, 330])).plot.line(
    x="mins", y="abs", color="wavelength:N"
).facet("runid", columns=4)


OK. Sample 82 doesn't look informative, so let's drop it.

In [None]:
# exclude run 82 from all further analysis
imgs = imgs.filter(pl.col("runid") != "82")
imgs.get_column("runid").unique()


## Slimming The Dataset

There will obviously be a section of each image which contains proportionally more information than the other sections.
Thus, as a first pass, we should try to reduce the time and wavelength modes as much as possible. The first can be done
by cutting the baseline past its return to the origin (say ~30 mins, from memory), and the second by doing the same,
dropping any points after a return to the origin.

After that, if we decide that the tensors are still too large, we can look at resampling by aggregation in the
wavelength and time modes.

## Time Snipping

At what time point can we comfortably cut the signal? Where do the baselines return to zero across all wavelengths? Something like: after smoothing, where is there an inflection point? Do we need to subtract a baseline first?


In [None]:
# run 84 at wavelengths 256 and 330, drawn wide so the peak shapes are visible
imgs.filter(
    pl.col("wavelength").is_in([256, 330]),
    runid="84",
).plot.line(
    x="mins", y="abs", color="wavelength:N"
).properties(width=1000, height=300)


As we can see, the two regions, ~256 and ~330 have their own distinctive features, indicating that there is value in preserving a wide range of wavelengths.

It is obvious that there is a baseline. Let's remove it with AsLS (asymmetric least squares). In fact, pybaselines comes with a native 2D implementation.

It looks like it's taking about a minute per run. Once we identify times/wavelengths to eliminate,
it will be faster.


In [None]:
# from pybaselines import Baseline2D

# for run in imgs.partition_by("runid")[0:1]:
#     tidy: pl.DataFrame() = run.pivot(on="wavelength", values="abs")

#     # print(tidy.schema)

#     # display(tidy.head())
#     fitter = Baseline2D()
#     min_unid = tidy.select("runid", "mins")
#     img = tidy.drop(["runid", "mins"])
#     np_img = img.to_numpy()
#     img_schema = img.schema

#     baseline, params = fitter.asls(tidy.drop(["runid", "mins"]).to_numpy(writable=True))

#     baseline_df = pl.DataFrame(data=baseline, schema=img_schema)
#     new_df = pl.DataFrame(
#         pl.concat([min_unid, baseline_df], how="horizontal"),
#         schema=tidy.schema,
#     )
#     display(new_df.head())

# # display(baseline_df.head())

# display(
#     new_df.unpivot(
#         index=["runid", "mins"], value_name="abs", variable_name="wavelength"
#     )
#     .with_columns(pl.col("wavelength").cast(int))
#     .filter(pl.col("wavelength").is_in([256, 330]))
#     .plot.line(x="mins", y="abs", color="wavelength:N")
#     .properties(width=1000)
# )

# new_df_l = new_df.unpivot(
#     index=["runid", "mins"], value_name="abs", variable_name="wavelength"
# )

# display(new_df_l.head())

# to graphically compare the signal and fitted baseline, join the tables
# joined = imgs.filter(runid="0101").join(
#     new_df_l.filter(runid="0101").with_columns(pl.col("wavelength").cast(int)),
#     on=["runid", "wavelength", "mins"],
#     how="left",
#     suffix="_baseline",
# )

# display(joined.head())

# joined_l = (
#     joined.rename({"abs": "ing_sig"})
#     .unpivot(
#         index=["runid", "wavelength", "mins"], value_name="abs", variable_name="signal"
#     )
#     .sort("runid", "signal", "wavelength", "mins")
# )

# display(
#     joined_l.filter(pl.col("wavelength").is_in([256, 330]))
#     # .with_columns(pl.concat_str(["wavelength", "signal"]).alias("nm_sig"))
#     .plot.line(x="mins", y="abs", color="signal")
#     .facet("wavelength")
# )


As we can see, the baseline fit is very poor. We should consolidate the processes created thus far and try again.

# Downsampling

To develop a pipeline, we need execution to be quick. Thus we'll downsample to a reasonable coarseness, develop the pipeline, then observe the results as the resolution is increased. One way of downsampling is to average over consecutive fixed-size windows. In this case, it is logical to first downsample by time, then by wavelength.

To downsample by time, we need to express time as a datetime or an integer. A natural method is to use a numeric index running from 0 to n, where n is the number of observations.


In [None]:
# integer sample index per (runid, wavelength) trace: the dense rank of `mins`.
# `pl.exclude("idx")` keeps every other column and discards any `idx` left over
# from a previous run of this cell.
imgs = imgs.select(
    pl.col("mins").rank(method="dense").over(["runid", "wavelength"]).alias("idx"),
    pl.exclude("idx"),
)
imgs.head()


In [None]:
display(imgs.shape)  # (rows, columns) of the long-format signal table


In [None]:
# approach taken from <https://stackoverflow.com/questions/70327284/how-can-we-resample-time-series-in-polars>


def downsample_by_time(
    df: pl.DataFrame,
    index_col: str,
    every: str,
    group_by: "str | list[str]",
    cols_to_perserve: "list[str] | None" = None,
    value_col: str = "abs",
) -> pl.DataFrame:
    """
    Downsample a long-format signal by averaging over fixed-width index windows.

    Parameters
    ----------
    df : pl.DataFrame
        Long-format signal. polars' ``group_by_dynamic`` expects ``index_col`` to be
        sorted ascending within each ``group_by`` group.
    index_col : str
        Integer-like column defining the windows; it is cast to Int64.
    every : str
        Window width in polars duration syntax, e.g. ``"10i"`` for 10 index units.
    group_by : str or list of str
        Column(s) defining independent traces (e.g. ``["runid", "wavelength"]``).
    cols_to_perserve : list of str, optional
        Extra columns to carry through; each window keeps the column's first value.
        ``None`` (the default) means no extra columns.
    value_col : str, default "abs"
        Column that is averaged within each window.

    Returns
    -------
    pl.DataFrame
        One row per (group, window) holding the group columns, the window's
        ``index_col`` label, the mean of ``value_col`` and any preserved columns.

    Raises
    ------
    TypeError
        If ``cols_to_perserve`` is not a list or contains a non-str element.
    """
    # avoid a shared mutable default; None means "no extra columns"
    if cols_to_perserve is None:
        cols_to_perserve = []

    if not isinstance(cols_to_perserve, list):
        raise TypeError("cols_to_preserve must be list of string")

    for col in cols_to_perserve:
        if not isinstance(col, str):
            raise TypeError("expect elements of cols_to_preserve to be str")

    return (
        df.with_columns(pl.col(index_col).cast(pl.Int64))
        .group_by_dynamic(index_column=index_col, every=every, group_by=group_by)
        .agg(pl.mean(value_col), *[pl.first(col) for col in cols_to_perserve])
    )


# average 10-index windows within each run/wavelength trace; each window keeps the
# first `mins` value it contains as its time stamp
downsampled = downsample_by_time(
    df=imgs,
    index_col="idx",
    every="10i",
    group_by=["runid", "wavelength"],
    cols_to_perserve=["mins"],
)


downsampled.shape


In [None]:
# run 0101 at wavelength 256 after downsampling, tagged with factor 10
ds_10_line = downsampled.filter(runid="0101", wavelength=256).select(
    "runid", pl.lit(10).alias("downsampling_factor"), "idx", "mins", "abs"
)

# the same trace at full resolution, tagged with factor 0
cols = [
    "runid",
    pl.lit(0).alias("downsampling_factor"),
    "idx",
    "mins",
    "abs",
]
# NOTE: `img` (run 0101, wavelength 256, full resolution) is reused by later cells
img = imgs.filter(runid="0101", wavelength=256).select(cols)
# vertical_relaxed: `idx` was cast to Int64 by downsample_by_time, so the two
# frames' dtypes may differ; overlay both traces to compare
pl.concat([ds_10_line, img], how="vertical_relaxed").sort(
    ["downsampling_factor", "mins"]
).plot.line(x="mins", y="abs", color="downsampling_factor:N").properties(
    width=1000, height=300
)


As we can see, downsampling markedly reduces the height of the most intense peaks; however, the peaks are widened and the area is unchanged, according to <https://terpconnect.umd.edu/~toh/spectrum/Integration.html#:~:text=Incidentally%2C%20smoothing%20a%20noisy%20signal,the%20overlap%20between%20adjacent%20peaks>.

I want to divide the signal into regions for easier viewing: bin 1 runs from 0 to 4 mins, bin 2 from 4 to 8, bin 3 from 8 to 15, bin 4 from 15 to 19, bin 5 from 19 to 22.7, bin 6 from 22.7 to 27, and bin 7 covers the remainder.

In [None]:
# region boundaries (mins) used only to split the trace into panels for viewing
bins = [4, 8, 15, 19, 22.7, 27]

# label each time point of `img` with its region; `idx` is the dense rank of `mins`
# and is the key used to join the labels back onto `img`
mins = img.select("mins")
time_labels = mins.with_columns(
    pl.col("mins").cut(bins).alias("bins"), pl.col("mins").rank("dense").alias("idx")
)

# one panel per region, each with its own axes
img.join(time_labels.drop("mins"), on="idx").plot.line(x="mins", y="abs").facet(
    "bins",
    columns=3,
).resolve_scale(x="independent", y="independent")


Incidentally, there appears to be a spike at ~42 mins. This corresponds to the point where the mobile phase returns to a ratio of 95% water to 5% methanol and re-equilibration before the next run commences, so any time after this point should be disregarded. More to the point, we should note what each region corresponds to in terms of the mobile phase. From 0 to 38 mins, the proportion of methanol increases by 2.5% per minute until the mobile phase is 100% methanol (5% + 38 × 2.5% = 100%). It is then held for two minutes before returning to the initial ratio from 40 to 42 mins. We can observe the character of the signal within these intervals:

## Mapping Gradient Table to Signal


To better understand the regions of the signal, we can map the change in mobile phase composition over time to the signal, then divide it based on the mapping.

Check that the intervals are labelled correctly.

In [None]:
def add_mob_phase_changes_to_time_labels(
    df: pl.DataFrame, right_ends: list[float], labels: list[str]
) -> pl.DataFrame:
    """
    Label each row of ``df`` with the mobile-phase interval its ``mins`` falls in.

    ``right_ends`` are the break points handed to ``Expr.cut`` and ``labels`` names
    the resulting bins (polars needs one more label than break points). The labels
    are written to a new ``intvls`` column.
    """
    interval_of_row = pl.col("mins").cut(breaks=right_ends, labels=labels)
    return df.with_columns(intvls=interval_of_row)


def fetch_solvent_timetable(
    con: db.DuckDBPyConnection, runid: "str | None" = None
) -> pl.DataFrame:
    """
    Fetch the solvent-proportion timetable for one run.

    Parameters
    ----------
    con : duckdb.DuckDBPyConnection
        Connection exposing a ``solvprop_over_mins`` table or view.
    runid : str, optional
        Run to fetch. If omitted, falls back to the run id in the first row of the
        notebook-global ``img`` frame (the previous, implicit behaviour). Pass it
        explicitly to avoid that hidden dependency.

    Returns
    -------
    pl.DataFrame
        All columns of ``solvprop_over_mins`` for the run, ordered by
        ``runid, mins, channel``.
    """
    if runid is None:
        # legacy fallback: `img` is a global defined in an earlier notebook cell
        runid = img.get_column("runid")[0]

    solvent_timetable = con.execute(
        """--sql
    select * from solvprop_over_mins where runid = ? order by runid, mins, channel
    """,
        parameters=[runid],
    ).pl()
    return solvent_timetable


def pivot_solvcomp_timetable(df: pl.DataFrame) -> pl.DataFrame:
    """
    Spread the long solvent timetable into one column per channel.

    ``percent`` values become columns named after ``channel``. ``mb_int`` is added
    as the dense rank of ``mins`` (1 for the earliest timetable entry).
    """
    wide = df.pivot(on="channel", values="percent")
    entry_number = pl.col("mins").rank("dense").alias("mb_int")
    return wide.with_columns(entry_number)


def check_min_max_of_intervals(df: pl.DataFrame, grouper) -> pl.DataFrame:
    """
    Report the earliest (``min``) and latest (``max``) ``mins`` of each group.

    Intended as a sanity check that interval labels cover the expected time
    ranges. Rows are ordered by each group's earliest time.
    """
    span = [pl.col("mins").min().alias("min"), pl.col("mins").max().alias("max")]
    return df.group_by(grouper).agg(*span).sort("min")


def add_starting_bs(df: pl.DataFrame) -> pl.DataFrame:
    """
    Attach each row's predecessor ``b`` and ``mins`` as its interval start point.

    ``b`` is renamed to ``ending_b``. ``starting_b`` and ``starting_mins`` are the
    previous row's ``b`` and ``mins`` (null on the first row), aligned back onto the
    frame by a left join on ``intvls``.
    """
    previous_point = df.select(
        "intvls",
        starting_b=pl.col("b").shift(1),
        starting_mins=pl.col("mins").shift(1),
    )
    return df.rename({"b": "ending_b"}).join(previous_point, on="intvls", how="left")


def calc_grads(st_piv: pl.DataFrame) -> pl.DataFrame:
    """
    Compute the per-interval slope of channel ``b`` (percent per minute).

    Returns columns ``intvls``, ``mins`` (interval end), ``diff`` (change in ``b``
    divided by elapsed minutes over the interval) and ``starting_b`` (``b`` at the
    interval start).
    """
    anchored = add_starting_bs(st_piv.sort("intvls"))
    b_change = pl.col("ending_b") - pl.col("starting_b")
    elapsed = pl.col("mins") - pl.col("starting_mins")
    return anchored.with_columns((b_change / elapsed).alias("diff")).select(
        "intvls", "mins", "diff", "starting_b"
    )


def join_solvcomp_over_time_img(
    solvcomp_over_time: pl.DataFrame, img: pl.DataFrame
) -> pl.DataFrame:
    """
    Attach the solvent composition to the signal ``img`` by matching ``mins``.

    Every row of ``img`` is kept (left join); rows whose time has no match get nulls.
    """
    return img.join(other=solvcomp_over_time, how="left", on="mins")


def join_grads(intvl_labelled_times: pl.DataFrame, grads: pl.DataFrame) -> pl.DataFrame:
    """Attach each interval's slope (``diff``) and start value (``starting_b``) to its time points."""
    per_interval = grads.select("intvls", "diff", "starting_b")
    return intvl_labelled_times.join(per_interval, on="intvls", how="left")


def add_rel_mins(df):
    """
    Add ``rel_mins``: the elapsed time of each row within its ``intvls`` interval.

    The time step is the mean spacing of ``mins`` over the whole frame, multiplied
    by the row's dense rank of ``mins`` inside its interval. Dense rank starts at 1,
    so the first sample of an interval gets one time step rather than zero.
    """
    step = df.get_column("mins").diff().mean()
    position_in_interval = pl.col("mins").rank("dense").over("intvls")
    return df.with_columns((position_in_interval * step).alias("rel_mins"))


def calc_b(df):
    """
    Reconstruct channel ``b`` at every time point from its interval's linear ramp.

    ``calc_b = rel_mins * diff + starting_b``: the interval's starting value plus
    its slope times the elapsed time within the interval. Rows whose interval has a
    null ``starting_b`` or ``diff`` get a null ``calc_b``.
    """
    # reference `diff` explicitly as a column: polars may treat a bare string operand
    # of an arithmetic expression as a string literal rather than a column name
    ramp = pl.col("rel_mins") * pl.col("diff")
    return df.with_columns((ramp + pl.col("starting_b")).alias("calc_b"))


def add_b_prop_over_gradients(
    intvl_labelled_times: pl.DataFrame, with_grads: pl.DataFrame
) -> pl.DataFrame:
    """
    Evaluate the channel-b proportion at every interval-labelled time point.

    Joins the per-interval slopes onto the times, adds the elapsed time within each
    interval and evaluates the linear ramp into a ``calc_b`` column.
    """
    with_slopes = join_grads(intvl_labelled_times, grads=with_grads)
    with_elapsed = add_rel_mins(with_slopes)
    return calc_b(with_elapsed)


def plot_signal_facet_intvls(signal_with_solvent_props: pl.DataFrame) -> alt.FacetChart:
    """
    Plot absorbance against time with one facet per mobile-phase interval.

    Parameters
    ----------
    signal_with_solvent_props : pl.DataFrame
        Must contain ``mins`` (Float64), ``abs`` (Float64) and ``intvls``
        (Categorical with physical ordering); other columns are ignored by the check.

    Returns
    -------
    alt.FacetChart
        Line chart of ``abs`` over ``mins`` faceted by ``intvls``, three per row.

    Raises
    ------
    AssertionError
        If the three columns' dtypes differ from the above; the message shows the
        expected and actual schemas.
    """
    exp_df_schema = pl.Schema(
        {
            "mins": pl.Float64(),
            "abs": pl.Float64(),
            "intvls": pl.Categorical(ordering="physical"),
        }
    )
    actual_schema = signal_with_solvent_props.select("mins", "abs", "intvls").schema
    assert exp_df_schema == actual_schema, (
        f"expected schema {exp_df_schema}, got {actual_schema}"
    )

    return signal_with_solvent_props.plot.line(x="mins", y="abs").facet(
        "intvls", columns=3
    )


def plot_prop_b_over_intvls_time(solcvomp_over_time: pl.DataFrame) -> alt.Chart:
    """
    Scatter-plot the reconstructed channel-b proportion against time.

    Parameters
    ----------
    solcvomp_over_time : pl.DataFrame
        Frame with ``mins``, ``calc_b`` and ``intvls`` columns. The misspelt name is
        kept so existing keyword callers keep working.

    Returns
    -------
    alt.Chart
        ``calc_b`` over ``mins`` coloured by ``intvls``, 1000 x 400.
    """
    # plot the argument; the body previously read the notebook-global
    # `solvcomp_over_time` because of the parameter-name typo
    return solcvomp_over_time.plot.scatter(
        x="mins", color="intvls", y="calc_b"
    ).properties(title="prop b over intvls", height=400, width=1000)


In [None]:
# solvent timetable for the run in the global `img` (run 0101, from the downsampling cell)
solvent_timetable = fetch_solvent_timetable(con=con)

# one column per solvent channel; each row is a timetable change point
st_piv = solvent_timetable.pipe(pivot_solvcomp_timetable)

# the timetable times are used as the interval break points
breaks = st_piv.get_column("mins").to_list()

# NOTE(review): cut needs len(breaks) + 1 labels, so this assumes a 5-row timetable;
# the trailing "" labels times after the last break -- TODO confirm
intvl_labels = ["start", "methanol rise", "plateau", "water rise", "re-equilib", ""]

# label every sample time of the run with its mobile-phase interval
intvl_labelled_times = time_labels[["mins"]].pipe(
    add_mob_phase_changes_to_time_labels, right_ends=breaks, labels=intvl_labels
)

# label the timetable rows themselves so slopes can be joined back by interval
st_piv = st_piv.pipe(
    add_mob_phase_changes_to_time_labels, right_ends=breaks, labels=intvl_labels
)

# per-interval slope and starting value of channel b
with_grads = calc_grads(st_piv=st_piv)
display(with_grads)

# channel-b proportion evaluated at every sample time
solvcomp_over_time = add_b_prop_over_gradients(
    intvl_labelled_times=intvl_labelled_times, with_grads=with_grads
)

display(solvcomp_over_time.pipe(plot_prop_b_over_intvls_time))

# attach the solvent composition to the 256 signal by time
signal_with_solvent_props = solvcomp_over_time.pipe(
    join_solvcomp_over_time_img, img=img
)


# signal faceted by mobile-phase interval
signal_with_solvent_props.select("mins", "abs", "intvls").pipe(plot_signal_facet_intvls)


So, as we can see, the sample is fully eluted well before the methanol plateau, returning to baseline at ~35 mins and varying around it thereafter. Smoothing such that the baseline in the 35–38 min region is 0 could be one approach.

TODO:
- [ ] clean up the code above so that it can take any sample and produce the region labels based on the input runid; get the mins and other inputs from the database.
- [ ] write a function to cut the signal at the 38 min mark AND produce a display of the methanol plateau + re-equilibration regions that I can inspect, should I want to, for quality control
- [ ] Smooth signal
- [ ] sharpen signal
- [ ] baseline subtraction
- [ ] decomposition...