## Train Coarse-Grid Operator model and create an MGRIT solver

***Date:** August 19th, 2024*

In [1]:
import os
import pdb
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim.lr_scheduler import ConstantLR, ExponentialLR
from torch.utils.data import Dataset, DataLoader, random_split

import scipy
from scipy import sparse as sp
from scipy.sparse.linalg import spsolve
from scipy.sparse import identity

from pymgrit.core.application import Application
from pymgrit.heat.heat_1d import VectorHeat1D
from pymgrit.core.mgrit import Mgrit

from loss_functions import NonGalerkinLoss1Rand, NonGalerkinLoss1Eig, NonGalerkinLoss2Eig, NonGalerkinLoss2EigConst, NonGalerkinLoss3Eig, NonGalerkinLoss3EigConst
from constants import (
    device,
    input_size,
    output_size,
    hidden_size,
    batch_size,
    epochs,
    nstencils,
    eps,
    plot_landscape,
    learning_rate,
)

# Uncomment to run everything in double precision instead of the default dtype.
# torch.set_default_dtype(torch.float64)
# NOTE(review): presumably a workaround for duplicate OpenMP runtimes being
# loaded by torch/numpy on macOS — confirm it is still needed.
os.environ['KMP_DUPLICATE_LIB_OK']='True'

# For reproducibility: fixes the seed of torch's global random number generator
torch.manual_seed(0)

[Dagrazs-MacBook-Air.local:52121] shmem: mmap: an error occurred while determining whether or not /var/folders/fd/8p6y3p9d67l9ms3hzwcf97vm0000gn/T//ompi.Dagrazs-MacBook-Air.501/jf.0/3518955520/sm_segment.Dagrazs-MacBook-Air.501.d1bf0000.0 could be created.


<torch._C.Generator at 0x10efd9150>

### Neural Network

In [None]:
class Net(nn.Module):
    """
    Fully connected network with three hidden layers and LeakyReLU activations.

    Layer widths come from the imported constants: ``input_size`` inputs,
    three hidden layers of ``hidden_size`` units and ``output_size`` outputs.
    The final layer is linear (no activation).
    """

    def __init__(self):
        super().__init__()

        # Hidden part: (Linear -> LeakyReLU) x 3, created in forward order
        widths = [input_size, hidden_size, hidden_size, hidden_size]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.LeakyReLU())

        # Output layer without activation
        layers.append(nn.Linear(hidden_size, output_size))

        # The attribute name is part of the state_dict keys of saved models
        self.linear_relu_stack = nn.Sequential(*layers)

    def forward(self, array):
        """
        Apply the network.

        :param array: input tensor whose last dimension has ``input_size`` entries
        :return: network output with ``output_size`` entries in the last dimension
        """
        return self.linear_relu_stack(array)

# Instantiate the network to be trained and move its parameters to the configured device
net = Net().to(device)

### Set up Data

In [None]:
class NonGalerkinDataset(Dataset):
    """
    Dataset that serves the entries of ``array`` one at a time.

    There are no labels: each item is just ``array[index]``, optionally passed
    through ``transform``.
    """

    def __init__(self, array, transform=None, target_transform=None):
        """
        :param array: indexable container of samples (supports ``len`` and ``[]``)
        :param transform: optional callable applied to every returned sample
        :param target_transform: stored for interface compatibility; not applied
            anywhere in this class
        """
        self.stencil = array
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.stencil)

    def __getitem__(self, index):
        """Return sample ``index``, transformed if a transform was given."""
        sample = self.stencil[index]
        return self.transform(sample) if self.transform else sample

In [None]:
# Build the training set: a single fine-grid stencil [beta, 1 - 2*beta, beta]
# repeated `nstencils` times, each row paired with a different coarsening
# factor m = 1, ..., nstencils stored in a fourth column.
# Augment this to use multiple stencils (different beta's)
dx = 1 / 16
dt = 1 / 4096
beta = dt / dx ** 2
stencil_dataset = torch.Tensor([beta, 1 - 2 * beta, beta]).repeat(nstencils, 1)
m = torch.arange(1, nstencils + 1).unsqueeze(1)
# Rows are now [beta, 1 - 2*beta, beta, m]
stencil_dataset = torch.cat((stencil_dataset, m), 1)

