<a href="https://colab.research.google.com/github/jigyasu10/jax-stats-ml-handbook/blob/main/Chapter_2_Descriptive_Statistics_with_JAX.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


# Chapter 2: Descriptive Statistics with JAX

In the previous chapter, we introduced JAX and its core functionalities. Now, we'll start applying JAX to a fundamental area of data analysis: **descriptive statistics**. Descriptive statistics are the bedrock of any data science endeavor. They provide us with the tools to summarize and understand the key characteristics of our datasets, allowing us to gain initial insights and identify patterns before diving into more complex analyses.

This chapter will cover several essential descriptive statistics, exploring both their conceptual underpinnings and their efficient implementation using JAX. We'll learn how to calculate these statistics from scratch using JAX's powerful numerical capabilities.

## 2.1 Measures of Central Tendency: Mean, Median, and Mode

Measures of central tendency help us understand the "typical" or "central" value in a dataset. We'll cover the three most common measures: mean, median, and mode.

### 2.1.1 Mean (Average)

**Conceptual Understanding:**

The **mean**, often referred to as the average, is the sum of all values in a dataset divided by the number of values. It's the most widely used measure of central tendency, representing the arithmetic center of the data.

**Mathematical Formula:**

For a dataset \( X = \{x_1, x_2, ..., x_n\} \), the mean \( \mu \) is calculated as:

\[
\mu = \frac{1}{n} \sum_{i=1}^{n} x_i
\]
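The formula is just a sum followed by a division by \(n\). This minimal sketch does that by hand and compares the result with `jnp.mean()`. `mean_from_formula` is a hypothetical helper for illustration, not part of JAX:

```python
import jax.numpy as jnp

def mean_from_formula(data):
    # Direct translation of mu = (1/n) * sum(x_i)
    n = data.shape[0]
    return jnp.sum(data) / n

data = jnp.array([2.0, 4.0, 1.0, 3.0, 5.0])
manual_mean = mean_from_formula(data)
library_mean = jnp.mean(data)
print(manual_mean, library_mean)  # both 3.0
```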

**When to Use:**

The mean is appropriate for **interval or ratio data** that is **symmetrically distributed** and without significant outliers. It's sensitive to extreme values.

**JAX Implementation:**

JAX computes the mean directly with `jax.numpy.mean()`.

In [None]:
import jax.numpy as jnp
import jax

@jax.jit
def calculate_mean_jax(data):
    """Calculates the mean of a JAX array.

    Args:
        data: A JAX array of numerical data.

    Returns:
        The mean of the data as a JAX scalar.
    """
    return jnp.mean(data)

**Code Example:**

In [None]:
data_example = jnp.array([2.0, 4.0, 1.0, 3.0, 5.0])
mean_value = calculate_mean_jax(data_example)
print(f"Data: {data_example}")
print(f"Mean: {mean_value:.2f}")

**Explanation:**

The `jax.numpy.mean()` function computes the mean over all elements of the input array. We also decorate `calculate_mean_jax` with `@jax.jit` to demonstrate just-in-time compilation: the first call traces and compiles the function for the given input shape and dtype, and later calls with matching inputs reuse the compiled version. For an operation this small, the speedup is likely negligible, but the same pattern pays off in larger computations.
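`jnp.mean()` also accepts NumPy's `axis` argument. This is useful for tabular data stored as a 2-D array, with rows as observations and columns as variables. A small sketch follows; `column_means` and the table values are made up for illustration:

```python
import jax
import jax.numpy as jnp

@jax.jit
def column_means(data):
    # axis=0 averages down each column, giving one mean per variable
    return jnp.mean(data, axis=0)

table = jnp.array([[1.0, 10.0],
                   [3.0, 20.0],
                   [5.0, 60.0]])
print(column_means(table))  # values 3.0 and 30.0
```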

### 2.1.2 Median (Middle Value)

**Conceptual Understanding:**

The **median** is the middle value in a dataset when it's ordered from least to greatest. If there's an even number of data points, the median is the average of the two middle values. The median is less sensitive to outliers than the mean.

**Procedure:**

1. Sort the dataset in ascending order.
2. If the number of data points (n) is odd, the median is the value at position \( \frac{n+1}{2} \).
3. If the number of data points (n) is even, the median is the average of the values at positions \( \frac{n}{2} \) and \( \frac{n}{2} + 1 \).
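The three steps translate almost line by line into JAX. This minimal sketch uses 0-based indexing, so the 1-based position \( \frac{n+1}{2} \) becomes index `n // 2`. `median_from_procedure` is a hypothetical name. Array shapes are static in JAX, so the Python `if` on `n` would also work inside `jax.jit`:

```python
import jax.numpy as jnp

def median_from_procedure(data):
    s = jnp.sort(data)  # Step 1: sort ascending
    n = s.shape[0]      # static Python int
    mid = n // 2
    if n % 2 == 1:
        # Step 2: odd n -> 1-based position (n+1)/2, i.e. 0-based index n//2
        return s[mid]
    # Step 3: even n -> mean of 1-based positions n/2 and n/2 + 1 (0-based mid-1 and mid)
    return (s[mid - 1] + s[mid]) / 2.0

odd = jnp.array([2.0, 4.0, 1.0, 3.0, 5.0])
even = jnp.array([2.0, 4.0, 1.0, 3.0, 5.0, 100.0])
print(median_from_procedure(odd), median_from_procedure(even))  # 3.0 3.5
```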

**When to Use:**

The median is robust to outliers and is suitable for **ordinal, interval, or ratio data**. It's especially useful when dealing with skewed distributions or datasets containing extreme values.

**JAX Implementation:**

JAX provides `jax.numpy.median()` to calculate the median.

In [None]:
@jax.jit
def calculate_median_jax(data):
    """Calculates the median of a JAX array.

    Args:
        data: A JAX array of numerical data.

    Returns:
        The median of the data as a JAX scalar.
    """
    return jnp.median(data)

**Code Example:**

In [None]:
data_example_odd = jnp.array([2.0, 4.0, 1.0, 3.0, 5.0])
median_value_odd = calculate_median_jax(data_example_odd)
print(f"Data (odd): {data_example_odd}")
print(f"Median (odd): {median_value_odd:.2f}")

data_example_even = jnp.array([2.0, 4.0, 1.0, 3.0, 5.0, 100.0]) # Adding an outlier and making it even
median_value_even = calculate_median_jax(data_example_even)
mean_value_even = calculate_mean_jax(data_example_even) # For comparison
print(f"Data (even, with outlier): {data_example_even}")
print(f"Median (even): {median_value_even:.2f}")
print(f"Mean (even): {mean_value_even:.2f} - Notice how the mean is affected by the outlier")

**Explanation:**

`jax.numpy.median()` handles both odd- and even-sized datasets. In the example with an outlier (100.0), the median moves only from 3.00 to 3.50, while the mean jumps from 3.00 to about 19.17. This contrast demonstrates the median's robustness.

### 2.1.3 Mode (Most Frequent Value)

**Conceptual Understanding:**

The **mode** is the value that appears most frequently in a dataset. A dataset can have no mode (if all values are unique), one mode (unimodal), two modes (bimodal), or more than two modes (multimodal).

**Procedure:**

1. Count the frequency of each unique value in the dataset.
2. The value(s) with the highest frequency are the mode(s).
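Suppose the data consist of non-negative integers; this is an assumption, because `jnp.bincount` requires it. The two steps then map directly onto `jnp.bincount` and `jnp.flatnonzero`. This sketch runs eagerly, outside `jax.jit`, because the number of modes depends on the data:

```python
import jax.numpy as jnp

data = jnp.array([1, 2, 2, 3, 3, 4, 5])

# Step 1: counts[v] is how often the non-negative integer v occurs
counts = jnp.bincount(data)

# Step 2: every value whose count equals the maximum is a mode
modes = jnp.flatnonzero(counts == counts.max())
print(modes)  # [2 3]
```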

**When to Use:**

The mode is applicable to **nominal, ordinal, interval, and ratio data**. It's particularly useful for categorical data or when you want to identify the most common category or value.

**JAX Implementation (from scratch):**

`jax.numpy` does not provide a `mode()` function, so we implement one from basic JAX operations.

In [None]:
# No @jax.jit here: jnp.unique (without a fixed `size`) and boolean-mask indexing
# return arrays whose length depends on the data, which jit cannot trace.
def calculate_mode_jax(data):
    """Calculates the mode(s) of a JAX array.

    Args:
        data: A JAX array of numerical data (categories must be integer-encoded).

    Returns:
        A JAX array containing every value tied for the highest frequency.
        If all values are unique, every value is returned.
    """
    unique_values, counts = jnp.unique(data, return_counts=True)
    max_count_index = jnp.argmax(counts)  # Index of the first value with the highest count
    max_count = counts[max_count_index]
    modes = unique_values[counts == max_count]  # Keep all values tied at max_count (handles multiple modes)
    return modes

**Code Example:**

In [None]:
data_mode_example = jnp.array([1, 2, 2, 3, 3, 3, 4, 4])
modes_value = calculate_mode_jax(data_mode_example)
print(f"Data: {data_mode_example}")
print(f"Mode(s): {modes_value}") # Expected mode: 3

data_mode_example_multi = jnp.array([1, 2, 2, 3, 3, 4, 5]) # Bimodal (2 and 3)
modes_value_multi = calculate_mode_jax(data_mode_example_multi)
print(f"Data (multimodal example): {data_mode_example_multi}")
print(f"Mode(s) (multimodal): {modes_value_multi}") # Expected modes: [2, 3]

**Explanation:**

* `jnp.unique(data, return_counts=True)`: Finds the unique values in the `data` array, in sorted order, and returns them along with how often each occurs.
* `jnp.argmax(counts)`: Finds the index of the highest count. On its own, this identifies only one mode (the smallest of any tied values), which suffices only when the mode is unique.
* `counts[max_count_index]` and `unique_values[counts == max_count]`: We extract the maximum count and then select every unique value with that count, so datasets with multiple modes are handled correctly.
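Code built on `jnp.unique` without its `size` argument, or on boolean-mask indexing, produces arrays with data-dependent shapes, and `jax.jit` cannot trace these. One fixed-shape alternative gives up reporting multiple modes and returns a single one. This sketch makes that assumption. `first_mode_jit` is a hypothetical name, and the pairwise comparison uses O(n²) memory, so it only suits modest array sizes:

```python
import jax
import jax.numpy as jnp

@jax.jit
def first_mode_jit(data):
    # Entry (i, j) is True when data[i] == data[j]
    matches = data[:, None] == data[None, :]
    counts = jnp.sum(matches, axis=1)  # frequency of each element's value
    # argmax picks the first element with the highest frequency,
    # so ties resolve to the mode that appears first in the data
    return data[jnp.argmax(counts)]

print(first_mode_jit(jnp.array([1, 2, 2, 3, 3, 3, 4, 4])))  # 3
print(first_mode_jit(jnp.array([1, 2, 2, 3, 3, 4, 5])))     # 2 (tied with 3)
```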

## 2.2 Measures of Dispersion: Variance and Standard Deviation

Measures of dispersion, also known as measures of variability, describe how spread out or scattered the data points are in a dataset. We'll focus on variance and standard deviation.

### 2.2.1 Variance

**Conceptual Understanding:**

**Variance** measures the average squared deviation of each data point from the mean. It quantifies the overall spread of the data around the mean. A higher variance indicates greater dispersion.

**Mathematical Formula:**

For a population dataset \( X = \{x_1, x_2, ..., x_N\} \) with population mean \( \mu \), the population variance \( \sigma^2 \) is:

\[
\sigma^2 = \frac{1}{N} \sum_{i=1}^{N} (x_i - \mu)^2
\]

For a sample dataset \( x = \{x_1, x_2, ..., x_n\} \) with sample mean \( \bar{x} \), the sample variance \( s^2 \) is usually computed with Bessel's correction. This uses \(n-1\) in the denominator to obtain an unbiased estimate of the population variance:

\[
s^2 = \frac{1}{n-1} \sum_{i=1}^{n} (x_i - \bar{x})^2
\]
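The two formulas differ only in their denominators. This minimal sketch computes both by hand for the five values used in the example below and checks them against `jnp.var()`:

```python
import jax.numpy as jnp

x = jnp.array([2.0, 4.0, 1.0, 3.0, 5.0])
n = x.shape[0]
squared_dev = (x - jnp.mean(x)) ** 2         # (x_i - mean)^2 for every point

pop_var = jnp.sum(squared_dev) / n           # sigma^2: divide by n
sample_var = jnp.sum(squared_dev) / (n - 1)  # s^2: Bessel's correction, divide by n - 1
print(pop_var, sample_var)  # 2.0 2.5
```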

**When to Use:**

Variance is used with **interval or ratio data**. It provides a measure of data spread, but its units are squared units of the original data, which can be less intuitive to interpret directly.

**JAX Implementation:**

JAX provides `jax.numpy.var()` for calculating variance. By default, `jnp.var()` calculates the *population variance* (dividing by N). To get the *sample variance* (dividing by n-1), we need to set the `ddof` (delta degrees of freedom) argument to 1.

In [None]:
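from functools import partial

# is_sample is a Python bool used in an `if`, so jit must treat it as static (compile-time) rather than tracing it
@partial(jax.jit, static_argnames=["is_sample"])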
def calculate_variance_jax(data, is_sample=False):
    """Calculates the variance of a JAX array.

    Args:
        data: A JAX array of numerical data.
        is_sample: Boolean, if True, calculates sample variance (divides by n-1), else population variance (divides by n).

    Returns:
        The variance of the data as a JAX scalar.
    """
    ddof = 1 if is_sample else 0 # Delta Degrees of Freedom: 1 for sample variance
    return jnp.var(data, ddof=ddof)

**Code Example:**

In [None]:
data_variance_example = jnp.array([2.0, 4.0, 1.0, 3.0, 5.0])
population_variance = calculate_variance_jax(data_variance_example, is_sample=False)
sample_variance = calculate_variance_jax(data_variance_example, is_sample=True)

print(f"Data: {data_variance_example}")
print(f"Population Variance: {population_variance:.2f}")
print(f"Sample Variance: {sample_variance:.2f}") # Sample variance is slightly larger

**Explanation:**

We use `jax.numpy.var()` and control whether to calculate population or sample variance using the `ddof` argument. Setting `ddof=1` implements Bessel's correction for sample variance.
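Inside a function compiled with `@jax.jit`, every argument is traced by default. A Python `if` on a flag like `is_sample` would therefore fail with a tracer-conversion error. Declaring the flag static with `static_argnames` makes JAX treat it as a compile-time constant, compiling one version per distinct value. A minimal sketch follows; `variance_static` is a hypothetical name:

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames=["is_sample"])
def variance_static(data, is_sample=False):
    # is_sample arrives as a plain Python bool, so this `if` runs at trace time
    ddof = 1 if is_sample else 0
    return jnp.var(data, ddof=ddof)

x = jnp.array([2.0, 4.0, 1.0, 3.0, 5.0])
print(variance_static(x), variance_static(x, is_sample=True))  # 2.0 2.5
```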

### 2.2.2 Standard Deviation

**Conceptual Understanding:**

**Standard deviation** is the square root of the variance. It also measures dispersion, but it is expressed in the same units as the original data, which often makes it easier to interpret than variance. Informally, it describes how far a typical data point lies from the mean.

**Mathematical Formula:**

The standard deviation \( \sigma \) is simply the square root of the variance \( \sigma^2 \):

\[
\sigma = \sqrt{\sigma^2} = \sqrt{\frac{1}{N} \sum_{i=1}^{N} (x_i - \mu)^2}  \text{ (Population)}
\]
\[
s = \sqrt{s^2} = \sqrt{\frac{1}{n-1} \sum_{i=1}^{n} (x_i - \bar{x})^2}  \text{ (Sample)}
\]
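The standard deviation is just the square root of the variance. Taking `jnp.sqrt` of `jnp.var` therefore reproduces `jnp.std` for either denominator, as this quick check shows:

```python
import jax.numpy as jnp

x = jnp.array([2.0, 4.0, 1.0, 3.0, 5.0])
pop_std = jnp.sqrt(jnp.var(x))             # sigma = sqrt(sigma^2) = sqrt(2.0)
sample_std = jnp.sqrt(jnp.var(x, ddof=1))  # s = sqrt(s^2) = sqrt(2.5)
print(f"{pop_std:.3f} {sample_std:.3f}")   # 1.414 1.581
```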

**When to Use:**

Like variance, standard deviation is used with **interval or ratio data**. It is widely used because it shares the data's units, making it a natural scale for describing and visualizing spread.

**JAX Implementation:**

JAX provides `jax.numpy.std()` for calculating standard deviation, which also accepts the `ddof` argument to differentiate between population and sample standard deviation.

In [None]:
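from functools import partial

# is_sample is a Python bool used in an `if`, so jit must treat it as static (compile-time) rather than tracing it
@partial(jax.jit, static_argnames=["is_sample"])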
def calculate_std_dev_jax(data, is_sample=False):
    """Calculates the standard deviation of a JAX array.

    Args:
        data: A JAX array of numerical data.
        is_sample: Boolean, if True, calculates sample standard deviation (divides by n-1 in variance calculation), else population standard deviation (divides by n).

    Returns:
        The standard deviation of the data as a JAX scalar.
    """
    ddof = 1 if is_sample else 0
    return jnp.std(data, ddof=ddof)

**Code Example:**

In [None]:
data_std_dev_example = jnp.array([2.0, 4.0, 1.0, 3.0, 5.0])
population_std_dev = calculate_std_dev_jax(data_std_dev_example, is_sample=False)
sample_std_dev = calculate_std_dev_jax(data_std_dev_example, is_sample=True)

print(f"Data: {data_std_dev_example}")
print(f"Population Standard Deviation: {population_std_dev:.2f}")
print(f"Sample Standard Deviation: {sample_std_dev:.2f}")

**Explanation:**

`jax.numpy.std()` mirrors `jax.numpy.var()` in its usage, making it simple to calculate both population and sample standard deviations by setting `ddof`.

## 2.3 Percentiles and Quartiles

**Conceptual Understanding:**

**Percentiles** divide a sorted dataset into 100 equal parts. The *p*-th percentile is the value below which *p*% of the data points fall.  **Quartiles** are specific percentiles that divide the data into four quarters:

* **Q1 (First Quartile):** 25th percentile.
* **Q2 (Second Quartile):** 50th percentile (which is also the median).
* **Q3 (Third Quartile):** 75th percentile.
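Several conventions exist for percentiles that fall between data points. `jnp.percentile()`, like NumPy, defaults to linear interpolation. For sorted data of length \(n\), the *p*-th percentile sits at the fractional 0-based position \((n-1)\,p/100\) and is interpolated between its two neighbours. Here is a sketch of that rule, checked against `jnp.percentile()`; `percentile_linear` is a hypothetical helper:

```python
import jax.numpy as jnp

def percentile_linear(data, p):
    s = jnp.sort(data).astype(jnp.float32)
    pos = (s.shape[0] - 1) * p / 100.0     # fractional 0-based position
    lo = jnp.floor(pos).astype(jnp.int32)  # index of the neighbour below
    hi = jnp.ceil(pos).astype(jnp.int32)   # index of the neighbour above
    frac = pos - lo
    return s[lo] + frac * (s[hi] - s[lo])  # linear interpolation between neighbours

data = jnp.array([10, 20, 30, 40, 50, 60, 70, 80, 90, 100])
print(percentile_linear(data, 90), jnp.percentile(data, 90))  # both approximately 91
```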

**When to Use:**

Percentiles and quartiles are useful for understanding the distribution of data, particularly for identifying the spread and skewness. They are applicable to **ordinal, interval, and ratio data** and are robust to outliers.

**JAX Implementation:**

JAX provides `jax.numpy.percentile()` to calculate percentiles.  To get quartiles, we can calculate percentiles at 25, 50, and 75.

In [None]:
@jax.jit
def calculate_percentile_jax(data, percentile_value):
    """Calculates a specific percentile of a JAX array.

    Args:
        data: A JAX array of numerical data.
        percentile_value: The percentile to calculate (e.g., 25 for 25th percentile).

    Returns:
        The percentile value as a JAX scalar.
    """
    return jnp.percentile(data, percentile_value)

@jax.jit
def calculate_quartiles_jax(data):
    """Calculates the quartiles (Q1, Q2, Q3) of a JAX array.

    Args:
        data: A JAX array of numerical data.

    Returns:
        A JAX array containing [Q1, Q2, Q3].
    """
    q1 = calculate_percentile_jax(data, 25)
    q2 = calculate_percentile_jax(data, 50) # Median
    q3 = calculate_percentile_jax(data, 75)
    return jnp.array([q1, q2, q3])

**Code Example:**

In [None]:
data_percentile_example = jnp.array([10, 20, 30, 40, 50, 60, 70, 80, 90, 100])
p90 = calculate_percentile_jax(data_percentile_example, 90) # 90th percentile
quartiles = calculate_quartiles_jax(data_percentile_example)

print(f"Data: {data_percentile_example}")
print(f"90th Percentile: {p90:.2f}") # 91.00 with the default linear interpolation
print(f"Quartiles (Q1, Q2, Q3): {quartiles}") # [32.5, 55.0, 77.5] with linear interpolation

**Explanation:**

* `jax.numpy.percentile(data, percentile_value)`:  Calculates the specified percentile.
* `calculate_quartiles_jax`:  A convenience function that uses `calculate_percentile_jax` to compute the 25th, 50th, and 75th percentiles and returns them as a JAX array.
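`jnp.percentile()` also accepts an array of percentiles, so all three quartiles can come from a single call. The difference Q3 - Q1, the interquartile range (IQR), measures the spread of the middle 50% of the data. A short sketch on the same data:

```python
import jax.numpy as jnp

data = jnp.array([10, 20, 30, 40, 50, 60, 70, 80, 90, 100])
q1, q2, q3 = jnp.percentile(data, jnp.array([25.0, 50.0, 75.0]))  # one call, three percentiles
iqr = q3 - q1  # interquartile range
print(q1, q2, q3, iqr)  # 32.5 55.0 77.5 45.0 (linear interpolation)
```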

## 2.4 Histograms and Data Visualization (using JAX for calculations)

**Conceptual Understanding:**

A **histogram** is a graphical representation of the distribution of numerical data. It groups data into bins and displays the frequency or count of data points falling into each bin.  Histograms provide a visual understanding of the data's shape, central tendency, and spread.

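A **box plot** (or box-and-whisker plot) is another graphical representation. It summarizes a distribution using its quartiles (Q1, Q2/median, Q3) together with its extremes. In the classic version, the whiskers reach the minimum and maximum. Many implementations, including Matplotlib's default, instead end the whiskers at the most extreme points within 1.5 × IQR of the box (IQR = Q3 - Q1) and draw points beyond them individually as potential outliers. Box plots are effective for comparing distributions across groups and for spotting outliers.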

**When to Use:**

Histograms and box plots are valuable for visualizing the distribution of **interval or ratio data**. Histograms show the detailed shape of the distribution, while box plots provide a more concise summary, particularly for comparing distributions or detecting outliers.
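Box plots commonly flag outliers with Tukey's rule: points below Q1 - 1.5 × IQR or above Q3 + 1.5 × IQR. The 1.5 multiplier matches Matplotlib's default `whis=1.5`. This sketch returns a boolean mask with the same shape as the input, which keeps it compatible with `jax.jit`. `tukey_outlier_mask` and the sample data are illustrative:

```python
import jax
import jax.numpy as jnp

@jax.jit
def tukey_outlier_mask(data, k=1.5):
    q1, q3 = jnp.percentile(data, jnp.array([25.0, 75.0]))
    iqr = q3 - q1
    lower, upper = q1 - k * iqr, q3 + k * iqr  # Tukey fences
    # Fixed-shape boolean mask, so no data-dependent output size under jit
    return (data < lower) | (data > upper)

data = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 100.0])
print(tukey_outlier_mask(data))  # only the last entry (100.0) is flagged
```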

**JAX Implementation (for Calculation):**

JAX itself primarily focuses on numerical computation and doesn't have built-in plotting functionalities. However, we can use JAX to efficiently *calculate* the data needed to create histograms and box plots, and then use libraries like `matplotlib` or `seaborn` (after potentially converting JAX arrays to NumPy arrays) for visualization.

In [None]:
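from functools import partial

# bins fixes the output shape and data_range is a Python tuple (or None), so both must be static under jit
@partial(jax.jit, static_argnames=["bins", "data_range"])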
def calculate_histogram_data_jax(data, bins=10, data_range=None):
    """Calculates data for a histogram from a JAX array.

    Args:
        data: A JAX array of numerical data.
        bins: Number of bins for the histogram.
        data_range: Optional tuple (min, max) to specify the range of the histogram.

    Returns:
        A tuple (bin_counts, bin_edges) for creating a histogram.
    """
    bin_counts, bin_edges = jnp.histogram(data, bins=bins, range=data_range)
    return bin_counts, bin_edges

**Code Example (Calculation and Visualization with Matplotlib):**

In [None]:
import matplotlib.pyplot as plt  # Matplotlib for plotting
import numpy as np  # NumPy arrays for the hand-off to Matplotlib

# Example data: 1000 draws from a normal distribution with mean 5 and standard deviation 2
data_hist_example = jax.random.normal(jax.random.PRNGKey(1), (1000,)) * 2 + 5

bin_counts, bin_edges = calculate_histogram_data_jax(data_hist_example, bins=20)

# Convert JAX arrays to NumPy arrays for Matplotlib
bin_counts_np = np.asarray(bin_counts)
bin_edges_np = np.asarray(bin_edges)

plt.figure(figsize=(8, 5))
# Draw the precomputed histogram: one point per bin (its left edge), weighted by that bin's count
plt.hist(bin_edges_np[:-1], bins=bin_edges_np, weights=bin_counts_np)
plt.title("Histogram of Example Data (Calculated with JAX)")
plt.xlabel("Data Values")
plt.ylabel("Frequency")
plt.show()


@jax.jit
def calculate_boxplot_data_jax(data):
    """Calculates data for a box plot (quartiles, min, max) from a JAX array.

    Args:
        data: A JAX array of numerical data.

    Returns:
        A dictionary containing 'Q1', 'Q2', 'Q3', 'min', 'max'.
    """
    quartiles = calculate_quartiles_jax(data)
    min_val = jnp.min(data)
    max_val = jnp.max(data)
    return {'Q1': quartiles[0], 'Q2': quartiles[1], 'Q3': quartiles[2], 'min': min_val, 'max': max_val}


boxplot_data = calculate_boxplot_data_jax(data_hist_example)  # Same data as the histogram
print(f"Box Plot Data (JAX calculated): {boxplot_data}")

# Matplotlib computes its own box-plot statistics from the raw data;
# its box edges and median should agree with the JAX-calculated Q1, Q2, and Q3 above.
plt.figure(figsize=(6, 4))
plt.boxplot([np.asarray(data_hist_example)], positions=[1], widths=0.6)  # A list with one dataset -> one box at x=1
plt.xticks([1], ['Data'])  # Label for the box
plt.title("Box Plot of Example Data")
plt.ylabel("Data Values")
plt.show()

**Explanation:**

* `calculate_histogram_data_jax`:  Uses `jax.numpy.histogram()` to calculate bin counts and edges, the core data needed for a histogram.
* `calculate_boxplot_data_jax`: Calculates quartiles, minimum, and maximum, which are the key summary statistics represented in a box plot.
* **Visualization with Matplotlib:** The code uses `matplotlib.pyplot.hist` and `matplotlib.pyplot.boxplot` to draw the plots. Before handing data to Matplotlib, which is built around NumPy arrays, we convert JAX arrays to NumPy arrays with `np.asarray()`. Note that `jnp.array()` would return another JAX array rather than a NumPy array.
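To see what `jnp.histogram()` computes, here is a hedged sketch of the counting step with equal-width bins. It uses `jnp.searchsorted` to locate each value's bin and a scatter-add to tally the counts. `histogram_counts` is a hypothetical helper, and it assumes every value lies inside `[lo, hi]`. Out-of-range values would be clipped into the first or last bin, whereas `numpy.histogram` ignores them:

```python
import jax.numpy as jnp

def histogram_counts(data, bins, lo, hi):
    edges = jnp.linspace(lo, hi, bins + 1)
    # Bin k covers [edges[k], edges[k+1]); searchsorted finds the first edge to the right of each value
    idx = jnp.searchsorted(edges, data, side="right") - 1
    # The right-most edge belongs to the last bin (closed interval), as in NumPy
    idx = jnp.clip(idx, 0, bins - 1)
    counts = jnp.zeros(bins, dtype=jnp.int32).at[idx].add(1)  # repeated indices accumulate
    return counts, edges

data = jnp.array([0.0, 0.5, 1.2, 2.7, 3.1, 4.0])
counts, edges = histogram_counts(data, bins=4, lo=0.0, hi=4.0)
ref_counts, ref_edges = jnp.histogram(data, bins=4, range=(0.0, 4.0))
print(counts, ref_counts)  # counts 2, 1, 1, 2 from both
```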

## 2.5 Correlation: Pearson and Spearman

**Conceptual Understanding:**

**Correlation** measures the statistical relationship between two variables. It describes the strength and direction of a linear (for Pearson) or monotonic (for Spearman) relationship.

* **Pearson Correlation Coefficient (r):** Measures the strength and direction of a *linear* relationship between two continuous variables. It ranges from -1 to +1:
    * +1: Perfect positive linear correlation.
    * -1: Perfect negative linear correlation.
    * 0: No linear correlation.

* **Spearman Rank Correlation Coefficient (ρ or ρ<sub>s</sub>):** Measures the strength and direction of a *monotonic* (not necessarily linear) relationship between two ordinal or continuous variables. It is computed from the ranks of the data values and also ranges from -1 to +1, with analogous interpretations.

**When to Use:**

* **Pearson Correlation:** Use for **interval or ratio data** when you want to assess the *linear* relationship between two variables. Sensitive to outliers and assumes linearity.
* **Spearman Correlation:** Use for **ordinal, interval, or ratio data** when you want to assess a *monotonic* relationship (variables tend to increase or decrease together, but not necessarily linearly). More robust to outliers than Pearson.

**Mathematical Formula (Pearson):**

For two variables \( X = \{x_1, x_2, ..., x_n\} \) and \( Y = \{y_1, y_2, ..., y_n\} \), the Pearson correlation coefficient \( r \) is:

\[
r = \frac{\sum_{i=1}^{n} (x_i - \bar{x})(y_i - \bar{y})}{\sqrt{\sum_{i=1}^{n} (x_i - \bar{x})^2} \sqrt{\sum_{i=1}^{n} (y_i - \bar{y})^2}}
\]
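Dividing the numerator and denominator by \(n\) shows that \(r\) is the covariance of \(X\) and \(Y\) divided by the product of their (population) standard deviations. This sketch checks the formula, that covariance form, and `jnp.corrcoef()` against each other on made-up data where \(r = 0.8\) by hand:

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
y = jnp.array([2.0, 1.0, 4.0, 3.0, 5.0])

dx, dy = x - x.mean(), y - y.mean()
r_formula = jnp.sum(dx * dy) / (jnp.sqrt(jnp.sum(dx**2)) * jnp.sqrt(jnp.sum(dy**2)))

# Covariance / (std_x * std_y); the 1/n factors cancel
r_cov = jnp.mean(dx * dy) / (jnp.std(x) * jnp.std(y))

# jnp.corrcoef returns the 2x2 correlation matrix; the off-diagonal entry is r
r_lib = jnp.corrcoef(x, y)[0, 1]
print(r_formula, r_cov, r_lib)  # all approximately 0.8
```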

**JAX Implementation (Pearson):**

We can implement Pearson correlation using JAX operations.

In [None]:
@jax.jit
def calculate_pearson_correlation_jax(x, y):
    """Calculates the Pearson correlation coefficient between two JAX arrays.

    Args:
        x: A JAX array representing the first variable.
        y: A JAX array representing the second variable (same length as x).

    Returns:
        The Pearson correlation coefficient as a JAX scalar.
    """
    x_mean = jnp.mean(x)
    y_mean = jnp.mean(y)
    numerator = jnp.sum((x - x_mean) * (y - y_mean))
    denominator = jnp.sqrt(jnp.sum((x - x_mean)**2) * jnp.sum((y - y_mean)**2))
    return numerator / denominator

**JAX Implementation (Spearman):**

For Spearman correlation, we first replace each value by its rank. Core `jax.numpy` has no dedicated ranking function, but ranks can be built from sorting primitives: applying `jnp.argsort()` twice maps each value to its 0-based position in sorted order. This costs O(n log n), the same as sorting, so performance is not the main limitation. The catch is ties. The double-argsort trick gives equal values distinct consecutive ranks, breaking ties by position, whereas standard Spearman correlation gives tied values the average of their ranks. The simplified implementation below is therefore exact only when neither variable contains ties.
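A tiny worked example of the double-argsort trick makes both its mechanics and its tie behaviour visible. This relies on `jnp.argsort` being stable by default, so equal values keep their original order:

```python
import jax.numpy as jnp

x = jnp.array([30.0, 10.0, 20.0])
order = jnp.argsort(x)           # positions of the smallest, middle, largest values: [1, 2, 0]
ranks = jnp.argsort(order) + 1   # inverting that permutation gives each value's 1-based rank
print(ranks)  # [3 1 2]

tied = jnp.array([10.0, 20.0, 20.0])
tied_ranks = jnp.argsort(jnp.argsort(tied)) + 1
print(tied_ranks)  # [1 2 3]: the tie is broken by position, not averaged to [1, 2.5, 2.5]
```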

In [None]:
@jax.jit
def calculate_spearman_correlation_jax(x, y):
    """Calculates the Spearman rank correlation coefficient between two JAX arrays.

    Simplified implementation: ranks come from a double argsort, so tied values get
    distinct consecutive ranks (ordered by position) instead of averaged ranks. The
    result matches the standard Spearman coefficient only when x and y have no ties.

    Args:
        x: A JAX array representing the first variable.
        y: A JAX array representing the second variable (same length as x).

    Returns:
        The Spearman rank correlation coefficient as a JAX scalar.
    """
    x_ranked_indices = jnp.argsort(x)  # Indices that would sort x
    y_ranked_indices = jnp.argsort(y)  # Indices that would sort y
    # Inverting the sorting permutation gives each element's 0-based rank; +1 makes ranks 1-based
    x_ranks = jnp.argsort(x_ranked_indices) + 1
    y_ranks = jnp.argsort(y_ranked_indices) + 1
    return calculate_pearson_correlation_jax(x_ranks, y_ranks)  # Spearman's rho is Pearson on ranks

**Code Example (Correlation):**

In [None]:
x_data_corr = jnp.array([1, 2, 3, 4, 5])
y_data_corr_pos = jnp.array([2, 4, 6, 8, 10]) # Positively correlated
y_data_corr_neg = jnp.array([10, 8, 6, 4, 2]) # Negatively correlated
y_data_corr_no = jax.random.uniform(jax.random.PRNGKey(2), (5,)) * 10 # Independent random values (no underlying relationship)

pearson_pos = calculate_pearson_correlation_jax(x_data_corr, y_data_corr_pos)
pearson_neg = calculate_pearson_correlation_jax(x_data_corr, y_data_corr_neg)
pearson_no = calculate_pearson_correlation_jax(x_data_corr, y_data_corr_no)

spearman_pos = calculate_spearman_correlation_jax(x_data_corr, y_data_corr_pos)
spearman_neg = calculate_spearman_correlation_jax(x_data_corr, y_data_corr_neg)
spearman_no = calculate_spearman_correlation_jax(x_data_corr, y_data_corr_no)

print(f"X Data: {x_data_corr}")
print(f"Y Data (Positive Corr): {y_data_corr_pos}")
print(f"Pearson Correlation (Positive): {pearson_pos:.2f}")
print(f"Spearman Correlation (Positive): {spearman_pos:.2f}") # Both equal 1.00: the relationship is perfectly linear and increasing

print(f"\nY Data (Negative Corr): {y_data_corr_neg}")
print(f"Pearson Correlation (Negative): {pearson_neg:.2f}")
print(f"Spearman Correlation (Negative): {spearman_neg:.2f}")

print(f"\nY Data (No Corr): {y_data_corr_no}")
print(f"Pearson Correlation (No): {pearson_no:.2f}")
print(f"Spearman Correlation (No): {spearman_no:.2f}") # With only 5 random points, neither coefficient need be close to 0

**Explanation:**

* **Pearson Correlation Implementation:** Follows the mathematical formula directly using `jax.numpy` operations.
* **Spearman Correlation Implementation (Simplified):**
    * `jnp.argsort()` is used to get the indices that would sort the arrays.
    * `jnp.argsort(jnp.argsort(...))` converts values to ranks. The inner call returns the order that sorts the array, and the outer call inverts that permutation, giving each element its 0-based position in sorted order (we add 1 for 1-based ranks). Like sorting, this takes O(n log n) time. We compute ranks for both `x` and `y`.
    * Spearman's rho is then simply the Pearson correlation calculated on the ranks.
    * **Note on ties:** The double-argsort trick breaks ties by position instead of assigning tied values their average rank. The result can therefore differ from the standard Spearman coefficient when `x` or `y` contains repeated values. For data without ties, it matches the standard definition; data with ties needs an average-rank scheme.
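Standard Spearman correlation gives tied values the average of the ranks they would occupy. Here is a minimal jit-compatible sketch of that, which assumes small-to-moderate n because it compares all pairs (O(n²) memory). `average_ranks` and `spearman_with_ties` are hypothetical helper names, not JAX APIs:

```python
import jax
import jax.numpy as jnp

def average_ranks(a):
    # Rank = (# strictly smaller values) + (size of the tie group + 1) / 2,
    # i.e. the mean of the 1-based positions the tie group would occupy
    less = jnp.sum(a[None, :] < a[:, None], axis=1)
    equal = jnp.sum(a[None, :] == a[:, None], axis=1)
    return less + (equal + 1) / 2.0

@jax.jit
def spearman_with_ties(x, y):
    rx, ry = average_ranks(x), average_ranks(y)
    dx, dy = rx - rx.mean(), ry - ry.mean()
    # Spearman's rho = Pearson correlation of the (average) ranks
    return jnp.sum(dx * dy) / jnp.sqrt(jnp.sum(dx**2) * jnp.sum(dy**2))

x = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
y = jnp.array([1.0, 2.0, 2.0, 3.0, 4.0])
print(average_ranks(y))          # ranks 1, 2.5, 2.5, 4, 5
print(spearman_with_ties(x, y))  # about 0.975 (= sqrt(0.95))
```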

## 2.6 Chapter Summary

In this chapter, we've explored essential descriptive statistics and their implementation in JAX. We covered:

* **Measures of Central Tendency:** Mean, Median, and Mode, understanding their properties and appropriate use cases.
* **Measures of Dispersion:** Variance and Standard Deviation, quantifying data spread.
* **Percentiles and Quartiles:**  Dividing data into portions to understand distribution.
* **Histograms and Box Plots:**  Using JAX to calculate data for visualizing data distributions.
* **Correlation:** Pearson and Spearman correlation coefficients for measuring relationships between variables.

For each statistic, we provided both conceptual explanations and practical JAX code examples. You now have a solid foundation for using JAX to calculate descriptive statistics, which are crucial for any data analysis workflow.

In the next chapter, we'll move on to another fundamental area: hypothesis testing, and continue to build our JAX-based data science toolkit.

