In [9]:
import numpy as np

# Problem size and the (simulated) rank decomposition.
elems_per_side = 4
# Length of the innermost axis of the transposition buffers. Equal to elems_per_side
# because every transform below is a complex-to-complex np.fft.fft/ifft.
# NOTE(review): presumably kept separate to allow a real-to-complex layout later.
comp_elems_per_side = elems_per_side
num_ranks = 2
elems_per_plane = elems_per_side * elems_per_side
elems_per_domain = elems_per_plane * elems_per_side

# Every rank must own the same number of planes: a remainder would silently drop
# planes from the split and break every buffer reshape further down.
if elems_per_side % num_ranks != 0:
    raise ValueError(
        f"elems_per_side ({elems_per_side}) must be divisible by num_ranks ({num_ranks})"
    )
planes_per_rank = elems_per_side // num_ranks

# Test input: a cube holding 0, 1, ..., elems_per_domain - 1.
numbers_3d = np.arange(elems_per_domain, dtype=float).reshape(elems_per_side, elems_per_side, elems_per_side)

# Slab decomposition along axis 0: rank r owns planes [r*planes_per_rank, (r+1)*planes_per_rank).
# np.split returns a list of views, exactly like slicing by hand.
split_arrays = np.split(numbers_3d, num_ranks, axis=0)

In [10]:
# Now as functions

def prepare_send_buffer_1_2_0(input_buffer):
    """Pack every rank's local block into its send buffer using axis order (1, 2, 0).

    Parameters
    ----------
    input_buffer : sequence of 3-D arrays, or a 4-D array
        One block per rank, indexed ``[rank][plane, num, comp]``.

    Returns
    -------
    numpy.ndarray of complex
        Shape ``(ranks, num, comp, planes)`` with
        ``send[rank][num, comp, plane] == input_buffer[rank][plane, num, comp]``.
        Axis 1 (``num``) is the axis ``communicate_all_to_all`` cuts into
        per-destination chunks.

    All sizes are read from ``input_buffer`` rather than from notebook globals,
    so the function works (and can be tested) for any block shape.
    """
    blocks = np.asarray(input_buffer)
    n_ranks, n_planes, n_num, n_comp = blocks.shape
    send_buffer = np.zeros((n_ranks, n_num, n_comp, n_planes), dtype=complex)

    # Explicit loop nest kept on purpose: it spells out the index mapping.
    for rank in range(n_ranks):
        for idx_p in range(n_planes):
            for idx_comp in range(n_comp):
                for idx_num in range(n_num):
                    send_buffer[rank][idx_num, idx_comp, idx_p] = blocks[rank][idx_p, idx_num, idx_comp]

    return send_buffer


def prepare_send_buffer_2_1_0(input_buffer):
    """Pack every rank's local block into its send buffer using axis order (2, 1, 0).

    Parameters
    ----------
    input_buffer : sequence of 3-D arrays, or a 4-D array
        One block per rank, indexed ``[rank][plane, comp, num]``.

    Returns
    -------
    numpy.ndarray of complex
        Shape ``(ranks, num, comp, planes)`` with
        ``send[rank][num, comp, plane] == input_buffer[rank][plane, comp, num]``.
        Axis 1 (``num``) is the axis ``communicate_all_to_all`` cuts into
        per-destination chunks.

    All sizes are read from ``input_buffer`` rather than from notebook globals.
    """
    blocks = np.asarray(input_buffer)
    n_ranks, n_planes, n_comp, n_num = blocks.shape
    send_buffer = np.zeros((n_ranks, n_num, n_comp, n_planes), dtype=complex)

    # Explicit loop nest kept on purpose: it spells out the index mapping.
    for rank in range(n_ranks):
        for idx_p in range(n_planes):
            for idx_comp in range(n_comp):
                for idx_num in range(n_num):
                    send_buffer[rank][idx_num, idx_comp, idx_p] = blocks[rank][idx_p, idx_comp, idx_num]

    return send_buffer


