In [6]:
import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm


class ReversePINN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(ReversePINN, self).__init__()
        # Define neural network layers
        self.layers = nn.ModuleList(
            [
                (
                    nn.Linear(input_size if i == 0 else hidden_size, hidden_size)
                    if i % 2 == 0
                    else nn.Tanh()
                )
                for i in range(20)
            ]
        )
        # Output layer for concentration (C)
        self.layers.append(nn.Linear(hidden_size, output_size))

        # Learnable parameters: D and Kd (initially set as trainable parameters)
        self.D = nn.Parameter(
            torch.tensor([10.0], dtype=torch.float32, requires_grad=True)
        )  # Diffusion coefficient
        self.Kd = nn.Parameter(
            torch.tensor([10.0], dtype=torch.float32, requires_grad=True)
        )  # Sorption coefficient

        # Loss function and optimizer
        self.loss = nn.MSELoss()
        self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-4)

    def forward(self, x):
        # Feedforward pass to get concentration profile
        for layer in self.layers:
            x = layer(x)
        return x

    def residual_loss(self, xtrain):
        g = xtrain.clone()
        g.requires_grad = True
        u_pred = self.forward(g)  # Predicted concentration

        # Compute gradients for the residuals
        u_x_t = torch.autograd.grad(
            u_pred,
            g,
            torch.ones([g.shape[0], 1]).to("cuda"),
            retain_graph=True,
            create_graph=True,
        )[0]
        u_xx_tt = torch.autograd.grad(
            u_x_t, g, torch.ones(g.shape).to("cuda"), create_graph=True
        )[0]

        # Diffusion term: u_xx (second spatial derivative)
        u_xx = u_xx_tt[:, [0]]

        # Physics-informed residual (diffusion-sorption-reaction equation)
        residual = u_x_t + self.D * u_xx - self.Kd * u_pred

        return self.loss(
            residual, torch.zeros_like(residual).to("cuda")
        )  # Residual should be zero

    def total_loss(self, xtrain, utrain):
        # Data loss: Compare predicted concentration to true concentration
        data_loss = self.loss(self.forward(xtrain), utrain)

        # Residual loss: Physics-informed
        residual_loss = self.residual_loss(xtrain)

        return data_loss + residual_loss

    def train_model(self, xtrain, utrain, epochs=1000):
        for epoch in tqdm(range(epochs)):
            self.optimizer.zero_grad()
            loss = self.total_loss(xtrain, utrain)
            loss.backward()
            self.optimizer.step()
            if epoch % 100 == 0:
                print(
                    f"Epoch {epoch}, Loss {loss.item()}, D {self.D.item()}, Kd {self.Kd.item()}"
                )

In [2]:
import h5py

Nx = 100
Nt = 201

with h5py.File("diff_sorp.h5", "r") as f:
    dtrue = f["1"]["D"][()]
    ktrue = f["1"]["Kd"][()]
    C_all = f["1"]["data"][:]
    f.close()
# Create input (x, t) and output (C) tensors
x = np.linspace(0, 1, Nx)
t = np.linspace(0, 1, Nt)
X, T = np.meshgrid(x, t)
X_train = torch.tensor(
    np.vstack((X.flatten(), T.flatten())).T, dtype=torch.float32
).cuda()
C_train = torch.tensor(C_all.flatten(), dtype=torch.float32).cuda()
X_train.shape, C_train.shape

(torch.Size([20100, 2]), torch.Size([20100]))

In [3]:
dtrue, ktrue

(82.44376791378382, 92.0820718580985)

In [4]:
idx = np.random.choice(Nx * Nt, 10000, replace=False)
X_train = X_train[idx]
C_train = C_train[idx]
X_train.shape, C_train.shape

(torch.Size([10000, 2]), torch.Size([10000]))

In [7]:
model = ReversePINN(2, 20, 1).cuda()
model.train_model(X_train, C_train, epochs=10000)

  0%|          | 12/10000 [00:00<02:51, 58.37it/s]

Epoch 0, Loss 3.148857593536377, D 9.999899864196777, Kd 9.999899864196777


  1%|          | 110/10000 [00:01<02:40, 61.44it/s]

Epoch 100, Loss 0.186541348695755, D 10.0056734085083, Kd 9.994372367858887


  2%|▏         | 208/10000 [00:03<02:33, 63.69it/s]

Epoch 200, Loss 0.15570379793643951, D 10.005768775939941, Kd 9.994296073913574


  3%|▎         | 313/10000 [00:04<02:34, 62.62it/s]

Epoch 300, Loss 0.15547527372837067, D 10.003291130065918, Kd 9.994296073913574


  4%|▍         | 411/10000 [00:06<02:30, 63.53it/s]

Epoch 400, Loss 0.15524114668369293, D 10.000638961791992, Kd 9.994296073913574


  5%|▌         | 509/10000 [00:08<02:31, 62.57it/s]

