# Summary

Milestone 3 of the reinforcement learning project. This notebook builds a GMRES solver that (1) runs on the GPU and (2) works with batched data.

## TODO

1. Finish testing.
2. Set up the dataset from the other notebook.
3. Integrate the GPU GMRES with the loss to see how much a single step will cost.
4. Improve loss from Milestone 2.
    1. Vectorize it.
    2. Additional terms for sparsity?
5. Begin to pretrain actor.
    1. First experiment with models.
        1. Minimize eye_dist.
        2. Try different layers.
        3. Try edge prediction. 
        4. REALLY TRY TO GET SPARSITY. (Log-sum penalty?)
    2. Store anything that works in a Python file.
    3. Build out a Python file that collects the models that have worked.
    4. Make sure to save weights and successful plots to /Weights and /Plots.

# Dataset

Same Helmholtz dataset as in Milestone 2.

In [156]:
import numpy as np
import scipy.sparse as sp
import torch
from torch_geometric.data import Dataset, Data
from torch_geometric.loader import DataLoader
from torch_geometric.utils import dense_to_sparse

In [157]:
def generate_random_helmholtz(n, density=0.001):
    L = 1.0
    k = np.random.uniform(10, 200)  # Random wavenumber
    h = L / (n - 1)  

    # Discretization of Helmholtz operator (1D)
    diagonals = [np.ones(n-1), -2*np.ones(n), np.ones(n-1)]

    helmholtz = sp.diags(diagonals, [-1, 0, 1]) / h**2 + k**2 * sp.eye(n)

    # Ensure no perturbations on the tridiagonal
    perturb = sp.random(n, n, density=density) * np.max(helmholtz)
    
    perturb.setdiag(0)  # Main diagonal
    perturb.setdiag(0, k=1)  # First upper diagonal
    perturb.setdiag(0, k=-1)  # First lower diagonal
        
    return helmholtz + perturb
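
For reference, the matrix assembled by `generate_random_helmholtz` can be restated as the formula below. This is only a reading of the code above; the symbols $H$, $P$ and $m$ are introduced here for illustration.

$$
A = H + P, \qquad H = \frac{1}{h^2}\,\mathrm{tridiag}(1,\,-2,\,1) + k^2 I, \qquad h = \frac{L}{n-1}, \quad k \sim \mathcal{U}(10,\,200)
$$

$$
P_{ij} = \begin{cases} 0 & \text{if } |i-j| \le 1 \\ m\,u_{ij} \text{ on a random sparse pattern with the given density} & \text{otherwise} \end{cases}, \qquad m = \max_{ij} H_{ij}, \quad u_{ij} \sim \mathcal{U}[0,\,1)
$$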

In [158]:
def test_singularity(mat):
    det_mat = np.linalg.det(mat)

    if np.isclose(det_mat, 0):
        print("Matrix is singular.")
        return False
        
    rank_mat = np.linalg.matrix_rank(mat)
    if rank_mat < mat.shape[0]:
        print("Matrix is singular.")
        return False
    
    print("Matrix is non-singular.")
        
    return True
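
A determinant-based check can mislead for large matrices, because the determinant itself may underflow or overflow in floating point even when the matrix is well conditioned. As a hedged alternative, the sketch below thresholds the condition number instead. The helper name `is_numerically_singular` and the cutoff `max_cond=1e12` are assumptions for illustration, not part of the original notebook.

```python
import numpy as np

def is_numerically_singular(mat, max_cond=1e12):
    """Hypothetical helper: flag a dense matrix whose 2-norm condition number is too large."""
    cond = np.linalg.cond(mat)
    return bool((not np.isfinite(cond)) or cond > max_cond)

# A perfectly conditioned matrix whose determinant is 0.1**500
scaled_identity = 0.1 * np.eye(500)
print(np.linalg.det(scaled_identity))            # underflows to 0.0
print(is_numerically_singular(scaled_identity))  # False
```

The threshold is a judgment call and would likely need to be much smaller when the solver itself runs in float32.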

In [163]:
def normalize_features(graph_data):
    x_min = graph_data.min()
    x_max = graph_data.max()
    out = (graph_data - x_min) / (x_max - x_min + 1e-8)
    return out.float()

def normalize_edge_attr(graph_data):
    e_min = graph_data.min()
    e_max = graph_data.max()
    out = (graph_data - e_min) / (e_max - e_min + 1e-8)
    return out.float()

# TODO

In [171]:
class HelmHoltzDataset(Dataset):
    def __init__(self, generator, checker, norm_feat, norm_edge, epoch_len=1000, size=500, density=0.001):
        # No root directory or transform: samples are generated on the fly in get()
        super().__init__(root=None, transform=None)
        
        self.epoch_len = epoch_len
        self.mat_size = size
        self.mat_density = density
        
        self.generator = generator
        self.checker = checker
        
        self.norm_features = norm_feat
        self.norm_edge_attr = norm_edge

    def len(self):
        return self.epoch_len

    def get(self, idx):
        mat = self.generator(self.mat_size, self.mat_density)
        
        # TODO: the matrix is already sparse; avoid the round-trip through a dense tensor
        mat = torch.tensor(mat.toarray(), dtype=torch.float32)
        edge_index, edge_weights = dense_to_sparse(mat)
        
        node_features = self.norm_features(mat)
        edge_weights = self.norm_edge_attr(edge_weights)

        data = Data(x=node_features, edge_index=edge_index, edge_attr=edge_weights)
        return data
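
One way to address the TODO in `get` is to read the edges straight from the SciPy matrix in COO form instead of densifying first. The sketch below is an assumption about how that could look; `sparse_to_edges` is a hypothetical helper, and the edge ordering it produces may differ from what `dense_to_sparse` returns.

```python
import numpy as np
import scipy.sparse as sp
import torch

def sparse_to_edges(mat):
    """Hypothetical helper: SciPy sparse matrix -> (edge_index, edge_weight) tensors."""
    coo = mat.tocoo(copy=True)
    coo.sum_duplicates()    # merge repeated (row, col) entries
    coo.eliminate_zeros()   # drop explicitly stored zeros, e.g. after setdiag(0)
    edge_index = torch.from_numpy(np.vstack([coo.row, coo.col])).long()
    edge_weight = torch.from_numpy(coo.data).float()
    return edge_index, edge_weight
```

Note that the node features in `get` are currently the dense matrix itself, so a dense copy would still be needed unless the features change too.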

In [172]:
dataset = HelmHoltzDataset(generate_random_helmholtz, test_singularity, normalize_features, normalize_edge_attr)

In [173]:
dataloader = DataLoader(dataset, batch_size=64)

In [174]:
next(iter(dataloader))

DataBatch(x=[32000, 500], edge_index=[2, 143445], edge_attr=[143445], batch=[32000], ptr=[65])

# GPU-GMRES

Solves linear systems with GMRES on the GPU. Batched solves may be possible, but I gave up on getting them to work.

GMRES implementation authored by https://github.com/devzhk/Pytorch-linalg
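
As background, here is a minimal dense GMRES sketch in NumPy. It is not the PytorchLinalg implementation: it has no restarts, it solves the small least-squares problem with `np.linalg.lstsq` rather than Givens rotations, and the function name `gmres_dense` is made up for illustration.

```python
import numpy as np

def gmres_dense(A, b, tol=1e-8, max_iter=None):
    """Unrestarted GMRES with x0 = 0. Returns (x, number of iterations used)."""
    n = b.shape[0]
    max_iter = n if max_iter is None else min(max_iter, n)
    beta = np.linalg.norm(b)
    if beta == 0.0:
        return np.zeros(n), 0
    Q = np.zeros((n, max_iter + 1))          # orthonormal Krylov basis
    H = np.zeros((max_iter + 1, max_iter))   # upper Hessenberg matrix
    Q[:, 0] = b / beta
    x = np.zeros(n)
    for j in range(max_iter):
        # Arnoldi step with modified Gram-Schmidt
        w = A @ Q[:, j]
        for i in range(j + 1):
            H[i, j] = Q[:, i] @ w
            w = w - H[i, j] * Q[:, i]
        H[j + 1, j] = np.linalg.norm(w)
        # Minimize ||beta * e1 - H y|| over the current Krylov subspace
        e1 = np.zeros(j + 2)
        e1[0] = beta
        y, *_ = np.linalg.lstsq(H[: j + 2, : j + 1], e1, rcond=None)
        x = Q[:, : j + 1] @ y
        rel_res = np.linalg.norm(A @ x - b) / beta
        # Stop on convergence or on breakdown of the Arnoldi process
        if rel_res < tol or H[j + 1, j] < 1e-14:
            return x, j + 1
        Q[:, j + 1] = w / H[j + 1, j]
    return x, max_iter
```

Recomputing the true residual every iteration costs an extra matrix-vector product. Practical implementations track it through the rotations instead, which is where the Givens-related errors later in this notebook come from.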

In [52]:
from PytorchLinalg.linalg import GMRES
import torch
import scipy.sparse.linalg as spla

In [53]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [54]:
def solve(A, b):   
    A, b = A.to(device), b.to(device)
    return GMRES(A, b, track=True)

In [55]:
def manual_test(A, b, x):
    return torch.linalg.norm(A @ x - b)

In [56]:
def solve_scipy(A, b):
    A, b = A.cpu().numpy(), b.cpu().numpy()
    # info == 0 means converged; info > 0 means the tolerance was not reached
    x, info = spla.gmres(A, b, maxiter=1000, restart=20)
    print(info)
    return x

## Testing

Make sure this actually works.

### Random (Garbage)

Should fail. There is no guarantee a random matrix is well conditioned.

In [57]:
def random_usage(n=500):    
    A = torch.randn(n, n, device=device)
    
    b = torch.randn(n, device=device)
    
    # Solve the systems
    sol, (conv_iter, errs) = solve(A, b)
    print(conv_iter)
    print(errs[-1] if len(errs) > 0 else []) # rel_tol
    
    return A, b, sol, conv_iter, errs

In [58]:
A, b, sol, conv_iter, errs = random_usage()

499
1.6573783796047792e-05


In [59]:
f'GMRES: {manual_test(A, b, sol)}, ZEROS: {manual_test(A, b, torch.zeros(500, device=device))}'

'GMRES: 0.000749350932892412, ZEROS: 22.068756103515625'

In [60]:
x = solve_scipy(A, b)

1000


In [61]:
f'GMRES: {manual_test(A, b, torch.from_numpy(x).to(device))}, ZEROS: {manual_test(A, b, torch.zeros(500, device=device))}'

'GMRES: 21.56355857849121, ZEROS: 22.068756103515625'
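
One plausible contributor to the gap between the two solvers is the `restart=20` passed to SciPy: restarted GMRES can stagnate on matrices like this one. The GPU run took 499 iterations on a 500 x 500 system, which suggests (but does not prove) that it was not restarted. The sketch below compares a short restart with a full-length one on a small random system; `compare_restarts` and its defaults are illustrative assumptions.

```python
import numpy as np
import scipy.sparse.linalg as spla

def compare_restarts(n=200, seed=0):
    """Relative residuals of SciPy GMRES with restart=20 versus restart=n."""
    rng = np.random.default_rng(seed)
    A = rng.standard_normal((n, n))
    b = rng.standard_normal(n)
    x_short, _ = spla.gmres(A, b, restart=20, maxiter=50)
    x_full, _ = spla.gmres(A, b, restart=n, maxiter=50)

    def rel_res(x):
        return np.linalg.norm(A @ x - b) / np.linalg.norm(b)

    return rel_res(x_short), rel_res(x_full)

res_short, res_full = compare_restarts()
print(f"restart=20: {res_short:.3e}, restart=n: {res_full:.3e}")
```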

### Random (Identity)

Should converge in a single iteration, since the solution is simply x = b.

In [62]:
def eye_usage(n=500):    
    A = torch.eye(n, device=device)
    
    b = torch.randn(n, device=device)
    
    # Solve the systems
    sol, (conv_iter, errs) = solve(A, b)
    print(conv_iter)
    print(errs[-1] if len(errs) > 0 else [])
    
    return A, b, sol, conv_iter, errs

In [63]:
A, b, sol, conv_iter, errs = eye_usage()

0
0.0


In [64]:
x = solve_scipy(A, b)

0


### Hilbert

Hilbert matrices from Milestone 1. These need a try/except for errors that arise from the Givens rotation.

In [65]:
from scipy.linalg import hilbert

In [121]:
def hilbert_usage(n=500):    
    A = torch.tensor(hilbert(n)).to(device, torch.float)
    
    b = torch.randn(n, device=device)
    
    # Solve the systems
    try:
        sol, (conv_iter, errs) = solve(A, b)
    except ValueError as e:
        print(f'ValueError: {e}, Givens-Rotation-Issue')
        return A, b, torch.zeros(n, device=device), -1, []
    
    print(conv_iter)
    print(errs[-1] if len(errs) > 0 else [])
    
    return A, b, sol, conv_iter, errs

In [122]:
A, b, sol, conv_iter, errs = hilbert_usage()

ValueError: 14-th cosine contains NaN, Givens-Rotation-Issue


In [123]:
f'GMRES: {manual_test(A, b, sol)}, ZEROS: {manual_test(A, b, torch.zeros(500, device=device))}'

'GMRES: 21.546018600463867, ZEROS: 21.546018600463867'

In [124]:
x = solve_scipy(A, b)

1000


In [125]:
f'GMRES: {manual_test(A, b, torch.from_numpy(x).to(device))}, ZEROS: {manual_test(A, b, torch.zeros(500, device=device))}'

'GMRES: 425.2770080566406, ZEROS: 21.546018600463867'
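
Both failures are consistent with how ill-conditioned Hilbert matrices are, although this notebook does not confirm the exact cause of the NaN. The sketch below prints the condition number for a few small sizes next to the reciprocal of float32 machine epsilon, the precision used on the GPU above. The helper name and the chosen sizes are assumptions for illustration.

```python
import numpy as np
from scipy.linalg import hilbert

def hilbert_condition_numbers(sizes=(4, 8, 12)):
    """2-norm condition number of the n x n Hilbert matrix for each n in sizes."""
    return {n: float(np.linalg.cond(hilbert(n))) for n in sizes}

conds = hilbert_condition_numbers()
inv_eps32 = 1.0 / np.finfo(np.float32).eps
for n, c in conds.items():
    print(f"n={n:2d}  cond={c:.2e}  exceeds 1/eps(float32): {c > inv_eps32}")
```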

### Helmholtz 

Helmholtz matrices from Milestone 1.

In [134]:
import numpy as np
import scipy.sparse as sp

In [135]:
def generate_random_helmholtz(n, density=0.001):
    L = 1.0
    k = np.random.uniform(10, 200)  # Random wavenumber
    h = L / (n - 1)  

    # Discretization of Helmholtz operator (1D)
    diagonals = [np.ones(n-1), -2*np.ones(n), np.ones(n-1)]

    helmholtz = sp.diags(diagonals, [-1, 0, 1]) / h**2 + k**2 * sp.eye(n)

    # Ensure no perturbations on the tridiagonal
    perturb = sp.random(n, n, density=density) * np.max(helmholtz)
    
    perturb.setdiag(0)  # Main diagonal
    perturb.setdiag(0, k=1)  # First upper diagonal
    perturb.setdiag(0, k=-1)  # First lower diagonal
        
    return helmholtz + perturb

In [138]:
def helmholtz_usage(n=500):    
    A = torch.from_numpy(generate_random_helmholtz(n).toarray()).to(device, torch.float)
    
    b = torch.randn(n, device=device)
    
    # Solve the systems
    try:
        sol, (conv_iter, errs) = solve(A, b)
    except ValueError as e:
        print(f'ValueError: {e}, Givens-Rotation-Issue')
        return A, b, torch.zeros(n, device=device), -1, []
    
    print(conv_iter)
    print(errs[-1] if len(errs) > 0 else [])
    
    return A, b, sol, conv_iter, errs

In [140]:
A, b, sol, conv_iter, errs = helmholtz_usage()

499
1.7057906006812118e-05


In [142]:
f'GMRES: {manual_test(A, b, sol)}, ZEROS: {manual_test(A, b, torch.zeros(500, device=device))}'

'GMRES: 0.0007967761484906077, ZEROS: 22.100479125976562'

In [141]:
x = solve_scipy(A, b)

1000


In [143]:
f'GMRES: {manual_test(A, b, torch.from_numpy(x).to(device))}, ZEROS: {manual_test(A, b, torch.zeros(500, device=device))}'

'GMRES: 0.058694325387477875, ZEROS: 22.100479125976562'

## Loss Integration

Identity distance from Milestone 2, now in vectorized form. This will be the form going forward.

In [177]:
from torch import nn

In [184]:
class IdentityDistance(nn.Module):
    def __init__(self, l1=1e-1, logsum=None):
        super().__init__()
        self.l1 = l1          # weight of the L1 sparsity term (None disables it)
        self.logsum = logsum  # epsilon of the log-sum penalty, used only when l1 is None
        
    def forward(self, inp, outp):
        batch_size = inp.ptr.shape[0] - 1
        # Batched product of input and predicted matrices (size hard-coded to 500 x 500)
        inner_product = torch.bmm(inp.x.view(batch_size, 500, 500), outp.view(batch_size, 500, 500))

        identity = torch.eye(500, device=inp.x.device).expand(batch_size, -1, -1)
        frobenius_loss = torch.norm(inner_product - identity, p='fro', dim=(1, 2))
        avg_loss = torch.mean(frobenius_loss)

        if self.l1 is not None:
            avg_loss += self.l1 * torch.norm(outp, p=1)
        elif self.logsum is not None:
            # Log-sum sparsity penalty on the prediction
            avg_loss += torch.sum(torch.log(1 + torch.abs(outp) / self.logsum))
        
        return avg_loss

In [185]:
eyeloss = IdentityDistance()
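
To make the two ingredients of this loss easier to test in isolation, here is a sketch on plain `(B, n, n)` tensors that reads the matrix size from the input instead of hard-coding 500. The function names and the default `eps` are assumptions, not part of the Milestone 2 code.

```python
import torch

def identity_distance(A, M):
    """Mean Frobenius distance ||A_i @ M_i - I||_F over a batch of shape (B, n, n)."""
    n = A.shape[-1]
    eye = torch.eye(n, device=A.device, dtype=A.dtype)
    residual = torch.bmm(A, M) - eye   # eye broadcasts over the batch dimension
    return torch.linalg.matrix_norm(residual, ord='fro').mean()

def log_sum_penalty(M, eps=1e-2):
    """Log-sum sparsity penalty: sum of log(1 + |m| / eps) over all entries."""
    return torch.log1p(M.abs() / eps).sum()
```

When `M` holds the exact inverses the distance is zero, and for an all-zero `M` it equals `sqrt(n)`, which gives two quick sanity checks.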

# Models

Potential models for actor network.

In [None]:
def shared_training_loop():
    pass
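
A possible shape for the shared loop is sketched below, assuming every model takes a batch and every criterion takes `(batch, output)`, as the loss above does. The name `training_loop_sketch` and all of its parameters are hypothetical; the stub above is still undecided.

```python
import torch
from torch import nn

def training_loop_sketch(model, loader, criterion, optimizer, epochs=1, device="cpu"):
    """Hypothetical shared loop. Returns the mean loss per epoch."""
    model.to(device).train()
    history = []
    for _ in range(epochs):
        total, count = 0.0, 0
        for batch in loader:
            batch = batch.to(device)
            optimizer.zero_grad()
            loss = criterion(batch, model(batch))  # same argument order as the loss above
            loss.backward()
            optimizer.step()
            total += loss.item()
            count += 1
        history.append(total / max(count, 1))
    return history
```

Keeping the model call as `model(batch)` is what pushes steps like edge prediction into the model class.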

## (Node Feature) Models

Models in which the direct node feature output is used as the estimate.

### TransformerConv

Uses a transformer network to pass messages.

### GCN

The basic GNN convolutional model.

### ChebConv

A higher-order convolutional model. **May be worth connecting K to the spectral properties of my input?**

## (Edge Prediction) Models

Models in which pairs of node features are passed to linear layers to compute an edge weight.

Sharing a training loop means some things like edge prediction will happen within the model class.
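
As a rough illustration of the pairwise idea, the sketch below concatenates the features of each edge's two endpoints and maps them to a scalar weight with a small MLP. The class name, layer sizes and activation are assumptions; in a full model the input `h` would come from a GNN encoder inside the same model class.

```python
import torch
from torch import nn

class PairwiseEdgePredictor(nn.Module):
    """Hypothetical head: one scalar weight per edge from its endpoint features."""

    def __init__(self, in_dim, hidden_dim=32):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(2 * in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, h, edge_index):
        # h: (num_nodes, in_dim); edge_index: (2, num_edges) of node indices
        src, dst = edge_index
        pair = torch.cat([h[src], h[dst]], dim=-1)
        return self.mlp(pair).squeeze(-1)
```

Because the features are concatenated in source-then-target order, the predicted weight for edge (i, j) can differ from that for (j, i).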

## (Misc) Models

Models I just had an idea for. Not sure if I will use this section, but it's here.