In [1]:
import pyspark
from pyspark import SparkContext
import imageio
import os
import numpy as np

In [2]:
def readImg(path):
    """Load the image stored at ``path`` and return its pixels.

    The file is decoded with ``imageio.imread`` and the result is converted
    to a NumPy array with dtype ``uint8``.
    """
    return np.array(imageio.imread(path), dtype='uint8')

def writeImg(path,buf):
    """Write the pixel buffer ``buf`` to the file ``path`` via ``imageio.imwrite``."""
    imageio.imwrite(path,buf)

def part_median_filter(local_data):
    """Apply a 3x3 median filter to one horizontal band of an image.

    Parameters
    ----------
    local_data : sequence
        ``[part_id, first, end, buf]`` where ``part_id`` identifies the band,
        ``first`` and ``end`` are the band's row bounds inside ``buf``
        (``end`` exclusive; float values are truncated with ``int``) and
        ``buf`` is the whole image as a NumPy array of shape ``(rows, cols)``
        or ``(rows, cols, channels)``.

    Returns
    -------
    tuple
        ``(part_id, new_buf)`` where ``new_buf`` is a freshly allocated array
        holding the filtered rows ``first .. end-1`` of ``buf`` (same dtype
        as ``buf``).

    Notes
    -----
    * The median is computed independently for every channel.
    * Neighbourhoods are always read from the untouched input ``buf``; the
      input is never modified and filtered values never feed back into
      later neighbourhoods.
    * Because ``buf`` is the whole image, the first and last row of a band
      are filtered using rows of the neighbouring bands. Only pixels on the
      outer border of the whole image (which have no complete 3x3
      neighbourhood) are copied through unchanged.
    """
    part_id = local_data[0]
    first = int(local_data[1])
    end = int(local_data[2])
    buf = local_data[3]
    nx = buf.shape[0]
    ny = buf.shape[1]

    # Start from a copy of the band: border pixels keep their original value
    # and the caller's image is left intact.
    new_buf = buf[first:end].copy()

    # Rows of this band that have a row above and below in the full image.
    row_start = max(first, 1)
    row_stop = min(end, nx - 1)

    if row_stop > row_start and ny > 2:
        # Build the nine shifted views of the image that line up each
        # interior pixel with one of its 3x3 neighbours, then take the
        # median across them (axis 0), i.e. per pixel and per channel.
        shifted = [
            buf[row_start + di:row_stop + di, 1 + dj:ny - 1 + dj]
            for di in (-1, 0, 1)
            for dj in (-1, 0, 1)
        ]
        # The median of nine samples is always one of the samples, so the
        # cast back to the buffer dtype on assignment is exact.
        new_buf[row_start - first:row_stop - first, 1:ny - 1] = np.median(
            np.stack(shifted), axis=0
        )

    # RETURN LOCAL IMAGE PART
    return part_id, new_buf


In [3]:
def main():
    """Median-filter ``lena_noisy.jpg`` in parallel with Spark.

    The image is split into ``nb_partitions`` horizontal bands, each band is
    processed by ``part_median_filter`` in its own Spark partition, and the
    filtered bands are stitched back together (in band order) and written to
    ``lena_filter.jpg``.
    """
    file = 'lena_noisy.jpg'
    img_buf = readImg(file)
    print('SHAPE', img_buf.shape)
    nx = img_buf.shape[0]

    # SPLIT IMAGE IN NB_PARTITIONS PARTS
    nb_partitions = 8
    print("NB PARTITIONS : ", nb_partitions)
    # Integer row bounds: bounds[0] == 0 and bounds[-1] == nx exactly, so no
    # row is lost to float rounding when nx is not a multiple of
    # nb_partitions.
    bounds = [ip * nx // nb_partitions for ip in range(nb_partitions + 1)]
    # Every band carries the full image so the filter can read neighbouring
    # rows across band boundaries.
    data = [[ip, bounds[ip], bounds[ip + 1], img_buf]
            for ip in range(nb_partitions)]

    print(data[0][3].shape)

    # CREATE SPARKCONTEXT
    sc = SparkContext()
    try:
        data_rdd = sc.parallelize(data, nb_partitions)

        # PARALLEL MEDIAN FILTER COMPUTATION
        result_rdd = data_rdd.map(part_median_filter)
        result_data = result_rdd.collect()
    finally:
        # Always release the context; otherwise re-running this cell fails
        # because only one SparkContext may be active at a time.
        sc.stop()

    # Order the bands by the part id returned by the filter before stitching.
    result_data.sort(key=lambda part: part[0])
    print(result_data[0][1].shape)

    # COMPUTE NEW IMAGE RESULTS FROM RESULT RDD
    # One concatenate call instead of re-copying the growing image per band.
    new_img_buf = np.concatenate([part[1] for part in result_data], axis=0)

    print(new_img_buf.shape)
    print('CREATE NEW PICTURE FILE')
    filter_file = 'lena_filter.jpg'
    writeImg(filter_file, new_img_buf)

# Run the pipeline only when this code is executed directly (``__name__`` is
# ``'__main__'``), not when it is imported from another module.
if __name__ == '__main__':
    main()

SHAPE (128, 128, 3)
NB PARTITIONS :  8
(128, 128, 3)
(16, 128, 3)
(128, 128, 3)
CREATE NEW PICTURE FILE
