In [1]:
# To start mysql docker: sudo docker-compose up -d

import datajoint as dj
import os
import sys
import json
import re
import warnings
import traceback
import glob
from scipy.io import loadmat

In [2]:
class CorruptData(Exception):
    """Exception for source data that cannot be interpreted (not raised in the visible notebook)."""


class InconsistentData(Warning):
    """Warning category used by WarningTracker for recoverable data inconsistencies."""

In [3]:
# Repository root: two directories above the kernel's starting directory.
# IPython keeps its directory history in ``_dh`` (first entry = start dir);
# fall back to the current working directory when ``_dh`` is unavailable
# (e.g. when this code runs outside IPython).
root_path = os.path.abspath(os.path.join((globals().get('_dh') or [os.getcwd()])[0], '../..'))

In [4]:
# The credentials file is parsed as JSON, despite its .ini extension.
local_cred_filename = os.path.join(root_path, 'datajoint/local_cred.ini')
with open(local_cred_filename, 'r') as f:
    local_cred = json.loads(f.read())

In [5]:
cred_filename = "~/"  # NOTE(review): not referenced anywhere in the visible notebook

# Point DataJoint's connection settings at the locally stored credentials.
for _config_key, _cred_key in (('database.host', 'host'),
                               ('database.user', 'user'),
                               ('database.password', 'password')):
    dj.config[_config_key] = local_cred[_cred_key]

In [6]:
# Cancel any transaction still open on DataJoint's connection
# (presumably left behind by an interrupted populate).
_db_connection = dj.conn()
_db_connection.cancel_transaction()

In [9]:
# Bind the 'franklab_nspike' database schema. locals() is handed to DataJoint as the
# context for resolving ``-> Table`` references in the table definitions below.
schema = dj.schema('franklab_nspike', locals())

In [56]:
# Maintenance only: wipe and recreate the whole 'franklab_nspike' schema.
# Guarded so "Restart & Run All" does not destroy the database; set the flag to
# True and run this cell by hand when a full reset is really wanted.
RESET_SCHEMA = False
schema = dj.schema('franklab_nspike', locals())
if RESET_SCHEMA:
    schema.drop()
    schema = dj.schema('franklab_nspike', locals())

In [59]:
# Maintenance only: drop the Animal table (DataJoint drops dependent tables too).
# Guarded so a fresh "Run All" neither drops data nor hits a NameError; the Animal
# table is defined in a later cell, so only enable this when running interactively.
DROP_ANIMAL_TABLE = False
if DROP_ANIMAL_TABLE:
    Animal().drop()

In [30]:
# Maintenance only: drop RippleDetectionConfig and its dependents.
# Guarded because the class is defined in a later cell (NameError on a fresh
# kernel) and because the drop is destructive.
DROP_RIPPLE_DETECTION_CONFIG = False
if DROP_RIPPLE_DETECTION_CONFIG:
    RippleDetectionConfig().drop()

In [58]:

