In [3]:
import matplotlib.pyplot as plt
import numpy as np
import h5py

import matplotlib as mpl
import pickle
import pandas as pd
import jaxley as jx
from jaxley.channels import HH
from jaxley_mech.channels.fm97 import Na, K, KA, KCa, Ca, Leak
import jax.numpy as jnp
from jax import jit, vmap
from tensorflow.data import Dataset

import pickle
from itertools import chain

In [5]:
# Identifier of the recorded cell whose data files are loaded below.
cell_id: str = "20161028_1"

In [7]:
# Load the stimulus metadata and the OFF bipolar-cell outputs for `cell_id`.
stimuli_meta, bc_output = (
    pd.read_pickle(f"../results/data/stimuli_meta_{cell_id}.pkl"),
    pd.read_pickle(f"../results/data/off_bc_output_{cell_id}.pkl"),
)

In [8]:
# Keep only the rows of the selected cell and of scan field (rec_id) 1.
# Filter on `cell_id` instead of a hard-coded ID so this stays consistent with
# the file that was loaded for `cell_id` above.
bc_output = bc_output[(bc_output["cell_id"] == cell_id) & (bc_output["rec_id"] == 1)]

In [9]:
# Which frame window to use downstream: "train" or "test".
train_or_test = "train"
# Any value other than "train" silently selects the test window later on,
# so reject typos (e.g. "Train") here instead of producing the wrong split.
assert train_or_test in ("train", "test"), f"train_or_test must be 'train' or 'test', got {train_or_test!r}"

In [14]:
# Frame window per scan field: frames [0, 512) for training, [512, 640) for testing.
start_n_scan = 0 if train_or_test == "train" else 128 * 4
num_datapoints_per_scanfield = 128 * 4 if train_or_test == "train" else 128 * 1
scan_window = slice(start_n_scan, start_n_scan + num_datapoints_per_scanfield)
cell_id = "20161028_1"  # "20170610_1", "20161028_1"
rec_ids = [1, 2, 9, 14, 15]  # 1,2,9,14,15 vs 1,4,5,6,8
# somatic: 1,3,4,7,10,12
# medium: 2,5,7,10
# far: 9, 11, 13, 14, 15

# Only loaded for visualization. Opened read-only (nothing is written) and
# closed as soon as the dataset has been read into memory.
with h5py.File("../data/noise.h5", "r") as noise_file:
    noise_stimulus = noise_file["k"][()]
noise_stimulus = noise_stimulus[:, :, scan_window]
# Repeat the windowed noise once per selected scan field (rec_id).
noise_full = np.concatenate([noise_stimulus for _ in range(len(rec_ids))], axis=2)

setup = pd.read_pickle("../results/data/setup.pkl")
recording_meta = pd.read_pickle("../results/data/recording_meta.pkl")
stimuli_meta = pd.read_pickle(f"../results/data/stimuli_meta_{cell_id}.pkl")
labels_df = pd.read_pickle(f"../results/data/labels_lowpass_{cell_id}.pkl")

# TODO Change to file that contains all outputs.
bc_output = pd.read_pickle(f"../results/data/off_bc_output_{cell_id}.pkl")


def _select(df, cell_id, rec_ids=None):
    """Return a copy of the rows of `df` for `cell_id` (and, if given, the scan fields `rec_ids`).

    A copy is returned so that columns can be reassigned below without pandas'
    SettingWithCopyWarning (assignment into a filtered slice).
    """
    mask = df["cell_id"] == cell_id
    if rec_ids is not None:
        mask &= df["rec_id"].isin(rec_ids)
    return df[mask].copy()


setup = _select(setup, cell_id, rec_ids)
stimuli_meta = _select(stimuli_meta, cell_id)
bc_output = _select(bc_output, cell_id, rec_ids)
recording_meta = _select(recording_meta, cell_id, rec_ids)
labels_df = _select(labels_df, cell_id, rec_ids)
# np.stack below fails with an opaque error on empty input; fail early with context instead.
assert not labels_df.empty and not bc_output.empty, f"no labels / BC outputs for cell {cell_id}, rec_ids {rec_ids}"

