
Google Colab TPU support #189

dsmic opened this issue Oct 9, 2023 · 8 comments

dsmic commented Oct 9, 2023

Feature

Desired Behavior / Functionality

Setting torch's default device to the TPU should work, but integration hangs.
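For reference, a minimal sketch of the setup in question (just the standard torch / torch_xla calls, nothing torchquad-specific yet):

import torch
import torch_xla.core.xla_model as xm

dev = xm.xla_device()          # the Colab TPU, shows up as e.g. xla:1
torch.set_default_device(dev)  # new tensors should now be created on the TPU

print(torch.ones(3).device)    # expected: xla:1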

How Can It Be Tested

I have a not entirely minimal example that can be tested in Google Colab. If you run it, it shows that the TPU is set up correctly, but the integration hangs. If you interrupt, you get:

tensor([1., 1., 3., 1., 1.], device='xla:1')
xla:1
(0, 1, 2) 0
(0, 2, 1) 1
(1, 0, 2) 1
(1, 2, 0) 0
(2, 0, 1) 0
(2, 1, 0) 1
[[0 1 2 3 4 5]
 [0 2 1 3 4 5]
 [1 0 2 3 4 5]
 [1 2 0 3 4 5]
 [2 0 1 3 4 5]
 [2 1 0 3 4 5]]
(6, 6)

---------------------------------------------------------------------------

KeyboardInterrupt                         Traceback (most recent call last)

<ipython-input-4-2b4f7889adcb> in <cell line: 142>()
    140 
    141 
--> 142 plotwf(ppp)
    143 
    144 

4 frames

/usr/local/lib/python3.10/dist-packages/torch/utils/_device.py in __torch_function__(self, func, types, args, kwargs)
     60         if func in _device_constructors() and kwargs.get('device') is None:
     61             kwargs['device'] = self.device
---> 62         return func(*args, **kwargs)
     63 
     64 # NB: This is directly called from C++ in torch/csrc/Device.cpp

KeyboardInterrupt:

In Google Colab there are two cells; the first installs TPU support for torch and the required libraries:

!pip install cloud-tpu-client==0.10 torch==2.0.0 torchvision==0.15.1 https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-2.0-cp310-cp310-linux_x86_64.whl
!pip install noisyopt
!pip install torchquad

The second cell runs the program:

# Python program to compute Hessian in PyTorch
# importing libraries
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.func import hessian
from torchquad import VEGAS, set_up_backend, set_precision
from torch import vmap
import time
from noisyopt import minimizeCompass
from itertools import permutations
from sympy.combinatorics.permutations import Permutation

N_Int_Points = 3000

# set_precision(data_type='float64', backend='torch')
import torch_xla.core.xla_model as xm
dev = xm.xla_device()
torch.set_default_device(dev)
set_up_backend("torch", data_type="float64", torch_enable_cuda=False)
# set_log_level("TRACE")


j = None
# j = torch.complex(torch.tensor(0, dtype=torch.float64), torch.tensor(1, dtype=torch.float64))

ppp = torch.tensor(
    [1.0, 1.0, 3.0, 1.0, 1.0]
    )
print(ppp)
print(ppp.device)

dist_nuclei = torch.tensor(ppp[0], dtype=torch.float64)
nNuclei = 3
nElectrons = 3
IntElectron = [[-dist_nuclei * (nNuclei//2) * 1.3, dist_nuclei * (nNuclei//2) * 1.3]]

m_Nuclei = 1836
m_Electron = 1
IntNuclei = [[-1.5, 1.5]]

nParticles = nElectrons + nNuclei
m = torch.tensor([m_Nuclei]*nNuclei + [m_Electron]*nElectrons)

offsets = torch.zeros(nParticles)  # at 0 there must be high spatial probability density for VEGAS integration to work
for i in range(nNuclei):
    offsets[i+nElectrons] = dist_nuclei * (i - nNuclei//2)

def CutRange(x, r):
    return torch.sigmoid(x+r)*(1-torch.sigmoid(x-r))


q = torch.tensor([-1]*nElectrons + [1]*nNuclei)


def V(dx):
    return torch.exp(-dx**2)


def Vpot(xinp):
    """Potential energy"""
    x = xinp + offsets
    x1 = x.reshape(-1, 1)
    x2 = x.reshape(1, - 1)
    dx = x1 - x2
    Vdx = q.reshape(-1, 1) * V(dx) * q.reshape(1, -1)
    Vdx = Vdx.triu(diagonal=1)
    return Vdx.sum()


def Epot(wf, x):
    return (torch.conj(wf(x)) * Vpot(x) * wf(x)).real


def H_single(wf, x):
    if j is None:
        gg = torch.func.grad(lambda x: wf(x).real)(x)
    else:
        gg = torch.complex(torch.func.grad(lambda x: wf(x).real)(x), torch.func.grad(lambda x: wf(x).imag)(x))
    v = 1/(2*m)  # from partial integration the minus sign already present
    gg = torch.sqrt(v) * gg
    return ((torch.dot(torch.conj(gg), gg) + Epot(wf, x)).real)


def H(wf, x):
    gg = vmap(lambda x: H_single(wf, x))(x)
    return gg


def testwf(ppp, x):
    # x = xx[tuple(perms[0]), ]  # allows summation over permutations later
    return torch.exp(-ppp[1]*x[nElectrons:]**2).prod(-1) * CutRange(x[:nElectrons], ppp[2] * nNuclei//2).prod(-1) * (ppp[4] * torch.sin(x[:nElectrons] * torch.pi / ppp[3])).prod(-1)


def Norm(wf, x):
    return (torch.conj(wf(x)) * wf(x)).real


vg = VEGAS()

for i in range(nNuclei):
    offsets[i+nElectrons] = ppp[0] * (i - nNuclei//2)

# create permutations

perms = []
perms_p = []
for i in permutations(list(range(nElectrons))):
    a = Permutation(list(i))
    print(i, a.parity())
    perms.append(list(i) + list(range(nElectrons, nParticles)))
    perms_p.append(a.parity())

perms = np.array(perms)
print(perms)
print(perms.shape)


def plotwf(ppp):
    pl_x = np.linspace(IntElectron[0][0].cpu(), IntElectron[0][1].cpu(), 100)
    pl_y = []

    plot_pos = 0
    for x in pl_x:
        def wf(x):
            return testwf(ppp, x)
        xinp = [0]*plot_pos + [x] + [0]*(nParticles-1-plot_pos)
        xinp = torch.from_numpy(np.array(xinp))
        if plot_pos < nElectrons:
            int_domain = [IntElectron[0]]*plot_pos + [[x, x+0.01]] + [IntElectron[0]]*(nElectrons-1-plot_pos) + [[-0.1, 0.1]]*nNuclei
        else:
            int_domain = [IntElectron[0]]*nElectrons + [[-0.1, 0.1]]*(plot_pos-nElectrons) + [[x, x+0.01]] + [[-0.1, 0.1]]*(nParticles - 1 - plot_pos)
        integral_value = vg.integrate(lambda y: vmap(lambda y: Norm(lambda x: testwf(ppp, x), y))(y), dim=nParticles, N=10000,  integration_domain=int_domain, max_iterations=20)
        pl_y.append(integral_value)
    pl_y = torch.tensor(pl_y).cpu().numpy()

    plt.plot(pl_x + offsets[plot_pos].cpu().numpy(), pl_y)
    plt.show()


plotwf(ppp)


def doIntegration(pinp):
    start = time.time()
    global offsets
    ppp = torch.tensor(pinp)
    for i in range(nNuclei):
        offsets[i+nElectrons] = ppp[0] * (i - nNuclei//2)
    IntElectron = [[-ppp[0] * (nNuclei//2) * 1.3, ppp[0] * (nNuclei//2) * 1.3]]
    IntNuclei = [[-1.7 / ppp[2], 1.7/ppp[2]]]
    Normvalue = integral_value = vg.integrate(lambda y: vmap(lambda y: Norm(lambda x: testwf(ppp, x), y))(y), dim=nParticles, N=N_Int_Points,  integration_domain=IntElectron*nElectrons+IntNuclei*nNuclei, max_iterations=30)
    integral_value = vg.integrate(lambda y: H(lambda x: testwf(ppp, x), y), dim=nParticles, N=N_Int_Points,  integration_domain=IntElectron*nElectrons+IntNuclei*nNuclei, max_iterations=30)
    retH = integral_value/Normvalue
    print("H", integral_value, ppp, retH, time.time() - start)
    return retH.cpu().numpy()


ret = minimizeCompass(doIntegration, x0=ppp.cpu().numpy(), deltainit=0.1, deltatol=0.01, paired=False, bounds=[[0.01, 5]]*ppp.shape[0], errorcontrol=True, funcNinit=10)

print(ret)


gomezzz commented Oct 9, 2023

Hi @dsmic !

Thanks for posting this.

On a practical level, my first suspicion would be this

torch.set_default_device(dev)
set_up_backend("torch", data_type="float64", torch_enable_cuda=False)

set_up_backend probably calls this code, leading to a call of

torch.set_default_tensor_type("float64"), which may not be correct for the TPU? 🤔

If that is not it, just to be sure: are you certain the problem is within torchquad? Could you use a different torch / torch_xla version to check whether you get more verbose feedback there?


dsmic commented Oct 9, 2023

Thanks for the response. I did some digging and it seems it is just awfully slow, taking 20 seconds to prepare the next call to my function:

counter = 0
def Norm(wf, x):
    global counter
    print('deb', x.device)
    res = (torch.conj(wf(x)) * wf(x)).real
    print(counter, 'res', res)
    counter += 1
    return res

My function returns nearly immediately :(

So I am not sure what is so expensive on the TPU...
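(For timing on XLA it may matter that tensors are evaluated lazily, so a per-call measurement only becomes meaningful after an explicit synchronization point. A minimal sketch, assuming torch_xla's xm.mark_step() and a hypothetical sample point x_sample:)

import time
import torch_xla.core.xla_model as xm

x_sample = torch.zeros(nParticles)   # hypothetical test input
start = time.time()
res = Norm(lambda x: testwf(ppp, x), x_sample)
xm.mark_step()                       # force the recorded XLA graph to execute
print('one Norm call took', time.time() - start, 's, result', res)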


gomezzz commented Oct 9, 2023

Can you check which device your tensors are on? I suspect you are using the CPU and not the TPU, because torch.set_default_tensor_type("float64") makes the CPU the default device. I am not quite sure what default tensor type should be used with TPU/XLA. You could try not setting up the backend at all, but I am not sure that works 🤔 Alternatively, try moving your torch.set_default_device(dev) call after the set_up_backend call?
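A sketch of that reordering (assuming nothing else in set_up_backend depends on the device being set first):

import torch
import torch_xla.core.xla_model as xm
from torchquad import set_up_backend

# let torchquad set the default tensor type first ...
set_up_backend("torch", data_type="float64", torch_enable_cuda=False)

# ... then pick the TPU as the default device so it is not overridden
dev = xm.xla_device()
torch.set_default_device(dev)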

If neither works, we might need a dedicated backend type for TPUs. Not sure if we ever tried them before.


dsmic commented Oct 9, 2023

Yes, the tensors are on the device (my debug print shows the device). I increased the log level, and the time seems to be spent within torchquad:

13:57:19|TQ-INFO| Setting Torch's default tensor type to Float64 (CUDA not initialized).
13:57:19|TQ-DEBUG| Setting LogLevel to TRACE
13:57:19|TQ-DEBUG| Checking inputs to Integrator.
13:57:19|TQ-DEBUG| 
 VEGAS integrating a 6-dimensional fn with 10000 points over [[-1.2999999523162842, -1.2899999523162842], [tensor(-1.3000, device='xla:1'), tensor(1.3000, device='xla:1')], [tensor(-1.3000, device='xla:1'), tensor(1.3000, device='xla:1')], [-0.1, 0.1], [-0.1, 0.1], [-0.1, 0.1]]

13:57:19|TQ-DEBUG| Setting up integration domain.
13:57:19|TQ-DEBUG| Starting VEGAS
13:57:19|TQ-DEBUG| Running Map Warmup with warmup_N_it=5, N_samples=80...
13:57:19|TQ-DEBUG| |  Iter  |    N_Eval    |     Result     |      Error     |    Acc        | Total Evals

tensor([1., 1., 3., 1., 1.], device='xla:1')
xla:1
(0, 1, 2) 0
(0, 2, 1) 1
(1, 0, 2) 1
(1, 2, 0) 0
(2, 0, 1) 0
(2, 1, 0) 1
[[0 1 2 3 4 5]
 [0 2 1 3 4 5]
 [1 0 2 3 4 5]
 [1 2 0 3 4 5]
 [2 0 1 3 4 5]
 [2 1 0 3 4 5]]
(6, 6)
deb xla:1
0 res xla:1

13:57:19|TQ-DEBUG| The integrand was not evaluated in 28 of 240 VEGASMap intervals. Filling the weights for some of them with neighbouring values.
13:57:20|TQ-DEBUG|   remaining intervals: 1
13:57:20|TQ-DEBUG|   remaining intervals: 0
13:57:35|TQ-DEBUG| |	0|         80|  4.000101e-05  |  3.852720e-11  |  1.551718e-01%| 80

deb xla:1
1 res xla:1

13:57:35|TQ-DEBUG| The integrand was not evaluated in 22 of 240 VEGASMap intervals. Filling the weights for some of them with neighbouring values.
13:57:35|TQ-DEBUG|   remaining intervals: 0
13:58:12|TQ-DEBUG| |	1|         80|  4.332120e-05  |  3.020238e-11  |  1.268587e-01%| 160

deb xla:1
2 res xla:1

13:58:12|TQ-DEBUG| The integrand was not evaluated in 34 of 240 VEGASMap intervals. Filling the weights for some of them with neighbouring values.
13:58:13|TQ-DEBUG|   remaining intervals: 1
13:58:13|TQ-DEBUG|   remaining intervals: 0
13:59:12|TQ-DEBUG| |	2|         80|  4.083231e-05  |  1.591607e-11  |  9.770439e-02%| 240

deb xla:1
3 res xla:1

13:59:13|TQ-DEBUG| The integrand was not evaluated in 33 of 240 VEGASMap intervals. Filling the weights for some of them with neighbouring values.
13:59:13|TQ-DEBUG|   remaining intervals: 1
13:59:13|TQ-DEBUG|   remaining intervals: 0
14:00:38|TQ-DEBUG| |	3|         80|  4.592329e-05  |  1.429694e-11  |  8.233576e-02%| 320

deb xla:1
4 res xla:1

14:00:38|TQ-DEBUG| The integrand was not evaluated in 38 of 240 VEGASMap intervals. Filling the weights for some of them with neighbouring values.
14:00:39|TQ-DEBUG|   remaining intervals: 2
14:00:39|TQ-DEBUG|   remaining intervals: 0

---------------------------------------------------------------------------

KeyboardInterrupt                         Traceback (most recent call last)


gomezzz commented Oct 10, 2023

Hmmmm, okay that's good.

Then it could be that the problem is specific to VEGAS. I noticed you are using a comparatively small number of evaluation points; that is usually quite inefficient with VEGAS (those evaluations are split across a number of iterations, so in the end you parallelize over only a small number of points at a time). Could you try a different integrator to see if that is better?
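For example, a sketch of swapping VEGAS for torchquad's plain MonteCarlo integrator on the normalization integral (same integrand and domain as in the script above):

from torchquad import MonteCarlo

mc = MonteCarlo()
Normvalue = mc.integrate(
    lambda y: vmap(lambda y: Norm(lambda x: testwf(ppp, x), y))(y),
    dim=nParticles,
    N=N_Int_Points,  # all points are evaluated in one batch, no per-iteration split
    integration_domain=IntElectron * nElectrons + IntNuclei * nNuclei,
)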


dsmic commented Oct 10, 2023

Yeah, the small number of evaluations was just during testing, as I thought it might have fallen back to the CPU. Usually I use much, much bigger numbers....

MonteCarlo is also not convincing. It is also very slow (much slower than the CPU) and even threw some weird exceptions at times....

As I have to pay for the TPU usage, I might not test too much. I am using the NVIDIA V100 card at the moment, which works quite well....

Thanks for your support...


gomezzz commented Oct 10, 2023

Okay, one final thought maybe: I noticed you are using float64; could this be the problem? TPUs are targeted at float16 if I am not mistaken?
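A sketch of what trying lower precision would look like (set_up_backend also accepts "float32"; whether the TPU really wants 16-bit types is a separate question):

from torchquad import set_up_backend

# single precision instead of float64
set_up_backend("torch", data_type="float32", torch_enable_cuda=False)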


dsmic commented Oct 10, 2023

Good tip, but it does not help :(