# # Set up for using different stencils
# nstencils = 4
# betas = torch.tensor([1/8, 1/12, 1/16, 1/24])
# stencil_datasets = torch.torch.empty((0, 4), dtype=torch.float32)
# for beta in betas:
#     stencil_dataset = torch.Tensor([beta, 1 - 2 * beta, beta]).repeat(nstencils, 1)
#     m = torch.arange(1, nstencils + 1).unsqueeze(1)
#     stencil_dataset = torch.cat((stencil_dataset, m), 1)
#     stencil_datasets = torch.cat((stencil_datasets, stencil_dataset), 0)

stencil_dataset = stencil_dataset.to(device)

# All samples go to the training split (Ntest = 0), so dataset_test is empty
dataset = NonGalerkinDataset(stencil_dataset)
Ndata = len(dataset)
Ntest = 0
Ntrain = Ndata - Ntest
dataset_train, dataset_test = random_split(dataset, [Ntrain, Ntest])
train_loader = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)

# Print the split sizes and the full stencil dataset
print(f"Train set size {Ntrain}, Test set size {Ntest}, batch size {batch_size}")
print(stencil_dataset)

In [None]:
def train_loop(dataloader, model, loss_fn, optimizer, eps=1e-2):
    """
    Train ``model`` for one pass over ``dataloader`` and return the batch losses.

    :param dataloader: iterable of input batches; must support ``len()`` and
        expose the underlying data as ``dataloader.dataset``
    :param model: network that maps an input batch to an output batch
    :param loss_fn: callable ``loss_fn(input_batch, output_batch)`` returning a
        scalar tensor
    :param optimizer: optimizer holding the parameters of ``model``
    :param eps: currently unused; kept so existing callers keep working
    :return: list with one float loss per processed batch
    """
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    # Report on the first and the last batch (every batch if there are <= 2)
    report_every = max(1, num_batches - 1)

    # Set the model to training mode - important for batch normalization and dropout layers
    # Unnecessary in this situation but added for best practices
    model.train()
    losses = []
    # Running sample count, so progress does not depend on a global batch size
    # and stays correct when the last batch is smaller than the others
    seen = 0

    for batch, X in enumerate(dataloader):
        # Compute prediction and loss
        output_batch = model(X)

        # Modify inputs to loss function depending on loss function used
        loss = loss_fn(X, output_batch)

        # Back propagation; gradients are cleared after the step so the next
        # batch starts from zero
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        loss_value = loss.item()
        losses.append(loss_value)
        seen += len(X)

        if batch % report_every == 0:
            print(f"Train Loss: {loss_value:>7e}  [{seen:>5d}/{size:>5d}]")

    return losses

### Select loss function

In [None]:
# Loss used for training and by the evaluation cells below; swap in one of the
# other imported NonGalerkinLoss* classes here to try a different loss.
loss_fn = NonGalerkinLoss1Rand()

In [None]:
optimizer = optim.SGD(net.parameters(), lr=learning_rate)
# scheduler = optim.lr_scheduler.ConstantLR(optimizer, learn_rate_drop_factor)

# Loss history: one row per epoch, one column per batch. The column count is
# taken from the loader itself so it is also right when batch_size does not
# divide the number of training samples (the last batch is then smaller).
loss_vis = np.zeros((epochs, len(train_loader)))

for epoch in range(epochs):
    print(f"Epoch {epoch+1}\n-------------------------------")
    loss = train_loop(train_loader, net, loss_fn, optimizer)
    # Record the per-batch losses of this epoch for the plot further down
    loss_vis[epoch] = np.array(loss)
    
    # # Modify eps throughout training for Loss 1
    # if ((epoch + 1) % eps_drop_period == 0):
    #     eps = eps / 2

    # # Modify learning rate throughout training
    # if ((epoch + 1) % learn_rate_drop_period == 0):
    #     scheduler.step()
    #     print("Update learning rate")


In [None]:
# Flatten the (epochs x batches) loss history into one row in training order.
# The length comes from the array itself, so it always matches what was recorded.
num_of_coord = loss_vis.size
loss_vis = np.reshape(loss_vis, (1, num_of_coord))
x = np.linspace(0, num_of_coord, num_of_coord)
plt.plot(x, loss_vis[0])
plt.xlabel('training step (batch)')
plt.ylabel('loss')
plt.show()

