In [None]:
%matplotlib ipympl

import os

import h5py
import matplotlib.pyplot as plt
import mpl_toolkits.mplot3d.axes3d  # noqa: F401 -- registers the "3d" projection on older Matplotlib
import numpy as np
import obspy
import scipy.signal
import scipy.spatial

# Integer dtype of the "/pivot" dataset created by FastMapTSLibrary.
DTYPE_INT  = np.int32
# Floating-point dtype of the "/image" dataset created by FastMapTSLibrary.
DTYPE_REAL = np.float64

In [None]:
# Synthetic test set: three classes of length-`npts` time series
# ("R" = noisy Ricker wavelet, "B" = noisy boxcar, "N" = pure noise).
npts = 150
nricker, nboxcar, nnoise = 16, 16, 16

# Seeded generator so that "Restart Kernel -> Run All" reproduces the
# same synthetic data set.
SEED = 0
rng = np.random.default_rng(SEED)


def ricker_wavelet(points, a):
    """
    Return a Ricker ("Mexican hat") wavelet.

    Local replacement for `scipy.signal.ricker`, which SciPy deprecated
    and no longer ships in current releases.

    Arguments
    points - Number of samples in the returned wavelet.
    a - Width parameter of the wavelet.

    Returns
    np.ndarray of shape (points,), symmetric about its centre.
    """
    amplitude = 2 / (np.sqrt(3 * a) * np.pi**0.25)
    offset = np.arange(points) - (points - 1.0) / 2
    scaled_sq = (offset / a)**2
    return (amplitude * (1 - scaled_sq) * np.exp(-scaled_sq / 2))


data, label, sampling_rate = [], [], []

for _ in range(nricker):
    a = 5 * rng.random() + 5 # Width of Ricker wavelet
    sigma = rng.random() * 16
    noise = rng.standard_normal(npts) / (sigma + 16)
    data.append(ricker_wavelet(npts, a) + noise)
    label.append("R")
    sampling_rate.append(1)

for _ in range(nboxcar):
    # Two *distinct* sample indices, so the boxcar is never empty (equal
    # endpoints would label a series "B" that contains no boxcar).
    istart, iend = np.sort(rng.choice(npts, size=2, replace=False))
    _data = np.ones(npts)
    _data[istart: iend] = -1
    sigma = rng.random() * 8
    noise = rng.standard_normal(npts) / (sigma + 8)
    data.append(_data + noise)
    label.append("B")
    sampling_rate.append(1)

for _ in range(nnoise):
    sigma = rng.random()
    data.append(sigma * rng.standard_normal(npts))
    label.append("N")
    sampling_rate.append(1)

