In [1]:
%matplotlib inline
# %matplotlib qt
import base
import numpy as np
import scipy as sp
from scipy.linalg import expm
from matplotlib import pyplot as plt
from importlib import reload
import torch
torch.set_default_dtype(torch.float64)
import signatory
import tensorflow as tf
tfk = tf.keras
tfkl = tf.keras.layers
tf.keras.backend.set_floatx('float64')
import Sig_method

In [2]:
class SigBag:
    """Index bookkeeping and multiplication for a truncated tensor algebra.

    Coordinates are indexed by words (tuples of letters in ``range(channels)``).
    Index 0 is the empty word ``()`` (the level-0 / scalar coordinate); the
    remaining indices follow the order of ``signatory.all_words(channels, depth)``.
    A coefficient vector therefore has length ``sig_dim``.

    Attributes:
        channels: alphabet size passed to ``signatory.all_words``.
        depth: truncation level; products whose word is longer are dropped.
        index2word: list mapping coordinate index -> word tuple.
        word2index: dict mapping word tuple -> coordinate index.
        sig_dim: number of coordinates, ``len(index2word)``.
        table: ``table[k]`` lists every index pair ``(i, j)`` whose words
            concatenate to word ``k`` (the multiplication table).
    """

    def __init__(self, channels, depth):
        self.channels = channels
        self.depth = depth
        self.index2word = self.index_to_word_aug()
        self.word2index = self.word_to_index_aug()
        self.sig_dim = len(self.index2word)
        self.table = self.sig_table()
        # Flat (left, right, target) index arrays derived from ``table`` once,
        # so ``product`` is a gather + scatter-add instead of a Python loop.
        self._left, self._right, self._target = self._flatten_table()

    def index_to_word_aug(self):
        """Return all words, with the empty word prepended at index 0."""
        return [()] + list(signatory.all_words(self.channels, self.depth))

    def word_to_index_aug(self):
        """Return the inverse mapping of ``index2word`` (word -> index)."""
        return {word: index for index, word in enumerate(self.index2word)}

    def sig_table(self):
        """Build the multiplication table.

        Returns:
            A list of length ``sig_dim`` whose entry ``k`` is the list of pairs
            ``(i, j)`` with ``index2word[i] + index2word[j] == index2word[k]``,
            ordered by ``i`` then ``j``. Concatenations longer than ``depth``
            are skipped; every shorter concatenation is expected to be present
            in ``word2index`` (a missing word raises ``KeyError``).
        """
        # Independent lists per coordinate (avoids the aliased ``[[]] * n`` idiom).
        table = [[] for _ in range(self.sig_dim)]
        for i, left_word in enumerate(self.index2word):
            for j, right_word in enumerate(self.index2word):
                # Check the length before building the concatenated tuple.
                if len(left_word) + len(right_word) <= self.depth:
                    table[self.word2index[left_word + right_word]].append((i, j))
        return table

    def _flatten_table(self):
        """Flatten ``table`` into parallel int64 arrays (left, right, target).

        Entry ``p`` means: coefficient ``a[left[p]] * b[right[p]]`` contributes
        to output coordinate ``target[p]``. Order follows ``table`` (by target,
        then by the table's pair order).
        """
        left, right, target = [], [], []
        for k, pairs in enumerate(self.table):
            for i, j in pairs:
                left.append(i)
                right.append(j)
                target.append(k)
        return (np.asarray(left, dtype=np.int64),
                np.asarray(right, dtype=np.int64),
                np.asarray(target, dtype=np.int64))

    def product(self, a, b):
        """Truncated tensor-algebra product of two coefficient vectors.

        Args:
            a, b: 1-D tensors (or arrays) of length ``sig_dim``.

        Returns:
            A 1-D tensor ``d`` of length ``sig_dim`` with
            ``d[k] = sum(a[i] * b[j] for (i, j) in table[k])``.
            The operation is differentiable with respect to ``a`` and ``b``.
        """
        # One term a[i] * b[j] per (i, j) pair of the multiplication table.
        terms = tf.gather(a, self._left) * tf.gather(b, self._right)
        # Scatter-add every term into its target coordinate k.
        return tf.math.unsorted_segment_sum(terms, self._target,
                                            num_segments=self.sig_dim)

