In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import tqdm
from torch.optim import Adam
from torch.utils.data import DataLoader

from model import FNO1d
from utils import relative_l2_error, PDEDatasetWithTime

In [2]:
# Pin the global PyTorch and NumPy RNGs so that runs from a fresh kernel
# are repeatable.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

In [3]:
# Data budget: how many samples go to the training split, and how many
# samples each DataLoader batch holds.
N_TRAIN, BATCH_SIZE = 64, 10

In [4]:
# Load the full dataset from disk, then carve out N_TRAIN samples for training
# and keep the remainder for validation.
training_data = PDEDatasetWithTime("../../data/FNO - Wave Equation/train_sol.npy")

# Use a dedicated, seeded generator for the split instead of the global RNG.
# Otherwise re-running this cell on a live kernel produces a *different*
# partition, and samples the model has already trained on can end up in
# `val_data`, silently contaminating the validation score.
SPLIT_SEED = 0
n_val = len(training_data) - N_TRAIN
train_data, val_data = torch.utils.data.random_split(
    training_data,
    [N_TRAIN, n_val],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

1280 Samples available


In [5]:
# Training batches are reshuffled every epoch.
train_data_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
# Validation only aggregates a metric, so ordering is irrelevant; keeping it
# fixed makes evaluation deterministic and avoids consuming global RNG state.
val_data_loader = DataLoader(val_data, batch_size=BATCH_SIZE, shuffle=False)

In [6]:
# Optimisation hyperparameters.
epochs = 500            # number of passes over the training loader
learning_rate = 1e-3    # initial Adam learning rate
# StepLR schedule: multiply the learning rate by `gamma` every `step_size` epochs.
step_size, gamma = 100, 0.5

In [7]:
# Model configuration passed positionally to FNO1d.
# NOTE(review): presumably `modes` is the number of retained Fourier modes and
# `width` the hidden channel count -- confirm against model.FNO1d.
modes, width = 16, 64

# Time-conditioned FNO: the forward pass also receives the time delta.
fno = FNO1d(modes, width, time_conditioning=True)

In [8]:
# Training objective: relative L2 error (imported from utils).
# torch.nn.MSELoss() is a drop-in alternative if an absolute loss is preferred.
loss_f = relative_l2_error

# Adam with light weight decay; the learning rate is decayed step-wise by the
# scheduler, which is advanced once per epoch in the training loop.
optimizer = Adam(
    fno.parameters(),
    lr=learning_rate,
    weight_decay=1e-5,
)
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=step_size,
    gamma=gamma,
)

In [9]:
# Train the model for `epochs` passes over `train_data_loader`, reporting the
# mean per-batch loss of each epoch on the progress bar.
fno.train()

progress_bar = tqdm.tqdm(range(epochs))
for epoch in progress_bar:
    train_loss = 0.0
    # `inputs` rather than `input`: loop variables in a notebook are globals,
    # so using `input` would shadow the built-in for the rest of the session.
    for time_delta, inputs, target in train_data_loader:

        optimizer.zero_grad()
        # Drop the trailing dimension of the model output before comparing with
        # `target` (assumes the model emits a size-1 last axis -- TODO confirm).
        prediction = fno(inputs, time_delta).squeeze(-1)

        loss = loss_f(prediction, target)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    # Average of the per-batch losses for this epoch.
    train_loss /= len(train_data_loader)

    # The learning-rate schedule advances once per epoch, not per batch.
    scheduler.step()

    progress_bar.set_postfix({"train_loss": train_loss})


0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to disable frozen modules.
0.00s - Note: Debugging will proceed. Set PYDEVD_DISABLE_FILE_VALIDATION=1 to disable this validation.
ERROR:tornado.general:SEND Error: Host unreachable
 11%|█         | 53/500 [00:12<01:43,  4.33it/s, train_loss=0.227] 


KeyboardInterrupt: 

In [21]:
# Validate the model: mean per-batch relative L2 error over the held-out split.

fno.eval()
progress_bar = tqdm.tqdm(val_data_loader)

# No gradients are needed for evaluation, so the optimizer is not touched here.
with torch.no_grad():
    test_relative_l2 = 0.0
    # `inputs` rather than `input` to avoid shadowing the built-in in the
    # notebook's global namespace.
    for time_delta, inputs, target in progress_bar:
        # Drop the trailing dimension of the model output before comparing with
        # `target` (assumes the model emits a size-1 last axis -- TODO confirm).
        prediction = fno(inputs, time_delta).squeeze(-1)

        loss = relative_l2_error(prediction, target)
        test_relative_l2 += loss.item()
    # Average of the per-batch errors.
    test_relative_l2 /= len(val_data_loader)


print("#" * 20)
print(f"Test relative L2 error: {test_relative_l2}")

100%|██████████| 122/122 [00:00<00:00, 334.40it/s]

####################
Test relative L2 error: 0.5511187400485649





In [22]:
# Save the model weights (state_dict only) to disk.
MODEL_PATH = "models/tfno_model.pth"
# torch.save does not create missing parent directories, so make sure the
# target folder exists before writing.
os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
torch.save(fno.state_dict(), MODEL_PATH)