In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import matplotlib.pyplot as plt


SEED = 42  # Single seed shared by every RNG below so runs are repeatable
# Seed PyTorch (CPU generator and all CUDA devices) and NumPy's global generator.
# Only torch used to be seeded here, which left any np.random-based step
# non-deterministic despite the fixed SEED.
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
np.random.seed(SEED)

# Project-local helpers: data loading / error metrics, then the FNO architecture
from FNO1D_utils import data_loader, relative_min_max, normal_RMSE
from FNO1D import FNO1d

# Location of the HDF5 data set
path = '/mimer/NOBACKUP/groups/ml_flame_storage/Max_1D/Data/1D200.h5'

# data_loader returns the two loaders followed by max/min pairs for the inputs,
# the outputs and the x / y coordinates.
(
    train_loader,
    test_loader,
    inp_max,
    inp_min,
    out_max,
    out_min,
    x_max,
    x_min,
    y_max,
    y_min,
) = data_loader(path, batch_size=20)

In [None]:
# Gas properties consumed by the physics-informed part of the loss.
# NOTE(review): the values look like air in SI units (J/(kg K)) — confirm.
C_p = 1000            # specific heat at constant pressure
gamma = 1.4           # heat-capacity ratio, C_p / C_v
C_v = C_p / gamma     # specific heat at constant volume
R = C_p - C_v         # specific gas constant, used as rho = p / (R * T)
# --------------------------------------------------------------------------------
# 5) Training 
# --------------------------------------------------------------------------------
def conservation_residuals(field_real, area, gas_constant, c_p):
    """Inlet-to-outlet imbalance of mass flux and total enthalpy for a 1D field.

    Args:
        field_real: tensor indexed as [batch, position, channel] in physical
            (de-normalised) units; channels 0 / 1 / 2 are read as
            temperature / pressure / velocity.
        area: geometry factor multiplying rho * u at every position; must
            broadcast against [batch, position].
        gas_constant: specific gas constant R in the ideal-gas law
            rho = p / (R * T).
        c_p: specific heat at constant pressure.

    Returns:
        Tuple ``(mass_residual, energy_residual, mass_residual_norm,
        energy_residual_norm)``. The first two are the per-sample
        last-position-minus-first-position differences; the last two are the
        same differences divided by the mean of the respective quantity over
        the whole batch.
    """
    temperature = field_real[:, :, 0]
    pressure = field_real[:, :, 1]
    velocity = field_real[:, :, 2]

    density = pressure / (gas_constant * temperature)   # ideal-gas law
    mass_flux = density * velocity * area
    total_enthalpy = c_p * temperature + (velocity ** 2) / 2

    mass_residual = mass_flux[:, -1] - mass_flux[:, 0]
    energy_residual = total_enthalpy[:, -1] - total_enthalpy[:, 0]

    # Scale by the batch mean so both residuals are dimensionless and comparable
    mass_residual_norm = mass_residual / torch.mean(mass_flux)
    energy_residual_norm = energy_residual / torch.mean(total_enthalpy)
    return mass_residual, energy_residual, mass_residual_norm, energy_residual_norm


