In [13]:
import numpy as np
import pandas as pd
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')
import matplotlib.pyplot as plt
import os
from matplotlib.backends.backend_pdf import PdfPages

from tqdm import tqdm
import pynwb

import bluepyopt as bpop
import bluepyopt.ephys as ephys
from neuron import h
import collections

In [14]:
io = pynwb.NWBHDF5IO("/media/ubuntu/sda/Patch-seq/data/Patch/601790945_icephys.nwb", 'r')
data = io.read()

In [15]:
current_time = {
    'step_1s': [301000, 201000, 401000],
    'step_fast': [101150],
    'slop': [1701000]
}

In [40]:
acquisition_data_dict = {}
stimulus_data_dict = {}

for key in current_time.keys():
    acquisition_data_dict[key] = pd.DataFrame()
    stimulus_data_dict[key] = pd.DataFrame()

for i in data.acquisition.keys():
    id = i.split("_")[1]
    acquisition = data.get_acquisition(f'data_{id}_AD0')
    if acquisition.data_type == 'CurrentClampSeries':
        stimulus = data.get_stimulus(f'data_{id}_DA0')
        for key in current_time.keys():
            if len(acquisition.data) in current_time[key]:
                acquisition_data = pd.DataFrame(np.array(acquisition.data))
                stimulus_data = pd.DataFrame(np.array(stimulus.data))

                acquisition_data_dict[key] = pd.concat((acquisition_data_dict[key], acquisition_data), axis=1)
                stimulus_data_dict[key] = pd.concat((stimulus_data_dict[key], stimulus_data), axis=1)

In [44]:
stage0_protocol = {}

for amplitude in stimulus_data_dict['step_1s'].max(axis = 0).unique():
    stage0_protocol[f'Amplitude_{int(amplitude)}'] = {}
    stage0_protocol[f'Amplitude_{int(amplitude)}']['stimuli'] = [
        {
            'amp': amplitude / 1000,
            'amp_end': amplitude /1000,
            'delay': 200,
            'duration': 1000,
            'stim_end': 1500,
            'totduration': 3000,
            'type': 'SquarePulse'
        }
    ]

In [45]:
stage0_protocol

