In [None]:
import sys

import numpy as np
import matplotlib.pyplot as plt

# Directory that is put on the import path so `surrogates` can be imported.
PATH_TO_BHPTNRSur = "/home/ubuntu/EG-UT/BHPTNRSurrogate"
# Register the directory only once: re-running this cell must not keep
# appending duplicate entries to sys.path.
if PATH_TO_BHPTNRSur not in sys.path:
    sys.path.append(PATH_TO_BHPTNRSur)
from surrogates import BHPTNRSur1dq1e4 as bhptsur

In [None]:
# Evaluate the surrogate at q = 2.5; the result is indexed by (l, m) tuples.
tsur, hsur = bhptsur.generate_surrogate(q=2.5)
# Quick sanity check: first sample of the (2, 2) mode.
print(hsur[2, 2][0])

# Overlay the real parts of the (2, 2) and (3, 3) modes on one wide figure.
# NOTE(review): the x label says seconds; confirm the surrogate's time units
# for a call that passes only q.
plt.figure(figsize=(20, 4))
for mode in [(2, 2), (3, 3)]:
    plt.plot(tsur, np.real(hsur[mode]), '-', label=f"{mode[0]}{mode[1]}")
plt.xticks(fontsize=13)
plt.yticks(fontsize=13)
plt.xlabel('time [seconds]', fontsize=15)
plt.ylabel('rh/M', fontsize=15)
plt.legend(fontsize=12)
plt.show()

In [None]:
# Run the dataset generation script from the shell. --q-min and --q-max are
# both 2.5 here; the logged "First 5 q values" below are all 2.5 accordingly.
!python3 generate_bhpt_dataset.py --n-samples 5000  --q-min 2.5 --q-max 2.5 --n-timesteps 4096

cannot import LAL
**** Surrogate loaded: BHPTNRSur1dq1e4 ****
First 5 q values: [2.5 2.5 2.5 2.5 2.5]
Generating waveforms:  32%|█████▍           | 1600/5000 [02:08<04:34, 12.40it/s]

In [None]:
"python3 generate_bhpt_dataset.py --n-samples 256  --q-min 2.5 --q-max 10.0"

import torch
import torch.nn as nn
import architectures as arch
from functools import partial
import argparse
from pathlib import Path
from torchvision.ops import MLP
from data import setup_waveform_dataloaders, WaveformDataset, load_bhpt_tensors
from architectures import SingleConvNeuralNet, GalerkinTransformer
from bhpt_running import RefinementPipeline, Trainer

parser = argparse.ArgumentParser(description="BHPT training script (core logic excerpt).")

# --- Data and optimisation options ---
parser.add_argument(
    "--data",
    type=Path,
    default=Path("bhpt_dataset.pt"),
    help="Path to the .pt dataset produced by generate_bhpt_dataset.py.",
)
for flag, kind, default in (
    ("--epochs", int, 100000),
    ("--batch-size", int, 1000),
    ("--lr", float, 1e-5),
):
    parser.add_argument(flag, type=kind, default=default)
parser.add_argument(
    "--weight-decay",
    type=float,
    default=0,
    help="Weight decay (L2 penalty) for Adam optimizer.",
)
parser.add_argument(
    "--n-timesteps",
    type=int,
    default=None,
    help="Number of temporal frames to sample from the raw data (consistent with notebook).",
)

# --- Paired on/off switches: --X enables, --no-X disables, default is enabled ---
for flag, on_help, off_help in (
    ("share", "Share weights between modules.", "Don't share weights between modules."),
    ("refinement", "Use refinement.", "Don't use refinement."),
    ("picard", "Use Picard iterations.", "Don't use Picard iterations."),
):
    parser.add_argument(f"--{flag}", action="store_true", help=on_help)
    parser.add_argument(f"--no-{flag}", dest=flag, action="store_false", help=off_help)
    parser.set_defaults(**{flag: True})

# --- Architecture hyperparameters ---
for flag, kind, default in (
    ("--d_model", int, 31),
    ("--nhead", int, 4),
    ("--dim_feedforward", int, 64),
    ("--dropout", float, 0.1),
    ("--n_layers", int, 2),
    ("--n_modules", int, 1),
    ("--q", int, 1),
    ("--r", float, 0.5),
):
    parser.add_argument(flag, type=kind, default=default)