# Constrain the labels to the selected frame window.
constrained_ca_activities = np.stack(labels_df["ca"].to_numpy())[:, scan_window].tolist()
labels_df["ca"] = constrained_ca_activities

# Constrain the BC activities to the same frame window.
constrained_activities = np.stack(bc_output["activity"].to_numpy())[:, scan_window].tolist()
bc_output["activity"] = constrained_activities

# Concatenate each BC's windowed activity across the selected scan fields, in the
# row order of `bc_output` (presumably sorted by rec_id like `rec_ids` — TODO confirm).
bc_output_concatenated = bc_output.groupby("bc_id", sort=False)["activity"].apply(
    lambda activities: list(chain.from_iterable(activities))
)

# Constrain to a single rec_id because, apart from the activity (which is dealt with above),
# the bc_outputs have the same info for every scanfield.
bc_output = bc_output[bc_output["rec_id"] == rec_ids[0]].copy()
# Align by bc_id rather than by position, so the result does not depend on the rows of
# this subset appearing in the same order as the groupby groups.
bc_output["activity"] = bc_output["bc_id"].map(bc_output_concatenated)

# Join stimulus dfs.
stimuli = stimuli_meta.join(bc_output.set_index("bc_id"), on="bc_id", how="left", rsuffix="_bc")
stimuli = stimuli.drop(columns="cell_id_bc")

# Join recording dfs on a combined (rec_id, roi_id) key. rec_id * 100 + roi_id is only
# collision-free while roi_id < 100; a collision would silently duplicate rows in the join.
labels_df["unique_id"] = labels_df["rec_id"] * 100 + labels_df["roi_id"]
recording_meta["unique_id"] = recording_meta["rec_id"] * 100 + recording_meta["roi_id"]
assert labels_df["unique_id"].is_unique, "unique_id collision in labels_df (roi_id >= 100 or duplicate ROIs)"
recordings = recording_meta.join(labels_df.set_index("unique_id"), on="unique_id", how="left", rsuffix="_ca")
recordings = recordings.drop(columns=["cell_id_ca", "rec_id_ca"])

# Illustration of the task

In [42]:
# Export the prepared BC outputs and recordings for the illustration figure.
for out_path, payload in (
    ("../results/01_illustration/bc_output.pkl", bc_output),
    ("../results/01_illustration/recordings.pkl", recordings),
):
    with open(out_path, "wb") as handle:
        pickle.dump(payload, handle)

In [17]:
# OPL hyperparameters.
# OPL hyperparameters: width parameter of the Gaussian kernel, and the half-width
# of the grid it is sampled on (passed as `filter_size` to build_opl_kernel).
gaussian_kernel_std, kernel_size = 20.0, 50

In [18]:
def gaussian_filter(spatial_axis, std=50):
    """Evaluate an unnormalized radial Gaussian profile.

    Computes ``(1 / std) * exp(-d**2 / std**2)`` elementwise. Note that the exponent
    uses ``std**2`` (not ``2 * std**2``) and the amplitude is ``1 / std``, so this is
    not a normalized probability density.

    Args:
        spatial_axis: Distance(s) from the center; scalar or array.
        std: Width parameter of the profile.

    Returns:
        Profile value(s) with the same shape as `spatial_axis`.
    """
    return (1 / std) * np.exp(-spatial_axis**2 / std**2)

