In [1]:
# ============================================================
# Functional Autoencoder (FAE) + Latent VAR Forecasting
# Australia Fertility Data (1921–2015)
#
# Goal:
#  1) Train a functional autoencoder on curves up to year 2010
#  2) Treat the latent vectors h_t as a multivariate time series
#  3) Fit a VAR model to {h_t}
#  4) Forecast h_2011..h_2015
#  5) Decode forecast latents back to curves
#  6) Plot Real vs Forecast for 2010–2015
#
# Notes:
#  - Uses B-spline basis (or Fourier) evaluated on rescaled age grid.
#  - Projects curves to basis coefficients via weighted least squares (stable for splines).
#  - Trains AE in coefficient space, but reconstruction loss is measured in curve space.
#  - Standardizes curves by age (mean/std across years) for stable training and VAR.
#  - Limits BLAS/MKL threads to avoid kernel crashes on some systems.
# ============================================================

import os
# ---- thread limits to reduce kernel crashes (set BEFORE numpy/torch) ----
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"

import math
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt

from statsmodels.tsa.api import VAR  # pip install statsmodels


# -------------------------
# Reproducibility
# -------------------------
def set_seed(seed=743):
    import random
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


# -------------------------
# Basis construction
# -------------------------
class BasisFCBuilder:
    """
    Build basis matrix evaluated on tpts.
    Output shape: [n_time, n_basis]
    """
    def __init__(self, n_basis=20, basis_type="Bspline", custom_basis_fn=None, bspline_degree=3):
        self.n_basis = n_basis
        self.basis_type = basis_type
        self.basis_type_l = basis_type.lower()
        self.custom_basis_fn = custom_basis_fn
        self.bspline_degree = bspline_degree

    def build(self, tpts: torch.Tensor):
        if self.custom_basis_fn is not None:
            B = self.custom_basis_fn(tpts)
            if not torch.is_tensor(B):
                B = torch.tensor(B, dtype=torch.float32)
            return B.float()

        if self.basis_type_l == "fourier":
            return self._build_fourier(tpts)
        elif self.basis_type_l in ("bspline", "b-spline", "b_spline"):
            return self._build_bspline(tpts, degree=self.bspline_degree)
        else:
            raise ValueError("basis_type must be 'Fourier' or 'Bspline'.")

    def _build_fourier(self, tpts: torch.Tensor):
        t = tpts.flatten()
        t_min, t_max = t.min(), t.max()
        denom = (t_max - t_min).clamp_min(1e-8)
        tau = (t - t_min) / denom  # normalize to [0,1]

        n_time = tpts.shape[0]
        n_basis = self.n_basis
        device = t.device

        B = torch.zeros(n_time, n_basis, dtype=torch.float32, device=device)
        if n_basis > 0:
            B[:, 0] = 1.0

        k = 1
        idx = 1
        while idx < n_basis:
            B[:, idx] = torch.sin(2.0 * math.pi * k * tau)
            idx += 1
            if idx < n_basis:
                B[:, idx] = torch.cos(2.0 * math.pi * k * tau)
                idx += 1
            k += 1
        return B

    def _build_bspline(self, tpts: torch.Tensor, degree: int = 3):
        # Open-uniform B-spline basis on [0,1] built via Cox-de Boor recursion
        t = tpts.flatten()
        t_min, t_max = t.min(), t.max()
        denom = (t_max - t_min).clamp_min(1e-8)
        tau = (t - t_min) / denom
        tau_np = tau.detach().cpu().numpy()

        n_time = tau_np.shape[0]
        n_basis = self.n_basis
        p = degree

        if n_basis < p + 1:
            raise ValueError(f"Bspline: n_basis={n_basis} must be at least degree+1={p+1}.")

        n_int = max(n_basis - p - 1, 0)
        if n_int > 0:
            interior = np.linspace(0.0, 1.0, n_int + 2)[1:-1]
            knots = np.concatenate((np.zeros(p + 1), interior, np.ones(p + 1)))
        else:
            knots = np.concatenate((np.zeros(p + 1), np.ones(p + 1)))

        N = np.zeros((n_basis, n_time), dtype=np.float64)

        # degree 0
        for i in range(n_basis):
            left, right = knots[i], knots[i + 1]
            N[i, :] = np.where((tau_np >= left) & (tau_np < right), 1.0, 0.0)
        N[-1, tau_np == 1.0] = 1.0

        # elevate to degree p
        for k in range(1, p + 1):
            N_next = np.zeros_like(N)
            for i in range(n_basis):
                denom_left = knots[i + k] - knots[i]
                if denom_left > 0:
                    coeff_left = (tau_np - knots[i]) / denom_left
                    N_left = coeff_left * N[i, :]
                else:
                    N_left = 0.0

                denom_right = (knots[i + k + 1] - knots[i + 1]) if (i + 1) < n_basis else 0.0
                if denom_right > 0 and (i + 1) < n_basis:
                    coeff_right = (knots[i + k + 1] - tau_np) / denom_right
                    N_right = coeff_right * N[i + 1, :]
                else:
                    N_right = 0.0

                N_next[i, :] = N_left + N_right
            N = N_next

        return torch.tensor(N.T, dtype=torch.float32, device=tpts.device)  # [n_time, n_basis]


