# Embedding of evoked responses and TCA
Inspired by the 'ripples' paper, where time points are used as features to investigate the diversity of responses across ages and cells.

In [None]:
import numpy as np
import os

# One session folder per recording day (the date is encoded in the folder name).
session_paths = [
    'data_proc/jm/jm064/2025-11-13_s',
    'data_proc/jm/jm064/2025-11-14_s',
    'data_proc/jm/jm064/2025-11-15_s',
    'data_proc/jm/jm064/2025-11-16_s',
    'data_proc/jm/jm064/2025-11-17_s',
    'data_proc/jm/jm064/2025-11-18_s',
]

# NOTE(review): this path is defined but never loaded in the visible code, so
# neuron i is treated as the same cell on every day -- presumably the
# resp_evoked arrays are already track2p-matched; TODO confirm.
t2p_indices_path = 'data_proc/jm/jm064/track2p/plane0_suite2p_indices.npy'

# Load one evoked-response record per session. `.item()` unwraps the 0-d object
# array returned by np.load; the record is indexed by string keys afterwards.
# allow_pickle=True unpickles arbitrary objects: only use it on trusted files.
all_resp_evoked = []
for session_path in session_paths:
    resp_file = os.path.join(session_path, 'resp_evoked', 'resp_evoked.npy')
    session_record = np.load(resp_file, allow_pickle=True)
    all_resp_evoked.append(session_record.item())

# Dimensions are read from the first session's 'resp_mean' array
# (axis 0 -> neurons, axis 1 -> features).
first_mean_shape = all_resp_evoked[0]['resp_mean'].shape
n_days = len(all_resp_evoked)
n_neurons = first_mean_shape[0]
n_features = first_mean_shape[1]

In [None]:
# Trials per session. Hard-coded; each session's 'resp' shape is printed in the
# loop so a mismatch can be spotted -- TODO confirm 60 holds for every session.
n_trials = 60

# Pre-allocated stacks with the day as the leading axis.
mean_stack_shape = (n_days, n_neurons, n_features)
trial_stack_shape = (n_days, n_trials, n_neurons, n_features)
all_resp_mean = np.zeros(mean_stack_shape)
all_resp = np.zeros(trial_stack_shape)

for i, resp_evoked in enumerate(all_resp_evoked):
    day_trials = resp_evoked['resp']
    all_resp_mean[i] = resp_evoked['resp_mean']
    print(day_trials.shape)
    all_resp[i] = day_trials

# TODO: THE MATCHING DOESN'T MAKE SENSE...

# Plot the mean response of a sample of neurons, one small panel per day.
import matplotlib.pyplot as plt

# 50 evenly spaced positions along the neuron axis, truncated to integer indices.
neuron_idxs = np.linspace(0, n_neurons - 1, 50).astype(np.int64)

for neuron_idx in neuron_idxs:
    fig, axs = plt.subplots(1, n_days, figsize=(15, 1))
    for day in range(n_days):
        day_ax = axs[day]
        day_ax.plot(all_resp_mean[day, neuron_idx])
        day_ax.set_title(f'Day {day+1}')
    fig.suptitle(f'Neuron {neuron_idx}')
    plt.show()

In [None]:
print(all_resp.shape)
# Collapse every leading axis into rows, keeping one feature vector per row.
# Given the (day, trial, neuron, feature) allocation of all_resp and NumPy's
# default C order, rows run day-major, then trial, with the neuron index
# varying fastest.
all_resp_flat = np.reshape(all_resp, (-1, n_features))
print(all_resp_flat.shape)
# now label rows by day


In [None]:
# now import and run umap
import umap
from sklearn.decomposition import PCA

In [None]:
def zscore_rows(X):
    """Z-score each row of an array independently.

    Every row is centred on its own mean and divided by its own population
    standard deviation (``ddof=0``), both taken along axis 1.

    Rows whose standard deviation is exactly zero (constant rows) come back as
    all zeros instead of NaN, so a single flat trace cannot poison downstream
    steps.

    Parameters
    ----------
    X : numpy.ndarray
        Input data with at least two dimensions; statistics are computed
        along axis 1.

    Returns
    -------
    numpy.ndarray
        Array of the same shape as ``X`` holding the row-wise z-scores.
    """
    centered = X - X.mean(axis=1, keepdims=True)
    scale = X.std(axis=1, keepdims=True)
    # 0 / 0 would yield NaN for constant rows; their centred values are zero,
    # so dividing by 1 instead gives a well-defined all-zero row.
    scale[scale == 0] = 1
    return centered / scale

In [None]:
# One row per (day, neuron) pair for the mean responses; the single-trial
# matrix is the flattened trial tensor built earlier.
data_mn = np.reshape(all_resp_mean, (-1, n_features))
data_st = all_resp_flat

