In [None]:
# Preview cell: import the BHPTNRSur1dq1e4 surrogate from a local checkout,
# generate one waveform at mass ratio q = 2.5 and plot its (2, 2) mode.
import os
import sys

import numpy as np
import matplotlib.pyplot as plt

# Location of the BHPTNRSurrogate checkout. Set the BHPTNRSUR_PATH environment
# variable to override it; the fallback is the original hard-coded location.
PATH_TO_BHPTNRSur = os.environ.get("BHPTNRSUR_PATH", "/home/ubuntu/EG-UT/BHPTNRSurrogate")

# add the path to the script directory (only once, so that re-running this
# cell does not keep appending duplicates to sys.path)
if PATH_TO_BHPTNRSur not in sys.path:
    sys.path.append(PATH_TO_BHPTNRSur)
from surrogates import BHPTNRSur1dq1e4 as bhptsur

# hsur is indexed by (l, m) mode tuples; tsur is the matching time grid.
tsur, hsur = bhptsur.generate_surrogate(q=2.5)

fig, ax = plt.subplots(figsize=(20, 4))
# Imaginary part of the (2, 2) mode.
ax.plot(tsur, np.imag(hsur[(2, 2)]), '-', label='22')
# To overlay another mode, e.g. the real part of (3, 3):
# ax.plot(tsur, np.real(hsur[(3, 3)]), '-', label='33')
# NOTE(review): the y label is in geometric units (rh/M) while the x label
# says seconds -- confirm which time unit generate_surrogate returns here.
ax.set_xlabel('time [seconds]', fontsize=15)
ax.set_ylabel('rh/M', fontsize=15)
# The curve carries a label, so draw the legend that displays it.
ax.legend()
plt.show()

In [None]:
# Run the dataset generator script in a subshell. The training configuration
# further down expects a .pt file produced by generate_bhpt_dataset.py
# (presumably the script's default output name matches -- TODO confirm).
# NOTE(review): --q-min and --q-max are both 2.5, so every one of the 5000
# samples is requested at the same mass ratio -- confirm this is intended.
!python3 generate_bhpt_dataset.py --n-samples 5000  --q-min 2.5 --q-max 2.5 --n-timesteps 4096

In [None]:
import time, math, gc
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.ops import MLP
from trim_transformer.transformer_layers import TrimTransformerEncoderLayer, TrimTransformerEncoder
import argparse
from pathlib import Path


# --- Reproducibility --------------------------------------------------------
# Seed torch's CPU generator and, when a GPU is present, every CUDA generator.
seed = 42
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

# --- Compute device ---------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# --- Hyper-parameters -------------------------------------------------------
# Each entry is (flag, keyword arguments for ArgumentParser.add_argument).
# The first group configures data and optimisation, the second the model.
_option_table = [
    ("--data", dict(type=Path, default=Path("bhpt_dataset.pt"),
                    help="Path to the .pt dataset produced by generate_bhpt_dataset.py.")),
    ("--epochs", dict(type=int, default=1000)),
    ("--batch-size", dict(type=int, default=100)),
    ("--lr", dict(type=float, default=1e-3)),
    ("--weight-decay", dict(type=float, default=0,
                            help="Weight decay (L2 penalty) for Adam optimizer.")),
    ("--n-timesteps", dict(type=int, default=None,
                           help="Number of temporal frames to sample from the raw data.")),
    ("--d_model", dict(type=int, default=31)),
    ("--nhead", dict(type=int, default=4)),
    ("--dim_feedforward", dict(type=int, default=64)),
    ("--dropout", dict(type=float, default=0.0)),
    ("--n_layers", dict(type=int, default=2)),
]

parser = argparse.ArgumentParser(description="BHPT training script.")
for _flag, _spec in _option_table:
    parser.add_argument(_flag, **_spec)

# Parse an empty argument list: inside a notebook sys.argv belongs to the
# kernel, so only the defaults declared above are used.
args = parser.parse_args([])

In [None]:
class WaveformDataset(Dataset):
    """Pairs each waveform with its parameter vector, optionally subsampled in time.

    Args:
        waveforms: Tensor whose dim 0 indexes samples and dim 1 indexes time.
        params: Tensor whose dim 0 indexes samples; must have as many rows as
            ``waveforms``.
        n_timesteps: Number of time frames to keep. When ``None`` the value is
            taken from the notebook-level ``args.n_timesteps`` if such an
            object exists. Subsampling only happens when the resulting value
            is smaller than the number of frames available.

    Raises:
        AssertionError: If the two tensors disagree on the number of samples.
        ValueError: If the requested number of timesteps is smaller than 1.
    """

    def __init__(self, waveforms: torch.Tensor, params: torch.Tensor, n_timesteps=None):
        super().__init__()
        assert waveforms.shape[0] == params.shape[0], "Number of waveforms and parameters must match"

        if n_timesteps is None:
            # Backward compatible default: read the notebook configuration
            # when it is defined, otherwise leave the time axis untouched.
            n_timesteps = getattr(globals().get("args"), "n_timesteps", None)

        if n_timesteps is not None:
            if n_timesteps < 1:
                raise ValueError(f"n_timesteps must be >= 1, got {n_timesteps}")
            if n_timesteps < waveforms.shape[1]:
                original_timesteps = waveforms.shape[1]
                # Evenly spaced frame indices that include the first and last frame.
                idx = torch.linspace(0, original_timesteps - 1, steps=n_timesteps, dtype=torch.long)
                waveforms = waveforms[:, idx.to(waveforms.device)]
                print(f"  Subsampled time from {original_timesteps} to {waveforms.shape[1]} steps.")

        # Store only after subsampling; assigning earlier kept the full-length
        # tensor and made the subsampling above a no-op.
        self.waveforms = waveforms
        self.params = params

    def __len__(self) -> int:
        """Number of samples (size of dim 0)."""
        return self.waveforms.shape[0]

    def __getitem__(self, idx: int):
        """Return the ``(waveform, params)`` pair stored at ``idx``."""
        return self.waveforms[idx], self.params[idx]