# No real command line exists inside the notebook: an empty argument string
# yields a namespace made purely of the defaults above.
args = parser.parse_args("")

In [None]:
import torch
import matplotlib.pyplot as plt
from pathlib import Path

# --- Configuration ---
# Adjust this path if your dataset is saved elsewhere
dataset_path = Path("bhpt_dataset.pt")
# How many random samples to plot from the dataset
n_to_plot = 3
# ---------------------

# 1. Load the dataset from the .pt file
if not dataset_path.exists():
    print(f"Error: Dataset file not found at '{dataset_path}'")
    print("Please make sure the path is correct and you have generated the dataset.")
else:
    # Plotting doesn't need a GPU, so load straight onto the CPU. The location
    # is passed inline so this cell never overwrites a `device` variable that
    # other cells rely on.
    data = torch.load(dataset_path, map_location="cpu")
    waveforms = data["waveforms"]  # 5-D tensor, unpacked below as (N, T, ., ., .)
    params = data["params"]        # one scalar per sample (read with .item())

    n_samples, n_timesteps, _, _, _ = waveforms.shape
    print(f"Dataset loaded with {n_samples} samples and {n_timesteps} timesteps.")

    # 2. Select random samples to plot. Indices are converted to plain ints so
    #    the titles read "Sample #42" rather than "Sample #tensor(42)".
    if n_samples < n_to_plot:
        print(f"Warning: Only {n_samples} samples available, plotting all of them.")
        indices_to_plot = list(range(n_samples))
    else:
        # Unique random indices (sampling without replacement)
        indices_to_plot = torch.randperm(n_samples)[:n_to_plot].tolist()

    if not indices_to_plot:
        # An empty dataset would otherwise request a figure with zero rows.
        print("Dataset contains no samples; nothing to plot.")
    else:
        # 3. Create the plots: one row per selected sample, so no blank axes
        #    are left over when fewer than n_to_plot samples exist.
        n_rows = len(indices_to_plot)
        fig, axes = plt.subplots(
            nrows=n_rows,
            ncols=1,
            figsize=(18, 4 * n_rows),
            squeeze=False  # Always return a 2D array for axes, even if nrows=1
        )

        # The time axis is just the sample index; it is the same for every row.
        time_axis = range(n_timesteps)

        for ax, sample_idx in zip(axes[:, 0], indices_to_plot):
            # Extract the data for the chosen sample
            waveform_sample = waveforms[sample_idx]
            q_value = params[sample_idx].item()      # Get scalar value

            # Index the two middle axes explicitly instead of a blanket
            # .squeeze(), which would also drop the time axis when T == 1.
            polarisations = waveform_sample[:, 0, 0, :]
            h_plus = polarisations[:, 0]
            h_cross = polarisations[:, 1]

            # Plot h_plus and h_cross on the same axes
            ax.plot(time_axis, h_plus, label=r'$h_+$ (plus)', lw=1.5)
            ax.plot(time_axis, h_cross, label=r'$h_\times$ (cross)', lw=1.5, linestyle='--')

            # --- Formatting ---
            ax.set_title(f"Sample #{sample_idx}: Mass Ratio q = {q_value:.3f}")
            ax.set_xlabel("Time Step Index")
            ax.set_ylabel("Strain (scaled)")
            ax.legend()
            ax.grid(True, linestyle=':', alpha=0.7)

        # Add a main title to the figure and adjust layout
        fig.suptitle("BHPT Dataset Waveform Visualization", fontsize=16, y=0.99)
        plt.tight_layout(rect=[0, 0, 1, 0.98])
        plt.show()

In [None]:
# Prefer the GPU when torch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# 1. Read the .pt dataset into a waveform tensor and a parameter tensor
waveforms_tensor, params_tensor = load_bhpt_tensors(args.data, n_timesteps=args.n_timesteps)

