In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch

from utils import make_grid
from visualise import visualise_v_quiver

# Use BedMachine bed topography (surface elevation minus ice thickness) as the "real" data

In [1]:
import os
from pathlib import Path

# Directory holding the input tensors; override with the ICE_THICKNESS_DATA_DIR env var
# instead of editing an absolute path in the notebook.
DATA_DIR = Path(os.environ.get("ICE_THICKNESS_DATA_DIR", "/home/kim/ice_thickness/data"))

# NOTE(review): weights_only=False unpickles arbitrary Python objects, which can execute
# code. Only load these files from a trusted source.
thickness = torch.load(DATA_DIR / "icethickness_bedmachine_m_h_x_y.pt", weights_only = False).float()
surface = torch.load(DATA_DIR / "surface_m_ellipsoid_true_s_x_y.pt", weights_only = False).float()

# Index 0 of the first dim holds the values. Presumably the remaining entries are x/y
# coordinates (per the file names) — TODO confirm.
bed = surface[0] - thickness[0]

# Sub-region of the bed used as the "real" field for the rest of the notebook.
BED_ROWS = slice(0, 150)
BED_COLS = slice(100, 250)
bed_selection = bed[BED_ROWS, BED_COLS]

std = bed_selection.std()
mean = bed_selection.mean()

# Normalise the selection to mean 0 / std 1.
bed_selection_norm = (bed_selection - mean) / std

fig, ax = plt.subplots(figsize = (6, 6))
mesh = ax.pcolormesh(bed_selection_norm, cmap = 'viridis', shading = 'auto')
ax.set_aspect('equal')
fig.colorbar(mesh, ax = ax, label = 'Normalised Bed Selection')
ax.set_title("Normalised Bed Selection")
fig.tight_layout()
plt.show()

NameError: name 'torch' is not defined

In [None]:
# Box-filter smoothing of the normalised bed.
# - padding = kernel // 2 keeps the output the same size as the input. With kernel 5 and
#   padding 1, a 150 x 150 selection would shrink to 148 x 148.
# - count_include_pad=False averages border windows over real pixels only, instead of
#   pulling them toward the zero padding.
SMOOTHING_KERNEL = 5
bed_smoothed = torch.nn.functional.avg_pool2d(
    bed_selection_norm.unsqueeze(0).unsqueeze(0),  # (1, 1, H, W) as avg_pool2d expects
    kernel_size = SMOOTHING_KERNEL,
    stride = 1,
    padding = SMOOTHING_KERNEL // 2,
    count_include_pad = False)[0, 0]
assert bed_smoothed.shape == bed_selection_norm.shape

fig, ax = plt.subplots(figsize = (6, 6))
c = ax.pcolormesh(bed_smoothed, cmap = 'viridis')
ax.set_aspect('equal')
ax.set_title("Smoothed Bed")

# Add colorbar next to the axis
cbar = plt.colorbar(c, ax = ax, fraction = 0.046, pad = 0.04)  # tweak these as needed
cbar.set_label('normalised bed elevation')

plt.tight_layout()
plt.show()

In [2]:
def generate_directed_stream(x_grid, angle_degree):
    """
    Generate a linear stream function psi(x, y) = cos(theta) * x + sin(theta) * y.

    Arguments:
        x_grid: Tensor whose last dimension holds (x, y) coordinates, e.g. shape (H, W, 2).
            Any leading shape is accepted.
        angle_degree: theta in degrees (Python number or 0-d tensor).
    Returns:
        Tensor with x_grid's leading shape (e.g. (H, W)) holding psi at every grid point.

    psi is constant along lines perpendicular to (cos theta, sin theta). With the velocity
    convention u = d(psi)/dy, v = -d(psi)/dx the resulting flow is uniform with direction
    (sin theta, -cos theta), i.e. theta - 90 degrees counter-clockwise from the +x axis.
    """
    # Build the angle in the grid's floating dtype and on its device so the weights
    # combine with x_grid without surprises.
    angle_dtype = x_grid.dtype if x_grid.is_floating_point() else torch.float32
    angle_rad = torch.deg2rad(torch.as_tensor(angle_degree, dtype = angle_dtype, device = x_grid.device))

    a = torch.cos(angle_rad)  # x weight
    b = torch.sin(angle_rad)  # y weight

    # Use the x_grid argument (not a notebook-level grid) so the result, and its autograd
    # graph, follow the grid that was actually passed in.
    directed_stream = a * x_grid[..., 0] + b * x_grid[..., 1]

    return directed_stream

