In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from tqdm.notebook import tqdm
import scipy.io as sio
import os

# Sparse coding 

Our model assumes that images, $\boldsymbol{I}$, are encoded linearly by the patterns of neural activation, $\boldsymbol{a}$:

$$
\boldsymbol{I}(\boldsymbol{x}) = \sum_i a_i \phi_i (\boldsymbol{x}) + \epsilon(\boldsymbol{x})= \Phi \boldsymbol{a} + \epsilon(\boldsymbol{x})
\tag{1}
$$

The _energy_ is given by

$$
E(\boldsymbol{a}, \Phi) = \underbrace{\left\|\boldsymbol{I}-\Phi \boldsymbol{a}\right\|^2}_{\text{preserve information}} + \lambda \underbrace{\sum_i S\left(\frac{a_i}{\sigma}\right)}_{\text{sparseness of}\ \boldsymbol{a}} \tag{2}
$$

Our goal is to find a set of basis functions and activations that minimize $E$ - in other words, that do a good job of reconstructing the images while keeping the activations sparse.

## Natural Images

In [10]:
# Downloading the datasets from O&F 1997

OF_URLS = ["http://www.rctn.org/bruno/sparsenet/IMAGES.mat", "http://www.rctn.org/bruno/sparsenet/IMAGES_RAW.mat"]

for file_url in OF_URLS:
    file_name = os.path.basename(file_url)
    if os.path.exists(file_name):
        print(f"{file_url} has already been downloaded.")
    else:
        !wget $file_url

--2021-05-11 09:46:49--  http://www.rctn.org/bruno/sparsenet/IMAGES.mat
Resolving www.rctn.org (www.rctn.org)... 208.113.160.104
Connecting to www.rctn.org (www.rctn.org)|208.113.160.104|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 20971720 (20M)
Saving to: ‘IMAGES.mat’


2021-05-11 09:54:10 (46.5 KB/s) - ‘IMAGES.mat’ saved [20971720/20971720]

--2021-05-11 09:54:11--  http://www.rctn.org/bruno/sparsenet/IMAGES_RAW.mat
Resolving www.rctn.org (www.rctn.org)... 208.113.160.104
Connecting to www.rctn.org (www.rctn.org)|208.113.160.104|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 20971720 (20M)
Saving to: ‘IMAGES_RAW.mat’


utime(IMAGES_RAW.mat): No such file or directory
2021-05-11 10:00:33 (53.8 KB/s) - ‘IMAGES_RAW.mat’ saved [20971720/20971720]



In [11]:
# Read both image stacks distributed at http://www.rctn.org/bruno/sparsenet/

mat_images = sio.loadmat('IMAGES.mat')
mat_images_raw = sio.loadmat('IMAGES_RAW.mat')

# Pull the image arrays out of the loaded MATLAB dictionaries
natural_imgs = mat_images['IMAGES']           # plotted below as "Whitened Natural Images"
natural_imgs_raw = mat_images_raw['IMAGESr']  # plotted below as "Natural Images"

FileNotFoundError: [Errno 2] No such file or directory: 'IMAGES.mat'

In [None]:
# Plot datasets: grid layout shared by the image-preview cells below

N = 10 # number of images to display
c = 5 # number of columns in the subplot grid
r = N // c # number of rows (floor division, so N should be a multiple of c)

In [None]:
# Preview the first N raw images on an r x c grid
fig = plt.figure(figsize=(8, 2*r))
for i in range(N):
    ax = fig.add_subplot(r, c, i+1)
    ax.imshow(natural_imgs_raw[:,:,i], cmap="gray")
    ax.axis("off")
fig.tight_layout()
fig.suptitle("Natural Images", fontsize=20)
# Leave headroom so the suptitle does not overlap the top row
fig.subplots_adjust(top=0.9)

In [None]:
# Preview the first N whitened images on the same r x c grid
fig = plt.figure(figsize=(8, 2*r))
for i in range(N):
    ax = fig.add_subplot(r, c, i+1)
    ax.imshow(natural_imgs[:,:,i], cmap="gray")
    ax.axis("off")
fig.tight_layout()
fig.suptitle("Whitened Natural Images", fontsize=20)
# Leave headroom so the suptitle does not overlap the top row
fig.subplots_adjust(top=0.9)

In [None]:
# Importing a somewhat less natural image

OTHER_URLS = ["https://dz2cdn1.dzone.com/storage/temp/3542733-printed-circuit-boards.jpg"]

for file_url in OF_URLS:
    file_name = os.path.basename(file_url)
    if os.path.exists(file_name):
        print(f"{file_url} has already been downloaded.")
    else:
        !wget $file_url

