In [148]:
import numpy as np
from scipy import stats
import matplotlib.pyplot as plt
%matplotlib inline

In [149]:
def generate_data_1D_cont(pi1, X, job=0):
    """Simulate correlated one-sided p-values with sample-specific alternative rates.

    Standard-normal-marginal z-scores are drawn in blocks of 10 from a
    multivariate normal whose covariance has unit variances, +0.25 between
    any two of the first five coordinates, -0.25 between each of the first
    five and each of the last five coordinates, and 0 between any two of the
    last five.  Null z-scores have mean 0, alternative z-scores have mean 2.
    Each z-score is converted to an upper-tail p-value ``P(Z > z)``.

    Sample ``i`` is labelled alternative when a fresh uniform draw on [0, 1)
    is not greater than ``pi1[i]``; it then receives the next unused
    alternative p-value, otherwise the next unused null p-value.

    Parameters
    ----------
    pi1 : 1-D sequence of float
        Per-sample probability of being an alternative; must have at least
        ``len(X)`` entries (only the first ``len(X)`` are used).
    X : sequence
        Covariate values; only its length is used and it is returned as is.
    job : int, optional
        Unused; kept so existing callers that pass it keep working.

    Returns
    -------
    p : numpy.ndarray of float, shape (len(X),)
        Simulated p-values.
    h : numpy.ndarray of float, shape (len(X),)
        1.0 where the sample was drawn from the alternative, else 0.0.
    X : the ``X`` argument, unchanged.

    Raises
    ------
    ValueError
        If ``pi1`` is not one-dimensional or is shorter than ``X``.
    """
    block = 10          # dimension of each correlated Gaussian draw
    half = block // 2

    n_samples = len(X)
    pi1 = np.asarray(pi1, dtype=float)
    if pi1.ndim != 1 or pi1.shape[0] < n_samples:
        raise ValueError("pi1 must be a 1-D sequence with at least len(X) entries")
    if n_samples == 0:
        return np.zeros(0), np.zeros(0), X

    # Block-structured covariance; the diagonal is restored last because the
    # top-left assignment also overwrites it.
    cov = np.eye(block)
    cov[:half, :half] = 0.25
    cov[:half, half:] = -0.25
    cov[half:, :half] = -0.25
    np.fill_diagonal(cov, 1.0)

    # Ceiling division: always draw at least n_samples values per population,
    # even when n_samples is not a multiple of the block size.  `//` keeps the
    # size an integer on Python 3 as well.
    n_blocks = -(-n_samples // block)

    mean = np.zeros(block)
    mean_alt = np.full(block, 2.0)
    # Each draw has shape (n_blocks, block); ravel() lays the blocks end to end.
    null_z = np.random.multivariate_normal(mean, cov, n_blocks).ravel()
    alt_z = np.random.multivariate_normal(mean_alt, cov, n_blocks).ravel()

    # Survival function = 1 - cdf, but without cancellation for large z.
    null_p = stats.norm.sf(null_z)
    alt_p = stats.norm.sf(alt_z)

    # One uniform draw per sample decides null vs. alternative.
    is_null = np.random.uniform(size=n_samples) > pi1[:n_samples]
    is_alt = ~is_null

    # Hand out p-values from each pool in order, exactly as a sequential
    # "take the next unused value" loop would.
    p = np.empty(n_samples)
    p[is_null] = null_p[:np.count_nonzero(is_null)]
    p[is_alt] = alt_p[:np.count_nonzero(is_alt)]
    h = is_alt.astype(float)
    return p, h, X



In [150]:
# Simulation settings.
SEED = 0             # fixed seed so the generated data set can be reproduced
N_SAMPLES = 100000   # number of covariate values / hypotheses to simulate
mu1 = 1              # centre of the first bump in the alternative probability
mu2 = 4              # centre of the second bump in the alternative probability

# Seed NumPy's global RNG: both the covariate draw below and
# generate_data_1D_cont use np.random.* functions.
np.random.seed(SEED)

# Covariates: uniform on [0, 5), sorted ascending.
X = np.sort(np.random.uniform(high=5, size=(N_SAMPLES,)))

# Per-sample probability of being an alternative: two Gaussian-shaped bumps,
# one of height 0.5 around mu1 and one of height 1 around mu2.
pi1 = 0.5 * np.exp(-(X - mu1) ** 2 / 0.2) + np.exp(-(X - mu2) ** 2 / 0.1)

p, h, x = generate_data_1D_cont(pi1, X)

In [151]:
# Write one CSV row per sample: covariate, p-value, and null/alternative label.
with open('data/data_1d_depend3.csv', 'w') as out_file:
    out_file.write('x_value, p_value, h\n')
    for x_val, p_val, h_val in zip(x, p, h):
        out_file.write("{}, {}, {}\n".format(x_val, p_val, h_val))

In [152]:
# Total of the label array, i.e. how many samples were marked h == 1.
h.sum()

19225.0

In [135]:
# Show the first 100 entries of the covariate array X.
print(X[:100])

[  6.86787784e-05   7.52044224e-05   1.95727497e-04   2.74009417e-04
   4.21584830e-04   4.22734455e-04   5.87592850e-04   5.94981460e-04
   6.62260706e-04   6.66086814e-04   6.92275689e-04   7.17042782e-04
   7.20898127e-04   8.44114587e-04   9.03311325e-04   9.64291982e-04
   1.02007560e-03   1.15286299e-03   1.18363088e-03   1.19553104e-03
   1.25880655e-03   1.27708894e-03   1.30050625e-03   1.33909861e-03
   1.41021694e-03   1.42206985e-03   1.44212998e-03   1.47688939e-03
   1.52102431e-03   1.53554520e-03   1.85877348e-03   1.88817735e-03
   1.88966493e-03   1.91878633e-03   1.91901301e-03   1.94413248e-03
   1.95528451e-03   1.97669066e-03   2.00496026e-03   2.08321348e-03
   2.08820093e-03   2.10645989e-03   2.11146342e-03   2.12807497e-03
   2.25102326e-03   2.26588104e-03   2.49087182e-03   2.53057173e-03
   2.63490677e-03   2.65131184e-03   2.82183822e-03   2.86904366e-03
   2.87534806e-03   2.88591781e-03   2.90329995e-03   2.92741669e-03
   2.93377680e-03   2.99009752e-03

In [138]:
# NOTE(review): in the visible cells `values` exists only as a local variable
# inside generate_data_1D_cont, so this cell presumably relied on leftover
# kernel state from an earlier experiment -- confirm, or remove the cell.
values

array([[-0.41213806,  0.16676604,  0.09916507, ..., -0.2764129 ,
         1.25491915, -0.62817246],
       [-0.17127784,  1.50325565, -0.02520192, ...,  1.5539611 ,
         0.34408451, -0.18809896],
       [ 0.39845313,  0.6501015 ,  2.53359971, ..., -0.07420448,
         2.00605202, -1.34861191],
       ..., 
       [ 1.23181165,  0.3892811 ,  0.78722482, ...,  2.01620394,
         0.23232371, -1.26393574],
       [-1.60537741,  0.39189896, -0.32544914, ..., -0.22488711,
        -1.7508691 ,  0.76236396],
       [ 1.09873367,  0.72888149, -0.23069915, ..., -0.91050103,
        -0.61216291,  0.50445697]])

In [147]:
# Transpose `values` and flatten the result to 1-D.
# NOTE(review): `values` is not defined at notebook scope in the visible
# cells; this presumably depends on leftover kernel state -- confirm or remove.
values.T.flatten()

array([-0.41213806, -0.17127784,  0.39845313, ..., -1.26393574,
        0.76236396,  0.50445697])