In [1]:
import os
import copy
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import random
import matplotlib.pyplot as plt

In [2]:
# ====== 1. Reproducibility & Paths ======
# One seed value feeds every RNG this notebook touches (PyTorch, NumPy, Python).
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# === monod paths (kept as-is) ===
DATA_DIR = "/home/yehoon/npz/amv/monod/monod_data/"
RESULT_SAVE_DIR = "./model/PRT_Monod/PRT_Monod"
MODEL_SAVE_PATH = "./model/PRT_Monod/PRT_Monod.pt"   # checkpoint written by the main cell
# Creating RESULT_SAVE_DIR also creates ./model/PRT_Monod, the folder that
# MODEL_SAVE_PATH lives in.
os.makedirs(RESULT_SAVE_DIR, exist_ok=True)

# --- training hyperparameters ---
BATCH_SIZE = 25
NUM_EPOCHS = 200
LR = 1e-3

# --- grid and network dimensions ---
NX = 64
NY = 148
TRUNK_DIM = 4     # monod: (x, y, t_norm, dist_inlet)
BRANCH2_DIM = 2   # (Pe, Da) for Monod setting
OUT_DIM = 128

In [3]:
# ====== 2. Data Loader (kept paths exactly the same) ======
def load_dataset():
    """Load the pre-built Monod train/test datasets from disk.

    The files are read from the same absolute paths as before. The previous
    ``os.path.join(DATA_DIR, <absolute path>)`` wrapper was dropped: joining
    onto an absolute path discards DATA_DIR, so it only obscured where the
    data really comes from.

    Returns:
        (train_dataset, test_dataset) exactly as stored in the two ``.pt`` files.
    """
    train_path = "/home/yehoon/npz/model/PRT_monod/monod_train_dataset_trunk4.pt"
    test_path = "/home/yehoon/npz/model/PRT_monod/monod_test_dataset_trunk4.pt"

    # SECURITY: torch.load unpickles arbitrary Python objects; only point this
    # at files you created or otherwise trust.
    # map_location="cpu" lets a dataset that was saved from a CUDA session be
    # loaded on a CPU-only machine; training/evaluation move each batch to the
    # active device themselves.
    # NOTE(review): recent PyTorch releases default torch.load to
    # weights_only=True, which rejects pickled dataset objects. If loading
    # fails there, pass weights_only=False (trusted files only) - confirm
    # against the installed version.
    train_dataset = torch.load(train_path, map_location="cpu")
    test_dataset = torch.load(test_path, map_location="cpu")
    return train_dataset, test_dataset

In [5]:
# ====== 3. Model Definition (paper-style names) ======
def _build_mlp(in_dim, hidden_dim, out_dim, num_layers):
    """Build a SiLU MLP mapping in_dim -> hidden_dim -> ... -> out_dim.

    The stack holds ``max(num_layers, 2)`` Linear layers; every Linear except
    the last one is followed by SiLU.
    """
    layers = [nn.Linear(in_dim, hidden_dim), nn.SiLU()]
    for _ in range(num_layers - 2):
        layers += [nn.Linear(hidden_dim, hidden_dim), nn.SiLU()]
    layers.append(nn.Linear(hidden_dim, out_dim))
    return nn.Sequential(*layers)