class WarningTracker:
    """Collect data-consistency problems found while populating tables.

    Every problem is appended to one of three lists (``missing_entries``,
    ``duplicate_entries``, ``duplicate_nonmatch_entries``) and, unless
    ``suppress_print`` is true, also emitted as an ``InconsistentData`` warning.
    """

    def __init__(self, suppress_print=False):
        self.suppress_print = suppress_print
        self.missing_entries = []
        self.duplicate_entries = []
        self.duplicate_nonmatch_entries = []

    def _emit(self, template, *args):
        """Format ``template`` with ``args`` and warn, unless printing is suppressed."""
        if not self.suppress_print:
            # {!s} instead of {:s}: fields/values may be lists or numpy scalars.
            warnings.warn(template.format(*args), InconsistentData)

    def missing_field(self, key, field):
        """Record that ``field`` could not be found for ``key``."""
        self._emit('Missing {!s}: {!s}', key, field)
        self.missing_entries.append((key, field))

    def duplicate(self, key, field, using_first=True):
        """Record several candidates for ``field`` of ``key``.

        ``using_first`` only changes the message, noting that the first
        candidate was kept.
        """
        first_str = ' using first' if using_first else ''
        self._emit('Duplicate entry for {!s}: {!s}{:s}.', key, field, first_str)
        self.duplicate_entries.append((key, field))

    def duplicate_nonmatch(self, key, field, value1, value2):
        """Record conflicting values ``value1`` / ``value2`` for ``field`` of ``key``."""
        self._emit('Duplicate entry for {!s}: {!s}, values are({!s}, {!s})', key, field, value1, value2)
        self.duplicate_nonmatch_entries.append((key, field, value1, value2))

    def __str__(self):
        return ('missing_entries:\n' + '\n'.join(map(str, self.missing_entries)) + '\n' +
                'duplicate_entries:\n' + '\n'.join(map(str, self.duplicate_entries)) + '\n' +
                'duplicate_nonmatch_entries:\n' + '\n'.join(map(str, self.duplicate_nonmatch_entries)))
    

@schema
class Animal(dj.Manual):
    """Manually entered animals and the directories holding their data.

    ``anim_path_raw`` is scanned by Day.make for per-day folders;
    ``anim_path_mat`` holds the processed .mat files read by the other tables.
    ``anim_name_short`` is the file-name prefix (e.g. 'bon').
    """
    definition = """
    anim_name: varchar(20)  #Name of animal
    ---
    anim_name_short: varchar(10)
    anim_path_raw: varchar(200)
    anim_path_mat: varchar(200)
    """
    
    def __init__(self, suppress_print=False, arg=None):
        # suppress_print silences self.warn's InconsistentData warnings; arg is
        # forwarded to the DataJoint base constructor.
        # NOTE(review): if the installed DataJoint copy-constructs tables as
        # ``cls(other)`` (e.g. inside ``&``), that positional argument lands in
        # suppress_print instead of arg — confirm against the DataJoint version.
        self.warn = WarningTracker(suppress_print=suppress_print)
        super().__init__(arg)


@schema
class Day(dj.Imported):
    definition = """
    -> Animal
    day: int
    ---
    day_path_raw: varchar(200)
    day_path_mat: varchar(200)
    day_start_time_sec: float
    day_end_time_sec: float
    day_start_time_nspike: int
    day_end_time_nspike: int
    """
    def __init__(self, suppress_print=False, arg=None):
        self.warn = WarningTracker(suppress_print=suppress_print)
        super().__init__(arg)

    def make(self, key):
        """Insert one row per recording-day folder of the animal in ``key``.

        Sub-directories of ``anim_path_raw`` named ``<lowercased anim_name><digits>``
        are treated as days, and the digits give the day number. The day's
        start/end come from the first row of ``ranges`` in that folder's
        ``times.mat``; seconds = nspike value / 10000. A day folder without
        ``times.mat`` is reported through ``self.warn`` and skipped.
        """
        anim_name, anim_path_raw, anim_path_mat = (Animal() & key).fetch1(
            'anim_name', 'anim_path_raw', 'anim_path_mat')
        # \d+ (not \d*): a folder named exactly like the animal must not match
        # with an empty day number, which would crash int('').
        day_dir_re = re.compile(r'^{:s}(\d+)$'.format(re.escape(anim_name.lower())))
        for dir_name in sorted(os.listdir(anim_path_raw)):
            m = day_dir_re.search(dir_name)
            if not m:
                continue
            day = int(m.group(1))
            day_path_raw = os.path.join(anim_path_raw, dir_name)
            times_path = os.path.join(day_path_raw, 'times.mat')
            if not os.path.isfile(times_path):
                # Missing times.mat means data folder was not processed for spike sorting (matclust)
                self.warn.missing_field(key, times_path)
                continue
            time_ranges = loadmat(times_path)['ranges']
            # Row 0 of 'ranges' spans the whole day; the following rows are epochs.
            day_start_time_nspike = int(time_ranges[0][0])
            day_end_time_nspike = int(time_ranges[0][1])
            self.insert1({'anim_name': anim_name,
                          'day': day,
                          'day_path_raw': day_path_raw,
                          # Processed .mat files live in the animal-level directory.
                          'day_path_mat': anim_path_mat,
                          'day_start_time_sec': day_start_time_nspike / 10000,
                          'day_end_time_sec': day_end_time_nspike / 10000,
                          'day_start_time_nspike': day_start_time_nspike,
                          'day_end_time_nspike': day_end_time_nspike})

                
