In [1]:
import sys
sys.path.append("../")
sys.path.append("../Modules/")

import os

import numpy as np

%cd ../scripts/

/home/drfrbc/Neural-Modeling/scripts


In [2]:
sim_folder = '11-06-2024-11-56-10-STAs/Complex_Np5'

In [3]:
from scripts.compare_sta import get_all_directories_within, group_directories_by_cell_and_seed
sim_directory = '11-06-2024-11-56-10-STAs/'
sim_directories = get_all_directories_within(sim_directory)
grouped_directories = group_directories_by_cell_and_seed(sim_directories)

In [4]:
import analysis

Get observation data (Each segment's voltage, synaptic currents, ion channel currents)

In [7]:
parameters = analysis.DataReader.load_parameters(sim_folder)
parameters.channel_names.extend(['v', 'i_AMPA', 'i_NMDA'])
parameters.channel_names = list(np.unique(parameters.channel_names))

print(parameters.channel_names)

all_data_list = [analysis.DataReader.read_data(sim_folder, data_name).astype(np.float16) for data_name in parameters.channel_names]


# Combine them into a single 3D array of shape (nseg, ntimes, ndatatypes)
all_obs_data_matrix = np.stack(all_data_list, axis=-1)
print(all_obs_data_matrix.shape)

['gNaTa_t_NaTa_t' 'i_AMPA' 'i_NMDA' 'i_pas' 'ica' 'ica_Ca_HVA'
 'ica_Ca_LVAst' 'ihcn_Ih' 'ik' 'ik_Im' 'ik_SK_E2' 'ik_SKv3_1' 'ina'
 'ina_NaTa_t' 'ina_Nap_Et2' 'v']
(643, 120001, 16)


In [9]:
nsegs = all_obs_data_matrix.shape[0]
ntimes = all_obs_data_matrix.shape[1]
n_obs_datatypes = all_obs_data_matrix.shape[2]

Get state data (Each segment's spike classification over time and Soma spikes)

In [10]:
# get soma spike states
soma_spikes = analysis.DataReader.read_data(sim_folder, 'soma_spikes')
flat_soma_spikes = soma_spikes.flatten()
# must round because soma spikes are recorded at dt=0.1 ms while all data is sampled every 1 ms
rounded_indices = np.round(flat_soma_spikes).astype(int) # may need to drop duplicates

# create the binary sequence
# Determine the length of the binary sequence
sequence_length = ntimes
binary_sequence = np.zeros(sequence_length, dtype=int)
binary_sequence[rounded_indices] = 1
soma_spikes_binary = binary_sequence

In [12]:
parameters.channel_names

array(['gNaTa_t_NaTa_t', 'i_AMPA', 'i_NMDA', 'i_pas', 'ica', 'ica_Ca_HVA',
       'ica_Ca_LVAst', 'ihcn_Ih', 'ik', 'ik_Im', 'ik_SK_E2', 'ik_SKv3_1',
       'ina', 'ina_NaTa_t', 'ina_Nap_Et2', 'v'], dtype='<U14')

In [14]:
# get dendritic spike states
ica = all_obs_data_matrix[:, :, parameters.channel_names.index('ica')]
v = all_obs_data_matrix[:, :, parameters.channel_names.index('v')]
inmda = all_obs_data_matrix[:, :, parameters.channel_names.index('i_NMDA')]

nseg = all_obs_data_matrix.shape[0]
ntimes = all_obs_data_matrix.shape[1]
nspike_types = 3

# Initialize the 3D matrix
spike_matrix = np.zeros((nseg, ntimes, nspike_types), dtype=int)

for i in range(nseg):
    # Get Ca spikes
    left_bounds, right_bounds, _ = analysis.VoltageTrace.get_Ca_spikes(v[i], -40, ica[i])
    # Convert Ca spikes to binary sequence
    for start, end in zip(left_bounds, right_bounds):
        start_idx = int(np.round(start))
        end_idx = int(np.round(end))
        spike_matrix[i, start_idx:end_idx+1, 0] = 1

    # Get Na spikes and their durations
    left_bounds, right_bounds, _ = analysis.VoltageTrace.get_Na_spikes(v[i], 0.001 / 1000, soma_spikes, 2, v[i], v[0])
    # Convert NMDA spikes to binary sequence
    for start, end in zip(left_bounds, right_bounds):
        start_idx = int(np.round(start))
        end_idx = int(np.round(end))
        spike_matrix[i, start_idx:end_idx+1, 1] = 1
    # threshold = 0.001 / 1000
    # spikes, _ = analysis.VoltageTrace.get_Na_spikes(v[i], threshold, soma_spikes, 2, v[i], v[0])
    # if len(spikes) > 0:
    #     _, downward_crossing = analysis.VoltageTrace.get_crossings(v[i], threshold)
    #     durations = analysis.VoltageTrace.get_duration(spikes, downward_crossing)
    #     for spike_start, duration in zip(spikes, durations):
    #         start_idx = int(np.round(spike_start))
    #         end_idx = int(np.round(spike_start + duration))
    #         spike_matrix[i, start_idx:end_idx+1, 1] = 1

    # Get NMDA spikes
    left_bounds, right_bounds, _ = analysis.VoltageTrace.get_NMDA_spikes(v[i], -40, inmda[i])
    # Convert NMDA spikes to binary sequence
    for start, end in zip(left_bounds, right_bounds):
        start_idx = int(np.round(start))
        end_idx = int(np.round(end))
        spike_matrix[i, start_idx:end_idx+1, 2] = 1

  start_idx = int(np.round(start))
  end_idx = int(np.round(end))
  start_idx = int(np.round(start))
  end_idx = int(np.round(end))
  start_idx = int(np.round(start))
  end_idx = int(np.round(end))


In [15]:
spike_matrix.shape

(643, 120001, 3)

In [16]:

# Segment index (449)
segment_index = 449

# Spike type index for Ca spikes (0)
ca_spike_type_index = 0

# Calculate the total ntimes where segment 449 had Ca spikes
total_ca_spikes_ntimes = np.sum(spike_matrix[segment_index, :, ca_spike_type_index])

print(f"Total ntimes where segment {segment_index} had Ca spikes: {total_ca_spikes_ntimes}")

Total ntimes where segment 449 had Ca spikes: 92238


In [17]:
print(all_obs_data_matrix.shape) # currents/conductances
print(spike_matrix.shape) # spike occurences

(643, 120001, 16)
(643, 120001, 3)


Calculate State Transition Probabilities

In [18]:
import numpy as np
from itertools import product

# Assuming spike_matrix is shape (nseg, ntimes, nspike_types)
nstates = 2 ** spike_matrix.shape[2]  # Total number of possible states
nseg, ntimes, nspike_types = spike_matrix.shape

# Initialize the transition count matrix
transition_counts = np.zeros((nstates, nstates))

# Function to convert state vector to a unique state index
def state_vector_to_index(state_vector):
    return sum(val * (2 ** idx) for idx, val in enumerate(state_vector))

# Count transitions between states
for seg in range(nseg):
    for t in range(ntimes - 1):
        current_state = state_vector_to_index(spike_matrix[seg, t])
        next_state = state_vector_to_index(spike_matrix[seg, t + 1])
        transition_counts[current_state, next_state] += 1

# Normalize to get transition probabilities
transition_probabilities = transition_counts / transition_counts.sum(axis=1, keepdims=True)

# Handle cases where there are no transitions from a state to avoid division by zero
transition_probabilities = np.nan_to_num(transition_probabilities)

# Example printout to verify the matrix
print("Transition Counts:\n", transition_counts)
print("Transition Probabilities:\n", transition_probabilities)


Transition Counts:
 [[3.9222683e+07 6.1213800e+05 1.6250000e+03 4.8000000e+01 1.1845300e+05
  1.4489000e+05 1.2000000e+01 2.2000000e+01]
 [5.8402300e+05 2.5542380e+07 0.0000000e+00 3.9000000e+01 1.5800000e+02
  1.5279400e+05 0.0000000e+00 0.0000000e+00]
 [5.1600000e+02 0.0000000e+00 5.1600000e+02 1.1020000e+03 0.0000000e+00
  0.0000000e+00 0.0000000e+00 7.0000000e+00]
 [0.0000000e+00 1.1640000e+03 0.0000000e+00 9.4000000e+02 0.0000000e+00
  4.1000000e+01 0.0000000e+00 0.0000000e+00]
 [3.5607000e+04 2.1800000e+02 0.0000000e+00 0.0000000e+00 1.1434090e+06
  9.1971000e+04 0.0000000e+00 0.0000000e+00]
 [2.5704200e+05 1.2346900e+05 0.0000000e+00 0.0000000e+00 9.1850000e+03
  9.1154510e+06 0.0000000e+00 7.0000000e+00]
 [0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
  0.0000000e+00 0.0000000e+00 1.2000000e+01]
 [0.0000000e+00 2.5000000e+01 0.0000000e+00 1.6000000e+01 0.0000000e+00
  7.0000000e+00 0.0000000e+00 3.0000000e+01]]
Transition Probabilities:
 [[9.78124917e-01

State 0 (000): No spikes of any type.
State 1 (001): NMDA spikes only.
State 2 (010): Sodium (Na) spikes only.
State 3 (011): Sodium (Na) and NMDA spikes.
State 4 (100): Calcium (Ca) spikes only.
State 5 (101): Calcium (Ca) and NMDA spikes.
State 6 (110): Calcium (Ca) and Sodium (Na) spikes.
State 7 (111): Calcium (Ca), Sodium (Na), and NMDA spikes.

In [20]:
print(transition_counts.shape)
print(transition_probabilities.shape)

(8, 8)
(8, 8)


Calculate Emission Probabilities

In [36]:
from collections import defaultdict
from scipy.stats import multivariate_normal

from collections import defaultdict
import numpy as np

# Helper function to convert state vector to unique index
def state_vector_to_index(state_vector):
    return sum(val * (2 ** idx) for idx, val in enumerate(state_vector))

# Function to calculate emission probabilities
def calculate_emission_probabilities(combined_data, state_data, nstates):
    emissions = defaultdict(list)
    
    for seg in range(nseg):
        for t in range(ntimes):
            state_tuple = tuple(state_data[seg, t])
            state_index = state_vector_to_index(state_tuple)
            observations = combined_data[seg, t, :]
            emissions[state_index].append(observations)
    
    # Debugging: Print collected emissions
    for state, obs in emissions.items():
        print(f"State {state}: Collected {len(obs)} observations")

    # Calculate mean and covariance for each state
    emission_probabilities = {}
    for state in range(nstates):
        if emissions[state]:
            observations = np.array(emissions[state])
            print(f"State {state}: Observations shape {observations.shape}")  # Debugging
            mean = observations.mean(axis=0)
            cov = np.cov(observations, rowvar=False)
            emission_probabilities[state] = (mean, cov)
        else:
            print(f"State {state}: No observations")  # Debugging
            emission_probabilities[state] = (np.zeros(combined_data.shape[2]), np.eye(combined_data.shape[2]))
    
    return emission_probabilities

# Calculate the emission probabilities
emission_probabilities = calculate_emission_probabilities(all_obs_data_matrix, spike_matrix, nstates)


State 0: Collected 40100514 observations
State 2: Collected 2141 observations
State 3: Collected 2145 observations
State 1: Collected 26279394 observations
State 4: Collected 1271205 observations
State 5: Collected 9505154 observations
State 7: Collected 78 observations
State 6: Collected 12 observations
State 0: Observations shape (40100514, 16)


MemoryError: Unable to allocate 4.78 GiB for an array with shape (40100514, 16) and data type float64

In [37]:
import numpy as np
from collections import defaultdict

# Helper function to convert state vector to unique index
def state_vector_to_index(state_vector):
    return sum(val * (2 ** idx) for idx, val in enumerate(state_vector))

# Incremental mean and covariance calculation
class IncrementalCovariance:
    def __init__(self, n_features, dtype=np.float32):
        self.n_features = n_features
        self.dtype = dtype
        self.mean = np.zeros(n_features, dtype=dtype)
        self.covariance = np.zeros((n_features, n_features), dtype=dtype)
        self.n_samples = 0

    def update(self, x):
        self.n_samples += 1
        delta = x - self.mean
        self.mean += delta / self.n_samples
        self.covariance += np.outer(delta, x - self.mean)

    def finalize(self):
        return self.mean, self.covariance / (self.n_samples - 1) if self.n_samples > 1 else self.covariance

def calculate_emission_probabilities(combined_data, state_data, nstates):
    dtype = np.float32  # Change data type here
    emissions = defaultdict(list)
    
    for seg in range(nseg):
        for t in range(ntimes):
            state_tuple = tuple(state_data[seg, t])
            state_index = state_vector_to_index(state_tuple)
            observations = combined_data[seg, t, :].astype(dtype)
            emissions[state_index].append(observations)
    
    # Debugging: Print collected emissions
    for state, obs in emissions.items():
        print(f"State {state}: Collected {len(obs)} observations")

    # Calculate mean and covariance for each state
    emission_probabilities = {}
    for state in range(nstates):
        if emissions[state]:
            obs_iter = iter(emissions[state])
            first_obs = next(obs_iter)
            inc_cov = IncrementalCovariance(len(first_obs), dtype=dtype)
            inc_cov.update(first_obs)
            for observation in obs_iter:
                inc_cov.update(observation)
            mean, cov = inc_cov.finalize()
            print(f"State {state}: Mean shape {mean.shape}, Covariance shape {cov.shape}")  # Debugging
            emission_probabilities[state] = (mean, cov)
        else:
            print(f"State {state}: No observations")  # Debugging
            emission_probabilities[state] = (np.zeros(combined_data.shape[2], dtype=dtype), np.eye(combined_data.shape[2], dtype=dtype))
    
    return emission_probabilities

# Calculate the emission probabilities
emission_probabilities = calculate_emission_probabilities(all_obs_data_matrix, spike_matrix, nstates)


: 

In [33]:
print(emission_probabilities[2])

(array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), array([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.

Combine Transition and Emission Probabilities

In [24]:
# Example function to get probability of an observation given a state
def get_emission_probability(observation, state, emission_probabilities):
    mean, cov = emission_probabilities[state]
    return multivariate_normal.pdf(observation, mean=mean, cov=cov)

# Example use of transition and emission probabilities for a given observation
def infer_state(observation, previous_state, transition_probabilities, emission_probabilities):
    max_prob = -1
    best_next_state = None
    for next_state in range(nstates):
        trans_prob = transition_probabilities[previous_state, next_state]
        emit_prob = get_emission_probability(observation, next_state, emission_probabilities)
        prob = trans_prob * emit_prob
        if prob > max_prob:
            max_prob = prob
            best_next_state = next_state
    return best_next_state


OLd below

In [None]:
import numpy as np
from sklearn.preprocessing import LabelEncoder
from hmmlearn import hmm

# Assuming 'data' is a NumPy array of shape (n_samples, n_features)
n_samples = 1000
n_features = 7
data = np.random.rand(n_samples, n_features)  # Replace with your actual data

# Generate labels for each sample
labels = np.random.choice(['CA_nexus', 'NMDA_distal', 'Soma_spike', 'Soma_burst'], n_samples)

# Encode labels to integers
# Encode labels to integersg
label_encoder = LabelEncoder()
encoded_labels = label_encoder.fit_transform(labels)


# Combine data and encoded labels into observation sequences
observation_sequences = np.column_stack((data, encoded_labels))


# Define the number of hidden states based on unique labels
n_states = len(np.unique(encoded_labels))

# Initialize the Gaussian HMM
model = hmm.GaussianHMM(n_components=n_states, covariance_type="diag", n_iter=100)

# Fit the model to the data
model.fit(data)

# Get the hidden states
hidden_states = model.predict(data)


In [None]:
hidden_states

In [None]:
hidden_states = model.predict(data)
plt.figure(figsize=(15, 6))
plt.plot(hidden_states, label='Hidden States')
plt.title('Hidden State Sequence Over Time')
plt.xlabel('Time')
plt.ylabel('State')
plt.legend()
plt.show()