In [4]:
import numpy as np
import scipy as scp

In [74]:
import numpy as np
from scipy.spatial.distance import mahalanobis
from scipy.stats import multivariate_normal, invwishart
def random_data_gen(n_samples=1000, n_feats=10, maha=1.0, ratio=0.5, seed=None):
    if seed:
        np.random.seed(seed)
    random_matrix = lambda n: np.dot(mat:=np.random.randn(n, n), mat.T)
    ## initialize multivariate normal dist with normally distributed means and covariance
    ## drawn from an inverse wishart distribution (conjugate prior for MVN)
    norm_means_a = np.random.randn(n_feats)
    norm_means_b = np.zeros_like(norm_means_a)
    wishart_cov = invwishart(n_feats, random_matrix(n_feats)).rvs()
    dist = mahalanobis(norm_means_a, norm_means_b, wishart_cov)
    norm_means_a = norm_means_a * (maha / dist)
    assert np.isclose(mahalanobis(norm_means_a, norm_means_b, wishart_cov), maha)
    ## multivariate normal distributions with different means and equal variances
    mvn_a = multivariate_normal(mean=norm_means_a, cov=wishart_cov)
    mvn_b = multivariate_normal(mean=norm_means_b, cov=wishart_cov)
    ## generate data samples from a multivariate normal
    data = np.vstack([mvn_a.rvs(int(n_samples*ratio)), mvn_b.rvs(n_samples - int(n_samples*ratio))])
    labels = np.arange(len(data))<int(n_samples*ratio)
    return data, labels


In [79]:
data, labels = random_data_gen(10000, 5, 1., ratio=.5, seed=56)
data, labels

(array([[  0.42043342,  -2.12769268,   1.58060764,  -0.27377687,
           0.33953751],
        [  2.41673574,  -1.5883988 ,  -0.33667244,  -1.2879792 ,
           2.86648047],
        [ -1.65304322,   1.84963592,  -1.01250505,   0.41389216,
          -3.13223372],
        ...,
        [ -4.40412632,   7.74836198,   8.61616593,  10.93168848,
         -16.76288283],
        [  4.3128885 ,  -7.89731866,  -6.05406186,  -8.91448981,
          12.31904641],
        [  0.59019821,  -0.68048415,  -0.79076833,  -1.30790723,
          -0.15644951]]),
 array([ True,  True,  True, ..., False, False, False]))

In [80]:
np.save("random_data_X.npy", data)
np.save("random_data_y", labels)