@schema
class Epoch(dj.Imported):
    definition = """
    -> Day
    epoch_id: tinyint
    ---
    epoch_name: varchar(50)
    epoch_time_str: varchar(50)
    epoch_start_time_sec: float
    epoch_end_time_sec: float
    epoch_start_time_nspike: int
    epoch_end_time_nspike: int
    """
    
    def __init__(self, suppress_print=False, arg=None):
        self.warn = WarningTracker(suppress_print=suppress_print)
        super().__init__(arg)
        
    def make(self, key):
        """Insert one row per epoch listed in the day's ``times.mat``.

        Row 0 of ``ranges``/``names`` covers the whole day, so epochs come from
        rows 1..n and ``epoch_id`` is 0-based. Each name is matched as optional
        leading digits, a word (the epoch name), then a time string made of
        digits, ':', '-' or '_'. Names that do not match are reported through
        ``self.warn`` and skipped. A missing ``times.mat`` is reported and
        nothing is inserted.
        """
        anim_name, day, day_path_raw = (Animal() * (Day() & key)).fetch1('anim_name', 'day', 'day_path_raw')
        # Built before loading so the missing-file warning can name the path.
        times_path = os.path.join(day_path_raw, 'times.mat')
        try:
            times_mat = loadmat(times_path)
        except FileNotFoundError:
            self.warn.missing_field(key, times_path)
            return

        time_ranges = times_mat['ranges']
        names = times_mat['names']
        for epoch_id, epoch_time_range in enumerate(time_ranges[1:]):
            epoch_start_time_nspike = int(epoch_time_range[0])
            epoch_end_time_nspike = int(epoch_time_range[1])
            name_entry = names[epoch_id + 1][0][0]
            name_re = re.search(r'^\d*\s*(\w*)\s*([0-9:\-_]*)$', name_entry)
            if not name_re:
                self.warn.missing_field(dict(key, epoch_id=epoch_id),
                                        'epoch name {!s}'.format(name_entry))
                continue
            epoch_name, epoch_time_str = name_re.groups()
            self.insert1({'anim_name': anim_name,
                          'day': day,
                          'epoch_id': epoch_id,
                          'epoch_name': epoch_name,
                          'epoch_time_str': epoch_time_str,
                          'epoch_start_time_sec': epoch_start_time_nspike / 10000,
                          'epoch_end_time_sec': epoch_end_time_nspike / 10000,
                          'epoch_start_time_nspike': epoch_start_time_nspike,
                          'epoch_end_time_nspike': epoch_end_time_nspike
                          })
    