In [None]:
# 1. Load data from the .pt file into two tensors
print(f"Loading data from {args.data}...")
# NOTE(security): torch.load unpickles the file; only point --data at files
# you generated yourself or otherwise trust.
data = torch.load(args.data)
waveforms_tensor = data["waveforms"]
params_tensor = data["params"]
print(f"  Loaded 'waveforms' with shape: {waveforms_tensor.shape}")
print(f"  Loaded 'params' with shape: {params_tensor.shape}")

# 2. Create DataLoaders from the tensors
# The whole dataset is moved to the compute device up front, so batches need
# no further host-to-device copies.
dataset = WaveformDataset(waveforms_tensor.to(device), params_tensor.to(device))
n_train = int(len(dataset) * 0.8)  # 80/20 train/validation split
n_val = len(dataset) - n_train
train_ds, val_ds = random_split(dataset, [n_train, n_val])

train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, drop_last=True)
# Validation keeps its final partial batch: dropping it would silently exclude
# samples from the reported validation loss.
val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, drop_last=False)

# 3. Define dimensions for the model
# Read the shape from the dataset (not the raw tensor) so that T reflects the
# time axis the model is actually fed.
N, T, H, W, Q = dataset.waveforms.shape
_, n_params = params_tensor.shape

print("DataLoaders are ready.")

In [None]:
# In this example, we use the positional encoding from https://arxiv.org/abs/1706.03762. This is
# an unnatural choice for this dataset, but it is the classic choice for language models.
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding from arXiv:1706.03762.

    For position ``pos`` and pair index ``i`` the table holds
    ``sin(pos / 10000**(2i / d_model))`` in column ``2i`` and the matching
    cosine in column ``2i + 1``. Odd ``d_model`` is supported: the last sine
    column simply has no cosine partner.

    Args:
        d_model: Width of the feature axis the encoding is added to.
        max_len: Number of positions precomputed; inputs may not be longer.
    """

    def __init__(self, d_model, max_len=10000):
        super().__init__()
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)  # (max_len, 1)
        # Even indices 0, 2, 4, ... give the 2i exponent from the paper and
        # yield ceil(d_model / 2) frequencies.
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model)
        )
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # Only d_model // 2 odd columns exist, so drop the surplus frequency
        # when d_model is odd instead of failing on a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Buffer: follows .to(device) and is saved with the module, but is not trained.
        self.register_buffer("pe", pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):
        """Add the encoding for the first ``x.size(1)`` positions to ``x``.

        Raises:
            ValueError: If ``x`` has more positions along dim 1 than ``max_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds positional table length {self.pe.size(1)}"
            )
        return x + self.pe[:, :seq_len]

# Compose the encoder, positional encoding, transformer, and decoder into a single module.
class Pipeline(nn.Module):
    """Chains encoder -> positional encoding -> transformer -> decoder.

    Args:
        encoder: Module applied directly to the 5-D input.
        decoder: Module applied to the transformer output after it has been
            reshaped back to the encoder output's shape.
        pos_enc: Module applied to the flattened token sequence.
        transformer: Module called as ``transformer(tokens, mask=mask)``.
        mask: Value forwarded as ``mask=`` on every call; may be ``None``.
            A plain tensor is registered as a non-persistent buffer so it
            follows ``.to(device)`` without being added to ``state_dict``.
    """

    def __init__(self, encoder, decoder, pos_enc, transformer, mask):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.pos_enc = pos_enc
        self.transformer = transformer
        if isinstance(mask, torch.Tensor) and not isinstance(mask, nn.Parameter):
            # A plain attribute would stay on its original device when the
            # pipeline is moved.
            self.register_buffer("mask", mask, persistent=False)
        else:
            self.mask = mask

    def forward(self, x):
        """Run the pipeline on a 5-D input and return a tensor shaped like ``x``.

        Raises:
            ValueError: If ``x`` does not have exactly five dimensions.
        """
        if x.dim() != 5:
            raise ValueError(f"expected a 5-D input, got {x.dim()} dimensions")
        z = self.encoder(x)
        # Merge dims 1..3 into a single token axis for the transformer.
        y = z.flatten(1, 3)
        y = self.pos_enc(y)
        y = self.transformer(y, mask=self.mask)
        # Restore the encoder output layout before decoding.
        y = y.reshape_as(z)
        y = self.decoder(y)
        # Requires the decoder output to hold exactly as many elements as x.
        return y.reshape_as(x)

