# 5 Multi-Image Dictionary

## 5.1 Imports & Constants

In [1]:
import itertools
import numpy as np
import matplotlib.pyplot as plt

from library import generator
from cifar10_web import cifar10

%config InlineBackend.figure_format='retina'

## 5.2 Utilities

In [2]:
def patch_to_image(pixels, patch_size, channels):
    """Rebuild an image patch from its flattened, channel-major representation.

    Args:
        pixels: Flat array holding every value of channel 0, then channel 1, and
            so on; each channel is a row-major ``patch_size x patch_size`` block.
        patch_size: Side length of the square patch.
        channels: Number of channels stored in ``pixels``.

    Returns:
        Array of shape ``(patch_size, patch_size, channels)`` (height, width,
        channel), the layout ``imshow`` expects for colour images.
    """
    planes = [np.reshape(plane, (patch_size, patch_size))
              for plane in np.split(pixels, channels)]
    return np.dstack(planes)

## 5.3 Figure Components

In [3]:
CIFAR_SIZE = 32      # side length, in pixels, of a CIFAR image
CIFAR_CHANNELS = 3   # colour channels per CIFAR image

# Only the first two return values are kept; the last two (presumably the test
# split) are discarded.
train_images, train_labels, _, _ = cifar10(path=None)
total_images = len(train_images)

# patch_to_image(image, CIFAR_SIZE, CIFAR_CHANNELS) expects each image as one flat,
# channel-major vector. Check that layout here so a mismatch fails immediately
# with a clear message instead of as a split/reshape error deep in the sampler.
expected_image_shape = (CIFAR_SIZE * CIFAR_SIZE * CIFAR_CHANNELS,)
assert np.shape(train_images[0]) == expected_image_shape, (
    f"expected flat images of shape {expected_image_shape}, "
    f"got {np.shape(train_images[0])}")

In [4]:
CHANNELS = 3
PATCH_SIZE = 8
TOTAL_PATCHES = 100000
SEED = 0  # makes patch sampling (and all later np.random draws) reproducible

np.random.seed(SEED)

# Each column of `observations` is one PATCH_SIZE x PATCH_SIZE patch cut from a
# random training image at a random offset, flattened channel by channel
# (all of channel 0, then channel 1, ...), i.e. the layout patch_to_image inverts.
observations = np.zeros((PATCH_SIZE * PATCH_SIZE * CHANNELS, TOTAL_PATCHES))
# randint's upper bound is exclusive; +1 so a patch may also touch the last
# row/column of the image (start offsets 0 .. CIFAR_SIZE - PATCH_SIZE inclusive).
max_start = CIFAR_SIZE - PATCH_SIZE + 1
for col in range(TOTAL_PATCHES):
    image_index = np.random.randint(total_images)
    image = patch_to_image(train_images[image_index], CIFAR_SIZE, CIFAR_CHANNELS)
    x_start = np.random.randint(max_start)
    y_start = np.random.randint(max_start)
    sample_patch = image[x_start:x_start + PATCH_SIZE, y_start:y_start + PATCH_SIZE, :CHANNELS]
    # (rows, cols, channels) -> (channels, rows, cols) so ravel emits one whole
    # channel after another, each in row-major order.
    observations[:, col] = np.transpose(sample_patch, (2, 0, 1)).ravel()

In [5]:
ITERATIONS = 200

# Run the learner on the clean patches and keep the iterate at index ITERATIONS
# (islice skips the first ITERATIONS items yielded by the generator).
updates = generator.get_dictionary_learning_iterates(observations)
learned_iterate = next(itertools.islice(updates, ITERATIONS, None))

# Transposed so that each column of `dictionary` is one atom, as the plotting
# code's `dictionary[:, index]` expects; the encoding is dictionary^T @ observations.
dictionary = learned_iterate.T
encoding = learned_iterate @ observations

In [17]:
### CIFAR BASES
# Plot the atoms learned from the clean patches, largest total |coefficient| first.

# Rank atoms by the L1 norm of their encoding row (summed |coefficient| over all patches).
norms = np.sum(np.abs(encoding), axis=1)
all_indices = sorted(range(len(norms)), key=lambda num: norms[num], reverse=True)

# Flip each atom so its coefficients sum to a non-negative value (an atom and its
# negation describe the same direction). np.sign(0) == 0 would blank the atom, so
# treat a zero sum as positive.
sum_signs = np.sign(np.sum(encoding, axis=1))
sum_signs[sum_signs == 0] = 1

ROWS, COLS = 12, 16

fig, axs = plt.subplots(ROWS, COLS, figsize=(64, 48))
fig.subplots_adjust(wspace=0.05, hspace=0.05)
for ax in axs.flat:
    ax.axis('off')  # also hides any cells left empty when there are fewer atoms than cells
