Nd4j.ScatterUpdates has a large overhead #10029

Open
ebeaufay opened this issue Sep 1, 2023 · 0 comments

Issue Description

My use case is a kind of EmbeddingLayer with a small dictionary (<1M entries).
I'm implementing a HashGridEncoding layer for training NeRFs.

Nd4j.scatterUpdate is slow, in fact slower than a naive CPU implementation.

Its run time appears to be constant with respect to the dictionary size and linear in the number of updates.

This is a problem for an EmbeddingLayer with a small dictionary and a large batch size.

Expected behavior

Nd4j.scatterUpdate is fast, with minimal overhead.
I would expect it to be on par with Nd4j.pullRows, which is a lot faster (although it is also a bottleneck in the forward pass).
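
For reference, here is a minimal plain-Java sketch of the row lookup that the forward pass needs, i.e. the kind of gather that Nd4j.pullRows is used for here. It is not ND4J's implementation; the class name `GatherSketch`, the method `gatherRows` and the flat row-major layout are assumptions made for illustration only.

```java
// Hypothetical helper for illustration; not part of ND4J.
final class GatherSketch {

    /**
     * Copies the rows of a row-major table selected by indices into a new array.
     * weights holds rows * cols floats; the result holds indices.length * cols floats.
     */
    static float[] gatherRows(float[] weights, int cols, int[] indices) {
        float[] out = new float[indices.length * cols];
        for (int i = 0; i < indices.length; i++) {
            // Row indices[i] of weights becomes row i of the output.
            System.arraycopy(weights, indices[i] * cols, out, i * cols, cols);
        }
        return out;
    }
}
```

A gather only reads from the table, so repeated indices never conflict with each other. The backward pass writes to the table instead, which is where repeated indices have to be accumulated.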

Encountered behavior

Nd4j.scatterUpdate runs in almost constant time relative to the size of the array being updated, but it has a large overhead.

Until around 1M "weights", this:

    public static void doScatterUpdate(INDArray weights, INDArray indices, INDArray updates){
        Nd4j.scatterUpdate(ScatterUpdate.UpdateOp.ADD, weights, indices, updates, 1);
    }

is slower than this

    public static void doScatterUpdateCPUWorkaround(INDArray weights, int[] indexes, INDArray updates){
        // Accumulate on the host in plain Java arrays, then add the result to weights in one op.
        float[][] tempWeights = new float[dictionarySize[0]][dictionarySize[1]];
        // One entry per column of updates; each entry is assigned in the loop below.
        float[][] tempUpdates = new float[dictionarySize[1]][];
        updates = updates.dup('f');

        // Copy each column of updates into a Java array.
        for (int j = 0; j < dictionarySize[1]; j++) {
            INDArray column = updates.getColumn(j);
            tempUpdates[j] = column.data().asFloat();
        }

        // Sum every update into the row selected by its index.
        for (int i = 0; i < indexes.length; i++) {
            for (int j = 0; j < dictionarySize[1]; j++) {
                tempWeights[indexes[i]][j]+=tempUpdates[j][i];
            }
        }

        INDArray reshape = Nd4j.create(tempWeights);
        weights.addi(reshape);
    }

by a factor of up to 50x (see the test class below).

The CPU workaround doesn't scale.
Nd4j.scatterUpdate runs in linear time relative to the number of updates, but even for a small number of updates it takes an abnormal amount of time, so perhaps some optimization is possible.
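
To make the operation being measured explicit, here is a minimal plain-Java sketch of a row-wise scatter-add, which is what the `ADD` update amounts to in this use case. It is a reference for the semantics only, not ND4J's implementation; the class name `ScatterAddSketch`, the method `scatterAdd` and the flat row-major layout are assumptions made for illustration.

```java
// Hypothetical helper for illustration; not part of ND4J.
final class ScatterAddSketch {

    /**
     * Adds row i of updates to row indices[i] of weights.
     * weights is row-major with cols columns; updates holds indices.length * cols floats.
     */
    static void scatterAdd(float[] weights, int cols, int[] indices, float[] updates) {
        if (updates.length != indices.length * cols) {
            throw new IllegalArgumentException("updates must hold one row per index");
        }
        for (int i = 0; i < indices.length; i++) {
            int dst = indices[i] * cols; // start of the target row in weights
            int src = i * cols;          // start of the source row in updates
            for (int j = 0; j < cols; j++) {
                weights[dst + j] += updates[src + j];
            }
        }
    }
}
```

The work is proportional to the number of updates times the row width and does not depend on the dictionary size. Because the loop is single-threaded, repeated indices simply accumulate; a parallel version would have to handle several updates targeting the same row.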

Since Nd4j.pullRows does essentially the same job and is much faster, it might be worth comparing the two.

Version Information

        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>nd4j-cuda-11.6</artifactId>
            <version>1.0.0-M2.1</version>
        </dependency>
        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>nd4j-cuda-11.6</artifactId>
            <version>1.0.0-M2.1</version>
            <classifier>windows-x86_64-cudnn</classifier>
        </dependency>

OS: Windows
GPU: Nvidia RTX 3060 Laptop GPU (driver: 31.0.15.3667)
cuDNN is installed too, in case it is used at all.

Test class

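import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate;
import org.nd4j.linalg.factory.Nd4j;

public class test {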

    static int[] dictionarySize = {10000,2};
    static int indexesSize = 1000000;
    public static void main(String[] args) {

        INDArray indices = Nd4j.create(indexesSize).assign(1).castTo(DataType.INT32);
        int[] indexes = indices.data().asInt();
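        // Note: reshape returns a new array and the result is discarded here, so indices keeps its original shape.
        indices.reshape(indices.size(0),1);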

        INDArray weights = Nd4j.create(dictionarySize).assign(6);
        INDArray updates = Nd4j.create(indexesSize, weights.size(1)).assign(4);
        doScatterUpdate(weights, indices, updates);
        // Every index is 1, so row 1 receives all updates: 6 + 4 * indexesSize.
        assert weights.getFloat(1,0) == 6 + 4f * indexesSize;

        weights = Nd4j.create(dictionarySize).assign(6);
        updates = Nd4j.create(indexesSize, weights.size(1)).assign(4);
        doScatterUpdateCPUWorkaround(weights, indexes, updates);
        // Same expected value as above: 6 + 4 * indexesSize.
        assert weights.getFloat(1,0) == 6 + 4f * indexesSize;

        // warmup
        for (int i = 0; i < 10; i++) {
            weights = Nd4j.create(dictionarySize).assign(6);
            updates = Nd4j.create(indexesSize, weights.size(1)).assign(4);
            doScatterUpdate(weights, indices, updates);

            weights = Nd4j.create(dictionarySize).assign(6);
            updates = Nd4j.create(indexesSize, weights.size(1)).assign(4);
            doScatterUpdateCPUWorkaround(weights, indexes, updates);
        }

        long start = System.currentTimeMillis();
        for (int i = 0; i < 10; i++) {
            weights = Nd4j.create(dictionarySize).assign(6);
            updates = Nd4j.create(indexesSize, weights.size(1)).assign(4);
            doScatterUpdate(weights, indices, updates);
        }
        System.out.println("scatterUpdates Nd4j : "+(System.currentTimeMillis()-start)+ " ms");

        start = System.currentTimeMillis();
        for (int i = 0; i < 10; i++) {
            weights = Nd4j.create(dictionarySize).assign(6);
            updates = Nd4j.create(indexesSize, weights.size(1)).assign(4);
            doScatterUpdateCPUWorkaround(weights, indexes, updates);
        }
        System.out.println("scatterUpdates CPU : "+(System.currentTimeMillis()-start)+ " ms");
    }

    public static void doScatterUpdate(INDArray weights, INDArray indices, INDArray updates){
        Nd4j.scatterUpdate(ScatterUpdate.UpdateOp.ADD, weights, indices, updates, 1);
    }

    public static void doScatterUpdateCPUWorkaround(INDArray weights, int[] indexes, INDArray updates){
        // Accumulate on the host in plain Java arrays, then add the result to weights in one op.
        float[][] tempWeights = new float[dictionarySize[0]][dictionarySize[1]];
        // One entry per column of updates; each entry is assigned in the loop below.
        float[][] tempUpdates = new float[dictionarySize[1]][];
        updates = updates.dup('f');

        // Copy each column of updates into a Java array.
        for (int j = 0; j < dictionarySize[1]; j++) {
            INDArray column = updates.getColumn(j);
            tempUpdates[j] = column.data().asFloat();
        }

        // Sum every update into the row selected by its index.
        for (int i = 0; i < indexes.length; i++) {
            for (int j = 0; j < dictionarySize[1]; j++) {
                tempWeights[indexes[i]][j]+=tempUpdates[j][i];
            }
        }

        INDArray reshape = Nd4j.create(tempWeights);
        weights.addi(reshape);
    }
}
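
The timed loops above also include the allocation of `weights` and `updates`. As a sketch of how the two could be separated, here is a small plain-Java timing helper that runs the setup outside the measured region. The class name `TimingSketch` and the method `meanMillis` are hypothetical and not part of ND4J.

```java
// Hypothetical helper for illustration; not part of ND4J.
final class TimingSketch {

    /**
     * Runs warmup iterations untimed, then returns the mean time in milliseconds
     * of op over the timed iterations. setup runs before every call to op and is
     * never included in the measurement.
     */
    static double meanMillis(Runnable setup, Runnable op, int warmup, int iterations) {
        if (iterations <= 0) {
            throw new IllegalArgumentException("iterations must be positive");
        }
        for (int i = 0; i < warmup; i++) {
            setup.run();
            op.run();
        }
        long totalNanos = 0L;
        for (int i = 0; i < iterations; i++) {
            setup.run();                    // allocation or reset, not timed
            long start = System.nanoTime();
            op.run();                       // only this call is timed
            totalNanos += System.nanoTime() - start;
        }
        return totalNanos / 1e6 / iterations;
    }
}
```

In the test class, `setup` would recreate `weights` and `updates` and `op` would call `doScatterUpdate` or `doScatterUpdateCPUWorkaround`. If the backend executes work asynchronously, `op` would also need to include whatever forces that work to finish; that detail is left out of this sketch.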

Contributing

I have little experience with C++, so letting me contribute there is a risk. That said, if you tell me there is no demand and no time to improve this, I will have a look.
