In [143]:
# To start mysql docker: sudo docker-compose up -d

import datajoint as dj
import os
import sys
import json
import re
import warnings
import traceback
from scipy.io import loadmat

In [93]:
class CorruptData(Exception):
    """Raised when a data file that a recording requires is missing."""


class InconsistentData(Warning):
    """Warning category for source data that is inconsistent but not fatal."""

In [94]:
# IPython keeps its directory history in `_dh`; entry 0 is the directory the kernel started in.
notebook_dir = globals()['_dh'][0]
# Project root: two directories above the notebook's starting directory.
root_path = os.path.abspath(os.path.join(notebook_dir, '../..'))

In [95]:
# Database credentials live under <root>/datajoint/; despite the .ini suffix the file is parsed as JSON.
local_cred_filename = os.path.join(root_path, 'datajoint/local_cred.ini')
with open(local_cred_filename) as cred_file:
    local_cred = json.load(cred_file)

In [96]:
# Point DataJoint at the database described by the local credentials file.
# (A dead `cred_filename = "~/"` assignment that nothing read was removed.)
dj.config['database.host'] = local_cred['host']
dj.config['database.user'] = local_cred['user']
dj.config['database.password'] = local_cred['password']

In [165]:
# Cancel any transaction left open on DataJoint's shared connection.
# NOTE(review): presumably a leftover from an interrupted populate(); likely
# unnecessary on a fresh kernel — confirm before keeping it in the run order.
dj.conn().cancel_transaction()

In [98]:
# Bind table classes to the 'franklab_nspike' database. The notebook namespace is
# passed as the context so that `-> Table` references resolve by class name.
schema = dj.schema('franklab_nspike', context=locals())

In [134]:
# WARNING: destructive. This drops the entire 'franklab_nspike' schema (all
# tables and their data) and then re-creates the schema object. Run it only
# to start over from scratch; never as part of a routine "Run All".
schema.drop()
schema = dj.schema('franklab_nspike', locals())

In [185]:
# `anim` is only created by the population cell further down, so on a fresh
# kernel (Restart & Run All) the bare `anim.drop()` raised NameError.
# Guard it; when it does run, DataJoint's drop() also drops tables depending on Animal.
if 'anim' in globals():
    anim.drop()

In [186]:

@schema
class Animal(dj.Manual):
    """One experimental animal and the locations of its data.

    ``anim_path_raw`` is the directory that ``Day.make`` scans for per-day
    raw-data folders. ``anim_path_mat`` is the directory that ``Tetrode`` and
    ``TetrodeEpoch`` read ``bontetinfo.mat`` from.
    """
    definition = """
    anim_name: varchar(20)  #Name of animal
    ---
    anim_path_raw: varchar(200)
    anim_path_mat: varchar(200)
    """

@schema
class Day(dj.Imported):
    """Recording days of an animal, one per processed raw-data sub-directory.

    Times are stored both as raw NSpike timestamps (``*_nspike``) and in
    seconds (``*_sec`` = nspike / 10000).
    """
    definition = """
    -> Animal
    day: int
    ---
    day_path_raw: varchar(200)
    day_start_time_sec: float
    day_end_time_sec: float
    day_start_time_nspike: int
    day_end_time_nspike: int
    """

    def make(self, key):
        """Insert one ``Day`` row per processed day directory of the animal in ``key``.

        Parameters
        ----------
        key : dict
            Primary key of an ``Animal`` row (``anim_name``).

        Notes
        -----
        Sub-directories of ``anim_path_raw`` named ``<lower-cased animal name><digits>``
        are treated as days; the digits give the ``day`` number. A directory without a
        ``times.mat`` file is skipped.
        """
        anim_name, anim_path_raw = (Animal() & key).fetch1('anim_name', 'anim_path_raw')
        # `\d+` (not `\d*`): a directory named exactly after the animal has no day number,
        # and the previous pattern matched it and then crashed on int('').
        day_dir_re = re.compile(r'^{:s}(\d+)$'.format(re.escape(anim_name.lower())))
        for dir_name in os.listdir(anim_path_raw):
            match = day_dir_re.search(dir_name)
            if not match:
                continue
            day_path_raw = os.path.join(anim_path_raw, dir_name)
            times_path = os.path.join(day_path_raw, 'times.mat')
            if not os.path.isfile(times_path):
                # Missing times.mat means data folder was not processed for spike sorting (matclust)
                continue
            # Row 0 of 'ranges' is the whole-day span; Epoch reads the per-epoch rows after it.
            time_ranges = loadmat(times_path)['ranges']
            day_start_time_nspike = time_ranges[0][0]
            day_end_time_nspike = time_ranges[0][1]
            # Previously the *_sec columns received the nspike values and vice versa.
            self.insert1({'anim_name': anim_name,
                          'day': int(match.group(1)),
                          'day_path_raw': day_path_raw,
                          'day_start_time_sec': day_start_time_nspike / 10000,
                          'day_end_time_sec': day_end_time_nspike / 10000,
                          'day_start_time_nspike': day_start_time_nspike,
                          'day_end_time_nspike': day_end_time_nspike})

