Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
af9d6e5
Adds args dictionary to psth
grero Jun 27, 2020
324bf4e
Adds code loading and saving PSTH objects
grero Jun 27, 2020
ab55da8
Updates psth tests
grero Jun 27, 2020
65df1f1
Updates chain tests
grero Jun 27, 2020
724ef6c
Adds h5py to requirements
grero Jun 27, 2020
790ef35
Adds some general load code to base object
grero Jun 28, 2020
7feb722
Adds code for processing arguments to DPObject
grero Jun 29, 2020
c592e50
Updates Spiketrain class with new argument handling
grero Jun 29, 2020
fc48ecd
Updates Raster object for new argument handling
grero Jun 29, 2020
7c49f0b
Fixes variable name
grero Jun 29, 2020
3d1e7d4
Updates PSTH class with new argument handler
grero Jun 29, 2020
3704c1f
Adds test for simple object
grero Jun 29, 2020
506ff27
Changes to hashing code
grero Jun 29, 2020
d80bd5f
Updates WorkingMemoryTrials
grero Jun 29, 2020
0dc90de
Fixes test_chain
grero Jun 29, 2020
a01a71e
Bumps version number
grero Jun 29, 2020
134a042
Merge branch 'master' into argcheck
grero Jun 29, 2020
61819ae
Fixes bug in test_object
grero Jun 29, 2020
8c60245
Always attempt to load trialstructures
grero Jun 30, 2020
ae68563
Fixes a bug where session was not recognized
grero Jun 30, 2020
6cb9e60
Fixes raster alignment
grero Jun 30, 2020
1d7caa4
Fixed trialLabel spelling
grero Jun 30, 2020
9115bd5
Fixes a bug in tests
grero Jun 30, 2020
c07d159
Fixes a bug in resizing for appending rasters
grero Jun 30, 2020
242f466
Adds code to always convert to np.array
grero Jun 30, 2020
ace8e49
Adds tests for appending
grero Jun 30, 2020
d27079d
Adds tests for aligning to second stimulus
grero Jun 30, 2020
f02cce8
Adds tests for plots
grero Jun 30, 2020
af4dbb8
Adds test for non-stimulus alignment
grero Jun 30, 2020
f60bfd2
Fixes a bug in sliding psth counts
grero Jun 30, 2020
b164d1a
Adds check for sliding psth
grero Jun 30, 2020
234a635
Adds test for trial level indexing
grero Jun 30, 2020
3cb97ee
Updates README
grero Jun 30, 2020
00f9ccf
Adds simple object example
grero Jun 30, 2020
d594b4e
Adds tests for saving and loading PSTH
grero Jun 30, 2020
db8c06e
Fixes an issue with reading from hdf5
grero Jun 30, 2020
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 56 additions & 3 deletions DataProcessingTools/objects.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,45 @@
import numpy as np
from . import levels
import h5py
import os


class DPObject():
    # NOTE(review): diff artifact -- this span mixes removed and added lines.
    # The two-line __init__ body below is the removed pre-change version;
    # argsList/filename are the newly added class attributes.
    def __init__(self, dirs=[], setidx=[], *args, **kwargs):
        # NOTE(review): mutable default arguments (fixed in the new __init__)
        self.dirs = dirs
        self.setidx = setidx
    # class-level defaults, overridden by subclasses (e.g. PSTH, Raster)
    argsList = []
    filename = ""

def __init__(self, *args, **kwargs):
self.dirs = [os.getcwd()]
self.setidx = []
self.args = {}
# process positional arguments
# TODO: We need to somehow consume these, ie. remove the processed ones
pargs = [p for p in filter(lambda t: not isinstance(t, tuple), type(self).argsList)]
qargs = pargs.copy()
for (k, v) in zip(pargs, args):
self.args[k] = v
qargs.remove(k)
# run the remaining throgh kwargs
for k in qargs:
if k in kwargs.keys():
self.args[k] = kwargs[k]

# process keyword arguments
kargs = filter(lambda t: isinstance(t, tuple), type(self).argsList)
for (k, v) in kargs:
self.args[k] = kwargs.get(k, v)

