<a href="https://colab.research.google.com/github/dmmdanielg/ImageProcessingColab/blob/main/ImageProcessingGitHub.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import multiprocessing
import os

# Number of CPUs the OS reports for this machine.
cores = multiprocessing.cpu_count()
print('cores', cores)

# Ask XLA to expose the host CPU as this many *virtual* devices so that
# `pmap` has several devices to shard over on a CPU-only machine. These
# devices share the `cores` above; they add no extra hardware parallelism.
# Must be set before JAX initialises its backends.
HOST_DEVICE_COUNT = 10
_host_device_flag = f"--xla_force_host_platform_device_count={HOST_DEVICE_COUNT}"
# Append to (rather than overwrite) any XLA flags that are already set.
os.environ['XLA_FLAGS'] = " ".join(
    part for part in (os.environ.get('XLA_FLAGS', ''), _host_device_flag) if part
)

import jax.numpy as jnp
import numpy as np
from jax import grad, jit, vmap, pmap, random, block_until_ready, config
from jax.lib import xla_bridge
from jax.scipy.signal import convolve2d as convolve2d_jax
from jax.config import config
from scipy.signal import convolve2d
from matplotlib import pyplot as plt
from functools import partial
from jax.numpy.fft import fft2 as jfft2
from jax.numpy.fft import ifft2 as jifft2
from jax.numpy.fft import fftshift as jfftshift
from jax.numpy.fft import rfft2 as jrfft2
from jax.numpy.fft import irfft2 as jrifft2
from scipy.fft import rfft2, irfft2, fftshift, fft2, ifft2

import time
from dataclasses import dataclass
from jax import device_count, devices
# Switch JAX to 64-bit floats/ints so its results can be compared with the
# float64 NumPy/SciPy reference implementation.
config.update("jax_enable_x64", True)

# Report what JAX can see (includes any virtual CPU devices requested
# through XLA_FLAGS before JAX was initialised).
for _label, _probe in (("device_count", device_count), ("devices", devices)):
    print(_label, _probe())




def processImage(image, alpha, beta, array1, array2):
    """Filter ``image`` with a normalised 5x5 kernel via real FFTs (NumPy/SciPy).

    A scalar ``sol`` is derived from ``image``, ``alpha``, ``beta`` and the two
    auxiliary arrays. It scales the separable kernel ``M[i, j] = (1+i)*(1+j)``.
    The kernel is then divided by its own maximum, so only the *sign* of
    ``sol`` affects the result:

    * ``sol > 0``: kernel is ``M / 25``.
    * ``sol < 0``: kernel is ``M``.
    * ``sol == 0``: kernel is NaN (0 / 0).

    The kernel is applied by multiplying real FFT spectra. This is a
    *circular* convolution: the kernel is zero-padded to ``image.shape`` and
    wraps around at the image edges.

    Parameters
    ----------
    image : 2-D array. ``image[i, i]`` is read for every row index ``i``, so
        ``image.shape[1] >= image.shape[0]`` is required.
    alpha, beta : scalars used in the ``var1`` term.
    array1, array2 : 2-D arrays combined element-wise and by ``@ .T``, so
        they must have matching shapes. Only the diagonal of the resulting
        square product is used.

    Returns
    -------
    ndarray of the same shape as ``image``.
    """
    mean = np.mean(image)

    i = np.arange(0, image.shape[0])
    var1 = (i * np.exp(alpha) + beta*beta*(beta + (i - 30)) + image[i, i]*0.1)*mean
    # Full (n, n) product, of which only the diagonal is used below.
    var2 = np.matmul(array1, array2.T) + np.matmul(array1, (array2*array2*(array2+array1)).T)

    sol = np.sum(np.diagonal(var2) * np.sum(var1))

    kernel = sol * np.fromfunction(lambda i, j: (1+i)*(1+j), (5, 5), dtype=int)
    kernel2 = kernel/np.max(kernel)

    image_spectrum = rfft2(image, s=image.shape)
    kernel_spectrum = rfft2(kernel2, s=image.shape)
    filtered_spectrum = image_spectrum * kernel_spectrum
    # Pass ``s`` explicitly: without it irfft2 assumes an even-length last
    # axis (2*(m-1) samples) and returns the wrong width for odd-width images.
    return irfft2(filtered_spectrum, s=image.shape)

