In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import polars as pl
import seaborn as sns
from setup.constants import PROJECT_ROOT

In [None]:
# Force these columns to Float64 instead of whatever type inference would choose.
wine_schema = dict.fromkeys(
    ["free sulfur dioxide", "total sulfur dioxide", "quality"],
    pl.Float64,
)

# Both files are parsed with ";" as the field separator.
red_wine_df = pl.read_csv(
    PROJECT_ROOT / "data" / "winequality-red.csv",
    separator=";",
    schema_overrides=wine_schema,
)
print(red_wine_df.head())

white_wine_df = pl.read_csv(
    PROJECT_ROOT / "data" / "winequality-white.csv",
    separator=";",
    schema_overrides=wine_schema,
)
print(white_wine_df.head())

In [None]:
# One boolean per column: True where the column's null count is zero.
print(red_wine_df.null_count().select(pl.all() == 0))
print(white_wine_df.null_count().select(pl.all() == 0))

In [None]:
# Summary statistics for the red-wine frame (rendered as the cell output).
red_wine_df.describe()

In [None]:
# Summary statistics for the white-wine frame (rendered as the cell output).
white_wine_df.describe()

In [None]:
# Number of red-wine rows at each quality score, highest score first.
print(red_wine_df.group_by("quality").len().sort("quality", descending=True))

In [None]:
# Bucket the quality score: >= 7 -> "good", >= 5 -> "average", anything else -> "bad".
red_wine_df_qc = red_wine_df.with_columns(
    **{
        "quality category": pl.when(pl.col("quality").ge(7))
        .then(pl.lit("good"))
        .when(pl.col("quality").ge(5))
        .then(pl.lit("average"))
        .otherwise(pl.lit("bad"))
    }
)

# Size of each category, largest first.
print(
    red_wine_df_qc.group_by("quality category")
    .agg(pl.len().alias("# of samples"))
    .sort("# of samples", descending=True)
)

In [None]:
# Number of white-wine rows at each quality score, highest score first.
print(white_wine_df.group_by("quality").len().sort("quality", descending=True))

In [None]:
# Bucket the quality score: >= 7 -> "good", >= 5 -> "average", anything else -> "bad".
white_wine_df_qc = white_wine_df.with_columns(
    **{
        "quality category": pl.when(pl.col("quality").ge(7))
        .then(pl.lit("good"))
        .when(pl.col("quality").ge(5))
        .then(pl.lit("average"))
        .otherwise(pl.lit("bad"))
    }
)

# Size of each category, largest first.
print(
    white_wine_df_qc.group_by("quality category")
    .agg(pl.len().alias("# of samples"))
    .sort("# of samples", descending=True)
)

In [None]:
# Make sure an output directory exists for each wine colour before any figure is saved.
# NOTE(review): relies on PROJECT_ROOT being a pathlib.Path, as its use with "/" suggests.
for color in ["white", "red"]:
    # parents=True also creates "figures" when missing; exist_ok=True makes re-runs a no-op.
    (PROJECT_ROOT / "figures" / color).mkdir(parents=True, exist_ok=True)

In [None]:
# One mean expression per column; each result column is named "<column>_mean".
columns = [name for name in red_wine_df.columns if name != "target"]
agg_exprs = [pl.col(name).mean().name.suffix("_mean") for name in columns]

# Per-category column means for red wine, ordered by mean quality (highest first).
red_wine_means_df = (
    red_wine_df_qc.group_by("quality category")
    .agg(agg_exprs)
    .sort("quality_mean", descending=True)
)
red_wine_means_df

In [None]:
# Per-category column means for white wine, ordered by mean quality (highest first).
# `agg_exprs` must already be defined by an earlier cell.
white_wine_means_df = (
    white_wine_df_qc.group_by(pl.col("quality category"))
    .agg(agg_exprs)
    .sort("quality_mean", descending=True)
)
white_wine_means_df

In [None]:
# Pairwise correlations between all red-wine columns, drawn as an annotated heatmap.
red_wine_corr = red_wine_df.corr()
plt.figure(figsize=(14, 12))
heatmap = sns.heatmap(
    red_wine_corr,
    vmin=-1,
    cmap="RdBu_r",
    annot=True,
    linewidths=0,
    xticklabels=red_wine_df.columns,
    yticklabels=red_wine_df.columns,
)
# sns.heatmap returns the Axes it drew on; title and save go through that handle.
heatmap.set_title("Red Wine Correlation Matrix")
heatmap.figure.savefig(PROJECT_ROOT / "figures" / "red" / "correlation_matrix")

