Add Burgers dataset and PINO #256

Merged · 9 commits · Nov 2, 2023
95 changes: 95 additions & 0 deletions config/burgers_config.yaml
@@ -0,0 +1,95 @@
default: &DEFAULT

  # General
  # For computing compression
  n_params_baseline: None
  verbose: True
  arch: 'tfno2d'

  # Distributed computing
  distributed:
    use_distributed: False
    wireup_info: 'mpi'
    wireup_store: 'tcp'
    model_parallel_size: 2
    seed: 666

  # FNO related
  tfno2d:
    data_channels: 3
    n_modes_height: 15
    n_modes_width: 15
    hidden_channels: 24
    lifting_channels: 24
    projection_channels: 24
    n_layers: 5
    domain_padding: None
    domain_padding_mode: 'one-sided'
    fft_norm: 'forward'
    norm: 'group_norm'
    skip: 'linear'
    implementation: 'factorized'
    separable: 0
    preactivation: 0
    half_prec_fourier: False
    half_prec_inverse: False
    stabilizer: None

    use_mlp: 1
    mlp:
      expansion: 0.5
      dropout: 0

    factorization: None
    rank: 0.05
    fixed_rank_modes: None
    dropout: 0.0
    tensor_lasso_penalty: 0.0
    joint_factorization: False

  # Optimizer
  opt:
    n_epochs: 10000
    learning_rate: 0.0001
    training_loss: 'l2'
    weight_decay: 1e-4
    amp_autocast: False

    scheduler_T_max: 500  # For cosine only, typically take n_epochs
    scheduler_patience: 100  # For ReduceLROnPlateau only
    scheduler: 'ReduceLROnPlateau'  # Or 'CosineAnnealingLR'
    step_size: 60
    gamma: 0.5
    precision_schedule: None

  # Dataset related
  data:
    folder: '/home/ubuntu/data/burgers/burgers.npz'
    batch_size: 16
    n_train: 800
    test_batch_sizes: [16]
    n_tests: [400]
    spatial_length: 128
    temporal_length: 101

    positional_encoding: True
    encode_input: False
    encode_output: False
    include_endpoint: [True, False]

  # Patching
  patching:
    levels: 0
    padding: 0
    stitching: False

  # Weights and biases
  wandb:
    log: False
    name: None  # If None, config will be used but you can override it here
    group: ""
    project: ""
    entity: ""  # put your username here
    sweep: False
    log_output: True
    log_test_interval: 1
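
For reference, here is a minimal sketch of reading this config with plain PyYAML. The repo's training scripts may use their own config reader, so treat the snippet as illustrative only.

```python
import yaml

# Minimal sketch (assumes it is run from the repo root).
# Caveat: plain YAML parses the literal `None` above as the *string* 'None'
# (YAML's null is `null` or `~`); the repo's config tooling handles this.
with open("config/burgers_config.yaml") as f:
    cfg = yaml.safe_load(f)["default"]

print(cfg["arch"])                  # 'tfno2d'
print(cfg["opt"]["learning_rate"])  # 0.0001
print(cfg["data"]["n_train"])       # 800
```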
100 changes: 100 additions & 0 deletions config/burgers_pino_config.yaml
@@ -0,0 +1,100 @@
default: &DEFAULT

  # General
  # For computing compression
  n_params_baseline: None
  verbose: True
  arch: 'tfno2d'

  # Distributed computing
  distributed:
    use_distributed: False
    wireup_info: 'mpi'
    wireup_store: 'tcp'
    model_parallel_size: 2
    seed: 666

  # FNO related
  tfno2d:
    data_channels: 3
    n_modes_height: 15
    n_modes_width: 15
    hidden_channels: 24
    lifting_channels: 24
    projection_channels: 24
    n_layers: 5
    domain_padding: None
    domain_padding_mode: 'one-sided'
    fft_norm: 'forward'
    norm: 'group_norm'
    skip: 'linear'
    implementation: 'factorized'
    separable: 0
    preactivation: 0
    half_prec_fourier: False
    half_prec_inverse: False
    stabilizer: None

    use_mlp: 1
    mlp:
      expansion: 0.5
      dropout: 0

    factorization: None
    rank: 0.05
    fixed_rank_modes: None
    dropout: 0.0
    tensor_lasso_penalty: 0.0
    joint_factorization: False

  # Optimizer
  opt:
    n_epochs: 10000
    learning_rate: 0.0001
    training_loss: ['equation', 'ic']
    pino_method: 'fdm'
    loss_weights:
      'l2': 0.0
      'equation': .2
      'ic': .8
    weight_decay: 1e-4
    amp_autocast: False

    scheduler_T_max: 500  # For cosine only, typically take n_epochs
    scheduler_patience: 100  # For ReduceLROnPlateau only
    scheduler: 'ReduceLROnPlateau'  # Or 'CosineAnnealingLR'
    step_size: 60
    gamma: 0.5
    precision_schedule: None

  # Dataset related
  data:
    folder: '/home/ubuntu/data/burgers/burgers.npz'
    batch_size: 16
    n_train: 800
    test_batch_sizes: [16]
    n_tests: [400]
    spatial_length: 128
    temporal_length: 101

    positional_encoding: True
    encode_input: False
    encode_output: False
    include_endpoint: [True, False]

  # Patching
  patching:
    levels: 0
    padding: 0
    stitching: False

  # Weights and biases
  wandb:
    log: False
    name: None  # If None, config will be used but you can override it here
    group: ""
    project: ""
    entity: ""  # put your username here
    sweep: False
    log_output: True
    log_test_interval: 1
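
The only substantive difference from burgers_config.yaml is the opt section: the plain 'l2' data loss is replaced by an equation loss plus an initial-condition loss, combined with the weights above. The PR adds a WeightedSumLoss for this; since its exact signature is not shown in this diff, the sketch below uses a hand-rolled stand-in for the same arithmetic.

```python
import torch.nn.functional as F
from neuralop.losses import BurgersEqnLoss, ICLoss

# Hand-rolled stand-in for the configured weighted sum
# (loss_weights: equation=0.2, ic=0.8, l2=0.0).
eqn_loss = BurgersEqnLoss(visc=0.01, method="fdm", loss=F.mse_loss)
ic_loss = ICLoss(loss=F.mse_loss)

def pino_training_loss(y_pred, x):
    # PDE-residual term + initial-condition term
    return 0.2 * eqn_loss(y_pred=y_pred) + 0.8 * ic_loss(y_pred=y_pred, x=x)
```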
4 changes: 2 additions & 2 deletions neuralop/__init__.py
@@ -4,5 +4,5 @@
 from .models import get_model
 from . import datasets
 from . import mpu
-from .training import Trainer
-from .training import LpLoss, H1Loss
+from .training import Trainer, CheckpointCallback
+from .losses import LpLoss, H1Loss, BurgersEqnLoss, ICLoss, WeightedSumLoss
1 change: 1 addition & 0 deletions neuralop/datasets/__init__.py
@@ -2,3 +2,4 @@
 from .spherical_swe import load_spherical_swe
 from .navier_stokes import load_navier_stokes_pt
 from .pt_dataset import load_pt_traintestsplit
+from .burgers import load_burgers_1dtime
79 changes: 78 additions & 1 deletion neuralop/datasets/burgers.py
@@ -1,8 +1,10 @@
 from pathlib import Path
 import torch
+import numpy as np
+from .tensor_dataset import TensorDataset


-def load_burgers(
+def load_burgers_1d(
     data_path, n_train, n_test, batch_train=32, batch_test=100, time=1, grid=[0, 1]
 ):

@@ -38,3 +40,78 @@ def load_burgers(
    )

    return train_loader, test_loader


def load_burgers_1dtime(
        data_path, n_train, n_test, batch_size=32, batch_size_test=100,
        temporal_length=101, spatial_length=128, temporal_subsample=1,
        spatial_subsample=1, pad=0):
    """
    Load burgers.npz data. Given the initial condition (t=0),
    predict timesteps 1 to temporal_length.
    """
    with np.load(data_path) as data:
        x_data = data['input']
        y_data = data['output']
        visc = data['visc']

    x_data = torch.from_numpy(x_data.astype(np.float32))
    x_data = x_data[:, :spatial_length:spatial_subsample]
    y_data = torch.from_numpy(y_data.astype(np.float32))
    y_data = y_data[:, :temporal_length:temporal_subsample, :spatial_length:spatial_subsample]
    visc = torch.from_numpy(visc.astype(np.float32)).item()

    x_train = x_data[:n_train]
    y_train = y_data[:n_train]
    x_test = x_data[n_train:n_train+n_test]
    y_test = y_data[n_train:n_train+n_test]

    domain_lengths = [spatial_length / 128, (temporal_length - 1) / 100]
    domain_starts = [0., 0.]

    spatial_length = spatial_length // spatial_subsample
    temporal_length = temporal_length // temporal_subsample

    if pad:
        x_train = torch.nn.ReplicationPad1d(pad)(x_train)
        x_test = torch.nn.ReplicationPad1d(pad)(x_test)
        spatial_length += 2 * pad
        temporal_length += 2 * pad
        incrs = [spatial_subsample / 128, temporal_subsample / 100]
        domain_lengths = [d + incr * pad for d, incr in zip(domain_lengths, incrs)]
        domain_starts = [-incr * pad for incr in incrs]

    # TODO: use include_endpoint arg here
    grid_x = torch.tensor(np.linspace(domain_starts[0], domain_lengths[0], spatial_length + 1)[:-1], dtype=torch.float)
    grid_t = torch.tensor(np.linspace(domain_starts[1], domain_lengths[1], temporal_length), dtype=torch.float)

    grid_x = grid_x.reshape(1, 1, spatial_length)
    grid_t = grid_t.reshape(1, temporal_length, 1)

    x_train = x_train.reshape(n_train, 1, spatial_length).repeat([1, temporal_length, 1])
    x_test = x_test.reshape(n_test, 1, spatial_length).repeat([1, temporal_length, 1])

    # TODO: add option to not have positional encoding
    x_train = torch.stack([x_train,
                           grid_t.repeat([n_train, 1, spatial_length]),
                           grid_x.repeat([n_train, temporal_length, 1])
                           ], dim=3)
    x_test = torch.stack([x_test,
                          grid_t.repeat([n_test, 1, spatial_length]),
                          grid_x.repeat([n_test, temporal_length, 1])
                          ], dim=3)

    x_train = x_train.permute(0, 3, 1, 2)
    x_test = x_test.permute(0, 3, 1, 2)
    y_train = y_train.unsqueeze(1)
    y_test = y_test.unsqueeze(1)

    train_db = TensorDataset(x_train, y_train)
    train_loader = torch.utils.data.DataLoader(train_db, batch_size=batch_size, shuffle=False)

    test_db = TensorDataset(x_test, y_test)
    test_loader = torch.utils.data.DataLoader(test_db, batch_size=batch_size_test, shuffle=False)

    output_encoder = None
    test_loaders = {'test': test_loader}

    return train_loader, test_loaders, output_encoder
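
A usage sketch for the new loader, with the numbers taken from config/burgers_config.yaml (the data path is the one hard-coded in the config; substitute your local copy of burgers.npz):

```python
from neuralop.datasets import load_burgers_1dtime

train_loader, test_loaders, output_encoder = load_burgers_1dtime(
    data_path="/home/ubuntu/data/burgers/burgers.npz",  # from the config; adjust locally
    n_train=800, n_test=400,
    batch_size=16, batch_size_test=16,
    temporal_length=101, spatial_length=128,
)

# Each input sample is stacked as (u0, t-grid, x-grid) channels,
# i.e. shape (3, 101, 128); each target has shape (1, 101, 128).
```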
Binary file added neuralop/datasets/data/burgers_lowres.pt
3 changes: 3 additions & 0 deletions neuralop/losses/__init__.py
@@ -0,0 +1,3 @@
from .data_losses import LpLoss, H1Loss
from .equation_losses import BurgersEqnLoss, ICLoss
from .meta_losses import WeightedSumLoss
File renamed without changes.
70 changes: 70 additions & 0 deletions neuralop/losses/equation_losses.py
@@ -0,0 +1,70 @@
import torch
import torch.nn.functional as F

from .data_losses import central_diff_2d


class BurgersEqnLoss(object):
    """
    Computes loss for Burgers' equation.
    """

    def __init__(self, visc=0.01, method="fdm", loss=F.mse_loss, domain_length=1.0):
        super().__init__()
        self.visc = visc
        self.method = method
        self.loss = loss
        self.domain_length = domain_length
        if not isinstance(self.domain_length, (tuple, list)):
            self.domain_length = [self.domain_length] * 2

    def fdm(self, u):
        # remove extra channel dimensions
        u = u.squeeze(1)

        # shapes
        _, nt, nx = u.shape

        # we assume that the input is given on a regular grid
        dt = self.domain_length[0] / (nt - 1)
        dx = self.domain_length[1] / nx

        # du/dt and du/dx
        dudt, dudx = central_diff_2d(u, [dt, dx], fix_x_bnd=True, fix_y_bnd=True)

        # d^2u/dx^2
        dudxx = (
            torch.roll(u, -1, dims=-1) - 2 * u + torch.roll(u, 1, dims=-1)
        ) / dx**2
        # fix boundary
        dudxx[..., 0] = (u[..., 2] - 2 * u[..., 1] + u[..., 0]) / dx**2
        dudxx[..., -1] = (u[..., -1] - 2 * u[..., -2] + u[..., -3]) / dx**2

        # right-hand side
        right_hand_side = -dudx * u + self.visc * dudxx

        # compute the loss between the left- and right-hand sides of Burgers' equation
        return self.loss(dudt, right_hand_side)

    def __call__(self, y_pred, **kwargs):
        if self.method == "fdm":
            return self.fdm(u=y_pred)
        raise NotImplementedError()


class ICLoss(object):
    """
    Computes loss for initial value problems.
    """

    def __init__(self, loss=F.mse_loss):
        super().__init__()
        self.loss = loss

    def initial_condition_loss(self, y_pred, x):
        boundary_true = x[:, 0, 0, :]
        boundary_pred = y_pred[:, 0, 0, :]
        return self.loss(boundary_pred, boundary_true)

    def __call__(self, y_pred, x, **kwargs):
        return self.initial_condition_loss(y_pred, x)
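
A quick smoke test for the two losses (not part of the PR), using tensors shaped like the Burgers batches produced by load_burgers_1dtime:

```python
import torch
from neuralop.losses import BurgersEqnLoss, ICLoss

y_pred = torch.randn(4, 1, 101, 128)  # (batch, channel, time, space)
x = torch.randn(4, 3, 101, 128)       # channels: (u0, t-grid, x-grid)

# domain_length is [T, L]: fdm() uses dt = T / (nt - 1), dx = L / nx
eqn = BurgersEqnLoss(visc=0.01, method="fdm", domain_length=[1.0, 1.0])
ic = ICLoss()

print(eqn(y_pred))    # scalar MSE of the PDE residual du/dt - (-u du/dx + visc d^2u/dx^2)
print(ic(y_pred, x))  # scalar MSE between predicted and true t=0 slices
```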