@schema
class Tetrode(dj.Imported):
    definition = """
    -> Animal
    tet_id: tinyint
    ---
    tet_hemisphere: varchar(50)
    """
    
    def __init__(self, suppress_print=False, arg=None):
        self.warn = WarningTracker(suppress_print=suppress_print)
        super().__init__(arg)
        
    def make(self, key):
        """Insert one row per tetrode found in the animal's ``<prefix>tetinfo.mat``.

        The hemisphere is collected from every non-empty (day, epoch) entry of
        ``tetinfo``, and a tetrode is inserted only when all entries agree.
        A missing hemisphere field is stored as ''. Disagreements are reported
        through ``self.warn`` and that tetrode is skipped.
        """
        anim_name, anim_name_short, anim_path_mat = (Animal() & key).fetch1(
            'anim_name', 'anim_name_short', 'anim_path_mat')
        mat = loadmat(os.path.join(anim_path_mat, '{:s}tetinfo.mat'.format(anim_name_short)))

        hemispheres = {}  # tet_id -> hemisphere value seen in each (day, epoch) entry
        for tet_days in mat['tetinfo'][0]:
            if len(tet_days[0]) == 0:
                continue
            for tet_epochs in tet_days[0]:
                for tet_id, tet_epoch in enumerate(tet_epochs[0]):
                    if len(tet_epoch[0]) == 0:
                        continue
                    try:
                        tet_hemi = tet_epoch[0][0]['hemisphere'][0]
                    except ValueError:
                        # The struct has no 'hemisphere' field for this entry.
                        tet_hemi = None
                    hemispheres.setdefault(tet_id, []).append(tet_hemi)

        for tet_id, tet_hemi in hemispheres.items():
            distinct = list(set(tet_hemi))
            if len(distinct) == 1:
                tet_hemisphere = '' if distinct[0] is None else distinct[0]
                self.insert1({'anim_name': anim_name,
                              'tet_id': tet_id,
                              'tet_hemisphere': tet_hemisphere
                              })
            else:
                self.warn.duplicate_nonmatch(dict(key, tet_id=tet_id), 'hemisphere',
                                             str(distinct[0]), str(distinct[1:]))


@schema
class TetrodeEpoch(dj.Computed):
    definition = """
    -> Tetrode
    -> Epoch
    ---
    tet_depth = NULL: int
    tet_num_cells = NULL: int
    tet_area = NULL: varchar(50)
    tet_subarea = NULL: varchar(50)
    tet_near_ca2 = NULL: tinyint       # boolean
    """
    
    def __init__(self, suppress_print=False, arg=None):
        self.warn = WarningTracker(suppress_print=suppress_print)
        super().__init__(arg)
    
    def make(self, key):
        """Insert the tetinfo attributes of one tetrode during one epoch.

        ``tetinfo`` is indexed as [day - 1][epoch_id][tet_id]. Each attribute
        missing from the struct is reported through ``self.warn`` and left NULL.
        No row is inserted when the tetinfo cell is empty or the indices fall
        outside the data.
        """
        anim_name, anim_name_short, anim_path_mat = (Animal() & key).fetch1(
            'anim_name', 'anim_name_short', 'anim_path_mat')

        # Parse each animal's tetinfo file once; later make() calls on this
        # instance reuse the cached result.
        tet_infos = getattr(self, 'anim_tet_infos', None)
        if tet_infos is None:
            tet_infos = self.anim_tet_infos = {}
        if anim_name not in tet_infos:
            tet_infos[anim_name] = loadmat(
                os.path.join(anim_path_mat, '{:s}tetinfo.mat'.format(anim_name_short)))
        mat = tet_infos[anim_name]

        try:
            tet_epoch_data = mat['tetinfo'][0][key['day'] - 1][0][key['epoch_id']][0][key['tet_id']][0]
        except (ValueError, IndexError):
            self.warn.missing_field(key, 'entire tetrode')
            return

        # An empty mat cell means nothing was recorded: skip the insert.
        if tet_epoch_data.size == 0:
            return

        # (attribute, accessor) pairs describing where each value sits in the struct.
        attribute_getters = (
            ('tet_depth', lambda d: d['depth'][0][0][0][0][0]),
            ('tet_num_cells', lambda d: d['numcells'][0][0][0]),
            ('tet_area', lambda d: d['area'][0][0]),
            ('tet_subarea', lambda d: d['subarea'][0][0]),
            ('tet_near_ca2', lambda d: d['nearCA2'][0][0][0]),
        )
        entry = dict(key)  # copy: do not mutate the caller's key
        for attribute, getter in attribute_getters:
            try:
                entry[attribute] = getter(tet_epoch_data)
            except (ValueError, IndexError):
                # Leave the attribute out so it is stored as NULL.
                self.warn.missing_field(key, attribute)
        self.insert1(entry)

