# In [21]:
import numpy as np

from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader

import pytorch_lightning as pl

# In [None]:
from typing import Any, Dict, List, Optional, Union, Tuple

from pytorch_lightning.utilities.types import STEP_OUTPUT


def same_padding1d(sequence_length: int, kernel_size: int, stride: Optional[int] = 1, dilation: Optional[int] = 1):
    """Return the ``(left, right)`` padding that keeps a 1-D convolution's output length.

    The total padding is split as evenly as possible; when it is odd, the
    extra element goes on the right-hand side.

    Args:
        sequence_length: Length of the input along the convolved axis.
        kernel_size: Size of the convolution kernel.
        stride: Stride of the convolution.
        dilation: Spacing between kernel taps.

    Returns:
        A ``(left, right)`` tuple of padding sizes.
    """
    # Span covered by the dilated kernel.
    receptive_span = (kernel_size - 1) * dilation + 1
    # Input length needed to produce `sequence_length` outputs, minus what we already have.
    total_padding = (sequence_length - 1) * stride + receptive_span - sequence_length
    left = total_padding // 2
    right = total_padding - left
    return left, right


class Pad1d(nn.ConstantPad1d):
    """Constant 1-D padding layer whose fill value defaults to zero.

    Args:
        padding: Padding specification forwarded unchanged to
            ``nn.ConstantPad1d`` (an int or a ``(left, right)`` pair).
        value: Constant used to fill the padded positions.
    """

    def __init__(self, padding: Any, value: Optional[float] = 0.):
        # The parent class has no default for `value`; this subclass only supplies one.
        super().__init__(padding=padding, value=value)


class SameConv(nn.Module):
    """1-D convolution that pads its input at call time.

    The padding is computed per call from the input length with
    ``same_padding1d`` and applied with ``Pad1d`` before an unpadded
    ``nn.Conv1d``. The padding is computed as if the stride were 1, so the
    output length equals the input length when ``stride == 1``.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel size; must support integer arithmetic because it
            is passed to ``same_padding1d``.
        stride: Stride of the underlying convolution.
        dilation: Dilation of the underlying convolution.
        bias: Whether the underlying convolution has a bias term.

    Attributes:
        conv1d_same: The wrapped ``nn.Conv1d``.
        weight: Alias of ``conv1d_same.weight``.
        bias: Alias of ``conv1d_same.bias`` (``None`` when bias is disabled).
        pad: Padding layer class instantiated on every forward call.
        padding: ``(left, right)`` padding used by the most recent forward call.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Optional[Union[Tuple[int], int]] = 3,
        stride: Optional[int] = 1,
        dilation: Optional[int] = 1,
        bias: Optional[bool] = False
    ) -> None:
        super().__init__()
        self.kernel_size, self.stride, self.dilation = kernel_size, stride, dilation
        # Unpadded convolution; the padding is added explicitly in forward().
        self.conv1d_same = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            bias=bias
        )
        self.weight = self.conv1d_same.weight
        # Mirror nn.Conv1d: `bias` always exists and is None when disabled,
        # instead of being defined only when `bias == True`.
        self.bias = self.conv1d_same.bias
        self.pad = Pad1d

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pad the last axis of ``x`` and apply the convolution."""
        # Stride is deliberately left at its default of 1 for the padding calculation.
        self.padding = same_padding1d(x.shape[-1], self.kernel_size, dilation=self.dilation)
        return self.conv1d_same(self.pad(self.padding)(x))


