In [21]:
import numpy as np
import os
from PIL import Image
import matplotlib.pyplot as plt
from ae import AutoEncoder
import torch
%reload_ext autoreload
%autoreload 2


In [2]:
def load_image_as_np(image_path):
    """Load an image file as an RGB uint8 NumPy array.

    The file is opened in a context manager so its handle is released as soon
    as the pixels have been converted; ``convert`` forces the pixel data to be
    read, so the returned array does not depend on the open file.

    Args:
        image_path: path of the image to read.

    Returns:
        uint8 array of shape (height, width, 3).
    """
    with Image.open(image_path) as img:
        rgb = img.convert('RGB')
    return np.asarray(rgb, dtype=np.uint8)

def store_image_from_np(image_path,data,format='RGB'):
    """Save a NumPy image array to ``image_path`` and return the PIL image.

    An 8-bit PIL mode such as 'RGB' needs uint8 pixels, but reconstructed
    images (e.g. decoder output) are typically float. Non-uint8 input is
    therefore clipped to [0, 255] and cast to uint8 before conversion.

    Args:
        image_path: destination path, passed to ``Image.save``.
        data: pixel array, e.g. shape (height, width, 3) for mode 'RGB'.
        format: PIL image *mode* (not file format) given to ``Image.fromarray``;
            the name is kept for backward compatibility.

    Returns:
        The ``PIL.Image.Image`` that was saved.
    """
    if data.dtype != np.uint8:
        # Clip before casting: out-of-range values would otherwise wrap around.
        data = np.clip(data, 0, 255).astype(np.uint8)
    img = Image.fromarray(data, format)
    img.save(image_path)
    return img

In [3]:
def segment_image(image,tile_size,pad_type='reflect'):
    """Split an image into a flat, row-major stack of square tiles.

    The image is padded at the bottom and right (``np.pad`` mode ``pad_type``)
    so that both spatial dimensions become multiples of ``tile_size``.

    Args:
        image: array of shape (height, width, channels).
        tile_size: edge length of each square tile in pixels; must be >= 1.
        pad_type: any ``np.pad`` mode, e.g. 'reflect', 'edge' or 'constant'.

    Returns:
        Array of shape (n_tiles, tile_size, tile_size, channels) where
        n_tiles = ceil(height / tile_size) * ceil(width / tile_size). Tiles are
        ordered left to right within a grid row, grid rows top to bottom.
        An empty image yields an empty stack.

    Raises:
        ValueError: if ``tile_size`` is smaller than 1.
    """
    if tile_size < 1:
        raise ValueError(f"tile_size must be >= 1, got {tile_size}")

    img_height, img_width = image.shape[:2]

    # Pad so the image can be cut into a whole grid of tiles even when its size
    # is not divisible by tile_size. (-n) % t is the distance from n up to the
    # next multiple of t (0 when n already is one).
    v_pad = (0, (-img_height) % tile_size)
    h_pad = (0, (-img_width) % tile_size)
    image = np.pad(image, (v_pad, h_pad, (0, 0)), pad_type)

    img_height, img_width, channels = image.shape
    tile_rows = img_height // tile_size
    tile_cols = img_width // tile_size

    # (rows, y, cols, x, c) -> (rows, cols, y, x, c) -> (rows * cols, y, x, c).
    # A final reshape (instead of np.concatenate) also works for an empty grid.
    tiled_array = image.reshape(tile_rows, tile_size, tile_cols, tile_size, channels)
    tiled_array = tiled_array.swapaxes(1, 2)
    return tiled_array.reshape(tile_rows * tile_cols, tile_size, tile_size, channels)


