# Generate results for panels c and e

This notebook runs the trained model's forward passes on the test data (for panel c) and on all data, i.e. the train, validation and test splits merged (for panel e), and pickles the results.

In [1]:
# Automatically reload edited project modules (e.g. `nex.rgc.*`) before executing code.
%load_ext autoreload
%autoreload 2

In [2]:
import os

# XLA reads this variable when the GPU client is created. Set it before JAX is
# imported so it is guaranteed to take effect. It limits JAX's GPU memory
# preallocation to 80% of device memory.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".8"

from jax import config

config.update("jax_enable_x64", True)  # run all JAX computations in float64
config.update("jax_platform_name", "gpu")

In [3]:
import time

import jax.numpy as jnp
from jax import jit, vmap, value_and_grad
import numpy as np
import matplotlib.pyplot as plt
import pickle
import matplotlib as mpl
import h5py
from functools import partial
import jax
import jaxlib

import jaxley as jx

from nex.rgc.utils.data_utils import (
    read_data,
    build_avg_recordings,
    build_training_data,
)
from nex.rgc.utils.utils import (
    build_cell,
    build_kernel,
)
from nex.rgc.utils.rf_utils import compute_all_trained_rfs
from nex.rgc.simulate import (
    predict,
    simulate,
)

In [4]:
import pandas as pd
from warnings import simplefilter
simplefilter(action="ignore", category=pd.errors.PerformanceWarning)

In [5]:
# Log the library versions this notebook was executed with.
for _lib_name, _lib in (("jax", jax), ("jaxlib", jaxlib), ("pandas", pd), ("numpy", np)):
    print(f"{_lib_name} {_lib.__version__}")

jax 0.4.30
jaxlib 0.4.30
pandas 2.2.0
numpy 1.26.4


In [6]:
# Root of the `nex/rgc` directory and the training run whose outputs are analysed.
# Previously used run: "results/train_runs/2024_05_30__10_55_30/0".
path_prefix = "../../../nex/rgc"
results_prefix = "/".join(["results", "train_runs", "2024_08_01__20_57_21", "0"])

In [7]:
# Which data to load: scan offset, samples per scan field, cell and recording ids.
start_n_scan = 0
num_datapoints_per_scanfield = 8 * 128
nseg = 4  # presumably compartments per branch (jaxley `nseg`); confirm
cell_id = "20161028_1"
rec_ids = list(range(1, 16))  # recording ids 1..15

In [8]:
stimuli, recordings, setup, noise_full = read_data(
    start_n_scan,
    num_datapoints_per_scanfield,
    cell_id,
    rec_ids,
    "noise",
    "..",
)

# The averaged recordings are cached as a pickle. Load the cache when it
# exists; otherwise build it and write it to the same path it is read from.
# NOTE: `pickle.load` executes arbitrary code, so only load trusted files.
avg_recordings_path = f"{path_prefix}/results/intermediate/avg_recordings.pkl"
if os.path.exists(avg_recordings_path):
    with open(avg_recordings_path, "rb") as handle:
        avg_recordings = pickle.load(handle)
else:
    avg_recordings = build_avg_recordings(
        recordings, rec_ids, nseg, num_datapoints_per_scanfield
    )
    os.makedirs(os.path.dirname(avg_recordings_path), exist_ok=True)
    with open(avg_recordings_path, "wb") as handle:
        pickle.dump(avg_recordings, handle)

# Number of averaged recordings (rows) per recording id, in rec_id order.
number_of_recordings_each_scanfield = list(avg_recordings.groupby("rec_id").size())

In [9]:
warmup = 5.0  # duration of the initial warm-up simulation (its `t_max`); presumably ms
i_amp = 0.1  # current amplitude passed to `build_training_data`; units not visible here

currents, labels, loss_weights = build_training_data(
    i_amp,
    stimuli,
    avg_recordings,
    rec_ids,
    num_datapoints_per_scanfield,
    number_of_recordings_each_scanfield,
)

# Branch index and compartment location of every stimulus, as NumPy arrays.
stim_branch_inds, stim_comps = (
    stimuli[column].to_numpy() for column in ("branch_ind", "comp")
)

