In [None]:
import os
import re
import pandas as pd
import numpy as np

from sklearn.model_selection import train_test_split
from tslearn.utils import to_time_series_dataset
from tslearn.metrics import dtw
from tslearn.clustering import TimeSeriesKMeans
from tslearn.utils import to_time_series

In [None]:
#Function from NeuroChat - read LFP
def load_lfp_Axona(file_name):
    """Load one Axona LFP channel and scale it to microvolts.

    Adapted from NeuroChat.
    1. The ASCII header of ``file_name`` is parsed for the sample count
       (``num_<EXT>_samples``, e.g. ``num_EEG_samples`` for ``.eeg``/``.eeg3``)
       and ``bytes_per_sample``.
    2. The ``.set`` file with the same basename supplies the channel gain and
       ``ADC_fullscale_mv``.
    3. The binary words following the ``data_start`` marker are assembled
       little-endian, converted to signed values and scaled.

    Parameters
    ----------
    file_name : str
        Path to the LFP file. Digits in the extension select the
        ``EEG_ch_<n>`` entry of the ``.set`` file (no digits -> ``EEG_ch_1``).

    Returns
    -------
    numpy.ndarray or None
        1-D float64 array with ``num_samples`` values. Returns None (after
        printing a message where applicable) when the file does not exist,
        the recording is blank (empty ``trial_date``), or no ``data_start``
        marker is present.

    Raises
    ------
    ValueError
        If the header lacks the sample count or ``bytes_per_sample``, or the
        ``.set`` file lacks ``ADC_fullscale_mv``.
    """
    if not os.path.isfile(file_name):
        print("No lfp file found for file {}".format(file_name))
        return None

    file_directory, file_basename = os.path.split(file_name)
    file_tag, file_extension = os.path.splitext(file_basename)
    file_extension = file_extension[1:]
    set_file = os.path.join(file_directory, file_tag + '.set')
    samples_key = 'num_' + file_extension[:3].upper() + '_samples'

    with open(file_name, 'rb') as f:
        raw = f.read()

    # The binary samples start right after the ASCII 'data_start' marker.
    # bytes.find is linear and terminates cleanly when the marker is absent.
    marker = raw.find(b'data_start')
    header = raw if marker < 0 else raw[:marker]

    num_samples = None
    bytes_per_sample = None
    for line in header.decode('latin-1').split('\n'):
        if line.startswith('trial_date') and line.strip() == 'trial_date':
            # Blank eeg file: the trial_date entry carries no value.
            return None
        if line.startswith('bytes_per_sample'):
            bytes_per_sample = int(''.join(line.split()[1:]))
        if line.startswith(samples_key):
            num_samples = int(''.join(line.split()[1:]))

    if marker < 0:
        print('Error: data_start marker not found!')
        return None
    if num_samples is None or bytes_per_sample is None:
        raise ValueError("{}: header lacks '{}' or 'bytes_per_sample'".format(
            file_name, samples_key))

    # Digits in the extension pick the EEG_ch_<n> entry ('eeg' -> 1, 'eeg3' -> 3).
    eeg_id = re.findall(r'\d+', file_extension)
    eeg_number = 1 if not eeg_id else int(eeg_id[0])

    with open(set_file, 'r', encoding='latin-1') as f_set:
        set_lines = f_set.readlines()

    # The unescaped '.' deliberately matches the space in e.g. 'EEG_ch_1 2', so
    # the first match is '<index> <value>', which splits into a key/value pair.
    pair_pattern = r'\d+.\d+|\d+'
    channel_lines = dict(
        tuple(map(int, re.findall(pair_pattern, line)[0].split()))
        for line in set_lines if line.startswith('EEG_ch_'))
    channel_id = channel_lines[eeg_number]
    gain_lines = dict(
        tuple(map(int, re.findall(pair_pattern, line)[0].split()))
        for line in set_lines if 'gain_ch_' in line)
    # NOTE(review): the -1 presumably maps the 1-based channel stored in
    # EEG_ch_<n> to the 0-based gain_ch_<i> index; confirm against the .set spec.
    gain = gain_lines[channel_id - 1]

    fullscale_mv = None
    for line in set_lines:
        if line.startswith('ADC_fullscale_mv'):
            # Integer or decimal value; the dot is escaped so it is literal.
            fullscale_mv = float(re.findall(r'\d+\.\d+|\d+', line)[0])
            break
    if fullscale_mv is None:
        raise ValueError('{}: ADC_fullscale_mv not found'.format(set_file))
    AD_bit_uvolt = 2 * fullscale_mv / (gain * np.power(2, 8 * bytes_per_sample))

    max_ADC_count = 2**(8 * bytes_per_sample - 1) - 1
    max_byte_value = 2**(8 * bytes_per_sample)
    record_size = bytes_per_sample
    sample_le = 256**(np.arange(0, bytes_per_sample, 1))

    byte_buffer = np.frombuffer(raw, dtype=np.uint8)[marker + len(b'data_start'):]
    len_bytebuffer = len(byte_buffer)
    # Trailer after the samples; excluded from the decoded range.
    end_offset = len('\r\ndata_end\r')
    lfp_wave = np.zeros([num_samples, ], dtype=np.float64)
    for k in range(bytes_per_sample):
        # Byte k of every word, weighted by 256**k (little-endian assembly).
        sample_value = sample_le[k] * byte_buffer[
            k:k + len_bytebuffer - end_offset - record_size:record_size]
        # Exactly num_samples words: truncate extras, zero-pad a short file.
        sample_value = sample_value[:num_samples]
        if sample_value.size < num_samples:
            sample_value = np.append(
                sample_value, np.zeros([num_samples - sample_value.size, ]))
        np.add(lfp_wave,
               sample_value.astype(np.float64, casting='unsafe', copy=False),
               out=lfp_wave)
    # Reinterpret unsigned words above the signed maximum as negative values.
    np.putmask(lfp_wave, lfp_wave > max_ADC_count, lfp_wave - max_byte_value)

    return lfp_wave * AD_bit_uvolt
        
        
