In [None]:
from flask import Flask, request, render_template, jsonify, url_for, send_from_directory
from pathlib import Path
import uuid, os, cv2, numpy as np, torch, torch.nn as nn
import snntorch as snn
from snntorch import surrogate
from aedat import Decoder   # pip install aedat

app = Flask(
    __name__,
    static_folder="static",
    template_folder="templates",
)

# Working directories for uploaded recordings and rendered videos. They are
# relative paths, so they are created under the current working directory
# when this cell/module runs.
UPLOAD_FOLDER = Path("uploads")
UPLOAD_FOLDER.mkdir(exist_ok=True)
OUTPUT_FOLDER = Path("outputs")
OUTPUT_FOLDER.mkdir(exist_ok=True)

# ─────────── Config ───────────────────────────────────────────────────────
NUM_BINS = 5                      # temporal bins per voxel grid
SENSOR_W, SENSOR_H = 346, 260     # accepted event x/y range (pixels)
RESIZE_W, RESIZE_H = 160, 120     # model input size after downsampling
WIN_MS = 100.0                    # inference window length in milliseconds
# One output video frame is written per window, so the frame rate is derived
# from WIN_MS; the two can no longer drift apart when the window is changed.
FPS = int(1000 / WIN_MS)
BLINK_THRESH = 0.5                # sigmoid(blink score) above this => "blink"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ─────────── Model ─────────────────────────────────────────────────────────
class SpikingCNN(nn.Module):
    """Spiking 3-D CNN with a blink-score head and a 2-D centre head.

    Three Conv3d -> BatchNorm3d -> Leaky stages (2 -> 32 -> 64 -> 128
    channels) are global-average-pooled to a 128-vector, concatenated with a
    flattened count map, and passed through one 64-unit hidden layer.

    The attribute names below are the keys ``load_state_dict`` matches
    against a saved checkpoint, so they must not be renamed.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: 3x3x3 kernels, padding=1 keeps D/H/W unchanged.
        self.conv1 = nn.Conv3d(2, 32, 3, padding=1)
        self.bn1 = nn.BatchNorm3d(32)
        self.conv2 = nn.Conv3d(32, 64, 3, padding=1)
        self.bn2 = nn.BatchNorm3d(64)
        self.conv3 = nn.Conv3d(64, 128, 3, padding=1)
        self.bn3 = nn.BatchNorm3d(128)
        self.pool = nn.AdaptiveAvgPool3d(1)
        # One Leaky neuron module is reused after every conv stage.
        # NOTE(review): if the installed snntorch keeps membrane state inside
        # the module, confirm that sharing it across stages is intended.
        self.act = snn.Leaky(beta=0.9, spike_grad=surrogate.fast_sigmoid())
        # Input = pooled conv features (128) + one value per downsampled pixel.
        self.fc1 = nn.Linear(128 + RESIZE_W * RESIZE_H, 64)
        self.drop = nn.Dropout(0.4)
        self.blink_head = nn.Linear(64, 1)
        self.centre_head = nn.Linear(64, 2)

    def forward(self, x, cnt):
        """Predict a blink score and a centre point.

        Args:
            x: Conv3d input with 2 channels, expected ``(B, 2, D, H, W)``.
            cnt: per-sample tensor holding ``RESIZE_W * RESIZE_H`` values;
                it is flattened to ``(B, RESIZE_W * RESIZE_H)`` before ``fc1``.

        Returns:
            ``(blink, centre)`` with shapes ``(B, 1)`` and ``(B, 2)``; the
            blink output is an unnormalised score.
        """
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        spikes = x
        for conv, norm in stages:
            # The neuron returns a pair; only the first element feeds forward.
            spikes, _mem = self.act(norm(conv(spikes)))

        batch = spikes.size(0)
        pooled = self.pool(spikes).view(batch, -1)
        counts = cnt.view(batch, -1)
        joint = torch.cat([pooled, counts], dim=1)
        hidden = self.drop(torch.relu(self.fc1(joint)))
        return self.blink_head(hidden), self.centre_head(hidden)

# ─────────── Read AEDAT4 via aedat.Decoder ───────────────────────────────
def _pick_field(ev_arr, names, candidates, index):
    """Return the first field of ``ev_arr`` whose name appears in ``candidates``.

    Raises:
        KeyError: none of ``candidates`` is a field of the structured array.
    """
    for name in candidates:
        if name in names:
            return ev_arr[name]
    wanted = " or ".join(repr(c) for c in candidates)
    raise KeyError(f"Packet #{index} events fields = {names} (need {wanted})")


def _packet_to_events(index, packet):
    """Convert one decoded packet into an ``(n, 4)`` ``[x, y, t_us, polarity]`` array.

    Returns ``None`` for dict packets that have no ``"events"`` entry (e.g.
    frame, IMU or trigger packets from other streams in the same file) so the
    caller can skip them instead of rejecting the whole recording.

    Raises:
        KeyError: an events array lacks a usable x, y, timestamp or polarity field.
    """
    if isinstance(packet, dict):
        ev_arr = packet.get("events")
        if ev_arr is None:
            return None
        names = ev_arr.dtype.names or ()
        xs = _pick_field(ev_arr, names, ("x",), index)
        ys = _pick_field(ev_arr, names, ("y",), index)
        ts = _pick_field(ev_arr, names, ("t", "ts", "timestamp"), index)
        # The `aedat` package exposes polarity as a boolean field named "on".
        ps = _pick_field(ev_arr, names, ("p", "polarity", "on"), index)
    else:
        # Object-style packet with per-field attributes.
        xs, ys = np.asarray(packet.x), np.asarray(packet.y)
        ts, ps = np.asarray(packet.timestamp), np.asarray(packet.polarity)

    # Mixed int widths are promoted to a single int64 array by column_stack.
    return np.column_stack([
        xs.astype(np.int32),
        ys.astype(np.int32),
        ts.astype(np.int64),
        ps.astype(np.int8),
    ])


def read_aedat4(path: Path) -> np.ndarray:
    """Parse an .aedat4 file into an ``(N, 4)`` int64 array ``[x, y, t_us, polarity]``.

    Supports:
      • object-style packets: ``packet.x``, ``packet.y``, ``packet.timestamp``,
        ``packet.polarity``
      • dict-style packets with a nested ``'events'`` structured array whose
        timestamp field is ``t``/``ts``/``timestamp`` and polarity field is
        ``p``/``polarity``/``on``

    Dict packets without ``'events'`` (non-event streams) are skipped. Rows
    keep decoder order (packets are concatenated, not sorted by time).
    Returns an empty ``(0, 4)`` array when the file has no events.

    Raises:
        KeyError: an events array is missing a required field.
    """
    chunks = []
    for i, packet in enumerate(Decoder(str(path))):
        ev = _packet_to_events(i, packet)
        if ev is not None:
            chunks.append(ev)

    if not chunks:
        return np.zeros((0, 4), dtype=np.int64)
    return np.vstack(chunks)



# ─────────── Voxel encoding ────────────────────────────────────────────────
def events_to_voxel(events, t0, t1):
    """Accumulate the events in ``[t0, t1)`` into a downsampled voxel grid.

    Args:
        events: ``(N, 4)`` array with columns ``[x, y, t, polarity]`` (as
            produced by ``read_aedat4``). Polarity is used directly as the
            channel index of a 2-channel grid, so 0/1 values are expected.
        t0, t1: window bounds, in the same units as column 2.

    Returns:
        ``(NUM_BINS, 2, RESIZE_H, RESIZE_W)`` float32 array of per-time-bin,
        per-polarity event counts (capped at 3), area-resampled from the
        ``SENSOR_H x SENSOR_W`` grid.
    """
    n_bins = NUM_BINS
    vox = np.zeros((n_bins, 2, SENSOR_H, SENSOR_W), np.float32)
    ts = events[:, 2]
    sl = events[(ts >= t0) & (ts < t1)]
    if sl.size:
        dt = (t1 - t0) / n_bins
        bins = np.clip(((sl[:, 2] - t0) / dt).astype(int), 0, n_bins - 1)
        xyp = sl[:, [0, 1, 3]].astype(int)
        xs, ys, ps = xyp[:, 0], xyp[:, 1], xyp[:, 2]
        # Ignore events whose coordinates fall outside the sensor grid.
        on_grid = (xs >= 0) & (xs < SENSOR_W) & (ys >= 0) & (ys < SENSOR_H)
        # Vectorised scatter-add instead of a per-event Python loop.
        # np.add.at is unbuffered, so repeated (bin, pol, y, x) indices each add 1.
        np.add.at(vox, (bins[on_grid], ps[on_grid], ys[on_grid], xs[on_grid]), 1)
    # Counts only ever increase from 0, so this just caps each cell at 3.
    vox = np.clip(vox, -3, 3)

    # Downsample every (bin, polarity) plane to the model's input size.
    small = np.zeros((n_bins, 2, RESIZE_H, RESIZE_W), np.float32)
    for t in range(n_bins):
        for c in range(2):
            small[t, c] = cv2.resize(
                vox[t, c],
                (RESIZE_W, RESIZE_H),
                interpolation=cv2.INTER_AREA,
            )
    return small

# ─────────── Load model once ────────────────────────────────────────────────
# The network is built and its trained weights loaded once, at import time,
# so every request handler reuses the same in-memory model.
MODEL_PATH = Path("model/blink_centre_model.pth")
model = SpikingCNN()
model.to(DEVICE)
state = torch.load(str(MODEL_PATH), map_location=DEVICE)
model.load_state_dict(state)
# Inference mode: BatchNorm uses running statistics and Dropout is disabled.
model.eval()

# ─────────── Flask routes ─────────────────────────────────────────────────
@app.route("/")
def index():
    """Serve the upload page (``templates/index.html``)."""
    page = "index.html"
    return render_template(page)

def _infer_window(vox):
    """Run the model on one voxel grid.

    Args:
        vox: ``(NUM_BINS, 2, RESIZE_H, RESIZE_W)`` float32 array from
            ``events_to_voxel``.

    Returns:
        ``(prob, cx, cy, count_map)``: blink probability, predicted centre in
        output-frame pixels, and the ``(RESIZE_H, RESIZE_W)`` absolute-count
        map that was fed to the model as ``cnt``.
    """
    # (T, C, H, W) -> (1, C, T, H, W): batch of one, polarity as Conv3d channels.
    x = torch.from_numpy(vox).permute(1, 0, 2, 3).unsqueeze(0).to(DEVICE)
    # Sum |x| over channel and time -> (1, 1, 1, H, W) per-pixel count map.
    cnt = x.abs().sum(dim=(1, 2), keepdim=True)
    with torch.no_grad():
        blink_logit, centre = model(x, cnt)
        prob = torch.sigmoid(blink_logit).item()
    # Centre outputs are clamped to [0, 1] and scaled to frame pixels.
    cx = int(centre[0, 0].clamp(0, 1).item() * RESIZE_W)
    cy = int(centre[0, 1].clamp(0, 1).item() * RESIZE_H)
    return prob, cx, cy, cnt.squeeze().cpu().numpy()


def _draw_frame(count_map, label, cx, cy):
    """Render a count map as a BGR frame with a centre dot and a text label."""
    # Normalise to 0..255; the epsilon avoids dividing by zero on an empty window.
    cm_img = (count_map / (count_map.max() + 1e-6) * 255).astype(np.uint8)
    frame = cv2.cvtColor(cm_img, cv2.COLOR_GRAY2BGR)
    cv2.circle(frame, (cx, cy), 3, (0, 0, 255), -1)
    cv2.putText(frame, label, (5, 20),
                cv2.FONT_HERSHEY_SIMPLEX, 0.6,
                (255, 255, 255), 1, cv2.LINE_AA)
    return frame


@app.route("/process", methods=["POST"])
def process_file():
    """Run blink/centre inference over an uploaded AEDAT4 recording.

    Expects a multipart upload in the form field ``"aedat"``. Events are split
    into consecutive ``WIN_MS`` windows; each window is voxelised, passed
    through ``model`` and rendered as one annotated frame of an MP4 written to
    ``OUTPUT_FOLDER``. The temporary upload is removed afterwards.

    Returns:
        200 ``{"video_url": ...}`` on success;
        400 ``{"error": ...}`` when no file is sent or the file has no events;
        500 ``{"error": ...}`` for any other failure (traceback is logged).
    """
    upload = request.files.get("aedat")
    if upload is None:
        return jsonify(error="No 'aedat' file in request"), 400

    uid = uuid.uuid4().hex
    in_path = UPLOAD_FOLDER / f"{uid}.aedat4"
    out_path = OUTPUT_FOLDER / f"{uid}.mp4"
    writer = None
    try:
        upload.save(str(in_path))

        events = read_aedat4(in_path)
        if events.shape[0] == 0:
            return jsonify(error="No events found in AEDAT file"), 400

        # read_aedat4 concatenates decoded packets, so sort by timestamp
        # before windowing with searchsorted.
        events = events[np.argsort(events[:, 2], kind="stable")]
        ts = events[:, 2]

        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        writer = cv2.VideoWriter(str(out_path), fourcc, FPS, (RESIZE_W, RESIZE_H))
        if not writer.isOpened():
            raise RuntimeError(f"Could not open video writer for {out_path.name}")

        win_us = int(WIN_MS * 1000)
        t_start, t_end = int(ts[0]), int(ts[-1])
        # Only whole windows are rendered, but always at least one, so a
        # recording shorter than WIN_MS still yields a non-empty video.
        n_win = max(1, (t_end - t_start) // win_us)

        for k in range(n_win):
            t0 = t_start + k * win_us
            t1 = t0 + win_us
            # Pre-slice the sorted events so each window only scans its own
            # events instead of masking the whole recording.
            lo, hi = np.searchsorted(ts, [t0, t1], side="left")
            vox = events_to_voxel(events[lo:hi], t0, t1)

            prob, cx, cy, count_map = _infer_window(vox)
            label = "blink" if prob > BLINK_THRESH else "open"
            writer.write(_draw_frame(count_map, label, cx, cy))

        # Release before handing out the URL so the MP4 is finalised.
        writer.release()
        writer = None
        return jsonify(video_url=url_for("download_file", filename=out_path.name))

    except Exception as e:
        # Top-level request boundary: log the traceback, report the message.
        app.logger.exception("Failed to process upload %s", in_path.name)
        return jsonify(error=str(e)), 500

    finally:
        # Always free the native writer and drop the temporary upload.
        if writer is not None:
            writer.release()
        try:
            in_path.unlink(missing_ok=True)
        except OSError:
            app.logger.warning("Could not remove upload %s", in_path)


@app.route("/download/<filename>")
def download_file(filename):
    """Send a rendered video from ``OUTPUT_FOLDER`` as a file download.

    ``send_from_directory`` resolves ``filename`` inside ``OUTPUT_FOLDER``.
    """
    return send_from_directory(
        OUTPUT_FOLDER,
        filename,
        as_attachment=True,
    )

if __name__ == "__main__":
    # Flask's default host (127.0.0.1, per the server log) with the
    # debugger and reloader disabled.
    app.run(port=24587, debug=False)




 * Serving Flask app '__main__'
 * Debug mode: off


 * Running on http://127.0.0.1:24587
Press CTRL+C to quit
127.0.0.1 - - [23/May/2025 16:27:33] "GET / HTTP/1.1" 200 -
127.0.0.1 - - [23/May/2025 16:27:33] "GET /static/app.js HTTP/1.1" 304 -
127.0.0.1 - - [23/May/2025 16:27:33] "GET /static/style.css HTTP/1.1" 304 -
127.0.0.1 - - [23/May/2025 16:28:34] "POST /process HTTP/1.1" 200 -
127.0.0.1 - - [23/May/2025 16:30:06] "GET /download/41fe2d0adaac4adb810a3594b3b94146.mp4 HTTP/1.1" 200 -
127.0.0.1 - - [24/May/2025 14:58:03] "POST /process HTTP/1.1" 200 -
127.0.0.1 - - [24/May/2025 14:58:20] "GET /download/da0b52a162f94087ba91a4a737a550cd.mp4 HTTP/1.1" 200 -
