In [None]:
import numpy as np
import matplotlib.pyplot as plt
import adaptive_latents as al
import tensortools as tt
from tensortools.cpwarp import fit_shifted_cp
from sklearn.cluster import SpectralClustering
import scipy

from naumann_utility_functions import make_responses, find_decompositions, reshape_by_group, compare_rows, plot_comparison, plot_per_neuron

# Seed the generator so the synthetic background noise added to model predictions later in
# the notebook is reproducible. (Presumably this does not control the tensor decompositions,
# which the notes below say are hard to seed.)
SEED = 0
rng = np.random.default_rng(SEED)


In [None]:
# Load one sub-dataset and build the per-stimulus response tensor
# (later cells index it as [stimulus, timepoint, neuron]).
d = al.datasets.Naumann24uDataset(sub_dataset_identifier=2)
responses = make_responses(d)

# Keep only the optically recorded neurons (the first n_neurons_in_optical columns).
non_nan_responses = responses[..., :d.n_neurons_in_optical]

# Many random restarts of the decomposition; the first model is used as the best one
# (presumably the list is sorted best-first, as the in-notebook redefinition does — confirm).
models = find_decompositions(non_nan_responses, n_restarts=200)
model = models[0]

In [None]:
%matplotlib inline
template = [[1.]]
convolved = scipy.signal.convolve2d(template, d.neural_data[:,:d.n_neurons_in_optical], mode='valid')
plt.matshow(convolved.T, aspect='auto')
# 83


In [None]:
%matplotlib inline
template = model.factors[2].T
convolved = scipy.signal.convolve2d(template, d.neural_data[:,:d.n_neurons_in_optical], mode='valid')
plt.matshow(convolved.T, aspect='auto')

In [None]:
%matplotlib inline
signal = 100 * np.array(d.neural_data[50:,75])
divisor = 100 * model.factors[2].flatten()

divisor = divisor[np.cumsum(np.abs(divisor)) > 1]
quotient, remainder = scipy.signal.deconvolve(signal, divisor)

fig, axs = plt.subplots(2,2)
axs[0,0].plot(signal)
axs[0,1].plot(divisor)

axs[1,0].plot(quotient)
axs[1,0].set_ylim([-10,10])
axs[1,1].plot(remainder)



In [None]:
# Quick look at the fitted temporal profile of the best decomposition.
fig, ax = plt.subplots()
ax.plot(model.factors[2].flatten())
ax.set_title("fitted temporal profile (model.factors[2])")
ax.set_xlabel("samples from stim");

In [None]:
%matplotlib inline
template = model.factors[2].T @ model.factors[0]
convolved = scipy.signal.convolve2d(template, d.neural_data[:,:d.n_neurons_in_optical], mode='valid')
plt.plot(convolved)
for s in d.opto_stimulations['sample']:
    plt.axvline(s, color='g', alpha=.5)

for s in d.visual_stimuli['sample']:
    plt.axvline(s, color='r', alpha=.5)


In [None]:

# Rebuild the response tensor for the current dataset and split it by stimulation group
# (responses_by_group[g] holds the stimuli of group g; later cells index it that way).
responses = make_responses(d)
responses_by_group = reshape_by_group(responses, d)


## Initial Plot
My initial question was whether the stimulations influenced neural activity, so I inspected the correspondence between stimulation times and neural activity (the C matrix) using a heatmap.

In [None]:
# Heatmap of every neural trace (rows = neurons, columns = samples), zoomed to samples
# 1000-1500 and neurons 0-300. The last visual stimulus is marked in red and each
# optogenetic stimulus in white.
fig, ax = plt.subplots(figsize=(10,10))
ax.matshow(d.neural_data.T)

red_line = ax.axvline(d.last_visual_sample, color='r')
legend_handles = [red_line]
legend_labels = ['last visual stimulus']

# Only one opto line is needed as the legend handle. Start from None so that a dataset with
# no opto stimulations gets a legend without that entry. Previously this raised NameError,
# or reused a stale line left over from an earlier run.
white_line = None
for x in d.opto_stimulations['sample']:
    white_line = ax.axvline(x, color='w', alpha = .5)
