In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import datetime
from colorama import just_fix_windows_console, init, Fore, Style
from typing import List, Callable
from prettytable import PrettyTable, ALL


# ------------------------------
# DEVICE (from device.py)
# ------------------------------
def get_device():
    """Select CUDA when it is available (CPU otherwise), print the choice and return it."""
    backend = "cpu"
    if torch.cuda.is_available():
        backend = "cuda"
    chosen = torch.device(backend)
    print(f"Using device: {chosen}")
    return chosen


# ------------------------------
# MODELS (from models.py)
# ------------------------------
class PINN(nn.Module):
    """Single-hidden-layer tanh network mapping time t to a state vector ŷ(t).

    The output is multiplied by (1 - exp(-t)), so ŷ(0) = 0 holds exactly.

    Args:
        out_vars: number of output components of ŷ(t).
        hidden: width of the hidden layer (default 100, the original size).
    """

    def __init__(self, out_vars=5, hidden=100):
        super().__init__()
        # Layer names fc1/fc2 are kept so saved state_dicts stay loadable.
        self.fc1 = nn.Linear(1, hidden)
        self.fc2 = nn.Linear(hidden, out_vars)
        self.activation = nn.Tanh()
        self.out_vars = out_vars

    def forward(self, t):
        """
        Forward pass for a scalar (or batch) time input.
        We assume t is a tensor of shape (N, 1) or a scalar tensor.
        A scalar t yields an output of shape (out_vars,); an (N, 1) batch
        yields (N, out_vars).
        The network output is modulated as:
            ŷ(t) = (1 - exp(-t)) * NN(t)
        to enforce ŷ(0) = 0.
        """
        if t.dim() == 0:
            # nn.Linear(1, ...) needs a trailing feature dimension of size 1.
            t = t.unsqueeze(0)
        x = self.activation(self.fc1(t))
        out = self.fc2(x)
        return (1 - torch.exp(-t)) * out


def create_loss(
    model: "PINN", ts: torch.Tensor, phi: Callable[[torch.Tensor], torch.Tensor]
):
    """Build a closure returning the mean squared ODE residual of ``model``.

    For each collocation time t_i the residual is dŷ/dt(t_i) - phi(ŷ(t_i)).
    The closure returns mean_i ||residual_i||^2 as a scalar tensor that can be
    back-propagated to the model parameters.

    Args:
        model: network accepting an (N, 1) batch or a scalar time tensor.
        ts: collocation times of shape (N,); an (N, 1) tensor is flattened.
        phi: right-hand side of dy/dt = phi(y), applied per sample via torch.vmap.

    Returns:
        A zero-argument callable computing the loss.
    """
    # Collocation points are constants. torch.func.jacrev differentiates with
    # respect to its input functionally, so no requires_grad is needed here.
    ts_flat = ts.detach().reshape(-1)
    ts_col = ts_flat.unsqueeze(1)

    def model_single(t):
        # Under vmap t is 0-d. Feed it as shape (1,) and return the (out_vars,)
        # output as-is. Squeezing here would drop the output axis when
        # out_vars == 1, and the residual would then mis-broadcast to (N, N).
        return model(t.reshape(1))

    dmodel_dt = torch.vmap(torch.func.jacrev(model_single))

    def compute_loss_vectorized():
        y_hat = model(ts_col)  # (N, out_vars)
        dy_dt = dmodel_dt(ts_flat)  # (N, out_vars)
        phi_y = torch.vmap(phi)(y_hat)
        residuals = dy_dt - phi_y
        return torch.mean(torch.sum(residuals**2, dim=1))

    return compute_loss_vectorized


def save_model(model: PINN, name: str = "pinn_model"):
    """Save ``model.state_dict()`` as ``<name>_<YYYY_MM_DD>_out_vars_<k>.pt``.

    The path is relative, so the file lands in the current working directory.
    The path used is printed.
    """
    stamp = datetime.date.today().strftime("%Y_%m_%d")
    target = f"{name}_{stamp}_out_vars_{model.out_vars}.pt"
    weights = model.state_dict()
    torch.save(weights, target)
    print(f"Model saved to {target}")


