In [1]:
import numpy as np
from scipy import io
import torch
import torch
import torch.nn as nn

# Utils

In [5]:
def get_data(ntrain, ntest, data_path="../data/dataGRFshorttspan.mat", num_query=1000):
    """Build the train/test tuples ``(x, a, y, u)`` from stored function samples.

    The ``"output"`` array of the MATLAB file at ``data_path`` is read with one
    function per row. Each row must hold exactly ``2**13`` samples: the
    interpolation grid below has ``2**13 + 1`` points and one wrap-around
    sample is appended to every row.

    Args:
        ntrain: number of leading rows used for training.
        ntest: number of following rows used for testing.
        data_path: location of the ``.mat`` file (defaults to the original
            hard-coded path).
        num_query: number of random query points drawn per function.

    Returns:
        ``(train_data, test_data)``; each is a tuple ``(x, a, y, u)`` with
        x: ``(n, 128, 1)`` sensor locations ``index / 2**13``,
        a: ``(n, 128, 1)`` float32 function values at those sensors,
        y: ``(n, num_query, 2)`` query points drawn uniformly from ``[0, 1)``,
        u: ``(n, num_query, 1)`` the row's function linearly interpolated at
           ``(y[..., 0] - y[..., 1]) mod 1``.

    Raises:
        IndexError: if the file holds fewer than ``ntrain + ntest`` rows.

    Note:
        Reseeds NumPy's global RNG with 42 so the query points are reproducible.
    """
    from scipy.interpolate import interp1d

    np.random.seed(42) # fixed seed
    N = ntrain + ntest
    data = io.loadmat(data_path)
    output = data["output"]
    if output.shape[0] < N:
        raise IndexError(
            f"requested {N} samples but {data_path!r} only holds {output.shape[0]}"
        )

    grid_size = 2**13      # samples per function on the periodic grid
    sensor_stride = 2**6   # keep every 64th grid point -> 128 sensors

    # Every sample uses the same equispaced sensor indices.
    sensor_idx = np.arange(0, grid_size, sensor_stride)
    input_index_array = np.tile(sensor_idx, (N, 1))

    input_func_data = output[:N, :].astype(np.float32)
    x_data = input_index_array * (1.0 / grid_size)
    a_data = input_func_data[:, sensor_idx]

    data_grid = np.linspace(0, 1, grid_size + 1)
    y_data = np.random.uniform(0, 1, (N, num_query, 2))
    characteristic = (y_data[:, :, 0] - y_data[:, :, 1]) % 1

    u_data = np.zeros((N, num_query, 1))
    for idx in range(N):
        row = output[idx, :]
        # Periodicity is obtained by repeating the first sample at grid point 1.0,
        # so every wrapped query in [0, 1] lies inside the interpolation range.
        f_data = np.append(row, row[0])
        f = interp1d(data_grid, f_data, kind='linear')
        u_data[idx, :, 0] = f(characteristic[idx, :])

    train_data = (x_data[:ntrain, :, None], a_data[:ntrain, :, None], y_data[:ntrain], u_data[:ntrain])
    test_data  = (x_data[ntrain:, :, None], a_data[ntrain:, :, None], y_data[ntrain:], u_data[ntrain:])
    return train_data, test_data
    
def inverse_time_decay(epoch, initial_lr, decay_factor, decay_epochs):
    """Return the learning rate for ``epoch`` under inverse-time decay.

    The rate is ``initial_lr / (1 + decay_factor * epoch / decay_epochs)``:
    it equals ``initial_lr`` at epoch 0 and shrinks as ``epoch`` grows.
    """
    elapsed_fraction = epoch / decay_epochs
    shrink = 1 + decay_factor * elapsed_fraction
    return initial_lr / shrink