def Conv1d(
    in_channels: int,
    out_channels: int,
    kernel_size: Optional[Union[Tuple[int], int]] = None,
    stride: Optional[int] = 1,
    padding: Optional[Union[str, int]] = 'same',
    dilation: Optional[int] = 1,
    bias: Optional[bool] = False
) -> nn.Module:
    """Build a 1-D convolution layer, with support for 'same' padding.

    With ``padding='same'``:
      * an odd ``kernel_size`` yields a plain ``nn.Conv1d`` with symmetric
        padding ``kernel_size // 2 * dilation``;
      * an even ``kernel_size`` yields a ``SameConv``, which pads
        asymmetrically at call time.
    Any other ``padding`` value is forwarded unchanged to ``nn.Conv1d``.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel size (required). A 1-element tuple is accepted
            and unwrapped in 'same' mode.
        stride: Convolution stride.
        padding: ``'same'`` or a value accepted by ``nn.Conv1d``.
        dilation: Convolution dilation.
        bias: Whether the layer has a bias term.

    Returns:
        The constructed convolution module.

    Raises:
        TypeError: If ``kernel_size`` is not provided.
    """
    if kernel_size is None:
        # Fail early with a clear message instead of an opaque `None % 2` error.
        raise TypeError("Conv1d() requires a `kernel_size` argument")

    if padding != 'same':
        return nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias
        )

    # The parity test and padding arithmetic below need a plain integer.
    if isinstance(kernel_size, (tuple, list)):
        (kernel_size,) = kernel_size

    if kernel_size % 2 == 1:
        # Odd kernels can be centred, so symmetric padding is enough.
        # `kernel_size` must be passed explicitly: nn.Conv1d has no default for it.
        return nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=kernel_size // 2 * dilation,
            dilation=dilation,
            bias=bias
        )

    # Even kernels need one more padded element on one side than the other.
    return SameConv(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
        dilation=dilation,
        bias=bias
    )


class DeepSVDDAutoEncoder(pl.LightningModule):
    """Convolutional autoencoder trained with a squared-error reconstruction loss.

    The encoder attributes (``encoder_conv1`` ... ``encoder_linear``) have the
    same names and shapes as those of ``DeepSVDD`` in this file, presumably so
    that weights trained here can initialise that model.

    Args:
        sequence_length: Length of the last input axis; it fixes the size of
            the linear bottleneck layers.
        in_channels: Number of input channels. The reconstruction has the
            same number of channels.
        representation_dim: Size of the latent representation.
    """

    def __init__(self, sequence_length: int, in_channels: int, representation_dim: int = 32) -> None:
        super().__init__()

        self.sequence_length = sequence_length
        self.in_channels = in_channels
        self.representation_dim = representation_dim

        # --- Encoder --- #
        self.encoder_conv1 = Conv1d(
            in_channels=in_channels,
            out_channels=8,
            kernel_size=5,
            bias=False,
        )
        self.encoder_bn1 = nn.BatchNorm1d(num_features=8, eps=1e-04, affine=False)
        self.encoder_conv2 = Conv1d(
            in_channels=8,
            out_channels=4,
            kernel_size=5,
            bias=False,
        )
        self.encoder_bn2 = nn.BatchNorm1d(num_features=4, eps=1e-04, affine=False)
        # The flattened feature map has 4 channels x sequence_length positions.
        self.encoder_linear = nn.Linear(self.sequence_length * 4, self.representation_dim, bias=False)

        # --- Decoder --- #
        self.decoder_linear = nn.Linear(self.representation_dim, self.sequence_length * 4, bias=False)
        self.decoder_conv1 = Conv1d(
            in_channels=4,
            out_channels=4,
            kernel_size=5,
            bias=False,
        )
        self.decoder_bn1 = nn.BatchNorm1d(num_features=4, eps=1e-04, affine=False)
        self.decoder_conv2 = Conv1d(
            in_channels=4,
            out_channels=8,
            kernel_size=5,
            bias=False,
        )
        self.decoder_bn2 = nn.BatchNorm1d(num_features=8, eps=1e-04, affine=False)
        # Reconstruct every input channel. A fixed single output channel would
        # make `x_hat - x` broadcast silently for multi-channel inputs.
        self.decoder_conv3 = Conv1d(
            in_channels=8,
            out_channels=in_channels,
            kernel_size=5,
            bias=False,
        )

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode ``x`` and reconstruct it.

        Args:
            x: Input of shape ``(batch, in_channels, sequence_length)``.

        Returns:
            ``(x_hat, z)``: the reconstruction and the latent representation
            of shape ``(batch, representation_dim)``.
        """
        z = F.leaky_relu(self.encoder_bn1(self.encoder_conv1(x)))
        z = F.leaky_relu(self.encoder_bn2(self.encoder_conv2(z)))
        z = z.view(z.size(0), -1)
        z = self.encoder_linear(z)  # Final representation output for encoder

        x_hat = self.decoder_linear(z)
        # Undo the flattening done before the encoder's linear layer.
        x_hat = x_hat.view(z.size(0), 4, self.sequence_length)
        x_hat = F.leaky_relu(self.decoder_bn1(self.decoder_conv1(x_hat)))
        x_hat = F.leaky_relu(self.decoder_bn2(self.decoder_conv2(x_hat)))
        x_hat = self.decoder_conv3(x_hat)  # Final reconstruction output for decoder

        return x_hat, z

    @staticmethod
    def _reconstruction_loss(x_hat: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Squared error summed over all non-batch axes, then averaged over the batch."""
        per_sample = torch.sum((x_hat - x) ** 2, dim=tuple(range(1, x_hat.dim())))
        return torch.mean(per_sample)

    def configure_optimizers(self) -> Any:
        """Return Adam plus a step scheduler that scales the LR by 0.1 at milestone 250."""
        # Set optimizer for the autoencoder task
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-4, weight_decay=1e-6, amsgrad=False)
        # Set learning rate scheduler
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[250], gamma=0.1)
        return [optimizer], [scheduler]

    def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
        """Compute and log the reconstruction loss for one training batch.

        The second element of ``batch`` is ignored.
        """
        x, _ = batch
        x_hat, _ = self(x)

        loss = self._reconstruction_loss(x_hat, x)
        self.log('train_loss', loss, prog_bar=True, on_step=False, on_epoch=True)

        return loss

    def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:
        """Log the reconstruction loss for one test batch; nothing is returned."""
        x, _ = batch
        x_hat, _ = self(x)

        loss = self._reconstruction_loss(x_hat, x)
        self.log('test_loss', loss, prog_bar=True, on_step=False, on_epoch=True)


