In [None]:
import os
import torch
import torch.nn as nn
import numpy as np
import time
import scipy.io

# Compute device: CUDA when PyTorch reports it as available, otherwise CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")
# Directory that receives the saved model weights and training history.
os.makedirs('./results_siren/', exist_ok=True)


# =============================
# SIREN (as in INR.pdf)
# =============================
class SineLayer(nn.Module):
    """Linear map followed by a scaled sine: ``sin(omega_0 * (W x + b))``.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        bias: whether the linear map carries an additive bias.
        is_first: selects the weight-initialisation range (see ``init_weights``).
        omega_0: factor multiplying the linear output before the sine.
    """

    def __init__(self, in_features, out_features, bias=True, is_first=False, omega_0=30):
        super().__init__()
        self.omega_0 = omega_0
        self.is_first = is_first
        self.in_features = in_features
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        self.init_weights()

    def init_weights(self):
        """Redraw the linear weights uniformly from ``[-bound, bound]``.

        ``bound`` is ``1 / in_features`` for a first layer and
        ``sqrt(6 / in_features) / omega_0`` otherwise. The bias keeps the
        default ``nn.Linear`` initialisation.
        """
        if self.is_first:
            bound = 1 / self.in_features
        else:
            bound = np.sqrt(6 / self.in_features) / self.omega_0
        # In-place re-initialisation must not be tracked by autograd.
        with torch.no_grad():
            self.linear.weight.uniform_(-bound, bound)

    def forward(self, x):
        """Return ``sin(omega_0 * linear(x))``."""
        pre_activation = self.linear(x)
        return torch.sin(self.omega_0 * pre_activation)


class Siren(nn.Module):
    """Stack of ``SineLayer`` modules with an optional plain linear output layer.

    Args:
        in_features: width of the input coordinates.
        hidden_features: width of every hidden layer.
        hidden_layers: number of hidden ``SineLayer`` modules placed after the
            first layer.
        out_features: width of the network output.
        outermost_linear: if True the last layer is an ``nn.Linear`` (no sine),
            otherwise it is another ``SineLayer``.
        first_omega_0: ``omega_0`` of the first layer.
        hidden_omega_0: ``omega_0`` of all other sine layers; also used to
            scale the initial weights of the linear output layer.
    """

    def __init__(self, in_features, hidden_features, hidden_layers, out_features,
                 outermost_linear=True, first_omega_0=30, hidden_omega_0=30.):
        super().__init__()

        # Layers are created in network order: first, hidden, output.
        stack = [SineLayer(in_features, hidden_features,
                           is_first=True, omega_0=first_omega_0)]
        for _ in range(hidden_layers):
            stack.append(SineLayer(hidden_features, hidden_features,
                                   is_first=False, omega_0=hidden_omega_0))

        if outermost_linear:
            head = nn.Linear(hidden_features, out_features)
            bound = np.sqrt(6 / hidden_features) / hidden_omega_0
            # Same weight range as the hidden sine layers; bias starts at zero.
            with torch.no_grad():
                head.weight.uniform_(-bound, bound)
                head.bias.fill_(0.0)
        else:
            head = SineLayer(hidden_features, out_features,
                             is_first=False, omega_0=hidden_omega_0)
        stack.append(head)

        self.net = nn.ModuleList(stack)

    def forward(self, coords):
        """Apply all layers in order to ``coords`` and return the result.

        The network does not enable gradients on ``coords`` itself; callers
        that need derivatives w.r.t. the input must set ``requires_grad``
        before calling.
        """
        out = coords
        for layer in self.net:
            out = layer(out)
        return out


