In [81]:
import pandas as pd
import numpy as np
import os
import torch
import torch.nn as nn
from torch.nn import MSELoss
from torch.nn.functional import mse_loss
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import time
import glob
from compute_distance import compute_distance
from scipy.spatial import cKDTree
from shapely.geometry import Point, Polygon, LineString
from shapely.ops import unary_union

In [None]:
weights_path = 'models/'
model = ""
data_path = 'data/unscaled_p/'
data_pattern = 'TrP4_*_10.csv'

num_epochs = 500
batch_size = 2000

loss_fn = mse_loss

h_n = 64
input_n = 4
n_layers = 8

ds = 0.4
dp = 0.1

lr = [1e-3, 1e-3, 2e-3]

In [83]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [84]:
class Swish(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)

# General PINN architecture for u, v, or p
class PINN(nn.Module):
    def __init__(self, input_dim=4, output_dim=1, hidden_dim=64, num_layers=8):
        super(PINN, self).__init__()
        layers = []

        # Input layer
        layers.append(nn.Linear(input_dim, hidden_dim))
        layers.append(Swish())

        # Hidden layers
        for _ in range(num_layers - 1):
            layers.append(nn.Linear(hidden_dim, hidden_dim))
            layers.append(Swish())

        # Output layer
        layers.append(nn.Linear(hidden_dim, output_dim))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

In [85]:
def initialize_models():
    net_u = PINN(input_dim=input_n, output_dim=1, hidden_dim=h_n, num_layers=n_layers).to(device)
    net_v = PINN(input_dim=input_n, output_dim=1, hidden_dim=h_n, num_layers=n_layers).to(device)
    net_p = PINN(input_dim=input_n, output_dim=1, hidden_dim=h_n, num_layers=n_layers).to(device)

    def init_xavier(model):
        for m in model.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.zeros_(m.bias)  # or small constant like 1e-3

    net_u.apply(init_xavier)
    net_v.apply(init_xavier)
    net_p.apply(init_xavier)

    optimizer_u = optim.Adam(net_u.parameters(), lr=lr[0])
    optimizer_v = optim.Adam(net_v.parameters(), lr=lr[1])
    optimizer_p = optim.Adam(net_p.parameters(), lr=lr[2])

    scheduler_u = optim.lr_scheduler.ReduceLROnPlateau(optimizer_u, mode='min', factor=0.5, patience=50, min_lr=1e-6)
    scheduler_v = optim.lr_scheduler.ReduceLROnPlateau(optimizer_v, mode='min', factor=0.5, patience=50, min_lr=1e-6)
    scheduler_p = optim.lr_scheduler.ReduceLROnPlateau(optimizer_p, mode='min', factor=0.5, patience=50, min_lr=1e-6)

    if os.path.exists(f"{weights_path}/{model}_u.pth"):
        net_u.load_state_dict(torch.load(f"{weights_path}/{model}_u.pth"))
    if os.path.exists(f"{weights_path}/{model}_v.pth"):
        net_v.load_state_dict(torch.load(f"{weights_path}/{model}_v.pth"))
    if os.path.exists(f"{weights_path}/{model}_p.pth"):
        net_p.load_state_dict(torch.load(f"{weights_path}/{model}_p.pth"))

    return net_u, net_v, net_p, optimizer_u, optimizer_v, optimizer_p, scheduler_u, scheduler_v, scheduler_p

In [86]:
class DLD_Dataset(Dataset):
    def __init__(self, data_path, data_pattern):
        self.files = sorted(glob.glob(os.path.join(data_path, data_pattern)))
        self.data = []

        for file in self.files:
            df = pd.read_csv(file)

            inputs = df[['x', 'y', 'd', 'N']].values
            targets = df[['u', 'v', 'p']].values
            self.data.append((inputs, targets))

        self.inputs = torch.tensor(np.vstack([d[0] for d in self.data]), dtype=torch.float32)
        self.targets = torch.tensor(np.vstack([d[1] for d in self.data]), dtype=torch.float32)

    def __len__(self):
        return self.inputs.shape[0]
    
    def __getitem__(self, idx):
        return self.inputs[idx], self.targets[idx]

