In [None]:
#Import packages
import numpy as np
import torch
import matplotlib.pyplot as plt
import torch.nn as nn
# from pyDOE2 import lhs
from torch.autograd import Variable
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
import matplotlib.pyplot as plt
import copy
from scipy import stats
import time
print(device)

### Solve the system of shallow water equations to generate simulated data

$ h_t + (hu)_x = 0 $

$ (hu)_t + \left(hu^2 + \frac{1}{2}gh^2\right)_x + g\,h\,(z_b)_x = 0 $

In [None]:
# ---- Physical and numerical parameters of the dam-break simulation ----
g = 9.81      # gravitational acceleration
nx = 200      # number of spatial grid points
L = 10.0      # length of the spatial domain [0, L] (not a water height)
dx = L / nx   # grid spacing used by the finite-volume update
dt = 0.001    # time-step size
nt = 1000     # number of stored time levels, including the initial one

# NOTE: np.linspace(0, L, nx) has spacing L/(nx-1), which differs slightly from dx.
x = np.linspace(0, L, nx)

# Solution history, laid out as (variable, space, time); U[0] is h and U[1] is hu.
U = np.zeros((2, nx, nt))

# Initial state: depth 1.0 on the left half, 0.5 on the right half, water at rest.
h0 = np.where(np.arange(nx) < int(nx/2), 1.0, 0.5)
u0 = np.zeros_like(h0)
hu0 = h0 * u0

for row, initial_field in enumerate((h0, hu0)):
    U[row, :, 0] = initial_field

def flux(U_col):
    """Physical flux of the shallow water system for a single state.

    U_col is indexed as (h, hu). Returns np.array([hu, hu*u + 0.5*g*h**2]),
    where u = hu/h, or 0.0 when h <= 1e-6 so that an (almost) dry cell does
    not trigger a division by zero. Reads the module-level constant g.
    """
    depth, momentum = U_col[0], U_col[1]

    if depth > 1e-6:
        velocity = momentum / depth
    else:
        velocity = 0.0

    momentum_flux = momentum * velocity + 0.5 * g * depth**2
    return np.array([momentum, momentum_flux])

def rusanov_flux(U_left, U_right):
    """Rusanov (local Lax-Friedrichs) numerical flux across one cell interface.

    U_left / U_right are the (h, hu) states on either side of the interface.
    The result is the average of the two physical fluxes minus a dissipation
    term scaled by the largest signal speed |u| + sqrt(g*h) of the two states.
    """
    def _signal_speed(state):
        # |u| + sqrt(g h); the velocity is taken as 0 for (almost) dry cells.
        depth = state[0]
        velocity = state[1] / depth if depth > 1e-6 else 0.0
        return abs(velocity) + np.sqrt(g * depth)

    central_part = 0.5 * (flux(U_left) + flux(U_right))
    s_max = max(_signal_speed(U_left), _signal_speed(U_right))
    dissipation = 0.5 * s_max * (U_right - U_left)

    return central_part - dissipation

# Time stepping: explicit first-order finite-volume update with Rusanov fluxes.
cfl_ratio = dt / dx
for n in range(1, nt):
    U_prev = U[:, :, n - 1]
    U_new = U_prev.copy()

    # interface_flux[k] is the numerical flux between cells k and k + 1.
    # Each interface is evaluated once and shared by both neighbouring cells.
    interface_flux = [
        rusanov_flux(U_prev[:, k], U_prev[:, k + 1]) for k in range(nx - 1)
    ]

    # Conservative update of the interior cells: outflow minus inflow.
    for i in range(1, nx - 1):
        U_new[:, i] -= cfl_ratio * (interface_flux[i] - interface_flux[i - 1])

    # Simple zero-gradient BCs: copy the nearest interior cell into each boundary cell.
    U_new[:, 0] = U_new[:, 1]
    U_new[:, -1] = U_new[:, -2]

    U[:, :, n] = U_new

