In [247]:
import numpy as np
import scipy.stats as sp
import pandas as pd

def create_GAN_data(N, class_ratio=0.5, random_state=None):
    """Simulate a two-group, mixed-type tabular dataset.

    The returned frame has exactly ``N`` rows and 12 float64 columns, in order:

    * ``0``        -- Poisson count; mean 1 for group 1, mean 2 for group 0.
    * ``1``        -- Normal, mean -2 (group 1) / +2 (group 0), sd 1.
    * ``2``        -- Normal with mean equal to column 1, sd 1.
    * ``3``        -- column 0 squared plus column 1 squared.
    * ``4, 5, 6``  -- trivariate normal, zero mean, covariance
                      [[1, .8, .2], [.8, 1, 0], [.2, 0, 1]]; independent of group.
    * ``"group"``  -- group indicator (0.0 / 1.0). Rows are NOT shuffled: all
                      group-0 rows come first, then all group-1 rows.
    * ``8``        -- Bernoulli; p=0.6 for group 1, p=0.3 for group 0.
    * ``9, 10, 11`` -- one-hot 3-class variable; class probabilities
                      [0.2, 0.7, 0.1] for group 1, [0.7, 0.2, 0.1] for group 0.

    Parameters
    ----------
    N : int
        Number of rows to generate.
    class_ratio : float, default 0.5
        Fraction of rows in group 1, must lie in [0, 1]. The group-1 count is
        ``int(N * class_ratio)``; all remaining rows are group 0, so the total
        is always exactly ``N``.
    random_state : int or None, default None
        If given, NumPy's *global* RNG is seeded with it (``np.random.seed``)
        before sampling. Output is then reproducible, but later unseeded
        NumPy/SciPy draws in the session are affected too.

    Returns
    -------
    pandas.DataFrame

    Raises
    ------
    ValueError
        If ``class_ratio`` is not in [0, 1].
    """
    if not 0 <= class_ratio <= 1:
        raise ValueError(f"class_ratio must be in [0, 1], got {class_ratio!r}")
    if random_state is not None:
        np.random.seed(random_state)

    # Group indicator. n0 is derived from n1 (rather than int(N*(1-class_ratio)))
    # so the two counts always sum to N even when N*class_ratio is fractional.
    n1 = int(N * class_ratio)
    n0 = N - n1
    group = np.concatenate([np.zeros(n0), np.ones(n1)])
    is_g1 = group == 1

    # Continuous variables. The order of the random draws below is kept fixed
    # so that a given random_state always yields the same dataset.
    x0 = sp.poisson.rvs(mu=np.where(is_g1, 1., 2.))
    x1 = sp.norm.rvs(loc=np.where(is_g1, -2, 2), scale=1)
    x2 = sp.norm.rvs(loc=x1, scale=1)
    x3 = x0**2 + x1**2
    cov456 = np.array([[1, 0.8, 0.2], [0.8, 1., 0.], [0.2, 0., 1.]])
    # multivariate_normal.rvs squeezes its output (shape (3,) when N == 1),
    # so restore the (N, 3) layout that column_stack needs.
    x456 = np.reshape(
        sp.multivariate_normal.rvs(mean=[0, 0, 0], cov=cov456, size=N), (N, 3)
    )

    # Binary variable
    x7 = sp.binom.rvs(p=np.where(is_g1, 0.6, 0.3), n=1)

    # Three-class variable, one-hot encoded: draw a full set for each group's
    # distribution, then select row-wise according to the group indicator.
    x890_0 = np.reshape(sp.multinomial.rvs(p=[0.7, 0.2, 0.1], n=1, size=N), (N, 3))
    x890_1 = np.reshape(sp.multinomial.rvs(p=[0.2, 0.7, 0.1], n=1, size=N), (N, 3))
    x890 = np.where(is_g1[:, None], x890_1, x890_0).astype(float)

    data = pd.DataFrame(np.column_stack([x0, x1, x2, x3, x456, group, x7, x890]))
    # Column 7 is the group indicator (it follows x0..x3 and the three x456 columns).
    return data.rename(columns={7: "group"})


In [246]:
# Write a training set and an independently seeded validation set to disk.
from pathlib import Path

SIM_DATA_DIR = Path("../simulation_data")
N_SAMPLES = 40000
SIM_DATA_SPLITS = [(123, "simulation.csv"), (124, "simulation_val.csv")]  # (seed, filename)

# to_csv cannot create missing parent directories, so make sure the target exists.
SIM_DATA_DIR.mkdir(parents=True, exist_ok=True)
for seed, filename in SIM_DATA_SPLITS:
    create_GAN_data(N_SAMPLES, random_state=seed).to_csv(SIM_DATA_DIR / filename, index=False)