def communicate_all_to_all(send_buffer):
    """Simulate an all-to-all exchange between the ranks' send buffers.

    Axis 1 of ``send_buffer`` (axis 0 of each rank's block) is cut into one
    equal chunk per rank. Rank ``s`` sends its ``d``-th chunk to rank ``d``, and
    rank ``d`` stores the chunk received from ``s`` at chunk position ``s``.

    Parameters
    ----------
    send_buffer : array-like, shape ``(ranks, rows, ...)``
        ``rows`` must be divisible by ``ranks``.

    Returns
    -------
    numpy.ndarray of complex with the same shape as ``send_buffer``.

    Raises
    ------
    ValueError
        If there are no ranks or ``rows`` is not divisible by the number of ranks.
    """
    send = np.asarray(send_buffer)
    n_ranks, n_rows = send.shape[0], send.shape[1]
    if n_ranks == 0 or n_rows % n_ranks:
        raise ValueError(f"cannot split {n_rows} rows evenly across {n_ranks} ranks")
    chunk = n_rows // n_ranks

    recv_buffer = np.zeros(send.shape, dtype=complex)
    for src_rank in range(n_ranks):
        for dst_rank in range(n_ranks):
            src_slice = slice(dst_rank * chunk, (dst_rank + 1) * chunk)
            dst_slice = slice(src_rank * chunk, (src_rank + 1) * chunk)
            recv_buffer[dst_rank][dst_slice] = send[src_rank][src_slice]

    return recv_buffer


def reorganize_recv_buffer_1_2_0(recv_buffer):
    """Unpack each rank's receive buffer into its new local block.

    ``recv_buffer[rank]`` has shape ``(num, comp, planes)``. Row ``num`` came from
    source rank ``num // rows_per_rank`` and is local row ``num % rows_per_rank``.
    Plane ``p`` of that source becomes position ``src * planes + p`` of the
    output's last axis.

    Parameters
    ----------
    recv_buffer : array-like, shape ``(ranks, num, comp, planes)``
        ``num`` must be divisible by ``ranks``.

    Returns
    -------
    numpy.ndarray of complex
        Shape ``(ranks, num // ranks, comp, ranks * planes)``.

    Raises
    ------
    ValueError
        If there are no ranks or ``num`` is not divisible by the number of ranks.
    """
    recv = np.asarray(recv_buffer)
    n_ranks, n_num, n_comp, n_planes = recv.shape
    if n_ranks == 0 or n_num % n_ranks:
        raise ValueError(f"cannot split {n_num} rows evenly across {n_ranks} ranks")
    rows_per_rank = n_num // n_ranks

    final_buffer = np.zeros((n_ranks, rows_per_rank, n_comp, n_ranks * n_planes), dtype=complex)

    for rank in range(n_ranks):
        for idx_p in range(n_planes):
            for idx_comp in range(n_comp):
                for idx_num in range(n_num):
                    src_rank = idx_num // rows_per_rank
                    global_plane = idx_p + src_rank * n_planes
                    local_row = idx_num % rows_per_rank
                    final_buffer[rank][local_row, idx_comp, global_plane] = recv[rank][idx_num, idx_comp, idx_p]

    return final_buffer


def reorganize_recv_buffer_2_1_0(recv_buffer):
    """Unpack each rank's receive buffer after the (2, 1, 0) exchange.

    The receive side of this exchange uses exactly the same index mapping as
    ``reorganize_recv_buffer_1_2_0``. Delegate to it instead of keeping a second
    verbatim copy of the loop nest that could drift out of sync.

    Returns
    -------
    numpy.ndarray of complex
        Whatever ``reorganize_recv_buffer_1_2_0`` returns for ``recv_buffer``.
    """
    return reorganize_recv_buffer_1_2_0(recv_buffer)


