# In [7]:
import numpy as np
import wave
import os
import shutil

def read_wave(file_name):
    """
    Read a PCM WAV file and return its raw samples as a NumPy array.

    Parameters
    ----------
    file_name : str or path-like
        Path of the WAV file to read.

    Returns
    -------
    waveData : numpy.ndarray
        1-D array of raw samples. 16-bit files give int16 (as before);
        8-bit files give uint8 (WAV stores 8-bit PCM unsigned) and 32-bit
        files give int32. Multi-channel audio is returned interleaved
        (ch0, ch1, ..., ch0, ch1, ...), not split per channel.
    framerate : int
        Sampling rate in frames per second.
    nframes : int
        Number of frames in the file (samples per channel).

    Raises
    ------
    ValueError
        If the sample width is not 1, 2 or 4 bytes (e.g. 24-bit audio).
    """
    # WAV PCM data is little-endian; spell the byte order out so the result
    # is also correct on big-endian hosts.
    sample_dtypes = {1: np.dtype('u1'), 2: np.dtype('<i2'), 4: np.dtype('<i4')}
    # ``with`` guarantees the file is closed even if reading fails.
    with wave.open(file_name, 'rb') as wav:
        _, sampwidth, framerate, nframes = wav.getparams()[:4]
        if sampwidth not in sample_dtypes:
            raise ValueError(f"unsupported sample width: {sampwidth} bytes")
        raw = wav.readframes(nframes)  # interleaved frames, encoded as bytes
    waveData = np.frombuffer(raw, dtype=sample_dtypes[sampwidth])
    return waveData, framerate, nframes

def segment_data(waveData, segment_size=2000, overlap_size=500, *,
                 dtype=np.int16, verbose=False):
    """
    Split a 1-D waveform into fixed-length, possibly overlapping windows.

    Window ``i`` covers
    ``waveData[i * overlap_size : i * overlap_size + segment_size]``.
    Despite its name, ``overlap_size`` is the *step* between the starts of
    consecutive windows. The actual overlap is ``segment_size - overlap_size``
    samples (1500 with the defaults). Windows are emitted until one reaches
    the last sample. A window that runs past the end is zero-padded on the
    right.

    Parameters
    ----------
    waveData : array-like
        1-D sequence of samples.
    segment_size : int
        Number of samples per window (must be > 0).
    overlap_size : int
        Step between window starts, in samples (must be > 0).
    dtype : numpy dtype, keyword-only
        dtype of the returned array (default int16, the original behavior).
        Values that do not fit are cast by NumPy's assignment rules.
    verbose : bool, keyword-only
        If True, print the data length, the window count and each window's
        start/end index.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(num_segments, segment_size)``. Empty input gives
        shape ``(0, segment_size)``. Input no longer than one window gives a
        single zero-padded window.

    Raises
    ------
    ValueError
        If ``segment_size`` or ``overlap_size`` is not positive.
    """
    if segment_size <= 0 or overlap_size <= 0:
        raise ValueError("segment_size and overlap_size must be positive")
    data = np.asarray(waveData)
    data_len = len(data)
    if data_len == 0:
        return np.zeros((0, segment_size), dtype=dtype)

    if data_len <= segment_size:
        # One (possibly zero-padded) window covers everything. The plain
        # formula below goes to 0 or negative here, which either dropped the
        # data or crashed in np.zeros.
        num_segments = 1
    else:
        # Integer ceiling division: enough steps for the last window to
        # reach the final sample.
        num_segments = -(-(data_len - segment_size) // overlap_size) + 1
    # With overlap_size > segment_size the last computed start can land at or
    # past the end of the data; never emit such an all-padding window.
    num_segments = min(num_segments, -(-data_len // overlap_size))
    if verbose:
        print("data_len=%d" % data_len)
        print("num_segments=%d" % num_segments)

    segments = np.zeros((num_segments, segment_size), dtype=dtype)
    for i in range(num_segments):
        start = i * overlap_size
        end = start + segment_size
        if verbose:
            print("start=%d" % start)
            print("end=%d" % end)
        chunk = data[start:end]  # shorter than segment_size only at the tail
        segments[i, :len(chunk)] = chunk
    return segments

def save_segments(segments, file_prefix):
    """
    Write every segment to its own text file.

    Segment number k (counting from 1, in iteration order) goes to
    ``{file_prefix}_{k}.txt``, each value formatted as an integer. For a
    1-D segment this means one value per line.
    """
    for index, seg in enumerate(segments, start=1):
        out_name = f"{file_prefix}_{index}.txt"
        np.savetxt(out_name, seg, delimiter='\t', fmt='%d')
        
def _read_samples(file_path):
    """Return, in order, every line of *file_path* that parses as a float.

    Lines that are not numbers (headers, blank lines) are skipped.
    """
    samples = []
    # Binary mode: float() accepts bytes, so no text encoding is assumed.
    with open(file_path, "rb") as f:
        for line in f:
            try:
                samples.append(float(line.strip()))
            except ValueError:
                continue  # non-numeric line; deliberately ignored
    return samples


def classify(path, save_path, threshold, ratio=0.3):
    """
    Copy segment files into ``action`` or ``rest`` sub-folders of *save_path*.

    Each regular file directly inside *path* is read as one numeric sample
    per line; non-numeric lines are ignored. If the fraction of samples
    strictly greater than *threshold* exceeds *ratio*, the file is copied to
    ``save_path/action/``. Otherwise it is copied to ``save_path/rest/``.
    Sub-directories and files with no numeric samples are skipped.

    Parameters
    ----------
    path : str
        Directory containing the segment text files.
    save_path : str
        Output root. The ``action`` and ``rest`` sub-directories are created
        if missing.
    threshold : float
        Sample value above which a sample counts toward "action".
    ratio : float
        Fraction of above-threshold samples that must be exceeded for a file
        to count as an action potential. The default 0.3 is an empirical rule
        that can be tuned based on results.
    """
    action_dir = os.path.join(save_path, "action")
    rest_dir = os.path.join(save_path, "rest")
    os.makedirs(action_dir, exist_ok=True)
    os.makedirs(rest_dir, exist_ok=True)

    for name in os.listdir(path):
        src = os.path.join(path, name)
        if not os.path.isfile(src):
            continue  # e.g. nested directories; nothing to classify
        signal = _read_samples(src)
        if not signal:
            continue  # no numeric data (previously a ZeroDivisionError)
        # NOTE(review): this is a signed comparison, so only positive
        # excursions count; if the signal is bipolar, abs(sample) may be
        # intended -- confirm.
        action_count = sum(1 for sample in signal if sample > threshold)
        dest_dir = action_dir if action_count / len(signal) > ratio else rest_dir
        shutil.copyfile(src, os.path.join(dest_dir, name))
            
        
if __name__ == '__main__':
    # Earlier pipeline stage, currently disabled: read the recording, cut it
    # into 2000-sample windows starting every 500 samples, and write each
    # window to its own text file.
    #     waveData, framerate, nframes = read_wave('./实验3-flexion.wav')
    #     segments = segment_data(waveData, segment_size=2000, overlap_size=500)
    #     save_segments(segments, './segment_signal/segment')
    #
    # NOTE(review): that stage writes to ./segment_signal/ while the step
    # below reads ./segment_signal_flexion -- confirm the intended input dir.
    classify('./segment_signal_flexion', './flexion', threshold=400)