Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

std.numerics.dotProduct for fixed-size arrays #9902

Open
dlangBugzillaToGithub opened this issue Apr 24, 2011 · 1 comment
Open

std.numerics.dotProduct for fixed-size arrays #9902

dlangBugzillaToGithub opened this issue Apr 24, 2011 · 1 comment

Comments

@dlangBugzillaToGithub
Copy link

bearophile_hugs reported this on 2011-04-24T11:21:34Z

Transferred from https://issues.dlang.org/show_bug.cgi?id=5880

CC List

Description

A third overload for fixed-size arrays offers:
- compile-time errors for a length mismatch instead of run-time ones (which are not even checked in release builds);
- better optimization opportunities for the compiler, because the lengths are known at compile time;
- passing the fixed-size array arguments by reference, which avoids copying them.


// Dot product of two generic input ranges (the overload used when at least
// one argument is not an array).
//
// Params:
//   a = first input range, consumed by the call
//   b = second input range, consumed in lockstep with `a`
// Returns: the sum of `a.front * b.front` over all element pairs, in the
//   common type of the two element types.
// Throws: via `enforce` when both ranges expose `length` and the lengths
//   differ, or (when at least one has no `length`) when `b` still has
//   elements after `a` has been exhausted.
/*pure*/ CommonType!(ElementType!(Range1), ElementType!(Range2))
dotProduct(Range1, Range2)(Range1 a, Range2 b)
    if (isInputRange!(Range1) && isInputRange!(Range2) &&
            !(isArray!(Range1) && isArray!(Range2)))
{
    // Not marked pure yet: the length property and enforce prevent it.
    enum bool bothHaveLength = hasLength!(Range1) && hasLength!(Range2);

    // With known lengths the mismatch is detected before any work is done.
    static if (bothHaveLength)
    {
        enforce(a.length == b.length);
    }

    typeof(return) total = 0;
    while (!a.empty)
    {
        total += a.front * b.front;
        a.popFront();
        b.popFront();
    }

    // Without known lengths the only check available is that `b` ran out
    // at the same time as `a`.
    static if (!bothHaveLength)
    {
        enforce(b.empty);
    }

    return total;
}

/// Ditto
pure Unqual!(CommonType!(F1, F2))
dotProduct(F1, F2)(in F1[] avector, in F2[] bvector)
if (!isStaticArray!F1 || !isStaticArray!F2)
{
    // Slice overload: a manually unrolled loop (16 elements, then 4, then 1
    // at a time) that alternates between two accumulators.
    // The lengths are only checked with assert, so the check disappears in
    // release builds.
    immutable len = avector.length;
    assert(len == bvector.length);

    // Raw pointers are indexed directly, so no bounds checks are involved.
    auto lhs = avector.ptr, rhs = bvector.ptr;
    typeof(return) evenSum = 0, oddSum = 0;

    // Number of leading elements covered by whole 16-wide / 4-wide blocks.
    immutable size_t unroll16End = len & ~15;
    immutable size_t unroll4End = len & ~3;
    size_t i = 0;

    while (i != unroll16End)
    {
        evenSum += lhs[i] * rhs[i];
        oddSum += lhs[i + 1] * rhs[i + 1];
        evenSum += lhs[i + 2] * rhs[i + 2];
        oddSum += lhs[i + 3] * rhs[i + 3];
        evenSum += lhs[i + 4] * rhs[i + 4];
        oddSum += lhs[i + 5] * rhs[i + 5];
        evenSum += lhs[i + 6] * rhs[i + 6];
        oddSum += lhs[i + 7] * rhs[i + 7];
        evenSum += lhs[i + 8] * rhs[i + 8];
        oddSum += lhs[i + 9] * rhs[i + 9];
        evenSum += lhs[i + 10] * rhs[i + 10];
        oddSum += lhs[i + 11] * rhs[i + 11];
        evenSum += lhs[i + 12] * rhs[i + 12];
        oddSum += lhs[i + 13] * rhs[i + 13];
        evenSum += lhs[i + 14] * rhs[i + 14];
        oddSum += lhs[i + 15] * rhs[i + 15];
        i += 16;
    }

    while (i != unroll4End)
    {
        evenSum += lhs[i] * rhs[i];
        oddSum += lhs[i + 1] * rhs[i + 1];
        evenSum += lhs[i + 2] * rhs[i + 2];
        oddSum += lhs[i + 3] * rhs[i + 3];
        i += 4;
    }

    // Fold the two partial sums together before handling the remainder.
    evenSum += oddSum;

    // At most three trailing elements are left; add them one by one.
    while (i != len)
    {
        evenSum += lhs[i] * rhs[i];
        ++i;
    }

    return evenSum;
}

