In [None]:
from fourier_flow import FourierFlow
from tg_rnn_network import TGRNNNetwork
from ff_training import train_fourier_flow
from tg_training import train_time_gan
import torch
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import random

In [None]:
# Seed every random number generator the notebook draws from (Python's
# `random`, NumPy's legacy global RNG and PyTorch) with one shared value so
# that a fresh "Restart & Run All" reproduces the same data.
SEED = 12345
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Data creation

Create three datasets:
1. sine waves with only a phase shift
2. sine waves with only a frequency shift
3. sine waves with both a phase and a frequency shift (note: the cell below currently generates only this third dataset)

In [None]:
# Synthetic dataset: one sine wave per row, each with its own randomly drawn
# angular frequency (Beta(2, 2), so within (0, 1)) and phase (standard normal).
n_samples = 1000
T = 101

# Column vectors of per-series parameters. The draw order (beta first, then
# normal) is kept so the global NumPy RNG stream is consumed as before.
freqs = np.random.beta(a=2, b=2, size=n_samples).reshape((-1, 1))
phases = np.random.normal(size=n_samples).reshape((-1, 1))

# A single (1, T) row of time steps 0..T-1; broadcasting against the (n, 1)
# parameter columns yields the full (n_samples, T) grid without tiling it.
time_grid = np.arange(T, dtype=np.float32).reshape((1, -1))
signals = np.sin(time_grid * freqs + phases)

# float32 tensor handed to the training routines.
X_signal = torch.from_numpy(signals).to(torch.float32)

### Fourier Flow Training

In [None]:
# Hyper-parameters for the Fourier Flow run.
n_epochs = 10
learning_rate = 1e-3

# Keep the Fourier Flow loss history under its own name: the TimeGAN training
# cell further down rebinds `losses`, which would otherwise discard this curve
# before it can be inspected.
model, ff_losses = train_fourier_flow(X_signal, n_epochs, learning_rate)

# Alias kept so that code written against the original name still works.
losses = ff_losses

### TimeGAN Training

In [None]:
# Hyper-parameters for the TimeGAN run.
n_epochs = 10
learning_rate = 1e-3

# train_time_gan returns the group of five networks together with the loss
# history; unpack in two steps so each line stays short.
tg_networks, losses = train_time_gan(X_signal, n_epochs, learning_rate)
embedder, recoverer, generator, supervisor, discriminator = tg_networks

In [None]:
# `losses` was last assigned by the TimeGAN training cell, so this is the
# TimeGAN loss curve. Use an explicit Axes and label it so the figure can be
# read on its own.
# NOTE(review): the x axis is simply the position in `losses`; whether one
# entry corresponds to an epoch or to a batch depends on train_time_gan — confirm.
fig, ax = plt.subplots()
ax.plot(losses)
ax.set(title='TimeGAN training loss', xlabel='index in loss history', ylabel='loss');

In [None]:
def get_freqencies(X : torch.Tensor) -> torch.Tensor:
    """Compute the dominant frequency of every series in the data.

    For each row the one-sided FFT magnitude spectrum is computed and the
    frequency of its largest bin is returned. The frequency axis is derived
    from the actual series length T, so the result is a real frequency
    rather than a shifted bin index.

    Args:
        X (torch.Tensor): DxT signal data (D series with T time steps each).
            Anything `torch.as_tensor` can convert is accepted as well.

    Returns:
        torch.Tensor: D dominant frequencies in cycles per time step, each in
            [0, 0.5]. Multiply by 2*pi to obtain radians per time step.

    Raises:
        ValueError: if X is not two-dimensional.
    """
    X = torch.as_tensor(X)
    if X.dim() != 2:
        raise ValueError(f"expected DxT signal data, got shape {tuple(X.shape)}")

    n_steps = X.shape[1]

    # The spectrum of a real signal is mirrored around zero, so a two-sided
    # argmax picks arbitrarily between +f and -f. The one-sided transform keeps
    # only the non-negative frequencies and makes the peak unambiguous.
    magnitudes = torch.abs(torch.fft.rfft(X, dim=1))
    peak_bins = torch.argmax(magnitudes, dim=1)

    # Frequency (cycles per step) of each one-sided bin: k / T.
    bin_freqs = torch.fft.rfftfreq(n_steps, device=X.device)
    return bin_freqs[peak_bins]

In [None]:
# Draw synthetic series from `model`, the Fourier Flow returned by
# train_fourier_flow above.
n_sample = 1_000
gen_series = model.sample(n_sample)

In [None]:
# Dominant frequency per series, first for the generated data and then for
# the real data (same evaluation order as two separate calls).
gen_freq, real_freq = (
    get_freqencies(series) for series in (gen_series, X_signal)
)

In [None]:
# Overlay the kernel density estimates of the dominant frequencies found in
# the generated and in the real series.
fig, ax = plt.subplots()

# Hand seaborn plain NumPy arrays instead of torch tensors, so the plot does
# not rely on implicit tensor -> array conversion (which fails for tensors on
# an accelerator or tensors that track gradients).
gen_freq_np = torch.as_tensor(gen_freq).detach().cpu().numpy()
real_freq_np = torch.as_tensor(real_freq).detach().cpu().numpy()

sns.kdeplot(gen_freq_np, linestyle='-', label='generated data', ax=ax)
sns.kdeplot(real_freq_np, linestyle='--', label='real data', ax=ax)
ax.set(
    title='Dominant frequency: generated vs. real data',
    xlabel='dominant frequency',
    ylabel='density',
)
ax.legend();
