# In [None]:
# Tinnitus project: phase-reset / ERP analysis

# Colab only: mount Google Drive at /content/gdrive. The data and figure paths
# used later are relative ('gdrive/My Drive/...'), which presumably assumes the
# working directory is /content -- confirm before running elsewhere.
from google.colab import drive
drive.mount('/content/gdrive')

import numpy as np
from numpy import exp
import dill    # pkl eval
from scipy.io import loadmat
import matplotlib.pyplot as plt
from scipy.signal import hilbert
import scipy.ndimage
import scipy as sp
from scipy import signal

from mpl_toolkits.axes_grid1 import make_axes_locatable

def butter_bandpass(data, lowcut, highcut, fs, order=5, axis=-1):
    """Apply a zero-phase Butterworth band-pass filter to ``data``.

    The filter is designed as second-order sections (SOS) and run forward and
    backward with ``scipy.signal.sosfiltfilt``. The (b, a) polynomial form
    used previously can become numerically ill-conditioned for narrow,
    low-frequency bands (e.g. 1-4 Hz at fs = 1000 Hz) at the default order of
    5, whereas SOS stays well-behaved. The edge padding length matches
    ``filtfilt``'s default (3 * number of (b, a) coefficients).

    Parameters
    ----------
    data : array_like
        Signal(s) to filter.
    lowcut, highcut : float
        Pass-band edges in Hz; must satisfy 0 < lowcut < highcut < fs / 2.
    fs : float
        Sampling rate in Hz.
    order : int, optional
        Butterworth order passed to ``scipy.signal.butter`` (default 5).
    axis : int, optional
        Axis of ``data`` along which to filter (default -1, the last axis).

    Returns
    -------
    ndarray
        Filtered signal with the same shape as ``data``.

    Raises
    ------
    ValueError
        If the cutoffs are outside (0, fs/2), or if ``data`` is not longer than
        the padding length along ``axis``.
    """
    nyq = 0.5 * fs
    sos = signal.butter(order, [lowcut / nyq, highcut / nyq], btype='band', output='sos')
    # filtfilt's default padlen is 3 * max(len(a), len(b)) = 3 * (2 * order + 1)
    # for a band-pass design; keep the same edge handling.
    padlen = 3 * (2 * order + 1)
    return signal.sosfiltfilt(sos, data, axis=axis, padlen=padlen)

def butter_filter_band(lowcut, highcut, fs, order=5):
    """Design a digital Butterworth band-pass filter in transfer-function form.

    Parameters
    ----------
    lowcut, highcut : float
        Pass-band edges in Hz.
    fs : float
        Sampling rate in Hz.
    order : int, optional
        Order passed to ``scipy.signal.butter`` (default 5).

    Returns
    -------
    tuple of ndarray
        ``(b, a)``: numerator and denominator polynomial coefficients.
    """
    nyquist = fs / 2.0
    # scipy expects the band edges as fractions of the Nyquist frequency.
    edges = [cutoff / nyquist for cutoff in (lowcut, highcut)]
    return signal.butter(order, edges, btype='band')

## Analysis parameters -- edit these to change what gets analysed ##
# Stimulus frequencies (Hz), stored as strings because they are compared with
# the event 'type' labels loaded from the .mat file.
# TODO: find a cleaner way to enumerate every stimulus frequency.
stim_freqs = np.array([str(hz) for hz in (1000, 3000, 5000, 7000, 9000, 11000)])
print('Stimulus frequency ', stim_freqs)

# Band-pass edges [low, high] in Hz for each EEG frequency band.
Delta = np.array([1., 4.])
Theta = np.array([4., 8.])
Alpha = np.array([8., 13.])
Beta = np.array([13., 30.])
Gamma = np.array([30., 100.])
# Band names, in the same order as the rows stacked below.
fnames = ['delta', 'theta', 'alpha', 'beta', 'gamma']
# Stack the bands as rows of a (band, edge) table and split its columns into
# parallel vectors of low and high cutoffs.
lf_cutoffs, hf_cutoffs = np.array([Delta, Theta, Alpha, Beta, Gamma]).T.copy()
print('Bandpass filter values ', [lf_cutoffs, hf_cutoffs])

# 1-based channel number of Cz (subtract 1 before indexing the data array).
cz_index = 69

# NOTE: testing an alternate export of the recording; the other file might be
# corrupted.
data = loadmat('gdrive/My Drive/Experiment 12-08-19/final MAT.mat')
# loadmat (default squeeze_me=False) returns MATLAB structs as 2-D arrays,
# hence the [0, 0] unwrapping to reach the struct's fields.
channels = data['EEG']['data'][0, 0]
labels = data['EEG']['event'][0, 0]['type'][0]

print('Data size: ', channels.shape[0], channels.shape[1], channels.shape[2])
# Cz channel (cz_index is 1-based). Assumes Cz is at the same position in every
# recording -- TODO confirm.
cz = channels[cz_index - 1].squeeze()