for index, ax in zip(all_indices[: ROWS * COLS], axs.flat):
    base = dictionary[:,index] * sum_signs[index]
    # Min-max rescale to [0, 1] for imshow; skip the division for a constant atom
    # (max 0 after the shift) instead of producing NaNs.
    base = base - base.min()
    peak = base.max()
    if peak > 0:
        base = base / peak
    ax.imshow(patch_to_image(base, PATCH_SIZE, CHANNELS))
fig.savefig('05-cifar-bases.pdf', bbox_inches='tight')
plt.close(fig)

In [7]:
### NOISY CIFAR BASES
# Relearn the dictionary after adding i.i.d. Gaussian noise to every entry of
# the clean patch matrix.

ITERATIONS = 200
NOISE_STDEV = 0.075  # standard deviation of the additive Gaussian noise

noisy_observations = observations + np.random.normal(scale=NOISE_STDEV, size=observations.shape)
updates = generator.get_dictionary_learning_iterates(noisy_observations)
# Keep the iterate at index ITERATIONS; transpose so each column is one atom.
noisy_dictionary = next(itertools.islice(updates, ITERATIONS, None))
noisy_dictionary = noisy_dictionary.T
noisy_encoding = noisy_dictionary.T @ noisy_observations

In [18]:
# Plot the atoms learned from the noisy patches, largest total |coefficient| first.

# Rank atoms by the L1 norm of their encoding row (summed |coefficient| over all patches).
norms = np.sum(np.abs(noisy_encoding), axis=1)
all_indices = sorted(range(len(norms)), key=lambda num: norms[num], reverse=True)

# Flip each atom so its coefficients sum to a non-negative value (an atom and its
# negation describe the same direction). np.sign(0) == 0 would blank the atom, so
# treat a zero sum as positive.
sum_signs = np.sign(np.sum(noisy_encoding, axis=1))
sum_signs[sum_signs == 0] = 1

ROWS, COLS = 12, 16

fig, axs = plt.subplots(ROWS, COLS, figsize=(64, 48))
fig.subplots_adjust(wspace=0.05, hspace=0.05)
for ax in axs.flat:
    ax.axis('off')  # also hides any cells left empty when there are fewer atoms than cells
for index, ax in zip(all_indices[: ROWS * COLS], axs.flat):
    base = noisy_dictionary[:,index] * sum_signs[index]
    # Min-max rescale to [0, 1] for imshow; skip the division for a constant atom
    # (max 0 after the shift) instead of producing NaNs.
    base = base - base.min()
    peak = base.max()
    if peak > 0:
        base = base / peak
    ax.imshow(patch_to_image(base, PATCH_SIZE, CHANNELS))
fig.savefig('05-noisy-cifar-bases.pdf', bbox_inches='tight')
plt.close(fig)

In [40]:
### CORRUPTED CIFAR BASES
# Sparse sign corruption: each entry is independently selected with probability
# BETA and, if selected, shifted by +SIGMA or -SIGMA with equal probability.

SIGMA = 0.5 * np.std(observations)
BETA = 0.2
ITERATIONS = 200

# 2*b - 1 maps each Bernoulli(0.5) draw b in {0, 1} to a random sign in {-1, +1}.
plus_minus = 2 * np.random.binomial(1, 0.5, size=observations.shape) - 1
corruptions = np.random.binomial(1, BETA, size=observations.shape) * plus_minus * SIGMA
corrupted_observations = observations + corruptions

updates = generator.get_dictionary_learning_iterates(corrupted_observations)
# Keep the iterate at index ITERATIONS; transpose so each column is one atom.
corrupted_dictionary = next(itertools.islice(updates, ITERATIONS, None))
corrupted_dictionary = corrupted_dictionary.T
corrupted_encoding = corrupted_dictionary.T @ corrupted_observations

In [41]:
# Plot the atoms learned from the sign-corrupted patches, largest total |coefficient| first.

# Rank atoms by the L1 norm of their encoding row (summed |coefficient| over all patches).
norms = np.sum(np.abs(corrupted_encoding), axis=1)
all_indices = sorted(range(len(norms)), key=lambda num: norms[num], reverse=True)

# Flip each atom so its coefficients sum to a non-negative value (an atom and its
# negation describe the same direction). np.sign(0) == 0 would blank the atom, so
# treat a zero sum as positive.
sum_signs = np.sign(np.sum(corrupted_encoding, axis=1))
sum_signs[sum_signs == 0] = 1

ROWS, COLS = 12, 16

fig, axs = plt.subplots(ROWS, COLS, figsize=(64, 48))
fig.subplots_adjust(wspace=0.05, hspace=0.05)
for ax in axs.flat:
    ax.axis('off')  # also hides any cells left empty when there are fewer atoms than cells
