In [None]:
# Install the script's third-party dependencies into the kernel's environment (%pip, not !pip).
# mpi4py also needs an MPI implementation (e.g. Open MPI or MPICH) that provides `mpirun`.
%pip install matplotlib mpi4py

In [None]:
%%file gridding.py

from mpi4py import MPI
import numpy as np
import matplotlib.pyplot as plt
import sys

# Function to transform coordinates u and v to wavelengths
def transform_coordinates(data, frequencies):
    """Convert the u and v coordinates (columns 0 and 1) to wavelengths.

    Each row's u and v are multiplied by ``frequency / c``. This assumes u and v
    are in metres and frequencies in Hz (TODO confirm against the input file).

    Parameters
    ----------
    data : 2-D array
        Visibility table; columns 0 and 1 hold u and v. All other columns are
        copied unchanged.
    frequencies : array-like
        Per-row frequency, broadcastable against ``data[:, 0]``.

    Returns
    -------
    numpy.ndarray
        A float64 copy of ``data`` with u and v converted. The input array is
        not modified.
    """
    c = 299792458.0  # Speed of light in metres per second
    freqs = np.asarray(frequencies, dtype=np.float64)
    # Work on a float64 copy: never mutate the caller's array, and avoid truncation
    # if the input happened to have an integer dtype.
    converted = np.array(data, dtype=np.float64)
    converted[:, 0] = converted[:, 0] * freqs / c
    converted[:, 1] = converted[:, 1] * freqs / c
    return converted

# Function to read the input file
def read_data(input_file):
    """Read a comma-separated numeric table into a 2-D float64 array.

    Blank or whitespace-only lines are skipped. Every non-blank line must hold
    the same number of comma-separated numbers.

    Parameters
    ----------
    input_file : str or path-like
        Path to the text file.

    Returns
    -------
    numpy.ndarray
        Array of shape (rows, columns). A file with no data lines yields an
        empty array of shape (0, 0).

    Raises
    ------
    ValueError
        If a value is not numeric, or a line has a different column count
        than the first data line. The message includes the line number.
    """
    rows = []
    with open(input_file, 'r') as file:
        # Iterate lazily instead of loading all lines with readlines().
        for line_number, line in enumerate(file, start=1):
            stripped = line.strip()
            if not stripped:
                continue  # tolerate blank / trailing empty lines
            try:
                row = [float(value) for value in stripped.split(',')]
            except ValueError as exc:
                raise ValueError(
                    f"{input_file}:{line_number}: non-numeric value ({exc})"
                ) from exc
            if rows and len(row) != len(rows[0]):
                raise ValueError(
                    f"{input_file}:{line_number}: expected {len(rows[0])} columns, got {len(row)}"
                )
            rows.append(row)
    if not rows:
        return np.empty((0, 0), dtype=np.float64)
    return np.array(rows, dtype=np.float64)

# Function to perform the gridding process
def grid_data(local_data, N, delta_x_rad):
    """Accumulate weighted visibilities onto an N x N regular uv grid.

    The uv cell size follows from the image pixel size:
    ``delta_u = 1 / (N * delta_x_rad)``. With this spacing, an inverse FFT of the
    grid yields an N x N image whose pixels are ``delta_x_rad`` wide.

    Parameters
    ----------
    local_data : 2-D array
        Visibility rows. Columns used:
        - 0: u (expected in wavelengths, as produced by transform_coordinates)
        - 1: v (same units as u)
        - 3: real part
        - 4: imaginary part
        - 5: weight
    N : int
        Number of grid cells per side.
    delta_x_rad : float
        Image pixel size in radians.

    Returns
    -------
    tuple of three (N, N) float64 arrays
        ``(F_r, F_i, W_t)``: per-cell sums of ``w*Re``, ``w*Im`` and ``w``.
        They are indexed ``[u_index, v_index]``, with zero frequency at ``N // 2``.
        Rows whose cell falls outside the grid are ignored.
    """
    local_F_r = np.zeros((N, N), dtype=np.float64)
    local_F_i = np.zeros((N, N), dtype=np.float64)
    local_W_t = np.zeros((N, N), dtype=np.float64)

    # uv cell size (wavelengths) matching an image of N pixels of delta_x_rad each.
    # Dividing u, v by the pixel size itself would push every point off the grid.
    delta_u = 1.0 / (N * delta_x_rad)

    weight = local_data[:, 5]
    ik = np.round(local_data[:, 0] / delta_u).astype(int) + N // 2
    jk = np.round(local_data[:, 1] / delta_u).astype(int) + N // 2

    valid = (ik >= 0) & (ik < N) & (jk >= 0) & (jk < N)
    ik, jk = ik[valid], jk[valid]
    w = weight[valid]

    # np.add.at accumulates repeated (ik, jk) pairs. `grid[ik, jk] += x` would keep
    # only one contribution per cell when several visibilities fall into it.
    np.add.at(local_F_r, (ik, jk), w * local_data[valid, 3])
    np.add.at(local_F_i, (ik, jk), w * local_data[valid, 4])
    np.add.at(local_W_t, (ik, jk), w)

    return local_F_r, local_F_i, local_W_t



# Parse command-line arguments
# Expected (fixed positions): -i <data_file> -d <deltax in arcsec> -N <image_size>
USAGE = "Usage: mpirun -n num_processes python gridding.py -i data_file -d deltax -N image_size"

