## NeXtCortex project

The packages DrWatson, SpikingNeuralNetworks, UnPack, Logging, Plots, Statistics, Random, CSV, DataFrames and Unitful must be installed to run this file.

The file src/network_utils.jl (which defines the `NetworkUtils` module) is also required.

In [7]:
using DrWatson
findproject(@__DIR__) |> quickactivate

include("src/network_utils.jl")
using .NetworkUtils

using SpikingNeuralNetworks
using UnPack
using Logging
using Plots
using Statistics
using Random
using CSV
using DataFrames
import SpikingNeuralNetworks: @update

# Discard every log record for the rest of the session. Note that this also hides
# any @info/@warn messages emitted by the packages during network construction/simulation.
global_logger(NullLogger())
# Bring the SNN unit constants used throughout this file (e.g. ms, mV, nS, pF, Hz)
# into global scope; every physical quantity below is written in these units.
SNN.@load_units

import SpikingNeuralNetworks: PoissonLayer, monitor!, sim!, SingleExpSynapse, IFParameter, PostSpike, STTC

# Directory that receives the heatmaps (PNG) and result tables (CSV) of the sweep.
img_path = "plots_and_images";

#### Network baseline configuration

In [8]:
# Baseline configuration consumed by `NetworkUtils.build_network` (src/network_utils.jl,
# not shown here). All quantities use the unit constants loaded by `SNN.@load_units`.
# NOTE(review): which entry is applied to which population is decided by
# build_network; the mapping described in the comments below is inferred from
# field names — confirm against build_network.
TC3inhib_network = (
    # Number of neurons in each population.
    Npop = (TE=200, CE=4000, PV=400, SST=300, VIP=300),
    # Configuration-level seed; presumably read by build_network (TODO confirm).
    seed = 1234,

    # Integrate-and-fire parameters, presumably shared by the TE and CE populations.
    # τm is written as capacitance / leak conductance (200 pF / 10 nS = 20 ms).
    exc = IFParameter(
        τm=200pF / 10nS,
        El=-70mV,
        Vt=-50mV,
        Vr=-70mV,
        R=1/10nS,
    ),

    # Integrate-and-fire parameters for the PV, SST and VIP classes; they differ
    # only in τm (10 ms, 20 ms and 15 ms respectively).
    inh_PV  = IFParameter(τm=100pF/10nS, El=-70mV, Vt=-53mV, Vr=-70mV, R=1/10nS),
    inh_SST = IFParameter(τm=200pF/10nS, El=-70mV, Vt=-53mV, Vr=-70mV, R=1/10nS),
    inh_VIP = IFParameter(τm=150pF/10nS, El=-70mV, Vt=-53mV, Vr=-70mV, R=1/10nS),

    # Post-spike settings (default and per inhibitory class); τabs is presumably
    # the absolute refractory period.
    spike     = PostSpike(τabs=5ms),
    spike_PV  = PostSpike(τabs=2ms),
    spike_SST = PostSpike(τabs=10ms),
    spike_VIP = PostSpike(τabs=5ms),

    # Single-exponential synapses (default and per inhibitory class): τe/τi are
    # presumably the excitatory/inhibitory decay constants and E_e/E_i the
    # corresponding reversal potentials.
    synapse     = SingleExpSynapse(τi=5ms, τe=5ms, E_i=-80mV, E_e=0mV),
    synapse_PV  = SingleExpSynapse(τi=5ms, τe=5ms, E_i=-80mV, E_e=0mV),
    synapse_SST = SingleExpSynapse(τi=12ms, τe=5ms, E_i=-80mV, E_e=0mV),
    synapse_VIP = SingleExpSynapse(τi=7ms, τe=5ms, E_i=-80mV, E_e=0mV),

    # Recurrent connections, presumably named `<source>_to_<target>`.
    # `p` is the connection probability and `μ` the synaptic strength.
    # `rule=:Fixed` presumably selects non-plastic weights (TODO confirm in build_network).
    connections = (
        TE_to_CE = (p=0.05, μ=4nS, rule=:Fixed),
        TE_to_PV = (p=0.05, μ=4nS, rule=:Fixed),

        CE_to_CE = (p=0.05, μ=2nS, rule=:Fixed),
        CE_to_PV = (p=0.05, μ=2nS, rule=:Fixed),
        CE_to_TE = (p=0.05, μ=2nS, rule=:Fixed),
        CE_to_SST = (p=0.05, μ=2nS, rule=:Fixed),
        CE_to_VIP = (p=0.05, μ=2nS, rule=:Fixed),

        PV_to_CE  = (p=0.05, μ=10nS, rule=:Fixed),
        PV_to_PV  = (p=0.05, μ=10nS, rule=:Fixed),
        PV_to_SST = (p=0.05, μ=10nS, rule=:Fixed),

        SST_to_CE  = (p=0.025, μ=10nS, rule=:Fixed),
        SST_to_PV  = (p=0.025, μ=10nS, rule=:Fixed),
        SST_to_VIP = (p=0.025, μ=10nS, rule=:Fixed),

        VIP_to_SST = (p=0.3, μ=10nS, rule=:Fixed),
    ),

    # External drive: one PoissonLayer (N sources at `rate`) per target population,
    # connected with the `conn` parameters (same p/μ/rule fields as above).
    afferents_to_TE  = (layer=PoissonLayer(rate=1.5Hz, N=1000), conn=(p=0.05, μ=4nS, rule=:Fixed)),
    afferents_to_CE  = (layer=PoissonLayer(rate=1.5Hz, N=1000), conn=(p=0.02, μ=4nS, rule=:Fixed)),
    afferents_to_PV  = (layer=PoissonLayer(rate=1.5Hz, N=1000), conn=(p=0.02, μ=4nS, rule=:Fixed)),
    afferents_to_SST = (layer=PoissonLayer(rate=1.5Hz, N=1000), conn=(p=0.15, μ=2nS, rule=:Fixed)),
    afferents_to_VIP = (layer=PoissonLayer(rate=1.5Hz, N=1000), conn=(p=0.10, μ=2nS, rule=:Fixed)),
);