# 2. Build the train/validation loaders from copies moved to `device`
loader_inputs = {
    "waveforms": waveforms_tensor.to(device),
    "params": params_tensor.to(device),
    "batch_size": args.batch_size,
}
train_loader, val_loader = setup_waveform_dataloaders(**loader_inputs)

# 3. Dimensions consumed by the model-definition cell
P = 3  # Positional encoding dimensions (t, y, x)
N, T, H, W, Q = waveforms_tensor.shape  # waveform tensor must be 5-D
_, n_params = params_tensor.shape       # parameter tensor must be 2-D

print("DataLoaders are ready.")



In [None]:
# Cell 3: Model Definition
# For waveform data (H=1, W=1), the "convolution" is just a linear projection.
# Kernel size K and stride S are (1,1).

# Encoder feature width: P channels of d_model are left free, so the encoder
# emits d_model - P features. The same value sizes the decoder input below.
encoder_out_dim = args.d_model - P

encoder = SingleConvNeuralNet(
    dim=Q,
    hidden_dim=encoder_out_dim,
    out_dim=encoder_out_dim,
    hidden_ff=128,
    K=[1, 1],
    S=[1, 1],
)
encoder = encoder.to(device)

# Dummy forward pass to get the shape for the decoder. The module is called
# directly (not via .forward) so any registered hooks still run.
with torch.no_grad():
    sample_waveforms = waveforms_tensor[0, None, ...].to(device)
    _, _, H_prime, W_prime, _ = encoder(sample_waveforms).shape
# The decoder's input channels depend on the encoder's output shape
decoder_in_channels = H_prime * W_prime * encoder_out_dim

# The total dimension fed to the transformer will be d_model + n_params
d_model_conditioned = args.d_model + n_params

if args.refinement:
    make_module = partial(arch.GalerkinTransformer,
                          d_model=d_model_conditioned,
                          nhead=args.nhead,
                          dim_feedforward=args.dim_feedforward,
                          dropout=args.dropout,
                          n_layers=args.n_layers)
    process_trajectory = arch.broadcast_initial_conditions
else:
    # Autoregressive training is not implemented for the BHPT case yet
    raise NotImplementedError("Only refinement mode is set up for BHPT training.")

# Either one set of weights reused by every module, or independent weights.
if args.share:
    modules = arch.make_weight_shared_modules(make_module, n_modules=args.n_modules)
else:
    modules = arch.make_weight_unshared_modules(make_module, n_modules=args.n_modules)

if args.picard:
    model = arch.PicardIterations(modules, q=args.q, r=args.r)
else:
    model = arch.ArbitraryIterations(modules)
model = model.to(device)


decoder = MLP(
    in_channels=decoder_in_channels,
    hidden_channels=[64, 256, H * W * Q], # Final output must match original H, W, Q
    activation_layer=nn.ELU,
)

decoder = decoder.to(device)

print(f"Model, Encoder, and Decoder are on device: {next(model.parameters()).device}")
# Report divisibility truthfully instead of asserting it unconditionally.
head_fit = "divisible" if d_model_conditioned % args.nhead == 0 else "NOT divisible"
print(f"Transformer d_model: {d_model_conditioned} ({head_fit} by nhead={args.nhead})")

In [None]:
# Import necessary libraries for animation
import matplotlib.animation as animation
from IPython.display import HTML
from pathlib import Path
import numpy as np

