Out of memory when calculate big tensor #1923

Closed
XDongiang opened this issue Dec 29, 2019 · 7 comments
Comments

@XDongiang

XDongiang commented Dec 29, 2019

How can I fix this? Please help.

import jax.numpy as np
from jax import device_put
from jax.config import config
from jax import jit
import numpy as onp
from jax import vmap

config.update("jax_enable_x64", True)

size = 800000
Kp = onp.random.sample(size*4).reshape(size,4)
Km = onp.random.sample(size*4).reshape(size,4)

phi = Kp + Km

def metric():
    G = onp.eye((4))
    G[0, 0] = -1
    G[1, 1] = -1
    G[2, 2] = -1
    return G

@jit
def dot(a, b):
    G = metric()
    return np.einsum('i,j,ij', a, b, G)


@jit
def Gt_con(a):
    G = metric()
    return G - np.einsum('i,j', a, a) / dot(a, a)

@jit
def Gt_cov(a):
    gt_con = Gt_con(a)
    G = metric()
    return np.einsum('ij,ik,jl->kl', gt_con, G, G)

@jit
@vmap
def P4(a):
    gt_cov = Gt_cov(a)
    gg = np.einsum('ij,kl->ijkl', gt_cov, gt_cov)
    gggg = np.einsum('ijkl,mnop->ijklmnop', gg, gg)
    # get_gggg
    get_gggg = gggg
    get_gggg += np.einsum('ijklmnop->minjolpk', gggg)
    get_gggg += np.einsum('ijklmnop->minkojpl', gggg)
    get_gggg += np.einsum('ijklmnop->minkolpj', gggg)
    get_gggg += np.einsum('ijklmnop->minlojpk', gggg)
    get_gggg += np.einsum('ijklmnop->minlokpj', gggg)
    get_gggg += np.einsum('ijklmnop->mjniokpl', gggg)
    get_gggg += np.einsum('ijklmnop->mjniolpk', gggg)
    get_gggg += np.einsum('ijklmnop->mjnkoipl', gggg)
    get_gggg += np.einsum('ijklmnop->mjnkolpi', gggg)
    get_gggg += np.einsum('ijklmnop->mjnloipk', gggg)
    get_gggg += np.einsum('ijklmnop->mjnlokpi', gggg)
    get_gggg += np.einsum('ijklmnop->mkniojpl', gggg)
    get_gggg += np.einsum('ijklmnop->mkniolpj', gggg)
    get_gggg += np.einsum('ijklmnop->mknjoipl', gggg)
    get_gggg += np.einsum('ijklmnop->mknjolpi', gggg)
    get_gggg += np.einsum('ijklmnop->mknloipj', gggg)
    get_gggg += np.einsum('ijklmnop->mknlojpi', gggg)
    get_gggg += np.einsum('ijklmnop->mlniojpk', gggg)
    get_gggg += np.einsum('ijklmnop->mlniokpj', gggg)
    get_gggg += np.einsum('ijklmnop->mlnjoipk', gggg)
    get_gggg += np.einsum('ijklmnop->mlnjokpi', gggg)
    get_gggg += np.einsum('ijklmnop->mlnkoipj', gggg)
    get_gggg += np.einsum('ijklmnop->mlnkojpi', gggg)
    # get_g_g_g_g
    get_g_g_g_g  = np.einsum('ijklmnop->ijmnlpko', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ijmnlokp', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ijmolpkn', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ijmolnkp', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ijmplokn', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ijmplnko', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ijnolpkm', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ijnolmkp', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ijnplokm', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ijnplmko', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ijoplnkm', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ijoplmkn', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ikmnlpjo', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ikmnlojp', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ikmolpjn', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ikmolnjp', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ikmplojn', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ikmplnjo', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->iknolpjm', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->iknolmjp', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->iknplojm', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->iknplmjo', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ikoplnjm', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ikoplmjn', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ilmnkpjo', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ilmnkojp', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ilmokpjn', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ilmoknjp', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ilmpkojn', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ilmpknjo', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ilnokpjm', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ilnokmjp', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ilnpkojm', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ilnpkmjo', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ilopknjm', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ilopkmjn', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->jkmnlpio', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->jkmnloip', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->jkmolpin', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->jkmolnip', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->jkmploin', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->jkmplnio', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->jknolpim', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->jknolmip', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->jknploim', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->jknplmio', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->jkoplnim', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->jkoplmin', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->jlmnkpio', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->jlmnkoip', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->jlmokpin', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->jlmoknip', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->jlmpkoin', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->jlmpknio', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->jlnokpim', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->jlnokmip', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->jlnpkoim', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->jlnpkmio', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->jlopknim', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->jlopkmin', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->klmnjpio', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->klmnjoip', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->klmojpin', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->klmojnip', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->klmpjoin', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->klmpjnio', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->klnojpim', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->klnojmip', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->klnpjoim', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->klnpjmio', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->klopjnim', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->klopjmin', gggg)
    # _gggg
    get_gggg_  = np.einsum('ijklmnop->ijklmnop', gggg)
    get_gggg_ += np.einsum('ijklmnop->ijklmonp', gggg)
    get_gggg_ += np.einsum('ijklmnop->ijklmpno', gggg)
    get_gggg_ += np.einsum('ijklmnop->ikjlmnop', gggg)
    get_gggg_ += np.einsum('ijklmnop->ikjlmonp', gggg)
    get_gggg_ += np.einsum('ijklmnop->ikjlmpno', gggg)
    get_gggg_ += np.einsum('ijklmnop->iljkmnop', gggg)
    get_gggg_ += np.einsum('ijklmnop->iljkmonp', gggg)
    get_gggg_ += np.einsum('ijklmnop->iljkmpno', gggg)
    return 1/24 * get_gggg - 1/84 * get_g_g_g_g + 1/105 * get_gggg_

out = P4(phi)
print(out)

@shoyer
Member

shoyer commented Dec 29, 2019

It looks like your code would result in an array with 800000 * 4^8, or about 5e10, elements, which is unlikely to fit into memory on most machines (roughly 400 GB in float64 at 8 bytes/element). So there’s no way JAX could help here; you’ll need to restructure your code to avoid creating huge arrays.
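
For reference, the estimate above can be reproduced with a few lines of plain Python. This is only a back-of-the-envelope sketch: it assumes float64 output (because `jax_enable_x64` is set) and counts a single copy of the vmapped `P4` result, ignoring any temporaries.

```python
# Rough size of the vmapped P4 output, assuming float64 (8 bytes per element).
batch = 800_000        # rows in phi
dim = 4                # each tensor index runs over 4 components
rank = 8               # P4 returns an 8-index tensor per row
bytes_per_element = 8  # float64

elements = batch * dim ** rank
total_bytes = elements * bytes_per_element

print(f"{elements:.3e} elements")     # 5.243e+10 elements
print(f"{total_bytes / 1e9:.0f} GB")  # 419 GB
```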

@XDongiang
Author

XDongiang commented Dec 29, 2019

Oh, I changed the code.

import jax.numpy as np
from jax import device_put
from jax.config import config
from jax import jit
import numpy as onp
from jax import vmap

config.update("jax_enable_x64", True)

size = 800000
Kp = onp.random.sample(size*4).reshape(size,4)
Km = onp.random.sample(size*4).reshape(size,4)

phi = Kp + Km

def metric():
    G = onp.eye((4))
    G[0, 0] = -1
    G[1, 1] = -1
    G[2, 2] = -1
    return G

@jit
def dot(a, b):
    G = metric()
    return np.einsum('i,j,ij', a, b, G)


@jit
def Gt_con(a):
    G = metric()
    return G - np.einsum('i,j', a, a) / dot(a, a)

@jit
def Gt_cov(a):
    gt_con = Gt_con(a)
    G = metric()
    return np.einsum('ij,ik,jl->kl', gt_con, G, G)

@jit
def P4(a):
    gt_cov = Gt_cov(a)
    gg = np.einsum('ij,kl->ijkl', gt_cov, gt_cov)
    gggg = np.einsum('ijkl,mnop->ijklmnop', gg, gg)
    # get_gggg
    get_gggg = gggg
    get_gggg += np.einsum('ijklmnop->minjolpk', gggg)
    get_gggg += np.einsum('ijklmnop->minkojpl', gggg)
    get_gggg += np.einsum('ijklmnop->minkolpj', gggg)
    get_gggg += np.einsum('ijklmnop->minlojpk', gggg)
    get_gggg += np.einsum('ijklmnop->minlokpj', gggg)
    get_gggg += np.einsum('ijklmnop->mjniokpl', gggg)
    get_gggg += np.einsum('ijklmnop->mjniolpk', gggg)
    get_gggg += np.einsum('ijklmnop->mjnkoipl', gggg)
    get_gggg += np.einsum('ijklmnop->mjnkolpi', gggg)
    get_gggg += np.einsum('ijklmnop->mjnloipk', gggg)
    get_gggg += np.einsum('ijklmnop->mjnlokpi', gggg)
    get_gggg += np.einsum('ijklmnop->mkniojpl', gggg)
    get_gggg += np.einsum('ijklmnop->mkniolpj', gggg)
    get_gggg += np.einsum('ijklmnop->mknjoipl', gggg)
    get_gggg += np.einsum('ijklmnop->mknjolpi', gggg)
    get_gggg += np.einsum('ijklmnop->mknloipj', gggg)
    get_gggg += np.einsum('ijklmnop->mknlojpi', gggg)
    get_gggg += np.einsum('ijklmnop->mlniojpk', gggg)
    get_gggg += np.einsum('ijklmnop->mlniokpj', gggg)
    get_gggg += np.einsum('ijklmnop->mlnjoipk', gggg)
    get_gggg += np.einsum('ijklmnop->mlnjokpi', gggg)
    get_gggg += np.einsum('ijklmnop->mlnkoipj', gggg)
    get_gggg += np.einsum('ijklmnop->mlnkojpi', gggg)
    # get_g_g_g_g
    get_g_g_g_g  = np.einsum('ijklmnop->ijmnlpko', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ijmnlokp', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ijmolpkn', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ijmolnkp', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ijmplokn', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ijmplnko', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ijnolpkm', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ijnolmkp', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ijnplokm', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ijnplmko', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ijoplnkm', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ijoplmkn', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ikmnlpjo', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ikmnlojp', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ikmolpjn', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ikmolnjp', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ikmplojn', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ikmplnjo', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->iknolpjm', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->iknolmjp', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->iknplojm', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->iknplmjo', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ikoplnjm', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ikoplmjn', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ilmnkpjo', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ilmnkojp', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ilmokpjn', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ilmoknjp', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ilmpkojn', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ilmpknjo', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ilnokpjm', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ilnokmjp', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ilnpkojm', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ilnpkmjo', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ilopknjm', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->ilopkmjn', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->jkmnlpio', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->jkmnloip', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->jkmolpin', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->jkmolnip', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->jkmploin', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->jkmplnio', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->jknolpim', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->jknolmip', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->jknploim', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->jknplmio', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->jkoplnim', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->jkoplmin', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->jlmnkpio', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->jlmnkoip', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->jlmokpin', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->jlmoknip', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->jlmpkoin', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->jlmpknio', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->jlnokpim', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->jlnokmip', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->jlnpkoim', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->jlnpkmio', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->jlopknim', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->jlopkmin', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->klmnjpio', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->klmnjoip', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->klmojpin', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->klmojnip', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->klmpjoin', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->klmpjnio', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->klnojpim', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->klnojmip', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->klnpjoim', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->klnpjmio', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->klopjnim', gggg)
    get_g_g_g_g += np.einsum('ijklmnop->klopjmin', gggg)
    # _gggg
    get_gggg_  = np.einsum('ijklmnop->ijklmnop', gggg)
    get_gggg_ += np.einsum('ijklmnop->ijklmonp', gggg)
    get_gggg_ += np.einsum('ijklmnop->ijklmpno', gggg)
    get_gggg_ += np.einsum('ijklmnop->ikjlmnop', gggg)
    get_gggg_ += np.einsum('ijklmnop->ikjlmonp', gggg)
    get_gggg_ += np.einsum('ijklmnop->ikjlmpno', gggg)
    get_gggg_ += np.einsum('ijklmnop->iljkmnop', gggg)
    get_gggg_ += np.einsum('ijklmnop->iljkmonp', gggg)
    get_gggg_ += np.einsum('ijklmnop->iljkmpno', gggg)
    return 1/24 * get_gggg - 1/84 * get_g_g_g_g + 1/105 * get_gggg_

@jit
def T2_cov(a):
    gt_cov = Gt_cov(a)
    r = a
    rt = np.einsum('ij,j->i', gt_cov, r)
    return (np.einsum('i,j->ij', rt, rt) -
            (1 / 3) * dot(rt, rt) * gt_cov)

@jit
def T1_cov(a):
    gt_cov = Gt_cov(a)
    r = a
    return np.einsum('ij,j->i', gt_cov, r)    

@jit
def T4_cov(a):
    r = a
    p4 = P4(a)
    result = np.einsum('ijklmnop,m,n,o,p', p4, r, r, r, r)
    return result

@jit
@vmap
def phif243(phi):
    t4_cov = T4_cov(phi)
    print(t4_cov.shape)
    t2_cov = T2_cov(phi)
    t1_cov = T1_cov(phi)
    return np.einsum('ijkl,j,kl->i',t4_cov,t1_cov,t2_cov)

out = phif243(phi)
print(out)

RuntimeError: Resource exhausted: Out of memory while trying to allocate 131174400000 bytes.

P4 is only an intermediate variable.
Why does it allocate so much memory for an intermediate variable?

@shoyer
Copy link
Member

shoyer commented Dec 29, 2019 via email

@XDongiang
Author

So is there no way to use JAX with a GPU to compute this function?
Thanks for your reply.

@shoyer
Member

shoyer commented Dec 29, 2019

You’ll have to figure out another way to write this calculation that uses less memory. For example, you could use lax.map (effectively a loop) instead of vmap.

mattjj added the application and question (Questions for the JAX team) labels Jan 7, 2020

@mattjj
Member

mattjj commented Jan 7, 2020

Whoa, that's quite a program!

Following @shoyer's suggestion, you might change the last bit of your code to be

from jax import lax

def phif243_(phi):
    t4_cov = T4_cov(phi)
    t2_cov = T2_cov(phi)
    t1_cov = T1_cov(phi)
    return np.einsum('ijkl,j,kl->i',t4_cov,t1_cov,t2_cov)

@jit
def phif243(phi):
  return lax.map(phif243_, phi)

out = phif243(phi)
print(out)

Then I don't get memory errors, though performance might suffer.

You can also mix vmap and lax.map by doing a reshape to split the leading axis, say from size 800000 to shape (8000, 100), and then map over the first while vmapping over the second:

@jit
def phif243(phi):
  return lax.map(vmap(phif243_), phi.reshape((-1, 100) + phi.shape[1:]))

I don't know what the best setup is; I just made those numbers up :)
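
The chunked pattern can be sanity-checked on a toy function before running it on the real tensors. The sketch below is not the code from this issue: `per_row` is a hypothetical stand-in for `phif243_`, `chunked` is an illustrative helper, and it uses the `jnp` alias rather than this thread's `np`. It assumes the leading axis divides evenly by the chunk size.

```python
import jax.numpy as jnp
from jax import lax, vmap

def per_row(x):
    # Hypothetical stand-in for phif243_: builds a (4, 4) intermediate
    # per row and reduces it back to shape (4,).
    outer = jnp.einsum('i,j->ij', x, x)
    return outer.sum(axis=1)

def chunked(f, xs, chunk):
    # Split the leading axis into (num_chunks, chunk), loop over the
    # chunks with lax.map, and vectorise inside each chunk with vmap.
    n = xs.shape[0]
    assert n % chunk == 0, "leading axis must be divisible by chunk"
    blocks = xs.reshape((-1, chunk) + xs.shape[1:])
    out = lax.map(vmap(f), blocks)
    # Merge the two leading axes back into one.
    return out.reshape((n,) + out.shape[2:])

xs = jnp.arange(32.0).reshape(8, 4)
full = vmap(per_row)(xs)               # everything batched at once
tiled = chunked(per_row, xs, chunk=2)  # only 2 rows batched at a time

print(bool(jnp.allclose(full, tiled)))
```

Only `chunk` rows' worth of intermediates should need to be live at once, so the chunk size is the knob that trades memory for speed.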

Finally if you have more than one GPU you can use pmap similarly!

Since the rest is probably pretty application-specific, and since we have a lot of issues open right now, I'm going to close this one. I hope you find a setup, perhaps using lax.map and vmap and pmap, that works well for you!

mattjj closed this as completed Jan 7, 2020
mattjj self-assigned this Jan 7, 2020

@mattjj
Member

mattjj commented Jan 7, 2020

Using this setup caused the calculation to finish on my CPU, but I didn't actually watch it so I'm not sure how long it took :)

@jit
def phif243(phi):
  return lax.map(vmap(phif243_), phi.reshape((-1, 800) + phi.shape[1:]))
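
To get a feel for how the chunk size affects memory, here is a rough estimate of one 8-index float64 intermediate when `chunk` rows are vmapped together. It is a sketch only: `chunk_bytes` is a hypothetical helper, and the real peak depends on how many such intermediates the compiler keeps alive at once.

```python
def chunk_bytes(chunk, dim=4, rank=8, itemsize=8):
    # Bytes for one rank-8, 4-per-axis float64 tensor per row, times chunk rows.
    return chunk * dim ** rank * itemsize

for chunk in (100, 800):
    print(chunk, f"{chunk_bytes(chunk) / 1e6:.1f} MB")
# 100 52.4 MB
# 800 419.4 MB
```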
