# ✅ DNABERT ➜ Spike Encoder ➜ Brian2 SNN: Final Pipeline
## 🧠 This notebook runs the full NeuroGenAI pipeline: from DNA to spikes, simulating brain-like computation on genetic sequences.

## 🧩 1. Setup & Imports

In [None]:
# 📦 Dependencies
!pip install matplotlib numpy transformers==4.41.0 torch brian2 scikit-learn

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from transformers import AutoTokenizer, AutoModel
import time
from sklearn.preprocessing import MinMaxScaler
from brian2 import *
import json

# 🧠 Core modules
from src.nlp.dna_embedding_model import DNAEmbedder
from src.snn.spike_encoder import SpikeEncoder
from src.snn.brian_model import run_brian2_simulation
from src.eval.snn_metrics import evaluate_spikes

## 🧬 2. Load and Clean DNA Sequences

In [None]:
# 🧬 Load the pre-cleaned sequence table (point fasta_path at your own CSV if desired)
fasta_path = "data/processed/human_fasta_clean.csv"
df = pd.read_csv(fasta_path)

# Keep rows whose Length is at least 30, then take the first 100 for the demo
long_enough = df["Length"] >= 30
df = df.loc[long_enough].head(100)
sequences = df["Sequence"].tolist()
print(f"✅ Loaded {len(sequences)} DNA sequences.")

## 🔬 3. Embed with DNABERT

In [None]:
# 🔬 Embed DNA with DNABERT (k=6 checkpoint)
embedder = DNAEmbedder(model_id="zhihan1996/DNA_bert_6", k=6)

# Re-read the sequences from `df` so this cell only depends on the DataFrame,
# not on the `sequences` variable left behind by the loading cell.
sequences = df["Sequence"].tolist()

# Embed every sequence in a single batch call
embeddings = embedder.embed_batch(sequences)
print("✅ Final embedding shape:", embeddings.shape)

# Save as .npy. The path lives in one variable so the log line below can
# never disagree with where the file was actually written.
embeddings_path = "data/processed/fasta_dnabert_embeddings.npy"
np.save(embeddings_path, embeddings)
print(f"📁 Saved to: {embeddings_path}")

## ⚡ 4. Encode Spikes from Embeddings

In [None]:
# ⚡ Encoder configured with rate_max_hz=120 (presumably the peak firing rate in Hz — confirm in SpikeEncoder)
encoder = SpikeEncoder(rate_max_hz=120)

# Where the intermediate arrays are persisted
firing_rates_path = "data/processed/firing_rates.npy"
spike_train_path = "data/processed/spike_train.npy"

# Embeddings -> normalized values -> firing rates (saved to disk)
norm_rates = encoder.normalize_embeddings(embeddings)
firing_rates = encoder.to_firing_rates(norm_rates)
np.save(firing_rates_path, firing_rates)

# ⏱️ Firing rates -> Poisson spike matrix (saved to disk)
spike_matrix = encoder.generate_poisson_spike_train(firing_rates)
np.save(spike_train_path, spike_matrix)

print("✅ Spike train shape:", spike_matrix.shape)

## 🧠 5. Simulate Spiking Neural Network (Brian2)

In [None]:
# 🔬 Simulation settings and output locations
snn_duration_ms = 100
sim_result_path = "data/outputs/snn_sim_results.npz"
plot_path = "data/outputs/snn_spike_plot.png"

# 📥 Read the spike matrix back from disk
spike_matrix = np.load("data/processed/spike_train.npy")

# 🧠 Hand everything to the Brian2 runner as keyword arguments
sim_kwargs = {
    "spike_matrix": spike_matrix,
    "duration_ms": snn_duration_ms,
    "plot_path": plot_path,
    "save_path": sim_result_path,
}
run_brian2_simulation(**sim_kwargs)

## 📊 6. Analyze and Save Metrics

In [None]:
# 📊 Evaluate the saved simulation results; output goes to metrics_path
metrics_path = "data/outputs/snn_metrics.json"

evaluate_spikes(path=sim_result_path, save_path=metrics_path)

## 📦 7. Metadata Log

In [None]:
# 📦 Collect the run's key settings so the outputs can be traced back to them
meta = {
    "model_id": embedder.model_id,
    "vector_dim": embeddings.shape[1],
    "n_sequences": len(sequences),
    "fasta_source": fasta_path,
    "rate_max_hz": encoder.rate_max_hz,
    "duration_ms": snn_duration_ms,
}