# NOTE: alternative input considered but not enabled -- embed the magnitude
# spectrum, np.abs(np.fft.fft(data, axis=1)), passed through zscore_rows,
# instead of feeding the raw time series into UMAP.

# Day index of every row, used to colour the embeddings. Both matrices are
# day-major, so each day id is repeated once per row belonging to that day.
day_ids = np.arange(n_days)
labels_mn = np.repeat(day_ids, n_neurons)
labels_st = np.repeat(day_ids, n_trials * n_neurons)


In [None]:
# UMAP embedding of the mean responses (one point per day/neuron row).
# The seed is fixed so the layout can be reproduced across runs.
umap_params_mn = {'random_state': 42, 'n_neighbors': 15}
reducer_mn = umap.UMAP(**umap_params_mn)
emb_umap_mn = reducer_mn.fit_transform(data_mn)

In [None]:
# UMAP embedding of the single-trial responses (one point per day/trial/neuron
# row), using the same settings and seed as the mean-response embedding.
reducer_st = umap.UMAP(n_neighbors=15, random_state=42)
emb_umap_st = reducer_st.fit_transform(data_st)

In [None]:
# Rows of the day-major mean-response matrix that belong to one neuron:
# day * n_neurons + neuron index, for every day.
nrn_idx = 306
nrn_idx_days = [day * n_neurons + nrn_idx for day in range(n_days)]

In [None]:
# Mean-response embedding coloured by day.
# The +8 offset maps day index 0 to 8 on the colour bar ('Postnatal day').
fig_mn, ax_mn = plt.subplots(figsize=(14, 10), dpi=300)
points_mn = ax_mn.scatter(
    emb_umap_mn[:, 0],
    emb_umap_mn[:, 1],
    c=labels_mn + 8,
    s=10,
    cmap='plasma',
    alpha=0.5,
)
ax_mn.axis('off')
cbar = fig_mn.colorbar(points_mn, ax=ax_mn)
cbar.ax.tick_params(labelsize=18)
cbar.set_label('Postnatal day', fontsize=24)

In [None]:
# Single-trial embedding coloured by day (day index + 8, labelled as postnatal
# day). Points are drawn very small because there is one per trial and neuron.
plt.figure(figsize=(14, 10), dpi=300)
x_st = emb_umap_st[:, 0]
y_st = emb_umap_st[:, 1]
points_st = plt.scatter(x_st, y_st, c=labels_st + 8, s=0.1, cmap='viridis', alpha=0.5)
plt.axis('off')
cbar = plt.colorbar(points_st)
cbar.ax.tick_params(labelsize=18)
cbar.set_label('Postnatal day', fontsize=24)

In [None]:
# Centroid, in single-trial UMAP space, of all trials of one neuron on each day.
nrn_idx = 33
all_centroid = np.zeros((n_days, 2))
for d in range(n_days):
    # rows of the flattened trial matrix recorded on day d
    day_mask = labels_st == d
    trial_indices = np.flatnonzero(day_mask)
    # The neuron axis varies fastest in the flattened matrix, so the flat row
    # index modulo n_neurons is the neuron index.
    is_this_neuron = (trial_indices % n_neurons) == nrn_idx
    neuron_trial_indices = trial_indices[is_this_neuron]
    neuron_trial_points = emb_umap_st[neuron_trial_indices]
    all_centroid[d] = neuron_trial_points.mean(axis=0)

    print(f'Centroid for neuron {nrn_idx} on day {d+1}: {all_centroid[d]}')

In [None]:
# Example neuron indices; the per-neuron figure loops iterate over this list.
ex_nrn_idxs = [33, 94, 284, 60, 47, 127, 282, 329]

In [None]:
# For every example neuron: compute the per-day centroid of its single-trial
# points and draw the centroid trajectory on top of the full embedding.

# Rows of the flattened trial matrix for each day; independent of the neuron,
# so they are looked up once before the loop.
rows_per_day = [np.where(labels_st == day)[0] for day in range(n_days)]