import os

# Settings shared by every stimulus-frequency / band combination.
fs = 1000            # sampling rate (Hz)
order = 2            # Butterworth order passed to butter_bandpass
n_prestim = 100      # samples before stimulus onset in each epoch (x axis starts at -100)
n_plot = 1100        # samples shown in the image plots: -100 .. 999 relative to onset
nbins = 200          # number of phase-histogram bins over [0, 1)
sigma_y = 2.0        # Gaussian smoothing along phase bins (rows)
sigma_x = 2.0        # Gaussian smoothing along time samples (columns)
sigma = [sigma_y, sigma_x]
# Set to True to plot the unwrapped phase of every single trial
# (one figure per trial, per band, per stimulus frequency).
show_trial_phases = False
fig_dir = 'gdrive/My Drive/Experiment 12-08-19/figs'
os.makedirs(fig_dir, exist_ok=True)

for stim_freq in stim_freqs:
    # Select the Cz epochs whose event label matches this stimulus frequency.
    ind_stim = (labels == stim_freq)
    if not np.any(ind_stim):
        print('No trials found for stimulus frequency ', stim_freq)
        continue
    cz_stim_orig = cz[:, ind_stim]   # (samples, trials); boolean indexing copies

    # Time axis in samples relative to stimulus onset (ms at fs = 1000).
    x = np.arange(cz_stim_orig.shape[0]) - n_prestim

    # Single-trial raw Cz for the first n_plot samples, one row per trial.
    plt.imshow(cz_stim_orig[0:n_plot, :].transpose(), aspect='auto', cmap='jet')
    plt.colorbar()
    plt.show()

    # Trial-averaged response, with markers at stimulus onset and zero amplitude.
    plt.plot(x, np.mean(cz_stim_orig, 1))
    plt.plot([0, 0], [-4, 4])
    plt.plot([x[0], x[-1]], [0, 0])
    plt.xlim([-n_prestim, n_plot - n_prestim])
    plt.show()

    for band_name, lowcut, highcut in zip(fnames, lf_cutoffs, hf_cutoffs):
        # Always filter from the unfiltered epochs into a NEW array, so each band
        # is independent of the previous one. Filtering the transposed
        # (trials, samples) array runs the filter along time for every trial.
        cz_stim = butter_bandpass(cz_stim_orig.T, lowcut, highcut, fs, order).T

        # Unwrapped instantaneous phase of each trial from its analytic signal.
        insta_phase = np.unwrap(np.angle(hilbert(cz_stim, axis=0)), axis=0)
        if show_trial_phases:
            for trial_phase in insta_phase.T:
                plt.plot(trial_phase)
                plt.show()

        # Instantaneous frequency (Hz) from the sample-to-sample phase step;
        # negative steps are shifted up by 2*pi before converting to Hz.
        insta_freq = np.diff(insta_phase, axis=0)
        insta_freq[insta_freq < 0] += 2 * np.pi
        insta_freq *= fs / (2 * np.pi)
        mean_if = np.mean(insta_freq, axis=0)   # one value per trial
        std_if = np.std(insta_freq, axis=0)     # one value per trial

        # Phase mapped to [0, 1): (phase + pi) / (2 pi), wrapped; (samples, trials).
        insta_phase_norms = (insta_phase + np.pi) / (2 * np.pi) % 1.

        # Histogram of phases across trials at every time sample: (nbins, samples).
        hist_phases = np.zeros([nbins, insta_phase_norms.shape[0]])
        for cntt, phases_at_t in enumerate(insta_phase_norms):
            hist_phases[:, cntt] = np.histogram(phases_at_t, nbins, (0, 1))[0]

        plt.plot(mean_if)
        plt.show()

        # Smooth the phase-by-time histogram image.
        smoothed_hist = sp.ndimage.gaussian_filter(hist_phases, sigma, mode='constant')

        plt.close('all')
        fig, ax = plt.subplots(1, 1)
        img = ax.imshow(smoothed_hist[:, 0:n_plot], vmin=0, vmax=1.5)

        # Column index -> time relative to onset; row index -> normalised phase.
        ax.set_xticks([0, 100, 600, 1100])
        ax.set_xticklabels(['-100', '0', '500', '1000'])
        ax.set_yticks([0, nbins // 2, nbins - 1])
        ax.set_yticklabels(['0', '0.5', '1'])

        cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
        fig.colorbar(img, cax=cax)
        # Save BEFORE showing: with the inline (notebook) backend plt.show()
        # closes the figures, so a later plt.savefig() writes an empty image.
        fig.savefig(os.path.join(fig_dir, 'epoch_' + band_name + '_' + stim_freq + 'hz.png'),
                    bbox_inches='tight', pad_inches=0)
        plt.show()