In [41]:
import os

# Wipe out system CUDA paths
os.environ["LD_LIBRARY_PATH"] = ""

# Optional: point to jaxlib's internal CUDA (if known)
# os.environ["XLA_FLAGS"] = "--xla_gpu_cuda_data_dir=/path/to/cuda"

# Force CPU (already working)
os.environ["JAX_PLATFORM_NAME"] = "cpu"


In [35]:
!pip uninstall jax jaxlib -y



Found existing installation: jax 0.4.30
Uninstalling jax-0.4.30:
  Successfully uninstalled jax-0.4.30
Found existing installation: jaxlib 0.4.30
Uninstalling jaxlib-0.4.30:
  Successfully uninstalled jaxlib-0.4.30


In [36]:
!pip install --upgrade pip
!pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html


Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable
Looking in links: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Collecting jax[cuda12_pip]
  Using cached jax-0.4.30-py3-none-any.whl.metadata (22 kB)
Collecting jaxlib<=0.4.30,>=0.4.27 (from jax[cuda12_pip])
  Using cached jaxlib-0.4.30-cp39-cp39-manylinux2014_x86_64.whl.metadata (1.0 kB)
Using cached jax-0.4.30-py3-none-any.whl (2.0 MB)
Using cached jaxlib-0.4.30-cp39-cp39-manylinux2014_x86_64.whl (79.6 MB)
Installing collected packages: jaxlib, jax
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2/2[0m [jax][32m1/2[0m [jax]
[1A[2KSuccessfully installed jax-0.4.30 jaxlib-0.4.30


In [3]:
import jax
print("JAX version:", jax.__version__)
print("JAX devices:", jax.devices())


JAX version: 0.6.2
JAX devices: [CudaDevice(id=0)]


In [3]:
import sys
print(sys.version)

3.10.16 | packaged by conda-forge | (main, Dec  5 2024, 14:16:10) [GCC 13.3.0]


In [4]:
!pip install nbimporter --quiet


In [4]:
#import nbimporter

from .synthetic_data_generator import create_synthetic_matrix, create_synthetic_data
from GNP import GNP
from ResGCN import ResGCN, scale_A_by_spectral_radius
from GMRES import GMRES
from scipy.sparse import csc_matrix, identity
import torch
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import math

ImportError: attempted relative import with no known parent package

In [5]:
Lx = 1000 # Size of domain in km
dxm = 2 # Mesh resolution in km
n2d = np.arange(0, Lx + 1, dxm, dtype="float32").shape[0]**2 # Number of mesh nodes

ss, ii, jj, tri, xcoord, ycoord = create_synthetic_matrix(Lx, dxm, False)
tt = create_synthetic_data(Lx, dxm)

E0627 07:56:02.263596 3292448 cuda_dnn.cc:523] Loaded runtime CuDNN library: 9.5.1 but source was compiled with: 9.8.0.  CuDNN library needs to have matching major version and equal or higher minor version. If using a binary install, upgrade your CuDNN library.  If building from sources, make sure the library loaded at runtime is compatible with the version specified during compile configuration.
E0627 07:56:02.349500 3292448 cuda_dnn.cc:523] Loaded runtime CuDNN library: 9.5.1 but source was compiled with: 9.8.0.  CuDNN library needs to have matching major version and equal or higher minor version. If using a binary install, upgrade your CuDNN library.  If building from sources, make sure the library loaded at runtime is compatible with the version specified during compile configuration.


XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

In [None]:
plt.rcParams['figure.figsize'] = [12, 6]
triang = mpl.tri.Triangulation(xcoord, ycoord, tri)

plt.tripcolor(triang, tt)
plt.colorbar()

In [None]:
num_layers = 8              # number of layers in GNP
embed = 32                  # embedding dimension in GNP
hidden = 64                 # hidden dimension in MLPs in GNP
drop_rate = 0.05            # dropout rate in GNP
disable_scale_input = False # whether disable the scaling of inputs in GNP
dtype = torch.float64       # training precision for GNP
lr = 2e-3                   # learning rate in training GNP
weight_decay = 0.0          # weight decay in training GNP
training_data = 'x_mix'     # type of training data x
m = 80                      # Krylov subspace dimension for training data
batch_size = 4              # batch size in training GNP
grad_accu_steps = 1         # gradient accumulation steps in training GNP
epochs = 1000               # number of epochs in training GNP 

In [None]:
kc = 2 * math.pi / np.logspace(1, 3, 10) # 10 values from 10 to 1000 km


In [None]:
tt.shape


In [None]:
n = 1 # Filter order
device = torch.device("cuda")
solver = GMRES()
As = []

for k in kc: # Looping over the scales
    Smat1 = csc_matrix((ss * (1.0 / np.square(k)), (ii, jj)), shape=(n2d, n2d))
    Smat = identity(n2d) + 2.0 * (Smat1 ** n)
    A = torch.sparse_csc_tensor(Smat.indptr, Smat.indices, Smat.data, Smat.shape, dtype=torch.float64).to(device)
    A = scale_A_by_spectral_radius(A)
    As += [A]
    
data = torch.tensor(np.array(tt), device=device, dtype=torch.float64)

In [None]:
net = ResGCN(As[-1], num_layers, embed, hidden, drop_rate, scale_input=True, dtype=dtype).to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = None
M = GNP(A, training_data, m, net, device)

In [None]:
iterations = np.zeros(len(As))

for i in range(len(As)):
    ttw = data - As[i] @ data  # Work with perturbations
    x, iters, _, _, _ = solver.solve(A=As[i], b=ttw, rtol=1e-6, max_iters=20000, progress_bar=False)
    print("Iteration " + str(iters))
    iterations[i] = iters

In [None]:
plt.plot(np.logspace(1, 3, 10), iterations)
plt.xlabel("Length [km]")
plt.ylabel("Iterations")
plt.xscale("log")
plt.yscale("log")

In [None]:
import time 

tic = time.time()
hist_loss, best_loss, best_epoch, model_file = M.train(
    batch_size, grad_accu_steps, epochs, optimizer, scheduler, num_workers=4,
    checkpoint_prefix_with_path="./tmp_", progress_bar=False)
M.net.load_state_dict(torch.load(f"./tmp_epoch_{best_epoch}.pt", map_location=device, weights_only=True))

print(f'Done. Training time: {time.time()-tic} seconds')
print(f'Loss: inital = {hist_loss[0]}, final = {hist_loss[-1]}, best = {best_loss}, epoch = {best_epoch}')

In [None]:
iterations_pre = np.zeros(len(As))

for i in range(len(As)):
    ttw = data - As[i] @ data  # Work with perturbations
    x, iters, _, _, _ = solver.solve(A=As[i], b=ttw, rtol=1e-6, max_iters=20000, progress_bar=False, M=M)
    print("Iteration " + str(iters))
    iterations_pre[i] = iters

In [None]:
plt.plot(np.logspace(1, 3, 10), iterations, label="No preconditioner")
plt.plot(np.logspace(1, 3, 10), iterations_pre, label="One phase of training")

plt.xlabel("Length [km]")
plt.ylabel("Iterations")
plt.xscale("log")
plt.yscale("log")
plt.legend()
    

In [None]:
! rm ./*.pt


In [None]:
net = ResGCN(As[0], num_layers, embed, hidden, drop_rate, scale_input=True, dtype=dtype).to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = None
M = GNP(A, training_data, m, net, device)

for A in As: 
    M.A = A
    hist_loss, best_loss, best_epoch, model_file = M.train(
    batch_size, grad_accu_steps, epochs // 10, optimizer, scheduler, num_workers=4, 
    checkpoint_prefix_with_path="./tmp_", progress_bar=False)
    M.net.load_state_dict(torch.load(f"./tmp_epoch_{best_epoch}.pt", map_location=device, weights_only=True))
    print(f'Loss: inital = {hist_loss[0]}, final = {hist_loss[-1]}, best = {best_loss}, epoch = {best_epoch}')
    

In [None]:
iterations_post = np.zeros(len(As))

for i in range(len(As)):
    ttw = data - As[i] @ data  # Work with perturbations
    x, iters, _, _, _ = solver.solve(A=As[i], b=ttw, rtol=1e-6, max_iters=20000, progress_bar=False, M=M)
    print("Iteration " + str(iters))
    iterations_post[i] = iters

In [None]:
plt.plot(np.logspace(1, 3, 10), iterations, label="No preconditioner")
plt.plot(np.logspace(1, 3, 10), iterations_pre, label="One phase of training")
plt.plot(np.logspace(1, 3, 10), iterations_post, label="10 phases of training")

plt.xlabel("Length [km]")
plt.ylabel("Iterations")
plt.xscale("log")
plt.yscale("log")
plt.legend()