In [None]:
# Option 1 (simplest): Pass your own basis matrices directly
import numpy as np
import torch

def custom_basis_polysin(tpts):
    """Evaluate a five-column polynomial + sinusoid basis at ``tpts``.

    Columns, in order: 1, t, t**2, sin(2*pi*t), cos(2*pi*t). The sinusoids
    have period 1 in the units of ``tpts``.

    Args:
        tpts: torch tensor of time points; it is flattened, so any shape that
            holds n_time values works.

    Returns:
        float32 torch tensor of shape [n_time, 5], created with torch.tensor
        (move it with ``.to(device)`` as needed).
    """
    grid = tpts.detach().cpu().numpy().ravel()
    angle = 2 * np.pi * grid
    columns = (
        np.ones_like(grid),
        grid,
        grid**2,
        np.sin(angle),
        np.cos(angle),
    )
    basis_matrix = np.stack(columns, axis=1)  # [n_time, n_basis]
    return torch.tensor(basis_matrix, dtype=torch.float32)

#basis_fc_project = custom_basis_polysin(tpts).to(device)  # [n_time, n_basis]
#basis_fc_revert  = custom_basis_polysin(tpts).to(device)

In [None]:
# Option 2: Modify build_basis_fc to accept a custom callable
from skfda import representation

def build_basis_fc(tpts, n_basis=20, basis_type="Bspline", custom_basis_fn=None):
    """
    Evaluate a functional basis at the time points ``tpts``.

    Args:
        tpts: torch tensor of time points. It is flattened before evaluating
            the built-in bases; it is passed unchanged to ``custom_basis_fn``.
        n_basis: number of basis functions for the built-in bases.
            NOTE(review): confirm how your skfda version's Fourier basis
            treats an even ``n_basis``.
        basis_type: "Bspline" (order 4, i.e. cubic) or "Fourier". Ignored
            when ``custom_basis_fn`` is given.
        custom_basis_fn: optional callable ``f(tpts)`` returning an array or
            tensor of shape [n_time, n_basis]. It overrides ``basis_type``.

    Returns:
        float32 torch tensor of shape [n_time, n_basis].

    Raises:
        ValueError: if ``basis_type`` is not "Bspline" or "Fourier", or if
            ``custom_basis_fn`` does not return a 2-D matrix.
    """
    if custom_basis_fn is not None:
        B = custom_basis_fn(tpts)  # expected: [n_time, n_basis]
        if not torch.is_tensor(B):
            B = torch.tensor(B, dtype=torch.float32)
        if B.ndim != 2:
            raise ValueError(
                "custom_basis_fn must return a 2-D [n_time, n_basis] matrix, "
                f"got shape {tuple(B.shape)}"
            )
        return B.float()

    # Otherwise use a built-in basis (Bspline/Fourier).
    t = tpts.flatten().detach().cpu().numpy()
    # Both built-in bases are defined on the observed time range. Without an
    # explicit domain_range, skfda's BSplineBasis uses the unit interval, which
    # places the knots wrongly whenever tpts is not on [0, 1].
    domain = (float(t.min()), float(t.max()))

    if basis_type == "Bspline":
        basis = representation.basis.BSplineBasis(
            domain_range=domain, n_basis=n_basis, order=4
        )
    elif basis_type == "Fourier":
        basis = representation.basis.Fourier(domain, n_basis=n_basis)
    else:
        raise ValueError("basis_type must be 'Bspline' or 'Fourier'")

    # skfda returns [n_basis, n_points, dim_codomain]. Drop the scalar codomain
    # axis, then transpose to the documented [n_time, n_basis] layout (the same
    # layout the custom_basis_fn branch returns).
    eval_ = basis(t, derivative=0)[:, :, 0].T
    return torch.from_numpy(np.ascontiguousarray(eval_)).float()

#basis_fc_project = build_basis_fc(tpts, custom_basis_fn=custom_basis_polysin).to(device)
#basis_fc_revert  = build_basis_fc(tpts, custom_basis_fn=custom_basis_polysin).to(device)

In [None]:
# Option 3: Use a data-driven basis (PCA basis from the curves)
import numpy as np
import torch

def pca_basis_from_data(x, n_basis):
    """
    Build a data-driven basis from the principal directions of the curves in x.

    Args:
        x: torch tensor [n_subject, n_time] (one curve per row, on a shared
            time grid).
        n_basis: number of leading components to keep. It must satisfy
            1 <= n_basis <= min(n_subject, n_time).

    Returns:
        float32 torch tensor [n_time, n_basis]. Its columns are the leading
        right singular vectors of the column-centered data, so they are
        orthonormal. Each column's sign is arbitrary (an SVD property).
        The subtracted column mean is not returned, so ``coef @ B.T``
        reconstructs the centered curves only. Note that centering lowers the
        rank to at most n_subject - 1, so trailing components beyond the rank
        carry no signal.

    Raises:
        ValueError: if x is not 2-D or n_basis is outside the valid range.
    """
    X = x.detach().cpu().numpy()
    if X.ndim != 2:
        raise ValueError(f"x must be 2-D [n_subject, n_time], got shape {X.shape}")

    max_basis = min(X.shape)
    # Guard explicitly: Vt[:n_basis] would silently return fewer columns for a
    # large n_basis, and slice from the end for a negative one.
    if not 1 <= n_basis <= max_basis:
        raise ValueError(
            f"n_basis must be in [1, {max_basis}] for x of shape {X.shape}, got {n_basis}"
        )

    X = X - X.mean(axis=0, keepdims=True)

    # Economy SVD: X = U S Vt, where Vt is [min(n_subject, n_time), n_time].
    # Its rows are ordered by decreasing singular value.
    _, _, Vt = np.linalg.svd(X, full_matrices=False)
    B = Vt[:n_basis].T  # [n_time, n_basis]
    return torch.tensor(B, dtype=torch.float32)

#basis_fc_project = pca_basis_from_data(x, n_basis_project).to(device)
#basis_fc_revert  = basis_fc_project.clone()  # often same basis

Important practical notes:

1) Shapes must match your project() / revert()

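If your code uses the [n_time, n_basis] layout, use it consistently in project(), revert(), and every basis builder. Transpose any basis that a library returns as [n_basis, n_time].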

2) Orthonormal vs non-orthonormal basis

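Reconstruction with coef @ B.T works for any basis. If project() computes the coefficients as x @ B, however, that is the least-squares fit only when B has orthonormal columns. For any other basis, solve the least-squares problem instead (e.g. torch.linalg.lstsq, or use the pseudo-inverse of B).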

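However, the optional “feature loss” MSE(feature, coef) is only meaningful when the basis is orthonormal, as noted in your old code. The PCA basis from Option 3 has orthonormal columns. The polynomial/sinusoid basis from Option 1 and B-spline bases do not. For non-orthonormal bases it is fine to drop this loss.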

3) Want to include a custom smoothing penalty?

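A second-difference penalty on the coefficients assumes that neighboring coefficients belong to basis functions that are neighbors in time, as with B-splines (the P-spline idea). For bases without that ordering (Fourier, polynomial, PCA, or wavelets), the penalty may not be meaningful. You can replace it with: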

- L2 penalty on coefficients

- a roughness penalty built from derivative matrices, e.g. $\lambda\, c^\top R\, c$ with $R_{jk} = \int B_j''(t)\, B_k''(t)\, dt$

- group lasso style penalties, etc.