def real_dist_transposition_1_2_0(input_buffer):
    """Distributed (1, 2, 0) transposition: pack, exchange between ranks, unpack.

    Returns the per-rank blocks produced by ``reorganize_recv_buffer_1_2_0``.
    """
    return reorganize_recv_buffer_1_2_0(
        communicate_all_to_all(
            prepare_send_buffer_1_2_0(input_buffer)
        )
    )


def real_dist_transposition_2_1_0(input_buffer):
    """Distributed (2, 1, 0) transposition: pack, exchange between ranks, unpack.

    Returns the per-rank blocks produced by ``reorganize_recv_buffer_2_1_0``.
    """
    # Each stage consumes the previous stage's output.
    stages = (prepare_send_buffer_2_1_0, communicate_all_to_all, reorganize_recv_buffer_2_1_0)
    data = input_buffer
    for stage in stages:
        data = stage(data)
    return data


def real_dist_transposition_0_2_1(input_buffer):
    """Rank-local (0, 2, 1) transposition: swap the last two axes of every block.

    No data moves between ranks here.

    Parameters
    ----------
    input_buffer : sequence of 3-D arrays, or a 4-D array
        One block per rank, indexed ``[rank][plane, comp, num]``.

    Returns
    -------
    numpy.ndarray of complex
        Shape ``(ranks, planes, num, comp)`` with
        ``out[rank][plane, num, comp] == input_buffer[rank][plane, comp, num]``.

    All sizes are read from ``input_buffer`` rather than from notebook globals.
    """
    blocks = np.asarray(input_buffer)
    n_ranks, n_planes, n_comp, n_num = blocks.shape
    final_buffer = np.zeros((n_ranks, n_planes, n_num, n_comp), dtype=complex)

    for rank in range(n_ranks):
        for idx_p in range(n_planes):
            for idx_comp in range(n_comp):
                for idx_num in range(n_num):
                    final_buffer[rank][idx_p, idx_num, idx_comp] = blocks[rank][idx_p, idx_comp, idx_num]

    return final_buffer

In [11]:
from enum import Enum

# Transform direction consumed by exec_fft / dist_fft:
# FORWARD selects np.fft.fft, BACKWARD selects np.fft.ifft.
fft_type = Enum("fft_type", {"FORWARD": 0, "BACKWARD": 1})

# When True, the verification cells also print the full arrays.
print_full_data = False

In [12]:
# FFT verification: single-process reference round trip with NumPy's n-D transforms.
print(f"Initial shape: {numbers_3d.shape} and dtype: {numbers_3d.dtype}\n")

fft_result = np.fft.fftn(numbers_3d)
# The input is real, so keep the real part of the inverse transform.
# abs() would return magnitudes and flip the sign of any negative input value.
ifft_result = np.fft.ifftn(fft_result).real

if print_full_data:
    print(numbers_3d)
print(f"Regular FFT result shape: {fft_result.shape} and dtype: {fft_result.dtype}")
if print_full_data:
    print(fft_result)
print(f"Regular iFFT result shape: {ifft_result.shape} and dtype: {ifft_result.dtype}")
if print_full_data:
    print(ifft_result)
# A floating-point round trip is only bit-exact by coincidence (here: small integer
# input), so compare with a tolerance to keep the check meaningful for other data.
print(f"Verification of regular FFTs: Arrays match (allclose) after forward and backward: {np.allclose(numbers_3d, ifft_result)}\n")