In [None]:
# Load the circuit-board photo (the "natural" image used for comparison).
# Expects the JPEG named in OTHER_URLS to be present in the working directory.
import imageio
# The cells below index the result as [:, :, i] for i in range(3), i.e. they
# treat it as a 3-channel array — presumably (rows, cols, RGB); TODO confirm.
circuit_imgs_raw = imageio.imread("3542733-printed-circuit-boards.jpg")

In [None]:
# Show the three channels of the raw circuit-board image side by side
fig = plt.figure(figsize=(8, 3))
for i in range(3):
    ax = fig.add_subplot(1, 3, i+1)
    ax.imshow(circuit_imgs_raw[:,:,i], cmap="gray")
    ax.axis("off")
fig.tight_layout()
fig.suptitle("\"Natural\" Images", fontsize=20)
# Leave headroom so the suptitle does not overlap the images
fig.subplots_adjust(top=0.9)

In [None]:
# Standardise each channel of the circuit image to zero mean and unit variance.
#
# Work on a float copy: np.copy() keeps the source dtype, and an image decoded
# from a JPEG is presumably an integer (uint8) array. Writing z-scores back into
# an integer array would silently truncate/wrap them.
circuit_imgs = circuit_imgs_raw.astype(np.float64)  # astype returns a new array
w, h, n = circuit_imgs.shape  # only n (the number of channels) is used below

for i in range(n):
    current_image = circuit_imgs[:, :, i]
    mean = np.mean(current_image)
    std = np.std(current_image)
    circuit_imgs[:, :, i] = (current_image - mean) / std

In [None]:
# Show the three channels of the per-channel normalised circuit-board image
fig = plt.figure(figsize=(8, 3))
for i in range(3):
    ax = fig.add_subplot(1, 3, i+1)
    ax.imshow(circuit_imgs[:,:,i], cmap="gray")
    ax.axis("off")
fig.tight_layout()
fig.suptitle("\"Natural\" Images", fontsize=20)
# Leave headroom so the suptitle does not overlap the images
fig.subplots_adjust(top=0.9)

## Sparseness penalty 

In [None]:
# Left panel: three candidate sparseness penalties S(x).
# Right panel: the densities p(x) proportional to exp(-S(x)) that they imply,
# normalised numerically over the sampled grid.
x = np.linspace(-20, 20, 1000)
color_map = cm.cool 

# np.trapz was renamed np.trapezoid in NumPy 2.0; use whichever this NumPy provides.
integrate = getattr(np, "trapezoid", None) or np.trapz

fig, (axS, axP) = plt.subplots(ncols=2, figsize=(8,4))
axS.plot(x, np.abs(x), label=r"$|x|$", color=color_map(0))
axS.plot(x, np.log(1+x**2), label=r"$\ln(1+x^2)$", color=color_map(.3333))
axS.plot(x, 1-np.exp(-x**2), label=r"$1-\exp(-x^2)$", color=color_map(.6666))

axS.set_xlabel("x")
axS.set_ylabel(r"$S(x)$")
axS.legend()
axS.set_xlim(-5, 5)
axS.set_ylim(0, 5)

# Each f_k now evaluates exp(-S(v)) on its own argument instead of ignoring it
# and silently reading the global x.
f1 = lambda v: np.exp(-np.abs(v))
f1_const = integrate(f1(x), x)
axP.plot(x, f1(x) / f1_const, label=r"$p \propto e^{-|x|}$", color=color_map(0))
f2 = lambda v: np.exp(-np.log(1+v**2))
f2_const = integrate(f2(x), x)
axP.plot(x, f2(x) / f2_const, label=r"$p \propto e^{-\ln(1+x^2)}$", color=color_map(.3333))
# exp(-(1 - exp(-v^2))) tends to exp(-1) rather than 0 for large |v|, so this
# curve's normalisation depends on the grid range chosen above.
f3 = lambda v: np.exp(-(1-np.exp(-v**2)))
f3_const = integrate(f3(x), x)
axP.plot(x, f3(x) / f3_const, label=r"$p \propto e^{-1+\exp(-x^2)}$", color=color_map(.6666))

axP.set_xlabel("x")
axP.set_ylabel(r"$p(x)$")
axP.legend()
axP.set_xlim(-5, 5)
axP.set_ylim(0, 1)

fig.tight_layout()

## Olshausen & Field Model

Training the model requires the following steps