/// Ditto
pure Unqual!(CommonType!(F1, F2))
dotProduct(F1, F2, size_t n, size_t n2)(ref const F1[n] avector, ref const F2[n2] bvector)
if (isStaticArray!(typeof(avector)) && isStaticArray!(typeof(bvector)))
{
    // Fixed-size overload: both arrays are taken by reference and a length
    // mismatch is rejected at compile time instead of at run time.
    static assert(n == n2); // do not move this to the template constraints

    // The block boundaries depend only on n, so they are compile-time constants.
    enum size_t unroll16End = n & ~15;
    enum size_t unroll4End = n & ~3;

    auto lhs = avector.ptr, rhs = bvector.ptr;
    typeof(return) evenSum = 0, oddSum = 0;
    size_t i = 0;

    // 16 elements per iteration, alternating between the two accumulators.
    while (i != unroll16End)
    {
        evenSum += lhs[i] * rhs[i];
        oddSum += lhs[i + 1] * rhs[i + 1];
        evenSum += lhs[i + 2] * rhs[i + 2];
        oddSum += lhs[i + 3] * rhs[i + 3];
        evenSum += lhs[i + 4] * rhs[i + 4];
        oddSum += lhs[i + 5] * rhs[i + 5];
        evenSum += lhs[i + 6] * rhs[i + 6];
        oddSum += lhs[i + 7] * rhs[i + 7];
        evenSum += lhs[i + 8] * rhs[i + 8];
        oddSum += lhs[i + 9] * rhs[i + 9];
        evenSum += lhs[i + 10] * rhs[i + 10];
        oddSum += lhs[i + 11] * rhs[i + 11];
        evenSum += lhs[i + 12] * rhs[i + 12];
        oddSum += lhs[i + 13] * rhs[i + 13];
        evenSum += lhs[i + 14] * rhs[i + 14];
        oddSum += lhs[i + 15] * rhs[i + 15];
        i += 16;
    }

    // Then 4 elements per iteration.
    while (i != unroll4End)
    {
        evenSum += lhs[i] * rhs[i];
        oddSum += lhs[i + 1] * rhs[i + 1];
        evenSum += lhs[i + 2] * rhs[i + 2];
        oddSum += lhs[i + 3] * rhs[i + 3];
        i += 4;
    }

    // Fold the two partial sums together before handling the remainder.
    evenSum += oddSum;

    // At most three trailing elements are left; add them one by one.
    while (i != n)
    {
        evenSum += lhs[i] * rhs[i];
        ++i;
    }

    return evenSum;
}

// Exercises every dotProduct overload: slices, generic ranges, fixed-size
// arrays, and the compile-time rejection of mismatched fixed-size lengths.
unittest
{
    // Empty slices yield zero.
    double[] a0, b0;
    assert(dotProduct(a0, b0) == 0);

    // Slice overload (array literals) and generic range overload (iota).
    assert(dotProduct([1.0, 2.0], [4.0, 6.0]) == 16.0);
    assert(dotProduct([1, 3, -5], [4, -2, -1]) == 3);
    assert(dotProduct(iota(1, 5), iota(10, 41, 10)) == 300);

    // Fixed-size overload.
    int[4] a1 = [1, 2, 3, 4];
    int[4] b1 = [10, 20, 30, 40];
    assert(dotProduct(a1, b1) == 300);

    // Fixed-size array mixed with a slice.
    int[] c1 = [10, 20, 30, 40];
    assert(dotProduct(a1, c1) == 300);

    // Fixed-size arrays of different lengths must be rejected at compile time.
    int[5] c2 = [10, 20, 30, 40, 0];
    assert(!__traits(compiles, { dotProduct(a1, c2); } )); // can't compile

    // 21 elements: one 16-wide block, one 4-wide block and one trailing element.
    int[21] ones = 1;
    int[21] twos = 2;
    assert(dotProduct(ones, twos) == 42);

    // more unittests needed
}
@dlangBugzillaToGithub
Copy link
Author

bearophile_hugs commented on 2011-04-24T12:39:28Z

This doesn't work yet:

// NOTE(review): none of the proposed overloads accepts this call. The two array
// overloads need both arguments to be arrays, which iota's result is not, and the
// generic overload requires isInputRange for both arguments, which a fixed-size
// array such as int[4] presumably does not satisfy. Confirm before relying on this.
int[4] a1 = [1, 2, 3, 4];
assert(dotProduct(a1, iota(10, 41, 10)) == 300);

@LightBender LightBender removed the P4 label Dec 6, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

No branches or pull requests

2 participants