In [None]:
t_index = 0
tstart = 1  # first time level kept; level 0 is dropped (there hu == 0 everywhere, so max_hu - min_hu would be 0)
N_x = nx - 1
N_t = nt - tstart
std_noise = 0.00 #standard deviation of noise if applicable (if using noisy dataset)
var = 0

# Physical (un-normalised) fields on the retained time levels, shape (nx, N_t).
h_phys = U[0, :, tstart:]
hu_phys = U[1, :, tstart:]
u = np.where(h_phys > 1e-6, hu_phys / h_phys, 0.0)

# Per-time-step extrema over space (axis 0), each of shape (N_t,).
min_h, max_h = h_phys.min(axis=0), h_phys.max(axis=0)
min_hu, max_hu = hu_phys.min(axis=0), hu_phys.max(axis=0)
min_u, max_u = u.min(axis=0), u.max(axis=0)

# Min-max normalised noise-free fields (used as ground truth for the error metrics).
h_clean_norm = (h_phys - min_h) / (max_h - min_h)
hu_clean_norm = (hu_phys - min_hu) / (max_hu - min_hu)

# Additive Gaussian noise on the physical fields; h is drawn first, then hu.
noise_shape = (N_x + 1, N_t)
h_noisy = h_phys + np.random.normal(0, std_noise, noise_shape)
hu_noisy = hu_phys + np.random.normal(0, std_noise, noise_shape)

# From here on h and hu denote the normalised (possibly noisy) observations.
h = (h_noisy - min_h) / (max_h - min_h)
hu = (hu_noisy - min_hu) / (max_hu - min_hu)


### Create the training and validation dataset
We first select $N_s$ random spatial points (note that $N_s \ll nx$) and $\texttt{steps}$ random temporal points ($\texttt{steps} \leq nt$) to build the combined training + validation set. This set is then divided into training and validation subsets using a random 80–20 split.
The test set consists of all the points in the entire spatio-temporal domain.

In [None]:
# Normalised observations on the full (space, time) grid.
u_data = h.reshape((N_x+1), N_t)
b_data = hu.reshape((N_x+1), N_t)
u0_data = h[:, 0].reshape((N_x+1), 1)
b0_data = hu[:, 0].reshape((N_x+1), 1)

# Normalised coordinates: t in [0, 1], x in [0, 1] (x divided by the domain length L).
t_bounds = np.linspace(0, 1, N_t)
x_bounds = np.linspace(0, L, N_x+1) / L

# Coordinate grids of shape (N_x+1, N_t).
t_data = np.tile(t_bounds, ((N_x+1), 1))
x_data = np.tile(x_bounds.reshape(-1, 1), (1, N_t))

N_s = 100    # number of spatial sensor locations
steps = 800  # number of sampled time levels
print('N_s', N_s, 'and Nt',steps)
assert N_s <= x_data.shape[0] and steps <= N_t, "cannot sample more points than the grid holds"

# A dedicated, seeded generator makes the sensor selection and the train/val
# split reproducible under Restart & Run All without touching the global RNG.
DATA_SEED = 0
data_rng = np.random.default_rng(DATA_SEED)
idx_s = data_rng.choice(x_data.shape[0], N_s, replace=False)
idx_t = data_rng.choice(N_t, steps, replace=False)

# Per-time-step normalisation constants repeated along x, shape (N_x+1, N_t).
h_max = np.tile(max_h, (N_x+1, 1))
hu_max = np.tile(max_hu, (N_x+1, 1))
h_min = np.tile(min_h, (N_x+1, 1))
hu_min = np.tile(min_hu, (N_x+1, 1))
u_max = np.tile(max_u, (N_x+1, 1))
u_min = np.tile(min_u, (N_x+1, 1))

# Open-mesh index selecting every (sensor, time) pair; same ordering as
# field[idx_s, :][:, idx_t].
_meas_grid = np.ix_(idx_s, idx_t)


def _at_measurements(field):
    """Gather `field` at the sampled sensors/times as an (N_s*steps, 1) column."""
    return field[_meas_grid].reshape((-1, 1))


