In [None]:
import numpy as np
import os
from scipy.signal import resample

def resample_signal(original_signal, original_freq, target_freq):
    """
    Resample a signal to a different sampling frequency.

    Resampling is done with ``scipy.signal.resample`` (Fourier method) along
    axis 0, so an (N, 1) column vector comes back as an (M, 1) column vector.

    Parameters:
        original_signal (ndarray): Original signal; samples run along axis 0.
        original_freq (float): Original sampling frequency of the signal (Hz).
        target_freq (float): Target sampling frequency after resampling (Hz).

    Returns:
        ndarray: Resampled signal with floor(len * target_freq / original_freq)
        samples along axis 0.
    """
    # Multiply before dividing: computing the duration first (len / original_freq)
    # adds a rounding error that int() can truncate to one sample short
    # (e.g. 29 samples, 100 Hz -> 100 Hz used to give 28 points).
    target_num_points = int(len(original_signal) * target_freq / original_freq)
    return resample(original_signal, target_num_points)

In [None]:
from scipy.signal import butter, filtfilt

def bandpass_filter(signal, fs, lowcut=0.5, highcut=20, order=6):
    """
    Apply a zero-phase Butterworth bandpass filter to a signal.

    Parameters:
    - signal: array-like
        Input signal to be filtered (list, 1-D array, or (N, 1) column);
        it is flattened before filtering.
    - fs: float
        Sampling frequency of the signal (Hz).
    - lowcut: float, optional (default=0.5)
        Lower cutoff frequency of the filter (Hz).
    - highcut: float, optional (default=20)
        Upper cutoff frequency of the filter (Hz); must be below fs / 2.
    - order: int, optional (default=6)
        The order of the Butterworth filter.

    Returns:
    - filtered_signal: ndarray of shape (N, 1)
        Filtered signal as a column vector.

    Raises:
    - ValueError
        If the cutoffs do not satisfy 0 < lowcut < highcut < fs / 2.
    """
    nyquist = 0.5 * fs
    # Validate up front so e.g. fs=32 with the default highcut=20 fails with a clear message.
    if not 0 < lowcut < highcut < nyquist:
        raise ValueError(
            f"Need 0 < lowcut < highcut < Nyquist ({nyquist} Hz); "
            f"got lowcut={lowcut}, highcut={highcut} for fs={fs}"
        )
    b, a = butter(order, [lowcut / nyquist, highcut / nyquist], btype='band')
    # np.asarray lets plain lists through; filtfilt runs forward+backward (zero phase).
    filtered_signal = filtfilt(b, a, np.asarray(signal).reshape(-1))
    return filtered_signal.reshape(-1, 1)

In [None]:
import os
import numpy as np
import librosa
from PIL import Image
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
import pandas as pd
from datetime import datetime
from pathlib import Path
import csv
from os.path import isfile, join
from os import listdir
from pathlib import Path

