In [None]:
import os
import os.path as osp
import numpy as np
from scipy.spatial.transform import Rotation as R
import pickle
import matplotlib.pyplot as plt
plt.rc('font',family='Times New Roman', size=14)
import pandas as pd

In [None]:
# Path to the recorded IMU pickle. The IMU_MEASUREMENT_PATH environment variable
# overrides the original hard-coded (machine-specific) default.
MEASUREMENT_PATH = os.environ.get(
    "IMU_MEASUREMENT_PATH",
    r"C:\Users\robotflow\Desktop\rfimu-interface\imu_data\imu_mem_2023-04-18_202012\imu.pkl",
)
# SECURITY NOTE: pickle.load can execute arbitrary code; only load files you produced yourself.
# Presumably imu_data maps IMU id -> dict holding at least 'quat' and 'gyro' arrays
# (those keys are read below) — shapes/ordering not verified here, TODO confirm.
with open(MEASUREMENT_PATH, 'rb') as _measurement_file:  # `with` closes the handle
    imu_data = pickle.load(_measurement_file)


In [None]:
# For every IMU, rotate the unit vector (0, 0, 1) by each recorded orientation, giving the
# sensor's local Z axis expressed in the world frame (one row per sample).
# scipy's Rotation.from_quat expects scalar-last (x, y, z, w) quaternions by default;
# assumes the recorder stored them in that order — TODO confirm.
imu_ids = list(imu_data.keys())
z_axis = {
    imu_id: R.from_quat(imu_data[imu_id]['quat']).apply(np.array([0, 0, 1]))
    for imu_id in imu_ids
}

In [None]:
def plot_z_axis(vectors, start_idx=0, end_idx=None):
    """Plot each series of 3-D vectors as a trajectory in a single 3-D axes.

    Args:
        vectors: mapping label -> sequence of (x, y, z) points (array of shape (N, 3)
            or anything ``np.asarray`` turns into one).
        start_idx: first sample to plot (inclusive).
        end_idx: last sample to plot (exclusive); ``None`` plots through the end of
            each series.
    """
    fig = plt.figure(dpi=300, figsize=(10, 10))
    ax = fig.add_subplot(111, projection='3d')
    for key in vectors:
        # Slicing with end_idx=None runs to the end, so no length lookup is needed
        # (the old `vectors.keys()[0]` raised TypeError on Python 3 dict views).
        points = np.asarray(vectors[key])[start_idx:end_idx]
        ax.plot(points[:, 0], points[:, 1], points[:, 2], label=key)
    ax.set_xlabel('X Label')
    ax.set_ylabel('Y Label')
    ax.set_zlabel('Z Label')
    ax.legend()
    plt.show()


In [None]:
plot_z_axis(z_axis, start_idx=3000, end_idx=6000)  # samples [3000, 6000) of every IMU

In [None]:
def plot_z_axis_flat(vectors, start_idx=0, end_idx=None):
    """Plot the X, Y and Z components of each vector series in three stacked panels.

    Args:
        vectors: mapping label -> sequence of (x, y, z) points (array of shape (N, 3)
            or anything ``np.asarray`` turns into one).
        start_idx: first sample to plot (inclusive).
        end_idx: last sample to plot (exclusive); ``None`` plots through the end of
            each series.
    """
    fig, axs = plt.subplots(3, dpi=300, figsize=(10, 10))
    axis_names = ['X', 'Y', 'Z']
    for key, series in vectors.items():
        # end_idx=None slices to the end; the old `vectors.keys()[0]` raised TypeError.
        points = np.asarray(series)[start_idx:end_idx]
        for dim, ax in enumerate(axs):
            ax.plot(points[:, dim], label=key)
    for ax, name in zip(axs, axis_names):
        ax.set_title(name)
        ax.set(ylabel=name + ' Axis')
        ax.legend()
    plt.show()

In [None]:
plot_z_axis_flat(z_axis, start_idx=3000, end_idx=6000)  # per-component view of the same window

In [None]:
from scipy.signal import hilbert, butter, filtfilt

