In [None]:
import os
import random

import scipy.io
import numpy as np

import torch
import torch.nn as nn
from torch import optim

from tqdm import tqdm

In [None]:
import wandb

# Authenticate this kernel with Weights & Biases before the training cells call wandb.init/log/watch.
wandb.login()

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

In [None]:
# torch.backends.cudnn.deterministic = True
# random.seed(321)
# np.random.seed(321)
# torch.manual_seed(321)
# torch.cuda.manual_seed_all(321)

In [None]:
from math import pi


class BackwardModel(nn.Module):
    """MLP surrogate ``u(t, x)`` together with the log-space coefficients of the PDE loss.

    The equation loss multiplies by ``exp(lambda1)`` and ``exp(lambda2)``, so both
    coefficients are stored as logarithms:

    * ``lambda1`` -- frozen at ``0`` (so the ``u * u_x`` coefficient is fixed to 1).
    * ``lambda2`` -- trainable, initialised to ``log(0.01 / pi)``.

    Args:
        activation: activation *class* (e.g. ``nn.Tanh``); a fresh instance follows every
            linear layer except the last.
        hidden_dim: width of every hidden layer (default 64, the previously hard-coded value).
        num_hidden_layers: number of hidden-to-hidden linear layers (default 4, which gives
            the original ``2 -> 64 -> 64 -> 64 -> 64 -> 64 -> 1`` stack).
    """

    def __init__(self, activation, hidden_dim=64, num_hidden_layers=4):
        super().__init__()
        self.activation = activation
        # Register the coefficients as genuine nn.Parameters and move the whole module at the
        # end. Calling ``.to(device)`` on a Parameter directly returns a plain, non-leaf tensor
        # on a GPU: it is not registered on the module and an optimizer refuses it.
        self.lambda1 = nn.Parameter(torch.tensor([0.0]), requires_grad=False)
        self.lambda2 = nn.Parameter(torch.log(torch.tensor([0.01 / pi])), requires_grad=True)

        modules = [nn.Linear(2, hidden_dim), self.activation()]
        for _ in range(num_hidden_layers):
            modules += [nn.Linear(hidden_dim, hidden_dim), self.activation()]
        modules.append(nn.Linear(hidden_dim, 1))
        self.layers = nn.Sequential(*modules)

        self.to(device)

    def forward(self, *args):
        """Stack the coordinate tensors as input columns and return a flattened prediction.

        This notebook calls it as ``model(t, x)``; each argument is presumably a 1-D tensor
        of the same length (TODO confirm for any other caller).
        """
        return self.layers(torch.stack(args, dim=1)).flatten()