if white_line is not None:
    legend_handles.append(white_line)
    legend_labels.append('optogenetic stimuli')

ax.legend(legend_handles, legend_labels)
ax.set_xlabel("sample")
ax.set_ylabel("neuron")

ax.set_xlim([1000, 1500])
ax.set_ylim([300, 0]);


### Result

Checking the image of the neural traces, it seems reasonable that some of the neural traces (especially around 75) are responding to the stimuli.
I was next interested in the consistency of the responses to multiple stimulations of the same neuron.

## Extract responses
To check the consistency of the neural responses to a single group of stimuli, I first extracted the neural activity after each stimulus.
For each stimulus, I defined the response period to be as long as possible while staying between stimuli, so 30 samples for 'output_012824_ds3'.
'Responses' were then the neural activity during the response period, baseline-subtracted by the activity at the first sample (so every response starts at 0).
Pseudocode with the `C` matrix would look something like this: `responses[...] = C[:, stim_start:stim_start + 30] - C[:, stim_start]`.


## Some neural responses are consistent
To compare the responses of a group of stimulations to the same neuron, I plotted the 'response' fluorescence traces.


In [None]:


def plot_single_group_responses(group_n, responses_by_group):
    """Plot every stimulus response in one stimulation group, plus the group average.

    Layout:
    - Top row: one panel per stimulus in the group, with all neurons' response traces.
    - Bottom row: the group-average traces (spanning the first two columns) and, in the
      fifth column, the per-neuron magnitude map drawn by ``make_stim_subplot``.
    - Extra top-row columns (when the group has fewer than 5 stimuli) are hidden.

    Reads the module-level dataset ``d`` (stimulation table and neuron positions).

    Parameters
    ----------
    group_n : int
        Index of the stimulation group to plot.
    responses_by_group : sequence of arrays
        Per-group response arrays, each indexed [stimulus, timepoint, neuron].
    """
    target_neurons = np.squeeze([np.unique(x) for x in reshape_by_group(d.opto_stimulations.target_neuron, d)])
    assert abs(len(responses_by_group) - len(target_neurons)) <= 1

    single_group_responses = responses_by_group[group_n]
    n_stimuli = single_group_responses.shape[0]

    # The bottom row needs at least 5 columns: the mean spans columns 0-1 and the per-neuron
    # map lives in column 4. Previously ncols = max(n_stimuli, 2), which raised IndexError
    # for groups with fewer than 5 stimuli.
    n_cols = max(n_stimuli, 5)
    fig, axs = plt.subplots(ncols=n_cols, nrows=2, figsize=(10, 5), sharey='row')
    gs = axs[1, 0].get_gridspec()
    for ax in axs[1, :4]:
        ax.remove()
    axs[1, 0] = fig.add_subplot(gs[1, :2])

    # Global index of this group's first stimulus, used for titles and table lookups.
    stimuli_in_previous_groups = sum(r.shape[0] for r in responses_by_group[:group_n])
    for i, ax in enumerate(axs[0, :]):
        if i >= n_stimuli:
            # Padding column: the group has fewer stimuli than there are columns.
            ax.axis('off')
            continue
        stimulus_number = stimuli_in_previous_groups + i
        ax.plot(single_group_responses[i, :, :])
        ax.set_xlabel('samples from stim')
        ax.set_title(f'stim {stimulus_number}, neuron = {d.opto_stimulations.loc[stimulus_number,"target_neuron"]}', fontsize='small')

    mean_responses = single_group_responses.mean(axis=0)
    axs[1, 0].plot(mean_responses)

    axs[0, 0].set_ylabel('response magnitude (a.u.)')

    axs[1, 0].set_ylabel('response magnitude (a.u.)')
    axs[1, 0].set_title(f'average response for group {group_n}')
    axs[1, 0].set_xlabel('samples from stim')

    # Reuse the same per-neuron map as the overview grid instead of duplicating its code.
    make_stim_subplot(axs[1, 4], group_n, responses_by_group)
    fig.tight_layout()

    print(f"Neuron with the highest average peak: {np.unravel_index(np.nanargmax(mean_responses), mean_responses.shape)[1]}")

