# 

In [1]:
import torch
import numpy as np
import os
import gpt as g

import sys
sys.path.append("src/")
from qcd_ml.dirac import dirac_wilson_clover

SharedMemoryMpi:  World communicator of size 1
SharedMemoryMpi:  Node  communicator of size 1
SharedMemoryMpi: SharedMemoryAllocate 1073741824 MMAP anonymous implementation 

__|__|__|__|__|__|__|__|__|__|__|__|__|__|__
__|__|__|__|__|__|__|__|__|__|__|__|__|__|__
__|_ |  |  |  |  |  |  |  |  |  |  |  | _|__
__|_                                    _|__
__|_   GGGG    RRRR    III    DDDD      _|__
__|_  G        R   R    I     D   D     _|__
__|_  G        R   R    I     D    D    _|__
__|_  G  GG    RRRR     I     D    D    _|__
__|_  G   G    R  R     I     D   D     _|__
__|_   GGGG    R   R   III    DDDD      _|__
__|_                                    _|__
__|__|__|__|__|__|__|__|__|__|__|__|__|__|__
__|__|__|__|__|__|__|__|__|__|__|__|__|__|__
  |  |  |  |  |  |  |  |  |  |  |  |  |  |  


Copyright (C) 2015 Peter Boyle, Azusa Yamaguchi, Guido Cossu, Antonin Portelli and other authors

This program is free software; you can redistribute it and/or modify
it under the terms of the 

In [15]:
def lattice2ndarray(lattice):
    """ 
    Converts a gpt (https://github.com/lehner/gpt) lattice to a numpy ndarray 
    keeping the ordering of axes as one would expect.
    Example::
        q_top = g.qcd.gauge.topological_charge_5LI(U_smeared, field=True)
        plot_scalar_field(lattice2ndarray(q_top))
    """
    shape = lattice.grid.fdimensions
    shape = list(reversed(shape))
    if lattice[:].shape[1:] != (1,):
        shape.extend(lattice[:].shape[1:])
   
    result = lattice[:].reshape(shape)
    result = np.swapaxes(result, 0, 3)
    result = np.swapaxes(result, 1, 2)
    return result

def ndarray2lattice(ndarray, grid, lat_constructor):
    """
    Converts an ndarray to a gpt lattice, it is the inverse 
    of lattice2ndarray.

    Example::
        lat = ndarray2lattice(arr, g.grid([4,4,4,8], g.double), g.vspincolor)
    """
    lat = lat_constructor(grid)
    data = np.swapaxes(ndarray, 0, 3)
    data = np.swapaxes(data, 1, 2)
    lat[:] = data.reshape([data.shape[0] * data.shape[1] * data.shape[2] * data.shape[3]] + list(data.shape[4:]))
    return lat

In [16]:
grid = g.grid([4, 4, 4, 8], g.double)

rng = g.random("foo")
psi = g.vspincolor(grid)

rng.cnormal(psi)

ndarray1 = lattice2ndarray(psi)
ndarray2 = lattice2ndarray(ndarray2lattice(ndarray1, grid, g.vspincolor))

assert np.allclose(ndarray1, ndarray2)

GPT :     659.905667 s : Initializing gpt.random(foo,vectorized_ranlux24_389_64) took 0.00016284 s