1. Start with a random set of decoding functions, $\phi_i$.
2. Find the patterns of activity that minimize E given the current $\phi_i$, $\min_a E$, for each image $I$. 
3. Improve the $\phi_i$ so that they lower the expected value, over all images, of the minimized energy, $\langle \min_{\boldsymbol{a}} E \rangle$.
4. Repeat 2-3 until the $\phi_i$ and $\boldsymbol{a}$ converge. 


In [None]:
class OlshausenField1996Model:
    """Sparse-coding model with a learned linear dictionary.

    The model holds a dictionary ``Phi`` of shape (num_inputs, num_units) and a
    batch of activations ``r`` of shape (batch_size, num_units). ``train``
    iterates ``update_r`` on a fixed batch until the activations stop changing,
    then applies one update to ``Phi``.

    Parameters
    ----------
    num_inputs : int
        Length of one flattened input patch.
    num_units : int
        Number of basis functions (columns of ``Phi``).
    batch_size : int
        Number of patches per call to ``train``.
    thresh_type : {"soft", "ln", "cauchy"}
        Selects the thresholding function applied to ``r`` after each update
        and the sparsity function used by ``calculate_total_error``.
    nt_max : int
        Maximum number of activation updates per call to ``train``.
    eps : float
        Convergence tolerance on the relative change of ``r``.
    lr_r, lr_Phi : float
        Step sizes for the activation and dictionary updates.
    lmda : float
        Sparsity weight; also passed as the threshold to the thresholding function.
    """

    def __init__(self, num_inputs, num_units, batch_size,
                    thresh_type="soft",
                    nt_max=1000, eps=1e-2,
                    lr_r=1e-2, lr_Phi=1e-2, lmda=5e-3):
        self.lr_r = lr_r # learning rate of r
        self.lr_Phi = lr_Phi # learning rate of Phi
        self.lmda = lmda # regularization parameter

        self.nt_max = nt_max # Maximum number of simulation time
        self.eps = eps  # small value which determines convergence
        
        self.num_inputs = num_inputs
        self.num_units = num_units
        self.batch_size = batch_size

        # "cauchy" has its own branch below, so it must be accepted here too.
        assert thresh_type in ["soft", "ln", "cauchy"], \
            f"unknown thresh_type: {thresh_type!r}"
        self.thresh_type = thresh_type
        if self.thresh_type == "soft":
            self._spasity_func = lambda x: np.abs(x)
            self._thresh_func = self.soft_thresholding_func
        elif self.thresh_type == "ln":
            # NumPy's natural logarithm is np.log (there is no np.ln).
            self._spasity_func = lambda x: np.log(1 + x**2)
            self._thresh_func = self.ln_thresholding_func
        elif self.thresh_type == "cauchy":
            # NOTE(review): |x| is reused as the reported sparsity cost here;
            # confirm this is the penalty intended for the Cauchy threshold.
            self._spasity_func = lambda x: np.abs(x)
            self._thresh_func = self.cauchy_thresholding_func

        # Weights
        Phi = np.random.randn(self.num_inputs, self.num_units).astype(np.float32)
        self.Phi = Phi * np.sqrt(1/self.num_units)

        # activity of neurons
        self.r = np.zeros((self.batch_size, self.num_units))
    
    def initialize_states(self):
        """Reset the activations ``r`` to zeros of shape (batch_size, num_units)."""
        self.r = np.zeros((self.batch_size, self.num_units))
        
    def normalize_rows(self):
        """Rescale ``Phi`` so that each column has unit L2 norm.

        The norm is taken along axis 0, i.e. per basis function (column), despite
        the method name. Norms are floored at 1e-8 to avoid division by zero.
        """
        self.Phi = self.Phi / np.maximum(np.linalg.norm(self.Phi, ord=2, axis=0, keepdims=True), 1e-8)

    # thresholding function of S(x)=|x|
    def soft_thresholding_func(self, x, lmda):
        """Shrink ``x`` towards zero by ``lmda``; entries with |x| <= lmda become 0."""
        return np.maximum(x - lmda, 0) - np.maximum(-x - lmda, 0)

    # thresholding function of S(x)=ln(1+x^2)
    def ln_thresholding_func(self, x, lmda):
        """Closed-form thresholding used with S(x)=ln(1+x^2).

        The 1e-8 added to ``h`` guards the final division against ``h == 0``.
        """
        f = 9*lmda*x - 2*np.power(x, 3) - 18*x
        g = 3*lmda - np.square(x) + 3
        h = np.cbrt(np.sqrt(np.square(f) + 4*np.power(g, 3)) + f)
        two_croot = np.cbrt(2) # cubic root of two
        return (1/3)*(x - h / two_croot + two_croot*g / (1e-8+h))

    # thresholding function https://arxiv.org/abs/2003.12507
    def cauchy_thresholding_func(self, x, lmda):
        """Thresholding from arXiv:2003.12507; returns 0 wherever |x| < lmda."""
        # np.maximum(..., 0) keeps the square-root argument non-negative.
        f = 0.5*(x + np.sqrt(np.maximum(x**2 - lmda,0)))
        g = 0.5*(x - np.sqrt(np.maximum(x**2 - lmda,0)))
        return f*(x>=lmda) + g*(x<=-lmda) 

    def calculate_error(self, inputs):
        """Return the reconstruction residual ``inputs - r @ Phi.T``."""
        error = inputs - self.r @ self.Phi.T
        return(error)

    def calculate_total_error(self, error):
        """Return mean squared residual plus ``lmda`` times the mean sparsity of ``r``."""
        recon_error = np.mean(error**2)
        sparsity_r = self.lmda*np.mean(self._spasity_func(self.r)) 
        return(recon_error + sparsity_r)

    def update_r(self, inputs):
        """Take one step on ``r`` along ``error @ Phi``, then apply the threshold.

        Returns the residual computed *before* ``r`` was updated.
        """
        error = self.calculate_error(inputs)
        r = self.r + self.lr_r * error @ self.Phi
        self.r = self._thresh_func(r, self.lmda)
        return(error)

    def update_Phi(self, inputs):
        """Update ``Phi`` in place by ``lr_Phi * error.T @ r``.

        Returns the residual computed *before* ``Phi`` was updated.
        """
        error = self.calculate_error(inputs)
        dPhi = error.T @ self.r
        self.Phi += self.lr_Phi * dPhi
        return(error)
    
    def train(self, inputs):
        """Fit activations to one batch, then update the dictionary once.

        ``inputs`` must have shape (batch_size, num_inputs). The activations are
        reset and ``Phi`` is re-normalised first. ``r`` is then updated until its
        relative change drops below ``eps``, at which point ``Phi`` is updated.
        If that has not happened by iteration ``nt_max - 2``, a message is printed
        and ``Phi`` is left unchanged for this batch.

        Returns the most recently computed residual, shape (batch_size, num_inputs).
        """
        self.initialize_states() # Reset states
        self.normalize_rows() # Normalize weights
        
        # Input an image patch until latent variables are converged 
        r_tm1 = self.r # set previous r (t minus 1)
        for t in range(self.nt_max):
            # Update r without updating weights 
            error = self.update_r(inputs)
            dr = self.r - r_tm1 

            # Relative change of r; eps in the denominator avoids dividing by zero
            dr_norm = np.linalg.norm(dr, ord=2) / (self.eps + np.linalg.norm(r_tm1, ord=2))
            r_tm1 = self.r # update r_tm1
            
            # Check convergence of r, then update weights
            if dr_norm < self.eps:
                error = self.update_r(inputs)
                error = self.update_Phi(inputs)
                break
            
            # On failure to converge, report and stop. The message no longer
            # reads the caller's loop variable `iter_`, which raised NameError
            # whenever the model was used outside that specific loop.
            if t >= self.nt_max-2: 
                print("Error: r failed to converge; last relative change:", dr_norm)
                break
        return(error)