def make_stim_subplot(ax, group_n, responses_by_group):
    """Draw one stimulation group's per-neuron response map onto ``ax``.

    Each neuron is passed to ``plot_per_neuron`` with a size proportional to the magnitude of
    its response, averaged over the group's stimuli and then over time. Sizes below 5 are
    set to NaN (presumably hidden by ``plot_per_neuron``). The group's target neuron is
    marked with a small blue dot.

    Reads the module-level dataset ``d``.
    """
    per_group_targets = reshape_by_group(d.opto_stimulations.target_neuron, d)
    target_neurons = np.squeeze([np.unique(targets) for targets in per_group_targets])
    assert abs(len(responses_by_group) - len(target_neurons)) <= 1

    group_responses = responses_by_group[group_n]
    target = target_neurons[group_n]

    # Average over stimuli, then over time -> one scalar per neuron.
    per_neuron_mean = np.mean(group_responses.mean(axis=0), axis=0)
    marker_sizes = np.abs(per_neuron_mean / 7.3) * 15
    marker_sizes[marker_sizes < 5] = np.nan
    plot_per_neuron(ax, marker_sizes, d)

    target_x = d.neuron_df.loc[target, 'x']
    target_y = d.neuron_df.loc[target, 'y']
    ax.scatter(target_x, target_y, s=10, color='blue')


In [None]:
# Hand-picked stimulation groups to inspect in detail, per sub-dataset.
_GROUPS_TO_INSPECT_BY_DATASET = {
    'output_012824_ds6_fish3': [5,7,13],
    'output_012824_ds3': [4,8,11],
    'output_020424_ds1': [9,14,19],
}
# Fail here with a clear message for an unlisted sub-dataset. The previous `.get` returned
# None, which only failed later as "'NoneType' object is not subscriptable".
if d.sub_dataset not in _GROUPS_TO_INSPECT_BY_DATASET:
    raise KeyError(f"no groups_to_inspect configured for sub-dataset {d.sub_dataset!r}")
groups_to_inspect = _GROUPS_TO_INSPECT_BY_DATASET[d.sub_dataset]


In [None]:
%matplotlib inline
# Detailed view of the first hand-picked group (see groups_to_inspect).
plot_single_group_responses(group_n=groups_to_inspect[0], responses_by_group=responses_by_group)

Some neurons showed consistent responses; group 5, stimulating neuron 107, is a good example of a consistent response.
This group of responses displays variability, but there is clearly a pattern for the average to pick up on; particularly the red trace (neuron 83).
Note, however, that the most consistently active neuron (83) is not the neuron that was targeted for stimulation (107).
This plot works for identifying consistent large-magnitude responses, but a limitation I haven't addressed is that it will not clearly show low-magnitude (but still consistent) responses.


In [None]:
# Same detailed view for the second and third hand-picked groups.
plot_single_group_responses(group_n=groups_to_inspect[1], responses_by_group=responses_by_group)
plot_single_group_responses(group_n=groups_to_inspect[2], responses_by_group=responses_by_group)

In [None]:
# Overview: one per-neuron response map per stimulation group, 5 per row.
# At least a 5x5 grid (the previous fixed size), growing with extra rows when there are more
# than 25 groups. The fixed grid used to drop those groups silently.
n_overview_cols = 5
n_overview_rows = max(5, int(np.ceil(len(responses_by_group) / n_overview_cols)))
fig, axs = plt.subplots(n_overview_rows, n_overview_cols, layout='tight')
for i, ax in enumerate(axs.flatten()):
    if i < len(responses_by_group):
        make_stim_subplot(ax, group_n=i, responses_by_group=responses_by_group)
    else:
        ax.axis('off')

## Comparison matrices
To visualize more completely how the responses within groups compare, I next constructed comparison matrices.
I was mostly looking to see if the responses within groups were more similar than responses between groups; this would 
indicate that targeting specific neurons had consistent, distinguishable effects.