class MyMain(object):
    """Batch driver that converts 1-D PPG segment CSVs into resized STFT time-frequency images.

    Typical use::

        runner = MyMain()
        runner.set_MyMain = (start_idx, database)   # resolves all paths for `database`
        runner.my_main_func(sample_rate=50, filter=True)

    For each segment listed in the database's menu CSV, the signal is loaded, optionally
    resampled and band-pass filtered, transformed with an STFT, resized to 128x128, and
    saved as CSV in ``output_path_csv`` (plus an optional PNG in ``output_path_fig``).
    """
    __slots__ = ('HPC_flag','UID', 'start_idx','stop_idx','database','input_path','menu_path','filename_menu','output_path_fig','output_path_csv') # Restrain the attribute

    @property
    def set_MyMain(self):
        """Return the run configuration as a tuple.

        Order: (start_idx, stop_idx, database, input_path, menu_path, HPC_flag,
        filename_menu, output_path_fig, output_path_csv). Attributes that have not been
        assigned yet (``stop_idx`` before :meth:`create_menu` runs, and ``HPC_flag``, which
        nothing in this class sets) are reported as ``None`` instead of raising.
        """
        return tuple(getattr(self, name, None) for name in (
            'start_idx', 'stop_idx', 'database', 'input_path', 'menu_path',
            'HPC_flag', 'filename_menu', 'output_path_fig', 'output_path_csv'))

    @set_MyMain.setter
    def set_MyMain(self, value):
        """Configure the run from ``value = (start_idx, database)`` and resolve all paths.

        Raises:
            ValueError: (from :meth:`create_paths`) if ``database`` is not recognised.
        """
        self.start_idx = value[0]
        self.database = value[1]
        (self.input_path, self.output_path_csv, self.output_path_fig,
         self.menu_path, self.filename_menu) = self.create_paths()

    def create_paths(self):
        """Resolve input/output directories and the menu filename for ``self.database``.

        Database names are matched case-insensitively: 'deepbeat', 'mimiciii' / 'mimic3',
        'simband'. The menu, figure and CSV output directories are created if missing;
        a missing input directory is only reported.

        Returns:
            tuple: (input_path, output_path_csv, output_path_fig, menu_path, filename_menu)

        Raises:
            ValueError: if ``self.database`` is not a known database name.
        """
        db = str(self.database).lower()
        if db == 'deepbeat':
            input_path = r"R:\ENGR_Chon\Darren\Public_Database\DeepBeat\Concatenated_DeepBeat\train\Darren_conversion\ppg_1d"
            menu_path = r"R:\ENGR_Chon\Darren\Public_Database\DeepBeat\Concatenated_DeepBeat\train\Darren_conversion"
            output_path_fig = r"R:\ENGR_Chon\Darren\Public_Database\DeepBeat\Concatenated_DeepBeat\train\Darren_conversion\tfs_plot"
            output_path_csv = r"R:\ENGR_Chon\Darren\Public_Database\DeepBeat\Concatenated_DeepBeat\train\Darren_conversion\tfs_csv"
            filename_menu = 'DeepBeat_segment_names_labels.csv'
        elif db in ('mimiciii', 'mimic3'):
            input_path = r"R:\ENGR_Chon\Darren\Public_Database\PPG_PeakDet_MIMICIII\Darren_conversion\test_1d_csv"
            menu_path = r"R:\ENGR_Chon\Darren\Public_Database\PPG_PeakDet_MIMICIII\Darren_conversion"
            output_path_fig = r"R:\ENGR_Chon\Darren\Public_Database\PPG_PeakDet_MIMICIII\Darren_conversion\test_tfs_plot"
            output_path_csv = r"R:\ENGR_Chon\Darren\Public_Database\PPG_PeakDet_MIMICIII\Darren_conversion\test_tfs_csv"
            filename_menu = '2020_Han_Sensors_MIMICIII_Ground_Truth.csv'
        elif db == 'simband':
            input_path = r"R:\ENGR_Chon\Darren\Public_Database\PPG_PeakDet_Simband\Darren_conversion\ppg_1d_csv"
            menu_path = r"R:\ENGR_Chon\Darren\Public_Database\PPG_PeakDet_Simband\Darren_conversion"
            output_path_fig = r"R:\ENGR_Chon\Darren\Public_Database\PPG_PeakDet_Simband\Darren_conversion\tfs_plot"
            output_path_csv = r"R:\ENGR_Chon\Darren\Public_Database\PPG_PeakDet_Simband\Darren_conversion\tfs_csv"
            filename_menu = 'simband_segments_labels.csv'
        else:
            # Fail fast: every path variable used below would otherwise be unbound.
            raise ValueError(f"Invalid Database: {self.database!r}")

        for directory in (menu_path, output_path_fig, output_path_csv):
            Path(directory).mkdir(parents=True, exist_ok=True)

        if not os.path.isdir(input_path):
            print("Input path does not exist, please check")
            print(input_path)

        return input_path, output_path_csv, output_path_fig, menu_path, filename_menu

    def create_menu(self):
        """Load (or build) the segment menu; return segment file names and labels.

        If ``menu_path/filename_menu`` exists it is read; its ``segment_names`` column holds
        names without the ``.csv`` extension and it has a ``labels`` column. Otherwise a
        menu with the same schema is generated from the ``.csv`` files in ``input_path``
        (labels left empty) and written to that location.

        Sets ``self.stop_idx`` to the number of menu rows.

        Returns:
            tuple[pandas.DataFrame, pandas.DataFrame]: a one-column ``segment_names`` frame
            with ``.csv`` appended, and a one-column ``labels`` frame.
        """
        menu_file = os.path.join(self.menu_path, self.filename_menu)
        if isfile(menu_file):
            print("Menu exists!")
            # Read names as text so purely numeric names can still take the '.csv' suffix.
            df_menu = pd.read_csv(menu_file, dtype={'segment_names': str})
        else:
            print("Menu does not exist!")
            # Generate the same schema an existing menu has, so the return works either way.
            csv_files = sorted(
                name for name in os.listdir(self.input_path)
                if name.endswith('.csv') and isfile(os.path.join(self.input_path, name))
            )
            segment_names = [name[:-len('.csv')] for name in csv_files]
            df_menu = pd.DataFrame(
                {'segment_names': segment_names, 'labels': [None] * len(segment_names)},
                dtype=object,
            )
            df_menu.to_csv(menu_file, index=False)

        self.stop_idx = len(df_menu)
        return df_menu[['segment_names']] + '.csv', df_menu[['labels']]

    # Subfunctions:
    def my_STFT_TFS(self, sample_rate, input_sig, newsize):
        """Compute a dB-scaled STFT of ``input_sig`` and resize it to ``newsize`` (width, height).

        Returns:
            tuple: (S_db_resample, hop_length, n_fft, S_db, image_S_db)
        """
        y = np.asarray(input_sig).reshape(-1)
        my_hop_length = int(sample_rate / 4)  # quarter-second hop; 12's resolution (129x126) is better than 25 (129x61)
        my_nfft = 256
        # NOTE(review): D is a *power* spectrogram (|STFT|**2) but is passed to
        # amplitude_to_db, which expects magnitudes; librosa.power_to_db is the matching
        # call. Kept as-is so saved TFS CSVs stay comparable with earlier runs.
        D = np.abs(librosa.stft(y, hop_length=my_hop_length, n_fft=my_nfft)) ** 2
        S_db = librosa.amplitude_to_db(D, ref=np.max)
        image_S_db = Image.fromarray(S_db)  # numpy -> PIL image so it can be resized
        S_db_resample = np.array(image_S_db.resize(newsize))  # resize, then back to numpy
        return S_db_resample, my_hop_length, my_nfft, S_db, image_S_db

    def my_plot_STFT_TFS(self, filename, input_sig, S_db_resample, label):
        """Save a two-panel PNG (signal on top, TFS matrix below) to ``output_path_fig``."""
        y = input_sig
        fig, axs = plt.subplots(2, figsize=(20, 12), dpi=80, gridspec_kw={'height_ratios': [1, 3]}, constrained_layout=True)
        axs[0].plot(y)
        axs[0].set(title='PPG signal, ' + filename)
        axs[0].set_xlabel('Sample', fontsize=16)
        axs[0].set_ylabel('a.u.', fontsize=16)
        axs[0].set_xlim([0, len(y)])
        start, end = axs[0].get_xlim()
        axs[0].xaxis.set_ticks(np.append(np.arange(int(start), int(end), 50), int(end)))

        axs[1].imshow(S_db_resample, aspect="auto", origin='lower')
        axs[1].set(title='STFT output matrix S_db (amplitude_to_db)')
        axs[1].set_xlabel('Matrix Width', fontsize=16)
        axs[1].set_ylabel('Matrix Height', fontsize=16)
        start, end = axs[1].get_xlim()
        axs[1].xaxis.set_ticks(np.append(np.arange(int(start), int(end), 10), int(end)))
        start, end = axs[1].get_ylim()
        axs[1].yaxis.set_ticks(np.append(np.arange(int(start), int(end), 10), int(end)))

        # Add label to the figure
        axs[1].text(0.95, 0.95, f'Label: {label}', transform=axs[1].transAxes, horizontalalignment='right', verticalalignment='top', fontsize=12, bbox=dict(facecolor='white', alpha=0.5))

        fig_name = filename + "_STFT_output_matrix.png"
        fig.savefig(os.path.join(self.output_path_fig, fig_name))
        # Release the figure so direct callers do not accumulate open figures.
        plt.close(fig)

    def watch_data_loading(self, filename):
        """Read a headerless signal CSV from ``input_path`` and return it as a numpy array."""
        return pd.read_csv(os.path.join(self.input_path, filename), header=None).to_numpy()

    # My main function:
    def my_main_func(self, sample_rate=50, resample=None, filter=False, plot=False):
        """Convert menu rows ``start_idx .. stop_idx-1`` into TFS CSVs (and optional plots).

        Parameters:
            sample_rate (float): Sampling frequency (Hz) of the signals on disk.
            resample (float or None): If given, resample each signal to this rate first;
                filtering and the STFT then operate at this new rate.
            filter (bool): Apply ``bandpass_filter`` before the STFT.
            plot (bool): Also save a PNG of the signal and its TFS.
        """
        df_menu, labels = self.create_menu()
        newsize = (128, 128)
        # Rate of the signal actually fed to the filter/STFT: after resampling, the
        # on-disk rate no longer describes the data.
        fs = sample_rate if resample is None else resample

        for rr in range(self.start_idx, self.stop_idx):
            filename = df_menu['segment_names'][rr]  # e.g. '005_2019_09_16_13_57_00_ppg_0008_filt.csv'
            label = labels['labels'][rr]
            stem = filename[:-4]  # drop the '.csv' suffix
            output_filename = os.path.join(self.output_path_csv, stem + "_STFT.csv")
            dt_string = datetime.now().strftime("%Y/%m/%d %H:%M:%S.%f")
            print("---------", dt_string, ",", rr, "/", self.stop_idx, ",", stem, "---------")

            input_sig = self.watch_data_loading(filename)
            if resample is not None:
                input_sig = resample_signal(input_sig, sample_rate, resample)
            if filter:
                input_sig = bandpass_filter(input_sig, fs)

            S_db_resample = self.my_STFT_TFS(fs, input_sig, newsize)[0]

            if plot:
                self.my_plot_STFT_TFS(stem, input_sig, S_db_resample, label)
                print("Saved TFS plot.")

            # Save TFS as csv
            np.savetxt(output_filename, S_db_resample, delimiter=',')
            print("Saved TFS csv.")

            plt.close('all')

In [None]:
# Build the converter; the setter takes (start_idx, database) and resolves every
# input/output path for that database.
my_main_instance = MyMain()
my_main_instance.set_MyMain = (0, 'simband')

# Sampling rates: Pulsewatch, Simband and MIMIC III are 50 Hz.
# DeepBeat is natively 32 Hz, but its 1-D PPG was already resampled to 50 Hz,
# so no additional resampling is needed here.
my_main_instance.my_main_func(
    sample_rate=50,
    resample=None,
    filter=True,
    plot=False,
)