In [87]:
def criterion(inputs, targets, net_u, net_v, net_p, loss_fn, epoch, rho=1.0, nu=0.01, ds=0.4, dp = 0.1):
    x, y, d, n = inputs[:, 0], inputs[:, 1], inputs[:, 2], inputs[:, 3]

    x.requires_grad_()
    y.requires_grad_()
    d.requires_grad_()
    n.requires_grad_()

    input_tensor = torch.stack((x, y, d, n), dim=1)

    u_pred = net_u(input_tensor).squeeze(1)
    v_pred = net_v(input_tensor).squeeze(1)
    p_pred = net_p(input_tensor).squeeze(1)

    distances = compute_distance(x, y, d, n, ds)

    ####
    # u_hard = u_pred
    # v_hard = v_pred
    # p_hard = p_pred

    u_hard = u_pred * distances
    v_hard = v_pred * distances

    xStart = 0
    xEnd = ds

    p_par = ((xEnd - x) / (xEnd - xStart) * dp) + ((x - xStart) * (xEnd - xStart) * 0.0)
    Dp = (x - xStart) * (xEnd - x)
    p_hard = p_par + Dp * p_pred

    u_x = torch.autograd.grad(u_hard, x, grad_outputs=torch.ones_like(u_hard), create_graph=True)[0]
    u_xx = torch.autograd.grad(u_x, x, grad_outputs=torch.ones_like(u_x), create_graph=True)[0]
    u_y = torch.autograd.grad(u_hard, y, grad_outputs=torch.ones_like(u_hard), create_graph=True)[0]
    u_yy = torch.autograd.grad(u_y, y, grad_outputs=torch.ones_like(u_y), create_graph=True)[0]

    v_x = torch.autograd.grad(v_hard, x, grad_outputs=torch.ones_like(v_hard), create_graph=True)[0]
    v_xx = torch.autograd.grad(v_x, x, grad_outputs=torch.ones_like(v_x), create_graph=True)[0]
    v_y = torch.autograd.grad(v_hard, y, grad_outputs=torch.ones_like(v_hard), create_graph=True)[0]
    v_yy = torch.autograd.grad(v_y, y, grad_outputs=torch.ones_like(v_y), create_graph=True)[0]

    p_x = torch.autograd.grad(p_hard, x, grad_outputs=torch.ones_like(p_hard), create_graph=True)[0]
    p_y = torch.autograd.grad(p_hard, y, grad_outputs=torch.ones_like(p_hard), create_graph=True)[0]

    pde_loss_x = u_hard * u_x + v_hard * u_y - nu * (u_xx + u_yy) + (1 / rho) * p_x
    pde_loss_y = u_hard * v_x + v_hard * v_y - nu * (v_xx + v_yy) + (1 / rho) * p_y
    pde_loss_continuity = u_x + v_y 

    pde_loss_x = loss_fn(pde_loss_x, torch.zeros_like(pde_loss_x))
    pde_loss_y = loss_fn(pde_loss_y, torch.zeros_like(pde_loss_y))
    pde_loss_continuity = loss_fn(pde_loss_continuity, torch.zeros_like(pde_loss_continuity))
    ####

    inlet_condition = (torch.abs(x) < 1e-6)
    outlet_condition = (torch.abs(x - ds) < 1e-6)
    wall_condition = (distances < 1e-6)

    u_avg = nu / d
    u_max = (3 / 2) * u_avg
    u_inlet = u_max * (1 - (4 * (((ds / 2) - y) ** 2)) / ((ds - d) ** 2))

    inlet_loss_u = mse_loss(u_hard[inlet_condition], u_inlet[inlet_condition]) if u_inlet[inlet_condition].numel() > 0 else torch.tensor(0.0, device=device)
    inlet_loss_v = mse_loss(v_hard[inlet_condition], torch.zeros_like(v_hard[inlet_condition])) if v_hard[inlet_condition].numel() > 0 else torch.tensor(0.0, device=device)
    inlet_loss_p = mse_loss(p_hard[inlet_condition], torch.full_like(p_hard[inlet_condition], dp)) if p_hard[inlet_condition].numel() > 0 else torch.tensor(0.0, device=device)

    inlet_loss = inlet_loss_u + inlet_loss_v + inlet_loss_p

    outlet_loss_p = mse_loss(p_hard[outlet_condition], torch.zeros_like(p_hard[outlet_condition])) if p_hard[outlet_condition].numel() > 0 else torch.tensor(0.0, device=device)

    outlet_loss = outlet_loss_p

    wall_loss_u = mse_loss(u_hard[wall_condition], torch.zeros_like(u_hard[wall_condition])) if u_hard[wall_condition].numel() > 0 else torch.tensor(0.0, device=device)
    wall_loss_v = mse_loss(v_hard[wall_condition], torch.zeros_like(v_hard[wall_condition])) if v_hard[wall_condition].numel() > 0 else torch.tensor(0.0, device=device)

    wall_loss = wall_loss_u + wall_loss_v
    
    loss_u = loss_fn(u_hard, targets[:, 0])
    loss_v = loss_fn(v_hard, targets[:, 1])
    loss_p = loss_fn(p_hard, targets[:, 2])

    boundaries_weight = 1
    pde_weight = 1
    data_weight = 0
    
    total_loss = data_weight * (loss_u + loss_v + loss_p) + boundaries_weight * (inlet_loss + outlet_loss + wall_loss) + pde_weight * (pde_loss_x + pde_loss_y + pde_loss_continuity)

    return total_loss, loss_u, loss_v, loss_p,\
        inlet_loss, outlet_loss, wall_loss, \
        pde_loss_x, pde_loss_y, pde_loss_continuity

