In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import h5py
# Print NumPy floats in fixed-point notation rather than scientific notation
# (values that round to zero at the current precision print as 0).
np.set_printoptions(suppress=True)

In [None]:
# Folder holding, for each recording <name>: <name>.plx, plus the derived
# <name>waveform.xlsx and <name>timestamp.mat files read by later cells.
data_path = r'Z:\zym\IS\control_6\plx'
filelist = os.listdir(data_path)
# Base names (extension stripped) of the .plx recordings.
# Sorted because os.listdir returns entries in arbitrary order, while later cells
# use a recording's position in plxlist as its trial row.
# NOTE(review): confirm that sorted file names follow trial order (e.g. zero-padded).
plx_ext = '.plx'
plxlist = sorted(f[:-len(plx_ext)] for f in filelist if f.endswith(plx_ext))

In [None]:
# Collect waveform rows from every recording, dropping rows whose
# peak_trough_distance is 0. Only the first three columns are kept:
# column 0 is the unit name, column 2 is averaged below as the unit's width.
# NOTE(review): column 2 is assumed to be peak_trough_distance — confirm the sheet layout.
wave_data = []
for file_name in plxlist:
    wave = pd.read_excel(os.path.join(data_path, file_name + 'waveform.xlsx'), engine='openpyxl')
    wave = wave[wave.peak_trough_distance != 0]
    wave_data += wave.values[:, 0:3].tolist()

# Group the column-2 values by unit, preserving the order rows were read in,
# so each unit's mean is taken over exactly the same sequence of values.
ptd_by_unit = {}
for row in wave_data:
    ptd_by_unit.setdefault(row[0], []).append(row[2])

unit_name = sorted(ptd_by_unit)
ptd_all = [np.mean(ptd_by_unit[u]) for u in unit_name]
ptd_avg = np.array(ptd_all)

# Classify by mean width: <= 0.4 -> 'narrow', > 0.4 -> 'broad'.
# A unit matching neither (e.g. NaN mean) keeps the placeholder 0.0.
narrow_idx = [k for k, width in enumerate(ptd_avg) if width <= 0.4]
broad_idx = [k for k, width in enumerate(ptd_avg) if width > 0.4]
cell_type = [0.0] * len(ptd_avg)
for k in narrow_idx:
    cell_type[k] = 'narrow'
for k in broad_idx:
    cell_type[k] = 'broad'

In [None]:
# Unit names for each waveform class, in the same order as narrow_idx / broad_idx.
unit_narrow = [unit_name[k] for k in narrow_idx]
unit_broad = [unit_name[k] for k in broad_idx]

In [None]:
# Per-unit summary table: mean trough-to-peak width and class, indexed by unit name.
data = pd.DataFrame({'trough to peak': ptd_avg, 'type': cell_type}, index=unit_name)

In [None]:
# Save the per-unit summary next to the recordings.
info_path = os.path.join(data_path, 'info.xlsx')
data.to_excel(info_path)

In [None]:
gap = 5  # analysis window length, minutes
window_sec = gap * 60  # same window in seconds (rates are plotted in Hz)


def _window_rates(spike_ts, starts, win):
    """Firing rate of one unit in each window (s, s + win), both ends exclusive.

    Args:
        spike_ts: 1-D array of the unit's spike timestamps.
        starts: 1-D array of window start times, same time base as spike_ts.
        win: window length, in the same units as the timestamps.

    Returns:
        1-D float array with one rate (spike count / win) per start.
    """
    return np.array(
        [np.count_nonzero((spike_ts > s) & (spike_ts < s + win)) / win for s in starts],
        dtype=float,
    )


narrow_set = set(unit_narrow)
broad_set = set(unit_broad)

fr = []           # per file: one rate array per unit, in waveform-sheet order
unit_all = []     # per file: unit names, aligned with fr
narrow_fr = []    # per file that has narrow units: their rates concatenated unit by unit
narrow_unit = []  # per file that has narrow units: their names, aligned with narrow_fr
broad_fr = []     # as narrow_fr, for broad units
broad_unit = []   # as narrow_unit, for broad units
# narrow_fr / broad_fr are kept as lists of 1-D arrays: stacking them with np.array
# fails (NumPy >= 1.24) as soon as two recordings hold different numbers of units.
for file_name in plxlist:
    wave = pd.read_excel(os.path.join(data_path, file_name + 'waveform.xlsx'), engine='openpyxl')
    unit_list = list(wave.unit)
    # The context manager closes the HDF5 handle even if a unit's dataset is missing.
    with h5py.File(os.path.join(data_path, file_name + 'timestamp.mat'), 'r') as timestamp:
        start = timestamp['timestamp/start'][:].flatten()
        rates = [
            _window_rates(timestamp['timestamp/' + u][:].flatten(), start, window_sec)
            for u in unit_list
        ]
    fr.append(rates)
    unit_all.append(list(unit_list))

    # Split this file's units by class; files with no unit of a class add nothing,
    # exactly as before (the per-file flattening is now done for every file).
    for wanted, fr_out, unit_out in ((narrow_set, narrow_fr, narrow_unit),
                                     (broad_set, broad_fr, broad_unit)):
        picked = [j for j, u in enumerate(unit_list) if u in wanted]
        if picked:
            fr_out.append(np.array([rates[j] for j in picked]).flatten())
            unit_out.append([unit_list[j] for j in picked])

