In [1]:
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <immintrin.h>
#include <iostream>

In [2]:
// Return a newly malloc'd N x N row-major matrix holding the transpose of X.
//   N : side length of the square matrix; X must hold at least N*N doubles.
//   X : source matrix in row-major order.
// Returns nullptr when N <= 0 or when the allocation fails. The caller owns the
// returned buffer and must free() it.
double *transpose(const int N, const double *X)
{
    if (N <= 0)
        return nullptr;
    // Do all size/index arithmetic in size_t so N*N cannot overflow int.
    const size_t n = static_cast<size_t>(N);
    double *X_T = static_cast<double *>(malloc(n * n * sizeof(double)));
    if (X_T == nullptr)
        return nullptr;
    for (size_t i = 0; i < n; ++i)
    {
        for (size_t j = 0; j < n; ++j)
        {
            // element (i, j) of X becomes element (j, i) of X_T
            X_T[j * n + i] = X[i * n + j];
        }
    }
    return X_T;
}

In [3]:
// Matrix dimensions. transpose() only handles square N x N input, so rows and cols are kept equal.
int rows = 4, cols = 4;

In [4]:
// Row-major rows x cols storage for the two input matrices; contents are filled in below.
double *A = static_cast<double *>(malloc(rows * cols * sizeof(double)));
double *B = static_cast<double *>(malloc(rows * cols * sizeof(double)));

In [5]:
// initialize A
for (int i = 0; i < rows * cols; i++) {
    A[i] = i;
}
// initialize B
for (int i = 0; i < rows * cols; i++) {
    B[i] = i * -1;
}

In [6]:
// AT is the rows x rows transpose of A: column k of A becomes row k of AT.
auto *AT = transpose(rows, A);

In [7]:
// Top-left corner shared by both panels, followed by the panel shapes:
// the A panel is 3 rows x 2 cols and the B panel is 2 rows x 3 cols.
int start_row = 0, start_col = 0;
int panel_a_rows = 3, panel_a_cols = 2;
int panel_b_rows = 2, panel_b_cols = 3;

In [8]:
// __m256d a0, a1;
/*
    0  1  2  3
    4  5  6  7
    8  9  10 11
    12 13 14 15

we have that start_row is 0 and start_col is 0, so the first element is upper left corner (0, 0) = 0
the panel_a_rows is 3 and panel_a_cols is 2, so we have a panel of 3 rows and 2 columns which means
we are looking at the following elements:
    0  1
    4  5
    8  9

a0 is the first column of the panel
a1 is the second column of the panel
*/
// we do something similar for B
// __m256d b0, b1;
/*
    -0  -1  -2  -3
    -4  -5  -6  -7
    -8  -9  -10 -11
    -12 -13 -14 -15

we have that start_row is 0 and start_col is 0, so the first element is upper left corner (0, 0) = -0
the panel_b_rows is 2 and panel_b_cols is 3, so we have a panel of 2 rows and 3 columns which means
we are looking at the following elements:
    -0  -1  -2
    -4  -5  -6

b0 is the first row of the panel
b1 is the second row of the panel
*/
// notice that the panel lengths (3 rows of A, 3 columns of B) are not multiples of the 4-lane __m256d width, so each packed vector must be padded with zeros

In [9]:
// Zero-filled, 32-byte-aligned staging buffer for two 4-lane __m256d vectors:
// lanes 0-3 will hold the panel's first column of A (a row of AT), lanes 4-7 its second column.
double *a0_a1_pack = static_cast<double *>(aligned_alloc(32, 8 * sizeof(double)));
memset(a0_a1_pack, 0, 8 * sizeof(double));

In [10]:
// Print AT one row per line, then the 8 doubles of the A pack buffer on a single line.
// AT was built by transpose(rows, A), so it is rows x rows: both its row count and
// its row stride are `rows` (using `cols` would read past the buffer when cols > rows).
for (int i = 0; i < rows; i++) {
    for (int j = 0; j < rows; j++) {
        std::cout << AT[i * rows + j] << " ";
    }
    std::cout << std::endl;
}
// print out values of the packed buffer (a0 lanes, then a1 lanes)
for (int i = 0; i < 8; i++) {
    std::cout << a0_a1_pack[i] << " ";
}