@schema
class Epoch(dj.Imported):
    """Recording epochs within a day, read from the day's ``times.mat``.

    Row 0 of ``times.mat['ranges']`` spans the whole day (it is what ``Day``
    stores). Each following row is one epoch. ``epoch_id`` is the 0-based
    position among those rows, so epoch ``epoch_id`` is row ``epoch_id + 1``
    of both ``ranges`` and ``names``.
    """
    definition = """
    -> Day
    epoch_id: tinyint
    ---
    epoch_name: varchar(50)
    epoch_time_str: varchar(50)
    epoch_start_time_sec: float
    epoch_end_time_sec: float
    epoch_start_time_nspike: int
    epoch_end_time_nspike: int
    """

    def make(self, key):
        """Insert one ``Epoch`` row per epoch listed in the day's ``times.mat``.

        Parameters
        ----------
        key : dict
            Primary key of a ``Day`` row (``anim_name``, ``day``).

        Raises
        ------
        CorruptData
            If ``times.mat`` is missing from the day's raw-data directory.
        """
        anim_name, day, day_path_raw = (Animal() * (Day() & key)).fetch1('anim_name', 'day', 'day_path_raw')
        times_path = os.path.join(day_path_raw, 'times.mat')
        try:
            times_mat = loadmat(times_path)
        except FileNotFoundError as err:
            raise CorruptData('Missing {:s}.'.format(times_path)) from err
        time_ranges = times_mat['ranges']
        names = times_mat['names']
        # Optional leading number, then a word (epoch name), then a trailing run of
        # digits / ':' / '-' / '_' (epoch time string).
        name_pattern = re.compile(r'\d*\s*(\w*)\s*([0-9:\-_]*)$')
        for epoch_id, epoch_time_range in enumerate(time_ranges[1:]):
            epoch_start_time_nspike = epoch_time_range[0]
            epoch_end_time_nspike = epoch_time_range[1]
            name_re = name_pattern.search(names[epoch_id + 1][0][0])
            if not name_re:
                continue
            epoch_name, epoch_time_str = name_re.groups()
            self.insert1({'anim_name': anim_name,
                          'day': day,
                          'epoch_id': epoch_id,
                          'epoch_name': epoch_name,
                          'epoch_time_str': epoch_time_str,
                          'epoch_start_time_sec': epoch_start_time_nspike / 10000,
                          # Previously this stored the *start* time by mistake.
                          'epoch_end_time_sec': epoch_end_time_nspike / 10000,
                          'epoch_start_time_nspike': epoch_start_time_nspike,
                          'epoch_end_time_nspike': epoch_end_time_nspike
                          })
    
