In [8]:
import numpy as np
import os
from mozaik.storage.queries import *
from mozaik.storage.datastore import PickledDataStore
from mozaik.tools.distribution_parametrization import PyNNDistribution
from parameters import ParameterSet

import matplotlib.pyplot as plt
from pathlib import Path


def load_datastore(base_dir):
    return PickledDataStore(
        load=True,
        parameters=ParameterSet(
            {"root_directory": base_dir, "store_stimuli": False}
        ),
        replace=False,
    )

def get_segments(data_store, sheet_name=None):
    """
    Returns a list of segments in the DataStore, ordered by their identifier.
    Optionally filters the segments for results from a specific sheet.

    Parameters
    ----------

    data_store : Datastore to retrieve segments from
    sheet_name : name of neuron sheet (layer) to retrieve the recorded segments for

    Returns
    -------
    A list of all segments in a DataStore, optionally filtered for sheet name.
    """
    dsv = param_filter_query(
        data_store, st_name="EndOfSimulationBlank", negative=True
    )

    if sheet_name is None:
        # If no sheet name specified, load all sheets
        return dsv.get_segments()
    else:
        return sorted(
            param_filter_query(dsv, sheet_name=sheet_name).get_segments(),
            key=lambda x: x.identifier,
        )

def get_voltages(data_store, sheet_name=None, max_neurons=None):
    """
    Returns the recorded membrane potentials for neurons recorded in the DataStore,
    ordered by segment identifiers and neuron ids, flattened into a single 1D vector.
    Can be optionally filtered for neuron sheet, and maximum number of neurons.

    Parameters
    ----------

    data_store : Datastore to retrieve voltages from
    sheet_name : name of neuron sheet (layer) to retrieve the recorded voltages from
    max_neurons : maximum number of neurons to get the voltages for

    Returns
    -------
    A 1D list of membrane potential voltages recorded in neurons in the DataStore
    """

    segments = get_segments(data_store, sheet_name)
    return [
        v
        for segment in segments
        for neuron_id in sorted(segment.get_stored_vm_ids())[:max_neurons]
        for v in segment.get_vm(neuron_id).flatten()
    ]

def get_spikes(data_store, sheet_name=None, max_neurons=None):
    """
    Returns the recorded spike times for neurons recorded in the DataStore,
    ordered by segment identifiers and neuron ids, flattened into a single 1D vector.
    Can be optionally filtered for neuron sheet, and maximum number of neurons.

    Parameters
    ----------

    data_store : Datastore to retrieve spike times from
    sheet_name : name of neuron sheet (layer) to retrieve the recorded spike times from
    max_neurons : maximum number of neurons to get spike times for

    Returns
    -------
    A 1D list of spike times recorded in neurons in the DataStore
    """
    segments = get_segments(data_store, sheet_name)
    return [
        v
        for segment in segments
        for neuron_id in sorted(segment.get_stored_spike_train_ids())[:max_neurons]
        for v in segment.get_spiketrain(neuron_id).flatten()
    ]

def check_spikes(ds0, ds1, sheet_name=None, max_neurons=None):
    """
    Check if spike times recorded in two DataStores are equal. Spike times are merged
    into a single 1D array and compared using numpy assertions.
    Can be optionally filtered for neuron sheet, and maximum number of neurons.

    Parameters
    ----------

    ds0, ds1 : DataStores to retrieve spike times from
    sheet_name : name of neuron sheet (layer) to check spike times for
    max_neurons : maximum number of neurons to check spike times for
    """
    np.testing.assert_equal(
        get_spikes(ds0, sheet_name, max_neurons),
        get_spikes(ds1, sheet_name, max_neurons),
    )