# Training function
def train_model(model, train_data, test_data, criterion, optimizer, num_epochs, device,
                initial_lr=0.001, decay_factor=0.5, decay_epochs=100000, log_every=5000):
    """Full-batch training loop with an inverse-time-decay learning rate.

    Each epoch is one optimizer step on the whole training set. The learning
    rate of every parameter group is overwritten each epoch, so the ``lr``
    the optimizer was constructed with is not used.

    Args:
        model: module called as ``model(x, a, y)``.
        train_data: tuple ``(x, a, y, u)`` of training tensors.
        test_data: tuple ``(x, a, y, u)`` passed to ``validate_model``.
        criterion: loss called as ``criterion(outputs, u)``.
        optimizer: optimizer whose ``param_groups`` are updated in place.
        num_epochs: number of optimizer steps.
        device: only forwarded to ``validate_model``; the data is not moved here.
        initial_lr, decay_factor, decay_epochs: schedule parameters passed to
            ``inverse_time_decay`` (defaults reproduce the original constants).
        log_every: print the loss and run validation every this many epochs.

    Returns:
        List with the training loss of every epoch.
    """
    x_train, a_train, y_train, u_train = train_data
    model.train()  # Set the model to training mode
    train_losses = []

    for epoch in range(num_epochs):
        # The schedule value is the same for every parameter group.
        lr = inverse_time_decay(epoch, initial_lr=initial_lr, decay_factor=decay_factor, decay_epochs=decay_epochs)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        optimizer.zero_grad()  # Clear previous gradients
        outputs = model(x_train, a_train, y_train)
        loss = criterion(outputs, u_train)
        loss.backward()
        optimizer.step()

        epoch_loss = loss.item()
        train_losses.append(epoch_loss)
        if epoch % log_every == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")
            validate_model(model, test_data, criterion, device=device)
            # Validation may switch the model to eval mode; make sure the
            # remaining epochs run in training mode again.
            model.train()
    return train_losses

def validate_model(model, test_data, criterion, device):
    """Evaluate ``model`` on ``test_data`` and print the metrics.

    Prints the ``criterion`` loss and the mean per-sample relative L2 error
    ``||outputs - u|| / ||u||``, where each sample is flattened to a vector.
    The model is put in eval mode for the evaluation and afterwards restored
    to whichever mode it was in on entry, so calling this in the middle of
    training does not leave the model in eval mode.

    Args:
        model: module called as ``model(x, a, y)``.
        test_data: tuple ``(x, a, y, u)`` of tensors.
        criterion: loss called as ``criterion(outputs, u)``.
        device: unused; kept so existing call sites keep working.
    """
    x_test, a_test, y_test, u_test = test_data
    was_training = model.training
    model.eval()  # Set the model to evaluation mode
    try:
        with torch.no_grad():  # Disable gradient calculation during validation
            actual_batch_size = u_test.shape[0]
            outputs = model(x_test, a_test, y_test)
            val_loss = criterion(outputs, u_test).item()
            diff_norms = torch.norm(
                outputs.reshape(actual_batch_size, -1) - u_test.reshape(actual_batch_size, -1), dim=1
            )
            y_norms = torch.norm(u_test.reshape(actual_batch_size, -1), dim=1)
            rel_l2_loss = torch.mean(diff_norms / y_norms).item()
    finally:
        # Restore the caller's mode even if the forward pass raised.
        model.train(was_training)

    print(f"Validation Loss: {val_loss:.5f}, mean Relative L2 Loss: {rel_l2_loss:.5f}")

# Model Description


In [6]:
class StochasticFeatures(nn.Module):
    """Fixed (non-trainable) feature map built from a bias-free linear projection.

    The projection weights are initialised by ``prior`` and then frozen.
    ``forward`` applies a nonlinearity selected by ``kernel``:

    * ``"trig"``      -> ``sin(Wx)``                 (``num_features`` outputs)
    * ``"trig symm"`` -> ``[sin(Wx), cos(Wx)]``      (``2 * num_features`` outputs)
    * ``"relu"``      -> ``relu(Wx)``                (``num_features`` outputs)

    Args:
        kernel: one of the names above; checked when ``forward`` is called.
        x_dim: size of the last dimension of the input.
        num_features: number of rows of the projection matrix ``W``.
        prior: callable applied in place to the ``(num_features, x_dim)``
            weight tensor to initialise it.
    """

    _KERNELS = ("trig", "trig symm", "relu")

    def __init__(self, kernel, x_dim, num_features, prior = torch.nn.init.normal_):
        super().__init__()
        self.x_dim = x_dim
        self.num_features = num_features
        self.kernel = kernel
        self.w = nn.Linear(x_dim, num_features, bias=False)
        prior(self.w.weight)
        # The projection is a fixed basis: exclude it from gradient updates.
        self.w.weight.requires_grad = False

    def forward(self, x):
        """Map ``x`` of shape ``(..., x_dim)`` to its features along the last axis.

        Raises:
            NotImplementedError: if ``self.kernel`` is not a supported name.
        """
        if self.kernel not in self._KERNELS:
            raise NotImplementedError(
                f"kernel {self.kernel!r} is not implemented; expected one of {self._KERNELS}"
            )
        projected = self.w(x)  # computed once, shared by the sin and cos halves
        if self.kernel == "trig":
            return torch.sin(projected)
        if self.kernel == "trig symm":
            return torch.cat([torch.sin(projected), torch.cos(projected)], dim=-1)
        return torch.nn.functional.relu(projected)