# -------------------------
# Weighted LS projection to basis coefficients
# -------------------------
def trapezoid_weights(tpts: torch.Tensor):
    """Trapezoidal weights W_vec, shape [n_time]."""
    t = tpts.flatten()
    dt = t[1:] - t[:-1]
    zero = torch.zeros(1, device=tpts.device, dtype=tpts.dtype)
    W = 0.5 * torch.cat([zero, dt]) + 0.5 * torch.cat([dt, zero])
    return W

def project_to_basis_coeffs(x: torch.Tensor, B: torch.Tensor, W_vec: torch.Tensor, ridge=1e-6):
    """
    Weighted least squares projection:
      c = (B^T W B + ridge I)^(-1) (B^T W x)
    x: [N, n_time]
    B: [n_time, n_basis]
    returns c: [N, n_basis]
    """
    BW = B * W_vec[:, None]                 # [T, M]
    G = B.T @ BW                            # [M, M]
    M = G.shape[0]
    G = G + ridge * torch.eye(M, device=G.device, dtype=G.dtype)
    Ginv = torch.linalg.pinv(G)             # stable inverse

    xW = x * W_vec[None, :]                 # [N, T]
    rhs = (B.T @ xW.T).T                    # [N, M]
    c = rhs @ Ginv.T                        # [N, M]
    return c

def reconstruct_from_coeffs(c: torch.Tensor, B: torch.Tensor):
    """x_hat = c B^T; c [N,M], B [T,M] -> [N,T]"""
    return c @ B.T


# -------------------------
# Functional autoencoder in coefficient space
# -------------------------
class FAECoef(nn.Module):
    """
    Encoder: coeffs -> latent h (K)
    Decoder: latent h -> coeffs_hat
    Curve reconstruction: x_hat = coeffs_hat @ B^T
    """
    def __init__(self, n_basis: int, n_rep: int, hidden: int = 64, nonlinear: bool = True):
        super().__init__()
        act = nn.Tanh() if nonlinear else nn.Identity()
        self.encoder = nn.Sequential(
            nn.Linear(n_basis, hidden),
            act,
            nn.Linear(hidden, n_rep)
        )
        self.decoder = nn.Sequential(
            nn.Linear(n_rep, hidden),
            act,
            nn.Linear(hidden, n_basis)
        )

    def forward(self, coef):
        h = self.encoder(coef)
        coef_hat = self.decoder(h)
        return coef_hat, h


def curve_smooth_penalty(x_hat):
    """Second finite difference penalty on curves x_hat [N,T]."""
    d2 = x_hat[:, 2:] - 2 * x_hat[:, 1:-1] + x_hat[:, :-2]
    return (d2**2).mean()