class BranchCNN(nn.Module):
    """Geometry branch: CNN encoder for (1, H, W) binary pore image.

    Stacks ``num_blocks`` blocks of Conv2d(3x3, stride 1, padding 1) -> SiLU ->
    AvgPool2d(2), flattens the result and projects it to ``out_dim`` with a
    single Linear layer.

    Args:
        in_channels: number of channels of the input image.
        out_dim: size of the returned feature vector.
        num_blocks: number of conv blocks, 0..5 (widths 16, 32, 64, 128, 256).
        input_size: (H, W) of the images fed to ``forward``; it fixes the input
            width of the final Linear layer. ``None`` keeps the old behaviour
            of reading the notebook-level ``NX, NY``.

    Raises:
        ValueError: if ``num_blocks`` is out of range, or if the pooling would
            shrink ``input_size`` to zero pixels.
    """
    _WIDTHS = (16, 32, 64, 128, 256)

    def __init__(self, in_channels, out_dim, num_blocks=5, input_size=None):  # monod often uses 5 blocks
        super().__init__()
        if not 0 <= num_blocks <= len(self._WIDTHS):
            raise ValueError(
                f"num_blocks must be between 0 and {len(self._WIDTHS)}, got {num_blocks}"
            )
        if input_size is None:
            # Backward-compatible default: the notebook-level grid size.
            input_size = (NX, NY)

        channels = [in_channels, *self._WIDTHS[:num_blocks]]
        layers = []
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            layers += [
                nn.Conv2d(c_in, c_out, 3, 1, 1),
                nn.SiLU(),
                nn.AvgPool2d(2),
            ]
        self.features = nn.Sequential(*layers)

        # Each AvgPool2d(2) floors the spatial size by 2; the convs keep it.
        h, w = input_size
        shrink = 2 ** num_blocks
        h, w = h // shrink, w // shrink
        if h == 0 or w == 0:
            raise ValueError(
                f"input_size {tuple(input_size)} is too small for {num_blocks} pooling blocks"
            )
        self.fc = nn.Linear(channels[-1] * h * w, out_dim)

    def forward(self, x):
        x = self.features(x)
        # flatten(…, 1) keeps the batch axis and also copes with
        # non-contiguous feature maps, where .view() would raise.
        return self.fc(torch.flatten(x, 1))


class BranchFNN(nn.Module):
    """Parameter branch: FNN for (Pe, Da)."""
    def __init__(self, in_dim=2, out_dim=128, hidden_dim=128, num_layers=3):
        super().__init__()
        self.net = _build_mlp(in_dim, hidden_dim, out_dim, num_layers)

    def forward(self, x):
        return self.net(x)


class Trunk(nn.Module):
    """Trunk network: FNN on (x, y, t_norm, dist_inlet)."""
    def __init__(self, trunk_in_dim, out_dim, num_layers=8, width=128):
        super().__init__()
        self.net = _build_mlp(trunk_in_dim, width, out_dim, num_layers)

    def forward(self, x):
        return self.net(x)


class PRT_DeepONet_Monod(nn.Module):
    """
    Paper-style DeepONet:
      y(x) = Σ_k  BranchCNN_k(geom) * BranchFNN_k(params) * Trunk_k(query) + b
    where:
      - geom   : (B, 1, nx, ny)
      - params : (B, branch2_in_dim)
      - query  : (B, 1, L, trunk_in_dim) or (B, L, trunk_in_dim), with
                 L == nx * ny (the output is reshaped onto the nx-by-ny grid)

    The geometry branch is sized from ``nx``/``ny``, so the model can be built
    for grids other than the notebook-level defaults.
    """
    def __init__(self, nx=64, ny=148, trunk_in_dim=4, out_dim=128, branch2_in_dim=2, cnn_blocks=5):
        super().__init__()
        self.nx, self.ny = nx, ny
        self.branch_geom = BranchCNN(1, out_dim, num_blocks=cnn_blocks, input_size=(nx, ny))
        self.branch_param = BranchFNN(branch2_in_dim, out_dim, hidden_dim=128, num_layers=3)
        self.trunk = Trunk(trunk_in_dim, out_dim, num_layers=8, width=128)
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, branch1_input, branch2_input, trunk_input):
        """Return the predicted field with shape (B, nx, ny, 1).

        Raises:
            ValueError: if the number of query points differs from nx * ny.
        """
        if trunk_input.ndim == 4:
            trunk_input = trunk_input.squeeze(1)                    # (B, L, D)
        batch, num_points, _ = trunk_input.shape
        if num_points != self.nx * self.ny:
            raise ValueError(
                f"expected {self.nx * self.ny} query points (nx*ny), got {num_points}"
            )

        t = self.trunk(trunk_input)                                 # (B, L, C)
        g = self.branch_geom(branch1_input)                         # (B, C)
        p = self.branch_param(branch2_input)                        # (B, C)

        # Same product as g * p * t in the original, contracted over C.
        out = ((g * p).unsqueeze(1) * t).sum(-1) + self.bias        # (B, L)
        return out.view(batch, self.nx, self.ny, 1)                 # (B, nx, ny, 1)

