In [None]:
# Functional Autoencoder in PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F


class FourierBasis(nn.Module):
    """
    Fourier basis sampled on a uniform grid over [0, 1], used to represent
    functional weights as linear combinations of fixed basis functions.

    Row 0 of ``Phi`` is the constant function 1. The following rows alternate
    sin(2*pi*f*t), cos(2*pi*f*t) for f = 1, 2, ... until ``num_basis`` rows
    have been filled (so an even ``num_basis`` ends on a sine).

    Args:
        num_basis: number of basis functions (rows of ``Phi``); must be >= 1.
        T: number of grid points; must be >= 2 so the grid spacing is defined.
        device: device for the buffers; CPU when None.
        dtype: floating-point dtype of the buffers.

    Buffers (non-trainable, follow the module through ``.to(...)``):
        t:   [T]             grid points, ``linspace(0, 1, T)``.
        Phi: [num_basis, T]  basis functions evaluated on the grid.
        dt:  scalar          grid spacing, used as the Riemann-sum weight.

    Raises:
        ValueError: if ``num_basis < 1`` or ``T < 2``.
    """
    def __init__(self, num_basis: int, T: int, device=None, dtype=torch.float32):
        super().__init__()
        if num_basis < 1:
            raise ValueError(f"num_basis must be >= 1, got {num_basis}")
        if T < 2:
            # With a single grid point the spacing would be 0/0 = nan.
            raise ValueError(f"T must be >= 2, got {T}")

        self.num_basis = num_basis
        self.T = T
        device = device or torch.device("cpu")

        # time grid in [0, 1]
        t = torch.linspace(0.0, 1.0, T, dtype=dtype, device=device)  # [T]
        self.register_buffer("t", t)

        # Build basis matrix Phi[b, t] = phi_b(t).
        # Row 0 stays constant 1; odd rows are sines, even rows are cosines,
        # and rows (2f - 1, 2f) share the frequency f.
        Phi = torch.ones(num_basis, T, dtype=dtype, device=device)
        for k in range(1, num_basis):
            freq = (k + 1) // 2
            angle = 2 * torch.pi * freq * t
            Phi[k] = torch.sin(angle) if k % 2 == 1 else torch.cos(angle)
        self.register_buffer("Phi", Phi)

        # Grid spacing for the Riemann sum. The expression is already a 0-dim
        # tensor with the right dtype/device, so it is registered directly
        # rather than re-wrapped with torch.tensor(...), which would emit a
        # copy-construct UserWarning.
        self.register_buffer("dt", (t[-1] - t[0]) / (T - 1))


class FunctionalAutoencoder(nn.Module):
    """
    Functional Autoencoder (FAE) for P-dimensional functional data.

    The encoder's first layer takes discretized L2 inner products between each
    input channel and a learned weight function, sums them over channels and
    applies ``self.activation``. Optional dense Tanh layers and a final linear
    map then produce the latent code. The decoder is a single functional layer
    that rebuilds every channel as a linear combination of learned weight
    functions. All weight functions are parameterized by coefficients over
    ``self.basis``.

    Args:
        P: number of functional channels.
        T: number of time points per channel.
        latent_dim: dimension d of the encoding.
        num_basis: number of basis functions per weight function.
        hidden_dims: optional sequence of encoder widths. The first entry is
            the width H1 of the functional layer; every further entry adds a
            dense Linear + Tanh layer. ``None`` or an empty sequence means
            H1 = latent_dim with no extra dense layers.
        device: device for all parameters and basis buffers; CPU when None.
        dtype: floating-point dtype of all parameters and basis buffers.

    Shapes:
        x:     [batch, P, T]  (P channels, T time points)
        z:     [batch, latent_dim]
        x_hat: [batch, P, T]
    """
    def __init__(
        self,
        P: int,
        T: int,
        latent_dim: int,
        num_basis: int = 16,
        hidden_dims=None,
        device=None,
        dtype=torch.float32,
    ):
        super().__init__()
        device = device or torch.device("cpu")
        self.P = P
        self.T = T
        self.latent_dim = latent_dim

        # Basis to represent functional weights
        self.basis = FourierBasis(num_basis=num_basis, T=T, device=device, dtype=dtype)

        B = num_basis

        # Normalize so that None and an empty sequence are handled alike
        # (indexing an empty sequence would otherwise raise IndexError).
        hidden_dims = list(hidden_dims) if hidden_dims is not None else []

        # --- Functional first layer: HP -> R^H1 ---
        # coefficients for w^{(1)}_{k,j}(t): shape [H1, P, B]
        H1 = hidden_dims[0] if hidden_dims else latent_dim
        self.w1_coeffs = nn.Parameter(
            torch.randn(H1, P, B, dtype=dtype, device=device) * 0.1
        )

        # Optional vector-to-vector hidden layers (between the functional first
        # layer and the latent code). device/dtype are forwarded so the dense
        # layers match the functional layers, e.g. when dtype=torch.float64.
        fc_layers = []
        in_dim = H1
        for h in hidden_dims[1:]:
            fc_layers.append(nn.Linear(in_dim, h, bias=True, device=device, dtype=dtype))
            fc_layers.append(nn.Tanh())
            in_dim = h
        # final linear to latent representation
        fc_layers.append(
            nn.Linear(in_dim, latent_dim, bias=True, device=device, dtype=dtype)
        )
        self.encoder_mlp = nn.Sequential(*fc_layers)

        # --- Functional last layer: R^latent -> HP ---
        # coefficients for w^{(last)}_{j,k}(t): shape [P, latent_dim, B]
        self.w_last_coeffs = nn.Parameter(
            torch.randn(P, latent_dim, B, dtype=dtype, device=device) * 0.1
        )

        self.activation = torch.tanh  # can be ReLU / etc.

    # ----- helper: build functional weights on the time grid -----
    def _weights_from_coeffs(self, coeffs):
        """
        Evaluate weight functions on the time grid from basis coefficients.

        coeffs: [..., B]
        basis.Phi: [B, T]
        returns: [..., T]
        """
        # Einstein summation: (..., B) * (B, T) -> (..., T)
        return torch.einsum("...b,bt->...t", coeffs, self.basis.Phi)

    # ----- encoder functional layer -----
    def _encode_first_layer(self, x):
        """
        Functional input layer.

        x: [batch, P, T]
        returns hidden_1: [batch, H1]

        Raises:
            ValueError: if ``x`` is not of shape [batch, P, T]. Without this
                check a tensor such as [batch, 1, T] would broadcast silently.
        """
        if x.dim() != 3 or x.shape[1] != self.P or x.shape[2] != self.T:
            raise ValueError(
                f"expected x of shape [batch, {self.P}, {self.T}], "
                f"got {tuple(x.shape)}"
            )

        # w1_funcs: [H1, P, T]
        w1_funcs = self._weights_from_coeffs(self.w1_coeffs)

        # h_raw[b, k] = Σ_j <x_j, w_{k,j}>_L2 ≈ Σ_j Σ_t x_j(t) w_{k,j}(t) dt
        # A single contraction over channels and time avoids materializing the
        # [batch, H1, P, T] elementwise product.
        h_raw = torch.einsum("bpt,hpt->bh", x, w1_funcs) * self.basis.dt

        return self.activation(h_raw)

    # ----- decoder functional layer -----
    def _decode_last_layer(self, z):
        """
        Functional output layer.

        z: [batch, latent_dim]
        returns x_hat: [batch, P, T]
        """
        # w_last_funcs: [P, latent_dim, T]
        w_last_funcs = self._weights_from_coeffs(self.w_last_coeffs)

        # x_hat_j(t) = Σ_k z_k * w_{j,k}(t)
        # z: [batch, K], w_last_funcs: [P, K, T] -> [batch, P, T]
        return torch.einsum("bk,pkt->bpt", z, w_last_funcs)

    # ----- full forward pass -----
    def encode(self, x):
        """
        x: [batch, P, T]
        returns z: [batch, latent_dim]
        """
        h1 = self._encode_first_layer(x)
        return self.encoder_mlp(h1)

    def decode(self, z):
        """
        z: [batch, latent_dim]
        returns x_hat: [batch, P, T]
        """
        return self._decode_last_layer(z)

    def forward(self, x):
        """
        x: [batch, P, T]
        returns (x_hat, z): reconstruction [batch, P, T] and code [batch, latent_dim]
        """
        z = self.encode(x)
        x_hat = self.decode(z)
        return x_hat, z