# -------------------------
# Train / encode / decode
# -------------------------
def train_fae(
    coef_train: torch.Tensor,
    x_train: torch.Tensor,
    B: torch.Tensor,
    n_rep: int = 7,
    hidden: int = 64,
    nonlinear: bool = True,
    epochs: int = 3000,
    batch_size: int = 16,
    lr: float = 2e-3,
    lamb: float = 1e-4,
    device: str = "cpu",
    log_every: int = 200
):
    """
    Train AE on (coef, x) pairs.
    Loss compares reconstructed curves to the true curves.
    """
    device = torch.device(device)
    coef_train = coef_train.to(device)
    x_train = x_train.to(device)
    B = B.to(device)

    ds = TensorDataset(coef_train, x_train)
    loader = DataLoader(ds, batch_size=batch_size, shuffle=True)

    model = FAECoef(n_basis=coef_train.shape[1], n_rep=n_rep, hidden=hidden, nonlinear=nonlinear).to(device)
    opt = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
    loss_fn = nn.MSELoss()

    history = {"epoch": [], "train_loss": []}

    for ep in range(1, epochs + 1):
        model.train()
        tot = 0.0
        nb = 0

        for coef_b, x_b in loader:
            coef_hat, _ = model(coef_b)
            x_hat = reconstruct_from_coeffs(coef_hat, B)

            loss = loss_fn(x_hat, x_b)
            if lamb > 0:
                loss = loss + lamb * curve_smooth_penalty(x_hat)

            opt.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()

            tot += float(loss.detach().cpu())
            nb += 1

        if ep % log_every == 0:
            avg = tot / max(nb, 1)
            history["epoch"].append(ep)
            history["train_loss"].append(avg)
            print(f"epoch {ep:5d} | train_loss={avg:.6f}")

    return model, history


@torch.no_grad()
def encode_latents(model: nn.Module, coef: torch.Tensor, device="cpu"):
    model.eval()
    coef = coef.to(device)
    _, h = model(coef)
    return h.detach().cpu().numpy()

@torch.no_grad()
def decode_latents_to_curves(model: nn.Module, h_np: np.ndarray, B: torch.Tensor, device="cpu"):
    model.eval()
    h = torch.tensor(h_np, dtype=torch.float32, device=device)
    coef_hat = model.decoder(h)
    x_hat = reconstruct_from_coeffs(coef_hat, B.to(device))
    return x_hat.detach().cpu().numpy()


# -------------------------
# VAR latent forecasting
# -------------------------
def fit_var_forecast(h_train: np.ndarray, steps: int = 5, maxlags: int = 2):
    """
    h_train: [T_train, K]
    returns: h_fore [steps, K], fitted var results
    """
    h_train = np.asarray(h_train)
    if not np.isfinite(h_train).all():
        raise ValueError("Latents contain NaN/Inf. Training or scaling unstable.")

    var_model = VAR(h_train)
    res = var_model.fit(maxlags=maxlags, ic=None)  # keep it stable
    h_fore = res.forecast(h_train[-res.k_ar:], steps=steps)
    return h_fore, res


# -------------------------
# Plotting
# -------------------------
def plot_training_loss(history):
    if len(history["epoch"]) == 0:
        return
    plt.figure()
    plt.plot(history["epoch"], history["train_loss"], marker="o")
    plt.xlabel("Epoch")
    plt.ylabel("Train loss (MSE)")
    plt.title("FAE training loss")
    plt.tight_layout()
    plt.show()

def plot_real_vs_forecast(ages, years, x_real, forecast_years, x_fore, title):
    """
    x_real: [n_year, n_age] in original scale
    x_fore: [len(forecast_years), n_age] in original scale
    Plots 2010-2015 real and 2011-2015 forecast dashed.
    """
    year_to_idx = {y: i for i, y in enumerate(years)}
    plt.figure(figsize=(9, 5))

    # plot real 2010-2015
    for y in [2010, 2011, 2012, 2013, 2014, 2015]:
        if y in year_to_idx:
            plt.plot(ages, x_real[year_to_idx[y]], label=f"Real {y}")

    # plot forecasts
    for i, y in enumerate(forecast_years):
        plt.plot(ages, x_fore[i], "--", label=f"Forecast {y}")

    plt.xlabel("Age")
    plt.ylabel("Fertility rate")
    plt.title(title)
    plt.legend(ncol=2, fontsize=8)
    plt.tight_layout()
    plt.show()