# =============================
# Network v (remains an MLP)
# =============================
class NetV(nn.Module):
    """Fully connected network with mixed activations.

    The input layer is followed by ``tanh``; hidden layers alternate between
    ``softplus`` (even index) and ``sin`` (odd index); the output layer is
    linear.

    Args:
        input_dim: width of the input.
        output_dim: width of the output.
        layers: number of hidden ``hidden_size x hidden_size`` layers.
        hidden_size: width of the hidden layers.
    """

    def __init__(self, input_dim, output_dim, layers, hidden_size):
        super().__init__()
        self.layers = layers
        self.hidden_size = hidden_size
        self.input_layer = nn.Linear(input_dim, hidden_size)
        hidden_stack = []
        for _ in range(layers):
            hidden_stack.append(nn.Linear(hidden_size, hidden_size))
        self.h_layers = nn.ModuleList(hidden_stack)
        self.output_layer = nn.Linear(hidden_size, output_dim)
        self.activation_tanh = nn.Tanh()
        self.activation_softplus = nn.Softplus()

    def forward(self, x):
        """Return the network output for input ``x``."""
        hidden = self.activation_tanh(self.input_layer(x))
        for index, layer in enumerate(self.h_layers):
            pre_activation = layer(hidden)
            is_even = index % 2 == 0
            # Even hidden layers use softplus, odd ones use sine.
            hidden = self.activation_softplus(pre_activation) if is_even else torch.sin(pre_activation)
        return self.output_layer(hidden)


