in this notebook, we consider a case where we have two Normal distributions with the same variance but two potentially distinct means. we ask whether we can statistically tell whether their means are different from each other. we answer this question in an empirical way. 

first, we create a null hypothesis where two Normal distributions have the same mean, i.e., they are the same distribution. we repeatedly draw pairs of sample sets from this distribution, aggregate the sampled mean differences and draw the histogram of these mean differences.

second, we create a target hypothesis where two Normal distributions have two different means, separated by some value, and share the same variance of $1$ (or an identity covariance in the case of multi-dimensional Normal.) from each of these Normal distributions, we draw `n_samples` samples and estimate the sample mean. we then compute the sampled mean difference between two sets drawn from the two Normal distributions, respectively. we repeat this procedure many times and aggregate these sampled mean differences. we draw the histogram on top of the histogram of the null hypothesis from above.

we will be able to see that the width of the second histogram shrinks rapidly as the number of samples we draw from each Normal distribution grows, implying that we can be increasingly more certain about the sampled mean difference if we have many samples from each of these two Normal distributions. this illustrates the idea of statistical power; in order for us to have confidence in our conclusion drawn from a set of samples, we better have enough of them.

we also notice (perhaps obviously) that the distribution of the sampled mean differences is centered at the true mean difference between the two Normal distributions. according to this, if the true means are close to each other, we would need increasingly more samples in order to ensure that our conclusion from a single set of samples (or a single pair of sample sets) is solid. this in fact tells us that statistical testing is not really about whether to reject or not, but about the degree to which we can trust our conclusion.

In [2]:
import numpy as np
import torch
import pyro
import scipy.stats as stats

In [3]:
from IPython.core.debugger import set_trace

In [4]:
# let's import some plotting libraries for drawing pretty plots.
import matplotlib.pyplot as plt
import seaborn as sns

import ipywidgets as widgets
from ipywidgets import interact, interact_manual

In [5]:
# this function fits a normal distribution to the given data.
# it returns the mean and variance of the fitted normal distribution.
def fit_normal(data):
    """Fit a normal distribution to `data` by moment matching.

    `data` must expose `.mean()` and `.var()` (e.g. a numpy array or a torch
    tensor). The variance estimator is whatever `data.var()` computes by
    default for that type.

    Returns:
        A `(mean, variance)` tuple.
    """
    return data.mean(), data.var()

In [6]:
# this function estimates the empirical cumulative distribution function of the data.
# it then evaluates that empirical CDF at the given value, i.e., the fraction of data points less than or equal to it.
def ecdf(data, value):
    """Evaluate the empirical CDF of `data` at `value`.

    Args:
        data: sequence of observations; it is sorted internally.
        value: point (or array of points) at which to evaluate the ECDF.

    Returns:
        The fraction of observations that are less than or equal to `value`.
    """
    ordered = np.sort(data)
    # side='right' returns the insertion index after any ties, which is
    # exactly the count of observations <= value.
    count_at_or_below = np.searchsorted(ordered, value, side='right')
    return count_at_or_below / len(ordered)

In [16]:
# this function draws a set of samples from a normal distribution with given mean and variance.
# this function supports multi-dimensional data.
# this function uses pyro for sampling.
def draw_samples(mean, variance, num_samples):
    """Draw `num_samples` samples from a Normal with the given mean and variance.

    Args:
        mean: 1-D tensor of per-dimension means; `mean.shape[0]` is used as
            the dimensionality of each sample.
        variance: per-dimension variance (tensor or number broadcastable
            against `mean`).
        num_samples: number of samples to draw.

    Returns:
        The value returned by `pyro.sample` for the site named "samples",
        drawn from the Normal expanded to `[num_samples, mean.shape[0]]`.
    """
    # `Normal` is parameterized by (loc, scale), where scale is the standard
    # deviation, so the variance must be converted before it is passed in.
    scale = torch.as_tensor(variance).sqrt()
    distribution = pyro.distributions.Normal(mean, scale)
    return pyro.sample("samples",
                       distribution.expand([num_samples, mean.shape[0]]))

In [37]:
# this function computes the kernel MMD between two sets of samples.
# this function uses the Gaussian kernel for computing the MMD.
def compute_mmd(samples1, samples2, bandwidth=1):
    """Return the biased estimate of the squared kernel MMD between two sample sets.

    Uses the Gaussian kernel
        k(x, y) = exp(-||x - y||^2 / (2 * bandwidth^2))
    and the V-statistic estimator
        MMD^2 = mean k(x, x') + mean k(y, y') - 2 * mean k(x, y),
    which is zero when the two sets are identical and grows as they separate.

    Args:
        samples1: tensor whose first axis indexes samples.
        samples2: tensor whose first axis indexes samples; each sample must
            have the same number of entries as those in `samples1`.
        bandwidth: Gaussian kernel bandwidth.

    Returns:
        A 0-dimensional tensor holding the squared MMD estimate.

    Raises:
        ValueError: if either sample set is empty (the estimate is undefined).
    """
    if len(samples1) == 0 or len(samples2) == 0:
        raise ValueError("compute_mmd requires two non-empty sets of samples")

    # flatten every sample to a vector so that the squared distance below
    # equals the squared norm of the difference of two samples.
    first = samples1.reshape(len(samples1), -1)
    second = samples2.reshape(len(samples2), -1)

    def _mean_kernel(a, b):
        # average Gaussian kernel value over all pairs (one row of a, one row of b).
        sq_dists = ((a[:, None, :] - b[None, :, :]) ** 2).sum(dim=-1)
        return torch.exp(-sq_dists / (2 * bandwidth ** 2)).mean()

    # the two within-set terms and the normalization were missing before;
    # without them the value is not a discrepancy between the two sets.
    return (_mean_kernel(first, first)
            + _mean_kernel(second, second)
            - 2 * _mean_kernel(first, second))