In [None]:
# Width of every token handed to the transformer.
token_dim = args.d_model + 1

# Construction order is kept as encoder -> transformer -> decoder so that the
# seeded weight initialisation is unchanged.

# Encoder: maps the trailing axis (size Q) to token_dim features.
encoder = MLP(
    in_channels=Q,
    hidden_channels=[args.d_model, args.dim_feedforward, token_dim],
    activation_layer=nn.ReLU,
)

# NOTE(review): args.dropout is parsed but not forwarded to the layer below --
# confirm that the layer's own default dropout is what is wanted.
layer = TrimTransformerEncoderLayer(
    d_model=token_dim,
    nhead=args.nhead,
    dim_feedforward=args.dim_feedforward,
    batch_first=True,
)
model = TrimTransformerEncoder(layer, num_layers=args.n_layers)

# Decoder: maps token_dim features to H * W * Q values.
# NOTE(review): Pipeline.forward reshapes the decoder output to the input's
# shape, which only has matching element counts when H * W == 1 -- confirm.
decoder = MLP(
    in_channels=token_dim,
    hidden_channels=[args.d_model, args.dim_feedforward, H * W * Q],  # Final output must match original H, W, Q
    activation_layer=nn.ReLU,
)

pos_enc = PositionalEncoding(token_dim, max_len=T)

# No attention mask is supplied (mask=None).
trim_pipeline = Pipeline(encoder, decoder, pos_enc, model, None).to(device)

# Report the size of each component.
for _tag, _module in (("encoder", encoder), ("decoder", decoder), ("model", model)):
    print(f"{_tag} parameter count:", sum(p.numel() for p in _module.parameters()))

In [None]:
def train_epoch(model, loader, optimizer, criterion):
    """Run one optimisation pass over ``loader`` with a next-step objective.

    For every batch the model receives all frames but the last
    (``waveforms[:, :-1]``) and is scored against the same sequence shifted by
    one frame (``waveforms[:, 1:]``).

    Args:
        model: Module to train; switched to training mode.
        loader: Iterable of ``(waveforms, params)`` batches; ``params`` is ignored.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss returning a scalar tensor that is a per-batch mean.

    Returns:
        ``(mean_loss, elapsed_seconds)``. The mean is weighted by batch size
        and taken over the samples actually processed, so it stays correct
        when the loader drops an incomplete final batch. ``mean_loss`` is NaN
        if the loader yields no batches.
    """
    model.train()
    start = time.time()
    running = 0.0
    n_seen = 0
    for waveforms, _ in loader:
        optimizer.zero_grad()
        pred = model(waveforms[:, :-1])
        loss = criterion(pred, waveforms[:, 1:])
        loss.backward()
        optimizer.step()
        batch_size = waveforms.size(0)
        # Undo the criterion's per-batch averaging before accumulating.
        running += loss.item() * batch_size
        n_seen += batch_size
    elapsed = time.time() - start
    mean_loss = running / n_seen if n_seen else float("nan")
    return mean_loss, elapsed

@torch.no_grad()
def evaluate(model, loader, criterion):
    """Compute the mean next-step loss over ``loader`` without gradients.

    Uses the same objective as training: the model receives
    ``waveforms[:, :-1]`` and is scored against ``waveforms[:, 1:]``.

    Args:
        model: Module to evaluate; switched to evaluation mode.
        loader: Iterable of ``(waveforms, params)`` batches; ``params`` is ignored.
        criterion: Loss returning a scalar tensor that is a per-batch mean.

    Returns:
        Batch-size-weighted mean loss over the samples actually processed
        (NaN if the loader yields no batches).
    """
    model.eval()
    running = 0.0
    n_seen = 0
    for waveforms, _ in loader:
        pred = model(waveforms[:, :-1])
        loss = criterion(pred, waveforms[:, 1:])
        batch_size = waveforms.size(0)
        # Undo the criterion's per-batch averaging before accumulating.
        running += loss.item() * batch_size
        n_seen += batch_size
    return running / n_seen if n_seen else float("nan")

# Mean-squared error between predicted and actual next frames.
criterion = nn.MSELoss()

# Adam over every parameter of the pipeline (encoder, transformer, decoder).
optimizer = torch.optim.Adam(trim_pipeline.parameters(), lr=args.lr, weight_decay=args.weight_decay)
for ep in range(1, args.epochs+1):
    train_loss, t = train_epoch(trim_pipeline, train_loader, optimizer, criterion)
    val_loss = evaluate(trim_pipeline, val_loader, criterion)
    # t is the wall-clock time of the training pass only.
    print(f"epoch {ep}/{args.epochs} | train {train_loss:.3e} | val {val_loss:.3e} | {t:.2f}s")
if device.type == "cuda":
    # Collect Python garbage first so tensors kept alive only by reference
    # cycles are released, then hand the freed cached blocks back to the driver.
    gc.collect()
    torch.cuda.empty_cache()