class DeepSVDD(pl.LightningModule):
    """Encoder trained with a soft-boundary hypersphere loss.

    Each input is mapped to a representation ``z``. The loss is
    ``R**2 + mean(max(0, ||z - center||**2 - R**2)) / nu``. The centre must
    be set with ``init_center`` before training. The radius ``R`` starts at 0
    and, once ``current_epoch >= warmup_epochs``, is reset after every
    training step to the ``1 - nu`` quantile of that batch's distances.

    Args:
        sequence_length: Length of the last input axis; it fixes the input
            size of the final linear layer.
        in_channels: Number of input channels.
        representation_dim: Size of the output representation.
        nu: Loss weight and quantile parameter; must lie in ``(0, 1]``.
        warmup_epochs: Number of initial epochs during which ``R`` is not updated.

    Raises:
        ValueError: If ``nu`` is outside ``(0, 1]``.
    """

    def __init__(
        self,
        sequence_length: int,
        in_channels: int,
        representation_dim: int = 32,
        nu: float = 0.1,
        warmup_epochs: int = 10,
    ) -> None:
        super().__init__()

        if not 0 < nu <= 1:
            raise ValueError(f"nu must be in (0, 1], got {nu}")

        self.sequence_length = sequence_length
        self.in_channels = in_channels
        self.representation_dim = representation_dim

        self.nu = nu
        self.warmup_epochs = warmup_epochs

        # Buffers follow the module across devices. They are non-persistent
        # so the state_dict keeps the same keys as before.
        self.register_buffer('R', torch.tensor(0.0), persistent=False)
        # Stays None until init_center() is called or a tensor is assigned.
        self.register_buffer('center', None, persistent=False)

        # --- Encoder --- #
        self.encoder_conv1 = Conv1d(
            in_channels=in_channels,
            out_channels=8,
            kernel_size=5,
            bias=False,
        )
        self.encoder_bn1 = nn.BatchNorm1d(num_features=8, eps=1e-04, affine=False)
        self.encoder_conv2 = Conv1d(
            in_channels=8,
            out_channels=4,
            kernel_size=5,
            bias=False,
        )
        self.encoder_bn2 = nn.BatchNorm1d(num_features=4, eps=1e-04, affine=False)
        # The flattened feature map has 4 channels x sequence_length positions.
        self.encoder_linear = nn.Linear(self.sequence_length * 4, self.representation_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` of shape ``(batch, in_channels, sequence_length)`` to
        a representation of shape ``(batch, representation_dim)``."""
        z = F.leaky_relu(self.encoder_bn1(self.encoder_conv1(x)))
        z = F.leaky_relu(self.encoder_bn2(self.encoder_conv2(z)))
        z = z.view(z.size(0), -1)
        z = self.encoder_linear(z)

        return z

    def init_center(self, loader: DataLoader, eps: Optional[float] = 0.01) -> torch.Tensor:
        """Set the centre to the mean representation over ``loader``.

        The model runs in eval mode without gradients, and its previous
        train/eval mode is restored afterwards. Non-zero coordinates whose
        magnitude is below ``eps`` are pushed to ``-eps`` or ``+eps``.
        The result is stored in ``self.center`` and also returned.

        Args:
            loader: Iterable yielding ``(x, y)`` batches; ``y`` is ignored.
            eps: Minimum magnitude enforced on non-zero centre coordinates.

        Returns:
            The centre, a tensor of shape ``(representation_dim,)``.

        Raises:
            ValueError: If ``loader`` yields no samples.
        """
        n_samples = 0
        center = torch.zeros(self.representation_dim, device=self.device)

        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                for x, _ in loader:
                    z = self(x.to(self.device))

                    n_samples += z.shape[0]
                    center += torch.sum(z, dim=0)
        finally:
            # Do not leave the module in eval mode for later training.
            self.train(was_training)

        if n_samples == 0:
            raise ValueError("init_center() received a loader with no samples")

        center /= n_samples

        # Coordinates moved to -eps no longer match the second mask, so one
        # precomputed "small" mask serves both assignments.
        small = center.abs() < eps
        center[small & (center < 0)] = -eps
        center[small & (center > 0)] = eps

        self.center = center
        return center

    def get_radius(self, distance: torch.Tensor, nu: float):
        """Return the ``1 - nu`` quantile of ``sqrt(distance)`` as a NumPy scalar."""
        return np.quantile(np.sqrt(distance.detach().cpu().numpy()), 1 - nu)

    def configure_optimizers(self) -> Any:
        """Return Adam plus a step scheduler that scales the LR by 0.1 at milestone 150."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-4, weight_decay=1e-6, amsgrad=False)
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[150], gamma=0.1)
        return [optimizer], [scheduler]

    def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
        """Compute and log the soft-boundary loss for one batch, then update ``R``.

        The second element of ``batch`` is ignored.

        Raises:
            RuntimeError: If the centre has not been initialised.
        """
        if self.center is None:
            raise RuntimeError(
                "DeepSVDD.center is not set; call init_center(loader) before training"
            )

        x, _ = batch
        z = self(x)

        distance = torch.sum((z - self.center) ** 2, dim=1)
        scores = distance - self.R ** 2
        loss = self.R ** 2 + (1 / self.nu) * torch.mean(torch.max(torch.zeros_like(scores), scores))

        self.log('train_loss', loss, prog_bar=True, on_step=False, on_epoch=True)

        if self.current_epoch >= self.warmup_epochs:
            # Rebind the buffer in R's own dtype, so the NumPy float64 quantile
            # does not turn R into a double tensor.
            self.R = torch.tensor(
                float(self.get_radius(distance, self.nu)),
                dtype=self.R.dtype,
                device=distance.device,
            )

        return loss

    def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:
        """Placeholder: no test-time computation is performed."""
        return


# Out: torch.Size([2, 3])