In [None]:
#!pip install matplotlib
#!pip install jupyter-resource-usage

In [1]:
#import torch
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=48"
import jax
jax.config.update('jax_array', True)
import jax.numpy as jnp
import jax.scipy
import numpy as np
# Requires 'pip install jaxwt'; jaxwt is JAX-accelerated subset of PyWavelet package (pywt)
import jaxwt as jwt
import pywt
#
from scipy import signal as sgnl
from scipy import ndimage
from scipy.interpolate import interp1d
#
from functools import partial
#
import time
#
import pandas as pd
from itertools import product
from pathlib import Path
from typing import Callable
#
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize, LogNorm, NoNorm
from matplotlib.cm import get_cmap
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

In [2]:
# Number of JAX devices visible to this process (the import cell sets
# XLA_FLAGS to request 48 host-platform devices).
device_count = jax.local_device_count()

# Cluster scratch locations: the raw dataset root that is walked for input
# files, and the directory that receives the extracted .npy feature files.
base_directory = '/scratch/group/statconsult/dataset/SmallMonolithWN/'
output_dir = '/scratch/group/statconsult/wavelet_extraction/output/'

2023-11-10 23:24:28.012912: W external/org_tensorflow/tensorflow/tsl/platform/default/dso_loader.cc:66] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /sw/eb/sw/PyTorch/1.12.1-foss-2022a-CUDA-11.7.0/lib/python3.10/site-packages/torch/lib:/sw/eb/sw/NCCL/2.12.12-GCCcore-11.3.0-CUDA-11.7.0/lib:/sw/eb/sw/magma/2.6.2-foss-2022a-CUDA-11.7.0/lib:/sw/eb/sw/GDRCopy/2.3-GCCcore-11.3.0/lib:/sw/eb/sw/cuDNN/8.4.1.50-CUDA-11.7.0/lib:/sw/eb/sw/LibTIFF/4.3.0-GCCcore-11.3.0/lib:/sw/eb/sw/libdeflate/1.10-GCCcore-11.3.0/lib:/sw/eb/sw/zstd/1.5.2-GCCcore-11.3.0/lib:/sw/eb/sw/lz4/1.9.3-GCCcore-11.3.0/lib:/sw/eb/sw/jbigkit/2.1-GCCcore-11.3.0/lib:/sw/eb/sw/libjpeg-turbo/2.1.3-GCCcore-11.3.0/lib:/sw/eb/sw/FFmpeg/4.4.2-GCCcore-11.3.0/lib:/sw/eb/sw/FriBidi/1.0.12-GCCcore-11.3.0/lib:/sw/eb/sw/X11/20220504-GCCcore-11.3.0/lib:/sw/eb/sw/fontconfig/2.14.0-GCCcore-11.3.0/lib:/sw/eb/sw/util-linux/2.38-GCCcore-11.3.0/lib:/

In [3]:
# Sanity check: report how many devices JAX will shard work across.
print(f"{device_count}")

48


In [4]:
# From the original FileMaker.py
def open_txt(filename):
    """
    Read one binary measurement file into a dict of metadata and arrays.

    File layout, as read below:
      * a 128-byte ASCII header of newline-separated fields
        (fs, dt, length, comp, data_shape);
      * float16 coordinates; the first 3*num_locs values are (x, y, z)
        triples, one per location;
      * float16 sample values, starting 2*448 bytes after the header.

    Returned keys:
    fs: Sample Rate
    dt: sample step 1/fs
    length: number of samples at each point
    comp: the measurement component represented in the values array
    data_shape: shape of overall array, can be ignored just used to
                reshape array on opening from binary
    x: X location of measurement
    y: Y location of measurement
    z: Z location of measurement
    values: All values from file, row index corresponds to location
    """
    header_size = 128  # bytes of ASCII header
    bytes_per_value = 2  # float16
    # NOTE(review): the samples start a fixed 448 float16 values after the
    # header regardless of num_locs; presumably FileMaker.py reserves a
    # fixed-size coordinate block -- confirm against the writer.
    data_offset = header_size + bytes_per_value * 448
    with open(filename, 'rb') as f:
        info = f.read(header_size).decode('ascii').split('\n')
        # NOTE(security): header fields are parsed with eval(), so only open
        # trusted files. ast.literal_eval would be safer but would reject
        # expression-valued fields (e.g. "1/fs") if the writer emits any.
        data_shape = eval(info[4])
        data_dict = {'fs': eval(info[0]),
                     'dt': eval(info[1]),
                     'length': eval(info[2]),
                     'comp': info[3],
                     'data_shape': data_shape}
        num_locs = data_shape[0]
        # Read only the coordinate bytes instead of the (possibly multi-GB)
        # remainder of the file.
        coord_bytes = f.read(bytes_per_value * 3 * num_locs)
        coords = jnp.frombuffer(coord_bytes, jnp.float16).reshape(num_locs, 3)
        f.seek(data_offset)
        data = jnp.frombuffer(f.read(), jnp.float16)
    data = data.reshape(data_dict['data_shape'])
    data_dict['x'] = coords[:, 0]
    data_dict['y'] = coords[:, 1]
    data_dict['z'] = coords[:, 2]
    data_dict['values'] = data
    return data_dict

For a given experimental setting (fixed voltage, cut state, torque, etc.), the data has shape:

(number of signals, number of components, length of signals)

In our data, the number of signals is the number of spatial locations (about 105), the length is the number of time steps (about 10 million), and the number of components is 3 (x, y, and z, untransformed). Each component is stored in its own file, identified by the `comp` header field.

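Here we apply a discrete wavelet transform (DWT) to each signal and component. We then summarize the coefficients at every decomposition level with a handful of statistics: median, standard deviation, the 5th/25th/75th/95th percentiles, and log2 energy. This 'compresses' each multi-million-sample signal into a short feature vector.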

In [5]:
# Wavelet analysis #
@jax.jit
def calc_energy(coeff_array):
    """Return log2 of the mean squared value along the last axis.

    A small constant (1e-4) is added before taking the log so that an
    all-zero row yields a finite value instead of -inf.
    """
    mean_power = jnp.square(coeff_array).mean(axis=-1)
    return jnp.log2(mean_power + 1e-4)

@jax.jit
def compressed_features(coeff_array):
    """Summarise the last axis of ``coeff_array`` with seven statistics.

    The output keeps the input's leading shape and appends a trailing axis
    of length 7, ordered as: median, std, 5th, 25th, 75th and 95th
    percentile (NaN-ignoring), then the log2 energy from ``calc_energy``.
    """
    q05, q25, q50, q75, q95 = (
        jnp.nanpercentile(coeff_array, q, axis=-1) for q in (5, 25, 50, 75, 95)
    )
    spread = jnp.nanstd(coeff_array, axis=-1)
    log_energy = calc_energy(coeff_array)
    return jnp.stack([q50, spread, q05, q25, q75, q95, log_energy], axis=-1)

# waveletname (a str) and level (used as a Python int by wavedec) must be
# compile-time constants; under a plain @jax.jit any non-default value would
# be traced and the call would fail.
@partial(jax.jit, static_argnames=('waveletname', 'level'))
def extract_wavelet_coeffs(signal, waveletname = 'db32', level = 16):
    """Multilevel DWT of ``signal`` followed by per-level summary statistics.

    Args:
        signal: array whose last axis is the time axis; leading axes are
            treated as a batch (assumed from the axis=-1 reductions -- confirm
            jaxwt.wavedec transforms the last axis).
        waveletname: wavelet name passed to ``jwt.wavedec``.
        level: number of decomposition levels passed to ``jwt.wavedec``.

    Returns:
        Array with the batch shape plus a trailing feature axis of length
        7 * len(coeffs): the seven ``compressed_features`` statistics for each
        coefficient array returned by ``jwt.wavedec``, concatenated in order.
    """
    coeffs = jwt.wavedec(signal, waveletname, level = level)
    return jnp.concatenate([compressed_features(c) for c in coeffs], axis=-1)

# Both the wavelet name (a str) and level (used for Python-side shapes) must
# be static under jit; previously only `level` was.
@partial(jax.jit, static_argnames=('waveletname', 'level'))
def calc_hurst_exponent(signal, waveletname = 'haar', level = 5):
    """Estimate a Hurst exponent from the log2 energies of wavelet coefficients.

    Fits a line to ``calc_energy`` of every coefficient array except the first
    one returned by ``jwt.wavedec``, against x = 1..level, and returns
    ``-0.5 * (slope + 1)``.

    NOTE(review): if jaxwt follows pywt ordering ([cA_n, cD_n, ..., cD_1]),
    x = 1 pairs with the coarsest detail level. The fitted slope is then the
    negative of the usual energy-vs-scale slope, and the formula corresponds
    to an fBm-type model. Confirm both assumptions.

    Returns:
        (hurst, x, energies, ls_fit): the estimate, the regression abscissae,
        the detail-level log2 energies, and the ``jnp.polyfit`` coefficients
        (slope first).
    """
    coeffs = jwt.wavedec(signal, waveletname, level = level)
    # Skip coeffs[0] and keep the remaining `level` coefficient arrays.
    energies = jnp.array([calc_energy(c) for c in coeffs])[1:]
    x = jnp.arange(1, level + 1, dtype=jnp.float32)
    ls_fit = jnp.polyfit(x, energies, 1)
    slope = ls_fit[0]
    hurst = -0.5 * (slope + 1)
    return hurst, x, energies, ls_fit
    

# Device-parallel version of extract_wavelet_coeffs: the input's leading axis
# is split one slice per device, so its size must not exceed the number of
# available devices (jax.local_device_count()).
extract_all_wavelet_coeffs = jax.pmap(extract_wavelet_coeffs, in_axes=0)

In [14]:
def create_features_from_txt(filename, num_devices = 8):
    """Build the per-location feature table for one measurement file.

    Loads ``filename`` with ``open_txt`` and runs the pmapped
    ``extract_all_wavelet_coeffs`` over the location rows. The rows are split
    into ``num_devices`` equal device slices, plus a ragged remainder that is
    mapped one row per device.

    Args:
        filename: path to a binary file readable by ``open_txt``.
        num_devices: number of devices to shard over; must be >= 1 and must
            not exceed ``jax.local_device_count()`` (a pmap requirement).

    Returns:
        Array of shape (num_locations, 2 + n_features). Column 0 is the x
        location, column 1 the y location, and the remaining columns are the
        wavelet summary features, in the same row order as the file.

    Raises:
        ValueError: if ``num_devices < 1`` or the file has no locations.
    """
    if num_devices < 1:
        raise ValueError(f"num_devices must be >= 1, got {num_devices}")
    data_dict = open_txt(filename)
    values = data_dict['values']
    num_locations = values.shape[0]
    if num_locations == 0:
        raise ValueError(f"{filename} contains no measurement locations")

    num_full_batches, ragged_batch_size = divmod(num_locations, num_devices)
    full_batch_dim = num_full_batches * num_devices

    # Only map non-empty pieces. Reshaping a zero-size array with -1 (or
    # pmapping a zero-length axis) raises, which would otherwise fail for files
    # whose location count is a multiple of, or smaller than, num_devices.
    feature_blocks = []
    if num_full_batches > 0:
        full_batch = values[:full_batch_dim].reshape(num_devices, num_full_batches, -1)
        feature_blocks.append(
            extract_all_wavelet_coeffs(full_batch).reshape(full_batch_dim, -1))
    if ragged_batch_size > 0:
        ragged_batch = values[full_batch_dim:].reshape(ragged_batch_size, 1, -1)
        feature_blocks.append(
            extract_all_wavelet_coeffs(ragged_batch).reshape(ragged_batch_size, -1))
    wavelet_coeffs = jnp.concatenate(feature_blocks, axis=0)

    x_loc_vec = data_dict['x'].reshape(-1, 1)
    y_loc_vec = data_dict['y'].reshape(-1, 1)
    return jnp.concatenate([x_loc_vec, y_loc_vec, wavelet_coeffs], axis=-1)

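`create_features_from_txt` returns an array of shape (number of locations, 2 + number of extracted features). The first two columns are the x and y location, followed by the wavelet features for one component of motion. Combining the features from all three components (x, y, z), keeping a single copy of the location columns, gives
(nlocs, 2 + n_extracted*3)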

The result is a standard tabular dataset: append any other predictors or responses you need and fit any off-the-shelf model.

In [7]:
#device_count = jax.local_device_count()
#test_feats = create_features_from_txt('/scratch/group/statconsult/dataset/SmallMonolithWN/RoughCut/480 lbin txt/Hanging_480lbin_ampVolts10_x.txt', num_devices = device_count)

In [15]:
def walk_filetree(
    extraction_function: Callable[[Path], np.ndarray],
    *,
    parent_directory: Path | None = None,
    target_extension: str = "*.npy",
):
    if parent_directory is None:
        raise ValueError("parent_directory must be specified")
    for dirpath, dirnames, filenames in os.walk(parent_directory):
        for filename in filenames:
            if filename.endswith(target_extension):
                filepath = os.path.join(dirpath, filename)
                if "Horizontal" in filepath:
                    continue
                if "NoTorque" in filepath:
                    continue
                filepath = Path(filepath)
                yield extraction_function(filepath)

def extraction_func(filepath):
    """Extract wavelet features from one input file and save them as ``.npy``.

    Features come from ``create_features_from_txt(filepath, device_count)``.
    They are written to ``<output_dir>/<cut_state>/<file stem>.npy``, where
    ``cut_state`` is the name of the file's grandparent directory. The
    directory is created if needed.

    Returns:
        0 (a constant placeholder value).
    """
    filepath = Path(filepath)
    processed_features = create_features_from_txt(filepath, device_count)
    cut_state = filepath.parent.parent.name
    # Path.stem drops only the final suffix, so names containing extra dots
    # (e.g. "..._ampVolts0.5_x.txt") are not truncated at the first '.',
    # which could make different inputs overwrite the same output file.
    output_filename = filepath.stem + '.npy'
    target_dir = os.path.join(output_dir, cut_state)
    os.makedirs(target_dir, exist_ok=True)  # no exists()/makedirs race
    target = os.path.join(target_dir, output_filename)
    jnp.save(target, processed_features)
    return 0

In [None]:
# Run extraction_func on every .txt file under base_directory. list() drives
# the lazy generator and collects each call's return value.
list(
    walk_filetree(
        extraction_func,
        parent_directory=base_directory,
        target_extension=".txt",
    )
)

In [None]:
import subprocess

# Cancel the SLURM job this notebook is running in (scancel on $SLURM_JOB_ID).
job_id = os.environ.get('SLURM_JOB_ID')
print(job_id)
if job_id is None:
    print('SLURM_JOB_ID is not set; not running under SLURM, nothing to cancel.')
else:
    # NOTE(review): presumably gives the notebook time to flush/save before
    # the job is killed -- confirm the 10 s pause is still needed.
    time.sleep(10)
    # Argument list, no shell: the job id is never interpreted by a shell.
    subprocess.run(['scancel', job_id], check=False)