# 🧠 NeuroGenAI | Spike Train Generation from DNABERT Embeddings
# Story 3: Convert semantic embeddings into biologically plausible spikes

## 🧬 1. Imports and Setup

## ❓ What Are Poisson Spike Trains?
Poisson spike trains mimic the irregular firing of biological neurons. Each neuron fires at a given average rate (in Hz), but the exact spike times are random and independent of one another. In our case, normalized DNABERT vector values are converted into firing rates, and therefore into per-timestep spike probabilities, for each neuron.
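The idea can be sketched in a few lines of NumPy. This assumes the common Bernoulli-per-bin approximation of a Poisson process (spike probability = rate × dt in each time bin, which holds while that product stays well below 1) and is not necessarily how `SpikeEncoder` implements it; `poisson_spike_train` is a hypothetical helper.

```python
import numpy as np


def poisson_spike_train(rates_hz, duration_ms=500.0, dt_ms=1.0, seed=42):
    """Hypothetical helper: Bernoulli-per-bin approximation of a Poisson process.

    rates_hz: 1-D array of firing rates, one per neuron (Hz).
    Returns a binary uint8 array of shape [timesteps, neurons].
    """
    rng = np.random.default_rng(seed)
    n_steps = int(round(duration_ms / dt_ms))
    # Spike probability per bin = rate (spikes/s) * bin width (s), clipped to [0, 1]
    p_spike = np.clip(np.asarray(rates_hz, dtype=float) * dt_ms / 1000.0, 0.0, 1.0)
    # One independent uniform draw per (timestep, neuron); broadcasting compares
    # each column against that neuron's probability
    return (rng.random((n_steps, p_spike.size)) < p_spike).astype(np.uint8)


rates = np.array([0.0, 20.0, 120.0])
train = poisson_spike_train(rates)
print(train.shape)  # (500, 3)
```

A neuron with rate 0 never fires, and over a long run the observed spike rate of each neuron converges to its requested rate.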

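## 🔍 Why Convert DNABERT to Spikes?
Unlike traditional ML models, which pass continuous-valued vectors between layers, Spiking Neural Networks (SNNs) communicate through binary events (spikes) that unfold over time, and simulators such as Brian2 advance them in discrete time steps. DNABERT embeddings must therefore be converted into spike trains before they can drive an SNN in a simulator like Brian2.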
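Simulators usually take spikes as events rather than as a dense matrix. As a hedged sketch, the hypothetical helper below turns a binary `[timesteps, neurons]` array into parallel arrays of neuron indices and spike times (ms), which is the layout Brian2's `SpikeGeneratorGroup(N, indices, times)` accepts once the times are given units. The Brian2 call is shown only as a comment and is not executed here.

```python
import numpy as np


def to_spike_events(spike_train, dt_ms=1.0):
    """Hypothetical helper: binary [timesteps, neurons] array -> (indices, times_ms).

    np.nonzero scans row-major, so events come out sorted by time, then neuron.
    """
    t_idx, n_idx = np.nonzero(np.asarray(spike_train))
    return n_idx, t_idx * dt_ms


# Possible use with Brian2 (not executed here):
#   from brian2 import SpikeGeneratorGroup, ms
#   indices, times_ms = to_spike_events(spike_train, dt_ms=1.0)
#   inputs = SpikeGeneratorGroup(spike_train.shape[1], indices, times_ms * ms)
```

Because the binary matrix holds at most one spike per neuron per bin, it never produces duplicate (neuron, time-bin) events.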

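## 🔥 What’s Cool Here?
- Each neuron's firing rate is bounded by a maximum rate (`rate_max_hz`, in Hz), and the neuron spikes probabilistically at its rate.
- The output is a binary time series of shape `[timesteps, neurons]`.
- It enables time-aware, biologically inspired modeling on top of DNA language models such as DNABERT.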
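Under the common assumptions that normalized values $\hat{x}_i \in [0, 1]$ are scaled linearly to the maximum rate and that each time bin is an independent trial (valid while $r_i\,\Delta t \ll 1$), the quantities involved relate as follows, with $r_{\max}$ = `rate_max_hz`, $\Delta t$ = `dt_ms` and $T$ = `duration_ms`, both converted to seconds. These are assumptions about the encoding, not a description of `SpikeEncoder`'s internals.

$$
r_i = r_{\max}\,\hat{x}_i, \qquad
P(\text{spike in one bin}) = r_i\,\Delta t, \qquad
\mathbb{E}[\text{spike count}_i] = r_i\,T, \qquad
\text{timesteps} = \frac{T}{\Delta t}
$$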

## ✅ Output Artifacts
- `data/processed/dnabert_rate_vectors.npy`: normalized firing rates
- `data/processed/spike_train.npy`: Poisson spike train (binary 0/1)
- `outputs/spike_train_preview.png`: spike raster plot visualization

In [None]:
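import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

from src.snn.spike_encoder import SpikeEncoder

# Relative paths below assume the working directory is the project root,
# which is also what makes `src` importable.
Path("data/processed").mkdir(parents=True, exist_ok=True)
Path("outputs").mkdir(parents=True, exist_ok=True)  # the raster plot is saved here

print("✅ Environment ready!")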

## 📥 2. Load Embeddings from DNABERT
These embeddings were generated in Story 1 and saved as a `.npy` file.

In [None]:
embedding_path = "data/processed/fasta_dnabert_embeddings.npy"
embeddings = np.load(embedding_path)

print("✅ Loaded DNABERT embeddings with shape:", embeddings.shape)

## ⚙️ 3. Initialize SpikeEncoder

- `rate_max_hz`: peak neuron firing rate (Hz)
- `duration_ms`: total simulation time (ms)
- `dt_ms`: simulation timestep (ms)
- `stdp_ready`: if `True`, retains the rate history for later use (e.g. STDP learning)
- `seed`: random seed, for reproducible spike trains
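A quick back-of-the-envelope check of what the settings used below imply, assuming the encoder divides `duration_ms` into bins of `dt_ms` and draws one spike/no-spike decision per bin (an assumption about `SpikeEncoder`, not confirmed by this notebook):

```python
rate_max_hz = 120
duration_ms = 500
dt_ms = 1.0

# Number of simulation bins (rows of the spike train)
n_steps = int(round(duration_ms / dt_ms))
# Spike probability per bin for a neuron firing at the peak rate
p_max = rate_max_hz * dt_ms / 1000.0
# Mean spike count over the whole run for a neuron at the peak rate
expected_spikes_max = rate_max_hz * duration_ms / 1000.0

print(n_steps, p_max, expected_spikes_max)  # 500 0.12 60.0
```

So even the most active neuron spikes in only about 12% of bins, which keeps the per-bin approximation reasonable.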

In [None]:
encoder = SpikeEncoder(
    rate_max_hz=120,
    duration_ms=500,
    dt_ms=1.0,
    stdp_ready=False,
    seed=42
)

print("🧠 SpikeEncoder initialized with:")
print(f"   Max Rate: {encoder.rate_max_hz} Hz | Duration: {encoder.duration_ms} ms | dt: {encoder.dt_ms} ms")

## 🔁 4. Normalize and Convert to Firing Rates

In [None]:
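# Normalize the embeddings, convert them to firing rates, and save them
# under output_dir (see Output Artifacts: dnabert_rate_vectors.npy)
firing_rates = encoder.encode_and_save(
    embeddings=embeddings,
    output_dir="data/processed",
    prefix="dnabert"
)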

print("✅ Firing rate matrix shape:", firing_rates.shape)
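The normalization inside `encode_and_save` is not shown in this notebook. One common choice, sketched here as an assumption, is global min-max scaling to [0, 1] followed by multiplication with the maximum rate; `embeddings_to_rates` is a hypothetical helper, and `SpikeEncoder` may normalize differently (per dimension, with a sigmoid, etc.).

```python
import numpy as np


def embeddings_to_rates(embeddings, rate_max_hz=120.0):
    """Hypothetical helper: global min-max normalize, then scale to [0, rate_max_hz] Hz."""
    emb = np.asarray(embeddings, dtype=float)
    lo, hi = emb.min(), emb.max()
    if hi == lo:
        # Constant input carries no information; map everything to 0 Hz
        return np.zeros_like(emb)
    normalized = (emb - lo) / (hi - lo)  # values now lie in [0, 1]
    return normalized * rate_max_hz
```

Global scaling keeps the relative magnitudes across all embedding dimensions and still works when only a single sequence is encoded, where per-dimension scaling would collapse every value to the same number.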

## ⚡ 5. Generate Poisson Spike Train (binary spikes over time)

In [None]:
spike_train = encoder.generate_poisson_spike_train(
    firing_rates=firing_rates,
    save_path="data/processed/spike_train.npy"
)

print("✅ Spike train shape (timesteps, neurons):", spike_train.shape)
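A simple sanity check is to measure each neuron's observed firing rate from the binary spike train and compare it with the requested rate. The hypothetical helper below does the measurement; comparing it against `firing_rates` assumes you select the row of rates that corresponds to the encoded sequence.

```python
import numpy as np


def empirical_rates_hz(spike_train, dt_ms=1.0):
    """Hypothetical helper: observed rate per neuron = spike count / duration (s)."""
    spike_train = np.asarray(spike_train)
    duration_s = spike_train.shape[0] * dt_ms / 1000.0
    return spike_train.sum(axis=0) / duration_s
```

For example, `empirical_rates_hz(spike_train, dt_ms=encoder.dt_ms)` should land close to the target rates, with deviations shrinking as `duration_ms` grows.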

## 🎯 6. Visualize Spike Raster Plot

In [None]:
encoder.plot_raster(
    spike_train,
    save_path="outputs/spike_train_preview.png"
)

print("📊 Raster plot saved for visual inspection.")
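`plot_raster` belongs to the project's `SpikeEncoder`, so its internals are not shown here. As a hedged sketch of what a raster plot involves, the hypothetical helper below draws one tick per spike (x = spike time in ms, y = neuron index) with Matplotlib and optionally saves the figure.

```python
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt


def plot_raster_sketch(spike_train, dt_ms=1.0, max_neurons=50, save_path=None):
    """Hypothetical helper: raster plot of a binary [timesteps, neurons] array."""
    subset = np.asarray(spike_train)[:, :max_neurons]  # limit neurons for readability
    t_idx, n_idx = np.nonzero(subset)                  # coordinates of every spike
    times_ms = t_idx * dt_ms

    fig, ax = plt.subplots(figsize=(10, 4))
    ax.scatter(times_ms, n_idx, marker="|", s=20, color="black")
    ax.set_xlabel("Time (ms)")
    ax.set_ylabel("Neuron index")
    ax.set_title("Spike raster")

    if save_path is not None:
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
    return fig, times_ms, n_idx
```

Usage in this notebook could look like `plot_raster_sketch(spike_train, dt_ms=encoder.dt_ms)`; capping the number of neurons keeps a 768-dimensional input readable.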