for index, ax in zip(all_indices[: ROWS * COLS], axs.flat):
    base = corrupted_dictionary[:,index] * sum_signs[index]
    # Min-max rescale to [0, 1] for imshow; skip the division for a constant atom
    # (max 0 after the shift) instead of producing NaNs.
    base = base - base.min()
    peak = base.max()
    if peak > 0:
        base = base / peak
    ax.imshow(patch_to_image(base, PATCH_SIZE, CHANNELS))
fig.savefig('05-corrupted-cifar-bases.pdf', bbox_inches='tight')
plt.close(fig)

In [42]:
### OUTLIERS CIFAR BASES
# Append TAU * (number of clean patches) extra columns of i.i.d. Gaussian noise,
# using the clean data's global mean and half its global standard deviation.

TAU = 0.2
MU = np.mean(observations)
STDEV = 0.5 * np.std(observations)
ITERATIONS = 200  # set here so this cell does not depend on an earlier cell's value

num_features, num_samples = observations.shape
num_outliers = int(TAU * num_samples)
noise_dimensions = (num_features, num_outliers)
outlier_observations = np.hstack((observations, np.random.normal(MU, STDEV, noise_dimensions)))

updates = generator.get_dictionary_learning_iterates(outlier_observations)
# Keep the iterate at index ITERATIONS; transpose so each column is one atom.
outlier_dictionary = next(itertools.islice(updates, ITERATIONS, None))
outlier_dictionary = outlier_dictionary.T
outlier_encoding = outlier_dictionary.T @ outlier_observations

In [43]:
# Plot the atoms learned from the outlier-augmented data, largest total |coefficient| first.

# Rank atoms by the L1 norm of their encoding row (summed |coefficient| over all
# columns, outlier columns included).
norms = np.sum(np.abs(outlier_encoding), axis=1)
all_indices = sorted(range(len(norms)), key=lambda num: norms[num], reverse=True)

# Flip each atom so its coefficients sum to a non-negative value (an atom and its
# negation describe the same direction). np.sign(0) == 0 would blank the atom, so
# treat a zero sum as positive.
sum_signs = np.sign(np.sum(outlier_encoding, axis=1))
sum_signs[sum_signs == 0] = 1

ROWS, COLS = 12, 16

fig, axs = plt.subplots(ROWS, COLS, figsize=(64, 48))
fig.subplots_adjust(wspace=0.05, hspace=0.05)
for ax in axs.flat:
    ax.axis('off')  # also hides any cells left empty when there are fewer atoms than cells
for index, ax in zip(all_indices[: ROWS * COLS], axs.flat):
    base = outlier_dictionary[:,index] * sum_signs[index]
    # Min-max rescale to [0, 1] for imshow; skip the division for a constant atom
    # (max 0 after the shift) instead of producing NaNs.
    base = base - base.min()
    peak = base.max()
    if peak > 0:
        base = base / peak
    ax.imshow(patch_to_image(base, PATCH_SIZE, CHANNELS))
fig.savefig('05-outlier-cifar-bases.pdf', bbox_inches='tight')
plt.close(fig)

## 5.4 Statistics

In [9]:
noise_stdev = 0.075
# Ratio of the mean clean pixel value to the standard deviation of the Gaussian
# noise added to build noisy_observations (scale=0.075 there as well).
mean_intensity = np.mean(observations)
snr_ratio = mean_intensity / noise_stdev
print('SNR:', snr_ratio)

SNR: 6.2397774110002135


In [20]:
TOP_BASES = 20

# Order atom indices by the L1 norm of their encoding row, largest first, for
# both the clean and the noisy dictionaries (ties keep their original order).
norms = np.sum(np.abs(encoding), axis=1)
clean_priorities = sorted(range(len(norms)), key=lambda atom: norms[atom], reverse=True)

norms = np.sum(np.abs(noisy_encoding), axis=1)
noisy_priorities = sorted(range(len(norms)), key=lambda atom: norms[atom], reverse=True)

In [22]:
# For each of the TOP_BASES highest-ranked noisy atoms, score its best match among
# the TOP_BASES highest-ranked clean atoms by |<noisy atom, clean atom>|.
# NOTE(review): this equals |cos(angle)| only if the atoms are unit-norm — confirm
# the generator normalizes its columns. A noisy atom whose clean counterpart ranks
# below TOP_BASES in the clean ordering will receive an artificially low score.
noisy_top = noisy_dictionary[:, noisy_priorities[:TOP_BASES]]
clean_top = dictionary[:, clean_priorities[:TOP_BASES]]
# base_angles[row, col] = |noisy_top[:, row] . clean_top[:, col]|
base_angles = np.abs(noisy_top.T @ clean_top)

top_angles = base_angles.max(axis=1).tolist()

values = np.percentile(top_angles, [0, 25, 50, 75, 100])
print(values)

[0.51467144 0.72026665 0.98916511 0.99979489 0.99999986]