def build_opl_kernel(filter: str, std, filter_size, res_filter=100, center=(0.0, 0.0)):
    """Build a 2D OPL kernel sampled on a square grid.

    Args:
        filter: Kernel type. Only "Gaussian" is implemented; "center_surround"
            is reserved and raises NotImplementedError.
        std: Width parameter forwarded to `gaussian_filter`.
        filter_size: Half-width of the grid; both axes span [-filter_size, filter_size].
        res_filter: Number of grid points per axis (default 100, as before).
        center: (x, y) position of the kernel center (default the origin, as before).

    Returns:
        Tuple ``(kernel, X, Y)``: the kernel values and the meshgrid coordinates,
        each of shape (res_filter, res_filter).

    Raises:
        NotImplementedError: If `filter` is "center_surround".
        ValueError: If `filter` is any other unknown name.
    """
    pos_x = np.linspace(-filter_size, filter_size, res_filter)
    pos_y = np.linspace(-filter_size, filter_size, res_filter)
    X, Y = np.meshgrid(pos_x, pos_y)

    # Euclidean distance of every grid point from the kernel center.
    dist_x = center[0] - X
    dist_y = center[1] - Y
    dists = np.sqrt(dist_x**2 + dist_y**2)

    if filter == "Gaussian":
        # Fixed 1/100 amplitude scaling; origin of this constant is not documented here — TODO confirm.
        kernel = gaussian_filter(dists, std) / 100.0
    elif filter == "center_surround":
        raise NotImplementedError("The 'center_surround' OPL kernel is not implemented yet.")
    else:
        raise ValueError(f"Unknown filter {filter!r}; expected 'Gaussian' or 'center_surround'.")

    return kernel, X, Y

In [19]:
# Sample the Gaussian OPL kernel with the hyperparameters set above.
kernel, X, Y = build_opl_kernel(
    "Gaussian",
    std=gaussian_kernel_std,
    filter_size=kernel_size,
)

In [40]:
# Export the kernel grid coordinates and values for the illustration figure.
for out_path, payload in (
    ("../results/01_illustration/kernel_x.pkl", X),
    ("../results/01_illustration/kernel_y.pkl", Y),
    ("../results/01_illustration/kernel_z.pkl", kernel),
):
    with open(out_path, "wb") as handle:
        pickle.dump(payload, handle)

In [22]:
class BipolarCell:
    """Piecewise-linear bipolar-cell input/output nonlinearity.

    An input is divided by `max_input` and then mapped to an output by linear
    interpolation through a fixed table of (intensity, response) support points.
    """

    def __init__(self, max_input):
        # Normalization scale: inputs are divided by this value before the lookup.
        self.max_input = max_input

        # Support points of the curve. `x_vals` span -100..100 and are mapped linearly
        # onto intensities in [0, 1]; `response` is the output at each support point.
        # NOTE(review): -6.75 / 6.75 break the halving pattern (12.5 -> 6.25); confirm against the source data.
        self.x_vals = [-100, -50, -25, -12.5, -6.75, -3, 3, 6.75, 12.5, 25.0, 50.0, 100.0]
        self.response = [-0.05, -0.12, -0.15, -0.1, -0.08, -0.03, 0.1, 0.18, 0.37, 0.64, 0.85, 1.0]
        self.intensity = (1.0 + 1.0 / 100 * np.asarray(self.x_vals)) / 2.0

    def __call__(self, input):
        """Return the interpolated output for `input` (scalar or array).

        Normalized inputs outside [0, 1] are clamped to the first/last response,
        following `np.interp`.
        """
        return np.interp(input / self.max_input, self.intensity, self.response)

In [23]:
# Sample the nonlinearity slightly beyond [0, 1] so the clamped ends are visible.
inputs = np.linspace(start=-0.1, stop=1.1, num=100)
bc = BipolarCell(1.0)
vals = bc(inputs)

In [39]:
# Export the sampled nonlinearity curve for the illustration figure.
for out_path, payload in (
    ("../results/01_illustration/nonlinearity_inputs.pkl", inputs),
    ("../results/01_illustration/nonlinearity_vals.pkl", vals),
):
    with open(out_path, "wb") as handle:
        pickle.dump(payload, handle)