In [7]:
import os
import os.path as op
import tables as tb
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import logging
import sys
import tempfile
import argparse
import scipy.io
import pylab as p
import matplotlib.cm as cm
import struct

from h5py import Dataset
from six import string_types
from six.moves import zip

import phy
from phy.io import KwikModel
from phy.io import create_kwik
from phy.utils._misc import _read_python
from phy.utils.logging import debug
from phy.utils.settings import _load_default_settings, _ensure_dir_exists
from attrdict import AttrDict
from phy.scripts.phy_script import _create_session
from phy.traces import Filter, Thresholder, compute_threshold, FloodFillDetector, WaveformExtractor, PCA
from phy.detect.spikedetekt import SpikeDetekt

In [21]:
def process_matdata(matfile, channelnumber, waveformstart, waveformend) :
    #This concatenates the waveforms and makes fake continuous data to be put into spikedetekt
    matdata = scipy.io.loadmat(matfile)
    
    #Find how many spikes there were on the channel to preallocate memory for waveforms and save time
    totalspikes = 0
    for key in matdata.keys() :
        if key.find('Wspk')!=-1 :
            if key.rsplit('Wspk', 1)[1][0:-3]==channelnumber :
                totalspikes += matdata[key].shape[1]

    wave_raw = np.zeros([totalspikes, waveformend-waveformstart])
    spike_indices = np.zeros([totalspikes])
    metadata = {}
    i=0
    for key in matdata.keys() :
        if key.find('Wspk')!=-1 :
            if key.rsplit('Wspk', 1)[1][0:-3]==channelnumber :
                for j in np.arange(matdata[key].shape[1]) :
                    metadata[i] = {}
                    #metadata[i]['episode'] = key.rsplit('Wspk', 1)[0][2:]
                    #metadata[i]['channel'] = key.rsplit('Wspk', 1)[1][0:-3]
                    #keyVspk = 'Ep'+metadata[i]['episode']+'Vspk'+metadata[i]['channel']+'Vu1'
                    metadata[i]['episode'] = key.rsplit('ep', 1)[1][2:]
                    metadata[i]['channel'] = key.rsplit('_', 1)[0][4:]
                    #keyVspk = 'Ep'+metadata[i]['episode']+'Vspk'+metadata[i]['channel']+'Vu1'
                    keyVspk = 'Vspk'+metadata[i]['channel']+'_ep'+metadata[i]['episode']
                    
                    metadata[i]['spiketime'] = matdata[keyVspk][j,0]
                    spike_indices[i] = matdata[keyVspk][j,0]
                    metadata[i]['episode_spikenumber'] = j
                    wave_raw[i,:] = matdata[key][waveformstart:waveformend,j]
                    i+=1
        
    cont_data = np.reshape(wave_raw, [np.size(wave_raw)])
    cont_data = cont_data.astype(np.int32)
    
    f = open('data.dat','wb')
    data = bytearray()
    for i in np.arange(len(cont_data)) :
        data.extend(struct.pack('l', cont_data[i]))
    f.write(data)
    f.close()

    return wave_raw, spike_indices, metadata
    # returning all waves, indices and metadata and saving the file for klusta


# just to save a file of spiketimes
def write_spikestimes(filename, spike_times):
    data = spike_times
    np.save(filename, data)

# make the files necessary to run klustakwik2, don't know for what    
def KK2_prepare_textfiles(kwik_path, n_features, channelnumber):
    filename, file_extension = op.splitext(kwik_path)
    model = KwikModel(kwik_path)
    features = model.features[:, :]
    times = model.spike_times
    times = np.expand_dims(times, axis=1)
    
    # make the text files to go into klustakwik2, fet.0 file and fmasks.0 file
    numfeatures = n_features+1
    data = np.concatenate((features, times), axis=1)

    np.savetxt(filename + '.fet.' + channelnumber, data, header = str(numfeatures), fmt = '%10.5f', comments='', delimiter=' ')

    dataMASKS = np.ones(data.shape)
    np.savetxt(filename + '.fmask.' + channelnumber, dataMASKS, header = str(numfeatures), fmt = '%10.5f', comments='', delimiter=' ')

    
# managing the clusterfile    
def add_clusters(clustfile, session, group):
    
    clusters = np.loadtxt(clustfile, dtype='int64', unpack=False)
    numclusters = clusters[0]+1
    clusters = clusters[1::]
    
    sc = clusters.astype(np.int32)
    session.model.creator.add_clustering(group=group, name='main', spike_clusters=sc)
    session.emit('open')

# new kwik session    
def initialize_session(paramsfile, overwrite, kwikfile):
    
    args = AttrDict({'file': paramsfile, 'overwrite': overwrite, 'kwikpath': kwikfile})
    
    params = _read_python(args.file)

    kwik_path = create_kwik(args.file, overwrite=args.overwrite, kwik_path=args.kwikpath)
    interval = None

    args.file = args.kwikpath
    session = _create_session(args, use_store=False)

    return dict(session=session, interval=interval)

# run spikedetect with "fake" data
def run_faux_detection(session):
    
    sd_dir = op.join(session.settings.exp_settings_dir, 'spikedetekt')
    _ensure_dir_exists(sd_dir)
    interval_samples = None
    traces = session.model.traces
    params = session.model.metadata
    params['probe_channels'] = session.model.probe.channels_per_group
    params['probe_adjacency_list'] = session.model.probe.adjacency
    debug("Running SpikeDetekt with the following parameters: " "{}.".format(params))
    
    sd = SpikeDetekt(tempdir=sd_dir, **params)
    out = sd.run_serial(traces, interval_samples=interval_samples)

    return sd, out, params


