## Time2Vector

## References
- Kazemi et al. (2019), *Time2Vec: Learning a Vector Representation of Time*: https://arxiv.org/pdf/1907.05321.pdf

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

import random

import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras.models import *
from tensorflow.keras.callbacks import * 
from tensorflow.keras.optimizers import *
from tensorflow.keras import backend as K

In [2]:
class Time2Vector(Layer):
    """Learned time embedding in the style of Time2Vec.

    The feature axis of the input is averaged to a single scalar ``x[t]`` per
    time step. Each step ``t`` is then mapped to two features:

    * a linear (non-periodic) term ``w_l[t] * x[t] + b_l[t]``
    * a periodic term ``sin(w_p[t] * x[t] + b_p[t])``

    Every weight/bias vector has one entry per time step, i.e. shape
    ``(seq_len,)``. Only a single periodic component is learned.

    Args:
        seq_len: Number of time steps in the input sequence. It sets the length
            of every learned weight/bias vector.
        **kwargs: Standard Keras ``Layer`` keyword arguments (e.g. ``name``,
            ``dtype``), forwarded to the base class.

    Input shape:
        ``(batch, seq_len, features)``.

    Output shape:
        ``(batch, seq_len, 2)``. Channel 0 is the linear term and channel 1
        is the periodic term.
    """

    def __init__(self, seq_len, **kwargs):
        # Forward **kwargs so base-layer options (name, dtype, trainable, ...)
        # are honoured instead of silently dropped. This also lets Keras
        # re-create the layer from get_config() when a saved model is loaded.
        super(Time2Vector, self).__init__(**kwargs)
        self.seq_len = int(seq_len)

    def build(self, input_shape):
        """Create the four per-time-step weight vectors of shape ``(seq_len,)``."""
        self.weights_linear = self.add_weight(name='weight_linear',
                                              shape=(self.seq_len,),
                                              initializer='uniform',
                                              trainable=True)

        self.bias_linear = self.add_weight(name='bias_linear',
                                           shape=(self.seq_len,),
                                           initializer='uniform',
                                           trainable=True)

        self.weights_periodic = self.add_weight(name='weight_periodic',
                                                shape=(self.seq_len,),
                                                initializer='uniform',
                                                trainable=True)

        self.bias_periodic = self.add_weight(name='bias_periodic',
                                             shape=(self.seq_len,),
                                             initializer='uniform',
                                             trainable=True)

        super(Time2Vector, self).build(input_shape)

    def call(self, x):
        """Return the concatenated linear and periodic time features.

        Args:
            x: Tensor of shape ``(batch, seq_len, features)``.

        Returns:
            Tensor of shape ``(batch, seq_len, 2)``.
        """
        # Collapse the feature axis: (batch, seq_len, features) -> (batch, seq_len).
        # The (seq_len,) weights then broadcast across the batch dimension.
        x = tf.math.reduce_mean(x, axis=-1)

        time_linear = self.weights_linear * x + self.bias_linear
        time_linear = tf.expand_dims(time_linear, axis=-1)  # (batch, seq_len, 1)

        time_periodic = tf.math.sin(self.weights_periodic * x + self.bias_periodic)
        time_periodic = tf.expand_dims(time_periodic, axis=-1)  # (batch, seq_len, 1)

        return tf.concat([time_linear, time_periodic], axis=-1)  # (batch, seq_len, 2)

    def get_config(self):
        """Return the layer config, including ``seq_len``, so the layer can be
        re-created when a saved model is loaded."""
        config = super(Time2Vector, self).get_config()
        config.update({'seq_len': self.seq_len})
        return config