### Stein's Paradoxを実装する
Stein's Paradox：3次元以上の多変量正規分布の平均を推定するとき、不偏推定量（観測値そのもの）よりも平均二乗誤差（各次元の和）が常に小さくなる、偏りのある推定量（James-Stein推定量）が存在する。

まずは適当に実装

In [None]:
import numpy as np
from tqdm import tqdm as t

In [None]:
def sample(n_trial, y_true):
    '''
    Draw n_trial samples from a unit-variance normal distribution centred on y_true.

    parameters
    ----------
    n_trial : int. number of samples (rows) to draw.
    y_true : np.ndarray. true mean, shape (1, p). Previously only p == 3 was
        supported; any p >= 1 now works.

    returns
    ----------
    X : np.ndarray. shape (n_trial, p). Column j is drawn from N(y_true[0, j], 1).
    '''
    # Draw one column at a time (instead of a single (n_trial, p) call) so the
    # global random stream is consumed in the same order as the original
    # three-column implementation; for p == 3 the output is bit-identical.
    columns = [np.random.normal(loc=mu, size=(n_trial, 1)) for mu in y_true[0]]
    return np.hstack(columns)

def james_stein(X):
    '''
    James-Stein estimate of the mean, computed independently for each row of X.

    Each row x is treated as a single observation of a p-dimensional mean with
    unit variance (as produced by `sample`). It is shrunk towards the origin:

        theta_js = (1 - (p - 2) / ||x||^2) * x

    The original implementation hard-coded the constant 1, which equals p - 2
    only for p == 3. Results for p == 3 are unchanged. The estimator only
    improves on the raw observation for p >= 3.

    parameters
    ----------
    X : np.ndarray. shape (n_trial, p). Each row is one observation.

    returns
    ----------
    theta_js : np.ndarray. James-Stein estimates, shape (n_trial, p).
    '''
    p = X.shape[1]
    # Squared Euclidean norm of each row, kept 2-D so it broadcasts over columns.
    squared_norm = (X**2).sum(axis=1, keepdims=True)
    return (1 - (p - 2) / squared_norm) * X

def mse(y_true, y_pred):
    '''
    Mean squared error, computed separately for each dimension.

    parameters
    ----------
    y_true : np.ndarray. true mean, shape (1, 3); broadcast against every row of y_pred.
    y_pred : np.ndarray. estimates, shape (n_trial, 3).

    returns
    ----------
    np.ndarray of shape (3,): (MSE of dim 1, MSE of dim 2, MSE of dim 3),
    each averaged over the n_trial rows. Summing this vector gives the total
    squared-error risk over all dimensions.
    '''
    squared_error = (y_pred - y_true) ** 2
    return squared_error.mean(axis=0)

In [None]:
def js_mse(n_trial, true):
    '''
    Per-dimension MSE of the James-Stein estimator on a freshly drawn sample.

    Draws n_trial observations around `true`, shrinks each one with
    james_stein, and scores the estimates against `true` with mse.
    '''
    # Around 10**7 trials is the memory limit on the author's machine.
    estimates = james_stein(sample(n_trial, true))
    return mse(true, estimates)

def ub_mse(n_trual, true):
    '''
    Per-dimension MSE of the unbiased estimator (the raw observations).

    `n_trual` (sic) is the number of samples to draw. The name is kept as-is
    so that keyword callers keep working.
    '''
    # Around 10**7 trials is the memory limit on the author's machine.
    return mse(true, sample(n_trual, true))

In [None]:
def get_result(n_trial, mu1=None, mu2=3, mu3=7):
    '''
    Sweep the first component of the true mean and record, for each value,
    the per-dimension MSE of the unbiased and James-Stein estimators.

    parameters
    ----------
    n_trial : int. number of samples drawn for each value of mu1.
    mu1 : array-like or None. values of the first true-mean component to sweep.
        None (default) means np.arange(-20, 31), i.e. -20 to 30 in steps of 1.
    mu2 : number. fixed second true-mean component (default 3).
    mu3 : number. fixed third true-mean component (default 7).

    returns
    ----------
    mu1 : np.ndarray. the swept values of the first component.
    result_ub : np.ndarray, shape (len(mu1), 3). unbiased-estimator MSE per dimension for each mu1.
    result_js : np.ndarray, shape (len(mu1), 3). James-Stein MSE per dimension for each mu1.
    '''
    # None sentinel instead of an ndarray default argument.
    mu1 = np.arange(-20, 31) if mu1 is None else np.asarray(mu1)
    result_ub = []
    result_js = []
    for m1 in t(mu1):
        true = np.array([[m1, mu2, mu3]])
        # js_mse and ub_mse each draw their own independent sample.
        result_js.append(js_mse(n_trial, true))
        result_ub.append(ub_mse(n_trial, true))

    return mu1, np.array(result_ub), np.array(result_js)

In [None]:
# Run the sweep with 10**6 samples per value of mu1.
mu1, result_ub, result_js = get_result(n_trial=10**6)

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

# TODO: switch inline figures to retina (high-DPI) output later
import seaborn as sns
sns.set()  # apply seaborn's default style to the matplotlib figures below

In [None]:
# Total risk: per-dimension MSEs summed over the three dimensions, for each mu1.
plt.plot(mu1, result_js.sum(axis=1), marker='.', linestyle='None', label='James-Stein')
plt.plot(mu1, result_ub.sum(axis=1), marker='.', linestyle='None', label='Unbiased')
plt.xlabel('mu_1')
plt.ylabel('total MSE (sum over dimensions)')
plt.legend();

In [None]:
# Per-dimension MSE of the James-Stein estimator as mu1 is swept.
# Loop over the columns instead of hard-coding three plot calls; labels stay mu_1, mu_2, ...
for dim in range(result_js.shape[1]):
    plt.plot(mu1, result_js[:, dim], marker='.', linestyle='None',
             label='mu_{}'.format(dim + 1))
plt.xlabel('mu_1')
plt.ylabel('MSE of James-Stein estimate')
plt.title('James-Stein: per-dimension MSE')
plt.legend();