In [1]:
# Set up the import path so the mindpype library can be imported:
# the parent of the current working directory is added to sys.path.
import os
import sys

_parent_dir = os.path.dirname(os.path.abspath('.'))
# Guard the append so re-running this cell does not pile up duplicate entries.
if _parent_dir not in sys.path:
    sys.path.append(_parent_dir)

In [2]:
import mindpype as mp

import numpy as np
from scipy import signal

In [3]:
# Reference implementation used to cross-check the output of the graph
def manual_computation(input_data, order=4, bandpass=(8,35), fs=250):
    """Band-pass filter the rows of ``input_data`` and return their covariance.

    Parameters
    ----------
    input_data : array_like, 2-D
        Data to process. Filtering runs along axis 1, and ``np.cov`` treats
        each row as one variable.
    order : int, optional
        Order passed to ``scipy.signal.butter`` (default 4).
    bandpass : tuple of float, optional
        Lower and upper critical frequencies of the pass band, in the same
        units as ``fs`` (default ``(8, 35)``).
    fs : float, optional
        Sampling frequency passed to ``scipy.signal.butter`` (default 250).

    Returns
    -------
    numpy.ndarray
        Covariance matrix of the filtered rows, as returned by ``np.cov``.
    """
    # design the Butterworth band-pass filter as second-order sections
    sos = signal.butter(order, bandpass, btype='bandpass', output='sos', fs=fs)

    # filter each row along axis 1, then take the covariance of the rows
    filtered_data = signal.sosfilt(sos, input_data, axis=1)
    cov_mat = np.cov(filtered_data)

    return cov_mat

In [4]:
# Create a session, then create the graph from that session.
# `s` is passed to every mindpype object created in the cells below, and
# `trial_graph` is the graph the kernels are added to and that gets executed.
s = mp.Session.create()
trial_graph = mp.Graph.create(s)

In [5]:
# Create dummy input data and tensor.
# The generator is seeded so the "random" input is identical on every run,
# which makes any mismatch in the final comparison reproducible.
# Rows are the variables for the covariance; filtering runs along the columns.
n_channels, n_samples = 12, 500
rng = np.random.default_rng(42)
input_data = rng.standard_normal((n_channels, n_samples))
input_tensor = mp.Tensor.create_from_data(s,input_data)

# Create output tensor, sized for the row-by-row covariance matrix
output_tensor = mp.Tensor.create(s,(n_channels,n_channels))

# Create virtual tensor for intermediate output
intermediate_tensor = mp.Tensor.create_virtual(s)

In [6]:
# Butterworth band-pass filter definition.
# These values match the ones hard-coded in manual_computation's
# scipy.signal.butter call.
order = 4
bandpass = (8, 35)  # pass-band edges, in Hz
fs = 250
filter_obj = mp.Filter.create_butter(
    s,
    order,
    bandpass,
    btype='bandpass',
    fs=fs,
    implementation='sos',
)

In [7]:
# Add the two processing nodes to the graph:
#   1. filter:     input_tensor        -> intermediate_tensor
#   2. covariance: intermediate_tensor -> output_tensor
mp.kernels.FilterKernel.add_to_graph(
    trial_graph,
    input_tensor,
    filter_obj,
    intermediate_tensor,
)
# Left as the cell's last bare expression, so its return value is displayed.
mp.kernels.CovarianceKernel.add_to_graph(
    trial_graph,
    intermediate_tensor,
    output_tensor,
)

<mindpype.graph.Node at 0x1d36f4d68b0>

In [8]:
# Verify the graph (i.e. schedule the nodes) before it is run.
# The call logs one "Verifying kernel ..." line per kernel in the graph.
trial_graph.verify()

Verifying kernel Filter...
Verifying kernel Covariance...


In [9]:
# Initialize the graph. This step is not required here, because none of the
# nodes needs initialization/training data, but it is still called.
trial_graph.initialize()

In [10]:
# Run the graph. No label is passed, so the log line reports "label: None".
# The result is read from output_tensor.data in the comparison cell.
trial_graph.execute()

Executing trial with label: None


In [11]:
# Reference result from the plain scipy/numpy implementation
ground_truth = manual_computation(input_data)

# Largest absolute element-wise deviation between graph output and reference
max_diff = np.max(
    np.abs(output_tensor.data - ground_truth)
)
print(max_diff)

# The test passes only if that deviation does not exceed the float64
# machine epsilon
print(
    "Test Passed =D"
    if max_diff <= np.finfo(np.float64).eps
    else "Test Failed D="
)


0.0
Test Passed =D