@schema
class LFP(dj.Imported):
    definition = """
    -> TetrodeEpoch
    ---
    lfp_filepath_eeg_mat = NULL: varchar(200)
    """
    
    def __init__(self, suppress_print=False, arg=None):
        self.warn = WarningTracker(suppress_print=suppress_print)
        super().__init__(arg)

    def make(self, key):
        """Insert the path of this tetrode/epoch's ``eeg`` .mat file, or NULL if absent.

        The expected file is
        ``<day_path_mat>/EEG/<anim_name_short>eeg<DD>-<E>-<TT>.mat``, where DD
        is the day and E, TT are the 1-based epoch and tetrode numbers. A
        missing file is reported through ``self.warn``.
        """
        anim_name_short, day_path_mat = (Animal() * (Day() & key)).fetch1('anim_name_short', 'day_path_mat')
        # The name has no wildcards, so check it directly instead of globbing.
        eeg_mat_path = os.path.join(day_path_mat, 'EEG/{:s}eeg{:02d}-{:d}-{:02d}.mat'.
                                    format(anim_name_short, key['day'], key['epoch_id'] + 1, key['tet_id'] + 1))
        entry = dict(key)  # copy: do not mutate the caller's key
        if os.path.isfile(eeg_mat_path):
            entry['lfp_filepath_eeg_mat'] = eeg_mat_path
        else:
            self.warn.missing_field(key, 'eeg mat')
        self.insert1(entry)


@schema
class LFPRaw(dj.Imported):
    definition = """
    -> TetrodeEpoch
    ---
    lfp_filepath_raw = NULL: varchar(200)
    lfp_tet_depth = NULL: int
    """
    
    def __init__(self, suppress_print=False, arg=None):
        self.warn = WarningTracker(suppress_print=suppress_print)
        super().__init__(arg)

    def make(self, key):
        """Insert the raw ``<TT>-<N>.eeg`` file of this tetrode and the number N as depth.

        Files are searched in the day's raw folder by the 1-based tetrode number
        only, so every epoch of a day gets the same file. A missing file (or an
        unparseable depth) is reported through ``self.warn`` and left NULL. If
        several files match, the first in sorted order is used and the rest are
        reported.
        """
        day_path_raw = (Day() & key).fetch1('day_path_raw')
        tet_id = key['tet_id']
        # Format only the file pattern so braces in day_path_raw cannot break it.
        candidates = sorted(glob.glob(os.path.join(day_path_raw, '{:02d}-*.eeg'.format(tet_id + 1))))
        entry = dict(key)  # copy: do not mutate the caller's key
        if not candidates:
            self.warn.missing_field(key, 'raw eeg')
            self.insert1(entry)
            return
        if len(candidates) > 1:
            self.warn.duplicate_nonmatch(key, 'lfp_filepath_raw', candidates[0], str(candidates[1:]))

        lfp_filepath_raw = candidates[0]
        entry['lfp_filepath_raw'] = lfp_filepath_raw
        re_match = re.search(r'\d*-(\d+)\.eeg$', os.path.basename(lfp_filepath_raw))
        if re_match:
            entry['lfp_tet_depth'] = int(re_match.group(1))
        else:
            self.warn.missing_field(key, 'lfp_tet_depth')
        self.insert1(entry)
     
    
