In [23]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.animation import FuncAnimation
import torch
import torch.nn as nn
import torch.nn.functional as F
from IPython.display import HTML

In [24]:
def generate_time_varying_flow(size=224, timesteps=60, freq=0.5, min_magnitude=0.1, max_magnitude=0.9, center=None):
    if center is None:
        center = (size//2, size//2)

    def generate_flow(t):
        y, x = np.mgrid[:size, :size]
        x = x - center[0]
        y = y - center[1]
    
        # Calculate distance from center
        r = np.sqrt(x**2 + y**2) + 1e-6  # Add small epsilon to avoid division by zero
    
        # Normalize coordinates
        x_norm = x / (size/2)  # ranges roughly from -1 to 1
        y_norm = y / (size/2)  # ranges roughly from -1 to 1
        r_norm = r / size      # ranges roughly from 0 to 0.7
    
        # Create the basic flow pattern
        u = y_norm * np.sin(r_norm * 5 + t) + x_norm * np.cos(freq * t)
        v = x_norm * np.sin(r_norm * 5 + t) - y_norm * np.cos(freq * t)
    
        # Calculate current magnitudes
        magnitude = np.sqrt(u**2 + v**2)
    
        # Normalize to [0,1] range
        magnitude_max = np.max(magnitude)
        magnitude_min = np.min(magnitude)
    
        if magnitude_max > magnitude_min:  # Avoid division by zero
            # First normalize to [0,1]
            normalized_magnitude = (magnitude - magnitude_min) / (magnitude_max - magnitude_min)
        
            # Then scale to [min_magnitude, max_magnitude]
            target_magnitude = min_magnitude + normalized_magnitude * (max_magnitude - min_magnitude)
        
            # Scale the vectors to have the target magnitude while preserving direction
            scale_factor = target_magnitude / (magnitude + 1e-10)
            u = u * scale_factor
            v = v * scale_factor
    
        return np.stack([u, v], axis=-1)
    
    flows = np.zeros((timesteps, size, size, 2))
    for i in range(timesteps):
        t = i * 2*np.pi / timesteps
        flows[i] = generate_flow(t=t)

    return flows

In [25]:
def animate_flow_field(flows, fps=30, density=20):
    # Get dimensions from flows array
    n_frames, size, _, _ = flows.shape
    
    # Set up the figure
    fig, ax = plt.subplots(figsize=(10, 10))
    
    # Create a custom colormap
    colors = [(0, 0, 0.5), (0, 0.5, 1), (1, 1, 1), (1, 0.5, 0), (0.5, 0, 0)]
    cmap = LinearSegmentedColormap.from_list('flow_cmap', colors, N=100)
    
    # Initial flow field
    u, v = flows[0, ..., 0], flows[0, ..., 1]
    magnitude = np.sqrt(u**2 + v**2)
    
    # Plot initial background
    img = ax.imshow(magnitude, cmap=cmap, origin='lower', animated=True)
    fig.colorbar(img, label='Velocity Magnitude')
    
    # Subsample for quiver plot
    step = size // density
    y, x = np.mgrid[:size:step, :size:step]
    u_sub = u[::step, ::step]
    v_sub = v[::step, ::step]
    
    # Create quiver plot
    quiver = ax.quiver(x, y, u_sub, v_sub, 
                      angles='xy', scale_units='xy', scale=0.1,
                      color='black', width=0.003, 
                      headwidth=4, headlength=5, headaxislength=3)
    
    # Add title with frame indicator
    title = ax.set_title(f"Flow Field (frame: 0/{n_frames-1})")
    
    # Animation update function
    def update(frame):
        # Get the flow field for this frame
        u, v = flows[frame, ..., 0], flows[frame, ..., 1]
        magnitude = np.sqrt(u**2 + v**2)
        
        # Update background image
        img.set_array(magnitude)
        
        # Update quiver
        u_sub = u[::step, ::step]
        v_sub = v[::step, ::step]
        quiver.set_UVC(u_sub, v_sub)
        
        # Update title
        title.set_text(f"Flow Field (frame: {frame}/{n_frames-1})")
        
        return img, quiver, title
    
    # Create animation
    anim = FuncAnimation(fig, update, frames=n_frames, 
                         interval=1000/fps, blit=True)
    
    plt.close(fig)
    
    return anim

In [26]:
SIZE = 256
TIMESTEPS = 10
flows = generate_time_varying_flow(size=SIZE, timesteps=TIMESTEPS)

anim = animate_flow_field(flows, fps=1, density=SIZE//16)
HTML(anim.to_jshtml())

In [32]:
class ConcatSquashLinear(nn.Module):
	def __init__(self, dim_in, dim_out, dim_c):
		super(ConcatSquashLinear, self).__init__()
		self._layer = nn.Linear(dim_in, dim_out)
		self._hyper_bias = nn.Linear(dim_c, dim_out, bias=False)
		self._hyper_gate = nn.Linear(dim_c, dim_out)

	def forward(self, context, x):
		gate = torch.sigmoid(self._hyper_gate(context))
		bias = self._hyper_bias(context)
		ret = self._layer(x) * gate + bias
		return ret

class FourierFeatureODE(nn.Module):
	def __init__(self, input_dim, hidden_dims, fourier_scale=10.0, num_fourier_features=10):
		super(FourierFeatureODE, self).__init__()
		
		self.input_dim = input_dim
		self.num_fourier_features = num_fourier_features
		
		self.register_buffer('B', torch.randn(input_dim, num_fourier_features) * fourier_scale)
		
		fourier_expanded_dim = input_dim + (2 * num_fourier_features)
		dim_list = [fourier_expanded_dim] + list(hidden_dims) + [input_dim]
		layers = []
		for i in range(len(dim_list) - 1):
			layers.append(ConcatSquashLinear(dim_list[i], dim_list[i + 1], 1))
		self.layers = nn.ModuleList(layers)
	
	def compute_positional_fourier_features(self, x):
		encodings = [x]
		for i in range(self.num_fourier_features // 2):
			freq = 2.0 ** i
			sin_features = torch.sin(freq * x)
			cos_features = torch.cos(freq * x)
			encodings.append(sin_features)
			encodings.append(cos_features)
		return torch.cat(encodings, dim=-1)

	def forward(self, coordinates, times):
		coordinates_fourier = self.compute_positional_fourier_features(coordinates)
		h = coordinates_fourier
		for l, layer in enumerate(self.layers):
			h = layer(times, h)
			if l < len(self.layers) - 1:
				h = F.tanh(h)
		return h

In [33]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

cuda


In [34]:
y_coords = torch.arange(SIZE).view(1, 1, SIZE, 1).expand(TIMESTEPS, SIZE, SIZE, 1)
x_coords = torch.arange(SIZE).view(1, SIZE, 1, 1).expand(TIMESTEPS, SIZE, SIZE, 1)
coordinate_tensor = torch.cat([x_coords, y_coords], dim=3)
coordinate_tensor = coordinate_tensor.float()
coordinate_tensor = coordinate_tensor.to(device)

time_tensor = torch.FloatTensor([t * 2*np.pi / TIMESTEPS for t in range(TIMESTEPS)])
time_tensor = time_tensor.view(TIMESTEPS, 1, 1, 1)
time_tensor = time_tensor.expand(TIMESTEPS, SIZE, SIZE, 1)
time_tensor = time_tensor.to(device)

In [35]:
def fit_flow_field(ode, flow_field, coordinates, times):
    optim = torch.optim.Adam(ode.parameters(), lr=1e-3, weight_decay=0)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optim, gamma=0.999)

    for epoch in range(10000):
        estimated_flow_field = ode(coordinates, times)

        loss = F.mse_loss(flow_field, estimated_flow_field)
        if epoch == 0 or (epoch < 1000 and (epoch + 1) % 100 == 0) or (epoch + 1) % 1000 == 0:
            print(f'Epoch {epoch+1}, Loss: {loss}')

        optim.zero_grad()
        loss.backward()
        optim.step()
        scheduler.step()

    return estimated_flow_field

In [36]:
flows_tensor = torch.FloatTensor(flows)
flows_tensor = flows_tensor.to(device)
ode = FourierFeatureODE(2, (256 for _ in range(3)), num_fourier_features=126).to(device)
estimated_flows_tensor = fit_flow_field(ode, flows_tensor, coordinate_tensor, time_tensor)

Epoch 1, Loss: 4.226852893829346
Epoch 100, Loss: 0.037353772670030594
Epoch 200, Loss: 0.012350354343652725
Epoch 300, Loss: 0.004505823832005262
Epoch 400, Loss: 0.002251699334010482
Epoch 500, Loss: 0.001512794871814549
Epoch 600, Loss: 0.0011485472787171602
Epoch 700, Loss: 0.0009340856340713799
Epoch 800, Loss: 0.0007892322028055787
Epoch 900, Loss: 0.0006869665230624378
Epoch 1000, Loss: 0.0006110970280133188
Epoch 2000, Loss: 0.0003342716081533581
Epoch 3000, Loss: 0.0002661958278622478
Epoch 4000, Loss: 0.00023915377096273005
Epoch 5000, Loss: 0.00022653581982012838
Epoch 6000, Loss: 0.00022023313795216382
Epoch 7000, Loss: 0.00021699443459510803
Epoch 8000, Loss: 0.00021531821403186768
Epoch 9000, Loss: 0.00021445580932777375
Epoch 10000, Loss: 0.00021402715356089175


In [37]:
estimated_flows = estimated_flows_tensor.detach().cpu().numpy()
anim = animate_flow_field(estimated_flows, fps=1, density=SIZE//16)
HTML(anim.to_jshtml())