In [4]:
def rebuild_image(tile_array,image_size):
    """Reassemble a row-major stack of square tiles into an image.

    Inverse of the tiling done by ``segment_image``: tile k sits at grid
    position (k // tile_cols, k % tile_cols), and any bottom/right padding
    beyond ``image_size`` is cropped away.

    Args:
        tile_array: array (or array-like) of shape
            (n_tiles, tile_size, tile_size, channels).
        image_size: (height, width, channels) of the original, unpadded image.

    Returns:
        Array of shape (height, width, channels) with ``tile_array``'s dtype.

    Raises:
        ValueError: if ``tile_array`` does not have the shape implied by
            ``image_size`` and its tile size (e.g. wrong tile count, or a
            channel-first layout).
    """
    tile_array = np.asarray(tile_array)
    img_height, img_width, channels = image_size

    tile_size = tile_array.shape[1]
    # Exact integer ceiling division (no float round-trip as with np.ceil(a / b)).
    tile_rows = -(-img_height // tile_size)
    tile_cols = -(-img_width // tile_size)

    expected_shape = (tile_rows * tile_cols, tile_size, tile_size, channels)
    if tuple(tile_array.shape) != expected_shape:
        raise ValueError(
            f"tile_array has shape {tuple(tile_array.shape)}, expected "
            f"{expected_shape} for an image of size {tuple(image_size)}"
        )

    # (rows, cols, y, x, c) -> (rows, y, cols, x, c) -> (rows * y, cols * x, c):
    # pixel (i * tile_size + y, j * tile_size + x) comes from tile (i, j).
    grid = tile_array.reshape(tile_rows, tile_cols, tile_size, tile_size, channels)
    image = grid.swapaxes(1, 2).reshape(tile_rows * tile_size,
                                        tile_cols * tile_size,
                                        channels)

    # Drop the padding that made the size a multiple of tile_size.
    return image[:img_height, :img_width]

In [44]:
def get_model(model_path):
    """Load the autoencoder saved at ``model_path`` and prepare it for inference.

    Args:
        model_path: path passed to ``AutoEncoder.load_autoencoder``.

    Returns:
        The loaded model, after ``.to(device)`` (first CUDA device if available,
        else CPU) and ``.eval()``.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = AutoEncoder.load_autoencoder(model_path)
    model.to(device)
    # Inference only: disable training-mode behaviour such as dropout and
    # batch-norm statistic updates. NOTE(review): assumes AutoEncoder is a
    # torch.nn.Module (suggested by the .to(device) call) — confirm.
    model.eval()
    return model

def apply_compress_function(tile_tensor, model_path):
    """Encode a batch of tiles with the autoencoder saved at ``model_path``.

    Args:
        tile_tensor: tiles passed unchanged to ``model.encode``.
            NOTE(review): the accepted dtype/layout is defined by
            ``AutoEncoder.encode`` (not visible here); a uint8 tile array
            previously failed with "expected scalar type Float but found Byte".
        model_path: path of the saved autoencoder (loaded via ``get_model``).

    Returns:
        Whatever ``model.encode`` returns (the compressed representation).
    """
    model = get_model(model_path)
    # Pure inference: skip building an autograd graph, which saves memory on
    # large tile batches and leaves tensor outputs with requires_grad=False.
    with torch.no_grad():
        return model.encode(tile_tensor)

def apply_decompression_function(encoded_tile_tensor,model_path) :
    """Decode a compressed tile batch with the autoencoder at ``model_path``.

    Args:
        encoded_tile_tensor: value produced by ``apply_compress_function``,
            passed unchanged to ``model.decode``.
        model_path: path of the saved autoencoder (loaded via ``get_model``).

    Returns:
        Whatever ``model.decode`` returns (the reconstructed tiles).
    """
    model = get_model(model_path)
    # Pure inference: no autograd graph is needed for decoding.
    with torch.no_grad():
        return model.decode(encoded_tile_tensor)

In [None]:
def make_tensor(tile_list_array):
    """Wrap a NumPy tile array in a torch tensor without copying.

    ``torch.from_numpy`` keeps the array's dtype and shape and shares its
    memory, so e.g. a uint8 tile array yields a uint8 tensor.

    Args:
        tile_list_array: NumPy array of tiles (any shape).

    Returns:
        A ``torch.Tensor`` sharing memory with ``tile_list_array``.
    """
    return torch.from_numpy(tile_list_array)

def retrieve_array(decoded_tile_tensor, channels=3, tile_size=8):
    """Convert a decoded tile tensor into a NumPy array of channel-first tiles.

    Args:
        decoded_tile_tensor: tensor holding ``channels * tile_size ** 2`` values
            per tile; it may live on any device and require grad.
        channels: channels per tile (default 3).
        tile_size: tile edge length in pixels (default 8).

    Returns:
        NumPy array of shape (n_tiles, channels, tile_size, tile_size).
    """
    # reshape rather than view: view raises on non-contiguous tensors (e.g. after
    # a permute/transpose), whereas reshape copies only when it has to.
    return (decoded_tile_tensor.detach()
            .cpu()
            .reshape(-1, channels, tile_size, tile_size)
            .numpy())

In [45]:
wd_path = os.path.abspath(os.getcwd())
models_path = os.path.join(wd_path, "models")
model_used_path = os.path.join(models_path, "ae_0.pt")
image_path = os.path.join(wd_path,"images","val_0_0.jpeg")
image_path_out = os.path.join(wd_path,"images","val_0_0_out.jpeg")

# Cut the test image into 8x8 tiles and round-trip them through the autoencoder.
test_image= load_image_as_np(image_path)
tile_list = segment_image(test_image,tile_size=8,pad_type='reflect')

# NOTE(review): tile_list is uint8; this call previously failed with
# "expected scalar type Float but found Byte" — confirm the dtype/scaling
# AutoEncoder.encode expects.
c_image = apply_compress_function(tile_list,model_used_path)
dc_image = apply_decompression_function(c_image,model_used_path)

end_image = rebuild_image(dc_image,test_image.shape)
# Clip before the uint8 cast so decoder overshoot outside [0, 255] doesn't wrap around.
plt.imshow(np.clip(end_image, 0, 255).astype('uint8'))
plt.axis('off')
plt.show()

[[[[ 59  28 207  18 164   4 205   0]
   [197  21 214   0 198   0 222   0]
   [204  62 102   0 150  26 199  35]
   [204 150   0 128  35 132 190   3]
   [ 34 147  41  65 135  76 160 163]
   [241 156  67 151  17  32  50  58]
   [221  42  22  17  55  13  58  48]
   [ 77  53  20  84   4  43  56 101]]

  [[ 38  18 216  39 193  30 219   1]
   [177   9 224  22 227  16 233   5]
   [180  49 107   6 175  46 204  34]
   [178 133   0 143  54 147 191   0]
   [  4 126  40  74 148  85 156 151]
   [209 131  62 156  27  35  40  40]
   [185  13  13  18  58  13  40  24]
   [ 38  23   9  84   6  39  37  73]]

  [[ 53  29 221  40 191  29 222   7]
   [189  19 226  20 223  13 235   9]
   [194  59 111   6 172  45 208  40]
   [189 141   2 140  50 144 193   2]
   [ 12 131  38  69 141  80 155 153]
   [214 134  58 149  18  28  38  40]
   [189  15   8  10  49   5  38  24]
   [ 41  23   3  74   0  30  33  72]]]]
tensor([[ 59,  28, 207,  18, 164,   4, 205,   0, 197,  21, 214,   0, 198,   0,
         222,   0, 204,  6

RuntimeError: expected scalar type Float but found Byte

In [34]:
wd_path = os.path.abspath(os.getcwd())
models_path = os.path.join(wd_path, "models")
model_used_path = os.path.join(models_path, "ae_0.pt")
image_path = os.path.join(wd_path,"images","tenemos.png")
image_path_out = os.path.join(wd_path,"images","tenemos_out.png")
tile_size= 8

image = load_image_as_np(image_path)

print("Original size:",image.shape)
tile_list = segment_image(image,tile_size,pad_type='reflect')
print(f"Number of {tile_size}x{tile_size} tiles:",tile_list.shape[0])

c_image = apply_compress_function(tile_list,model_used_path)
dc_image = apply_decompression_function(c_image,model_used_path)

end_image = rebuild_image(dc_image,image.shape)
assert end_image.shape == image.shape, (end_image.shape, image.shape)

# Reconstructed pixels may be float and slightly outside [0, 255]: clip and cast
# once, so the saved file and the displayed figure contain the same pixels.
end_image_u8 = np.clip(end_image, 0, 255).astype(np.uint8)
img = store_image_from_np(image_path_out,end_image_u8,format='RGB')

plt.imshow(end_image_u8)
plt.axis('off')
plt.show()

Original size: (900, 1200, 3)
Number of 8x8 tiles: 16950
(16950, 3, 8, 8)
torch.Size([16950, 3, 8, 8])


RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.