In [6]:
import torch
from torch import nn
import numpy as np
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

from models import DeepONet, BelNet, StochasticFeatures
from utils import get_data, train_model, validate_model

# Prepare data for DeepONet.
# NumPy is seeded right before get_data (fixed seed) so the 1000 training /
# 200 test samples it returns are reproducible. Every returned array is then
# turned into a float32 tensor on `device` with a trailing singleton axis
# appended via unsqueeze(-1).
np.random.seed(42)  # fixed seed
train_data, test_data = get_data(1000, 200)
train_data, test_data = (
    [torch.from_numpy(arr).float().to(device).unsqueeze(-1) for arr in split]
    for split in (train_data, test_data)
)
x_train, a_train, y_train, u_train = train_data
x_test, a_test, y_test, u_test = test_data
print("train shape:", *(t.shape for t in train_data))
print("test shape:", *(t.shape for t in test_data))
    

SEED = 42  # seed applied at the start of every example
NUM_EPOCHS = 500000


def count_parameters(model):
    """Return the number of trainable parameters (``requires_grad``) in *model*."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def make_trig_basis(low, high):
    """Build a ``StochasticFeatures`` basis whose frequencies are pinned to a fixed grid.

    The ``prior`` callback overwrites the feature weight tensor in place with
    the ``high - low + 1`` evenly spaced values ``2*pi*low, ..., 2*pi*high``
    (i.e. ``2*pi*k`` for every integer ``k`` in ``[low, high]``), reshaped
    to ``(num, 1)``.

    Args:
        low: Smallest integer frequency multiplier.
        high: Largest integer frequency multiplier (inclusive).

    Returns:
        A ``StochasticFeatures`` module built with ``kernel="trig symm"``,
        ``x_dim=1`` and ``num_features = high - low + 1``.
    """
    num = high - low + 1

    def prior(w):
        # NOTE(review): assumes StochasticFeatures passes a weight holding
        # num * 1 elements (num_features x x_dim); view(num, 1) relies on it.
        with torch.no_grad():  # in-place init must not be recorded by autograd
            w.copy_(torch.linspace(low * 2 * torch.pi, high * 2 * torch.pi, num).view(num, 1))

    return StochasticFeatures(kernel="trig symm", x_dim=1, num_features=num, prior=prior)


# Examples compared in this cell:
#   0: DeepONet
#   1: BelNet with 50 learned basis functions and no fixed basis
#   2: BelNet with 50 learned + 2*num fixed trigonometric basis functions
#   3: BelNet with only the 2*num fixed trigonometric basis functions
for example_num in range(4):
    # Reseed at the start of every example so each model's initialisation and
    # training are reproducible on their own, instead of depending on how many
    # random numbers the previously trained examples consumed.
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    print(f"Using examples {example_num}")
    if example_num == 0:
        model = DeepONet()
    elif example_num == 1:
        model = BelNet(num_learned_basis=50, num_fixed_basis=0)
    elif example_num in (2, 3):
        low = 1
        high = 30
        num = high - low + 1
        basis = make_trig_basis(low, high)
        # presumably "trig symm" yields one sin and one cos feature per
        # frequency, hence 2 * num fixed basis functions -- TODO confirm
        model = BelNet(
            num_learned_basis=50 if example_num == 2 else 0,
            num_fixed_basis=num * 2,
            basis=basis,
        )
    else:
        # Guard against silently reusing the previous iteration's model.
        raise ValueError(f"no model configured for example {example_num}")

    model = model.to(device)
    print(f"Model has {count_parameters(model)} parameters")
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # Train the model
    train_losses = train_model(model, train_data, test_data, criterion, optimizer, num_epochs=NUM_EPOCHS, device=device)
    print(f"Example {example_num} finished training, final error is: ")
    validate_model(model, test_data, criterion, device=device)

    # save model
    torch.save(model.state_dict(), f'model_save_{example_num}.pth')


train shape: torch.Size([1000, 128, 1]) torch.Size([1000, 128, 1]) torch.Size([1000, 128, 1]) torch.Size([1000, 128, 1])
test shape: torch.Size([200, 128, 1]) torch.Size([200, 128, 1]) torch.Size([200, 128, 1]) torch.Size([200, 128, 1])
Using examples 0
Model has 99712 parameters
Epoch [1/500000], Loss: 0.3495
Validation Loss: 0.16244, mean Relative L2 Loss: 0.68947
Epoch [5001/500000], Loss: 0.0066
Validation Loss: 0.00720, mean Relative L2 Loss: 0.12422
Epoch [10001/500000], Loss: 0.0112
Validation Loss: 0.01247, mean Relative L2 Loss: 0.17381
Epoch [15001/500000], Loss: 0.0031
Validation Loss: 0.00358, mean Relative L2 Loss: 0.07756
Epoch [20001/500000], Loss: 0.0031
Validation Loss: 0.00366, mean Relative L2 Loss: 0.07786
Epoch [25001/500000], Loss: 0.0022
Validation Loss: 0.00260, mean Relative L2 Loss: 0.06122
Epoch [30001/500000], Loss: 0.0024
Validation Loss: 0.00288, mean Relative L2 Loss: 0.06685
Epoch [35001/500000], Loss: 0.0107
Validation Loss: 0.01138, mean Relative L2 Lo