In [9]:
import numpy as np
import scipy
from scipy.io import loadmat
import matplotlib.pyplot as plt

In [None]:
def sample_images_raw(fname, patch_size=12, n_patches=10000):
    """Sample random square patches from the image stack stored in a .mat file.

    Parameters
    ----------
    fname : str or path-like
        Path to a MATLAB ``.mat`` file. Its ``'IMAGESr'`` variable is indexed
        here as ``[row, col, image]``.
    patch_size : int, default 12
        Side length, in pixels, of each square patch.
    n_patches : int, default 10000
        Number of patches to draw (drawn with ``np.random``, so seeding the
        global numpy RNG makes the result reproducible).

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(patch_size * patch_size, n_patches)``; column
        ``i`` is patch ``i`` flattened in row-major (C) order.

    Raises
    ------
    ValueError
        If ``patch_size`` is larger than either spatial dimension of the images.
    """
    image_data = loadmat(fname)['IMAGESr']
    if image_data.ndim == 2:
        # A stack holding a single image may be stored without its trailing
        # image axis; restore it so the indexing below stays uniform.
        image_data = image_data[:, :, np.newaxis]

    # Rows and columns are handled separately so non-square images work.
    n_rows, n_cols, n_images = image_data.shape[:3]
    if patch_size > n_rows or patch_size > n_cols:
        raise ValueError(
            f"patch_size={patch_size} exceeds image size {n_rows}x{n_cols}")

    patches = np.zeros(shape=(patch_size * patch_size, n_patches))

    for i in range(n_patches):
        image_id = np.random.randint(0, n_images)
        # randint's upper bound is exclusive: the +1 lets a patch sit flush
        # with the bottom/right edge (and makes patch_size == image size legal).
        row = np.random.randint(0, n_rows - patch_size + 1)
        col = np.random.randint(0, n_cols - patch_size + 1)

        img = image_data[:, :, image_id]
        patches[:, i] = img[row:row + patch_size, col:col + patch_size].reshape(-1)

    return patches

def display_network(A, opt_normalize=True, opt_graycolor=True):
    """Tile the columns of ``A`` into a single 2-D image for visualisation.

    The global mean of ``A`` is subtracted first. Each column is then reshaped
    to an ``sz x sz`` tile, with ``sz = ceil(sqrt(A.shape[0]))``, and laid out
    row by row on a grid of ``ceil(sqrt(ncols))`` tiles per row. Tiles are
    separated by a 1-pixel border of background value.

    Parameters
    ----------
    A : numpy.ndarray
        2-D array with one flattened patch per column. ``A.shape[0]`` must be
        a perfect square, otherwise the per-tile ``reshape`` raises ValueError.
    opt_normalize : bool, default True
        If True, each tile is divided by its own max absolute value. If False,
        all tiles share the max absolute value of the whole (mean-centred) ``A``.
    opt_graycolor : bool, default True
        If True the background/border value is 1.0; otherwise it is 0.1.

    Returns
    -------
    numpy.ndarray
        2-D float image of shape ``(1 + rows * (sz + 1), 1 + cols * (sz + 1))``.
        A tile whose scale is zero (an all-zero column) is left as zeros
        instead of becoming NaN.
    """
    A = A - np.average(A)

    n_pixels, n_tiles = A.shape
    sz = int(np.ceil(np.sqrt(n_pixels)))
    buf = 1
    stride = sz + buf
    # max(1, ...) avoids a 0/0 grid size when A has no columns.
    grid_cols = max(1, int(np.ceil(np.sqrt(n_tiles))))
    grid_rows = -(-n_tiles // grid_cols)  # ceiling division

    image = np.ones(shape=(buf + grid_rows * stride, buf + grid_cols * stride))
    if not opt_graycolor:
        image *= 0.1

    # Loop-invariant scale for the non-normalized mode.
    global_scale = np.max(np.abs(A)) if (not opt_normalize and A.size) else 0.0

    for k in range(n_tiles):
        grid_row, grid_col = divmod(k, grid_cols)
        column = A[:, k]
        scale = np.max(np.abs(column)) if opt_normalize else global_scale
        tile = column.reshape(sz, sz)
        top = buf + grid_row * stride
        left = buf + grid_col * stride
        # Guard against 0/0: an all-zero tile stays zero rather than NaN.
        image[top:top + sz, left:left + sz] = tile / scale if scale > 0 else tile
    return image


def get_optimal_k(threshold, s):
    """Return the smallest number of leading components retaining ``threshold`` of the total.

    Parameters
    ----------
    threshold : float
        Target fraction of ``sum(s)`` to retain, e.g. ``0.99``.
    s : array_like
        Per-component magnitudes. Presumably these are PCA eigenvalues sorted
        in descending order; TODO confirm at the call site. They are used in
        the given order and are not re-sorted.

    Returns
    -------
    int
        The smallest ``k`` in ``[0, len(s)]`` such that
        ``sum(s[:k]) / sum(s) >= threshold``. The count includes the component
        that first reaches the threshold, so the retained fraction is at least
        ``threshold``. Returns ``len(s)`` if no prefix reaches it (for example
        ``threshold > 1``, or rounding just below 1.0). Returns ``0`` if ``s``
        is empty or sums to zero.
    """
    s = np.asarray(s, dtype=float).ravel()
    total = s.sum()
    if s.size == 0 or total == 0:
        return 0
    # Prepend 0 so index k of `retained` is the fraction kept by the first k components.
    retained = np.concatenate(([0.0], np.cumsum(s) / total))
    reached = np.flatnonzero(retained >= threshold)
    return int(reached[0]) if reached.size else int(s.size)

In [None]:
# Seed the global numpy RNG so patch sampling and tile selection are reproducible.
SEED = 0
N_SHOW = 200  # number of patches to visualise
np.random.seed(SEED)

x = sample_images_raw('IMAGES_RAW.mat')

n, m = x.shape
# Draw distinct patch indices: randint samples with replacement and could repeat tiles.
random_sel = np.random.choice(m, size=min(N_SHOW, m), replace=False)
image_x = display_network(x[:, random_sel])

fig, ax = plt.subplots()
ax.imshow(image_x, cmap=plt.cm.gray)
ax.set_title('Raw patch images')
ax.axis('off')
plt.show()