@schema
class Tetrode(dj.Imported):
    """Tetrodes of an animal, with the hemisphere recorded for each.

    Tetrode ids are 1-based: the tetrode at position ``i`` of an epoch's
    tetinfo array gets ``tet_id = i + 1``.
    """
    definition = """
    -> Animal
    tet_id: tinyint
    ---
    tet_hemisphere: varchar(50)
    """
    def make(self, key):
        """Insert one row per tetrode found in ``bontetinfo.mat``.

        The hemisphere value is collected from every non-empty day/epoch entry
        of a tetrode. A row is inserted only when all collected values agree;
        a missing ``hemisphere`` field counts as ``None`` and is stored as ``''``.
        Tetrodes whose entries disagree are skipped with a warning.

        NOTE(review): the file name ``bontetinfo.mat`` is hard-coded rather than
        derived from the animal, so this presumably only works for Bond.
        """
        anim_name, anim_path_mat = (Animal() & key).fetch1('anim_name', 'anim_path_mat')
        tetinfo = loadmat(os.path.join(anim_path_mat, 'bontetinfo.mat'))['tetinfo']

        # tet_id -> hemisphere values seen across all days and epochs
        hemispheres = {}
        for tet_days in tetinfo[0]:
            if len(tet_days[0]) == 0:
                continue
            for tet_epochs in tet_days[0]:
                for position, tet_epoch in enumerate(tet_epochs[0]):
                    if len(tet_epoch[0]) == 0:
                        continue
                    seen = hemispheres.setdefault(position + 1, [])
                    try:
                        seen.append(tet_epoch[0][0]['hemisphere'][0])
                    except ValueError:
                        # presumably: the struct has no 'hemisphere' field
                        seen.append(None)

        for tet_id, seen in hemispheres.items():
            distinct = set(seen)
            if len(distinct) != 1:
                warnings.warn("Tetrode {:d} doesn't have exactly 1 type for hemisphere entry: {:s}".format(tet_id, str(distinct)))
                continue
            (hemisphere,) = distinct
            self.insert1({'anim_name': anim_name,
                          'tet_id': tet_id,
                          'tet_hemisphere': '' if hemisphere is None else hemisphere
                          })