Epoch 500, Loss 0.15500888228416443, D 9.997982025146484, Kd 9.994296073913574


  6%|▌         | 607/10000 [00:09<02:30, 62.51it/s]

Epoch 600, Loss 0.1547805517911911, D 9.995450973510742, Kd 9.994296073913574


  7%|▋         | 712/10000 [00:11<02:24, 64.24it/s]

Epoch 700, Loss 0.15455478429794312, D 9.993144989013672, Kd 9.994296073913574


  8%|▊         | 810/10000 [00:12<02:31, 60.72it/s]

Epoch 800, Loss 0.15432755649089813, D 9.991159439086914, Kd 9.994296073913574


  9%|▉         | 907/10000 [00:14<02:25, 62.44it/s]

Epoch 900, Loss 0.15409229695796967, D 9.989611625671387, Kd 9.994296073913574


 10%|█         | 1011/10000 [00:16<02:25, 61.75it/s]

Epoch 1000, Loss 0.15383917093276978, D 9.98864459991455, Kd 9.994296073913574


 11%|█         | 1109/10000 [00:17<02:24, 61.50it/s]

Epoch 1100, Loss 0.1535528153181076, D 9.988454818725586, Kd 9.994296073913574


 12%|█▏        | 1210/10000 [00:19<02:21, 61.97it/s]

Epoch 1200, Loss 0.15320777893066406, D 9.98931884765625, Kd 9.994296073913574


 13%|█▎        | 1308/10000 [00:20<02:14, 64.40it/s]

Epoch 1300, Loss 0.15275883674621582, D 9.991662979125977, Kd 9.994296073913574


 14%|█▍        | 1413/10000 [00:22<02:17, 62.46it/s]

Epoch 1400, Loss 0.15211786329746246, D 9.996142387390137, Kd 9.994255065917969


 15%|█▌        | 1511/10000 [00:24<02:15, 62.69it/s]

Epoch 1500, Loss 0.15110065042972565, D 10.003744125366211, Kd 9.994159698486328


 16%|█▌        | 1611/10000 [00:25<02:23, 58.63it/s]

Epoch 1600, Loss 0.14932839572429657, D 10.015676498413086, Kd 9.994064331054688


 17%|█▋        | 1711/10000 [00:27<02:21, 58.78it/s]

Epoch 1700, Loss 0.146308034658432, D 10.032292366027832, Kd 9.99390983581543


 18%|█▊        | 1811/10000 [00:29<02:25, 56.13it/s]

Epoch 1800, Loss 0.1422041803598404, D 10.051521301269531, Kd 9.99366283416748


 19%|█▉        | 1914/10000 [00:31<02:09, 62.67it/s]

Epoch 1900, Loss 0.1367885172367096, D 10.072051048278809, Kd 9.993273735046387


 20%|██        | 2012/10000 [00:32<02:06, 63.19it/s]

Epoch 2000, Loss 0.12886011600494385, D 10.094144821166992, Kd 9.992734909057617


 21%|██        | 2110/10000 [00:34<02:03, 64.06it/s]

Epoch 2100, Loss 0.11626271158456802, D 10.117528915405273, Kd 9.99202823638916


 22%|██▏       | 2208/10000 [00:35<02:00, 64.44it/s]

Epoch 2200, Loss 0.10063805431127548, D 10.13970947265625, Kd 9.991294860839844


 23%|██▎       | 2313/10000 [00:37<01:59, 64.36it/s]

Epoch 2300, Loss 0.09514500200748444, D 10.152056694030762, Kd 9.990806579589844


 24%|██▍       | 2411/10000 [00:38<02:02, 62.20it/s]

Epoch 2400, Loss 0.09436696022748947, D 10.157161712646484, Kd 9.990519523620605


 25%|██▌       | 2509/10000 [00:40<01:56, 64.24it/s]

Epoch 2500, Loss 0.09412041306495667, D 10.160111427307129, Kd 9.990326881408691


 26%|██▌       | 2614/10000 [00:42<01:54, 64.70it/s]

Epoch 2600, Loss 0.09394551813602448, D 10.162301063537598, Kd 9.990135192871094


 27%|██▋       | 2712/10000 [00:43<01:52, 64.87it/s]

Epoch 2700, Loss 0.09382367879152298, D 10.164353370666504, Kd 9.989944458007812


 28%|██▊       | 2810/10000 [00:45<01:55, 62.14it/s]

Epoch 2800, Loss 0.09373323619365692, D 10.166158676147461, Kd 9.989765167236328


 29%|██▉       | 2908/10000 [00:46<01:50, 64.26it/s]

Epoch 2900, Loss 0.09365847706794739, D 10.167877197265625, Kd 9.989574432373047


 30%|███       | 3013/10000 [00:48<01:51, 62.78it/s]

