In [None]:
%matplotlib inline
import sys  
sys.path.insert(0, 'D:/Beths/')
# sys.path.insert(0, 'D:/NeuroChaT/neurochat/')
# from neurochat.nc_lfp import NLfp
import os
import re
from mne.preprocessing import ICA
import mne
import datetime
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from tqdm.notebook import tqdm
import statistics
import math
import random
from sklearn import linear_model
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cbook as cbook
import matplotlib.cm as cm
from shapely.geometry import Point, Polygon
from matplotlib.collections import LineCollection
from matplotlib.ticker import MultipleLocator

## Import from my files
from data_lfp import mne_lfp_Axona, load_lfp_Axona
from data_pos import RecPos

#### **Class to read position files**

In [None]:
class RecPos:
    """
    Read an Axona ``.pos`` file: header values plus tracked LED coordinates.

    The ``.pos`` file is looked for next to ``file_name`` (same directory and
    base name, extension replaced by ``.pos``).  When it is missing, blank, or
    has no ``data_start`` marker, a message is printed and the data-dependent
    attributes below are simply not created.

    To dos:
        * read different numbers of LEDs
        * Adapt to NeuroChat

    Attributes
    ----------
    pos_file : str
        Path of the ``.pos`` file that was looked up.
    bytes_per_sample : int
        Size in bytes of one position record (20).
    min_x, max_x, min_y, max_y : int
    window_min_x, window_max_x, window_min_y, window_max_y : int
    bytes_per_tstamp, bytes_per_coord, pixels_per_metre, total_samples : int
        Header values; each one exists only if its key is in the header.
    raw_position : dict of np.ndarray
        Keys ``big_spotx``, ``big_spoty``, ``little_spotx``, ``little_spoty``;
        each value is a float column of shape ``(n_samples, 1)``.
    time : np.ndarray
        Column of zeros with one row per decoded sample (the record
        timestamps are not decoded).
    """

    # Header keys holding one integer -> attribute name the value is stored in.
    _INT_HEADER_FIELDS = {
        "min_x": "min_x",
        "max_x": "max_x",
        "min_y": "min_y",
        "max_y": "max_y",
        "window_min_x": "window_min_x",
        "window_max_x": "window_max_x",
        "window_min_y": "window_min_y",
        "window_max_y": "window_max_y",
        "bytes_per_timestamp": "bytes_per_tstamp",
        "bytes_per_coord": "bytes_per_coord",
        "pixels_per_metre": "pixels_per_metre",
        "num_pos_samples": "total_samples",
    }
    # Coordinate value this class treats as "LED not tracked".
    _MISSING = 1023

    def __init__(self, file_name):
        """
        Parse the ``.pos`` file that belongs to ``file_name``.

        Parameters
        ----------
        file_name : str
            Path of any file of the recording (e.g. the ``.set`` file); only
            its directory and base name are used.
        """
        self.bytes_per_sample = 20  # Axona daqUSB manual
        file_directory, file_basename = os.path.split(file_name)
        file_tag, _ = os.path.splitext(file_basename)
        self.pos_file = os.path.join(file_directory, file_tag + ".pos")
        if not os.path.isfile(self.pos_file):
            print(f"No pos file found for file {self.pos_file}")
            return

        with open(self.pos_file, "rb") as f:
            raw = f.read()

        marker = b"data_start"
        marker_at = raw.find(marker)
        header_bytes = raw if marker_at < 0 else raw[:marker_at]
        # latin-1 maps every byte to a character, so decoding cannot fail.
        for line in header_bytes.decode("latin-1").split("\n"):
            if line.startswith("trial_date") and line.strip() == "trial_date":
                # Blank pos file
                self.total_samples = 0
                print("No position data.")
                return
            for key, attribute in self._INT_HEADER_FIELDS.items():
                if line.startswith(key):
                    setattr(self, attribute, int(line.split()[1]))

        if marker_at < 0:
            print("Error: data_start marker not found!")
            return

        payload = np.frombuffer(raw[marker_at + len(marker):], dtype=np.uint8)
        record = self.bytes_per_sample
        # Only whole records count; the short "data_end" trailer never fills one.
        available = len(payload) // record
        declared = getattr(self, "total_samples", available)
        n_samples = min(declared, available)
        if n_samples < declared:
            print(
                f"Warning: header declares {declared} position samples but "
                f"only {n_samples} are stored in {self.pos_file}"
            )

        # Work in float64 so 256 * byte cannot overflow the uint8 storage type.
        records = (
            payload[: n_samples * record]
            .reshape(n_samples, record)
            .astype(np.float64)
        )
        self.time = np.zeros([n_samples, 1])
        # pos format: t,x1,y1,x2,y2,numpix1,numpix2 => 20 bytes
        self.raw_position = {
            "big_spotx": self._coordinate(records, 4),  # bytes 4,5
            "big_spoty": self._coordinate(records, 6),  # bytes 6,7
            "little_spotx": self._coordinate(records, 8),  # bytes 8,9
            "little_spoty": self._coordinate(records, 10),  # bytes 10,11
        }

    @staticmethod
    def _coordinate(records, first_byte):
        """Combine two consecutive bytes (high byte first) into one column."""
        high = records[:, first_byte]
        low = records[:, first_byte + 1]
        return (256 * high + low).reshape(-1, 1)

    def get_cam_view(self):
        """Return (and cache in ``cam_view``) the min/max x/y header values."""
        self.cam_view = {
            "min_x": self.min_x,
            "max_x": self.max_x,
            "min_y": self.min_y,
            "max_y": self.max_y,
        }
        return self.cam_view

    def get_tmaze_start(self):
        """
        Classify the start corner from position samples 100 to 299.

        Returns
        -------
        str
            ``"top left"``, ``"top right"``, ``"down left"``, ``"down right"``
            or ``"impossible to find"`` (also used when no position data is
            available).
        """
        position = self.get_position()
        if position is None:
            return "impossible to find"
        x, y = position
        a = self._start_window(x[100:300])
        b = self._start_window(y[100:300])
        a_median = a.median()
        b_mean = b.mean()
        if a_median < 200 and b_mean > 300:
            start = "top left"
        elif a_median > 400 and b_mean > 300:
            start = "top right"
        elif a_median < 200 and b_mean < 200:
            start = "down left"
        elif a_median > 300 and b_mean < 200:
            start = "down right"
        else:
            start = "impossible to find"
        return start

    def _start_window(self, values):
        """Blank the missing marker, clip to 0..500 and back-fill the gaps."""
        series = pd.Series(list(values), dtype=float)
        series = series.where(series != self._MISSING)
        return series.clip(0, 500).bfill()

    def get_window_view(self):
        """
        Return (and cache in ``windows_view``) the window_* header values.

        Prints ``"No window view"`` and returns None when the header did not
        provide them.
        """
        try:
            self.windows_view = {
                "window_min_x": self.window_min_x,
                "window_max_x": self.window_max_x,
                "window_min_y": self.window_min_y,
                "window_max_y": self.window_max_y,
            }
        except AttributeError:
            print("No window view")
            return None
        return self.windows_view

    def get_pixel_per_metre(self):
        """Return the ``pixels_per_metre`` header value."""
        return self.pixels_per_metre

    def get_raw_pos(self):
        """Return the undecorated big-LED coordinates as two lists (x, y)."""
        bigx = list(self.raw_position["big_spotx"][:, 0])
        bigy = list(self.raw_position["big_spoty"][:, 0])

        return bigx, bigy

    def filter_max_speed(self, x, y, max_speed = .6):  # max speed
        """
        Blank samples that jump too far from the previous sample.

        A sample is replaced by NaN in both returned sequences when its
        distance to the previous input sample exceeds
        ``max_speed * pixels_per_metre``.  The inputs are not modified.

        NOTE(review): the threshold is compared with a per-sample
        displacement without any sampling-interval factor, and the first and
        last samples are never tested - confirm both are intended.
        """
        tmp_x = x.copy()
        tmp_y = y.copy()
        limit = max_speed * self.pixels_per_metre
        for i in range(1, len(tmp_x) - 1):
            if math.hypot(x[i] - x[i - 1], y[i] - y[i - 1]) > limit:
                tmp_x[i] = np.nan
                tmp_y[i] = np.nan
        return tmp_x, tmp_y

    @classmethod
    def _merge_leds(cls, big, small):
        """
        Fill a coordinate lost on one LED with the value of the other LED.

        Samples lost on both LEDs become NaN.  Returns two lists.
        """
        big = np.asarray(big, dtype=float)
        small = np.asarray(small, dtype=float)
        big_lost = big == cls._MISSING
        small_lost = small == cls._MISSING
        both_lost = big_lost & small_lost
        merged_big = np.where(big_lost, small, big)
        merged_small = np.where(small_lost, big, small)
        merged_big[both_lost] = np.nan
        merged_small[both_lost] = np.nan
        return merged_big.tolist(), merged_small.tolist()

    @staticmethod
    def _boxcar(values, kernel):
        """Smooth with ``kernel`` after edge-padding by one kernel length."""
        npad = len(kernel)
        padded = np.pad(values, (npad, npad), "edge")
        return np.convolve(padded, kernel, mode="same")[npad:-npad]

    def get_position(self, drop_nan=False):
        """
        Return the cleaned, smoothed position as two lists ``(x, y)``.

        Steps: fill single-LED dropouts from the other LED, blank samples
        rejected by ``filter_max_speed``, cubic-interpolate the gaps, average
        both LEDs and apply a 20-sample boxcar filter.

        Parameters
        ----------
        drop_nan : bool, optional
            False (default) keeps one value per recorded sample, so NaNs that
            interpolation could not fill stay in the output.  True removes
            every sample where x or y is NaN, keeping x and y paired.

        Returns
        -------
        tuple of list, or None
            None (after printing a message) when the file provided no
            position samples or no ``pixels_per_metre``.
        """
        raw = getattr(self, "raw_position", None)
        if (
            raw is None
            or len(raw["big_spotx"]) == 0
            or not hasattr(self, "pixels_per_metre")
        ):
            print(f"No position information found in {self.pos_file}")
            return None

        # Try to clean single blocked LED, x and y handled independently
        bxx, sxx = self._merge_leds(raw["big_spotx"][:, 0], raw["little_spotx"][:, 0])
        byy, syy = self._merge_leds(raw["big_spoty"][:, 0], raw["little_spoty"][:, 0])

        ### Remove coordinates with max_speed
        bxx, byy = self.filter_max_speed(bxx, byy)
        sxx, syy = self.filter_max_speed(sxx, syy)

        ### Interpolate missing values
        def fill_gaps(values):
            return pd.Series(values).astype(float).interpolate("cubic")

        ### Average both LEDs
        x = ((fill_gaps(bxx) + fill_gaps(sxx)) / 2).to_numpy()
        y = ((fill_gaps(byy) + fill_gaps(syy)) / 2).to_numpy()

        ## Boxcar filter 400 ms (axona tint default)
        # sample rate = 20 ms
        b = int(400 / 20)
        kernel = np.ones(b) / b
        x = self._boxcar(x, kernel)
        y = self._boxcar(y, kernel)

        if drop_nan:
            # NaN never compares equal to anything, so test with isnan.
            keep = ~(np.isnan(x) | np.isnan(y))
            x, y = x[keep], y[keep]

        return list(x), list(y)

    def get_speed(self):
        """Placeholder: prints a notice and returns None."""
        print("Not implemented")

    def get_angular_pos(self):
        """Placeholder: prints a notice and returns None."""
        print("Not implemented")