# =============================
# Main solver (SIREN-WAN)
# =============================
class ParabolicWanPDESolver(nn.Module):
    def __init__(self, dim, N_x, N_t, N_bd, file_path,
                 beta_int=100.0, beta_intw=500.0, beta_bd=1000.0,
                 v_step=1, v_rate=0.015, u_step=1, u_rate=0.0001,
                 u_hidden_dim=64, u_hidden_layers=4,
                 iteration=20, device='cuda'):
        """Store the problem setup and build the two networks and optimizers.

        Args:
            dim: number of spatial coordinates; networks take ``dim + 1`` inputs.
            N_x: number of spatial sample points per training batch.
            N_t: number of time levels per training batch.
            N_bd: number of points per boundary face / time slice.
            file_path: stored as ``self.dir``.
            beta_int, beta_intw, beta_bd: weights of the loss terms.
            v_step, u_step: optimizer steps per iteration for ``net_v`` / ``net_u``.
            v_rate, u_rate: learning rates for ``net_v`` (Adagrad) / ``net_u`` (Adam).
            u_hidden_dim, u_hidden_layers: hidden width / depth of ``net_u``.
            iteration: number of outer training iterations.
            device: device the networks and sampled tensors are placed on.
        """
        super().__init__()

        # Run configuration
        self.device = device
        self.dim = dim
        self.iteration = iteration
        self.dir = file_path

        # Time interval [t0, t1] and spatial interval [low, up] per coordinate
        self.t0 = -1.0
        self.t1 = 1.0
        self.low = -1.0
        self.up = 1.0

        # Parameters of the exact solution
        self.la = np.pi / 2
        self.pho = 2.0
        self.mu = self.la**2 - 1

        # Sample sizes
        self.N_x, self.N_t, self.N_bd = N_x, N_t, N_bd

        # Loss weights and per-iteration step counts
        self.beta_int, self.beta_intw, self.beta_bd = beta_int, beta_intw, beta_bd
        self.u_step, self.v_step = u_step, v_step

        # Networks: the solution network is a SIREN (replacing XNODE);
        # net_u is created before net_v.
        siren_u = Siren(
            in_features=dim + 1,
            hidden_features=u_hidden_dim,
            hidden_layers=u_hidden_layers,
            out_features=1,
            outermost_linear=True,
            first_omega_0=1.0,
            hidden_omega_0=1.0
        )
        self.net_u = siren_u.to(device)

        mlp_v = NetV(dim + 1, 1, layers=6, hidden_size=40)
        self.net_v = mlp_v.to(device)

        self.optimizer_u = torch.optim.Adam(self.net_u.parameters(), lr=u_rate)
        self.optimizer_v = torch.optim.Adagrad(self.net_v.parameters(), lr=v_rate)

        print(f"SIREN-WAN initialized: dim={dim}, N_x={N_x}, N_t={N_t}")

    def sample_train(self, N_x, N_t, N_bd):
        low, up, t0, t1 = self.low, self.up, self.t0, self.t1

        # Сэмплируем пространственные и временные точки отдельно
        x_spatial = np.random.uniform(low, up, (N_x, self.dim))  # [N_x, dim]
        t_temporal = np.linspace(t0, t1, N_t)                   # [N_t]

        # Декартово произведение: [N_x * N_t, dim+1]
        x_dm_list = []
        for t in t_temporal:
            for x in x_spatial:
                x_dm_list.append(np.concatenate([x, [t]]))
        x_dm_full = np.array(x_dm_list)

        # f(x,t) = μu - u²
        u_exact = self.pho * np.sin(self.la * x_dm_full[:, 0:1]) * np.exp((self.mu - self.la**2) * x_dm_full[:, -1:])
        f_dm = self.mu * u_exact - u_exact**2

        # Начальные условия (t = t0)
        x_init = np.random.uniform(low, up, (N_bd, self.dim))
        t_init = np.full((N_bd, 1), t0)
        x_init_full = np.concatenate([x_init, t_init], axis=1)
        u_init = self.pho * np.sin(self.la * x_init[:, 0:1]) * np.exp((self.mu - self.la**2) * t0)
        u_init = u_init.reshape(-1, 1)

        # Конечные точки (t = t1)
        x_right = np.random.uniform(low, up, (N_bd, self.dim))
        t_right = np.full((N_bd, 1), t1)
        x_right_full = np.concatenate([x_right, t_right], axis=1)

        # Граничные точки
        x_bd_list = []
        for i in range(self.dim):
            x_bound = np.random.uniform(low, up, (N_bd, self.dim))
            t_bound = np.random.uniform(t0, t1, (N_bd, 1))
            x_bound[:, i] = up
            x_bd_list.append(np.concatenate([x_bound, t_bound], axis=1))
            x_bound = np.random.uniform(low, up, (N_bd, self.dim))
            x_bound[:, i] = low
            x_bd_list.append(np.concatenate([x_bound, t_bound], axis=1))
        x_bd = np.concatenate(x_bd_list, axis=0)
        u_bd = self.pho * np.sin(self.la * x_bd[:, 0:1]) * np.exp((self.mu - self.la**2) * x_bd[:, -1:])
        u_bd = u_bd.reshape(-1, 1)

        # В тензоры
        to_t = lambda x: torch.FloatTensor(x).to(self.device)
        return {
            'x_dm': to_t(x_dm_full),
            'x_init': to_t(x_init_full),
            'x_right': to_t(x_right_full),
            'x_bd': to_t(x_bd),
            'f_val': to_t(f_dm),
            'u_init': to_t(u_init),
            'u_bd': to_t(u_bd)
        }

    def sample_test(self, N_test=5000):
        x_test = np.random.uniform(self.low, self.up, (N_test, self.dim))
        t_test = np.random.uniform(self.t0, self.t1, (N_test, 1))
        x_test_full = np.concatenate([x_test, t_test], axis=1)
        u_exact = self.pho * np.sin(self.la * x_test[:, 0:1]) * np.exp((self.mu - self.la**2) * t_test)
        u_exact = u_exact.reshape(-1, 1)
        return {
            'test_x': torch.FloatTensor(x_test_full).to(self.device),
            'test_u': torch.FloatTensor(u_exact).to(self.device)
        }

    def fun_w(self, x):
        I1 = 0.210987
        x_list = torch.split(x, 1, dim=1)
        h_len = (self.up - self.low) / 2.0
        x_scale_list = [(x_i - self.low - h_len) / h_len for x_i in x_list]
        z_x_list = []
        for xs in x_scale_list:
            supp = (1 - torch.abs(xs)) > 0
            denom = xs ** 2 - 1
            safe_denom = torch.where(supp, denom, torch.ones_like(denom))
            z_x = torch.where(supp, torch.exp(1.0 / safe_denom) / I1, torch.zeros_like(xs))
            z_x_list.append(z_x)
        w_val = torch.ones_like(z_x_list[0])
        for z in z_x_list:
            w_val = w_val * z
        if x.requires_grad:
            dw = torch.autograd.grad(w_val, x, grad_outputs=torch.ones_like(w_val),
                                    create_graph=True, allow_unused=False)[0]
            dw = torch.where(torch.isnan(dw), torch.zeros_like(dw), dw)
        else:
            dw = torch.zeros_like(x)
        return w_val, dw

    def grad_u(self, x_in):
        x_in = x_in.detach().requires_grad_(True)
        u_val = self.net_u(x_in)
        grad_u = torch.autograd.grad(u_val, x_in, grad_outputs=torch.ones_like(u_val),
                                    create_graph=True, retain_graph=True)[0]
        return u_val, grad_u

    def grad_v(self, x_in):
        x_in = x_in.detach().requires_grad_(True)
        v_val = self.net_v(x_in)
        grad_v = torch.autograd.grad(v_val, x_in, grad_outputs=torch.ones_like(v_val),
                                    create_graph=True, retain_graph=True)[0]
        return v_val, grad_v

    def compute_loss(self, train_dict):
        x_dm = train_dict['x_dm']
        x_init = train_dict['x_init']
        x_right = train_dict['x_right']
        x_bd = train_dict['x_bd']
        f_val = train_dict['f_val']
        u_init_true = train_dict['u_init']
        u_bd_true = train_dict['u_bd']

        u_val, grad_u = self.grad_u(x_dm)
        v_val, grad_v = self.grad_v(x_dm)
        grad_u_x = grad_u[:, :-1]; grad_u_t = grad_u[:, -1:]
        grad_v_x = grad_v[:, :-1]; grad_v_t = grad_v[:, -1:]

        w_val, grad_w = self.fun_w(x_dm[:, :-1])
        wv_val = w_val * v_val

        dudw_val = torch.sum(grad_u_x * grad_w, dim=1, keepdim=True)
        dudv_val = torch.sum(grad_u_x * grad_v_x, dim=1, keepdim=True)
        dudwv_val = v_val * dudw_val + w_val * dudv_val
        u_w_dvt_val = u_val * w_val * grad_v_t
        uu_wv_val = u_val * u_val * wv_val
        f_wv_val = f_val * wv_val

        u_init_pred, _ = self.grad_u(x_init)
        u_right_pred, _ = self.grad_u(x_right)
        w_init, _ = self.fun_w(x_init[:, :-1])
        w_right, _ = self.fun_w(x_right[:, :-1])
        v_init, _ = self.grad_v(x_init)
        v_right, _ = self.grad_v(x_right)

        uwv_init = torch.mean(u_init_pred * w_init * v_init)
        uwv_right = torch.mean(u_right_pred * w_right * v_right)

        term1 = uwv_right
        term2 = torch.mean(dudwv_val)
        term3 = torch.mean(u_w_dvt_val)
        term4 = uwv_init
        term5 = torch.mean(uu_wv_val)
        term6 = torch.mean(f_wv_val)

        residual = (term1 + term2 - term3) - (term4 + term5 + term6)
        test_norm = torch.mean(wv_val**2) + 1e-8
        loss_int = self.beta_int * (residual ** 2) / test_norm

        # Вспомогательный
        uw_init = torch.mean(u_init_pred * w_init)
        uw_right = torch.mean(u_right_pred * w_right)
        dudw_aux = torch.mean(dudw_val)
        uu_w_aux = torch.mean(u_val * u_val * w_val)
        f_w_aux = torch.mean(f_val * w_val)
        residual_w = (uw_right + dudw_aux) - (uw_init + uu_w_aux + f_w_aux)
        w_norm = torch.mean(w_val**2) + 1e-8
        loss_intw = self.beta_intw * (residual_w ** 2) / w_norm

        loss_bd = torch.mean(torch.abs(self.net_u(x_bd) - u_bd_true))
        loss_init = torch.mean(torch.abs(u_init_pred - u_init_true))

        loss_u = self.beta_bd * (loss_bd + loss_init) + loss_int + loss_intw
        loss_v = -torch.log(loss_int + 1e-8)

        return loss_u, loss_v, loss_int, loss_bd, loss_init

    def train(self):
        history = {'step': [], 'loss_u': [], 'loss_v': [], 'l2r': []}
        test_dict = self.sample_test()

        for i in range(self.iteration):
            train_dict = self.sample_train(self.N_x, self.N_t, self.N_bd)
            loss_u, loss_v, loss_int, loss_bd, loss_init = self.compute_loss(train_dict)

            # Обучение v
            for _ in range(self.v_step):
                self.optimizer_v.zero_grad()
                _, loss_v, _, _, _ = self.compute_loss(train_dict)
                loss_v.backward()
                self.optimizer_v.step()

            # Обучение u
            for _ in range(self.u_step):
                self.optimizer_u.zero_grad()
                loss_u, _, _, _, _ = self.compute_loss(train_dict)
                loss_u.backward()
                self.optimizer_u.step()

            # Логирование
            if i % 100 == 0:
                with torch.no_grad():
                    pred_u = self.net_u(test_dict['test_x'])
                    test_u = test_dict['test_u']
                    err_l2 = torch.sqrt(torch.mean((test_u - pred_u)**2)).item()
                    u_norm = torch.sqrt(torch.mean(test_u**2)).item()
                    l2r = err_l2 / (u_norm + 1e-8)

                    history['step'].append(i)
                    history['loss_u'].append(loss_u.item())
                    history['loss_v'].append(loss_v.item())
                    history['l2r'].append(l2r)

                if i % 1000 == 0:
                    print(f"Iter {i:6d} | L2r: {l2r:.6f} | Loss_u: {loss_u.item():.4f}")

        return history, test_dict


