In [1]:
from typing import List, Tuple, Union, Optional

import torch

def temporal_scale_distribution(
    n_scales: int,
    min_scale: float = 1,
    max_scale: Optional[float] = None,
    c: Optional[float] = 1.41421,
):
    r"""
    Provides temporal scales according to [Lindeberg2016].
    The scales will be logarithmic by default, but can be changed by providing other values for c.

    .. math::
        \tau_k = c^{2(k - K)} \tau_{max}
        \mu_k = \sqrt{\tau_k - \tau_{k - 1}}

    The returned values are :math:`\sqrt{\tau_k}` for :math:`k = 1, \dots, K` (not :math:`\mu_k`),
    where :math:`\tau_1` equals ``min_scale`` and :math:`\tau_K` equals ``max_scale``.

    Arguments:
      n_scales (int): Number of scales to generate (K)
      min_scale (float): The minimum scale :math:`\tau_1`; the first returned value is its square root.
      max_scale (Optional[float]): The maximum scale :math:`\tau_K`. Defaults to None. If set, c is
        ignored and derived as :math:`c = (\tau_{max} / \tau_{min})^{1 / (2(K - 1))}`.
      c (Optional[float]): The base from which to generate scale values. Should be a value
        between 1 and 2, exclusive. Defaults to sqrt(2). Ignored if max_scale is set.

    Returns:
      torch.Tensor: 1-D tensor of ``n_scales`` values :math:`\sqrt{\tau_k}`, increasing in k when c > 1.
      If ``max_scale`` is set and ``n_scales <= 1``, a single value ``sqrt(min_scale)`` is returned.

    .. [Lindeberg2016] Lindeberg 2016, Time-Causal and Time-Recursive Spatio-Temporal
        Receptive Fields, https://link.springer.com/article/10.1007/s10851-015-0613-9.
    """
    ks = torch.linspace(1, n_scales, n_scales)
    if max_scale is not None:
        if n_scales <= 1:
            # A single scale cannot span a range (and K - 1 == 0 would divide by zero below).
            return torch.tensor([min_scale]).sqrt()
        # Pick c so that c^{2(K-1)} == max/min: then tau_K == max_scale and tau_1 == min_scale.
        # Using min/max here would invert c (< 1) and yield tau_1 == max_scale**2 / min_scale.
        c = (max_scale / min_scale) ** (1 / (2 * (n_scales - 1)))
    else:
        max_scale = (c ** (2 * (n_scales - 1))) * min_scale
    taus = c ** (2 * (ks - n_scales)) * max_scale
    return taus.sqrt()


In [52]:
from typing import Callable, List, NamedTuple, Optional, Tuple, Type, Union

import torch

from norse.torch.module.leaky_integrator_box import LIBoxCell, LIBoxParameters
from norse.torch.module.snn import SNNCell
from norse.torch.functional.receptive_field import (
    spatial_receptive_fields_with_derivatives,
    temporal_scale_distribution,
)