t_meas = _at_measurements(t_data)
x_meas = _at_measurements(x_data)
h_max_meas = _at_measurements(h_max)
h_min_meas = _at_measurements(h_min)
hu_max_meas = _at_measurements(hu_max)
hu_min_meas = _at_measurements(hu_min)
u_max_meas = _at_measurements(u_max)
u_min_meas = _at_measurements(u_min)

u_meas = _at_measurements(u_data)
b_meas = _at_measurements(b_data)

X_meas = np.hstack((x_meas, t_meas))

# Random train/validation split of the measurement set.
Split_TrainVal = 0.8
N_train = int(N_s*steps*Split_TrainVal)
idx_train = data_rng.choice(X_meas.shape[0], N_train, replace=False)

X_train = X_meas[idx_train, :]
h_max_train = h_max_meas[idx_train, :]
h_min_train = h_min_meas[idx_train, :]
hu_max_train = hu_max_meas[idx_train, :]
hu_min_train = hu_min_meas[idx_train, :]
u_max_train = u_max_meas[idx_train, :]
u_min_train = u_min_meas[idx_train, :]

u_train = u_meas[idx_train, :]
b_train = b_meas[idx_train, :]

# Validation Measurements, which are the rest of measurements
idx_val = np.setdiff1d(np.arange(X_meas.shape[0]), idx_train, assume_unique=True)
assert len(idx_train) + len(idx_val) == X_meas.shape[0]

X_val = X_meas[idx_val, :]
h_max_val = h_max_meas[idx_val, :]
h_min_val = h_min_meas[idx_val, :]
hu_max_val = hu_max_meas[idx_val, :]
hu_min_val = hu_min_meas[idx_val, :]
u_max_val = u_max_meas[idx_val, :]
u_min_val = u_min_meas[idx_val, :]

u_val = u_meas[idx_val, :]
b_val = b_meas[idx_val, :]

### Define all the relevant functions: $\ell_0$ function, the neural network, count_weight function, etc.

In [None]:
class L0Gate(nn.Module):
    """Hard-concrete stochastic gate used for L0 regularisation.

    Follows the hard-concrete construction of Louizos, Welling & Kingma,
    "Learning Sparse Neural Networks through L0 Regularization" (2018).

    Args:
        shape: shape of the gate tensor (int or tuple), e.g. ``(H,)`` for one
            gate per neuron of a layer with H units.
        droprate_init: initial probability that a gate is closed; the
            log-alpha parameters are initialised around logit(1 - droprate_init).
        temperature: temperature of the concrete distribution.
        gamma, zeta: lower/upper end of the stretch interval. gamma must be
            negative and zeta greater than 1 so that, after clamping to [0, 1],
            a gate can be exactly 0 or exactly 1.
        eps: the uniform noise is kept inside [eps, 1 - eps] so its logit is finite.
    """

    def __init__(self, shape, droprate_init=0.5, temperature=2./3.,
                 gamma=-0.1, zeta=1.1, eps=1e-6):
        super().__init__()
        if not (gamma < 0.0 and zeta > 1.0):
            raise ValueError("hard-concrete gates need gamma < 0 and zeta > 1")
        # torch.empty allocates a tensor *of* the given shape. The legacy
        # torch.Tensor((H,)) call treats a tuple as data and would create a
        # single-element tensor, i.e. one gate per layer instead of per neuron.
        self.qz_loga = nn.Parameter(torch.empty(shape))
        self.temperature = temperature
        self.gamma = gamma
        self.zeta = zeta
        self.eps = eps
        # init log-alpha so that sigmoid(log-alpha) ~= 1 - droprate_init
        init_mean = float(np.log(1 - droprate_init) - np.log(droprate_init))
        self.qz_loga.data.normal_(mean=init_mean, std=1e-2)

    def _stretch_and_clamp(self, s):
        """Map s in (0, 1) to (gamma, zeta) and clip the result to [0, 1]."""
        return torch.clamp(s * (self.zeta - self.gamma) + self.gamma, 0, 1)

    def _hard_concrete_sample(self):
        """Draw one reparameterised hard-concrete sample per gate."""
        u = torch.rand_like(self.qz_loga).clamp_(self.eps, 1 - self.eps)
        s = torch.sigmoid((torch.log(u) - torch.log(1 - u) + self.qz_loga) / self.temperature)
        return self._stretch_and_clamp(s)

    def forward(self):
        """Return gate values in [0, 1].

        Training mode returns a random sample; eval mode returns the
        deterministic, noise-free gate so that validation and test
        predictions are repeatable.
        """
        if self.training:
            return self._hard_concrete_sample()
        return self._stretch_and_clamp(torch.sigmoid(self.qz_loga))

    def l0_loss(self):
        """Expected number of non-zero gates (differentiable L0 penalty).

        P(gate != 0) = sigmoid(log_alpha - temperature * log(-gamma / zeta)).
        """
        shift = self.temperature * float(np.log(-self.gamma / self.zeta))
        return torch.sum(torch.sigmoid(self.qz_loga - shift))