@schema
class LFPGnd(dj.Imported):
    definition = """
    -> TetrodeEpoch
    ---
    lfp_filepath_eeggnd_mat = NULL: varchar(200)
    """
    
    def __init__(self, suppress_print=False, arg=None):
        self.warn = WarningTracker(suppress_print=suppress_print)
        super().__init__(arg)

    def make(self, key):
        """Insert the path of this tetrode/epoch's ``eeggnd`` .mat file, or NULL if absent.

        The expected file is
        ``<day_path_mat>/EEG/<anim_name_short>eeggnd<DD>-<E>-<TT>.mat``, where
        DD is the day and E, TT are the 1-based epoch and tetrode numbers. A
        missing file is reported through ``self.warn``.
        """
        anim_name_short, day_path_mat = (Animal() * (Day() & key)).fetch1('anim_name_short', 'day_path_mat')
        # The name has no wildcards, so check it directly instead of globbing.
        eeggnd_mat_path = os.path.join(day_path_mat, 'EEG/{:s}eeggnd{:02d}-{:d}-{:02d}.mat'.
                                       format(anim_name_short, key['day'], key['epoch_id'] + 1, key['tet_id'] + 1))
        entry = dict(key)  # copy: do not mutate the caller's key
        if os.path.isfile(eeggnd_mat_path):
            entry['lfp_filepath_eeggnd_mat'] = eeggnd_mat_path
        else:
            self.warn.missing_field(key, 'eeggnd')
        self.insert1(entry)


@schema
class RippleDetectionConfig(dj.Lookup):
    """Ripple-detection parameter sets; all four attributes form the primary key.

    RippleInterval.make uses ``rip_alg`` to choose the ``<prefix>ripples<rip_alg><DD>.mat``
    file and the ``ripples<rip_alg>`` variable inside it. The other three
    parameters are not read anywhere in the visible notebook.
    """
    definition = """
    rip_alg : varchar(20)
    rip_detect_thresh : decimal(5,2)
    rip_min_thresh_dur : decimal(6,4)
    rip_tet_filter : varchar(100)
    ---

    """
    
    # Single default row: 'cons' algorithm, threshold 2.0, minimum duration 0.0300
    # (units not stated here), and a tetrode-filter expression stored as a literal string.
    contents = [['cons', 2.0, 0.0300, "'(isequal($validripple, 1))'"]]


@schema
class RippleInterval(dj.Imported):
    definition = """
    -> Epoch
    -> RippleDetectionConfig
    ---
    part_path: varchar(200)
    path_abs: varchar(200)
    """

    def __init__(self, suppress_print=False, arg=None):
        self.warn = WarningTracker(suppress_print=suppress_print)
        super().__init__(arg)

    class LFPSource(dj.Part):
        definition = """
        -> LFP
        -> RippleInterval
        ---
        
        """
    
    def make(self, key):
        """Locate and load the ripple detections for one epoch (work in progress).

        Reads ``<day_path_mat>/<anim_name_short>ripples<rip_alg><DD>.mat`` and
        takes the ``ripples<rip_alg>`` entry for ``key['day']`` / ``key['epoch_id']``.
        A missing file is reported through ``self.warn`` and the key is skipped.

        TODO: no rows are inserted yet. The mapping from the loaded epoch data
        to ``part_path`` / ``path_abs`` and to ``LFPSource`` rows is still
        unwritten.
        """
        anim_name_short, day_path_mat = (Animal() * (Day() & key)).fetch1('anim_name_short', 'day_path_mat')
        rip_alg = key['rip_alg']
        day = key['day']
        # The name has no wildcards, so check it directly instead of globbing.
        rip_mat_path = os.path.join(day_path_mat, '{:s}ripples{:s}{:02d}.mat'.
                                    format(anim_name_short, rip_alg, day))
        if not os.path.isfile(rip_mat_path):
            self.warn.missing_field(key, rip_mat_path)
            return

        # Cache loaded files on this instance. The cache key includes the animal
        # and algorithm, not only the day, so different animals or configs never
        # share an entry.
        cache = getattr(self, 'day_rip_mats', None)
        if cache is None:
            cache = self.day_rip_mats = {}
        cache_key = (key['anim_name'], rip_alg, day)
        if cache_key not in cache:
            cache[cache_key] = loadmat(rip_mat_path)['ripples{:s}'.format(rip_alg)][0]
        rip_mat = cache[cache_key]

        # Indexed by 1-based day, then 0-based epoch_id.
        rip_epoch = rip_mat[day - 1][0][key['epoch_id']][0]  # noqa: F841 (input for the TODO above)

