In [1]:
import numpy as np 

from mldb import ComputationGraph

In [2]:
def load_data(N, D, rng):
    """Sample an (N, D) matrix of standard-normal values.

    Parameters
    ----------
    N : int
        Number of rows (samples).
    D : int
        Number of columns per sample.
    rng : numpy.random.RandomState
        Object exposing ``normal(loc, scale, size)``; its state is advanced
        by the draw.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(N, D)`` drawn from N(0, 1).
    """
    shape = (N, D)
    return rng.normal(loc=0, scale=1, size=shape)

def extract_features(data, features):
    """Build a feature matrix by applying row-wise reducers to ``data``.

    Parameters
    ----------
    data : numpy.ndarray
        Input array; each reducer is applied along axis 1.
    features : iterable of callables
        Functions with the numpy reduction signature
        ``f(data, axis=1, keepdims=True)`` (e.g. ``np.mean``).

    Returns
    -------
    numpy.ndarray
        One column per reducer, in the order given, joined along axis 1.
    """
    columns = []
    for reducer in features:
        # keepdims=True keeps a size-1 axis 1 so results can be joined on it.
        columns.append(reducer(data, axis=1, keepdims=True))
    return np.concatenate(columns, axis=1)

def get_weights(D, rng=None):
    """Draw random initial parameters for a linear model with D features.

    Parameters
    ----------
    D : int
        Number of input features (rows of the weight column).
    rng : numpy.random.RandomState or numpy.random.Generator, optional
        Source of randomness. When None (the default, matching the original
        behaviour) the global ``np.random`` state is used. That state is not
        controlled by any explicitly seeded generator, so the weights are not
        reproducible. Pass a seeded generator to make them reproducible.

    Returns
    -------
    tuple
        ``(w, b)`` where ``w`` has shape ``(D, 1)`` with entries drawn from
        N(0, 1) divided by 4, and ``b`` is the constant bias 0.5.
    """
    source = np.random if rng is None else rng
    # Presumably scaled by 1/4 to keep initial scores small; confirm intent.
    return source.normal(0, 1, (D, 1)) / 4, 0.5

def dot_add(feats, wb):
    """Apply an affine map: ``feats.dot(weights) + bias``.

    Parameters
    ----------
    feats : numpy.ndarray
        Feature matrix; must support ``.dot`` with the weights.
    wb : tuple
        Exactly two items, ``(weights, bias)``; anything else raises on unpacking.

    Returns
    -------
    numpy.ndarray
        The linear scores.
    """
    weights, bias = wb
    linear = feats.dot(weights)
    return linear + bias

def sigmoid(z_score):
    """Logistic sigmoid ``1 / (1 + exp(-z))``, evaluated element-wise.

    The value is computed as ``exp(-log(1 + exp(-z)))`` via ``np.logaddexp``.
    This gives the same result as the naive formula, but very negative inputs
    (e.g. -1000) map to 0.0 without the overflow RuntimeWarning that
    ``np.exp(-z)`` would emit.

    Parameters
    ----------
    z_score : float or numpy.ndarray
        Scores to squash.

    Returns
    -------
    float or numpy.ndarray
        Values in [0, 1], with the same shape as ``z_score``.
    """
    # logaddexp(0, -z) == log(1 + exp(-z)), computed without overflow.
    return np.exp(-np.logaddexp(0.0, -z_score))

# Row-wise summary statistics used as model features. Each one is a numpy
# reduction that accepts (data, axis=..., keepdims=...).
feat_funcs = [
    np.min,     # minimum
    np.max,     # maximum
    np.std,     # standard deviation
    np.mean,    # mean
    np.median,  # median
    np.ptp,     # peak-to-peak range (max - min)
]

In [3]:
graph = ComputationGraph()

# RNG node: takes no inputs and returns np.random.RandomState(1234).
# NOTE(review): it is not visible here whether mldb caches this value or
# re-runs the lambda on every evaluate(). If it re-runs, every evaluation of a
# downstream node restarts from seed 1234. Confirm against mldb.
rng = graph.node(
    func=lambda: np.random.RandomState(1234)
)

# Raw data: a (10000, 3) standard-normal matrix drawn from `rng`.
# A node passed as a keyword argument is wired in as a source (the printed
# NodeWrapper shows sources=[rng]). Plain values such as N and D are listed
# as kwargs.
# TODO: consider moving N=10000 and D=3 into a config cell as named constants.
data = graph.node(
    func=load_data, 
    rng=rng,
    N=10000, D=3
)

# Features: one column per entry of feat_funcs. Each function reduces across
# axis 1 of `data`, giving a matrix of shape (10000, len(feat_funcs)).
features = graph.node(
    func=extract_features, 
    data=data, 
    features=feat_funcs
)

# Linear-model parameters (w, b), with one weight per feature column.
# NOTE(review): get_weights is called without `rng`, so it draws from the
# global np.random state. The weights are therefore NOT covered by seed 1234
# and change between kernel restarts.
params = graph.node(
    func=get_weights, 
    D=len(feat_funcs)
)

# Linear score: features.dot(w) + b, giving one score per row.
z_score = graph.node(
    func=dot_add, 
    feats=features, wb=params
)

# Predictions: the scores passed through the logistic sigmoid, giving values in [0, 1].
predictions = graph.node(
    func=sigmoid, 
    z_score=z_score
)

In [4]:
# For every node in the graph, print its id, its wiring repr and its computed
# value, followed by a blank line. The recorded output lists nodes in creation
# order.
for key, value in graph.nodes.items():
    # The id and repr are printed before evaluating, so they still appear if
    # evaluate() raises.
    print(key, value, sep="\n")
    # NOTE(review): it is unclear whether evaluate() reuses cached upstream
    # results or recomputes them. Confirm the cost for large graphs.
    print(value.evaluate(), end="\n\n")

e2b19f2f-3fcc-45f2-817d-d417adb30308
<NodeWrapper sources=[] kwargs=[] factor=<lambda> sink=e2b19f2f-3fcc-45f2-817d-d417adb30308>
RandomState(MT19937)

b04e4752-3011-47f5-b429-a4e90b032eea
<NodeWrapper sources=[rng] kwargs=[N,D] factor=load_data sink=b04e4752-3011-47f5-b429-a4e90b032eea>
[[ 0.47143516 -1.19097569  1.43270697]
 [-0.3126519  -0.72058873  0.88716294]
 [ 0.85958841 -0.6365235   0.01569637]
 ...
 [ 0.21585284  0.12923273 -0.29803653]
 [ 0.34074804  1.75820379  0.0025372 ]
 [ 1.24617813  0.61897186 -0.50800977]]

b487cc62-b83c-4810-8307-0bcf23e55c34
<NodeWrapper sources=[data] kwargs=[features] factor=extract_features sink=b487cc62-b83c-4810-8307-0bcf23e55c34>
[[-1.19097569e+00  1.43270697e+00  1.08378781e+00  2.37722146e-01
   4.71435164e-01  2.62368266e+00]
 [-7.20588733e-01  8.87162940e-01  6.82384182e-01 -4.86925630e-02
  -3.12651896e-01  1.60775167e+00]
 [-6.36523504e-01  8.59588414e-01  6.12453664e-01  7.95870938e-02
   1.56963721e-02  1.49611192e+00]
 ...
 [-2.9803653