In [6]:
import keras
from keras.models import Sequential
from keras.layers import Dense
from PIL import Image
import json
import requests
from io import BytesIO
import numpy as np

In [67]:
def get_image_from_url(url):
    """
    returns PIL Image object from response
    :param url:
    :return:
    """
    response = requests.get(url, stream=True)
    img_loc = Image.open(BytesIO(response.content))
    return img_loc


def rgbToGray(rgb):
    """
    :param rgb: 1D np array consisting of RGB values of the pixel
    :return: gray = 0.21r + 0.72g  + 0.07b
    """
    return np.dot(np.array([0.21, 0.72, 0.07]), rgb)


def generate_data_set_from_image(image: Image, x_data, y_data, window_shape = (3, 3)):
    """
    Window dimensions are always assumed to be odd, i.e. we always have a middle element
    """
    arr = np.asarray(image)
    n, m = arr.shape[0], arr.shape[1]

    # base case: image must be larger than the window size
    if n < window_shape[0] or m < window_shape[1]:
        return None

    row_margin = window_shape[0] // 2
    col_margin = window_shape[1] // 2
    for i in range(0 + row_margin, n - row_margin):
        for j in range(0 + col_margin, m - col_margin):
            data_pt = []
            for x in range(i - row_margin, i + row_margin + 1):
                for y in range(j - col_margin, j + col_margin + 1):
                    rgb = arr[x, y, ] / 255
                    # middle element of filter, to be used as output
                    if x == i and y == j:
                        y_data.append(rgb)
                    data_pt.append(rgbToGray(rgb) / 255)
            x_data.append(np.array(data_pt))
    return x_data, y_data

In [93]:
# global variables
window_shape = (5, 5)
row_margin = window_shape[0] // 2
col_margin = window_shape[1] // 2

if __name__ == '__main__':
    x_data = []
    y_data = []
    with open('dataset/forest_photos_info.json', 'r') as fo:
        dataset = json.loads(fo.read())
        for i, image_key in enumerate(oceans):
            print(f"Processing Img: {image_key}")
            img = get_image_from_url(dataset[image_key]['url'])
            # img.show()
            # input()
            generate_data_set_from_image(img, x_data, y_data, window_shape)
            if i == 10:
                print(f"Generating data from {i+1} images")
                break

Processing Img: 27950322870_771d53a6e2.jpg
Processing Img: 27927622835_32b1427ea7.jpg
Processing Img: 22888085296_81b4b91bb0.jpg
Processing Img: 25386994904_7e78d93464.jpg
Processing Img: 3765844098_5a00513949.jpg
Processing Img: 8391674818_19b47906b4.jpg
Processing Img: 26237203136_59d280b8e4.jpg
Processing Img: 28063736211_9321570bbd.jpg
Processing Img: 26601819584_77220c0eb5.jpg
Processing Img: 8023983958_567e126058.jpg
Processing Img: 16479227725_ac8b713757.jpg
Generating data from 11 images


In [94]:
model = Sequential()

model.add(Dense(30, input_dim=window_shape[0] * window_shape[1], activation='sigmoid'))
model.add(Dense(15, activation='sigmoid'))
model.add(Dense(3))
          
model.compile(loss='mean_squared_error', optimizer='adam', metrics=['accuracy'])

In [None]:
model.fit(np.array(x_data), np.array(y_data))

Epoch 1/1

### Now Let's try to predict the output image

In [76]:
list(dataset.keys())[-10:-1]

['1299455_fb3636f431.jpg',
 '7460265344_f3a60b2a6d.jpg',
 '5660119953_229aac8e2d.jpg',
 '14759591072_a6bfaa05ac.jpg',
 '26911681691_7b7494d1bc.jpg',
 '7793140654_4f8abfb57c.jpg',
 '13193044613_345bf475bd.jpg',
 '4692542180_71ebaa29a7.jpg',
 '4124025640_59d557c75a.jpg']

In [91]:
# test image
img_str = '4124025640_59d557c75a.jpg'
img = get_image_from_url(dataset[img_str]['url'])
np_img = np.asarray(img)
img.show()

# gray scale
gray_img = 0.21*np_img[:, :, 0] + 0.72*np_img[:, :, 1] + 0.07*np_img[:, :, 2]
gray_img = Image.fromarray(gray_img.astype(np.uint8))
gray_img.show()

In [92]:
out = np.zeros(shape)
x_test, y_test = [], []

generate_data_set_from_image(img, x_test, y_test, window_shape)
out = model.predict(np.array(x_test))
out = out * 255
out = out.reshape(np_img.shape - np.array([2*row_margin, 2*col_margin, 0]))
print(out.shape)

out_img = Image.fromarray(out.astype(np.uint8))
out_img.show()

(371, 496, 3)


#### end #####