In [10]:
# Connections whose (μ, p) are swept; each name must be a key of
# `TC3inhib_network.connections`.
pops_to_modify = (
    :VIP_to_SST,
    :PV_to_CE,
    :SST_to_CE,
    :TE_to_CE,
)

# Sweep grids: connection probabilities, and synaptic strengths labelled in nS
# in the printouts and plots.
p_values = Float64[0.1, 0.2, 0.7, 1]
μ_values = Float64[0.1, 0.5, 1, 1.5, 2, 5, 10];

In [None]:
"""
    compute_both_metrics(network, pop, μ, p; seed=1234, window=100ms, step=20ms,
                         threshold=0.8, sim_time=3.0, sttc_dt=50ms, subsample=5)

Build and simulate `network` with connection `pop` overridden to strength `μ` and
connection probability `p`. Then quantify synchrony with the spike time tiling
coefficient (`STTC`) on every `subsample`-th neuron of `model.pop`.

Two metrics are computed:

1. The overall STTC: the mean of `STTC(spikes, sttc_dt)` over the whole run.
2. The transition time: the centre (in seconds) of the first sliding window that
   starts a run of 3 consecutive windows whose mean STTC is `≥ threshold`.

The sliding analysis is skipped, and empty series are returned, in two cases:

- the sampled neurons fire fewer than 100 spikes in total;
- the overall STTC is below 0.1 or is NaN.

# Arguments
- `network`: configuration NamedTuple with a `connections` field, passed to
  `NetworkUtils.build_network`.
- `pop`: key of `network.connections` to override.
- `μ`, `p`: new strength and connection probability. `μ` is written into the config
  as-is, so it must be in the same units as the other strengths (e.g. `4nS`).

# Keywords
- `seed`: passed to `Random.seed!` and also written into the config's `seed` field.
- `window`, `step`: sliding-window length and stride, in model time units (e.g. `100ms`).
- `threshold`: STTC level that must be sustained for a transition.
- `sim_time`: simulated duration in seconds (a plain number).
- `sttc_dt`: STTC coincidence interval. It is used for both the overall and the windowed
  STTC, so both are on the same scale as `threshold`.
- `subsample`: stride used to select the analysed neurons.

# Returns
`(overall_sttc, trans_time, sttc_series, t_centers)`:

- `trans_time == sim_time` means no transition was detected.
- `t_centers` holds the window centres, in seconds.
"""
function compute_both_metrics(network, pop, μ, p; seed=1234, window=100ms, step=20ms,
                              threshold=0.8, sim_time=3.0, sttc_dt=50ms, subsample=5)
    Random.seed!(seed)

    # Override only the swept connection. The trial seed is also written into the
    # config so that any seeding build_network performs from `network.seed` follows
    # this trial. NOTE(review): confirm build_network reads `network.seed`.
    modulation = (; network.connections[pop]..., μ = μ, p = p)
    network_modified = (; network...,
        seed = seed,
        connections = (; network.connections..., pop => modulation),
    )

    model = NetworkUtils.build_network(network_modified)
    # NOTE(review): only spike times are analysed below; these `:v` traces are unused here.
    monitor!(model.pop, [:v], sr=1kHz)
    sim!(model, sim_time*1s)

    # Spike times from SNN.spiketimes are in model time units (milliseconds).
    # Keep every `subsample`-th neuron to bound the cost of the pairwise STTC.
    spikes_ms = SNN.spiketimes(model.pop)[1:subsample:end]
    spikes_s = [s ./ 1000.0 for s in spikes_ms]

    # `init=0` keeps this well-defined when no neurons are selected.
    total_spikes = sum(length, spikes_s; init=0)
    if total_spikes < 100  # Network is essentially inactive
        println("  Inactive network: only $total_spikes spikes")
        return 0.0, sim_time, Float64[], Float64[]
    end
    println("  Active network: $total_spikes spikes")

    t_first, t_last = extrema(Iterators.flatten(spikes_s))
    println("  Spike time range: $(round(t_first, digits=2))s - $(round(t_last, digits=2))s")

    # STTC expects spike times in model units (ms), hence the ms-based trains.
    overall_sttc = mean(STTC(spikes_ms, sttc_dt))
    println("  Overall STTC: $(round(overall_sttc, digits=3))")

    # A NaN mean (possible if STTC yields NaN entries) is treated like low synchrony
    # instead of silently falling through to the "high synchrony" branch.
    if isnan(overall_sttc) || overall_sttc < 0.1
        println("  Low overall synchrony - no epileptic-like activity")
        return overall_sttc, sim_time, Float64[], Float64[]
    end

    println("  High synchrony detected - computing transition time...")

    window_s = window / 1000.0
    step_s = step / 1000.0
    t_centers = collect(window_s/2:step_s:(sim_time - window_s/2))
    sttc_series = _sliding_sttc(spikes_s, t_centers, window_s, sttc_dt)

    # Transition: first window starting 3 consecutive windows at or above threshold.
    first_idx = _first_sustained_crossing(sttc_series, threshold, 3)
    if first_idx === nothing
        trans_time = sim_time
        println("  No transition - network stays below threshold")
    else
        trans_time = t_centers[first_idx]
        println("  ✓ Transition detected at $(round(trans_time, digits=2))s")
    end

    return overall_sttc, trans_time, sttc_series, t_centers