#### **Functions that read LFP from an entire file and convert it to MNE format**

In [None]:
# Function from NeuroChat - read LFP
def load_lfp_Axona(file_name):
    """
    Load the samples of one Axona EEG/EGF file as a scaled float array.

    The matching ``.set`` file (same directory and base name) supplies the
    channel mapping, the gain and ``ADC_fullscale_mv`` used for scaling.

    Parameters
    ----------
    file_name : str
        Path of the recording file, e.g. ``name.eeg`` or ``name.eeg2``.  A
        number in the extension selects the ``EEG_ch_<n>`` entry of the set
        file; without a number entry 1 is used.

    Returns
    -------
    np.ndarray or None
        One float64 value per sample declared in the header
        (``raw value * 2 * ADC_fullscale_mv / (gain * 2**bits)``).
        None when the file does not exist, is blank (empty ``trial_date``)
        or has no ``data_start`` marker.

    Raises
    ------
    ValueError
        If the header lacks ``bytes_per_sample`` or the sample count, or the
        set file lacks ``ADC_fullscale_mv``.
    KeyError
        If the set file has no channel or gain entry for this file.
    """
    if not os.path.isfile(file_name):
        print("No lfp file found for file {}".format(file_name))
        return None

    file_directory, file_basename = os.path.split(file_name)
    file_tag, file_extension = os.path.splitext(file_basename)
    file_extension = file_extension[1:]
    set_file = os.path.join(file_directory, file_tag + ".set")

    with open(file_name, "rb") as f:
        raw = f.read()

    marker = b"data_start"
    marker_at = raw.find(marker)
    header_bytes = raw if marker_at < 0 else raw[:marker_at]

    samples_key = "num_" + file_extension[:3].upper() + "_samples"
    bytes_per_sample = None
    total_samples = None
    # latin-1 maps every byte to a character, so decoding cannot fail.
    for line in header_bytes.decode("latin-1").split("\n"):
        if line.startswith("trial_date") and line.strip() == "trial_date":
            # Blank eeg file
            return None
        if line.startswith("bytes_per_sample"):
            bytes_per_sample = int("".join(line.split()[1:]))
        if line.startswith(samples_key):
            total_samples = int("".join(line.split()[1:]))

    if marker_at < 0:
        print("Error: data_start marker not found!")
        return None
    if not bytes_per_sample or total_samples is None:
        raise ValueError(
            "Header of {} lacks bytes_per_sample or {}".format(file_name, samples_key)
        )

    eeg_ID = re.findall(r"\d+", file_extension)
    eeg_number = int(eeg_ID[0]) if eeg_ID else 1

    channel_lines = {}
    gain_lines = {}
    fullscale_mv = None
    with open(set_file, "r", encoding="latin-1") as f_set:
        for line in f_set:
            found = re.match(r"EEG_ch_(\d+)\s+(\d+)", line)
            if found:
                channel_lines[int(found.group(1))] = int(found.group(2))
            found = re.search(r"gain_ch_(\d+)\s+(\d+)", line)
            if found:
                gain_lines[int(found.group(1))] = int(found.group(2))
            if fullscale_mv is None:
                # \d+ (not a literal "d") so values of any length are read.
                found = re.match(r"ADC_fullscale_mv\s+(\d+)", line)
                if found:
                    fullscale_mv = int(found.group(1))
    if fullscale_mv is None:
        raise ValueError("ADC_fullscale_mv not found in {}".format(set_file))

    channel_id = channel_lines[eeg_number]
    gain = gain_lines[channel_id - 1]
    AD_bit_uvolt = 2 * fullscale_mv / (gain * np.power(2, 8 * bytes_per_sample))

    max_ADC_count = 2 ** (8 * bytes_per_sample - 1) - 1
    max_byte_value = 2 ** (8 * bytes_per_sample)

    payload = np.frombuffer(raw[marker_at + len(marker):], dtype=np.uint8)
    end_offset = len("\r\ndata_end\r")
    # Bytes that can hold sample data once the trailer is set aside.
    usable = len(payload) - end_offset - bytes_per_sample
    stored = -(-usable // bytes_per_sample) if usable > 0 else 0
    # Never read more than the header declares; missing samples stay zero.
    used = min(stored, total_samples)

    lfp_wave = np.zeros(total_samples, dtype=np.float64)
    if used:
        # One row per sample, least significant byte first.
        byte_weights = 256.0 ** np.arange(bytes_per_sample)
        sample_bytes = payload[: used * bytes_per_sample].reshape(
            used, bytes_per_sample
        )
        lfp_wave[:used] = sample_bytes.astype(np.float64) @ byte_weights

    # Values above the positive range are negative two's-complement numbers.
    lfp_wave = np.where(
        lfp_wave > max_ADC_count, lfp_wave - max_byte_value, lfp_wave
    )
    return lfp_wave * AD_bit_uvolt


def mne_lfp_Axona(file_name):
    """
    Create a mne object from a Axona recording.
    ------
    Load all saved EEG channels of an Axona recording into a mne object.

    Parameters:
    ------
    file_name (str): Axona .set file in the same folder as the EEG recordings
        that belong to it.  Channel 1 is read from ``<name>.eeg`` and channel
        ``n`` from ``<name>.eeg<n>``.

    Returns:
    ------
    mne.io.RawArray with one "eeg" channel per loaded file, named ``ch_<n>``
    where ``n`` is the channel number of the ``saveEEG_ch_<n>`` set entry.
    Channels whose file is missing, blank, lacks the ``data_start`` marker or
    never rises above zero are left out.  The scaled samples are divided by
    1000 before being handed to mne (NOTE(review): presumably mV to V -
    confirm the unit).

    Raises:
    ------
    ValueError: if the set file lacks ``ADC_fullscale_mv``, an EEG header
        lacks a required field, or no channel could be loaded.
    KeyError: if the set file has no mapping or gain for a saved channel.
    """
    file_directory, file_basename = os.path.split(file_name)
    file_tag, _ = os.path.splitext(file_basename)
    set_file = os.path.join(file_directory, file_tag + ".set")
    if not os.path.isfile(set_file):
        # Keep accepting a set file passed under another extension.
        set_file = file_name

    # Open Set files configurations
    fullscale_mv = None
    channel_map = {}  # map internal channels from Axona set
    recorded_channels = {}  # map of recorded channels from Axona set
    gains = {}  # gain per internal channel number
    with open(set_file, "r", encoding="latin-1") as f_set:
        for line in f_set:
            # \d+ (not a literal "d") so values of any length are read.
            found = re.match(r"ADC_fullscale_mv\s+(\d+)", line)
            if found:
                fullscale_mv = int(found.group(1))
            found = re.match(r"EEG_ch_(\d+)\s+(\d+)", line)
            if found:
                channel_map[int(found.group(1))] = int(found.group(2))
            found = re.match(r"saveEEG_ch_(\d+)\s+(\d+)", line)
            if found:
                recorded_channels[int(found.group(1))] = int(found.group(2))
            found = re.search(r"gain_ch_(\d+)\s+(\d+)", line)
            if found:
                gains[int(found.group(1))] = int(found.group(2))
    if fullscale_mv is None:
        raise ValueError("ADC_fullscale_mv not found in {}".format(set_file))

    # All recorded EEG channels
    channel_ids = [ch for ch, saved in recorded_channels.items() if saved]

    marker = b"data_start"
    end_offset = len("\r\ndata_end\r")
    data = []
    labels = []
    ch_types = []
    sampling_rate = None
    for ch in channel_ids:  # Loop for all channels
        # The first eeg channel has no number in its extension.
        suffix = "" if ch == 1 else str(ch)
        eeg_file = os.path.join(file_directory, file_tag + ".eeg" + suffix)
        if not os.path.isfile(eeg_file):
            continue

        with open(eeg_file, "rb") as f:
            raw = f.read()

        marker_at = raw.find(marker)
        header_bytes = raw if marker_at < 0 else raw[:marker_at]
        # Reset per file so values never leak from the previous channel.
        bytes_per_sample = None
        total_samples = None
        blank = False
        for line in header_bytes.decode("latin-1").split("\n"):
            if line.startswith("trial_date") and line.strip() == "trial_date":
                # Blank eeg file
                blank = True
                break
            if line.startswith("sample_rate"):
                sampling_rate = float("".join(re.findall(r"\d+.\d+|\d+", line)))
            if line.startswith("bytes_per_sample"):
                bytes_per_sample = int("".join(line.split()[1:]))
            if line.startswith("num_EEG_samples"):
                total_samples = int("".join(line.split()[1:]))
        if blank:
            continue
        if marker_at < 0:
            print("Error: data_start marker not found!")
            continue
        if not bytes_per_sample or total_samples is None or sampling_rate is None:
            raise ValueError(
                "Header of {} lacks sample_rate, bytes_per_sample or "
                "num_EEG_samples".format(eeg_file)
            )

        AD_bit_uvolt = (
            2
            * fullscale_mv
            / (gains[channel_map[ch] - 1] * np.power(2, 8 * bytes_per_sample))
        )
        max_ADC_count = 2 ** (8 * bytes_per_sample - 1) - 1
        max_byte_value = 2 ** (8 * bytes_per_sample)

        payload = np.frombuffer(raw[marker_at + len(marker):], dtype=np.uint8)
        # Bytes that can hold sample data once the trailer is set aside.
        usable = len(payload) - end_offset - bytes_per_sample
        stored = -(-usable // bytes_per_sample) if usable > 0 else 0
        # Never read more than the header declares; missing samples stay zero.
        used = min(stored, total_samples)

        lfp_wave = np.zeros(total_samples, dtype=np.float64)
        if used:
            # One row per sample, least significant byte first.
            byte_weights = 256.0 ** np.arange(bytes_per_sample)
            sample_bytes = payload[: used * bytes_per_sample].reshape(
                used, bytes_per_sample
            )
            lfp_wave[:used] = sample_bytes.astype(np.float64) @ byte_weights
        # Values above the positive range are negative two's-complement numbers.
        lfp_wave = np.where(
            lfp_wave > max_ADC_count, lfp_wave - max_byte_value, lfp_wave
        )

        lfp_data = lfp_wave * AD_bit_uvolt
        # Skip channels that never rise above zero.
        if lfp_data.size and lfp_data.max() > 0:
            data.append(lfp_data / 1000)
            labels.append(f"ch_{ch}")
            ch_types.append("eeg")

    if not data:
        raise ValueError("No EEG channel could be loaded for {}".format(file_name))

    info = mne.create_info(ch_names=labels, sfreq=sampling_rate, ch_types=ch_types)
    return mne.io.RawArray(np.array(data), info)

### **Plot helper functions**

In [None]:
def plot_small_sq(x,y, wview):
    """
    Scatter the positions on a small 3x3 inch figure without axes.

    The view is fixed to pixels 40..300 on both axes; ``wview`` is accepted
    but not used.  Shows the figure and returns the result of ``plt.show()``.
    """
    plt.figure(figsize=(3, 3))
    plt.axis("off")
    plt.scatter(x, y, c="black", marker=".")
    # Fixed window instead of wview['window_max_x'] / wview['window_max_y'].
    plt.xlim(40, 300)
    plt.ylim(40, 300)
    plt.xlabel("X pixels")
    plt.ylabel("Y pixels")
    plt.tight_layout()
    return plt.show()

def plot_tmaze(x,y, wview, dot = True):
    """
    Plot the positions of a T-maze recording on a 6x6 inch figure.

    Parameters
    ----------
    x, y : sequence of float
        Position in pixels.
    wview : dict or None
        Window view holding ``window_max_x`` and ``window_max_y``; they set
        the upper axis limits.  With None the limits are left automatic.
    dot : bool, optional
        True (default) draws black dots, False a green line.

    Shows the figure and returns the result of ``plt.show()``.
    """
    fig, ax = plt.subplots(figsize=(6, 6))
    if dot:
        ax.scatter(x, y, c="black", marker=".")
    else:
        ax.plot(x, y, c="g", linewidth=3)
    if wview is not None:
        ax.set_xlim(0, int(wview["window_max_x"]))
        ax.set_ylim(0, int(wview["window_max_y"]))
    ax.set_xlabel("X pixels")
    ax.set_ylabel("Y pixels")
    ax.set_title("T maze position plot")
    fig.tight_layout()
    return plt.show()

def plot_mne(raw_array, base_name):
    """
    Load ``raw_array`` and open its plot, blocking until it is closed.

    Two channels and 30 s are shown at a time; ``base_name`` only appears in
    the window title.
    """
    plot_options = {
        "n_channels": 2,
        "block": True,
        "duration": 30,
        "show": True,
        "clipping": "transparent",
        "title": f"Raw LFP Data from {base_name}",
        "remove_dc": False,
        "scalings": {"eeg": 250e-5},
    }
    raw_array.load_data()
    raw_array.plot(**plot_options)

### **Read data table and select Tmaze recordings**


In [None]:
# Recording table without screening sessions and without habituation sessions.
df = pd.read_csv("data_scheme.csv", parse_dates=["date_time"])
df = df.loc[(df.maze != "screening") & (df.habituation == 0)]
# "<folder>/<filename>" of every remaining T-maze recording.
tmaze_files = (
    df.loc[df.maze == "tmaze", ["folder", "filename"]]
    .agg("/".join, axis=1)
    .values
)
df.head()

### **Open LFP data in MNE**

In [None]:
# Open one randomly chosen T-maze recording.
# randrange excludes len(tmaze_files); randint(0, n) can return n, which is
# one past the last valid index.
i = random.randrange(len(tmaze_files))
file = tmaze_files[i]
lfps = mne_lfp_Axona(file)
pos = RecPos(file)
plot_mne(lfps, file)

### From the start, for each position sample (50 Hz), 5 EEG samples are collected.
**Cut all NaN values from x and y, and keep the shorter continuous segment (x or y) truncated by NaNs**

In [None]:
# for file in tmaze_files:
#     try:
#         pos = RecPos(file)
#         x,y = pos.get_position()
#         rx,ry = pos.get_raw_pos()
#         nanX =  sum([a for a in x if a == np.nan])
#         nanY = sum([a for a in y if a == np.nan])
#         if nanX >0:
#             print(f' x: {nanX}\t y: {nanY}')
#     except:
#          continue

### **Finding decision region**

**TODO:**

1. get the order of the points

2. find where the animal turned to left or right



**Find the turn:**
   1. Define a vector length (e.g. 100)
   2. Walk with this vector from the middle of the start arm to the end
   3. Divide this vector into 2 with a gap in the middle (e.g. a: 40, gap: 20, b: 40)
   4. Normalize both small vectors
   5. Calculate the direction with the dot product
    

#### **Helper functions**

In [None]:
# Corner points, in pixels, of one four-sided region per start label returned
# by RecPos.get_tmaze_start(); the last entry is a placeholder region.
coord = {
    "down left": [(225, 110), (290, 175), (180, 235), (100, 140)],
    "top left": [(100, 250), (205, 390), (300, 340), (180, 200)],
    "top right": [(370, 235), (260, 310), (360, 400), (460, 330)],
    "down right": [(365, 100), (465, 160), (375, 230), (290, 160)],
    "impossible to find": [(0, 1), (2, 1), (2, 2), (2, 2)],
}

def is_inside(x, y, start):
    """
    Collect the (x, y) points lying within the region of ``start``.

    Returns the list of matching points, or False when there is none.
    Raises KeyError for an unknown ``start`` label.
    """
    # Corner points of the region belonging to each start label.
    regions = {
        'down left': [(225,110),(290,175),(180,235),(100, 140)],
        'top left':  [(100,250),(205,390),(300,340),(180,200)],
        'top right':[(370,235),(260,310),(360, 400),(460,330)],
        'down right':  [(365,100),(465,160),(375,230),(290,160)],
        'impossible to find': [(0,1),(2,1),(2, 2),(2,2)]
    }
    region = Polygon(regions[start])
    points_inside = [
        (vx, vy) for vx, vy in zip(x, y) if Point(vx, vy).within(region)
    ]
    return points_inside if points_inside else False

def move_window(x, y, idx, size, space):
    '''Break the vector into 2 windows divided by a space in the middle.

    Starting at ``idx``, the first window holds ``size`` samples, the next
    ``space`` samples are skipped and the second window holds the following
    ``size`` samples.

    Returns (bx, by, cx, cy): x and y of the first window, then x and y of
    the second window.
    '''
    first = slice(idx, idx + size)
    second = slice(idx + size + space, idx + 2 * size + space)
    bx, by = x[first], y[first]
    cx, cy = x[second], y[second]
    # The last value used to be the undefined name `xy` (NameError).
    return bx, by, cx, cy

def calculate_regression(x, y):
    '''Fit y on x with a linear regression and return the fitted line.

    Returns (x as a column array of shape (n, 1), predicted y for each x).
    '''
    features = np.array(x).reshape(-1, 1)
    targets = np.array(y)
    model = linear_model.LinearRegression()
    model.fit(features, targets)
    return features, model.predict(features)
 
def calculate_angle(a, b, c):
    '''Angle in degrees at vertex b, from the direction b->a to the direction b->c.

    Points are (x, y) pairs; a negative result is shifted up by 360.
    '''
    heading_to_c = math.atan2(c[1] - b[1], c[0] - b[0])
    heading_to_a = math.atan2(a[1] - b[1], a[0] - b[0])
    ang = math.degrees(heading_to_c - heading_to_a)
    if ang < 0:
        return ang + 360
    return ang

def calculate_lenght(a, b):
    '''Straight-line distance between the first and the last point.

    ``a`` holds the x values and ``b`` the y values of the points.
    '''
    span_a = a[-1] - a[0]
    span_b = b[-1] - b[0]
    return math.sqrt((span_a ** 2) + (span_b ** 2))


def is_out(x_med, y_med, px, py, d):
    ''' Checks if a point is far from the (x_med, y_med) point.

    Given a point (px, py) and the reference point (x_med, y_med), return
    True when the point lies at a distance of d or more from the reference
    and False when it is closer.

    input: (float): x_med, y_med, px, py, d
    output: (bool): True if the point is outside the distance d
    '''
    # The comparison used to be `d > distance`, which is True for points
    # INSIDE the radius - the opposite of what the name and docstring say.
    return math.sqrt((px - x_med) ** 2 + (py - y_med) ** 2) >= d
        
    
def get_cord_inside(x,y,start):
    '''Return the coordinates of the points inside the region of ``start``.

    The points are selected with is_inside(); the result is (px, py), the x
    and y values of the selected points.  Both lists are empty when no point
    lies inside the region.
    '''
    insiders = is_inside(x, y, start)
    if not insiders:
        # is_inside() answers False when nothing matched; iterating over it
        # used to raise a TypeError.
        return [], []
    px = [point[0] for point in insiders]
    py = [point[1] for point in insiders]
    return px, py


# def filter_high_density_point(x,y):
    
#     read = {}
#     points = [(xn,yn) for xn,yn in zip(x,y)]
#     count_dict = {i:points.count(i) for i in points}
#     for k in count_dict.keys():
#         if count_dict[k]
    
#     return x,y

In [None]:
file = '/mnt/d/Beths/CSR6/+ maze/27032018_t3/S8/27032018_CSR6_+maze_t3_.set'
# Rebuild pos from this file; otherwise x, y and start silently come from
# whichever recording an earlier cell left in `pos`.
pos = RecPos(file)
x, y = pos.get_position()

start = pos.get_tmaze_start()

#### **Test Sean's idea of getting values outside a range from the beginning**

In [None]:
def find_start_points(file, d=90, st=200, min_run=30):
    """
    Find the sample index at which the animal has left its start location.

    The start location is the median position of samples 300..499 (after
    NaN removal).  Scanning from sample ``st``, the search stops once more
    than ``min_run`` consecutive samples lie ``d`` pixels or more away from
    it.  The full trajectory is also drawn on a new 6x6 inch figure, which
    is left open for the caller.

    Parameters
    ----------
    file : str
        Recording file handed to ``RecPos``.
    d : float, optional
        Distance threshold in pixels (default 90).
    st : int, optional
        Index of the first sample that is examined (default 200).
    min_run : int, optional
        Run length that has to be exceeded (default 30).

    Returns
    -------
    int
        One past the index of the sample that completed the run, or one past
        the last examined sample when no run was found.

    Raises
    ------
    ValueError
        If the recording provides no position data.
    """
    position = RecPos(file).get_position()
    if position is None:
        raise ValueError(f"No position data available for {file}")
    x = np.asarray(position[0], dtype=float)
    y = np.asarray(position[1], dtype=float)
    # remove nans with one shared mask so x[i] and y[i] stay the same sample
    valid = ~(np.isnan(x) | np.isnan(y))
    x, y = x[valid], y[valid]
    plt.figure(figsize=(6, 6))
    plt.plot(x, y, c='black', linewidth=.4)
    x_med = np.median(x[300:500])
    y_med = np.median(y[300:500])
    outs = 0
    idx = st
    for px, py in zip(x[st:], y[st:]):
        idx += 1
        if math.sqrt((px - x_med) ** 2 + (py - y_med) ** 2) >= d:
            outs += 1
            if outs > min_run:
                break
        else:
            # Back near the start (or undefined distance): restart the run.
            outs = 0
    return idx

In [None]:
# Load the position track of one fixed recording.
file = '/mnt/d/Beths/CSR6/+ maze/27032018_t3/S8/27032018_CSR6_+maze_t3_.set'
pos = RecPos(file)
x, y = pos.get_position()

In [None]:
# Visual check of the start-exit detection on 30 randomly chosen recordings.
for _ in range(30):
    # randrange excludes len(tmaze_files); randint(0, n) can return n, which
    # is one past the last valid index.
    i = random.randrange(len(tmaze_files))
    file = tmaze_files[i]
    pos = RecPos(file)
    x, y = pos.get_position()
    start = pos.get_tmaze_start()
    wview = pos.get_window_view()
    # remove nans with one shared mask so x[i] and y[i] stay the same sample
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    valid = ~(np.isnan(x) | np.isnan(y))
    x, y = x[valid], y[valid]
    plt.figure(figsize=(6, 6))
    plt.plot(x, y, c='black', linewidth=.4)
    x_med = np.median(x[300:500])
    y_med = np.median(y[300:500])
    st = 200
    idx = st
    outx = []
    outy = []
    d = 90
    # Collect the first run of more than 90 consecutive samples lying at
    # least d pixels away from the median position.
    for px, py in zip(x[st:], y[st:]):
        idx += 1
        if math.sqrt((px - x_med) ** 2 + (py - y_med) ** 2) >= d:
            outx.append(px)
            outy.append(py)
            if len(outx) > 90:
                break
        else:
            outx = []
            outy = []

    # NOTE(review): the red marker is centred on the median of samples
    # 100..299 while the detection above used samples 300..499 - confirm
    # which window is intended.
    x_med = np.median(x[100:300])
    y_med = np.median(y[100:300])
    x = x[st:idx + 100]
    y = y[st:idx + 100]
    plt.plot(x, y, c='black', linewidth=1)
    plt.plot(outx, outy, c='green', linewidth=2)
    if wview is not None:  # get_window_view() gives None without window data
        plt.xlim(0, xmax=int(wview['window_max_x']))
        plt.ylim(0, ymax=int(wview['window_max_y']))
    plt.scatter(x_med, y_med, c="red", s=d ** 2, alpha=.1)
    plt.xlabel('X pixels')
    plt.ylabel('Y pixels')
    plt.title(f'{start}')
    print(idx)
    plt.tight_layout()
    plt.show()

In [None]:
def get_first_dots(x, y, d=30, st=200, min_run=30, ref_window=(300, 500)):
    """
    Return the first long run of consecutive samples far from a reference point.

    The reference point is the median of ``x`` and ``y`` over ``ref_window``.
    Samples are scanned from index ``st`` onwards; a sample counts as
    "outside" when its Euclidean distance to the reference point is at least
    ``d``. A sample that is not outside resets the current run.

    Parameters
    ----------
    x, y : sequence of float
        Position samples; ``x[k]`` and ``y[k]`` are taken as one point.
    d : float, optional
        Distance threshold. Default 30.
    st : int, optional
        Index of the first sample that is scanned. Default 200.
    min_run : int, optional
        The run is returned as soon as it holds MORE than this many
        consecutive outside samples. Default 30 (i.e. 31 samples).
    ref_window : tuple of int, optional
        ``(start, stop)`` slice bounds of the samples whose median defines
        the reference point. Default ``(300, 500)``.

    Returns
    -------
    tuple of list or None
        ``(outx, outy)`` with the coordinates of the run, or ``None`` when no
        sufficiently long run exists.
    """
    ref_lo, ref_hi = ref_window
    x_med = np.median(x[ref_lo:ref_hi])
    y_med = np.median(y[ref_lo:ref_hi])
    outx = []
    outy = []
    # Slice before zipping so the whole track is not materialised as a list.
    for px, py in zip(x[st:], y[st:]):
        if math.sqrt((px - x_med)**2 + (py - y_med)**2) >= d:
            outx.append(px)
            outy.append(py)
            if len(outx) > min_run:
                return outx, outy
        else:
            # Run interrupted: discard what was collected so far.
            outx = []
            outy = []

    return None

#### **Find before and after decision using regression**

In [None]:
# Load the position track of one example recording for the regression experiments.
# NOTE(review): absolute, machine-specific path -- confirm it exists on the machine running this cell.
file = '/mnt/d/Beths/CSR6/+ maze/27032018_t3/S8/27032018_CSR6_+maze_t3_.set'
# RecPos derives the matching .pos file name from the given path (same directory, same base name).
pos = RecPos(file)
x,y = pos.get_position()
# get_first_dots(x,y)

In [None]:
# For 10 randomly chosen recordings, slide two windows (separated by a gap)
# along the first half of the track, fit a regression line to each window and
# plot the first pair of windows whose lines meet at a roughly right angle
# inside the start region.
for n in range(0, 10):
    # randrange excludes the upper bound; randint(0, len(...)) could return
    # len(tmaze_files) and raise IndexError.
    i = random.randrange(len(tmaze_files))
    file = tmaze_files[i]
    start = None  # reset so the error report never prints a stale value
    try:
        pos = RecPos(file)
        x, y = pos.get_position()
        # One shared NaN mask keeps x[k] and y[k] paired.
        x = np.asarray(x)
        y = np.asarray(y)
        valid = ~(np.isnan(x) | np.isnan(y))
        x = x[valid]
        y = y[valid]
        x = x[:len(x)//2] # half of the recording
        y = y[:len(y)//2] # half of the recording
        start = pos.get_tmaze_start()
        wview = pos.get_window_view()
        space = 25 # space between regression lines
        lenght = 90 # length of regression line
        idx = find_start_points(file)
        print(idx)
        for idx1 in range(idx, len(x)-2*lenght-space, 2):
            idx2 = idx1+space+lenght
            win1 = slice(idx1, idx1 + lenght)  # samples of the first window
            win2 = slice(idx2, idx2 + lenght)  # samples of the second window
            x1, y1 = calculate_regression(x[win1], y[win1])
            lx1 = len(x1)//2
            ly1 = len(y1)//2
            x2, y2 = calculate_regression(x[win2], y[win2])
            # Angle at the end of the first line, between its start and the
            # end of the second line (presumably in degrees -- confirm
            # against calculate_angle).
            angle = calculate_angle((x1[0], y1[0]), (x1[-1], y1[-1]), (x2[-1], y2[-1]))
            length1 = calculate_lenght(x[win1], y[win1])
            length2 = calculate_lenght(x[win2], y[win2])
            # Accept a near-perpendicular turn where both windows cover enough
            # distance and the midpoint of the first line passes is_inside.
            if 65 < angle < 115 and length1 > 110 and length2 > 110 and is_inside(x1[lx1:lx1+1], y1[ly1:ly1+1], start):
                print(angle)
                fig, ax = plt.subplots(figsize=(5, 5))
                ax.plot(x, y, color='black', linewidth=.4) # plot maze
                ax.plot(x1, y1, color='red')
                ax.scatter(x[win1], y[win1], color='green')
                ax.plot(x2, y2, color='red')
                ax.scatter(x[win2], y[win2], color='blue')
                ax.axis('off')
                ax.set_xlim(0, int(wview['window_max_x']))
                ax.set_ylim(0, int(wview['window_max_y']))
                ax.set_xlabel('X pixels')
                ax.set_ylabel('Y pixels')
                ax.set_title('T maze postition plot')
                fig.tight_layout()
    #             plt.savefig('reg_tmaze6.png')
                plt.show()
                plt.close(fig)
                break
    except Exception as err:
        # Keep going with the next recording, but say which one failed and why
        # (a bare `except:` would also swallow KeyboardInterrupt).
        print(file)
        print(start)
        print(f'{type(err).__name__}: {err}')

#### Plot a random T maze 

In [None]:
# Plot the first half of one randomly chosen T-maze recording together with
# the outline stored in `coord` for its start label and the points returned
# by is_inside.
# randrange excludes the upper bound; randint(0, len(...)) could return
# len(tmaze_files) and raise IndexError.
i = random.randrange(len(tmaze_files))
# file = '/mnt/d/Beths/LSubRet1/recording/+maze/06092017_3rd/S7/06092017_LSubRet1_+maze_trial_3_7.set'
fig, ax = plt.subplots(figsize=(6, 6))
file = tmaze_files[i]
pos = RecPos(file)
x, y = pos.get_position()
x = x[:len(x)//2]
y = y[:len(y)//2]
start = pos.get_tmaze_start()
insiders = is_inside(x, y, start)
# Raises KeyError when `start` has no entry in `coord`.
area = coord[start]
aera_x = [b[0] for b in area]
area_y = [b[1] for b in area]
# The former `if start in coord` guard ran only after the lookup above had
# already succeeded, so it was always true; `c` is kept in case a later cell
# still reads it.
c = area
px = [b[0] for b in insiders]
py = [b[1] for b in insiders]
wview = pos.get_window_view()
ax.scatter(x, y, c="black", marker='.', alpha=0.5)
ax.plot(aera_x, area_y, c="red", marker='.')
ax.axis('on')
ax.scatter(px, py, c="green", marker='_')
ax.set_xlim(0, int(wview['window_max_x']))
ax.set_ylim(0, int(wview['window_max_y']))
# The x label deliberately carries the file path so the plot identifies its source.
ax.set_xlabel(file)
ax.set_ylabel('Y pixels')
ax.set_title(f'Started {start}')

fig.tight_layout()
# plt.savefig(f'decision_point_{i}.png')
# # plt.close()

In [None]:
# def calculate_bbox():
    

# def get_tmaze_start(x,y):

#     a = x[100:250]
#     b = y[100:250]
#     a = pd.Series([n if n != 1023 else np.nan for n in a])
#     b = pd.Series([n if n != 1023 else np.nan for n in b])
#     a.clip(0, 500, inplace=True)
#     b.clip(0, 500, inplace=True)
#     a.fillna(method="backfill", inplace=True)
#     b.fillna(method="backfill", inplace=True)
#     if a.mean() < 200 and b.mean() > 300:
#         start = "top left"
#     elif a.mean() > 400 and b.mean() > 300:
#         start = "top right"
#     elif a.mean() < 200 and b.mean() < 200:
#         start = "down left"
#     elif a.mean() > 300 and b.mean() < 200:
#         start = "down right"
#     else:
#         start = "impossible to find"
#     return start

### **Separate into forced and choice**

In [None]:
# Figure layout: forced-run path (top left), chosen-run path (top right) and
# four stacked LFP traces (bottom).
# NOTE: `file` is not set here; it must have been defined by an earlier cell.
fig = plt.figure("T_maze LFP", figsize=(8,8))

# Plot the T maze position
pos = RecPos(file)
print(pos.pos_file)
x, y = pos.get_position()
start = pos.get_tmaze_start()
wview = pos.get_window_view()
x_max = int(wview['window_max_x'])
y_max = int(wview['window_max_y'])
# The first half of the position samples is drawn as the forced run and the
# second half as the chosen run.
half_pos = len(x) // 2

ax0 = fig.add_subplot(2, 2, 1)
ax0.plot(x, y, c='black', linewidth=2, alpha=.3)
ax0.scatter(x[0:half_pos], y[0:half_pos], c="black", marker='.')
ax0.set_xlim(0, x_max)
ax0.set_ylim(0, y_max)
ax0.set_xlabel('X pixels')
ax0.set_ylabel('Y pixels')
ax0.set_title(f'Forced path started on {start}')

ax1 = fig.add_subplot(2, 2, 2)
ax1.scatter(x[half_pos:], y[half_pos:], c="black", marker='.')
ax1.plot(x, y, c='black', linewidth=2, alpha=.3)
ax1.set_xlim(0, x_max)
ax1.set_ylim(0, y_max)
ax1.set_xlabel('X pixels')
ax1.set_ylabel('Y pixels')
ax1.set_title(f'Choosen path started on {start}')

# Load the EEG data
base = file[:-3]  # drop the last three characters of the path (the extension) but keep the dot
ch0_f = load_lfp_Axona(base + 'eeg')
ch1_f = load_lfp_Axona(base + 'eeg2')
ch0_c = load_lfp_Axona(base + 'eeg3')
ch1_c = load_lfp_Axona(base + 'eeg4')
half = len(ch0_f) // 2
# NOTE(review): the two "choice" rows are the second half of the .eeg3/.eeg4
# files, not of .eeg/.eeg2 -- confirm this matches the 'ch0'/'ch1' tick labels.
# The upper bound 2*half keeps all four traces the same length when the
# sample count is odd; otherwise np.array() would receive ragged rows.
eeg = [ch0_f[0:half], ch1_f[0:half], ch0_c[half:2 * half], ch1_c[half:2 * half]]

data = np.array(eeg)
data = data.T
n_samples = half  # length of each trace; no need to reload the file
n_rows = len(eeg)
# NOTE(review): the time axis is stretched to 0-10 regardless of the real
# duration, so the 'Time (s)' label is only right if each half lasts 10 s.
t = 10 * np.arange(n_samples) / n_samples

# Plot the EEG
ax2 = fig.add_subplot(2, 1, 2)
ax2.set_xlim(0, 10)
ax2.set_xticks(np.arange(10))
dmin = data.min()
dmax = data.max()
dr = (dmax - dmin) * 2  # vertical spacing between neighbouring traces
y0 = dmin
y1 = (n_rows - 1) * dr + dmax
ax2.set_ylim(y0, y1)
ticklocs = [i * dr for i in range(n_rows)]
# Add each trace's vertical offset directly to its samples instead of passing
# `offsets=`/`transOffset=` to LineCollection: that keyword was renamed in
# newer Matplotlib releases, while plain segments work on every version.
segs = [np.column_stack((t, data[:, i] + ticklocs[i])) for i in range(n_rows)]
lines = LineCollection(segs)
ax2.add_collection(lines)

# Put one tick at the baseline of each trace
ax2.set_yticks(ticklocs)
ax2.set_yticklabels(['ch0 forced', 'ch1 forced', 'ch0 choice', 'ch1 choice'])
ax2.set_xlabel('Time (s)')
plt.tight_layout()
plt.savefig('Tmaze3.png')
plt.show()

Try plotting a bounding box

In [None]:
def points_in_circle_np(radius, x0=0, y0=0):
    """
    Iterate over the integer grid points lying inside or on a circle.

    Parameters
    ----------
    radius : int
        Circle radius.
    x0, y0 : int, optional
        Circle centre. Default (0, 0).

    Yields
    ------
    tuple
        ``(x, y)`` for every grid point with
        ``(x - x0)**2 + (y - y0)**2 <= radius**2``, ordered by x, then y.
    """
    # Candidate coordinates covering the bounding square of the circle.
    x_ = np.arange(x0 - radius - 1, x0 + radius + 1, dtype=int)
    y_ = np.arange(y0 - radius - 1, y0 + radius + 1, dtype=int)
    # Broadcast to a 2-D grid and keep the indices of the points in the circle.
    ix, iy = np.where((x_[:, np.newaxis] - x0)**2 + (y_ - y0)**2 <= radius**2)
#     ix, iy = np.where((np.hypot((x_-x0)[:,np.newaxis], y_-y0)<= radius)) # alternative implementation
    # `yield`, not `return`: a return inside the loop handed back only the
    # first point and silently dropped all the others.
    for px, py in zip(x_[ix], y_[iy]):
        yield px, py

In [None]:
# Plot the track of two randomly chosen recordings with a red marker at the
# median position of the whole track.
for n in range(0, 2):
    # randrange excludes the upper bound; randint(0, len(...)) could return
    # len(tmaze_files) and raise IndexError.
    i = random.randrange(len(tmaze_files))
    file = tmaze_files[i]
    pos = RecPos(file)
    x, y = pos.get_position()
    # Remove NaNs with one shared mask so x[k] and y[k] stay paired.
    x = np.asarray(x)
    y = np.asarray(y)
    valid = ~(np.isnan(x) | np.isnan(y))
    x = x[valid]
    y = y[valid]
    start = pos.get_tmaze_start()
    # Read the window size of THIS recording; the cell previously reused
    # whatever `wview` an earlier cell had left behind.
    wview = pos.get_window_view()
    fig, ax = plt.subplots(figsize=(6, 6))
    ax.plot(x, y, c='black', linewidth=1, alpha=.7)
    ax.set_xlim(0, int(wview['window_max_x']))
    ax.set_ylim(0, int(wview['window_max_y']))
    x_avg = np.median(x)
    y_avg = np.median(y)
    # Marker area `s` is in points**2, not in data units.
    ax.scatter(x_avg, y_avg, c="red", s=3000)
    ax.set_xlabel('X pixels')
    ax.set_ylabel('Y pixels')
    ax.set_title(f'{start}')
    fig.tight_layout()
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across iterations