In [6]:
# ====== 4. Training Utilities (same skeleton) ======
class _EarlyStopping:
    """Track the best validation loss and flag when it stops improving.

    A call counts as an improvement only if the loss drops by more than
    ``delta`` below the best value seen so far; the model weights at that point
    are deep-copied into ``best_model_state``. After ``patience`` consecutive
    calls without improvement, ``early_stop`` is set to True.
    """
    def __init__(self, patience=15, delta=1e-5, verbose=True):
        self.patience = patience
        self.delta = delta
        self.verbose = verbose
        self.counter = 0
        self.best_loss = None
        self.early_stop = False
        self.best_model_state = None

    def __call__(self, val_loss, model):
        if self.best_loss is None or val_loss < self.best_loss - self.delta:
            self.best_loss = val_loss
            self.counter = 0
            self.best_model_state = copy.deepcopy(model.state_dict())
            if self.verbose:
                print(f"Validation loss decreased. New best loss: {val_loss:.6f}")
        else:
            self.counter += 1
            if self.verbose:
                print(f"No improvement in validation loss. Counter: {self.counter}/{self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True


def train_model(
    model, train_dataset, test_dataset, num_epochs=1000, lr=0.001, batch_size=128, patience=15):
    """Train ``model`` with AdamW + Huber loss and early stopping on the test loss.

    Args:
        model: module called as ``model(branch1, branch2, trunk)``.
        train_dataset, test_dataset: datasets yielding
            (branch1, branch2, trunk, target) tensors.
        num_epochs: maximum number of epochs.
        lr: AdamW learning rate.
        batch_size: batch size for both loaders.
        patience: epochs without test-loss improvement before stopping.
            (Previously this argument was ignored and 15 was always used.)

    Returns:
        (train_losses, test_losses): per-epoch averages of the batch losses.

    On return the model holds the weights of the best test-loss epoch, whether
    training stopped early or ran out of epochs.
    """
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader  = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    criterion = nn.HuberLoss(delta=1.0)
    # Mixed precision only makes sense on CUDA; disabling it explicitly on CPU
    # avoids the "CUDA is not available" warnings and keeps the code path clear.
    use_amp = device.type == 'cuda'
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    early_stopping = _EarlyStopping(patience=patience, delta=1e-5, verbose=True)
    train_losses, test_losses = [], []

    for epoch in range(num_epochs):
        # --- Training ---
        model.train()
        running_train_loss = 0.0
        for batch in train_loader:
            batch_branch1, batch_branch2, batch_trunk, batch_target = [b.to(device) for b in batch]
            optimizer.zero_grad()
            with torch.cuda.amp.autocast(enabled=use_amp):
                preds = model(batch_branch1, batch_branch2, batch_trunk)
                loss = criterion(preds, batch_target)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            running_train_loss += loss.item()
        avg_train_loss = running_train_loss / len(train_loader)
        train_losses.append(avg_train_loss)

        # --- Validation ---
        model.eval()
        running_test_loss = 0.0
        with torch.no_grad():
            for batch in test_loader:
                batch_branch1, batch_branch2, batch_trunk, batch_target = [b.to(device) for b in batch]
                with torch.cuda.amp.autocast(enabled=use_amp):
                    preds = model(batch_branch1, batch_branch2, batch_trunk)
                    loss = criterion(preds, batch_target)
                running_test_loss += loss.item()
        avg_test_loss = running_test_loss / len(test_loader)
        test_losses.append(avg_test_loss)

        print(f"Epoch [{epoch+1}/{num_epochs}] Train: {avg_train_loss:.6f} | Test: {avg_test_loss:.6f}")

        early_stopping(avg_test_loss, model)
        if early_stopping.early_stop:
            print("Early stopping triggered. Restoring best model state.")
            break

    # Restore the best checkpoint on every exit path, not only on early stop;
    # otherwise a run that exhausts num_epochs keeps the last-epoch weights.
    if early_stopping.best_model_state is not None:
        model.load_state_dict(early_stopping.best_model_state)

    return train_losses, test_losses

In [7]:
# ====== 5. Evaluation Example (same skeleton) ======
def evaluate(model, test_dataset, num_samples=5):
    """Show ground truth next to the model prediction for the leading test samples.

    Args:
        model: module called as ``model(branch1, branch2, trunk)``.
        test_dataset: object exposing ``.tensors`` ordered as
            (branch1, branch2, trunk, target).
        num_samples: upper bound on how many samples (from index 0) are plotted.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.eval()
    model.to(device)
    n_shown = min(num_samples, len(test_dataset))
    for idx in range(n_shown):
        # idx:idx+1 keeps the leading batch axis on every tensor.
        sample = [tensor[idx:idx+1].to(device) for tensor in test_dataset.tensors]
        b1, b2, trunk, y_true = sample
        with torch.no_grad():
            y_pred = model(b1, b2, trunk)
        # Drop the batch axis and the trailing singleton axis for imshow.
        panels = (
            ('Ground Truth', y_true.cpu().numpy()[0, ..., 0]),
            ('Prediction', y_pred.cpu().numpy()[0, ..., 0]),
        )
        plt.figure(figsize=(8,3))
        for position, (title, image) in enumerate(panels, start=1):
            plt.subplot(1, 2, position)
            plt.imshow(image, cmap='viridis')
            plt.title(title)
            plt.axis('off')
        plt.suptitle(f"Sample #{idx}")
        plt.show()

In [None]:
# ====== 6. Main Entry (same skeleton) ======
if __name__ == "__main__":
    # 1) Load Data
    train_dataset, test_dataset = load_dataset()
    # Shapes (for reference):
    # train_dataset.tensors[0]: (N, 1, 64, 148)        # geometry
    # train_dataset.tensors[1]: (N, 2)                 # (Pe, Da)
    # train_dataset.tensors[2]: (N, 9472, 4)           # (x, y, t_norm, dist_inlet)
    # train_dataset.tensors[3]: (N, 64, 148, 1)        # target

    # 2) Build Model (paper-style names)
    model = PRT_DeepONet_Monod(
        nx=NX, ny=NY, trunk_in_dim=TRUNK_DIM, out_dim=OUT_DIM,
        branch2_in_dim=BRANCH2_DIM, cnn_blocks=5
    )

    # 3) Train
    train_losses, test_losses = train_model(
        model, train_dataset, test_dataset,
        num_epochs=NUM_EPOCHS, lr=LR, batch_size=BATCH_SIZE
    )

    # 4) Save right after training, before any plotting: a failure in the
    #    figures below (e.g. a missing display backend) must not cost the
    #    trained weights.
    os.makedirs(os.path.dirname(MODEL_SAVE_PATH) or ".", exist_ok=True)
    torch.save(model.state_dict(), MODEL_SAVE_PATH)
    print(f"Model saved to {MODEL_SAVE_PATH}")

    # 5) Evaluate
    evaluate(model, test_dataset)

    # 6) (Optional) Plot losses
    plt.figure()
    plt.plot(train_losses, label="Train")
    plt.plot(test_losses, label="Test")
    plt.xlabel("Epoch"); plt.ylabel("Loss")
    plt.legend(); plt.tight_layout(); plt.show()

Epoch [1/200] Train: 0.001885 | Test: 0.001330
Validation loss decreased. New best loss: 0.001330
Epoch [2/200] Train: 0.001039 | Test: 0.001186
Validation loss decreased. New best loss: 0.001186
Epoch [3/200] Train: 0.000420 | Test: 0.001122
Validation loss decreased. New best loss: 0.001122
Epoch [4/200] Train: 0.000288 | Test: 0.001095
Validation loss decreased. New best loss: 0.001095
Epoch [5/200] Train: 0.000235 | Test: 0.001044
Validation loss decreased. New best loss: 0.001044
Epoch [6/200] Train: 0.000207 | Test: 0.001065
No improvement in validation loss. Counter: 1/15
Epoch [7/200] Train: 0.000189 | Test: 0.001040
No improvement in validation loss. Counter: 2/15
Epoch [8/200] Train: 0.000175 | Test: 0.001035
No improvement in validation loss. Counter: 3/15
Epoch [9/200] Train: 0.000167 | Test: 0.001026
Validation loss decreased. New best loss: 0.001026