for nrn_idx in ex_nrn_idxs:

    all_centroid = np.zeros((n_days, 2))
    for d, trial_indices in enumerate(rows_per_day):
        # The neuron axis varies fastest in the flattened matrix, so the flat
        # row index modulo n_neurons is the neuron index.
        neuron_trial_indices = trial_indices[trial_indices % n_neurons == nrn_idx]
        neuron_trial_points = emb_umap_st[neuron_trial_indices]
        all_centroid[d] = neuron_trial_points.mean(axis=0)

        print(f'Centroid for neuron {nrn_idx} on day {d+1}: {all_centroid[d]}')

    fig, ax = plt.subplots(figsize=(14, 10), dpi=300)
    # background: every single-trial point, coloured by day
    ax.scatter(emb_umap_st[:, 0], emb_umap_st[:, 1], c=labels_st+8, s=0.1, cmap='viridis', alpha=0.5)
    # foreground: this neuron's centroids joined by a grey line
    centroid_points = ax.scatter(
        all_centroid[:, 0],
        all_centroid[:, 1],
        c=np.arange(n_days)+8,
        s=50,
        cmap='viridis',
        label=f'Neuron {nrn_idx} centroids',
        zorder=3,
    )
    ax.plot(all_centroid[:, 0], all_centroid[:, 1], c='grey', linewidth=3, zorder=2, label='Trajectory')
    ax.legend()
    ax.axis('off')
    # The colour bar is built from the centroid scatter (the most recent mappable).
    cbar = fig.colorbar(centroid_points, ax=ax)
    cbar.ax.tick_params(labelsize=18)
    cbar.set_label('Postnatal day', fontsize=24)
    plt.show()

In [None]:
# Per-neuron summary figure: the mean-response UMAP with this neuron's
# trajectory across days on top (panel 'A'), and its mean response on each day
# in a strip of small panels underneath.
#
# The strip is sized from n_days instead of being fixed at six panels, so the
# cell keeps working when sessions are added or removed. Panel keys start at
# 'B' and advance one character per day ('B'..'G' for six days).
day_keys = [chr(ord('B') + day) for day in range(n_days)]
mosaic = [['A'] * n_days for _ in range(6)] + [day_keys]
# Guards against 0/0 with a single day; equal to n_days - 1 otherwise.
color_span = max(n_days - 1, 1)

for nrn_idx in ex_nrn_idxs:
    # rows of the day-major mean-response embedding that belong to this neuron
    nrn_idx_days = [nrn_idx + i*n_neurons for i in range(n_days)]

    fig, axs = plt.subplot_mosaic(mosaic=mosaic, figsize=(10, 10))
    # Share the y axis across the day panels. The first panel is the reference
    # and is skipped, since sharing an axes with itself is meaningless.
    for key in day_keys[1:]:
        axs[key].sharey(axs[day_keys[0]])

    # all (day, neuron) points in the background, this neuron's points on top
    axs['A'].scatter(emb_umap_mn[:, 0], emb_umap_mn[:, 1], c=labels_mn, s=5, cmap='plasma', alpha=0.7, zorder=0)
    axs['A'].scatter(emb_umap_mn[nrn_idx_days, 0], emb_umap_mn[nrn_idx_days, 1], c=np.arange(n_days), s=50, cmap='plasma')
    axs['A'].plot(emb_umap_mn[nrn_idx_days, 0], emb_umap_mn[nrn_idx_days, 1], c='gray', alpha=0.5, label=f'Trajectory of neuron {nrn_idx}')
    axs['A'].set_title('UMAP embedding of evoked responses')
    # axis labels are kept for when the axis is switched back on; axis('off')
    # below hides them
    axs['A'].set_xlabel('UMAP 1')
    axs['A'].set_ylabel('UMAP 2')
    axs['A'].axis('off')
    # legend in the top left corner of the embedding panel
    axs['A'].legend(loc='upper left', frameon=False)

    for day, key in enumerate(day_keys):
        day_ax = axs[key]
        # same 'plasma' colour as this day's point in the embedding panel
        color = plt.cm.plasma(day / color_span)
        day_ax.plot(all_resp_mean[day, nrn_idx], color=color)
        day_ax.set_xticks([])
        day_ax.set_yticks([])
        day_ax.set_axis_off()
        # postnatal-day tag (day index + 8) in the top left corner of the panel
        day_ax.text(0.05, 0.95, f'P{8+day}', transform=day_ax.transAxes, fontsize=12, verticalalignment='top', color=color)

    # show each figure as it is finished, like the other per-neuron loops
    plt.show()


In [None]:
import tensortools as tt

# Tensor component analysis on the (day, neuron, feature) mean-response tensor.
data = all_resp_mean
# Alternative (not enabled): treat every trial of every neuron as its own row.
# data = all_resp.reshape(n_days, n_trials*n_neurons, n_features)

# Fit an ensemble of models with the "ncp_hals" method: ranks 1-9, with
# 5 random replicates / optimization runs per rank.
rank_sweep = range(1, 10)
ensemble = tt.Ensemble(fit_method="ncp_hals")
ensemble.fit(data, ranks=rank_sweep, replicates=5)