end

# Mean STTC of the spikes inside each window [t - window_s/2, t + window_s/2) for every
# centre t (all in seconds). Windows with too little activity, or where STTC throws,
# score 0.0. The first `n_debug` windows are printed.
function _sliding_sttc(spikes_s, t_centers, window_s, sttc_dt;
                       min_spikes=10, min_active=5, n_debug=3)
    series = Float64[]
    sizehint!(series, length(t_centers))
    for (idx, t_center) in enumerate(t_centers)
        t_start = t_center - window_s/2
        t_end = t_center + window_s/2
        # Back to model units (ms) because STTC expects them.
        # NOTE(review): times stay absolute (not window-relative); check how SNN's
        # STTC derives the recording duration.
        spikes_window_ms = [s[(s .≥ t_start) .& (s .< t_end)] .* 1000.0 for s in spikes_s]

        n_spikes = sum(length, spikes_window_ms; init=0)
        n_active = count(!isempty, spikes_window_ms)

        value = 0.0
        if n_spikes >= min_spikes && n_active >= min_active
            try
                value = Float64(mean(STTC(spikes_window_ms, sttc_dt)))
            catch e
                # Best effort: a failing window scores 0.0, but never swallow Ctrl-C.
                e isa InterruptException && rethrow()
            end
        end
        push!(series, value)

        if idx <= n_debug
            println("    Window at t=$(round(t_center, digits=2))s: $n_spikes spikes, STTC=$(round(value, digits=3))")
        end
    end
    return series