In [None]:
# Compare the three hand-picked groups' responses at a single timepoint,
# once with each available comparison metric.
timepoint = 14
responses_sub_mat = np.vstack(
    [responses_by_group[i] for i in groups_to_inspect]
)[:, timepoint, :d.n_neurons_in_optical]

comparison_methods = ['distances', 'angles', 'norm_distances']

fig, axs = plt.subplots(ncols=3, figsize=(10,5))

for ax, method in zip(axs, comparison_methods):
    plot_comparison(
        ax,
        compare_rows(responses_sub_mat, method=method),
        group_sizes=[responses_by_group[i].shape[0] for i in groups_to_inspect],
        group_names=[chr(ord('A') + i) for i in groups_to_inspect],
    )
    ax.set_title(f'Comparison using {method}')

fig.tight_layout()


Next, I wanted to see if the responses within groups were more similar than the responses between groups.
From the traces in the previous section, it appears that group 5 has more internal consistency than groups 7 and 13, which I want to quantify here.
The distance metric that ended up being most useful for this comparison is the angle (`method='angles'` above) between responses at a given timepoint (here, timepoint 14).
I tried using distances at first, but I think that metric was too sensitive to bulk changes in excitation, like in the 3rd trial of group 13.

I interpret the comparisons above to mean that stimuli in group 5 had a consistent response, whereas groups 7 and 13 had varying responses that did not align.
(I cherry-picked these from the next graph.)

In [None]:
# Angle between every pair of stimulus responses at one timepoint, across all groups.
timepoint = 13
non_nan_responses = responses[..., :d.n_neurons_in_optical]

# The narrow second axis holds the colorbar.
fig, axs = plt.subplots(ncols=2, figsize=(7,7), gridspec_kw={'width_ratios': [20, 1]})

comparison_matrix = compare_rows(non_nan_responses[:, timepoint, :], method='angles')
plot_comparison(
    axs[0],
    comparison_matrix,
    ax_colorbar=axs[1],
    group_sizes=[g.shape[0] for g in responses_by_group],
)

fig.tight_layout()


In [None]:

# Rows/columns of the all-stimuli comparison matrix for the trials annotated in the next
# figure. Computed here because this cell used to display a variable that is only defined
# further down, which raised NameError under Restart & Run All.
# Trials are named <group letter><index within group>, e.g. "N2" = 3rd stimulus of group N.
trials_to_mark = "N2 C0 L2 D2 I3 G0 N3 C1 G1".split(' ')
group_offsets = np.concatenate([[0], np.cumsum([g.shape[0] for g in responses_by_group])[:-1]])
trials_indexes_to_mark = [int(group_offsets[ord(g) - ord('A')]) + int(i) for g, i in trials_to_mark]
trials_indexes_to_mark

In [None]:
# Same all-pairs angle comparison as above, with selected trials labelled on the axes.
timepoint = 13
non_nan_responses = responses[...,:d.n_neurons_in_optical]

fig, axs = plt.subplots(ncols=2, gridspec_kw={'width_ratios': [20, 1]}, figsize=(7,7))

comparison_matrix = compare_rows(non_nan_responses[:,timepoint,:], method='angles')
group_sizes = [g.shape[0] for g in responses_by_group]
plot_comparison(axs[0], comparison_matrix, group_sizes=group_sizes, ax_colorbar=axs[1])


# Trials are named <group letter><index within group>, e.g. "N2" = 3rd stimulus of group N.
trials_to_mark = "N2 C0 L2 D2 I3 G0 N3 C1 G1".split(' ')
# Convert names to matrix indices using the actual group sizes. The previous
# `(ord(g) - 65) * 5 + int(i)` assumed every group had exactly 5 stimuli.
group_offsets = np.concatenate([[0], np.cumsum(group_sizes)[:-1]])
trials_indexes_to_mark = [int(group_offsets[ord(g) - ord('A')]) + int(i) for g, i in trials_to_mark]

axs[0].set_xticks(trials_indexes_to_mark)
axs[0].set_xticklabels(trials_to_mark, rotation=90, size='small')
axs[0].set_yticks(trials_indexes_to_mark)

axs[0].set_xlabel("")
axs[0].set_ylabel("")

fig.tight_layout()