# --- 1. Animation Function ---
# This function will be used to create a GIF of the training progress.
def create_animation_from_data(predictions_over_time, true_wave, losses, epoch_numbers, save_path=None):
    """
    Create an animation comparing model predictions to the true waveform over epochs.

    Args:
        predictions_over_time (list): A list of numpy arrays, where each array is a
                                      model prediction (T, 2) for a given epoch.
        true_wave (np.array): The ground truth waveform (T, 2) to compare against.
        losses (list): A list of tuples, each containing (train_loss, val_loss),
                       one per entry of ``predictions_over_time``.
        epoch_numbers (list): The epoch number shown for each frame, one per entry
                              of ``predictions_over_time``.
        save_path (str or Path, optional): Path to save the animation GIF.

    Returns:
        matplotlib.animation.FuncAnimation: the animation object (its figure is
        closed so no static plot is displayed).

    Raises:
        ValueError: if ``losses`` or ``epoch_numbers`` does not have exactly one
                    entry per prediction frame.
    """
    # Fail up front instead of with an IndexError half-way through rendering.
    n_frames = len(predictions_over_time)
    if len(losses) != n_frames or len(epoch_numbers) != n_frames:
        raise ValueError(
            "predictions_over_time, losses and epoch_numbers must have the same length "
            f"(got {n_frames}, {len(losses)} and {len(epoch_numbers)})"
        )

    true_wave = np.asarray(true_wave)
    fig_anim, ax_anim = plt.subplots(figsize=(15, 7))

    true_plus = true_wave[:, 0]
    true_cross = true_wave[:, 1]
    time_axis = range(len(true_plus))

    # Plot the true signals as dashed lines
    line_true_plus, = ax_anim.plot(time_axis, true_plus, 'b--', linewidth=1.5, label=r'True $h_+$', alpha=0.6)
    line_true_cross, = ax_anim.plot(time_axis, true_cross, 'c--', linewidth=1.5, label=r'True $h_\times$', alpha=0.6)

    # Plot the predicted signals as solid lines (data will be updated in animate)
    line_pred_plus, = ax_anim.plot([], [], 'r-', linewidth=1.5, label=r'Pred $h_+$', alpha=0.8)
    line_pred_cross, = ax_anim.plot([], [], 'g-', linewidth=1.5, label=r'Pred $h_\times$', alpha=0.8)

    # Fixed plot limits taken from the true signal, padded outwards by 20 % of
    # each bound's magnitude. Padding by |bound| (rather than multiplying the
    # bound by 1.2) keeps the data in view even when the minimum is positive or
    # the maximum is negative.
    lowest, highest = float(true_wave.min()), float(true_wave.max())
    y_min = lowest - 0.2 * abs(lowest)
    y_max = highest + 0.2 * abs(highest)
    if y_min == y_max:
        # Only possible for an all-zero signal; use a non-degenerate range.
        y_min, y_max = -1.0, 1.0
    ax_anim.set_xlim(0, len(true_plus))
    ax_anim.set_ylim(y_min, y_max)

    # Formatting
    ax_anim.set_xlabel('Time Steps')
    ax_anim.set_ylabel('Strain (scaled)')
    ax_anim.set_title('Training Progress: Prediction vs. Ground Truth')
    ax_anim.legend(loc='upper right')
    ax_anim.grid(True, linestyle=':', alpha=0.5)

    # Text box for epoch and loss information
    text_info_anim = ax_anim.text(0.02, 0.95, '', transform=ax_anim.transAxes, fontsize=12,
                                 bbox=dict(boxstyle="round,pad=0.3", facecolor="wheat", alpha=0.7),
                                 verticalalignment='top')

    # Animation function: this is called sequentially for each frame
    def animate(frame):
        if frame < n_frames:
            prediction = predictions_over_time[frame]
            train_loss, val_loss = losses[frame]

            # Update prediction lines
            line_pred_plus.set_data(time_axis, prediction[:, 0])
            line_pred_cross.set_data(time_axis, prediction[:, 1])

            # Update text box
            epoch_num = epoch_numbers[frame]
            text_info_anim.set_text(f'Epoch: {epoch_num}\nTrain Loss: {train_loss:.7f}\nVal Loss: {val_loss:.7f}')

        # With blit=True the artists that changed must be returned.
        return line_pred_plus, line_pred_cross, text_info_anim

    # Create the animation
    anim = animation.FuncAnimation(fig_anim, animate, frames=n_frames,
                                 interval=200, blit=True)

    # Save the animation if a path is provided
    if save_path:
        print(f"Saving animation to {save_path}...")
        anim.save(str(save_path), writer='pillow', fps=5)
        print("Save complete.")

    plt.close(fig_anim) # Avoid displaying the static plot in the notebook
    return anim

# --- 2. Training Setup ---
# The optimizer needs to know about all trainable parameters
all_params = list(model.parameters()) + list(encoder.parameters()) + list(decoder.parameters())
optim = torch.optim.Adam(all_params, lr=args.lr, weight_decay=args.weight_decay)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=args.epochs)

