# Case C: Simple channel flow




## Phenomenon
Start-up and propagation of flow entering a straight channel. Relevant for canals, simplified river reaches, or laboratory flumes.

## Specific Setup
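- Equations: 2D shallow water equations with linear bed friction; no rainfall
- Terrain (z_terrain): horizontal (flat) bed, z_terrain = 0
- Initial Conditions (IC): almost dry (uniform depth h = 1e-6, water at rest, u = v = 0)
- Boundary Conditions (BCs):
    - Inlet (x = x_min = 0): u = u_inflow, v = 0
    - Outlet (x = x_max): closed by a wall, u = 0
    - Sides (y = y_min, y = y_max): walls, v = 0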

In [None]:
# ----- Imports -----
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ExponentialLR
import numpy as np

import matplotlib.pyplot as plt
import matplotlib.animation as animation
from mpl_toolkits.mplot3d import Axes3D
from IPython.display import HTML, display, clear_output
from tqdm.notebook import tqdm # Use notebook version if in Jupyter/Colab

import time
import os

In [None]:
import os

from google.colab import drive
drive.mount('/content/drive')

In [None]:
os.chdir("drive/My Drive/Colab Data/")
!ls

In [None]:
# ----- Domain Definition & Simulation Parameters -----
# Domain for a 6 x 1 (length x width) rectangular channel
domain = {'x_min': 0.0, 'x_max': 6.0, # Channel length = 6
          'y_min': 0.0, 'y_max': 1.0, # Channel width = 1
          't_min': 0.0, 't_max': 120.0} # Simulation time in seconds (2 minutes)

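# Sampling points (collocation and boundary points are resampled every epoch)
N_collocation = 10000 # Interior PDE points per epoch (an earlier value of 7000 was acceptable)
N_bc = 2500          # Boundary points per epoch, split equally over the 4 edges
N_ic = 3000          # Not used by train_pinn: the IC is enforced on the (x_grid, y_grid) mesh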

# Normalisation scaling factors: map each physical interval onto [-1, 1]
scale_x = 2.0 / (domain['x_max'] - domain['x_min']) # = 1/3 for the channel length of 6
scale_y = 2.0 / (domain['y_max'] - domain['y_min']) # = 2 for the channel width of 1
scale_t = 2.0 / (domain['t_max'] - domain['t_min']) # = 1/60 for t_max = 120 s

# --- Device Setup ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
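The scale factors above do double duty: `normalize` uses them to map each physical interval onto [-1, 1], and the residual function reuses them to convert network derivatives back to physical units. A small standalone check for this domain (the helper `to_unit_interval` is illustrative and simply restates the `normalize` formula):

```python
# Normalisation factors for the notebook's domain (values copied from the cell above)
domain = {'x_min': 0.0, 'x_max': 6.0,
          'y_min': 0.0, 'y_max': 1.0,
          't_min': 0.0, 't_max': 120.0}

def to_unit_interval(value, lo, hi):
    """Affine map from [lo, hi] to [-1, 1] (same formula as `normalize`)."""
    return 2.0 / (hi - lo) * (value - lo) - 1.0

scale_x = 2.0 / (domain['x_max'] - domain['x_min'])
scale_y = 2.0 / (domain['y_max'] - domain['y_min'])
scale_t = 2.0 / (domain['t_max'] - domain['t_min'])
print(scale_x, scale_y, scale_t)   # approx. 0.3333, 2.0, 0.01667

# Channel inlet, mid-length and closed end map to -1, 0 and +1 (up to round-off)
print(to_unit_interval(0.0, 0.0, 6.0),
      to_unit_interval(3.0, 0.0, 6.0),
      to_unit_interval(6.0, 0.0, 6.0))
```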

In [None]:
# --- Terrain Definition (Flat Channel Bed) ---
def create_terrain_torch(X, Y):
    """ Creates a flat channel bed at z=0. """
    if not isinstance(X, torch.Tensor): X = torch.tensor(X, dtype=torch.float32, device=device)
    if not isinstance(Y, torch.Tensor): Y = torch.tensor(Y, dtype=torch.float32, device=device)
    terrain = torch.zeros_like(X)
    return terrain

def create_terrain_np(x, y):
    """ Creates terrain using NumPy arrays by calling the torch version. """
    X_np, Y_np = np.meshgrid(x, y, indexing='xy')
    with torch.no_grad():
        terrain_torch = create_terrain_torch(torch.tensor(X_np, dtype=torch.float32, device=device),
                                             torch.tensor(Y_np, dtype=torch.float32, device=device))
    return terrain_torch.cpu().numpy()

# --- PINN Model Definition ---
class FloodNet(nn.Module):
    def __init__(self, in_dim=3, hid_dim=64, out_dim=3): # Adjust hid_dim if needed
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hid_dim), nn.Tanh(),
            nn.Linear(hid_dim, hid_dim), nn.Tanh(),
            nn.Linear(hid_dim, hid_dim), nn.Tanh(),
            nn.Linear(hid_dim, hid_dim), nn.Tanh(),
            nn.Linear(hid_dim, out_dim)
        )
        self.init_weights()
    def init_weights(self):  # Xavier-uniform weights, zero biases
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
    def forward(self, x): return self.net(x)

# --- Coordinate Normalisation ---
def normalize(x, y, t):
    # Uses updated global scale factors
    x_n = scale_x * (x - domain['x_min']) - 1.0
    y_n = scale_y * (y - domain['y_min']) - 1.0
    t_n = scale_t * (t - domain['t_min']) - 1.0
    return x_n, y_n, t_n
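For a sense of model size, the trainable parameters of the `FloodNet` stack above (3 inputs, four Tanh hidden layers of width 64, 3 outputs, i.e. five `nn.Linear` layers) can be counted by hand: each `nn.Linear(a, b)` holds `a*b` weights plus `b` biases. The sketch is plain Python so it runs without PyTorch; `count_mlp_params` is a hypothetical helper, not part of the notebook.

```python
def count_mlp_params(layer_sizes):
    """Weights plus biases of a fully connected stack, e.g. [3, 64, 64, 64, 64, 3]."""
    return sum(a * b + b for a, b in zip(layer_sizes[:-1], layer_sizes[1:]))

# FloodNet(in_dim=3, hid_dim=64, out_dim=3): Linear(3,64), 3 x Linear(64,64), Linear(64,3)
print(count_mlp_params([3, 64, 64, 64, 64, 3]))  # 12931
```

With the model instantiated, `sum(p.numel() for p in model.parameters())` should report the same total.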

## Shallow Water Equations (SWEs) used in this script

The provided Python function `shallow_water_residuals` calculates the residuals for the **2D non-linear shallow water equations**. These equations are used to model fluid flow where the horizontal scale is much larger than the depth.

### Variables:
* **$\zeta$ (`zeta`)**: Water surface elevation.
* **$z_b$ (`terrain_c`)**: Bed elevation.
* **$h$ (`h`)**: Water depth, $h = \zeta - z_b$.
* **$u$ (`u`)**: Depth-averaged velocity in the x-direction.
* **$v$ (`v`)**: Depth-averaged velocity in the y-direction.
* **$g$ (`g`)**: Acceleration due to gravity.
* **$f_k$ (`friction_factor`)**: Bed friction coefficient.
* **$\frac{\partial \phi}{\partial t}$ (`phi_t`)**: Partial derivative of $\phi$ with respect to time.
* **$\frac{\partial \phi}{\partial x}$ (`phi_x`)**: Partial derivative of $\phi$ with respect to x.
* **$\frac{\partial \phi}{\partial y}$ (`phi_y`)**: Partial derivative of $\phi$ with respect to y.
* **$\frac{\partial z_b}{\partial x}$ (`z_x`)**, **$\frac{\partial z_b}{\partial y}$ (`z_y`)**: Bed slopes (explicitly set to zero in this specific code snippet, meaning a locally flat bed is assumed for the evaluation).

---
### 1. Continuity Equation (Mass Conservation)
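Represented by `r1 = h_t + u*h_x + h*u_x + v*h_y + h*v_y`, i.e. the flux terms expanded with the product rule, $\frac{\partial (hu)}{\partial x} = u\frac{\partial h}{\partial x} + h\frac{\partial u}{\partial x}$ (and likewise in y).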

Standard form:
$$ \frac{\partial h}{\partial t} + \frac{\partial (hu)}{\partial x} + \frac{\partial (hv)}{\partial y} = 0 $$

* The local rate of change of water depth balances the net divergence of the volume flux $(hu, hv)$: depth rises wherever more water flows in than out.
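The product-rule form used in `r1` and the conservative flux form are the same quantity for smooth fields. A quick NumPy check on made-up 1D profiles (chosen only for illustration, not taken from the model) confirms they agree up to finite-difference error:

```python
import numpy as np

# Smooth, illustrative 1D profiles along the 6 m channel length
x = np.linspace(0.0, 6.0, 601)
h = 0.1 + 0.05 * np.sin(x)        # depth
u = 0.5 * np.exp(-0.2 * x)        # velocity

# Conservative form: derivative of the flux h*u
flux_form = np.gradient(h * u, x)
# Expanded form used in r1: u*h_x + h*u_x
expanded_form = u * np.gradient(h, x) + h * np.gradient(u, x)

print(np.max(np.abs(flux_form - expanded_form)))  # small: finite-difference error only
```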

---
### 2. Momentum Equation in x-direction
Represented by `r2 = u_t + u*u_x + v*u_y + g*(h_x + z_x) + friction_factor*u`

Standard form:
$$ \frac{\partial u}{\partial t} + u\frac{\partial u}{\partial x} + v\frac{\partial u}{\partial y} + g\left(\frac{\partial h}{\partial x} + \frac{\partial z_b}{\partial x}\right) + f_k u = 0 $$

* **Terms:**
    * $\frac{\partial u}{\partial t}$: Local acceleration.
    * $u\frac{\partial u}{\partial x} + v\frac{\partial u}{\partial y}$: Advective acceleration (non-linear).
    * $g\left(\frac{\partial h}{\partial x} + \frac{\partial z_b}{\partial x}\right)$: Hydrostatic pressure gradient driven by the free-surface slope, since $\frac{\partial h}{\partial x} + \frac{\partial z_b}{\partial x} = \frac{\partial \zeta}{\partial x}$. (Note: $\frac{\partial z_b}{\partial x}$ is set to 0 here.)
    * $f_k u$: Linear bed friction.
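Because the pressure term is $g\,\partial\zeta/\partial x$, a level water surface over a sloping bed produces no driving force: the depth gradient exactly cancels the bed slope (the "lake at rest" state). The sketch below uses an illustrative linear bed; in this notebook the bed is flat, so the term reduces to $g\,\partial h/\partial x$.

```python
import numpy as np

g = 9.81
x = np.linspace(0.0, 6.0, 61)
z_b = 0.02 * x                 # illustrative bed rising 2 cm per metre
zeta = np.full_like(x, 0.3)    # level free surface ("lake at rest")
h = zeta - z_b                 # depth shrinks as the bed rises

h_x = np.gradient(h, x)
z_x = np.gradient(z_b, x)
pressure_term = g * (h_x + z_x)   # the g*(h_x + z_x) term of the x-momentum residual

print(np.max(np.abs(pressure_term)))  # ~0 up to round-off
```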

---
### 3. Momentum Equation in y-direction
Represented by `r3 = v_t + u*v_x + v*v_y + g*(h_y + z_y) + friction_factor*v`

Standard form:
$$ \frac{\partial v}{\partial t} + u\frac{\partial v}{\partial x} + v\frac{\partial v}{\partial y} + g\left(\frac{\partial h}{\partial y} + \frac{\partial z_b}{\partial y}\right) + f_k v = 0 $$

* **Terms:** Similar to the x-momentum equation, but for the y-direction. (Note: $\frac{\partial z_b}{\partial y}$ is treated as 0 here).

---
### Key Characteristics:
* **2D Depth-Averaged:** Models flow in two horizontal dimensions.
* **Non-linear:** Contains advective terms (e.g., $u \frac{\partial u}{\partial x}$).
* **Includes Gravity:** Accounts for gravitational forces due to surface slopes.
* **Includes Linear Bed Friction:** Uses a simplified friction model ($f_k u, f_k v$).
* **Handles Variable Bed Topography (in general form):** The equations are structured for bed slopes, but this snippet specifically evaluates them for a **locally flat bed** by setting $z_x=0$ and $z_y=0$.
* **Excludes:** Coriolis force, wind stress, and explicit source/sink terms (like rainfall) within these residual definitions.
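To gauge the strength of the linear friction term: for spatially uniform flow over a flat bed with a level surface, the advective and pressure terms vanish and the x-momentum equation reduces to $\frac{\partial u}{\partial t} = -f_k u$, with solution $u(t) = u_0 e^{-f_k t}$. With the main block's `friction_factor = 0.015` (units of 1/s) the e-folding time is $1/f_k \approx 67$ s, so over the 120 s simulation friction alone would reduce a velocity to about 16.5 % of its initial value. This is a back-of-envelope estimate, not a result from the trained model.

```python
import math

f_k = 0.015      # friction_factor from the main block (1/s)
u0 = 0.5         # u_inflow (m/s)
t_max = 120.0    # simulation length (s)

def friction_only_velocity(t, u0=u0, f_k=f_k):
    """Analytic solution of du/dt = -f_k * u (all other terms dropped)."""
    return u0 * math.exp(-f_k * t)

print(1.0 / f_k)                            # e-folding time, about 66.7 s
print(friction_only_velocity(t_max) / u0)   # about 0.165
```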

The code uses `torch.autograd.grad` for derivative calculations, applied to normalised coordinates and then scaled to physical coordinates, which is a common practice in PINNs.
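The scaling step is the chain rule: with $\hat{x} = s_x (x - x_{min}) - 1$, $\frac{\partial \phi}{\partial x} = s_x \frac{\partial \phi}{\partial \hat{x}}$ (and likewise for $y$ and $t$ with `scale_y`, `scale_t`). The NumPy sketch checks this on an arbitrary test field $\phi = \sin(\pi \hat{x})$, comparing a finite-difference derivative in physical $x$ against the chain-rule expression; the test field is purely illustrative.

```python
import numpy as np

x_min, x_max = 0.0, 6.0
scale_x = 2.0 / (x_max - x_min)

x = np.linspace(x_min, x_max, 2001)
x_hat = scale_x * (x - x_min) - 1.0          # normalised coordinate, as in `normalize`
phi = np.sin(np.pi * x_hat)                  # arbitrary smooth test field

dphi_dx_fd = np.gradient(phi, x, edge_order=2)            # finite differences in physical x
dphi_dx_chain = scale_x * np.pi * np.cos(np.pi * x_hat)   # scale_x * dphi/dx_hat

print(np.max(np.abs(dphi_dx_fd - dphi_dx_chain)))  # small: the two agree
```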

In [None]:
# --- Shallow Water Equations ---
def shallow_water_residuals(zeta, u, v, terrain_c,
                            Xc, Yc, Tc, Xn, Yn, Tn,
                            g=9.81, friction_factor=0.01):
    h = zeta - terrain_c  # water depth

    # Derivatives w.r.t. NORMALISED coordinates. create_graph=True keeps them
    # differentiable, so the PDE loss can back-propagate through h_t, u_x, etc.
    # autograd.grad returns one tensor per input, unpacked directly below.
    h_xhat, h_yhat, h_that = torch.autograd.grad(h, (Xn, Yn, Tn), grad_outputs=torch.ones_like(h), create_graph=True, retain_graph=True)
    u_xhat, u_yhat, u_that = torch.autograd.grad(u, (Xn, Yn, Tn), grad_outputs=torch.ones_like(u), create_graph=True, retain_graph=True)
    v_xhat, v_yhat, v_that = torch.autograd.grad(v, (Xn, Yn, Tn), grad_outputs=torch.ones_like(v), create_graph=True, retain_graph=True)

    # Convert derivatives to PHYSICAL coordinates
    h_x, h_y, h_t = h_xhat*scale_x, h_yhat*scale_y, h_that*scale_t
    u_x, u_y, u_t = u_xhat*scale_x, u_yhat*scale_y, u_that*scale_t
    v_x, v_y, v_t = v_xhat*scale_x, v_yhat*scale_y, v_that*scale_t

    # Terrain slope derivatives - Explicitly zero for flat terrain
    z_x = torch.zeros_like(terrain_c); z_y = torch.zeros_like(terrain_c)

    # Residuals
    r1 = h_t + u*h_x + h*u_x + v*h_y + h*v_y
    r2 = u_t + u*u_x + v*u_y + g*(h_x + z_x) + friction_factor*u
    r3 = v_t + u*v_x + v*v_y + g*(h_y + z_y) + friction_factor*v

    return r1, r2, r3
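Derivatives that enter a PINN loss need `create_graph=True`. With `create_graph=False`, `torch.autograd.grad` returns tensors that are detached from the graph, so no gradient flows from them back to the network weights. In `shallow_water_residuals` the residuals would still require grad through the `u`, `h`, `v` factors and the friction term, so training would run without an error while the `h_t`, `u_t`, `g*h_x`, ... contributions were silently ignored. The toy below (a single stand-in weight `w`, not the notebook's model) shows the effect:

```python
import torch

w = torch.tensor(3.0, requires_grad=True)          # stand-in for a network weight
x = torch.tensor([1.0, 2.0], requires_grad=True)   # stand-in for collocation coordinates

def grad_of_derivative_loss(create_graph):
    """Return d(loss)/dw for loss = mean((dh/dx)**2), h = w*x**2, or None if detached."""
    h = w * x**2
    (h_x,) = torch.autograd.grad(h, x, grad_outputs=torch.ones_like(h),
                                 create_graph=create_graph)
    loss = torch.mean(h_x**2)          # PDE-style loss built only from a derivative
    if not loss.requires_grad:
        return None                    # derivative was detached: nothing reaches w
    (g_w,) = torch.autograd.grad(loss, w)
    return g_w.item()

print(grad_of_derivative_loss(create_graph=False))  # None
print(grad_of_derivative_loss(create_graph=True))   # 60.0, since loss = 10*w**2
```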

In [None]:
# --- Training Loop ---
def train_pinn(model, opt, scheduler, x_grid, y_grid, n_epochs=8000,
               w_pde=1.0, w_ic=150.0, w_bc=150.0, w_phys=15.0,
               friction_factor=0.01, u_inflow=0.5):
    model.train(); loss_hist = []; start_time = time.time()

    # --- IC Setup: almost-dry bed (h = 1e-6) at rest at t = 0 ---
    print("Setting up initial conditions...")
    X0_np, Y0_np = np.meshgrid(x_grid, y_grid, indexing='xy'); T0_np = np.zeros_like(X0_np)
    terrain0_np = create_terrain_np(x_grid, y_grid)
    h0_np = np.full_like(X0_np, 1e-6); zeta0_np = h0_np + terrain0_np
    X0_t = torch.tensor(X0_np, dtype=torch.float32, device=device)
    Y0_t = torch.tensor(Y0_np, dtype=torch.float32, device=device)
    T0_t = torch.tensor(T0_np, dtype=torch.float32, device=device)
    zeta0_t = torch.tensor(zeta0_np, dtype=torch.float32, device=device)
    # Stack (x, y, t) along a new last axis, then flatten the grid to (N, 3)
    Xn0, Yn0, Tn0 = normalize(X0_t.unsqueeze(-1), Y0_t.unsqueeze(-1), T0_t.unsqueeze(-1))
    inp0 = torch.cat([Xn0, Yn0, Tn0], dim=2).view(-1, 3)
    # --- End IC Setup ---

    N_bc_each = N_bc//4

    print(f"Starting training for {n_epochs} epochs...");
    for ep in tqdm(range(n_epochs), desc=f'Training Channel Flow (u_in={u_inflow:.2f}, f={friction_factor:.3f})'):
        model.train(); opt.zero_grad()

        # Collocation points - sampling respects new domain size automatically
        x_c = torch.rand(N_collocation,1,device=device)*(domain['x_max']-domain['x_min'])+domain['x_min']
        y_c = torch.rand(N_collocation,1,device=device)*(domain['y_max']-domain['y_min'])+domain['y_min']
        t_c = torch.rand(N_collocation,1,device=device)*(domain['t_max']-domain['t_min'])+domain['t_min']
        Xc, Yc = x_c.clone().detach().requires_grad_(True), y_c.clone().detach().requires_grad_(True)
        Tc = t_c.clone().detach()
        terrain_c = create_terrain_torch(Xc, Yc)
        Xn, Yn, Tn_no_grad = normalize(Xc, Yc, Tc)
        Tn = Tn_no_grad.clone().detach().requires_grad_(True)  # fresh leaf so time derivatives can be taken
        zeta_c, u_c, v_c = model(torch.cat([Xn, Yn, Tn], dim=1)).split(1, dim=1)

        # PDE Loss
        r1,r2,r3=shallow_water_residuals(zeta_c,u_c,v_c,terrain_c,Xc,Yc,Tc,Xn,Yn,Tn,friction_factor=friction_factor)
        PDE_loss=torch.mean(r1**2)+torch.mean(r2**2)+torch.mean(r3**2)

        # Physical Loss
        h_c=zeta_c-terrain_c; PHYS_loss=torch.mean(F.relu(-h_c)**2)

        # IC Loss
        zeta0_pred,u0_pred,v0_pred=model(inp0).split(1,dim=1); zeta0_pred_grid=zeta0_pred.view_as(X0_t)
        IC_loss=torch.mean((zeta0_pred_grid-zeta0_t)**2)+torch.mean(u0_pred**2)+torch.mean(v0_pred**2)

        # BC Loss (sampling assumes x_min = y_min = t_min = 0, as in `domain` above)
        t_bc = torch.rand(N_bc_each*2, 1, device=device)*domain['t_max']
        x_bc_wall = torch.rand(N_bc_each, 1, device=device)*domain['x_max']
        y_bc_wall = torch.rand(N_bc_each, 1, device=device)*domain['y_max']
        def predict(x, y, t):  # returns (zeta, u, v) at physical coordinates
            return model(torch.cat(normalize(x, y, t), dim=1)).split(1, dim=1)
        # Inlet (x = x_min): u = u_inflow, v = 0
        _, u_in, v_in = predict(torch.full_like(y_bc_wall, domain['x_min']), y_bc_wall, t_bc[:N_bc_each])
        BC_loss = torch.mean((u_in - u_inflow)**2) + torch.mean(v_in**2)
        # Outlet wall (x = x_max): u = 0
        _, u_w1, _ = predict(torch.full_like(y_bc_wall, domain['x_max']), y_bc_wall, t_bc[N_bc_each:])
        BC_loss = BC_loss + torch.mean(u_w1**2)
        # Side walls (y = y_min and y = y_max): v = 0
        _, _, v_wy0 = predict(x_bc_wall, torch.full_like(x_bc_wall, domain['y_min']), t_bc[:N_bc_each])
        _, _, v_wy1 = predict(x_bc_wall, torch.full_like(x_bc_wall, domain['y_max']), t_bc[N_bc_each:])
        BC_loss = (BC_loss + torch.mean(v_wy0**2) + torch.mean(v_wy1**2)) / 4.0

        # Total Loss & Backprop
        loss = w_pde*PDE_loss + w_ic*IC_loss + w_bc*BC_loss + w_phys*PHYS_loss
        loss.backward()
        opt.step(); scheduler.step()  # the LR scheduler is stepped once per epoch

        # Logging
        if ep%500==0 or ep==n_epochs-1: loss_hist.append(loss.item()); elapsed=time.time()-start_time; print(f"Epoch {ep}/{n_epochs} | Loss: {loss.item():.4e} | PDE: {PDE_loss.item():.3e} | IC: {IC_loss.item():.3e} | BC: {BC_loss.item():.3e} | Phys: {PHYS_loss.item():.3e} | LR: {scheduler.get_last_lr()[0]:.2e} | Time: {elapsed:.1f}s")
    print(f"Training finished. Total time: {time.time()-start_time:.1f}s"); return loss_hist
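The boundary sampling in `train_pinn` multiplies `torch.rand` by `domain['x_max']`, `domain['y_max']` and `domain['t_max']`, which is only correct because every minimum is zero here. A minimal NumPy sketch of edge sampling that respects arbitrary minima follows; `uniform` and `sample_edge` are hypothetical helpers, and porting them to `torch.rand` would be mechanical.

```python
import numpy as np

domain = {'x_min': 0.0, 'x_max': 6.0, 'y_min': 0.0, 'y_max': 1.0,
          't_min': 0.0, 't_max': 120.0}
EDGES = ('inlet', 'outlet', 'y_min', 'y_max')

def uniform(rng, n, lo, hi):
    """n samples drawn uniformly from [lo, hi), returned with shape (n, 1)."""
    return lo + (hi - lo) * rng.random((n, 1))

def sample_edge(rng, n, edge, dom):
    """Hypothetical helper: (x, y, t) arrays of shape (n, 1) on one edge of `dom`."""
    if edge not in EDGES:
        raise ValueError(f"unknown edge: {edge}")
    t = uniform(rng, n, dom['t_min'], dom['t_max'])
    if edge in ('inlet', 'outlet'):
        # Constant x, random y across the channel
        y = uniform(rng, n, dom['y_min'], dom['y_max'])
        x = np.full_like(y, dom['x_min'] if edge == 'inlet' else dom['x_max'])
    else:
        # Constant y, random x along a side wall
        x = uniform(rng, n, dom['x_min'], dom['x_max'])
        y = np.full_like(x, dom[edge])
    return x, y, t

rng = np.random.default_rng(0)
x_in, y_in, t_in = sample_edge(rng, 625, 'inlet', domain)   # 625 = N_bc // 4

# The same code works for a domain whose minima are not zero
shifted = dict(domain, x_min=2.0, x_max=8.0, t_min=10.0, t_max=20.0)
x_s, y_s, t_s = sample_edge(rng, 625, 'y_max', shifted)
```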

In [None]:
# --- Visualisation Function (3D Surface Plot) ---
def visualize_results(model, x, y, t_vis, terrain_np, u_inflow, friction_factor, save_path="channel_3d.mp4"):
    # (This function adapts automatically to x, y grid shape)
    print(f"Generating 3D surface animation (will save to {save_path})...")
    model.eval(); Xp_np,Yp_np = np.meshgrid(x,y,indexing='xy'); Xp_t = torch.tensor(Xp_np,dtype=torch.float32,device=device).unsqueeze(-1); Yp_t = torch.tensor(Yp_np,dtype=torch.float32,device=device).unsqueeze(-1)
    min_terrain = np.min(terrain_np); max_terrain = np.max(terrain_np)
    fig=plt.figure(figsize=(12,6)); ax=fig.add_subplot(111,projection='3d') # Wide figure for the elongated channel
    z_min_plot=min_terrain-0.05; z_max_plot_init=max_terrain+0.3
    ax.set_zlim(z_min_plot,z_max_plot_init); print(f"Initial 3D Z limits: ({z_min_plot:.2f}, {z_max_plot_init:.2f})")
    water_surf=None; max_zeta_observed=-np.inf; min_depth_display=1e-4
    def update(frame):
        nonlocal water_surf,max_zeta_observed,z_max_plot_init
        ti=t_vis[frame]; ax.clear()
        with torch.no_grad():
            T_t=torch.full_like(Xp_t,fill_value=ti,dtype=torch.float32,device=device); Xn,Yn,Tn=normalize(Xp_t,Yp_t,T_t)
            inp=torch.cat([Xn,Yn,Tn],dim=2).view(-1,3); zeta_pred,_,_=model(inp).split(1,dim=1)
            zeta_np=zeta_pred.view(Xp_np.shape).cpu().numpy()
        max_zeta_observed=max(max_zeta_observed,np.max(zeta_np)); h_np=zeta_np-terrain_np
        zeta_display=np.where(h_np>min_depth_display,np.maximum(zeta_np,terrain_np+1e-5),np.nan)
        ax.plot_surface(Xp_np,Yp_np,terrain_np,color='dimgray',alpha=0.3,rstride=5,cstride=5)
        if not np.all(np.isnan(zeta_display)): water_surf=ax.plot_surface(Xp_np,Yp_np,zeta_display,color='skyblue',alpha=0.7,rstride=5,cstride=5)
        else: water_surf=None
        ax.set_title(f"Channel Flow 3D (u_in={u_inflow:.2f}, f={friction_factor:.3f}): Time = {ti:.1f}s")
        ax.set_xlabel('X (Flow Direction)'); ax.set_ylabel('Y (Width)'); ax.set_zlabel('Water Surface Elev.')
        current_z_max=max(z_max_plot_init,max_zeta_observed+0.05); ax.set_zlim(z_min_plot,current_z_max)
        # Set view angle
        ax.view_init(elev=25, azim=-110)
    ani=animation.FuncAnimation(fig,update,frames=len(t_vis),interval=50,blit=False) # 50 ms per frame = 20 fps, matching the saved MP4
    try:
        print(f"Saving 3D animation to {save_path}..."); writer=animation.FFMpegWriter(fps=20,metadata=dict(artist='PINN Flood Sim'),bitrate=1800)
        ani.save(save_path,writer=writer); print(f"3D animation successfully saved to {save_path}")
    except Exception as e: print(f"\n--- ERROR: Failed to save 3D animation ---\n{e}"); save_path=None
    html_video=ani.to_jshtml(); plt.close(fig); print(f"Max zeta observed (3D): {max_zeta_observed:.3f}"); print("3D surface animation generation complete.")
    return HTML(html_video),save_path

# --- Visualisation Function (2D Depth and Velocity Plot), figure sized to the channel aspect ratio ---
def visualize_velocity_2d(model, x, y, t_vis, terrain_np, u_inflow, friction_factor, save_path="channel_2d.mp4"):
    print(f"Generating 2D velocity animation (will save to {save_path})...")
    model.eval(); Xp_np,Yp_np = np.meshgrid(x,y,indexing='xy'); Xp_t = torch.tensor(Xp_np,dtype=torch.float32,device=device).unsqueeze(-1); Yp_t = torch.tensor(Yp_np,dtype=torch.float32,device=device).unsqueeze(-1)
    # Adjust figure size for better aspect ratio visualization
    aspect_ratio = (domain['x_max']-domain['x_min']) / (domain['y_max']-domain['y_min'])
    fig_width = 10
    fig_height = fig_width / aspect_ratio + 1.5 # Add some height for title etc.
    fig,ax=plt.subplots(figsize=(fig_width, fig_height));
    max_h_plot=0.2; max_v_scale=u_inflow*1.1; skip=3 # skip: draw an arrow at every 3rd grid point in x and y
    max_h_observed_anim=0.0; min_depth_display=1e-4

    def update_2d(frame):
        nonlocal max_h_plot,max_v_scale,max_h_observed_anim; ti=t_vis[frame]; ax.clear()
        with torch.no_grad():
            T_t=torch.full_like(Xp_t,fill_value=ti,dtype=torch.float32,device=device); Xn,Yn,Tn=normalize(Xp_t,Yp_t,T_t)
            inp=torch.cat([Xn,Yn,Tn],dim=2).view(-1,3); zeta_pred,u_pred,v_pred=model(inp).split(1,dim=1)
            zeta_np=zeta_pred.view(Xp_np.shape).cpu().numpy(); u_np=u_pred.view(Xp_np.shape).cpu().numpy(); v_np=v_pred.view(Xp_np.shape).cpu().numpy()
        h_np=zeta_np-terrain_np; h_plot=np.maximum(h_np,0); max_h_observed_anim=max(max_h_observed_anim,np.max(h_plot))
        if max_h_observed_anim > max_h_plot: max_h_plot=max_h_observed_anim
        depth_levels=np.linspace(min_depth_display,max_h_plot+1e-6,15)
        if len(depth_levels)>1: contour=ax.contourf(Xp_np,Yp_np,h_plot,levels=depth_levels,cmap='Blues',extend='max',alpha=0.8)
        else: contour=None
        quiver_mask=h_plot[::skip,::skip]>min_depth_display
        X_q=Xp_np[::skip,::skip][quiver_mask]; Y_q=Yp_np[::skip,::skip][quiver_mask]; U_q=u_np[::skip,::skip][quiver_mask]; V_q=v_np[::skip,::skip][quiver_mask]
        if X_q.size>0:
            current_max_v = np.max(np.sqrt(U_q**2+V_q**2)) if U_q.size>0 else 0.01
            if current_max_v > max_v_scale: max_v_scale = current_max_v
            dynamic_scale = max(max_v_scale,0.01)*20 # Adjust scale multiplier
            quiver = ax.quiver(X_q, Y_q, U_q, V_q, color='red', scale=dynamic_scale, scale_units='xy', angles='xy', width=0.004, headwidth=3, headlength=5)
        else: quiver=None
        ax.set_title(f"Channel Flow 2D (u_in={u_inflow:.2f}, f={friction_factor:.3f}): Time = {ti:.1f}s")
        ax.set_xlabel("X (Flow Direction)"); ax.set_ylabel("Y (Width)")
        # 'auto' lets the plot fill the axes box (x and y are not drawn to equal scale)
        ax.set_aspect('auto', adjustable='box')
        ax.set_xlim(domain['x_min'],domain['x_max']); ax.set_ylim(domain['y_min'],domain['y_max'])
    ani_2d=animation.FuncAnimation(fig,update_2d,frames=len(t_vis),interval=50,blit=False)  # 50 ms per frame = 20 fps, matching the saved MP4
    try:
        print(f"Saving 2D animation to {save_path}..."); writer=animation.FFMpegWriter(fps=20,metadata=dict(artist='PINN Flood Sim'),bitrate=1800)
        ani_2d.save(save_path,writer=writer); print(f"2D animation successfully saved to {save_path}")
    except Exception as e: print(f"\n--- ERROR: Failed to save 2D animation ---\n{e}"); save_path=None
    html_video_2d=ani_2d.to_jshtml(); plt.close(fig); print(f"Max depth observed (2D): {max_h_observed_anim:.3f}"); print("2D velocity animation generation complete.")
    return HTML(html_video_2d),save_path
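Both animation functions hide dry cells: the 3D plot replaces $\zeta$ with NaN where $h \le$ `min_depth_display`, and the 2D plot thins the grid by `skip` and keeps quiver arrows only on wet points. A small NumPy sketch of those two masking steps on a made-up depth field with the notebook's 30 x 90 grid shape (rows = y, columns = x, as produced by `meshgrid(..., indexing='xy')`):

```python
import numpy as np

min_depth_display = 1e-4
skip = 3

# Made-up fields: wet (h = 0.05) only in the first 30 columns of the channel
h = np.zeros((30, 90))
h[:, :30] = 0.05
u = np.full_like(h, 0.5)
v = np.zeros_like(h)
terrain = np.zeros_like(h)
zeta = terrain + h

# 3D plot: NaN hides dry cells from plot_surface
zeta_display = np.where(h > min_depth_display, np.maximum(zeta, terrain + 1e-5), np.nan)

# 2D plot: thin the grid, then keep only wet points for quiver
wet = h[::skip, ::skip] > min_depth_display
U_q = u[::skip, ::skip][wet]
V_q = v[::skip, ::skip][wet]
print(np.isnan(zeta_display).sum(), U_q.size)  # 1800 dry cells hidden, 100 arrows kept
```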

In [None]:
# --- Main Execution Block ---
if __name__ == '__main__':
    # --- Setup Grid and Time for Rectangular Domain ---
    grid_res_x = 90 # More points along the length
    grid_res_y = 30 # Fewer points across the width
    xg = np.linspace(domain['x_min'], domain['x_max'], grid_res_x)
    yg = np.linspace(domain['y_min'], domain['y_max'], grid_res_y)
    t_vis_frames = 80 # Number of animation frames; 80 frames at 20 fps give a 4 s video
    tg_vis = np.linspace(domain['t_min'], domain['t_max'], t_vis_frames)
    terrain_np_vis = create_terrain_np(xg, yg) # Will be flat
    print(f"Flat channel terrain created (Shape: {terrain_np_vis.shape}).")

    # --- Model, Optimiser, Scheduler ---
    model = FloodNet(in_dim=3, hid_dim=64, out_dim=3).to(device)
    learning_rate = 5e-4
    opt = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = ExponentialLR(opt, gamma=0.997) # Stepped every epoch: lr drops by a factor e roughly every 333 epochs

    # --- Simulation Parameters ---
    n_epochs = 10000       # Adjust based on results; the longer domain may need more
    friction_factor = 0.015 # Linear friction coefficient f_k (1/s)
    u_inflow = 0.5         # Inflow velocity at x = x_min
    # Loss weights (tune these); the initial guess used w_pde = 1.0 with the other weights unchanged
    weights = {'w_pde': 50.0, 'w_ic': 150.0, 'w_bc': 150.0, 'w_phys': 15.0}
    # -------------------------------------------------------------------------

    # --- Train the Model ---
    print("-" * 60); print(f"Starting training: Rectangular Channel Flow Scenario"); print(f"Domain: x=[{domain['x_min']},{domain['x_max']}], y=[{domain['y_min']},{domain['y_max']}], t=[{domain['t_min']},{domain['t_max']}]")
    print(f"Parameters: epochs={n_epochs}, lr={learning_rate}, friction={friction_factor}, u_in={u_inflow}"); print(f"Weights: PDE={weights['w_pde']}, IC={weights['w_ic']}, BC={weights['w_bc']}, Phys={weights['w_phys']}"); print("-" * 60)
    loss_history = train_pinn(model, opt, scheduler, xg, yg, n_epochs=n_epochs,
                              friction_factor=friction_factor, u_inflow=u_inflow, **weights)

    # --- Plot Loss History ---
    clear_output(wait=True)
    print("Training complete. Plotting loss history...")
    plt.figure(figsize=(8, 4)); plt.plot(np.arange(len(loss_history))*500, loss_history); plt.yscale('log'); plt.title('Training Loss History (Channel Flow)'); plt.xlabel('Epoch'); plt.ylabel('Loss (log scale)'); plt.grid(True, which='both', ls='--'); plt.show()

    # --- Generate, Save, and Display Visualisations ---
    ff_str = f"{friction_factor:.3f}".replace('.', 'p')
    save_name_3d = f"channel_flow_rect_f{ff_str}_3d.mp4"
    save_name_2d = f"channel_flow_rect_f{ff_str}_2d.mp4"

    html_3d, saved_3d_path = visualize_results(model, xg, yg, tg_vis, terrain_np_vis, u_inflow, friction_factor, save_path=save_name_3d)
    if saved_3d_path: print(f"3D Animation saved to: {os.path.abspath(saved_3d_path)}")
    display(html_3d)

    html_2d, saved_2d_path = visualize_velocity_2d(model, xg, yg, tg_vis, terrain_np_vis, u_inflow, friction_factor, save_path=save_name_2d)
    if saved_2d_path: print(f"2D Animation saved to: {os.path.abspath(saved_2d_path)}")
    display(html_2d)

    print("-" * 60); print("Script finished."); print("-" * 60)
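`ExponentialLR(opt, gamma=0.997)` is stepped once per epoch, so the learning rate after $n$ epochs is $\text{lr}_0 \gamma^n$. With $\text{lr}_0 = 5\times10^{-4}$ that is about $1.2\times10^{-6}$ by epoch 2000 and about $4.5\times10^{-17}$ by epoch 10000, which suggests most of the 10000 epochs change the weights very little. One option is to pick $\gamma$ from a target final rate, $\gamma = (\text{lr}_{final}/\text{lr}_0)^{1/n}$; the 100x decay target below is an illustrative assumption, not a tuned value.

```python
import math

lr0, gamma = 5e-4, 0.997     # values from the main block

def lr_after(n_epochs, lr0=lr0, gamma=gamma):
    """Learning rate of ExponentialLR after n scheduler.step() calls."""
    return lr0 * gamma ** n_epochs

for n in (0, 1000, 2000, 10000):
    print(n, f"{lr_after(n):.2e}")

def gamma_for(lr_final, n_epochs, lr0=lr0):
    """Per-step decay factor that takes lr0 to lr_final in n_epochs steps."""
    return (lr_final / lr0) ** (1.0 / n_epochs)

print(gamma_for(5e-6, 10000))  # about 0.99954 (100x decay over 10000 epochs)
```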