def replace_spikes(session, sd, out, wave_raw, group, n_channels, n_features):
    out.spike_samples[0] = np.linspace(0, wave_raw.shape[1]*wave_raw.shape[0], wave_raw.shape[0]+1, dtype='int')+16
    out.spike_samples[0] = out.spike_samples[0][0:-1] 
    spike_samples = out.spike_samples[0]
    n_spikes = len(spike_samples) if spike_samples is not None else 0
    spike_recordings = None
    out.masks[0] = np.ones([len(spike_samples),1], dtype='float32')
    wave_raw = np.expand_dims(wave_raw, axis=2)
    pcs = sd.waveform_pcs(wave_raw, out.masks[0])
    out.features[0] = sd.features(wave_raw, pcs)
    
    session.model.creator.add_spikes(group=group,
                                spike_samples=spike_samples,
                                spike_recordings=None,
                                masks=out.masks[0],
                                features=out.features[0],
                                n_channels=n_channels,
                                n_features=n_features,
                                )

In [22]:
Folder = '/media/matias/DATA/WORKSPACE2/EXP_2/Spike_Sort/FC7/'

matfile = Folder+'FC-151217-7_waves.mat'

channelnumber = '2'   ### or 1 depending where we start coubn t
waveformstart = 0
waveformend = 48

paramsfile= Folder + 'params.prm'
overwrite = True
kwik_path = Folder + 'data' + channelnumber +'.kwik'
clustfile = Folder + 'data.clu.'+ channelnumber

n_channels = 1

n_features = 3

group = 0

In [24]:
%%capture
# MAIN PROGRAM

# obtaining everything from the matfile
wave_raw, spike_indices, metadata = process_matdata(matfile, channelnumber, waveformstart, waveformend)

# saving spiketimes to a file
write_spikestimes('spike_times.npy', spike_indices/np.float(30000))

# changing data type of raw waves
wave_raw = wave_raw.astype(np.float32)

# new session in klusta
newsession = initialize_session(paramsfile, overwrite, kwik_path)

# running spikedetect with fake data
sd, out, params = run_faux_detection(newsession['session'])

# don't fucking now
replace_spikes(newsession['session'], sd, out, wave_raw, group, n_channels, n_features)

# just preparing stuff to run clustering
KK2_prepare_textfiles(kwik_path, n_features, channelnumber)

# running the clustering
%run C:\Miniconda3\envs\phy\Scripts\kk2_legacy-script.py data $channelnumber max_possible_clusters=6 num_starting_clusters=3 max_iterations=10000

11:46:16 [I] Saving a backup of the Kwik file in /media/matias/DATA/WORKSPACE2/EXP_2/Spike_Sort/FC7/data2.kwik.bak.


ValueError: cannot mmap an empty file

In [14]:
# adding clusters
add_clusters(clustfile, newsession['session'], group)

NameError: name 'newsession' is not defined

In [13]:
# trying stuff
wave_raw, spike_indices, metadata = process_matdata(matfile, channelnumber, waveformstart, waveformend)

In [3]:
clusters = np.loadtxt('data.clu.5', dtype='int64', unpack=False)
clusters = clusters[1::]
pcs = np.loadtxt('data.fet.'+channelnumber, dtype='float32', unpack=False, skiprows=1)

In [4]:
# checking stuff
np.unique(clusters)

array([1, 3, 4, 5, 6], dtype=int64)

In [8]:
# extracting features and waves

pcs1 = pcs[np.where(clusters==1)[0],:]
pcs3 = pcs[np.where(clusters==3)[0],:]
pcs4 = pcs[np.where(clusters==4)[0],:]
pcs5 = pcs[np.where(clusters==5)[0],:]
pcs6 = pcs[np.where(clusters==6)[0],:]

waves1 = wave_raw[np.where(clusters==1)[0],:]
waves3 = wave_raw[np.where(clusters==3)[0],:]
waves4 = wave_raw[np.where(clusters==4)[0],:]
waves5 = wave_raw[np.where(clusters==5)[0],:]
waves6 = wave_raw[np.where(clusters==6)[0],:]

In [9]:
# plot the clusters

fig = plt.figure()
ax = Axes3D(fig)
ax.scatter(pcs1[:,0], pcs1[:,1], pcs1[:,2], color = 'b')
ax.scatter(pcs3[:,0], pcs3[:,1], pcs3[:,2], color = 'r')
ax.scatter(pcs4[:,0], pcs4[:,1], pcs4[:,2], color = 'g')
ax.scatter(pcs5[:,0], pcs5[:,1], pcs5[:,2], color = 'k')
ax.scatter(pcs6[:,0], pcs6[:,1], pcs6[:,2], color = 'y')
plt.show()

In [12]:
# Plot the waves

for i in np.arange(waves1.shape[0]) :
    plt.plot(waves1[i,:], 'b')
#for i in np.arange(waves3.shape[0]) :
    #plt.plot(waves3[i,:], 'r')
#for i in np.arange(500) :
    #plt.plot(waves4[i,:], 'g')
#for i in np.arange(500) :
    #plt.plot(waves6[i,:], 'y')
plt.show()