In [4]:
import math
from GNAT.network import Synapse, PhysNetwork
from GNAT.quadtree import Spike, SpikePair, QuadTree

# Your code that uses these classes


In [3]:

# --- Module-level configuration -------------------------------------------
# Penalty added to gamma by compute_gamma() when the post/pre spike lag is
# shorter than the synaptic delay.
LARGE_GAMMA = 999_999
# Capacity of the in-memory edge buffer (number of GNATEdge slots).
N_EDGBUF = 8192

# --- Shared edge-buffer state ----------------------------------------------
# Fixed-size slot list; only the first `edgbuf_sz` entries are meaningful.
g_edgbuf = [None for _ in range(N_EDGBUF)]
# Count of edges currently waiting in `g_edgbuf`.
edgbuf_sz = 0
# Output file object for flushed edges; None until initialize_edge_buffer().
fp_edgbuf = None

# Your function implementations and other code here

# Class definitions
class GNATEdge:
    """One buffered edge linking a presynaptic and a postsynaptic spike pair.

    Attributes are plain references; nothing is copied or validated.
    """

    def __init__(self, spp_pre=None, spp_post=None, cd_ratio=0.0):
        """Store the two spike pairs and the accompanying ratio value.

        Args:
            spp_pre: Presynaptic spike pair (other code in this file reads
                its ``sp1`` and ``sp2`` attributes).
            spp_post: Postsynaptic spike pair (same attribute layout).
            cd_ratio: Numeric value stored with the edge.
                NOTE(review): meaning of "cd" is not visible here — confirm.
        """
        self.spp_pre, self.spp_post = spp_pre, spp_post
        self.cd_ratio = cd_ratio

# Function definitions
def finalize_edge_buffer():
    """Flush any buffered edges and close the edge output file.

    The file is closed even when flushing raises, so a failed write does not
    leave the handle open. ``fp_edgbuf`` is left bound to the (now closed)
    file object; it is not reset to None here.

    Raises:
        Whatever ``flush_edge_buffer()`` raises (including its fatal exit
        when no output file has been initialised).
    """
    try:
        flush_edge_buffer()
    finally:
        # Guard against the "never initialised" case, where there is nothing
        # to close.
        if fp_edgbuf is not None:
            fp_edgbuf.close()


def initialize_edge_buffer(fname):
    """Open the edge output file and reset the in-memory edge buffer.

    Args:
        fname: Path of the output file. It is opened in text write mode
            ("w"), so an existing file is truncated.

    Returns:
        0 on success.

    Raises:
        SystemExit: With code -1 when the file cannot be opened (after
            printing a FATAL message).

    Note:
        A previously opened ``fp_edgbuf`` is not explicitly closed here; call
        ``finalize_edge_buffer()`` first when re-initialising.
    """
    global fp_edgbuf, edgbuf_sz, g_edgbuf

    try:
        fp_edgbuf = open(fname, "w")
    except OSError as err:
        print(f"FATAL: unable to open output file {fname}")
        # `exit` is an interactive-shell helper (installed by `site`, and
        # replaced by IPython inside notebooks), so it is not guaranteed to
        # stop execution. Raising SystemExit directly always aborts here.
        raise SystemExit(-1) from err

    edgbuf_sz = 0
    g_edgbuf = [None] * N_EDGBUF
    return 0

def QTreeMapGNATEdge(qt, r, spp_post, syn, tau, theta):
    """Walk a quadtree and buffer an edge for every qualifying spike pair.

    Nodes whose boundary does not intersect ``r`` are skipped together with
    their whole subtree. For every other node, each spike pair in the node's
    linked list (``qt.pairs`` chained through ``.next``) is tested against
    ``spp_post``; pairs that pass are added with a cd_ratio of 1.

    NOTE(review): ``BBoxIntersects`` is not defined or imported in the visible
    cells — confirm it is in scope before this runs.

    Args:
        qt: Quadtree node exposing ``bdry``, ``pairs``, ``NW``, ``SW``,
            ``NE`` and ``SE``.
        r: Query region handed to ``BBoxIntersects``.
        spp_post: Postsynaptic spike pair every candidate is tested against.
        syn: Synapse/edge parameters forwarded to ``GNAT_test_for_edge``.
        tau: Time constant forwarded to ``GNAT_test_for_edge``.
        theta: Threshold forwarded to ``GNAT_test_for_edge``.

    Returns:
        None.
    """
    # Prune: nothing below a non-intersecting node is visited.
    if not BBoxIntersects(qt.bdry, r):
        return

    # Test every presynaptic candidate stored directly on this node.
    candidate = qt.pairs
    while candidate:
        passes = GNAT_test_for_edge(candidate, spp_post, syn, tau, theta)
        if passes:
            GNAT_add_edge(candidate, spp_post, 1)
        candidate = candidate.next

    # Descend into the existing children, in the fixed order NW, SW, NE, SE.
    for quadrant in ("NW", "SW", "NE", "SE"):
        child = getattr(qt, quadrant)
        if child:
            QTreeMapGNATEdge(child, r, spp_post, syn, tau, theta)


