In [1]:
from deap import gp
import operator
import math
import numpy as np

## Primitive set

In [2]:
def protectedDiv(left, right):
    try:
        return left / right
    except ZeroDivisionError:
        return 1

pset = gp.PrimitiveSet("MAIN", 1)
pset.addPrimitive(operator.add, 2)
pset.addPrimitive(operator.sub, 2)
pset.addPrimitive(operator.mul, 2)
pset.addPrimitive(protectedDiv, 2)
pset.addPrimitive(math.cos, 1)

pset.renameArguments(ARG0='x')

## Helpers

In [9]:
def generate_random_tree(pset, min_, max_):
    expr = gp.genHalfAndHalf(pset, min_=min_, max_=max_)
    tree = gp.PrimitiveTree(expr)
    return tree

def build_primitives_terminals_dict(pset):
    prims = dict()
    prims_funcs = list(pset.primitives.values())[0]
    prims_names = [p.name for p in prims_funcs]
    prims.update(zip(prims_names, prims_funcs))

    # for arguments, add key = value = name to the dict
    for arg in pset.arguments:
        prims[str(arg)] = arg

    return prims

def tree_to_nodes_matrix(tree: gp.PrimitiveTree, pset: gp.PrimitiveSet, prims_names: list, n_nodes=0):
    n_prims = pset.prims_count + len(pset.arguments)
    if n_nodes == 0:
        n_nodes = len(tree)
    m = np.zeros((n_nodes, n_prims))

    for i, prim in enumerate(tree):
        prim_name = prim.name.replace('ARG0', 'x')
        prim_idx = prims_names.index(prim_name)
        m[i, prim_idx] = 1.
    
    return m

def eval_fitness(tree, pset, points):
    func = gp.compile(tree, pset)

    sqerrors = ((func(x) - (x**2 + math.cos(x)))**2 for x in points)
    return math.fsum(sqerrors) / len(points)

def generate_dataset(n_samples, pset, min_, max_, points, prims_names):
    n_prims = pset.prims_count + len(pset.arguments)
    max_nodes = 2**(max_+1) - 1
    X = np.zeros((n_samples, max_nodes*n_prims))
    y = np.zeros((n_samples, 1))
    
    for i in range(n_samples):
        fit = math.nan
        while math.isnan(fit) or math.isinf(fit):
            tree = generate_random_tree(pset, min_, max_)
            m = tree_to_nodes_matrix(tree, pset, prims_names, max_nodes).ravel()
            try:
                fit = eval_fitness(tree, pset, points)
            except:
                fit = math.nan

        X[i,:] = m        
        y[i,:] = fit

    return X, y
    
    

## Generation of datasets

In [10]:
tree = generate_random_tree(pset, min_=1, max_=3)
print(tree)
prims = build_primitives_terminals_dict(pset)
tree_to_nodes_matrix(tree, pset, list(prims.keys()))

mul(protectedDiv(x, x), cos(x))


array([[0., 0., 1., 0., 0., 0.],
       [0., 0., 0., 1., 0., 0.],
       [0., 0., 0., 0., 0., 1.],
       [0., 0., 0., 0., 0., 1.],
       [0., 0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 0., 1.]])

In [11]:
points = np.arange(0.,1.1,0.1)
print(points)
eval_fitness(tree, pset, points)

[0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1. ]


  return left / right


nan

In [12]:
from torch.utils.data import Dataset, DataLoader, random_split
from torch import nn
import torch

In [118]:
torch.manual_seed(0)
torch.set_default_dtype(torch.float64)
min_ = 1
max_ = 3
max_nodes = 2**(max_+1) - 1
n_prims = pset.prims_count + len(pset.arguments)
n_samples = 1000
X, y = generate_dataset(n_samples, pset, min_, max_, points, list(prims.keys()))
y_normalized = (y - np.mean(y))/np.std(y)
frac = 0.8
X_train, X_valid = random_split(X, [frac, 1-frac])
y_train, y_valid = random_split(y_normalized, [frac, 1-frac])

  return left / right
  return left / right


