# CNN for Image Denoising

## GPU implementation
We used JAX to run the code on the GPU, relying in particular on the optimized convolution provided by the `jax.lax` module. `lax` exposes the primitive operations on which higher-level modules such as `jax.numpy` are built. Specifically, we used the `lax.conv_general_dilated` function, a general n-dimensional convolution operator with optional dilation.
First, we reshape the input image by adding two extra dimensions at the beginning: the batch size and the number of channels. The input is a grayscale image, so it has a single channel, and the same holds for the output image. Since the batch size is also 1, both the input and the output tensors have shape `(1, 1, 28, 28)`.
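The reshape step can be illustrated as follows. This is a minimal sketch: the `image` array is a hypothetical placeholder for one 28x28 MNIST digit, not data from our experiments.

```python
import jax.numpy as jnp

# Hypothetical 28x28 grayscale image standing in for one MNIST digit.
image = jnp.arange(28 * 28, dtype=jnp.float32).reshape(28, 28) / (28 * 28)

# Prepend the batch (N) and channel (C) axes: (H, W) -> (N, C, H, W).
x = image.reshape(1, 1, 28, 28)   # equivalently: image[None, None, :, :]

print(x.shape)  # (1, 1, 28, 28)
```

Indexing with `None` is an equivalent alternative that does not require spelling out the spatial size.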
The same operation is applied to the kernel, changing its shape to `(1, 1, 3, 3)` (output channels, input channels, height, width). We use a stride of 1, since the denoising process must not skip any pixel. The padding scheme is `'SAME'`, so that the output image has the same size as the input image, and no dilation is applied. Finally, we reshape the output tensor back to the original shape of the input.
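The full convolution step described above can be sketched as follows. This is an illustrative sketch under stated assumptions, not our exact code: the helper name `denoise_conv` and the 3x3 mean-filter kernel are hypothetical (in the CNN the kernel weights are learned). It relies on the default layout of `lax.conv_general_dilated`, which is NCHW for the input and OIHW for the kernel.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def denoise_conv(image, kernel):
    """Hypothetical helper: convolve one (H, W) grayscale image with a (kh, kw) kernel."""
    h, w = image.shape
    x = image.reshape(1, 1, h, w)              # input:  (N, C, H, W)
    k = kernel.reshape(1, 1, *kernel.shape)    # kernel: (O, I, kh, kw)
    y = lax.conv_general_dilated(
        x, k,
        window_strides=(1, 1),   # stride 1: visit every pixel
        padding="SAME",          # output spatial size == input spatial size
        lhs_dilation=(1, 1),     # no input dilation
        rhs_dilation=(1, 1),     # no kernel dilation
    )
    return y.reshape(h, w)                     # back to the original (H, W) shape

image = jax.random.uniform(jax.random.PRNGKey(0), (28, 28))
kernel = jnp.full((3, 3), 1.0 / 9.0)           # hypothetical 3x3 mean filter
out = denoise_conv(image, kernel)
print(out.shape)  # (28, 28)
```

With `'SAME'` padding and a 3x3 kernel, the image is zero-padded by one pixel on each side, so border pixels are averaged with zeros.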
The convolution implemented in the `lax` module proved far more efficient than our hand-written one, since on NVIDIA GPUs it is automatically dispatched to the cuDNN library.
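When replacing a hand-written convolution with `lax.conv_general_dilated`, it is worth checking that both produce the same result. The sketch below assumes a hypothetical hand-written version, `handwritten_conv_same`, that accumulates shifted slices of a zero-padded image; it is not the hand-written implementation we actually benchmarked. Note that `lax.conv_general_dilated` computes a cross-correlation (the kernel is not flipped), so a non-symmetric random kernel is used to make any flip mismatch visible.

```python
import jax
import jax.numpy as jnp
from jax import lax

def handwritten_conv_same(image, kernel):
    """Hypothetical hand-written 'SAME' cross-correlation built from shifted slices."""
    kh, kw = kernel.shape
    h, w = image.shape
    # Zero-pad with the same low/high split that lax uses for 'SAME' at stride 1.
    ph, pw = (kh - 1) // 2, (kw - 1) // 2
    padded = jnp.pad(image, ((ph, kh - 1 - ph), (pw, kw - 1 - pw)))
    out = jnp.zeros_like(image)
    for di in range(kh):
        for dj in range(kw):
            # Contribution of kernel entry (di, dj) to every output pixel at once.
            out = out + kernel[di, dj] * padded[di:di + h, dj:dj + w]
    return out

def lax_conv_same(image, kernel):
    """Same operation via lax.conv_general_dilated (NCHW input, OIHW kernel)."""
    y = lax.conv_general_dilated(image[None, None], kernel[None, None], (1, 1), "SAME")
    return y[0, 0]

key_img, key_ker = jax.random.split(jax.random.PRNGKey(0))
image = jax.random.uniform(key_img, (28, 28))
kernel = jax.random.normal(key_ker, (3, 3))   # non-symmetric: a flip would be detected

ref = lax_conv_same(image, kernel)
mine = handwritten_conv_same(image, kernel)
print(jnp.max(jnp.abs(ref - mine)))           # small: float32 rounding only
```

The two results agree up to floating-point rounding. The performance gap comes from the fact that the `lax` version is lowered by XLA to a single optimized convolution kernel, whereas a hand-written version issues many small operations.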

## Performance Tests
First, note that the input is always a single 28x28 image, which by itself does not offer a high degree of parallelism. Nonetheless, JAX's GPU support proved very efficient for the kind of operation we apply to the image, i.e. a convolution. The CPU version is very slow, which is why we ran it only once: with such long execution times, run-to-run fluctuations are small compared to the total, so the relative standard deviation is expected to be very low. For the GPU versions, we performed 30 runs for each number of iterations and computed the mean and the standard deviation of the execution times. The results are shown in the table below.
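The timing procedure can be sketched as follows. This is a minimal sketch, not our actual benchmark harness: the workload (`n_iterations` repeated applications of a 3x3 convolution) and the helpers `conv_step`, `run` and `benchmark` are hypothetical stand-ins for whatever the number of iterations controls in the real experiments. Two details matter when timing JAX on a GPU: the first call triggers JIT compilation, so a warm-up run excludes it, and JAX dispatches work asynchronously, so `block_until_ready()` is needed before the clock is stopped.

```python
import statistics
import time

import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def conv_step(image, kernel):
    """One 'SAME' convolution of an (H, W) image with a (kh, kw) kernel."""
    y = lax.conv_general_dilated(image[None, None], kernel[None, None], (1, 1), "SAME")
    return y[0, 0]

def run(image, kernel, n_iterations):
    """Hypothetical workload: apply the convolution n_iterations times."""
    for _ in range(n_iterations):
        image = conv_step(image, kernel)
    return image

def benchmark(image, kernel, n_iterations, n_runs=30):
    """Return mean and sample standard deviation (seconds) over n_runs timed runs."""
    run(image, kernel, 1).block_until_ready()                 # warm-up: excludes JIT compilation
    times = []
    for _ in range(n_runs):
        start = time.perf_counter()
        run(image, kernel, n_iterations).block_until_ready()  # wait for async dispatch to finish
        times.append(time.perf_counter() - start)
    return statistics.mean(times), statistics.stdev(times)

image = jax.random.uniform(jax.random.PRNGKey(0), (28, 28))
kernel = jnp.full((3, 3), 1.0 / 9.0)
for n in (10, 100):
    mean, std = benchmark(image, kernel, n)
    print(f"{n:>4} iterations: {mean * 1e3:.3f} ms +/- {std * 1e3:.3f} ms")
```

Only the computation inside `run` is timed; data loading, noise generation and plotting stay outside the measured region.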
Note that we did not measure the time needed to fetch the image from the MNIST dataset, to apply salt-and-pepper noise to it, or to produce the plot and write it to a file. These operations are not relevant to our analysis and would distort the performance comparison of the CNN training and of the subsequent image denoising.