In [None]:
import pandas as pd

# Load the exported trace; the 'File Name' column is dropped before analysis.
file_path = "trace_220521-10-1.xls"
df = pd.read_excel(file_path).drop(columns='File Name')

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display
#波の位置に印をつける。
import numpy as np
from scipy.signal import find_peaks
from sklearn.linear_model import LinearRegression  

In [None]:

%matplotlib widget
class wave_analysis():
    """Interactive plotting and spike-onset line fitting for recorded traces.

    ``df`` holds one column group per channel ``X``: ``"X Timestamps"``,
    ``"X Values"`` and ``"X St. Dev."`` (the column names this class builds).
    Timestamps are treated as seconds (see the axis label and time windows).
    make_avgwave_data() / drop_avgwave_data() modify ``df`` itself, i.e. the
    caller's frame, not a copy.
    """

    # Column names written by make_avgwave_data(); derived data, not a channel.
    _AVG_COLUMNS = ('avg Values', 'avg Timestamps', 'avg St. Dev.')

    def __init__(self, df):
        """Keep a reference to ``df`` and list its channel names.

        A channel name is the part of a column label before the first space,
        in column order without duplicates. The derived ``avg`` group is
        skipped so a frame that still carries it does not gain a fake channel.
        """
        self.df = df
        # Channel (spike) names.
        self.options = [
            name
            for name in dict.fromkeys(label.split(' ')[0] for label in self.df.columns)
            if name != 'avg'
        ]

    @staticmethod
    def _labels(spk):
        """Return [channel, timestamps column, values column, st. dev. column] for ``spk``."""
        return [spk, spk + " Timestamps", spk + " Values", spk + " St. Dev."]

    def make_figure(self, std=False, avg_wave=False, lr=False, points=False):
        """Show a channel dropdown and a figure that is redrawn on every selection.

        Args:
            std: also plot the channel's "St. Dev." column.
            avg_wave: also plot the channel-average trace.
            lr: also plot the regression line used for the spike onset.
            points: also mark local extrema and the detected signal.
        """
        dropdown = widgets.Dropdown(
            options=self.options,
            description='Channel:',
            value=self.options[0]
        )

        fig, ax = plt.subplots(figsize=(6, 4))

        def on_change(change):
            # Called by ipywidgets for trait changes; only react to a new selection.
            if change['type'] == 'change' and change['name'] == 'value':
                self.update(ax, fig, self._labels(change['new']),
                            std=std, avg_wave=avg_wave, lr=lr, points=points)

        display(dropdown)
        dropdown.observe(on_change)
        # Initial drawing for the default selection.
        self.update(ax, fig, self._labels(dropdown.value),
                    std=std, avg_wave=avg_wave, lr=lr, points=points)

    def update(self, ax, fig, labels, tmin=0, tmax=0.1, std=False, avg_wave=False, lr=False, points=False):
        """Clear ``ax`` and redraw every requested layer for one channel.

        Args:
            ax, fig: matplotlib Axes / Figure to draw into.
            labels: [channel, timestamps column, values column, st. dev. column].
            tmin, tmax: time window (s) in which local extrema are marked.
            std, avg_wave, lr, points: toggles for the optional layers.
        """
        ax.clear()
        self.plot_wave(ax, labels)
        if std:
            self.plot_std(ax, labels)
        if avg_wave:
            self.plot_avgwave(ax)
        if points:
            self.plot_points(ax, labels, tmin, tmax)
        if lr:
            self.plot_spk_onset(ax, labels)
        ax.legend(loc='lower left', fontsize='small')
        ax.grid(True)
        ax.set_xlabel('Time (s)')
        ax.set_ylabel('Value (mV?)')
        ax.set_title(labels[0] + " Waveform with peaks")
        fig.tight_layout()
        fig.canvas.draw()

    def plot_wave(self, ax, labels):
        """Plot the channel's values against its timestamps."""
        ax.plot(self.df[labels[1]], self.df[labels[2]], label=labels[0])

    def plot_avgwave(self, ax):
        """Plot the channel-average trace once.

        If the average columns are not already in ``self.df`` they are computed
        for this call only and removed again afterwards; precomputed columns
        are left untouched.
        """
        created_here = 'avg Values' not in self.df.columns
        if created_here:
            self.make_avgwave_data()
        ax.plot(self.df['avg Timestamps'], self.df['avg Values'], '-', label='Avg spk')
        if created_here:
            self.drop_avgwave_data()

    def plot_std(self, ax, labels):
        """Plot the channel's "St. Dev." column against its timestamps."""
        ax.plot(self.df[labels[1]], self.df[labels[3]], label=labels[3])

    def find_points(self, labels, tmin=0, tmax=0.1):
        """Locate local extrema and the "signal" of one channel.

        Returns:
            (lmax_in_range, lmin_in_range, sig_in_range): lists of row POSITIONS
            of local maxima / minima with timestamps in [tmin, tmax], and of
            signal candidates (minima with value <= -0.04 and prominence >= 0.03)
            whose timestamps lie 1-5 ms after ``tmin``.
        """
        values = self.df[labels[2]]
        # Positional access: find_peaks returns positions, not index labels.
        t = self.df[labels[1]].to_numpy()
        lmax, _ = find_peaks(values, distance=15, prominence=0.01)   # local maxima
        lmin, _ = find_peaks(-values, distance=15, prominence=0.01)  # local minima
        lmax_in_range = [i for i in lmax if tmin <= t[i] <= tmax]
        lmin_in_range = [i for i in lmin if tmin <= t[i] <= tmax]

        # Signal: a pronounced minimum expected between tmin+1 ms and tmin+5 ms.
        signal, _ = find_peaks(-values, distance=15, prominence=0.03, height=0.04)
        sig_in_range = [i for i in signal if tmin + 0.001 <= t[i] <= tmin + 0.005]

        return lmax_in_range, lmin_in_range, sig_in_range

    def plot_points(self, ax, labels, tmin=0, tmax=0.1):
        """Mark local maxima (red), minima (blue) and the signal (yellow) on ``ax``.

        The first signal's timestamp is annotated below it when a signal exists.
        """
        lmax_in_range, lmin_in_range, sig_in_range = self.find_points(labels, tmin, tmax)
        t = self.df[labels[1]]
        v = self.df[labels[2]]

        ax.plot(t.iloc[lmax_in_range], v.iloc[lmax_in_range], 'ro', label='Local maxima')
        ax.plot(t.iloc[lmin_in_range], v.iloc[lmin_in_range], 'bo', label='Local minima')
        ax.plot(t.iloc[sig_in_range], v.iloc[sig_in_range], 'yo', markersize=7, label='Signal')
        if sig_in_range:  # nothing to annotate for a channel without a detected signal
            t_sig = t.iloc[sig_in_range[0]]
            ax.text(t_sig, v.iloc[sig_in_range[0]] - 0.02, f"{t_sig}s", ha='center')

    def get_linear_regression(self, labels, p=0.3):
        """Fit a straight line between the first local maximum and the signal.

        The window runs from the first local maximum in 0-0.01 s up to and
        including the first detected signal sample. The leading fraction ``p``
        (0 <= p < 1) of that window is skipped so the fit sits nearer its middle.

        Returns:
            (model, t_vals_df, y_vals): the fitted LinearRegression, a one-column
            DataFrame of 100 evenly spaced timestamps spanning the fitted samples,
            and the model's predictions at those timestamps.

        Raises:
            ValueError: no local maximum or signal was found, or the window
            (after trimming) contains no samples.
        """
        lmax_in_range, _, sig_in_range = self.find_points(labels, tmin=0, tmax=0.01)
        if not lmax_in_range or not sig_in_range:
            raise ValueError(
                f"{labels[0]}: need a local maximum in 0-0.01 s and a signal to fit the onset line"
            )

        start = lmax_in_range[0]
        stop = sig_in_range[0] + 1  # exclusive bound, so the signal sample is included
        # Window length is (stop - start); trim its leading fraction p.
        first = start + round((stop - start) * p)
        if first >= stop:
            raise ValueError(f"{labels[0]}: no samples left between the maximum and the signal")

        t_col = self.df[labels[1]].iloc[first:stop]
        data_t = t_col.to_frame()
        data_val = self.df[labels[2]].iloc[first:stop]

        model = LinearRegression()
        model.fit(data_t, data_val)

        # Regression line sampled over the fitted time range.
        t_vals = np.linspace(t_col.min(), t_col.max(), 100).reshape(-1, 1)
        t_vals_df = pd.DataFrame(t_vals, columns=data_t.columns)
        y_vals = model.predict(t_vals_df)

        return model, t_vals_df, y_vals

    def plot_spk_onset(self, ax, labels, p=0.3):
        """Draw the regression line from get_linear_regression() on ``ax``.

        ``p`` (0 <= p < 1) is the leading fraction of the fit window to skip.
        """
        _, t_vals_df, y_vals = self.get_linear_regression(labels, p)
        ax.plot(t_vals_df, y_vals, color='red', label='linear_regression')

    def get_options(self):
        """Return the list of channel names."""
        return self.options

    def get_spk_onset(self, name: str = None, num: int = None, p: float = 0.3):
        """Return (slope, intercept) of the onset regression line for one channel.

        Pass exactly one of ``name`` (e.g. "SPKC01") or ``num`` (index into
        ``self.options``). ``p`` (0 <= p < 1) is the leading fraction of the fit
        window to skip.

        Raises:
            ValueError: neither or both of ``name`` and ``num`` were given, or the
            regression window could not be determined.
        """
        if (name is None) == (num is None):
            raise ValueError("pass exactly one of `name` or `num`")
        spk = name if name is not None else self.options[num]
        model, _, _ = self.get_linear_regression(self._labels(spk), p)
        return model.coef_[0], model.intercept_

    def make_avgwave_data(self):
        """Add channel-average columns to ``self.df`` (modifies the frame in place).

        Writes "avg Values" (row-wise mean of every channel's "Values" column),
        "avg Timestamps" (copied from the first channel, so channels are assumed
        to share a time base -- TODO confirm) and "avg St. Dev." (a centred
        7-sample rolling standard deviation of "avg Values", not an across-channel
        spread).
        """
        value_cols = [name + " Values" for name in self.options]
        self.df['avg Values'] = self.df[value_cols].mean(axis=1)
        self.df['avg Timestamps'] = self.df[self.options[0] + " Timestamps"]
        self.df['avg St. Dev.'] = self.df['avg Values'].rolling(window=7, center=True).std()

    def drop_avgwave_data(self):
        """Remove the columns added by make_avgwave_data(), if present.

        Drops in place so the frame shared with the caller is restored, mirroring
        make_avgwave_data(), which also writes into it. Safe to call repeatedly.
        """
        present = [col for col in self._AVG_COLUMNS if col in self.df.columns]
        self.df.drop(columns=present, inplace=True)
       

In [None]:
A = wave_analysis(df)
# Precompute the channel-average columns before building the interactive figure.
A.make_avgwave_data()
A.make_figure(std=True, avg_wave=True, lr=True, points=True)

In [None]:
# Channel names detected in the loaded trace (same list get_options() returns).
A.options

In [None]:
# Ask the analyser to drop the "avg ..." columns added by make_avgwave_data().
A.drop_avgwave_data();

In [None]:
# Report the onset regression line y = coef * t + intercept for every channel.
options = A.get_options()
for items in options:
    # get_spk_onset() runs peak detection and a regression fit; compute it once per channel.
    coef, intercept = A.get_spk_onset(items)
    print(f"{items}: y = {coef}x + {intercept}")