0 4 8 12 
1 5 9 13 
2 6 10 14 
3 7 11 15 
0 0 0 0 0 0 0 0 

In [11]:
for (int i = 0; i < panel_a_rows; ++i) {
    a0_a1_pack[i] = AT[start_col * cols + start_row + i];
    a0_a1_pack[i + 4] = AT[start_col * cols + start_row + i + 4];
}

In [12]:
for (int i = 0; i < 8; i++) {
    std::cout << a0_a1_pack[i] << " ";
}

0 4 8 0 1 5 9 0 

In [13]:
// Aligned 256-bit load of lanes 0-3. _mm256_load_pd requires a 32-byte-aligned address,
// which the aligned_alloc(32, ...) buffer provides.
__m256d a0_aligned = _mm256_load_pd(&a0_a1_pack[0]);

In [13]:
// Load the two packed A columns into vectors: a0 <- lanes 0-3, a1 <- lanes 4-7.
// The unaligned-load intrinsic is used here even though the buffer is 32-byte aligned.
__m256d a0 = _mm256_loadu_pd(&a0_a1_pack[0]);
__m256d a1 = _mm256_loadu_pd(&a0_a1_pack[4]);

In [14]:
// Print the four double lanes of v, lowest lane (index 0) first, each followed by a
// space, then end the line.
void print_m256(__m256d v) {
    // Copy the lanes out with memcpy instead of reading through a double* cast of &v;
    // the cast relies on the compiler treating __m256d as may_alias, memcpy is portable.
    double lanes[4];
    memcpy(lanes, &v, sizeof lanes);
    for (int i = 0; i < 4; i++) {
        std::cout << lanes[i] << " ";
    }
    std::cout << std::endl;
}

In [15]:
// Show both packed A columns, one vector per line.
print_m256(a0); print_m256(a1);

0 4 8 0 
1 5 9 0 


In [16]:
// 32-byte-aligned staging buffer for b0 (lanes 0-3) and b1 (lanes 4-7).
double *b0_b1_pack = static_cast<double *>(aligned_alloc(32, 8 * sizeof(double)));

In [17]:
// Zero all 8 lanes so anything the packing loop does not write stays as padding.
memset(b0_b1_pack, 0, sizeof(double) * 8);

In [18]:
for (int i = 0; i < panel_b_cols; ++i) {
    b0_b1_pack[i] = B[start_row * rows + start_col + i];
    b0_b1_pack[i + 4] = B[start_row * rows + start_col + i + 4];
}

In [19]:
for (int i = 0; i < 8; i++) {
    std::cout << b0_b1_pack[i] << " ";
}

0 -1 -2 0 -4 -5 -6 0 

In [20]:
// Scratch inputs for the outer-product experiment, left commented out and unused.
// (A real `result` array is declared further down; uncommenting the last line would clash with it.)
// __m256d atest = {1.0, 2.0, 3.0, 4.0};
// __m256d btest = {5.0, 6.0, 7.0, 8.0};
// __m256d result[4];

In [21]:
// Commented-out reference for a 4x4 outer product: result[i] = a0[i] * b0, built by
// broadcasting each lane of a0 and multiplying by b0. The hard-coded `result` tiles
// further down are written to match how this function's output would be indexed.
// void vector_outer_product(__m256d a0, __m256d b0, __m256d* result) {
//   // Broadcast each element of a0 into a separate vector
//   __m256d a0_broadcasted0 = _mm256_set1_pd( ((double*)&a0)[0] ); 
//   __m256d a0_broadcasted1 = _mm256_set1_pd( ((double*)&a0)[1] ); 
//   __m256d a0_broadcasted2 = _mm256_set1_pd( ((double*)&a0)[2] ); 
//   __m256d a0_broadcasted3 = _mm256_set1_pd( ((double*)&a0)[3] );