if __name__ == "__main__":

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ------------------------------------------------
    # Instantiate the FNO and training parameters
    # ------------------------------------------------
    model = FNO1d(modes=16, width=32, hidden_mlp=128, N_x=128, N_fourier_layers=4).to(device)
    phys_w_max      = 10     # the weight you want to end up with
    warmup_epochs   = 2000   # physics weight ramps linearly from 0 over these epochs
    epochs = 2500
    learning_rate = 1e-2
    step_size = 100
    # StepLR decay factor. Not named `gamma` on purpose: an `if` block opens no
    # new scope, so that name would overwrite the module-level heat-capacity
    # ratio `gamma` for everything that runs afterwards.
    lr_gamma = 0.7
    weight_decay = 1e-4
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=lr_gamma)
    out_max_torch = torch.from_numpy(out_max).to(device)
    out_min_torch = torch.from_numpy(out_min).to(device)

    # Define loss and error evaluation function
    criterion_loss = relative_min_max
    criterion_eval = relative_min_max

    # ------------------------------------------------
    # Preallocate losses and relative errors
    # ------------------------------------------------
    train_losses = []
    phys_losses = []
    test_losses = []
    test_rel_list = []
    test_rel_pressure_list = []
    test_rel_velocity_list = []
    test_rel_temperature_list = []
    rel_final_batch = []
    rel_err_sum = 0

    # ------------------------------------------------
    # Training loop
    # ------------------------------------------------
    for epoch in range(epochs):
        # Linear warm-up of the physics weight: 0 at epoch 0, phys_w_max from warmup_epochs on
        phys_w = phys_w_max * min(1.0, epoch / warmup_epochs)
        model.train()
        total_train_loss = 0.0
        total_phys_loss = 0.0
        train_rel_pressure = 0.0
        train_rel_velocity = 0.0
        train_rel_temperature = 0.0
        for bc, target_field, x, y in train_loader:
            bc, target_field, x, y = bc.to(device), target_field.to(device), x.to(device), y.to(device)
            optimizer.zero_grad()
            pred_field = model(bc, x, y)

            # Undo the min-max normalisation of the fields
            pred_field_real = pred_field*(out_max_torch-out_min_torch) + out_min_torch
            target_field_real = target_field*(out_max_torch-out_min_torch) + out_min_torch

            # Renormalize physical units
            real_x = x*(x_max-x_min) + x_min   # not used by the loss below
            real_y = y*(y_max-y_min) + y_min

            # Physics loss: the mass flux must be built from the de-normalised
            # geometry `real_y`. The normalised `y` was used here before, so the
            # residual was evaluated on a min-max-scaled geometry.
            # NOTE(review): assumes y is the geometry factor of the mass flux — confirm against data_loader.
            mass_residual, energy_residual, mass_residual_norm, energy_residual_norm = conservation_residuals(
                pred_field_real, real_y, R, C_p
            )

            loss_phys = phys_w*(torch.sqrt(torch.mean(mass_residual_norm**2)) + torch.sqrt(torch.mean(energy_residual_norm**2)))

            # Training loss
            loss_vec = criterion_loss(pred_field, target_field)
            train_loss = (loss_vec[0] + loss_vec[1] + loss_vec[2])
            loss = train_loss + loss_phys
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) # Works good for small datasets
            optimizer.step()
            total_train_loss += train_loss.item()
            total_phys_loss += loss_phys.item()

            # Batch wise relative errors (pre-step prediction). Pure bookkeeping,
            # so no autograd graph is recorded for it.
            with torch.no_grad():
                train_eval_error = criterion_eval(pred_field, target_field)*100
            batch_train_rel_temperature = train_eval_error[0]
            batch_train_rel_pressure = train_eval_error[1]
            batch_train_rel_velocity = train_eval_error[2]
            train_rel_pressure += batch_train_rel_pressure.item()
            train_rel_velocity += batch_train_rel_velocity.item()
            train_rel_temperature += batch_train_rel_temperature.item()

        # Average over batchwise losses and errors
        avg_train_loss = total_train_loss / len(train_loader)
        avg_phys_loss = total_phys_loss/len(train_loader)
        avg_train_rel_pressure = train_rel_pressure / len(train_loader)
        avg_train_rel_velocity = train_rel_velocity / len(train_loader)
        avg_train_rel_temperature = train_rel_temperature / len(train_loader)
        train_losses.append(avg_train_loss)
        phys_losses.append(avg_phys_loss)

        # ------------------------------------------------
        # Evaluate model on test data
        # ------------------------------------------------
        model.eval()
        total_test_loss = 0.0
        test_rel_error = 0.0
        test_rel_pressure = 0.0
        test_rel_velocity = 0.0
        test_rel_temperature = 0.0
        with torch.no_grad():
            for bc, target_field, x, y in test_loader:
                bc, target_field, x, y = bc.to(device), target_field.to(device), x.to(device), y.to(device)
                pred_field = model(bc, x, y)

                # Test Loss
                loss_vec = criterion_eval(pred_field, target_field)
                loss = (loss_vec[0] + loss_vec[1] + loss_vec[2])
                total_test_loss += loss.item()

                batch_rel_error = torch.mean(loss_vec)*100
                test_rel_error += batch_rel_error.item()

                # Batch wise relative errors: reuse loss_vec rather than calling
                # criterion_eval a second time on the same tensors
                test_eval_error = loss_vec*100
                batch_rel_temperature = test_eval_error[0]
                batch_rel_pressure = test_eval_error[1]
                batch_rel_velocity = test_eval_error[2]

                test_rel_temperature += batch_rel_temperature.item()
                test_rel_pressure += batch_rel_pressure.item()
                test_rel_velocity += batch_rel_velocity.item()

                # Individual relative error of each sample (final epoch only):
                # RMS error along dim 1, scaled by the target's max-min range along dim 1
                if epoch == epochs-1:
                    num = (pred_field-target_field)**2
                    denom = (torch.amax(target_field, dim=[1], keepdim=True) - torch.amin(target_field, dim=[1], keepdim=True))**2
                    rel_err = torch.mean(num/denom, dim=[1])
                    rel_err = torch.sqrt(rel_err)*100
                    rel_final_batch.extend(rel_err.cpu().tolist())

        # Average batchwise errors
        avg_test_loss = total_test_loss / len(test_loader)
        avg_test_rel = test_rel_error / len(test_loader)
        avg_rel_pressure = test_rel_pressure / len(test_loader)
        avg_rel_velocity = test_rel_velocity / len(test_loader)
        avg_rel_temperature = test_rel_temperature / len(test_loader)
        test_losses.append(avg_test_loss)

        # Store test relative errors for plotting
        test_rel_list.append(avg_test_rel)
        test_rel_temperature_list.append(avg_rel_temperature)
        test_rel_pressure_list.append(avg_rel_pressure)
        test_rel_velocity_list.append(avg_rel_velocity)

        if epoch % 50 == 0 or epoch == epochs - 1:
            # The three physics diagnostics come from the LAST training batch of
            # this epoch; .item() prints plain numbers instead of tensor reprs.
            print("mass loss", torch.mean(mass_residual).item())
            print("energy loss", torch.mean(energy_residual).item())
            print("phys_loss", loss_phys.item())
            print(f"Epoch [{epoch+1}/{epochs}] - Train Loss: {avg_train_loss:.6f} - Test Loss: {avg_test_loss:.6f} - Test Relative Error: {avg_test_rel:.6f}%")
            print(
                f"  Train:  Temperature Rel: {avg_train_rel_temperature:.6f}% | "
                f"Pressure Rel: {avg_train_rel_pressure:.6f}% | "
                f"Velocity Rel: {avg_train_rel_velocity:.6f}% | "

            )
            print(
                f"  Test:   Temperature Rel: {avg_rel_temperature:.6f}% | "
                f"Pressure Rel: {avg_rel_pressure:.6f}% | "
                f"Velocity Rel: {avg_rel_velocity:.6f}% | "

            )
            print()
        scheduler.step()

    print("Training complete!")
    # ------------------------------------------------
    # Plot losses vs Epoch
    # ------------------------------------------------
    plt.figure(figsize=(10, 5))
    plt.plot(range(1, epochs+1), train_losses, label="Train Loss")
    plt.plot(range(1, epochs+1), test_losses, label="Test Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Log-Loss")
    plt.yscale('log')
    plt.title("Loss vs Epoch")
    plt.legend()
    plt.show()

In [None]:
#torch.save(model.state_dict(), '1DPFNO.pth')