@schema
class TetrodeEpoch(dj.Computed):
    """Per-epoch metadata of each tetrode, read from the animal's ``bontetinfo.mat``.

    Values that cannot be read from the MATLAB struct are left out of the
    inserted row, so their columns fall back to the ``NULL`` defaults. Each is
    also recorded in ``self.missing_entries`` as an
    ``(anim_name, day, epoch_id, tet_id, field)`` tuple and, unless
    ``suppress_missing_print`` is set, printed.

    NOTE(review): the file name ``bontetinfo.mat`` is hard-coded, so this
    presumably only works for animal Bond.
    """
    definition = """
    -> Tetrode
    -> Epoch
    ---
    tet_depth = NULL: int
    tet_num_cells = NULL: int
    tet_area = NULL: varchar(50)
    tet_subarea = NULL: varchar(50)
    tet_near_ca2 = NULL: tinyint       # boolean
    """

    # (column, tetinfo struct field, number of [0] unwraps applied to the field value)
    _TETINFO_FIELDS = (
        ('tet_depth', 'depth', 5),
        ('tet_num_cells', 'numcells', 3),
        ('tet_area', 'area', 2),
        ('tet_subarea', 'subarea', 2),
        ('tet_near_ca2', 'nearCA2', 3),
    )

    def __init__(self, suppress_missing_print=False, arg=None):
        """
        Parameters
        ----------
        suppress_missing_print : bool
            If True, missing values are only recorded in ``missing_entries``, not printed.
        arg :
            Forwarded unchanged to the DataJoint base-class constructor.

        NOTE(review): if the installed DataJoint builds copies via ``cls(other_table)``,
        the table would bind to ``suppress_missing_print`` here. Confirm against the version in use.
        """
        self.suppress_missing_print = suppress_missing_print
        # anim_name -> loadmat() result, so each bontetinfo.mat is read once per instance
        self.anim_tet_infos = {}
        # (anim_name, day, epoch_id, tet_id, field) for every value that could not be read
        self.missing_entries = []
        super().__init__(arg)

    def _process_missing_field(self, field, key, suppress_missing_print):
        """Record (and unless suppressed, print) that ``field`` is missing for ``key``."""
        if not suppress_missing_print:
            print(self._form_missing_str(field, key))
        self._save_missing_entry(field, key)

    def _form_missing_str(self, field, key):
        """Return a one-line human-readable description of a missing field."""
        missing_str = ('Missing animal {:s}, day {:d}, epoch_id {:d}, tet_id {:d}: {:s}'.
                       format(key['anim_name'], key['day'], key['epoch_id'], key['tet_id'], field))

        return missing_str

    def _save_missing_entry(self, field, key):
        """Append ``(anim_name, day, epoch_id, tet_id, field)`` to ``self.missing_entries``."""
        entry = (key['anim_name'], key['day'], key['epoch_id'], key['tet_id'], field)
        try:
            self.missing_entries.append(entry)
        except AttributeError:
            # Attribute not set, e.g. the instance was created without this class's __init__.
            self.missing_entries = [entry]

    def _load_tet_info(self, anim_name, anim_path_mat):
        """Return the loaded ``bontetinfo.mat`` for an animal, reading it at most once per instance."""
        cache = getattr(self, 'anim_tet_infos', None)
        if cache is None:
            cache = self.anim_tet_infos = {}
        if anim_name not in cache:
            cache[anim_name] = loadmat(os.path.join(anim_path_mat, 'bontetinfo.mat'))
        return cache[anim_name]

    @staticmethod
    def _lookup_tet_epoch(tetinfo_mat, key):
        """Return the tetinfo struct array for ``key`` from a loaded ``bontetinfo.mat``.

        ``day`` and ``tet_id`` are 1-based (``Tetrode`` stores enumerate index + 1),
        while ``epoch_id`` is 0-based. Previously ``tet_id`` was used unshifted,
        which read the neighbouring tetrode's data.

        Raises
        ------
        IndexError
            For out-of-range indices, including non-positive ones that Python would
            otherwise silently wrap around to the end of the array.
        """
        day_idx = key['day'] - 1
        epoch_idx = key['epoch_id']
        tet_idx = key['tet_id'] - 1
        if min(day_idx, epoch_idx, tet_idx) < 0:
            raise IndexError('negative tetinfo index for key {!r}'.format(key))
        return tetinfo_mat['tetinfo'][0][day_idx][0][epoch_idx][0][tet_idx][0]

    @staticmethod
    def _unwrap_field(tet_epoch_data, mat_field, n_unwrap):
        """Read ``mat_field`` from a loadmat struct and strip ``n_unwrap`` levels of ``[0]`` nesting."""
        value = tet_epoch_data[mat_field]
        for _ in range(n_unwrap):
            value = value[0]
        return value

    def make(self, key):
        """Insert the tetinfo values for one (tetrode, epoch) key.

        Nothing is inserted when the entry is absent from ``bontetinfo.mat``
        (reported as 'entire tetrode') or when the MATLAB cell is empty.
        The caller's ``key`` dict is not modified.
        """
        suppress = getattr(self, 'suppress_missing_print', False)
        anim_name, anim_path_mat = (Animal() & key).fetch1('anim_name', 'anim_path_mat')
        tetinfo_mat = self._load_tet_info(anim_name, anim_path_mat)

        # Only the lookup is guarded, so errors from insert1 are no longer
        # misreported as a missing tetrode.
        try:
            tet_epoch_data = self._lookup_tet_epoch(tetinfo_mat, key)
        except (ValueError, IndexError):
            self._process_missing_field('entire tetrode', key, suppress)
            return

        # if mat cell is empty, skip insert
        if tet_epoch_data.size == 0:
            return

        row = dict(key)
        for column, mat_field, n_unwrap in self._TETINFO_FIELDS:
            try:
                row[column] = self._unwrap_field(tet_epoch_data, mat_field, n_unwrap)
            except (ValueError, IndexError):
                # Leave the column out of the row so it takes its NULL default.
                self._process_missing_field('"{:s}"'.format(column), key, suppress)
        self.insert1(row)