In [None]:
class FastMapTSLibrary(object):
    """
    HDF5-backed library of time series together with their FastMap
    embedding into k-dimensional Euclidean space.

    Layout of the HDF5 file
    /waveforms/<9-digit id> - One dataset per appended time series.
    /image - (library_size, kdim) embedded coordinates.
    /pivot - (2, kdim) ids of the two pivot objects chosen for each
             embedding dimension.

    The instance can be used as a context manager; the HDF5 file is
    closed when the `with` block exits.
    """


    def __init__(self, path, kdim, sampling_rate=100, mode="w", overwrite=False):
        """
        Arguments
        path - Path of the HDF5 file backing the library.
        kdim - Number of dimensions of the embedding.
        sampling_rate - Stored on the instance and exposed through the
                        `sampling_rate` property.
        mode - File mode handed to h5py.File.
        overwrite - Allow mode "w" to replace an existing file.
        """
        self._kdim = kdim
        self._sampling_rate = sampling_rate
        self._kdtree = None
        self._init_hdf5(path, mode, overwrite=overwrite)
        # Pick up waveforms that are already present when an existing
        # file is reopened; a freshly created file starts empty.
        if "waveforms" in self._hdf5:
            self._library_size = len(self._hdf5["waveforms"])
        else:
            self._library_size = 0


    @property
    def image(self):
        """
        HDF5 dataset of shape (library_size, kdim) holding the embedded
        coordinates; rows that have not been embedded yet are NaN.
        """
        if "image" in self.hdf5:
            existing = self.hdf5["image"]
            # require_dataset insists on an exact shape match, so follow
            # the library when it changed size since the dataset was made.
            if existing.shape[0] != self.library_size:
                existing.resize(self.library_size, axis=0)
        return (
            self.hdf5.require_dataset(
                "/image",
                shape=(self.library_size, self.kdim),
                maxshape=(None, self.kdim),
                dtype=DTYPE_REAL,
                fillvalue=np.nan
            )
        )

    @property
    def kdim(self):
        """
        [Read only] The number of dimensions of the embedding.
        """
        return (self._kdim)

    @property
    def kdtree(self):
        """
        k-d tree over the embedded coordinates. Built on first access and
        cached until the library changes (see `append` and `embed`).
        """
        if self._kdtree is None:
            self._kdtree = scipy.spatial.cKDTree(self.image[...])
        return (self._kdtree)

    @property
    def hdf5(self):
        """
        [Read only] The h5py.File backing the library.
        """
        return (self._hdf5)

    @property
    def library_size(self):
        """
        [Read only] The number of time series in the library.
        """
        return (self._library_size)

    @property
    def pivot(self):
        """
        HDF5 dataset of shape (2, kdim); column i holds the integer ids
        of the two pivot objects used for embedding dimension i.
        """
        return (
            self.hdf5.require_dataset(
                "/pivot",
                shape=(2, self.kdim),
                dtype=DTYPE_INT
            )
        )

    @property
    def sampling_rate(self):
        """
        [Read only] The sampling rate passed to the constructor.
        """
        return (self._sampling_rate)

    @property
    def waveforms(self):
        """
        HDF5 group holding one dataset per appended time series.
        """
        return (self.hdf5.require_group("/waveforms"))


    def __del__(self):
        self.close()


    def __enter__(self):
        return (self)


    def __exit__(self, exc_type, exc_value, exc_traceback):
        # Close the file but do not suppress an in-flight exception.
        self.close()


    def _embed(self, k, dist, icol):
        """
        Recursive function to embed the library data into k-dimensional
        Euclidean space.

        Arguments
        k - Number of dimensions still to be computed.
        dist - Callable dist(a, b) giving the (residual) distance between
               two waveform datasets.
        icol - Index of the most recently computed image column (-1
               before the first one).

        Raises
        ValueError if the library is empty.
        """

        if k <= 0:
            return (True)

        icol += 1

        keys = list(self.waveforms.keys())
        if not keys:
            raise (ValueError("Cannot embed an empty library."))

        # Choose the pivot objects: start from a random object, jump to
        # the object furthest from it, then to the one furthest from that.
        b = self.waveforms[np.random.choice(keys)]
        a_name = self.furthest(b, dist)
        a = self.waveforms[a_name]
        b_name = self.furthest(a, dist)
        b = self.waveforms[b_name]

        # Record the names of the pivot objects.
        self.pivot[0, icol] = int(a_name)
        self.pivot[1, icol] = int(b_name)

        d_ab = dist(a, b)
        if d_ab == 0:
            # The pivots coincide, so no further axis can be resolved.
            # Zero this column and all remaining ones so that no NaN fill
            # values are left in the image.
            self.image[:, icol:] = 0
            return (True)

        # Project all objects onto line between objects a and b. The
        # column is assembled in memory and written to HDF5 in one go.
        column = np.full(self.library_size, np.nan, dtype=DTYPE_REAL)

        def update_image(name, i):
            irow = int(name.split("/")[-1])
            d_ai = dist(a, i)
            d_bi = dist(b, i)
            column[irow] = (d_ai**2 + d_ab**2 - d_bi**2) / (2 * d_ab)

        self.waveforms.visititems(update_image)
        self.image[:, icol] = column

        # Project all objects onto the hyperplane perpendicular to the
        # line between objects a and b.
        def new_dist(a, b, coords=column, old_dist=dist):
            i = int(a.name.split("/")[-1])
            j = int(b.name.split("/")[-1])
            d_sq = old_dist(a, b)**2 - (coords[i] - coords[j])**2
            # The radicand can turn negative when `dist` violates the
            # triangle inequality; clamp it rather than emit NaN.
            return (np.sqrt(max(d_sq, 0.0)))

        return (self._embed(k-1, new_dist, icol=icol))


    def _init_hdf5(self, path, mode, overwrite=False):
        """
        Initialize the HDF5 backend.

        Raises
        IOError if path exists, mode is "w" and overwrite is False.
        """

        if os.path.exists(path) and mode == "w" and not overwrite:
            raise (IOError(f"{path} already exists."))
        self._hdf5 = h5py.File(path, mode=mode)


    def append(self, data, labels, sampling_rate):
        """
        Append data to the library inventory. This does not perform the
        actual embedding into k-dimensional Euclidean space.

        Arguments
        data - The time series; stored as a single dataset named after
               the current library size (zero-padded to 9 digits).
        labels - Stored in the "labels" attribute of the dataset.
        sampling_rate - Stored in the "sampling_rate" attribute of the
                        dataset.
        """

        name = f"{self.library_size:09d}"
        dataset = self.waveforms.create_dataset(name, data=data)
        dataset.attrs["labels"] = labels
        dataset.attrs["sampling_rate"] = sampling_rate
        # Count the entry only once it has actually been written.
        self._library_size += 1
        # Any cached k-d tree no longer covers the whole library.
        self._kdtree = None

        return (True)


    def close(self):
        """
        Close the HDF5 file. Safe to call when the file was never opened.
        """

        hdf5 = getattr(self, "_hdf5", None)
        if hdf5 is not None:
            hdf5.close()


    def embed(self):
        """
        Embed the library data into k-dimensional Euclidean space.
        """

        # The image is about to be rewritten; drop the stale k-d tree.
        self._kdtree = None
        return_value = self._embed(self.kdim, distance, icol=-1)
        return (return_value)


    def furthest(self, b, dist):
        """
        Return the name of the object furthest from b, or None if the
        library is empty. Ties go to the object visited first.
        """

        best_name, best_dist = None, -np.inf

        def _furthest(name, a):
            nonlocal best_name, best_dist
            d = dist(a, b)
            if d > best_dist:
                best_name, best_dist = name, d

        self.waveforms.visititems(_furthest)

        return (best_name)


    def query(self, q):
        """
        Embed the time series q using the stored pivot objects and look
        up its nearest neighbour among the embedded library entries.

        Returns
        The result of cKDTree.query for the embedded point.
        """

        image = np.zeros(self.kdim,)
        library_image = self.image[...]
        dist = distance

        def coordinate(obj, icol):
            # The query object has no row in the library image; its
            # coordinates are the ones computed so far in `image`.
            if obj is q:
                return (image[icol])
            return (library_image[int(obj.name.split("/")[-1]), icol])

        for icol in range(self.kdim):

            # Retrieve the pivot objects
            a = self.waveforms[f"{int(self.pivot[0, icol]):09d}"]
            b = self.waveforms[f"{int(self.pivot[1, icol]):09d}"]

            d_ab = dist(a, b)
            if d_ab == 0:
                # `_embed` stops at a degenerate pivot pair as well; the
                # remaining coordinates stay zero.
                break

            # Project query object onto line between objects a and b.
            d_aq = dist(a, q)
            d_bq = dist(b, q)
            image[icol] = (d_aq**2 + d_ab**2 - d_bq**2) / (2 * d_ab)

            # Update the dist metric by projecting onto the hyperplane
            # perpendicular to the line between objects a and b.
            def new_dist(x, y, icol=icol, old_dist=dist):
                d_sq = (
                    old_dist(x, y)**2
                    - (coordinate(x, icol) - coordinate(y, icol))**2
                )
                return (np.sqrt(max(d_sq, 0.0)))

            dist = new_dist

        return (self.kdtree.query(image))
        