This is a similar graph, but now considering all pairs of stimuli.
Note that much of the variation centers around 90°, as we might expect: in a high-dimensional neuron space, two responses that share no structure are close to orthogonal.
In the next plot, I will center the color axis range on 90° to give the colorscale more dynamic range.

Note that for many groups, the within-group similarity is similar to across-group similarities.
For example, the reactions to stimuli in group 4 are about as similar to each other as they are to reactions in group 5.
If this were a covariance matrix, I would immediately suspect a single low-rank principal component that most groups are aligned to.
But, since I'm looking at angles instead of covariance, I need to check more explicitly in the next section.
Having one dominating response component makes me wonder what's going on biologically, though.

Also note that groups 7 and 13 don't seem to align to the others; this makes me suspect that those stimulations were different somehow.
Since their angles are all about 90°, I wonder if those stimuli failed somehow.

Next, I wanted to check how these relationships vary as a function of time since the stimulus delivery, which led to the animation below.
The animation creates comparison matrices very similar to the one above, but for varying timepoints.


In [None]:
%matplotlib inline
# Animate the all-pairs comparison matrix over time since the stimulus, one frame per timepoint.
# vmin/vmax of 60-120 centre the colour scale on 90; this presumes compare_rows' default
# method is the angle in degrees — TODO confirm.
fig, axs = plt.subplots(ncols=2, gridspec_kw={'width_ratios': [20, 1]}, figsize=(10,10))
with al.plotting_functions.AnimationManager(fig=fig, make_axs=False, fps=10, filetype='gif') as am:
    for step in range(responses.shape[1]):
        # Clear the matrix axis each frame. NOTE(review): the colorbar axis axs[1] is never
        # cleared, so whether colorbars accumulate depends on plot_comparison — confirm.
        axs[0].cla()
        
        comparison_matrix = compare_rows(responses[:,step,:d.n_neurons_in_optical])
        plot_comparison(axs[0], comparison_matrix, vmin=60, vmax=120, ax_colorbar=axs[1], group_sizes=[g.shape[0] for g in responses_by_group])
        axs[0].set_title(f"{step = }")
        
        am.grab_frame()

My biggest conclusion from the animation is that there is a temporal organization to the response angle convergence. 
Later stimuli seem to reach peak convergence before the earlier stimuli do; this is what the wave-like pattern in the animation shows.

Another way to show this is to extract the dominant component of the neural activity at a time we know there is a response (timestep 14), and
project all of the responses over time onto this neural component.
Assuming the neurons active in the response are stable over time (which I checked), this projection shows the ramp in response for each trial. 
In the graph below, we can see that later stimulations had faster responses, as the animation suggested.

In [None]:
# Dominant neural pattern at timepoint 14: the first right-singular vector of the
# (stimulus x neuron) response matrix. The data are not mean-centred, so this is not strictly PCA.
_, _, vh = np.linalg.svd(non_nan_responses[:, 14, :], full_matrices=False)
svd_neural_component = vh[0]
# Singular vectors are only defined up to sign. Orient the component so its loadings sum to a
# positive value. A flipped vector would make the `< .001` mask below hide every
# positively-loaded neuron and would invert the projection.
if svd_neural_component.sum() < 0:
    svd_neural_component = -svd_neural_component
# Projection of each stimulus' response onto the component, transposed to (timepoint, stimulus).
activation_of_neural_component = (non_nan_responses @ svd_neural_component).T

fig, axs = plt.subplots(ncols=2, figsize=(10,5))
axs[0].bar(np.arange(svd_neural_component.size), svd_neural_component)
axs[0].set_title("timestep 14 dominant neural pattern")
axs[0].set_xlabel("neuron #")
scatter_sizes = svd_neural_component.copy()
scatter_sizes[scatter_sizes < .001] = np.nan
scatter_sizes = scatter_sizes * 40
plot_per_neuron(axs[1], scatter_sizes, d)


In [None]:

fig, ax = plt.subplots(figsize=(10,5))
ax.matshow(activation_of_neural_component)
ax.set_xlabel("stimulation")
ax.set_ylabel("timepoint")
ax.set_title("Projection of neural activity onto timestep 14's dominant activity pattern")

