In [21]:
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from scipy.signal import lfilter

import abt

def plot_weights_one_fig(weights, pulse_width=18e-6):
    """Plot every channel's weight trace as a stacked step curve in one figure.

    Parameters
    ----------
    weights : ndarray
        2-D array of weights. Traces are read along axis 0 (``weights[channel]``);
        the smaller dimension is taken as the channel count and the larger as the
        number of time steps, so the input is assumed to be
        (num_channels, num_steps) with num_channels < num_steps.
    pulse_width : float, default 18e-6
        Duration of one time step in seconds.

    Returns
    -------
    matplotlib.figure.Figure
        The figure containing the stacked traces.
    """
    num_chan = min(weights.shape)
    num_steps = max(weights.shape)
    fig, ax = plt.subplots()

    # Step k starts exactly at k * pulse_width. (linspace over num_steps points up
    # to num_steps * pulse_width would space samples slightly wider than pulse_width.)
    t_vector = np.arange(num_steps) * pulse_width
    colors = matplotlib.colormaps["tab20"]
    for channel in range(num_chan):
        # Negative weights are clipped to 0; each trace is lifted by channel + 1
        # so the channels stack vertically without overlapping baselines.
        ax.step(
            t_vector,
            channel + weights[channel].clip(0) + 1,
            color=colors(channel),
            where="post",
        )

    # Enable the grid once, explicitly: a bare grid() call toggles visibility,
    # so calling it once per channel hid the grid for an even channel count.
    ax.grid(True)
    ax.set_xlim(0, num_steps * pulse_width)
    ax.set_xlabel("Time [s]")
    ax.set_ylabel("Electrode")
    return fig

name = "tone_1kHz"
# Build the electrodogram for the named test sound with current steering enabled.
# pulse_train and weights_matrix are inspected and transformed by later cells.
(pulse_train, weights_matrix, audio_signal) = abt.wav_to_electrodogram(
    abt.sounds[name],
    current_steering=True,
)

In [12]:
np.shape(weights_matrix)  # (num_electrodes, num_samples), as unpacked by the transform below

(16, 111090)

In [13]:
# Raw electrodogram. The truncated repr only shows the (all-zero) corners;
# pulse_train.shape or np.count_nonzero(pulse_train) is more informative.
pulse_train

array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]])

In [5]:
# Steering fractions from 0 to 1 in steps of 1/8 (9 values); presumably the
# same fractions matched by transform_pulse_train_to_121_virtual — confirm.
# np.linspace fixes the endpoints and count explicitly, avoiding the float-step
# pitfalls of np.arange(0, 1.1, .125) and its magic 1.1 stop.
channels = np.linspace(0.0, 1.0, 9)
channels

array([0.   , 0.125, 0.25 , 0.375, 0.5  , 0.625, 0.75 , 0.875, 1.   ])

In [28]:

def transform_pulse_train_to_121_virtual(pulse_train, weights_matrix, steps_per_pair=8):
    """Map a current-steered electrodogram onto virtual channels as biphasic pulses.

    Every negative sample of ``pulse_train`` is treated as a pulse. For the
    electrode pair it belongs to (``[el, el + 1]``, or ``[n - 2, n - 1]`` for the
    last electrode), the pair's weights at that sample are matched against the grid
    ``[1 - k / steps_per_pair, k / steps_per_pair]`` for ``k = 0..steps_per_pair``.
    A match writes the summed amplitude of both electrodes to virtual channel
    ``lower_electrode * steps_per_pair + k``. Pulses whose weights match no grid
    point are skipped; this includes the partner electrode of a pair already
    handled from the other side. Finally each row is filtered with ``[1, -1]``, so
    every negative phase is followed by an equal positive phase one sample later.

    Parameters
    ----------
    pulse_train : ndarray, shape (num_electrodes, num_samples)
        Per-electrode amplitudes; negative values mark pulses.
    weights_matrix : ndarray, shape (num_electrodes, num_samples)
        Per-electrode current-steering weights, same shape as ``pulse_train``.
    steps_per_pair : int, default 8
        Number of steering steps between adjacent electrodes. With 16 electrodes
        and the default of 8 this gives the 121 virtual channels in the name.

    Returns
    -------
    ndarray, shape ((num_electrodes - 1) * steps_per_pair + 1, num_samples)
        Biphasic virtual-channel pulse train.

    Raises
    ------
    ValueError
        If the inputs differ in shape, fewer than two electrodes are given, or
        ``steps_per_pair`` is not positive.
    """
    pulse_train = np.asarray(pulse_train)
    weights_matrix = np.asarray(weights_matrix)
    if pulse_train.shape != weights_matrix.shape:
        raise ValueError(
            f"pulse_train shape {pulse_train.shape} != weights_matrix shape {weights_matrix.shape}"
        )
    num_electrodes, num_samples = weights_matrix.shape
    if num_electrodes < 2:
        raise ValueError("current steering needs at least two electrodes")
    if steps_per_pair < 1:
        raise ValueError("steps_per_pair must be a positive integer")
    num_virtual = (num_electrodes - 1) * steps_per_pair + 1

    # Row k holds the expected [lower, upper] weights for steering step k.
    upper_share = np.arange(steps_per_pair + 1) / steps_per_pair
    steering_grid = np.stack([1.0 - upper_share, upper_share], axis=1)

    pulse_times, pulse_electrodes = np.where(pulse_train.T < 0)
    virtual_train = np.zeros((num_virtual, num_samples))
    for el in range(num_electrodes):
        times = pulse_times[pulse_electrodes == el]
        if times.size == 0:
            continue
        # The last electrode has no upper neighbour, so pair it with the one below.
        lower = min(el, num_electrodes - 2)
        pair = [lower, lower + 1]
        pair_weights = weights_matrix[np.ix_(pair, times)]  # (2, num_pulses)
        # np.isclose instead of exact ==, so float rounding noise in the weights
        # does not silently drop pulses. Grid rows are 1/steps apart, so at most
        # one row can match a given pulse.
        matches = np.isclose(steering_grid[:, :, None], pair_weights[None, :, :]).all(axis=1)
        matched = matches.any(axis=0)
        steps = matches.argmax(axis=0)[matched]
        hit_times = times[matched]
        # Electrodes are visited in ascending order, so a later electrode mapping to
        # the same (channel, time) overwrites the value written by an earlier one.
        virtual_train[lower * steps_per_pair + steps, hit_times] = (
            pulse_train[np.ix_(pair, hit_times)].sum(axis=0)
        )

    # y[n] = x[n] - x[n-1] along time: pulses are already negative-first, so this
    # appends the opposite-polarity second phase.
    kernel = np.array([1, -1])
    return lfilter(kernel, 1, virtual_train, axis=-1)

transform_pulse_train_to_121_virtual(pulse_train=pulse_train, weights_matrix=weights_matrix)


[0.875 0.   ] [np.int64(1), np.int64(2)]
[0.875 0.   ] [np.int64(1), np.int64(2)]
[0.875 0.   ] [np.int64(1), np.int64(2)]
[0.875 0.   ] [np.int64(1), np.int64(2)]
[0.875 0.   ] [np.int64(1), np.int64(2)]
[0.875 0.   ] [np.int64(1), np.int64(2)]
[0.875 0.   ] [np.int64(1), np.int64(2)]
[0.875 0.   ] [np.int64(1), np.int64(2)]
[0.875 0.   ] [np.int64(1), np.int64(2)]
[0.875 0.   ] [np.int64(1), np.int64(2)]
[0.875 0.   ] [np.int64(3), np.int64(4)]
[0.875 0.   ] [np.int64(3), np.int64(4)]
[0.875 0.   ] [np.int64(3), np.int64(4)]
[0.875 0.   ] [np.int64(3), np.int64(4)]
[0.875 0.   ] [np.int64(3), np.int64(4)]
[0.875 0.   ] [np.int64(3), np.int64(4)]
[0.875 0.   ] [np.int64(3), np.int64(4)]
[0.875 0.   ] [np.int64(3), np.int64(4)]
[0.875 0.   ] [np.int64(3), np.int64(4)]
[0.875 0.   ] [np.int64(3), np.int64(4)]
[0.875 0.   ] [np.int64(3), np.int64(4)]
[0.75 0.  ] [np.int64(6), np.int64(7)]
[0.75 0.  ] [np.int64(6), np.int64(7)]
[0.75 0.  ] [np.int64(6), np.int64(7)]
[0.625 0.   ] [np.int6

array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]])