In [19]:
def GMRES_torch(A, b, x0, maxiter=1000, eps=1e-4
              , regulate_b_norm=1e-4
              , innerproduct=None
              , prec=None):
    """
    GMRES solver.
    
    innerproduct is a function (vec,vec)->scalar which is a product.
    prec is a function vec->vec.

    Literature:
    - https://en.wikipedia.org/wiki/Generalized_minimal_residual_method
    - https://www-users.cse.umn.edu/~saad/Calais/PREC.pdf

    Authors:
    - Daniel Knüttel 2024
    """
    if hasattr(A, "__call__"):
        apply_A = lambda x: A(x)
    else:
        apply_A = lambda x: A @ x

    if innerproduct is None:
        innerproduct = lambda x,y: (x.conj() * y).sum()

    
    rk = b - apply_A(x0)

    b_norm = np.sqrt(innerproduct(b, b).real) + regulate_b_norm
    
    rk_norm = np.sqrt(innerproduct(rk, rk).real)
    res = rk_norm / b_norm
    if rk_norm / b_norm <= eps:
        return x0, {"converged": True, "k": 0}

    vk = rk / rk_norm

    v = [None, vk]
    
    cs = np.zeros(maxiter + 2, np.complex128)
    sn = np.zeros(maxiter + 2, np.complex128)
    gamma = np.zeros(maxiter + 2, np.complex128)
    gamma[1] = rk_norm
    H = [None]
    
    converged = False
    for k in range(1, maxiter + 1):
        if prec is not None:
            z = prec(v[k])
        else:
            z = v[k]
        qk = apply_A(z)
        
        Hk = np.zeros(k + 2, np.complex128)
        for i in range(1, k + 1):
            Hk[i] = innerproduct(v[i], qk)
        for i in range(1, k + 1):
            qk -= Hk[i] * v[i]
            
        Hk[k+1] = np.sqrt(innerproduct(qk, qk).real)
        v.append(qk / Hk[k+1])

        for i in range(1, k):
            # (c   s ) [a]   [a']
            # (-s* c*) [b] = [b']
            tmp = cs[i+1] * Hk[i] + sn[i+1] * Hk[i+1]
            Hk[i+1] = -np.conj(sn[i+1]) * Hk[i] + np.conj(cs[i+1]) * Hk[i+1]
            Hk[i] = tmp
            

        beta = np.sqrt(np.abs(Hk[k])**2 + np.abs(Hk[k + 1])**2)

        # ( c    s )[a]   [X]
        # (-s*   c*)[b] = [0]
        # is solved by 
        # s* = b; c* = a
        sn[k+1] = np.conj(Hk[k+1]) / beta
        cs[k+1] = np.conj(Hk[k]) / beta
        Hk[k] = cs[k+1] * Hk[k] + sn[k+1] * Hk[k+1]
        Hk[k+1] = 0
        
        
        gamma[k+1] = -np.conj(sn[k+1]) * gamma[k]
        gamma[k] = cs[k+1] * gamma[k]
        
        H.append(Hk)
        res = np.abs(gamma[k+1]) / b_norm

        if np.abs(gamma[k+1]) / b_norm <= eps:
            converged = True
            break

    y = np.zeros(k+1, np.complex128)
    for i in reversed(range(1, k + 1)):
        overlap = 0
        for j in range(i+1, k+1):
            overlap += H[j][i] * y[j]
        y[i] = (gamma[i] - overlap) / H[i][i]
    if prec is None:
        x = x0 + sum(yi * vi for yi, vi in zip(y[1:], v[1:]))
    else:
        x = x0 + sum(yi * prec(vi) for yi, vi in zip(y[1:], v[1:]))
    return x, {"converged": converged, "k": k, "res": res}

In [20]:
innerproduct = lambda x,y: (x.conj() * y).sum()
norm = lambda x: torch.sqrt(innerproduct(x, x).real)

def orthonormalize(vecs):
    basis = []
    for vec in vecs:
        for b in basis:
            vec = vec - innerproduct(b, vec) * b
        vec = vec / norm(vec)
        basis.append(vec)
    return basis

In [21]:
psi = torch.complex(
        torch.randn(8, 8, 8, 16, 4, 3, dtype=torch.double)
        , torch.randn(8, 8, 8, 16, 4, 3, dtype=torch.double))

n_basis = 4

bv = [torch.randn_like(psi) for _ in range(n_basis)]
#bv = orthonormalize(bv)
bv = [bi / norm(bi) for bi in bv]

In [22]:
for b1 in bv:
    for b2 in bv:
        print(f"{innerproduct(b1, b2): .1e}  ", end="")
    print()

 1.0e+00+0.0e+00j   3.0e-03+4.8e-04j   5.4e-04+1.8e-03j  -7.2e-05+2.5e-03j  
 3.0e-03-4.8e-04j   1.0e+00+0.0e+00j   2.2e-04-2.1e-03j   8.3e-04+2.1e-04j  
 5.4e-04-1.8e-03j   2.2e-04+2.1e-03j   1.0e+00+0.0e+00j   3.5e-03-2.3e-03j  
-7.2e-05-2.5e-03j   8.3e-04-2.1e-04j   3.5e-03+2.3e-03j   1.0e+00+0.0e+00j  


In [26]:

U = torch.tensor(np.load(os.path.join("test", "assets","1500.config.npy")))
U_gpt = g.load("/home/knd35666/data/ensembles/ens_001/1500.config/")
w_gpt = g.qcd.fermion.wilson_clover(U_gpt, {"mass": -0.5,
    "csw_r": 0.0,
    "csw_t": 0.0,
    "xi_0": 1.0,
    "nu": 1.0,
    "isAnisotropic": False,
    "boundary_phases": [1,1,1,1]})