Epoch 3000, Loss 0.09359855949878693, D 10.169474601745605, Kd 9.989402770996094


 31%|███       | 3111/10000 [00:49<01:49, 62.94it/s]

Epoch 3100, Loss 0.09354732930660248, D 10.170990943908691, Kd 9.989212036132812


 32%|███▏      | 3209/10000 [00:51<01:46, 64.04it/s]

Epoch 3200, Loss 0.09350716322660446, D 10.172483444213867, Kd 9.989031791687012


 33%|███▎      | 3307/10000 [00:53<01:47, 62.32it/s]

Epoch 3300, Loss 0.09346488863229752, D 10.173826217651367, Kd 9.98885440826416


 34%|███▍      | 3412/10000 [00:54<01:41, 64.67it/s]

Epoch 3400, Loss 0.09344291687011719, D 10.17516803741455, Kd 9.988666534423828


 35%|███▌      | 3510/10000 [00:56<01:47, 60.47it/s]

Epoch 3500, Loss 0.09340085834264755, D 10.176380157470703, Kd 9.988495826721191


 36%|███▌      | 3607/10000 [00:57<01:44, 61.42it/s]

Epoch 3600, Loss 0.09337326139211655, D 10.177613258361816, Kd 9.98830509185791


 37%|███▋      | 3712/10000 [00:59<01:40, 62.37it/s]

Epoch 3700, Loss 0.09335515648126602, D 10.178808212280273, Kd 9.988127708435059


 38%|███▊      | 3810/10000 [01:01<01:39, 62.10it/s]

Epoch 3800, Loss 0.093325175344944, D 10.17993450164795, Kd 9.987943649291992


 39%|███▉      | 3908/10000 [01:02<01:35, 64.09it/s]

Epoch 3900, Loss 0.09330771118402481, D 10.181077003479004, Kd 9.987754821777344


 40%|████      | 4013/10000 [01:04<01:36, 61.99it/s]

Epoch 4000, Loss 0.09328452497720718, D 10.182066917419434, Kd 9.987571716308594


 41%|████      | 4111/10000 [01:05<01:33, 62.83it/s]

Epoch 4100, Loss 0.09326624125242233, D 10.183025360107422, Kd 9.987380981445312


 42%|████▏     | 4209/10000 [01:07<01:34, 61.34it/s]

Epoch 4200, Loss 0.09324843436479568, D 10.184074401855469, Kd 9.987190246582031


 43%|████▎     | 4313/10000 [01:09<01:32, 61.27it/s]

Epoch 4300, Loss 0.09323251992464066, D 10.185027122497559, Kd 9.987004280090332


 44%|████▍     | 4411/10000 [01:10<01:30, 62.03it/s]

Epoch 4400, Loss 0.09321688860654831, D 10.185977935791016, Kd 9.98681354522705


 45%|████▌     | 4503/10000 [01:12<02:06, 43.29it/s]

Epoch 4500, Loss 0.09321710467338562, D 10.186904907226562, Kd 9.986616134643555


 46%|████▌     | 4606/10000 [01:15<02:10, 41.29it/s]

Epoch 4600, Loss 0.09318885952234268, D 10.187758445739746, Kd 9.98642349243164


 47%|████▋     | 4714/10000 [01:17<01:22, 63.82it/s]

Epoch 4700, Loss 0.09317554533481598, D 10.188616752624512, Kd 9.98623275756836


 48%|████▊     | 4812/10000 [01:18<01:21, 63.91it/s]

Epoch 4800, Loss 0.09316359460353851, D 10.189457893371582, Kd 9.986026763916016


 49%|████▉     | 4910/10000 [01:20<01:26, 58.76it/s]

Epoch 4900, Loss 0.09315139800310135, D 10.190242767333984, Kd 9.985832214355469


 50%|█████     | 5013/10000 [01:21<01:16, 65.01it/s]

Epoch 5000, Loss 0.09313968569040298, D 10.19108772277832, Kd 9.985641479492188


 51%|█████     | 5111/10000 [01:23<01:20, 60.87it/s]

Epoch 5100, Loss 0.09312926977872849, D 10.191861152648926, Kd 9.985416412353516


 52%|█████▏    | 5208/10000 [01:25<01:16, 62.71it/s]

Epoch 5200, Loss 0.09311860054731369, D 10.192623138427734, Kd 9.985225677490234


 53%|█████▎    | 5313/10000 [01:26<01:17, 60.45it/s]

Epoch 5300, Loss 0.09315194189548492, D 10.193382263183594, Kd 9.985011100769043


 54%|█████▍    | 5410/10000 [01:28<01:14, 62.02it/s]

Epoch 5400, Loss 0.09309892356395721, D 10.194093704223633, Kd 9.984779357910156


 55%|█████▌    | 5508/10000 [01:29<01:11, 62.48it/s]