### Save & Load the Trained NN

In [None]:
# Save only the trained parameters (state_dict) to the file "psi-model" in the
# working directory; the Net class is needed again to load them.
torch.save(net.state_dict(), "psi-model")

In [None]:
# Rebuild the architecture and restore the saved parameters.
# map_location makes the load work even if the checkpoint was written on a
# different device than the one configured now.
model = Net().to(device)
model.load_state_dict(torch.load("psi-model", map_location=device))
# Switch to inference mode (last expression, so the notebook also displays the model)
model.eval()

In [None]:
# Compare the stencil predicted by the network (psi) with the stencil obtained
# by replacing beta with m*beta (disc), for two (dt, m) combinations.
# For each case the loss function is evaluated on both stencils.
dx = 1 / 16
cases = (
    ("This is the phi we trained on", 1 / 4096, 4),
    ("This is not the phi we trained on", 1 / 2560, 2),
)

for case_index, (banner, dt, m) in enumerate(cases):
    # Blank line between consecutive reports
    if case_index > 0:
        print()
    print(banner)

    beta = dt / dx ** 2
    # Network input: fine stencil followed by the coarsening factor
    phi_m = torch.tensor([beta, 1 - 2 * beta, beta, m]).to(device).reshape((1,4))
    psi = model(phi_m).reshape((1,3))
    disc = torch.tensor([m*beta, 1 - 2 * m*beta, m*beta]).to(device).reshape((1,3))

    loss_psi = loss_fn(phi_m, psi, eps)
    loss_disc = loss_fn(phi_m, disc, eps)

    print(f"phi = {phi_m}")
    print(f"psi = {psi.detach()}")
    print(f"output of loss function on psi: {loss_psi}")
    print(f"dsc = {disc}")
    print(f"output of loss function on disc: {loss_disc}")

### Get Optimal Stencil

In [None]:
# Search directly for the symmetric stencil [a, b, a] that minimises the loss
# for one fixed (phi, m) pair, as a reference for the network output.
# Imported here because the file only does `import scipy`, which does not
# guarantee that the `scipy.optimize` submodule is loaded.
from scipy.optimize import minimize

dt = 1 / 4096
dx = 1/ 16
beta = dt / dx ** 2
m = 4
phi_m = torch.tensor([beta, 1 - 2 * beta, beta, m]).to(device)
print(phi_m)

def loss(x):
    """
    Objective for the optimiser.

    :param x: sequence ``[a, b]`` describing the symmetric stencil ``[a, b, a]``
    :return: loss value as a Python float
    """
    # Build the candidate with the same dtype and device as phi_m so the loss
    # function never sees mixed dtypes or devices (the optimiser hands over
    # float64 NumPy values living on the CPU).
    candidate = torch.tensor(
        [x[0], x[1], x[0]], dtype=phi_m.dtype, device=phi_m.device
    ).reshape((1, 3))
    value = loss_fn(phi_m.reshape((1, 4)), candidate, eps)
    return value.item()

# Derivative-free search starting from [a, b] = [1, 1]
opt = minimize(loss, [1, 1], method='Nelder-Mead')
optimal_stencil = torch.tensor([opt.x[0], opt.x[1], opt.x[0]])
print(optimal_stencil)

### Print Loss Landscape