end

# Index of the first element that begins `run_length` consecutive values ≥ `threshold`,
# or `nothing` when no such run exists (including series shorter than `run_length`).
function _first_sustained_crossing(series, threshold, run_length)
    last_start = length(series) - run_length + 1
    return findfirst(i -> all(≥(threshold), view(series, i:i+run_length-1)), 1:last_start)
end

# Main analysis: banner explaining how to read the two heatmaps produced below.
let rule = "="^70
    println("\n", rule)
    println("TRANSITION TIME ANALYSIS - CORRECTED TIME UNITS")
    println(rule)
    foreach(println, (
        "\nInterpretation:",
        "  - STTC heatmap: Overall synchrony (0-1) across full 3s simulation",
        "  - Transition heatmap: Time (in seconds) when STTC≥0.8 is reached",
        "    • 3.0s = Network never becomes epileptic-like",
        "    • <3.0s = Time when epileptic-like activity emerges",
    ))
    println(rule)
end

# Parameter sweep. For each connection in `pops_to_modify` and each (μ, p) pair,
# run `compute_both_metrics` for three seeds and average the results. Then print
# them and write heatmaps (PNG) plus a results table (CSV) to `img_path`.
#
# `μ_values` are in nS (that is how they are printed and plotted). The network
# config expresses strengths in model units (e.g. `4nS`), so each value is scaled
# by `nS` before being passed on.
mkpath(img_path)  # savefig / CSV.write fail if the output directory is missing