@jit
def processImage_jax(image, alpha, beta, array1, array2):
    """JAX/XLA (jit-compiled) counterpart of the NumPy ``processImage``.

    Builds the 5x5 kernel ``sol * (1+i)*(1+j)``, normalises it by its maximum
    (so only the sign of ``sol`` matters; ``sol == 0`` yields NaN), and applies
    it to ``image`` by multiplying real FFT spectra. This is a *circular*
    convolution: the kernel is zero-padded to ``image.shape`` and wraps at
    the edges.

    Parameters
    ----------
    image : 2-D array. ``image[i, i]`` is read for every row index, so
        ``image.shape[1] >= image.shape[0]`` is required.
    alpha, beta : scalars used in the ``var1`` term.
    array1, array2 : 2-D arrays with matching shapes. Only the diagonal of
        their square product is used.

    Returns
    -------
    jax Array with the same shape as ``image``.
    """
    mean = jnp.mean(image)

    i = jnp.arange(0, image.shape[0])
    var1 = (i * jnp.exp(alpha) + beta*beta*(beta + (i - 30)) + image[i, i]*0.1)*mean
    # Full (n, n) product, of which only the diagonal is used below.
    var2 = jnp.matmul(array1, array2.T) + jnp.matmul(array1, (array2*array2*(array2+array1)).T)

    sol = jnp.sum(jnp.diagonal(var2) * jnp.sum(var1))

    # The (1+i)*(1+j) pattern is a trace-time constant; only ``sol`` is traced.
    kernel = sol * np.fromfunction(lambda i, j: (1+i)*(1+j), (5, 5), dtype=int)
    kernel2 = kernel/jnp.max(kernel)

    image_spectrum = jrfft2(image, s=image.shape)
    kernel_spectrum = jrfft2(kernel2, s=image.shape)
    filtered_spectrum = image_spectrum * kernel_spectrum
    # Pass ``s`` explicitly: without it irfft2 assumes an even-length last
    # axis and returns the wrong width for odd-width images.
    return jrifft2(filtered_spectrum, s=image.shape)
        
def processImage_batch_jax(batch_images, alpha, beta, array1, array2):
    """Run ``processImage_jax`` on each image of a stack, one call per image.

    ``batch_images`` is indexed along its first axis. Each per-image result
    is written into a freshly allocated float64 NumPy array of shape
    ``batch_images.shape[:3]``, which is returned.
    """
    n_images, height, width = batch_images.shape[:3]
    stacked = np.zeros((n_images, height, width))
    for idx, frame in enumerate(batch_images):
        stacked[idx] = processImage_jax(frame, alpha, beta, array1, array2)
    return stacked
        
def processImage_batch(batch_images, alpha, beta, array1, array2):
    """Run the NumPy ``processImage`` on each image of a stack, one call per image.

    ``batch_images`` is indexed along its first axis. Each per-image result
    is written into a freshly allocated float64 array of shape
    ``batch_images.shape[:3]``, which is returned.
    """
    n_images, height, width = batch_images.shape[:3]
    stacked = np.zeros((n_images, height, width))
    for idx, frame in enumerate(batch_images):
        stacked[idx] = processImage(frame, alpha, beta, array1, array2)
    return stacked

np.random.seed(5)
image = (np.random.rand(1000,1000))
batch_images = ((np.random.rand(40,500,500)))
batch_images_pmap = np.reshape(batch_images,(10,4,500,500))

array1 = (np.random.rand(100,1000))
array2 = (np.random.rand(100,1000))

process_image_jax_vmap = vmap(processImage_jax, in_axes = (0,None,None,None,None))
process_image_jax_vmap_jit = jit(process_image_jax_vmap)
process_image_jax_pmap = pmap(process_image_jax_vmap, in_axes = (0,None,None,None,None))

%timeit -n 7 -r 5 processImage_batch_jax(batch_images, 0.5, 0.3, array1, array2)
%timeit -n 7 -r 3 processImage_batch(batch_images, 0.5, 0.3, array1, array2)
%timeit -n 7 -r 5 process_image_jax_vmap_jit(batch_images, 0.5, 0.3, array1, array2).block_until_ready()
%timeit -n 7 -r 5 process_image_jax_vmap(batch_images, 0.5, 0.3, array1, array2).block_until_ready()
%timeit -n 7 -r 5 process_image_jax_pmap(batch_images_pmap, 0.5, 0.3, array1, array2).block_until_ready()

batch_images_results = processImage_batch(batch_images, 0.5, 0.3, array1, array2)

batch_images_results_vmap_jax= process_image_jax_vmap_jit(batch_images, 0.5, 0.3, array1, array2)
batch_images_results_jax = processImage_batch_jax(batch_images, 0.5, 0.3, array1, array2)
batch_images_results_numpy = processImage_batch(batch_images, 0.5, 0.3, array1, array2)
batch_images_results_jax_2 = process_image_jax_pmap(batch_images_pmap, 0.5, 0.3, array1, array2)
batch_images_results_jax_2 = np.vstack(batch_images_results_jax_2)

print(np.allclose(batch_images_results_vmap_jax,batch_images_results))
print(np.allclose(batch_images_results_jax_2,batch_images_results))



cores 2




device_count 10
devices [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7), CpuDevice(id=8), CpuDevice(id=9)]
1.16 s ± 636 ms per loop (mean ± std. dev. of 5 runs, 10 loops each)