redoLevel = kwargs.get("redoLevel", 0)
fname = self.get_filename()
if redoLevel == 0 and os.path.isfile(fname):
self.load(fname)
else:
# create object
self.create(*args, **kwargs)

    def create(self, *args, **kwargs):
        """Compute this object's contents from scratch.

        Hook for subclasses; called from ``__init__`` when no previously
        saved file is found or when ``redoLevel`` is non-zero.
        """
        pass


    def plot(self, i, fig):
        """Plot set *i* of this object into *fig*.

        Hook for subclasses; the base implementation draws nothing.
        """
        pass
Expand Down Expand Up @@ -61,6 +95,25 @@ def append(self, obj):
for d in obj.dirs:
self.dirs.append(d)

def get_filename(self):
"""
Return the base filename with an argument hash
appended
"""
h = self.hash()
fname = self.filename.replace(".mat", "_{0}.mat".format(h))
return fname

    def hash(self):
        """Return a hash representing this object's arguments.

        Hook for subclasses (see PSTH.hash); the base implementation
        returns None.
        """
        pass

def load(self, fname=None):
if fname is None:
fname = self.get_filename()

with h5py.File(fname) as ff:
self.dirs = [s.decode() for s in ff["dirs"][:]]
self.setidx = ff["setidx"][:].tolist()

class DPObjects():
def __init__(self, objects):
Expand Down
113 changes: 76 additions & 37 deletions DataProcessingTools/psth.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,67 +4,106 @@
from . import levels
from .raster import Raster
import os
import glob
import h5py
import hashlib


class PSTH(DPObject):
    # NOTE(review): diff artifact -- the __init__ below is the removed
    # pre-change version; the class docstring and attributes that follow it
    # belong to the post-change class.
    def __init__(self, bins, windowsize=1, spiketimes=None, trialidx=None, triallabels=None,
                 alignto=None, trial_event=None, dirs=None):
        DPObject.__init__(self)
        tmin = bins[0]
        tmax = bins[-1]
        if spiketimes is None:
            # attempt to load from the current directory
            raster = Raster(tmin, tmax, alignto=alignto, trial_event=trial_event)
            spiketimes = raster.spiketimes
            trialidx = raster.trialidx
            triallabels = raster.trial_labels
    """
    PSTH(bins,windowSize=1, dirs=None, redoLevel=0, saveLevel=1)
    """
    filename = "psth.mat"
    argsList = ["bins", ("windowSize", 1)]

def __init__(self, *args, **kwargs):
"""
Return a PSTH object using the specified bins
"""
DPObject.__init__(self, *args, **kwargs)

    def create(self, *args, **kwargs):
        # Build the PSTH counts from a freshly constructed Raster.
        # NOTE(review): diff artifact -- this span still contains pre-change
        # lines (marked below) alongside their replacements.
        saveLevel = kwargs.get("saveLevel", 1)
        bins = self.args["bins"]

        # attempt to load from the current directory
        raster = Raster(bins[0], bins[-1], **kwargs)
        spiketimes = raster.spiketimes
        trialidx = raster.trialidx
        self.trialLabels = raster.trialLabels

        ntrials = trialidx.max()+1
        # NOTE(review): np.int was removed in numpy 1.24; use plain int.
        counts = np.zeros((ntrials, np.size(bins)), dtype=np.int)
        for i in range(np.size(spiketimes)):
            # NOTE(review): searchsorted returns the insertion index, so a
            # spike below bins[0] lands in bin 0 -- confirm this is intended.
            jj = np.searchsorted(bins, spiketimes[i])
            if 0 <= jj < np.size(bins):
                counts[trialidx[i], jj] += 1

        # NOTE(review): the next three lines are the removed pre-change
        # windowsize handling, superseded by windowSize just below.
        self.windowsize = windowsize
        if windowsize > 1:
            scounts = np.zeros((ntrials, len(bins)-windowsize))
        windowSize = self.args["windowSize"]
        if windowSize > 1:
            # sliding-window counts: one column per window position
            scounts = np.zeros((ntrials, len(bins)-windowSize+1))
            for i in range(ntrials):
                # NOTE(review): the next two lines are the pre-change loop.
                for j in range(len(bins)-windowsize):
                    scounts[i, j] = counts[i, j:j+windowsize].sum()
                for j in range(len(bins)-windowSize+1):
                    scounts[i, j] = counts[i, j:j+windowSize].sum()

            self.data = scounts
            # NOTE(review): first assignment is pre-change. Also note
            # bins[:-windowSize] has len(bins)-windowSize entries while
            # scounts has one column more -- confirm which is intended.
            self.bins = bins[:-windowsize]
            self.bins = np.array(bins[:-windowSize])
        else:
            self.data = counts
            # NOTE(review): first assignment is the pre-change version.
            self.bins = bins
            self.bins = np.array(bins)

        self.ntrials = ntrials
        # NOTE(review): the triallabels handling below is removed pre-change
        # code; labels now come from the Raster above.
        if triallabels is None:
            self.trial_labels = np.ones((ntrials,))
        elif np.size(triallabels) == ntrials:
            self.trial_labels = triallabels
        elif np.size(triallabels) == np.size(spiketimes):
            dd = {}
            for t, l in zip(trialidx, triallabels):
                dd[t] = l
            self.trial_labels = np.array([dd[t] for t in range(ntrials)])
        else:
            self.trial_labels = triallabels

        # index to keep track of sets, e.g. trials
        self.setidx = [0 for i in range(self.ntrials)]
        # NOTE(review): pre-change dirs handling (now done in DPObject).
        if dirs is not None:
            self.dirs = dirs
        else:
            self.dirs = [os.getcwd()]

        if saveLevel > 0:
            self.save()

