In [3]:
def get_combinations(widths, heights):
    """Return the distinct (width, height) sizes from ``widths`` x ``heights``.

    A size and its transpose -- ``(w, h)`` and ``(h, w)`` -- are treated as the
    same size; whichever orientation is encountered first while iterating
    ``widths`` (outer) and ``heights`` (inner) is the one kept. Repeated input
    values never produce duplicate output pairs.

    Args:
        widths: Iterable of candidate widths (values must be hashable).
        heights: Iterable of candidate heights (values must be hashable).

    Returns:
        A tuple of ``(width, height)`` pairs in ascending sorted order.
    """
    # A set gives O(1) membership checks and guarantees each kept pair is unique.
    seen = set()

    for width in widths:
        for height in heights:
            size = (width, height)
            # Skip exact repeats as well as transposed duplicates.
            if size in seen or (height, width) in seen:
                continue
            seen.add(size)

    return tuple(sorted(seen))

### 3) Image Processing

In [4]:
def load_one_MNIST_image():
    """Load MNIST via Keras and return the first training image, preprocessed.

    The image's pixel values are divided by 255 and a trailing channel axis is
    appended (``tf.newaxis``), so for the standard Keras MNIST arrays the
    result has shape (28, 28, 1). It is wrapped with ``np.array(...,
    requires_grad=False)``, so it is not treated as a trainable parameter.

    Returns:
        The preprocessed first training image.
    """
    # load_data() still reads the whole dataset; the labels and the test split
    # are not needed here.
    (train_images, _train_labels), _test_data = keras.datasets.mnist.load_data()

    # Normalize pixels and add the convolution channel axis on the one image we
    # return, instead of on every train and test image.
    one_image = train_images[0] / 255
    one_image = np.array(one_image[..., tf.newaxis], requires_grad=False)

    return one_image


def _count_filter_positions(image_length, filter_length, stride_length):
    """Return how many unpadded filter placements fit along one image axis."""
    if image_length < filter_length:
        return 0
    # Placements start at 0, s, 2s, ... and must satisfy start + f <= length.
    return (image_length - filter_length) // stride_length + 1


def parse_image(image_size, filter_size, stride_size, filters_count):
    """Summarize how a set of filters tiles an image (no padding).

    Args:
        image_size: ``(image_width, image_height)`` in pixels.
        filter_size: ``(filter_width, filter_height)`` in pixels.
        stride_size: ``(stride_width, stride_height)`` in pixels.
        filters_count: Number of filters applied at every placement.

    Returns:
        A dict with the filter/stride dimensions, the number of horizontal and
        vertical placements, their product (``filter_repetitions``), the total
        ``filter_applications`` (placements x filters), and ``complexity``,
        computed as ``2 ** filter_surface * filters_count * filter_repetitions``.
    """
    image_width, image_height = image_size
    filter_width, filter_height = filter_size
    stride_width, stride_height = stride_size

    filter_surface = filter_width * filter_height

    # Repetition counts along each axis (standard valid-convolution formula).
    horizontal_filter_repetitions = _count_filter_positions(
        image_width, filter_width, stride_width)
    vertical_filter_repetitions = _count_filter_positions(
        image_height, filter_height, stride_height)

    filter_repetitions = horizontal_filter_repetitions * vertical_filter_repetitions

    filter_applications = filter_repetitions * filters_count

    # NOTE(review): the 2 ** filter_surface factor presumably reflects one qubit
    # per filter pixel (get_pixel_values sets qubits_count = filter_surface).
    complexity = (2 ** filter_surface
                  * filters_count
                  * filter_repetitions)

    result = {"filter_width": filter_width,
              "filter_height": filter_height,
              "filter_surface": filter_surface,
              "filters_count": filters_count,

              "stride_width": stride_width,
              "stride_height": stride_height,

              "horizontal_filter_repetitions": horizontal_filter_repetitions,
              "vertical_filter_repetitions": vertical_filter_repetitions,
              "filter_repetitions": filter_repetitions,
              "filter_applications": filter_applications,

              "complexity": complexity}

    return result


def get_pixel_values(image, experiment):
    """Collect the pixels under every filter placement as flattened rows.

    Args:
        image: Array indexed as ``image[y, x, ...]`` (first axis vertical,
            second horizontal).
        experiment: Dict as produced by ``parse_image``; this function reads
            ``filter_width``, ``filter_height``, ``stride_width``,
            ``stride_height``, ``horizontal_filter_repetitions`` and
            ``vertical_filter_repetitions``.

    Returns:
        A ``jnp`` array with one row per placement, each row being the
        flattened image fragment under the filter. Rows are ordered with the
        horizontal index outermost: all vertical placements of column 0, then
        column 1, and so on.
    """
    filter_width = experiment['filter_width']
    filter_height = experiment['filter_height']

    stride_width = experiment['stride_width']
    stride_height = experiment['stride_height']

    horizontal_filter_repetitions = experiment['horizontal_filter_repetitions']
    vertical_filter_repetitions = experiment['vertical_filter_repetitions']

    # Fetch pixels
    fragments = []

    for feature_x in range(horizontal_filter_repetitions):
        image_corner_x = feature_x * stride_width

        for feature_y in range(vertical_filter_repetitions):
            image_corner_y = feature_y * stride_height

            image_fragment = image[image_corner_y : image_corner_y + filter_height,
                                   image_corner_x : image_corner_x + filter_width]

            fragments.append(image_fragment.flatten())

    return jnp.array(fragments)