def correlate(a, b):
    """
    Return the normalized cross-correlation of a and b.

    Both inputs are demeaned and scaled by their standard deviation; a is
    additionally scaled by len(a). The result is the "full"
    cross-correlation, of length len(a) + len(b) - 1.

    If either input has zero variance the normalization is undefined;
    zero correlation is returned instead of an array of NaN.
    """

    a = np.asarray(a)
    b = np.asarray(b)

    std_a, std_b = np.std(a), np.std(b)
    if std_a == 0 or std_b == 0:
        return (np.zeros(len(a) + len(b) - 1))

    a = (a - np.mean(a)) / (std_a * len(a))
    b = (b - np.mean(b)) / std_b
    corr = np.correlate(a, b, "full")

    return (corr)


def distance(a, b):
    """
    Return the distance between a and b, defined as one minus the peak
    of their normalized cross-correlation.
    """

    peak = np.max(correlate(a, b))
    return 1 - peak

In [None]:
# Release the HDF5 handle left over from a previous run of this cell:
# the file cannot be re-created with mode "w" while it is still open in
# this process.
try:
    fastmap.hdf5.close()
except NameError:
    pass

fastmap = FastMapTSLibrary("fastmap_test.h5", 8, overwrite=True)
for _series, _label, _rate in zip(data, label, sampling_rate):
    fastmap.append(_series, _label, _rate)

fastmap.embed();

In [None]:
plt.close("all")
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1, projection="3d")

# Read the embedding into memory once instead of once per scatter call.
embedding = fastmap.image[...]

# The loop variable is deliberately not called `label`: that name holds the
# list of labels built during data generation, and overwriting it would
# break a re-run of the library-building cell.
for class_label in ("R", "B", "N"):
    idxs = np.array(
        [
            int(key)
            for key, waveform in fastmap.waveforms.items()
            if waveform.attrs["labels"] == class_label
        ],
        dtype=int,
    )
    ax.scatter(
        embedding[idxs, 0],
        embedding[idxs, 1],
        embedding[idxs, 2],
        label=class_label
    )

ax.set_xlabel("FastMap dimension 0")
ax.set_ylabel("FastMap dimension 1")
ax.set_zlabel("FastMap dimension 2")
ax.legend(title="Class");