# =============================
# Run
# =============================
if __name__ == '__main__':
    dim = 5

    # Solver configuration, grouped by purpose.
    solver_kwargs = dict(
        dim=dim,
        file_path='./results_siren/',
        device=device,
        # Sampling
        N_x=400,      # spatial points
        N_t=20,       # time levels
        N_bd=400,     # boundary points
        # Loss weights
        beta_int=100.0,
        beta_intw=500.0,
        beta_bd=1000.0,
        # Optimisation
        v_step=1,
        v_rate=0.015,
        u_step=1,
        u_rate=0.0001,  # 1e-4 for SIREN
        iteration=20000,
        # Solution network size
        u_hidden_dim=64,
        u_hidden_layers=4,
    )
    demo = ParabolicWanPDESolver(**solver_kwargs)

    history, test_dict = demo.train()

    # Persist the trained solution network and the logged history.
    torch.save(demo.net_u.state_dict(), './results_siren/net_u_siren.pth')
    scipy.io.savemat('./results_siren/history_siren.mat', history)
    print("SIREN-WAN training completed and saved.")

Using device: cuda
SIREN-WAN initialized: dim=5, N_x=400, N_t=20
Iter      0 | L2r: 1.218347 | Loss_u: 5520.2217
Iter   1000 | L2r: 0.024462 | Loss_u: 80.2962
Iter   2000 | L2r: 0.011046 | Loss_u: 52.6475
Iter   3000 | L2r: 0.009118 | Loss_u: 31.8148
Iter   4000 | L2r: 0.006979 | Loss_u: 21.5423
Iter   5000 | L2r: 0.009720 | Loss_u: 25.4938
Iter   6000 | L2r: 0.005586 | Loss_u: 28.4218
Iter   7000 | L2r: 0.006581 | Loss_u: 48.0427
Iter   8000 | L2r: 0.006446 | Loss_u: 52.7316
Iter   9000 | L2r: 0.004918 | Loss_u: 25.9735
Iter  10000 | L2r: 0.003650 | Loss_u: 12.6060
Iter  11000 | L2r: 0.004307 | Loss_u: 13.8130
Iter  12000 | L2r: 0.003552 | Loss_u: 21.7656
Iter  13000 | L2r: 0.004089 | Loss_u: 24.3719
Iter  14000 | L2r: 0.003108 | Loss_u: 15.9557
Iter  15000 | L2r: 0.068099 | Loss_u: 302.5242
Iter  16000 | L2r: 0.091491 | Loss_u: 434.0579
Iter  17000 | L2r: 0.036763 | Loss_u: 112.4852
Iter  18000 | L2r: 0.112068 | Loss_u: 281.8498
Iter  19000 | L2r: 0.104328 | Loss_u: 260.5639