def exec_fft(input_buffer, type):
    """Apply a 1-D FFT to every row (along the last axis) of every rank's block.

    Parameters
    ----------
    input_buffer : sequence of arrays, or an ndarray
        Per-rank blocks. Every 1-D row along the last axis is transformed
        independently, as each rank would do locally.
    type : fft_type
        ``fft_type.FORWARD`` uses ``np.fft.fft``; ``fft_type.BACKWARD`` uses
        ``np.fft.ifft``. (The name shadows the builtin ``type``; it is kept
        for caller compatibility.)

    Returns
    -------
    numpy.ndarray of complex with the same shape as ``input_buffer``.

    Raises
    ------
    ValueError
        If ``type`` is neither FORWARD nor BACKWARD, instead of silently
        returning an all-zero buffer.
    """
    if type == fft_type.FORWARD:
        transform = np.fft.fft
    elif type == fft_type.BACKWARD:
        transform = np.fft.ifft
    else:
        raise ValueError(f"unsupported fft type: {type!r}")

    blocks = np.asarray(input_buffer)
    result = np.zeros_like(blocks, dtype=complex)
    # Visit (rank, plane, row, ...) in C order and transform the 1-D row there.
    # This covers the actual extents of the input instead of global sizes.
    for row_index in np.ndindex(blocks.shape[:-1]):
        result[row_index] = transform(blocks[row_index])

    return result
                
                
def dist_fft(input_buffer, type):
    """3-D FFT of slab-decomposed data via three 1-D passes and transpositions.

    Pipeline:
    1. FFT along the last axis.
    2. Distributed (1, 2, 0) transposition, then FFT.
    3. Distributed (2, 1, 0) transposition, then FFT.
    4. Rank-local (0, 2, 1) transposition, which brings the axes back to the
       order of the input blocks.

    Parameters
    ----------
    input_buffer : sequence of per-rank blocks, or a 4-D array
    type : fft_type
        Direction, forwarded to every ``exec_fft`` pass.

    Returns
    -------
    numpy.ndarray
        Complex per-rank blocks for FORWARD. For BACKWARD, the real part is
        returned.

    Note
    ----
    The BACKWARD branch assumes the original forward input was real. The real
    part is taken instead of ``abs()``, which would flip negative values.
    """
    result = exec_fft(input_buffer, type)
    result = real_dist_transposition_1_2_0(result)
    result = exec_fft(result, type)
    result = real_dist_transposition_2_1_0(result)
    result = exec_fft(result, type)
    result = real_dist_transposition_0_2_1(result)

    if type == fft_type.BACKWARD:
        result = np.real(result)

    return result
        
        
# The distributed pipeline transforms the axes in a different sequence (x, then z,
# then y, with transpositions in between) than a single n-D call. Results therefore
# only need to agree to floating-point tolerance, not bit for bit.
dist_fft_result = dist_fft(split_arrays, fft_type.FORWARD)
fft_dist_res = np.concatenate(dist_fft_result, axis=0)
print(f"Distributed FFT result shape: {fft_dist_res.shape} and dtype: {fft_dist_res.dtype}")
if print_full_data:
    print(fft_dist_res)
print(f"Verification of distributed FFT: Arrays match (allclose): {np.allclose(fft_result, fft_dist_res)}\n")

dist_ifft_result = dist_fft(dist_fft_result, fft_type.BACKWARD)
ifft_dist_res = np.concatenate(dist_ifft_result, axis=0)
print(f"Distributed iFFT result shape: {ifft_dist_res.shape} and dtype: {ifft_dist_res.dtype}")
if print_full_data:
    print(ifft_dist_res)
print(f"Verification of distributed iFFT: Arrays match (allclose): {np.allclose(ifft_result, ifft_dist_res)}\n")

if print_full_data:
    print(numbers_3d)
print(f"Verification of distributed FFTs: Arrays match (allclose) after forward and backward: {np.allclose(numbers_3d, ifft_dist_res)}")

Initial shape: (4, 4, 4) and dtype: float64

Regular FFT result shape: (4, 4, 4) and dtype: complex128
Regular iFFT result shape: (4, 4, 4) and dtype: float64
Verification of regular FFTs: Arrays are identical after forward and backward: True

Distributed FFT result shape: (4, 4, 4) and dtype: complex128
Verification of distributed FFT: Arrays are identical: True

Distributed iFFT result shape: (4, 4, 4) and dtype: float64
Verification of distributed iFFT: Arrays are identical: True

Verification of distributed FFTs: Arrays are identical after forward and backward: True