class ProjectionNets(nn.Module):
    """Concatenation of a learned MLP basis and a fixed basis evaluated at ``x``.

    Args:
        x_dim: size of the last dimension of the input.
        num_learned_basis: output width of the fully connected network.
        num_fixed_basis: number of features requested from the default fixed
            basis; only used when ``basis`` is None.
        learning_hidden_dim: list with the widths of the hidden layers.
        activation: module inserted after every hidden linear layer.
        basis: callable/module applied to ``x`` to produce the fixed features.
            When None, a ``StochasticFeatures`` map with the "trig" kernel is
            created.
    """

    def __init__(self, x_dim, num_learned_basis, num_fixed_basis, learning_hidden_dim, activation = torch.nn.ReLU(), basis=None):
        super().__init__()
        self.x_dim = x_dim
        self.num_learned_basis = num_learned_basis
        self.learning_hidden_dim = learning_hidden_dim
        self.activation = activation
        self.num_fixed_basis = num_fixed_basis
        # A user-supplied basis is stored here, before the MLP is created.
        self.basis = basis

        # Learned part: Linear layers with an activation between them and
        # none after the final Linear.
        widths = [x_dim] + learning_hidden_dim + [num_learned_basis]
        final_pos = len(widths) - 2
        stack = []
        for pos, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            stack.append(nn.Linear(fan_in, fan_out))
            if pos != final_pos:
                stack.append(activation)
        self.fnn = nn.Sequential(*stack)

        # Fixed part: fall back to sine features when nothing was supplied.
        if self.basis is None:
            self.basis = StochasticFeatures(kernel="trig", x_dim=x_dim, num_features=num_fixed_basis)

    def forward(self, x):
        """Evaluate both bases on ``x`` of shape ``(..., x_dim)``.

        Returns the learned features followed by the fixed features,
        concatenated along the last dimension.
        """
        pieces = (self.fnn(x), self.basis(x))
        return torch.cat(pieces, dim=-1)

def periodic(x):
    """Embed ``x`` with the first two Fourier harmonics of period 1.

    Returns ``[cos(2*pi*x), sin(2*pi*x), cos(4*pi*x), sin(4*pi*x)]``
    concatenated along the last dimension, so that dimension grows fourfold.
    """
    angle = 2 * torch.pi * x
    doubled = 2 * angle
    harmonics = [
        torch.cos(angle),
        torch.sin(angle),
        torch.cos(doubled),
        torch.sin(doubled),
    ]
    return torch.cat(harmonics, -1)


class DeepONet(nn.Module):
    """Branch/trunk operator network.

    The branch MLP maps the sampled input function to ``width`` coefficients,
    the trunk MLP maps the periodic embedding of a query point to ``width``
    basis values, and the prediction is their dot product.

    Args:
        num_sensors: number of input-function samples fed to the branch net.
        width: hidden and output width of both networks.

    The defaults reproduce the original fixed 128/128 architecture, including
    the layer creation order and the ``state_dict`` keys.
    """

    def __init__(self, num_sensors=128, width=128):
        super().__init__()
        # periodic() emits 4 features for each of the 2 query coordinates.
        y_dim = 2 * 4

        self.branch_nets = nn.Sequential(
            nn.Linear(num_sensors, width),
            nn.Tanh(),
            nn.Linear(width, width),
            nn.Tanh(),
            nn.Linear(width, width),
            nn.Tanh(),
            nn.Linear(width, width),
        )

        self.trunk_net = nn.Sequential(
            nn.Linear(y_dim, width),
            nn.Tanh(),
            nn.Linear(width, width),
            nn.Tanh(),
            nn.Linear(width, width),
        )

    def forward(self, x, u, y):
        """
        Inputs:
          x: sensor locations; accepted so the call signature matches BelNet, but not used here.
          u: Tensor of shape [B, L, 1] with the function values at the L = num_sensors
             sensors (the last two dimensions are flattened to [B, L]).
          y: Tensor of shape [B, Q, 2] with Q query points per sample.
        Returns:
          Tensor of shape [B, Q, 1].
        """
        u_flattened = torch.flatten(u, start_dim=-2)
        branch_results = self.branch_nets(u_flattened)   # [B, width]
        z = periodic(y)                                  # [B, Q, 8]
        trunk_results = self.trunk_net(z)                # [B, Q, width]
        # Dot product between branch coefficients and trunk basis values.
        results = (branch_results.unsqueeze(1) * trunk_results).sum(dim=-1, keepdim=True)
        return results
    
    