def count_weights_and_nnz(net):
    """Count all parameters of `net` and those belonging to active neurons.

    Hidden layers are discovered as consecutive attribute pairs
    ``hidden_layer1``/``g1``, ``hidden_layer2``/``g2``, ... so the count
    follows the network depth automatically. A neuron counts as active when
    sigmoid(gate.qz_loga) > 0.5; a gate with a single element is treated as
    shared by every neuron of its layer. Each active neuron contributes all
    of its incoming weights plus its bias. The output layer is counted as
    fully dense.

    Returns:
        (total_weights, total_nnz) as Python ints.
    """
    total_weights = 0
    total_nnz = 0

    depth = 1
    while hasattr(net, f"hidden_layer{depth}") and hasattr(net, f"g{depth}"):
        layer = getattr(net, f"hidden_layer{depth}")
        gate = getattr(net, f"g{depth}")

        out_features, in_features = layer.weight.shape
        bias_per_neuron = 0 if layer.bias is None else 1

        # All parameters stored in this layer.
        total_weights += layer.weight.numel() + bias_per_neuron * out_features

        # Hard threshold on the gate's open-probability proxy.
        is_open = torch.sigmoid(gate.qz_loga.detach()) > 0.5
        # expand() broadcasts a layer-wise (single-element) gate to all neurons
        # and is a no-op for a per-neuron gate.
        n_active = int(is_open.expand(out_features).sum().item())

        # Each active neuron has all its incoming weights + its bias.
        total_nnz += n_active * (in_features + bias_per_neuron)
        depth += 1

    # Output layer (fully dense)
    out_w = net.output_layer.weight
    out_b = net.output_layer.bias
    n_out = out_w.numel() + (0 if out_b is None else out_b.numel())
    total_weights += n_out
    total_nnz += n_out

    return total_weights, total_nnz

#Activation function
def m(x):
    """Element-wise activation applied after every hidden layer (sine)."""
    # Swap in torch.relu(x) here to experiment with a ReLU network instead.
    return x.sin()

def cart_inputs(x,t):
    """Cartesian product of the 1-D coordinate arrays x and t.

    Returns two column vectors of shape (len(x) * len(t), 1). Rows are ordered
    with x as the slow index and t as the fast index, i.e.
    (x0, t0), (x0, t1), ..., (x1, t0), ...
    Both columns share the common dtype of x and t. Empty inputs give empty
    (0, 1) columns instead of raising.
    """
    x_flat = np.asarray(x).reshape(-1)
    t_flat = np.asarray(t).reshape(-1)
    common = np.result_type(x_flat, t_flat)

    # Vectorised replacement for the nested Python comprehension.
    x_col = np.repeat(x_flat, t_flat.size).astype(common, copy=False).reshape(-1, 1)
    t_col = np.tile(t_flat, x_flat.size).astype(common, copy=False).reshape(-1, 1)
    return x_col, t_col

