# Check the implementation against Gudhi's

In [None]:
import gudhi
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from ripslayer import RipsModule
import torch.nn as nn
import torch
# Make no GPU visible to TensorFlow, so its computations run on the CPU.
tf.config.set_visible_devices(devices=[], device_type='GPU')

# Small example
Comparison between the TensorFlow and PyTorch implementations

In [None]:
from gudhi.tensorflow.rips_layer import RipsLayer
from gudhi.wasserstein import wasserstein_distance

In [None]:
# Two toy point clouds are available. Previously both were built and the circle
# was silently overwritten; the switch makes the choice explicit. The default
# (False) keeps the 3-point cloud and dim = 0 that the cell ended up using.
USE_CIRCLE = False

if USE_CIRCLE:
    # 100 points sampled uniformly on the unit circle (one loop, so the
    # 1-dimensional diagram is the interesting one).
    np.random.seed(1)
    angles = np.random.uniform(0, 2 * np.pi, 100)
    X = np.column_stack([np.cos(angles), np.sin(angles)])
    dim = 1
else:
    # Three points in the plane; look at the 0-dimensional diagram.
    X = np.array([[0.1, 0.], [1.5, 1.5], [0., 1.6]])
    dim = 0

## TensorFlow

In [None]:
XTF = tf.Variable(X, dtype=tf.float32)
lr = 1
optimizer = tf.keras.optimizers.SGD(learning_rate=lr)

# The Rips layer and the (empty) target diagram do not change between epochs,
# so build them once instead of re-creating them inside every iteration.
layer = RipsLayer(homology_dimensions=[dim], maximum_edge_length=10)
empty_dgm = tf.constant(np.empty([0, 2]))

num_epochs = 1
# One snapshot per iteration, taken *before* that iteration's update:
# Xs[i] -> Xs[i+1] is the step driven by grads[i]. The loop runs num_epochs + 1
# times, so the last update is applied but its result is not recorded.
losses, Dgs, Xs, grads = [], [], [], []
for epoch in range(num_epochs + 1):
    with tf.GradientTape() as tape:
        # [0][0]: first requested homology dimension, finite part of its
        # diagram (per the Gudhi RipsLayer output convention).
        dgm = layer.call(X=XTF)[0][0]
        # Negated order-1 Wasserstein distance to the empty diagram: minimising
        # it pushes the diagram points away from the diagonal.
        loss = -wasserstein_distance(dgm, empty_dgm, order=1, enable_autodiff=True)
    Dgs.append(dgm.numpy())
    Xs.append(XTF.numpy())
    losses.append(loss.numpy())
    gradients = tape.gradient(loss, [XTF])
    grads.append(gradients[0].numpy())
    optimizer.apply_gradients(zip(gradients, [XTF]))

In [None]:
# Visualise the first SGD step of the TensorFlow run. Every point with a
# non-zero gradient is drawn before the update (red, 'Step i') and after it
# (green, 'Step i+1'). A blue arrow along -lr * gradient joins the two.
# Points with zero gradient are shown faintly for context.
x_before, x_after, g_before = Xs[0], Xs[1], grads[0]
pts_to_move = np.flatnonzero(np.linalg.norm(g_before, axis=1) != 0)

plt.figure()
for pt in pts_to_move:
    dx, dy = -lr * g_before[pt]
    plt.arrow(x_before[pt, 0], x_before[pt, 1], dx, dy, color='blue',
              length_includes_head=True, head_length=.05, head_width=.1, zorder=10)
plt.scatter(x_before[:, 0], x_before[:, 1], c='red', s=50, alpha=.2, zorder=3)
plt.scatter(x_before[pts_to_move, 0], x_before[pts_to_move, 1],
            c='red', s=150, marker='o', zorder=2, alpha=.7, label='Step i')
plt.scatter(x_after[pts_to_move, 0], x_after[pts_to_move, 1],
            c='green', s=150, marker='o', zorder=1, alpha=.7, label='Step i+1')
plt.axis('square')
plt.legend()
plt.show()

## PyTorch

In [None]:
X_torch = nn.Parameter(torch.tensor(X, dtype=torch.float32), requires_grad=True)
lr = 1
optimizer = torch.optim.SGD([X_torch], lr=lr)
optimizer.zero_grad()
num_epochs = 1
# One snapshot per iteration, taken *before* that iteration's update, matching
# the TensorFlow loop: Xs[i] -> Xs[i+1] is the step driven by grads[i].
losses, Dgs, Xs, grads = [], [], [], []
rips = RipsModule(homology_dimensions=[dim], maximum_edge_length=10)
# The target diagram never changes, so build it once.
empty_dgm = torch.empty([0, 2])
for epoch in range(num_epochs + 1):
    # [0][0]: first requested homology dimension, first part of its output
    # (assumed to mirror Gudhi's RipsLayer, i.e. the finite diagram -- confirm).
    diag = rips(X_torch)[0][0]
    loss = -wasserstein_distance(diag, empty_dgm, order=1, enable_autodiff=True)
    Dgs.append(diag.detach().numpy())
    # Snapshot the torch parameter, not the TF variable. .numpy() shares memory
    # with the tensor, and optimizer.step() updates X_torch in place, so clone
    # first; otherwise every entry of Xs would show the latest positions.
    Xs.append(X_torch.detach().clone().numpy())
    losses.append(loss.detach().numpy())
    loss.backward()
    # Record the gradient before zero_grad() clears it; the plot cell reads grads[0].
    grads.append(X_torch.grad.detach().clone().numpy())
    optimizer.step()
    optimizer.zero_grad()

In [None]:
# Same visualisation for the PyTorch run. Points whose gradient is non-zero are
# shown before (red) and after (green) the first SGD step, with blue arrows
# along the update -lr * gradient; all other points are drawn faintly.
before, after = Xs[0], Xs[1]
step = -lr * grads[0]
pts_to_move = np.nonzero(np.linalg.norm(grads[0], axis=1) != 0)[0]

plt.figure()
for pt in pts_to_move:
    plt.arrow(before[pt, 0], before[pt, 1], step[pt, 0], step[pt, 1], color='blue',
              length_includes_head=True, head_length=.05, head_width=.1, zorder=10)
plt.scatter(before[:, 0], before[:, 1], c='red', s=50, alpha=.2, zorder=3)
plt.scatter(before[pts_to_move, 0], before[pts_to_move, 1],
            c='red', s=150, marker='o', zorder=2, alpha=.7, label='Step i')
plt.scatter(after[pts_to_move, 0], after[pts_to_move, 1],
            c='green', s=150, marker='o', zorder=1, alpha=.7, label='Step i+1')
plt.axis('square')
plt.legend()
plt.show()