In [1]:
import torch
import torch.nn as nn
import architectures as arch
from functools import partial
import argparse
from pathlib import Path
from torchvision.ops import MLP
from training import build_trainer
from data import load_navier_stokes_tensor, setup_dataloaders
from architectures import SingleConvNeuralNet


def _add_toggle(target, name, default, on_help, off_help):
    """Register a ``--<name>`` / ``--no-<name>`` pair that writes to one destination.

    ``--<name>`` stores True, ``--no-<name>`` stores False, and ``default`` is
    the value used when neither flag is given.
    """
    target.add_argument(f"--{name}", action="store_true", help=on_help)
    target.add_argument(f"--no-{name}", dest=name, action="store_false", help=off_help)
    target.set_defaults(**{name: default})


parser = argparse.ArgumentParser(
    description="Navier–Stokes training script (core logic excerpt).",
)

# --- Data and optimisation settings ---------------------------------------
parser.add_argument(
    "--data",
    type=Path,
    default=Path("../anie/ns_viscosity_data_large/ns_data_visc_8e-4.mat"),
    help="Path to the .mat dataset produced by the FNO codebase.",
)
parser.add_argument("--epochs", type=int, default=10)
parser.add_argument("--batch-size", type=int, default=8)
parser.add_argument("--lr", type=float, default=1e-3)
parser.add_argument(
    "--weight-decay",
    type=float,
    default=1e-4,
    help="Weight decay (L2 penalty) for Adam optimizer.",
)
parser.add_argument(
    "--n-timesteps",
    type=int,
    default=11,
    help="Number of temporal frames to sample from the raw data (consistent with notebook).",
)

# --- Boolean switches, each with an explicit negative form -----------------
_add_toggle(parser, "share", True,
            "Share weights between modules.",
            "Don't share weights between modules.")
_add_toggle(parser, "refinement", False,
            "Use refinement.",
            "Don't use refinement.")
_add_toggle(parser, "picard", True,
            "Use Picard iterations.",
            "Don't use Picard iterations.")

# --- Architecture hyperparameters ------------------------------------------
parser.add_argument("--d_model", type=int, default=64)
parser.add_argument("--nhead", type=int, default=4)
parser.add_argument("--dim_feedforward", type=int, default=64)
parser.add_argument("--dropout", type=float, default=0.1)
parser.add_argument("--n_layers", type=int, default=4)
parser.add_argument("--n_modules", type=int, default=1)
parser.add_argument("--q", type=int, default=1)
parser.add_argument("--r", type=float, default=0.5)

# Parse an explicit empty argument list so only the defaults above are used
# and sys.argv is never consulted.
args = parser.parse_args([])

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Load the initial conditions and trajectories, moving each tensor to the
# selected device as it is unpacked.
init_conds, trajs = (
    tensor.to(device)
    for tensor in load_navier_stokes_tensor(args.data, n_timesteps=args.n_timesteps)
)

train_loader, val_loader = setup_dataloaders(init_conds, trajs, batch_size=args.batch_size)

# Trajectory tensor dimensions. Presumably (samples, time, height, width,
# channels) -- TODO confirm against load_navier_stokes_tensor.
N, T, H, W, Q = trajs.shape
# P is subtracted from d_model when sizing the encoder later in the notebook.
# NOTE(review): the meaning of the 3 reserved channels is not visible here.
P = 3

In [3]:
# Encoder: takes Q input channels and emits d_model - P features.
encoder = SingleConvNeuralNet(dim=Q,
                              hidden_dim=args.d_model-P,
                              out_dim=args.d_model-P,
                              hidden_ff=128,
                              K=[4,4],
                              S=[4,4])
encoder = encoder.to(device)

# Dummy forward pass on a single trajectory to discover the encoded grid size.
# The module is called directly (not via .forward) so registered hooks run,
# and it is put in eval mode for the probe so that any dropout/normalisation
# layers the encoder may contain are not affected; the mode is restored after.
was_training = encoder.training
encoder.eval()
try:
    with torch.no_grad():
        _, _, H_prime, W_prime, _ = encoder(trajs[0, None, ...].to(device)).shape
finally:
    encoder.train(was_training)
# Number of encoded positions per frame; used as the causal block size below.
block_size = H_prime * W_prime

# Keyword arguments shared by both transformer variants.
transformer_kwargs = dict(
    d_model=args.d_model,
    nhead=args.nhead,
    dim_feedforward=args.dim_feedforward,
    dropout=args.dropout,
    n_layers=args.n_layers,
)
if args.refinement:
    make_module = partial(arch.AcausalTransformer, **transformer_kwargs)
else:
    # Only the block-causal variant takes the block size.
    make_module = partial(arch.BlockCausalTransformer,
                          block_size=block_size,
                          **transformer_kwargs)

# Build n_modules modules from the factory, with or without weight sharing.
if args.share:
    modules = arch.make_weight_shared_modules(make_module, n_modules=args.n_modules)
else:
    modules = arch.make_weight_unshared_modules(make_module, n_modules=args.n_modules)

# Wrap the modules in the selected iteration scheme.
if args.picard:
    model = arch.PicardIterations(modules, q=args.q, r=args.r)
else:
    model = arch.ArbitraryIterations(modules)
model = model.to(device)

# Decoder MLP: input width is H' * W' * (d_model - P); layer widths are
# 64 -> 256 -> H * W, with ELU as the activation layer.
decoder = MLP(
    in_channels=block_size*(args.d_model-P),
    hidden_channels=[64, 256, H*W],
    activation_layer=nn.ELU,
)
decoder = decoder.to(device)

In [None]:
# A single Adam optimiser over the model, encoder and decoder parameters so
# that all three are updated together.
trainable_params = [
    *model.parameters(),
    *encoder.parameters(),
    *decoder.parameters(),
]
optim = torch.optim.Adam(trainable_params, lr=args.lr, weight_decay=args.weight_decay)
# Cosine learning-rate schedule spanning the whole run (stepped once per epoch).
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=args.epochs)

# NOTE(review): 'causal_mane_steps' may be a typo (e.g. for 'causal_many_steps');
# confirm against the kinds build_trainer accepts before changing it.
kind = 'causal_mane_steps'

trainer = build_trainer(kind,
                        encoder=encoder,
                        model=model,
                        decoder=decoder)

for epoch in range(1, args.epochs + 1):
    train_loss = trainer.train_epoch(train_loader, optim)
    val_loss   = trainer.eval_epoch(val_loader)
    scheduler.step()
    print(f"Epoch {epoch:3d} | train loss: {train_loss:.6f} | val loss: {val_loss:.6f}")

# Save every component the optimiser updated. Previously only `model` was
# written, so the encoder and decoder weights were lost. The original
# "state_dict" key is kept so existing loaders of this file still work.
checkpoint = {
    "state_dict": model.state_dict(),
    "encoder_state_dict": encoder.state_dict(),
    "decoder_state_dict": decoder.state_dict(),
}
torch.save(checkpoint, Path("tmpweights.pt"))