# Training a SNN Demapper with Receptive-field Encoding

More details can be found in [E. Arnold et al., “Spiking neural network nonlinear
demapping on neuromorphic hardware for IM/DD optical communication”](https://ieeexplore.ieee.org/abstract/document/10059327/).


In [None]:
import matplotlib.pyplot as plt
from pathlib import Path
import numpy as np
import torch
import norse.torch as norse

from IMDD import PAM4IMDD, IMDDModel, IMDDParams, helpers
from IMDD.snn.encoding import ReceptiveFieldEncoder

## Receptive-field Encoding

One question when training a spiking demapper is how to best translate a chunk of real-valued data into a spiking representation in an efficient way.
The receptive-field encoding translates each sample $y_k$ in the chunk to a set of $K$ spiking neurons.
Each neuron has a `reference_point` assigned and its spike time is determined by the distance of $y_k$ to the given reference value.
This results in a spatio-temporal encoding with $K$ neurons per sample $k$.

We first create an IMDD link to generate data in order to visualize the input encoding:

In [None]:
# Settings of the simulated IM/DD link, collected by keyword and handed to
# IMDDParams in one go.
link_settings = dict(
    N=10000,
    n_taps=7,
    alphabet=torch.tensor([-3., -1., 1., 3.]),
    oversampling_factor=3,
    baudrate=112_000_000_000,
    wavelength=1270e-9,
    dispersion_parameter=-5e-6,
    fiber_length=4000,
    noise_power_gain_db=20.,
    roll_off=0.2,
    bias=2.2,
)
params = IMDDParams(**link_settings)
link = IMDDModel(params)

## Generate some data
# Draw the samples that are sent over the link
samples = link.source()
print("Send symbols:\n", samples, samples.shape)

# Pass them through the link to obtain the received (chunked) data
chunks = link(samples)
print("Received data (chunked):\n", chunks, chunks.shape)

Next we create an instance of the `ReceptiveFieldEncoder`. The encoding is defined by the `references` $\chi_i$, which we choose to be $10$ values equidistantly distributed in $[0, 7]$. The distance $y_k - \chi_i$ is scaled by `scaling`. We neglect spikes which are later than `cutoff`.

In [None]:
# Number of reference points, equally spaced on [0, 7]
K = 10
references = torch.linspace(0, 7, K)
print("References: ", references)

# Time grid shared by the encoding and the SNN
dt = 0.0005
time_length = 30e-3  # s
cutoff = 15e-3  # s

# The encoder
encoder = ReceptiveFieldEncoder(
    references=references,
    scaling=0.008,
    offset=0.0,
    cutoff=cutoff,
    time_length=time_length,
    dt=dt,
    inverse=False,
)
# Bare expression: shows the encoder's repr as the notebook cell output
encoder

Now we encode the first 10 chunks into binary spike tensors. The first dimension of the resulting data is the time axis, which has 60 steps, corresponding to 30 ms. This becomes clearer in the plot below.

In [None]:
# Encode only the first ten samples to keep the example small
first_chunks = chunks[:10]
spikes = encoder(first_chunks)
print("Spikes:\n", spikes, spikes.shape)

In [None]:
# Convert time-dense spikes into event-based spikes for the first chunk.
# Each row of `events` is (time step, input neuron index); this assumes the
# encoder output is laid out as (time, batch, input neuron).
events = torch.nonzero(spikes[:, 0])

# Total number of input neurons: K receptive-field neurons per tap
n_inputs = params.n_taps * K

fig, axes = plt.subplots(1)
axes.set_xlim(0, time_length)
axes.set_ylim(0, n_inputs)
# Spike times are plotted as `step * dt`; dt, time_length and cutoff are all
# given in seconds, so the axis is labelled in seconds as well.
axes.set_xlabel("$t$ [s]")
axes.set_ylabel("input neuron $i$")
axes.scatter(events[:, 0] * dt, events[:, 1], s=5, color="black")
# Spikes later than `cutoff` are dropped by the encoder
axes.vlines(cutoff, 0, n_inputs, color="blue", ls=":")
# Separate the K-neuron groups that belong to the individual taps
for tap in range(params.n_taps):
    axes.hlines(K * (tap + 1), 0, cutoff, color="grey", lw=0.5)
plt.show()

# Model

In [None]:
class SNNDemapper(torch.nn.Module):
    """Spiking demapper: encoder -> linear -> LIF -> linear -> LI readout.

    The input is turned into a spike train by ``encoder``, projected by a
    bias-free linear layer onto a hidden LIF layer and projected again by a
    second bias-free linear layer onto an LI readout. The score of each
    output is the maximum of its readout trace over time.

    Every ``forward`` call caches the following tensors on the module (time is
    the leading dimension); ``regularize`` reads ``spikes`` and ``traces``:

    * ``spikes``: output of the hidden LIF layer,
    * ``traces``: output of the LI readout,
    * ``v_lif``: the ``v`` field of the hidden-layer state,
    * ``score``: maximum of ``traces`` over time.
    """

    def __init__(self,
                 n_in: int,
                 n_hidden: int,
                 n_out: int,
                 encoder: torch.nn.Module,
                 lif_params: "norse.LIFParameters",
                 li_params: "norse.LIParameters",
                 dt: float,
                 device):
        """Build the two-layer network.

        Args:
            n_in: Number of input neurons, i.e. size of the encoder output.
            n_hidden: Number of hidden LIF neurons.
            n_out: Number of readout neurons.
            encoder: Module translating the real-valued input into spikes
                with time as leading dimension.
            lif_params: Parameters of the hidden LIF layer.
            li_params: Parameters of the LI readout layer.
            dt: Time resolution; stored on the module.
            device: Device the linear layers are created on and on which the
                regularization term is accumulated.
        """
        super().__init__()
        self.device = device

        # NOTE(review): dt is only stored here; it is not forwarded to
        # LIFCell / LICell, which therefore integrate with their own default
        # time step. Confirm that this is intended.
        self.dt = dt

        # Regularization strengths and the target used by `regularize`
        self.reg_bursts = 0.0005
        self.reg_weight_1 = 0.0001
        self.reg_weight_2 = 0.0001
        self.reg_readout = 0.0
        self.target_rate = 0.5

        # Encoding symbols to spikes
        self.encoder = encoder

        # SNN; `bias` is a boolean flag of torch.nn.Linear, so disable it
        # explicitly with False.
        self.linear_1 = torch.nn.Linear(
            n_in, n_hidden, bias=False, device=device)
        self.lif = norse.LIFCell(lif_params)
        self.linear_2 = torch.nn.Linear(
            n_hidden, n_out, bias=False, device=device)
        self.li = norse.LICell(li_params)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Encode ``input``, run the SNN step by step and return the scores.

        Args:
            input: Real-valued data handed to the encoder.

        Returns:
            Maximum over time of the readout traces. The same tensor is
            cached as ``self.score``.
        """
        zi = self.encoder(input).float()

        # Passing None as state lets the cells create their initial state
        s_lif, s_li = None, None
        zs, ys, vs = [], [], []
        for z_in in zi:  # iterate over the leading (time) dimension
            z, s_lif = self.lif(self.linear_1(z_in), s_lif)
            y, s_li = self.li(self.linear_2(z), s_li)

            zs.append(z)
            ys.append(y)
            # Only the `v` field of the hidden state is needed afterwards
            vs.append(s_lif.v)

        self.spikes = torch.stack(zs)
        self.traces = torch.stack(ys)
        self.v_lif = torch.stack(vs)

        self.score = torch.amax(self.traces, 0)

        return self.score

    def regularize(self) -> torch.Tensor:
        """Return the regularization term for the most recent forward pass.

        The term is the sum of the weighted mean squared weights of both
        linear layers, a penalty on the deviation of the per-neuron spike
        count (summed over time, averaged over the second dimension) from
        ``target_rate``, and a penalty on the squared maximum readout trace.

        ``forward`` has to be called first because the cached ``spikes`` and
        ``traces`` are read here.

        Returns:
            Scalar tensor on ``self.device``.
        """
        reg = torch.zeros((), device=self.device)
        # Regularize linear weights
        reg = reg + self.reg_weight_1 * torch.mean(self.linear_1.weight ** 2)
        reg = reg + self.reg_weight_2 * torch.mean(self.linear_2.weight ** 2)
        # Regularize firing rates
        rate_error = (self.target_rate - self.spikes.sum(0)).mean(0)
        reg = reg + self.reg_bursts * (rate_error ** 2).mean()
        # Regularize readout traces
        max_trace = torch.max(self.traces, 0).values
        reg = reg + self.reg_readout * torch.mean(max_trace ** 2)
        return reg

# Training

In [None]:

def stats(loss: torch.Tensor, pred: torch.Tensor, data: torch.Tensor):
    """Return bit error rate, accuracy and number of wrong decisions.

    Args:
        loss: Loss of the batch. Accepted but not used.
        pred: Per-class scores; the decision is the argmax along dim 1.
        data: Reference values the decisions are compared against.

    Returns:
        Tuple ``(ber, acc, count)``. ``ber`` and ``acc`` are whatever
        ``helpers.bit_error_rate`` and ``helpers.accuracy`` return; ``count``
        is a tensor with the number of decisions that differ from ``data``.
    """
    decisions = pred.argmax(dim=1)
    n_wrong = torch.count_nonzero(decisions != data)
    return (
        helpers.bit_error_rate(data, pred, False),
        helpers.accuracy(data, pred, False),
        n_wrong,
    )


def train(dataloader, optimizer, scheduler, loss_fn, demapper, device):
    """Train ``demapper`` for one pass over ``dataloader``.

    Args:
        dataloader: Iterable yielding ``(data, target)`` batches.
        optimizer: Optimizer updating the demapper's parameters.
        scheduler: Learning-rate scheduler; stepped once per call.
        loss_fn: Callable ``loss_fn(prediction, target)``.
        demapper: Model providing ``__call__`` and ``regularize``.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(loss, accuracy, ber)`` averaged over all batches.
    """
    batch_losses, batch_accs, batch_bers = [], [], []

    for data, target in dataloader:
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()

        pred_b = demapper(data)
        # Task loss plus the demapper's own regularization terms
        loss_b = loss_fn(pred_b, target)
        loss_b += demapper.regularize()

        # Optimize
        loss_b.backward()
        optimizer.step()

        # Accumulate the per-batch statistics
        ber_b, acc_b, _ = stats(loss_b, pred_b, target)
        batch_losses.append(loss_b.detach())
        batch_accs.append(acc_b)
        batch_bers.append(ber_b)

    # The learning-rate schedule advances once per pass, not once per batch
    scheduler.step()

    mean_loss = torch.stack(batch_losses).reshape(-1).mean()
    mean_acc = np.stack(batch_accs).reshape(-1).mean()
    mean_ber = np.stack(batch_bers).reshape(-1).mean()
    return mean_loss, mean_acc, mean_ber


def test(dataloader, demapper, loss_fn, device, min_false_symbols, max_test_epochs):
    """Evaluate ``demapper`` until enough symbol errors have been observed.

    The dataloader is iterated completely up to ``max_test_epochs`` times.
    After each complete pass the loop stops as soon as at least
    ``min_false_symbols`` wrong decisions have been counted.

    Args:
        dataloader: Iterable yielding ``(data, target)`` batches.
        demapper: Model providing ``__call__`` and ``regularize``.
        loss_fn: Callable ``loss_fn(prediction, target)``.
        device: Device the batches are moved to.
        min_false_symbols: Number of wrong decisions after which testing stops.
        max_test_epochs: Upper bound on the number of passes over the data.

    Returns:
        Tuple ``(loss, accuracy, ber, n_false)``: the first three are averaged
        over all evaluated batches, ``n_false`` is the accumulated number of
        wrong decisions.
    """
    loss, acc, ber, n_false = [], [], [], 0

    # Nothing is back-propagated here, so autograd is disabled to avoid
    # building and holding a graph for every evaluated batch.
    with torch.no_grad():
        for _ in range(max_test_epochs):
            for data, target in dataloader:
                data = data.to(device)
                target = target.to(device)

                pred_b = demapper(data)
                loss_b = loss_fn(pred_b, target)
                loss_b += demapper.regularize()

                ber_b, acc_b, count = stats(loss_b, pred_b, target)

                loss.append(loss_b.detach())
                acc.append(acc_b)
                ber.append(ber_b)

                n_false += count

            # Checked only after a complete pass over the dataloader
            if n_false >= min_false_symbols:
                break

    return (torch.stack(loss).reshape(-1).mean(),
            np.stack(acc).reshape(-1).mean(),
            np.stack(ber).reshape(-1).mean(), n_false)

In [None]:
# The device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device: ", device)

# Training parameters
batch_size_train = 100
batch_size_val = 10000
lr = 0.001
epochs = 100
val_interval = 10  # validate (and possibly checkpoint) every 10th epoch
min_false_symbols = 1000
max_test_epochs = 100

torch.manual_seed(0)
np.random.seed(0)

lif_demapper = SNNDemapper(
    n_in=params.n_taps * K,  # n_taps * n_reference_points
    n_hidden=40,
    n_out=4,  # len(alphabet)
    encoder=encoder,
    lif_params=norse.LIFParameters(
        tau_mem_inv=1/6e-3,
        tau_syn_inv=1/6e-3,
        v_leak=0.,
        v_reset=0.,
        v_th=1.),
    li_params=norse.LIParameters(
        tau_mem_inv=torch.tensor(1/6e-3).to(device),
        tau_syn_inv=torch.tensor(1/6e-3).to(device),
        # Kept on the same device as the other readout parameters
        v_leak=torch.tensor(0.).to(device)),
    dt=dt,
    device=device)

# Dataset
dataset = PAM4IMDD(params)

# Dataloader; training and validation draw from the same dataset object
train_loader = torch.utils.data.DataLoader(
    dataset, batch_size_train, shuffle=True)
val_loader = torch.utils.data.DataLoader(
    dataset, batch_size_val, shuffle=True)

# Loss function
loss_fn = torch.nn.CrossEntropyLoss()

# The SNRs we train the demapper for, from high to low
snrs = torch.flip(torch.arange(15., 24., 1.), dims=(0,))
snrs[0] = 30.  # We train the first demapper with only a little noise

# Validation data: one row per SNR and validation run, with the columns
# (loss, ber, acc, n_false)
val_datas = torch.zeros((snrs.shape[0], epochs // val_interval, 4))

model_dir = Path("./models")
model_dir.mkdir(exist_ok=True)

for i, snr in enumerate(snrs):
    print(f"SNR: {snr.item()}")
    # update SNR in Dataset
    dataset.simulator.params.noise_power_gain_db = snr.item()

    # New optimizer and scheduler for every SNR; the demapper itself is
    # carried over from the previous SNR.
    optimizer = torch.optim.Adam(lif_demapper.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=10, gamma=0.9)

    # train for SNR
    best_val_ber = np.inf
    for epoch in range(epochs):
        train_loss, train_acc, train_ber = train(
            train_loader, optimizer, scheduler, loss_fn, lif_demapper, device)
        print(f"Epoch={epoch}, loss={train_loss:.4f}, ber={train_ber:.7f}, "
              f"acc={train_acc:.4f}")

        if (epoch + 1) % val_interval != 0:
            continue

        val_loss, val_acc, val_ber, n_false = test(
            val_loader, lif_demapper, loss_fn, device, min_false_symbols,
            max_test_epochs)
        row = epoch // val_interval
        val_datas[i, row, 0] = val_loss
        val_datas[i, row, 1] = val_ber
        val_datas[i, row, 2] = val_acc
        val_datas[i, row, 3] = n_false

        # Save best Demapper
        if val_ber < best_val_ber:
            torch.save(
                lif_demapper.state_dict(), model_dir / f"snr_{int(snr)}.pt")
            best_val_ber = val_ber

        print(f"Epoch={epoch}, val_loss={val_loss:.4f}, "
              f"val_ber={val_ber:.7f}, val_acc={val_acc:.4f}, "
              f"n_false={n_false}")

    # Save data after every SNR so that partial results survive an abort
    np.save("snrs.npy", val_datas.numpy())


In [None]:
# Test demapper on independent data

torch.manual_seed(42)
np.random.seed(42)

min_false_symbols = 2000

# Dataset and loader
dataset = PAM4IMDD(params)
test_loader = torch.utils.data.DataLoader(
    dataset, batch_size_val, shuffle=False)

# One row per tested SNR with the columns (snr, ber, n_false)
bers = np.zeros((snrs.shape[0] - 1, 3))

# The first entry of `snrs` is the low-noise value used to start training;
# it is not part of the test sweep.
for s, snr in enumerate(snrs[1:]):
    snr_db = snr.item()

    # Set SNR in dataset
    dataset.simulator.params.noise_power_gain_db = snr_db

    # Load best model for current SNR; map_location lets a checkpoint that was
    # written on another device be loaded on the current one.
    state_dict = torch.load(
        f"./models/snr_{int(snr)}.pt", map_location=device)
    lif_demapper.load_state_dict(state_dict)

    ber, n_false = [], 0

    # Pure evaluation: no autograd graph needed
    with torch.no_grad():
        # NOTE(review): this loop has no upper bound and does not terminate
        # if the demapper never makes min_false_symbols errors.
        while True:
            for data, target in test_loader:
                data = data.to(device)
                target = target.to(device)

                pred_b = lif_demapper(data)

                ber_b = helpers.bit_error_rate(target, pred_b, False)
                ber.append(ber_b)

                # Count as a Python int so the total can be stored in the
                # NumPy result array regardless of the device in use.
                n_false += torch.count_nonzero(
                    torch.argmax(pred_b, 1) != target).item()

            if n_false >= min_false_symbols:
                break

    bers[s, 0] = snr_db
    bers[s, 1] = np.stack(ber).reshape(-1).mean()
    bers[s, 2] = n_false

    print(f"Tested Demapper for {snr_db}. BER = {bers[s, 1]}, n_false = {n_false}")

np.save("test_bers.npy", bers)


In [None]:
# Plot first sample in last batch.
# Each row of `events` is (time step, hidden neuron index) of one spike.
events = torch.nonzero(lif_demapper.spikes.detach().cpu()[:, 0]).numpy()
# Number of simulated time steps (leading dimension of the cached spikes)
n_steps = lif_demapper.spikes.shape[0]

_, axs = plt.subplots(nrows=3)
for ax in axs:
    ax.set_xlim(0, n_steps)
# Hidden-layer spikes: the neuron index comes from `events`, not from the
# encoder output `spikes` of the earlier cell.
axs[0].scatter(events[:, 0], events[:, 1])
# `v` of the hidden LIF layer over time
axs[1].plot(lif_demapper.v_lif.detach().cpu().numpy()[:, 0])
# Readout traces over time
axs[2].plot(lif_demapper.traces.detach().cpu().numpy()[:, 0])

## Plot BER-SNR curve

In [None]:
# Columns of the saved test results: (snr, ber, n_false)
data = np.load("./test_bers.npy")

import pickle

# NOTE(review): pickle.load can execute arbitrary code while loading; only
# open reference files that come from a trusted source.
with open("../huawei/results/data/imdd/ecoc_jlt/ber_snr_references/1tap_linear.pkl", "rb") as file:
    data_linear_1tap = pickle.load(file)

with open("../huawei/results/data/imdd/ecoc_jlt/ber_snr_references/7tap_linear.pkl", "rb") as file:
    data_linear_7tap = np.array(pickle.load(file))

with open("../huawei/results/data/imdd/ecoc_jlt/hw/snr_sweep_additive/n_hidden_40/n_taps_7/test_bers.pkl", "rb") as file:
    data_snn_bss_7tap = pickle.load(file)

print(data_snn_bss_7tap)


print(data)

# Use a dedicated name for the matplotlib settings so the link parameters
# stored in `params` stay available when earlier cells are re-run.
rc_params = {'text.usetex' : True,
             'font.size' : 8,
             }
plt.rcParams.update(rc_params)

color = ["#FAC90F", "#FA8D0F", "#0F69FA", "#7A6F45"]

fig, axs = plt.subplots(figsize=(3.5, 2.7))
axs.set_ylabel("BER")
# Raw string with an ASCII minus: avoids invalid escape sequences and a
# Unicode minus sign that LaTeX rendering may not accept.
axs.set_xlabel(r"$-10 \log_{10}(\sigma^2)$ [dB]")
axs.set_yscale("log")
axs.set_ylim(1e-4, 3e-2)
axs.set_xlim(14.5, 22.5)
axs.set_xticks(data[:, 0])
axs.plot(data[:, 0], data[:, 1], lw=1, color=color[0], label=r"7 tap SNN")
# The first column of the linear reference is negated before plotting and
# its first row is skipped.
axs.plot(-data_linear_7tap[1:, 0], data_linear_7tap[1:, 1], lw=1, color=color[1], label=r"7 tap, linear \cite{arnold2023spiking}")
axs.plot(data_snn_bss_7tap["snrs"], data_snn_bss_7tap["bers"], lw=1, color=color[2], label=r"7 tap SNN, BSS-2 \cite{arnold2023spiking}")
axs.scatter(-data_linear_7tap[1:, 0], data_linear_7tap[1:, 1], s=10, color=color[1])
axs.scatter(data_snn_bss_7tap["snrs"], data_snn_bss_7tap["bers"], s=10, color=color[2])
axs.scatter(data[:, 0], data[:, 1], color=color[0], s=10)
axs.hlines(2e-3, 14.5, 22.5, color="grey", ls="--", label="KP4 FEC Threshold")
axs.grid(which="minor", lw=0.5)
axs.grid(which="major", lw=0.7)
axs.yaxis.set_label_coords(-0.1, 0.6)
axs.legend(fontsize=9, handlelength=0.5)
plt.tight_layout()
plt.savefig("./snr.pgf")