In [1]:
import jax
import numpy as np
import jax.numpy as jnp
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import sklearn.metrics as skm

import warnings
warnings.filterwarnings('ignore')

In [2]:
seed = 32
num_samples = 100000
num_features = 5

# New JAX PRNG API (recommended)
key = jax.random.key(seed)

# Split keys for different random operations - JAX requires unique keys!
key, coeff_key, X_key, bias_key, noise_key = jax.random.split(key, 5)

random_coeff = jax.random.randint(coeff_key, shape=[num_features], minval=-10, maxval=10)

X = 2 * jax.random.normal(X_key, shape=(num_samples, num_features))

# Generate Random Bias and Coefficients
random_bais = jax.random.choice(bias_key, random_coeff, shape=(1,))
random_coeff = jax.random.choice(coeff_key, random_coeff, shape=(num_features,))

print(f"Random Bias: {random_bais}")
print(f"Random Coefficients: {random_coeff}")

coeff_features = []

# Construct each feature with random coeffcients choosen
for idx, coeff in enumerate(random_coeff):
  coeff_features.append(coeff * X[:, idx:idx+1])

# Print equation
equation = f"Y = {random_bais[0]}"
for idx, coeff in enumerate(random_coeff):
  equation += f" + {coeff} * X{idx+1}"
print(equation)

# Stack the features into a single matrix
coeff_features = jnp.hstack(coeff_features)

# Generate output from random data with unique noise key
y = random_bais + jnp.sum(coeff_features, axis=1) + jax.random.normal(noise_key, shape=(num_samples,))

Random Bias: [-8]
Random Coefficients: [ 5  5 -8 -2 -8]
Y = -8 + 5 * X1 + 5 * X2 + -8 * X3 + -2 * X4 + -8 * X5
