# Compressed sensing on images

This notebook intends to illustrate the application of compressed sensing to images.

In particular, given an undersampled image $y$ (on the left), the objective is to reconstruct the original image $x$ (on the right):

![](extra/compressed_sensing.png)

This notebook applies compressed sensing as in [compressed_sensing.ipynb](compressed_sensing.ipynb) but, in this case, to a 2D problem.

Another difference is that in this notebook the basis matrix $\Psi$ is built from a collection of images, each flattened into a column vector, instead of being the Discrete Cosine Transform matrix.

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from numpy import ma
from scipy.sparse import dok_array
import torch
from icecream import ic
from torchvision.datasets import MNIST
from torchvision.transforms import Compose
from torchvision.transforms import Resize
from torchvision.transforms import ToTensor
from torchvision.transforms import Grayscale
from torchvision.utils import make_grid

from optimize import cosamp

In [None]:
# Digit classes whose training images are collected into the dictionary.
DATA_LABELS = (9, 1, 3)
# Class of the test image to reconstruct: the first of the dictionary classes.
TARGET_LABEL = DATA_LABELS[0]

## Data

In [None]:
# Image geometry: every sample is resized to img_size and kept single-channel.
img_size = 128
n_channels = 1

# Preprocessing shared by both splits.
transform = Compose([Resize(img_size), ToTensor(), Grayscale()])

# Train split is built first, then the test split; both live under ./data.
train_dataset, test_dataset = (
    MNIST('data', train=is_train, transform=transform, download=True)
    for is_train in (True, False)
)

## Problem definition

In [None]:
# Length of a flattened image: every pixel of every channel.
n = n_channels * img_size * img_size
signal_dim = p = n
i = 0
# Number of measurements: 10% of the signal entries.
undersampled_dim = int(0.1 * n)
# Number of images collected into the dictionary.
dictionary_size = 2048

## Signal
$x$

In [None]:
# Walk the test split from the start until an image of the target class shows up.
i = 0
test_image, target = test_dataset[i]
while target != TARGET_LABEL:
    i += 1
    test_image, target = test_dataset[i]

In [None]:
# permute moves the leading dimension of the tensor to the end before plotting
plt.imshow(test_image.permute(1, 2, 0))
plt.title('original image')

In [None]:
# flatten the image tensor into the 1-D signal x, converted to a NumPy array
x = test_image.view(-1).numpy()

## Sample matrix
$C$

In [None]:
def sample_matrix(undersampled_dim, signal_dim):
    """Build a random row-selection (sampling) matrix.

    Each row is a one-hot vector that picks a single entry of the signal, so
    ``C @ x`` returns ``undersampled_dim`` randomly chosen entries of ``x``.

    Args:
        undersampled_dim: number of measurements (rows of the result).
        signal_dim: length of the signal (columns of the result).

    Returns:
        A float array of shape ``(undersampled_dim, signal_dim)`` with exactly
        one ``1.0`` per row.
    """
    # Draw distinct positions whenever possible so that no measurement is
    # wasted on an entry that was already sampled; fall back to sampling with
    # replacement only when more measurements than signal entries are requested.
    picked = np.random.choice(
        signal_dim,
        size=undersampled_dim,
        replace=undersampled_dim > signal_dim,
    )
    # Fill the selected entries directly instead of slicing rows out of a
    # dense signal_dim x signal_dim identity matrix.
    C = np.zeros((undersampled_dim, signal_dim))
    C[np.arange(undersampled_dim), picked] = 1.0
    return C

C = sample_matrix(undersampled_dim, signal_dim)
# every row of C holds a single non-zero entry, so this is the number of measurements
ic(np.count_nonzero(C))
plt.imshow(C)

## Undersampled signal
$y = C x$

In [None]:
# Measure the signal: each row of C picks one entry of x.
y = np.matmul(C, x)
# Applying C to the pixel indices gives the position of every measurement.
_coord = np.matmul(C, np.arange(n))

plt.plot(x, label='original signal', alpha=0.2)
plt.plot(_coord, y, '.', label='undersampled signal', color='orange')
plt.ylabel('values')
plt.xlabel('pixels')
plt.legend()

In [None]:
# sampling mask
# Row/column coordinate of every pixel, flattened in the same row-major order as x.
ii, jj = np.meshgrid(np.arange(img_size), np.arange(img_size), indexing='ij')
# Selecting with C keeps the coordinates of the measured pixels. The matrix
# product yields floats, so cast back to integers before using them as indices.
ii = (C @ ii.flatten()).astype(np.intp)
jj = (C @ jj.flatten()).astype(np.intp)
mask = dok_array((img_size, img_size), dtype=bool)
mask[ii, jj] = True

# plot
# masked_array hides entries whose mask is True, hence the inversion:
# only the pixels that were NOT measured are hidden.
subsampled_image = ma.masked_array(
    x.reshape(img_size, img_size),
    mask=~mask.todense()
)
plt.imshow(subsampled_image)
plt.title('undersampled image')