In [88]:
def periodic_pairs(unique_dn, ds, N):
    device = unique_dn.device
    K = unique_dn.shape[0]

    idx_dn = torch.randint(0, K, (N,), dtype=torch.long, device=device)
    dn = unique_dn[idx_dn]
    d = dn[:, 0]
    n = dn[:, 1]

    t = torch.rand(N, device=device)

    tilt = ds / n

    x_bot = ds * t
    y_bot = tilt * t

    x_top = ds * t
    y_top = ds + tilt * t

    bottom_inp = torch.stack([x_bot, y_bot, d, n], dim=1)
    top_inp    = torch.stack([x_top, y_top, d, n], dim=1)

    return bottom_inp, top_inp

In [None]:
def main():
    dataset = DLD_Dataset(data_path, data_pattern)
    print(f"Dataset size: {len(dataset)}")

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    print(f"Number of batches: {len(dataloader)}")

    net_u, net_v, net_p, optimizer_u, optimizer_v, optimizer_p, scheduler_u, scheduler_v, scheduler_p = initialize_models()

    min_loss = float('inf')

    loss_history = {
        'epoch': [],
        'total_loss': [],
        'loss_u': [],
        'loss_v': [],
        'loss_p': [],
        'inlet_loss': [],
        'outlet_loss': [],
        'wall_loss': [],
        'periodic_loss': [],
        'pde_loss_x': [],
        'pde_loss_y': [],
        'pde_loss_continuity': [],
    }

    start_time = time.time()

    for epoch in range(num_epochs+1):
        total_loss, total_loss_u, total_loss_v, total_loss_p, \
            total_inlet_loss, total_outlet_loss, total_wall_loss, \
            total_pde_loss_x, total_pde_loss_y, total_pde_loss_continuity = 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
        
        total_periodic_loss = 0.0

        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer_u.zero_grad()
            optimizer_v.zero_grad()
            optimizer_p.zero_grad()

            loss, loss_u, loss_v, loss_p, \
                inlet_loss, outlet_loss, wall_loss, \
                pde_loss_x, pde_loss_y, pde_loss_continuity = criterion(inputs, targets, net_u, net_v, net_p, loss_fn, epoch, ds=ds, dp=dp)
            
            bottom_inputs, top_inputs = periodic_pairs(torch.unique(inputs[:, 2:4], dim=0).to(device), ds, 100)
            bottom_inputs = bottom_inputs.to(device)
            top_inputs = top_inputs.to(device)

            periodic_loss = 0.0

            with torch.no_grad():
                u_bottom = net_u(bottom_inputs).squeeze(1)
                v_bottom = net_v(bottom_inputs).squeeze(1)
                p_bottom = net_p(bottom_inputs).squeeze(1)

                u_top = net_u(top_inputs).squeeze(1)
                v_top = net_v(top_inputs).squeeze(1)
                p_top = net_p(top_inputs).squeeze(1)

                # Periodic boundary condition loss
                periodic_loss += mse_loss(u_bottom, u_top)
                periodic_loss += mse_loss(v_bottom, v_top)
                periodic_loss += mse_loss(p_bottom, p_top)

            loss += periodic_loss

            loss.backward()

            optimizer_u.step()
            optimizer_v.step()
            optimizer_p.step()

            total_loss += loss.item()
            total_loss_u += loss_u.item()
            total_loss_v += loss_v.item()
            total_loss_p += loss_p.item()

            total_inlet_loss += inlet_loss.item()
            total_outlet_loss += outlet_loss.item()
            total_wall_loss += wall_loss.item()
            total_periodic_loss += periodic_loss.item()

            total_pde_loss_x += pde_loss_x.item()
            total_pde_loss_y += pde_loss_y.item()
            total_pde_loss_continuity += pde_loss_continuity.item()

        total_loss_u /= len(dataloader)
        total_loss_v /= len(dataloader)
        total_loss_p /= len(dataloader)
        
        total_inlet_loss /= len(dataloader)
        total_outlet_loss /= len(dataloader)
        total_wall_loss /= len(dataloader)
        total_periodic_loss /= len(dataloader)

        total_pde_loss_x /= len(dataloader)
        total_pde_loss_y /= len(dataloader)
        total_pde_loss_continuity /= len(dataloader)
        total_loss /= len(dataloader)

        scheduler_u.step(total_loss_u)
        scheduler_v.step(total_loss_v)
        scheduler_p.step(total_loss_p)

        loss_history['epoch'].append(epoch)
        loss_history['total_loss'].append(total_loss)
        loss_history['loss_u'].append(total_loss_u)
        loss_history['loss_v'].append(total_loss_v)
        loss_history['loss_p'].append(total_loss_p)

        loss_history['inlet_loss'].append(total_inlet_loss)
        loss_history['outlet_loss'].append(total_outlet_loss)
        loss_history['wall_loss'].append(total_wall_loss)
        loss_history['periodic_loss'].append(periodic_loss.item())

        loss_history['pde_loss_x'].append(total_pde_loss_x)
        loss_history['pde_loss_y'].append(total_pde_loss_y)
        loss_history['pde_loss_continuity'].append(total_pde_loss_continuity)

        print(f"Epoch {epoch}/{num_epochs}, Total Loss: {total_loss:.4f} || Loss U: {total_loss_u:.4f}, Loss V: {total_loss_v:.4f}, Loss P: {total_loss_p:.4f} "
            f"|| Inlet Loss: {total_inlet_loss:.4f}, Outlet Loss: {total_outlet_loss:.4f}, Wall Loss: {total_wall_loss:.4f}, Periodic Loss: {periodic_loss.item():.4f} "
            f"|| PDE Losses: x: {total_pde_loss_x:.4f}, y: {total_pde_loss_y:.4f}, continuity: {total_pde_loss_continuity:.4f} || Time: {time.time() - start_time:.2f}s")

        start_time = time.time()

        if total_loss < min_loss:
            min_loss = total_loss
            torch.save(net_u.state_dict(), f"{weights_path}/best_u.pth")
            torch.save(net_v.state_dict(), f"{weights_path}/best_v.pth")
            torch.save(net_p.state_dict(), f"{weights_path}/best_p.pth")
            print(f"New best model saved with loss: {min_loss:.4f} at epoch {epoch}")

        if epoch % 100 == 0:
            torch.save(net_u.state_dict(), f"{weights_path}/epoch_{epoch}_u.pth")
            torch.save(net_v.state_dict(), f"{weights_path}/epoch_{epoch}_v.pth")
            torch.save(net_p.state_dict(), f"{weights_path}/epoch_{epoch}_p.pth")
            print(f"Checkpoint saved at epoch {epoch}")
    
    loss_df = pd.DataFrame(loss_history)
    loss_df.to_csv(f"results/loss_history.csv", index=False)

    plt.figure(figsize=(10, 6))
    plt.plot(loss_history['epoch'], loss_history['loss_u'], label='Loss U', color='blue')
    plt.plot(loss_history['epoch'], loss_history['loss_v'], label='Loss V', color='orange')
    plt.plot(loss_history['epoch'], loss_history['loss_p'], label='Loss P', color='green')
    plt.plot(loss_history['epoch'], loss_history['total_loss'], label='Total Loss', color='red')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Loss History')
    plt.legend()
    plt.grid()
    plt.savefig(f"results/loss_history.png")
    plt.show()

    print("Training complete.")