# Instantiate the pipeline and trainer from your notebook's logic
pipeline = RefinementPipeline(
    model=model,
    encoder=encoder,
    decoder=decoder,
    process_trajectory=process_trajectory
).to(device)

trainer = Trainer(pipeline)

# Get a fixed sample from the validation set for consistent visualization
try:
    vis_wave_batch, vis_param_batch = next(iter(val_loader))
except StopIteration:
    print("Validation loader is empty. Using a sample from the training loader for visualization.")
    vis_wave_batch, vis_param_batch = next(iter(train_loader))
true_wave_for_vis = vis_wave_batch[0].squeeze().cpu().numpy()


def _checkpoint_state(epoch, val_loss):
    """Collect everything written to a checkpoint file for the given epoch."""
    return {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'encoder_state_dict': encoder.state_dict(),
        'decoder_state_dict': decoder.state_dict(),
        'optimizer_state_dict': optim.state_dict(),
        'loss': val_loss,
        'args': args,
    }


# --- 3. Training Loop ---
animation_save_frequency = 50

predictions_history = []
losses_history = []
epochs_for_animation = []

best_val_loss = float('inf')
best_epoch = 0  # 0 means "no best checkpoint written yet"
save_path = Path("best_bhpt_weights.pt")

print("Starting training...")
for epoch in range(1, args.epochs + 1):
    # Set models to training mode
    pipeline.train()
    train_loss = trainer.train_epoch(train_loader, optim)

    # Set models to evaluation mode
    pipeline.eval()
    val_loss = trainer.eval_epoch(val_loader)

    # Record a frame for the animation. The forward pass on the visualization
    # sample is only run on epochs that are actually stored.
    if epoch % animation_save_frequency == 0 or epoch == 1 or epoch == args.epochs:
        with torch.no_grad():
            pred_wave = pipeline(vis_wave_batch, vis_param_batch)
        predictions_history.append(pred_wave[0].squeeze().cpu().numpy().copy())
        losses_history.append((train_loss, val_loss))
        epochs_for_animation.append(epoch)

    scheduler.step()
    print(f"Epoch {epoch:3d} | Train Loss: {train_loss:.7f} | Val Loss: {val_loss:.7f} | LR: {scheduler.get_last_lr()[0]:.1e}")

    # Keep the best-validation weights, so save_path really holds the best
    # model rather than whatever the final epoch produced.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_epoch = epoch
        torch.save(_checkpoint_state(epoch, val_loss), save_path)

    # Periodic snapshot every 100 epochs
    if epoch % 100 == 0:
        torch.save(_checkpoint_state(epoch, val_loss), f"bhpt_weights_epoch_{epoch}.pt")

if args.epochs > 0:
    if best_epoch == 0:
        # The validation loss never compared below infinity (e.g. it was NaN):
        # fall back to saving the final state so the file still exists.
        torch.save(_checkpoint_state(epoch, val_loss), save_path)
        print(f"  -> Saving model weights to {save_path}")
    else:
        print(f"  -> Best weights (epoch {best_epoch}, val loss {best_val_loss:.7f}) saved to {save_path}")

print("\n--- Training Complete ---")

# --- 4. Create and Display Animation ---
# Generate the animation from the stored history
animation_save_path = Path("training_progress.gif")
final_anim = create_animation_from_data(
    predictions_history,
    true_wave_for_vis,
    losses_history,
    epochs_for_animation,  # epoch label for each stored frame
    animation_save_path
)

# Display the animation directly in the notebook
HTML(final_anim.to_jshtml())

In [None]:
# The last tensor axis stores the [real, imaginary] pair; combine the two
# components into the complex magnitude at every point.
real_part = waveforms_tensor[..., 0]
imag_part = waveforms_tensor[..., 1]
amplitudes = (real_part.square() + imag_part.square()).sqrt()

# Reduce over time (dim=1) to get the peak amplitude of each waveform.
peak_amplitudes_per_waveform = amplitudes.amax(dim=1)

# Summary statistics of those peaks across the whole dataset.
smallest_peak = peak_amplitudes_per_waveform.min().item()
largest_peak = peak_amplitudes_per_waveform.max().item()
mean_peak = peak_amplitudes_per_waveform.mean().item()