# Diagnostics for choosing the number of components.
fig, axes = plt.subplots(1, 2)
ax_objective, ax_similarity = axes
tt.plot_objective(ensemble, ax=ax_objective)    # reconstruction error vs. number of components
tt.plot_similarity(ensemble, ax=ax_similarity)  # model similarity vs. number of components
fig.tight_layout()

# Low-d factors of one example model: rank `num_components`, replicate index
# `replicate`.
num_components = 4
replicate = 0
# One line colour per tensor mode, in axis order (day, neuron, feature).
mode_line_kw = [{'color': 'C2'}, {'color': 'C0'}, {'color': 'C3'}]
example_factors = ensemble.factors(num_components)[replicate]
tt.plot_factors(example_factors, line_kw=mode_line_kw)

plt.show()

In [None]:
# Split the chosen model into one weight matrix per tensor mode, stored with
# one row per component (the transpose of the model's factor matrices).
chosen_factors = ensemble.factors(num_components)[replicate]

day_components = np.zeros((num_components, n_days))
nrn_components = np.zeros((num_components, n_neurons))
feature_components = np.zeros((num_components, n_features))

# factor matrices are (mode size, component); copy them in transposed
day_components[:] = chosen_factors[0].T
nrn_components[:] = chosen_factors[1].T
feature_components[:] = chosen_factors[2].T

In [None]:
# Distribution of the neuron weights, one histogram row per component.
fig, axs = plt.subplots(num_components, 1, figsize=(2, 5))
for hist_ax, neuron_weights in zip(axs, nrn_components):
    hist_ax.hist(neuron_weights, bins=26)
    hist_ax.axis('off')
plt.show()

In [None]:
import matplotlib.colors as colors


In [None]:
# Scatter the mean-response UMAP, colour-coded by each TCA component.
for i in range(num_components):
    # Outer product of the day and neuron weights; flattened day-major, so the
    # values line up with the rows of emb_umap_mn.
    day_by_neuron = np.outer(day_components[i], nrn_components[i])
    # Log scale, since most values are close to zero with a few large outliers;
    # the small offset avoids log(0).
    c = np.log(np.abs(day_by_neuron.ravel()) + 1e-5)

    fig, ax = plt.subplots(figsize=(10, 8))
    points = ax.scatter(emb_umap_mn[:, 0], emb_umap_mn[:, 1], c=c, s=5, cmap='viridis')
    ax.axis('off')
    cbar = fig.colorbar(points, ax=ax, ticks=[])
    cbar.set_label(fr'$TC{i+1}_{{\mathrm{{day}}}}\otimes TC{i+1}_{{\mathrm{{neuron}}}}$ (log scale)', fontsize=12)
    ax.set_title(f'UMAP embedding colored by TCA component {i+1}')
    ax.set_xlabel('UMAP 1')
    ax.set_ylabel('UMAP 2')
    plt.show()

In [None]:
# Same colouring as the plain per-component scatter, with the three factors of
# the component (day, neuron, feature) drawn in a strip under the embedding.
for i in range(num_components):
    day_weights = day_components[i]
    neuron_weights = nrn_components[i, :]
    feature_weights = feature_components[i, :]

    # Outer product of the day and neuron weights; flattened day-major, so the
    # values line up with the rows of emb_umap_mn.
    c = np.outer(day_weights, neuron_weights).flatten()
    # Log scale, since most values are close to zero with a few large outliers;
    # the small offset avoids log(0).
    c = np.log(np.abs(c) + 1e-5)

    fig, axs = plt.subplot_mosaic(mosaic='AAA\nAAA\nAAA\nAAA\nAAA\nAAA\nBCD', figsize=(10, 8), dpi=300)
    emb_ax = axs['A']
    points = emb_ax.scatter(
        emb_umap_mn[:, 0],
        emb_umap_mn[:, 1],
        c=c,
        s=5,
        cmap='viridis',
    )
    emb_ax.axis('off')
    cbar = fig.colorbar(points, ax=emb_ax, ticks=[], shrink=0.8)
    cbar.set_label(fr'$TC{i+1}_{{\mathrm{{day}}}}\otimes TC{i+1}_{{\mathrm{{neuron}}}}$ (log scale)', fontsize=12)
    emb_ax.set_title(f'UMAP embedding colored by TCA component {i+1}')

    # (panel key, weights, line colour, title) for each factor of the component
    factor_panels = (
        ('B', day_weights, 'C2', f'Day component {i+1}'),
        ('C', neuron_weights, 'C0', f'Neuron component {i+1}'),
        ('D', feature_weights, 'C3', f'Feature component {i+1}'),
    )
    for panel_key, weights, line_color, panel_title in factor_panels:
        panel = axs[panel_key]
        panel.plot(weights, color=line_color)
        panel.set_title(panel_title)
        panel.set_axis_off()

    plt.show()