### Acceleration of CNN using MPI and JIT

In [15]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

The following code loads the performance data produced by the job submissions of the different implementations. The goal is a simple analysis of the time needed to train the convolutional layer. The analysis uses 5, 10, and 15 images and 200 epochs; the MPI+JAX runs also cover 20 and 25 images, and the serial baseline includes one configuration with 20 epochs.

In [19]:
# All three CSV files share the same header-less layout, so the column names are defined once.
columns = ["type", "process", "size_dataset", "num_epochs", "tot_time", "iter_time"]
df_jax = pd.read_csv('time_analysis_jax.csv', header=None, names=columns)
df_mpi = pd.read_csv('time_analysis_mpi.csv', header=None, names=columns)
df_serial = pd.read_csv('time_analysis_serial.csv', header=None, names=columns)
# Stack the per-implementation results into a single table.
df_combined = pd.concat([df_jax, df_mpi, df_serial], ignore_index=True)
df_combined.head()

,type,process,size_dataset,num_epochs,tot_time,iter_time
0,mpi_jax,5,5,200,39.914208,0.199549
1,mpi_jax,5,5,200,40.024055,0.200098
2,mpi_jax,5,5,200,40.01478,0.200052
3,mpi_jax,5,5,200,39.969562,0.199825
4,mpi_jax,5,5,200,39.91204,0.199538


The following code computes the mean and the standard deviation of the timings over the 5 runs performed for each implementation and dataset size.

In [20]:
avg_time = df_combined.groupby(['type', 'process', 'size_dataset', 'num_epochs']).mean()
avg_time 

type,process,size_dataset,num_epochs,tot_time,iter_time
mpi,5,5,200,783.121949,3.915565
mpi,10,10,200,778.669797,3.893304
mpi,15,15,200,796.242369,3.956753
mpi_jax,5,5,200,39.966929,0.199813
mpi_jax,10,10,200,40.174792,0.200852
mpi_jax,15,15,200,40.26092,0.201283
mpi_jax,20,20,200,40.399966,0.201978
mpi_jax,25,25,200,40.270402,0.20133
serial,1,5,20,392.333935,19.616696
serial,1,5,200,4002.104321,20.010521


In [21]:
std_deviation = df_combined.groupby(['type', 'process', 'size_dataset', 'num_epochs']).std()
std_deviation

type,process,size_dataset,num_epochs,tot_time,iter_time
mpi,5,5,200,3.355117,0.016775
mpi,10,10,200,2.183629,0.010917
mpi,15,15,200,4.188852,0.021374
mpi_jax,5,5,200,0.053274,0.000266
mpi_jax,10,10,200,0.214462,0.001072
mpi_jax,15,15,200,0.194258,0.000971
mpi_jax,20,20,200,0.288446,0.001442
mpi_jax,25,25,200,0.087213,0.000436
serial,1,5,20,,
serial,1,5,200,,
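
The NaN entries in the serial rows are what pandas returns for the sample standard deviation (`ddof=1`) of a group with a single observation. The sketch below illustrates this on a small subset: the five `mpi_jax` runs listed in the `head()` output, plus one serial row. That the serial configuration has exactly one run is an assumption inferred from the NaN values, and the `runs` and `summary` names are illustrative only.

```python
import pandas as pd

# Five mpi_jax runs copied from the head() output, plus the serial run for
# 5 images and 200 epochs (assumed to be a single run, so its mean is its value).
runs = pd.DataFrame(
    {
        "type": ["mpi_jax"] * 5 + ["serial"],
        "process": [5] * 5 + [1],
        "size_dataset": [5] * 6,
        "num_epochs": [200] * 6,
        "tot_time": [39.914208, 40.024055, 40.01478, 39.969562, 39.91204, 4002.104321],
        "iter_time": [0.199549, 0.200098, 0.200052, 0.199825, 0.199538, 20.010521],
    }
)

# Mean, sample standard deviation, and number of runs per configuration.
# A group with count == 1 has an undefined sample std, reported as NaN.
keys = ["type", "process", "size_dataset", "num_epochs"]
summary = runs.groupby(keys)["tot_time"].agg(["mean", "std", "count"])
print(summary)
```

Adding `count` to the aggregation makes it visible at a glance which configurations were repeated and which were measured only once.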


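From the tables, it is evident that the plain MPI implementation is far slower than the MPI implementation with JIT: about 780 to 796 seconds in total (roughly 3.9 seconds per iteration) versus about 40 seconds in total, a speedup of nearly 20x over 200 epochs. The first iteration of the JIT implementation is computationally expensive because it includes compilation and accounts for most of the approximately 40 seconds, while subsequent iterations are remarkably efficient (around 0.002 seconds). Because `iter_time` is an average over all 200 iterations (approximately `tot_time / num_epochs`), the roughly 0.2 seconds reported in the table is dominated by this one-time cost rather than reflecting the steady-state iteration time. This highlights the JIT implementation's advantage over the more traditional MPI approach for workloads with a high number of iterations.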

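In contrast, the serial version needs about 392 seconds for 20 iterations and about 4,002 seconds for 200 iterations on 5 images (roughly 20 seconds per iteration). For 200 iterations this is about five times slower than MPI with 5 processes, which emphasizes its limitations in handling larger workloads.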
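
The comparison can be made explicit by computing speedup ratios from the averaged total times. The following is a minimal sketch that copies the 200-epoch means from the table into a small DataFrame; the variable names are illustrative and not part of the original notebook.

```python
import pandas as pd

# Mean total times in seconds for 200 epochs, copied from the averaged table.
# Rows are indexed by the number of processes (equal to the number of images).
mean_tot_time = pd.DataFrame(
    {
        "mpi": {5: 783.121949, 10: 778.669797, 15: 796.242369},
        "mpi_jax": {5: 39.966929, 10: 40.174792, 15: 40.260920},
    }
)
mean_tot_time.index.name = "process"

# Speedup of MPI+JAX over plain MPI at the same process count.
mean_tot_time["speedup"] = mean_tot_time["mpi"] / mean_tot_time["mpi_jax"]

# Serial baseline: 5 images, 200 epochs, one process.
serial_tot_time = 4002.104321
speedup_vs_serial = serial_tot_time / mean_tot_time.loc[5, ["mpi", "mpi_jax"]]

print(mean_tot_time.round(2))
print(speedup_vs_serial.round(2))
```

These ratios compare total wall-clock time, so for the JIT runs they include the one-time compilation cost.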

The performance advantage of the JIT approach stems from its ability to compile the convolution and loss functions into optimized machine code during their initial execution. JAX traces each function on its first call and compiles it with XLA; subsequent calls with the same input shapes and types reuse the cached compiled code instead of going back through Python's interpreter, which keeps the overhead of repeated iterations low.
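
The compile-once behaviour can be observed directly by timing the first and second call of a jitted function. The sketch below is a toy illustration, not the notebook's actual training code: `toy_loss`, the array shapes, and the `timed` helper are all hypothetical stand-ins for the convolution and loss functions mentioned above.

```python
import time

import jax
import jax.numpy as jnp


def toy_loss(weights, x, y):
    # Hypothetical stand-in for the notebook's loss: a small dense model with MSE.
    pred = jnp.tanh(x @ weights)
    return jnp.mean((pred - y) ** 2)


# jax.jit returns a wrapper that traces and compiles on the first call,
# then reuses the compiled code for inputs with the same shapes and dtypes.
toy_loss_jit = jax.jit(toy_loss)

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k1, (256, 64))
y = jax.random.normal(k2, (256, 8))
weights = jax.random.normal(k3, (64, 8))


def timed(fn, *args):
    # JAX dispatches work asynchronously, so wait for the result before stopping the clock.
    start = time.perf_counter()
    out = fn(*args)
    out.block_until_ready()
    return out, time.perf_counter() - start


first_out, first_call = timed(toy_loss_jit, weights, x, y)    # includes tracing and compilation
second_out, second_call = timed(toy_loss_jit, weights, x, y)  # reuses the cached executable
eager_out, _ = timed(toy_loss, weights, x, y)                 # un-jitted reference result

print(f"first jitted call:  {first_call:.6f} s")
print(f"second jitted call: {second_call:.6f} s")
```

The absolute timings depend on the hardware and the JAX version, so only the relative difference between the first and second call is meaningful here.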