In [None]:
import os
import sys
import numpy as np
import tensorflow as tf
import tensorflow.math as tfm
import tensorflow_probability as tfp
from tensorflow.keras import layers

# Make the parent directory importable so the local `datagen` / `synthetic`
# modules imported below can be found.
current_dir = os.getcwd()  # same value IPython's %pwd returns; works outside IPython too
parent_dir = os.path.dirname(current_dir)
if parent_dir not in sys.path:  # keep the cell idempotent when it is re-run
    sys.path.append(parent_dir)

from datagen import synthetic_bn
from synthetic import train

In [2]:
# Generator architecture: one Dense layer per hidden width listed below,
# followed by a final Dense layer with `dim_output` units.
hidden_layer_sizes = [4, 4]
dim_output = 1
model_layers = [layers.Dense(units) for units in (*hidden_layer_sizes, dim_output)]
rseed = 42  # seed used by the weight-setting cell and passed to data_gen
k = 5       # dataset index; it appears in the saved .npz file name

In [None]:
#initial weight shape
# Run one small random binary batch through the network so every layer has
# been called once before the next cell overwrites its weights with set_weights.
# NOTE(review): presumably required because Keras Dense layers create their
# kernel/bias lazily on first call — confirm; `layer.build((None, d_in))` would
# achieve the same without sampling.
# This cell runs before np.random.seed(rseed) is called (next cell), so X_tmp
# differs between runs.
N = 10
dim_input = 4  # width of the binary input X; also used by the later cells
X_tmp = np.random.randint(2, size=[N, dim_input])
x = X_tmp[:4,:]
for layer in model_layers:
    logits = layer(x)
    # The next layer consumes this layer's Bernoulli sample, not the raw input.
    x = tfp.distributions.Bernoulli(logits=logits).sample()

In [4]:
# Manually set weights: seed NumPy, then give each layer a kernel drawn
# uniformly from [-10, 10) and an all-zero bias, so the generator network is
# fully determined by `rseed`.
np.random.seed(rseed)
dims_in = [dim_input, *hidden_layer_sizes]
dims_out = [*hidden_layer_sizes, dim_output]
for d_in, d_out, layer in zip(dims_in, dims_out, model_layers):
    w = np.random.uniform(-10.0, 10.0, (d_in, d_out))
    b = np.zeros(d_out)
    layer.set_weights([w, b])

In [7]:
def data_gen(N, dim_in, model_layers, rseed = 42):
    """Sample a synthetic dataset from a stochastic binary layered network.

    Inputs are i.i.d. fair coin flips. Each layer maps the previous layer's
    binary sample to logits, and the next binary sample is drawn as
    Bernoulli(sigmoid(logits)). The final layer's sample is the label Y.

    Args:
        N: number of samples (rows) to draw.
        dim_in: width of the binary input X.
        model_layers: non-empty sequence of callables (e.g. Keras Dense
            layers) mapping an (N, d_in) array to (N, d_out) logits.
        rseed: seed for NumPy's global RNG. It governs both X and every
            Bernoulli draw, so equal seeds give identical (X, Y).

    Returns:
        X: (N, dim_in) int array of 0/1 inputs.
        Y: int32 array of 0/1 outputs of shape (N, d_out) of the last layer.

    Raises:
        ValueError: if `model_layers` is empty.
    """
    model_layers = list(model_layers)
    if not model_layers:
        raise ValueError("model_layers must contain at least one layer")

    np.random.seed(rseed)
    X = np.random.randint(2, size=[N, dim_in])

    h = X
    for layer in model_layers:
        # Each layer consumes the previous layer's binary sample, not X.
        logits = np.asarray(layer(h), dtype=np.float64)
        # sigmoid via tanh: exact 0/1 limits for huge |logits|, no exp overflow.
        p = 0.5 * (1.0 + np.tanh(0.5 * logits))
        # Bernoulli draw from the seeded NumPy RNG (u in [0, 1), so p=0/1 are exact).
        h = (np.random.uniform(size=p.shape) < p).astype(np.int32)
    Y = h

    return X, Y

In [None]:
# Monte Carlo estimate of P(Y=1 | X=x) for every binary input pattern x:
# draw many samples, bucket them by input pattern, and average Y per bucket.
X, Y = data_gen(int(1e6), dim_input, model_layers, rseed)
dim_x = dim_input
# Encode each binary row as an integer index; x[0] is the most significant bit.
bits = np.array([int(2 ** i) for i in range(dim_x)][::-1], dtype='int')
n_patterns = int(2 ** dim_x)
if Y.size != len(Y):
    raise ValueError("expected exactly one binary output per sample")

idx = X @ bits
# bincount replaces the per-row Python loop: sample count per pattern, and
# the sum of Y per pattern.
cnt = np.bincount(idx, minlength=n_patterns)
pos = np.bincount(idx, weights=Y.reshape(-1), minlength=n_patterns)

# Patterns that were never drawn get NaN instead of a 0/0 division warning.
true_p = np.round(np.divide(pos, cnt, out=np.full(n_patterns, np.nan), where=cnt > 0), 4)
print(cnt)
print("The true probabilities of all inputs are ", true_p)

In [None]:
# Save the generator's parameters, a 10k-sample dataset drawn with the same
# seed, and the estimated P(Y=1 | x) table to synthetic_data/bn_synthetic_{k}.npz.
X_dat, Y_dat = data_gen(int(1e4), dim_input, model_layers, rseed)
model_weights = [layer.get_weights() for layer in model_layers]

# Layers hold differently shaped [kernel, bias] pairs, so np.savez cannot turn
# `model_weights` into a regular array (recent NumPy raises ValueError on
# ragged input). Pack them into an explicit 1-D object array instead.
# Reading it back needs np.load(filename, allow_pickle=True); weight[i] is
# then the list of arrays returned by get_weights() for layer i.
weight_arr = np.empty(len(model_weights), dtype=object)
for i, layer_weights in enumerate(model_weights):
    weight_arr[i] = layer_weights

filename = f"synthetic_data/bn_synthetic_{k}.npz"
os.makedirs(os.path.dirname(filename), exist_ok=True)  # np.savez does not create directories
np.savez(filename, weight=weight_arr, x=X_dat, y=Y_dat, prob=true_p)