In [None]:
def generate_patches(input_images, patch_size, batch_size):
    """Sample random square patches from a stack of images.

    Parameters
    ----------
    input_images : array of shape (H, W, num_images)
        Image stack; the last axis indexes images.
    patch_size : int
        Side length of each square patch; must not exceed H or W.
    batch_size : int
        Number of patches to draw.

    Returns
    -------
    ndarray of shape (batch_size, patch_size**2)
        Flattened patches, each with its own mean subtracted.

    Raises
    ------
    ValueError
        If ``patch_size`` is larger than the image height or width.
    """
    H, W, num_images = input_images.shape
    if patch_size > H or patch_size > W:
        raise ValueError(
            f"patch_size={patch_size} does not fit in images of size {H}x{W}")
    
    # Upper-left corners of the patches. randint's upper bound is exclusive, so
    # +1 makes the last valid corner reachable and allows patch_size == H or W.
    x0 = np.random.randint(0, W-patch_size+1, batch_size)
    y0 = np.random.randint(0, H-patch_size+1, batch_size)

    # Generating inputs
    patches_list = []
    for i in range(batch_size):        
        idx = np.random.randint(0, num_images)  # pick a source image at random
        img = input_images[:, :, idx]
        clip = img[y0[i]:y0[i]+patch_size, x0[i]:x0[i]+patch_size].flatten()
        patches_list.append(clip - np.mean(clip))  # remove the patch's mean
        
    patches = np.array(patches_list) # Input image patches
    return(patches)