if len(sys.argv) != 7:
    print(USAGE)
    sys.exit(1)

if sys.argv[1] != '-i':
    print("Error: You must specify the data file using -i")
    sys.exit(1)

if sys.argv[3] != '-d':
    print("Error: You must specify the value of deltax using -d")
    sys.exit(1)

if sys.argv[5] != '-N':
    print("Error: You must specify the image size using -N")
    sys.exit(1)

input_file = sys.argv[2]

# Report malformed numbers with a usage message instead of a raw traceback.
try:
    delta_x_arcsec = float(sys.argv[4])
    N = int(sys.argv[6])
except ValueError:
    print("Error: deltax must be a number and image_size an integer")
    print(USAGE)
    sys.exit(1)

# A zero/negative/non-finite pixel size or image size would later divide by zero
# (uv cell size) or build an empty grid.
if not (np.isfinite(delta_x_arcsec) and delta_x_arcsec > 0) or N <= 0:
    print("Error: deltax and image_size must be positive")
    sys.exit(1)

# Pixel size: arcseconds -> radians
delta_x_rad = (np.pi / 180 / 3600) * delta_x_arcsec
N = int(N)
    
# Initialize MPI
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()

if rank == 0:
    try:
        data = read_data(input_file)
        frequencies = data[:, 6]
        data = np.ascontiguousarray(transform_coordinates(data, frequencies), dtype=np.float64)
    except Exception as exc:
        # Abort the whole job: otherwise the other ranks would block forever in the
        # collective calls below waiting for the root.
        print(f"Error: could not load '{input_file}': {exc}", file=sys.stderr)
        comm.Abort(1)
    n_rows, n_cols = data.shape
else:
    data = None
    n_rows = n_cols = None

# Only the root knows the table's shape; broadcast it so every rank can size its
# receive buffer (non-root ranks have data = None).
n_rows, n_cols = comm.bcast((n_rows, n_cols), root=0)

# Split the rows as evenly as possible: the first (n_rows % size) ranks get one extra
# row, so nothing is dropped when n_rows is not a multiple of the number of processes.
base_rows, extra_rows = divmod(n_rows, size)
rows_per_rank = [base_rows + (1 if r < extra_rows else 0) for r in range(size)]
counts = [rows * n_cols for rows in rows_per_rank]  # Scatterv counts are in elements
displs = [sum(counts[:r]) for r in range(size)]

# Scatter the data from the root process to all processes
recv_data = np.empty((rows_per_rank[rank], n_cols), dtype=np.float64)
send_spec = [data, counts, displs, MPI.DOUBLE] if rank == 0 else None
comm.Scatterv(send_spec, recv_data, root=0)

# Perform the gridding process on all processes
local_F_r, local_F_i, local_W_t = grid_data(recv_data, N, delta_x_rad)

# Gather every rank's partial (N, N) grids onto the root, stacked along a new first
# axis as (size, N, N). The receive buffers must hold all `size` contributions; they
# are only used on the root, so other ranks pass None.
if rank == 0:
    global_F_r = np.empty((size, N, N), dtype=np.float64)
    global_F_i = np.empty((size, N, N), dtype=np.float64)
    global_W_t = np.empty((size, N, N), dtype=np.float64)
else:
    global_F_r = global_F_i = global_W_t = None

comm.Gather(local_F_r, global_F_r, root=0)
comm.Gather(local_F_i, global_F_i, root=0)
comm.Gather(local_W_t, global_W_t, root=0)

# The root process (rank 0) joins the partial results and forms the dirty image
if rank == 0:
    # Sum the per-rank partial grids (stacked along axis 0) into single (N, N) grids
    global_F_r_sum = np.sum(global_F_r, axis=0)
    global_F_i_sum = np.sum(global_F_i, axis=0)
    global_W_t_sum = np.sum(global_W_t, axis=0)

    # Weighted average of the visibilities per cell. Cells with zero total weight are
    # left at 0: a plain division gives NaN (0/0) there, and a single NaN turns the
    # whole inverse FFT into NaN.
    gridded_vis = global_F_r_sum + 1j * global_F_i_sum
    averaged_vis = np.zeros_like(gridded_vis)
    np.divide(gridded_vis, global_W_t_sum, out=averaged_vis, where=global_W_t_sum != 0)

    # The grid's zero frequency sits at index N // 2. ifftshift moves it to index 0
    # for ifft2, and fftshift moves the image centre back to the middle of the array.
    dirty_image = np.fft.fftshift(np.fft.ifft2(np.fft.ifftshift(averaged_vis)))

    # Display the dirty image. The script is launched via `!mpirun` in a subprocess,
    # where plt.show() may not open a window, so the figure is also saved to disk.
    fig, ax = plt.subplots()
    im = ax.imshow(np.abs(dirty_image), origin='lower', extent=[0, N, 0, N])
    fig.colorbar(im, ax=ax)
    ax.set_title('Dirty image')
    ax.set_xlabel('pixel')
    ax.set_ylabel('pixel')
    fig.savefig('dirty_image.png', dpi=150)
    plt.show()

In [None]:
# Run gridding.py on 8 MPI processes: data file hltau_completo_uv.csv, pixel size 0.1 arcsec, 256 x 256 image.
!mpirun -n 8 python gridding.py -i hltau_completo_uv.csv -d 0.1 -N 256