# Rotating Bar Example

A simple work-through example to model the effect of a rotating bar in a simple Milky Way system


In [10]:
import logging

import galax.potential as gp
import numpy as np
import optax
import unxt as u
from flax import nnx

from galactoPINNs.data import (
    flatten_time_dict_by_time,
    generate_time_dep_data,
    scale_data_time,
)
from galactoPINNs.evaluate import evaluate_performance_node
from galactoPINNs.models.node_model import NODEModel
from galactoPINNs.train import train_model_node

## Set up logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
)

In [None]:
####
### Generate synthetic data for training and testing
####

MW_potential = gp.MilkyWayPotential(units="galactic")
true_rotation_rate = 2 # degree per Myr

def alpha_of_t(t: u.Quantity["time"]) -> u.Quantity["angle"]:
    t_myr = t.to_value("myr")
    return u.Quantity(t_myr * true_rotation_rate, "degree")

rot_bar = gp.LongMuraliBarPotential(
            m_tot=u.Quantity(1e10, "Msun"), a=u.Quantity(4.0, "kpc"), b=u.Quantity(1.0, "kpc"), c=u.Quantity(1.5, "kpc"),
            alpha = alpha_of_t, units="galactic")

mw_lmc_bar = MW_potential + rot_bar

true_analytic_function = mw_lmc_bar
analytic_baseline_potenial = gp.NFWPotential(m=6.0e11, r_s=5.0, units="galactic")

In [12]:
####
### Generate the training and testing sets
####

N_samples_train = 1024
N_samples_test = 1024
r_max_train = 150
r_max_test = 200

ts_train = np.linspace(0, 100, 6)
ts_test = np.linspace(0, 180, 8)
times_train = ts_train
times_test = ts_test

raw_datadict = generate_time_dep_data(
    galax_potential= true_analytic_function,
    times_train = times_train,
    times_test = times_test,
    n_samples_train= N_samples_train,
    n_samples_test= N_samples_test,
    r_max_train =100,    # kpc
    r_max_test  =150,     # kpc
)


In [None]:
####
### Nondimensionalize the data, and set up the model configuration
####

distribution = "superposition"
include_analytic = True
scale = "nfw"
r_s = 15.62

initial_config = {
    "r_s": r_s,
    "include_analytic": include_analytic,
    "ab_potential": analytic_baseline_potenial}

# nondimensionalize the data
data, transformers = scale_data_time(
    raw_datadict, initial_config)

# configure the desired model features
config = {
    "x_transformer": transformers["x"],
    "a_transformer": transformers["a"],
    "u_transformer": transformers["u"],
    "t_transformer": transformers["t"],
    "r_s": r_s,
    "scale": scale,
    "enforce_bc": False,
    "include_analytic": include_analytic,
    "delta_phi_depth": 3,
    "delta_phi_width": 64,
    "initial_correction_depth": 4,
    "integration_mode": "gl3"
    }


In [5]:
###
## Initialize and train a model with a non-trainable baseline potential
###

## set up the model and hyperparameters
l_rel = 0.5  # weight for the relative loss term
lr0 = 1e-3  # initial learning rate
tx = optax.adam(lr0)
net = NODEModel(config,rngs=nnx.Rngs(0))

## set up the optimizer
opt = nnx.Optimizer(net, optax.adam(lr0), wrt=nnx.Param)

x_train, a_train = flatten_time_dict_by_time(data, split="train")

In [6]:
###
## Train the model
## Should converge in < 1 minute. For optimal results, train for much longer.
###

out = train_model_node(
    model=net,
    optimizer=opt,
    x_train = x_train,
    a_train = a_train,
    num_epochs = 1000,
    log_every = 100,
    analytic_potential = analytic_baseline_potenial
)

2026-01-28 17:51:20,838 | INFO | galactoPINNs.train | Epoch 0, Loss: 1.450405
2026-01-28 17:51:24,490 | INFO | galactoPINNs.train | Epoch 100, Loss: 0.028884
2026-01-28 17:51:29,647 | INFO | galactoPINNs.train | Epoch 200, Loss: 0.021541
2026-01-28 17:51:34,785 | INFO | galactoPINNs.train | Epoch 300, Loss: 0.019746
2026-01-28 17:51:39,910 | INFO | galactoPINNs.train | Epoch 400, Loss: 0.017753
2026-01-28 17:51:45,009 | INFO | galactoPINNs.train | Epoch 500, Loss: 0.016794
2026-01-28 17:51:50,117 | INFO | galactoPINNs.train | Epoch 600, Loss: 0.017211
2026-01-28 17:51:55,229 | INFO | galactoPINNs.train | Epoch 700, Loss: 0.014537
2026-01-28 17:52:00,386 | INFO | galactoPINNs.train | Epoch 800, Loss: 0.014594
2026-01-28 17:52:05,734 | INFO | galactoPINNs.train | Epoch 900, Loss: 0.013982


In [9]:
###
## Evaluate performance
###

perf = evaluate_performance_node(
    model=out["model"],
    t_eval = times_test[-1],
    raw_datadict = raw_datadict,
    num_test = 800,
    analytic_baseline=analytic_baseline_potenial
)