class BelNet(nn.Module):
    """Operator network whose branch input is a projection of ``u`` onto a basis.

    ``ProjectionNets`` evaluates ``num_learned_basis`` learned and (by default)
    ``num_fixed_basis`` fixed basis functions at the sensor locations. The
    input function is averaged against each of them, the resulting
    coefficients go through the branch MLP, and the output is the dot product
    with the trunk MLP evaluated on the periodic embedding of the query points.

    Args:
        num_learned_basis: number of learned basis functions.
        num_fixed_basis: number of fixed basis features; must equal the output
            width of ``basis`` when one is supplied, because it sizes the first
            branch layer.
        basis: optional fixed basis forwarded to ``ProjectionNets``.
    """

    def __init__(self, num_learned_basis = 50, num_fixed_basis = 50, basis = None):
        super().__init__()
        sensor_dim = 1
        query_features = 2 * 4  # periodic() emits 4 features per query coordinate
        total_basis = num_learned_basis + num_fixed_basis

        self.projection_nets = ProjectionNets(
            sensor_dim,
            num_learned_basis,
            num_fixed_basis,
            learning_hidden_dim=[64,64,64,64],
            activation=torch.nn.ReLU(),
            basis=basis,
        )

        branch_layers = [
            nn.Linear(total_basis, 100),  # corresponds to xi_i^k
            nn.Tanh(),
            nn.Linear(100, 100),          # corresponds to xi_i^k
            nn.Tanh(),
            nn.Linear(100, 100),          # corresponds to xi_i^k
            nn.Tanh(),
            nn.Linear(100, 128),          # corresponds to c_i^k
        ]
        self.branch_nets = nn.Sequential(*branch_layers)

        trunk_layers = [
            nn.Linear(query_features, 128),
            nn.Tanh(),
            nn.Linear(128, 128),
            nn.Tanh(),
            nn.Linear(128, 128),
        ]
        self.trunk_net = nn.Sequential(*trunk_layers)

    def forward(self, x, u, y):
        """
        Inputs:
          x: Tensor of shape [B, L, 1], where B = batch size, L = number of integration points per sample.
          u: Tensor of shape [B, L, 1] (function values corresponding to x), as required by the batched matmul.
          y: Tensor of shape [B, Q, 2] with Q query points per sample.
        Returns:
          Tensor of shape [B, Q, 1].
        """
        _, num_points, _ = x.shape
        basis_values = self.projection_nets(x)                     # [B, L, R]
        # Mean over the L points of basis * u for each of the R basis functions.
        weighted = torch.bmm(basis_values.transpose(1, 2), u)      # [B, R, 1]
        coefficients = (weighted / float(num_points)).squeeze(-1)  # [B, R]
        branch_out = self.branch_nets(coefficients)                # [B, 128]
        trunk_out = self.trunk_net(periodic(y))                    # [B, Q, 128]
        return torch.sum(branch_out.unsqueeze(1) * trunk_out, dim=-1, keepdim=True)

In [7]:
import torch
from torch import nn
import numpy as np
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ---------------------------------------------------------------------------
# Prepare the data shared by every model variant
# ---------------------------------------------------------------------------
np.random.seed(42) # fixed seed
train_data, test_data = get_data(1000,200)
train_data= [torch.from_numpy(d).float().to(device) for d in train_data]
test_data = [torch.from_numpy(d).float().to(device) for d in test_data]
(x_train, a_train, y_train, u_train) = train_data
(x_test, a_test, y_test, u_test)     = test_data
print("train shape:", x_train.shape, a_train.shape, y_train.shape, u_train.shape)
print("test shape:", x_test.shape, a_test.shape, y_test.shape, u_test.shape)

# Seed once before any model is built so the whole sweep is reproducible.
np.random.seed(42)
torch.manual_seed(42)

# Fixed Fourier basis used by examples 2 and 3: angular frequencies 2*pi*k
# for k = low..high, expanded into sin and cos features ("trig symm").
low = 1
high = 30
num = high - low +1