Epoch 5500, Loss 0.09308934956789017, D 10.19484806060791, Kd 9.98453140258789


 56%|█████▌    | 5613/10000 [01:31<01:11, 61.56it/s]

Epoch 5600, Loss 0.09308058768510818, D 10.19554328918457, Kd 9.984274864196777


 57%|█████▋    | 5711/10000 [01:33<01:08, 62.44it/s]

Epoch 5700, Loss 0.0930715873837471, D 10.196210861206055, Kd 9.983988761901855


 58%|█████▊    | 5809/10000 [01:34<01:08, 61.28it/s]

Epoch 5800, Loss 0.09306368231773376, D 10.196879386901855, Kd 9.983711242675781


 59%|█████▉    | 5911/10000 [01:36<01:08, 59.68it/s]

Epoch 5900, Loss 0.09305491298437119, D 10.197535514831543, Kd 9.983428955078125


 60%|██████    | 6009/10000 [01:38<01:01, 65.14it/s]

Epoch 6000, Loss 0.09305337816476822, D 10.198187828063965, Kd 9.983138084411621


 61%|██████    | 6107/10000 [01:39<01:01, 63.65it/s]

Epoch 6100, Loss 0.09303908795118332, D 10.198807716369629, Kd 9.982852935791016


 62%|██████▏   | 6211/10000 [01:41<01:02, 60.98it/s]

Epoch 6200, Loss 0.0932098999619484, D 10.19944953918457, Kd 9.982564926147461


 63%|██████▎   | 6309/10000 [01:42<01:04, 57.65it/s]

Epoch 6300, Loss 0.09302407503128052, D 10.200043678283691, Kd 9.982259750366211


 64%|██████▍   | 6411/10000 [01:44<00:59, 60.75it/s]

Epoch 6400, Loss 0.09301663190126419, D 10.200615882873535, Kd 9.981973648071289


 65%|██████▌   | 6508/10000 [01:46<00:57, 60.33it/s]

Epoch 6500, Loss 0.0930098220705986, D 10.201191902160645, Kd 9.981649398803711


 66%|██████▌   | 6613/10000 [01:47<00:54, 61.70it/s]

Epoch 6600, Loss 0.09300263971090317, D 10.20175838470459, Kd 9.981359481811523


 67%|██████▋   | 6711/10000 [01:49<00:52, 63.19it/s]

Epoch 6700, Loss 0.0929984301328659, D 10.202322006225586, Kd 9.98104190826416


 68%|██████▊   | 6807/10000 [01:51<00:59, 53.68it/s]

Epoch 6800, Loss 0.09298911690711975, D 10.202858924865723, Kd 9.980722427368164


 69%|██████▉   | 6908/10000 [01:52<00:48, 63.96it/s]

Epoch 6900, Loss 0.09298302978277206, D 10.203429222106934, Kd 9.980422973632812


 70%|███████   | 7006/10000 [01:54<00:49, 60.30it/s]

Epoch 7000, Loss 0.09297610074281693, D 10.203953742980957, Kd 9.980064392089844


 71%|███████   | 7108/10000 [01:56<00:45, 63.78it/s]

Epoch 7100, Loss 0.09296960383653641, D 10.204434394836426, Kd 9.979697227478027


 72%|███████▏  | 7213/10000 [01:57<00:43, 64.11it/s]

Epoch 7200, Loss 0.09296464174985886, D 10.204944610595703, Kd 9.979318618774414


 73%|███████▎  | 7306/10000 [01:59<00:48, 55.21it/s]

Epoch 7300, Loss 0.09295728802680969, D 10.205424308776855, Kd 9.978950500488281


 74%|███████▍  | 7412/10000 [02:01<00:42, 60.90it/s]

Epoch 7400, Loss 0.09295470267534256, D 10.205913543701172, Kd 9.978557586669922


 75%|███████▌  | 7512/10000 [02:02<00:39, 63.25it/s]

Epoch 7500, Loss 0.09294535219669342, D 10.206387519836426, Kd 9.97818660736084


 76%|███████▌  | 7610/10000 [02:04<00:39, 60.91it/s]

Epoch 7600, Loss 0.0930410623550415, D 10.206857681274414, Kd 9.977798461914062


 77%|███████▋  | 7708/10000 [02:05<00:36, 62.40it/s]

Epoch 7700, Loss 0.09293389320373535, D 10.207327842712402, Kd 9.977411270141602


 78%|███████▊  | 7813/10000 [02:07<00:35, 61.83it/s]

Epoch 7800, Loss 0.09292813390493393, D 10.207804679870605, Kd 9.977029800415039


 79%|███████▉  | 7896/10000 [02:09<00:34, 61.20it/s]


KeyboardInterrupt: 