In [31]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import tqdm
from tqdm import tqdm
import matplotlib.colors as mcolors
from copy import deepcopy

from topofisher.input_simulators.noisy_ring import CircleSimulator

from topofisher.filtrations.numpy.alphaDTML import AlphaDTMLayer

from topofisher.vectorizations.numpy.custom_vectorizations import TOPK

# from topofisher.fisher.plot_fisher_stats import plotContours, plotSummaryDerivativeHists, plotConvergence
# from topofisher.fisher.imnn import IMNNLayer, MopedLayer, ExtraDimLayer

from topofisher.pipelines.circle import CirclePipeline

In [28]:
import importlib

import topofisher

# Re-execute edited topofisher source without restarting the kernel.
# The base `pipeline` module is reloaded before `circle` (which presumably
# builds on it), then `CirclePipeline` is re-imported so the notebook name
# refers to the freshly reloaded class. The attribute accesses below assume
# both submodules were already imported earlier in the session.
# NOTE(review): `%load_ext autoreload` + `%autoreload 2` in the imports cell
# would make this manual reload cell unnecessary.
for _module in (topofisher.pipelines.pipeline, topofisher.pipelines.circle):
    importlib.reload(_module)
del _module

from topofisher.pipelines.circle import CirclePipeline

In [44]:
# Build and run the noisy-circle pipeline: AlphaDTM filtration, then the
# `vecLayer` vectorization; no Fisher layer is attached (fisherLayer=None).
# NOTE(review): `vecLayer` is defined in a cell further DOWN this notebook.
# It worked interactively only because that cell ran first (In [43] before
# In [44]); a fresh Restart & Run All in file order raises NameError here.
# TODO: move the VectorizationLayers / vecLayer cells above this one.
circle_pipeline = CirclePipeline(
    ncirc=200,
    nback=20,
    bgmAvg=1.0,
    n_s=100,
    n_d=100,
    theta_fid=tf.constant([1.0, 0.2]),
    delta_theta=tf.constant([0.1, 0.02]),
    filtLayer=AlphaDTMLayer(m=0.9, show_tqdm=True),
    vecLayer=vecLayer,
    fisherLayer=None,
    # presumably one flag per entry of theta_fid — TODO confirm against CirclePipeline
    find_derivative=[True, True],
)
circle_pipeline.run_pipeline()

100%|██████████| 100/100 [00:00<00:00, 180.27it/s]
100%|██████████| 100/100 [00:00<00:00, 182.02it/s]
100%|██████████| 100/100 [00:00<00:00, 179.97it/s]
100%|██████████| 100/100 [00:00<00:00, 181.32it/s]
100%|██████████| 100/100 [00:00<00:00, 181.88it/s]


In [42]:
def getFinitePairs(diag):
    """Drop persistence pairs whose death value is not strictly below +inf.

    Args:
        diag: Array-like of (birth, death) rows; assumed to be shaped (n, 2),
            with the death value in column 1.

    Returns:
        np.ndarray holding only the rows with ``death < inf``. A 1-D empty
        input (e.g. ``np.array([])``) is returned as an empty ``(0, 2)`` array
        instead of raising IndexError.
    """
    diag = np.asarray(diag)
    if diag.ndim == 1 and diag.size == 0:
        # A dimension with no intervals may come back as a 1-D empty array,
        # on which ``diag[:, 1]`` would raise IndexError.
        return diag.reshape(0, 2)
    # ``< np.inf`` drops +inf (essential classes) and NaN deaths, keeps -inf.
    return diag[diag[:, 1] < np.inf]

class VectorizationLayers:
    """Vectorize persistence diagrams of simplex trees, one vectorization per homology dimension.

    ``vectorizations[i]`` is paired with ``hom_dims[i]``. Each vectorization is
    fitted on the diagrams of the first batch passed to
    :meth:`vectorize_simplex_trees` and only transformed on later calls.

    Args:
        vectorizations: Sequence of objects exposing ``fit(pds)`` and
            ``transform(pds)`` (e.g. ``TOPK``).
        hom_dims: Sequence of homology dimensions, one per vectorization.
        name: Label stored on the instance; not used by the methods here.

    Raises:
        ValueError: If ``vectorizations`` and ``hom_dims`` differ in length.
    """

    def __init__(self, vectorizations, hom_dims, name="vecLayer"):
        if len(vectorizations) != len(hom_dims):
            raise ValueError("Make sure that the vectorizations and the hom_dims are compatible.")
        self.vectorizations = vectorizations
        self.hom_dims = hom_dims
        self.num_hom_dim = len(hom_dims)
        self.name = name
        # One flag per vectorization; fitting happens lazily on the first batch.
        self.is_fitted = [False] * self.num_hom_dim

    def get_persistence_diagrams(self, sts, hom_dim):
        """Return the finite persistence pairs in ``hom_dim`` for each simplex tree.

        Args:
            sts: Iterable of simplex trees (presumably gudhi ``SimplexTree``
                objects with persistence already computed — only
                ``persistence_intervals_in_dimension`` is called on them).
            hom_dim: Homology dimension to extract.

        Returns:
            List with one array per tree, pairs with infinite death removed.
        """
        return [
            getFinitePairs(st.persistence_intervals_in_dimension(hom_dim))
            for st in sts
        ]

    def vectorize_simplex_trees(self, sts):
        """Vectorize every simplex tree across all configured homology dimensions.

        On the first call each vectorization is fitted on that batch's
        diagrams and its ``is_fitted`` flag is set; later calls only transform.

        Args:
            sts: Iterable of simplex trees. It is materialized once, so one-shot
                iterators/generators are safe to pass.

        Returns:
            ``post_process`` applied to the per-dimension ``transform`` outputs
            concatenated along the last axis, in ``hom_dims`` order.
        """
        # sts is traversed once per homology dimension; without this, a
        # generator would be exhausted after the first dimension.
        sts = list(sts)
        vecs = []
        for idx, (hom_dim, veclayer) in enumerate(zip(self.hom_dims, self.vectorizations)):
            pds = self.get_persistence_diagrams(sts, hom_dim)
            if not self.is_fitted[idx]:
                veclayer.fit(pds)
                self.is_fitted[idx] = True
            vecs.append(veclayer.transform(pds))
        return self.post_process(np.concatenate(vecs, axis=-1))

    def post_process(self, vecs):
        """Return ``vecs`` unchanged; presumably an override point for subclasses."""
        return vecs

In [43]:
# One TOPK vectorization per homology dimension (0 and 1). They are kept as
# separate instances because VectorizationLayers fits each one on its own
# dimension's diagrams.
# NOTE(review): the pipeline cell earlier in this notebook consumes `vecLayer`;
# in file order this cell runs too late on a fresh Restart & Run All.
vecLayer = VectorizationLayers(
    vectorizations=[TOPK(bdp_type="bd", is_binned=False) for _ in range(2)],
    hom_dims=[0, 1],
)