def functional_mse_loss(x, x_hat, dt):
    """
    Discretized functional L2 reconstruction loss.

    For every sample i and channel j the squared error is integrated over
    time with a Riemann sum, and the result is averaged over both the batch
    and the channels:

        loss = 0.5 * mean_{i,j} [ Σ_t (x_ij(t) - x_hat_ij(t))^2 * dt ]

    Args:
        x, x_hat: target and reconstruction, [B, P, T].
        dt: grid spacing used as the quadrature weight.

    Returns:
        Scalar tensor holding the loss.
    """
    residual = x - x_hat                                  # [B, P, T]
    # integrate the squared residual over the time axis
    per_curve_energy = residual.pow(2).sum(dim=-1) * dt   # [B, P]
    # mean over batch and channels, with the conventional factor 1/2
    return per_curve_energy.mean() * 0.5


In [None]:
# Example training loop
import torch
from torch.utils.data import DataLoader, TensorDataset

# ----- configuration -----
SEED = 0            # fixes fake data, weight init and batch order
BATCH_SIZE = 32
LEARNING_RATE = 1e-3
NUM_EPOCHS = 50

# Seed before anything random happens so Restart & Run All is reproducible.
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ----- fake data (replace with real functional data) -----
n_samples = 256
P = 3        # number of functional dimensions
T = 100      # number of time points
# Sampled on CPU and then moved, so the seed alone determines the data.
X_train = torch.randn(n_samples, P, T).to(device)

dataset = TensorDataset(X_train)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# ----- build model -----
latent_dim = 10
num_basis = 16
hidden_dims = [32, 32]   # first is size of functional hidden layer

model = FunctionalAutoencoder(
    P=P,
    T=T,
    latent_dim=latent_dim,
    num_basis=num_basis,
    hidden_dims=hidden_dims,
    device=device,
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Grid spacing of the model's basis, reused as the quadrature weight in the loss.
dt = model.basis.dt

# ----- training -----
model.train()
for epoch in range(NUM_EPOCHS):
    total_loss = 0.0
    for (batch_x,) in loader:
        # No-op for the fake data above; needed once real CPU data is used.
        batch_x = batch_x.to(device)

        x_hat, z = model(batch_x)
        loss = functional_mse_loss(batch_x, x_hat, dt)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight by batch size so a smaller final batch does not skew the mean.
        total_loss += loss.item() * batch_x.size(0)

    avg_loss = total_loss / len(dataset)
    print(f"Epoch {epoch:03d} | loss = {avg_loss:.6f}")

# After training, get embeddings
model.eval()
with torch.no_grad():
    embeddings = model.encode(X_train)    # [n_samples, latent_dim]
