In [1]:
import torch
import torch.nn as nn
import architectures as arch
from functools import partial
import argparse
from pathlib import Path
from torchvision.ops import MLP
from running import make_positional_encoding, train_one_epoch, evaluate
from data import load_navier_stokes_tensor, setup_dataloaders
from architectures import SingleConvNeuralNet


parser = argparse.ArgumentParser(description="Navier–Stokes training script (core logic excerpt).")
parser.add_argument("--data", type=Path, default=Path("../anie/ns_viscosity_data_large/ns_data_visc_8e-4.mat"), help="Path to the .mat dataset produced by the FNO codebase.")
parser.add_argument("--epochs", type=int, default=10)
parser.add_argument("--batch-size", type=int, default=8)
parser.add_argument("--lr", type=float, default=1e-3)
parser.add_argument("--weight-decay", type=float, default=1e-4, help="Weight decay (L2 penalty) for Adam optimizer.")
parser.add_argument("--time-points", type=int, default=11, help="Number of temporal frames to sample from the raw data (consistent with notebook).")

parser.add_argument("--share", action="store_true", help="Share weights between modules.")
parser.add_argument("--no-share", dest="share", action="store_false", help="Don't share weights between modules.")
parser.set_defaults(share=True)

parser.add_argument("--refinement", action="store_true", help="Use refinement.")
parser.add_argument("--no-refinement", dest="refinement", action="store_false", help="Don't use refinement.")
parser.set_defaults(refinement=True)

parser.add_argument("--picard", action="store_true", help="Use Picard iterations.")
parser.add_argument("--no-picard", dest="picard", action="store_false", help="Don't use Picard iterations.")
parser.set_defaults(picard=True)

parser.add_argument("--d_model", type=int, default=64)
parser.add_argument("--nhead", type=int, default=4)
parser.add_argument("--dim_feedforward", type=int, default=64)
parser.add_argument("--dropout", type=float, default=0.1)
parser.add_argument("--n_layers", type=int, default=4)
parser.add_argument("--n_modules", type=int, default=1)
parser.add_argument("--q", type=int, default=1)
parser.add_argument("--r", type=float, default=0.5)

args = parser.parse_args("")

if args.refinement:
    make_module = partial(arch.AcausalTransformer,
                    d_model=args.d_model,
                    nhead=args.nhead,
                    dim_feedforward=args.dim_feedforward,
                    dropout=args.dropout,
                    n_layers=args.n_layers)
    process_trajectory = arch.broadcast_initial_conditions
else:
    make_module = partial(arch.CausalTransformer,
                    d_model=args.d_model,
                    nhead=args.nhead,
                    dim_feedforward=args.dim_feedforward,
                    dropout=args.dropout,
                    n_layers=args.n_layers)
    process_trajectory = arch.Identity
if args.share:
    modules = arch.make_weight_shared_modules(make_module, n_modules=args.n_modules)
else:
    modules = arch.make_weight_unshared_modules(make_module, n_modules=args.n_modules)
if args.picard:
    model = arch.PicardIterations(modules, q=args.q, r=args.r)
else:
    model = arch.ArbitraryIterations(modules)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
model = model.to(device)

tensor = load_navier_stokes_tensor(args.data, time_points=args.time_points)



Using device: cuda


In [2]:
tensor = tensor.to(device)
train_loader, val_loader = setup_dataloaders(tensor, batch_size=args.batch_size)

P = 3
N, T, H, W, Q = tensor.shape
encoder = SingleConvNeuralNet(dim=Q,
                                hidden_dim=args.d_model-P,
                                out_dim=args.d_model-P,
                                hidden_ff=128,
                                K=[4,4],
                                S=[4,4])
encoder = encoder.to(device)

# Dummy forward pass to get shapes
with torch.no_grad():
    _, _, H_prime, W_prime, _ = encoder.forward(tensor[0, None, ...].to(device)).shape