# White lines separate stimulation groups. Each group gets one letter tick, centred under
# its columns.
group_sizes = [g.shape[0] for g in responses_by_group]
group_edges = np.cumsum(group_sizes) - .5
for boundary in group_edges:
    ax.axvline(boundary, color='w', lw=.5)
ax.set_xticks(
    [edge - size / 2 for size, edge in zip(group_sizes, group_edges)],
    [chr(ord('A') + k) for k in range(len(group_edges))],
);


In [None]:

# Unlike `responses` (fixed-length windows), take each stimulus' full inter-stimulus recording:
# - each segment runs from the stimulus onset to the next stimulus (the last one runs to the
#   end of the recording);
# - it is baseline-subtracted by its first sample;
# - segments are NaN-padded to a common length.
stim_samples = list(d.opto_stimulations['sample'])
# The slicing below requires onsets in ascending order. The previous
# `d.opto_stimulations.sort_values(by='time')` discarded its result (a no-op), so check the
# order explicitly. Sorting here would also desynchronise the columns from the group
# boundaries drawn below.
if np.any(np.diff(stim_samples) < 0):
    raise ValueError("opto_stimulations must be in ascending sample order")
borders = stim_samples + [d.neural_data.shape[0]]
segments = [d.neural_data[start:stop, :d.n_neurons_in_optical] for start, stop in zip(borders[:-1], borders[1:])]

max_shape = np.max([s.shape for s in segments], axis=0)
peri_stim_recordings = np.full([len(segments)] + list(max_shape), np.nan)
for i, psr in enumerate(segments):
    if psr.shape[0] == 0:
        # Two stimuli on the same sample: nothing to baseline, leave the row NaN.
        continue
    peri_stim_recordings[i, :psr.shape[0], :psr.shape[1]] = psr - psr[0]


fig, ax = plt.subplots(figsize=(10,5))
ax.matshow((peri_stim_recordings @ svd_neural_component).T)
ax.set_xlabel("stimulation")
ax.set_ylabel("samples from stim")

# Draw each group boundary once (they were previously drawn twice).
group_sizes = [g.shape[0] for g in responses_by_group]
group_edges = np.cumsum(group_sizes)-.5
for boundary in group_edges:
    ax.axvline(boundary, color='w', lw=.5)
ax.set_xticks([e - s/2 for s,e in zip(group_sizes, group_edges)], [chr(65+i) for i in range(len(group_edges))]);