print("Analysis of Peak Waveform Amplitudes:")
print(f"  Smallest peak amplitude in the dataset: {smallest_peak:.2e}")
print(f"  Largest peak amplitude in the dataset:  {largest_peak:.2e}")
print(f"  Mean peak amplitude in the dataset:     {mean_peak:.2e}")

In [None]:
# The optimizer needs to know about all trainable parameters
all_params = list(model.parameters()) + list(encoder.parameters()) + list(decoder.parameters())
optim = torch.optim.Adam(all_params, lr=args.lr, weight_decay=args.weight_decay)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=args.epochs)

# Instantiate the pipeline and trainer
pipeline = RefinementPipeline(
    model=model,
    encoder=encoder,
    decoder=decoder,
    process_trajectory=process_trajectory
).to(device)

trainer = Trainer(pipeline)


def _snapshot(epoch, val_loss):
    """Collect the state written to the weights file for the given epoch."""
    return {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'encoder_state_dict': encoder.state_dict(),
        'decoder_state_dict': decoder.state_dict(),
        'optimizer_state_dict': optim.state_dict(),
        'loss': val_loss,
    }


best_val_loss = float('inf')
best_epoch = 0  # 0 means "nothing saved yet"
save_path = Path("best_bhpt_weights.pt")

print("Starting training...")
for epoch in range(1, args.epochs + 1):
    # Toggle train/eval mode explicitly so mode-dependent layers such as
    # dropout are active while training and disabled while validating.
    pipeline.train()
    train_loss = trainer.train_epoch(train_loader, optim)

    pipeline.eval()
    val_loss = trainer.eval_epoch(val_loader)

    scheduler.step()
    print(f"Epoch {epoch:3d} | train loss: {train_loss:.6f} | val loss: {val_loss:.10f}")

    # Persist the weights whenever validation improves, so the file named
    # "best" really holds the best-validation model.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_epoch = epoch
        torch.save(_snapshot(epoch, val_loss), save_path)

if args.epochs > 0:
    if best_epoch == 0:
        # Validation loss never compared below infinity (e.g. NaN): keep the
        # final state so the weights file still exists.
        torch.save(_snapshot(epoch, val_loss), save_path)
    # Report only after the file has actually been written.
    print("\nTraining complete. Saved model weights to best_bhpt_weights.pt")
    if best_epoch:
        print(f"Best validation loss {best_val_loss:.10f} at epoch {best_epoch}.")

In [None]:
import sys
from pathlib import Path

import torch
import numpy as np
import matplotlib.pyplot as plt
from functools import partial
from torchvision.ops import MLP

import architectures as arch
from architectures import SingleConvNeuralNet, GalerkinTransformer
from bhpt_running import RefinementPipeline

# ----------------------------------------------------------------------
# 0.  Surrogate model (make sure its repo is on the PYTHONPATH)
# ----------------------------------------------------------------------
# Directory put on the import path so `surrogates` resolves.
PATH_TO_BHPTNRSur = "/home/ubuntu/EG-UT/BHPTNRSurrogate"
already_registered = PATH_TO_BHPTNRSur in sys.path
if not already_registered:
    sys.path.append(PATH_TO_BHPTNRSur)
from surrogates import BHPTNRSur1dq1e4 as bhptsur

# Prefer the GPU when torch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Hyperparameters for rebuilding the network at inference time.
# NOTE(review): presumably these must match the run that produced the
# checkpoint loaded afterwards — confirm against the training arguments.
P = 3                                # (t, y, x) positional encodings
d_model, nhead, dim_ff = 127, 4, 2048
dropout, n_layers, n_modules = 0.1, 8, 1
q_picard, r_picard = 1, 0.5
share_weights = True

H, W = 1, 1                          # spatial size
Q = 2                                # (h₊, h×)

# Encoder: emits d_model - P features.
encoder_out_dim = d_model - P
encoder = SingleConvNeuralNet(
    dim=Q,
    hidden_dim=encoder_out_dim,
    out_dim=encoder_out_dim,
    hidden_ff=128,
    K=[1, 1],
    S=[1, 1],
)
encoder = encoder.to(device)