In [119]:
class CustomDataset(Dataset):
    def __init__(self, X, y, transform=None, target_transform=None):
        self.X = X
        self.y = y
        
    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return X[idx,:], y[idx, 0]

train_dataset = CustomDataset(X_train, y_train)
valid_dataset = CustomDataset(X_valid, y_valid)

In [120]:
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        # self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(max_nodes*n_prims, 2*max_nodes*n_prims),
            nn.ReLU(),
            nn.Linear(2*max_nodes*n_prims, 2*max_nodes*n_prims),
            nn.ReLU(),
            nn.Linear(2*max_nodes*n_prims, 2*max_nodes*n_prims),
            nn.ReLU(),
            nn.Linear(2*max_nodes*n_prims, max_nodes*n_prims),
            nn.ReLU(),
            nn.Linear(max_nodes*n_prims, 1),
        )

    def forward(self, x):
        logits = self.linear_relu_stack(x)
        return logits

In [121]:
def train_loop(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    # Set the model to training mode - important for batch normalization and dropout layers
    # Unnecessary in this situation but added for best practices
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # if batch % 100 == 0:
        #    loss, current = loss.item(), (batch + 1) * len(X)
        #    print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")


def test_loop(dataloader, model, loss_fn):
    # Set the model to evaluation mode - important for batch normalization and dropout layers
    # Unnecessary in this situation but added for best practices
    model.eval()
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    # Evaluating the model with torch.no_grad() ensures that no gradients are computed during test mode
    # also serves to reduce unnecessary gradient computations and memory usage for tensors with requires_grad=True
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()

    test_loss /= num_batches
    print(f"Test Error: Avg loss: {test_loss:>8f} \n")

In [122]:
model = NeuralNetwork()
loss_fn = nn.MSELoss()

learning_rate = 1e-4
batch_size = 1
epochs = 1000

train_dataloader = DataLoader(train_dataset, batch_size = batch_size)
valid_dataloader = DataLoader(valid_dataset, batch_size = None)

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(train_dataloader, model, loss_fn, optimizer)
    test_loop(valid_dataloader, model, loss_fn)
print("Done!")

Epoch 1
-------------------------------
Test Error: Avg loss: 0.368981 

Epoch 2
-------------------------------
Test Error: Avg loss: 0.374681 

Epoch 3
-------------------------------
Test Error: Avg loss: 0.606294 

Epoch 4
-------------------------------
Test Error: Avg loss: 1.083269 

Epoch 5
-------------------------------
Test Error: Avg loss: 1.600060 

Epoch 6
-------------------------------
Test Error: Avg loss: 2.302049 

Epoch 7
-------------------------------
Test Error: Avg loss: 2.926142 

Epoch 8
-------------------------------
Test Error: Avg loss: 3.352274 

Epoch 9
-------------------------------
Test Error: Avg loss: 2.962078 

Epoch 10
-------------------------------
Test Error: Avg loss: 1.299532 

Epoch 11
-------------------------------
Test Error: Avg loss: 0.727187 

Epoch 12
-------------------------------
Test Error: Avg loss: 0.465645 

Epoch 13
-------------------------------
Test Error: Avg loss: 0.622650 

Epoch 14
-------------------------------
Test E

In [123]:
test = np.zeros((max_nodes, n_prims))
test[:6,:] = np.array([[1,0,0,0,0,0], [0,0,1,0,0,0], [0,0,0,0,0,1], [0,0,0,0,0,1], [0,0,0,0,1,0], [0,0,0,0,0,1]])
test_tensor = torch.from_numpy(test.flatten())
pred = model(test_tensor)
pred.item()

0.0022937242056880885

In [124]:
def softmax(x):
    sm = torch.nn.Softmax(dim=1)
    with torch.no_grad():
        x_reshaped = torch.reshape(x, (max_nodes, n_prims))
    return sm(x_reshaped)

In [128]:
def optimize_tree(x0: np.array, learning_rate, max_iter):
    x0 = torch.tensor(x0, requires_grad = True)
    optimizer_tree = torch.optim.Adam([x0], lr = learning_rate)
    #sm = torch.nn.Softmax(dim=1)
    for i in range(max_iter):
        pred = model(x0)
        pred.backward()
        optimizer_tree.step()
        optimizer_tree.zero_grad()
        #with torch.no_grad():
        #    x0_reshaped = torch.reshape(x0, (max_nodes, n_prims))
        #    x0 = sm(x0_reshaped).flatten().requires_grad_()
        print(pred.item())
        # print(softmax(x0))
    return x0

In [132]:
tree = generate_random_tree(pset, min_=1, max_=3)
print(tree)
x0 = tree_to_nodes_matrix(tree, pset, list(prims.keys()), n_nodes = max_nodes)
x = optimize_tree(x0.ravel(), 1e-3, 400)

mul(protectedDiv(sub(x, x), sub(x, x)), protectedDiv(mul(x, x), protectedDiv(x, x)))
0.743423375185207
0.7115612591726094
0.6779824539549805
0.6405427254642633
0.6082357965849474
0.6371015366847466
0.6323480415767231
0.6033591164087836
0.5877860267980892
0.5899527344765011
0.590751602655142
0.5894782682931514
0.5864207322338725
0.5802908937429574
0.5721595295089139
0.563152281298401
0.5543309214101926
0.545924985942168
0.5454260241105295
0.5439196620655649
0.5402156376333195
0.5342328464892746
0.526318508294126
0.5183369331941824
0.5171532075745663
0.5158225312860034
0.5127345753500063
0.5074870546502207
0.5007357090898481
0.4942823762765209
0.49252736494968374
0.4902502694909402
0.48627357574638047
0.4817789423182731
0.47661945837543646
0.4733166067438988
0.47165286086348873
0.46864042850429255
0.46416213621703595
0.45989178858266627
0.4584086573131883
0.45590483580978985
0.4522567992884226
0.44839136375159633
0.4449892350328953
0.44322297904693525
0.4405474169137667
0.436803916677060

In [133]:
sm = torch.nn.Softmax(dim=1)
with torch.no_grad():
    x_reshaped = torch.reshape(x, (max_nodes, n_prims))
    x = sm(x_reshaped)
x

tensor([[0.1314, 0.1313, 0.3365, 0.1336, 0.1311, 0.1361],
        [0.1329, 0.1160, 0.1198, 0.3806, 0.1219, 0.1288],
        [0.1312, 0.3544, 0.1196, 0.1363, 0.1009, 0.1576],
        [0.1135, 0.1486, 0.1131, 0.1647, 0.1275, 0.3326],
        [0.0993, 0.1305, 0.1279, 0.1368, 0.1132, 0.3923],
        [0.1342, 0.3657, 0.1296, 0.1325, 0.1303, 0.1078],
        [0.1327, 0.1399, 0.1192, 0.1238, 0.1399, 0.3445],
        [0.1351, 0.1207, 0.1461, 0.1338, 0.1218, 0.3425],
        [0.1061, 0.1221, 0.1225, 0.3979, 0.1262, 0.1252],
        [0.1180, 0.1378, 0.3576, 0.1382, 0.1202, 0.1282],
        [0.1256, 0.1240, 0.1022, 0.1641, 0.1413, 0.3427],
        [0.1139, 0.1126, 0.1165, 0.1391, 0.1221, 0.3958],
        [0.1214, 0.1261, 0.1536, 0.3167, 0.1461, 0.1362],
        [0.1134, 0.1255, 0.1239, 0.1462, 0.1511, 0.3398],
        [0.1356, 0.1220, 0.1268, 0.1394, 0.1589, 0.3172]])

In [35]:
print(list(prims.keys()))

['add', 'sub', 'mul', 'protectedDiv', 'cos', 'x']