def prior(w):
    """Overwrite the (num, 1) projection weights with the frequencies 2*pi*k."""
    with torch.no_grad():
        w.copy_(torch.linspace(low*2*torch.pi, high*2*torch.pi, num).view(num, 1))


def count_parameters(model):
    """Number of trainable scalars in ``model``."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# ---------------------------------------------------------------------------
# Train and evaluate the four model variants
#   0: DeepONet
#   1: BelNet with learned basis only
#   2: BelNet with learned + fixed Fourier basis
#   3: BelNet with fixed Fourier basis only
# ---------------------------------------------------------------------------
for example_num in range(4):

    print(f"Using examples {example_num}")
    if example_num == 0:
        model = DeepONet().to(device)
    elif example_num == 1:
        model = BelNet(num_learned_basis = 50, num_fixed_basis = 0)
    else:
        basis = StochasticFeatures(kernel="trig symm", x_dim=1, num_features=num, prior = prior)
        learned = 50 if example_num == 2 else 0
        # "trig symm" yields sin and cos per frequency, hence num*2 fixed features.
        model = BelNet(num_learned_basis = learned, num_fixed_basis = num*2, basis = basis)

    model = model.to(device)

    print(f"Model has {count_parameters(model)} parameters")
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # Train the model
    train_losses = train_model(model, train_data, test_data, criterion, optimizer, num_epochs=500000, device=device)
    print(f"Example {example_num} finished training, final error is: ")
    validate_model(model, test_data, criterion, device=device)

    # save model
    torch.save(model.state_dict(), f'model_save_{example_num}.pth')


train shape: torch.Size([1000, 128, 1]) torch.Size([1000, 128, 1]) torch.Size([1000, 1000, 2]) torch.Size([1000, 1000, 1])
test shape: torch.Size([200, 128, 1]) torch.Size([200, 128, 1]) torch.Size([200, 1000, 2]) torch.Size([200, 1000, 1])
Using examples 0
Model has 100224 parameters
Epoch [1/500000], Loss: 0.3985
Validation Loss: 0.35323, mean Relative L2 Loss: 1.01816
Epoch [5001/500000], Loss: 0.0194
Validation Loss: 0.02008, mean Relative L2 Loss: 0.23530
Epoch [10001/500000], Loss: 0.0214
Validation Loss: 0.02327, mean Relative L2 Loss: 0.25198
Epoch [15001/500000], Loss: 0.0131
Validation Loss: 0.01380, mean Relative L2 Loss: 0.18625
Epoch [20001/500000], Loss: 0.0117
Validation Loss: 0.01268, mean Relative L2 Loss: 0.17623
Epoch [25001/500000], Loss: 0.0126
Validation Loss: 0.01353, mean Relative L2 Loss: 0.18349
Epoch [30001/500000], Loss: 0.0084
Validation Loss: 0.00914, mean Relative L2 Loss: 0.14280
Epoch [35001/500000], Loss: 0.0084
Validation Loss: 0.00923, mean Relative 



Epoch [5001/500000], Loss: 0.0353
Validation Loss: 0.03562, mean Relative L2 Loss: 0.32972
Epoch [10001/500000], Loss: 0.0352
Validation Loss: 0.03556, mean Relative L2 Loss: 0.32938
Epoch [15001/500000], Loss: 0.0274
Validation Loss: 0.02890, mean Relative L2 Loss: 0.29040
Epoch [20001/500000], Loss: 0.0299
Validation Loss: 0.03059, mean Relative L2 Loss: 0.30012
Epoch [25001/500000], Loss: 0.0201
Validation Loss: 0.02072, mean Relative L2 Loss: 0.24001
Epoch [30001/500000], Loss: 0.0367
Validation Loss: 0.03739, mean Relative L2 Loss: 0.33671
Epoch [35001/500000], Loss: 0.0165
Validation Loss: 0.01698, mean Relative L2 Loss: 0.21231
Epoch [40001/500000], Loss: 0.0168
Validation Loss: 0.01723, mean Relative L2 Loss: 0.21316
Epoch [45001/500000], Loss: 0.0363
Validation Loss: 0.03786, mean Relative L2 Loss: 0.33648
Epoch [50001/500000], Loss: 0.0169
Validation Loss: 0.01747, mean Relative L2 Loss: 0.21746
Epoch [55001/500000], Loss: 0.0341
Validation Loss: 0.03499, mean Relative L2 Los