# Transformer factory; the width is d_model + 1 to make room for the scalar
# mass-ratio parameter.
make_module = partial(
    arch.GalerkinTransformer,
    d_model=d_model + 1,
    nhead=nhead,
    dim_feedforward=dim_ff,
    dropout=dropout,
    n_layers=n_layers,
)
if share_weights:
    modules = arch.make_weight_shared_modules(make_module, n_modules)
else:
    modules = arch.make_weight_unshared_modules(make_module, n_modules)
model = arch.PicardIterations(modules, q=q_picard, r=r_picard)
model = model.to(device)

# Decoder: maps the encoder feature width back to H * W * Q outputs.
decoder_in = encoder_out_dim * H * W
decoder = MLP(
    in_channels=decoder_in,
    hidden_channels=[64, 256, H * W * Q],
    activation_layer=torch.nn.ELU,
)
decoder = decoder.to(device)

In [None]:
# SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary Python
# objects, so only load checkpoint files you trust.
ckpt = torch.load("bhpt_weights_epoch_96600.pt", map_location=device, weights_only=False)

# Restore each component from its checkpoint entry and put it in eval mode.
for component, state_key in (
    (model, "model_state_dict"),
    (encoder, "encoder_state_dict"),
    (decoder, "decoder_state_dict"),
):
    component.load_state_dict(ckpt[state_key])
    component.eval()

# ----------------------------------------------------------------------
# 3.  Build a test input from the surrogate (use ANY q in [2.5, 1e4])
# ----------------------------------------------------------------------
modes_used = (
    (2, 2),
)

q_test = 2.5
# Factor applied to the surrogate strain before it is fed to the network.
# NOTE(review): presumably mirrors the scaling used when the training dataset
# was generated — confirm against generate_bhpt_dataset.py.
amplitude_scale = 20

tsur, hdict = bhptsur.generate_surrogate(q=q_test)

# Sum the requested modes, then split into plus / cross
h_sum = sum(hdict[m] for m in modes_used)          # complex array
h_plus = np.real(h_sum) * amplitude_scale          # (T,)
h_cross = np.imag(h_sum) * amplitude_scale         # (T,)

# Number of leading time steps to keep. It is set to the full length, so the
# slices below currently keep everything; lower it to truncate the input.
n_timesteps = len(tsur)
tsur = tsur[:n_timesteps]
h_plus = h_plus[:n_timesteps]
h_cross = h_cross[:n_timesteps]

wave_np = np.stack([h_plus, h_cross], axis=-1).astype(np.float32)  # (T, 2)
wave_np = wave_np[:, None, None, :]                                # (T, 1, 1, 2)
wave_t  = torch.from_numpy(wave_np)[None].to(device)               # (1, T, 1, 1, 2)
param_t = torch.tensor([[q_test]], dtype=torch.float32, device=device)

In [None]:
# Assemble the restored components into a pipeline for inference.
inference_pipeline = RefinementPipeline(
    model=model,
    encoder=encoder,
    decoder=decoder,
    process_trajectory=arch.broadcast_initial_conditions,
)
inference_pipeline = inference_pipeline.to(device)
inference_pipeline.eval()

# Forward pass without gradient tracking.
with torch.no_grad():
    pred = inference_pipeline(wave_t, param_t)

# Drop the leading batch axis and the two singleton middle axes, then move the
# result to NumPy.
pred_np = pred.squeeze(0)[:, 0, 0, :].cpu().numpy()   # (T, 2)
pred_plus = pred_np[..., 0]

In [None]:
# Overlay the surrogate's plus polarisation and the network prediction.
plt.figure(figsize=(20, 4))
for series, legend_text in (
    (h_plus, "Surrogate ∑ modes"),
    (pred_plus, "Transformer output"),
):
    plt.plot(tsur, series, label=legend_text, lw=1.0)
plt.title(f"Mass-ratio q = {q_test}", fontsize=15)
plt.xlabel("time [seconds]", fontsize=14)
plt.ylabel("rh/M", fontsize=14)
plt.legend(fontsize=12)
plt.show()