# In [1]:
# Translated PINN class from TensorFlow 1.x to PyTorch

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import grad
import numpy as np
import os
from typing import List, Dict

class FeedForwardNN(nn.Module):
    """Fully connected network with tanh activations on the hidden layers.

    ``layers`` lists every layer width from input to output, e.g.
    ``[3, 20, 20, 1]``. Each ``nn.Linear`` except the last is followed by an
    ``nn.Tanh``; the final layer is purely linear.
    """

    def __init__(self, layers: List[int]):
        super().__init__()
        stack = []
        # Pair each width with its successor for all but the output layer.
        for fan_in, fan_out in zip(layers[:-2], layers[1:-1]):
            stack.extend((nn.Linear(fan_in, fan_out), nn.Tanh()))
        # Output layer: no activation.
        stack.append(nn.Linear(layers[-2], layers[-1]))
        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the layer stack to ``x`` (last dimension must be ``layers[0]``)."""
        out = self.model(x)
        return out


class PINN:
    """Physics-informed neural network for the fields v1..v5.

    ``v1`` and ``v5`` are fitted to observations with an MSE data loss. When
    ``use_pde`` is true, three further networks (``v2``, ``v3``, ``v4``) are
    trained jointly and the mean squared PDE residuals ``f_v1`` and ``f_v5``
    are added to the loss. The code names ``log(v1)`` ``lnn`` and
    ``log(v5)`` ``lnTe``, which suggests v1 is a density and v5 an electron
    temperature (NOTE(review): confirm against the original TF1 model).

    Every network receives ``(x, y, t)`` rescaled to roughly ``[-1, 1]``
    using the per-column min/max of the training coordinates.
    """

    def __init__(self, x, y, t, v1, v5, layers, diff_norms, use_pde=True, device='cpu'):
        """Build the networks, normalization bounds and Adam optimizer.

        Args:
            x, y, t: Training coordinates. They are concatenated along dim 1,
                so each is expected to be a column array of shape (N, 1).
            v1, v5: Observed values compared against the v1/v5 network outputs
                (presumably shape (N, 1) as well).
            layers: Layer widths passed to ``FeedForwardNN`` for every network.
            diff_norms: Mapping with keys ``'DiffX_norm'`` and ``'DiffY_norm'``
                that scales the fourth-derivative diffusion terms.
            use_pde: If true, create the v2/v3/v4 networks and add the PDE
                residual losses.
            device: Torch device used for all tensors and networks.
        """
        self.device = device
        self.use_pde = use_pde
        self.diff_norms = diff_norms
        self.loss_history = {'v1': [], 'v5': [], 'f1': [], 'f5': []} if use_pde else {'v1': [], 'v5': []}
        # Detached per-term losses from the latest loss_function() call;
        # train_step() appends them to loss_history.
        self._last_loss_terms = {}

        # Coordinates track gradients so autograd can form the PDE derivatives.
        self.x = torch.tensor(x, dtype=torch.float32).to(device).requires_grad_()
        self.y = torch.tensor(y, dtype=torch.float32).to(device).requires_grad_()
        self.t = torch.tensor(t, dtype=torch.float32).to(device).requires_grad_()
        self.v1_true = torch.tensor(v1, dtype=torch.float32).to(device)
        self.v5_true = torch.tensor(v5, dtype=torch.float32).to(device)

        # The normalization bounds are constants. Computing them from the
        # grad-tracking coordinates would add a spurious coordinate -> min/max
        # path to every derivative w.r.t. x, y, t. It would also break the
        # second training step, because the first backward() frees the saved
        # tensors of that shared min/max graph.
        with torch.no_grad():
            X = torch.cat([self.x, self.y, self.t], dim=1)
            self.lb = X.min(0).values
            self.ub = X.max(0).values

        # Networks (creation order kept stable so seeded initialization is reproducible).
        self.net_v1 = FeedForwardNN(layers).to(device)
        self.net_v5 = FeedForwardNN(layers).to(device)
        if use_pde:
            self.net_v2 = FeedForwardNN(layers).to(device)
            self.net_v3 = FeedForwardNN(layers).to(device)
            self.net_v4 = FeedForwardNN(layers).to(device)

        self.params = list(self.net_v1.parameters()) + list(self.net_v5.parameters())
        if use_pde:
            self.params += list(self.net_v2.parameters())
            self.params += list(self.net_v3.parameters())
            self.params += list(self.net_v4.parameters())

        self.optimizer = optim.Adam(self.params, lr=1e-3)

    def normalize(self, x):
        """Map ``x`` column-wise from ``[lb, ub]`` to about ``[-1, 1]``.

        The ``1e-6`` guards against division by zero when a coordinate is constant.
        """
        return 2.0 * (x - self.lb) / (self.ub - self.lb + 1e-6) - 1.0

    def forward_net(self, net, x, y, t):
        """Evaluate ``net`` on the normalized concatenation of ``x``, ``y``, ``t``."""
        xyt = torch.cat([x, y, t], dim=1)
        xyt_norm = self.normalize(xyt)
        return net(xyt_norm)

    def compute_derivative(self, f, x, order=1):
        """Return the ``order``-th derivative of ``f`` with respect to ``x``.

        Uses repeated ``torch.autograd.grad`` with all-ones ``grad_outputs``.
        This yields per-point derivatives only when each row of ``f`` depends
        solely on the same row of ``x``, as for the row-wise MLPs here.
        ``create_graph=True`` keeps the result differentiable, so it can be
        differentiated again and back-propagated through the loss.
        """
        for _ in range(order):
            f = grad(f, x, torch.ones_like(f), create_graph=True)[0]
        return f

    def compute_source_terms(self, x, v1, v5):
        """Return the source terms ``(S_n, S_Ee)``, each of shape (N, 1).

        Both start from the same Gaussian in ``x`` (centre ``X_SRC``, width
        ``SIG_SRC``) times its amplitude. Then, per point:

        * values not above 0.01 are replaced by 0.001;
        * points with ``x <= X_SRC`` are set to 0.5;
        * ``S_n`` is zeroed where ``v1 > 5.0`` and ``S_Ee`` where ``v5 > 1.0``.

        Only column 0 of ``x``, ``v1`` and ``v5`` is used.
        """
        X_SRC = 0.5
        SIG_SRC = 0.1
        N_SRC_A = 1.0
        ENER_SRC_A = 1.0

        # Both sources share the Gaussian shape and the x > X_SRC mask.
        profile = torch.exp(-((x[:, 0] - X_SRC) ** 2) / (2.0 * SIG_SRC ** 2))
        right_of_source = x[:, 0] > X_SRC

        def _shape_source(amplitude, cutoff_mask):
            # Floor, left-side override, then cutoff: the same order as the TF1 port.
            s = amplitude * profile
            s = torch.where(s > 0.01, s, torch.full_like(s, 0.001))
            s = torch.where(right_of_source, s, torch.full_like(s, 0.5))
            s = torch.where(cutoff_mask, torch.zeros_like(s), s)
            return s.unsqueeze(1)

        S_n = _shape_source(N_SRC_A, v1[:, 0] > 5.0)
        S_Ee = _shape_source(ENER_SRC_A, v5[:, 0] > 1.0)
        return S_n, S_Ee

    def loss_function(self):
        """Return the total training loss as a scalar tensor.

        The data loss is the MSE of the v1 and v5 networks against the
        observations. With ``use_pde``, the mean squared PDE residuals
        ``f_v1`` and ``f_v5`` are added. The individual terms are stored,
        detached, in ``self._last_loss_terms`` under the ``loss_history`` keys.
        """
        x, y, t = self.x, self.y, self.t
        v1 = self.forward_net(self.net_v1, x, y, t)
        v5 = self.forward_net(self.net_v5, x, y, t)

        loss_v1 = torch.mean((v1 - self.v1_true) ** 2)
        loss_v5 = torch.mean((v5 - self.v5_true) ** 2)

        if not self.use_pde:
            self._last_loss_terms = {'v1': loss_v1.detach(), 'v5': loss_v5.detach()}
            return loss_v1 + loss_v5

        v2 = self.forward_net(self.net_v2, x, y, t)
        v3 = self.forward_net(self.net_v3, x, y, t)
        v4 = self.forward_net(self.net_v4, x, y, t)

        # First derivatives used by the advection (Poisson-bracket-like) terms.
        v1_t = self.compute_derivative(v1, t)
        v1_x = self.compute_derivative(v1, x)
        v1_y = self.compute_derivative(v1, y)
        v5_t = self.compute_derivative(v5, t)
        v5_x = self.compute_derivative(v5, x)
        v5_y = self.compute_derivative(v5, y)
        v2_x = self.compute_derivative(v2, x)
        v2_y = self.compute_derivative(v2, y)

        # Physics constants (placeholders for now)
        MINOR_RADIUS = 0.22
        TAU_T = 1.0
        EPS_R = 1.0
        ALPHA_D = 1.0
        EPS_V = 1.0
        ETA = 1.0
        MASS_RATIO = 1.0

        B = (0.22 + 0.68) / (0.68 + 0.22 + MINOR_RADIUS * x)
        pe = v1 * v5
        pe_y = self.compute_derivative(pe, y)
        jp = v1 * ((TAU_T ** 0.5) * v4 - v3)

        # 1e-6 keeps the logarithms finite when v1 or v5 reach zero.
        lnn = torch.log(v1 + 1e-6)
        lnTe = torch.log(v5 + 1e-6)

        # Fourth-order derivatives for the (hyper-)diffusion terms.
        lnn_xxxx = self.compute_derivative(lnn, x, 4)
        lnn_yyyy = self.compute_derivative(lnn, y, 4)
        lnTe_xxxx = self.compute_derivative(lnTe, x, 4)
        lnTe_yyyy = self.compute_derivative(lnTe, y, 4)

        D_lnn = -((50. / self.diff_norms['DiffX_norm']) ** 2.) * lnn_xxxx + \
                -((50. / self.diff_norms['DiffY_norm']) ** 2.) * lnn_yyyy

        D_lnTe = -((50. / self.diff_norms['DiffX_norm']) ** 2.) * lnTe_xxxx + \
                 -((50. / self.diff_norms['DiffY_norm']) ** 2.) * lnTe_yyyy

        S_n, S_Ee = self.compute_source_terms(x, v1, v5)

        # Residuals; the 1e-6 terms guard divisions by v1, v5 and pe.
        f_v1 = v1_t + (1. / B) * (v2_y * v1_x - v2_x * v1_y) - (-EPS_R * (v1 * v2_y - ALPHA_D * pe_y) + S_n + v1 * D_lnn)
        f_v5 = v5_t + (1. / B) * (v2_y * v5_x - v2_x * v5_y) - v5 * (
            5. * EPS_R * ALPHA_D * v5_y / 3. +
            (2. / 3.) * (-EPS_R * (v2_y - ALPHA_D * pe_y / (v1 + 1e-6)) +
                        (1. / (v1 + 1e-6)) * (ETA * jp * jp / (v5 * MASS_RATIO + 1e-6))) +
            (2. / (3. * pe + 1e-6)) * S_Ee + D_lnTe)

        loss_f1 = torch.mean(f_v1 ** 2)
        loss_f5 = torch.mean(f_v5 ** 2)

        self._last_loss_terms = {
            'v1': loss_v1.detach(),
            'v5': loss_v5.detach(),
            'f1': loss_f1.detach(),
            'f5': loss_f5.detach(),
        }
        return loss_v1 + loss_v5 + loss_f1 + loss_f5

    def train_step(self):
        """Run one Adam step on the total loss and return its value as a float.

        The per-term losses of this step are appended to ``self.loss_history``.
        """
        self.optimizer.zero_grad()
        loss = self.loss_function()
        loss.backward()
        self.optimizer.step()
        # setdefault tolerates a loss_history loaded from an older checkpoint.
        for key, value in self._last_loss_terms.items():
            self.loss_history.setdefault(key, []).append(value.item())
        return loss.item()

    def predict(self, x_star, y_star, t_star):
        """Evaluate the networks at new points and return numpy arrays.

        Returns a dict with keys ``'v1'`` and ``'v5'``, plus ``'v2'``,
        ``'v3'`` and ``'v4'`` when ``use_pde`` is true. The inputs are
        normalized with the training bounds.
        """
        # No derivatives are needed here, so the inputs do not track gradients.
        x_star = torch.tensor(x_star, dtype=torch.float32).to(self.device)
        y_star = torch.tensor(y_star, dtype=torch.float32).to(self.device)
        t_star = torch.tensor(t_star, dtype=torch.float32).to(self.device)

        with torch.no_grad():
            v1 = self.forward_net(self.net_v1, x_star, y_star, t_star)
            v5 = self.forward_net(self.net_v5, x_star, y_star, t_star)
            result = {'v1': v1.cpu().numpy(), 'v5': v5.cpu().numpy()}
            if self.use_pde:
                v2 = self.forward_net(self.net_v2, x_star, y_star, t_star)
                v3 = self.forward_net(self.net_v3, x_star, y_star, t_star)
                v4 = self.forward_net(self.net_v4, x_star, y_star, t_star)
                result.update({
                    'v2': v2.cpu().numpy(),
                    'v3': v3.cpu().numpy(),
                    'v4': v4.cpu().numpy()
                })
        return result

    def save(self, path):
        """Write network weights, ``diff_norms`` and ``loss_history`` to ``path/checkpoint.pth``.

        The v2/v3/v4 entries are ``None`` when ``use_pde`` is false. The
        optimizer state is not saved.
        """
        os.makedirs(path, exist_ok=True)
        torch.save({
            'net_v1': self.net_v1.state_dict(),
            'net_v5': self.net_v5.state_dict(),
            'net_v2': self.net_v2.state_dict() if self.use_pde else None,
            'net_v3': self.net_v3.state_dict() if self.use_pde else None,
            'net_v4': self.net_v4.state_dict() if self.use_pde else None,
            'diff_norms': self.diff_norms,
            'loss_history': self.loss_history
        }, os.path.join(path, 'checkpoint.pth'))

    def load(self, path):
        """Restore the state written by ``save`` from ``path/checkpoint.pth``.

        Tensors are mapped onto ``self.device``, so a checkpoint saved on a
        GPU can be loaded on a CPU-only host.
        """
        # NOTE(review): newer torch versions default to weights_only=True.
        # That is fine for plain floats and lists, but not for e.g. numpy
        # scalars stored in diff_norms; confirm what callers pass.
        checkpoint = torch.load(os.path.join(path, 'checkpoint.pth'), map_location=self.device)
        self.net_v1.load_state_dict(checkpoint['net_v1'])
        self.net_v5.load_state_dict(checkpoint['net_v5'])
        if self.use_pde:
            self.net_v2.load_state_dict(checkpoint['net_v2'])
            self.net_v3.load_state_dict(checkpoint['net_v3'])
            self.net_v4.load_state_dict(checkpoint['net_v4'])
        self.diff_norms = checkpoint['diff_norms']
        self.loss_history = checkpoint['loss_history']