Size Match AssertionError in Numba parallelization #6065

Open
2 tasks done
DeepakSaini119 opened this issue Aug 2, 2020 · 1 comment

Comments

@DeepakSaini119

Reporting a bug

I have code that computes the embedding of each bag of tokens from the embeddings of its individual tokens:

@nb.njit(nb.float32[:,:](nb.int32[:],nb.float32[:],nb.int32[:],nb.float32[:,:]), parallel=True)
def _compute_doc_embeddings2(indices, data, indptr, vocab_embeddings):
    m = len(indptr) - 1
    
    embs = np.zeros((m, vocab_embeddings.shape[1]), dtype=np.float32)
    for i in nb.prange(m):
        _indices = indices[indptr[i]: indptr[i + 1]]
        _data = data[indptr[i]: indptr[i + 1]].copy()  # copy as nb requires contiguous

        embs[i, :] = np.sum(vocab_embeddings[_indices] * _data.reshape((-1, 1)), axis=0)
    
    return embs

This works fine when I don't run it in parallel mode. But when I call the code above (with parallel=True and prange) like this:

i = np.array([0,1,2,3], dtype=np.int32)
d = np.array([0.5,0.5,0.5,0.5], dtype=np.float32)
p = np.array([0,2,4], dtype=np.int32)
ve = np.array([[0.4,-0.4,0.5], [-0.4,0.4,-0.5], [0.6,-0.6,0.7], [-0.6,0.6,-0.7]], dtype=np.float32)

embs = _compute_doc_embeddings2(i, d, p, ve)

I get this error:

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-17-62efe4922c46> in <module>
      5 
      6 # embs = np.zeros((2, 3), dtype=np.float32)
----> 7 embs = _compute_doc_embeddings2(i, d, p, ve)

AssertionError: Sizes of $116binary_subscr.32, $124call_method.36 do not match on <ipython-input-16-bb6227693479> (10)

I can't decipher this error and would appreciate any help. Thanks in advance.
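
For reference, the result the function is meant to produce can be checked without Numba. The sketch below assumes that `indices`, `data`, and `indptr` follow the CSR sparse-matrix layout (the report does not say so explicitly, but the names and the slicing match it). The helper name `doc_embeddings_reference` is made up for illustration. It computes each row as the weighted sum of token embeddings and cross-checks the result against a dense matrix product.

```python
import numpy as np


def doc_embeddings_reference(indices, data, indptr, vocab_embeddings):
    """Serial NumPy reference: weighted sum of token embeddings per document.

    Assumes a CSR-style layout: document `doc` owns the slice
    indptr[doc]:indptr[doc + 1] of `indices` (token ids) and `data` (weights).
    """
    m = len(indptr) - 1
    embs = np.zeros((m, vocab_embeddings.shape[1]), dtype=np.float32)
    for doc in range(m):
        start, stop = indptr[doc], indptr[doc + 1]
        tokens = indices[start:stop]
        weights = data[start:stop]
        # (k, dim) * (k, 1) broadcasts to (k, dim); then sum over the k tokens.
        embs[doc, :] = np.sum(
            vocab_embeddings[tokens] * weights.reshape((-1, 1)), axis=0
        )
    return embs


# Same inputs as in the report.
i = np.array([0, 1, 2, 3], dtype=np.int32)
d = np.array([0.5, 0.5, 0.5, 0.5], dtype=np.float32)
p = np.array([0, 2, 4], dtype=np.int32)
ve = np.array([[0.4, -0.4, 0.5],
               [-0.4, 0.4, -0.5],
               [0.6, -0.6, 0.7],
               [-0.6, 0.6, -0.7]], dtype=np.float32)

embs = doc_embeddings_reference(i, d, p, ve)

# Cross-check: scatter the weights into a dense (docs x vocab) matrix and
# multiply. This scatter is only valid when a document has no repeated
# token ids, which holds for the sample input.
dense = np.zeros((len(p) - 1, ve.shape[0]), dtype=np.float32)
for row in range(len(p) - 1):
    dense[row, i[p[row]:p[row + 1]]] = d[p[row]:p[row + 1]]
expected = dense @ ve
```

For these particular inputs each document pairs a row of `ve` with its negation at equal weight, so both embeddings come out as zeros. That makes the sample a good reproducer for the error but a weak check of the arithmetic.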

@stuartarchibald
Contributor

Thanks for the report. I think this is a bug: the code is hitting a runtime assertion in the parallel transform. Reproducer:

from numba import njit, prange, float32, int32
import numpy as np

@njit(parallel=True)
def foo(indices, data, indptr, z):
    m = len(indptr) - 1

    embs = np.zeros((m, z.shape[1]), dtype=np.float32)
    for i in prange(m):
        _indices = indices[indptr[i]: indptr[i + 1]]
        _data = data[indptr[i]: indptr[i + 1]].copy()
        f = z[_indices]
        g = _data.reshape((-1, 1))
        print(f.shape, g.shape)
        x =  f * g

    return embs


i = np.array([0,1,2,3], dtype=np.int32)
d = np.array([0.5,0.5,0.5,0.5], dtype=np.float32)
p = np.array([0,2,4], dtype=np.int32)
ve = np.array([[0.4,-0.4,0.5],
               [-0.4,0.4,-0.5],
               [0.6,-0.6,0.7],
               [-0.6,0.6,-0.7]], dtype=np.float32)

foo(i, d, p, ve)

gives:

(2, 3) (2, 1)
(2, 3) (2, 1)
Traceback (most recent call last):
  File "issue6065.py", line 28, in <module>
    foo(i, d, p, ve)
AssertionError: Sizes of f, g do not match on issue6065.py (15)
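
The two shapes printed before the failure are compatible under NumPy's broadcasting rules, which is why the assertion reads as a bug rather than a user error. A small NumPy-only check follows; the array contents are arbitrary and only the shapes are taken from the output above.

```python
import numpy as np

# Arrays with the shapes printed by the reproducer; the values are arbitrary.
f = np.arange(6, dtype=np.float32).reshape((2, 3))
g = np.array([0.5, 0.5], dtype=np.float32).reshape((-1, 1))

# The size-1 trailing axis of g is stretched to match f's trailing axis.
x = f * g
print(f.shape, g.shape, x.shape)  # (2, 3) (2, 1) (2, 3)
```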

cc @DrTodd13
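
A possible workaround, not confirmed anywhere in this thread, is to replace the broadcast array expression inside the `prange` body with explicit scalar loops. The sketch below rests on the assumption that the assertion is only emitted for array-shaped operands, so a loop body with no array-by-array multiplication should not trigger it; this has not been verified against the Numba version from the report. The function name `compute_doc_embeddings_loops` is made up, and the `ImportError` fallback exists only so the sketch still runs (serially) where Numba is not installed.

```python
import numpy as np

try:
    from numba import njit, prange
except ImportError:
    # Fallback so the sketch still runs (serially) without Numba installed.
    def njit(*args, **kwargs):
        return lambda func: func
    prange = range


@njit(parallel=True)
def compute_doc_embeddings_loops(indices, data, indptr, vocab_embeddings):
    m = len(indptr) - 1
    dim = vocab_embeddings.shape[1]
    embs = np.zeros((m, dim), dtype=np.float32)
    # Each iteration writes only to row i, so iterations are independent.
    for i in prange(m):
        for k in range(indptr[i], indptr[i + 1]):
            tok = indices[k]
            w = data[k]
            # Scalar accumulation instead of a (k, dim) * (k, 1) broadcast.
            for j in range(dim):
                embs[i, j] += w * vocab_embeddings[tok, j]
    return embs


# Same inputs as in the report.
i = np.array([0, 1, 2, 3], dtype=np.int32)
d = np.array([0.5, 0.5, 0.5, 0.5], dtype=np.float32)
p = np.array([0, 2, 4], dtype=np.int32)
ve = np.array([[0.4, -0.4, 0.5],
               [-0.4, 0.4, -0.5],
               [0.6, -0.6, 0.7],
               [-0.6, 0.6, -0.7]], dtype=np.float32)

result = compute_doc_embeddings_loops(i, d, p, ve)
```

The scalar loops also avoid the temporary arrays created by the fancy indexing and the `.copy()` in the original, at the cost of being less compact.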
