# LCD-Task: Training an SNN Demapper

The SNN demapper used in this example is based on [1].


### References

[1] [E. Arnold et al., “Spiking neural network nonlinear
demapping on neuromorphic hardware for IM/DD optical communication”](https://ieeexplore.ieee.org/abstract/document/10059327/).


In [None]:
from pathlib import Path
from typing import Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
import norse.torch as norse

from IMDD import LCDDataset, helpers

## Receptive-field Encoding

One question when training a spiking demapper is how to best translate a chunk of real-valued data into a spiking representation in an efficient way.
The receptive-field encoding maps each sample $y[k]$ in the chunk to a set of $P$ spiking neurons.
Each neuron has a `reference_point` assigned and its spike time is determined by the distance of $y[k]$ to the given reference value.
This results in a spatio-temporal encoding with $P$ neurons per sample $k$, i.e. $n_\text{taps} \cdot P$ input neurons, as detailed in [1].

We first implement the encoder as defined in [1]:

In [None]:
class ReceptiveFieldEncoder(torch.nn.Module):
    """Encode real-valued samples into spike times via receptive fields.

    Every input value is compared with each entry of ``references``. The
    absolute distance, multiplied by ``scaling``, is the spike time of the
    corresponding input neuron. Spikes later than ``cutoff`` are dropped, and
    ``offset`` is then added to all remaining spike times.

    Args:
        scaling: Factor converting a distance into a spike time.
        offset: Constant time shift added to every spike time.
        time_length: Length of the encoded time window (same unit as ``dt``).
        dt: Width of one time bin.
        references: Reference points, e.g. a 1-D tensor with one entry per
            input neuron of a sample.
        cutoff: Latest (un-shifted) spike time that is kept. Defaults to
            ``time_length``.
    """

    def __init__(self, scaling: float, offset: float, time_length: float,
                 dt: float, references: torch.Tensor,
                 cutoff: Optional[float] = None):
        super().__init__()
        self.scaling = scaling
        self.offset = offset
        self.time_length = time_length
        self.dt = dt
        self.time_steps = int(time_length // dt) + 1
        self.references = references
        self.cutoff = cutoff if cutoff is not None else time_length

    def forward(self, trace: torch.Tensor) -> torch.Tensor:
        """Encode a batch of traces into a dense binary spike tensor.

        Args:
            trace: Tensor whose first axis is the batch axis.

        Returns:
            Integer tensor of shape ``(time_steps, batch, n_neurons)``, where
            all non-batch axes of ``trace`` and the reference axis are
            flattened into ``n_neurons``. It holds a 1 at the time bin of
            every kept spike and 0 elsewhere.
        """
        dev = trace.device

        # Spike time of every (value, reference) pair, flattened per batch
        # entry so that the neurons of one value are adjacent.
        times = self.scaling * torch.abs(
            trace.unsqueeze(-1) - self.references.to(dev)).reshape(
                trace.shape[0], -1)

        # Move discarded spikes beyond the window; with a non-negative offset
        # the mask below then drops them.
        times[(times < 0) | (times > self.cutoff)] = self.time_length + self.dt
        times += self.offset

        # Note the "+1": a spike at t=0 lands in bin 1, bin 0 stays empty.
        bins = (times / self.dt + 1).long()
        mask = bins < self.time_steps
        # "ij" is torch's historical default; passing it explicitly avoids
        # the meshgrid deprecation warning.
        mesh = torch.meshgrid(
            *[torch.arange(s, device=dev) for s in times.shape],
            indexing="ij")

        # COO indices: (time bin, batch index, neuron index) of each spike
        indices = torch.stack(
            (bins[mask].reshape(-1),
             *(m[mask].reshape(-1) for m in mesh)))

        spikes = torch.sparse_coo_tensor(
            indices, torch.ones(indices.shape[1], device=dev),
            (self.time_steps, *times.shape), dtype=int)

        return spikes.to_dense()

Next we create the dataset and visualize the input encoding.

In [None]:
# Build the dataset and draw one example: a chunk of received samples
# together with the index q of its symbol
dataset = LCDDataset()
y_chunk, q = dataset[45]

# Inspect the example and the bits that belong to symbol index q
print("Received symbols (chunked):\n", y_chunk, y_chunk.shape)
print("Corresponding index q:\n", q, q.shape)
print("Corresponding send bits:\n",
      helpers.get_graylabel(2)[q])

Next we create an instance of the `ReceptiveFieldEncoder`. The encoding is defined by the `references` $\chi_i$, which we choose to be $P = 10$ values equidistantly distributed in $[0, 7]$. The distance $|y[k] - \chi_i|$ is multiplied by `scaling` to obtain the spike time of neuron $i$. We neglect spikes that occur later than `cutoff`.

In [None]:
# P reference points, spread equidistantly over [0, 7]
P = 10
references = torch.linspace(0, 7, P)
print("References: ", references)

# Time grid shared by the encoder and the SNN (seconds)
dt = 5e-4
time_length = 0.03
cutoff = 0.015

# Receptive-field encoder; arguments in constructor order
encoder = ReceptiveFieldEncoder(
    0.008,        # scaling
    0.0,          # offset
    time_length,
    dt,
    references,
    cutoff=cutoff,
)
encoder

Now we encode `y_chunk` into a binary spike tensor. The first dimension of the result is the time axis; it has 60 steps of `dt` = 0.5 ms, i.e. 30 ms in total. This becomes clearer in the plot below.

In [None]:
# Encode the single chunk; indexing with None prepends a batch axis of size 1
spikes = encoder(y_chunk[None])
print("Spikes:\n", spikes, spikes.shape)

In [None]:
# Convert time-dense spikes into event-based spikes
events = torch.nonzero(spikes[:, 0])
time = np.linspace(0, time_length, int(time_length / dt))

fig, axes = plt.subplots(1)
axes.set_xlim(0, time_length)
axes.set_ylim(0, dataset.simulator.params.n_taps * P)
axes.set_xlabel("$t$ [ms]")
axes.set_ylabel("input neuron $i$")
axes.scatter(events[:, 0] * dt, events[:, 1], s=5, color="black")
axes.vlines(cutoff, 0, dataset.simulator.params.n_taps * P, color="blue", ls=":")
for i in range(7):
    axes.hlines(10 * (i + 1), 0, 0.015, color="grey", lw=0.5)
plt.show()

# Model

Now we define an SNN which we train to solve the demapping task.

In [None]:
class SNNDemapper(torch.nn.Module):
    """Spiking demapper: encoder -> Linear -> LIF -> Linear -> LI readout.

    The score of each output neuron is the maximum of its LI membrane trace
    over all time steps.

    Args:
        n_in: Number of input neurons; must match the width of the encoder
            output fed into the first linear layer.
        n_hidden: Number of LIF neurons in the hidden layer.
        n_out: Number of readout (LI) neurons.
        encoder: Module turning a batch of samples into a time-major spike
            tensor.
        lif_params: Parameters of the hidden LIF neurons.
        li_params: Parameters of the leaky-integrator readout neurons.
        dt: Integration time step. It is forwarded to the LIF and LI cells so
            that the neurons run on the same time grid as the encoder.
        device: Device on which the linear layers are allocated.
    """

    def __init__(self,
                 n_in: int,
                 n_hidden: int,
                 n_out: int,
                 encoder: torch.nn.Module,
                 lif_params: norse.LIFParameters,
                 li_params: norse.LIParameters,
                 dt: float,
                 device):
        super().__init__()
        self.device = device

        self.dt = dt
        # Regularization strengths (used in `regularize`)
        self.reg_bursts = 0.0005
        self.reg_weight_1 = 0.0001
        self.reg_weight_2 = 0.0001
        self.reg_readout = 0.0
        # Target mean spike count per hidden neuron and sample
        self.target_rate = 0.5

        # Encoding symbols to spikes
        self.encoder = encoder

        # SNN. Without an explicit `dt`, norse cells use their default step
        # (0.001), which would not match the encoder's time bins.
        self.linear_1 = torch.nn.Linear(
            n_in, n_hidden, device=device, bias=False)
        self.lif = norse.LIFCell(lif_params, dt=dt)
        self.linear_2 = torch.nn.Linear(
            n_hidden, n_out, device=device, bias=False)
        self.li = norse.LICell(li_params, dt=dt)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Encode ``input``, simulate the SNN and return the readout scores.

        Side effects: stores the input spikes (``self.zi``), hidden spikes
        (``self.spikes``), readout traces (``self.traces``) and hidden
        membrane voltages (``self.v_lif``), all time-major, for `rate` and
        `regularize`.

        Returns:
            Maximum over time of the readout traces (``self.score``).
        """
        # Time-major input spikes; presumably (T, batch, n_in) - TODO confirm
        self.zi = self.encoder(input).float()

        s_lif, s_li = None, None
        zs, ys, vs = [], [], []
        for zi_t in self.zi:
            z, s_lif = self.lif(self.linear_1(zi_t), s_lif)
            y, s_li = self.li(self.linear_2(z), s_li)

            zs.append(z)
            ys.append(y)
            vs.append(s_lif.v)

        self.spikes = torch.stack(zs)
        self.traces = torch.stack(ys)
        self.v_lif = torch.stack(vs)

        self.score = torch.amax(self.traces, 0)

        return self.score

    @property
    def rate(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Spike counts per sample of the last forward pass.

        Returns:
            ``(input spikes, hidden spikes)``, each summed over time and
            neurons, one value per batch entry.
        """
        return self.zi.sum(0).sum(1), self.spikes.sum(0).sum(1)

    def regularize(self) -> torch.Tensor:
        """ Regularization terms for demapper

        Sum of an L2 penalty on both weight matrices, a penalty pulling the
        batch-mean spike count of each hidden neuron towards
        ``target_rate`` and an (by default disabled) penalty on the peak
        readout traces. Uses the state of the last forward pass.
        """
        reg = torch.tensor(0.).to(self.device)
        # Regularize linear weights
        reg += self.reg_weight_1 * torch.mean(self.linear_1.weight ** 2)
        reg += self.reg_weight_2 * torch.mean(self.linear_2.weight ** 2)
        # Regularize firing rates
        reg += self.reg_bursts * (
            (self.target_rate - self.spikes.sum(0)).mean(0) ** 2).mean()
        # Regularize readout traces
        reg += self.reg_readout * torch.mean(torch.max(self.traces, 0)[0] ** 2)
        return reg

# Training

In [None]:
# Some functions for training and testing

def stats(loss: torch.Tensor, pred: torch.Tensor, data: torch.Tensor):
    """Return ``(bit error rate, accuracy, number of wrong symbols)``.

    ``pred`` holds one score per class along axis 1 and ``data`` the target
    symbol indices. ``loss`` is part of the signature but not used.
    """
    bit_errors = helpers.bit_error_rate(data, pred, False)
    correct_share = helpers.accuracy(data, pred, False)
    n_wrong = torch.count_nonzero(pred.argmax(1) != data)
    return bit_errors, correct_share, n_wrong


def train(dataloader, optimizer, scheduler, loss_fn, demapper, device):
    """Run one training epoch.

    Every batch is optimized on the task loss plus the demapper's own
    regularization. The learning-rate scheduler is stepped once at the end
    of the epoch.

    Returns:
        Tuple of the epoch means ``(loss, accuracy, BER)``.
    """
    losses, accuracies, bers = [], [], []

    for inputs, labels in dataloader:
        optimizer.zero_grad()

        inputs = inputs.to(device)
        labels = labels.to(device)

        scores = demapper(inputs)
        batch_loss = loss_fn(scores, labels)
        batch_loss += demapper.regularize()

        batch_loss.backward()
        optimizer.step()

        batch_ber, batch_acc, _ = stats(batch_loss, scores, labels)

        losses.append(batch_loss.detach())
        accuracies.append(batch_acc)
        bers.append(batch_ber)

    scheduler.step()

    mean_loss = torch.stack(losses).reshape(-1).mean()
    mean_acc = np.stack(accuracies).reshape(-1).mean()
    mean_ber = np.stack(bers).reshape(-1).mean()
    return mean_loss, mean_acc, mean_ber


def test(dataloader, demapper, loss_fn, device, min_false_symbols, max_test_epochs):
    """Evaluate the demapper until enough symbol errors have been observed.

    Passes over ``dataloader`` are repeated until at least
    ``min_false_symbols`` wrongly classified symbols have been counted, or
    ``max_test_epochs`` passes are done. No gradients are recorded.

    Returns:
        Tuple ``(mean loss, mean accuracy, mean BER, n_false)`` where
        ``n_false`` is the number of wrongly classified symbols.
    """
    loss, acc, ber, n_false = [], [], [], 0

    # Evaluation only: without no_grad an autograd graph would be built for
    # every (large) validation batch.
    with torch.no_grad():
        for _ in range(max_test_epochs):
            for data, target in dataloader:

                data = data.to(device)
                target = target.to(device)

                pred_b = demapper(data)
                loss_b = loss_fn(pred_b, target)
                loss_b += demapper.regularize()

                ber_b, acc_b, count = stats(loss_b, pred_b, target)

                loss.append(loss_b.detach())
                acc.append(acc_b)
                ber.append(ber_b)

                n_false += count

            if n_false >= min_false_symbols:
                break

    return (torch.stack(loss).reshape(-1).mean(),
            np.stack(acc).reshape(-1).mean(),
            np.stack(ber).reshape(-1).mean(), n_false)

In [None]:
# The device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device: ", device)

# Training parameters
batch_size_train = 100 
batch_size_val = 10000 
lr = 0.001
epochs = 100
max_test_epochs = 100
min_false_symbols = 2000
min_false_bits = 2000

# Loss function
loss_fn = torch.nn.CrossEntropyLoss()

# Bits
gray_bits = torch.tensor(helpers.get_graylabel(2))

## Train SNN demapper SNR sweep

First we train the SNN demapper for a range of noise powers $\sigma_\text{n}^2$.
We start training with little noise and increase in 1 dB steps.
For each noise level we save the model which performs best on the validation data and test it later against unseen data.
Note that we continue the training of the SNN across different noise levels.
A demapper could also be trained for only one noise level and then be tested across different SNRs.

In [None]:
# Fixe seed
torch.manual_seed(0)
np.random.seed(0)

# We create folder to save the models and training data at
data_dir = Path("../results2/snr")
data_dir.mkdir(exist_ok=True, parents=True)
model_dir = data_dir / "models"
model_dir.mkdir(exist_ok=True, parents=True)

# Model
lif_demapper = SNNDemapper(
    n_in=70,  # n_taps * n_reference_points
    n_hidden=40,
    n_out=4,  # len(alphabet)
    encoder=encoder,
    lif_params=norse.LIFParameters(
        tau_mem_inv=1/6e-3,
        tau_syn_inv=1/6e-3,
        v_leak=0.,
        v_reset=0.,
        v_th=1.),
    li_params=norse.LIParameters(
        tau_mem_inv=torch.tensor(1/6e-3).to(device),
        tau_syn_inv=torch.tensor(1/6e-3).to(device),
        v_leak=torch.tensor(0.)),
    dt=dt,
    device=device)

# Dataset
dataset = LCDDataset()

# Dataloader
train_loader = torch.utils.data.DataLoader(
    dataset, batch_size_train, shuffle=True)
val_loader = torch.utils.data.DataLoader(
    dataset, batch_size_val, shuffle=True)

# The SNRs we train the demapper for
snrs = torch.flip(torch.arange(15., 24., 1.), dims=(0,))
snrs[0] = 30. # We train the first demapper with only a little noise

# Validation data
val_datas = torch.zeros((snrs.shape[0], epochs // 10, 4))

# Sweep SNR
for i, snr in enumerate(snrs):
    print(f"SNR: {snr.item()}")
    # update SNR in Dataset
    dataset.set_noise_power_db(-snr.item())

    # New scheduler
    optimizer = torch.optim.Adam(lif_demapper.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=10, gamma=0.9)

    # Train for SNR
    val_data = torch.zeros((epochs // 10, 4))
    best_val_ber = np.inf
    for epoch in range(epochs):
        train_loss, train_acc, train_ber = train(
            train_loader, optimizer, scheduler, loss_fn, lif_demapper, device)
        if (epoch + 1) % 10 == 0 and epoch > 0:
            val_loss, val_acc, val_ber, n_false = test(
                val_loader, lif_demapper, loss_fn, device, min_false_symbols,
                max_test_epochs)
            val_datas[i, epoch // 10, 0] = val_loss
            val_datas[i, epoch // 10, 1] = val_ber
            val_datas[i, epoch // 10, 2] = val_acc
            val_datas[i, epoch // 10, 3] = n_false

            # Save best Demapper
            if val_ber < best_val_ber:
                torch.save(
                    lif_demapper.state_dict(), model_dir / f"snr_{int(snr)}.pt")
                best_val_ber = val_ber

            print(f"Epoch={epoch}, val_loss={val_loss:.4f}, "
                  + f"val_ber={val_ber:.7f}, val_acc={val_acc:.4f}, "
                  + f"n_false={n_false}")

    # Save data at each SNR in case we interrupt training
    np.save(data_dir / "val_bers.npy", val_datas)

After training, we test the best SNN of each SNR on new, independent data.

In [None]:
# Test demapper on independent data for different SNRs
snr_data = np.zeros((snrs.shape[0] - 1, 5))

for s, snr in enumerate(snrs[1:]):
    # Fix seed
    torch.manual_seed(42)
    np.random.seed(42)

    # Dataset and loader
    dataset = LCDDataset()
    test_loader = torch.utils.data.DataLoader(
        dataset, batch_size_val, shuffle=False)

    # Set SNR in dataset
    dataset.set_noise_power_db(-snr.item())

    # Load best model for current SNR
    state_dict = torch.load(model_dir / f"snr_{int(snr)}.pt")
    lif_demapper.load_state_dict(state_dict)

    loss, acc, ber, i_rate, h_rate, n_false = [], [], [], [], [], 0
    while True:
        for i, (data, target) in enumerate(test_loader):

            data = data.to(device)
            target = target.to(device)

            pred_b = lif_demapper(data)

            ber_b = helpers.bit_error_rate(target, pred_b, False)
            ber.append(ber_b)

            i_rate.append(lif_demapper.rate[0].cpu().detach().numpy())
            h_rate.append(lif_demapper.rate[1].cpu().detach().numpy())

            # Number false bits in current batch
            n_false += torch.count_nonzero(
                (gray_bits[torch.argmax(pred_b.cpu(), 1)]
                 != gray_bits[target.cpu()]).reshape(-1))

        if n_false >= min_false_bits:
            break

    snr_data[s, 0] = snr
    snr_data[s, 1] = np.stack(ber).reshape(-1).mean()
    snr_data[s, 2] = n_false
    snr_data[s, 3] = np.stack(i_rate).reshape(-1).mean()
    snr_data[s, 4] = np.stack(h_rate).reshape(-1).mean()

    print(f"Tested Demapper for {snr}. BER = {snr_data[s, 1]}, " \
          + f"n_false = {n_false}, rate = {snr_data[s, 4]}")

    np.save(data_dir / "test_bers.npy", snr_data)

### Plot the BER - SNR curve

In [None]:
# Results of the SNR test cell above (which writes to ../results2/snr)
plot_dir = Path("../results2/snr")
data = np.load(plot_dir / "test_bers.npy")

color = ["#FAC90F", "#FA8D0F", "#0F69FA", "#7A6F45"]

fig, axs = plt.subplots(ncols=2, figsize=(5.5, 2.7))

# Left: BER over noise level with the KP4 FEC line at BER 2e-3
axs[0].set_ylabel("BER")
axs[0].set_xlabel(r"$-\sigma^2_\mathrm{n}$ [dB]")
axs[0].set_yscale("log")
axs[0].set_ylim(1e-4, 3e-2)
axs[0].set_xlim(14.5, 22.5)
axs[0].grid(which="minor", lw=0.2, ls=":")
axs[0].grid(which="major", lw=0.7)
axs[0].set_xticks(data[:, 0])
axs[0].set_yticks([1e-4, 1e-3, 1e-2])
axs[0].plot(data[:, 0], data[:, 1], lw=1, color=color[0], label=r"SNN")
axs[0].scatter(data[:, 0], data[:, 1], color=color[0], s=10)
axs[0].hlines(2e-3, 14.5, 22.5, color="grey", ls="--", label="KP4 FEC")
axs[0].legend()

# Right: mean number of input and hidden spikes per sample
axs[1].set_ylabel("Spikes")
axs[1].set_xlabel(r"$-\sigma^2_\mathrm{n}$ [dB]")
axs[1].set_xticks(data[:, 0])
axs[1].set_ylim(0, 95)
axs[1].plot(data[:, 0], data[:, 3], lw=1, color=color[0])
axs[1].scatter(data[:, 0], data[:, 3], color=color[0], s=10, label="Input")
axs[1].plot(data[:, 0], data[:, 4], lw=1, color=color[2])
axs[1].scatter(data[:, 0], data[:, 4], color=color[2], s=10, label="Hidden")
axs[1].legend()

plt.tight_layout()
plt.savefig(plot_dir / "ber_snr.png")

## Sweep Hidden Size at -20 dB

Now we repeat the procedure, but instead we keep the noise power constant at $-20$ dB and sweep the number of neurons in the hidden layer.

In [None]:
# Train demapper on independet data
data_dir = Path("../results2/hidden_size")
data_dir.mkdir(exist_ok=True, parents=True)
model_dir = data_dir / "models"
model_dir.mkdir(exist_ok=True, parents=True)


for n_hidden in [5, 10, 15, 20, 30, 40, 60, 100]:
    print(f"n_hidden: {n_hidden}")
    
    # reset seed
    torch.manual_seed(0)
    np.random.seed(0)

    lif_demapper = SNNDemapper(
        n_in=70,  # n_taps * n_reference_points
        n_hidden=n_hidden,
        n_out=4,  # len(alphabet)
        encoder=encoder,
        lif_params=norse.LIFParameters(
            tau_mem_inv=1/6e-3,
            tau_syn_inv=1/6e-3,
            v_leak=0.,
            v_reset=0.,
            v_th=1.),
        li_params=norse.LIParameters(
            tau_mem_inv=torch.tensor(1/6e-3).to(device),
            tau_syn_inv=torch.tensor(1/6e-3).to(device),
            v_leak=torch.tensor(0.)),
        dt=dt,
        device=device)

    # Dataset
    dataset = LCDDataset()

    # Dataloader
    train_loader = torch.utils.data.DataLoader(
        dataset, batch_size_train, shuffle=True)
    val_loader = torch.utils.data.DataLoader(
        dataset, batch_size_val, shuffle=True)

    # We pre train at lower noise levels
    for i, snr in enumerate(torch.tensor([30, 22, 21, 20])):
        print(f"SNR: {snr.item()}")
        # update SNR in Dataset
        dataset.set_noise_power_db(-snr.item())

        # New scheduler
        optimizer = torch.optim.Adam(lif_demapper.parameters(), lr=lr)
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=10, gamma=0.9)

        # train for SNR
        best_val_ber = np.inf
        for epoch in range(epochs):
            train_loss, train_acc, train_ber = train(
                train_loader, optimizer, scheduler, loss_fn, lif_demapper, device)
            if (epoch + 1) % 10 == 0 and epoch > 0 and snr == 20.:
                val_loss, val_acc, val_ber, n_false = test(
                    val_loader, lif_demapper, loss_fn, device, min_false_symbols,
                    max_test_epochs)

                # Save best Demapper
                if val_ber < best_val_ber:
                    torch.save(
                        lif_demapper.state_dict(),
                        model_dir / f"n_hidden_{n_hidden}.pt")
                    best_val_ber = val_ber

                print(f"Epoch={epoch}, val_loss={val_loss:.4f}, "
                    + f"val_ber={val_ber:.7f}, val_acc={val_acc:.4f}, "
                    + f"n_false={n_false}")

In [None]:
# Test demapper on independent data for different hidden sizes
datas = np.zeros((8, 5))

for s, n_hidden in enumerate([5, 10, 15, 20, 30, 40, 60, 100]):
    torch.manual_seed(42)
    np.random.seed(42)


    lif_demapper = SNNDemapper(
        n_in=70,  # n_taps * n_reference_points
        n_hidden=n_hidden,
        n_out=4,  # len(alphabet)
        encoder=encoder,
        lif_params=norse.LIFParameters(
            tau_mem_inv=1/6e-3,
            tau_syn_inv=1/6e-3,
            v_leak=0.,
            v_reset=0.,
            v_th=1.),
        li_params=norse.LIParameters(
            tau_mem_inv=torch.tensor(1/6e-3).to(device),
            tau_syn_inv=torch.tensor(1/6e-3).to(device),
            v_leak=torch.tensor(0.)),
        dt=dt,
        device=device)

    # Dataset and loader
    dataset = LCDDataset()
    test_loader = torch.utils.data.DataLoader(
        dataset, batch_size_val, shuffle=False)

    # Load best model for current SNR
    state_dict = torch.load(model_dir / f"n_hidden_{n_hidden}.pt")
    lif_demapper.load_state_dict(state_dict)

    loss, acc, ber, n_false, i_rate, h_rate = [], [], [], 0, [], []
    while True:
        for i, (data, target) in enumerate(test_loader):

            data = data.to(device)
            target = target.to(device)

            pred_b = lif_demapper(data)

            ber_b = helpers.bit_error_rate(target, pred_b, False)
            ber.append(ber_b)
            i_rate.append(lif_demapper.rate[0].cpu().detach().numpy())
            h_rate.append(lif_demapper.rate[1].cpu().detach().numpy())

            n_false += torch.count_nonzero(
                (gray_bits[torch.argmax(pred_b.cpu(), 1)]
                 != gray_bits[target.cpu()]).reshape(-1))

        if n_false >= min_false_bits:
            break

    datas[s, 0] = n_hidden
    datas[s, 1] = np.stack(ber).reshape(-1).mean()
    datas[s, 2] = n_false
    datas[s, 3] = np.stack(i_rate).reshape(-1).mean()
    datas[s, 4] = np.stack(h_rate).reshape(-1).mean()

    print(f"Tested Demapper for {n_hidden}. BER = {datas[s, 1]}, " \
          + f"n_false = {n_false}, rate = {datas[s, 4]}")

    np.save(data_dir / "test_bers.npy", datas)

### Plot BER-Hidden Size Curve

In [None]:
# Results of the hidden-size test cell above (writes to ../results2/hidden_size)
plot_dir = Path("../results2/hidden_size")
data = np.load(plot_dir / "test_bers.npy")

color = ["#FAC90F", "#FA8D0F", "#0F69FA", "#7A6F45"]

fig, axs = plt.subplots(ncols=2, figsize=(5.5, 2.7))

# Left: BER over the number of hidden neurons
axs[0].set_ylabel("BER")
axs[0].set_xlabel("Hidden Neurons")
axs[0].set_yscale("log")
axs[0].set_xscale("log")
axs[0].grid(which="minor", lw=0.2, ls=":")
axs[0].grid(which="major", lw=0.7)
# Labelled so that legend() has an entry to show
axs[0].plot(data[:, 0], data[:, 1], lw=1, color=color[0], label="SNN")
axs[0].scatter(data[:, 0], data[:, 1], color=color[0], s=10)
axs[0].set_ylim(4e-4, 2e-3)
axs[0].legend()

# Right: mean number of input and hidden spikes per sample
axs[1].set_ylabel("Spikes")
axs[1].set_xscale("log")
axs[1].set_xlabel("Hidden Neurons")
axs[1].set_ylim(0, 95)
axs[1].plot(data[:, 0], data[:, 3], lw=1, color=color[0])
axs[1].scatter(data[:, 0], data[:, 3], color=color[0], s=10, label="Input")
axs[1].plot(data[:, 0], data[:, 4], lw=1, color=color[2])
axs[1].scatter(data[:, 0], data[:, 4], color=color[2], s=10, label="Hidden")
axs[1].legend()

plt.tight_layout()
plt.savefig(plot_dir / "ber_hidden_size.png")

## Sweep $n_\text{taps}$ at -20 dB 

In [None]:
# Train demapper

data_dir = Path("../results2/n_taps")
data_dir.mkdir(exist_ok=True, parents=True)
model_dir = data_dir / "models"
model_dir.mkdir(exist_ok=True, parents=True)

ths = [0.4, 0.7, 0.8, 1, 1, 1, 1, 1, 1]
taps = [1, 3, 5, 7, 9, 11, 13, 15, 17]


for v_th, n_taps in zip(ths, taps):
    print(f"n_taps: {n_taps}")

    # reset seed
    torch.manual_seed(0)
    np.random.seed(0)

    lif_demapper = SNNDemapper(
        n_in=10*n_taps,  # n_taps * n_reference_points
        n_hidden=40,
        n_out=4,  # len(alphabet)
        encoder=encoder,
        lif_params=norse.LIFParameters(
            tau_mem_inv=1/6e-3,
            tau_syn_inv=1/6e-3,
            v_leak=0.,
            v_reset=0.,
            v_th=v_th),
        li_params=norse.LIParameters(
            tau_mem_inv=torch.tensor(1/6e-3).to(device),
            tau_syn_inv=torch.tensor(1/6e-3).to(device),
            v_leak=torch.tensor(0.)),
        dt=dt,
        device=device)

    # Dataset
    dataset = LCDDataset(params)
    dataset.set_n_taps(n_taps)

    # Dataloader
    train_loader = torch.utils.data.DataLoader(
        dataset, batch_size_train, shuffle=True)
    val_loader = torch.utils.data.DataLoader(
        dataset, batch_size_val, shuffle=True)

    for i, snr in enumerate(torch.tensor([30, 22, 21, 20])):
        print(f"SNR: {snr.item()}")
        # update SNR in Dataset
        dataset.set_noise_power_gain_db(-snr.item())

        # New scheduler
        optimizer = torch.optim.Adam(lif_demapper.parameters(), lr=lr)
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=10, gamma=0.9)

        # train for SNR
        best_val_ber = np.inf
        for epoch in range(epochs):
            train_loss, train_acc, train_ber = train(
                train_loader, optimizer, scheduler, loss_fn, lif_demapper,
                device)
            if (epoch + 1) % 10 == 0 and epoch > 0 and snr == 20.:
                val_loss, val_acc, val_ber, n_false = test(
                    val_loader, lif_demapper, loss_fn, device, min_false_symbols,
                    max_test_epochs)

                # Save best Demapper
                if val_ber < best_val_ber:
                    torch.save(
                        lif_demapper.state_dict(),
                        model_dir / f"n_taps_{n_taps}.pt")
                    best_val_ber = val_ber

                print(f"Epoch={epoch}, val_loss={val_loss:.4f}, "
                    + f"val_ber={val_ber:.7f}, val_acc={val_acc:.4f}, "
                    + f"n_false={n_false}")

In [None]:
# Test demapper on independent data for different n_taps
datas = np.zeros((9, 5))


for s, (v_th, n_taps) in enumerate(zip(ths, taps)):
    # Fix Seed
    torch.manual_seed(42)
    np.random.seed(42)

    lif_demapper = SNNDemapper(
        n_in=10 * n_taps,  # n_taps * n_reference_points
        n_hidden=40,
        n_out=4,  # len(alphabet)
        encoder=encoder,
        lif_params=norse.LIFParameters(
            tau_mem_inv=1/6e-3,
            tau_syn_inv=1/6e-3,
            v_leak=0.,
            v_reset=0.,
            v_th=v_th),
        li_params=norse.LIParameters(
            tau_mem_inv=torch.tensor(1/6e-3).to(device),
            tau_syn_inv=torch.tensor(1/6e-3).to(device),
            v_leak=torch.tensor(0.)),
        dt=dt,
        device=device)

    # Dataset and loader
    dataset = PAM4IMDDDataset(params)
    test_loader = torch.utils.data.DataLoader(
        dataset, batch_size_val, shuffle=False)
    dataset.set_noise_power_gain_db(20.)
    dataset.set_n_taps(n_taps)

    # Load best model for current SNR
    state_dict = torch.load(model_dir / f"/n_taps_{n_taps}.pt")
    lif_demapper.load_state_dict(state_dict)

    loss, acc, ber, n_false, i_rate, h_rate = [], [], [], 0, [], []
    while True:
        for i, (data, target) in enumerate(test_loader):

            data = data.to(device)
            target = target.to(device)

            pred_b = lif_demapper(data)

            ber_b = helpers.bit_error_rate(target, pred_b, False)
            ber.append(ber_b)
            i_rate.append(lif_demapper.rate[0].cpu().detach().numpy())
            h_rate.append(lif_demapper.rate[1].cpu().detach().numpy())

            n_false += torch.count_nonzero(
                (gray_bits[torch.argmax(pred_b.cpu(), 1)]
                 != gray_bits[target.cpu()]).reshape(-1))

        if n_false >= min_false_bits:
            break

    datas[s, 0] = n_taps
    datas[s, 1] = np.stack(ber).reshape(-1).mean()
    datas[s, 2] = n_false
    datas[s, 3] = np.stack(i_rate).reshape(-1).mean()
    datas[s, 4] = np.stack(h_rate).reshape(-1).mean()

    print(f"Tested Demapper for {n_taps}. BER = {datas[s, 1]}, " \
          + f"n_false = {n_false}, rate = {datas[s, 4]}")

    np.save(data_dir / "test_bers.npy", datas)

### Plot $n_\text{taps}$-BER Curve

In [None]:
# Results of the n_taps test cell above (writes to ../results2/n_taps)
plot_dir = Path("../results2/n_taps")
data = np.load(plot_dir / "test_bers.npy")
data[:, 0] = [1, 3, 5, 7, 9, 11, 13, 15, 17]


color = ["#FAC90F", "#FA8D0F", "#0F69FA", "#7A6F45"]

fig, axs = plt.subplots(ncols=2, figsize=(5.5, 2.7))

# Left: BER over the number of taps (raw strings: "\m" is not an escape)
axs[0].set_ylabel("BER")
axs[0].set_yticks([1e-3, 5e-4])
axs[0].set_xlabel(r"$n_\mathrm{taps}$")
axs[0].set_ylim(4e-4, 4e-2)
axs[0].set_yscale("log")
axs[0].grid(which="minor", lw=0.2, ls=":")
axs[0].grid(which="major", lw=0.7)
# Labelled so that legend() has an entry to show
axs[0].plot(data[:, 0], data[:, 1], lw=1, color=color[0], label="SNN")
axs[0].scatter(data[:, 0], data[:, 1], color=color[0], s=10)
axs[0].legend()

# Right: mean number of input and hidden spikes per sample
axs[1].set_ylabel("Spikes")
axs[1].set_xlabel(r"$n_\mathrm{taps}$")
axs[1].set_ylim(0, 95)
axs[1].plot(data[:, 0], data[:, 3], lw=1, color=color[0])
axs[1].scatter(data[:, 0], data[:, 3], color=color[0], s=10, label="Input")
axs[1].plot(data[:, 0], data[:, 4], lw=1, color=color[2])
axs[1].scatter(data[:, 0], data[:, 4], color=color[2], s=10, label="Hidden")
axs[1].legend()

plt.tight_layout()
plt.savefig(plot_dir / "n_taps_ber.png")