In [10]:
dt = 0.025
t_max = 200.0

# `t_max / dt + 2` samples spaced by `dt` (0 ... t_max + dt). The sample count
# is computed explicitly rather than via `np.arange(0, t_max + 2*dt, dt)`,
# whose length depends on floating-point rounding of the stop value. The
# values are identical. The two extra samples presumably match the length of
# the simulated traces; confirm against the simulator output.
num_time_points = int(round(t_max / dt)) + 2
time_vec = np.arange(num_time_points) * dt

In [11]:
# Load the trained cell (trusted pickle) and collect the branch indices of its
# basal and somatic groups.
with open(f"{path_prefix}/{results_prefix}/cell.pkl", "rb") as handle:
    cell = pickle.load(handle)

basal_inds, somatic_inds = (
    list(np.unique(cell.group_nodes[group]["branch_index"].to_numpy()))
    for group in ("basal", "soma")
)

# Drop any recordings/stimuli stored with the pickled cell.
cell.delete_recordings()
cell.delete_stimuli()

# Insert a "Cai" recording at every averaged-recording location first, then a
# "v" recording at the same locations (same insertion order as before).
for state in ("Cai", "v"):
    for _, rec in avg_recordings.iterrows():
        cell.branch(rec["branch_ind"]).loc(rec["comp"]).record(state, verbose=False)

number_of_recordings = np.sum(number_of_recordings_each_scanfield)
print(f"Inserted {len(cell.recordings)} recordings")
print(f"number_of_recordings_each_scanfield {number_of_recordings_each_scanfield}")

Inserted 294 recordings
number_of_recordings_each_scanfield [12, 6, 15, 21, 13, 10, 9, 10, 10, 6, 4, 11, 8, 4, 8]


In [12]:
# Simulate the cell for `warmup`; keep only the returned states (the recordings
# are discarded). They are later passed to `vmapped_predict` as initial states.
init_states = jx.integrate(cell, t_max=warmup, return_states=True)[1]

2024-08-06 08:02:29.369396: W external/xla/xla/service/gpu/nvptx_compiler.cc:765] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.3.107). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [13]:
# Clear previous trainables, then make axial resistivity and radius of all
# basal branches trainable (in this order, matching the saved parameters).
cell.delete_trainables()
for param_name in ("axial_resistivity", "radius"):
    cell.basal.branch("all").make_trainable(param_name)

Number of newly added trainable parameters: 154. Total number of trainable parameters: 154
Number of newly added trainable parameters: 154. Total number of trainable parameters: 308


In [14]:
kernel = build_kernel(time_vec, dt)

# Fixed output scale/offset handed to `predict` through `static`; their exact
# role is defined inside `predict` (not visible here).
output_scale, output_offset = jnp.asarray(60.0), jnp.asarray(-1.3)

In [15]:
run_dir = f"{path_prefix}/{results_prefix}"


def _load_run_pickle(relative_path):
    """Unpickle `relative_path` inside the training-run directory (trusted files only)."""
    with open(f"{run_dir}/{relative_path}", "rb") as handle:
        return pickle.load(handle)


all_opt_params = _load_run_pickle("opt_params/params_16.pkl")
transform_params = _load_run_pickle("transforms/transform_params.pkl")
transform_basal = _load_run_pickle("transforms/transform_basal.pkl")
transform_somatic = _load_run_pickle("transforms/transform_somatic.pkl")

opt_params, opt_basal_params, opt_somatic_params = all_opt_params

# Apply each transform's `forward` to its optimised parameter group
# (presumably optimiser space -> model space; confirm against the transforms).
parameters = transform_params.forward(opt_params)
basal_neuron_params = transform_basal.forward(opt_basal_params)
somatic_neuron_params = transform_somatic.forward(opt_somatic_params)

In [16]:
# Count every optimised value. `all_opt_params` is
# `(opt_params, opt_basal_params, opt_somatic_params)`; each group is a list of
# dicts mapping a parameter name to its array. All keys of every dict are
# counted, not only the first, so the total stays correct even if a dict ever
# holds more than one parameter.
all_ = sum(
    int(np.size(values))
    for param_group in all_opt_params
    for param_dict in param_group
    for values in param_dict.values()
)

