In [4]:
import warnings
warnings.filterwarnings('ignore')
import numpy as np
import umap
import os
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
from torch import nn
from torch.utils.data import DataLoader

import pytorch_lightning as pl

from deeptime.models.representation import ConvVariationalAutoEncoder
from deeptime.data import BaseDataset
from deeptime.models.utils import Conv1dSamePadding, UpSample

from sktime.datasets import load_UCR_UEA_dataset
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

In [None]:
class ConvVAEOCC(pl.LightningModule):
    
    def __init__(
        self,
        in_channels: int,
        in_features: int,
        latent_dim: int,
        learning_rate: float = 5e-6,
        radius: float = 0.
    ) -> None:
        """Build the convolutional encoder ``self.e`` of the VAE one-class model.

        The encoder is three ``Conv1dSamePadding -> BatchNorm1d -> Tanh``
        stages (128 -> 256 -> 128 channels), a flatten, and an MLP that maps
        to ``2 * latent_dim`` outputs.

        Args:
            in_channels: Number of input channels fed to the first convolution.
            in_features: Sequence length of each input. The first ``Linear``
                layer is sized ``128 * in_features``, which assumes the
                stride-1 "same"-padded convolutions preserve the length.
            latent_dim: Latent size. The encoder emits ``latent_dim * 2``
                values per sample (presumably mean and log-variance
                concatenated, as consumed by ``reparametrize`` — TODO confirm
                against ``forward``).
            learning_rate: Stored on the instance; its consumer (presumably
                the optimizer configuration) is not visible here.
            radius: Stored on the instance; its use is not visible here.
        """
        super().__init__()
        self.in_channels = in_channels
        self.in_features = in_features
        self.latent_dim = latent_dim

        self.learning_rate = learning_rate

        self.radius = radius

        self.e = nn.Sequential(
            Conv1dSamePadding(
                in_channels=in_channels,
                out_channels=128,
                kernel_size=8,
                stride=1,
                bias=False,
            ),
            nn.BatchNorm1d(num_features=128),
            nn.Tanh(),
            Conv1dSamePadding(
                in_channels=128,
                out_channels=256,
                kernel_size=5,
                stride=1,
                bias=False,
            ),
            nn.BatchNorm1d(num_features=256),
            nn.Tanh(),
            Conv1dSamePadding(
                in_channels=256,
                out_channels=128,
                kernel_size=3,
                stride=1,
                bias=False,
            ),
            # Must match the 128 output channels of the conv above; the
            # previous value of 256 made the first forward pass fail.
            nn.BatchNorm1d(num_features=128),
            nn.Tanh(),
            nn.Flatten(),
            nn.Linear(in_features=128 * in_features, out_features=256, bias=False),
            nn.Tanh(),
            nn.Linear(in_features=256, out_features=128, bias=False),
            nn.Tanh(),
            nn.Linear(in_features=128, out_features=latent_dim * 2, bias=False),
            # NOTE(review): this Tanh bounds every output to [-1, 1]. If the
            # second half is a log-variance, the std used by reparametrize is
            # limited to [exp(-0.5), exp(0.5)]. Confirm this is intended.
            nn.Tanh(),
        )
        
    def reparametrize(self, mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
        """Sample from a Gaussian with the reparameterization trick.

        Noise is drawn from a standard normal shaped like the standard
        deviation, so the sample stays differentiable with respect to both
        ``mu`` and ``log_var``.

        Args:
            mu: Mean of the Gaussian; must broadcast with ``log_var``.
            log_var: Log-variance of the Gaussian.

        Returns:
            ``mu + eps * exp(0.5 * log_var)`` with ``eps ~ N(0, 1)``.
        """
        # exp(0.5 * log var) == sqrt(var) == standard deviation
        sigma = (0.5 * log_var).exp()
        return mu + torch.randn_like(sigma) * sigma
    
    def forward(self, 