In [None]:
# Simulation constants
num_units = 100 # number of neurons (units), i.e. basis functions to learn
patch_size = 16 # side length of each square image patch, in pixels

num_iter = 500 # number of training iterations (one batch of patches each)
batch_size = 250 # number of patches per batch

lmda = 5e-3 # sparsity weight

# Image set to draw training patches from
image_set = natural_imgs

# Define model; each input is a flattened patch, hence patch_size**2 inputs
model = OlshausenField1996Model(num_inputs=patch_size**2, 
                                num_units=num_units,
                                batch_size=batch_size,
                                lmda=lmda)

In [None]:
# Run simulation: one fresh batch of random patches per iteration
error_list = [] # List to save errors
for iter_ in tqdm(range(num_iter)):
    patches = generate_patches(image_set, patch_size, batch_size) # Generating image patches
    error = model.train(patches) # train model with patches 

    error_list.append(model.calculate_total_error(error))
    # Print the moving average over the last 100 iterations. The previous slice
    # [iter_-99:iter_] left out the newest error and so averaged only 99 values.
    if iter_ % 100 == 99:  
        print(f"iter: {iter_+1}/{num_iter}, Moving error:",
              np.mean(error_list[-100:]))

In [None]:
# Training curve: total error recorded at every iteration
error_ax = plt.figure(figsize=(6, 4)).gca()
error_ax.set_ylabel("Error")
error_ax.set_xlabel("Iterations")
error_ax.plot(np.arange(len(error_list)), np.asarray(error_list))
plt.tight_layout()
plt.show()

In [None]:
def plot_receptive_fields(fields, c=10, title="Receptive Fields"):
    """Draw each row of ``fields`` as a square grayscale image on a grid.

    Parameters
    ----------
    fields : array of shape (num_units, num_inputs)
        One flattened receptive field per row; ``num_inputs`` must be a
        perfect square so each row can be reshaped to a square patch.
    c : int
        Number of columns in the subplot grid.
    title : str
        Figure title.
    """
    num_units, num_inputs = fields.shape
    # round() before int() guards against sqrt landing just below an integer
    patch_size = int(round(np.sqrt(num_inputs)))
    # Ceiling division: with floor division the grid had too few cells whenever
    # num_units was not a multiple of c, and plt.subplot raised on the overflow.
    r = -(-num_units // c)

    fig = plt.figure(figsize=(6, .6*(r+4)))
    plt.subplots_adjust(hspace=0.1, wspace=0.1)
    for i in tqdm(range(num_units)):
        plt.subplot(r, c, i+1)
        plt.imshow(np.reshape(fields[i], (patch_size, patch_size)), cmap="gray")
        plt.axis("off")
    plt.tight_layout()
    fig.suptitle(title, fontsize=20)
    # Leave headroom so the suptitle does not overlap the top row
    plt.subplots_adjust(top=0.9)

In [None]:
# Phi is (num_inputs, num_units); transpose so that each row is one unit's field
plot_receptive_fields(model.Phi.T)

## What else could we compare it to?

In [None]:
from sklearn.decomposition import FastICA, PCA
# Draw num_iter*batch_size patches in one call; the PCA and ICA cells below fit on them.
# Note: this rebinds `patches`, which the training loop also used for its batches.
patches = generate_patches(image_set, patch_size, num_iter*batch_size)

### PCA

Principal components analysis (PCA) does not care about sparseness per se. Instead, it aims to find receptive fields (basis functions) that capture the most variability in the images. The first basis function captures the most, the second basis function captures the second most, and so on.

In [None]:
# Fit PCA with one component per model unit (fit returns the estimator itself)
pca = PCA(n_components=num_units).fit(patches)

# Each row of components_ is one basis function over the flattened patch
plot_receptive_fields(pca.components_, title="PCA Receptive Fields")

### ICA

Independent component analysis (ICA) is an approach which attempts to find receptive fields (basis functions) that result in _statistically independent_ activations. In other words,

$$ p(\boldsymbol{a}) = \prod_i p(a_i) $$

In [None]:
# Fit FastICA with one component per model unit (fit returns the estimator itself)
ica = FastICA(n_components=num_units).fit(patches)

# Each row of components_ is one basis function over the flattened patch
plot_receptive_fields(ica.components_, title="ICA Receptive Fields")