-
Notifications
You must be signed in to change notification settings - Fork 2.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Out of memory when calculate big tensor #1923
Comments
It looks like your code would result an array with |
oh,i changed the code . import jax.numpy as np
from jax import device_put
from jax.config import config
from jax import jit
import numpy as onp
from jax import vmap
config.update("jax_enable_x64", True)
size = 800000
Kp = onp.random.sample(size*4).reshape(size,4)
Km = onp.random.sample(size*4).reshape(size,4)
phi = Kp + Km
def metric():
G = onp.eye((4))
G[0, 0] = -1
G[1, 1] = -1
G[2, 2] = -1
return G
@jit
def dot(a, b):
G = metric()
return np.einsum('i,j,ij', a, b, G)
@jit
def Gt_con(a):
G = metric()
return G - np.einsum('i,j', a, a) / dot(a, a)
@jit
def Gt_cov(a):
gt_con = Gt_con(a)
G = metric()
return np.einsum('ij,ik,jl->kl', gt_con, G, G)
@jit
def P4(a):
gt_cov = Gt_cov(a)
gg = np.einsum('ij,kl->ijkl', gt_cov, gt_cov)
gggg = np.einsum('ijkl,mnop->ijklmnop', gg, gg)
# get_gggg
get_gggg = gggg
get_gggg += np.einsum('ijklmnop->minjolpk', gggg)
get_gggg += np.einsum('ijklmnop->minkojpl', gggg)
get_gggg += np.einsum('ijklmnop->minkolpj', gggg)
get_gggg += np.einsum('ijklmnop->minlojpk', gggg)
get_gggg += np.einsum('ijklmnop->minlokpj', gggg)
get_gggg += np.einsum('ijklmnop->mjniokpl', gggg)
get_gggg += np.einsum('ijklmnop->mjniolpk', gggg)
get_gggg += np.einsum('ijklmnop->mjnkoipl', gggg)
get_gggg += np.einsum('ijklmnop->mjnkolpi', gggg)
get_gggg += np.einsum('ijklmnop->mjnloipk', gggg)
get_gggg += np.einsum('ijklmnop->mjnlokpi', gggg)
get_gggg += np.einsum('ijklmnop->mkniojpl', gggg)
get_gggg += np.einsum('ijklmnop->mkniolpj', gggg)
get_gggg += np.einsum('ijklmnop->mknjoipl', gggg)
get_gggg += np.einsum('ijklmnop->mknjolpi', gggg)
get_gggg += np.einsum('ijklmnop->mknloipj', gggg)
get_gggg += np.einsum('ijklmnop->mknlojpi', gggg)
get_gggg += np.einsum('ijklmnop->mlniojpk', gggg)
get_gggg += np.einsum('ijklmnop->mlniokpj', gggg)
get_gggg += np.einsum('ijklmnop->mlnjoipk', gggg)
get_gggg += np.einsum('ijklmnop->mlnjokpi', gggg)
get_gggg += np.einsum('ijklmnop->mlnkoipj', gggg)
get_gggg += np.einsum('ijklmnop->mlnkojpi', gggg)
# get_g_g_g_g
get_g_g_g_g = np.einsum('ijklmnop->ijmnlpko', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ijmnlokp', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ijmolpkn', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ijmolnkp', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ijmplokn', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ijmplnko', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ijnolpkm', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ijnolmkp', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ijnplokm', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ijnplmko', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ijoplnkm', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ijoplmkn', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ikmnlpjo', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ikmnlojp', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ikmolpjn', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ikmolnjp', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ikmplojn', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ikmplnjo', gggg)
get_g_g_g_g += np.einsum('ijklmnop->iknolpjm', gggg)
get_g_g_g_g += np.einsum('ijklmnop->iknolmjp', gggg)
get_g_g_g_g += np.einsum('ijklmnop->iknplojm', gggg)
get_g_g_g_g += np.einsum('ijklmnop->iknplmjo', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ikoplnjm', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ikoplmjn', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ilmnkpjo', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ilmnkojp', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ilmokpjn', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ilmoknjp', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ilmpkojn', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ilmpknjo', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ilnokpjm', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ilnokmjp', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ilnpkojm', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ilnpkmjo', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ilopknjm', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ilopkmjn', gggg)
get_g_g_g_g += np.einsum('ijklmnop->jkmnlpio', gggg)
get_g_g_g_g += np.einsum('ijklmnop->jkmnloip', gggg)
get_g_g_g_g += np.einsum('ijklmnop->jkmolpin', gggg)
get_g_g_g_g += np.einsum('ijklmnop->jkmolnip', gggg)
get_g_g_g_g += np.einsum('ijklmnop->jkmploin', gggg)
get_g_g_g_g += np.einsum('ijklmnop->jkmplnio', gggg)
get_g_g_g_g += np.einsum('ijklmnop->jknolpim', gggg)
get_g_g_g_g += np.einsum('ijklmnop->jknolmip', gggg)
get_g_g_g_g += np.einsum('ijklmnop->jknploim', gggg)
get_g_g_g_g += np.einsum('ijklmnop->jknplmio', gggg)
get_g_g_g_g += np.einsum('ijklmnop->jkoplnim', gggg)
get_g_g_g_g += np.einsum('ijklmnop->jkoplmin', gggg)
get_g_g_g_g += np.einsum('ijklmnop->jlmnkpio', gggg)
get_g_g_g_g += np.einsum('ijklmnop->jlmnkoip', gggg)
get_g_g_g_g += np.einsum('ijklmnop->jlmokpin', gggg)
get_g_g_g_g += np.einsum('ijklmnop->jlmoknip', gggg)
get_g_g_g_g += np.einsum('ijklmnop->jlmpkoin', gggg)
get_g_g_g_g += np.einsum('ijklmnop->jlmpknio', gggg)
get_g_g_g_g += np.einsum('ijklmnop->jlnokpim', gggg)
get_g_g_g_g += np.einsum('ijklmnop->jlnokmip', gggg)
get_g_g_g_g += np.einsum('ijklmnop->jlnpkoim', gggg)
get_g_g_g_g += np.einsum('ijklmnop->jlnpkmio', gggg)
get_g_g_g_g += np.einsum('ijklmnop->jlopknim', gggg)
get_g_g_g_g += np.einsum('ijklmnop->jlopkmin', gggg)
get_g_g_g_g += np.einsum('ijklmnop->klmnjpio', gggg)
get_g_g_g_g += np.einsum('ijklmnop->klmnjoip', gggg)
get_g_g_g_g += np.einsum('ijklmnop->klmojpin', gggg)
get_g_g_g_g += np.einsum('ijklmnop->klmojnip', gggg)
get_g_g_g_g += np.einsum('ijklmnop->klmpjoin', gggg)
get_g_g_g_g += np.einsum('ijklmnop->klmpjnio', gggg)
get_g_g_g_g += np.einsum('ijklmnop->klnojpim', gggg)
get_g_g_g_g += np.einsum('ijklmnop->klnojmip', gggg)
get_g_g_g_g += np.einsum('ijklmnop->klnpjoim', gggg)
get_g_g_g_g += np.einsum('ijklmnop->klnpjmio', gggg)
get_g_g_g_g += np.einsum('ijklmnop->klopjnim', gggg)
get_g_g_g_g += np.einsum('ijklmnop->klopjmin', gggg)
# _gggg
get_gggg_ = np.einsum('ijklmnop->ijklmnop', gggg)
get_gggg_ += np.einsum('ijklmnop->ijklmonp', gggg)
get_gggg_ += np.einsum('ijklmnop->ijklmpno', gggg)
get_gggg_ += np.einsum('ijklmnop->ikjlmnop', gggg)
get_gggg_ += np.einsum('ijklmnop->ikjlmonp', gggg)
get_gggg_ += np.einsum('ijklmnop->ikjlmpno', gggg)
get_gggg_ += np.einsum('ijklmnop->iljkmnop', gggg)
get_gggg_ += np.einsum('ijklmnop->iljkmonp', gggg)
get_gggg_ += np.einsum('ijklmnop->iljkmpno', gggg)
return 1/24 * get_gggg - 1/84 * get_g_g_g_g + 1/105 * get_gggg_
@jit
def T2_cov(a):
gt_cov = Gt_cov(a)
r = a
rt = np.einsum('ij,j->i', gt_cov, r)
return (np.einsum('i,j->ij', rt, rt) -
(1 / 3) * dot(rt, rt) * gt_cov)
@jit
def T1_cov(a):
gt_cov = Gt_cov(a)
r = a
return np.einsum('ij,j->i', gt_cov, r)
@jit
def T4_cov(a):
r = a
p4 = P4(a)
result = np.einsum('ijklmnop,m,n,o,p', p4, r, r, r, r)
return result
@jit
@vmap
def phif243(phi):
t4_cov = T4_cov(phi)
print(t4_cov.shape)
t2_cov = T2_cov(phi)
t1_cov = T1_cov(phi)
return np.einsum('ijkl,j,kl->i',t4_cov,t1_cov,t2_cov)
out = phif243(phi)
print(out)
RuntimeError: Resource exhausted: Out of memory while trying to allocate 131174400000 bytes. the P4 is a intermediate variable |
vmap's memory usage is no different than manually writing vectorised code
with an extra dimension. So you still are allocating giant arrays.
…On Sun, Dec 29, 2019 at 4:02 AM Dongsean ***@***.***> wrote:
oh,i changed the code . it lack a step
import jax.numpy as np
from jax import device_put
from jax.config import config
from jax import jit
import numpy as onp
from jax import vmap
config.update("jax_enable_x64", True)
size = 800000
Kp = onp.random.sample(size*4).reshape(size,4)
Km = onp.random.sample(size*4).reshape(size,4)
phi = Kp + Km
def metric():
G = onp.eye((4))
G[0, 0] = -1
G[1, 1] = -1
G[2, 2] = -1
return G
@jit
def dot(a, b):
G = metric()
return np.einsum('i,j,ij', a, b, G)
@jit
def Gt_con(a):
G = metric()
return G - np.einsum('i,j', a, a) / dot(a, a)
@jit
def Gt_cov(a):
gt_con = Gt_con(a)
G = metric()
return np.einsum('ij,ik,jl->kl', gt_con, G, G)
@jit
def P4(a):
gt_cov = Gt_cov(a)
gg = np.einsum('ij,kl->ijkl', gt_cov, gt_cov)
gggg = np.einsum('ijkl,mnop->ijklmnop', gg, gg)
# get_gggg
get_gggg = gggg
get_gggg += np.einsum('ijklmnop->minjolpk', gggg)
get_gggg += np.einsum('ijklmnop->minkojpl', gggg)
get_gggg += np.einsum('ijklmnop->minkolpj', gggg)
get_gggg += np.einsum('ijklmnop->minlojpk', gggg)
get_gggg += np.einsum('ijklmnop->minlokpj', gggg)
get_gggg += np.einsum('ijklmnop->mjniokpl', gggg)
get_gggg += np.einsum('ijklmnop->mjniolpk', gggg)
get_gggg += np.einsum('ijklmnop->mjnkoipl', gggg)
get_gggg += np.einsum('ijklmnop->mjnkolpi', gggg)
get_gggg += np.einsum('ijklmnop->mjnloipk', gggg)
get_gggg += np.einsum('ijklmnop->mjnlokpi', gggg)
get_gggg += np.einsum('ijklmnop->mkniojpl', gggg)
get_gggg += np.einsum('ijklmnop->mkniolpj', gggg)
get_gggg += np.einsum('ijklmnop->mknjoipl', gggg)
get_gggg += np.einsum('ijklmnop->mknjolpi', gggg)
get_gggg += np.einsum('ijklmnop->mknloipj', gggg)
get_gggg += np.einsum('ijklmnop->mknlojpi', gggg)
get_gggg += np.einsum('ijklmnop->mlniojpk', gggg)
get_gggg += np.einsum('ijklmnop->mlniokpj', gggg)
get_gggg += np.einsum('ijklmnop->mlnjoipk', gggg)
get_gggg += np.einsum('ijklmnop->mlnjokpi', gggg)
get_gggg += np.einsum('ijklmnop->mlnkoipj', gggg)
get_gggg += np.einsum('ijklmnop->mlnkojpi', gggg)
# get_g_g_g_g
get_g_g_g_g = np.einsum('ijklmnop->ijmnlpko', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ijmnlokp', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ijmolpkn', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ijmolnkp', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ijmplokn', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ijmplnko', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ijnolpkm', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ijnolmkp', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ijnplokm', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ijnplmko', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ijoplnkm', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ijoplmkn', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ikmnlpjo', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ikmnlojp', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ikmolpjn', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ikmolnjp', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ikmplojn', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ikmplnjo', gggg)
get_g_g_g_g += np.einsum('ijklmnop->iknolpjm', gggg)
get_g_g_g_g += np.einsum('ijklmnop->iknolmjp', gggg)
get_g_g_g_g += np.einsum('ijklmnop->iknplojm', gggg)
get_g_g_g_g += np.einsum('ijklmnop->iknplmjo', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ikoplnjm', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ikoplmjn', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ilmnkpjo', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ilmnkojp', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ilmokpjn', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ilmoknjp', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ilmpkojn', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ilmpknjo', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ilnokpjm', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ilnokmjp', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ilnpkojm', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ilnpkmjo', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ilopknjm', gggg)
get_g_g_g_g += np.einsum('ijklmnop->ilopkmjn', gggg)
get_g_g_g_g += np.einsum('ijklmnop->jkmnlpio', gggg)
get_g_g_g_g += np.einsum('ijklmnop->jkmnloip', gggg)
get_g_g_g_g += np.einsum('ijklmnop->jkmolpin', gggg)
get_g_g_g_g += np.einsum('ijklmnop->jkmolnip', gggg)
get_g_g_g_g += np.einsum('ijklmnop->jkmploin', gggg)
get_g_g_g_g += np.einsum('ijklmnop->jkmplnio', gggg)
get_g_g_g_g += np.einsum('ijklmnop->jknolpim', gggg)
get_g_g_g_g += np.einsum('ijklmnop->jknolmip', gggg)
get_g_g_g_g += np.einsum('ijklmnop->jknploim', gggg)
get_g_g_g_g += np.einsum('ijklmnop->jknplmio', gggg)
get_g_g_g_g += np.einsum('ijklmnop->jkoplnim', gggg)
get_g_g_g_g += np.einsum('ijklmnop->jkoplmin', gggg)
get_g_g_g_g += np.einsum('ijklmnop->jlmnkpio', gggg)
get_g_g_g_g += np.einsum('ijklmnop->jlmnkoip', gggg)
get_g_g_g_g += np.einsum('ijklmnop->jlmokpin', gggg)
get_g_g_g_g += np.einsum('ijklmnop->jlmoknip', gggg)
get_g_g_g_g += np.einsum('ijklmnop->jlmpkoin', gggg)
get_g_g_g_g += np.einsum('ijklmnop->jlmpknio', gggg)
get_g_g_g_g += np.einsum('ijklmnop->jlnokpim', gggg)
get_g_g_g_g += np.einsum('ijklmnop->jlnokmip', gggg)
get_g_g_g_g += np.einsum('ijklmnop->jlnpkoim', gggg)
get_g_g_g_g += np.einsum('ijklmnop->jlnpkmio', gggg)
get_g_g_g_g += np.einsum('ijklmnop->jlopknim', gggg)
get_g_g_g_g += np.einsum('ijklmnop->jlopkmin', gggg)
get_g_g_g_g += np.einsum('ijklmnop->klmnjpio', gggg)
get_g_g_g_g += np.einsum('ijklmnop->klmnjoip', gggg)
get_g_g_g_g += np.einsum('ijklmnop->klmojpin', gggg)
get_g_g_g_g += np.einsum('ijklmnop->klmojnip', gggg)
get_g_g_g_g += np.einsum('ijklmnop->klmpjoin', gggg)
get_g_g_g_g += np.einsum('ijklmnop->klmpjnio', gggg)
get_g_g_g_g += np.einsum('ijklmnop->klnojpim', gggg)
get_g_g_g_g += np.einsum('ijklmnop->klnojmip', gggg)
get_g_g_g_g += np.einsum('ijklmnop->klnpjoim', gggg)
get_g_g_g_g += np.einsum('ijklmnop->klnpjmio', gggg)
get_g_g_g_g += np.einsum('ijklmnop->klopjnim', gggg)
get_g_g_g_g += np.einsum('ijklmnop->klopjmin', gggg)
# _gggg
get_gggg_ = np.einsum('ijklmnop->ijklmnop', gggg)
get_gggg_ += np.einsum('ijklmnop->ijklmonp', gggg)
get_gggg_ += np.einsum('ijklmnop->ijklmpno', gggg)
get_gggg_ += np.einsum('ijklmnop->ikjlmnop', gggg)
get_gggg_ += np.einsum('ijklmnop->ikjlmonp', gggg)
get_gggg_ += np.einsum('ijklmnop->ikjlmpno', gggg)
get_gggg_ += np.einsum('ijklmnop->iljkmnop', gggg)
get_gggg_ += np.einsum('ijklmnop->iljkmonp', gggg)
get_gggg_ += np.einsum('ijklmnop->iljkmpno', gggg)
return 1/24 * get_gggg - 1/84 * get_g_g_g_g + 1/105 * get_gggg_
@jit
def T2_cov(a):
gt_cov = Gt_cov(a)
r = a
rt = np.einsum('ij,j->i', gt_cov, r)
return (np.einsum('i,j->ij', rt, rt) -
(1 / 3) * dot(rt, rt) * gt_cov)
@jit
def T1_cov(a):
gt_cov = Gt_cov(a)
r = a
return np.einsum('ij,j->i', gt_cov, r)
@jit
def T4_cov(a):
r = a
p4 = P4(a)
result = np.einsum('ijklmnop,m,n,o,p', p4, r, r, r, r)
return result
@jit
@vmap
def phif243(phi):
t4_cov = T4_cov(phi)
print(t4_cov.shape)
t2_cov = T2_cov(phi)
t1_cov = T1_cov(phi)
return np.einsum('ijkl,j,kl->i',t4_cov,t1_cov,t2_cov)
out = phif243(phi)
print(out)
RuntimeError: Resource exhausted: Out of memory while trying to allocate 131174400000 bytes.
the P4 is a intermediate variable
whty
—
You are receiving this because you commented.
Reply to this email directly, view it on GitHub
<#1923?email_source=notifications&email_token=AAJJFVSNQ6DXPMKS6ACH45TQ3CGWHA5CNFSM4KAVYS7KYY3PNVWWK3TUL52HS4DFVREXG43VMVBW63LNMVXHJKTDN5WW2ZLOORPWSZGOEHY53GA#issuecomment-569499032>,
or unsubscribe
<https://github.com/notifications/unsubscribe-auth/AAJJFVVDZNJ5W7XNPAHHOGLQ3CGWHANCNFSM4KAVYS7A>
.
|
so , it does not have method to use jax with gpu to calculate the function? |
You’ll have to figure out another way to write this calculation that uses less memory. For example, you could use |
Whoa, that's quite a program! Following @shoyer's suggestion, you might change the last bit of your code to be from jax import lax
def phif243_(phi):
t4_cov = T4_cov(phi)
t2_cov = T2_cov(phi)
t1_cov = T1_cov(phi)
return np.einsum('ijkl,j,kl->i',t4_cov,t1_cov,t2_cov)
@jit
def phif243(phi):
return lax.map(phif243_, phi)
out = phif243(phi)
print(out) Then I don't get memory errors, though performance might suffer. You can also mix @jit
def phif243(phi):
return lax.map(vmap(phif243_), phi.reshape((-1, 100) + phi.shape[1:])) I don't know what the most setup is; I just made those numbers up :) Finally if you have more than one GPU you can use Since the rest is probably pretty application-specific, and since we have a lot of issues open right now, I'm going to close this one. I hope you find a setup, perhaps using |
Using this setup caused the calculation to finish on my CPU, but I didn't actually watch it so I'm not sure how long it took :) @jit
def phif243(phi):
return lax.map(vmap(phif243_), phi.reshape((-1, 800) + phi.shape[1:])) |
how to fix it , please give me a help
The text was updated successfully, but these errors were encountered: