diff --git a/DataProcessingTools/objects.py b/DataProcessingTools/objects.py
index eea8b93..456257c 100644
--- a/DataProcessingTools/objects.py
+++ b/DataProcessingTools/objects.py
@@ -1,11 +1,45 @@
 import numpy as np
 from . import levels
+import h5py
+import os
 
 
 class DPObject():
-    def __init__(self, dirs=[], setidx=[], *args, **kwargs):
-        self.dirs = dirs
-        self.setidx = setidx
+    argsList = []
+    filename = ""
+
+    def __init__(self, *args, **kwargs):
+        self.dirs = [os.getcwd()]
+        self.setidx = []
+        self.args = {}
+        # process positional arguments
+        # TODO: We need to somehow consume these, i.e. remove the processed ones
+        pargs = [p for p in filter(lambda t: not isinstance(t, tuple), type(self).argsList)]
+        qargs = pargs.copy()
+        for (k, v) in zip(pargs, args):
+            self.args[k] = v
+            qargs.remove(k)
+        # run the remaining through kwargs
+        for k in qargs:
+            if k in kwargs.keys():
+                self.args[k] = kwargs[k]
+
+        # process keyword arguments
+        kargs = filter(lambda t: isinstance(t, tuple), type(self).argsList)
+        for (k, v) in kargs:
+            self.args[k] = kwargs.get(k, v)
+
+        redoLevel = kwargs.get("redoLevel", 0)
+        fname = self.get_filename()
+        if redoLevel == 0 and os.path.isfile(fname):
+            self.load(fname)
+        else:
+            # create object
+            self.create(*args, **kwargs)
+
+    def create(self, *args, **kwargs):
+        pass
+
     def plot(self, i, fig):
         pass
 
@@ -61,6 +95,25 @@ def append(self, obj):
         for d in obj.dirs:
             self.dirs.append(d)
 
+    def get_filename(self):
+        """
+        Return the base filename with an argument hash
+        appended
+        """
+        h = self.hash()
+        fname = self.filename.replace(".mat", "_{0}.mat".format(h))
+        return fname
+
+    def hash(self):
+        pass
+
+    def load(self, fname=None):
+        if fname is None:
+            fname = self.get_filename()
+
+        with h5py.File(fname, "r") as ff:
+            self.dirs = [s.decode() for s in ff["dirs"][:]]
+            self.setidx = ff["setidx"][:].tolist()
 
 class DPObjects():
     def __init__(self, objects):
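The new `DPObject` base class resolves its arguments from the class-level `argsList` (plain strings are positional arguments, `(name, default)` tuples are optional keyword arguments) and then either loads a previously saved copy or calls `create()`, depending on `redoLevel` and on whether the cached file already exists. A minimal sketch of a subclass using this machinery, assuming the package is importable as `DataProcessingTools`; the `AverageRate` class, its file name and its argument names are invented for illustration:

```python
import hashlib
import numpy as np
import DataProcessingTools as DPT


class AverageRate(DPT.objects.DPObject):
    # "averagerate.mat" and the argument names below are made up for this sketch
    filename = "averagerate.mat"
    # plain strings are required positional arguments,
    # (name, default) tuples are optional keyword arguments
    argsList = ["triallength", ("baseline", 0.0)]

    def create(self, *args, **kwargs):
        # called by DPObject.__init__ when no saved file is found
        # or when redoLevel > 0
        self.data = np.array([self.args["baseline"]])
        self.setidx = [0]

    def hash(self):
        # the arguments are hashed into the file name so that different
        # parameter sets map to different cache files (same idea as PSTH.hash)
        h = hashlib.sha1(b"averagerate")
        for v in self.args.values():
            h.update(np.atleast_1d(v).tobytes())
        return h.hexdigest()


# positional and keyword forms both end up in self.args
obj = AverageRate(1.0, baseline=0.5, redoLevel=1, saveLevel=0)
assert obj.args["triallength"] == 1.0
assert obj.args["baseline"] == 0.5
```

Because `get_filename()` is evaluated before `create()` runs, `hash()` can only rely on `self.args`, which is why subclasses hash their arguments rather than their computed data.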
diff --git a/DataProcessingTools/psth.py b/DataProcessingTools/psth.py
index 77cba4c..418ab65 100644
--- a/DataProcessingTools/psth.py
+++ b/DataProcessingTools/psth.py
@@ -4,20 +4,34 @@ from . import levels
 from .raster import Raster
 import os
+import glob
+import h5py
+import hashlib
 
 
 class PSTH(DPObject):
-    def __init__(self, bins, windowsize=1, spiketimes=None, trialidx=None, triallabels=None,
-                 alignto=None, trial_event=None, dirs=None):
-        DPObject.__init__(self)
-        tmin = bins[0]
-        tmax = bins[-1]
-        if spiketimes is None:
-            # attempt to load from the current directory
-            raster = Raster(tmin, tmax, alignto=alignto, trial_event=trial_event)
-            spiketimes = raster.spiketimes
-            trialidx = raster.trialidx
-            triallabels = raster.trial_labels
+    """
+    PSTH(bins, windowSize=1, dirs=None, redoLevel=0, saveLevel=1)
+    """
+    filename = "psth.mat"
+    argsList = ["bins", ("windowSize", 1)]
+
+    def __init__(self, *args, **kwargs):
+        """
+        Return a PSTH object using the specified bins
+        """
+        DPObject.__init__(self, *args, **kwargs)
+
+    def create(self, *args, **kwargs):
+        saveLevel = kwargs.get("saveLevel", 1)
+        bins = self.args["bins"]
+
+        # attempt to load from the current directory
+        raster = Raster(bins[0], bins[-1], **kwargs)
+        spiketimes = raster.spiketimes
+        trialidx = raster.trialidx
+        self.trialLabels = raster.trialLabels
+
         ntrials = trialidx.max()+1
         counts = np.zeros((ntrials, np.size(bins)), dtype=np.int)
         for i in range(np.size(spiketimes)):
@@ -25,38 +39,63 @@ def __init__(self, bins, windowsize=1, spiketimes=None, trialidx=None, triallabe
             if 0 <= jj < np.size(bins):
                 counts[trialidx[i], jj] += 1
 
-        self.windowsize = windowsize
-        if windowsize > 1:
-            scounts = np.zeros((ntrials, len(bins)-windowsize))
+        windowSize = self.args["windowSize"]
+        if windowSize > 1:
+            scounts = np.zeros((ntrials, len(bins)-windowSize+1))
             for i in range(ntrials):
-                for j in range(len(bins)-windowsize):
-                    scounts[i, j] = counts[i, j:j+windowsize].sum()
+                for j in range(len(bins)-windowSize+1):
+                    scounts[i, j] = counts[i, j:j+windowSize].sum()
             self.data = scounts
-            self.bins = bins[:-windowsize]
+            self.bins = np.array(bins[:len(bins)-windowSize+1])
         else:
             self.data = counts
-            self.bins = bins
+            self.bins = np.array(bins)
         self.ntrials = ntrials
-        if triallabels is None:
-            self.trial_labels = np.ones((ntrials,))
-        elif np.size(triallabels) == ntrials:
-            self.trial_labels = triallabels
-        elif np.size(triallabels) == np.size(spiketimes):
-            dd = {}
-            for t, l in zip(trialidx, triallabels):
-                dd[t] = l
-            self.trial_labels = np.array([dd[t] for t in range(ntrials)])
-        else:
-            self.trial_labels = triallabels
         # index to keep track of sets, e.g. trials
         self.setidx = [0 for i in range(self.ntrials)]
-        if dirs is not None:
-            self.dirs = dirs
-        else:
-            self.dirs = [os.getcwd()]
+
+        if saveLevel > 0:
+            self.save()
+
+    def load(self, fname=None):
+        DPObject.load(self)
+        if fname is None:
+            fname = self.filename
+        with h5py.File(fname, "r") as ff:
+            args = {}
+            for (k, v) in ff["args"].items():
+                self.args[k] = v[()]
+            self.data = ff["counts"][:]
+            self.ntrials = self.data.shape[0]
+            self.bins = self.args["bins"][:self.data.shape[-1]]
+            self.trialLabels = ff["trialLabels"][:]
+
+    def hash(self):
+        """
+        Returns a hash representation of this object's arguments.
+        """
+        #TODO: This is not replicable across sessions
+        h = hashlib.sha1(b"psth")
+        for (k, v) in self.args.items():
+            x = np.atleast_1d(v)
+            h.update(x.tobytes())
+        return h.hexdigest()
+
+    def save(self, fname=None):
+        if fname is None:
+            fname = self.get_filename()
+
+        with h5py.File(fname, "w") as ff:
+            args = ff.create_group("args")
+            args["bins"] = self.args["bins"]
+            args["windowSize"] = self.args["windowSize"]
+            ff["counts"] = self.data
+            ff["trialLabels"] = self.trialLabels
+            ff["dirs"] = np.array(self.dirs, dtype='S256')
+            ff["setidx"] = self.setidx
 
     def append(self, psth):
         if not (self.bins == psth.bins).all():
         DPObject.append(self, psth)
         self.data = np.concatenate((self.data, psth.data), axis=0)
-        self.trial_labels = np.concatenate((self.trial_labels, psth.trial_labels),
+        self.trialLabels = np.concatenate((self.trialLabels, psth.trialLabels),
                                            axis=0)
         self.ntrials = self.ntrials + psth.ntrials
 
@@ -73,13 +112,13 @@ def plot(self, i=None, ax=None, overlay=False):
             ax = gca()
         if not overlay:
             ax.clear()
-        trial_labels = self.trial_labels[i]
+        trialLabels = self.trialLabels[i]
         data = self.data[i, :]
-        labels = np.unique(trial_labels)
+        labels = np.unique(trialLabels)
         for li in range(len(labels)):
             label = labels[li]
-            idx = trial_labels == label
+            idx = trialLabels == label
             mu = data[idx, :].mean(0)
             sigma = data[idx, :].std(0)
             ax.plot(self.bins, mu)
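Together, `hash()`, `get_filename()` and `save()` above give the PSTH an argument-keyed cache: with `saveLevel=1` the computed counts are written to `psth_<sha1>.mat`, and a later construction with the same arguments and `redoLevel=0` loads that file instead of recomputing from the raster. A sketch of that flow, assuming the working directory is a cell directory with `unit.mat` and a day-level `event_markers.csv` in place (as in the test changes further down):

```python
import os
import DataProcessingTools as DPT

bins = [-100., 200., 400., 600.]
common = dict(trialEvent="stimulus1", sortBy="stimulus1", trialType="reward_on")

# first construction computes the PSTH and, with saveLevel=1, writes
# psth_<hash-of-args>.mat in the current directory
psth_a = DPT.psth.PSTH(bins, windowSize=2, redoLevel=1, saveLevel=1, **common)

# same arguments, redoLevel=0: the cached file is found and loaded,
# so the raster and the spike train are not touched again
psth_b = DPT.psth.PSTH(bins, windowSize=2, redoLevel=0, saveLevel=0, **common)
assert (psth_a.data == psth_b.data).all()

os.remove(psth_a.get_filename())  # remove the cache file again
```

Note that `windowSize` is passed as a keyword here: only names listed as plain strings in `argsList` are filled from positional arguments.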
diff --git a/DataProcessingTools/raster.py b/DataProcessingTools/raster.py
index 4262a0f..68d947c 100644
--- a/DataProcessingTools/raster.py
+++ b/DataProcessingTools/raster.py
@@ -7,43 +7,69 @@ class Raster(DPObject):
-    def __init__(self, tmin, tmax, alignto=None, trial_event=None,
-                 spiketimes=None,
-                 trial_labels=None, dirs=None):
-        DPObject.__init__(self)
-        if spiketimes is None:
-            spiketrain = Spiketrain()
-            spiketimes = spiketrain.timestamps.flatten()
-        if alignto is None:
-            trials = get_trials()
-            # convert from seconds to ms
-            alignto = 1000*trials.get_timestamps(trial_event)
-        if trial_labels is None:
-            trial_labels = np.arange(len(alignto))
+    """
+    Raster(tmin, tmax, trialEvent, trialType, sortBy)
+    """
+    filename = "raster.mat"
+    argsList = ["tmin", "tmax", "trialEvent", "trialType", "sortBy"]
+    def __init__(self, *args, **kwargs):
+        DPObject.__init__(self, *args, **kwargs)
+
+    def create(self, *args, **kwargs):
+        trials = get_trials()
+        #TODO: This only works with correct trials for now
+        rewardOnset, cidx, stimIdx = trials.get_timestamps("reward_on")
+        trialEvent = self.args["trialEvent"]
+        if "stimulus" in trialEvent:
+            if trialEvent == "stimulus1":
+                stimnr = 0
+            elif trialEvent == "stimulus2":
+                stimnr = 1
+            else:
+                stimnr = -1
+                raise ValueError("Unknown trial event {0}".format(trialEvent))
+
+            alignto, stimidx, trialLabel = trials.get_stim(stimnr, cidx)
+            alignto = 1000*np.array(alignto)
+        else:
+            alignto, sidx, stimIdx = trials.get_timestamps(trialEvent)
+            qidx = np.isin(sidx, cidx)
+            trialIdx = sidx[qidx]
+            alignto = 1000*alignto[qidx]
+            sortBy = self.args["sortBy"]
+            if sortBy == "stimulus1":
+                stimnr = 0
+            elif sortBy == "stimulus2":
+                stimnr = 1
+            else:
+                raise ValueError("Unknown trial sorting {0}".format(sortBy))
+
+            ts, identity, trialLabel = trials.get_stim(stimnr, trialIdx)
+        #TODO: Never reload spike trains
+        spiketrain = Spiketrain(*args, **kwargs)
+        spiketimes = spiketrain.timestamps.flatten()
+        tmin = self.args["tmin"]
+        tmax = self.args["tmax"]
         bidx = np.digitize(spiketimes, alignto+tmin)
         idx = (bidx > 0) & (bidx <= np.size(alignto))
         raster = spiketimes[idx] -
alignto[bidx[idx]-1] ridx = (raster > tmin) & (raster < tmax) self.spiketimes = raster[ridx] self.trialidx = bidx[idx][ridx]-1 - self.trial_labels = trial_labels + self.trialLabels = trialLabel self.setidx = [0 for i in range(len(self.trialidx))] - if dirs is None: - self.dirs = [os.getcwd()] - else: - self.dirs = dirs def append(self, raster): DPObject.append(self, raster) n_old = len(self.spiketimes) n_new = n_old + len(raster.spiketimes) - self.spiketimes.resize(n_new) + self.spiketimes = np.resize(self.spiketimes, n_new) self.spiketimes[n_old:n_new] = raster.spiketimes - self.trialidx.resize(n_new) + self.trialidx = np.resize(self.trialidx, n_new) self.trialidx[n_old:n_new] = raster.trialidx - self.trial_labels = np.concatenate((self.trial_labels, raster.trial_labels)) + self.trialLabels = np.concatenate((self.trialLabels, raster.trialLabels)) def plot(self, idx=None, ax=None, overlay=False): if ax is None: diff --git a/DataProcessingTools/spiketrain.py b/DataProcessingTools/spiketrain.py index 170d0dd..e2d2137 100644 --- a/DataProcessingTools/spiketrain.py +++ b/DataProcessingTools/spiketrain.py @@ -5,14 +5,22 @@ class Spiketrain(DPObject): - def __init__(self): - DPObject.__init__(self) - self.level = "cell" - self.filename = "unit.mat" + """ + Spiketrain() + """ + level = "cell" + filename = "unit.mat" + + def __init__(self, *args, **kwargs): + DPObject.__init__(self, *args, **kwargs) + # always load since we do not create spike trains here. if os.path.isfile(self.filename): self.load() - def load(self): + def load(self, fname=None): q = sio.loadmat(self.filename) self.timestamps = q["timestamps"] self.spikeshape = q["spikeForm"] + + def get_filename(self): + return self.filename diff --git a/DataProcessingTools/trialstructures.py b/DataProcessingTools/trialstructures.py index b86f642..61fc9ec 100644 --- a/DataProcessingTools/trialstructures.py +++ b/DataProcessingTools/trialstructures.py @@ -8,10 +8,10 @@ class TrialStructure(DPObject): - def __init__(self): - DPObject.__init__(self) + def __init__(self, **kwargs): self.events = [] self.timestamps = [] + DPObject.__init__(self, **kwargs) def get_timestamps(self, event_label): """ @@ -25,34 +25,37 @@ def get_timestamps(self, event_label): class WorkingMemoryTrials(TrialStructure): filename = "event_markers.csv" level = "day" - def __init__(self): - TrialStructure.__init__(self) - self.trialevents = {"session_start": "11000000", - "trial_start": "00000000", - "fix_start": "00000001", - "stimBlankStart": "00000011", - "delay_start": "00000100", - "response_on": "00000101", - "reward_on": "00000110", - "failure": "00000111", - "trial_end": "00100000", - "manual_reward_on": "00001000", - "stim_start": "00001111", - "reward_off ": "00000100", - "trial_start": "00000010", - "target_on": "10100000", - "target_off": "10000000", - "left_fixation": "00011101"} - self.reverse_map = dict((v,k) for k,v in self.trialevents.items()) + trialevents = {"session_start": "11000000", + "fix_start": "00000001", + "stimBlankStart": "00000011", + "delay_start": "00000100", + "response_on": "00000101", + "reward_on": "00000110", + "failure": "00000111", + "trial_end": "00100000", + "manual_reward_on": "00001000", + "stim_start": "00001111", + "reward_off": "00000100", + "trial_start": "00000010", + "target_on": "10100000", + "target_off": "10000000", + "left_fixation": "00011101"} + + def __init__(self, **kwargs): + self.reverse_map = dict((v, k) for k, v in self.trialevents.items()) + TrialStructure.__init__(self, **kwargs) + # always load self.load() - 
def load(self): + def load(self, fname=None): sessiondir = get_level_name("session") leveldir = resolve_level(self.level) tidx = -1 stidx = -1 self.trialidx = [] self.stimidx = [] + self.events = [] + self.timestamps = [] with open(os.path.join(leveldir, self.filename), "r") as csvfile: data = csv.DictReader(csvfile) for row in data: @@ -113,14 +116,19 @@ def get_timestamps(self, event_label): trials.get_timestamps("stimulus_on_1_*") """ - idx = np.zeros((len(self.events), ), dtype=np.bool) + events = self.events + trialidx = self.trialidx + timestamps = self.timestamps + stimidx = self.stimidx + + idx = np.zeros((len(events), ), dtype=np.bool) p = re.compile(event_label) - for (i,ee) in enumerate(self.events): + for (i ,ee) in enumerate(events): m = p.match(ee) - if m is not None: + if m is not None: idx[i] = True - - return self.timestamps[idx], self.trialidx[idx], self.stimidx[idx] + + return timestamps[idx], trialidx[idx], stimidx[idx] def get_stim(self, stimidx=0, trialidx=None): """ @@ -151,7 +159,6 @@ def get_trials(): """ for Trials in TrialStructure.__subclasses__(): leveldir = resolve_level(Trials.level) - with CWD(leveldir): - if os.path.isfile(Trials.filename): - trials = Trials() - return trials + if os.path.isfile(os.path.join(leveldir, Trials.filename)): + trials = Trials() + return trials diff --git a/README.md b/README.md index a249c79..64531a3 100644 --- a/README.md +++ b/README.md @@ -27,47 +27,26 @@ celldirs = DPT.levels.get_level_dirs("cell", cwd) ### Objects ```python import DataProcessingTools as DPT -spiketimes = np.cumsum(np.random.exponential(0.3, 100000)) -trialidx = np.random.random_integers(0, 100, (100000, )) -trial_labels = np.random.random_integers(1, 9, (101, )) -bins = np.arange(0, 100.0, 2.0) -psth1 = DPT.psth.PSTH(spiketimes, trialidx, bins, trial_labels) -psth1.dirs = ["Pancake/20130923/session01/array01/channel001/cell01"] -spiketimes = np.cumsum(np.random.exponential(0.3, 100000)) -trialidx = np.random.random_integers(0, 100, (100000, )) -psth2 = DPT.psth.PSTH(spiketimes, trialidx, bins, trial_labels) -psth2.dirs = ["Pancake/20130923/session01/array01/channel002/cell01"] - -# Concatenate the two PSTH objects into a list for plotting -ppsth = DPT.objects.DPObjects([psth1, psth2]) - -# Plot the first PSTH -fig = ppsth.plot(0) - -# Plot the second PSTH without overlay -ppsth.plot(1, overlay=False) - -# Append psth2 to psth1, creating an object spanning mulitple sets -psth1.append(psth2) - -# To access the data for the first cell in this compound object -cell_idx = psth1.getindex("cell") - -# The above returns a function that gives the index into the object's data -# corresponding to the cell level. - -# To access the data related to the first cell, we can do this -cell1_data = psth1.data[cell_idx(0),:] - -# If, on the other hand, we want to group the data by session, we could do - -session_idx = psth1.getindex("session") -session1_data = psth1.data[session_idx(0),:] - -# For this example, since we only have a single session, we would be looking -# at all the data. 
+class DirFiles(DPT.DPObject): + """ + DirFiles(redoLevels=0, SaveLevels=0, objectLevel='Session') + """ + def __init__(self, *args, **kwargs): + # initialize fields in parent + DPT.DPObject.__init__(self, *args, **kwargs) + # check for files or directories in current directory + dir_listing = os.listdir() + # check number of items + dnum = len(dir_listing) + # create object if there are some items in this directory + if dnum > 0: + # update fields in parent + self.dirs = os.getcwd() + self.dir_list = dir_listing + self.setidx = [0 for i in range(dnum)] ``` + ### A complete example Here is an example of how to compute raster and psth for a list for a list of cells and step through plots of both using PanGUI @@ -81,25 +60,6 @@ import PanGUI # change this to wherever you keep the data hierarchy datadir = os.path.expanduser("~/Documents/workingMemory") -# get the trials for one session -with DPT.misc.CWD(os.path.join(datadir, "Whiskey/20200106/session02")): - trials = DPT.trialstructures.WorkingMemoryTrials() - -# get response trials -response_cue, ridx, _ = trials.get_timestamps("response_on") - -# get error trials -failure, eidx, _ = trials.get_timestamps("failure") - -#error trials with response cue -reidx = np.intersect1d(ridx, eidx) - -# get stimulus onset of the error trials -stim_onset, identity, location = trials.get_stim(0, reidx) - -#convert from seconds to ms -stim_onset = 1000*np.array(stim_onset) - # get all the cells for one session with DPT.misc.CWD(os.path.join(datadir, "Whiskey/20200106/session02")): cells = DPT.levels.get_level_dirs("cell") @@ -107,13 +67,18 @@ with DPT.misc.CWD(os.path.join(datadir, "Whiskey/20200106/session02")): # gather rasters and PSTH for these cells, for the error trials bins = np.arange(-300, 1000.0, 10) with DPT.misc.CWD(cells[0]): - raster = DPT.raster.Raster(-300.0, 1000.0, stim_onset) - psth = DPT.psth.PSTH(bins, 10, raster.spiketimes, raster.trialidx, location) + raster = DPT.raster.Raster(-300.0, 1000.0, "stimulus1", "reward_on", "stimulus1", + redoLevel=1, saveLevel=0) + psth = DPT.psth.PSTH(bins, 10, trialEvent="stimulus1", sortBy="stimulus1", trialType="reward_on", + redoLevel=1, saveLevel=0) for cell in cells[1:]: with DPT.misc.CWD(cell): - praster = DPT.raster.Raster(-300.0, 1000.0, stim_onset) + praster = DPT.raster.Raster(-300.0, 1000.0, "stimulus1", "reward_on", "stimulus1", + redoLevel=1, saveLevel=0) raster.append(praster) - psth.append(DPT.psth.PSTH(bins, 10, praster.spiketimes, praster.trialidx, location)) + ppsth = DPT.psth.PSTH(bins, 10, trialEvent="stimulus1", sortBy="stimulus1", trialType="reward_on", + redoLevel=1, saveLevel=0) + psth.append(ppsth) app = PanGUI.create_window([raster, psth], cols=1, indexer="cell") ``` diff --git a/requirements.txt b/requirements.txt index f23634b..830b4c8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ numpy scipy matplotlib +h5py diff --git a/setup.py b/setup.py index d24cda2..6d166a8 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,7 @@ from setuptools import setup setup(name="DataProcessingTools", - version="0.6.0", + version="0.7.0", description="""Tools for processing data with hierarchical organization""", url="https://github.com/grero/DataProcessingTools.git", author="Roger Herikstad", diff --git a/tests/test_chain.py b/tests/test_chain.py index cf7a9df..8b726f8 100644 --- a/tests/test_chain.py +++ b/tests/test_chain.py @@ -1,32 +1,118 @@ import DataProcessingTools as DPT +from DataProcessingTools.trialstructures import WorkingMemoryTrials import tempfile import os 
import numpy as np import scipy.io as sio +import csv +import matplotlib.pylab as plt def test_load(): - spiketimes = np.array([0.1, 0.15, 0.3, 0.4, 0.5, 0.6, - 1.1, 1.2, 1.3, 1.4, 2.5, 2.6]) - trial_events = np.array([0.1, 1.1]) + spiketimes = 1000*np.array([0.1, 0.15, 0.3, 0.4, 0.5, 0.6, + 1.1, 1.2, 1.3, 1.4, 2.5, 2.6]) + trialEvents = [("11000001", 0.0), + (WorkingMemoryTrials.trialevents["trial_start"], 0.05), + ("10100001", 0.1), + ("10000001", 0.2), + ("01100001", 0.1), + ("01000001", 0.2), + (WorkingMemoryTrials.trialevents["response_on"], 0.3), + (WorkingMemoryTrials.trialevents["reward_on"], 0.4), + (WorkingMemoryTrials.trialevents["reward_off"], 0.5), + (WorkingMemoryTrials.trialevents["trial_end"], 0.6), + (WorkingMemoryTrials.trialevents["trial_start"], 1.0), + ("10100001", 1.1), + ("10000001", 1.2), + ("01100001", 1.1), + ("01000001", 1.2), + (WorkingMemoryTrials.trialevents["response_on"], 1.3), + (WorkingMemoryTrials.trialevents["reward_on"], 1.4), + (WorkingMemoryTrials.trialevents["reward_off"], 1.5), + (WorkingMemoryTrials.trialevents["trial_end"], 1.6)] tempdir = tempfile.gettempdir() with DPT.misc.CWD(tempdir): - pth = "Pancake/20130923/session01/array01/channel001/cell01" - if not os.path.isdir(pth): - os.makedirs(pth) - sio.savemat(os.path.join(pth, "unit.mat"), {"timestamps": spiketimes, - "spikeForm": [0]}) - with DPT.misc.CWD(pth): - raster = DPT.raster.Raster(-0.1, 0.5, trial_events) + pth = "Pancake/20130923/session01/array01/channel001" + for cell in ["cell01", "cell02"]: + if not os.path.isdir(os.path.join(pth, cell)): + os.makedirs(os.path.join(pth, cell)) + sio.savemat(os.path.join(pth, cell, "unit.mat"), {"timestamps": spiketimes, + "spikeForm": [0.0]}) + + with DPT.misc.CWD(os.path.join(pth, "cell01")): + with DPT.misc.CWD(DPT.levels.resolve_level("day")): + # save trial info + with open("event_markers.csv", "w") as csvfile: + writer = csv.writer(csvfile) + writer.writerow(["words", "timestamps"]) + for ee in trialEvents: + writer.writerow(ee) + trials = DPT.trialstructures.get_trials() + assert len(trials.events) == len(trialEvents) + + trials = DPT.trialstructures.get_trials() + assert len(trials.events) == len(trialEvents) + stim_onset, tidx, sidx = trials.get_timestamps("stimulus_on_1_*") + assert (stim_onset == [0.1, 1.1]).all() + spiketrain = DPT.spiketrain.Spiketrain() + assert np.allclose(spiketrain.timestamps, spiketimes) + raster = DPT.raster.Raster(-100.0, 500.0, "stimulus1", "reward_on", "stimulus1", + redoLevel=1, saveLevel=0) assert (raster.trialidx == [0, 0, 0, 0, 0, 1, 1, 1, 1]).all() - assert np.isclose(raster.spiketimes, [0.0, 0.05, 0.2, 0.3, 0.4, 0.0, 0.1, 0.2, 0.3]).all() - psth = DPT.psth.PSTH([-0.1, 0.2, 0.4, 0.6], alignto=trial_events) + assert np.isclose(raster.spiketimes, [0., 50., 200., 300., 400., 0., 100., 200., 300.]).all() + raster2 = DPT.raster.Raster(-100.0, 500.0, "stimulus2", "reward_on", "stimulus2", + redoLevel=1, saveLevel=0) + assert (raster2.trialidx == [0, 0, 0, 0, 0, 1, 1, 1, 1]).all() + assert np.isclose(raster2.spiketimes, [0., 50., 200., 300., 400., 0., 100., 200., 300.]).all() + raster3 = DPT.raster.Raster(-100.0, 500.0, "reward_on", "reward_on", "stimulus2", + redoLevel=1, saveLevel=0) + assert (raster3.trialidx == [0, 0, 0, 1]).all() + psth = DPT.psth.PSTH([-100., 200., 400., 600.], 1, trialEvent="stimulus1", + sortBy="stimulus1", + trialType="reward_on", + redoLevel=1, saveLevel=0) assert psth.data.shape == (2, 4) assert (psth.data[0, :] == [0, 3, 2, 1]).all() assert (psth.data[1, :] == [0, 3, 1, 0]).all() + 
psths = DPT.psth.PSTH([-100., 200., 400., 600.], windowSize=2, trialEvent="stimulus1", + sortBy="stimulus1", + trialType="reward_on", + redoLevel=1, saveLevel=0) + assert psths.args["windowSize"] == 2 + assert psths.data.shape == (2, 3) + + with DPT.misc.CWD(os.path.join(pth, "cell02")): + raster2 = DPT.raster.Raster(-100.0, 500.0, "stimulus1", "reward_on", "stimulus1", + redoLevel=1, saveLevel=0) + psth2 = DPT.psth.PSTH([-100., 200., 400., 600.], 1, trialEvent="stimulus1", + sortBy="stimulus1", + trialType="reward_on", + redoLevel=1, saveLevel=1) + + psth3 = DPT.psth.PSTH([-100., 200., 400., 600.], 1, trialEvent="stimulus1", + sortBy="stimulus1", + trialType="reward_on", + redoLevel=0, saveLevel=0) + os.remove(psth3.get_filename()) + assert (psth2.data == psth3.data).all() + assert (psth2.args["windowSize"] == psth3.args["windowSize"]) + assert (psth2.args["bins"] == psth3.args["bins"]).all() - os.remove(os.path.join(pth, "unit.mat")) + raster.append(raster2) + cellidx = raster.getindex("cell") + assert len(cellidx(1)) == len(raster2.setidx) + raster.plot(cellidx(0)) + xy = plt.gca().lines[0].get_data() + assert np.allclose(xy[1], raster.trialidx[cellidx(0)]) + psth.append(psth2) + cellidx = psth.getindex("cell") + assert len(cellidx(1)) == len(psth2.setidx) + psth.plot(cellidx(0)) + assert len(plt.gca().lines) == 1 + os.remove(os.path.join(pth, "cell01", "unit.mat")) + os.remove(os.path.join(pth, "cell02", "unit.mat")) os.rmdir("Pancake/20130923/session01/array01/channel001/cell01") + os.rmdir("Pancake/20130923/session01/array01/channel001/cell02") os.rmdir("Pancake/20130923/session01/array01/channel001") os.rmdir("Pancake/20130923/session01/array01") diff --git a/tests/test_objects.py b/tests/test_objects.py index 69e7fff..ba8bb16 100644 --- a/tests/test_objects.py +++ b/tests/test_objects.py @@ -1,10 +1,21 @@ import DataProcessingTools as DPT +class MyObj(DPT.objects.DPObject): + def __init__(self, dirs, *args, **kwargs): + DPT.objects.DPObject.__init__(self, *args, **kwargs) + self.dirs = dirs + for i in range(len(dirs)): + self.setidx.extend([i for j in range(3)]) + + def test_level_idx(): - obj = DPT.DPObject(dirs=["session01/array01/channel01/cell01", - "session01/array01/channel01/cell02"], - setidx=[0, 0, 0, 1, 1, 1]) + obj = MyObj(dirs=["session01/array01/channel01/cell01", + "session01/array01/channel01/cell02"]) + + # test trial + idx = obj.getindex("trial") + assert idx(0) == [0] # test cell level idx = obj.getindex("cell") @@ -17,13 +28,12 @@ def test_level_idx(): def test_append(): - obj1 = DPT.objects.DPObject(dirs=["session01/array01/channel001/cell01", - "session01/array01/channel001/cell02"], - setidx=[0, 0, 0, 1, 1, 1]) - obj2 = DPT.objects.DPObject(dirs=["session01/array02/channel033/cell01", - "session01/array02/channel034/cell01"], - setidx=[0, 0, 0, 1, 1, 1]) + obj1 = MyObj(dirs=["session01/array01/channel001/cell01", + "session01/array01/channel001/cell02"]) + + obj2 = MyObj(dirs=["session01/array02/channel033/cell01", + "session01/array02/channel034/cell01"]) obj1.append(obj2) @@ -45,3 +55,16 @@ def test_append(): idx = obj1.getindex(None) assert idx(0) is None + + +def test_object(): + + class MyObj2(DPT.objects.DPObject): + argsList = ["tmin", "tmax"] + filename = "test.mat" + + obj = MyObj2(-0.1, 1.0) + + assert obj.args["tmin"] == -0.1 + assert obj.args["tmax"] == 1.0 + diff --git a/tests/test_psth.py b/tests/test_psth.py deleted file mode 100644 index 4db38c8..0000000 --- a/tests/test_psth.py +++ /dev/null @@ -1,68 +0,0 @@ -import numpy as np -import 
DataProcessingTools as DPT -import matplotlib.pylab as plt - - -def test_psth(): - spiketimes = np.cumsum(np.random.exponential(0.3, 100000)) - trialidx = np.random.random_integers(0, 100, (100000, )) - trial_labels = np.random.random_integers(1, 9, (101, )) - bins = np.arange(0, 100.0, 2.0) - psth = DPT.psth.PSTH(bins, 1, spiketimes, trialidx, trial_labels, - dirs=["Pancake/20130923/session01/array01/channel001/cell01"]) - - assert psth.data.shape[0] == 101 - - idx = psth.update_idx(1) - assert idx == 1 - - idx = psth.update_idx(201) - assert idx == 100 - - psth.plot() - fig = plt.gcf() - assert len(fig.axes[0].lines) == 9 - - spiketimes = np.cumsum(np.random.exponential(0.3, 100000)) - trialidx = np.random.random_integers(0, 100, (100000, )) - - psth2 = DPT.psth.PSTH(bins, 1, spiketimes, trialidx, trial_labels, - dirs= ["Pancake/20130923/session01/array01/channel002/cell01"]) - ppsth = DPT.objects.DPObjects([psth]) - ppsth.append(psth2) - assert ppsth[0] == psth - assert ppsth[1] == psth2 - - ppsth.plot(0) - fig = plt.gcf() - assert len(fig.axes[0].lines) == 9 - ppsth.plot(1, ax=fig.axes[0]) - assert len(fig.axes[0].lines) == 9 - - ppsth.plot(0, ax=fig.axes[0], overlay=True) - assert len(fig.axes[0].lines) == 18 - - # test appending objects - psth.append(psth2) - - assert psth.data.shape[0] == 202 - assert psth.trial_labels.shape[0] == 202 - assert len(psth.setidx) == 202 - - idx = psth.getindex("cell") - data = psth.data[idx(1), :] - assert (data == psth2.data).all() - - idx = psth.getindex("session") - data = psth.data[idx(0), :] - assert (data == psth.data).all() - -def test_sliding_psth(): - spiketimes = np.cumsum(np.random.exponential(0.3, 100000)) - trialidx = np.random.random_integers(0, 100, (100000, )) - trial_labels = np.random.random_integers(1, 9, (101, )) - bins = np.arange(0, 100.0, 2.0) - psth = DPT.psth.PSTH(bins, 10, spiketimes, trialidx, trial_labels, - dirs=["Pancake/20130923/session01/array01/channel001/cell01"]) - assert psth.data.shape == (101, len(bins) - 10) - assert psth.bins.shape == (len(bins) - 10, ) diff --git a/tests/test_raster.py b/tests/test_raster.py deleted file mode 100644 index 989f071..0000000 --- a/tests/test_raster.py +++ /dev/null @@ -1,33 +0,0 @@ -import DataProcessingTools as DPT -import numpy as np -import tempfile -import os -import scipy.io as sio -import matplotlib.pylab as plt - - -def test_basic(): - - spiketimes = np.array([0.1, 0.15, 0.3, 0.4, 0.5, 0.6, - 1.1, 1.2, 1.3, 1.4, 2.5, 2.6]) - trial_events = np.array([0.1, 1.1]) - - raster = DPT.raster.Raster(-0.1, 0.5, alignto=trial_events, spiketimes=spiketimes, - dirs=["session01/array01/channel001/cell01"]) - assert (raster.trialidx == [0, 0, 0, 0, 0, 1, 1, 1, 1]).all() - assert np.isclose(raster.spiketimes, [0.0, 0.05, 0.2, 0.3, 0.4, 0.0, 0.1, 0.2, 0.3]).all() - - # test plotting - fig = plt.figure() - ax = fig.add_subplot(111) - raster.plot(ax=ax) - x, y = ax.lines[0].get_data() - assert np.isclose(raster.spiketimes, x).all() - assert np.isclose(raster.trialidx, y).all() - - raster.append(DPT.raster.Raster(-0.1, 0.5, alignto=trial_events, spiketimes=spiketimes, - dirs=["session01/array01/channel002/cell01"])) - cell_idx = raster.getindex("cell") - for i in range(2): - assert np.allclose(raster.spiketimes[cell_idx(i)], [0.0, 0.05, 0.2, 0.3, 0.4, 0.0, 0.1, 0.2, 0.3]) - assert (raster.trialidx[cell_idx(i)] == [0, 0, 0, 0, 0, 1, 1, 1, 1]).all() diff --git a/tests/test_trialstructures.py b/tests/test_trialstructures.py deleted file mode 100644 index bdfa1e8..0000000 --- 
a/tests/test_trialstructures.py +++ /dev/null @@ -1,71 +0,0 @@ -import DataProcessingTools as DPT -import numpy as np -import os - -def test_basic(): - events = [("trial_start", 0.1), - ("target_on", 0.2), - ("target_off", 0.5), - ("reward_on", 0.6), - ("reward_off", 0.8), - ("trial_end", 0.85), - ("trial_start", 1.1), - ("target_on", 1.2), - ("target_off", 1.5), - ("failure_on", 1.6), - ("failure_off", 1.8), - ("trial_end", 1.85)] - - trials = DPT.trialstructures.TrialStructure() - trials.events = np.array([event[0] for event in events]) - trials.timestamps = np.array([event[1] for event in events]) - - trial_starts = trials.get_timestamps("trial_start") - assert (trial_starts == [0.1, 1.1]).all() - -def test_working_memory(): - testdir = "animal/20130923" - with DPT.misc.CWD(os.path.join(os.path.dirname(__file__),testdir)): - trials = DPT.trialstructures.WorkingMemoryTrials() - assert len(trials.events) == 12746 - trial_starts, trialidx, stimidx = trials.get_timestamps("trial_start") - assert len(trial_starts) == 1733 - assert np.allclose(trial_starts[:2], [47.7139, 50.11616667]) - - stim1_onset, trialidx, stimidx = trials.get_timestamps("stimulus_on_1_*") - assert len(stim1_onset) == 1915 - assert np.allclose(stim1_onset[-5:], - [9107.5137, 9117.90566667, 9164.79406667, 9254.83466667, 9259.95556667]) - - stim1_offset, trialidx, stimidx = trials.get_timestamps("stimulus_off_1_*") - assert len(stim1_offset) == len(stim1_onset) - reward_on, trialidx, stimidx = trials.get_timestamps("reward_on") - assert len(reward_on) == 395 - assert trials.trialidx[-1] == 1732 - - stim1_onset, identity, location = trials.get_stim(0) - assert len(stim1_onset) == 1434 - assert identity.count(1) == 1434 - assert location.count(0) == 417 - assert location.count(1) == 274 - assert location.count(2) == 431 - assert location.count(3) == 312 - - # only correct trials - reward_onset, ridx, _ = trials.get_timestamps("reward_on") - stim1_onset, identity, location = trials.get_stim(0, ridx) - assert len(stim1_onset) == 395 - - testdir = "animal/20130923/session01" - with DPT.misc.CWD(os.path.join(os.path.dirname(__file__),testdir)): - trials = DPT.trialstructures.WorkingMemoryTrials() - trial_starts,trialidx, stimidx = trials.get_timestamps("trial_start") - assert len(trial_starts) == 50 - assert trial_starts[0] == 0.003666666666667595 - - -def test_auto_discovery(): - testdir = "animal/20130923/session01/array01/channel001/cell01" - with DPT.misc.CWD(os.path.join(os.path.dirname(__file__), testdir)): - trials = DPT.trialstructures.get_trials() - assert len(trials.events) == 12746 \ No newline at end of file
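Beyond the Raster/PSTH usage shown in the README changes, the reworked trial structures can still be queried directly: `get_trials()` discovers the `WorkingMemoryTrials` structure from `event_markers.csv`, and `get_timestamps()` treats its event label as a regular expression. A sketch of such a query, reusing the example data path from the README (the path and session are illustrative):

```python
import os
import numpy as np
import DataProcessingTools as DPT

# assumed data location, as in the README example
datadir = os.path.expanduser("~/Documents/workingMemory")

with DPT.misc.CWD(os.path.join(datadir, "Whiskey/20200106/session02")):
    # picks the TrialStructure subclass whose marker file exists at its level
    trials = DPT.trialstructures.get_trials()

    # event labels are matched as regular expressions, so wildcards work
    stim_onset, tidx, sidx = trials.get_timestamps("stimulus_on_1_*")
    reward_onset, ridx, _ = trials.get_timestamps("reward_on")

    # first-stimulus onsets restricted to correct (rewarded) trials,
    # converted from seconds to milliseconds
    stim1_onset, identity, location = trials.get_stim(0, ridx)
    stim1_onset = 1000*np.array(stim1_onset)
```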