In [1]:
import math
from GNAT.network import Synapse, PhysNetwork
from GNAT.quadtree import Spike, SpikePair, QuadTree
from GNAT.raster import SpikeRaster

In [2]:
class GnatFinder:
    def __init__(self, n_cells, qtarray_size):
        self.g_raster = SpikeRaster(n_cells)
        self.g_network = PhysNetwork()
        self.g_qtarray = [QuadTree() for _ in range(qtarray_size)]



In [3]:
def compute_gnat_edges(tau, thresh, c_radius):
    for post_idx in range(g_network.n_cells):
        # Print status
        if post_idx % 10 == 0:
            print(f"Cell {post_idx} of {g_network.n_cells}")

        # Iterate over spike pairs in post qtree
        sp_a = g_raster.sp_lists[post_idx]
        while sp_a:
            sp_b = sp_a.next
            while sp_b:
                if not spike_equals(sp_a, sp_b):
                    spp_post = create_spike_pair(sp_a, sp_b)
                    # print_spike_pair(spp_post)  # Uncomment if needed

                    # List of presynaptic partners
                    presyn = g_network.presyns[post_idx]

                    while presyn:
                        # Quadtree associated to presynaptic neuron
                        presyn_qtree = g_qtarray[presyn.src_id]

                        # Set query bounding box
                        query_bbox = BoundingBox(c_x=spp_post.sp1.ts, c_y=spp_post.sp2.ts, w2=c_radius)

                        # Apply edge test to queried range
                        QTreeMapGNATEdge(presyn_qtree, query_bbox, spp_post, presyn, tau, thresh)
                        presyn = presyn.next

                sp_b = sp_b.next
            sp_a = sp_a.next


In [5]:
import sys

def main():
    # Check number of arguments
    if len(sys.argv) < 7:
        print(f"Usage: {sys.argv[0]} <N cells> <spike file> <network file> <tau> <thresh> <causal_radius>")
        sys.exit(-1)

    _n_cells = int(sys.argv[1])
    tau = float(sys.argv[4])
    thresh = float(sys.argv[5])
    c_radius = float(sys.argv[6])

    # Initialize raster and network
    g_raster = SpikeRaster(_n_cells)
    g_network = PhysNetwork(_n_cells)

    # Read spikes from file into global raster
    RasterReadFile(g_raster, sys.argv[2])

    # Read network connectivity file
    PhysNetworkReadFile(g_network, sys.argv[3])
    # Uncomment if needed: PhysNetworkPrint(g_network)

    # Build top-level bounding box
    _cx = (g_raster.t_max + g_raster.t_min) / 2
    _cy = _cx
    _hw = (g_raster.t_max - g_raster.t_min) / 2
    bbox_top_level = BBoxCreate(_cx, _cy, _hw)

    # Build quadtrees for each cell
    g_qtarray = [QTreeCreate(bbox_top_level) for _ in range(_n_cells)]
    for idx in range(_n_cells):
        insert_spike_pairs(g_qtarray[idx], g_raster.sp_lists[idx])
        # Uncomment if needed for debugging:
        # print("-------- QuadTree --------")
        # QTreePrint(g_qtarray[idx])
        # print("-------- End QuadTree --------")

    # Initialize output file
    initialize_edge_buffer("gnat2_out.txt")

    # Compute gnats
    compute_gnat_edges(tau, thresh, c_radius)

    # Clean up
    finalize_edge_buffer()


In [6]:
main()

Usage: /Users/marco_cmp/opt/anaconda3/envs/GNATS/lib/python3.10/site-packages/ipykernel_launcher.py <N cells> <spike file> <network file> <tau> <thresh> <causal_radius>


SystemExit: -1

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