def train_model(
    phi, ts, device, lr=0.001, epochs=1000, out_vars=5, *, tol=1e-4, log_every=10
):
    """Train a fresh PINN with Adam on the ODE residual loss and return it.

    Args:
        phi: right-hand side of dy/dt = phi(y), passed to ``create_loss``.
        ts: 1-D tensor of collocation times; moved to ``device`` before use.
        device: device on which the model is created.
        lr: Adam learning rate.
        epochs: maximum number of optimisation steps.
        out_vars: output dimension of the PINN.
        tol: stop early once the loss falls below this value (default 1e-4).
        log_every: print the loss every ``log_every`` epochs (default 10;
            values <= 0 disable periodic logging).

    Returns:
        The trained PINN.
    """
    just_fix_windows_console()
    init()
    model = PINN(out_vars=out_vars).to(device)
    # Keep the collocation points on the model's device. A CPU ts with a CUDA
    # model would otherwise fail inside the loss.
    compute_loss_vectorized = create_loss(model, ts.to(device), phi)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    print(Fore.RED + "Starting training..." + Style.RESET_ALL)
    for epoch in range(1, epochs + 1):
        optimizer.zero_grad()
        loss_val = compute_loss_vectorized()
        loss_val.backward()
        optimizer.step()
        # One host/device sync per epoch instead of up to three .item() calls.
        loss_item = loss_val.item()
        if log_every > 0 and epoch % log_every == 0:
            print(
                Fore.MAGENTA
                + f"Epoch {epoch}/{epochs}: Loss = {loss_item:.6f}"
                + Style.RESET_ALL
            )
        if loss_item < tol:
            print(
                Fore.GREEN
                + f"Stopping training at epoch {epoch} with loss = {loss_item:.6f}"
                + Style.RESET_ALL
            )
            break
    print(Fore.BLUE + "Training complete.\n" + Style.RESET_ALL)
    return model


def load_model(model, device, filename):
    """Load a state dict (e.g. one written by ``save_model``) from ``filename`` into ``model``.

    Tensors are mapped onto ``device`` while loading. ``weights_only=True``
    limits unpickling to tensors and plain containers. That is all a state
    dict needs, and it avoids executing arbitrary pickled code from the file.

    Returns:
        ``model`` itself, updated in place.
    """
    state = torch.load(filename, map_location=device, weights_only=True)
    model.load_state_dict(state)
    return model


def test_model(model: PINN, test_times: List[float], device=torch.device("cpu")):
    """Print a table of model predictions ŷ(t) at each time in ``test_times``.

    Args:
        model: trained network to evaluate.
        test_times: times at which to evaluate (formatted with ``:6.2f``).
        device: device on which the input time tensors are created.
    """
    table = PrettyTable()
    table.field_names = ["t", "ŷ(t)"]
    table.hrules = ALL
    table.align["t"] = "l"
    table.align["ŷ(t)"] = "l"
    # Pure inference. No autograd graph is needed to print predictions.
    with torch.no_grad():
        for t in test_times:
            t_tensor = torch.tensor(t, dtype=torch.float32, device=device)
            y_pred = model(t_tensor)
            y_pred_str = str(y_pred.cpu().numpy().flatten())
            table.add_row([f"{t:6.2f}", y_pred_str])
    print(Fore.MAGENTA + table.get_string() + Style.RESET_ALL)


# ------------------------------
# ODE (from ode.py)
# ------------------------------
def createPhi(D, A, b):
    """Build the ODE right-hand side phi(y) for y = [x; u].

    With A of shape (r, n), x = y[:n] and u = y[n:]:
        m     = max(0, u + A x - b)
        dx/dt = -(D + A^T m)
        du/dt = m - u
    phi returns the concatenation [dx/dt; du/dt] as a 1-D tensor of length n + r.

    D may be given as (n,) or (n, 1), and b as (r,), (r, 1) or (1, 1). Both are
    flattened once, so a single phi handles every r. Previously the shape
    handling was keyed on r: r == 1 with a 1-D b crashed, and r > 1 with a
    column b silently broadcast to an (r, r) matrix.
    """
    r, n = A.shape
    D_vec = D.reshape(n)
    b_vec = b.reshape(r)

    def phi(y):
        x = y[:n]
        u = y[n:]
        m = torch.clamp(u + A @ x - b_vec, min=0.0)
        top = -(D_vec + A.t() @ m)
        bottom = m - u
        return torch.cat((top, bottom), 0)

    return phi


# ------------------------------
# TRAIN (from train.py)
# ------------------------------
# Module-level device, resolved once. The example_* functions below read this global.
device = get_device()