@schema
class RawSpikes(dj.Imported):
    """Per tetrode/epoch path to raw spike data (declaration only).

    NOTE(review): no make() is defined, so populate() has nothing to run yet.
    """
    definition = """
    -> TetrodeEpoch
    ---
    raw_spike_path: varchar(200)
    """

@schema
class Position(dj.Imported):
    """Per-epoch path to position data (declaration only).

    NOTE(review): no make() is defined, so populate() has nothing to run yet.
    """
    definition = """
    -> Epoch
    ---
    pos_path: varchar(200)
    """

@schema
class LinearPosition(dj.Imported):
    """Path to linearized position data derived from a Position row (declaration only).

    NOTE(review): no make() is defined, so populate() has nothing to run yet.
    """
    definition = """
    -> Position
    ---
    lin_pos_path: varchar(200)
    """


In [59]:
%%time
# Register the animal and populate the session-level tables derived from it.
anim = Animal(suppress_print=True)
# skip_duplicates keeps this cell re-runnable once Bond is already registered.
anim.insert1({'anim_name': 'Bond', 'anim_name_short': 'bon', 'anim_path_raw': '/opt/data/daliu/other/mkarlsso/bond/', 
              'anim_path_mat': '/opt/data/daliu/other/mkarlsso/Bon/'},
             skip_duplicates=True)
display(anim)
day = Day(suppress_print=True)
day.populate()
display(day)
epoch = Epoch(suppress_print=True)
epoch.populate()
display(epoch)
tet = Tetrode(suppress_print=True)
tet.populate()
display(tet)
tet_ep = TetrodeEpoch(suppress_print=True)
tet_ep.populate(reserve_jobs=True)
display(tet_ep)


In [None]:
# Maintenance only: drop the three LFP tables so they can be repopulated.
# The classes are used instead of the lfp/lfp_raw/lfp_gnd instances, which are
# created in a later cell. The flag stops "Run All" from dropping data.
DROP_LFP_TABLES = False
if DROP_LFP_TABLES:
    for lfp_table in (LFP(), LFPRaw(), LFPGnd()):
        lfp_table.drop()

In [60]:
%%time

# Build the three LFP file-index tables, then populate and show each in turn.
lfp = LFP(suppress_print=True)
lfp_raw = LFPRaw(suppress_print=True)
lfp_gnd = LFPGnd(suppress_print=True)
for _lfp_table in (lfp, lfp_raw, lfp_gnd):
    _lfp_table.populate()
    display(_lfp_table)

In [61]:
# Handles for the ripple tables used in the cells below.
rip_config = RippleDetectionConfig()
rip = RippleInterval()


In [None]:
# Maintenance only: drop RippleDetectionConfig (DataJoint also drops dependent
# tables such as RippleInterval). The flag stops "Run All" from dropping data.
DROP_RIP_CONFIG = False
if DROP_RIP_CONFIG:
    rip_config.drop()

In [62]:
# Run RippleInterval.make for every (Epoch, RippleDetectionConfig) key not yet populated.
rip.populate()

In [47]:
# Display the ripple-detection parameter table.
rip_config

In [21]:
# Display the RippleInterval.LFPSource part table.
rip.LFPSource()

In [13]:
# Draw the table-dependency diagram of the schema.
# NOTE(review): newer DataJoint releases name this dj.Diagram (ERD kept as an alias) — confirm for the installed version.
dj.ERD(schema)