w_torch = dirac_wilson_clover(U, -0.5, 1.0)
w = lambda x: torch.tensor(lattice2ndarray(w_gpt(ndarray2lattice(x.numpy(), U_gpt[0].grid, g.vspincolor))))


GPT :     871.696638 s : Reading /home/knd35666/data/ensembles/ens_001/1500.config/
GPT :     871.708816 s : Switching view to [1,1,1,1]/Read
GPT :     872.028659 s : Read 0.00109863 GB at 0.00343484 GB/s (0.00395677 GB/s for distribution, 0.0260415 GB/s for reading + checksum, 0.160435 GB/s for checksum, 1 views per node)
GPT :     872.063653 s : Read 0.00109863 GB at 0.0337133 GB/s (0.0803193 GB/s for distribution, 0.0581121 GB/s for reading + checksum, 0.533642 GB/s for checksum, 1 views per node)
GPT :     872.091018 s : Read 0.00109863 GB at 0.0458284 GB/s (0.548767 GB/s for distribution, 0.0500114 GB/s for reading + checksum, 0.563946 GB/s for checksum, 1 views per node)
GPT :     872.114229 s : Read 0.00109863 GB at 0.0628933 GB/s (2.92943 GB/s for distribution, 0.0642848 GB/s for reading + checksum, 0.284603 GB/s for checksum, 1 views per node)
GPT :     872.114652 s : Completed reading /home/knd35666/data/ensembles/ens_001/1500.config/ in 0.419604 s


In [51]:
import time

start = time.perf_counter_ns()
w(psi)
stop = time.perf_counter_ns()
print((stop - start) / 1000 / 1000)

13.064981


In [27]:
def GMRES_restarted(A, b, x0, max_restart=10, maxiter_inner=100, eps=1e-4
              , regulate_b_norm=1e-3
              , innerproduct=None
              , prec=None):
    x = x0
    total_iterations = 0
    for rs in range(max_restart):
        x, ret = GMRES_torch(A, b, x, maxiter=maxiter_inner, eps=1e-4
              , regulate_b_norm=regulate_b_norm
              , innerproduct=None
              , prec=None)
        total_iterations += ret["k"]
        print("restarting with res:", ret["res"])
        if ret["converged"]:
            break
    ret["k"] = total_iterations
    return x, ret

In [28]:
zero = torch.zeros_like(psi)

ui = []
for i, b in enumerate(bv):
    uk, ret = GMRES_restarted(w, zero, b, eps=1e-3, maxiter_inner=20, max_restart=4)
    print(f"[{i:2d}]: {ret['converged']} ({ret['k']:5d})")
    ui.append(uk)


restarting with res: tensor(4.3550, dtype=torch.float64)
restarting with res: tensor(0.1862, dtype=torch.float64)
restarting with res: tensor(0.0117, dtype=torch.float64)
restarting with res: tensor(0.0008, dtype=torch.float64)
[ 0]: False (   80)
restarting with res: tensor(4.4334, dtype=torch.float64)
restarting with res: tensor(0.1767, dtype=torch.float64)
restarting with res: tensor(0.0102, dtype=torch.float64)
restarting with res: tensor(0.0007, dtype=torch.float64)
[ 1]: False (   80)
restarting with res: tensor(4.3920, dtype=torch.float64)
restarting with res: tensor(0.1767, dtype=torch.float64)
restarting with res: tensor(0.0103, dtype=torch.float64)
restarting with res: tensor(0.0007, dtype=torch.float64)
[ 2]: False (   80)
restarting with res: tensor(4.4562, dtype=torch.float64)
restarting with res: tensor(0.1900, dtype=torch.float64)
restarting with res: tensor(0.0117, dtype=torch.float64)
restarting with res: tensor(0.0007, dtype=torch.float64)
[ 3]: False (   80)


In [29]:
for b1 in ui:
    for b2 in ui:
        print(f"{abs(innerproduct(b1, b2)): .1e}  ", end="")
    print()

 3.3e-12   9.4e-13   7.8e-13   1.4e-12  
 9.4e-13   2.3e-12   1.1e-12   8.2e-13  
 7.8e-13   1.1e-12   2.6e-12   8.1e-13  
 1.4e-12   8.2e-13   8.1e-13   2.8e-12  


In [30]:
block_size = [4, 4, 4, 4]