In [None]:
# Trial axis: 18 rows = the 6 time-point labels x days D1-D3 used in the plots.
n_trials = 18


def _rate_table(units_of_type, n_rows=n_trials):
    """Build a (trial x unit) float table of firing rates for ``units_of_type``.

    Row k is filled from the k-th recording (``unit_all[k]`` / ``fr[k]``, which
    have one entry per file in plxlist). A recording without any unit of this
    type therefore leaves its row at 0 instead of shifting every later recording
    up by one row. Units not present in a recording also get 0, as before.
    NOTE(review): 0 for "unit not recorded" lowers the group means; NaN may be intended.

    Raises:
        ValueError: if a unit yields other than exactly one rate window per file.
    """
    # Float dtype from the start, so fillna does not have to downcast an object frame.
    table = pd.DataFrame(index=np.arange(0, n_rows, 1), columns=units_of_type, dtype=float)
    wanted = set(units_of_type)
    for k, (names, unit_rates) in enumerate(zip(unit_all, fr)):
        for name, rate in zip(names, unit_rates):
            if name not in wanted:
                continue
            rate = np.ravel(rate)
            if rate.size != 1:
                raise ValueError(
                    f'expected one rate window per unit per file, got {rate.size} '
                    f'for {name!r} in file #{k}')
            table.loc[k, name] = rate[0]
    return table.fillna(0)


narrow_df = _rate_table(unit_narrow)
broad_df = _rate_table(unit_broad)
all_fr = pd.concat([broad_df, narrow_df], axis=1)

In [None]:
def _mean_sem(rate_table):
    """Per-row mean across units and its standard error.

    The error is the population std (ddof=0) divided by sqrt(number of units),
    as before. Values are cast to float first, because NumPy's std fails on
    object-dtype arrays. Returns all-NaN arrays (without warnings) when the
    table has no unit columns.
    """
    values = rate_table.to_numpy(dtype=float)
    n_units = values.shape[1]
    if n_units == 0:
        return np.full(values.shape[0], np.nan), np.full(values.shape[0], np.nan)
    return values.mean(axis=1), values.std(axis=1) / np.sqrt(n_units)


# Note: std_* hold the standard error of the mean, not the standard deviation.
mean_broad, std_broad = _mean_sem(broad_df)
mean_narrow, std_narrow = _mean_sem(narrow_df)

In [None]:
# Six time-point labels per day, repeated for days D1-D3 (18 trials).
time_labels = ['0', '5', '15', '30', '45', '60'] * 3
x = np.arange(0, 18, 1)
fr_values = all_fr.to_numpy()

# Same rates as all_fr, re-indexed by time-point label for the Excel export.
data = pd.DataFrame(fr_values, index=time_labels, columns=all_fr.columns)

# One figure per unit; each is saved and closed to keep memory bounded in the loop.
for col_idx, unit in enumerate(all_fr.columns):
    rates = fr_values[:, col_idx]
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['bottom'].set_linewidth(1.2)
    ax.spines['left'].set_linewidth(1.2)
    ax.plot(x, rates, alpha=0.8, c='steelblue')
    ax.scatter(x, rates, c='darkgray', edgecolors='k')
    ax.set_ylabel('Firing Rate(Hz)', fontsize=14)
    ax.set_xticks(x)
    ax.set_xticklabels(time_labels)
    ax.set_xlabel('Trial(D1-D3)', fontsize=14)
    ax.set_title(unit)
    fig.savefig(os.path.join(data_path, '%d_' % gap + '%s_IS.jpg' % unit), dpi=300, bbox_inches='tight')
    plt.close(fig)
data.to_excel(os.path.join(data_path, '%dmin_data.xlsx' % gap))

In [None]:
# Group-average firing rate (mean +/- SEM band) of narrow and broad units per trial.
fig, ax = plt.subplots(figsize=(10, 5))
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['bottom'].set_linewidth(1.2)
ax.spines['left'].set_linewidth(1.2)
x = np.arange(0, 18, 1)

has_narrow = len(narrow_idx) != 0
has_broad = len(broad_idx) != 0

# Only the mean lines carry a label, so ax.legend() shows exactly one entry per
# group. (A label-only plt.legend([...]) assigns the labels to the first artists
# in drawing order, which in recent matplotlib is the narrow line and the
# narrow scatter, mislabelling the broad group.)
if has_narrow:
    ax.plot(x, mean_narrow, alpha=0.8, c='slategray', label='narrow')
    ax.scatter(x, mean_narrow, c='slategray', edgecolors='k')
    ax.fill_between(x, mean_narrow - std_narrow, mean_narrow + std_narrow, color='darkgray', alpha=0.1)

if has_broad:
    ax.plot(x, mean_broad, alpha=0.8, c='steelblue', label='broad')
    ax.scatter(x, mean_broad, c='steelblue', edgecolors='k')
    ax.fill_between(x, mean_broad - std_broad, mean_broad + std_broad, color='lightblue', alpha=0.1)

ax.set_ylabel('Firing Rate(Hz)', fontsize=14)
ax.set_xticks(x)
ax.set_xticklabels(['0', '5', '15', '30', '45', '60'] * 3)
ax.set_xlabel('Trial(D1-D3)', fontsize=14)
if has_narrow or has_broad:
    ax.legend()
fig.tight_layout()
fig.savefig(os.path.join(data_path, '%d_' % gap + 'cell' + '.jpg'), dpi=400)
plt.show()