In [2]:
from typing import Dict, Optional
import torch
from SINDy_library import *
from ..autoencoder.autoencoder import AutoEncoder

class SINDy(torch.nn.Module):
    """
    SINDy dz predictions on an encoder's latent state.

    The encoder maps (x, dx, ddx) to (z, dz, ddz). The SINDy library Theta is
    built from the latent state, and the derivative is predicted as
    Theta @ sindy_coefficients (optionally masked).

    Arguments:
        params - Dictionary object containing the parameters that specify the training.
            See params.txt file for a description of the parameters.
            Required keys read here: 'input_dim', 'latent_dim', 'poly_order',
            'library_dim', 'model_order', 'sequential_thresholding',
            'coefficient_initialization'.
            Optional keys: 'include_sine' (default False) and
            'coefficient_mask' (default: all ones of shape
            (library_dim, latent_dim)).
        encoder - module called as encoder(x, dx, ddx) -> (z, dz, ddz).
        decoder - keyword-only; module called as
            decoder(z, dz, ddz) -> (x_decode, dx_decode, ddx_decode).

    Returns (from forward):
        dict of named tensors; 'sindy_predict' holds SINDy's prediction for
        dz (model_order == 1) or ddz (any other model_order).
    """

    def __init__(
        self,
        params: Optional[Dict] = None,
        encoder: Optional["AutoEncoder"] = None,
        *args,
        decoder: Optional[torch.nn.Module] = None,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)
        # None instead of a shared mutable {} default; a missing params dict
        # still fails below with KeyError on the first required key.
        self.params = {} if params is None else params
        self.encoder = encoder
        self.decoder = decoder

        self.input_dim = self.params['input_dim']
        self.latent_dim = self.params['latent_dim']
        self.poly_order = self.params['poly_order']
        self.include_sine = self.params.get('include_sine', False)
        self.library_dim = self.params['library_dim']
        self.model_order = self.params['model_order']
        self.sequential_thresholding = self.params['sequential_thresholding']
        self.coefficient_initialization = self.params['coefficient_initialization']

        # Mask multiplied into the coefficients when sequential thresholding
        # is enabled. Registered as a buffer: it follows .to(device) and is
        # saved in the state_dict, but it is not a trainable parameter.
        mask = self.params.get('coefficient_mask')
        if mask is None:
            mask = torch.ones(self.library_dim, self.latent_dim)
        self.register_buffer(
            'coefficient_mask', torch.as_tensor(mask, dtype=torch.float32)
        )

        # Trainable SINDy coefficients, shape (library_dim, latent_dim).
        # Wrapped in a Parameter so the optimizer actually updates them.
        self.sindy_coefficients = torch.nn.Parameter(
            torch.zeros(self.library_dim, self.latent_dim)
        )
        self.init_sindy_coefficients(self.coefficient_initialization)

    def init_sindy_coefficients(self, name='normal', std=1., k=3):
        """
        Initialise self.sindy_coefficients in place.

        Arguments:
            name - 'xavier' (Xavier/Glorot uniform), 'uniform' (U(0, 1)),
                'constant' (every entry set to k) or 'normal' (N(0, std**2)).
            std - standard deviation used by 'normal'.
            k - fill value used by 'constant'.

        Raises:
            ValueError - if name is not one of the options above.
        """
        coefficients = self.sindy_coefficients
        # torch.nn.init.* modify the tensor in place (under no_grad), which
        # keeps the registered Parameter object; assigning a plain tensor to a
        # Parameter attribute would raise TypeError.
        if name == 'xavier':
            torch.nn.init.xavier_uniform_(coefficients)
        elif name == 'uniform':
            torch.nn.init.uniform_(coefficients, a=0.0, b=1.0)
        elif name == 'constant':
            torch.nn.init.constant_(coefficients, k)
        elif name == 'normal':
            torch.nn.init.normal_(coefficients, mean=0.0, std=std)
        else:
            raise ValueError(
                f"Unknown SINDy coefficient initialization: {name!r}"
            )

    def forward(self, x, dx, ddx) -> Dict[str, Optional[torch.Tensor]]:
        """
        Encode the inputs, predict the latent derivative with SINDy and decode.

        Arguments:
            x, dx, ddx - inputs forwarded unchanged to self.encoder.

        Returns:
            dict with keys 'x', 'dx', 'ddx', 'z', 'dz', 'ddz', 'x_decode',
            'dx_decode', 'ddx_decode', 'dz_predict', 'ddz_predict',
            'sindy_coefficients' and 'sindy_predict'. For model_order == 1,
            'ddz_predict' is None. For any other order, 'dz_predict' is None.

        Raises:
            RuntimeError - if the encoder or decoder was not provided.
        """
        if self.encoder is None or self.decoder is None:
            raise RuntimeError(
                "SINDy.forward needs both an encoder and a decoder; "
                "pass encoder=... and decoder=... to the constructor."
            )

        z, dz, ddz = self.encoder(x, dx, ddx)

        # Build the library matrix Theta from the latent state.
        if self.model_order == 1:
            theta = sindy_library_pt(
                z, self.latent_dim, self.poly_order, self.include_sine
            )
        else:
            theta = sindy_library_pt_order2(
                z, dz, self.latent_dim, self.poly_order, self.include_sine
            )

        # With sequential thresholding, entries zeroed in the mask do not
        # contribute to the prediction.
        coefficients = self.sindy_coefficients
        if self.sequential_thresholding:
            coefficients = self.coefficient_mask * coefficients
        sindy_predict = torch.matmul(theta, coefficients)

        x_decode, dx_decode, ddx_decode = self.decoder(z, dz, ddz)

        if self.model_order == 1:
            dz_predict, ddz_predict = sindy_predict, None
        else:
            dz_predict, ddz_predict = None, sindy_predict

        # These outputs have different shapes (e.g. sindy_coefficients is
        # library_dim x latent_dim) and some are None, so they cannot be
        # concatenated into one tensor; return them by name instead.
        return {
            'x': x,
            'dx': dx,
            'ddx': ddx,
            'z': z,
            'dz': dz,
            'ddz': ddz,
            'x_decode': x_decode,
            'dx_decode': dx_decode,
            'ddx_decode': ddx_decode,
            'dz_predict': dz_predict,
            'ddz_predict': ddz_predict,
            'sindy_coefficients': self.sindy_coefficients,
            'sindy_predict': sindy_predict,
        }

ImportError: attempted relative import with no known parent package

In [None]:
# Smoke-test cell: construct the model and call it.
# NOTE(review): this cannot run as written. SINDy() passes no params, so the
# constructor fails with KeyError on params['input_dim'], and sindy() calls
# forward without the required x, dx, ddx arguments.
# TODO: build a params dict with the keys SINDy.__init__ reads (input_dim,
# latent_dim, poly_order, library_dim, model_order, sequential_thresholding,
# coefficient_initialization) plus an encoder, then call the model on real
# inputs.
sindy = SINDy()
sindy()