def check_voltages(ds0, ds1, sheet_name=None, max_neurons=None):
    """
    Check if membrane potential voltages recorded in two DataStores are equal. Voltages
    are merged into a single 1D array and compared using numpy assertions.
    Can be optionally filtered for neuron sheet, and maximum number of neurons.

    Parameters
    ----------

    ds0, ds1 : DataStores to retrieve spike times from
    sheet_name : name of neuron sheet (layer) to check voltages for
    max_neurons : maximum number of neurons to check voltages for
    """

    print(len(get_voltages(ds0, sheet_name, max_neurons)), flush=True)
    print(len(get_voltages(ds1, sheet_name, max_neurons)), flush=True)
    np.testing.assert_equal(
        get_voltages(ds0, sheet_name, max_neurons),
        get_voltages(ds1, sheet_name, max_neurons),
    )

**TODO:**
1. comapare params
2. compare figs
3. compare datastore
    - test spikes
    - test voltages 


In [3]:
# Data from Remys model
paper_model = '/home/haman/layers56/LSV1M_paper/'
# Data from referenced model
ref_model = '/home/haman/layers56/LSV1M_refs-update/'

spont = '20241129-Spont/SelfSustainedPushPull_ParameterSearch_____trial:1'
size = '20241129-SizeTuning/SelfSustainedPushPull_ParameterSearch_____trial:1'
orient = '20241129-OrientTuning/SelfSustainedPushPull_ParameterSearch_____trial:1'

# "sheet_name", ["V1_Exc_L4", "V1_Inh_L4", "V1_Exc_L2/3", "V1_Inh_L2/3"]


In [None]:
for experiment in [spont, size, orient]:
    print(experiment.split('/')[0].split('-')[1])
    paper = load_datastore(paper_model + experiment)
    ref = load_datastore(ref_model + experiment)

    for sheet in ["V1_Exc_L4", "V1_Inh_L4", "V1_Exc_L2/3", "V1_Inh_L2/3"]:
        print(sheet)
        check_voltages(paper, ref, sheet, max_neurons=5)
        check_spikes(paper, ref, sheet, max_neurons=5)

Spont
V1_Exc_L4
201600
201600
V1_Inh_L4
201600
201600
V1_Exc_L2/3
201600
201600
V1_Inh_L2/3
201600
201600
SizeTuning
V1_Exc_L4
2410800
2410800
V1_Inh_L4
2410800
2410800
V1_Exc_L2/3
2410800
2410800
V1_Inh_L2/3
2410800
2410800
OrientTuning
V1_Exc_L4
3304700


In [9]:
get_segments(paper, "V1_Exc_L4")
get_voltages(paper, "V1_Exc_L4")

[array(-66.14553378) * mV,
 array(-66.08566572) * mV,
 array(-67.21809941) * mV,
 array(-67.69979669) * mV,
 array(-68.67354227) * mV,
 array(-69.25099729) * mV,
 array(-70.28812457) * mV,
 array(-71.24279548) * mV,
 array(-71.81908005) * mV,
 array(-72.2558431) * mV,
 array(-72.48595575) * mV,
 array(-72.45390447) * mV,
 array(-72.12378448) * mV,
 array(-71.9928111) * mV,
 array(-72.89747556) * mV,
 array(-73.06498061) * mV,
 array(-71.22067931) * mV,
 array(-69.02953449) * mV,
 array(-65.98745755) * mV,
 array(-63.3747007) * mV,
 array(-62.62393046) * mV,
 array(-64.15459631) * mV,
 array(-67.99869021) * mV,
 array(-70.62408309) * mV,
 array(-71.6621623) * mV,
 array(-71.23827294) * mV,
 array(-71.77418182) * mV,
 array(-72.35769567) * mV,
 array(-72.681688) * mV,
 array(-72.36751569) * mV,
 array(-71.34411108) * mV,
 array(-71.30657549) * mV,
 array(-70.99645863) * mV,
 array(-70.5980364) * mV,
 array(-69.9543057) * mV,
 array(-68.91962378) * mV,
 array(-68.52422508) * mV,
 array(-6