## The problem
As mentioned at the beginning, given an undersampled image $y$ (on the left), the objective is to reconstruct the original image $x$ (on the right):

In [None]:
# side-by-side view of the reconstruction problem
_, (ax1, ax2) = plt.subplots(ncols=2)

# left: only the measured pixels; right: the complete image
ax1.imshow(subsampled_image)
ax1.set_title('undersampled image $y$')
ax2.imshow(test_image.permute(1, 2, 0))
ax2.set_title('original image $x$')
plt.tight_layout()

## Dictionary
$\Psi$

In [None]:
# Collect the first `dictionary_size` training images whose label is one of
# DATA_LABELS, remembering the dataset position each one came from.
dictionary_images, indices = [], []
i = 0
while len(dictionary_images) < dictionary_size:
    image, target = train_dataset[i]
    i += 1
    if target not in DATA_LABELS:
        continue
    dictionary_images.append(image)
    indices.append(i - 1)

In [None]:
# preview
# first 100 dictionary images, arranged with nrow=10
grid = make_grid(dictionary_images[:100], nrow=10)
ic(grid.shape)
# display only the first slice along the leading dimension of the grid
plt.imshow(grid[0], cmap='viridis')
plt.show()

In [None]:
# Stack the images along a new last axis and flatten the remaining axes, so
# that each column of Psi is one dictionary image written as a vector.
Psi = torch.stack(dictionary_images, dim=-1).view(n_channels * img_size * img_size, -1).numpy()
ic(Psi.shape)

In [None]:
# visualise the dictionary: one column per image, one row per pixel
plt.imshow(Psi)
plt.xlabel('images')
plt.ylabel('pixels')

## Theta

In [None]:
# apply the sampling matrix to the dictionary: keeps only the measured rows of Psi
Theta = C @ Psi
plt.imshow(Theta)

## Find sparse representation $s$ of $y$

Find $s$ by solving:

$\min \|s\|_1$ s.t. $y = \Theta s$

In [None]:
# optimize
sparsity = 5  # number of sparse elements in solution
# cosamp is imported from the local `optimize` module; presumably it returns
# the coefficient vector s such that Theta @ s approximates y — confirm there
s = cosamp(Theta, y, sparsity, max_iter=10000)
ic(s)

In [None]:
# left: coefficient values by index; right: histogram of the coefficients
fig, (ax1, ax2) = plt.subplots(ncols=2, nrows=1)

# s is drawn twice: once as a faint red line, once as markers
ax1.plot(s, 'red', alpha=.5)
ax1.plot(s, '.')
ax1.set_title('s')
ax2.hist(s)
ax2.set_title('histogram of s')
plt.tight_layout()

### Check $s$ quality

In [None]:
# Compare the measurements y with their sparse approximation Theta @ s.
fig, (ax1, ax2) = plt.subplots(ncols=2)
ax1.plot(y, label='$y$')
# Raw strings keep the LaTeX backslash from being read as a string escape.
ax1.plot(Theta @ s, '.', label=r'$\Theta s$')
# The curves shown are y and Theta @ s, so the title names those two.
ax1.set_title(r'$y$ vs. $\Theta s$')
ax1.legend()

# Residual of the fit at every measurement.
ax2.plot(Theta @ s - y)
ax2.set_title('difference')

## Recover $x$

This is done by computing $x = \Psi s$ using the obtained sparse vector $s$.

In [None]:
# reconstruct the full signal from the sparse coefficients
x_r = Psi @ s
# pixel index of each measurement, used as the horizontal coordinate of the samples
_coord = C @ np.arange(n)
plt.plot(x, alpha=0.5, label='original signal')
plt.plot(_coord, y, '.', c='orange', label='undersampled signal')
plt.plot(x_r, c='green', label='recovered signal')
plt.legend()

In [None]:
# Reshape the flattened original and recovered signals back into images.
_x = x.reshape(img_size, img_size)
_x_r = x_r.reshape(img_size, img_size)

fig, (ax1, ax2) = plt.subplots(ncols=2)

# Left panel: the original signal x (not the measurements y).
img = ax1.imshow(_x)
ax1.set_title('original $x$')
# Bind each colorbar to its own axes so it is placed next to the right panel.
fig.colorbar(img, ax=ax1)

# Right panel: the reconstruction Psi @ s.
img = ax2.imshow(_x_r)
ax2.set_title('reconstructed $x$')
fig.colorbar(img, ax=ax2)

plt.tight_layout()

In [None]:
# Positions of the `sparsity` coefficients of s with the largest magnitude.
_, _idx = torch.tensor(np.abs(s)).topk(sparsity)
ic(_idx)

# Turn each selected column of Psi back into a single-channel image and tile them.
grid = make_grid(
    [
        torch.tensor(Psi[:, k].reshape(img_size, img_size)).unsqueeze(0)
        for k in _idx
    ],
    nrow=sparsity // 2,
)
ic(grid.shape)
plt.imshow(grid[0], cmap='viridis')
plt.show()