@schema
class LFP(dj.Imported):
    """LFP file locations for each tetrode-epoch.

    NOTE(review): ``make`` is currently a debugging placeholder. It prints the
    key and displays the matching Animal x Tetrode rows but inserts nothing,
    so ``populate()`` leaves this table empty.
    """
    definition = """
    -> TetrodeEpoch
    ---
    lfp_path_raw_part: varchar(200)
    lfp_path_raw_abs: varchar(200)
    lfp_path_eeg_part: varchar(200)
    lfp_path_eeg_abs: varchar(200)
    """

    def make(self, key):
        print(key)
        matching_tetrodes = (Animal() * Tetrode()) & key
        display(matching_tetrodes)
     
@schema
class RippleInterval(dj.Imported):
    """Ripple intervals of an epoch, keyed additionally by detection ``algorithm``.

    NOTE(review): ``make`` is a placeholder that only prints the key and
    inserts nothing, so ``populate()`` leaves this table empty.
    """
    definition = """
    -> Epoch
    algorithm: varchar(20)
    ---
    part_path: varchar(200)
    path_abs: varchar(200)
    """

    class LFPSource(dj.Part):
        """Links a ripple interval to LFP rows (presumably the LFPs it was detected from)."""
        # `-> master` is DataJoint's idiom for a Part table's reference to its master.
        # The column order (LFP key first) is unchanged. Without a `---` separator,
        # all attributes are primary, as they were before.
        definition = """
        -> LFP
        -> master
        """

    def make(self, key):
        print(key)

@schema
class RawSpikes(dj.Imported):
    """Path to the raw spike data for each tetrode-epoch.

    NOTE(review): no ``make`` is defined here, so this notebook declares the
    table but does not populate it yet.
    """
    definition = """
    -> TetrodeEpoch
    ---
    raw_spike_path: varchar(200)
    """

@schema
class Position(dj.Imported):
    """Path to the position data (``pos_path``) for each epoch.

    NOTE(review): no ``make`` is defined here, so this notebook declares the
    table but does not populate it yet.
    """
    definition = """
    -> Epoch
    ---
    pos_path: varchar(200)
    """

@schema
class LinearPosition(dj.Imported):
    """Path to the linearized position data (``lin_pos_path``) for each ``Position`` row.

    NOTE(review): no ``make`` is defined here, so this notebook declares the
    table but does not populate it yet.
    """
    definition = """
    -> Position
    ---
    lin_pos_path: varchar(200)
    """


In [162]:
# One handle per table for interactive inspection (instantiated left to right).
day, epoch, tet, tet_ep, lfp, rip = (
    Day(), Epoch(), Tetrode(), TetrodeEpoch(), LFP(), RippleInterval()
)

In [177]:
# Inspect the per-animal bontetinfo.mat cache. The bare attribute access raised
# AttributeError on a fresh kernel, where populate() had not yet created it.
getattr(tet_ep, 'anim_tet_infos', {})

In [187]:
%%time
anim = Animal()
# skip_duplicates keeps this cell re-runnable. Without it, a second run raises a
# duplicate-key error on the existing 'Bond' row before anything below executes.
anim.insert1({'anim_name': 'Bond', 'anim_path_raw': '/opt/data/daliu/other/mkarlsso/bond/',
              'anim_path_mat': '/opt/data/daliu/other/mkarlsso/Bon/'},
             skip_duplicates=True)
display(anim)

day = Day()
epoch = Epoch()
tet = Tetrode()
tet_ep = TetrodeEpoch(suppress_missing_print=True)
# Populate in dependency order: Day <- Animal, Epoch <- Day, Tetrode <- Animal,
# TetrodeEpoch <- (Tetrode, Epoch).
for table in (day, epoch, tet, tet_ep):
    table.populate()
    display(table)

lfp = LFP()
rip = RippleInterval()

In [46]:
# Render the entity-relationship diagram of every table in the schema.
# NOTE(review): newer DataJoint releases name this dj.Diagram — confirm against the installed version.
dj.ERD(schema)