Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DataProcessingTools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from . import misc, levels, objects, raster, psth
from . import misc, levels, objects, raster, psth, trialstructures
from .spiketrain import Spiketrain
14 changes: 8 additions & 6 deletions DataProcessingTools/levels.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def resolve_level(target_level, cwd=None):
this_idx = levels.index(this_level)
target_idx = levels.index(target_level)
pl = ["."]
for i in range(0, this_idx - target_idx+1):
for i in range(0, this_idx - target_idx):
pl.append("..")
return os.path.join(*pl)

Expand All @@ -47,14 +47,16 @@ def get_level_dirs(target_level, cwd=None):
this_level = level(cwd)
this_idx = levels.index(this_level)
target_idx = levels.index(target_level)
if target_idx <= this_idx:
if target_idx == this_idx:
dirs = [os.path.join(cwd, ".")]
elif target_idx < this_idx:
rel_path = resolve_level(target_level, cwd)
pattern = level_patterns_s[target_idx]
gpattern = os.path.join(cwd, rel_path, pattern)
dirs = glob.glob(gpattern)
gpattern = os.path.join(cwd, rel_path, "..", pattern)
dirs = sorted(glob.glob(gpattern))
else:
patterns = level_patterns_s[this_idx+1:target_idx+1]
dirs = glob.glob(os.path.join(cwd, *patterns))
dirs = sorted(glob.glob(os.path.join(cwd, *patterns)))
return dirs


Expand All @@ -63,7 +65,7 @@ def get_level_name(target_level, cwd=None):
Return the name of the requested level
"""
if cwd is None:
cwd = os.getwd()
cwd = os.getcwd()

this_level = level(cwd)
this_idx = levels.index(this_level)
Expand Down
4 changes: 2 additions & 2 deletions DataProcessingTools/psth.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@

class PSTH(DPObject):
def __init__(self, bins, spiketimes=None, trialidx=None, triallabels=None,
trial_events=None):
alignto=None, trial_event=None):
tmin = bins[0]
tmax = bins[-1]
if spiketimes is None:
# attempt to load from the current directory
raster = Raster(tmin, tmax, trial_event=trial_events)
raster = Raster(tmin, tmax, alignto=alignto, trial_event=trial_event)
spiketimes = raster.spiketimes
trialidx = raster.trialidx
triallabels = raster.trial_labels
Expand Down
18 changes: 10 additions & 8 deletions DataProcessingTools/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,27 @@
import numpy as np
from .objects import DPObject
from .spiketrain import Spiketrain
from .trialstructures import *
import os


class Raster(DPObject):
def __init__(self, tmin, tmax, trial_event=None,
def __init__(self, tmin, tmax, alignto=None, trial_event=None,
spiketimes=None,
trial_labels=None, dirs=None):
if spiketimes is None:
spiketrain = Spiketrain()
spiketimes = spiketrain.timestamps.flatten()
if trial_event is None:
# TODO: Load trials here
pass
if alignto is None:
trials = get_trials()
# convert from seconds to ms
alignto = 1000*trials.get_timestamps(trial_event)
if trial_labels is None:
trial_labels = np.arange(len(trial_event))
trial_labels = np.arange(len(alignto))

bidx = np.digitize(spiketimes, trial_event+tmin)
idx = (bidx > 0) & (bidx <= np.size(trial_event))
raster = spiketimes[idx] - trial_event[bidx[idx]-1]
bidx = np.digitize(spiketimes, alignto+tmin)
idx = (bidx > 0) & (bidx <= np.size(alignto))
raster = spiketimes[idx] - alignto[bidx[idx]-1]
ridx = (raster > tmin) & (raster < tmax)
self.spiketimes = raster[ridx]
self.trialidx = bidx[idx][ridx]-1
Expand Down
120 changes: 120 additions & 0 deletions DataProcessingTools/trialstructures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
from .objects import DPObject
from .levels import *
from .misc import *
import numpy as np
import os
import csv
import re


class TrialStructure(DPObject):
def __init__(self):
DPObject.__init__(self)
self.events = []
self.timestamps = []

def get_timestamps(self, event_label):
"""
Return the timestamps corresponding to the
specified event.
"""
idx = np.where(self.events == event_label)[0]
return self.timestamps[idx]


class WorkingMemoryTrials(TrialStructure):
filename = "event_markers.csv"
level = "day"
def __init__(self):
TrialStructure.__init__(self)
self.trialevents = {"session_start": "11000000",
"trial_start": "00000000",
"fix_start": "00000001",
"stimBlankStart": "00000011",
"delay_start": "00000100",
"response_on": "00000101",
"reward_on": "00000110",
"failure": "00000111",
"trial_end": "00100000",
"manual_reward_on": "00001000",
"stim_start": "00001111",
"reward_off ": "00000100",
"trial_start": "00000010",
"target_on": "10100000",
"target_off": "10000000",
"left_fixation": "00011101"}
self.reverse_map = dict((v,k) for k,v in self.trialevents.items())
self.load()

def load(self):
sessiondir = get_level_name("session")
leveldir = resolve_level(self.level)
with open(os.path.join(leveldir, self.filename), "r") as csvfile:
data = csv.DictReader(csvfile)
for row in data:
word = row["words"]
if word[:2] == "11":
idx = int(word[2:], 2)
event = "".join(("session", str(idx).zfill(2)))
elif (word[:2] == "10") or (word[:2] == "01"):
if word[:2] == "10":
stimid = 1
else:
stimid = 2
if word[2] == "1":
switch = "on"
else:
switch = "off"
locidx = int(word[3:], 2)
event = "stimulus_{0}_{1}_{2}".format(switch, stimid, locidx)
else:
event = self.reverse_map.get(word, None)

if event is not None:
self.events.append(event)
self.timestamps.append(np.float(row["timestamps"]))
if sessiondir:
# filter events to only those in the current session
sidx0 = self.events.index(sessiondir)
sid = int("".join([f for f in filter(str.isdigit, sessiondir)]))
sid += 1
try:
ssid = "".join(("session", str(sid).zfill(2)))
sidx1 = self.events.index(ssid)
except:
sidx1 = len(self.events)
self.events = np.array(self.events[sidx0:sidx1])
self.timestamps = np.array(self.timestamps[sidx0:sidx1]) - self.timestamps[sidx0]
else:
self.events = np.array(self.events)
self.timestamps = np.array(self.timestamps)

def get_timestamps(self, event_label):
"""
Return the timestamps of all events matching `event_label`.
Wildcard can be used as well, so that, for example, to find all
stimulus 1 onsets, regardless of position, use

trials.get_timestamps("stimulus_on_1_*")

"""
idx = np.zeros((len(self.events), ), dtype=np.bool)
p = re.compile(event_label)
for (i,ee) in enumerate(self.events):
m = p.match(ee)
if m is not None:
idx[i] = True

return self.timestamps[idx]

def get_trials():
"""
Attempt to auto-discover the trial structure by looking for a file
corresponding to a known structure in the current working directory
"""
for Trials in TrialStructure.__subclasses__():
leveldir = resolve_level(Trials.level)
with CWD(leveldir):
if os.path.isfile(Trials.filename):
trials = Trials()
return trials
Loading