for pop in pops_to_modify

    println("\n\nAnalyzing population: $pop")
    println("-"^70)

    STTC_matrix = zeros(length(μ_values), length(p_values))
    Trans_matrix = zeros(length(μ_values), length(p_values))

    for (i, μ) in enumerate(μ_values)
        for (j, p) in enumerate(p_values)
            println("\n[$i,$j] Testing μ=$(μ)nS, p=$p")

            sttc_vals = Float64[]
            trans_vals = Float64[]

            for seed in [1234, 1235, 1236]
                overall_sttc, trans_time, _, _ = compute_both_metrics(
                    TC3inhib_network, pop, μ * nS, p;
                    seed=seed, threshold=0.8, sim_time=3.0,
                    window=100ms, step=20ms
                )

                push!(sttc_vals, overall_sttc)
                push!(trans_vals, trans_time)
            end

            # Seeds without a transition contribute `sim_time` (3.0 s) to this mean.
            avg_sttc = mean(sttc_vals)
            avg_trans = mean(trans_vals)
            STTC_matrix[i, j] = avg_sttc
            Trans_matrix[i, j] = avg_trans

            # Summary for this parameter combination
            if avg_sttc >= 0.8
                if avg_trans < 2.99
                    println("  ✓ EPILEPTIC-LIKE: STTC=$(round(avg_sttc,digits=2)), emerges at t=$(round(avg_trans,digits=2))s")
                else
                    println("  ⚠ High sync but late: STTC=$(round(avg_sttc,digits=2)), emerges after 2.99s")
                end
            elseif avg_sttc >= 0.1
                println("  ○ Moderate sync: STTC=$(round(avg_sttc,digits=2)), no epileptic threshold")
            else
                println("  ✗ Inactive: STTC≈0, no activity")
            end
        end
    end

    println("\n" * "="^70)
    println("RESULTS for $pop:")
    println("="^70)
    println("\nSTTC Matrix (Overall Synchrony):")
    display(STTC_matrix)
    println("\n\nTransition Time Matrix (seconds to epileptic-like state):")
    display(Trans_matrix)

    # Count epileptic cases
    n_epileptic = count(STTC_matrix .>= 0.8)
    n_with_transition = count(Trans_matrix .< 2.99)
    println("\n\nSummary:")
    println("  - Parameter combinations with epileptic-like activity (STTC≥0.8): $n_epileptic/$(length(STTC_matrix))")
    println("  - Parameter combinations with detected transitions: $n_with_transition/$(length(Trans_matrix))")

    # Heatmaps: rows of the matrices follow μ (y axis), columns follow p (x axis).
    sttc_heatmap = heatmap(
        p_values, μ_values, STTC_matrix,
        xlabel="Connection probability (p)",
        ylabel="Synaptic strength (μ in nS)",
        title="Epileptic-like Activity: Overall STTC - $pop",
        color=:viridis,
        clims=(0, 1),
        colorbar_title="STTC",
        size=(500, 400)
    )

    for i in eachindex(μ_values), j in eachindex(p_values)
        val = STTC_matrix[i,j]
        val_str = val < 0.01 ? "~0" : string(round(val, digits=2))
        annotate!(sttc_heatmap, p_values[j], μ_values[i],
                 text(val_str, 8, :white))
    end

    savefig(sttc_heatmap, "$img_path/Final_$(pop)_STTC_heatmap.png")

    # Transition time heatmap
    transition_heatmap = heatmap(
        p_values, μ_values, Trans_matrix,
        xlabel="Connection probability (p)",
        ylabel="Synaptic strength (μ in nS)",
        title="Time to Epileptic Activity - $pop",
        color=:viridis,
        clims=(0, 3),
        colorbar_title="Time (s)",
        size=(500, 400)
    )

    for i in eachindex(μ_values), j in eachindex(p_values)
        trans_val = Trans_matrix[i,j]
        sttc_val = STTC_matrix[i,j]

        if sttc_val < 0.1  # Inactive
            label = "N/A"
            color = :gray
        elseif trans_val >= 2.99  # Active but no transition
            label = "Never"
            color = :white
        else  # Has transition
            label = string(round(trans_val, digits=2))
            color = :white
        end

        annotate!(transition_heatmap, p_values[j], μ_values[i],
                 text(label, 8, color))
    end

    savefig(transition_heatmap, "$img_path/Final_$(pop)_transition_heatmap.png")

    # Combined plot
    combined_plot = plot(sttc_heatmap, transition_heatmap,
                        layout=(1,2), size=(1100, 400),
                        left_margin=5Plots.mm, right_margin=5Plots.mm)
    savefig(combined_plot, "$img_path/Final_$(pop)_combined.png")

    # One CSV row per (μ, p) pair, μ-major order: the transposes make `vec` walk
    # each matrix row by row, matching the repeat(inner/outer) columns.
    CSV.write("$img_path/Final_$(pop)_results.csv",
             DataFrame(mu=repeat(μ_values, inner=length(p_values)),
                      p=repeat(p_values, outer=length(μ_values)),
                      sttc_overall=vec(STTC_matrix'),
                      transition_time=vec(Trans_matrix'),
                      is_epileptic=vec((STTC_matrix .>= 0.8)'),
                      has_transition=vec((Trans_matrix .< 2.99)')))

    println("\n✓ Analysis complete for $pop")
end

# Closing banner for the whole sweep.
let rule = "="^70
    println('\n', rule, '\n', "ANALYSIS COMPLETE", '\n', rule)
end