In [None]:
import spikeextractors as se
import spiketoolkit as st
import spikewidgets as sw
import edlio
import matplotlib.pyplot as plt
import numpy as np
%matplotlib notebook

In [None]:
import os

# Path to the CED Spike2 (.smrx) recording analysed in this notebook.
# Set the CED_FILE environment variable to point at your own copy instead of
# editing a hard-coded local path; the default below is the original example.
ced_file = os.environ.get(
    "CED_FILE",
    "D:/CED_example_data/Other example/m365_pt1_590-1190secs-001.smrx",
)

## Load TTL signals and extract triggers

In [None]:
# Read per-channel metadata from the .smrx file; the cells below use each
# entry's "title" and "rate" fields to decide which channels to load.
channel_info = se.CEDRecordingExtractor.get_all_channels_info(ced_file)

In [None]:
# Select the TTL channels (any channel whose title contains "TTL"), report
# them, and load them together into one recording extractor.
smrx_channels = [ch for ch, info in channel_info.items() if 'TTL' in info["title"]]
for ch in smrx_channels:
    print("Loading", channel_info[ch]["title"])

rec = se.CEDRecordingExtractor(ced_file, smrx_channels)

In [None]:
# Load the samples of the selected TTL channels. Below, each row of `traces`
# is handled as one channel (rows are iterated for edge detection and
# `traces.T` is plotted).
traces = rec.get_traces()

In [None]:
# Overview of all TTL channels (one line per channel) before edge detection.
fig, ax = plt.subplots()
ax.plot(traces.T)
ax.set(title="TTL traces", xlabel="Sample index", ylabel="Amplitude");

### Find rising and falling edges

In [None]:
def detect_ttl_edges(trace):
    """Find the rising and falling edges of one TTL-like trace.

    The trace is binarised at the midpoint between its minimum and maximum.

    Parameters
    ----------
    trace : array-like, 1D
        Samples of a single channel.

    Returns
    -------
    edges : np.ndarray
        Sorted sample indices of all edges. Each index ``i`` is the last sample
        before the level change, i.e. samples ``i`` and ``i + 1`` lie on
        opposite sides of the threshold.
    state : np.ndarray
        Same length as ``edges``: ``1`` marks a rising edge, ``-1`` a falling edge.
    """
    trace = np.asarray(trace)
    # Midpoint computed as a Python float: np.ptp keeps the input dtype, so on
    # int8/int16 traces a full-range swing wraps to a negative value and the
    # threshold would end up below the minimum (no edges detected).
    threshold = (float(np.max(trace)) + float(np.min(trace))) / 2
    above = (trace > threshold).astype('int8')
    steps = np.diff(above)

    rising = np.flatnonzero(steps > 0)
    falling = np.flatnonzero(steps < 0)

    edges = np.concatenate((rising, falling))
    kinds = np.concatenate((np.ones(len(rising), dtype=int),
                            np.full(len(falling), -1, dtype=int)))
    # A rising and a falling edge can never share an index, so one argsort
    # orders both arrays consistently.
    order = np.argsort(edges)
    return edges[order], kinds[order]


ttls = []
states = []
for tr in traces:
    ttl, state = detect_ttl_edges(tr)
    ttls.append(ttl)
    states.append(state)

In [None]:
# plot 
for i, tr in enumerate(traces):
    fig, ax = plt.subplots()
    
    ax.plot(tr)
    state = states[i]
    ttl = ttls[i]
    
    for t in ttl[state == 1]:
        plt.axvline(t, color='r', alpha=0.5)
    for t in ttl[state == -1]:
        plt.axvline(t, color='g', alpha=0.5)

The `Conditions` arrays in the .mat files are simply the TTL pulses, with each rising edge and its matching falling edge in the same row:

In [None]:
# Build one pulse table per TTL channel: one row per pulse with columns
# (rising-edge sample, falling-edge sample).
conditions = []

for ttl, state in zip(ttls, states):
    rising = ttl[state == 1]
    falling = ttl[state == -1]
    assert len(rising) == len(falling), "Different number of rising/falling edges"
    # Equal counts are not sufficient: a trace that starts AND ends high yields
    # a falling edge first and a rising edge last, which would pair each rising
    # edge with the previous pulse's falling edge.
    assert np.all(rising < falling), "Falling edge precedes its rising edge (trace starts high?)"

    condition = np.zeros((len(rising), 2), dtype='int')
    condition[:, 0] = rising
    condition[:, 1] = falling

    conditions.append(condition)

In [None]:
# Pulse table of the first TTL channel: one row per pulse, (rising, falling) sample indices.
conditions[0]

## Save other signals to NWB

Non-TTL signals can be saved to NWB in the usual way. The `Keyboard` channel is skipped because it causes problems for reasons not yet understood (TODO: investigate).

Also, a single extractor can only hold channels that share the same sampling frequency; here only channels sampled above `sampling_rate_limit` are loaded together.

In [None]:
# Load every non-TTL channel (except 'Keyboard') whose sampling rate exceeds
# `sampling_rate_limit`; slower channels are reported and skipped.
smrx_channels = []
smrx_channels_names = []

sampling_rate_limit = 15000

for ch, info in channel_info.items():
    if 'TTL' in info["title"] or info["title"] == 'Keyboard':
        continue
    if info['rate'] > sampling_rate_limit:
        print("Loading", info["title"])
        smrx_channels.append(ch)
        smrx_channels_names.append(info["title"])
    else:
        print("Skipped", info["title"])
        print(info)

# One extractor can only hold a single sampling frequency, and the rate filter
# above does not guarantee that; fail early with a readable message.
selected_rates = {channel_info[ch]['rate'] for ch in smrx_channels}
assert len(selected_rates) <= 1, f"Selected channels have different sampling rates: {sorted(selected_rates)}"

rec = se.CEDRecordingExtractor(ced_file, smrx_channels)

In [None]:
# Visual check of the loaded non-TTL channels over trange=[20, 30]
# (NOTE(review): trange is presumably in seconds — confirm against spikewidgets docs).
sw.plot_timeseries(rec, trange=[20, 30])

These can be saved directly to NWB as an `ElectricalSeries`.