## Tensor decomposition
I suspected that what I did in the figure above, where I projected a 3D (stimulation × time × neuron) array onto one neural activity pattern, giving a 2D (stimulation × time) array, 
was trying to reinvent tensor decompositions.
Therefore, I used the ideas and code from [Williams et al. 2018](https://doi.org/10.1016/j.neuron.2018.05.015) to do a proper
tensor decomposition, with the 3 modes being neurons, trials, and time.

In [None]:
def find_decompositions(non_nan_responses, n_restarts=200):
    """Fit a shifted CP decomposition from many random restarts.

    NOTE: this redefines (shadows) the ``find_decompositions`` imported from
    naumann_utility_functions at the top of the notebook. Cells after this one use this
    version. Presumably the two are meant to be identical — confirm and keep only one.

    The responses are transposed to (neuron, stimulus, time) before fitting, with
    non-negative neuron and stimulus factors.

    Parameters
    ----------
    non_nan_responses : ndarray
        Responses indexed [stimulus, timepoint, neuron].
    n_restarts : int
        Number of independent random initialisations to try.

    Returns
    -------
    list
        Successfully fitted models, sorted by final loss (best first).

    Raises
    ------
    RuntimeError
        If every restart failed. Previously an empty list was returned, and callers' ``[0]``
        then raised an unhelpful IndexError.
    """
    models = []
    for _ in range(n_restarts):
        try:
            m = tt.cpwarp.fit_shifted_cp(
                non_nan_responses.transpose(2, 0, 1),  # -> (neuron, stimulus, time)
                1,  # presumably the rank
                max_iter=1000,
                boundary="edge",
                max_shift_axis0=None,
                max_shift_axis1=.3,
                u_nonneg=True,  # neurons
                v_nonneg=True,  # trials
            )
        except ZeroDivisionError:
            # Some random initialisations fail numerically; skip them (best-effort).
            continue
        models.append(m)

    if not models:
        raise RuntimeError(f"all {n_restarts} shifted-CP restarts failed with ZeroDivisionError")

    models.sort(key=lambda m: m.loss_hist[-1])
    return models

models = find_decompositions(non_nan_responses, n_restarts=200)
model = models[0]

# NOTE(review): reading the final loss as a fraction of variance assumes tensortools reports a
# normalised loss — confirm.
print(f"proportion of variance unexplained: {model.loss_hist[-1]}")

In [None]:
%matplotlib inline
_, axs = plt.subplots(ncols=3, figsize=(10,5), layout='tight')
result = model
neural_component = np.squeeze(result.factors[0])
axs[0].bar(x=np.arange(neural_component.size), height=neural_component)
axs[0].set_title("neural loadings (U)")
axs[0].set_xlabel("neuron #")

per_stim_component = result.factors[1].T
all_group_loadings = []
total = 0
for group_loadings in reshape_by_group(per_stim_component, d):
    axs[1].plot(np.arange(total, total+len(group_loadings)), group_loadings, '.-')
    total = total + len(group_loadings)
    all_group_loadings.append(group_loadings)
axs[1].set_title("trial loadings (V)")
axs[1].set_xlabel("trial #")

temporal_component = result.factors[2].T
axs[2].plot(temporal_component)
axs[2].set_title("temporal component")
axs[2].set_xlabel("time (steps)");



In [None]:

# Repeat the all-pairs angle comparison on the decomposition's reconstruction.
# The model was fit on non_nan_responses.transpose(2, 0, 1), i.e. (neuron, stimulus, time).
# Assuming predict() returns that layout, transpose(1, 2, 0) restores
# [stimulus, timepoint, neuron]. The previous transpose([1, 0, 2]) gave
# [stimulus, neuron, time], so `[:, timepoint, :]` selected neuron 13 instead of timepoint 13.
inferred_responses = model.predict().transpose(1, 2, 0)
assert inferred_responses.shape == non_nan_responses.shape, "unexpected predict() layout"
# Small noise floor; presumably so rows are not exactly parallel or zero when computing
# angles — confirm.
background = rng.normal(size=inferred_responses.shape) * .01
inferred_responses = inferred_responses + background

timepoint = 13
fig, axs = plt.subplots(ncols=2, gridspec_kw={'width_ratios': [20, 1]}, figsize=(7,7))
comparison_matrix = compare_rows(inferred_responses[:,timepoint,:], method='angles')
plot_comparison(axs[0], comparison_matrix, group_sizes=[g.shape[0] for g in responses_by_group], ax_colorbar=axs[1])
fig.tight_layout()


The result of this decomposition isn't particularly inspiring in the neural or trial loadings (although the neuron loadings 
match well with the SVD from above). The interesting part is that all of the datasets admit a decomposition which appears to replicate the 
dynamics of the reporter.
It has a near-zero response for the first ~7 timepoints, quickly rises to a peak, and then decays.

If these dynamics match anything we can externally verify, I think cross-referencing the dynamics with a ground-truth would give a lot of credibility to the validity of the rest of the decomposition.


Below is a debugging cell for the decomposition; the problem is non-convex, and there are often multiple solution basins.
This code lets you look at the clusters of solutions in case the lowest-loss basin isn't the one you want.
I used it for hyperparameter tuning in the code above.
The third dataset didn't need tuning (it always gave satisfactory decompositions), but the first one was more difficult.
However, I've found values that work for all 3 datasets.

In [None]:
%matplotlib inline
t_traces = np.vstack([m.factors[2] for m in models])
t_traces = t_traces / np.linalg.norm(t_traces, axis=1)[:,None]


c = SpectralClustering(n_clusters=3, affinity='precomputed')
predictions = c.fit_predict(t_traces @ t_traces.T)
clusters, counts = np.unique(predictions, return_counts=True)
for cluster, count in zip(clusters, counts):
    predictions[predictions==cluster] = count

