In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

# Define the PINN model
class PINN_Y88(nn.Module):
    """Fully connected network for Y(t) with a trainable decay constant.

    The network is a 1 -> 64 -> 64 -> 1 multilayer perceptron with tanh
    activations. Alongside the layer weights, the module owns a scalar
    parameter ``lambda_Y`` (initialised to 0.02) that is learned jointly
    with the network.
    """

    def __init__(self):
        super().__init__()
        width = 64
        layers = [
            nn.Linear(1, width),
            nn.Tanh(),
            nn.Linear(width, width),
            nn.Tanh(),
            nn.Linear(width, 1),
        ]
        self.net = nn.Sequential(*layers)
        # Registered as a Parameter so that optimisers built from
        # ``model.parameters()`` update it together with the weights.
        initial_lambda = torch.tensor(0.02, dtype=torch.float32)
        self.lambda_Y = nn.Parameter(initial_lambda)

    def forward(self, t):
        """Return the network output for ``t`` (last dimension must be 1)."""
        prediction = self.net(t)
        return prediction

# Physics-based loss: dY/dt = -lambda * Y
def physics_loss(model, t):
    """Return the mean squared residual of the ODE dY/dt = -lambda_Y * Y.

    Args:
        model: Callable module mapping ``t`` to ``Y``. It must expose a
            ``lambda_Y`` attribute, which is used as the decay constant
            in the residual.
        t: Tensor of time points at which the residual is evaluated. The
            caller's tensor is not modified.

    Returns:
        Scalar tensor ``mean((dY/dt + lambda_Y * Y) ** 2)``. The derivative
        is taken with ``create_graph=True`` so the result can itself be
        backpropagated through.
    """
    # Differentiate with respect to a private leaf copy rather than setting
    # ``requires_grad`` on the caller's tensor: flipping the flag in place
    # leaks out of this function and makes every later use of that tensor
    # (e.g. the data-loss forward pass) track gradients back to the inputs.
    t = t.detach().clone().requires_grad_(True)
    Y = model(t)
    # ``grad_outputs`` of ones yields the element-wise derivative dY/dt for
    # a network that maps each time point independently.
    dY_dt = torch.autograd.grad(
        outputs=Y,
        inputs=t,
        grad_outputs=torch.ones_like(Y),
        create_graph=True
    )[0]
    lambda_Y = model.lambda_Y
    return torch.mean((dY_dt + lambda_Y * Y)**2)

# Reference decay curve used as training targets: Y(t) = exp(-lambda_true * t)
lambda_true = 0.0065
# 200 evenly spaced time samples on [0, 200], stored as a (200, 1) float column
t_train = torch.linspace(0, 200, 200).unsqueeze(1).float()
Y_true = torch.exp(t_train * -lambda_true)

# Build the network; Adam receives every model parameter, lambda_Y included
model = PINN_Y88()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Minimise the combined data + physics objective
epochs = 5000
for epoch in range(epochs):
    model.train()
    optimizer.zero_grad()

    # Data term: mean squared error against the reference decay curve
    Y_pred = model(t_train)
    data_loss = (Y_pred - Y_true).pow(2).mean()

    # Physics term: mean squared ODE residual at the same time points
    p_loss = physics_loss(model, t_train)

    # Both terms are summed with equal weight and drive one optimiser step
    loss = data_loss + p_loss
    loss.backward()
    optimizer.step()

    # Report progress on every 500th epoch only
    if epoch % 500 != 0:
        continue
    print(f"Epoch {epoch}, Total Loss: {loss.item():.6f}, Lambda_Y: {model.lambda_Y.item():.5f}")

# Sample the trained model on 100 evenly spaced times over [0, 200]
model.eval()
t_test = torch.linspace(0, 200, 100).unsqueeze(1).float()
with torch.no_grad():
    Y_test = model(t_test).cpu().numpy()
# The Sr-88 column is reported as the complement of the predicted Y-88 value
Sr_test = 1 - Y_test

# Tabulate the predictions, one line per test time
print("Predictions (Y-88, Sr-88):")
for i, time in enumerate(t_test.numpy().ravel()):
    print(f"Time: {time:.2f} days, Y-88: {Y_test[i, 0]:.4f}, Sr-88: {Sr_test[i, 0]:.4f}")

# t_np = t_test.numpy().flatten()
# plt.plot(t_np, Y_test[:, 0], label="Y-88 (PINN)")
# plt.plot(t_np, Sr_test[:, 0], label="Sr-88 (PINN)")
# plt.xlabel("Time (days)")
# plt.ylabel("Amount (mol)")
# plt.title("Y-88 Decay and Sr-88 Accumulation (PINN)")
# plt.legend()
# plt.grid(True)
# plt.show()


Epoch 0, Total Loss: 0.767286, Lambda_Y: 0.01900
Epoch 500, Total Loss: 0.000010, Lambda_Y: 0.00606
Epoch 1000, Total Loss: 0.000001, Lambda_Y: 0.00637
Epoch 1500, Total Loss: 0.000000, Lambda_Y: 0.00641
Epoch 2000, Total Loss: 0.000000, Lambda_Y: 0.00642
Epoch 2500, Total Loss: 0.000000, Lambda_Y: 0.00643
Epoch 3000, Total Loss: 0.000000, Lambda_Y: 0.00644
Epoch 3500, Total Loss: 0.000001, Lambda_Y: 0.00644
Epoch 4000, Total Loss: 0.000003, Lambda_Y: 0.00645
Epoch 4500, Total Loss: 0.000000, Lambda_Y: 0.00645
Predictions (Y-88, Sr-88):
Time: 0.00 days, Y-88: 1.0023, Sr-88: -0.0023
Time: 2.02 days, Y-88: 0.9866, Sr-88: 0.0134
Time: 4.04 days, Y-88: 0.9738, Sr-88: 0.0262
Time: 6.06 days, Y-88: 0.9619, Sr-88: 0.0381
Time: 8.08 days, Y-88: 0.9499, Sr-88: 0.0501
Time: 10.10 days, Y-88: 0.9372, Sr-88: 0.0628
Time: 12.12 days, Y-88: 0.9243, Sr-88: 0.0757
Time: 14.14 days, Y-88: 0.9118, Sr-88: 0.0882
Time: 16.16 days, Y-88: 0.8998, Sr-88: 0.1002
Time: 18.18 days, Y-88: 0.8882, Sr-88: 0.1118
T