In [None]:
class PBLossEq(nn.Module):
    """Mean squared residual of the Burgers-type equation evaluated on the network output.

    Residual: ``u_t + exp(lambda1) * u * u_x - exp(lambda2) * u_xx``, where ``u`` and all
    of its derivatives come from the prediction ``u_pred``.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def _grad(outputs, inputs):
        """d(outputs)/d(inputs), keeping the graph so higher derivatives and backward() work."""
        return torch.autograd.grad(
            outputs, inputs,
            grad_outputs=torch.ones_like(outputs),
            create_graph=True,  # also retains the graph (retain_graph defaults to create_graph)
        )[0]

    def forward(self, model: "BackwardModel", x, t, u_pred, u):
        """Return the mean squared PDE residual.

        Args:
            model: provides the log-space coefficients ``lambda1`` and ``lambda2``.
            x, t: coordinate tensors ``u_pred`` was computed from; both must require grad.
                Note the order is ``x`` then ``t`` -- passing them by keyword is safest.
            u_pred: network prediction at ``(t, x)``.
            u: observed values; not used by this loss (the signature matches ``PBLossU``).
        """
        u_t = self._grad(u_pred, t)
        u_x = self._grad(u_pred, x)
        u_xx = self._grad(u_x, x)
        # The residual must describe the surrogate itself, so the nonlinear term uses the
        # network output rather than mixing measured values with network derivatives.
        residual = u_t + torch.exp(model.lambda1) * u_pred * u_x - torch.exp(model.lambda2) * u_xx
        return residual.pow(2).mean()


class PBLossU(nn.Module):
    """Data-fit term: mean squared error between the prediction and the observed values."""

    def __init__(self):
        super().__init__()

    def forward(self, model: BackwardModel, x, t, u_pred, u):
        """Return ``mean((u_pred - u) ** 2)``; ``model``, ``x`` and ``t`` are accepted but ignored."""
        error = u_pred - u
        return (error ** 2).mean()

In [None]:
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from random import randint


class CustomDataset(Dataset):
    """Map-style dataset of aligned ``(t, x, u)`` samples.

    Args:
        t, x, u: indexable sequences of equal length (``load_data`` passes 1-D float32
            tensors). ``__getitem__`` returns the ``item``-th element of each.
    """

    def __init__(self, t, x, u):
        # ``super(CustomDataset).__init__()`` only re-initialised an unbound super object and
        # never initialised the base class; use the zero-argument form.
        super().__init__()
        self.t = t
        self.x = x
        self.u = u

    def __len__(self):
        return len(self.t)

    def __getitem__(self, item):
        return self.t[item], self.x[item], self.u[item]


def load_data(path):
    """Load ``t``, ``x`` and ``usol`` from a MATLAB file and flatten the grid into samples.

    ``usol`` is transposed so rows follow ``t`` and columns follow ``x``. Sample ``k`` is
    ``(t[k // len(x)], x[k % len(x)], usol[k % len(x), k // len(x)])``, i.e. t-major order.

    Args:
        path: file path handed to ``scipy.io.loadmat``.

    Returns:
        CustomDataset holding 1-D float32 tensors ``t``, ``x`` and ``u``.

    Raises:
        ValueError: if ``usol`` does not have shape ``(len(x), len(t))``.
    """
    data = scipy.io.loadmat(path)

    t_grid = np.ravel(data['t'])
    x_grid = np.ravel(data['x'])
    u_grid = np.real(data['usol']).T  # rows follow t, columns follow x

    if u_grid.shape != (t_grid.size, x_grid.size):
        raise ValueError(
            f"usol has shape {u_grid.T.shape}, expected (len(x), len(t)) = "
            f"({x_grid.size}, {t_grid.size})"
        )

    # Vectorised equivalent of the nested t/x loop: repeat t for every x, tile x for every t.
    t = torch.as_tensor(np.repeat(t_grid, x_grid.size), dtype=torch.float32)
    x = torch.as_tensor(np.tile(x_grid, t_grid.size), dtype=torch.float32)
    u = torch.as_tensor(u_grid.reshape(-1), dtype=torch.float32)

    return CustomDataset(t, x, u)

In [None]:
import torch.optim.lr_scheduler as lr_scheduler
from IPython.display import clear_output


def train_log(loss_eq, loss_u, example_ct, epoch, lambda1, lambda2):
    """Log the losses and the current PDE coefficients to Weights & Biases.

    Args:
        loss_eq, loss_u: scalar loss values.
        example_ct: used as the wandb ``step`` (``train`` passes its running sample count).
        epoch: current epoch.
        lambda1, lambda2: log-space coefficient tensors with a single element; their
            exponentials are logged as plain Python floats.
    """
    # Detach and convert explicitly so logging builds no autograd graph and wandb receives
    # ordinary numbers rather than (possibly GPU-resident) tensors.
    with torch.no_grad():
        coef1 = torch.exp(lambda1).item()
        coef2 = torch.exp(lambda2).item()
    wandb.log({
        "epoch": epoch,
        "loss_eq": loss_eq,
        "loss_u": loss_u,
        "lambda1": coef1,
        "lambda2": coef2,
    }, step=example_ct)


def train(model, loader, criterion_eq, criterion_u, optimizer_layers, optimizer_lambdas, config):
    """Train for ``config.epochs`` epochs over ``loader``, logging and plotting along the way.

    Each optimizer gets an ``ExponentialLR`` scheduler (gamma 0.999) that ``train_batch``
    steps after every batch. Metrics are sent to wandb whenever ``batch count + 1`` is a
    multiple of 100, and ``draw`` is called at the end of every epoch.

    Returns:
        (number of examples seen, number of batches processed)
    """
    wandb.watch(model, criterion_eq, log="all", log_freq=10)
    seen_examples = 0
    seen_batches = 0

    layers_scheduler = lr_scheduler.ExponentialLR(optimizer_layers, gamma=0.999)
    lambdas_scheduler = lr_scheduler.ExponentialLR(optimizer_lambdas, gamma=0.999)

    for epoch in tqdm(range(config.epochs)):
        for t, x, u in loader:
            loss_eq, loss_u = train_batch(
                epoch, t, x, u,
                model,
                optimizer_layers, optimizer_lambdas,
                layers_scheduler, lambdas_scheduler,
                criterion_eq, criterion_u,
            )
            seen_examples += len(t)
            seen_batches += 1

            should_log = (seen_batches + 1) % 100 == 0
            if should_log:
                train_log(loss_eq, loss_u, seen_examples, epoch, model.lambda1, model.lambda2)

        clear_output(wait=True)
        # NOTE(review): ``draw`` is defined in a later cell; on Restart & Run All it must be
        # defined before the cell that calls model_pipeline runs.
        draw(model, epoch)

    return seen_examples, seen_batches


def train_batch(epoch, t, x, u, model, optimizer_layers, optimizer_lambdas, scheduler1, scheduler2, criterion_eq,
                criterion_u):
    """Run one optimisation step on a mini-batch and return both loss values.

    Args:
        epoch: current epoch (not used here; kept for the call signature).
        t, x, u: batch tensors. They are moved to ``device``, and ``t``/``x`` are marked as
            requiring grad so the equation loss can differentiate the prediction.
        model: called as ``model(t, x)``.
        optimizer_layers, optimizer_lambdas: both are zeroed and stepped on ``loss_u + loss_eq``.
        scheduler1, scheduler2: stepped once per batch, after the optimizers.
        criterion_eq, criterion_u: losses with signature ``(model, x, t, u_pred, u)``.

    Returns:
        (loss_eq, loss_u) as Python floats.
    """
    t = t.to(device)
    x = x.to(device)
    u = u.to(device)
    t.requires_grad = True
    x.requires_grad = True
    u_pred = model(t, x)

    # The losses take ``x`` before ``t``. Passing ``(t, x)`` positionally swapped the
    # coordinates (u_t was differentiated w.r.t. x and vice versa); keywords prevent that.
    loss_eq = criterion_eq(model, x=x, t=t, u_pred=u_pred, u=u)
    loss_u = criterion_u(model, x=x, t=t, u_pred=u_pred, u=u)

    optimizer_layers.zero_grad()
    optimizer_lambdas.zero_grad()

    total_loss = loss_u + loss_eq
    total_loss.backward()

    optimizer_layers.step()
    optimizer_lambdas.step()

    scheduler1.step()
    scheduler2.step()

    return loss_eq.item(), loss_u.item()

In [None]:
def model_pipeline(hyperparameters):
    """Build the model, optimizers and loaders inside a wandb run, train, and return the model.

    ``hyperparameters`` is passed to ``wandb.init(config=...)``. It must provide
    ``learning_rate``, ``dataset``, ``test_size``, ``batch_size`` and ``epochs`` (the last
    one is read by ``train``).

    Returns:
        The trained ``BackwardModel``.
    """
    with wandb.init(project="PINN Backward model", config=hyperparameters):
        config = wandb.config
        model = BackwardModel(
            activation=nn.Tanh
        )
        # Keep network weights and PDE coefficients in disjoint optimizers. When the lambdas are
        # registered Parameters, ``model.parameters()`` yields them too, so passing it here gave
        # lambda2 two Adam updates per batch (one from each optimizer).
        optimizer_layers = optim.Adam(
            model.layers.parameters(),
            lr=config.learning_rate,
        )
        optimizer_lambdas = optim.Adam(
            [model.lambda1, model.lambda2],
            lr=config.learning_rate
        )

        criterion_eq = PBLossEq()
        criterion_u = PBLossU()
        dataset = load_data(config.dataset)

        # NOTE: despite its name, ``config.test_size`` is the fraction of samples given to the
        # *training* split (the first length below is assigned to ``train_dataset``).
        n_train = int(len(dataset) * config.test_size)
        train_dataset, test_dataset = random_split(
            dataset,
            (n_train, len(dataset) - n_train),
            generator=torch.Generator()
        )
        train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
        # Built for completeness; not consumed by train() yet.
        test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False)

        train(model, train_loader, criterion_eq, criterion_u, optimizer_layers, optimizer_lambdas, config)

    return model

# def start_train_LBFGS(model):
#     criterion = PBLoss()
#     dataset = load_data("burgers_shock.mat")
#
#     train_dataset, test_dataset = random_split(
#         dataset,
#         (int(len(dataset) * 0.8), len(dataset) - int(len(dataset) * 0.8)),
#         generator=torch.Generator()
#     )
#     train_loader = DataLoader(train_dataset, batch_size=len(train_dataset), shuffle=True)
#     test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)
#
#     train_LBFGS(model, train_loader, criterion)
#
#     return model

In [None]:
# Hyper-parameters consumed by model_pipeline/train. model_pipeline uses ``test_size`` as the
# fraction of samples put in the *training* split, so 1 trains on every sample.
params = {
    "epochs": 100,
    "batch_size": 512,
    "learning_rate": 0.005,
    "test_size": 1,
    "dataset": "burgers_shock.mat",
}
model = model_pipeline(params)

In [None]:
wandb.finish()
# FIXME: ``start_train_LBFGS`` exists only as a commented-out sketch (which itself refers to
# undefined ``PBLoss``/``train_LBFGS``), so calling it on a fresh kernel raises NameError and
# aborts Restart & Run All before the plotting cells below. Skip the L-BFGS stage -- and do not
# open an empty wandb run -- until it is implemented.
if "start_train_LBFGS" in globals():
    wandb.init()
    model = start_train_LBFGS(model)
else:
    print("start_train_LBFGS is not defined; skipping L-BFGS fine-tuning.")

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Apply seaborn's "darkgrid" theme before the plotting cells below.
sns.set_theme(style='darkgrid')


def draw(model, epoch, res=256, iters=6, out_dir="./images"):
    """Plot predictions over x in [-1, 1) at ``iters + 1`` evenly spaced times in [0, 1].

    The figure is saved to ``{out_dir}/{epoch}.png`` (the directory is created if missing)
    and then shown.

    Args:
        model: module callable as ``model(t, x)`` with 1-D tensors; inputs are sent to the
            device of its first parameter.
        epoch: used for the file name and the title.
        res: number of x samples per curve (default 256, previously hard-coded).
        iters: number of equal time steps between t=0 and t=1 (default 6).
        out_dir: directory for the saved PNG.
    """
    model_device = next(model.parameters()).device
    # Same grid as before: -1 + i * (2 / res) for i in [0, res).
    x = torch.arange(res, dtype=torch.float32) * (2 / res) - 1

    fig, ax = plt.subplots(figsize=(16, 8))
    # Inference only: no autograd graph, inputs on the model's device (it may be a GPU).
    with torch.no_grad():
        for it in range(iters + 1):
            t = torch.full_like(x, (1 / iters) * it)
            pred = model(t.to(model_device), x.to(model_device))
            ax.plot(x.numpy(), pred.cpu().numpy(), markersize=3)
    ax.set(title=f"epoch {epoch}", xlabel="x", ylabel="u")

    # Save before showing: plt.show() may close the figure, which left the saved PNG blank.
    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(os.path.join(out_dir, f"{epoch}.png"))
    plt.show()
    plt.close(fig)  # avoid accumulating one open figure per epoch

In [None]:
res = 256
iters = 6
# Evaluate on the model's own device and without autograd (the model may live on the GPU).
model_device = next(model.parameters()).device
# Same grid as before: -1 + i * (2 / res) for i in [0, res).
x = torch.arange(res, dtype=torch.float32) * (2 / res) - 1

fig, ax = plt.subplots(figsize=(16, 8))
with torch.no_grad():
    for it in range(iters + 1):
        T = (1 / iters) * it
        t = torch.full_like(x, T)
        pred = model(t.to(model_device), x.to(model_device))
        ax.plot(x.numpy(), pred.cpu().numpy(), markersize=3)
ax.set(xlabel="x", ylabel="u")
plt.show()