In [None]:
if __name__ == "__main__":
    main()

Dataset size: 99603
Number of batches: 50
Epoch 0/500, Total Loss: 0.0708 || Loss U: 0.0206, Loss V: 0.0005, Loss P: 0.0012 || Inlet Loss: 0.0041, Outlet Loss: 0.0000, Wall Loss: 0.0000, Periodic Loss: 0.0046 || PDE Losses: x: 0.0624, y: 0.0000, continuity: 0.0001 || Time: 4.49s
New best model saved with loss: 0.0708 at epoch 0
Checkpoint saved at epoch 0
Epoch 1/500, Total Loss: 0.0700 || Loss U: 0.0206, Loss V: 0.0005, Loss P: 0.0012 || Inlet Loss: 0.0041, Outlet Loss: 0.0000, Wall Loss: 0.0000, Periodic Loss: 0.0141 || PDE Losses: x: 0.0624, y: 0.0000, continuity: 0.0000 || Time: 4.73s
New best model saved with loss: 0.0700 at epoch 1
Epoch 2/500, Total Loss: 0.0844 || Loss U: 0.0206, Loss V: 0.0005, Loss P: 0.0012 || Inlet Loss: 0.0042, Outlet Loss: 0.0000, Wall Loss: 0.0000, Periodic Loss: 0.0017 || PDE Losses: x: 0.0622, y: 0.0001, continuity: 0.0000 || Time: 4.69s
Epoch 3/500, Total Loss: 0.0858 || Loss U: 0.0206, Loss V: 0.0005, Loss P: 0.0012 || Inlet Loss: 0.0041, Outlet Loss