In [None]:
def make_scatter_hist(df: pl.DataFrame, x: str, y: str, filename: Path, title: str) -> None:
    """Plot ``y`` against ``x`` with a regression fit and marginal histograms, then save it.

    Args:
        df: Frame containing at least the ``x`` and ``y`` columns.
        x: Name of the column drawn on the horizontal axis.
        y: Name of the column drawn on the vertical axis.
        filename: Destination handed to ``JointGrid.savefig``.
        title: Title drawn above the whole figure.
    """
    # JointGrid creates its own figure, so no separate plt.figure() call is made here;
    # one would only leave a blank, unused figure behind on every call.
    grid = (
        sns.JointGrid(x=x, y=y, data=df)
        .plot_joint(sns.regplot, scatter_kws={"s": 10})
        .plot_marginals(sns.histplot)
    )
    # plt.title() would label whichever axes is current (a marginal axes once the grid
    # is built). A figure-level title sits above the whole grid instead; it is raised
    # slightly so it does not overlap the top marginal histogram.
    grid.figure.suptitle(title, y=1.02)
    # JointGrid.savefig uses a tight bounding box by default, so the raised title is kept.
    grid.savefig(filename)

In [None]:
# Red wine: pH against fixed acidity.
make_scatter_hist(
    red_wine_df.select(["fixed acidity", "pH"]),
    x="fixed acidity",
    y="pH",
    title="Red Wine - pH vs. Fixed Acidity",
    filename=PROJECT_ROOT / "figures" / "red" / "fixedAcidity_pH",
)

In [None]:
# Red wine: citric acid against fixed acidity.
make_scatter_hist(
    red_wine_df.select(["fixed acidity", "citric acid"]),
    x="fixed acidity",
    y="citric acid",
    title="Red Wine - Citric Acid vs. Fixed Acidity",
    filename=PROJECT_ROOT / "figures" / "red" / "fixedAcidity_citricAcid",
)

In [None]:
# Red wine: density against fixed acidity.
make_scatter_hist(
    red_wine_df.select(["fixed acidity", "density"]),
    x="fixed acidity",
    y="density",
    title="Red Wine - Density vs. Fixed Acidity",
    filename=PROJECT_ROOT / "figures" / "red" / "fixedAcidity_density",
)

In [None]:
# Pairwise correlations between all white-wine columns, drawn as an annotated heatmap.
white_wine_corr = white_wine_df.corr()
plt.figure(figsize=(14, 12))
heatmap = sns.heatmap(
    white_wine_corr,
    vmin=-1,
    cmap="RdBu_r",
    annot=True,
    linewidths=0,
    xticklabels=white_wine_df.columns,
    yticklabels=white_wine_df.columns,
)
# sns.heatmap returns the Axes it drew on; title and save go through that handle.
heatmap.set_title("White Wine Correlation Matrix")
heatmap.figure.savefig(PROJECT_ROOT / "figures" / "white" / "correlation_matrix")

In [None]:
# White wine: residual sugar against density.
make_scatter_hist(
    white_wine_df.select(["density", "residual sugar"]),
    x="density",
    y="residual sugar",
    title="White Wine - Residual Sugar vs. Density",
    filename=PROJECT_ROOT / "figures" / "white" / "density_residualSugar",
)

In [None]:
# White wine: alcohol against density.
# The data must come from the white-wine frame so it matches the title and the
# "white" output folder (this call previously plotted the red-wine data).
make_scatter_hist(
    df=white_wine_df[["density", "alcohol"]],
    x="density",
    y="alcohol",
    filename=PROJECT_ROOT / "figures" / "white" / "density_alcohol",
    title="White Wine - Alcohol vs. Density",
)

In [None]:
# Mean volatile acidity per quality score, highest score first, for each wine colour.
red_va_by_quality = (
    red_wine_df.group_by("quality")
    .agg(pl.col("volatile acidity").mean().alias("volatile acidity mean"))
    .sort("quality", descending=True)
)
print(red_va_by_quality)

white_va_by_quality = (
    white_wine_df.group_by("quality")
    .agg(pl.col("volatile acidity").mean().alias("volatile acidity mean"))
    .sort("quality", descending=True)
)
print(white_va_by_quality)