In [None]:
# Uniform-flow stream function at 250 degrees, used as the default underlying stream
# for compose_stream_from_bed.
# NOTE(review): x_test_grid is not created in any cell shown here (presumably via
# utils.make_grid). Define it in an earlier cell so the notebook runs top to bottom.
baseline_stream = generate_directed_stream(x_test_grid, 250)

In [None]:
def compose_stream_from_bed(x_grid, underlying_stream = None, bed = None, bed_weight = 0.01):
    """Compose a stream function from a base stream plus a scaled, normalised bed field.

    Arguments:
        x_grid: Tensor of shape (H, W, 2) holding (x, y) coordinates expected in [0, 1].
            They are rescaled to [-1, 1] for ``grid_sample``.
        underlying_stream: Base stream function of shape (H, W). If None, the notebook-level
            ``baseline_stream`` is looked up at call time.
        bed: 2-D bed elevation field of any size (h, w). It is resampled onto x_grid.
            If None, the notebook-level ``bed_smoothed`` is looked up at call time.
        bed_weight: Scale applied to the normalised bed before it is added.
    Returns:
        Tensor of shape (H, W): underlying_stream + bed_weight * (bed sampled on x_grid,
        normalised to mean 0 / std 1).
    """
    # Resolve defaults here rather than in the signature. Signature defaults are evaluated
    # once, when the defining cell runs, so later re-assignments would be silently ignored.
    if underlying_stream is None:
        underlying_stream = baseline_stream
    if bed is None:
        bed = bed_smoothed

    # Resample the bed elevation (arbitrary size) onto the grid.
    bed_on_grid = torch.nn.functional.grid_sample(
        bed.unsqueeze(0).unsqueeze(0),   # (1, 1, h, w): add batch and channel dimensions
        x_grid.unsqueeze(0) * 2 - 1,     # (1, H, W, 2): add batch dimension, map [0, 1] -> [-1, 1]
        align_corners = True,
        mode = 'bilinear')[0, 0]         # (H, W): drop batch and channel dimensions

    # Normalise to mean 0 / std 1 using DETACHED statistics.
    # If gradients flowed through the mean, sum(bed_on_grid - mean) would be identically 0,
    # so autograd.grad(..., grad_outputs=ones) would return a zero bed contribution.
    # Detaching makes each output depend only on its own grid point; forward values are
    # unchanged. The clamp avoids 0/0 = NaN for a constant bed.
    mean = bed_on_grid.mean().detach()
    std = bed_on_grid.std().detach().clamp_min(torch.finfo(bed_on_grid.dtype).eps)
    bed_on_grid_norm = (bed_on_grid - mean) / std

    # Combined
    combined_stream = underlying_stream + bed_on_grid_norm * bed_weight

    return combined_stream

In [None]:
# Track gradients on the grid so psi can later be differentiated with respect to it.
x_test_grid.requires_grad_(True)

# psi = cos(theta) * x + sin(theta) * y. With u = d(psi)/dy, v = -d(psi)/dx the uniform flow
# points theta - 90 degrees counter-clockwise from the +x axis, i.e. 250 degrees for theta = 340.
# NOTE(review): the compass reading depends on the plot's axis orientation —
# confirm 340 degrees is the intended angle.
baseline_stream = generate_directed_stream(
    x_grid = x_test_grid,
    angle_degree = 340)

combined_stream = compose_stream_from_bed(
    x_grid = x_test_grid,
    underlying_stream = baseline_stream,
    bed = bed_smoothed,
    bed_weight = 0.02)

fig, ax = plt.subplots(figsize = (6, 6))
c = ax.pcolormesh(combined_stream.detach().cpu(), cmap = 'viridis')
ax.set_aspect('equal')
ax.set_title("Combined Stream Function (flow along contours)")
ax.set_xlabel("grid column index")
ax.set_ylabel("grid row index")

# Add colorbar next to the axis
cbar = plt.colorbar(c, ax = ax, fraction = 0.046, pad = 0.04)  # tweak these as needed
cbar.set_label('stream function value')

plt.tight_layout()
plt.show()

In [None]:
# Turn the stream function psi into a divergence-free velocity field via its partial derivatives: u = d(psi)/dy, v = -d(psi)/dx