In [31]:
uk.shape

torch.Size([8, 8, 8, 16, 4, 3])

In [32]:
L_fine = uk.shape[:4]
L_fine

torch.Size([8, 8, 8, 16])

In [33]:
L_coarse = [lf // bs for lf, bs in zip(L_fine, block_size)]

In [34]:
L_coarse

[2, 2, 2, 4]

In [35]:
import itertools
from collections import deque

# 

In [39]:
lx, ly, lz, lt = block_size

ui_blocked = list(np.empty(L_coarse, dtype=object))

for bx, by, bz, bt in itertools.product(*(range(li) for li in L_coarse)):
    for uk in ui:
        u_block = uk[bx * lx: (bx + 1)*lx
                    , by * ly: (by + 1)*ly
                    , bz * lz: (bz + 1)*lz
                    , bt * lt: (bt + 1)*lt]
        if ui_blocked[bx][by][bz][bt] is None:
            ui_blocked[bx][by][bz][bt]  = []
        ui_blocked[bx][by][bz][bt].append(u_block)
    
    ui_blocked[bx][by][bz][bt] = orthonormalize(ui_blocked[bx][by][bz][bt])

In [102]:
len(ui_blocked[0][0][0][0])


4

In [36]:
def v_project(block_size, ui_blocked, n_basis, L_coarse, v):
    projected = torch.complex(torch.zeros(L_coarse + [n_basis], dtype=torch.double)
                              , torch.zeros(L_coarse + [n_basis], dtype=torch.double))
    lx, ly, lz, lt = block_size
    
    for bx, by, bz, bt in itertools.product(*(range(li) for li in L_coarse)):
        for k, uk in enumerate(ui_blocked[bx][by][bz][bt]):
            projected[bx, by, bz, bt, k] = innerproduct(v[bx * lx: (bx + 1)*lx
                                                        , by * ly: (by + 1)*ly
                                                        , bz * lz: (bz + 1)*lz
                                                        , bt * lt: (bt + 1)*lt], uk)
    return projected

In [37]:
def v_prolong(block_size, ui_blocked, n_basis, L_coarse, v):
    L_fine = [bi*li for bi,li in zip(block_size, L_coarse)]
    prolonged = torch.complex(torch.zeros(L_fine + list(ui_blocked[0][0][0][0][0].shape[4:]), dtype=torch.double)
                              , torch.zeros(L_fine + list(ui_blocked[0][0][0][0][0].shape[4:]), dtype=torch.double))
    for bx, by, bz, bt in itertools.product(*(range(li) for li in L_coarse)):
        for k, uk in enumerate(ui_blocked[bx][by][bz][bt]):
            prolonged[bx * lx: (bx + 1)*lx
                    , by * ly: (by + 1)*ly
                    , bz * lz: (bz + 1)*lz
                    , bt * lt: (bt + 1)*lt] += v[bx,by,bz,bt,k] * uk
    return prolonged

In [40]:
v_prolong(block_size, ui_blocked, n_basis, L_coarse, v_project(block_size, ui_blocked, n_basis, L_coarse, psi))

tensor([[[[[[-2.1033e-02-1.2182e-02j,  1.4261e-02+8.2737e-03j,
             -3.3875e-02+1.0338e-02j],
            [-7.7219e-02-2.0970e-02j,  1.4176e-02+1.4268e-03j,
              2.1541e-02+3.9030e-02j],
            [-2.4076e-02-3.4080e-02j,  2.1603e-02-8.5465e-03j,
             -1.8592e-02-5.3017e-02j],
            [ 8.3675e-03+5.9521e-02j,  4.6360e-03+1.7195e-03j,
              1.6911e-03-3.2527e-02j]],

           [[-1.9882e-02-2.3489e-03j, -4.0199e-02+2.8447e-02j,
              4.8063e-03+3.4052e-02j],
            [ 4.4102e-03+8.2228e-02j, -4.4410e-03+2.6748e-02j,
              2.5241e-02+1.4563e-02j],
            [-2.4003e-02+1.8222e-02j, -1.2591e-02-2.4705e-02j,
             -5.8666e-02+1.5418e-02j],
            [ 4.4672e-02-3.0265e-02j, -4.3323e-03-5.8281e-03j,
              1.2353e-02-2.0462e-02j]],

           [[ 1.8284e-02-3.2619e-02j,  5.1414e-02+1.9638e-03j,
             -1.0616e-02-2.1163e-03j],
            [-5.1323e-03-3.4069e-02j, -6.5964e-03-4.1034e-02j,
             -4

In [41]:
for bx, by, bz, bt in itertools.product(*(range(li) for li in L_coarse)):
    for k, uk in enumerate(ui_blocked[bx][by][bz][bt]):
        projected = torch.complex(torch.zeros(L_coarse + [n_basis], dtype=torch.double)
                          , torch.zeros(L_coarse + [n_basis], dtype=torch.double))
        projected[bx,by,bz,bt,k] = 1
        print(bx, by, bz, bt, k)
        assert torch.allclose(projected
                              , v_project(block_size, ui_blocked, n_basis, L_coarse, v_prolong(block_size, ui_blocked, n_basis, L_coarse, projected))
                             )

0 0 0 0 0
0 0 0 0 1
0 0 0 0 2
0 0 0 0 3
0 0 0 1 0
0 0 0 1 1
0 0 0 1 2
0 0 0 1 3
0 0 0 2 0
0 0 0 2 1
0 0 0 2 2
0 0 0 2 3
0 0 0 3 0
0 0 0 3 1
0 0 0 3 2
0 0 0 3 3
0 0 1 0 0
0 0 1 0 1
0 0 1 0 2
0 0 1 0 3
0 0 1 1 0
0 0 1 1 1
0 0 1 1 2
0 0 1 1 3
0 0 1 2 0
0 0 1 2 1
0 0 1 2 2
0 0 1 2 3
0 0 1 3 0
0 0 1 3 1
0 0 1 3 2
0 0 1 3 3
0 1 0 0 0
0 1 0 0 1
0 1 0 0 2
0 1 0 0 3
0 1 0 1 0
0 1 0 1 1
0 1 0 1 2
0 1 0 1 3
0 1 0 2 0
0 1 0 2 1
0 1 0 2 2
0 1 0 2 3
0 1 0 3 0
0 1 0 3 1
0 1 0 3 2
0 1 0 3 3
0 1 1 0 0
0 1 1 0 1
0 1 1 0 2
0 1 1 0 3
0 1 1 1 0
0 1 1 1 1
0 1 1 1 2
0 1 1 1 3
0 1 1 2 0
0 1 1 2 1
0 1 1 2 2
0 1 1 2 3
0 1 1 3 0
0 1 1 3 1
0 1 1 3 2
0 1 1 3 3
1 0 0 0 0
1 0 0 0 1
1 0 0 0 2
1 0 0 0 3
1 0 0 1 0
1 0 0 1 1
1 0 0 1 2
1 0 0 1 3
1 0 0 2 0
1 0 0 2 1
1 0 0 2 2
1 0 0 2 3
1 0 0 3 0
1 0 0 3 1
1 0 0 3 2
1 0 0 3 3
1 0 1 0 0
1 0 1 0 1
1 0 1 0 2
1 0 1 0 3
1 0 1 1 0
1 0 1 1 1
1 0 1 1 2
1 0 1 1 3
1 0 1 2 0
1 0 1 2 1
1 0 1 2 2
1 0 1 2 3
1 0 1 3 0
1 0 1 3 1
1 0 1 3 2
1 0 1 3 3
1 1 0 0 0
1 1 0 0 1
1 1 0 0 2
1 1 0 0 3


In [43]:
def get_coarse_operator(block_size, ui_blocked, n_basis, L_coarse, fine_operator):
    def operator(source_coarse):
        source_fine = v_prolong(block_size, ui_blocked, n_basis, L_coarse, source_coarse)
        dst_fine = fine_operator(source_fine)
        return v_project(block_size, ui_blocked, n_basis, L_coarse, dst_fine)
    return operator

In [44]:
w_coarse = get_coarse_operator(block_size, ui_blocked, n_basis, L_coarse, w)

In [45]:
def prec_mm(v):
    v_coarse = v_project(block_size, ui_blocked, n_basis, L_coarse, v)
    sol_coarse, r = GMRES_torch(w_coarse, v_coarse, v_coarse, eps=1e-3, maxiter=50)
    print("INNER: ", r)
    sol_fine = v_prolong(block_size, ui_blocked, n_basis, L_coarse, sol_coarse)
    return sol_fine
    

In [46]:
x, ret = GMRES_torch(w, psi, psi, eps=1e-4, maxiter=1000)

In [47]:
ret

{'converged': True, 'k': 47, 'res': tensor(9.2229e-05, dtype=torch.float64)}