In [None]:
import math

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.modules.utils import _single, _reverse_repeat_tuple
from torch.nn.parameter import Parameter

import pytorch_lightning as pl

from typing import Tuple, Union, Literal, Callable, Any

### Deformable Convolutions applied to Time Series

#### Kernel definition

In these experiments we define the kernel used in a convolution as follows. Let $\mathcal{R}$ be a kernel of size 3 used to sample a small region of the input. We can define this kernel by its positions:

$$
\mathcal{R} = \{ (-1), (0), (+1) \}
$$

Note that for each position $i$ in the kernel, $\mathcal{R}_i$ represents an index value arranged symmetrically around zero, which makes it easier for us to apply the offsets. Also, each $\mathcal{R}_i$ has a weight $k_{\mathcal{R}_i}$ associated with that position.

For example, for that kernel with size $3$, we have a tensor storing all the weights for each position.

```python
k = torch.Tensor([0.32, 0.21, -0.34])
```

Where $k$ is the kernel weights.


#### Formulating a basic one-dimensional convolution operation


In the simplest case, the output of a one-dimensional convolutional layer with input $x$ of size $(N, C_{in}, L)$ and output $y$ of size $(N, C_{out}, L)$ can be precisely described as:

$$
y(N_i, C_{out_j}, p_0) = \text{bias}(C_{out_j}) + \sum^{C_{in} - 1}_{k=0} \sum_{p_n \in \mathcal{R}} w(C_{out_j}, k, p_n) ~ \cdot ~ {x}(N_i, k, p_0 + p_n)
$$

where $N$ is the batch size, $C$ denotes the number of channels, $L$ is the length of the signal sequence, $p_0$ is the starting position of each kernel and $p_n$ enumerates all the positions in $\mathcal{R}$.

#### First concepts

Unlike the normal convolution, which uses a simple fixed sampling grid, the deformable convolution introduces offsets into the convolution operation. If $\mathcal{R}$ is the normal grid, then the deformable convolution adds learned offsets to the grid, thereby deforming its sampling positions.

This operation can be explained by the following equation:

$$
y(N_i, C_{out_j}, p_0) = \text{bias}(C_{out_j}) + \sum^{C_{in} - 1}_{k=0} \sum_{p_n \in \mathcal{R}} w(C_{out_j}, k, p_n) \cdot x(N_i, k, p_0 + p_n + \Delta p_n)
$$

where the new term $\Delta p_n$ denotes the offsets added to the normal convolution.

**Note 1**: As the sampling is done at irregular, offset locations, and $\Delta p_n$ is generally fractional, we use linear interpolation to implement the above equation.

##### Linear Interpolation

We use Linear Interpolation because as we add offsets to the existing sampling positions, we obtain fractional points, which are not defined locations on the grid. In order to estimate their values we use linear interpolation which uses 2 of the neighbouring values to estimate the value of the new deformed position.

The equation used to perform the linear interpolation and estimate the value at the fractional position is given below, where $p = p_0 + p_n + \Delta p_n$ is the deformed position, $q$ enumerates all the valid positions on the input feature map $x$ and $G(\cdot)$ is the linear interpolation kernel.

$$
x(p) = \sum_q G(q, p) \cdot x(q)
$$

**TODO**: formulate the new equation for one dimensional linear interpolation instead of the bilinear used for 2D images.

Note that for 2D images $G$ is two dimensional and is separable into two one-dimensional kernels:

$$
G(q, p) = g(q_x, p_x) \cdot g(q_y, p_y)
$$

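where $g(a,b) = \max(0, 1 - |a-b|)$. For a one-dimensional signal there is only a single axis, so the interpolation kernel reduces to $G(q, p) = g(q, p) = \max(0, 1 - |q - p|)$.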

##### Modulated Modules

$$
y(p) = \sum^{K}_{k=1} w_k \cdot x(p + p_k + \Delta p_k) \cdot \Delta m_k
$$

In [None]:
def linear_interpolation(
    x: torch.Tensor,
    offsets: torch.Tensor,
    kernel_size: int,
    dilation: int,
    stride: int,
    dilated_positions = None,
    device: str = 'cpu',
    unconstrained: bool = False
) -> torch.Tensor:
    """Sample ``x`` at fractional, offset kernel positions via linear interpolation.

    For every output step ``t0`` and kernel tap ``k`` the sampling position is
    ``T = t0 * stride + dilated_positions[k] + offsets[..., t0, k]``. The value at
    ``T`` is the weighted sum of the two neighbouring integer positions, using the
    kernel ``g(q, T) = max(0, 1 - |q - T|)``.

    Args:
        x: Input signal of shape (batch, channels, length).
        offsets: Learned offsets of shape (batch, groups, out_length, kernel_size).
            When ``groups`` is smaller than ``channels`` the indices are repeated
            along the channel axis.
        kernel_size: Number of kernel taps.
        dilation: Spacing between consecutive taps.
        stride: Step between consecutive kernel start positions.
        dilated_positions: Optional pre-computed tap positions (kernel_size,).
            Built from ``kernel_size`` and ``dilation`` when ``None``.
        device: Kept for backward compatibility; the computation always runs on
            the device of ``offsets``.
        unconstrained: If ``False`` each sampling position is restricted to the
            receptive field of its own kernel window; if ``True`` it is only
            clamped to ``[0, length]``.

    Returns:
        Tensor of shape (batch, channels, out_length, kernel_size) holding the
        interpolated samples.
    """
    # Ensure that the x and offsets are in the same device
    assert x.device == offsets.device, 'The tensors x and offsets must be on same device.'

    # Receptive field covered by the dilated kernel
    kernel_rfield = dilation * (kernel_size - 1) + 1

    # Position of every kernel tap relative to the kernel start
    if dilated_positions is None:
        dilated_positions = torch.linspace(
            0,
            kernel_rfield - 1,
            kernel_size,
            device=offsets.device,
            dtype=offsets.dtype
        )

    # Start position of each kernel window: out_length x 1
    max_t0 = (offsets.shape[-2] - 1) * stride
    t0s = torch.linspace(
        0, max_t0, offsets.shape[-2], device=offsets.device, dtype=offsets.dtype
    ).unsqueeze(-1)
    dilated_offsets_repeated = dilated_positions + offsets

    T = t0s + dilated_offsets_repeated  # batch_size x groups x out_length x kernel_size
    if not unconstrained:
        T = torch.max(T, t0s)
        T = torch.min(T, t0s + torch.max(dilated_positions))
    else:
        T = torch.clamp(T, 0.0, float(x.shape[-1]))

    # The index bound must come from the time axis. It is captured before x is
    # reshaped below; using the channel axis here collapses every index.
    length = x.shape[-1]

    with torch.no_grad():
        # Left neighbour of every sampling position, kept inside [0, length - 2]
        # so that the right neighbour (U + 1) is still a valid index.
        U = torch.floor(T).to(torch.long)
        U = torch.clamp(U, min=0, max=max(length - 2, 0))

        U = torch.stack([U, U + 1], dim=-1)

        if U.shape[1] < x.shape[1]:
            U = U.repeat(1, x.shape[1], 1, 1, 1)

    x = x.unsqueeze(-1).repeat(1, 1, 1, U.shape[-1])

    x = torch.stack([
        x.gather(index=torch.clamp(U[:, :, :, i, :], 0, length - 1), dim=-2)
        for i in range(U.shape[-2])], dim=-1)

    # Linear interpolation weights max(0, 1 - |q - T|); clamping avoids allocating
    # a zero tensor on a possibly mismatching device.
    G = torch.clamp(1 - torch.abs(U - T.unsqueeze(-1)), min=0)

    mx = torch.multiply(G, x.moveaxis(-2, -1))
    return torch.sum(mx, dim=-1)

#### General and Basic Implementation

The Figure below shows an example of the overall strategy used to implement the deformable convolution.

<div align='center'>
    <img src='./deformable_convolution.png'>
</div>

As shown in this Figure, the offsets are obtained by applying a convolution layer over the input. The convolution kernel used has the same spatial resolution and dilation as those of the current convolution layer. The output offset has the same resolution as the input and has $C_{\text{off}}$ channels, where $C_{\text{off}}$ in this case corresponds to $N$ 1D offsets.

In [None]:
class DeformableConvolution1d(nn.Module):
    """One-dimensional deformable convolution with externally supplied offsets.

    The input is padded, sampled at the offset kernel positions with
    ``interpolation_func`` and the sampled taps are then combined with the
    learned weights through a regular ``F.conv1d`` whose stride equals the
    kernel size.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Number of kernel taps.
        stride: Step between consecutive kernel windows.
        padding: ``'valid'``, ``'same'`` or an explicit integer padding.
        dilation: Spacing between kernel taps.
        groups: Number of blocked connections between input and output channels.
        bias: Whether to learn an additive bias.
        padding_mode: One of ``'zeros'``, ``'reflect'``, ``'replicate'``,
            ``'circular'``.
        device: Device the module is moved to at the end of construction.
        interpolation_func: Callable used to sample the input at the deformed
            positions.
        unconstrained: Optional flag forwarded to ``interpolation_func``; it is
            only passed along when it is not ``None``.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int,
                 stride: int = 1,
                 padding: Union[int, Literal['valid', 'same']] = 'valid',
                 dilation: int = 1,
                 groups: int = 1,
                 bias: bool = True,
                 padding_mode: str = 'reflect',
                 device: str = 'cpu',
                 interpolation_func: Callable = linear_interpolation,
                 unconstrained: Union[bool, None] = None,
                 *args,
                 **kwargs) -> None:
        # Initialise the nn.Module machinery before touching any attribute.
        super().__init__(*args, **kwargs)

        self.device = device
        self.interpolation_func = interpolation_func
        padding_ = padding if isinstance(padding, str) else _single(padding)
        stride_ = _single(stride)
        dilation_ = _single(dilation)
        kernel_size_ = _single(kernel_size)

        # groups == 0 must be rejected too, otherwise the modulo checks below
        # fail with an unrelated ZeroDivisionError.
        if groups <= 0:
            raise ValueError('groups must be a positive integer')
        if in_channels % groups != 0:
            raise ValueError('input channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out channels must be divisible by groups')

        valid_padding_strings = {'same', 'valid'}
        if isinstance(padding, str):
            if padding not in valid_padding_strings:
                raise ValueError('invalid padding string, you must use valid or same')
            if padding == 'same' and any(s != 1 for s in stride_):
                raise ValueError('padding=same is not supported for strided convolutions')

        valid_padding_modes = {'zeros', 'reflect', 'replicate', 'circular'}
        if padding_mode not in valid_padding_modes:
            raise ValueError('invalid padding mode, you must use zeros, reflect, replicate or circular')

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding_
        self.dilation = dilation
        self.groups = groups
        self.padding_mode = padding_mode

        # Padding in the (left, right) layout expected by F.pad.
        if isinstance(self.padding, str):
            self._reversed_padding_repeated_twice = [0, 0] * len(kernel_size_)
            if padding == 'same':
                for d, k, i in zip(dilation_, kernel_size_, range(len(kernel_size_) - 1, -1, -1)):
                    total_padding = d * (k - 1)
                    left_pad = total_padding // 2
                    self._reversed_padding_repeated_twice[2 * i] = left_pad
                    self._reversed_padding_repeated_twice[2 * i + 1] = (
                        total_padding - left_pad
                    )
        else:
            self._reversed_padding_repeated_twice = _reverse_repeat_tuple(self.padding, 2)

        self.weight = Parameter(
            torch.empty(out_channels, in_channels // groups, kernel_size)
        )

        # Tap positions relative to the kernel start: 0, d, 2d, ...
        self.dilated_positions = torch.linspace(
            0, dilation * kernel_size - dilation, kernel_size
        )

        if bias:
            self.bias = Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)

        # Only stored when given, so that custom interpolation functions without
        # an `unconstrained` keyword keep working.
        if unconstrained is not None:
            self.unconstrained = unconstrained

        self.reset_parameters()
        self.to(device)

    def reset_parameters(self) -> None:
        """Initialise weight and bias with the scheme used by this module."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

    def extra_repr(self) -> str:
        """Return the hyper-parameter summary shown in ``repr(module)``."""
        s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        if self.padding != (0,) * len(self.padding):
            s += ', padding = {padding}'
        s += ', dilation={dilation}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias is None:
            s += ', bias=False'
        if self.padding_mode != 'zeros':
            s += ', padding_mode={padding_mode}'

        return s.format(**self.__dict__)

    def __setstate__(self, state) -> None:
        super().__setstate__(state)
        # Restored states that lack the attribute fall back to zero padding.
        if not hasattr(self, 'padding_mode'):
            self.padding_mode = 'zeros'

    def forward(self,
                x: torch.Tensor,
                offsets: torch.Tensor,
                mask: torch.Tensor = None) -> torch.Tensor:
        """Apply the deformable convolution.

        Args:
            x: Input of shape (batch, in_channels, length).
            offsets: Sampling offsets of shape
                (batch, offset_groups, out_length, kernel_size).
            mask: Accepted for interface compatibility; currently not used.

        Returns:
            Tensor of shape (batch, out_channels, out_length).
        """
        in_shape = x.shape

        # Padding is always applied explicitly here because the final conv1d
        # runs on the already sampled taps. Explicit integer padding with
        # padding_mode='zeros' used to be skipped.
        if any(p != 0 for p in self._reversed_padding_repeated_twice):
            if self.padding_mode == 'zeros':
                x = F.pad(
                    x,
                    self._reversed_padding_repeated_twice,
                    mode='constant',
                    value=0
                )
            else:
                x = F.pad(
                    x,
                    self._reversed_padding_repeated_twice,
                    mode=self.padding_mode
                )

        # Follow the device of the offsets; dilated_positions is a plain tensor
        # and is therefore not moved by Module.to().
        if not self.device == offsets.device:
            self.device = offsets.device
        if self.dilated_positions.device != self.device:
            self.dilated_positions = self.dilated_positions.to(self.device)

        interpolation_kwargs = dict(
            kernel_size=self.kernel_size,
            dilation=self.dilation,
            offsets=offsets,
            stride=self.stride,
            dilated_positions=self.dilated_positions,
            device=self.device
        )
        unconstrained = getattr(self, 'unconstrained', None)
        if unconstrained is not None:
            interpolation_kwargs['unconstrained'] = unconstrained

        x = self.interpolation_func(x, **interpolation_kwargs)

        # (batch, channels, out_length, kernel_size) -> (batch, channels, out_length * kernel_size);
        # a stride of kernel_size then applies the weights once per output step.
        x = x.flatten(-2, -1)
        output = F.conv1d(
            x,
            weight=self.weight,
            bias=self.bias,
            stride=self.kernel_size,
            groups=self.groups
        )

        if self.padding == 'same':
            assert in_shape[-1] == output.shape[-1], f'input length {in_shape} and output length {output.shape} do not match'

        return output

In [None]:
EPS = 1e-9


class GlobalLayerNormalization(nn.Module):
    """Normalise each sample over dimensions 1 and 2 with a learned scale and shift.

    ``gamma`` and ``beta`` have shape (1, 1, channel_size) and therefore
    broadcast over the last dimension of the input.
    """

    def __init__(self, channel_size) -> None:
        super().__init__()
        param_shape = (1, 1, channel_size)
        self.gamma = nn.Parameter(torch.Tensor(*param_shape))
        self.beta = nn.Parameter(torch.Tensor(*param_shape))

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Reset the scale to one and the shift to zero."""
        with torch.no_grad():
            self.gamma.fill_(1)
            self.beta.zero_()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``gamma * (x - mean) / sqrt(var + EPS) + beta``."""

        def _global_mean(t: torch.Tensor) -> torch.Tensor:
            # Mean over dim 1 followed by dim 2, keeping both axes for broadcasting.
            return t.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)

        mean = _global_mean(x)
        var = _global_mean((x - mean) ** 2)
        return self.gamma * (x - mean) / (var + EPS) ** 0.5 + self.beta

In [None]:
class PackedDeformableConvolution1d(DeformableConvolution1d):
    """Deformable 1-D convolution that predicts its own offsets from the input.

    The offsets are produced by a pointwise (kernel size 1) convolution over the
    input and then handed to ``DeformableConvolution1d.forward``.

    Args:
        offset_groups: Number of independent offset sets; only ``1`` or
            ``in_channels`` are supported.
        All remaining arguments are forwarded to ``DeformableConvolution1d``.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int,
                 stride: int = 1,
                 padding: Union[int, Literal['valid', 'same']] = 'valid',
                 dilation: int = 1,
                 groups: int = 1,
                 bias: bool = True,
                 padding_mode: str = 'reflect',
                 offset_groups: int = 1,
                 device: str = 'cpu',
                 interpolation_func: Callable = linear_interpolation,
                 unconstrained: Union[bool, None] = None,
                 *args, **kwargs) -> None:

        assert offset_groups in [1, in_channels], 'offset groups only implemented for 1 or in_channels'

        super().__init__(in_channels,
                         out_channels,
                         kernel_size,
                         stride,
                         padding,
                         dilation,
                         groups,
                         bias,
                         padding_mode,
                         device,
                         interpolation_func,
                         unconstrained,
                         *args, **kwargs)

        self.offset_groups = offset_groups

        # NOTE: the depthwise branch and both normalisation / activation layers
        # are created (and kept, so existing state dicts still load) but the
        # current forward pass only uses `offset_pconv`.
        self.offset_dconv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=kernel_size,
            stride=1,
            groups=in_channels,
            padding=padding,
            padding_mode=padding_mode,
            bias=False
        )
        self.offset_dconv_norm = GlobalLayerNormalization(
            in_channels
        )
        self.offset_dconv_prelu = nn.LeakyReLU()

        # Pointwise convolution producing kernel_size offsets per offset group.
        self.offset_pconv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=kernel_size*offset_groups,
            kernel_size=1,
            stride=1,
            bias=False
        )
        self.offset_pconv_norm = GlobalLayerNormalization(
            kernel_size * offset_groups
        )
        self.offset_pconv_prelu = nn.LeakyReLU()

        self.device = device
        self.to(device)

    def forward(self,
                x: torch.Tensor,
                with_offsets: bool = False
                ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Predict the offsets from ``x`` and apply the deformable convolution.

        Args:
            x: Input of shape (batch, in_channels, length).
            with_offsets: When ``True`` the predicted offsets are returned too.

        Returns:
            The convolution output, or ``(output, offsets)`` when
            ``with_offsets`` is ``True``.
        """
        # Compare against the actual parameter device *before* updating
        # `self.device`; checking after the assignment can never fail.
        assert str(x.device) == str(self.weight.device), 'x and the deformable conv must be on same device'
        self.device = x.device

        offsets = self.offset_pconv(x)  # (batch, kernel_size * offset_groups, length)

        # Rearrange to (batch, offset_groups, length, kernel_size).
        offsets = offsets.unsqueeze(0).chunk(self.offset_groups, dim=2)
        offsets = torch.vstack(offsets).moveaxis((0, 2), (1, 3))

        output = super().forward(x, offsets)
        if with_offsets:
            return output, offsets
        return output
        

### Classifiers using Deformable Convolutions

In [None]:
from aeon.datasets import load_arrow_head, load_osuleaf, load_acsf1, load_classification
import numpy as np

# Load the pre-defined TRAIN / TEST splits of the CinCECGTorso dataset.
X_train, y_train, _ = load_classification(split='TRAIN', name='CinCECGTorso')
X_test, y_test, _ = load_classification(split='TEST', name='CinCECGTorso')

# Convert the labels of both splits to floats.
y_train, y_test = y_train.astype(float), y_test.astype(float)

X_train.shape, y_train.shape

In [None]:
# Number of distinct class labels present in the training split.
num_classes = np.unique(y_train).size
np.unique(y_train)

In [None]:
# Map the class labels onto 0..K-1, as required by nn.CrossEntropyLoss.
# Subtracting 1 is only correct when the labels are exactly 1..K; building the
# mapping from the sorted unique training labels also covers label sets that
# start elsewhere or have gaps, and is a no-op when re-run.
# A test label that never occurs in the training split raises a KeyError.
label_to_index = {label: index for index, label in enumerate(np.unique(y_train))}

y_train = np.array([label_to_index[label] for label in y_train], dtype=y_train.dtype)
y_test = np.array([label_to_index[label] for label in y_test], dtype=y_test.dtype)

np.unique(y_train)

In [None]:
from torch.utils.data import TensorDataset, DataLoader

# Wrap the numpy arrays as tensors: float inputs, integer (long) class targets.
train_dataset = TensorDataset(torch.from_numpy(X_train).float(), torch.from_numpy(y_train).long())
test_dataset = TensorDataset(torch.from_numpy(X_test).float(), torch.from_numpy(y_test).long())

# Reshuffle the training samples every epoch so that mini-batches do not follow
# the storage order of the dataset; the evaluation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=8)

In [None]:
from typing import Any
from pytorch_lightning.utilities.types import STEP_OUTPUT, OptimizerLRScheduler
from sklearn.metrics import accuracy_score

class FCN(pl.LightningModule):
    """Baseline fully convolutional classifier.

    Three Conv1d -> BatchNorm1d -> ReLU blocks, followed by an average over the
    time axis and a linear layer. ``forward`` returns class probabilities.
    """

    def __init__(self, num_classes: int) -> None:
        super().__init__()
        self.in_channels = 1
        self.num_classes = num_classes

        self.layers = nn.Sequential(*[
            nn.Conv1d(in_channels=self.in_channels, out_channels=128, kernel_size=8, stride=1, padding='same'),
            nn.BatchNorm1d(num_features=128),
            nn.ReLU(),
            nn.Conv1d(in_channels=128, out_channels=256, kernel_size=5, stride=1, padding='same'),
            nn.BatchNorm1d(num_features=256),
            nn.ReLU(),
            nn.Conv1d(in_channels=256, out_channels=128, kernel_size=3, stride=1, padding='same'),
            nn.BatchNorm1d(num_features=128),
            nn.ReLU(),
        ])

        self.linear = nn.Linear(128, self.num_classes)
        self.softmax = nn.Softmax(dim=1)

        self.criteria = nn.CrossEntropyLoss()

    def configure_optimizers(self) -> OptimizerLRScheduler:
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def _logits(self, x: torch.Tensor) -> torch.Tensor:
        """Return the unnormalised class scores of shape (batch, num_classes)."""
        features = self.layers(x)
        return self.linear(features.mean(dim=-1))

    def forward(self, x):
        """Return the class probabilities of shape (batch, num_classes)."""
        return self.softmax(self._logits(x))

    def training_step(self, batch, batch_idx) -> torch.Tensor:
        x, y = batch

        # CrossEntropyLoss applies log-softmax itself, so it must receive the
        # raw logits; feeding it softmax outputs applies softmax twice.
        logits = self._logits(x)

        loss = self.criteria(logits, y)
        self.log('train_loss', loss, prog_bar=True)

        acc = accuracy_score(y.cpu().numpy(), logits.argmax(dim=-1).cpu().numpy())
        self.log('train_acc', acc, prog_bar=True)

        return loss

    def test_step(self, batch, batch_idx) -> Any:
        x, y = batch
        y_pred = self(x)

        acc = accuracy_score(y.cpu().numpy(), y_pred.argmax(dim=-1).cpu().numpy())
        self.log('test_acc', acc, prog_bar=True)

        return


In [None]:
from sklearn.metrics import accuracy_score

device = torch.device('cuda')

accuracies = []

# Repeat the experiment ten times, each with a freshly initialised model.
for experiment_id in range(10):
    fcn = FCN(num_classes=num_classes)
    fcn_trainer = pl.Trainer(max_epochs=300, accelerator='gpu', devices=-1)
    fcn_trainer.fit(fcn, train_loader)

    # Evaluate manually on the test split with the model in inference mode.
    fcn = fcn.to(device)
    fcn.eval()

    with torch.no_grad():
        y_preds, y_s = [], []

        for x, y in test_loader:
            x, y = x.to(device), y.to(device)

            y_pred = fcn(x).argmax(dim=-1).cpu().tolist()

            y_preds.extend(y_pred)
            y_s.extend(y.cpu().numpy())

        accuracies.append(accuracy_score(y_preds, y_s))

    # Running mean / standard deviation over the runs finished so far.
    print(np.mean(accuracies))
    print(np.std(accuracies))


In [None]:
# Display the accuracies collected in the `accuracies` list, one entry per experiment run.
accuracies

In [None]:
# Mean and standard deviation of the values in `accuracies`.
np.mean(accuracies), np.std(accuracies)

In [None]:
from pytorch_lightning.utilities.types import OptimizerLRScheduler
from sklearn.metrics import accuracy_score




class DefFCN(pl.LightningModule):
    """FCN variant whose first block is a ``PackedDeformableConvolution1d``.

    The remaining two blocks are regular Conv1d -> BatchNorm1d -> LeakyReLU
    stages, followed by an average over the time axis and a linear layer.
    ``forward`` returns class probabilities.

    Args:
        num_classes: Number of target classes.
        lr_ratio: Factor applied to the gradients by ``_set_lr``.
    """

    def __init__(self, num_classes, lr_ratio: float = 1.0) -> None:
        super().__init__()
        self.in_channels = 1
        self.num_classes = num_classes
        # Read by `_set_lr`; it was never defined before, so using the hook
        # raised an AttributeError.
        self.lr_ratio = lr_ratio

        self.conv_blocks = nn.Sequential(*[
            PackedDeformableConvolution1d(
                in_channels=self.in_channels, out_channels=128, kernel_size=8, padding='same', stride=1
            ),
            nn.BatchNorm1d(num_features=128),
            nn.LeakyReLU(),

            nn.Conv1d(in_channels=128, out_channels=256, kernel_size=5, stride=1, padding='same'),
            nn.BatchNorm1d(num_features=256),
            nn.LeakyReLU(),

            nn.Conv1d(in_channels=256, out_channels=128, kernel_size=3, stride=1, padding='same'),
            nn.BatchNorm1d(num_features=128),
            nn.LeakyReLU(),
        ])

        self.criteria = nn.CrossEntropyLoss()

        self.linear = nn.Linear(128, self.num_classes)
        self.softmax = nn.Softmax(dim=1)

    def _set_lr(self, module, grad_in, grad_out):
        """Scale every non-``None`` entry of ``grad_in`` by ``self.lr_ratio``.

        The signature matches a module backward hook; nothing in this class
        registers it.
        """
        return tuple(
            grad * self.lr_ratio if grad is not None else grad
            for grad in grad_in
        )

    def configure_optimizers(self) -> OptimizerLRScheduler:
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def _logits(self, x: torch.Tensor) -> torch.Tensor:
        """Return the unnormalised class scores of shape (batch, num_classes)."""
        features = self.conv_blocks(x)
        return self.linear(features.mean(dim=-1))

    def forward(self, x):
        """Return the class probabilities of shape (batch, num_classes)."""
        return self.softmax(self._logits(x))

    def training_step(self, batch, batch_idx) -> torch.Tensor:
        x, y = batch

        # CrossEntropyLoss applies log-softmax itself, so it must receive the
        # raw logits; feeding it softmax outputs applies softmax twice.
        logits = self._logits(x)

        loss = self.criteria(logits, y)
        self.log('train_loss', loss, prog_bar=True)

        acc = accuracy_score(y.cpu().numpy(), logits.argmax(dim=-1).cpu().numpy())
        self.log('train_acc', acc, prog_bar=True)

        return loss

    def test_step(self, batch, batch_idx) -> Any:
        x, y = batch
        y_pred = self(x)

        acc = accuracy_score(y.cpu().numpy(), y_pred.argmax(dim=-1).cpu().numpy())
        self.log('test_acc', acc, prog_bar=True)

        return

In [None]:
from sklearn.metrics import accuracy_score

device = torch.device('cuda')

accuracies = []

# Repeat the experiment ten times, each with a freshly initialised model.
for experiment_id in range(10):
    def_fcn = DefFCN(num_classes=num_classes)
    def_fcn_trainer = pl.Trainer(max_epochs=300, accelerator='gpu', devices=-1)
    def_fcn_trainer.fit(def_fcn, train_loader)

    # Evaluate the model that was just trained. The previous code moved `fcn`
    # here, so the baseline network was scored instead of the DefFCN.
    def_fcn = def_fcn.to(device)
    def_fcn.eval()

    with torch.no_grad():
        y_preds = []
        y_s = []

        for x, y in test_loader:
            x = x.to(device)
            y = y.to(device)

            y_pred = def_fcn(x)
            y_pred = y_pred.argmax(dim=-1).cpu().tolist()

            y_preds.extend(y_pred)

            y_s.extend(y.cpu().numpy())

        # accuracy_score expects (y_true, y_pred).
        accuracies.append(accuracy_score(y_s, y_preds))

    # Running mean / standard deviation over the runs finished so far.
    print(np.mean(accuracies))
    print(np.std(accuracies))


In [None]:
# Display the accuracies collected in the `accuracies` list, one entry per experiment run.
accuracies

### More Complex Models

#### Our Method