print(f"Number of synaptic parameters:   {len(all_opt_params[0][0]['w_bc_to_rgc'])}")
print(f"total number of parameters:      {int(all_)}")
print(f"total number of branches:        {len(all_opt_params[0][1]['axial_resistivity'])+1}")  # Plus 1 for soma.


Number of synaptic parameters:   287
total number of parameters:      607
total number of branches:        155


In [17]:
num_truncations = 4

# Everything `predict` needs that is neither optimised nor batched.
static = dict(
    cell=cell,
    dt=dt,
    t_max=t_max,
    time_vec=time_vec,
    num_truncations=num_truncations,
    output_scale=output_scale,
    output_offset=output_offset,
    kernel=kernel,
    somatic_inds=somatic_inds,
    basal_inds=basal_inds,
    stim_branch_inds=stim_branch_inds,
    stim_comps=stim_comps,
)

# Batch over axis 0 of the 4th argument (the input currents). The three
# parameter groups and the initial states are shared across the batch.
_predict_with_static = partial(predict, static=static)
vmapped_predict = jit(vmap(_predict_with_static, in_axes=(None, None, None, 0, None)))

In [18]:
def evaluate(
    split_images,
    split_labels,
    split_currents,
    split_masks,
    batch_size=128 * 8,
    predict_fn=None,
):
    """Run the trained model on one data split, batch by batch.

    Args:
        split_images: Noise stimulus images; samples are indexed along axis 2.
        split_labels: Recorded calcium responses; samples along axis 0.
        split_currents: Model input currents; samples along axis 0.
        split_masks: Per-sample ROI masks (nonzero where an ROI was measured);
            samples along axis 0. Its length sets the number of samples.
        batch_size: Number of samples per forward pass (default 1024).
        predict_fn: Optional callable mapping a batch of currents to calcium
            predictions. Defaults to `vmapped_predict` with the trained
            `parameters`, `basal_neuron_params`, `somatic_neuron_params` and
            warm-up `init_states` from the notebook's global scope.

    Returns:
        Tuple `(images, ca_recordings, ca_predictions, masks)` with all batches
        concatenated back together.

    Raises:
        ValueError: If `batch_size` is not positive or the split is empty.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}.")

    if predict_fn is None:
        # Resolved lazily so the globals are only required when no callable is given.
        def predict_fn(batch_currents):
            return vmapped_predict(
                parameters,
                basal_neuron_params,
                somatic_neuron_params,
                batch_currents,
                init_states,
            )

    num_samples = len(split_masks)
    if num_samples == 0:
        raise ValueError("evaluate() received an empty split.")

    num_batches = int(np.ceil(num_samples / batch_size))
    print("num batches", num_batches)

    all_images = []
    all_ca_recordings = []
    all_ca_predictions = []
    all_masks = []

    for k in range(num_batches):
        print("k", k)
        batch = slice(k * batch_size, (k + 1) * batch_size)

        all_images.append(split_images[:, :, batch])
        all_ca_recordings.append(split_labels[batch])
        all_masks.append(split_masks[batch])
        all_ca_predictions.append(predict_fn(split_currents[batch]))

    return (
        np.concatenate(all_images, axis=2),
        np.concatenate(all_ca_recordings, axis=0),
        np.concatenate(all_ca_predictions, axis=0),
        np.concatenate(all_masks, axis=0),
    )


### Panel c: correlation between recorded calcium and model predictions (test data)

In [19]:
# Sample indices of the train / validation / test splits saved by the training run.
split_dir = f"{path_prefix}/{results_prefix}/data"
split_inds = {}
for split_name in ("train", "val", "test"):
    with open(f"{split_dir}/{split_name}_inds.pkl", "rb") as handle:
        split_inds[split_name] = pickle.load(handle)

train_inds, val_inds, test_inds = (
    split_inds["train"],
    split_inds["val"],
    split_inds["test"],
)


In [20]:
# Restrict every data array to the held-out test samples (images: axis 2).
inds = test_inds
split_images, split_labels, split_currents, split_masks = (
    noise_full[:, :, inds],
    labels[inds],
    currents[inds],
    loss_weights[inds],
)

In [21]:
# Forward pass over the test split.
test_images, test_ca_recordings, test_ca_predictions, test_masks = evaluate(
    split_images=split_images,
    split_labels=split_labels,
    split_currents=split_currents,
    split_masks=split_masks,
)

num batches 3
k 0
k 1
k 2


In [24]:
# Per-ROI Pearson correlation between recorded and predicted calcium on the
# test split, using only the samples in which that ROI was measured. The loop
# covers all recorded ROIs (`number_of_recordings`, 147 in this run) instead
# of a hard-coded count.
rhos = []
for roi_id in range(number_of_recordings):
    roi_was_measured = test_masks[:, roi_id].astype(bool)

    rho_trained = np.corrcoef(
        test_ca_recordings[roi_was_measured, roi_id],
        test_ca_predictions[roi_was_measured, roi_id],
    )[0, 1]
    rhos.append(rho_trained)
rhos = np.asarray(rhos)

print("mean", np.mean(rhos))
print("max", np.max(rhos))
print("std", np.std(rhos))
print("larger > 0", np.sum(rhos > 0.0), "out of", len(rhos))

mean 0.24699644990499506
max 0.5119800747722698
std 0.12544279124008825
larger > 0 146 out of 147


In [25]:
# Index of the first ROI of each recording id in the flat ROI axis
# (cumulative ROI counts, starting at 0).
roi_ids = np.cumsum([0] + number_of_recordings_each_scanfield)[:-1]

# Pickle the panel-c (test split) results. Create the output folder first,
# because `open(..., "wb")` does not create missing directories.
results_dir = "../results/03_results"
os.makedirs(results_dir, exist_ok=True)
for result_name, result in [
    ("test_images", test_images),
    ("test_ca_recordings", test_ca_recordings),
    ("test_ca_predictions", test_ca_predictions),
    ("test_masks", test_masks),
    ("roi_ids", roi_ids),
]:
    with open(f"{results_dir}/{result_name}.pkl", "wb") as handle:
        pickle.dump(result, handle)


### Panel e: forward simulations for receptive fields, based on the merged train/val/test data

In [26]:
# Panel e uses every sample (train, validation and test merged).
split_images, split_labels, split_currents, split_masks = (
    noise_full,
    labels,
    currents,
    loss_weights,
)

In [27]:
# Forward pass over the full dataset.
all_images, all_ca_recordings, all_ca_predictions, all_masks = evaluate(
    split_images=split_images,
    split_labels=split_labels,
    split_currents=split_currents,
    split_masks=split_masks,
)

num batches 15
k 0
k 1
k 2
k 3
k 4
k 5
k 6
k 7
k 8
k 9
k 10
k 11
k 12
k 13
k 14


In [31]:
# Per-ROI Pearson correlation between recorded and predicted calcium on the
# full dataset, using only the samples in which that ROI was measured. The
# loop covers all recorded ROIs (`number_of_recordings`, 147 in this run)
# instead of a hard-coded count.
rhos = []
for roi_id in range(number_of_recordings):
    roi_was_measured = all_masks[:, roi_id].astype(bool)

    rho_trained = np.corrcoef(
        all_ca_recordings[roi_was_measured, roi_id],
        all_ca_predictions[roi_was_measured, roi_id],
    )[0, 1]
    rhos.append(rho_trained)
rhos = np.asarray(rhos)

print("mean", np.mean(rhos))
print("max", np.max(rhos))
print("std", np.std(rhos))
print("larger > 0:", np.sum(rhos > 0.0), "out of", len(rhos))

mean 0.2520464123394455
max 0.46443021518596933
std 0.11246235871375339
larger > 0: 147 out of 147


In [32]:
# Pickle the panel-e (full dataset) results. Create the output folder first,
# because `open(..., "wb")` does not create missing directories.
results_dir = "../results/03_results"
os.makedirs(results_dir, exist_ok=True)
for result_name, result in [
    ("all_images", all_images),
    ("all_ca_recordings", all_ca_recordings),
    ("all_ca_predictions", all_ca_predictions),
    ("all_masks", all_masks),
]:
    with open(f"{results_dir}/{result_name}.pkl", "wb") as handle:
        pickle.dump(result, handle)