In [None]:
# Bar colour for each quality score (3 through 9), keyed by the score as np.float64.
colors = {
    np.float64(score): hex_code
    for score, hex_code in [
        (3.0, "#FF9800"),
        (4.0, "#FFC107"),
        (5.0, "#FFEB3B"),
        (6.0, "#DCE775"),
        (7.0, "#AEEA00"),
        (8.0, "#64DD17"),
        (9.0, "#00C853"),
    ]
}


def make_barplot(
    df: pl.DataFrame,
    x: str,
    y: str,
    title: str,
    filename: Path,
    hue: str = "quality",
    palette: dict = colors,
) -> None:
    """Draw a 12x6 seaborn bar plot of ``y`` per ``x`` and save it to ``filename``.

    Args:
        df: Frame holding the ``x``, ``y`` and ``hue`` columns.
        x: Column placed on the horizontal axis.
        y: Column whose values set the bar heights.
        title: Title of the plot.
        filename: Destination passed to ``savefig``.
        hue: Column used to colour the bars.
        palette: Mapping from ``hue`` values to colours.
    """
    fig = plt.figure(figsize=(12, 6))
    # barplot draws on the current axes of the figure just created and returns them.
    ax = sns.barplot(data=df, x=x, y=y, hue=hue, palette=palette)
    ax.set_title(title)
    fig.savefig(filename)

In [None]:
# Red wine: mean volatile acidity at each quality score.
make_barplot(
    red_va_by_quality,
    x="quality",
    y="volatile acidity mean",
    filename=PROJECT_ROOT / "figures" / "red" / "volatile_acid_mean_quality",
    title="Red Wine - Volatile Acidity Mean by Quality",
)

In [None]:
# White wine: mean volatile acidity at each quality score.
make_barplot(
    white_va_by_quality,
    x="quality",
    y="volatile acidity mean",
    filename=PROJECT_ROOT / "figures" / "white" / "volatile_acid_mean_quality",
    title="White Wine - Volatile Acidity Mean by Quality",
)

In [None]:
# Mean alcohol per quality score, highest score first, for each wine colour.
red_alc_by_quality = (
    red_wine_df.group_by("quality")
    .agg(pl.col("alcohol").mean().alias("alcohol mean"))
    .sort("quality", descending=True)
)
print(red_alc_by_quality)

white_alc_by_quality = (
    white_wine_df.group_by("quality")
    .agg(pl.col("alcohol").mean().alias("alcohol mean"))
    .sort("quality", descending=True)
)
print(white_alc_by_quality)

In [None]:
# Red wine: mean alcohol at each quality score.
make_barplot(
    red_alc_by_quality,
    x="quality",
    y="alcohol mean",
    filename=PROJECT_ROOT / "figures" / "red" / "alcohol_quality",
    title="Red Wine - Alcohol Mean by Quality",
)

In [None]:
# White wine: mean alcohol at each quality score.
make_barplot(
    white_alc_by_quality,
    x="quality",
    y="alcohol mean",
    filename=PROJECT_ROOT / "figures" / "white" / "alcohol_quality",
    title="White Wine - Alcohol Mean by Quality",
)

In [None]:
# Mean sulphates per quality score for red wine, highest score first.
red_sul_by_quality = (
    red_wine_df.group_by("quality")
    .agg(pl.col("sulphates").mean().alias("sulphates mean"))
    .sort("quality", descending=True)
)
print(red_sul_by_quality)

In [None]:
# Red wine: mean sulphates at each quality score.
make_barplot(
    red_sul_by_quality,
    x="quality",
    y="sulphates mean",
    filename=PROJECT_ROOT / "figures" / "red" / "sulphates_quality",
    title="Red Wine - Sulphates Mean by Quality",
)

In [None]:
# Mean chlorides per quality score for white wine, highest score first.
white_chlor_by_quality = (
    white_wine_df.group_by("quality")
    .agg(pl.col("chlorides").mean().alias("chlorides mean"))
    .sort("quality", descending=True)
)
print(white_chlor_by_quality)

In [None]:
# White wine: mean chlorides at each quality score.
make_barplot(
    white_chlor_by_quality,
    x="quality",
    y="chlorides mean",
    filename=PROJECT_ROOT / "figures" / "white" / "chlorides_quality",
    title="White Wine - Chlorides Mean by Quality",
)