In [3]:
import torch
from typing import List
from colorama import Fore, Style
from prettytable import PrettyTable, ALL


# ------------------------------
# DEVICE (inlined from device.py)
# ------------------------------
def get_device():
    """Return the torch device to run on and print which one was chosen.

    CUDA is selected whenever ``torch.cuda.is_available()`` reports a usable
    GPU; otherwise the CPU device is returned.
    """
    if torch.cuda.is_available():
        chosen = torch.device("cuda")
    else:
        chosen = torch.device("cpu")
    print(f"Using device: {chosen}")
    return chosen


# ------------------------------
# MODEL and HELPER FUNCTIONS (inlined from models.py)
# ------------------------------
class PINN(torch.nn.Module):
    """One-hidden-layer tanh MLP mapping a time value to ``out_vars`` outputs.

    The raw network output is scaled by the envelope ``1 - exp(-t)``, so the
    prediction is exactly zero at ``t = 0``.
    """

    def __init__(self, out_vars=5):
        super().__init__()
        # NOTE: the attribute names fc1 / fc2 become the state_dict keys, and
        # their creation order fixes the parameter-initialisation RNG order;
        # keep both stable so saved checkpoints continue to load.
        self.fc1 = torch.nn.Linear(1, 100)
        self.fc2 = torch.nn.Linear(100, out_vars)
        self.activation = torch.nn.Tanh()
        self.out_vars = out_vars

    def forward(self, t):
        """Evaluate the network at time(s) ``t``.

        Accepts a 0-d (scalar) tensor, which yields an output of shape
        ``(out_vars,)``, or a tensor whose last dimension is 1, such as
        ``(N, 1)``, which yields ``(N, out_vars)``.

        The result is ``(1 - exp(-t)) * NN(t)``, which forces ``ŷ(0) = 0``.
        """
        # Linear(1, ...) needs a trailing feature axis; give scalars one.
        if t.ndim == 0:
            t = t.reshape(1)
        hidden = self.activation(self.fc1(t))
        raw = self.fc2(hidden)
        envelope = 1 - torch.exp(-t)
        return envelope * raw


def load_model(model, device, filename):
    """Load a saved state_dict from ``filename`` into ``model`` and return it.

    Args:
        model: Module whose architecture matches the saved checkpoint.
        device: Device the stored tensors are mapped onto while loading.
        filename: Path to a file written by ``torch.save(model.state_dict(), ...)``.

    Returns:
        The same ``model`` object, updated in place.

    Raises:
        RuntimeError: from ``load_state_dict`` when the keys or tensor shapes
            do not match the model (strict loading).
    """
    # Only a state_dict (tensors in plain containers) is expected here, so
    # restrict unpickling to weights. This refuses arbitrary pickled objects
    # that could execute code on load.
    state_dict = torch.load(filename, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    return model


def test_model(model: PINN, test_times: List[float], device=torch.device("cpu")):
    """Evaluate ``model`` at each time in ``test_times`` and print a table.

    Each row shows the time (formatted ``6.2f``) and the flattened model
    output. The whole table is printed in magenta via colorama.

    Args:
        model: Network to evaluate.
        test_times: Time values to evaluate at (ints are accepted too).
        device: Device the input tensors are created on; should match the
            model's device.
    """
    table = PrettyTable()
    table.field_names = ["t", "ŷ(t)"]
    table.hrules = ALL  # horizontal line between rows
    table.align["t"] = "l"
    table.align["ŷ(t)"] = "l"
    # Pure inference: no gradients are needed, so avoid building an autograd
    # graph for every evaluation (the input no longer requires grad either).
    with torch.no_grad():
        for t in test_times:
            t_tensor = torch.tensor(t, dtype=torch.float32, device=device)
            y_pred = model(t_tensor)
            # Flatten the tensor and convert to string for display.
            y_pred_str = str(y_pred.detach().cpu().numpy().flatten())
            table.add_row([f"{t:6.2f}", y_pred_str])
    print(Fore.MAGENTA + table.get_string() + Style.RESET_ALL)


# ------------------------------
# MAIN: Reload and test the model
# ------------------------------
# Pick CUDA if present, otherwise fall back to the CPU.
dev = get_device()

# The architecture must match the checkpoint for load_state_dict to succeed;
# the saved model has 8 output variables.
model = PINN(out_vars=8)
model = model.to(dev)

# Checkpoint path, relative to the notebook's working directory.
filename = "../pinn_1000_2025_02_20.pt"

# load_model fills `model` in place and hands the same object back.
model_reloaded = load_model(model, dev, filename)

# Evaluate at a few sample times and print the resulting table.
test_times = [0.0, 2.5, 5.0, 7.5, 10.0, 10.2]
test_model(model_reloaded, test_times, device=dev)

Using device: cuda
[35m+--------+-------------------------------------------------------------------------+
| t      | ŷ(t)                                                                    |
+--------+-------------------------------------------------------------------------+
|   0.00 | [ 0.  0.  0.  0. -0.  0.  0. -0.]                                       |
+--------+-------------------------------------------------------------------------+
|   2.50 | [-0.51622915  1.6665292   1.7863696   0.6442201   0.75406796  1.0715411 |
|        |  -0.00254623 -0.00289825]                                               |
+--------+-------------------------------------------------------------------------+
|   5.00 | [-0.3519519   1.5629302   1.6229733   0.9386521   1.0295663   1.7156954 |
|        |  -0.00424191  0.00363769]                                               |
+--------+-------------------------------------------------------------------------+
|   7.50 | [-0.21428967  1.4687133   1.5