//   // Multiply the broadcasted vectors with b0
//   result[0] = _mm256_mul_pd(a0_broadcasted0, b0);
//   result[1] = _mm256_mul_pd(a0_broadcasted1, b0);
//   result[2] = _mm256_mul_pd(a0_broadcasted2, b0);
//   result[3] = _mm256_mul_pd(a0_broadcasted3, b0);
// }

In [61]:
// let's hardcode some values
// result 0 should contain 0,0,0,0
// result 1 should contain 0,-4,-8,0
// result 2 should contain 0,-8,-16,0
// result 3 should contain 0,0,0,0
// __m256d first = _mm256_set_pd(0.0, -0.0, -0.0, -0.0);
// __m256d second = _mm256_set_pd(0.0, -4.0, -8.0, 0.0);
// __m256d third = _mm256_set_pd(0.0, -8.0, -16.0, 0.0);
// __m256d fourth = _mm256_set_pd(0.0, 0.0, 0.0, 0.0);
// for testing purposes we'll hard code distinct values
__m256d first = _mm256_set_pd(4.0, 3.0, 2.0, 1.0);
__m256d second = _mm256_set_pd(8.0, 7.0, 6.0, 5.0);
__m256d third = _mm256_set_pd(12.0, 11.0, 10.0, 9.0);
__m256d fourth = _mm256_set_pd(16.0, 15.0, 14.0, 13.0);
// I understand that the above is not exactly proper, because the values are stored in reverse order
// but this is to reach parity with the commented out vector_outer_product function
// which when indexed, does what you expect: result[0][1] = 2.0, result[0][2] = 3.0, etc.

In [62]:
// Tile rows in order: result[i] is row i of the 4x4 tile that gets scattered into C.
__m256d result[4] = {
    first,
    second,
    third,
    fourth,
};

In [63]:
// Output matrix C (rows x cols), zero-initialised so the scatter loop can accumulate into it with +=.
double *C = static_cast<double *>(calloc(rows * cols, sizeof(double)));

In [64]:
// Destination corner in C, and the size of the tile region to write back (3 x 3 of the 4 x 4 tile).
int write_row = 0, write_col = 0;
int out_rows = 3, out_cols = 3;

In [76]:
((double*)&result[0])[2]

1.0000000

In [72]:
// Scatter-accumulate the top-left out_rows x out_cols corner of the 4x4 tile `result`
// into C at (write_row, write_col), logging each write. C is rows x cols in row-major
// order, so its row stride is `cols`. result holds 4 vectors of 4 lanes, hence the caps at 4.
// NOTE: this uses +=, so running the cell again accumulates the tile into C a second time.
for (int i = 0; i < out_rows && i < 4; ++i) {
    // Copy row i's lanes out with memcpy rather than reading through a double* pun.
    double lanes[4];
    memcpy(lanes, &result[i], sizeof lanes);
    for (int j = 0; j < out_cols && j < 4; ++j) {
        int write_idx = (write_row + i) * cols + write_col + j;
        printf("selecting i: %d, j: %d, ", i, j);
        printf("writing %f to %d\n", lanes[j], write_idx);
        C[write_idx] += lanes[j];
    }
}

selecting i: 0, j: 0, writing 1.000000 to 0
selecting i: 0, j: 1, writing 2.000000 to 1
selecting i: 0, j: 2, writing 1.000000 to 2
selecting i: 1, j: 0, writing 5.000000 to 4
selecting i: 1, j: 1, writing 6.000000 to 5
selecting i: 1, j: 2, writing 5.000000 to 6
selecting i: 2, j: 0, writing 9.000000 to 8
selecting i: 2, j: 1, writing 10.000000 to 9
selecting i: 2, j: 2, writing 9.000000 to 10


In [73]:
// print out C and see if it's correct
for (int i = 0; i < rows * cols; i++) {
    std::cout << C[i] << " ";
}

3 6 3 0 15 18 15 0 27 30 27 0 0 0 0 0 