# Equation fit

Using Halerium graphs.

Author: {{ cookiecutter.author_name }}
Created: {{ cookiecutter.timestamp }}

In [0]:
# Link to project experiments folder hypothesis_experiment_learnings.board (refresh and hit enter on this line to see the link)

## How to use the notebook

The following cells:
- specify objective, variables, and variable types,
- read the dataset,
- set up the equations,
- train the model,
- present results from the tests.

By default, the notebook is set up to run with an example (wine quality). To see how it works, run the notebook without changing the code.

For your project, adjust the code in the linked cells to match your objectives, variables, dataset, etc., and then execute all cells in order.

Please refer to equation.board for detailed instructions. The headers in this notebook follow the cards on the board.

In [ ]:
# <halerium id="c2999dce-0ade-4607-9779-250e225c9155">
# Link to equation.board
# </halerium id="c2999dce-0ade-4607-9779-250e225c9155">


### Imports

In [0]:
import numpy as np
import pandas as pd

from sklearn.model_selection import train_test_split

import halerium.core as hal

from halerium.core import Graph, Entity, Variable, StaticVariable, link, DataLinker
from halerium.core import get_posterior_model, get_generative_model
from halerium.core.model import Trainer
from halerium.objectives import Predictor

### 2. Import the Dataset

In [0]:
# <halerium id="0b70a16e-ce54-4b94-be63-69e87a246afc">
# Set to True for time-series data: the loading cell then parses a 'date' column as datetimes.
time_series = False
# Fraction of the rows held out as the test set (passed to train_test_split further down).
test_size = 0.25
# The sentinel 'default example' is swapped for the wine-quality example CSV URL by the loading cell.
path = 'default example' # Specify the path of the data
# </halerium id="0b70a16e-ce54-4b94-be63-69e87a246afc">


Importing the dataset

In [0]:
# Resolve the 'default example' sentinel to the bundled wine-quality dataset.
if path == 'default example':
    path = 'https://raw.githubusercontent.com/erium/halerium-example-data/main/hypothesis_testing/WineQT.csv'

# sep=None asks pandas to sniff the delimiter. Only the Python parsing engine
# supports that, so it is requested explicitly; otherwise pandas falls back to
# it anyway and emits a ParserWarning. The same options are used for both data
# kinds so that time-series files get delimiter detection as well.
read_options = {'sep': None, 'engine': 'python'}
if time_series:
    # Time-series data is expected to carry a 'date' column, parsed as datetimes.
    read_options['parse_dates'] = ['date']

df = pd.read_csv(path, **read_options)

Visualising the dataset

In [0]:
# Show the loaded DataFrame as the cell output to sanity-check columns and values.
df

### 3. Model the Equations

In [0]:
# Build the Halerium graph that encodes the equation to be fitted:
#   pH.mean = a0 + a1 * fixed_acidity + volatile_acidity ** a2
# The variables declared here are later addressed as graph.inputs.<name> and
# graph.outputs.<name> when data is attached in the training and results cells.
# NOTE(review): `inputs` and `outputs` are not assigned anywhere in this notebook;
# presumably Halerium provides them inside the graph context - confirm against the
# Halerium documentation before adding explicit definitions.
graph = Graph("graph")
with graph:
# <halerium id="5555f505-bd21-41f6-8a2a-acfb07aaa168">
    with inputs:
# </halerium id="5555f505-bd21-41f6-8a2a-acfb07aaa168">
        # Inputs for the equation (the x) you may specify the mean and variance if it is known
        fixed_acidity = Variable('fixed_acidity')
        volatile_acidity = Variable('volatile_acidity')
        # volatile_acidity = Variable('volatile_acidity', mean = 0.2, variance = 0.01)

# <halerium id="5555f505-bd21-41f6-8a2a-acfb07aaa168">
    with outputs:
# </halerium id="5555f505-bd21-41f6-8a2a-acfb07aaa168">
        # Output of the equation (the y) whose mean is defined by the equation below
        pH = Variable('pH')

    # Container entity grouping the coefficients of the equation
    model_parameters = Entity('model_parameters')
# <halerium id="5555f505-bd21-41f6-8a2a-acfb07aaa168">
    with model_parameters:
# </halerium id="5555f505-bd21-41f6-8a2a-acfb07aaa168">
        # Coefficients to be fitted, each declared with mean 0 and variance 10**2 (= 100)
        a0 = StaticVariable('a0', mean=0, variance=10**2)
        a1 = StaticVariable('a1', mean=0, variance=10**2)
        a2 = StaticVariable('a2', mean=0, variance=10**2)

# <halerium id="5555f505-bd21-41f6-8a2a-acfb07aaa168">
    # The equation
# </halerium id="5555f505-bd21-41f6-8a2a-acfb07aaa168">
    # a0 is the offset, a1 scales fixed_acidity, and a2 is the exponent applied to volatile_acidity
    pH.mean = a0 + a1 * fixed_acidity + volatile_acidity ** a2
    # You may specify the variance if you have the domain knowledge
    # pH.variance = a0 + 1

In [0]:
# Split the dataset into a training set and a hold-out test set.
# - For time-series data the rows are NOT shuffled: the first rows become the
#   training set and the last rows the test set, so no later rows leak into
#   training (assumes the file is ordered chronologically - verify for your data).
# - A fixed random_state makes the shuffled split, and therefore every
#   downstream result, reproducible across "Restart & Run All".
train, test = train_test_split(
    df,
    test_size=test_size,
    shuffle=not time_series,
    random_state=42,
)

# Re-number the rows from 0. drop=True discards the old row labels instead of
# inserting them as a spurious "index" column, and re-assigning (rather than
# inplace=True) keeps the cell free of in-place mutation.
train = train.reset_index(drop=True)
test = test.reset_index(drop=True)
test

### 4. Train the model

In [0]:
# 'Training' the model
model = get_posterior_model(
    graph=graph,
    data={
        graph.inputs.fixed_acidity: train["fixed acidity"],
        graph.inputs.volatile_acidity: train["volatile acidity"],
        graph.outputs.pH: train["pH"],
    })
posterior_graph = model.get_posterior_graph()

### 5. Get the results

In [0]:
from functions.equation import show_results

# Only the inputs of the test split are supplied; pH is left for the model to predict.
test_inputs = {
    graph.inputs.fixed_acidity: test["fixed acidity"],
    graph.inputs.volatile_acidity: test["volatile acidity"],
}
model = get_generative_model(graph=posterior_graph, data=test_inputs)

# Predicted mean pH per test row, next to the observed pH values.
predicted = model.get_means(graph.outputs.pH)
true = list(test['pH'])

# <halerium id="6cb52f4d-4ed6-4958-8332-ed5555f7c8d1">
show_results(predicted, true)
# </halerium id="6cb52f4d-4ed6-4958-8332-ed5555f7c8d1">