In [None]:
# Gradient of psi with respect to every grid coordinate. A ones grad_outputs back-propagates
# sum(psi); this equals the per-point gradient only if each psi value depends solely on
# its own grid point.
(grad_psi,) = torch.autograd.grad(
    combined_stream,
    x_test_grid,
    grad_outputs = torch.ones_like(combined_stream),
)  # (n_side, n_side, 2): [..., 0] = d(psi)/dx, [..., 1] = d(psi)/dy

# Velocity via the curl of the stream function: u = d(psi)/dy, v = -d(psi)/dx.
u, v = grad_psi[..., 1], -grad_psi[..., 0]

velocity = torch.stack((u, v), dim = -1)  # (n_side, n_side, 2)

In [None]:
def finite_diff_grad(field, dx=1.0, dy=1.0):
    """Finite-difference gradient of a 2-D scalar field.

    Interior points use central differences. The first/last column (for d/dx) and the
    first/last row (for d/dy) repeat the nearest interior value, so the output keeps the
    input's spatial size.

    Args:
        field: Tensor of shape (H, W) with H >= 3 and W >= 3. Columns are x, rows are y.
        dx: Spacing between neighbouring columns.
        dy: Spacing between neighbouring rows.

    Returns:
        Tensor of shape (H, W, 2) with [..., 0] = d(field)/dx and [..., 1] = d(field)/dy.

    Raises:
        ValueError: if field is not 2-D or has fewer than 3 rows or columns
            (central differences need at least one interior point).
    """
    if field.dim() != 2:
        raise ValueError(f"field must be 2-D (H, W), got shape {tuple(field.shape)}")
    n_rows, n_cols = field.shape
    if n_rows < 3 or n_cols < 3:
        raise ValueError(f"field must be at least 3 x 3, got {n_rows} x {n_cols}")

    # Central differences on the interior.
    dpsi_dx = (field[:, 2:] - field[:, :-2]) / (2 * dx)
    dpsi_dy = (field[2:, :] - field[:-2, :]) / (2 * dy)

    # Pad edges by repeating the nearest interior value.
    dpsi_dx = torch.cat([dpsi_dx[:, :1], dpsi_dx, dpsi_dx[:, -1:]], dim=1)
    dpsi_dy = torch.cat([dpsi_dy[:1, :], dpsi_dy, dpsi_dy[-1:, :]], dim=0)

    # Combine to shape (H, W, 2)
    return torch.stack([dpsi_dx, dpsi_dy], dim=-1)


# Finite-difference counterpart of the autograd velocity. The spacing defaults to one grid
# cell, so these derivatives are in index units. They are presumably not on the same scale
# as the autograd gradient with respect to x_test_grid's coordinates — verify before comparing.
grad_psi = finite_diff_grad(field=combined_stream)

# u = d(psi)/dy, v = -d(psi)/dx
u, v = grad_psi[..., 1], -grad_psi[..., 0]

velocity = torch.stack((u, v), dim=-1)  # (H, W, 2)

In [None]:
import numpy as np

# Quiver plot of the velocity field (velocity: (H, W, 2) with [..., 0] = u, [..., 1] = v).
H, W = velocity.shape[:2]
U = velocity[..., 0].detach().cpu().numpy()  # x-component (u)
V = velocity[..., 1].detach().cpu().numpy()  # y-component (v)

# Arrow positions in index coordinates: X runs along columns, Y along rows.
X, Y = np.meshgrid(np.arange(W), np.arange(H))

# Draw every `step`-th arrow so the field stays readable.
step = 5
fig, ax = plt.subplots(figsize = (6, 6))
# angles='xy' draws arrows in data coordinates, so they flip together with the inverted
# y-axis below. The default 'uv' ignores axis inversion and would draw the y-component
# pointing the wrong way.
ax.quiver(X[::step, ::step], Y[::step, ::step],
          U[::step, ::step], V[::step, ::step],
          angles='xy', color='blue', headwidth=3)

ax.invert_yaxis()  # Flip y-axis to match matrix layout (row 0 at the top)
ax.axis('equal')
ax.set_title('Velocity Field from Stream Function')
fig.tight_layout()
plt.show()

In [None]:
# Fresh leaf copy of the grid so gradients are taken with respect to a clean tensor.
x_test_grid = x_test_grid.clone().detach().requires_grad_(True)

