In [None]:
import random

import scipy.io
import numpy as np

import torch
import torch.nn as nn
from torch import optim

from tqdm import tqdm

In [None]:
import wandb

# Authenticate this session with Weights & Biases; the training pipeline
# further down opens a run with wandb.init and logs metrics to it.
wandb.login()

In [None]:
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
device

In [None]:
# Single seed shared by every random number generator the notebook touches.
SEED = 321

# Request deterministic cuDNN kernels, then seed Python, NumPy and PyTorch
# (CPU generator plus every CUDA device).
torch.backends.cudnn.deterministic = True
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

In [None]:
class BackwardModel(nn.Module):
    """Small fully connected network carrying two extra trainable scalars.

    The network is ``Linear -> activation -> Linear -> activation -> Linear``.
    ``lambda1`` and ``lambda2`` are one-element parameters initialised to 0.2
    and -5.0. Everything is placed on the notebook-level ``device`` when the
    model is constructed.

    Args:
        in_size: Number of input features (how many tensors ``forward`` gets).
        out_size: Width of the last linear layer.
        hidden_size: Width of both hidden linear layers.
        activation: Zero-argument callable such as ``nn.Tanh`` that returns a
            fresh activation module; it is called once per hidden layer.
    """

    def __init__(self, in_size, out_size, hidden_size, activation):
        super().__init__()
        self.activation = activation

        # Registered before the layers so they come first in parameters().
        self.lambda1 = nn.Parameter(torch.tensor([0.2], device=device))
        self.lambda2 = nn.Parameter(torch.tensor([-5.0], device=device))

        modules = [
            nn.Linear(in_size, hidden_size),
            activation(),
            nn.Linear(hidden_size, hidden_size),
            activation(),
            nn.Linear(hidden_size, out_size),
        ]
        self.layers = nn.Sequential(*modules).to(device)

    def forward(self, *args):
        """Stack the given 1-D tensors as columns, run the network, flatten the output."""
        features = torch.stack(args, dim=1)
        return self.layers(features).flatten()

