# In [None]:
#@title FFTConv2D Tensorflow Keras

import functools

import tensorflow as tf
from tensorflow.python.eager import context
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import activations
from tensorflow.python.keras import backend
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
from tensorflow.keras.layers import Layer
from tensorflow.keras.layers import InputSpec
# imports for backwards namespace compatibility
# pylint: disable=unused-import
from tensorflow.python.keras.layers.pooling import AveragePooling1D
from tensorflow.python.keras.layers.pooling import AveragePooling2D
from tensorflow.python.keras.layers.pooling import AveragePooling3D
from tensorflow.python.keras.layers.pooling import MaxPooling1D
from tensorflow.python.keras.layers.pooling import MaxPooling2D
from tensorflow.python.keras.layers.pooling import MaxPooling3D
# pylint: enable=unused-import
from tensorflow.python.keras.utils import conv_utils
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import array_ops_stack
from tensorflow.python.ops import nn
from tensorflow.python.ops import nn_ops
# pylint: disable=g-classes-have-attributes
from tensorflow.keras.utils import register_keras_serializable

@register_keras_serializable()
class FFTConv2D(Layer):
  """2D convolution layer evaluated in the frequency domain.

  Accepts the same constructor arguments as `tf.keras.layers.Conv2D` and
  reports the same output shape (`compute_output_shape` uses
  `conv_utils.conv_output_length`). The convolution itself is computed by
  `fft_op` with real 2D FFTs (`tf.signal.rfft2d` / `tf.signal.irfft2d`).
  Like `Conv2D`, it is a cross-correlation:

    out[b, i, j, f] = sum_{di, dj, c}
        in[b, s_h * i + d_h * di, s_w * j + d_w * dj, c] * kernel[di, dj, c, f]

  `padding` may be `'valid'` or `'same'`. `'causal'` is rejected because it is
  a 1D-only padding mode. `conv_op` is stored but not used by this layer.
  The FFT arithmetic runs in float32, and the result is cast back to the
  input dtype.
  """

  def __init__(self,
               filters,
               kernel_size,
               strides=1,
               padding='valid',
               data_format=None,
               dilation_rate=1,
               groups=1,
               activation=None,
               use_bias=True,
               kernel_initializer='glorot_uniform',
               bias_initializer='zeros',
               kernel_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               bias_constraint=None,
               trainable=True,
               name=None,
               conv_op=None,
               **kwargs):
    super(FFTConv2D, self).__init__(
        trainable=trainable,
        name=name,
        activity_regularizer=regularizers.get(activity_regularizer),
        **kwargs)
    rank = 2
    self.rank = rank

    if isinstance(filters, float):
      filters = int(filters)
    if filters is not None and filters < 0:
      raise ValueError(f'Received a negative value for `filters`.'
                       f'Was expecting a positive value, got {filters}.')
    self.filters = filters
    self.groups = groups or 1
    self.kernel_size = conv_utils.normalize_tuple(
        kernel_size, rank, 'kernel_size')
    self.strides = conv_utils.normalize_tuple(strides, rank, 'strides')
    self.padding = conv_utils.normalize_padding(padding)
    self.data_format = conv_utils.normalize_data_format(data_format)
    self.dilation_rate = conv_utils.normalize_tuple(
        dilation_rate, rank, 'dilation_rate')

    self.activation = activations.get(activation)
    self.use_bias = use_bias
    self.conv_op = conv_op

    self.kernel_initializer = initializers.get(kernel_initializer)
    self.bias_initializer = initializers.get(bias_initializer)
    self.kernel_regularizer = regularizers.get(kernel_regularizer)
    self.bias_regularizer = regularizers.get(bias_regularizer)
    self.kernel_constraint = constraints.get(kernel_constraint)
    self.bias_constraint = constraints.get(bias_constraint)
    self.input_spec = InputSpec(min_ndim=self.rank + 2)

    self._validate_init()
    self._is_causal = self.padding == 'causal'
    self._channels_first = self.data_format == 'channels_first'
    self._tf_data_format = conv_utils.convert_data_format(
        self.data_format, self.rank + 2)

  def _validate_init(self):
    """Reject constructor argument combinations this layer cannot honour.

    Raises:
      ValueError: if `filters` is not divisible by `groups`, if
        `kernel_size` or `strides` contain a zero, or if `padding` is
        `'causal'`.
    """
    if self.filters is not None and self.filters % self.groups != 0:
      raise ValueError(
          'The number of filters must be evenly divisible by the number of '
          'groups. Received: groups={}, filters={}'.format(
              self.groups, self.filters))

    if not all(self.kernel_size):
      raise ValueError('The argument `kernel_size` cannot contain 0(s). '
                       'Received: %s' % (self.kernel_size,))

    if not all(self.strides):
      raise ValueError('The argument `strides` cannot contains 0(s). '
                       'Received: %s' % (self.strides,))

    # Causal padding only exists for 1D convolutions, so a 2D layer must
    # always reject it. Decide this without referencing 1D layer classes,
    # which are not defined alongside this layer.
    if self.padding == 'causal':
      raise ValueError('Causal padding is only supported for 1D convolutions '
                       '(`Conv1D`, `SeparableConv1D`); `FFTConv2D` received '
                       'padding="causal".')

  def build(self, input_shape):
    """Create the kernel (and optional bias) for the given input shape.

    Raises:
      ValueError: if the channel dimension is undefined or not divisible by
        `groups`.
    """
    input_shape = tensor_shape.TensorShape(input_shape)
    input_channel = self._get_input_channel(input_shape)
    if input_channel % self.groups != 0:
      raise ValueError(
          'The number of input channels must be evenly divisible by the number '
          'of groups. Received groups={}, but the input has {} channels '
          '(full input shape is {}).'.format(self.groups, input_channel,
                                             input_shape))
    # Kernel layout matches Conv2D: (kh, kw, in_channels // groups, filters).
    kernel_shape = self.kernel_size + (input_channel // self.groups,
                                       self.filters)

    self.kernel = self.add_weight(
        name='kernel',
        shape=kernel_shape,
        initializer=self.kernel_initializer,
        regularizer=self.kernel_regularizer,
        constraint=self.kernel_constraint,
        trainable=True,
        dtype=self.dtype)
    if self.use_bias:
      self.bias = self.add_weight(
          name='bias',
          shape=(self.filters,),
          initializer=self.bias_initializer,
          regularizer=self.bias_regularizer,
          constraint=self.bias_constraint,
          trainable=True,
          dtype=self.dtype)
    else:
      self.bias = None
    channel_axis = self._get_channel_axis()
    self.input_spec = InputSpec(min_ndim=self.rank + 2,
                                axes={channel_axis: input_channel})

    # Convert Keras formats to TF native formats.
    if self.padding == 'causal':
      tf_padding = 'VALID'  # Causal padding handled in `call`.
    elif isinstance(self.padding, str):
      tf_padding = self.padding.upper()
    else:
      tf_padding = self.padding
    tf_dilations = list(self.dilation_rate)
    tf_strides = list(self.strides)

    tf_op_name = self.__class__.__name__
    if tf_op_name == 'FFTConv1D':
      tf_op_name = 'fftconv1d'  # Backwards compat.

    # Spatial-domain reference op kept as an attribute; `call` uses `fft_op`.
    self._convolution_op = functools.partial(
        nn_ops.convolution_v2,
        strides=tf_strides,
        padding=tf_padding,
        dilations=tf_dilations,
        data_format=self._tf_data_format,
        name=tf_op_name)
    self.built = True

  def call(self, inputs):
    """Apply the FFT convolution, then the optional bias and activation."""
    input_shape = inputs.shape

    if self._is_causal:  # Apply causal padding to inputs for Conv1D.
      inputs = array_ops.pad(inputs, self._compute_causal_padding(inputs))

    # Forward the layer configuration so the computed output agrees with
    # `compute_output_shape` (and therefore with `set_shape` below).
    outputs = self.fft_op(
        inputs,
        self.kernel,
        strides=self.strides,
        padding=self.padding,
        dilation_rate=self.dilation_rate,
        data_format=self.data_format)

    if self.use_bias:
      output_rank = outputs.shape.rank
      if self.rank == 1 and self._channels_first:
        # nn.bias_add does not accept a 1D input tensor.
        bias = array_ops.reshape(self.bias, (1, self.filters, 1))
        outputs += bias
      else:
        # Handle multiple batch dimensions.
        if output_rank is not None and output_rank > 2 + self.rank:

          def _apply_fn(o):
            return nn.bias_add(o, self.bias, data_format=self._tf_data_format)

          outputs = conv_utils.squeeze_batch_dims(
              outputs, _apply_fn, inner_rank=self.rank + 1)
        else:
          outputs = nn.bias_add(
              outputs, self.bias, data_format=self._tf_data_format)

    if not context.executing_eagerly():
      # Infer the static output shape:
      out_shape = self.compute_output_shape(input_shape)
      outputs.set_shape(out_shape)

    if self.activation is not None:
      return self.activation(outputs)
    return outputs

  @tf.function
  def fft_op(self, batch_images, kernels_outputs, strides=(1, 1),
             padding='valid', dilation_rate=(1, 1),
             data_format='channels_last'):
    """Cross-correlate `batch_images` with `kernels_outputs` using real FFTs.

    Produces the same values and shape as `tf.nn.convolution` with the same
    strides, padding, dilations and data format. The input is zero-padded
    for `'same'`, and the kernel is dilated. Both are transformed with
    `rfft2d` at the padded input size, and the channels are mixed per
    frequency bin as `X * conj(K)`. The inverse transform then gives a
    circular correlation. Its first `H - kh + 1` rows and `W - kw + 1`
    columns contain no wrap-around and equal the linear 'valid'
    correlation. Finally, the result is subsampled by `strides`.

    Args:
      batch_images: rank-4 tensor, `(batch, H, W, C)` for `'channels_last'`
        or `(batch, C, H, W)` for `'channels_first'`. `C` must be static.
      kernels_outputs: rank-4 kernel `(kh, kw, C // groups, filters)` with a
        static shape. `groups` is inferred as `C // kernel_depth`.
      strides: int or pair of ints.
      padding: `'valid'` or `'same'` (TF semantics).
      dilation_rate: int or pair of ints.
      data_format: `'channels_last'` or `'channels_first'`.

    Returns:
      Tensor in `data_format` layout, with the dtype of `batch_images`.

    Raises:
      ValueError: for unsupported padding or inconsistent channel counts.
    """
    strides = conv_utils.normalize_tuple(strides, 2, 'strides')
    dilation_rate = conv_utils.normalize_tuple(
        dilation_rate, 2, 'dilation_rate')
    padding = conv_utils.normalize_padding(padding)
    if padding not in ('valid', 'same'):
      raise ValueError('FFTConv2D only supports "valid" or "same" padding. '
                       'Received: {}'.format(padding))

    out_dtype = batch_images.dtype
    # Work in NCHW: rfft2d transforms the two innermost axes.
    images = tf.cast(batch_images, tf.float32)
    if data_format == 'channels_last':
      images = tf.transpose(images, perm=[0, 3, 1, 2])
    kernel = self._dilate_kernel(
        tf.cast(kernels_outputs, tf.float32), dilation_rate)
    k_h, k_w, in_per_group, n_filters = kernel.shape.as_list()

    n_channels = images.shape[1]
    if n_channels is None:
      raise ValueError('The channel dimension of the inputs '
                       'should be defined. Found `None`.')
    if n_channels % in_per_group != 0:
      raise ValueError(
          'The number of input channels ({}) must be a multiple of the '
          'kernel input depth ({}).'.format(n_channels, in_per_group))
    groups = n_channels // in_per_group
    if n_filters % groups != 0:
      raise ValueError(
          'The number of filters ({}) must be divisible by the inferred number '
          'of groups ({}).'.format(n_filters, groups))

    if padding == 'same':
      pads_h = self._same_padding(tf.shape(images)[2], k_h, strides[0])
      pads_w = self._same_padding(tf.shape(images)[3], k_w, strides[1])
      images = tf.pad(images, [[0, 0], [0, 0], pads_h, pads_w])

    height = tf.shape(images)[2]
    width = tf.shape(images)[3]
    # FFT size must hold both operands; it only exceeds the padded input size
    # when the kernel is larger, in which case the valid output is empty.
    fft_h = tf.maximum(height, k_h)
    fft_w = tf.maximum(width, k_w)
    fft_length = tf.stack([fft_h, fft_w])

    images = tf.pad(
        images, [[0, 0], [0, 0], [0, fft_h - height], [0, fft_w - width]])
    kernel = tf.transpose(kernel, perm=[2, 3, 0, 1])  # (c/g, F, kh, kw)
    kernel = tf.pad(
        kernel, [[0, 0], [0, 0], [0, fft_h - k_h], [0, fft_w - k_w]])

    x_fft = tf.signal.rfft2d(images, fft_length=fft_length)
    # conj(K) turns the spectral product into a cross-correlation.
    k_fft = tf.math.conj(tf.signal.rfft2d(kernel, fft_length=fft_length))
    spectrum_shape = tf.shape(x_fft)[2:]

    # Split channels/filters into groups:
    #   x: (batch, g, c/g, H, Wf)   k: (c/g, g, F/g, H, Wf)
    x_fft = tf.reshape(
        x_fft, tf.concat([[-1, groups, in_per_group], spectrum_shape], 0))
    k_fft = tf.reshape(
        k_fft,
        tf.concat([[in_per_group, groups, n_filters // groups],
                   spectrum_shape], 0))
    # Move frequency bins to leading batch axes so channel mixing is a
    # single batched matmul: (g, H, Wf, batch, c/g) @ (g, H, Wf, c/g, F/g).
    x_fft = tf.transpose(x_fft, perm=[1, 3, 4, 0, 2])
    k_fft = tf.transpose(k_fft, perm=[1, 3, 4, 0, 2])
    y_fft = tf.linalg.matmul(x_fft, k_fft)  # (g, H, Wf, batch, F/g)
    y_fft = tf.transpose(y_fft, perm=[3, 0, 4, 1, 2])  # (batch, g, F/g, H, Wf)
    y_fft = tf.reshape(y_fft, tf.concat([[-1, n_filters], spectrum_shape], 0))

    # Explicit fft_length keeps odd spatial sizes intact.
    result = tf.signal.irfft2d(y_fft, fft_length=fft_length)

    # Keep only wrap-free ('valid') positions, then subsample by the stride.
    valid_h = tf.maximum(height - k_h + 1, 0)
    valid_w = tf.maximum(width - k_w + 1, 0)
    result = result[:, :, :valid_h:strides[0], :valid_w:strides[1]]

    if data_format == 'channels_last':
      result = tf.transpose(result, perm=[0, 2, 3, 1])
    return tf.cast(result, out_dtype)

  @staticmethod
  def _dilate_kernel(kernel, dilation_rate):
    """Insert `rate - 1` zeros between kernel taps along both spatial axes."""
    rate_h, rate_w = dilation_rate
    if rate_h > 1:
      k_h, k_w, c_in, c_out = kernel.shape.as_list()
      kernel = tf.pad(kernel[:, tf.newaxis],
                      [[0, 0], [0, rate_h - 1], [0, 0], [0, 0], [0, 0]])
      kernel = tf.reshape(kernel, [k_h * rate_h, k_w, c_in, c_out])
      kernel = kernel[:(k_h - 1) * rate_h + 1]
    if rate_w > 1:
      k_h, k_w, c_in, c_out = kernel.shape.as_list()
      kernel = tf.pad(kernel[:, :, tf.newaxis],
                      [[0, 0], [0, 0], [0, rate_w - 1], [0, 0], [0, 0]])
      kernel = tf.reshape(kernel, [k_h, k_w * rate_w, c_in, c_out])
      kernel = kernel[:, :(k_w - 1) * rate_w + 1]
    return kernel

  @staticmethod
  def _same_padding(size, kernel_extent, stride):
    """Return `[before, after]` zero padding following TF's 'SAME' rule."""
    out_size = (size + stride - 1) // stride
    total = tf.maximum((out_size - 1) * stride + kernel_extent - size, 0)
    return [total // 2, total - total // 2]

  def _spatial_output_shape(self, spatial_input_shape):
    return [
        conv_utils.conv_output_length(
            length,
            self.kernel_size[i],
            padding=self.padding,
            stride=self.strides[i],
            dilation=self.dilation_rate[i])
        for i, length in enumerate(spatial_input_shape)
    ]

  def compute_output_shape(self, input_shape):
    input_shape = tensor_shape.TensorShape(input_shape).as_list()
    batch_rank = len(input_shape) - self.rank - 1
    if self.data_format == 'channels_last':
      return tensor_shape.TensorShape(
          input_shape[:batch_rank]
          + self._spatial_output_shape(input_shape[batch_rank:-1])
          + [self.filters])
    else:
      return tensor_shape.TensorShape(
          input_shape[:batch_rank] + [self.filters] +
          self._spatial_output_shape(input_shape[batch_rank + 1:]))

  def _recreate_conv_op(self, inputs):  # pylint: disable=unused-argument
    return False

  def get_config(self):
    config = {
        'filters':
            self.filters,
        'kernel_size':
            self.kernel_size,
        'strides':
            self.strides,
        'padding':
            self.padding,
        'data_format':
            self.data_format,
        'dilation_rate':
            self.dilation_rate,
        'groups':
            self.groups,
        'activation':
            activations.serialize(self.activation),
        'use_bias':
            self.use_bias,
        'kernel_initializer':
            initializers.serialize(self.kernel_initializer),
        'bias_initializer':
            initializers.serialize(self.bias_initializer),
        'kernel_regularizer':
            regularizers.serialize(self.kernel_regularizer),
        'bias_regularizer':
            regularizers.serialize(self.bias_regularizer),
        'activity_regularizer':
            regularizers.serialize(self.activity_regularizer),
        'kernel_constraint':
            constraints.serialize(self.kernel_constraint),
        'bias_constraint':
            constraints.serialize(self.bias_constraint),
        'trainable':
            self.trainable
    }
    base_config = super(FFTConv2D, self).get_config()
    full_config = dict(list(base_config.items()) + list(config.items()))
    return full_config

  def _compute_causal_padding(self, inputs):
    """Calculates padding for 'causal' option for 1-d conv layers."""
    left_pad = self.dilation_rate[0] * (self.kernel_size[0] - 1)
    if getattr(inputs.shape, 'ndims', None) is None:
      batch_rank = 1
    else:
      batch_rank = len(inputs.shape) - 2
    if self.data_format == 'channels_last':
      causal_padding = [[0, 0]] * batch_rank + [[left_pad, 0], [0, 0]]
    else:
      causal_padding = [[0, 0]] * batch_rank + [[0, 0], [left_pad, 0]]
    return causal_padding

  def _get_channel_axis(self):
    if self.data_format == 'channels_first':
      return -1 - self.rank
    else:
      return -1

  def _get_input_channel(self, input_shape):
    channel_axis = self._get_channel_axis()
    if input_shape.dims[channel_axis].value is None:
      raise ValueError('The channel dimension of the inputs '
                       'should be defined. Found `None`.')
    return int(input_shape[channel_axis])

  def _get_padding_op(self):
    if self.padding == 'causal':
      op_padding = 'valid'
    else:
      op_padding = self.padding
    if not isinstance(op_padding, (list, tuple)):
      op_padding = op_padding.upper()
    return op_padding

# In [None]:
#@title VFFT-CNN

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Input, Dense, Dropout, Flatten, Activation, BatchNormalization, Add, Attention
from tensorflow.keras.layers import Conv2D, DepthwiseConv2D, MaxPooling2D, AveragePooling2D, GlobalAveragePooling2D, SeparableConv2D  # straturi convolutionale si max-pooling
from tensorflow.keras.optimizers import RMSprop, SGD, Adadelta, Adam, Nadam
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers.schedules import ExponentialDecay
from tensorflow.keras.initializers import RandomNormal, HeNormal, GlorotUniform, GlorotNormal
from tensorflow.keras.regularizers import L2, L1L2

# Shared hyper-parameters read by the FFT block builders below.
kernel_regularizer = L2(1e-4)                  # L2 penalty passed to every FFTConv2D kernel
kernel_initializer = GlorotUniform(seed=None)  # unseeded Glorot-uniform kernel init
drop_rate = 0.35  # Best value for CIFAR-100 after tuning in range 0.25 - 0.75!
psiz, stri = 4, 2  # MaxPooling2D window size and stride used in fft_conv_block

#-------------------------- FFT building blocks ------------------------------
# Define a convolutional block with FFTConv2D
def fft_conv_block(inputs, inputs_x, filters, kernel_size, padding, input_shape):
    """Run `inputs_x` through FFTConv2D -> BatchNorm -> ReLU -> MaxPool -> Dropout.

    `inputs` takes no part in the computation; it is returned unchanged as
    the second element of the result tuple.

    Returns:
        A tuple `(block_output, inputs)`.
    """
    stages = (
        FFTConv2D(
            filters=filters,
            kernel_size=(kernel_size,) * 2,
            padding=padding,
            input_shape=input_shape,
            kernel_initializer=kernel_initializer,
            kernel_regularizer=kernel_regularizer,
        ),
        BatchNormalization(),
        Activation('relu'),
        MaxPooling2D(pool_size=(psiz,) * 2, strides=(stri,) * 2, padding=padding),
        Dropout(drop_rate),
    )
    features = inputs_x
    for stage in stages:
        features = stage(features)
    return features, inputs

def fft_block(inputs, filters, kernel_size, padding, input_shape):
    """Apply FFTConv2D -> BatchNorm -> ReLU to `inputs` (no pooling or dropout)."""
    conv = FFTConv2D(
        filters=filters,
        kernel_size=(kernel_size,) * 2,
        padding=padding,
        input_shape=input_shape,
        kernel_initializer=kernel_initializer,
        kernel_regularizer=kernel_regularizer,
    )
    normalized = BatchNormalization()(conv(inputs))
    return Activation('relu')(normalized)

def create_v_cnn_fft_model(input_shape, num_classes, flat=1, fil=(20,), nl=(1,),
                           hid=(), csize=15, padding='same',
                           initial_learning_rate=0.0001, use_lr_schedule=False):
    """Build and compile a V-CNN classifier made of FFTConv2D macro-layers.

    Args:
        input_shape: shape of one input sample, passed to `Input`.
        num_classes: number of softmax outputs.
        flat: 1 -> `Flatten`, 0 -> `GlobalAveragePooling2D` before the dense
            head; any other value applies no reduction.
        fil: filters per macro-layer; its length sets the macro-layer count.
        nl: extra (non-pooled) FFT blocks per macro-layer; indexed in step
            with `fil`, so it needs at least `len(fil)` entries.
        hid: widths of optional hidden ReLU dense layers.
        csize: square kernel size for every FFTConv2D.
        padding: padding for the convolutions and pooling layers.
        initial_learning_rate: Adam learning rate (the starting rate when
            `use_lr_schedule` is True).
        use_lr_schedule: if True, decay the learning rate with
            `ExponentialDecay(decay_steps=10000, decay_rate=0.96,
            staircase=True)`; otherwise keep it constant.

    Returns:
        A compiled `tf.keras.Model` (Adam with clipvalue=1.0,
        categorical cross-entropy, accuracy metric).
    """
    inputs = Input(shape=input_shape)
    original_inputs = inputs
    x = inputs

    # First macro-layer (connected to the input). When nl[0] > 0 it stacks
    # nl[0] + 1 plain FFT blocks before the pooled block.
    if nl[0] > 0:
        x = fft_block(x, fil[0], csize, padding, input_shape)
        for _ in range(nl[0]):
            x = fft_block(x, fil[0], csize, padding, input_shape)
    x, inputs = fft_conv_block(inputs, x, fil[0], csize, padding, input_shape)

    # The remaining macro-layers: nl[layer] plain blocks, then a pooled block.
    for layer in range(1, len(fil)):
        for _ in range(nl[layer]):
            x = fft_block(x, fil[layer], csize, padding, input_shape)
        x, inputs = fft_conv_block(inputs, x, fil[layer], csize, padding, input_shape)

    # Exit classifier
    if flat == 1:
        x = Flatten()(x)
    elif flat == 0:
        x = GlobalAveragePooling2D()(x)

    for units in hid:
        x = Dense(units, activation='relu')(x)

    outputs = Dense(num_classes, activation='softmax')(x)

    model = Model(inputs=original_inputs, outputs=outputs)

    # The schedule used to be built and then discarded; it is now opt-in so
    # the default behaviour (constant learning rate) is unchanged.
    learning_rate = initial_learning_rate
    if use_lr_schedule:
        learning_rate = ExponentialDecay(
            initial_learning_rate,
            decay_steps=10000,
            decay_rate=0.96,
            staircase=True)

    model.compile(
        optimizer=tf.keras.optimizers.Adam(clipvalue=1.0, learning_rate=learning_rate),
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )

    return model