#Neural network; uncomment/comment based on how many layers you want in the network
class Net(nn.Module):
    """Fully connected network with one L0 gate per hidden layer.

    The network maps normalised (x, t) to two outputs (h, u). Their product
    h*u is returned as the second field and is what the notebook compares
    with the normalised hu data. `forward` additionally returns the residuals
    of the two shallow water equations, evaluated on the de-normalised fields.

    To change the depth, add or remove hidden_layer{k}/g{k} pairs in __init__
    and the matching gate/layer lines in forward.
    """

    def __init__(self, H):
        super().__init__()

        self.hidden_layer1 = nn.Linear(2, H)
        self.hidden_layer2 = nn.Linear(H, H)
        self.hidden_layer3 = nn.Linear(H, H)
        self.hidden_layer4 = nn.Linear(H, H)

        # Add gates (one per hidden layer, constructed with shape (H,))
        self.g1 = L0Gate((H,))
        self.g2 = L0Gate((H,))
        self.g3 = L0Gate((H,))
        self.g4 = L0Gate((H,))

        self.output_layer = nn.Linear(H, 2)

    def forward(self, x,t,hmax,hmin,humax,humin,umax,umin):
        """Evaluate the network and the PDE residuals.

        Args:
            x, t: (N, 1) coordinate tensors with requires_grad=True.
            hmax, hmin, humax, humin: min-max normalisation constants, one per
                sample; tensors are reshaped to (N, 1) so 1-D inputs are accepted.
            umax, umin: accepted for interface compatibility; not used.

        Returns:
            (h, hu, pde_u, pde_b): the two network fields and the residuals of
            the mass and momentum equations, each of shape (N, 1).
        """
        inputs = torch.cat([x, t], dim=1)

        # Gates are evaluated up front, in layer order.
        z1 = self.g1()
        z2 = self.g2()
        z3 = self.g3()
        z4 = self.g4()

        h1 = m(self.hidden_layer1(inputs)) * z1
        h2 = m(self.hidden_layer2(h1)) * z2
        h3 = m(self.hidden_layer3(h2)) * z3
        h4 = m(self.hidden_layer4(h3)) * z4
        output = self.output_layer(h4)
        h = output[:,0].reshape(-1,1)
        u = output[:,1].reshape(-1,1)
        hu = h*u

        # First derivatives only; create_graph=True keeps them differentiable
        # so the residual losses can be back-propagated.
        h_x = torch.autograd.grad(h.sum(), x, create_graph=True,allow_unused=True)[0]
        h_t = torch.autograd.grad(h.sum(), t, create_graph=True,allow_unused=True)[0]
        hu_t = torch.autograd.grad(hu.sum(), t, create_graph=True,allow_unused=True)[0]
        hu_x = torch.autograd.grad(hu.sum(), x, create_graph=True,allow_unused=True)[0]

        # Per-sample constants as columns: a 1-D (N,) tensor would otherwise
        # broadcast against the (N, 1) fields and yield (N, N) residuals.
        hmax, hmin, humax, humin = (
            c.reshape(-1, 1) if torch.is_tensor(c) else c
            for c in (hmax, hmin, humax, humin)
        )
        c_h = hmax - hmin
        c_hu = humax - humin
        inv_L = 1 / L  # x is normalised by the domain length L

        # De-normalised depth and discharge.
        H_phys = c_h*h + hmin
        HU_phys = c_hu*hu + humin

        # NOTE(review): the time derivatives treat the normalisation constants
        # as independent of t; they are computed per time step upstream, so
        # this is an approximation -- confirm it is acceptable.
        # Mass conservation: H_t + (HU)_x = 0
        pde_u = (c_h*h_t + inv_L*c_hu*hu_x).reshape(-1,1)

        # Velocity U = HU / H and its x-derivative via the quotient rule.
        hat_u_x = (H_phys*(c_hu*hu_x)*inv_L - HU_phys*(c_h*h_x)*inv_L)/(H_phys**2)
        hat_u = HU_phys/H_phys

        # Momentum: (HU)_t + (HU * U)_x + g H H_x = 0, using the same module-level
        # gravitational constant g that generated the training data.
        pde_b = (c_hu*hu_t + HU_phys*hat_u_x + (c_hu*hu_x)*inv_L*hat_u
                 + g*H_phys*(c_h*h_x)*inv_L).reshape(-1,1)

        return h,hu,pde_u,pde_b


