In [3]:
import torch
from torch import nn
from core.neural_network import FeedForwardNN
from schrodinger_box.time_independent_1d import Schrodinger1DTimeIndependentPINN, LossTISE1D

In [22]:
# Pick the compute device once for the whole notebook: GPU if PyTorch sees one, else CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [23]:
L = 1.0 # Length of the infinite potential well

In [27]:
def train_tise_example(
    L: float = 1.0,
    n_epochs: int = 5000,
    N_samples: int = 256,
    hidden_layers: int = 3,
    width: int = 64,
    lr: float = 1e-3,
    device: "str | torch.device | None" = None,
    E_init: float = 5.0,
    seed: "int | None" = None,
    log_every: int = 500,
):
    """Train a PINN for the 1D time-independent Schrodinger equation on [0, L].

    A tanh ``FeedForwardNN`` (1 input, 1 output) is wrapped in
    ``Schrodinger1DTimeIndependentPINN`` and optimised with Adam on
    ``LossTISE1D``. Each epoch draws a fresh batch of uniformly sampled
    collocation points.

    Args:
        L: Domain length; collocation points are drawn uniformly from [0, L).
        n_epochs: Number of optimisation steps (one batch per step).
        N_samples: Collocation points per batch.
        hidden_layers: Number of hidden layers in the feed-forward network.
        width: Units per hidden layer.
        lr: Adam learning rate.
        device: Device to train on. ``None`` (default) picks ``"cuda"`` when
            available and ``"cpu"`` otherwise. It is resolved at call time
            instead of being frozen from a notebook global when the
            defining cell runs.
        E_init: Initial energy guess passed to the PINN wrapper.
        seed: If given, seeds PyTorch before the model is built, so weight
            initialisation and collocation sampling are reproducible.
        log_every: Print the loss and learned energy every ``log_every``
            epochs (and always at epoch 1). Values <= 0 log only epoch 1.

    Returns:
        The trained ``Schrodinger1DTimeIndependentPINN``. The learned energy
        is available as ``pinn.energy``.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    # Seed before building the model so the initial weights are reproducible too.
    if seed is not None:
        torch.manual_seed(seed)

    # Build the network and wrap it with the Schrodinger PINN (which holds the energy).
    model = FeedForwardNN(
        in_dim=1,
        out_dim=1,
        hidden_layers=hidden_layers,
        width=width,
        activation_func=nn.Tanh,
    ).to(device)
    pinn = Schrodinger1DTimeIndependentPINN(model, L=L, E_init=E_init).to(device)

    total_loss_fn = LossTISE1D()

    # pinn.parameters() covers the network weights and whatever trainable
    # parameters the wrapper registers. The logged energy changes during
    # training, so it appears to be one of them.
    optimizer = torch.optim.Adam(pinn.parameters(), lr=lr)

    for epoch in range(1, n_epochs + 1):
        # Sample collocation points x ∈ [0, L) uniformly, resampled every epoch.
        x_batch = torch.rand(N_samples, 1, device=device) * L

        optimizer.zero_grad()
        loss = total_loss_fn(pinn, x_batch)
        loss.backward()
        optimizer.step()

        # NOTE(review): in the recorded run the loss reaches ~1e-7 while E
        # barely leaves E_init. Confirm that LossTISE1D penalises the trivial
        # psi == 0 solution (e.g. via a normalisation term).
        if epoch == 1 or (log_every > 0 and epoch % log_every == 0):
            with torch.no_grad():
                E_learned = pinn.energy.item()
                print(
                    f"Epoch {epoch:5d} | Loss = {loss.item():.3e} | "
                    f"E_learned = {E_learned:.5f}"
                )

    # The learned energy lives on the returned module (pinn.energy).
    return pinn


In [28]:
# Pass the notebook-level well length explicitly so that editing `L` above
# actually changes the training domain (otherwise the function default is used).
pinn = train_tise_example(L=L, device=device)

Epoch     1 | Loss = 3.146e-02 | E_learned = 5.00100
Epoch   500 | Loss = 1.622e-07 | E_learned = 5.01752
Epoch  1000 | Loss = 1.834e-07 | E_learned = 5.01752
Epoch  1500 | Loss = 9.119e-08 | E_learned = 5.01752
Epoch  2000 | Loss = 7.638e-07 | E_learned = 5.01753
Epoch  2500 | Loss = 2.581e-08 | E_learned = 5.01835
Epoch  3000 | Loss = 1.912e-08 | E_learned = 5.02071
Epoch  3500 | Loss = 4.237e-07 | E_learned = 5.02213
Epoch  4000 | Loss = 7.040e-08 | E_learned = 5.02534
Epoch  4500 | Loss = 4.281e-08 | E_learned = 5.03838
Epoch  5000 | Loss = 1.190e-08 | E_learned = 5.03839