class TemporalReceptiveField(torch.nn.Module):
    """Creates ``n_scales`` temporal receptive fields for arbitrary n-dimensional inputs.
    The scale spaces are selected in a range of [min_scale, max_scale] using an exponential distribution, scattered using ``torch.linspace``.

    The input is replicated once per scale along a new dimension inserted directly in front of the
    ``len(shape)`` trailing dimensions. All replicas are passed through a single ``activation`` cell,
    whose parameters are built from the trainable per-scale tensor ``ps``.

    Parameters:
        shape (torch.Size): The shape of the incoming tensor, where the first dimension denotes channels.
        n_scales (int): The number of temporal scale spaces to iterate over. When ``time_constants``
            is given, the number of scales is taken from its first dimension instead.
        activation (SNNCell): The activation neuron. Defaults to LIBoxCell
        activation_state_map (Callable): A function that takes a tensor and provides a neuron parameter tuple.
            Required if activation is changed, since the default behaviour provides LIBoxParameters.
            It is called once, with the ``ps`` parameter.
        min_scale (float): The minimum scale space. Defaults to 1.
        max_scale (Optional[float]): The maximum scale. Defaults to None. If set, c is ignored.
        c (Optional[float]): The base from which to generate scale values. Should be a value
            between 1 and 2, exclusive. Defaults to sqrt(2). Ignored if max_scale is set.
        time_constants (Optional[torch.Tensor]): Hardcoded time constants. Will overwrite the automatically
            generated, logarithmically distributed scales, if set. Its first dimension indexes the scales.
            Defaults to None.
        dt (float): Neuron simulation timestep. Defaults to 0.001.
    """

    def __init__(
        self,
        shape: torch.Size,
        n_scales: int = 4,
        activation: Type[SNNCell] = LIBoxCell,
        activation_state_map: Callable[
            [torch.Tensor], NamedTuple
        ] = lambda t: LIBoxParameters(tau_mem_inv=t),
        min_scale: float = 1,
        max_scale: Optional[float] = None,
        c: float = 1.41421,
        time_constants: Optional[torch.Tensor] = None,
        dt: float = 0.001,
    ):
        super().__init__()
        if time_constants is None:
            # One value per scale, 1 / (dt * scale); the default state map passes these to the
            # neurons as ``tau_mem_inv``.
            rates = (1 / dt) / temporal_scale_distribution(
                n_scales, min_scale=min_scale, max_scale=max_scale, c=c
            )
            # Lay the values out as (n_scales, channels, 1, ..., 1) so each scale/channel pair
            # broadcasts over the remaining input dimensions. ``repeat`` copies into fresh storage.
            trailing_ones = [1] * (len(shape) - 1)
            self.time_constants = (
                rates.to(torch.float32)
                .reshape(-1, 1, *trailing_ones)
                .repeat(1, shape[0], *trailing_ones)
            )
        else:
            self.time_constants = time_constants
        self.ps = torch.nn.Parameter(self.time_constants)
        # pytype: disable=missing-parameter
        self.neurons = activation(p=activation_state_map(self.ps), dt=dt)
        # pytype: enable=missing-parameter
        self.rf_dimension = len(shape)
        # forward() must create exactly one input replica per parameterised scale, so derive the
        # count from the time constants; hand-supplied ones may disagree with ``n_scales``.
        self.n_scales = (
            self.time_constants.shape[0] if self.time_constants.dim() > 0 else n_scales
        )

    def forward(self, x: torch.Tensor, state: Optional[NamedTuple] = None):
        """Runs ``x`` through every temporal scale at once.

        Arguments:
            x (torch.Tensor): Input expected to end in ``shape``; leading dimensions (e.g. batch)
                are preserved.
            state (Optional[NamedTuple]): Neuron state from a previous call, or None.

        Returns:
            The activation cell's result on the replicated input. The scale dimension sits directly
            before the trailing ``len(shape)`` dimensions. For the default LIBoxCell the result is an
            ``(output, state)`` pair.
        """
        replicas = [x] * self.n_scales
        x_repeated = torch.stack(replicas, dim=-self.rf_dimension - 1)
        return self.neurons(x_repeated, state)

In [111]:
from norse.torch.module.lift import Lift
# Three temporal scales for inputs of shape (1, 1, 10), fed a batch of two constant 0.1 inputs.
m = TemporalReceptiveField(shape=(1, 1, 10), n_scales=3)
inputs = torch.full((2, 1, 1, 10), 0.1)
y, s = m(inputs)

In [112]:
# (batch, scale, *input shape): the scale axis is inserted before the three input dimensions.
y.size()

torch.Size([2, 3, 1, 1, 10])

In [113]:
# Second temporal scale, first element of the input shape, for both batch entries.
y[:, 1, 0, 0, 0]

tensor([0.0707, 0.0707], grad_fn=<SelectBackward0>)