In [None]:
class PBLoss(nn.Module):
    """Physics-informed loss: Burgers-type PDE residual plus data misfit.

    The residual is ``u_t + exp(lambda1) * u * u_x - exp(lambda2) * u_xx``,
    where every ``u`` term is the network prediction ``u_pred`` and the
    derivatives are obtained with autograd. Evaluating the nonlinear term on
    the prediction (rather than on the observed values) keeps the residual a
    function of the network alone.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def _grad(outputs, inputs):
        """Return d(outputs)/d(inputs), keeping the graph so it can be differentiated again."""
        return torch.autograd.grad(
            outputs, inputs,
            grad_outputs=torch.ones_like(outputs),
            retain_graph=True,
            create_graph=True,
        )[0]

    def forward(self, model: nn.Module, x, t, u_pred, u):
        """Compute the combined loss.

        Args:
            model: Object exposing ``lambda1`` and ``lambda2`` tensors; both are
                passed through ``torch.exp`` before use.
            x: Spatial coordinates; must require grad and be part of the graph
                that produced ``u_pred``.
            t: Time coordinates; same requirements as ``x``.
            u_pred: Network output evaluated at ``(t, x)``.
            u: Observed values compared against ``u_pred``.

        Returns:
            Scalar tensor: mean squared PDE residual plus mean squared data
            error (a NaN data term is replaced by 0).
        """
        u_t = self._grad(u_pred, t)
        u_x = self._grad(u_pred, x)
        u_xx = self._grad(u_x, x)

        coeff_convection = torch.exp(model.lambda1)
        coeff_diffusion = torch.exp(model.lambda2)

        # The convective term uses the prediction, not the observed data.
        residual = u_t + coeff_convection * u_pred * u_x - coeff_diffusion * u_xx
        physics_loss = residual.pow(2).mean()
        data_loss = (u_pred - u).pow(2).mean().nan_to_num(nan=0.0)

        return physics_loss + data_loss

In [None]:
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from random import randint

class CustomDataset(Dataset):
    """Dataset of aligned ``(t, x, u)`` samples.

    Args:
        t: Indexable collection of time coordinates.
        x: Indexable collection of spatial coordinates.
        u: Indexable collection of target values.

    Raises:
        ValueError: If the three collections do not have the same length.
    """

    def __init__(self, t, x, u):
        # Zero-argument super() actually runs the parent initialiser;
        # the one-argument form super(CustomDataset) does not.
        super().__init__()
        if not (len(t) == len(x) == len(u)):
            raise ValueError(
                f"t, x and u must have the same length, got {len(t)}, {len(x)} and {len(u)}"
            )
        self.t = t
        self.x = x
        self.u = u

    def __len__(self):
        """Return the number of samples."""
        return len(self.t)

    def __getitem__(self, item):
        """Return the ``(t, x, u)`` triple stored at position ``item``."""
        return self.t[item], self.x[item], self.u[item]


def load_data(path, boundary_samples=30000):
    """Load a gridded solution from a ``.mat`` file and append boundary samples.

    The file must contain ``'t'``, ``'x'`` and ``'usol'``. After transposing,
    ``usol`` is indexed as ``[time, space]``; every grid point becomes one
    sample, ordered with time as the outer index and space as the inner one.
    Then ``boundary_samples`` extra points are appended with ``t`` drawn
    uniformly from ``[0, 1)``, ``x`` picked at random from ``{1, -1}`` and
    ``u = 0``.

    Args:
        path: Path handed to ``scipy.io.loadmat``.
        boundary_samples: Number of boundary points to append (default 30000).

    Returns:
        CustomDataset holding three float32 1-D tensors ``t``, ``x`` and ``u``.

    Raises:
        ValueError: If ``boundary_samples`` is negative or the shape of
            ``usol`` does not match the lengths of ``t`` and ``x``.
    """
    if boundary_samples < 0:
        raise ValueError(f"boundary_samples must be non-negative, got {boundary_samples}")

    data = scipy.io.loadmat(path)

    t_axis = np.asarray(data['t']).flatten()
    x_axis = np.asarray(data['x']).flatten()
    u_grid = np.real(data['usol']).T  # rows: time, columns: space

    n_t, n_x = t_axis.size, x_axis.size
    if u_grid.shape != (n_t, n_x):
        raise ValueError(
            f"'usol' has shape {u_grid.shape} after transposing, expected {(n_t, n_x)}"
        )

    # Flatten the grid in (time-major, space-minor) order without Python loops.
    t_grid = torch.from_numpy(np.repeat(t_axis, n_x).astype(np.float32))
    x_grid = torch.from_numpy(np.tile(x_axis, n_t).astype(np.float32))
    u_vals = torch.from_numpy(u_grid.reshape(-1).astype(np.float32))

    # Boundary points: random time, x on either edge, u fixed to zero.
    t_boundary = torch.rand(boundary_samples)
    x_boundary = torch.tensor(
        [1.0 if randint(0, 1) == 0 else -1.0 for _ in range(boundary_samples)],
        dtype=torch.float32,
    )
    u_boundary = torch.zeros(boundary_samples)

    t = torch.cat([t_grid, t_boundary])
    x = torch.cat([x_grid, x_boundary])
    u = torch.cat([u_vals, u_boundary])

    return CustomDataset(t, x, u)

In [None]:
def train_log(loss, example_ct, epoch, lambda1, lambda2):
    """Send one row of training metrics to Weights & Biases.

    Values are converted to plain Python floats before logging, so no
    autograd graph or device tensor is handed to the logger and each metric
    is recorded as a scalar.

    Args:
        loss: Current loss (scalar tensor or number).
        example_ct: Number of examples seen so far; used as the W&B step.
        epoch: Current epoch index.
        lambda1: One-element tensor; logged as ``exp(lambda1)``.
        lambda2: One-element tensor; logged as ``exp(lambda2)``.
    """
    wandb.log(
        {
            "epoch": epoch,
            "loss": float(loss),
            "lambda1": float(torch.exp(lambda1.detach())),
            "lambda2": float(torch.exp(lambda2.detach())),
        },
        step=example_ct,
    )


def train(model, loader, criterion, optimizer, config, log_every=25):
    """Run the training loop and log metrics every ``log_every`` batches.

    Args:
        model: Network to train; must expose ``lambda1`` and ``lambda2``.
        loader: Iterable yielding ``(t, x, u)`` batches.
        criterion: Loss callable forwarded to ``train_batch``.
        optimizer: Optimizer forwarded to ``train_batch``.
        config: Object with an ``epochs`` attribute.
        log_every: Log after every this many batches, i.e. at batch
            ``log_every``, ``2 * log_every``, ... (default 25).

    Raises:
        ValueError: If ``log_every`` is smaller than 1.
    """
    if log_every < 1:
        raise ValueError(f"log_every must be at least 1, got {log_every}")

    wandb.watch(model, criterion, log="all", log_freq=10)

    example_ct = 0  # examples seen so far; doubles as the logging step
    batch_ct = 0    # batches processed so far
    for epoch in tqdm(range(config.epochs)):
        for t, x, u in loader:
            loss = train_batch(t, x, u, model, optimizer, criterion)
            example_ct += len(t)
            batch_ct += 1

            # batch_ct is already 1-based here, so test it directly.
            if batch_ct % log_every == 0:
                train_log(loss, example_ct, epoch, model.lambda1, model.lambda2)


def train_batch(t, x, u, model, optimizer, criterion):
    """Run one optimisation step on a single batch.

    Args:
        t: Batch of time coordinates.
        x: Batch of spatial coordinates.
        u: Batch of target values.
        model: Network called as ``model(t, x)``.
        optimizer: Optimizer whose ``zero_grad``/``step`` are invoked.
        criterion: Loss callable invoked as ``criterion(model, x, t, u_pred, u)``.

    Returns:
        The loss tensor computed for this batch.
    """
    # All three tensors must live on the same device as the model; the inputs
    # also need gradients so the loss can differentiate the prediction.
    t = t.to(device).requires_grad_(True)
    x = x.to(device).requires_grad_(True)
    u = u.to(device)

    u_pred = model(t, x)

    # The criterion takes space first, then time: (model, x, t, u_pred, u).
    loss = criterion(model, x, t, u_pred, u)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss

In [None]:
def model_pipeline(hyperparameters):
    """Build, train and return a ``BackwardModel`` inside a W&B run.

    Args:
        hyperparameters: Mapping passed to ``wandb.init`` as the run config.
            The code reads ``learning_rate``, ``dataset``, ``test_size``,
            ``batch_size`` and (via ``train``) ``epochs`` from it.

    Returns:
        The trained model.
    """
    with wandb.init(project="PINN Backward model", config=hyperparameters):
        cfg = wandb.config

        net = BackwardModel(in_size=2, out_size=1, hidden_size=64, activation=nn.Tanh)
        adam = optim.Adam(net.parameters(), lr=cfg.learning_rate)
        loss_fn = PBLoss()

        full_set = load_data(cfg.dataset)
        total = len(full_set)

        # NOTE(review): the first split, sized by cfg.test_size, is the one used
        # for training; confirm that this is the intended meaning of test_size.
        first_len = int(total * cfg.test_size)
        train_part, test_part = random_split(
            full_set,
            (first_len, total - first_len),
            generator=torch.Generator(),
        )

        train_batches = DataLoader(train_part, batch_size=cfg.batch_size, shuffle=True)
        # Constructed alongside the training loader; nothing below consumes it.
        test_batches = DataLoader(test_part, batch_size=cfg.batch_size, shuffle=False)

        train(net, train_batches, loss_fn, adam, cfg)

    return net

In [None]:
# Run configuration: passed to model_pipeline, which forwards it to wandb.init
# and reads every key back from wandb.config.
params = {
    "epochs": 50,
    "batch_size": 32,
    "learning_rate": 0.005,
    "test_size": 0.8,
    "dataset": "burgers_shock.mat",
}
model = model_pipeline(params)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

sns.set_theme(style='darkgrid')

res = 256  # number of x samples per curve

# Spatial grid x_i = -1 + i * (2 / res), created on the same device as the model.
x = (-1 + torch.arange(res, dtype=torch.float32) * (2 / res)).to(device)

plt.figure(figsize=(16, 8))
# Inference only: no autograd graph is needed for plotting.
with torch.no_grad():
    for it in range(0, 11):
        T = 0.1 * it
        t = torch.full_like(x, T)
        pred = model(t, x)
        # Move results back to the CPU before handing them to matplotlib.
        plt.plot(x.cpu().numpy(), pred.cpu().numpy(), label=f"t = {T:.1f}")

plt.xlabel("x")
plt.ylabel("u(t, x)")
plt.title("Model prediction at successive times")
plt.legend()
plt.show()

In [None]:
# Show the class of the object returned by model_pipeline as the cell output.
type(model)