In [None]:
# Imports and reproducibility
from __future__ import annotations

import os
import sys
import random
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader

from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split 

from skfda import representation

def set_seed(seed: int) -> None:
    """
    Seed every global RNG this notebook relies on with the same value.

    Covers Python's ``random`` module, NumPy's legacy global generator and
    PyTorch (the CPU generator plus the generators of all CUDA devices), so
    code drawing from these RNGs repeats for a given ``seed``.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        # torch documents this as safe to call when CUDA is unavailable.
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

In [None]:
# Hyperparameters
@dataclass(frozen=True)
class FAEConfig:
    """
    Immutable hyperparameters for the functional-autoencoder experiments.

    Settings are validated on construction (``__post_init__``) so that a typo
    such as ``"bspline"`` or an out-of-range value fails immediately instead of
    deep inside a training run.
    """

    # Model size and basis settings.
    # NOTE(review): n_rep is presumably the representation (latent) size — confirm.
    n_rep: int = 5
    n_basis_project: int = 20
    n_basis_revert: int = 20
    basis_type_project: str = "Bspline"   # "Bspline" or "Fourier"
    basis_type_revert: str = "Bspline"    # "Bspline" or "Fourier"

    # Penalty settings; lamb presumably weights the penalty term — confirm.
    penalty: Optional[str] = "diff"       # None or "diff"
    lamb: float = 1e-3

    # Optimisation settings.
    epochs: int = 5000
    batch_size: int = 28
    lr: float = 1e-3
    weight_decay: float = 1e-6
    init_weight_sd: Optional[float] = 0.5

    # Experiment protocol. NOTE(review): split_rate is presumably the training
    # fraction for train_test_split and n_iter the number of repeated runs — confirm.
    split_rate: float = 0.8
    n_iter: int = 20
    base_seed: int = 743

    log_every: int = 100
    device: str = "cpu"

    def __post_init__(self) -> None:
        """
        Validate the configuration.

        Raises:
            ValueError: if a basis type or penalty is not one of the supported
                options, a size is below 1, ``split_rate`` is outside (0, 1),
                ``lr`` is not positive, ``lamb``/``weight_decay`` is negative,
                or ``init_weight_sd`` is set but not positive.
        """
        valid_basis = ("Bspline", "Fourier")
        for name in ("basis_type_project", "basis_type_revert"):
            value = getattr(self, name)
            if value not in valid_basis:
                raise ValueError(f"{name} must be one of {valid_basis}, got {value!r}")

        if self.penalty not in (None, "diff"):
            raise ValueError(f"penalty must be None or 'diff', got {self.penalty!r}")

        for name in ("n_rep", "n_basis_project", "n_basis_revert", "batch_size"):
            if getattr(self, name) < 1:
                raise ValueError(f"{name} must be >= 1, got {getattr(self, name)!r}")

        # A split of exactly 0 or 1 would leave one side of the split empty.
        if not 0.0 < self.split_rate < 1.0:
            raise ValueError(f"split_rate must be in (0, 1), got {self.split_rate!r}")

        if not self.lr > 0:
            raise ValueError(f"lr must be > 0, got {self.lr!r}")

        for name in ("lamb", "weight_decay"):
            if getattr(self, name) < 0:
                raise ValueError(f"{name} must be >= 0, got {getattr(self, name)!r}")

        if self.init_weight_sd is not None and not self.init_weight_sd > 0:
            raise ValueError(
                f"init_weight_sd must be None or > 0, got {self.init_weight_sd!r}"
            )

In [None]:
# Basis construction
def build_basis_fc(
    tpts: torch.Tensor,
    n_basis: int,
    basis_type: str
) -> torch.Tensor:
    """
    Evaluate an skfda basis on the time grid ``tpts`` and return it as a float32 tensor.

    Both basis families are built on the observed range ``[min(tpts), max(tpts)]``,
    so B-spline knots and the Fourier period cover the same interval.

    Args:
        tpts: Evaluation time points. Any device / grad state is accepted; the
            values are copied to a CPU NumPy array for skfda. Must contain at
            least two distinct values.
        n_basis: Number of basis functions requested.
        basis_type: ``"Bspline"`` (order 4) or ``"Fourier"``.

    Returns:
        The ``[:, :, 0]`` slice of skfda's ``evaluate`` output, as float32.
        NOTE(review): this was previously documented as [n_time, n_basis], but
        skfda documents the rows of ``evaluate`` as the basis functions, which
        would make it [n_basis, n_time]. The orientation is unchanged here —
        confirm against the callers before relying on either description.

    Raises:
        ValueError: for an unknown ``basis_type``, an empty or degenerate time
            grid, or when skfda builds a different number of basis functions
            than ``n_basis``.
    """
    # skfda works on NumPy arrays; detach/cpu makes this safe for CUDA or
    # grad-tracking tensors, which cannot be converted implicitly.
    tpts_np = tpts.detach().cpu().numpy()
    if tpts_np.size == 0:
        raise ValueError("tpts must contain at least one time point")

    t_min = float(tpts_np.min())
    t_max = float(tpts_np.max())
    if not t_max > t_min:
        raise ValueError(
            f"tpts must span a non-degenerate interval, got [{t_min}, {t_max}]"
        )

    if basis_type == "Bspline":
        # Pass the domain explicitly: without it skfda's BSpline defaults to
        # [0, 1], which disagrees with the Fourier branch whenever tpts does not
        # span exactly [0, 1].
        basis = representation.basis.BSpline(
            domain_range=[t_min, t_max], n_basis=n_basis, order=4
        )
    elif basis_type == "Fourier":
        basis = representation.basis.Fourier([t_min, t_max], n_basis=n_basis)
    else:
        raise ValueError(f"Unknown basis_type: {basis_type}")

    # Guard the size contract: downstream layers are sized from n_basis, so a
    # silently adjusted basis count would surface later as a shape mismatch.
    if basis.n_basis != n_basis:
        raise ValueError(
            f"skfda built {basis.n_basis} {basis_type} basis functions but "
            f"{n_basis} were requested (skfda may adjust n_basis for some "
            f"basis types, e.g. to an odd value for Fourier)"
        )

    # The trailing axis of skfda's evaluation output is dropped via [:, :, 0].
    evaluated = basis.evaluate(tpts_np, derivative=0)[:, :, 0]
    return torch.from_numpy(evaluated).float()