In [1]:
import tensorflow as tf
import numpy as np
import math

from module.one_dim.layers import *

In [2]:
class Modulation(tf.keras.layers.Layer):
    """Trainable elementwise phase modulation.

    ``call`` multiplies the input by ``cos(phi) + 1j * sin(phi)``. Here ``phi``
    is the (optionally limited) phase derived from a trainable weight that holds
    one entry per element of the last input axis (created in ``build``).

    Args:
        limitation: Selects how the raw ``phi`` weight becomes the applied phase.
            ``'sigmoid'`` uses ``phi_max * sigmoid(phi)``, which lies in
            ``(0, phi_max)``. Any other value, including the default ``None``,
            uses ``phi`` unconstrained.
        phi_max: Non-negative upper bound used by the ``'sigmoid'`` limitation.
            It is stored as a non-trainable ``tf.Variable``, so it is listed
            among the layer's weights.
        **kwargs: Standard ``tf.keras.layers.Layer`` arguments (``name``,
            ``dtype``, ``trainable``, ...). Accepting them lets
            ``from_config(get_config())`` rebuild the layer.

    Raises:
        ValueError: If ``phi_max`` is negative (or NaN).
    """

    def __init__(self, limitation=None, phi_max=0.0, **kwargs):
        super(Modulation, self).__init__(**kwargs)
        # Plain Python value: it only selects a code path in get_limited_phi
        # and has to stay JSON-serializable for get_config.
        self.limitation = limitation

        phi_max = float(phi_max)
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        # Written as `not >=` so that NaN is rejected as well.
        if not phi_max >= 0.0:
            raise ValueError(f"phi_max must be >= 0, got {phi_max}")
        # The variable name "theta_max" is intentionally left unchanged to avoid
        # altering the saved name of this weight.
        self.phi_max = tf.Variable(phi_max, validate_shape=False, name="theta_max", trainable=False)

    def build(self, input_dim):
        """Create the trainable ``phi`` weight, one entry per last-axis element."""
        self.input_dim = input_dim
        self.phi = self.add_weight("phi",
                                   shape=[int(input_dim[-1])])
        super(Modulation, self).build(input_dim)

    @tf.function
    def get_limited_phi(self):
        """Return the phase actually applied in ``call``.

        This is ``phi_max * sigmoid(phi)`` when ``limitation == 'sigmoid'``,
        otherwise the raw ``phi`` weight.
        """
        if self.limitation == 'sigmoid':
            # Cast so that a float32 phi_max also works with a non-float32 layer dtype.
            return tf.cast(self.phi_max, self.phi.dtype) * tf.sigmoid(self.phi)
        return self.phi

    def get_config(self):
        """Return a serializable config that ``from_config`` can consume."""
        config = super().get_config()
        config.update({
            "limitation": self.limitation,
            # Plain float: a tf.Variable is not JSON-serializable.
            "phi_max": float(self.phi_max.numpy()),
        })
        return config

    @classmethod
    def from_config(cls, config):
        """Rebuild the layer from ``get_config()`` output."""
        return cls(**config)

    @tf.function
    def call(self, x):
        """Multiply ``x`` elementwise by ``cos(phi) + 1j * sin(phi)``."""
        # Go through get_limited_phi so that `limitation` actually takes effect.
        phi = self.get_limited_phi()
        return x * tf.complex(tf.cos(phi), tf.sin(phi))


In [11]:
class ElectricFieldToIntensity(tf.keras.layers.Layer):
    """Weightless layer that maps a field ``x`` to ``|x|**2 / 2`` elementwise.

    Args:
        **kwargs: Standard ``tf.keras.layers.Layer`` arguments (``name``,
            ``dtype``, ``trainable``, ...). The base ``get_config`` emits these
            keys, so ``__init__`` must accept them for
            ``from_config(get_config())`` to work.
    """

    def __init__(self, **kwargs):
        super(ElectricFieldToIntensity, self).__init__(**kwargs)

    def build(self, input_dim):
        """Record the input shape; the layer creates no weights."""
        self.input_dim = input_dim
        super(ElectricFieldToIntensity, self).build(input_dim)

    def get_config(self):
        """Return the base layer config; this layer adds no arguments of its own."""
        config = super().get_config()
        return config

    @classmethod
    def from_config(cls, config):
        """Rebuild the layer from ``get_config()`` output."""
        return cls(**config)

    @tf.function
    def call(self, x):
        """Return ``|x|**2 / 2`` computed elementwise."""
        return tf.abs(x)**2/2

In [12]:
# Assemble the pipeline: IntensityToElectricField -> Modulation (phase) -> ElectricFieldToIntensity.
shape = (100, 100)  # NOTE(review): not used in this cell; kept in case later cells rely on it
inputs = tf.keras.Input(10)

# Each layer is instantiated and immediately applied, in pipeline order.
x = inputs
for layer_cls in (IntensityToElectricField, Modulation, ElectricFieldToIntensity):
    x = layer_cls()(x)
del layer_cls

model = tf.keras.Model(inputs, x)

In [14]:
# Run a batch of 5 all-ones samples (10 features each) through the model.
one = np.full((5, 10), 1.0)
pred = model.predict(x=one)
pred

array([[0.99999994, 0.99999994, 1.0000001 , 0.99999994, 0.99999994,
        0.99999994, 0.99999994, 0.99999994, 0.99999994, 0.99999994],
       [0.99999994, 0.99999994, 1.0000001 , 0.99999994, 0.99999994,
        0.99999994, 0.99999994, 0.99999994, 0.99999994, 0.99999994],
       [0.99999994, 0.99999994, 1.0000001 , 0.99999994, 0.99999994,
        0.99999994, 0.99999994, 0.99999994, 0.99999994, 0.99999994],
       [0.99999994, 0.99999994, 1.0000001 , 0.99999994, 0.99999994,
        0.99999994, 0.99999994, 0.99999994, 0.99999994, 0.99999994],
       [0.99999994, 0.99999994, 1.0000001 , 0.99999994, 0.99999994,
        0.99999994, 0.99999994, 0.99999994, 0.99999994, 0.99999994]],
      dtype=float32)

[array([ 0.21082616,  0.40405482,  0.46338618,  0.5306232 , -0.41534615,
         0.5390552 ,  0.39568508,  0.0988189 , -0.31008095,  0.2886622 ],
       dtype=float32),
 0.0]