def _run_example(tag, D, A, b, ts, epochs, lr=0.001, test_times=(0.0, 2.5, 5.0, 7.5, 10.0)):
    """Shared pipeline for the examples: build phi, train, print predictions, save."""
    phi = createPhi(D, A, b)
    # One output per column of A (n) plus one per row of A (r).
    out_vars = sum(A.shape)
    model = train_model(phi, ts, device, lr=lr, epochs=epochs, out_vars=out_vars)
    test_model(model, list(test_times), device=device)
    save_model(model, f"{tag}_pinn_{epochs}")


def example_1(ts: torch.Tensor, epochs: int = 1000):
    """Example 1: A is 1 x 4 (r = 1, n = 4)."""
    D = torch.tensor([-9.54, -8.16, -4.26, -11.43], dtype=torch.float32, device=device)
    A = torch.tensor([[3.18, 2.72, 1.42, 3.81]], dtype=torch.float32, device=device)
    b = torch.tensor([[7.81]], dtype=torch.float32, device=device)
    _run_example("example_1", D, A, b, ts, epochs)


def example_2(ts: torch.Tensor, epochs: int = 1000):
    """Example 2: A is 6 x 3 (r = 6, n = 3)."""
    D = torch.tensor([-3, -1, -3], dtype=torch.float32, device=device)
    A = torch.tensor(
        [[2, 1, 1], [1, 2, 3], [2, 2, 1], [-1, 0, 0], [0, -1, 0], [0, 0, -1]],
        dtype=torch.float32,
        device=device,
    )
    b = torch.tensor([2, 5, 6, 0, 0, 0], dtype=torch.float32, device=device)
    _run_example("example_2", D, A, b, ts, epochs)


def example_3(ts: torch.Tensor, epochs: int = 1000):
    """Example 3: A is 5 x 3 (r = 5, n = 3)."""
    D = torch.tensor([-1.0, -4.0, -3.0], dtype=torch.float32, device=device)
    A = torch.tensor(
        [
            [2.0, 2.0, 1.0],
            [1.0, 2.0, 2.0],
            [-1.0, 0.0, 0.0],
            [0.0, -1.0, 0.0],
            [0.0, 0.0, -1.0],
        ],
        dtype=torch.float32,
        device=device,
    )
    b = torch.tensor([4.0, 6.0, 0.0, 0.0, 0.0], dtype=torch.float32, device=device)
    _run_example("example_3", D, A, b, ts, epochs)


# ------------------------------
# MAIN: Set up collocation points and run examples
# ------------------------------
if __name__ == "__main__":
    # 128 evenly spaced collocation times on [0, 10], on the module-level device.
    ts = torch.linspace(0, 10, 128, dtype=torch.float32, device=device)
    example_1(ts, epochs=1000)
    example_2(ts, epochs=1000)
    example_3(ts, epochs=1000)

Using device: cuda
Starting training...
Epoch 10/1000: Loss = 302.110291
Epoch 20/1000: Loss = 147.956573
Epoch 30/1000: Loss = 101.997650
Epoch 40/1000: Loss = 81.422707
Epoch 50/1000: Loss = 70.404388
Epoch 60/1000: Loss = 60.871868
Epoch 70/1000: Loss = 56.565331
Epoch 80/1000: Loss = 54.044228
Epoch 90/1000: Loss = 51.575279
Epoch 100/1000: Loss = 48.887615
Epoch 110/1000: Loss = 46.421116
Epoch 120/1000: Loss = 43.992611
Epoch 130/1000: Loss = 41.788727
Epoch 140/1000: Loss = 39.550995
Epoch 150/1000: Loss = 37.632568
Epoch 160/1000: Loss = 35.729435
Epoch 170/1000: Loss = 34.152168
Epoch 180/1000: Loss = 32.572163
Epoch 190/1000: Loss = 31.232347
Epoch 200/1000: Loss = 30.100338
Epoch 210/1000: Loss = 28.934067
Epoch 220/1000: Loss = 27.976254
Epoch 230/1000: Loss = 27.201008
Epoch 240/1000: Loss = 26.356358
Epoch 250/1000: Loss = 25.568775
Epoch 260/1000: Loss = 24.904524
Epoch 270/1000: Loss = 24.343601
Epoch 280/1000: Loss = 23.825855
Epoch 290/1000: Loss = 23.169077
Epoch 300