In [1]:
import numpy as np

# Put the local easyesn checkout at the front of the module search path so that it takes
# precedence over any installed copy when `easyesn` is imported in the next cell.
# NOTE(review): `import easyesn` only resolves through this entry if '../src/easyesn/'
# itself contains an `easyesn/` package directory; if the package lives directly at
# '../src/easyesn/', the entry should be '../src/' instead — confirm against the repo layout.
import sys
sys.path[0:0] = ['../src/easyesn/']

In [2]:
from easyesn import SpatioTemporalESN
from easyesn.optimizers import GridSearchOptimizer
from easyesn import helper as hlp
import numpy as np
import matplotlib.pyplot as plt

# `scipy.ndimage.filters` is a deprecated alias namespace (DeprecationWarning since
# SciPy 1.8, scheduled for removal); the public location is `scipy.ndimage`.
from scipy.ndimage import convolve

# Seed the legacy global NumPy RNG so the random frequencies drawn below (and presumably
# easyesn's random reservoir initialisation) are reproducible across Restart & Run All.
np.random.seed(42)

Using Numpy backend.


In [3]:
inputLength = 1000
size = 50

# Shared time axis, repeated for every grid cell -> shape (inputLength, size, size).
data = np.linspace(0, 20*np.pi, inputLength)
data = np.repeat(data, size*size).reshape(-1, size, size)

# Per-cell random frequencies in [1, 3), constant over time.
freq1 = np.tile(np.random.rand(size, size), (inputLength, 1, 1))*2+1
# NOTE(review): freq2 is never used (outputData below also uses freq1). It is kept because
# drawing it advances the global RNG; removing it would change every later random draw
# (presumably including easyesn's reservoir initialisation). Was cos(freq2*data) intended?
freq2 = np.tile(np.random.rand(size, size), (inputLength, 1, 1))*2+1

# freq1*data already has shape (inputLength, size, size), so no reshape is needed.
inputData = np.sin(freq1*data)
outputData = np.cos(freq1*data)

# 3x3 sharpening kernel, applied to every output frame independently.
filter1 = np.array([[-1,-1,-1],
                   [-1,9,-1],
                   [-1,-1,-1]])/9.0

# A (1, 3, 3) kernel filters all frames in one call without mixing neighbouring time
# steps, which is equivalent to convolving each frame with the 2-D kernel in a loop.
outputData = convolve(outputData, filter1[np.newaxis, :, :])

# Chronological split: the first 70 % of the time steps train the ESN and the remaining
# 30 % validate it. Validation must start exactly where training ends; starting it at
# 30 % would put 40 % of the series in both sets and leak training data into validation.
trainFraction = 0.7
splitIndex = int(inputLength*trainFraction)

trainingInput = inputData[:splitIndex]
validationInput = inputData[splitIndex:]

trainingOutput = outputData[:splitIndex]
validationOutput = outputData[splitIndex:]

In [4]:
# Optional visual sanity check (currently disabled): overlay the first 200 time steps of
# three neighbouring grid cells, first for the training input, then for the training output.
#plt.plot(trainingInput[:200, 0,0])
#plt.plot(trainingInput[:200, 0,1])
#plt.plot(trainingInput[:200, 0,2])
#plt.show()

#plt.plot(trainingOutput[:200, 0,0])
#plt.plot(trainingOutput[:200, 0,1])
#plt.plot(trainingOutput[:200, 0,2])
#plt.show()

In [5]:
# Hyper-parameters grouped by their presumed role; see easyesn's SpatioTemporalESN for the
# exact semantics of each keyword.
esn = SpatioTemporalESN(
    # spatial input / filtering settings
    inputShape=(size, size),
    filterSize=3,
    stride=1,
    borderMode="mirror",
    averageOutputWeights=True,
    # reservoir dynamics
    n_reservoir=100,
    leakingRate=0.2,
    spectralRadius=0.8,
    # readout regression
    regression_parameters=[1e-2],
    solver="lsqr",
)

In [6]:
# Train on the training split. transientTime=100 presumably discards the first 100
# reservoir states as warm-up before the readout is fitted (TODO confirm in easyesn docs).
esn.fit(
    trainingInput,
    trainingOutput,
    verbose=1,
    transientTime=100,
)

In [8]:
# Predict on the validation split. With transientTime=0, presumably no leading steps are
# dropped, so the prediction should line up step-for-step with validationOutput.
prediction = esn.predict(
    validationInput,
    verbose=1,
    transientTime=0,
)

In [None]:
# Compare the ESN prediction with the ground truth at one grid cell. Both curves must use
# the same cell, otherwise the overlay is meaningless.
pixelRow, pixelCol = 48, 20

fig, ax = plt.subplots()
ax.plot(prediction[:, pixelRow, pixelCol], label="prediction")
ax.plot(validationOutput[:, pixelRow, pixelCol], label="target")
ax.set(title=f"Validation output at grid cell ({pixelRow}, {pixelCol})",
       xlabel="time step", ylabel="output")
ax.legend()
plt.show()

np.mean(prediction)

In [None]:
# Mean squared error over every time step and grid cell of the validation split.
np.square(prediction - validationOutput).mean()

In [7]:
# Display the learned readout weights. `_WOut` is a private easyesn attribute (leading
# underscore), so its name or layout may change between library versions.
esn._WOut

array([[-0.21776517, -0.23160711, -0.08902271, -0.06657201, -0.03198176,
         0.97526248, -0.08901269, -0.12332577, -0.10674441, -0.22035305,
        -0.0297574 ,  0.00713275,  0.14129932,  0.13419346,  0.02673814,
         0.03931331,  0.09278676, -0.00693552,  0.05752113,  0.30500012,
         0.02276986, -0.05991313,  0.0065816 ,  0.09595953,  0.04888067,
        -0.01064464, -0.09717834, -0.03055349, -0.03878397, -0.06007259,
         0.28060619, -0.07737388, -0.00835358, -0.13041343,  0.03335606,
        -0.06319969,  0.18562258, -0.20584215,  0.06859663, -0.07995034,
         0.19006198, -0.07234545, -0.01270077,  0.0592227 ,  0.00397802,
         0.09483952, -0.20609116,  0.0518487 , -0.0905249 ,  0.01663222,
        -0.02028399, -0.00649596,  0.04190407, -0.01801149, -0.00501378,
         0.0780484 , -0.04605494,  0.05720729, -0.0016246 , -0.13150283,
        -0.0383095 ,  0.0366072 ,  0.11428047, -0.19908204, -0.11260537,
        -0.04877565, -0.30104276,  0.01539599, -0.0