In [None]:
from topologylayer.nn import AlphaLayer, BarcodePolyFeature
from gtda.plotting import plot_point_cloud
import matplotlib.pyplot as plt

import math

import numpy as np
from scipy.signal import convolve

import torch
from torch.autograd import Variable
from torch import nn
from torch.autograd import Variable
from torch.nn import functional as F


class Transpose(nn.Module):
    """Swap the first two axes of the input tensor.

    Handy inside ``nn.Sequential`` to turn a (channels, length) tensor into a
    (length, channels) point cloud.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Same operation as torch.transpose(x, 0, 1), written in method form.
        return x.transpose(0, 1)

In [None]:
n_periods = 50
samples_per_period = 30
N = n_periods*samples_per_period
# endpoint=False makes the sample spacing exactly 2*pi/samples_per_period, so every
# period spans exactly `samples_per_period` samples (the conv window size used below).
t = np.linspace(0, 2*np.pi*n_periods, N, endpoint=False)
# Sinusoid with a slow linear drift, min-max rescaled to [0, 1].
x = np.sin(t) + 0.01*t
x = x - np.min(x)
x = x / np.max(x)
fig, ax = plt.subplots(figsize=(12, 4))
ax.plot(x)
ax.set(title="Input series: sin(t) + 0.01 t, min-max scaled", xlabel="sample", ylabel="x");

In [None]:
from persim import plot_diagrams

# Toy check: embed 100*sin(t) with an untrained 2-channel sliding-window convolution
# and look at the H1 diagram of the resulting planar point cloud.
model = nn.Sequential(
    nn.BatchNorm1d(1),
    nn.Conv1d(1, 2, samples_per_period, stride=1, bias=False),
    nn.Flatten(start_dim=0, end_dim=1),  # (1, 2, L) -> (2, L)
    Transpose(),                         # (2, L) -> (L, 2): one 2-D point per window
    AlphaLayer(maxdim=1),
)
data = torch.from_numpy((100*np.sin(t)).reshape(1, 1, -1)).float()
# A single forward pass; no gradients are needed just to inspect the diagram.
with torch.no_grad():
    res = model(data)
# NOTE(review): assumes res[0] is the list of diagrams indexed by homology dimension,
# so res[0][1] is the H1 diagram — confirm against topologylayer's AlphaLayer.
dgm1 = res[0][1].detach().numpy()
# persim's default labels start at $H_0$, so label this lone H1 diagram explicitly.
plot_diagrams(dgm1, labels=["$H_1$"], title="H1 of sliding-window embedding of 100 sin(t)")

In [None]:
class SlidingAutoencoder(nn.Module):
    """Learn a warping of a time series whose sliding-window embedding has strong H1.

    The series is passed through a learnable 1-D convolution + ReLU (the "warp").
    A second, bias-free convolution with ``dim`` output channels and kernel size
    ``win`` maps every length-``win`` window of the warped series to a point in
    R^dim; that point cloud goes through BatchNorm and an alpha-complex
    persistence layer. Training minimizes the MSE between the warped and the
    original series minus ``lam`` times a polynomial feature of the H1 barcode.
    """

    def __init__(self, x, dim, win, lam=1, lr=1e-2):
        """
        Parameters
        ----------
        x: ndarray(N)
            Time series
        dim: int
            Dimension of embedding
        win: int
            Window size
        lam: float
            Weight of topological regularization
        lr: float
            Learning rate
        """
        super().__init__()
        self.x_orig = x
        # Shape (batch=1, channels=1, N), the layout nn.Conv1d expects.
        self.x = torch.from_numpy(x.reshape(1, 1, -1)).float()
        self.win = win
        self.dim = dim
        self.lam = lam
        ## TODO: This could be an RNN, and there could also be more layers
        # padding='same' keeps the warped series the same length N as the input.
        self.linear1 = nn.Conv1d(1, 1, win+1, stride=1, padding='same')
        self.linear1_relu = nn.ReLU()
        # Currently unused (its call is disabled in _warp); kept so the module's
        # parameters and state_dict keys stay unchanged.
        self.norm1 = nn.BatchNorm1d(1)
        # Learnable sliding-window embedding: (1, 1, N) -> (1, dim, N-win+1).
        self.conv1 = nn.Conv1d(1, dim, win, stride=1, bias=False)
        self.flatten = nn.Flatten(start_dim=0, end_dim=1)
        self.transpose = Transpose()
        ## TODO: This needs to do z normalization of the point cloud
        # NOTE(review): in train mode BatchNorm1d already standardizes each of the
        # `dim` coordinates across the points (up to a learnable affine) — confirm
        # whether that satisfies the TODO above.
        self.norm2 = nn.BatchNorm1d(dim)
        self.alpha = AlphaLayer(maxdim=1)
        # Created after every layer so self.parameters() includes all of them.
        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)

        # Loss functions and losses
        self.mse_loss = nn.MSELoss()
        # Barcode feature on the H1 diagram with arguments (1, 2, 0); presumably
        # (dim, p, q), i.e. the sum of squared bar lengths — confirm in topologylayer.
        self.maxh1_loss = BarcodePolyFeature(1, 2, 0)
        self.losses = []

    def _warp(self):
        """Return the warped series, shape (1, 1, N): ReLU(linear1(x))."""
        y = self.linear1(self.x)
        y = self.linear1_relu(y)
        #y = self.norm1(y)
        return y

    def _warped_numpy(self):
        """Return the current warped series as a flat ndarray.

        Only the warp is evaluated, under no_grad, so no alpha complex and no
        autograd graph is built just to read the values.
        """
        with torch.no_grad():
            return self._warp().numpy().flatten()

    def forward(self):
        """
        Returns
        -------
        y: The warped time series
        dgms: The persistence diagrams of the sliding window embedding of the
              warped time series
        """
        y = self._warp()
        sw = self.conv1(y)           # (1, dim, N-win+1)
        sw = self.flatten(sw)        # (dim, N-win+1)
        sw = self.transpose(sw)      # (N-win+1, dim): one point per window
        sw = self.norm2(sw)
        dgms = self.alpha(sw)
        return y, dgms

    def train_step(self):
        """Take one optimizer step and append the scalar loss to ``self.losses``."""
        self.optimizer.zero_grad()
        self.train()
        y, dgms = self.forward()
        # Reconstruction keeps the warp close to the input; subtracting the H1
        # feature means minimizing the loss rewards persistent loops.
        loss = self.mse_loss(y, self.x) - self.lam*self.maxh1_loss(dgms)
        loss.backward()
        self.optimizer.step()
        self.losses.append(loss.item())

    def _plot_progress(self, epoch):
        """Plot the original series against the current warp, titled with the last loss."""
        y = self._warped_numpy()
        fig, ax = plt.subplots(figsize=(12, 4))
        ax.plot(self.x_orig, label="original")
        ax.plot(y, label="warped")
        ax.set_title("{}: {:.3f}".format(epoch, self.losses[-1]))
        ax.legend()
        plt.show()
        plt.close(fig)

    def train_epochs(self, num_epochs, plot_every=50):
        """Train for ``num_epochs`` steps, resetting ``self.losses`` first.

        Parameters
        ----------
        num_epochs: int
            Number of optimizer steps
        plot_every: int
            Plot progress every this many epochs (starting at epoch 0);
            0 or None disables plotting.

        Returns
        -------
        ndarray(N): the warped time series after training
        """
        self.losses = []
        for epoch in range(num_epochs):
            self.train_step()
            if plot_every and epoch % plot_every == 0:
                self._plot_progress(epoch)
        return self._warped_numpy()

# Seed torch so the Conv1d weight initialization, and hence the learned warp, is reproducible.
SEED = 0
N_EPOCHS = 500
torch.manual_seed(SEED)
autoencoder = SlidingAutoencoder(x, lam=1, dim=2, win=samples_per_period)
y = autoencoder.train_epochs(N_EPOCHS)