def plot_phase_synchrony(d1: np.ndarray, d2: np.ndarray,
                         lowcut: float = 0.001, highcut: float = 50,
                         fs: float = 400., order: int = 2,
                         save_path='output.pdf') -> float:
    """Band-pass two signals, compare their instantaneous phases and plot the result.

    Each signal is Butterworth band-pass filtered forward and backward (``filtfilt``,
    so no phase lag is introduced), its instantaneous phase is taken from the analytic
    signal (``hilbert``), and the phase synchrony ``1 - sin(|phi1 - phi2| / 2)`` is
    computed per sample. That value lies in [0, 1]: 1 when the phases coincide,
    0 when they differ by pi.

    Args:
        d1, d2: 1-D signals of equal length.
        lowcut, highcut: pass band edges in Hz; must satisfy 0 < lowcut < highcut < fs / 2.
        fs: sampling rate in Hz.
        order: Butterworth order passed to ``scipy.signal.butter``.
        save_path: where to save the figure; falsy to skip saving.

    Returns:
        Mean phase synchrony over all samples.

    Raises:
        ValueError: if ``d1`` and ``d2`` differ in length (or the filter band is invalid).
    """
    if len(d1) != len(d2):
        raise ValueError(f"d1 and d2 must have the same length, got {len(d1)} and {len(d2)}")

    def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
        # Normalise band edges to the Nyquist frequency, as butter() expects by default.
        nyq = 0.5 * fs
        b, a = butter(order, [lowcut / nyq, highcut / nyq], btype='band')
        return filtfilt(b, a, data)

    y1 = butter_bandpass_filter(d1, lowcut=lowcut, highcut=highcut, fs=fs, order=order)
    y2 = butter_bandpass_filter(d2, lowcut=lowcut, highcut=highcut, fs=fs, order=order)

    # Instantaneous phase in radians, in [-pi, pi].
    al1 = np.angle(hilbert(y1), deg=False)
    al2 = np.angle(hilbert(y2), deg=False)
    phase_synchrony = 1 - np.sin(np.abs(al1 - al2) / 2)
    n = len(al1)

    # Plot the results.
    fig, ax = plt.subplots(3, 1, figsize=(14, 7), sharex=True)
    ax[0].plot(y1, color='r', label='dev1')
    ax[0].plot(y2, color='b', label='dev2')
    ax[0].legend()
    ax[0].set(xlim=[0, n], title='Gyroscope Measurement')
    ax[1].plot(al1, color='r')
    ax[1].plot(al2, color='b')
    ax[1].set(ylabel='Angle', title='Angle', xlim=[0, n])
    ax[2].plot(phase_synchrony)
    ax[2].set(ylim=[0, 1.1], xlim=[0, n], title='Instantaneous Phase Synchrony',
              xlabel='N samples', ylabel='Phase Synchrony')
    plt.tight_layout()
    if save_path:
        fig.savefig(save_path)
    plt.show()
    return phase_synchrony.mean()


In [None]:
# X component of the world-frame Z axis: first IMU vs. second IMU.
plot_phase_synchrony(d1=z_axis[imu_ids[0]][:, 0], d2=z_axis[imu_ids[1]][:, 0])

In [None]:
# Column 0 of the raw gyroscope readings: first IMU vs. second IMU.
plot_phase_synchrony(d1=imu_data[imu_ids[0]]['gyro'][:, 0], d2=imu_data[imu_ids[1]]['gyro'][:, 0])


In [None]:
def plot_rolling_window_correlation(d1: np.ndarray, d2: np.ndarray,
                                    window_size: int = 50, n_samples=600):
    """Plot two signals and their centred rolling-window (Pearson) correlation.

    Args:
        d1, d2: 1-D signals of equal length.
        window_size: rolling window length in samples.
        n_samples: upper x-limit of both panels; ``None`` shows the whole signal.
    """
    def get_triangle(df, k=0):
        '''
        Return the strictly-upper-triangle values of a correlation matrix.

        Masks out the lower triangle including the diagonal (tril with k=0), flattens
        the remainder, and drops NaNs.

        df: pandas correlation matrix
        '''
        x = np.hstack(df.mask(np.tril(np.ones(df.shape), k=k).astype(bool)).values.tolist())
        return x[~np.isnan(x)]

    def rolling_correlation(data, wrap=False, *args, **kwargs):
        '''
        Pairwise rolling correlation between the columns of ``data``.

        data: DataFrame with observations in rows and signals in columns.
        wrap: if True, pad with the second half of the data in front and the first
            half behind, so windows near the edges see wrapped-around samples.
        *args/**kwargs: forwarded to ``DataFrame.rolling`` (e.g. window, center).

        Returns a DataFrame with one row per original observation and one column
        per column pair (upper triangle of the correlation matrix).
        '''
        data_len = data.shape[0]
        half_data_len = data_len // 2
        if wrap:
            data = pd.concat([data.iloc[half_data_len:], data, data.iloc[:half_data_len]],
                             axis=0).reset_index(drop=True)
            # Skip exactly the prepended padding.
            start_len = data_len - half_data_len
        else:
            # Nothing was prepended; the old code still skipped half the rows here.
            start_len = 0
        _rolling = data.rolling(*args, **kwargs).corr()
        rs = pd.DataFrame([get_triangle(_rolling.loc[i]) for i in range(data.shape[0])])
        return rs.iloc[start_len:start_len + data_len].reset_index(drop=True)

    x_max = len(d1) if n_samples is None else n_samples

    fig, ax = plt.subplots(2, 1, figsize=(20, 5), sharex=True)
    ax[0].plot(d1, color='r', label='y1')
    ax[0].plot(d2, color='b', label='y2')
    ax[0].legend()
    ax[0].set(xlim=[0, x_max], title='Timeseries Data')
    window_corr_synchrony = rolling_correlation(data=pd.DataFrame({'y1': d1, 'y2': d2}),
                                                wrap=True, window=window_size, center=True)
    window_corr_synchrony.plot(ax=ax[1], legend=False)
    ax[1].set(ylim=[-1.1, 1.1], xlim=[0, x_max],
              title='Windowed Correlation Synchrony (size: ' + str(window_size) + ')',
              xlabel='Time', ylabel='Correlation Synchrony')
    plt.tight_layout()
    plt.show()

In [None]:
# Rolling correlation of gyroscope column 0 between the first two IMUs.
plot_rolling_window_correlation(d1=imu_data[imu_ids[0]]['gyro'][:, 0], d2=imu_data[imu_ids[1]]['gyro'][:, 0])