decoder = MLP(
    in_channels=H_prime*W_prime*(args.d_model-P),
    hidden_channels=[64, 256, H*W],
    activation_layer=nn.ELU,
)

positional_encodings = make_positional_encoding(T, H_prime, W_prime, device=device)
positional_encodings = positional_encodings[None,...].repeat(args.batch_size,1,1,1,1)

decoder = decoder.to(device)

optim = torch.optim.Adam(
    list(model.parameters()) + list(encoder.parameters()) + list(decoder.parameters()), lr=args.lr, weight_decay=args.weight_decay
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=args.epochs)

for epoch in range(1, args.epochs + 1):
    train_loss = train_one_epoch(model,
                                 encoder,
                                 decoder,
                                 process_trajectory,
                                 train_loader,
                                 optim,
                                 positional_encodings)
    val_loss = evaluate(model,
                        encoder,
                        decoder,
                        process_trajectory,
                        val_loader,
                        positional_encodings)
    scheduler.step()
    print(f"Epoch {epoch:3d} | train loss: {train_loss:.6f} | val loss: {val_loss:.6f}")

torch.save({"state_dict": model.state_dict()}, Path("tmpweights.pt"))

Epoch   1 | train loss: 0.323091 | val loss: 0.137176
Epoch   2 | train loss: 0.120290 | val loss: 0.108359
Epoch   3 | train loss: 0.098632 | val loss: 0.091244
Epoch   4 | train loss: 0.086226 | val loss: 0.083542
Epoch   5 | train loss: 0.078092 | val loss: 0.079151
Epoch   6 | train loss: 0.072348 | val loss: 0.074000
Epoch   7 | train loss: 0.067465 | val loss: 0.071946
Epoch   8 | train loss: 0.063301 | val loss: 0.066403
Epoch   9 | train loss: 0.060048 | val loss: 0.064498
Epoch  10 | train loss: 0.057972 | val loss: 0.063673


In [3]:
from running import format_input_for_model, reformat_output_from_model, reformat_output_from_decoder
import torch.nn.functional as F

In [4]:
for trajectories in train_loader:
    encoder_inputs = process_trajectory(trajectories)  # (B, T, H, W, Q)
    encoder_outputs = encoder(encoder_inputs)  # (B, T, H', W', Q')
    model_inputs = format_input_for_model(encoder_outputs, positional_encodings)  # (B, T*H'*W', P+Q')
    model_outputs = model(model_inputs)  # (B, T*H'*W', P+Q')
    decoder_inputs = reformat_output_from_model(model_outputs, encoder_outputs)  # (B, T, H'*W'*Q')
    decoder_outputs = decoder(decoder_inputs)  # (B, T, H, W, Q)
    outputs = reformat_output_from_decoder(decoder_outputs, trajectories)

In [5]:
outputs.shape, trajectories.shape

(torch.Size([8, 11, 64, 64, 1]), torch.Size([8, 11, 64, 64, 1]))

In [6]:
train_loader.dataset.dataset.data[0, 0] - train_loader.dataset.dataset.data[0, 1]

tensor([[[-0.4313],
         [-0.4721],
         [-0.5008],
         ...,
         [-0.2531],
         [-0.3033],
         [-0.3696]],

        [[-0.4638],
         [-0.5088],
         [-0.5317],
         ...,
         [-0.3009],
         [-0.3443],
         [-0.4117]],

        [[-0.4784],
         [-0.5163],
         [-0.5381],
         ...,
         [-0.3640],
         [-0.3922],
         [-0.4323]],

        ...,

        [[-0.2985],
         [-0.3426],
         [-0.3830],
         ...,
         [-0.1068],
         [-0.1771],
         [-0.2381]],

        [[-0.3529],
         [-0.3886],
         [-0.4325],
         ...,
         [-0.1600],
         [-0.2278],
         [-0.2926]],

        [[-0.3873],
         [-0.4280],
         [-0.4671],
         ...,
         [-0.2007],
         [-0.2607],
         [-0.3224]]], device='cuda:0')