# Stein's Paradox Notebook

### ref: https://www.statslab.cam.ac.uk/~rjs57/SteinParadox.pdf
### questions / suggestions --> Adam Kraft adk@mit.edu
### this notebook grew out of conversations I had about Stein's Paradox with Mike Fleder and Josh Joseph

## Why

Because Stein's paradox is utterly mind-blowing. I don't remember where I was when Kennedy was shot, likely because I didn't exist then. But I remember where I was when Mike told me about Stein's paradox... if you haven't heard about it, whoa, make sure you're sitting down ;-)

But why put so much effort into simulating this? There are lots of interesting stats and probability puzzles that blew my mind initially. The Monty Hall problem is an example. For reference, in the Monty Hall problem there are 3 doors. Behind one is a new car; behind the other two are goats. A game show contestant picks one door at random. Monty, the host, then opens one of the other doors to reveal a goat (he knows where the car is, so he knows not to reveal it). The contestant must now decide: switch to the remaining closed door, or stick with the one they picked at first. The answer is that switching is always the better strategy: it wins 2/3 of the time, while staying wins only 1/3. When I first heard about it as a kid, I thought it was a lie: both choices seemed equally good. I didn't believe the math teacher. I thought, screw that, I'll write a simulation to prove that I'm right. So I did, and in the process found out *why* my intuition failed me.

For the Monty Hall problem, you can get great intuition by messing with the numbers. Suppose there were 100 doors instead of 3, and 99 of them have goats. You pick a door, then Monty opens 98 of the remaining doors, revealing a goat behind each and leaving just one other door closed. Now it's obvious that your chances are better (or worse, if you like goats) if you switch. Whereas your odds of finding the car behind the door you initially picked stay at 1/100, the odds if you switch are now obviously much higher (99/100), because Monty has just given up a lot of information.
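A minimal sketch of the kind of simulation described above, generalized to `n_doors` doors. It assumes the contestant's first pick is uniform and that Monty always opens every other door except one, never revealing the car. The function name and parameters are hypothetical, chosen for illustration.

```python
import random

def monty_hall_win_rate(n_doors=3, switch=True, n_games=100_000, seed=0):
    """
    Fraction of games won in an n-door Monty Hall game (hypothetical helper).
    Monty opens every unchosen door except one, and never reveals the car.
    """
    rng = random.Random(seed)
    wins = 0
    for _ in range(n_games):
        car = rng.randrange(n_doors)
        pick = rng.randrange(n_doors)
        # Monty leaves exactly one other door closed: the car door if the
        # contestant missed it, otherwise a randomly chosen goat door.
        if pick != car:
            other = car
        else:
            other = rng.choice([d for d in range(n_doors) if d != pick])
        final = other if switch else pick
        wins += (final == car)
    return wins / n_games

print(monty_hall_win_rate(3, switch=True))    # roughly 2/3
print(monty_hall_win_rate(3, switch=False))   # roughly 1/3
print(monty_hall_win_rate(100, switch=True))  # roughly 99/100
```

The code makes the "Monty gives up information" argument explicit: switching wins exactly when the first pick was wrong, which happens with probability (n-1)/n.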

I tried hard to develop the same flavor of intuition for Stein's paradox and have so far failed. This notebook is for messing with the numbers until I, with some luck, develop a better intuition for Stein's insane puzzle.

## What

Stein's paradox takes a little bit more setup than the Monty Hall problem. Suppose we have $p \ge 3$ independent random variables. It adds emphasis to the mind-blowing aspect of this problem to think of these variables as obviously unrelated things: baseball stats, birthrates in China, Martian weather. All we know about them is that each is drawn from a normal distribution with unit variance: $X_i \sim \mathcal{N}(\theta_i,1)$. For now, don't worry about the fact that the variance is 1 or even that everything's a normal distribution. With some work, a lot of distributions can be transformed to fit the preconditions of this paradox. What we want is an estimator, $\hat{\theta} = f(X)$, computed from the observations $X = (X_1, X_2, \ldots, X_p)$, that approximates the vector of true means $\theta = (\theta_1, \theta_2, \ldots, \theta_p)$.

We of course want a good estimator, so we'll define a loss function $L(\hat{\theta},\theta)$ as a way of scoring a guess at the vector of means. The loss function could be squared Euclidean distance, absolute error (taxicab distance), or some other metric. In principle, this loss could be any way of describing estimation error: it represents the penalty for error in guessing this vector of means, and so there is value to us in minimizing it. A very popular choice of loss function is mean squared error, which is proportional to the squared Euclidean distance between our estimated means and the true means. Since the estimator $\hat{\theta}$ depends on random observations, our loss value will be random too. To take this randomness out of the problem and be able to compare estimators to one another deterministically, we define the risk, $R(\hat{\theta},\theta) = E[L(\hat{\theta}(X),\theta)]$, as the expected loss of a given rule for computing $\hat{\theta}$ from random observations $X$, where the expectation is over the distribution of $X$ for a fixed $\theta$.
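As a quick worked example of the definition, consider the simplest estimator, $\hat{\theta}(X) = X$. Under squared Euclidean loss, and with each $X_i$ having unit variance, its risk does not depend on $\theta$ at all:

$$R(\hat{\theta},\theta) = E\left[\lVert X-\theta\rVert^2\right] = \sum_{i=1}^{p} E\left[(X_i-\theta_i)^2\right] = \sum_{i=1}^{p} \operatorname{Var}(X_i) = p$$

So with $p = 3$ we should expect this estimator's risk to come out at about 3 in any simulation.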

In order to reason about the relative risk of different estimators, statisticians define a concept of **dominance**. They say an estimator $\hat{\theta}$ **strictly dominates** $\tilde{\theta}$ if $R(\hat{\theta},\theta) \le R(\tilde{\theta},\theta)$ for all $\theta$, and the inequality is strict for some $\theta$. That is to say, for any value $\theta$ could take, the risk, i.e. expected loss, of using estimator $\hat{\theta}$ is less than or equal to the risk of using estimator $\tilde{\theta}$. Furthermore, there is at least some $\theta$ for which the risk of $\hat{\theta}$ is strictly less than that of $\tilde{\theta}$. Clearly, if we want to minimize risk, we should never use an estimator $\tilde{\theta}$ if we can show that it is strictly dominated by another estimator $\hat{\theta}$. Any estimator that we can show to be strictly dominated is called **inadmissible**. If we can show that an estimator $\hat{\theta}$ is not strictly dominated by any other estimator, we call $\hat{\theta}$ **admissible**. So far, so good.
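The definition can be turned into a small check. This is a hypothetical sketch (`strictly_dominates` is not part of the notebook) that assumes we already have both estimators' risks evaluated at the same finite set of $\theta$ values. Because dominance is a statement about *every* $\theta$, a finite grid can only rule dominance out or lend it support; it can never prove it. The `tol` parameter is an assumed slack for Monte Carlo noise.

```python
import numpy as np

def strictly_dominates(risk_hat, risk_tilde, tol=0.0):
    """
    risk_hat[k] and risk_tilde[k] are the risks of two estimators at the same theta_k.
    True if risk_hat <= risk_tilde at every grid point (up to tol) and is
    strictly smaller (by more than tol) at at least one grid point.
    """
    risk_hat = np.asarray(risk_hat, dtype=float)
    risk_tilde = np.asarray(risk_tilde, dtype=float)
    never_worse = np.all(risk_hat <= risk_tilde + tol)
    sometimes_better = np.any(risk_hat < risk_tilde - tol)
    return bool(never_worse and sometimes_better)

print(strictly_dominates([1.0, 2.0], [1.0, 3.0]))  # True: never worse, better at the 2nd point
print(strictly_dominates([1.0, 4.0], [2.0, 3.0]))  # False: worse at the 2nd point
```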

Now back to our variables: baseball, birthrates, Mars weather. Suppose we have one sample from each and we want to try to guess the means. If we consider any one variable on its own, the intuitive answer of using the sample itself as an estimate of that one variable's mean is admissible. Using the samples to estimate the means has advantages: it's unbiased (because $E[X_i] = \theta_i$) and it turns out to be the maximum likelihood estimator. But, **if we're concerned with minimizing the overall risk of all three variables combined, this intuitive unbiased estimator is inadmissible!** That's right: to minimize risk under the MSE loss function, the following function, called the James-Stein estimator, strictly dominates the simple unbiased estimator discussed above whenever $p \ge 3$.
$$\hat{\theta}^{JS}(X) = \left( 1 - \frac{p-2}{\lVert X \rVert^2} \right)X$$
What makes this seem so crazy is that to jointly estimate all of these obviously independent variables, we have to scale each one by a factor that depends on all of the variables! Obviously, there is some sleight of hand here. Chinese birth rates and baseball stats give us no info about weather on Mars, and, sure enough, if we want to estimate the mean of some weather phenomenon on Mars by itself we had better not pay attention to the other variables. But if our goal is to have the smallest error, measured as mean squared error, over all three means, we have to contract each variable toward the origin (in fact, toward any fixed point in the space) by a factor that depends on all three variables!
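To make the "depends on all of the variables" point concrete, here is a sketch of the estimator shrinking toward an arbitrary point $\nu$ (`js_toward` is a hypothetical helper; with $\nu = 0$ it reduces to the formula above). It assumes unit variance and that $\nu$ is fixed before looking at the data. Changing only $X_3$ changes the estimate of $\theta_1$.

```python
import numpy as np

def js_toward(X, nu):
    """James-Stein estimate shrinking X toward a fixed point nu (unit variance assumed)."""
    X = np.asarray(X, dtype=float)
    nu = np.asarray(nu, dtype=float)
    d = X - nu                                   # offset from the shrinkage target
    p = X.shape[-1]
    coef = 1 - (p - 2) / np.sum(d**2, axis=-1, keepdims=True)
    return nu + coef * d

origin = [0.0, 0.0, 0.0]
print(js_toward([1.0, 2.0, 2.0], origin))  # ||X||^2 = 9: first coordinate 1.0 -> 8/9, about 0.889
print(js_toward([1.0, 2.0, 4.0], origin))  # same X_1, but ||X||^2 = 21: 1.0 -> 20/21, about 0.952
```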

As a first step, let's generate some data and confirm that this is true.



In [42]:
import numpy as np

def generate_true_means(p=3,lower=-10,upper=10):
    """
    generate p means, each sampled from a uniform distribution, 
    that will be used as our latent distributions
    """
    return np.random.uniform(lower,upper,(p,))

def sample_gaussian(theta,n_samples):
    """
    given true means, sample n_samples from multivariate gaussian in which 
    components are independent and each component has unit variance
    """
    p = theta.shape[-1]
    cov = np.eye(p)
    samples = np.random.multivariate_normal(theta,cov,n_samples)
    return samples

def squared_euclidean_loss(theta_hat,theta):
    return np.sum((theta_hat-theta)**2,axis=-1,keepdims=True) #note, this can operate on a matrix where each row is a sample

def estimator_unbiased(X):
    theta_hat = X
    return theta_hat

def estimator_JS(X):
    p = X.shape[-1]
    contraction_coefficients = (1 - (p-2)/np.sum(X**2,axis=-1,keepdims=True))
    theta_hat = X * contraction_coefficients
    return theta_hat


def estimate_risk(theta, sample_fn, estimator_fn, loss_fn, n_trials=1000000):
    """
    empirically estimate risk as defined in the intro text above, by averaging loss
    over n_trials trials. in each trial, use sample_fn to draw a sample from distribution
    represented by theta. use estimator_fn on the resulting samples to estimate theta_hat.
    use loss_fn to compute L(theta_hat,theta). average that loss over the trials.
    """
    samples = sample_fn(theta,n_trials)
    theta_hats = estimator_fn(samples)
    thetas = np.broadcast_to(theta,theta_hats.shape) # one copy of theta per trial, without building a Python list
    losses = loss_fn(theta_hats,thetas)
    return np.mean(losses)


In [53]:
def run_sim(loss_fcn=squared_euclidean_loss,
            n_trials=1000000,
            latent_var_generator=generate_true_means,
            distribution_sampler=sample_gaussian,
            n_sims=10):
    js_dominates=True
    for i in range(n_sims):
        theta = latent_var_generator()
        risk_unbiased = estimate_risk(theta,distribution_sampler,estimator_unbiased,loss_fcn,n_trials=n_trials)
        risk_js = estimate_risk(theta,distribution_sampler,estimator_JS,loss_fcn,n_trials=n_trials)
        print("unbiased risk:",risk_unbiased,"js risk:",risk_js,"js is lower risk?",(risk_js<=risk_unbiased))
        js_dominates = js_dominates and (risk_js<=risk_unbiased)
    print("js dominates?",js_dominates)


In [54]:
run_sim()

unbiased risk: 2.99826315857 js risk: 2.98058829213 js is lower risk? True
unbiased risk: 3.00231236828 js risk: 2.99340992961 js is lower risk? True
unbiased risk: 2.99853299418 js risk: 2.99096974133 js is lower risk? True
unbiased risk: 2.99616159043 js risk: 2.96736481587 js is lower risk? True
unbiased risk: 3.00027596844 js risk: 2.99188544271 js is lower risk? True
unbiased risk: 2.99960346553 js risk: 2.99660652209 js is lower risk? True
unbiased risk: 2.99706995951 js risk: 2.98502284672 js is lower risk? True
unbiased risk: 2.99860805794 js risk: 2.98802362124 js is lower risk? True
unbiased risk: 3.00568511784 js risk: 2.9904872553 js is lower risk? True
unbiased risk: 2.99297273264 js risk: 2.99147959655 js is lower risk? True
js dominates? True


### Need n_trials to be large to observe the effect
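The effect of JS is subtle here: even with 1 million samples used to compute each risk, the estimated JS risk does not always come out lower for every $\theta$ we try. But JS really does dominate; increasing the number of samples shows that. Part of the reason the gap is so small is that `generate_true_means` draws each $\theta_i$ uniformly from $[-10, 10]$, so $\lVert\theta\rVert$ is usually large, $\lVert X\rVert^2$ is large too, and the contraction factor $1 - \frac{p-2}{\lVert X \rVert^2}$ is very close to 1.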
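Another reason a million trials are needed: `estimate_risk` draws fresh samples for each estimator, so each risk estimate carries its own Monte Carlo noise. For the unbiased estimator the loss $\lVert X-\theta\rVert^2$ is $\chi^2_p$-distributed with variance $2p$, so with $p=3$ and $10^6$ trials its standard error is about $\sqrt{6/10^6} \approx 0.0025$, comparable to several of the risk gaps printed above. A standard variance-reduction trick, common random numbers (not something `estimate_risk` does), is to score both estimators on the *same* draws and average the per-draw difference in loss. The sketch below assumes squared Euclidean loss and unit variance; `paired_risk_gap` is a hypothetical name. Near the origin the loss difference is heavy-tailed for $p = 3$, so the standard error it reports is only a rough guide there.

```python
import numpy as np

def paired_risk_gap(theta, n_trials=200_000, seed=0):
    """
    Estimate R(unbiased) - R(JS) at one theta, scoring both estimators on the
    SAME draws of X (common random numbers) under squared Euclidean loss.
    Returns (mean gap, rough standard error of that mean).
    """
    rng = np.random.default_rng(seed)
    theta = np.asarray(theta, dtype=float)
    p = theta.shape[-1]
    X = theta + rng.standard_normal((n_trials, p))     # X ~ N(theta, I), one row per trial
    sq_norm = np.sum(X**2, axis=1, keepdims=True)
    js = (1 - (p - 2) / sq_norm) * X                    # James-Stein estimate per row
    loss_unbiased = np.sum((X - theta)**2, axis=1)
    loss_js = np.sum((js - theta)**2, axis=1)
    gap = loss_unbiased - loss_js                       # positive means JS did better on that draw
    return gap.mean(), gap.std(ddof=1) / np.sqrt(n_trials)

print(paired_risk_gap([8.0, -6.0, 5.0]))  # a small positive gap, many standard errors above zero
```

With paired draws the noise in the *difference* is much smaller than the noise in each risk separately, so far fewer trials are needed to see that JS wins at a given $\theta$.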

Ok, now how about if we use absolute error instead of squared Euclidean error?

In [55]:
def absolute_loss(theta_hat,theta):
    return np.sum(np.abs(theta_hat-theta),axis=-1,keepdims=True)

In [56]:
run_sim(loss_fcn=absolute_loss)

unbiased risk: 2.39170699817 js risk: 2.39085319174 js is lower risk? True
unbiased risk: 2.39577703455 js risk: 2.39048429349 js is lower risk? True
unbiased risk: 2.3945096451 js risk: 2.39044798324 js is lower risk? True
unbiased risk: 2.39439740011 js risk: 2.3902831959 js is lower risk? True
unbiased risk: 2.39503940081 js risk: 2.38884710132 js is lower risk? True
unbiased risk: 2.39465117067 js risk: 2.38827192904 js is lower risk? True
unbiased risk: 2.39503674682 js risk: 2.38509888387 js is lower risk? True
unbiased risk: 2.39339898124 js risk: 2.3904099265 js is lower risk? True
unbiased risk: 2.39269465931 js risk: 2.38608964634 js is lower risk? True
unbiased risk: 2.3929781683 js risk: 2.39225451907 js is lower risk? True
js dominates? True


### Absolute loss (L1 norm) doesn't break it!
Even if we use absolute loss, JS still appears to dominate the unbiased estimator.

Can we choose a loss fcn that breaks it? I bet max error will...

In [57]:
def max_loss(theta_hat,theta):
    return np.max(np.abs(theta_hat-theta),axis=-1,keepdims=True)

In [58]:
run_sim(loss_fcn=max_loss)

unbiased risk: 1.32564558468 js risk: 1.32419042449 js is lower risk? True
unbiased risk: 1.32598934962 js risk: 1.32408490405 js is lower risk? True
unbiased risk: 1.32651057232 js risk: 1.32611025949 js is lower risk? True
unbiased risk: 1.32564458817 js risk: 1.32536032886 js is lower risk? True
unbiased risk: 1.32711840958 js risk: 1.32182619464 js is lower risk? True
unbiased risk: 1.32595609911 js risk: 1.32076143416 js is lower risk? True
unbiased risk: 1.32544108336 js risk: 1.32418763447 js is lower risk? True
unbiased risk: 1.32618432975 js risk: 1.32070706952 js is lower risk? True
unbiased risk: 1.32639837504 js risk: 1.32459700189 js is lower risk? True
unbiased risk: 1.32651895065 js risk: 1.31929153265 js is lower risk? True
js dominates? True


### Well, shit!

Even using max loss, JS appears to dominate! WTF, I thought it was more sensitive to the loss fcn!
As a sanity check, JS should no longer dominate under a loss that only depends on one variable.

In [59]:
def zero_component_loss(theta_hat,theta):
    return np.abs(theta[...,0:1]-theta_hat[...,0:1])

In [60]:
run_sim(loss_fcn=zero_component_loss)

unbiased risk: 0.7972482459 js risk: 0.792144150751 js is lower risk? True
unbiased risk: 0.798436227119 js risk: 0.793525124094 js is lower risk? True
unbiased risk: 0.797453486825 js risk: 0.793044001236 js is lower risk? True
unbiased risk: 0.798219042859 js risk: 0.798657360774 js is lower risk? False
unbiased risk: 0.79889274602 js risk: 0.787940634562 js is lower risk? True
unbiased risk: 0.798947281604 js risk: 0.739980348082 js is lower risk? True
unbiased risk: 0.79728341616 js risk: 0.794905736018 js is lower risk? True
unbiased risk: 0.797673424632 js risk: 0.807084676145 js is lower risk? False
unbiased risk: 0.79834554216 js risk: 0.796573929604 js is lower risk? True
unbiased risk: 0.79698448702 js risk: 0.802350612499 js is lower risk? False
js dominates? False


### sanity check passed

As expected, if we only pay attention to 1 variable instead of all three, JS no longer dominates the unbiased estimator.

### one thing still bugs me, though:
The word "contraction" gets thrown around a lot when people talk about Stein's. Intuitively, it makes sense that the JS estimator is pulling points toward the origin, essentially "shrinking" the distribution by pulling it toward that point. **But what if the point is already close to the origin?** Then it seems to me like the JS estimator flings it through the origin and out the other side, possibly farther away from the origin than it started.

In code:

In [62]:
def always_at_origin(p=3,**unused):
    """
    a "generator" that always returns the origin
    """
    return np.array([0.0]*p)

run_sim(latent_var_generator=always_at_origin)

unbiased risk: 3.00030279659 js risk: 1.99367355102 js is lower risk? True
unbiased risk: 2.99943855049 js risk: 1.99282332762 js is lower risk? True
unbiased risk: 2.99921387891 js risk: 2.00374666511 js is lower risk? True
unbiased risk: 3.00106265255 js risk: 1.99454508109 js is lower risk? True
unbiased risk: 3.00227894682 js risk: 2.00056366207 js is lower risk? True
unbiased risk: 3.00063555488 js risk: 1.99203483437 js is lower risk? True
unbiased risk: 2.99879198694 js risk: 1.99320834782 js is lower risk? True
unbiased risk: 2.99945540393 js risk: 1.97827891061 js is lower risk? True
unbiased risk: 2.99458495796 js risk: 1.98807995623 js is lower risk? True
unbiased risk: 2.99990235957 js risk: 2.00094711632 js is lower risk? True
js dominates? True


### Huh.

So, I guess if a distribution has variance 1 there are still sufficiently many points farther than 1 unit away from the origin that risk is reduced, even though we're taking the points that are near the origin and flinging them far away. 

Let me just confirm that:

In [64]:
def sample_sharper_peak(theta,n_samples):
    """
    use a tighter variance than our standard variance of 1.
    
    Note: this violates a precondition of Stein's; I'm just trying
    to verify that this in fact breaks Stein's by flinging too many
    points far away from the origin.
    """
    p = theta.shape[-1]
    cov = np.eye(p) * 0.1 #variance of 0.1 instead of 1
    samples = np.random.multivariate_normal(theta,cov,n_samples)
    return samples

run_sim(latent_var_generator=always_at_origin, distribution_sampler=sample_sharper_peak)

unbiased risk: 0.300228779805 js risk: 8.02888388347 js is lower risk? False
unbiased risk: 0.299923972537 js risk: 8.18654015186 js is lower risk? False
unbiased risk: 0.300086425943 js risk: 8.22135080668 js is lower risk? False
unbiased risk: 0.299888916902 js risk: 8.20168244419 js is lower risk? False
unbiased risk: 0.300068480635 js risk: 8.55650105357 js is lower risk? False
unbiased risk: 0.299888665436 js risk: 8.20264150545 js is lower risk? False
unbiased risk: 0.300060158817 js risk: 8.41407124005 js is lower risk? False
unbiased risk: 0.299887970711 js risk: 8.31567443368 js is lower risk? False
unbiased risk: 0.299681511716 js risk: 8.17051738886 js is lower risk? False
unbiased risk: 0.299872232031 js risk: 8.3230930708 js is lower risk? False
js dominates? False


### Good.

That is totally what I'd expect. When the distribution has a sharp peak, most samples are less than distance 1 from the origin. For $p = 3$, the $1-\frac{p-2}{\lVert X \rVert^2}$ multiplier in the JS estimator is negative whenever $\lVert X \rVert < 1$, and it is less than $-1$ whenever $\lVert X \rVert^2 < \frac{1}{2}$, which is true of most points $X$ here. Those points get flipped through the origin and flung farther from it than they started.
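These fractions can be computed exactly rather than eyeballed. A sketch, assuming $p = 3$ and $\theta = 0$ as in the `always_at_origin` runs: then $\lVert X\rVert^2/\sigma^2$ has a chi-square distribution with 3 degrees of freedom, whose CDF has the closed form $F(x) = \operatorname{erf}(\sqrt{x/2}) - \sqrt{2x/\pi}\,e^{-x/2}$. The thresholds $\lVert X\rVert^2 < 1$ (negative multiplier) and $\lVert X\rVert^2 < \frac{1}{2}$ (multiplier below $-1$) come from setting $1 - 1/\lVert X\rVert^2$ equal to $0$ and $-1$. `chi2_3_cdf` is a hypothetical helper; `scipy.stats.chi2.cdf` would do the same job.

```python
import math

def chi2_3_cdf(x):
    """CDF of a chi-square distribution with 3 degrees of freedom (closed form)."""
    return math.erf(math.sqrt(x / 2)) - math.sqrt(2 * x / math.pi) * math.exp(-x / 2)

p = 3
for sigma2 in (1.0, 0.1):
    # With theta = 0, ||X||^2 = sigma2 * S where S ~ chi-square(3).
    frac_negative = chi2_3_cdf((p - 2) / sigma2)       # P(||X||^2 < 1): multiplier < 0
    frac_flung = chi2_3_cdf((p - 2) / (2 * sigma2))    # P(||X||^2 < 1/2): multiplier < -1
    print(sigma2, round(frac_negative, 3), round(frac_flung, 3))
```

With $\sigma^2 = 1$ this gives roughly 20% of points with a negative multiplier and about 8% flung farther out; with $\sigma^2 = 0.1$ it is roughly 98% and 83%.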

So this is a pretty good confirmation of what I said before, that, when the variance is 1, even if the distribution mean is at the origin, there are still sufficiently many points farther than 1 unit away from the origin that risk is reduced, even though we are taking the points that are very near the origin and flinging them far away. That is, the effect of flinging these points is sufficiently weak not to matter, when variance is 1 (as it's required to be).
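The origin numbers can also be checked against a closed form. Under squared Euclidean loss with unit variance, a standard identity (from Stein's lemma) gives

$$R(\hat{\theta}^{JS},\theta) = p - (p-2)^2\,E_\theta\left[\frac{1}{\lVert X\rVert^2}\right]$$

At $\theta = 0$, $\lVert X\rVert^2 \sim \chi^2_p$ and $E[1/\chi^2_p] = \frac{1}{p-2}$ for $p \ge 3$, so the JS risk is exactly $p - (p-2) = 2$, matching the `always_at_origin` run. The formula assumes unit variance. If the variance is a known $\sigma^2$, the usual form of the estimator scales the shrinkage by $\sigma^2$, i.e. uses the multiplier $1 - \frac{(p-2)\sigma^2}{\lVert X\rVert^2}$, and its risk at the origin becomes $2\sigma^2$. The sketch below (hypothetical helpers `estimator_JS_known_variance` and `risk_at_origin`, not used elsewhere in this notebook) checks that with $\sigma^2 = 0.1$ the rescaled estimator beats the unbiased one again, which suggests the blow-up in the `sample_sharper_peak` run comes from plugging the wrong variance into the unit-variance formula.

```python
import numpy as np

def estimator_JS_known_variance(X, sigma2):
    """JS estimate when X ~ N(theta, sigma2 * I) with sigma2 known (hypothetical helper)."""
    p = X.shape[-1]
    coef = 1 - (p - 2) * sigma2 / np.sum(X**2, axis=-1, keepdims=True)
    return coef * X

def risk_at_origin(estimator, sigma2, p=3, n_trials=200_000, seed=0):
    """Monte Carlo squared-error risk when the true mean is the origin."""
    rng = np.random.default_rng(seed)
    X = np.sqrt(sigma2) * rng.standard_normal((n_trials, p))
    theta_hat = estimator(X)
    return np.mean(np.sum(theta_hat**2, axis=1))   # theta = 0, so loss = ||theta_hat||^2

sigma2 = 0.1
r_unbiased = risk_at_origin(lambda X: X, sigma2)
r_js = risk_at_origin(lambda X: estimator_JS_known_variance(X, sigma2), sigma2)
print(r_unbiased, r_js)   # roughly 0.3 and 0.2
```

Both calls use the same seed, so the two estimators are scored on identical draws.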

## Once upon a time, at Mamalehs,
Mike, Josh and I tried to come up with a geometric interpretation of Stein's paradox. The thought was, somehow the volume of sample space in which estimates are improved by contracting toward the origin exceeds the volume of space in which estimates are made worse by contracting toward the origin. Or else maybe the volume in which things are improved is improved more than the volume in which things are made worse. Or some combination of those effects.

If the geometric interpretation can be made intuitive somehow (like with a picture), then even though the stats interpretation is still insanity-inducing, we can at least think of it in terms of the geometric metaphor. Maybe we can put it all behind us like a bad dream: could it be that Stein's paradox is just an artifact of symmetry? That is, because our distributions are all rotationally symmetric about all axes that pass through their means (because the components are independent and all have unit variance), maybe it's just a trick about the geometry of spheres (or hyperspheres, or balls?) that makes Stein's paradox work at all. Then we can rest easier: distributions with less symmetry would be less likely to be affected by this paradox.

IIRC, we had drawn something like this on a napkin:
![alt text](stein_imgs/F1.png "Figure 1")

The JS estimator moves $X$ toward the origin along the line through $X$ and the origin (except in the degenerate case discussed and simulated above, where it flings $X$ away, but we can safely ignore that case).

If $X$ was farther from the origin along that line than point $q$, as shown on the figure, then the movement brings $X$ closer to $\theta$ until $X$ passes $q$. Then, once it's past $q$, $X$'s movement toward the origin starts to increase the distance between $X$ and $q$. The distance of point $q$ from the origin is 

$$\lVert\theta\rVert\cos(\alpha) = \frac{X \cdot \theta}{\lVert X \rVert}$$

the coordinates of $q$ are

$$\left(\frac{X \cdot \theta}{\lVert X \rVert^2}\right)X$$
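A small numerical check of the two formulas above, as a sketch with a made-up $\theta$ and one random draw of $X$ (`projection_point_q` is a hypothetical name). It confirms that $\theta - q$ is perpendicular to the line through the origin and $X$, that $\lVert q\rVert = |X\cdot\theta|/\lVert X\rVert$, and that distances to $\theta$ from points on the line split by Pythagoras.

```python
import numpy as np

def projection_point_q(X, theta):
    """Point q on the line through the origin and X that is closest to theta."""
    X = np.asarray(X, dtype=float)
    theta = np.asarray(theta, dtype=float)
    return (X @ theta) / (X @ X) * X

rng = np.random.default_rng(1)
theta = np.array([3.0, -1.0, 2.0])       # illustrative true means
X = theta + rng.standard_normal(3)       # one draw of X ~ N(theta, I)
q = projection_point_q(X, theta)

print(np.dot(theta - q, X))                                     # ~0 up to rounding
print(np.linalg.norm(q), abs(X @ theta) / np.linalg.norm(X))    # equal

# Any point Y = c * X on the line: ||Y - theta||^2 = ||Y - q||^2 + ||q - theta||^2
Y = 0.7 * X
print(np.sum((Y - theta)**2), np.sum((Y - q)**2) + np.sum((q - theta)**2))
```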

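This gets us somewhere, maybe. Because $\theta - q$ is perpendicular to the line through the origin and $X$, every point $Y$ on that line satisfies $\lVert Y - \theta\rVert^2 = \lVert Y - q\rVert^2 + \lVert q - \theta\rVert^2$ (Pythagoras), so moving closer to $q$ along the line is the same as moving closer to $\theta$. All points $X$ that get closer to $q$, after the JS estimator has had its way with them, will therefore result in improved estimates of $\theta$ under JS compared with using the points themselves as estimates of the mean. That is, the points that improve under Stein's are those points for which: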

$$\left\lVert\hat{\theta}^{JS}(X) - \left(\frac{X \cdot \theta}{\lVert X \rVert^2}\right)X\right\rVert \lt \left\lVert X - \left(\frac{X \cdot \theta}{\lVert X \rVert^2}\right)X \right\rVert$$

$$\left\lVert \left( 1 - \frac{p-2}{\lVert X \rVert^2} \right)X - \left(\frac{X \cdot \theta}{\lVert X \rVert^2}\right)X\right\rVert \lt \left\lVert X - \left(\frac{X \cdot \theta}{\lVert X \rVert^2}\right)X \right\rVert$$

$$\left\lVert \left( 1 - \frac{p-2}{\lVert X \rVert^2} - \frac{X \cdot \theta}{\lVert X \rVert^2}\right)X\right\rVert \lt \left\lVert \left(1 - \frac{X \cdot \theta}{\lVert X \rVert^2}\right)X \right\rVert$$

$$\left| 1 - \frac{p-2}{\lVert X \rVert^2} - \frac{X \cdot \theta}{\lVert X \rVert^2}\right|\lVert X\rVert \lt \left|1 - \frac{X \cdot \theta}{\lVert X \rVert^2}\right|\lVert X\rVert$$

$$\left| 1 - \frac{p-2}{\lVert X \rVert^2} - \frac{X \cdot \theta}{\lVert X \rVert^2}\right| \lt \left|1 - \frac{X \cdot \theta}{\lVert X \rVert^2}\right|$$

$$\left| 1 - \frac{p-2 + X \cdot \theta}{\lVert X \rVert^2}\right| \lt \left|1 - \frac{X \cdot \theta}{\lVert X \rVert^2}\right|$$
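A numerical sanity check of where the derivation lands, as a sketch assuming $p = 3$ and an arbitrary, made-up $\theta$. On the same draws it compares three things: whether JS actually lands closer to $\theta$ than $X$ does; the inequality $\left| 1 - \frac{p-2 + X \cdot \theta}{\lVert X \rVert^2}\right| < \left|1 - \frac{X \cdot \theta}{\lVert X \rVert^2}\right|$, i.e. the derivation's last step with the two fractions on the left combined; and one further rearrangement. Since $\frac{p-2}{\lVert X\rVert^2} > 0$, an inequality of the form $|u - a| < |u|$ with $a > 0$ holds exactly when $0 < a < 2u$, which here becomes $2\left(\lVert X\rVert^2 - X\cdot\theta\right) > p-2$.

```python
import numpy as np

rng = np.random.default_rng(2)
p = 3
theta = np.array([1.5, -0.5, 0.8])               # illustrative true means
X = theta + rng.standard_normal((100_000, p))    # one row per draw of X ~ N(theta, I)

sq = np.sum(X**2, axis=1)                        # ||X||^2
dot = X @ theta                                  # X . theta
js = (1 - (p - 2) / sq)[:, None] * X             # James-Stein estimate per row

# Direct test: does JS land closer to theta than X itself?
direct = np.sum((js - theta)**2, axis=1) < np.sum((X - theta)**2, axis=1)

# Derived condition, with the fractions combined.
derived = np.abs(1 - (p - 2 + dot) / sq) < np.abs(1 - dot / sq)

# Simplified condition: 2 * (||X||^2 - X . theta) > p - 2.
simplified = 2 * (sq - dot) > p - 2

print(np.mean(direct == derived), np.mean(direct == simplified))   # both 1.0
print("fraction of draws improved by JS:", direct.mean())
```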