In [None]:
#!pip install matplotlib
#!pip install jupyter-resource-usage

In [13]:
import torch
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
import jax
jax.config.update('jax_array', True)
import jax.numpy as jnp
import jax.scipy
import numpy as np
# Requires 'pip install jaxwt'; jaxwt is JAX-accelerated subset of PyWavelet package (pywt)
import jaxwt as jwt
import pywt
#
from scipy import signal as sgnl
from scipy import ndimage
from scipy.interpolate import interp1d
#
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize, LogNorm, NoNorm
from matplotlib.cm import get_cmap
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

In [2]:
# From the original FileMaker.py
def open_txt(filename):
    """
    fs: Sample Rate
    dt: sample step 1/fs
    length: number of samples at each point
    comp: the measurment component represented in the values array
    data_shape: shape of overall array, can be ignored just used to 
                reshape array on opening from binary
    x: X location of measurement
    y: Y location of measurement
    z: Z location of measurement
    values: All values from file, row index corresponds to location
    """
    with open(filename, 'rb') as f:
        info = f.read(128).decode('ascii').split('\n')
        data = jnp.frombuffer(f.read(), jnp.float16)
    data_dict = {'fs': eval(info[0]), 'dt': eval(info[1]),
                 'length': eval(info[2]), 'comp': info[3],
                 'data_shape': eval(info[4])}

    data = data[448:].reshape(data_dict['data_shape'])
    data_dict['x'] = data[:, 0]
    data_dict['y'] = data[:, 1]
    data_dict['z'] = data[:, 2]
    data_dict['values'] = data[:, 3:]
    return data_dict

In [19]:
xcomp_rough_480lb_10v_file = "/scratch/group/statconsult/dataset/SmallMonolithWN/RoughCut/480 lbin Cut_txt/Hanging_480lbin_ampVolts10_x.txt"

xcomp_rough_480lb_10v = open_txt(xcomp_rough_480lb_10v_file)

print(xcomp_rough_480lb_10v.keys())
print(xcomp_rough_480lb_10v['values'].shape)
print(f"Time step is {xcomp_rough_480lb_10v['dt']} seconds\n")
print(f"Data of shape {xcomp_rough_480lb_10v['data_shape']} contains {xcomp_rough_480lb_10v['length']} time points and {xcomp_rough_480lb_10v['x'].shape} spatial locations\n")

dict_keys(['fs', 'dt', 'length', 'comp', 'data_shape', 'x', 'y', 'z', 'values'])
(105, 9999997)
Time step is 8e-07 seconds

Data of shape (105, 10000000) contains 10000000 time points and (105,) spatial locations



In [None]:
# 1D discrete FFT over time axis
fft_example = jnp.fft.fftn(xcomp_rough_480lb_10v['values'], axes = [1])

print(f"Resulting FFT has shape {fft_example.shape}")

For a given experimental setting (fixed voltage, cut state, torque, etc); we have data of a shape that looks like:

(number of signals, number of components, length of signals)

In our data, the number of signals is the number of spatial locations (about 105), the length is the number of time steps (about 10 million), and the number of components is 3 (for x, y, and z without doing a transform)

Here, we examine doing a discrete wavelet transform for each signal and component. We summarize the resulting coefficients with a variety of different statistics in order to effectively 'compress' the information down.

In [4]:
# Wavelet analysis #
@jax.jit
def compressed_features(coeff_array):
    coeff_array = coeff_array.squeeze()
    mean = jnp.nanmean(coeff_array)
    median = jnp.nanpercentile(coeff_array, 50)
    std = jnp.nanstd(coeff_array)
    perc05 = jnp.nanpercentile(coeff_array, 5)
    perc25 = jnp.nanpercentile(coeff_array, 25)
    perc75 = jnp.nanpercentile(coeff_array, 75)
    perc95 = jnp.nanpercentile(coeff_array, 95)
    perc_crosses_zero = (jnp.count_nonzero(jnp.diff(coeff_array > 0)))/coeff_array.shape[0]
    perc_crosses_mean = (jnp.count_nonzero(jnp.diff(coeff_array > mean)))/coeff_array.shape[0]
    return jnp.array([mean, median, std, perc05, perc25, perc75, perc95, perc_crosses_zero, perc_crosses_mean])

@jax.jit
def extract_wavelet_coeffs(signal, waveletname = 'haar'):
    coeffs = jwt.wavedec(signal, waveletname, level = 5)
    features = jnp.concatenate([compressed_features(coeff) for coeff in coeffs])
    return features
    

In [5]:
test = extract_wavelet_coeffs(xcomp_rough_480lb_10v['values'][1,:])
print(test)

[ 0.00000000e+00  4.38094139e-05  4.15039062e-03 -4.85992432e-03
 -1.96838379e-03  2.06184387e-03  4.97055054e-03  5.80160022e-01
  5.80160022e-01  0.00000000e+00  2.68220901e-06  3.74984741e-03
 -5.61523438e-03 -2.29263306e-03  2.30598450e-03  5.63049316e-03
  5.36745608e-01  5.36745608e-01  0.00000000e+00 -6.73532486e-06
  9.72747803e-03 -1.58538818e-02 -6.48117065e-03  6.48880005e-03
  1.58538818e-02  2.77363211e-01  2.77363211e-01  0.00000000e+00
 -1.21593475e-05  1.06277466e-02 -1.73797607e-02 -7.11441040e-03
  7.13348389e-03  1.74102783e-02  5.25674403e-01  5.25674403e-01
  0.00000000e+00  2.68220901e-06  7.69424438e-03 -1.25885010e-02
 -5.14221191e-03  5.15365601e-03  1.25961304e-02  6.06627226e-01
  6.06627226e-01 -0.00000000e+00  0.00000000e+00  4.20761108e-03
 -6.79779053e-03 -2.77137756e-03  2.76947021e-03  6.80160522e-03
  6.30391955e-01  6.30391955e-01]


In [11]:
print(jax.devices())
print(jax.local_device_count())

[CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)]
8


In [9]:
full_wavelet_coeffs = jax.vmap(extract_wavelet_coeffs)(xcomp_rough_480lb_10v['values'])

The above returns an array of size (number of locations, number of extracted features); combine with the other 2 components of motion to get 
(nlocs, n_extracted*3)

Then, we have a standard tabular data set! Stick whatever other predictors / responses you want and fit any out of the box model.

In [10]:
print(full_wavelet_coeffs.shape)

(105, 54)