### Create all the relevant input-output data.

In [None]:
def _np_to_device(array, requires_grad=False):
    """Convert a NumPy array to a float32 tensor on `device`.

    With requires_grad=True the returned tensor is marked as requiring grad
    after the move to `device`, which is the supported replacement for the
    deprecated `Variable(...)` wrapper.
    """
    tensor = torch.from_numpy(array).float().to(device)
    return tensor.requires_grad_(True) if requires_grad else tensor


# Initial-condition points: every x location at the first retained time level.
x_ic,t_ic = cart_inputs(x_bounds,t_bounds[0]*np.ones((1)))
u_ic = u0_data.reshape(-1,1).copy()
b_ic = b0_data.reshape(-1,1).copy()

# Collocation points coincide with the training measurement locations.
x_collocation = X_train[:,0].reshape(-1,1)
t_collocation = X_train[:,1].reshape(-1,1)

pt_x_ic = _np_to_device(x_ic, requires_grad=True)
pt_t_ic = _np_to_device(t_ic, requires_grad=True)

# Zero target for the PDE residual losses.
all_zeros = np.zeros((X_train.shape[0],1))
pt_all_zeros = _np_to_device(all_zeros)

pt_x_collocation = _np_to_device(x_collocation, requires_grad=True)
pt_t_collocation = _np_to_device(t_collocation, requires_grad=True)

x_val = X_val[:,0].reshape(-1,1)
t_val = X_val[:,1].reshape(-1,1)

pt_x_val = _np_to_device(x_val, requires_grad=True)
pt_t_val = _np_to_device(t_val, requires_grad=True)

# Normalisation constants at the training points.
h_max_train_t = _np_to_device(h_max_train)
h_min_train_t = _np_to_device(h_min_train)
hu_max_train_t = _np_to_device(hu_max_train)
hu_min_train_t = _np_to_device(hu_min_train)
u_max_train_t = _np_to_device(u_max_train)
u_min_train_t = _np_to_device(u_min_train)

# Normalisation constants at the validation points.
h_max_val_t = _np_to_device(h_max_val)
h_min_val_t = _np_to_device(h_min_val)
hu_max_val_t = _np_to_device(hu_max_val)
hu_min_val_t = _np_to_device(hu_min_val)
u_max_val_t = _np_to_device(u_max_val)
u_min_val_t = _np_to_device(u_min_val)

### Main training loop and validation errors

In [None]:
# Define hyperparameter sets
learning_rates = [0.005] #learning rate
hidden_dim = 50 #number of neurons per layer
lam_0 = [1e-3,1e-6] #regularization parameter
num_repeats = 3  # repeat training with different seeds
patience = 5  # non-improving validation checks tolerated before early stopping
epochs_no_improve = 0
max_epochs = 10000
validate_every = 2000


best_global_val_loss = float('inf')
best_model_state = None
best_hparams = {}

results = []
results_h = []
results_hu = []

mse_cost_function1 = torch.nn.MSELoss() # Mean squared error


def _float_on_device(array):
    """Copy a NumPy array into a float32 tensor living on `device`."""
    return torch.tensor(array).float().to(device)


def _relative_error_per_step(reference, prediction):
    """Relative L2 error of each column (time level) as a float32 CPU tensor."""
    per_step = torch.norm(reference - prediction, dim=0) / torch.norm(reference, dim=0)
    return per_step.float().cpu()