def load(self, fname=None):
DPObject.load(self)
if fname is None:
fname = self.filename
with h5py.File(fname) as ff:
args = {}
for (k, v) in ff["args"].items():
self.args[k] = v.value
self.data = ff["counts"][:]
self.ntrials = self.data.shape[0]
self.bins = self.args["bins"][:self.data.shape[-1]]
self.trialLabels = ff["trialLabels"][:]

def hash(self):
"""
Returns a hash representation of this object's arguments.
"""
#TODO: This is not replicable across sessions
h = hashlib.sha1(b"psth")
for (k, v) in self.args.items():
x = np.atleast_1d(v)
h.update(x.tobytes())
return h.hexdigest()

def save(self, fname=None):
if fname is None:
fname = self.get_filename()

with h5py.File(fname, "w") as ff:
args = ff.create_group("args")
args["bins"] = self.args["bins"]
args["windowSize"] = self.args["windowSize"]
ff["counts"] = self.data
ff["trialLabels"] = self.trialLabels
ff["dirs"] = np.array(self.dirs, dtype='S256')
ff["setidx"] = self.setidx

def append(self, psth):
if not (self.bins == psth.bins).all():
ValueError("Incompatible bins")

DPObject.append(self, psth)
self.data = np.concatenate((self.data, psth.data), axis=0)
self.trial_labels = np.concatenate((self.trial_labels, psth.trial_labels),
self.trialLabels = np.concatenate((self.trialLabels, psth.trialLabels),
axis=0)
self.ntrials = self.ntrials + psth.ntrials

Expand All @@ -73,13 +112,13 @@ def plot(self, i=None, ax=None, overlay=False):
ax = gca()
if not overlay:
ax.clear()
trial_labels = self.trial_labels[i]
trialLabels = self.trialLabels[i]
data = self.data[i, :]
labels = np.unique(trial_labels)
labels = np.unique(trialLabels)

for li in range(len(labels)):
label = labels[li]
idx = trial_labels == label
idx = trialLabels == label
mu = data[idx, :].mean(0)
sigma = data[idx, :].std(0)
ax.plot(self.bins, mu)
Expand Down
68 changes: 47 additions & 21 deletions DataProcessingTools/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,43 +7,69 @@


class Raster(DPObject):
    # NOTE(review): diff artifact -- the __init__ below is the removed
    # pre-change version; the class docstring and attributes that follow it
    # belong to the post-change class.
    def __init__(self, tmin, tmax, alignto=None, trial_event=None,
                 spiketimes=None,
                 trial_labels=None, dirs=None):
        DPObject.__init__(self)
        if spiketimes is None:
            spiketrain = Spiketrain()
            spiketimes = spiketrain.timestamps.flatten()
        if alignto is None:
            trials = get_trials()
            # convert from seconds to ms
            alignto = 1000*trials.get_timestamps(trial_event)
        if trial_labels is None:
            trial_labels = np.arange(len(alignto))
    """
    Raster(tmin, tmax, TrialEvent,trialType, sortBy)
    """
    filename = "raster.mat"
    argsList = ["tmin", "tmax", "trialEvent", "trialType", "sortBy"]

    def __init__(self, *args, **kwargs):
        """Create a Raster; argument handling and load/create dispatch
        are done by DPObject.__init__."""
        DPObject.__init__(self, *args, **kwargs)

    def create(self, *args, **kwargs):
        # Build a spike raster aligned to the requested trial event.
        # NOTE(review): diff artifact -- this span still contains some
        # pre-change lines (marked below).
        trials = get_trials()
        #TODO: This only works with correct trials for now
        rewardOnset, cidx, stimIdx = trials.get_timestamps("reward_on")
        trialEvent = self.args["trialEvent"]
        if "stimulus" in trialEvent:
            if trialEvent == "stimulus1":
                stimnr = 0
            elif trialEvent == "stimulus2":
                stimnr = 1
            else:
                stimnr = -1
                # NOTE(review): exception constructed but never raised --
                # should be `raise ValueError(...)` (typo "Unkonwn").
                ValueError("Unkonwn trial sorting {0}".format(trialEvent))

            alignto, stimidx, trialLabel = trials.get_stim(stimnr, cidx)
            # presumably converts seconds to ms -- confirm units of get_stim
            alignto = 1000*np.array(alignto)
        else:
            alignto, sidx, stimIdx = trials.get_timestamps(trialEvent)
            # keep only the correct trials identified above
            qidx = np.isin(sidx, cidx)
            trialIdx = sidx[qidx]
            alignto = 1000*alignto[qidx]
            sortBy = self.args["sortBy"]
            if sortBy == "stimulus1":
                stimnr = 0
            elif sortBy == "stimulus2":
                stimnr = 1
            else:
                # NOTE(review): exception constructed but never raised here
                # either; stimnr is also left undefined on this path.
                ValueError("Unkonwn trial sorting {0}".format(sortBy))

            ts, identity, trialLabel = trials.get_stim(stimnr, trialIdx)
        #TODO: Never reload spike trains
        spiketrain = Spiketrain(*args, **kwargs)
        spiketimes = spiketrain.timestamps.flatten()
        tmin = self.args["tmin"]
        tmax = self.args["tmax"]
        # assign each spike to the trial whose alignment point precedes it
        bidx = np.digitize(spiketimes, alignto+tmin)
        idx = (bidx > 0) & (bidx <= np.size(alignto))
        raster = spiketimes[idx] - alignto[bidx[idx]-1]
        ridx = (raster > tmin) & (raster < tmax)
        self.spiketimes = raster[ridx]
        self.trialidx = bidx[idx][ridx]-1
        # NOTE(review): next line is the removed pre-change version of the
        # trialLabels assignment that follows it.
        self.trial_labels = trial_labels
        self.trialLabels = trialLabel
        self.setidx = [0 for i in range(len(self.trialidx))]
        # NOTE(review): pre-change dirs handling (now done in DPObject).
        if dirs is None:
            self.dirs = [os.getcwd()]
        else:
            self.dirs = dirs

def append(self, raster):
DPObject.append(self, raster)
n_old = len(self.spiketimes)
n_new = n_old + len(raster.spiketimes)
self.spiketimes.resize(n_new)
self.spiketimes = np.resize(self.spiketimes, n_new)
self.spiketimes[n_old:n_new] = raster.spiketimes
self.trialidx.resize(n_new)
self.trialidx = np.resize(self.trialidx, n_new)
self.trialidx[n_old:n_new] = raster.trialidx

self.trial_labels = np.concatenate((self.trial_labels, raster.trial_labels))
self.trialLabels = np.concatenate((self.trialLabels, raster.trialLabels))

def plot(self, idx=None, ax=None, overlay=False):
if ax is None:
Expand Down
18 changes: 13 additions & 5 deletions DataProcessingTools/spiketrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,22 @@


class Spiketrain(DPObject):
    # NOTE(review): diff artifact -- the __init__ below is the removed
    # pre-change version; level/filename became class attributes.
    def __init__(self):
        DPObject.__init__(self)
        self.level = "cell"
        self.filename = "unit.mat"
    """
    Spiketrain()
    """
    level = "cell"
    filename = "unit.mat"

def __init__(self, *args, **kwargs):
DPObject.__init__(self, *args, **kwargs)
# always load since we do not create spike trains here.
if os.path.isfile(self.filename):
self.load()

def load(self):
def load(self, fname=None):
q = sio.loadmat(self.filename)
self.timestamps = q["timestamps"]
self.spikeshape = q["spikeForm"]

    def get_filename(self):
        # Spike train files are fixed inputs, so no argument hash is
        # appended (overrides DPObject.get_filename).
        return self.filename
Loading