In [5]:
class Layer_Sig(tf.keras.layers.Layer):
    """Keras layer mapping a batch of paths to a running tensor-algebra stream.

    For an input of shape ``(batch, length, channels)``, step ``t`` is embedded
    as the ``sig_dim``-vector ``(1, dx_t, 0, ..., 0)`` with
    ``dx_t = x_t - x_{t-1}`` (and ``dx_0 = 0``). The output at step ``t`` is the
    ordered product ``step_0 * step_1 * ... * step_t`` under ``SigBag``'s
    truncated multiplication.

    NOTE(review): each step is ``1 + dx`` rather than ``exp(dx)``, so levels
    >= 2 differ from the exact signature of the piecewise-linear path. The
    check cell compares against ``Sig_method.sig_stream2``; confirm that this
    is the intended definition.
    """

    def __init__(self, channels, depth, **kwargs):
        super(Layer_Sig, self).__init__(**kwargs)
        self.channels = channels
        self.depth = depth
        self.sigbag = SigBag(channels, depth)
        # Flatten the multiplication table once: output coordinate
        # ``target[p]`` receives ``a[left[p]] * b[right[p]]``.
        left, right, target = [], [], []
        for k, pairs in enumerate(self.sigbag.table):
            for i, j in pairs:
                left.append(i)
                right.append(j)
                target.append(k)
        self._pair_left = np.asarray(left, dtype=np.int64)
        self._pair_right = np.asarray(right, dtype=np.int64)
        self._pair_target = np.asarray(target, dtype=np.int64)

    def _batched_product(self, a, b):
        """Row-wise truncated product of two ``(batch, sig_dim)`` tensors."""
        terms = (tf.gather(a, self._pair_left, axis=-1)
                 * tf.gather(b, self._pair_right, axis=-1))  # (batch, n_pairs)
        # unsorted_segment_sum reduces over the leading axis, so put pairs first.
        summed = tf.math.unsorted_segment_sum(
            tf.transpose(terms), self._pair_target,
            num_segments=self.sigbag.sig_dim)  # (sig_dim, batch)
        return tf.transpose(summed)

    def call(self, inputs):
        """Compute the stream for every path in the batch.

        Args:
            inputs: tensor of shape ``(batch, length, channels)``; the batch
                and length dimensions may be unknown (``None``).

        Returns:
            Tensor of shape ``(batch, length, sig_dim)``.

        Raises:
            ValueError: if the static channel dimension differs from
                ``channels``.
        """
        path = inputs
        n_channels = path.shape[-1]
        if n_channels is not None and n_channels != self.channels:
            raise ValueError(
                f"Layer_Sig expected {self.channels} channels, got {n_channels}")
        sig_dim = self.sigbag.sig_dim
        # Increments dx_t, with a zero increment prepended for t = 0 and zeros
        # appended for every coordinate above level 1.
        increments = path[:, 1:, :] - path[:, :-1, :]
        increments = tf.pad(increments,
                            [[0, 0], [1, 0], [0, sig_dim - self.channels - 1]])
        # The level-0 coordinate of every step element is 1.
        level0 = tf.ones_like(path[:, :, :1])
        steps = tf.concat([level0, increments], axis=-1)  # (batch, length, sig_dim)
        # tf.scan iterates over axis 0, so move time first. With no initializer,
        # output[0] = steps[0] and output[t] = product(output[t-1], steps[t]).
        stream = tf.scan(self._batched_product, tf.transpose(steps, [1, 0, 2]))
        return tf.transpose(stream, [1, 0, 2])


In [11]:
# Compare the TF layer against the reference Sig_method implementation on one
# random path. Seeded so the comparison is reproducible.
SEED = 0
batch = 1
length = 10
channels = 2
depth = 3
rng = np.random.default_rng(SEED)
path = rng.normal(size=[batch, length, channels])
path_tf = tf.convert_to_tensor(path)

layer_sig = Layer_Sig(channels, depth)
sig_stream = layer_sig(path_tf)

sig_stream2 = Sig_method.sig_stream2(path[0], depth)

diff = sig_stream2.numpy() - sig_stream.numpy()
# The implementations may accumulate terms in a different order, so compare up
# to floating-point rounding instead of expecting exact zeros.
max_abs_diff = np.max(np.abs(diff))
assert max_abs_diff < 1e-10, max_abs_diff
diff

array([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]])