{'Amplitude_50': {'stimuli': [{'amp': np.float32(0.05),
    'amp_end': np.float32(0.05),
    'delay': 200,
    'duration': 1000,
    'stim_end': 1500,
    'totduration': 3000,
    'type': 'SquarePulse'}]},
 'Amplitude_90': {'stimuli': [{'amp': np.float32(0.09),
    'amp_end': np.float32(0.09),
    'delay': 200,
    'duration': 1000,
    'stim_end': 1500,
    'totduration': 3000,
    'type': 'SquarePulse'}]},
 'Amplitude_130': {'stimuli': [{'amp': np.float32(0.13),
    'amp_end': np.float32(0.13),
    'delay': 200,
    'duration': 1000,
    'stim_end': 1500,
    'totduration': 3000,
    'type': 'SquarePulse'}]},
 'Amplitude_120': {'stimuli': [{'amp': np.float32(0.12),
    'amp_end': np.float32(0.12),
    'delay': 200,
    'duration': 1000,
    'stim_end': 1500,
    'totduration': 3000,
    'type': 'SquarePulse'}]},
 'Amplitude_110': {'stimuli': [{'amp': np.float32(0.11),
    'amp_end': np.float32(0.11),
    'delay': 200,
    'duration': 1000,
    'stim_end': 1500,
    'totduration': 300

In [17]:
key = 'step_1s'
acquisition_data_dict[key] = acquisition_data_dict[key].iloc[40000:110000, :].T
stimulus_data_dict[key] = stimulus_data_dict[key].iloc[40000:110000, :].T

key = 'step_fast'
acquisition_data_dict[key] = acquisition_data_dict[key].iloc[50000:55000, :].T
stimulus_data_dict[key] = stimulus_data_dict[key].iloc[50000:55000, :].T


In [18]:
test_raw_trace = acquisition_data_dict['step_1s'].iloc[22, :]
test_stimulus = stimulus_data_dict['step_1s'].iloc[22, :]

In [19]:
averaged_trace = np.array(test_raw_trace).reshape(-1, 50).max(axis=1)
averaged_stimulus = np.array(test_stimulus).reshape(-1, 50).max(axis=1)

In [33]:
a = ephys.models.CellModel('test')

In [20]:
import collections
from neuron import h
from bluepyopt import ephys

class AdExModel(ephys.models.CellModel):
    def __init__(self, name):
        super().__init__(name)
        
        # Create morphology (single soma section)
        self.morphology = ephys.morphologies.NrnFileMorphology('none')
        
        # Create mechanisms (none needed for AdEx point process)
        self.mechanisms = []
        
        # Define parameters
        self.params = self.define_parameters()
        
        # Create the point process (will be created in instantiate)
        self.adex = None
        self.soma = None

    def instantiate(self, sim=None):
        """Create the model in NEURON"""
        if sim is None:
            sim = ephys.simulators.NrnSimulator()
        
        # Create soma section
        self.soma = h.Section(name='soma', cell=self)
        self.soma.L = 20
        self.soma.diam = 20
        self.soma.cm = 1
        
        # Create AdEx point process
        self.adex = h.AdEx(0.5, sec=self.soma)
        
        # Set initial parameter values
        for param in self.params.values():
            param.instantiate(self.adex, sim=sim)
        
        # Create a section list for the morphology
        self.all = h.SectionList()
        self.all.append(sec=self.soma)
        
        # Create a section list for the soma
        self.somatic = h.SectionList()
        self.somatic.append(sec=self.soma)
        
        # Create a section list for the axon (empty)
        self.axonal = h.SectionList()
        
        # Create a section list for the basal dendrites (empty)
        self.basal = h.SectionList()
        
        # Create a section list for the apical dendrites (empty)
        self.apical = h.SectionList()

    def destroy(self, sim=None):
        """Clean up NEURON objects"""
        if sim is None:
            sim = ephys.simulators.NrnSimulator()
        
        # Destroy sections
        self.soma = None
        
        # Destroy point process
        if self.adex:
            self.adex = None
        
        # Destroy section lists
        self.all = None
        self.somatic = None
        self.axonal = None
        self.basal = None
        self.apical = None

    def define_parameters(self):
        """Define all modifiable parameters"""
        params = collections.OrderedDict()
        
        params['gL'] = ephys.parameters.NrnPointProcessParameter(
            name='gL',
            param_name='gL',
            value=30,
            bounds=[10, 50],  # nS
            frozen=False,
            locations=None
        )
        
        params['EL'] = ephys.parameters.NrnPointProcessParameter(
            name='EL',
            param_name='EL',
            value=-70.6,
            bounds=[-80, -60],  # mV
            frozen=False,
            locations=None
        )
        
        params['VT'] = ephys.parameters.NrnPointProcessParameter(
            name='VT',
            param_name='VT',
            value=-50.4,
            bounds=[-55, -45],  # mV
            frozen=False,
            locations=None
        )
        
        params['DeltaT'] = ephys.parameters.NrnPointProcessParameter(
            name='DeltaT',
            param_name='DeltaT',
            value=2,
            bounds=[0.5, 5],  # mV
            frozen=False,
            locations=None
        )
        
        params['Vr'] = ephys.parameters.NrnPointProcessParameter(
            name='Vr',
            param_name='Vr',
            value=-70.6,
            bounds=[-75, -65],  # mV
            frozen=False,
            locations=None
        )
        
        params['tauw'] = ephys.parameters.NrnPointProcessParameter(
            name='tauw',
            param_name='tauw',
            value=144,
            bounds=[50, 300],  # ms
            frozen=False,
            locations=None
        )
        
        params['a'] = ephys.parameters.NrnPointProcessParameter(
            name='a',
            param_name='a',
            value=4,
            bounds=[0.1, 10],  # nS
            frozen=False,
            locations=None
        )
        
        params['b'] = ephys.parameters.NrnPointProcessParameter(
            name='b',
            param_name='b',
            value=0.0805,
            bounds=[0.01, 0.2],  # nA
            frozen=False,
            locations=None
        )
        
        params['C'] = ephys.parameters.NrnPointProcessParameter(
            name='C',
            param_name='C',
            value=281,
            bounds=[100, 500],  # pF
            frozen=False,
            locations=None
        )
        
        return params
    
    def get_adex_point_process(self):
        """Get the AdEx point process"""
        return self.adex
    
    def get_soma(self):
        """Get the soma section"""
        return self.soma


In [21]:
def create_protocols(step_amplitude, step_delay, step_duration, total_duration):
    step_stim = ephys.stimuli.NrnSquarePulse(
        step_amplitude=step_amplitude,  # nA 
        step_delay=step_delay,       # ms
        step_duration=step_duration,   # ms
        location=ephys.locations.NrnSeclistLocation('somatic', seclist_name='somatic'),
        total_duration=total_duration   # ms
    )
    
    protocol = ephys.protocols.SweepProtocol(
        name='StepProtocol',
        stimuli=[step_stim],
        recordings=[
            ephys.recordings.CompRecording(
                name='soma.v',
                location=ephys.locations.NrnSeclistLocation('somatic', seclist_name='somatic'),
                variable='v'
            )
        ]
    )
    
    return protocol

In [22]:
from scipy import stats

class AdExFeatureExtractor(bpop.ephys.efeatures.eFELFeature):
    def __init__(self, feature_name, recording_names={'': 'soma.v'}):
        super().__init__(feature_name, recording_names=recording_names)
        self.spike_threshold = -30 
        self.ahp_window = 20  
        self.ap_analysis_window = 10  
        
    def calculate_feature(self, responses, feature_name):
        voltage = responses['soma.v'].values
        time = responses['soma.v'].index.values
        
        spike_indices, spike_times = self.detect_spikes(voltage, time)
        
        if feature_name == 'resting_potential':
            return self.calc_resting_potential(voltage, time)
        
        elif feature_name == 'steady_state_voltage':
            return self.calc_steady_state_voltage(voltage, time)
            
        elif feature_name == 'voltage_deflection':
            return self.calc_voltage_deflection(voltage, time)
        
        elif feature_name == 'spike_frequency':
            return self.calc_spike_frequency(spike_times)
        
        elif feature_name == 'ISI_slope':
            return self.calc_ISI_slope(spike_times)
        
        elif feature_name == 'adaptation_index':
            return self.calc_adaptation_index(spike_times)
        
        elif feature_name == 'time_to_first_spike':
            return self.calc_time_to_first_spike(spike_times)
        
        elif feature_name == 'AP_amplitude':
            return self.calc_AP_amplitude(voltage, spike_indices, spike_times)
        
        elif feature_name == 'AP_width':
            return self.calc_AP_width(voltage, time, spike_indices)
        
        elif feature_name == 'AHP_depth':
            return self.calc_AHP_depth(voltage, time, spike_indices)
                
        else:
            raise ValueError(f"Unknown feature: {feature_name}")
    
    def detect_spikes(self, voltage, time):
        dvdt = np.gradient(voltage, time)
        spike_indices = np.where(dvdt > 20)[0] 
        
        spikes = []
        spike_indices_refined = []
        
        for idx in spike_indices:
            if voltage[idx] > self.spike_threshold:
                window_start = max(0, idx - 5)
                window_end = min(len(voltage), idx + 5)
                peak_idx = np.argmax(voltage[window_start:window_end]) + window_start
                
                if not spikes or time[peak_idx] - spikes[-1] > 2:
                    spikes.append(time[peak_idx])
                    spike_indices_refined.append(peak_idx)
        
        return np.array(spike_indices_refined), np.array(spikes)
    
    def calc_resting_potential(self, voltage, time):
        pre_stim = (time > 50) & (time < 100)
        return np.mean(voltage[pre_stim]) if np.any(pre_stim) else np.nan
    
    def calc_steady_state_voltage(self, voltage, time):
        post_stim = (time > 1100) & (time < 1500)
        return np.mean(voltage[post_stim]) if np.any(post_stim) else np.nan
    
    def calc_voltage_deflection(self, voltage, time):
        stim_period = (time > 150) & (time < 1050)
        resting = self.calc_resting_potential(voltage, time)
        return np.mean(voltage[stim_period]) - resting if np.any(stim_period) else np.nan
    
    def calc_spike_frequency(self, spike_times):
        if len(spike_times) < 2:
            return 0
        return 1000 / np.mean(np.diff(spike_times))
    
    def calc_ISI_slope(self, spike_times):
        if len(spike_times) < 3:
            return 0
            
        isi = np.diff(spike_times)
        x = np.arange(len(isi))
        slope, _, _, _, _ = stats.linregress(x, isi)
        return slope
    
    def calc_adaptation_index(self, spike_times):
        if len(spike_times) < 3:
            return 0
            
        isi = np.diff(spike_times)
        first_isi = isi[0]
        last_isi = isi[-1]
        
        return (first_isi - last_isi) / (first_isi + last_isi) if (first_isi + last_isi) != 0 else 0
    
    def calc_time_to_first_spike(self, spike_times):
        return spike_times[0] - 100 if len(spike_times) > 0 else None
    
    def calc_AP_amplitude(self, voltage, spike_indices, spike_times):
        if len(spike_indices) == 0:
            return 0
            
        amplitudes = []
        for idx in spike_indices:
            search_start = max(0, idx - 20)
            threshold_idx = np.argmin(voltage[search_start:idx]) + search_start
            amplitude = voltage[idx] - voltage[threshold_idx]
            amplitudes.append(amplitude)
        
        return np.mean(amplitudes)
    
    def calc_AP_width(self, voltage, time, spike_indices):
        if len(spike_indices) == 0:
            return 0
            
        widths = []
        for idx in spike_indices:
            peak_voltage = voltage[idx]

            search_start = max(0, idx - 20)
            threshold_idx = np.argmin(voltage[search_start:idx]) + search_start
            threshold_voltage = voltage[threshold_idx]
            
            half_height = threshold_voltage + (peak_voltage - threshold_voltage) / 2
            
            rising_segment = voltage[threshold_idx:idx]
            rising_times = time[threshold_idx:idx]
            rising_cross = np.where(rising_segment > half_height)[0]
            rise_time = rising_times[rising_cross[0]] if rising_cross.size > 0 else time[idx]
            
            falling_segment = voltage[idx:min(len(voltage), idx+50)]
            falling_times = time[idx:min(len(voltage), idx+50)]
            falling_cross = np.where(falling_segment < half_height)[0]
            fall_time = falling_times[falling_cross[0]] if falling_cross.size > 0 else time[idx]
            
            width = fall_time - rise_time
            widths.append(width)
        
        return np.mean(widths)
    
    def calc_AHP_depth(self, voltage, time, spike_indices):
        if len(spike_indices) < 2:
            return 0
            
        ahp_depths = []
        for i in range(len(spike_indices)-1):
            start_idx = spike_indices[i]
            end_idx = spike_indices[i+1]
            
            window_start = start_idx + 1
            window_end = min(len(voltage), start_idx + self.ahp_window)
            
            if window_end - window_start > 2:
                min_idx = np.argmin(voltage[window_start:window_end]) + window_start
                ahp_depth = voltage[start_idx] - voltage[min_idx]
                ahp_depths.append(ahp_depth)
        
        return np.mean(ahp_depths) if ahp_depths else 0
    

In [23]:
def create_objectives(features, exp_features, exp_std):
    objectives = []
    
    for feature in features:        
        objective = bpop.ephys.objectives.SingletonObjective(
            name=f"{feature}_objective",
            feature=exp_features[feature]       
        )
        objectives.append(objective)
    
    return objectives

In [24]:
def create_optimizer(cell_model, protocol, objectives):
    sim = bpop.ephys.simulators.NrnSimulator()

    evaluator = bpop.ephys.evaluators.CellEvaluator(
            cell_model=cell_model,
            param_names=[p for p in cell_model.params.keys()],
            fitness_protocols={protocol.name: protocol},
            fitness_calculator=bpop.ephys.objectivescalculators.ObjectivesCalculator(objectives),
            #sim = sim,
            isolate_protocols = False
        )
    
    optimizer = bpop.optimisations.DEAPOptimisation(
        evaluator=evaluator,
        offspring_size=50,
        map_function=map,
        seed=1
    )
    
    return optimizer, evaluator

- Stage0

In [25]:
acquisition_data_dict = {}
stimulus_data_dict = {}
feature_extractor = AdExFeatureExtractor('all_features')
feature_names_constant = [
    'resting_potential', 'steady_state_voltage']

feature_names_specific = [
    'spike_frequency', 'ISI_slope', 'adaptation_index', 'time_to_first_spike', #'voltage_deflection',
    'AP_amplitude', 'AP_width', 'AHP_depth'
]

features_constant = {}
features_specific = {}

for feature in feature_names_constant:
    features_constant[feature] = []

for i in data.acquisition.keys():
    id = i.split("_")[1]
    acquisition = data.get_acquisition(f'data_{id}_AD0')
    if acquisition.data_type == 'CurrentClampSeries':
        if len(acquisition.data) in [301000, 201000, 401000]:
            stimulus = data.get_stimulus(f'data_{id}_DA0')
            inject_current = np.array(stimulus.data).max()
            
            features_specific[str(inject_current)] = {}

            voltage = acquisition.data[45000:120000]
            timestamps = np.linspace(0, 1500, len(voltage))
            responses = {'soma.v': pd.Series(voltage, index=timestamps)}
            
            for feature in feature_names_specific:
                try:
                    features_specific[str(inject_current)][feature] = feature_extractor.calculate_feature(
                        responses=responses, feature_name=feature)
                except Exception as e:
                    print(f"Error calculating {feature} for {1}: {str(e)}")
                    features_specific[feature] = np.nan

            for feature in feature_names_constant:
                try:
                    features_constant[feature].append(feature_extractor.calculate_feature(
                        responses=responses, feature_name=feature))
                except Exception as e:
                    print(f"Error calculating {feature} for {1}: {str(e)}")
                    features_constant[feature].append(np.nan)
                    

In [26]:
features_constant_mean = {}
features_constant_std = {}
for key in features_constant.keys():
    features_constant_mean[key] = np.array(features_constant[key]).mean()
    features_constant_std[key] = np.array(features_constant[key]).std()

In [27]:
from neuron import h
h.load_file("stdlib.hoc")
h.load_file("stdrun.hoc")
h.nrn_load_dll("/media/ubuntu/sda/Patch-seq/data/AdEx_Neuron/x86_64/.libs/libnrnmech.so")  

1.0

In [28]:
stage0_model = AdExModel(name="stage0")
for name in ['gL', 'EL', 'C']:
    stage0_model.params[name].frozen = False

stage0_protocol = create_protocols(170, 100, 1000, 1500)

stage0_features = ['resting_potential', 'steady_state_voltage', #'voltage_deflection'
                   ]

stage0_objectives = create_objectives(stage0_features, exp_features=features_constant_mean, exp_std=features_constant_std)



In [29]:
stage0_optimizer, stage0_evaluator = create_optimizer(
    stage0_model, stage0_protocol, stage0_objectives)


ValueError: CellEvaluator: you have to provide a Simulator object to the 'sim' argument of the CellEvaluator constructor

In [None]:
stage0_optimizer.run()

SweepProtocolException: Failed to run Neuron Sweep Protocol

In [None]:


stage0_final_pop, stage0_hall_of_fame, stage0_log, stage0_history = stage0_optimizer.run(
        max_ngen=50)

best_stage0_params = stage0_hall_of_fame[0]
print("Stage 0 best parameters:", best_stage0_params)

[<bluepyopt.ephys.objectives.SingletonObjective at 0x7f0603de5010>,
 <bluepyopt.ephys.objectives.SingletonObjective at 0x7f0603af07c0>]

- Stage1

In [None]:
stage1_model = AdExModel("AdEx_Stage1")
for param in stage1_model.params:
    if param.name in best_stage0_params:
        param.value = best_stage0_params[param.name]

for param in stage1_model.params:
        param.frozen = False

stage1_protocols = []
for amp in [0.1, 0.15, 0.2]:  
    stim = ephys.stimuli.NrnSquarePulse(
        step_amplitude=amp,
        step_delay=100,
        step_duration=1000,
        location=ephys.locations.NrnSeclistLocation('somatic', seclist_name='somatic'),
        total_duration=1200
    )

    protocol = ephys.protocols.SweepProtocol(
        name=f'StepProtocol_{amp}',
        stimuli=[stim],
        recordings=[
            ephys.recordings.CompRecording(
                name='soma.v',
                location=ephys.locations.NrnSeclistLocation('somatic', seclist_name='somatic'),
                variable='v'
            )
        ]
    )
    stage1_protocols.append(protocol)

stage1_compound_protocol = ephys.protocols.CompoundProtocol(
    protocols=stage1_protocols)

stage1_features = [
    'resting_potential', 'steady_state_voltage', 'voltage_deflection',
    'spike_frequency', 'ISI_slope', 'adaptation_index', 'time_to_first_spike',
    'AP_amplitude', 'AP_width', 'AHP_depth'
]

stage1_objectives = []
for protocol in stage1_protocols:
    stage1_objectives.extend(create_objectives(protocol, stage1_features))

stage1_evaluator = bpop.ephys.evaluators.CellEvaluator(
    cell_model=stage1_model,
    param_names=[p.name for p in stage1_model.params],
    fitness_protocols={p.name: p for p in stage1_protocols},
    fitness_calculator=bpop.ephys.objectivescalculators.ObjectivesCalculator(stage1_objectives)
)

stage1_optimizer = bpop.optimisations.DEAPOptimisation(
    evaluator=stage1_evaluator,
    offspring_size=100,
    map_function=map,
    seed=42
)

stage1_final_pop, stage1_hall_of_fame, stage1_log, stage1_history = stage1_optimizer.run(
    max_ngen=200)

best_stage1_params = stage1_hall_of_fame[0]
print("Stage 1 best parameters:", best_stage1_params)