In [None]:
# Plot Loss Landscape
# Evaluates the loss on an n x n grid of symmetric stencils [a, b, a] with
# a = i/n and b = j/n, and shows it as a log-scaled image.
if plot_landscape:
    # Local import: LogNorm is only needed for this optional plot
    from matplotlib.colors import LogNorm

    n = 200
    dt = 1/4096
    dx = 1/16
    beta = dt / dx ** 2
    m = 4
    phi_m = torch.tensor([beta, 1 - 2 *  beta, beta, m]).to(device).reshape([1, 4])

    # psis has shape (3, n, n) with psis[:, i, j] = [i/n, j/n, i/n].
    # np.fromfunction yields float64; cast to phi_m's dtype before moving to
    # the device so the loss function sees matching dtypes.
    psis = torch.from_numpy(np.fromfunction(
        lambda i, j: np.array([i / n, j / n, i / n]),
        [n, n]
    )).to(phi_m.dtype).to(device)

    # Plain NumPy array: filled with Python floats and handed to matplotlib
    losses = np.zeros((n, n))
    for i in range(n):
        if i % 10 == 0:
            print(i)  # progress indicator
        for j in range(n):
            losses[i, j] = loss_fn(phi_m, psis[:, i, j].reshape([1, 3]), eps).item()

    # Rows (y axis) correspond to a, columns (x axis) to b
    plt.imshow(losses, aspect = 1, origin = 'lower', norm=LogNorm(vmin=0.001, vmax=1))
    plt.colorbar()
    # Mark the stencil [m*beta, 1 - 2*m*beta, m*beta]; for the values above
    # this is the grid point (x, y) = (100, 50)
    plt.scatter([n * (1 - 2 * m * beta)], [n * m * beta], c = 'red', marker = 'x')
    # range() needs an integer step, hence the floor division
    tick_positions = range(0, n + 1, n // 5)
    tick_labels = [0, 0.2, 0.4, 0.6, 0.8, 1]
    plt.xticks(tick_positions, tick_labels)
    plt.yticks(tick_positions, tick_labels)
    plt.title('Loss 1 with rand and const, eps=1e-2, for matrix [a, b, a]')
    plt.xlabel('b')
    plt.ylabel('a')
    plt.show()

## PyMGRIT Integration

In [None]:
"""
Vector and application class for the 1D heat equation

This extends the Python code
"""
##########

class Heat1DExp(Application):
    """
    Application class for the 1D heat equation
        u_t - a*u_xx = b(x,t),  a > 0, x in [x_start,x_end], t in [0,T],
    discretized with periodic boundary conditions in space and explicit
    (forward Euler) time stepping.
    """

    def __init__(self, x_start, x_end, nx, a, init_cond=lambda x: x * 0, rhs=lambda x, t: x * 0, *args, **kwargs):
        """
        Constructor.

        :param x_start: left interval bound of spatial domain
        :param x_end: right interval bound of spatial domain
        :param nx: number of equispaced grid points on [x_start, x_end] including
            both end points; the last one is dropped, leaving nx - 1 unknowns
        :param a: thermal conductivity
        :param init_cond: callable giving the initial condition on the grid
        :param rhs: callable ``rhs(x, t)`` giving the right-hand side
        :param args: forwarded to the ``Application`` base class
        :param kwargs: forwarded to the ``Application`` base class
        """
        super().__init__(*args, **kwargs)

        # Periodic spatial grid: the right end point is dropped
        self.x_start = x_start
        self.x_end = x_end
        self.x = np.linspace(self.x_start, self.x_end, nx)[0:-1]
        self.nx = nx - 1
        self.dx = self.x[1] - self.x[0]

        # Thermal conductivity
        self.a = a

        # Identity and spatial discretization matrix used by the time stepper
        self.identity = identity(self.nx, dtype='float', format='csr')
        self.space_disc = self.compute_matrix()

        # Right-hand side routine
        self.rhs = rhs

        # Data structure for any user-defined time point
        self.vector_template = VectorHeat1D(self.nx)

        # Initial condition, sampled on the spatial grid
        self.init_cond = init_cond
        self.vector_t_start = VectorHeat1D(self.nx)
        self.vector_t_start.set_values(self.init_cond(self.x))

    def compute_matrix(self):
        """
        Assemble the spatial discretization matrix L of the 1D heat equation.

        Second-order central finite differences with matrix stencil
           (a / dx^2) * [-1  2  -1]
        The entries in the two corners of the matrix close the stencil
        periodically.

        :return: sparse (nx x nx) matrix in CSR format
        """
        scale = self.a / self.dx ** 2

        main_diag = np.full(self.nx, 2 * scale)
        off_diag = np.full(self.nx - 1, -scale)

        # Offsets: main, sub, super, then the two wrap-around corner entries
        return sp.diags(
            diagonals=[main_diag, off_diag, off_diag, off_diag, off_diag],
            offsets=[0, -1, 1, self.nx - 1, -self.nx + 1],
            shape=(self.nx, self.nx),
            format='csr',
        )

    def step(self, u_start: VectorHeat1D, t_start: float, t_stop: float) -> VectorHeat1D:
        """
        Advance the solution from t_start to t_stop with one forward Euler step:
           u_i = (I - dt*L) * u_{i-1} + dt*b_{i-1},
        where L = self.space_disc and the right-hand side is evaluated at t_start.

        :param u_start: approximate solution for the input time t_start
        :param t_start: time associated with the input approximate solution u_start
        :param t_stop: time to evolve the input approximate solution to
        :return: approximate solution at input time t_stop
        """
        dt = t_stop - t_start
        propagator = self.identity - dt * self.space_disc
        u_new = propagator * u_start.get_values() + dt * self.rhs(self.x, t_start)

        result = VectorHeat1D(len(u_new))
        result.set_values(u_new)
        return result


##########

class Heat1DNN(Application):
    """
    Application class for the heat equation in 1D space,
        u_t - a*u_xx = b(x,t),  a > 0, x in [x_start,x_end], t in [0,T],
    with periodic boundary conditions in space and offline-trained neural network model.

    The time stepper applies a three-point stencil predicted by the global
    ``model`` instead of a discretization of the PDE on its own time grid.
    """

    def __init__(self, x_start, x_end, nx, a, m, init_cond=lambda x: x * 0, rhs=lambda x, t: x * 0, *args, **kwargs):
        """
        Constructor.

        :param x_start: left interval bound of spatial domain
        :param x_end: right interval bound of spatial domain
        :param nx: number of equispaced grid points on [x_start, x_end] including
            both end points; the last one is dropped, leaving nx - 1 unknowns
        :param a: thermal conductivity
        :param m: coarsening factor; one step of this class spans m fine-grid
            steps, and m is passed to the network as its fourth input
        :param init_cond: initial condition
        :param rhs: right-hand side
        :param args: forwarded to the ``Application`` base class
        :param kwargs: forwarded to the ``Application`` base class
        """

        super().__init__(*args, **kwargs)
        # Periodic spatial grid: the right end point is dropped
        self.x_start = x_start
        self.x_end = x_end
        self.x = np.linspace(self.x_start, self.x_end, nx)
        self.x = self.x[0:-1]
        self.nx = nx - 1
        self.dx = self.x[1] - self.x[0]
        self.m = m

        # Thermal conductivity
        self.a = a

        # Set right-hand side routine
        self.rhs = rhs

        # Set the data structure for any user-defined time point
        self.vector_template = VectorHeat1D(self.nx)

        # Set initial condition
        self.init_cond = init_cond
        self.vector_t_start = VectorHeat1D(self.nx)
        self.vector_t_start.set_values(self.init_cond(self.x))

        # Assembled stencil operators keyed by fine-grid dt. Every step with
        # the same dt uses the same operator, so the network is evaluated once
        # per distinct dt instead of once per step. The cache belongs to this
        # instance; create a new instance after changing the global `model`.
        self._psi_cache = {}

    def _stencil_operator(self, dt):
        """
        Return the sparse periodic operator built from the network stencil.

        :param dt: fine-grid time-step size
        :return: sparse (nx x nx) CSR matrix
        """
        psi = self._psi_cache.get(dt)
        if psi is not None:
            return psi

        beta = dt * self.a / self.dx ** 2

        # use NN to get psi; no gradients are needed for inference
        phi_m = torch.Tensor([beta, 1 - 2 * beta, beta, self.m]).to(device)
        with torch.no_grad():
            psi_model = model(phi_m)

        # this assumes a 3-point stencil (for now)
        diagonal = np.ones(self.nx)  * psi_model[1].item()
        lower = np.ones(self.nx - 1) * psi_model[0].item()
        upper = np.ones(self.nx - 1) * psi_model[2].item()

        # The last two offsets are the wrap-around corner entries (periodic BCs)
        psi = sp.diags(
            diagonals=[diagonal, lower, upper, lower, upper],
            offsets=[0, -1, 1, self.nx-1, -self.nx+1], shape=(self.nx, self.nx),
            format='csr')

        self._psi_cache[dt] = psi
        return psi

    def step(self, u_start: VectorHeat1D, t_start: float, t_stop: float) -> VectorHeat1D:
        """
        Time integration routine for 1D heat equation example problem:
            Neural Network Model

        One-step method
           u_i = psi * u_{i-1} + dt*b_{i-1},
        where psi is the operator built from the NN stencil and dt is the
        fine-grid step (t_stop - t_start) / m.

        :param u_start: approximate solution for the input time t_start
        :param t_start: time associated with the input approximate solution u_start
        :param t_stop: time to evolve the input approximate solution to
        :return: approximate solution at input time t_stop
        """

        # compute dt for the finest mesh
        dt = (t_stop - t_start) / self.m
        psi = self._stencil_operator(dt)

        tmp = u_start.get_values()
        # NOTE(review): the forcing term is scaled by the fine-grid dt, not by
        # the full step t_stop - t_start — confirm this is intended.
        tmp = psi * tmp + dt * self.rhs(self.x, t_start)
        ret = VectorHeat1D(len(tmp))
        ret.set_values(tmp)
        return ret
    

In [None]:
"""
Apply three-level MGRIT F-cycles with FCF-relaxation to solve the 1D heat equation
    u_t - a*u_xx = b(x,t),  a > 0, x in [0,1], t in [0,T],
with RHS b(x,t) chosen such that the exact solution given below holds,
periodic BCs in space,
    u(0,t)   = u(1,t),    t in [0,T],
    u_x(0,t) = u_x(1,t),  t in [0,T],
and subject to the initial condition
    u(x,0)  = sin^2(pi*x),  x in [0,1]

=> exact solution u(x,t) = sin^2(pi*x)*cos(t)
"""

def main(plotting=False):
    """
    Set up a three-level MGRIT hierarchy for the 1D heat equation and solve it.

    The finest level uses explicit forward Euler (``Heat1DExp``); the two
    coarser levels use the neural-network stencil (``Heat1DNN``).

    :param plotting: if True, draw the fine-level space-time solution with
        ``plt.imshow`` (the caller is responsible for ``plt.show()``)
    :return: the value returned by ``Mgrit.solve()``
    """
    def rhs(x, t):
        """
        Right-hand side of 1D heat equation example problem at a given space-time point (x,t),
        :param x: spatial grid point
        :param t: time point
        :return: right-hand side of 1D heat equation example problem at point (x,t)
        """

        return -(np.sin(t))*(np.sin(np.pi*x)**2) - 2*(np.pi**2)*(np.cos(t))*((np.cos(np.pi*x)**2) - (np.sin(np.pi*x)**2))

    def init_cond(x):
        """
        Initial condition of 1D heat equation example,
        :param x: spatial grid point
        :return: initial condition of 1D heat equation example problem
        """
        return np.sin(np.pi * x) ** 2

    # Domain: space = [0,1], time = [0,1]
    # nx = 17    =>  dx = 1/16
    # nt = 2561  =>  dt = 1/2560 on the finest level
    #            =>  beta = dt/dx^2 = 256/2560 = 1/10
    # Each coarser level halves the number of time intervals
    # (2561 -> 1281 -> 641 points), i.e. m = 2 and m = 4 relative to the
    # finest level.
    heat0 = Heat1DExp(x_start=0, x_end=1, nx=17, a=1, init_cond=init_cond, rhs=rhs, t_start=0, t_stop=1, nt=2561)
    heat1 = Heat1DNN(x_start=0, x_end=1, nx=17, a=1, m=2, init_cond=init_cond, rhs=rhs, t_start=0, t_stop=1, nt=1281)
    heat2 = Heat1DNN(x_start=0, x_end=1, nx=17, a=1, m=4, init_cond=init_cond, rhs=rhs, t_start=0, t_stop=1, nt=641)

    # Setup three-level MGRIT solver (F-cycles) and solve the problem
    problem = [heat0, heat1, heat2]
    mgrit = Mgrit(problem=problem, cf_iter=1, cycle_type='F', nested_iteration=False, max_iter=10,
                  logging_lvl=20, random_init_guess=True)
    info = mgrit.solve()

    if plotting:
        # Stack the fine-level solution vectors row by row: one row per time point
        fine_solution = mgrit.u[0]
        nt = len(fine_solution)
        nx = fine_solution[0].size
        sol = np.vstack([np.array(u.get_values()) for u in fine_solution])
        plt.imshow(sol, aspect=nx/nt, origin='lower')
        plt.colorbar()

    # Hand the solver result back instead of discarding it
    return info


if __name__ == '__main__':
    # Run the solver with the space-time plot enabled, then display the figure
    main(plotting=True)
    plt.show()
