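# 4 Patch-Level Color Analysis

This notebook adds Gaussian noise to the color duck image (`images/duck.jpg`), splits both the clean and the noisy image into overlapping 8×8 RGB patches, and learns a patch dictionary for each. For three highlighted patches it saves the patch locations, the clean and noisy patch snapshots, and the six dictionary bases with the largest-magnitude coefficients as PDF figures.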

## 4.1 Imports & Constants

In [1]:
import itertools
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patch

from PIL import Image
from sklearn import feature_extraction

from library import generator

%config InlineBackend.figure_format='retina'

In [2]:
DUCK = "images/duck.jpg"  # input image path, relative to the notebook's working directory

## 4.2 Utilities

In [3]:
def read_color_image(address):
    """Load an image file into a float64 NumPy array.

    Parameters
    ----------
    address : str or path-like
        Path of the image file to read.

    Returns
    -------
    numpy.ndarray
        The pixel values as ``float64``. The shape follows the file's PIL
        mode (e.g. ``(H, W, 3)`` for an RGB image); it is not forced to any
        particular number of channels.
    """
    # The context manager closes the file handle deterministically instead of
    # relying on PIL's lazy load()/garbage collection; np.array copies the
    # pixels, so the result stays valid after the file is closed.
    with Image.open(address) as image:
        pixels = np.array(image, dtype=np.float64)
    return pixels

def to_color_patches(pixels, patch_size):
    """Extract every overlapping patch of a color image as a flattened column.

    Columns are laid out channel-major: the first ``height * width`` entries
    of a column are channel 0 of that patch (row-major), the next
    ``height * width`` entries are channel 1, and so on.

    Parameters
    ----------
    pixels : numpy.ndarray
        Image indexed as ``pixels[row, col, channel]``.
    patch_size : tuple of int
        ``(height, width)`` forwarded to ``extract_patches_2d``.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(channels * height * width, n_patches)``.
    """
    def flatten_channel(plane):
        # (n_patches, height, width) -> (height * width, n_patches)
        windows = feature_extraction.image.extract_patches_2d(plane, patch_size)
        return windows.reshape(len(windows), -1).T

    per_channel = [flatten_channel(pixels[:, :, c]) for c in range(pixels.shape[-1])]
    return np.vstack(per_channel)

def patch_to_image(pixels, patch_size, channels):
    """Rebuild a ``(patch_size, patch_size, channels)`` patch from a flat vector.

    ``pixels`` is split into ``channels`` equal consecutive chunks. Chunk
    ``c`` is reshaped row-major to ``(patch_size, patch_size)`` and becomes
    channel ``c`` of the result. This inverts the channel-major column layout
    produced by ``to_color_patches``.
    """
    planes = [
        np.reshape(chunk, (patch_size, patch_size))
        for chunk in np.split(pixels, channels)
    ]
    return np.stack(planes, axis=-1)

def normalize_image(pixels):
    """Clip values to the 8-bit range [0, 255] and rescale them to [0, 1].

    A new array is returned; ``pixels`` itself is left untouched.
    """
    return np.clip(pixels, 0, 255) / 255

def get_color_patch_index(patch_matrix, snapshot, channels):
    """Return the index of the first column of ``patch_matrix`` equal to ``snapshot``.

    ``snapshot`` is flattened in the same channel-major layout that
    ``to_color_patches`` uses for its columns (channel 0 row-major, then
    channel 1, ...), using only its first ``channels`` channels.

    Parameters
    ----------
    patch_matrix : numpy.ndarray
        2-D array whose columns are flattened patches.
    snapshot : numpy.ndarray
        Patch indexed as ``snapshot[row, col, channel]``.
    channels : int
        Number of leading channels of ``snapshot`` to compare.

    Returns
    -------
    int
        Index of the first exactly-matching column.

    Raises
    ------
    KeyError
        If no column matches.
    """
    target = np.concatenate([snapshot[:, :, c].ravel() for c in range(channels)])
    # Compare all columns at once instead of looping over them in Python:
    # target[:, None] broadcasts the flattened snapshot against every column.
    matches = np.flatnonzero((patch_matrix == target[:, None]).all(axis=0))
    if matches.size == 0:
        raise KeyError('snapshot does not match any column of patch_matrix')
    return int(matches[0])

## 4.3 Figure Components

In [4]:
# Load the clean image, make a noisy copy, and flatten both into patch matrices.
SEED = 0
NOISE_SCALE = 20  # std. dev. of the additive Gaussian noise, in raw pixel-value units

CHANNELS = 3
PATCH_SIZE = 8

# Seed NumPy's global RNG so the noise is identical on every Restart & Run All.
# This also seeds anything later in the notebook that draws from the global
# RNG (possibly library.generator -- unverified).
np.random.seed(SEED)

duck = read_color_image(DUCK)
# Every later cell assumes exactly CHANNELS color planes (e.g. no alpha channel).
assert duck.ndim == 3 and duck.shape[-1] == CHANNELS, (
    f'expected an H x W x {CHANNELS} image, got shape {duck.shape}'
)
noisy_duck = duck + np.random.normal(scale=NOISE_SCALE, size=duck.shape)

clean_patches = to_color_patches(duck, (PATCH_SIZE, PATCH_SIZE))
noisy_patches = to_color_patches(noisy_duck, (PATCH_SIZE, PATCH_SIZE))

In [10]:
### HIGHLIGHTED CLEAN DUCK

# (x, y) of the top-left pixel of each highlighted patch: x selects image
# columns and y selects image rows (see the slicing in the snapshot cells).
LOCATIONS = [(205, 117), (85, 170), (125, 135)]
COLORS = ['red', 'purple', 'blue']

fig, ax = plt.subplots(1, figsize=(64, 64))
ax.imshow(normalize_image(duck))

for (x_start, y_start), color in zip(LOCATIONS, COLORS):
    # imshow centres pixel (row, col) on data coordinate (col, row), so the
    # extracted patch duck[y:y+P, x:x+P] spans x-0.5 .. x+P-0.5 (likewise in y).
    # Shift the rectangle by half a pixel so it outlines exactly those pixels.
    rect = patch.Rectangle(
        (x_start - 0.5, y_start - 0.5),
        PATCH_SIZE,
        PATCH_SIZE,
        linewidth=30,
        edgecolor=color,
        facecolor='none'
    )
    ax.add_patch(rect)
ax.axis('off')
fig.savefig('04-highlighted-patches.pdf', bbox_inches='tight')
plt.close(fig)

In [6]:
### PATCH SNAPSHOTS

def save_snapshot(image, location, path):
    """Crop the PATCH_SIZE x PATCH_SIZE patch at ``location`` and save it to ``path``.

    ``location`` is ``(x, y)``: ``x`` selects columns and ``y`` selects rows.
    Nearest-neighbour interpolation keeps every image pixel a crisp square in
    the (64 x 64 inch) saved figure.
    """
    x_start, y_start = location
    snapshot = image[y_start:y_start + PATCH_SIZE, x_start:x_start + PATCH_SIZE]
    fig, ax = plt.subplots(1, figsize=(64, 64))
    ax.imshow(normalize_image(snapshot), interpolation='nearest')
    ax.axis('off')
    fig.savefig(path, bbox_inches='tight')
    plt.close(fig)


for index, location in enumerate(LOCATIONS):
    save_snapshot(duck, location, f'04-clean-snapshot-{index}.pdf')
    save_snapshot(noisy_duck, location, f'04-noisy-snapshot-{index}.pdf')

In [7]:
### SNAPSHOT BASES

# Column index (shared by clean_patches and noisy_patches) of each highlighted
# patch. The index is found by searching the clean patches by content, but it
# is later used for the noisy encoding too. That is only valid if the search
# found the patch at *this* location rather than an identical clean patch
# elsewhere in the image, so verify it against the noisy patch matrix.
correct_indices = []
for x_start, y_start in LOCATIONS:
    rows = slice(y_start, y_start + PATCH_SIZE)
    cols = slice(x_start, x_start + PATCH_SIZE)
    index = get_color_patch_index(clean_patches, duck[rows, cols, :], CHANNELS)

    # Both matrices come from to_color_patches with the same patch size and
    # image shape, so equal column indices mean equal image locations. The
    # noisy pixels are (almost surely) unique, so this check catches a
    # duplicate clean match.
    noisy_snapshot = noisy_duck[rows, cols, :]
    noisy_column = np.concatenate([noisy_snapshot[:, :, c].ravel() for c in range(CHANNELS)])
    assert np.array_equal(noisy_patches[:, index], noisy_column), (
        f'patch at {(x_start, y_start)} matched column {index}, which belongs to a '
        'different image location; the clean patch content is not unique'
    )
    correct_indices.append(index)

In [8]:
ITERATIONS = 100


def learn_dictionary(patches, iterations):
    """Run dictionary learning on ``patches`` and return ``(dictionary, encoding)``.

    Takes the iterate at 0-based position ``iterations`` from
    ``generator.get_dictionary_learning_iterates(patches)`` and transposes it
    so that ``dictionary[:, k]`` is atom ``k`` (the layout the plotting cell
    indexes). It then projects every patch onto the atoms:
    ``encoding = dictionary.T @ patches``.

    Raises
    ------
    RuntimeError
        If the generator yields fewer than ``iterations + 1`` iterates.
    """
    missing = object()
    updates = generator.get_dictionary_learning_iterates(patches)
    iterate = next(itertools.islice(updates, iterations, None), missing)
    if iterate is missing:
        # A bare StopIteration here would give no hint of what went wrong.
        raise RuntimeError(
            f'dictionary learning produced fewer than {iterations + 1} iterates'
        )
    dictionary = iterate.T
    encoding = dictionary.T @ patches
    return dictionary, encoding


clean_dictionary, clean_encoding = learn_dictionary(clean_patches, ITERATIONS)
noisy_dictionary, noisy_encoding = learn_dictionary(noisy_patches, ITERATIONS)

In [9]:
PACKAGED = [
    ('clean', clean_dictionary, clean_encoding),
    ('noisy', noisy_dictionary, noisy_encoding)
]

BASES = 6


def atom_to_image(atom, sign):
    """Flip ``atom`` by ``sign``, rescale it to [0, 1] and reshape it into a color patch.

    A zero ``sign`` (coefficient exactly 0) is treated as +1, and a constant
    atom is drawn as all zeros. Previously both cases divided by zero and
    produced NaN images.
    """
    oriented = atom * (sign if sign != 0 else 1)
    shifted = oriented - oriented.min()
    span = shifted.max()
    if span > 0:
        shifted = shifted / span
    return patch_to_image(shifted, PATCH_SIZE, CHANNELS)


for index, correct_index in enumerate(correct_indices):
    for label, dictionary, encoding in PACKAGED:
        coefficients = encoding[:, correct_index]
        # Atoms ordered by decreasing |coefficient|. The stable sort keeps tied
        # atoms in index order, like list.sort(..., reverse=True) did.
        ranked = np.argsort(-np.abs(coefficients), kind='stable')[:BASES]

        # squeeze=False keeps `axs` a 2-D array even when BASES == 1.
        fig, axs = plt.subplots(1, BASES, figsize=(60, 10), squeeze=False)
        fig.subplots_adjust(wspace=0.05, hspace=0.05)
        for ax in axs.flat:
            ax.axis('off')
        for col, ax in zip(ranked, axs.flat):
            ax.imshow(atom_to_image(dictionary[:, col], np.sign(coefficients[col])))
        fig.savefig(f'04-patch-{label}-bases-{index}.pdf', bbox_inches='tight')
        plt.close(fig)