def GNAT_test_for_edge(spp_pre, spp_post, edg, tau, thresh):
    """Decide whether a pre/post spike-pair combination qualifies as an edge.

    Gamma is computed for the first spikes of both pairs and for the second
    spikes of both pairs; the combination qualifies only when both values are
    at or below ``thresh``.

    Args:
        spp_pre: Presynaptic spike pair (``sp1``, ``sp2``).
        spp_post: Postsynaptic spike pair (``sp1``, ``sp2``).
        edg: Edge/synapse parameters forwarded to ``compute_gamma``.
        tau: Time constant forwarded to ``compute_gamma``.
        thresh: Upper bound (inclusive) each gamma must satisfy.

    Returns:
        True when both gammas are <= ``thresh``, otherwise False.
    """
    pre_spikes = (spp_pre.sp1, spp_pre.sp2)
    post_spikes = (spp_post.sp1, spp_post.sp2)

    # Both gammas are evaluated up front, before any comparison is made.
    gammas = [
        compute_gamma(pre, post, edg, tau)
        for pre, post in zip(pre_spikes, post_spikes)
    ]
    return all(gamma <= thresh for gamma in gammas)


def compute_gamma(sp_pre, sp_post, edg, tau):
    """Compute the gamma cost of pairing a presynaptic with a postsynaptic spike.

    With ``lag = sp_post.ts - sp_pre.ts`` the result is
    ``edg.neg_log_rel_w + (lag - edg.delay) / tau``, plus ``LARGE_GAMMA`` when
    the lag is shorter than ``edg.delay``.

    Args:
        sp_pre: Presynaptic spike (``ts`` attribute).
        sp_post: Postsynaptic spike (``ts`` attribute).
        edg: Object providing ``delay`` and ``neg_log_rel_w``.
        tau: Divisor applied to the delay-corrected lag.

    Returns:
        The gamma value.
    """
    lag = sp_post.ts - sp_pre.ts

    # Penalty switch (the complement of a Heaviside step): zero once the lag
    # reaches the delay, LARGE_GAMMA before that.
    if lag >= edg.delay:
        penalty = 0
    else:
        penalty = LARGE_GAMMA

    base_cost = edg.neg_log_rel_w + (lag - edg.delay) / tau
    return penalty + base_cost


def compute_omega(sp_pre, sp_post, edg, tau):
    """Compute the omega weight of pairing a presynaptic with a postsynaptic spike.

    With ``delta_t = sp_post.ts - sp_pre.ts`` the result is
    ``edg.rel_w * exp(-(delta_t - edg.delay) / tau)`` when ``delta_t`` is at
    least ``edg.delay``, and 0.0 otherwise.

    Args:
        sp_pre: Presynaptic spike (``ts`` attribute).
        sp_post: Postsynaptic spike (``ts`` attribute).
        edg: Object providing ``delay`` and ``rel_w``.
        tau: Divisor applied to the delay-corrected lag.

    Returns:
        The omega value as a float.

    Raises:
        ZeroDivisionError: If ``tau`` is zero.
    """
    delta_t = sp_post.ts - sp_pre.ts
    exponent = -(delta_t - edg.delay) / tau

    # Heaviside gate, applied BEFORE exponentiating: for a lag far below the
    # delay the exponent is large and positive, and math.exp() would raise
    # OverflowError even though the gated result is simply zero.
    if delta_t < edg.delay:
        return 0.0

    return edg.rel_w * math.exp(exponent)


def flush_edge_buffer():
    """Write every buffered edge to the output file and empty the buffer.

    The first ``edgbuf_sz`` entries of ``g_edgbuf`` are written, in order,
    through ``fprint_GNAT_edge``; afterwards ``edgbuf_sz`` is reset to 0. The
    slots themselves are not cleared.

    Raises:
        SystemExit: With code -1 when no output file has been initialised
            (after printing a FATAL message).
    """
    global edgbuf_sz

    if fp_edgbuf is None:
        print("FATAL: output file not initialized")
        # `exit` is an interactive-shell helper (installed by `site`, and
        # replaced by IPython inside notebooks), so it is not guaranteed to
        # stop execution; falling through would write to a None file handle.
        raise SystemExit(-1)

    if edgbuf_sz == 0:
        return

    for idx in range(edgbuf_sz):
        fprint_GNAT_edge(fp_edgbuf, g_edgbuf[idx])

    edgbuf_sz = 0


def fprint_GNAT_edge(fp, edg):
    """Write one edge to ``fp`` as a single space-separated text line.

    Field order: presynaptic neuron id, presynaptic spike times (sp1, sp2),
    postsynaptic neuron id, postsynaptic spike times (sp1, sp2). Each neuron
    id is taken from the pair's first spike.

    Args:
        fp: Object with a ``write(str)`` method.
        edg: Edge exposing ``spp_pre`` and ``spp_post`` spike pairs.

    Returns:
        None.
    """
    pre, post = edg.spp_pre, edg.spp_post
    fields = (
        pre.sp1.n_id, pre.sp1.ts, pre.sp2.ts,
        post.sp1.n_id, post.sp1.ts, post.sp2.ts,
    )
    fp.write("{} {} {} {} {} {}\n".format(*fields))


def GNAT_add_edge(_spp_pre, _spp_post, _cd_ratio):
    """Append a new edge to the in-memory buffer, flushing first when full.

    Args:
        _spp_pre: Presynaptic spike pair for the new edge.
        _spp_post: Postsynaptic spike pair for the new edge.
        _cd_ratio: Value stored as the edge's ``cd_ratio``.

    Returns:
        None.
    """
    global edgbuf_sz, g_edgbuf

    # Make room: a full buffer is written out and the fill count restarted.
    if N_EDGBUF <= edgbuf_sz:
        flush_edge_buffer()
        edgbuf_sz = 0

    # Store the edge in the next free slot and advance the fill count.
    new_edge = GNATEdge(_spp_pre, _spp_post, _cd_ratio)
    g_edgbuf[edgbuf_sz] = new_edge
    edgbuf_sz = edgbuf_sz + 1