class RecPos:
    """
    Recording-position data read from an Axona ``.pos`` file.

    The constructor parses the ``.pos`` header (camera and window limits,
    pixels per metre, number of samples) and the 20-byte binary records that
    follow the ``data_start`` marker.

    To dos:
        * read different numbers of LEDs
        * Adapt to NeuroChat

    Attributes
    ----------
    pos_file : str
        Path of the ``.pos`` file derived from the constructor argument.
    bytes_per_sample : int
        Size of one position record in bytes (20).
    raw_position : dict of str -> numpy.ndarray
        Keys ``'big_spotx'``, ``'big_spoty'``, ``'little_spotx'`` and
        ``'little_spoty'``; each value is an (n, 1) float64 array of raw
        coordinates. The methods of this class treat 1023 as "LED not tracked".
    """

    # Coordinate value treated as "no LED detected" throughout this class.
    _MISSING = 1023

    # Header keys whose integer value is stored under the given attribute.
    _INT_HEADER_FIELDS = {
        'min_x': 'min_x', 'max_x': 'max_x',
        'min_y': 'min_y', 'max_y': 'max_y',
        'window_min_x': 'window_min_x', 'window_max_x': 'window_max_x',
        'window_min_y': 'window_min_y', 'window_max_y': 'window_max_y',
        'bytes_per_timestamp': 'bytes_per_tstamp',
        'bytes_per_coord': 'bytes_per_coord',
        'pixels_per_metre': 'pixels_per_metre',
        'num_pos_samples': 'total_samples',
    }

    def __init__(self, file_name):
        """Read ``<basename of file_name>.pos`` from the same directory.

        Prints a message and leaves ``raw_position`` unset when the file is
        missing, the recording is blank (empty ``trial_date``), or the
        ``data_start`` marker is absent.
        """
        self.bytes_per_sample = 20  # Axona daqUSB manual
        file_directory, file_basename = os.path.split(file_name)
        file_tag = os.path.splitext(file_basename)[0]
        self.pos_file = os.path.join(file_directory, file_tag + '.pos')
        if not os.path.isfile(self.pos_file):
            print(f"No pos file found for file {file_name}")
            return

        with open(self.pos_file, 'rb') as f:
            raw = f.read()

        # bytes.find terminates cleanly even when the marker is missing.
        marker = raw.find(b'data_start')
        header = raw if marker < 0 else raw[:marker]
        for line in header.decode('latin-1').split('\n'):
            if line.startswith('trial_date') and line.strip() == 'trial_date':
                # Blank recording: the trial_date entry carries no value.
                print('No position data.')
                return
            parts = line.split()
            if len(parts) >= 2 and parts[0] in self._INT_HEADER_FIELDS:
                setattr(self, self._INT_HEADER_FIELDS[parts[0]], int(parts[1]))

        if marker < 0:
            print('Error: data_start marker not found!')
            return

        byte_buffer = np.frombuffer(raw, dtype=np.uint8)[marker + len(b'data_start'):]
        record = self.bytes_per_sample
        # Never read past the buffer if the file holds fewer records than announced.
        n_records = min(self.total_samples, len(byte_buffer) // record)
        # int64 so that 256 * byte cannot overflow uint8 arithmetic.
        records = byte_buffer[:n_records * record].reshape(n_records, record).astype(np.int64)

        def coord(offset):
            # 16-bit word at bytes (offset, offset + 1) of each record, high byte first.
            word = 256 * records[:, offset] + records[:, offset + 1]
            return word.astype(np.float64).reshape(-1, 1)

        # pos format: t (4 bytes), x1, y1, x2, y2, numpix1, numpix2, ... (2 bytes each)
        self.raw_position = {'big_spotx': coord(4),
                             'big_spoty': coord(6),
                             'little_spotx': coord(8),
                             'little_spoty': coord(10)}

    def get_cam_view(self):
        """Return (and cache) the camera limits ``min_x``/``max_x``/``min_y``/``max_y``."""
        self.cam_view = {'min_x': self.min_x, 'max_x': self.max_x,
                         'min_y': self.min_y, 'max_y': self.max_y}
        return self.cam_view

    def get_tmaze_start(self):
        """Guess the T-maze start arm from the mean position of samples 100-249.

        Coordinates equal to 1023 are ignored, values are clipped to [0, 500],
        and gaps are back-filled before averaging. Returns one of
        'top left', 'top right', 'down left', 'down right' or
        'impossible to find'.
        NOTE(review): the right-hand thresholds differ (x > 400 for top,
        x > 300 for down); confirm that this asymmetry is intended.
        """
        x, y = self.get_position()

        def window_mean(values):
            window = pd.Series(values[100:250], dtype=float)
            return window.replace(self._MISSING, np.nan).clip(0, 500).bfill().mean()

        mean_x = window_mean(x)
        mean_y = window_mean(y)
        if mean_x < 200 and mean_y > 300:
            return 'top left'
        if mean_x > 400 and mean_y > 300:
            return 'top right'
        if mean_x < 200 and mean_y < 200:
            return 'down left'
        if mean_x > 300 and mean_y < 200:
            return 'down right'
        return 'impossible to find'

    def get_window_view(self):
        """Return the window limits, or print 'No window view' and return None
        when the header did not provide them."""
        try:
            self.windows_view = {'window_min_x': self.window_min_x,
                                 'window_max_x': self.window_max_x,
                                 'window_min_y': self.window_min_y,
                                 'window_max_y': self.window_max_y}
            return self.windows_view
        except AttributeError:
            print('No window view')

    def get_pixel_per_metre(self):
        """Return ``pixels_per_metre`` from the header."""
        return self.pixels_per_metre

    def get_raw_pos(self):
        """Return the big-LED raw coordinates as two flat lists (x, y)."""
        bigx = [value[0] for value in self.raw_position['big_spotx']]
        bigy = [value[0] for value in self.raw_position['big_spoty']]
        return bigx, bigy

    def filter_max_speed(self, x, y, max_speed = 3):
        """Flag samples whose step from the previous raw sample is too large.

        Sample i (i >= 1) is set to 1023 in both returned copies when the
        Euclidean distance between (x[i], y[i]) and (x[i-1], y[i-1]) exceeds
        ``max_speed * self.pixels_per_metre`` pixels. Steps that involve NaN
        are never flagged. The inputs are not modified.

        NOTE(review): the limit is applied per position sample, not per
        second, so ``max_speed`` is effectively metres per sample rather than
        m/s. The header sample rate is not parsed here; confirm the intent.
        """
        tmp_x = x.copy()
        tmp_y = y.copy()
        steps = np.hypot(np.diff(np.asarray(x, dtype=float)),
                         np.diff(np.asarray(y, dtype=float)))
        limit = max_speed * self.pixels_per_metre
        # steps[j] is the move from sample j to sample j + 1.
        for i in np.flatnonzero(steps > limit) + 1:
            tmp_x[i] = self._MISSING
            tmp_y[i] = self._MISSING
        return tmp_x, tmp_y

    @staticmethod
    def _merge_led_pair(big, small, missing):
        """Fill a coordinate missing (== ``missing``) from one LED with the
        other LED's value; both become NaN when both LEDs are missing."""
        merged_big, merged_small = [], []
        for b, s in zip(big, small):
            if b == missing and s == missing:
                b = s = np.nan
            elif b == missing:
                b = s
            elif s == missing:
                s = b
            merged_big.append(b)
            merged_small.append(s)
        return merged_big, merged_small

    def get_position(self):
        """Return the cleaned head position (x, y) as two lists of floats.

        Steps:
        1. Each LED's missing coordinate is replaced by the other LED's.
        2. Too-fast jumps are flagged with ``filter_max_speed``.
        3. All gaps (missing and speed-flagged) are filled by cubic
           interpolation.
        4. The two LEDs are averaged.
        """
        bigx, bigy = self.get_raw_pos()
        smallx = [value[0] for value in self.raw_position['little_spotx']]
        smally = [value[0] for value in self.raw_position['little_spoty']]

        bxx, sxx = self._merge_led_pair(bigx, smallx, self._MISSING)
        byy, syy = self._merge_led_pair(bigy, smally, self._MISSING)

        ### Remove coordinates moving faster than max_speed (flagged as 1023)
        bxx, byy = self.filter_max_speed(bxx, byy)
        sxx, syy = self.filter_max_speed(sxx, syy)

        def fill_gaps(values):
            # Speed-flagged samples become NaN so interpolation replaces them too.
            series = pd.Series(values, dtype=float).replace(self._MISSING, np.nan)
            return series.interpolate('cubic')

        bxx, sxx, byy, syy = (fill_gaps(v) for v in (bxx, sxx, byy, syy))

        ### Average both LEDs
        x = list((bxx + sxx) / 2)
        y = list((byy + syy) / 2)
        return x, y

    def get_speed(self):
        """Placeholder: prints 'Not implemented'."""
        print('Not implemented')

    def get_angular_pos(self):
        """Placeholder: prints 'Not implemented'."""
        print('Not implemented')

In [None]:
import numpy
from scipy.spatial.distance import cdist
import matplotlib.pyplot as plt

from tslearn.generators import random_walks
from tslearn.preprocessing import TimeSeriesScalerMeanVariance
from tslearn import metrics

numpy.random.seed(0)

# Two LFP channels recorded in the same T-maze trial (.eeg and .eeg3).
# NOTE(review): absolute, machine-specific paths; consider a configurable DATA_DIR.
file1 = r'/mnt/d/Beths/CanCCaRet1/tmaze/s3_27022019/t2/27022019_CanCCaRet1_tmaze_3_2.eeg'
file2 = r'/mnt/d/Beths/CanCCaRet1/tmaze/s3_27022019/t2/27022019_CanCCaRet1_tmaze_3_2.eeg3'

# Number of leading samples compared with DTW.
# NOTE(review): 250 * 10 presumably means 10 s at a 250 Hz .eeg sample rate;
# confirm against the sample_rate entry of the file header.
N_SAMPLES = 250 * 10

lfp_1 = load_lfp_Axona(file1)
lfp_2 = load_lfp_Axona(file2)
# load_lfp_Axona prints a message and returns None when it cannot read a file;
# stop here with the paths instead of failing later on slicing None.
if lfp_1 is None or lfp_2 is None:
    raise RuntimeError('Could not load LFP data from {} / {}'.format(file1, file2))
s_x = lfp_1[:N_SAMPLES]
s_y = lfp_2[:N_SAMPLES]

# Channel 1 is tiled twice; channel 2 is followed by its time reversal.
s_y1 = numpy.concatenate((s_x, s_x)).reshape((-1, 1))
s_y2 = numpy.concatenate((s_y, s_y[::-1])).reshape((-1, 1))
sz = s_y1.shape[0]

# path: aligned (i, j) index pairs; sim: tslearn's DTW score, which is a
# distance (lower means more similar), not a similarity.
path, sim = metrics.dtw_path(s_y1, s_y2)

fig = plt.figure(1, figsize=(8, 8))

# definitions for the axes: distance matrix in the centre, s_y2 above it,
# s_y1 (mirrored) to its left
left, bottom = 0.01, 0.1
w_ts = h_ts = 0.2
left_h = left + w_ts + 0.02
width = height = 0.65
bottom_h = bottom + height + 0.02

rect_s_y = [left, bottom, w_ts, height]
rect_gram = [left_h, bottom, width, height]
rect_s_x = [left_h, bottom_h, width, h_ts]

ax_gram = fig.add_axes(rect_gram)
ax_s_x = fig.add_axes(rect_s_x)
ax_s_y = fig.add_axes(rect_s_y)

# mat[i, j] = distance between s_y1[i] (rows) and s_y2[j] (columns).
mat = cdist(s_y1, s_y2)

ax_gram.imshow(mat, origin='lower')
ax_gram.axis("off")
ax_gram.autoscale(False)
ax_gram.plot([j for (i, j) in path], [i for (i, j) in path], "w-",
             linewidth=3.)

ax_s_x.plot(numpy.arange(sz), s_y2, "b-", linewidth=.5)
ax_s_x.axis("off")
ax_s_x.set_xlim((0, sz - 1))

ax_s_y.plot(- s_y1, numpy.arange(sz), "b-", linewidth=.5)
ax_s_y.axis("off")
ax_s_y.set_ylim((0, sz - 1))

plt.show()