# ---- Tensors that never change: build them once instead of every epoch ----
h_train_target = _float_on_device(u_train)
hu_train_target = _float_on_device(b_train)
h_val_target = _float_on_device(u_val)
hu_val_target = _float_on_device(b_val)
h_ic_target = _float_on_device(u_ic.reshape(-1,1))
hu_ic_target = _float_on_device(b_ic.reshape(-1,1))

_norm_fields = (h_max, h_min, hu_max, hu_min, u_max, u_min)
# Column vectors (N, 1) so they line up with the (N, 1) network fields.
ic_norm_t = tuple(_float_on_device(f[:,0]).reshape(-1,1) for f in _norm_fields)
full_norm_t = tuple(_float_on_device(f).reshape(-1,1) for f in _norm_fields)

h_true_t = torch.tensor(h_clean_norm).to(device)
hu_true_t = torch.tensor(hu_clean_norm).to(device)
full_field_true = torch.vstack((h_true_t, hu_true_t))

for lr in learning_rates:
    for lambda0 in lam_0:
        run_errors = []
        run_error_h = []
        run_error_hu = []
        best_run_error = float('inf')
        best_run_state = None

        print(f"\n=== Training with lr={lr}, L0={lambda0} ===hidden dim={hidden_dim}====")

        for run in range(num_repeats):
            print(f"Run {run+1}/{num_repeats}")
            torch.manual_seed(run)   # different init each repeat
            t_start_train = time.time()

            # Reinitialize model each time
            net = Net(hidden_dim).to(device)
            optimizer = torch.optim.Adam(net.parameters(), lr=lr)

            best_val_loss = float('inf')
            best_model_wts = copy.deepcopy(net.state_dict())
            # The early-stopping counter must start from zero for every run;
            # otherwise stale counts from earlier runs trigger premature stops.
            epochs_no_improve = 0
            # ---- Training loop ----
            for epoch in range(max_epochs):
                net.train()
                optimizer.zero_grad()

                out, bout, f_out, g_out = net(pt_x_collocation, pt_t_collocation,
                                            h_max_train_t, h_min_train_t,
                                            hu_max_train_t, hu_min_train_t,
                                            u_max_train_t, u_min_train_t)

                net_ic, net_bic, _, _ = net(pt_x_ic, pt_t_ic, *ic_norm_t)

                h_out = out.reshape(-1, 1)
                hu_out = bout.reshape(-1, 1)

                # Individual loss terms (not all of them enter the objective below)
                mse_h = mse_cost_function1(f_out, pt_all_zeros)
                mse_hu = mse_cost_function1(g_out, pt_all_zeros)
                mse_huic = mse_cost_function1(net_bic.reshape(-1,1), hu_ic_target)

                mse_hic = mse_cost_function1(net_ic.reshape(-1,1), h_ic_target)
                mse_hdata = mse_cost_function1(h_out, h_train_target)
                mse_hudata = mse_cost_function1(hu_out, hu_train_target)

                L0_term = (net.g1.l0_loss() + net.g2.l0_loss() + net.g3.l0_loss() + net.g4.l0_loss())

                # Objective: h data + hu initial condition + momentum residual + L0 penalty.
                # Alternative kept for reference: mse_hudata + mse_h + mse_hic + lambda0 * L0_term
                loss = mse_hdata + mse_huic + mse_hu + lambda0 * L0_term
                loss.backward()
                optimizer.step()

                # ---- Validation ----
                if (epoch) % validate_every == 0:
                    net.eval()
                    # torch.no_grad() cannot be used here: Net.forward calls
                    # torch.autograd.grad internally to build the residuals.
                    u_val_learnt, b_val_learnt, _, _ = net(pt_x_val, pt_t_val,
                                                            h_max_val_t, h_min_val_t,
                                                            hu_max_val_t, hu_min_val_t,
                                                            u_max_val_t, u_min_val_t)

                    rel_hudata_val = torch.norm(b_val_learnt - hu_val_target) / torch.norm(hu_val_target)
                    rel_hdata_val = torch.norm(u_val_learnt - h_val_target) / torch.norm(h_val_target)

                    val_loss = 0.5*(rel_hudata_val + rel_hdata_val).item()
                    print(f"Epoch {epoch}: Val Loss={val_loss:.4f}")

                    if val_loss < best_val_loss:
                        best_val_loss = val_loss
                        best_model_wts = copy.deepcopy(net.state_dict())
                        epochs_no_improve = 0
                    else:
                        epochs_no_improve += 1

                    if epochs_no_improve >= patience:
                        print(f"Early stopping at epoch {epoch}")
                        break

            # Restore best model weights
            net.load_state_dict(best_model_wts)
            net.eval()
            t_end_train = time.time()

            # ---- Compute error metrics on this run (full space-time grid) ----
            t_start_eval = time.time()
            x1, t1 = cart_inputs(x_bounds, t_data[0,:])
            pt_x = torch.from_numpy(x1).float().to(device).requires_grad_(True)
            pt_t = torch.from_numpy(t1).float().to(device).requires_grad_(True)

            pt_u, pt_b, _, _ = net(pt_x, pt_t, *full_norm_t)
            t_end_eval = time.time()

            # Detach so the stored errors do not keep the evaluation graph alive.
            grid_shape = (x_bounds.shape[0], t_data[0,:].shape[0])
            ms_u = pt_u.detach().reshape(grid_shape)
            ms_b = pt_b.detach().reshape(grid_shape)
            full_field_net = torch.vstack((ms_u, ms_b))

            # Relative L2 error at every time level.
            error_uv = _relative_error_per_step(full_field_true, full_field_net)
            error_hu = _relative_error_per_step(hu_true_t, ms_b)
            error_h = _relative_error_per_step(h_true_t, ms_u)

            print('\nError uv',torch.mean(error_uv),' Error u',torch.mean(error_h),'Error hu:', torch.mean(error_hu),'\n')

            # Aliases kept so that cells relying on the *_fin names keep working.
            error_h_fin = error_h
            error_hu_fin = error_hu
            error_uv_fin = error_uv

            final_error = torch.mean(error_uv)
            run_errors.append(final_error.item())
            run_error_h.append(error_h)
            run_error_hu.append(error_hu)

            # Remember the weights of the most accurate run for this setting.
            if final_error.item() < best_run_error:
                best_run_error = final_error.item()
                best_run_state = copy.deepcopy(net.state_dict())

            print(f"Run {run+1}/{num_repeats} -> error={final_error.item():.4f}")
            total_w, nnz_w = count_weights_and_nnz(net)
            perc = 100 * nnz_w / total_w

            print(f"\nTotal weights: {total_w}")
            print(f"\nNon-zero (active) weights: {nnz_w}")
            print(f"\nPercentage active: {perc:.2f}%")
            print(f"\nTotal training time is",t_end_train-t_start_train)
            print(f"\nTotal evaluation time is",t_end_eval-t_start_eval)
            print("\n---------------------------------------------------------\n")


        # ---- Compute mean ± CI over the repeated runs ----
        mean_err = np.mean(run_errors)
        sem = stats.sem(run_errors)
        ci95 = sem * stats.t.ppf((1+0.95)/2., len(run_errors)-1)

        results.append((lr, hidden_dim, mean_err, ci95))
        print(f"===> lr={lr}, hidden_dim={hidden_dim}, L0={lambda0} and {N_t-std_noise} steps -> full error={mean_err:.4f} ± {ci95:.4f}")

        # Track global best (selected on the mean full-field error)
        if mean_err < best_global_val_loss:
            best_global_val_loss = mean_err
            best_model_state = best_run_state
            # Record the swept L0 weight too; it is the hyper-parameter that varies.
            best_hparams = {'lr': lr, 'hidden_dim': hidden_dim, 'L0': lambda0}

print(f"\nBest hyperparameters: {best_hparams}, Error={best_global_val_loss:.4f}")