# ============================================================
# MAIN
# ============================================================
if __name__ == "__main__":
    set_seed(743)
    torch.set_num_threads(1)

    # -------- Load data
    df = pd.read_csv("Australiafertility.csv")
    ages = df["age"].to_numpy()

    year_cols = [c for c in df.columns if c != "age"]
    years = np.array([int(c) for c in year_cols])

    rates = df[year_cols].to_numpy()        # [n_age, n_year]
    x_raw = rates.T.astype(float)           # [n_year, n_age]
    # If you have smoothed curves, use:
    # x_raw = smoothed.T.astype(float)

    # -------- Age grid -> [0,1]
    tpts_np = ages.astype(float)
    tpts_rescale = (tpts_np - tpts_np.min()) / (tpts_np.max() - tpts_np.min())
    tpts = torch.tensor(tpts_rescale, dtype=torch.float32)

    # -------- Basis
    n_basis = 50
    basis_type = "Bspline"   # or "Fourier"
    bspline_degree = 3

    basis_builder = BasisFCBuilder(n_basis=n_basis, basis_type=basis_type, bspline_degree=bspline_degree)
    B = basis_builder.build(tpts)           # [n_age, n_basis]
    W_vec = trapezoid_weights(tpts)         # [n_age]

    # -------- Convert to tensor
    x = torch.tensor(x_raw, dtype=torch.float32)  # [n_year, n_age]

    # -------- Standardize by age (mean/std over years)
    x_mean = x.mean(dim=0, keepdim=True)
    x_std = x.std(dim=0, keepdim=True).clamp_min(1e-6)
    x_scaled = (x - x_mean) / x_std

    # -------- Project each scaled curve to basis coefficients
    coef_all = project_to_basis_coeffs(x_scaled, B, W_vec, ridge=1e-6)  # [n_year, n_basis]

    # -------- Train on years <= 2010, forecast 2011-2015
    train_end_year = 2010
    forecast_years = [2011, 2012, 2013, 2014, 2015]

    year_to_idx = {y: i for i, y in enumerate(years)}
    idx_train = [year_to_idx[y] for y in years if y <= train_end_year]
    idx_fore_target = [year_to_idx[y] for y in forecast_years]

    coef_train = coef_all[idx_train]
    x_train = x_scaled[idx_train]  # train targets are curves (scaled)

    # -------- Train FAE
    device = "cpu"
    model, history = train_fae(
        coef_train=coef_train,
        x_train=x_train,
        B=B,
        n_rep=7,
        hidden=64,
        nonlinear=True,
        epochs=3000,
        batch_size=16,
        lr=2e-3,
        lamb=1e-4,
        device=device,
        log_every=200
    )

    plot_training_loss(history)

    # -------- Encode latent series in time order
    h_train = encode_latents(model, coef_train, device=device)  # [T_train, K]

    # -------- VAR forecast latents
    h_fore, var_res = fit_var_forecast(h_train, steps=len(forecast_years), maxlags=2)

    # -------- Decode forecast latents -> forecast scaled curves
    x_fore_scaled = decode_latents_to_curves(model, h_fore, B, device=device)  # [5, n_age]

    # -------- Unscale forecasts back to original scale
    x_fore = x_fore_scaled * x_std.numpy() + x_mean.numpy()  # broadcasts

    # -------- Plot real vs forecast (2010-2015)
    plot_real_vs_forecast(
        ages=ages,
        years=years,
        x_real=x_raw,
        forecast_years=forecast_years,
        x_fore=x_fore,
        title="Australia fertility: Real vs Forecast (FAE latents + VAR)"
    )


epoch   200 | train_loss=0.003987
epoch   400 | train_loss=0.002913
epoch   600 | train_loss=0.001894
epoch   800 | train_loss=0.001759
epoch  1000 | train_loss=0.001421
epoch  1200 | train_loss=0.001896
epoch  1400 | train_loss=0.000920
epoch  1600 | train_loss=0.000975
epoch  1800 | train_loss=0.000774
epoch  2000 | train_loss=0.000663
epoch  2200 | train_loss=0.000556
epoch  2400 | train_loss=0.000353
epoch  2600 | train_loss=0.001921
epoch  2800 | train_loss=0.000392
epoch  3000 | train_loss=0.000286


: 