# Build BOTH parts of psi from the new leaf. The default underlying stream is a tensor
# computed from an earlier x_test_grid, so it has no autograd path to this leaf and its
# contribution would silently drop out of the gradient.
# NOTE(review): 340 degrees mirrors the angle used where combined_stream is first plotted; confirm.
flow_angle_degree = 340
combined_stream = compose_stream_from_bed(
    x_grid = x_test_grid,
    underlying_stream = generate_directed_stream(x_grid = x_test_grid, angle_degree = flow_angle_degree),
    bed = bed_smoothed)

# A ones grad_outputs back-propagates sum(psi). This equals the per-point gradient only
# when each psi value depends on its own grid point alone.
grad_stream = torch.autograd.grad(
    outputs=combined_stream,                       # e.g. (128, 128)
    inputs=x_test_grid,                            # e.g. (128, 128, 2)
    grad_outputs=torch.ones_like(combined_stream),
    create_graph=True,
)[0]  # same shape as x_test_grid, e.g. (128, 128, 2)

In [None]:
import numpy as np

# grad_stream: (H, W, 2) with [..., 0] = d(psi)/dx and [..., 1] = d(psi)/dy.
H, W = grad_stream.shape[:2]

# Arrow positions in index coordinates (x = columns, y = rows).
X, Y = np.meshgrid(np.arange(W), np.arange(H))

U = grad_stream[..., 0].detach().cpu().numpy()  # d(psi)/dx
V = grad_stream[..., 1].detach().cpu().numpy()  # d(psi)/dy

# Draw every `step`-th arrow for a cleaner plot.
step = 5
fig, ax = plt.subplots(figsize=(6, 6))
# angles='xy': arrows are drawn in data coordinates so they flip with the inverted y-axis.
# The default 'uv' ignores axis inversion and would misdraw the y-component.
ax.quiver(X[::step, ::step], Y[::step, ::step],
          U[::step, ::step], V[::step, ::step],
          angles='xy', color='blue', headwidth=3)

ax.invert_yaxis()  # Flip y to match matrix layout (row 0 at the top)
ax.axis('equal')
ax.set_title('Gradient of Stream Function')
fig.tight_layout()
plt.show()

# Jacobian of the stream function with respect to the grid coordinates

In [None]:
# Full Jacobian of psi (H, W) with respect to the grid (H, W, 2): shape (H, W, H, W, 2).
# NOTE(review): for a 128 x 128 float32 grid this is 128**4 * 2 values (~2 GiB), and only
# its diagonal is used below.
# NOTE(review): compose_stream_from_bed runs with its default underlying stream, a
# precomputed tensor that does not depend on the input, so it contributes nothing here.
full_jac = torch.autograd.functional.jacobian(func = compose_stream_from_bed, inputs = x_test_grid)

In [None]:
# full_jac: (H, W, H, W, 2), output dims (H, W) first, then input dims (H, W, 2), e.g. 128 x 128 grid.
# Pair each psi[i, j] with its own grid point: keep J[i, j, i, j, :] -> (H, W, 2).
diag_jac = torch.einsum('ijijc->ijc', full_jac)
diag_jac.shape  # (H, W, 2): output location, then d/dx and d/dy

In [None]:
# Per-point gradient from the Jacobian diagonal: [..., 0] = d(psi)/dx, [..., 1] = d(psi)/dy.
U = diag_jac[..., 0].detach().cpu().numpy()  # d(psi)/dx
V = diag_jac[..., 1].detach().cpu().numpy()  # d(psi)/dy

# Build arrow positions from diag_jac itself instead of reusing X, Y left over from another cell.
H, W = diag_jac.shape[:2]
X, Y = np.meshgrid(np.arange(W), np.arange(H))

# Draw every `step`-th arrow for a cleaner plot.
step = 5
fig, ax = plt.subplots(figsize=(6, 6))
# angles='xy': arrows are drawn in data coordinates so they flip with the inverted y-axis.
ax.quiver(X[::step, ::step], Y[::step, ::step],
          U[::step, ::step], V[::step, ::step],
          angles='xy', color='blue', headwidth=3)

ax.invert_yaxis()  # Flip y to match matrix layout (row 0 at the top)
ax.axis('equal')
ax.set_title('Gradient of Stream Function (Jacobian diagonal)')
fig.tight_layout()
plt.show()