In [17]:
# draw two sets of `n_samples` samples from a `n_dimensions`-dimensional 
# normal distribution with mean 0 and variance 1.
# compute the difference in the means of the two sets of samples.
# repeat this process `n_repeats` times.
# return the list of mean differences.
def compute_mean_diff(n_samples, n_dimensions, n_repeats):
    """Simulate mean differences under the null (both sets from the same Normal).

    Each repetition draws two independent sets of `n_samples` samples via
    `draw_samples`, both with an all-zeros mean and an all-ones variance
    argument of length `n_dimensions`, and records
    `set1.mean() - set2.mean()`.

    Returns:
        A list with `n_repeats` mean differences.
    """
    def _one_mean_difference():
        # the first set is drawn before the second, so the order in which
        # random numbers are consumed is the same on every repetition.
        first = draw_samples(torch.zeros(n_dimensions),
                             torch.ones(n_dimensions),
                             n_samples)
        second = draw_samples(torch.zeros(n_dimensions),
                              torch.ones(n_dimensions),
                              n_samples)
        return first.mean() - second.mean()

    return [_one_mean_difference() for _ in range(n_repeats)]

In [18]:
# now draw two sets of samples from two different normal distributions with different means.
# compute the difference in the means of the two sets of samples.
# repeat this process `n_repeats` times.
# return the list of mean differences.
def compute_mean_diff_diff_means(n_samples, n_dimensions, n_repeats, mean1, mean2, compute_p_value=False):
    """Simulate mean differences between samples from two Normals with different means.

    Each repetition draws `n_samples` samples around `mean1` and `n_samples`
    samples around `mean2` via `draw_samples` (both with an all-ones variance
    argument of length `n_dimensions`) and records
    `set1.mean() - set2.mean()`.

    Args:
        compute_p_value: when true, also run `scipy.stats.ttest_ind` on every
            pair of sample sets and collect the p-values.

    Returns:
        The list of mean differences, or a `(mean_diffs, p_values)` pair when
        `compute_p_value` is true.
    """
    differences = []
    collected_p_values = [] if compute_p_value else None

    for _ in range(n_repeats):
        first = draw_samples(mean1, torch.ones(n_dimensions), n_samples)
        second = draw_samples(mean2, torch.ones(n_dimensions), n_samples)
        if compute_p_value:
            # only the p-value of the two-sample t-test is kept.
            _, p_value = stats.ttest_ind(first, second)
            collected_p_values.append(p_value)
        differences.append(first.mean() - second.mean())

    return (differences, collected_p_values) if compute_p_value else differences

In [None]:
# create an interactive plot where we can vary the sample size, dimensionality of the data, and the mean difference.
# this plot shows how the difference in the means of two sets of samples drawn from the same distribution changes.
# this plot also shows how the difference in the means of two sets of samples drawn from two different distributions changes.
def plot_mean_diffs(n_samples, n_dimensions, n_repeats, mean1, mean2):
    """Overlay the null and the alternative histograms of sampled mean differences.

    The null histogram always uses 5,000 repetitions of `compute_mean_diff`;
    the alternative histogram uses `n_repeats` repetitions of
    `compute_mean_diff_diff_means`. The scalars `mean1` and `mean2` are
    turned into constant mean vectors of length `n_dimensions`, each divided
    by `sqrt(n_dimensions)`.
    """
    null_diffs = compute_mean_diff(n_samples, n_dimensions, 5_000)

    dim_scale = np.sqrt(n_dimensions)
    mean_vector1 = mean1 * torch.ones(n_dimensions) / dim_scale
    mean_vector2 = mean2 * torch.ones(n_dimensions) / dim_scale
    alt_diffs = compute_mean_diff_diff_means(n_samples,
                                             n_dimensions,
                                             n_repeats,
                                             mean_vector1,
                                             mean_vector2)

    fig, ax = plt.subplots(figsize=(5, 3))
    # stat='density' normalizes each histogram, so the two are comparable
    # even though they are built from different numbers of repetitions.
    histogram_specs = (
        (null_diffs, 'blue', 'Same Distribution'),
        (alt_diffs, 'red', 'Different Distribution'),
    )
    for diffs, color, label in histogram_specs:
        sns.histplot([diff.item() for diff in diffs],
                     ax=ax, color=color, label=label, stat='density')

    ax.set_xlabel('Mean Difference')
    ax.set_ylabel('Frequency')
    ax.set_title('Mean Difference vs Frequency')
    ax.legend()
    plt.show()

# expose `plot_mean_diffs` as an interactive plot.
# there is one slider per argument: the sample size, the dimensionality of the data,
# the number of repetitions and the two means.
slider_specs = {
    'n_samples': widgets.IntSlider(min=10, max=100, step=1, value=100),
    'n_dimensions': widgets.IntSlider(min=1, max=10, step=1, value=1),
    'n_repeats': widgets.IntSlider(min=10, max=1000, step=10, value=100),
    'mean1': widgets.FloatSlider(min=-1, max=1, step=0.1, value=0),
    'mean2': widgets.FloatSlider(min=-1, max=1, step=0.1, value=1),
}
interact_manual(plot_mean_diffs, **slider_specs)

interactive(children=(IntSlider(value=100, description='n_samples', min=10), IntSlider(value=1, description='n…

<function __main__.plot_mean_diffs(n_samples, n_dimensions, n_repeats, mean1, mean2)>