# `json` is already imported in the setup cell, so nothing needs importing
# while the file handle is open. The path is defined once so the log line
# below always matches the file that was written.
meta_path = "data/outputs/meta.json"
with open(meta_path, "w", encoding="utf-8") as f:
    json.dump(meta, f, indent=4)

print(f"✅ Pipeline metadata saved to {meta_path}")

# Visualize SNN Activity

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.animation import FuncAnimation
import os

# 📥 Load simulation data. The context manager closes the .npz archive once
# both arrays have been read into memory; the arrays stay usable afterwards.
with np.load("data/outputs/snn_sim_results.npz") as sim_data:
    spike_times = sim_data["spike_times"]
    spike_indices = sim_data["spike_indices"]

# 🔭 Optional 3D Spike Raster Plot
fig = plt.figure(figsize=(10, 6))
ax = fig.add_subplot(111, projection='3d')

# Every point is drawn on the z=0 plane; colour encodes the neuron index.
ax.scatter(spike_times, spike_indices, zs=0, zdir='z', s=2, c=spike_indices, cmap='plasma')
ax.set_xlabel('Time (ms)')
ax.set_ylabel('Neuron ID')
ax.set_zlabel('Depth (for visual separation)')
ax.set_title("3D Spike Raster Plot")

plt.tight_layout()

# Single source of truth for the output path, shared by savefig and the log line.
raster_3d_path = "data/outputs/snn_spike_plot_3d.png"
plt.savefig(raster_3d_path, dpi=300)
plt.close(fig)  # close this figure explicitly rather than "whatever is current"
print(f"🖼️ 3D spike raster saved to: {raster_3d_path}")

## 🎞️ Optional: Animate Spike Activity Over Time

In [None]:
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

# 🧠 Prepare data for animation
time_window = 1  # ms of simulated time shown in each frame
frame_interval = 20  # ms; passed to FuncAnimation as the inter-frame delay

# np.max raises ValueError on an empty array, so fall back to a minimal
# 1 ms axis when the simulation produced no spikes at all.
has_spikes = np.size(spike_times) > 0
max_time = int(np.max(spike_times)) + 1 if has_spikes else 1

# Always render at least one frame, even if time_window exceeds max_time.
frames = max(1, int(max_time / time_window))

fig, ax = plt.subplots(figsize=(10, 4))
ax.set_xlim(0, max_time)
# +10 leaves headroom above the highest neuron index.
ax.set_ylim(0, (np.max(spike_indices) if has_spikes else 0) + 10)
ax.set_xlabel("Time (ms)")
ax.set_ylabel("Neuron ID")
ax.set_title("Spiking Activity Over Time")

# Start with an empty scatter; update() fills in the points for each frame.
scat = ax.scatter([], [], s=3, c='red')

def update(frame):
    """Redraw the scatter with only the spikes belonging to ``frame``.

    The frame covers the half-open time interval that starts at
    ``frame * time_window`` and is ``time_window`` long. Spikes outside that
    interval are dropped from the plot, so each frame shows a single slice
    of activity rather than an accumulating raster.

    Reads the module-level ``spike_times``, ``spike_indices``,
    ``time_window`` and ``scat``.

    Returns:
        A one-element tuple holding the updated scatter artist.
    """
    window_lo = frame * time_window
    window_hi = window_lo + time_window
    in_window = np.logical_and(spike_times >= window_lo, spike_times < window_hi)
    # One (time, neuron) row per spike; an empty selection yields a (0, 2) array.
    points = np.column_stack((spike_times[in_window], spike_indices[in_window]))
    scat.set_offsets(points)
    return (scat,)

# Build the animation: update() is called once per frame index.
ani = FuncAnimation(fig, update, frames=frames, interval=frame_interval, blit=True)

# The explicit fps sets the GIF's frame rate; per the matplotlib docs the
# animation's interval is only the fallback when fps is not given.
# The path is defined once so the log line always matches the written file.
gif_path = "data/outputs/visual_spikes.gif"
ani.save(gif_path, writer="pillow", fps=15)
plt.close(fig)  # close the animation's figure explicitly

print(f"🎬 Animated spike activity saved to: {gif_path}")