In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

In [2]:
# 모델 정의
class PINN(nn.Module):
    def __init__(self, hidden_size):
        super(PINN, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(2, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, 1)
        )

    def forward(self, x, t):
        inputs = torch.cat([x, t], dim=-1)
        return self.net(inputs)

In [3]:
# u_x = torch.autograd.grad(u.sum(), x, create_graph=True)[0]
def gradient(y, x):
    return torch.autograd.grad(y, x, 
                               grad_outputs=torch.ones_like(y),
                               create_graph=True)[0]

# 손실 함수 정의
def loss_fn(model, x, t, alpha):
    x.requires_grad_(True)
    t.requires_grad_(True)

    u = model(x, t)
    u_t = gradient(u, t)
    u_x = gradient(u, x)
    u_xx = gradient(u_x, x)
    pde_loss = torch.mean((u_t - alpha * u_xx)**2)

    # 경계 조건 손실
    x_bc = torch.tensor([[0.0], [1.0]], requires_grad=False)
    t_bc = torch.tensor([[0.0], [0.0]], requires_grad=False)
    u_bc = model(x_bc, t_bc)
    bc_loss = torch.mean(u_bc**2)

    # 초기 조건 손실
    x_ic = torch.linspace(0, 1, 100).reshape(-1, 1)
    t_ic = torch.zeros_like(x_ic)
    u_ic = model(x_ic, t_ic)
    ic_loss = torch.mean((u_ic - torch.sin(np.pi * x_ic))**2)
    return pde_loss + bc_loss + ic_loss

In [4]:
# 랜덤 시드
torch.manual_seed(11)

# 학습
alpha = 1.0
hidden_size = 20
model = PINN(hidden_size)
optimizer = optim.Adam(model.parameters(), lr=0.001)

x = torch.linspace(0, 1, 100).reshape(-1, 1)
t = torch.linspace(0, 1, 100).reshape(-1, 1)

n_epochs = 5000
for epoch in range(1, n_epochs + 1):
    loss = loss_fn(model, x, t, alpha)

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    if epoch % (n_epochs // 10) == 0:
        print(f"[{epoch:4d}/{n_epochs}] Loss: {loss.item():.2e}")

[ 500/5000] Loss: 2.50e-03
[1000/5000] Loss: 4.29e-04
[1500/5000] Loss: 1.71e-04
[2000/5000] Loss: 1.29e-04
[2500/5000] Loss: 1.06e-04
[3000/5000] Loss: 7.55e-05
[3500/5000] Loss: 1.00e-04
[4000/5000] Loss: 4.18e-05
[4500/5000] Loss: 3.21e-05
[5000/5000] Loss: 2.48e-05