p = np.argsort(predictions)[::-1]

fig, axs = plt.subplots(ncols=2, figsize=(10,5))
axs[0].matshow(t_traces[p] @ t_traces[p].T)
comp_mat_examples = np.convolve(np.cumsum([0] + sorted(counts)[::-1]), [.5,.5], 'valid').astype(int)
axs[0].set_xticks(comp_mat_examples)
example_trials = p[comp_mat_examples]
axs[1].plot(t_traces[example_trials].T);

axs[1].set_title("extracted temporal profile")
axs[0].set_title("temporal profile covariance (clusters)")



In [None]:
# Refit each sub-dataset (index 0 twice, to show run-to-run variability) and overlay the
# peak-normalised temporal components, both as fitted and aligned on their peaks.
# NOTE: this cell rebinds the globals d, responses, non_nan_responses, models and model, so
# later cells use the last dataset in the list (index 1).
# NOTE(review): the top of the notebook passes `sub_dataset_identifier=`, this cell passes
# `sub_dataset_index=` — confirm both keywords are accepted.
fig, axs = plt.subplots(ncols=2, sharey=True, layout='tight', figsize=(10,5))

for sub_dataset_index in [0, 0, 2, 1]:
    d = al.datasets.Naumann24uDataset(sub_dataset_index=sub_dataset_index)
    responses = make_responses(d)
    non_nan_responses = responses[..., :d.n_neurons_in_optical]
    decompositions = find_decompositions(non_nan_responses, n_restarts=200)
    models = decompositions
    model = models[0]
    temporal_component = np.squeeze(model.factors[2].T)
    normalised_component = temporal_component / temporal_component.max()

    t = np.arange(len(temporal_component))
    axs[0].plot(t, normalised_component, label=d.sub_dataset)

    t = t - np.argmax(temporal_component)
    axs[1].plot(t, normalised_component, label=d.sub_dataset)


axs[0].set_title("Cross-dataset temporal component comparison (unaligned)")
axs[0].set_xlabel("time (samples)")
axs[0].set_ylabel("temporal response magnitude (normalized)")
axs[1].set_title("(aligned)")
axs[1].set_xlabel("time from peak (samples)")

for ax in axs:
    ax.legend()
    

Above, I compare the temporal component discovered for all of the datasets; they have similar dynamics.
Note that although I do one of the datasets twice, the result isn't consistent; this decomposition is stochastic.
This makes the delays it discovers also stochastic, so phase misalignment between the temporal components between runs could be a problem.
(As an aside, the library uses parallelization in a way that makes it difficult to control the randomness with a seed; this is bad for replicability.)

## Dataset 1 delay recovery
In dataset 1, the per-trial delays discovered by the decomposition replicate the peak times of the activation of the timestep-14 dominant neural component unusually well.
This doesn't work as well in the other datasets.

In [None]:
# Compare two per-trial quantities, both mean-centred:
# - the delays recovered by the shifted-CP model;
# - the sample at which each trial's projection onto the timestep-14 dominant neural pattern peaks.
_, _, vh = np.linalg.svd(non_nan_responses[:, 14, :], full_matrices=False)
svd_neural_component = vh[0]
# Fix the arbitrary SVD sign (positive loading sum). With a flipped vector, argmax below would
# find each trial's most negative projection instead of its peak.
if svd_neural_component.sum() < 0:
    svd_neural_component = -svd_neural_component
# Shape (timepoint, stimulus).
activation_of_neural_component = (non_nan_responses @ svd_neural_component).T

per_trial_shifts = model.v_s.T
pc1_argmaxes = np.argmax(activation_of_neural_component, axis=0)

fig, ax = plt.subplots()
ax.plot(pc1_argmaxes - pc1_argmaxes.mean(), label='peak of timestep-14 pattern projection')
ax.plot(per_trial_shifts - per_trial_shifts.mean(), label='shifted-CP per-trial shift')
ax.set_xlabel("trial #")
ax.set_ylabel("offset from mean")
ax.legend();
