In [None]:
%matplotlib inline
# import imageio
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
# import videoUtility
import numpy.linalg as la
import scipy.io

import sys

# import sparsify
import sparsify_PyTorch
import utility

import torch
import torch.optim as optim
from torchvision import datasets
from torchvision.transforms.v2 import ToTensor
from torch.utils.data import DataLoader, Dataset


In [None]:
def _load_mnist(train):
    """Return one MNIST split stored under ./data (downloaded if missing), using a ToTensor transform."""
    return datasets.MNIST(
        root="data",
        train=train,
        download=True,
        transform=ToTensor(),
    )

training_data = _load_mnist(train=True)
test_data = _load_mnist(train=False)

# Demo subset: the first `dataset_size` training images (<dataset_size> x 28 x 28).
# NOTE: `.data` is the dataset's raw image tensor; the ToTensor transform above is
# applied only when items are fetched by indexing the dataset, so `images` keeps the
# raw pixel values rather than ToTensor's [0, 1] scaling.
dataset_size = 1000
images = training_data.data[:dataset_size]


In [None]:
# Spot-check a single image from the demo subset.
SAMPLE_IDX = 26
utility.imshow(images[SAMPLE_IDX])

In [None]:
torch.cuda.set_device(0)  # use GPU 0 (the first CUDA device)
# Now let's start to learn sparse coding basis.
# Layer1 sparse coding initialization.

# Seed both RNGs so the dictionary initialization (torch) and the patch sampling
# in the training loop (numpy) are reproducible across Restart & Run All.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Patches are xdim x ydim = 5x5 pixels, i.e. 25-dimensional signals, so a
# 2048-atom dictionary is roughly 82x overcomplete.
xdim = 5 #Patch size
ydim = 5 #Patch size
BASIS1_NUM = 2048
BASIS1_SIZE = [xdim*ydim, BASIS1_NUM]
BATCH_SIZE = 20

# Random Gaussian dictionary; each column (atom) is then scaled to unit L2 norm.
basis1 = torch.randn(BASIS1_SIZE).cuda()
basis1.div_(basis1.norm(2,0))

# NOTE(review): `lambd` is not referenced by the training loop in this notebook;
# its ISTA_PN call passes the sparsity penalty 0.08 directly. Confirm before tuning.
lambd = 1.0
STEPS = 30000

# Running statistics; the training loop updates them as exponential moving
# averages with weight (ACT_HISTORY_LEN - 1) / ACT_HISTORY_LEN.
ACT_HISTORY_LEN = 300
HessianDiag = torch.zeros(BASIS1_NUM).cuda()
ActL1 = torch.zeros(BASIS1_NUM).cuda()
signalEnergy = 0.
noiseEnergy = 0.

# Patch sampling ranges. A border of edgeBuff pixels is excluded on every side,
# so valid top-left corners run from edgeBuff to size - edgeBuff - patch inclusive:
# that is size - patch - 2*edgeBuff + 1 positions (the +1 makes the last one reachable).
edgeBuff = 2
spRange_t = images.shape[0]
spRange_x = images.shape[1] - xdim - edgeBuff * 2 + 1
spRange_y = images.shape[2] - ydim - edgeBuff * 2 + 1

# Patch batch buffer, one flattened patch per column. float32 matches basis1's dtype
# and avoids silently truncating pixel values if `images` is ever made floating point.
I = np.zeros([xdim*ydim, BATCH_SIZE], dtype=np.float32)
totalSteps1 = 0

In [None]:
# Layer-1 dictionary learning. The loop resumes from totalSteps1, so re-running this
# cell after an interruption continues where it stopped instead of starting over.
ISTA_LAMBDA = 0.08   # sparsity penalty passed to ISTA_PN
ISTA_ITERS = 1000    # ISTA iterations per batch
LOG_EVERY = 100      # print running statistics every LOG_EVERY steps
decay = (ACT_HISTORY_LEN - 1.0) / ACT_HISTORY_LEN  # EMA weight for the running statistics

for i in range(totalSteps1, STEPS):
    # Sample BATCH_SIZE random patches (image index, then top-left corner offset
    # by edgeBuff) into the columns of the reusable buffer I.
    for j in range(BATCH_SIZE):
        sIdx = int(np.floor(np.random.rand() * spRange_t))
        xIdx = int(np.floor(np.random.rand() * spRange_x + edgeBuff))
        yIdx = int(np.floor(np.random.rand() * spRange_y + edgeBuff))
        I[:, j] = images[sIdx, xIdx:xIdx + xdim, yIdx:yIdx + ydim].reshape([xdim * ydim])
    I_cuda = torch.from_numpy(I).cuda()

    # Sparse coefficient inference by ISTA:
    #   positive-only codes     -> sparsify_PyTorch.ISTA (previously tried with penalty 0.03)
    #   positive-negative codes -> sparsify_PyTorch.ISTA_PN (used here)
    ahat, Res = sparsify_PyTorch.ISTA_PN(I_cuda, basis1, ISTA_LAMBDA, ISTA_ITERS)

    # Exponential moving averages of per-atom activation statistics and of
    # signal / residual energy; snr is the ratio of the two energies.
    ActL1 = ActL1.mul(decay) + ahat.abs().mean(1) / ACT_HISTORY_LEN
    HessianDiag = HessianDiag.mul(decay) + torch.pow(ahat, 2).mean(1) / ACT_HISTORY_LEN
    signalEnergy = signalEnergy * decay + torch.pow(I_cuda, 2).sum() / ACT_HISTORY_LEN
    noiseEnergy = noiseEnergy * decay + torch.pow(Res, 2).sum() / ACT_HISTORY_LEN
    snr = signalEnergy / noiseEnergy

    # Dictionary update. 0.001 and 0.005 are positional hyperparameters of
    # quadraticBasisUpdate; see sparsify_PyTorch for their meaning.
    totalSteps1 = totalSteps1 + 1
    basis1 = sparsify_PyTorch.quadraticBasisUpdate(basis1, Res, ahat, 0.001, HessianDiag, 0.005)

    # .item() converts the 0-d CUDA tensors to plain Python floats so the log is
    # readable, instead of printing "tensor(..., device='cuda:0')" reprs.
    if i % LOG_EVERY == 0:
        print(totalSteps1, snr.item(),
              HessianDiag.min().item(), HessianDiag.max().item(),
              ActL1.min().item(), ActL1.max().item(), ActL1.sum().item())

In [None]:
#Dictionary Visualization
# Show the learned dictionary in pages of TILES_PER_SIDE x TILES_PER_SIDE atoms
# (atoms 0-1023, then 1024-2047 for BASIS1_NUM = 2048), one figure per page.
# The title is totalSteps1, the number of completed training steps. Using it instead
# of the training loop's variable `i` keeps this cell runnable even if no training
# iteration ran in this kernel.
# NOTE(review): assumes displayVecArry lays out the given columns on a rows x cols
# grid; confirm against utility.displayVecArry.
basis1_host = basis1.cpu().numpy()
TILES_PER_SIDE = 32
ATOMS_PER_PAGE = TILES_PER_SIDE * TILES_PER_SIDE

for page_start in range(0, basis1_host.shape[1], ATOMS_PER_PAGE):
    fig, ax = plt.subplots(figsize=(20, 20))
    utility.displayVecArry(
        basis1_host[:, page_start:page_start + ATOMS_PER_PAGE],
        TILES_PER_SIDE,
        TILES_PER_SIDE,
        ax=ax,
        title=totalSteps1,
        equal_contrast=True,
    )

In [None]:
# Save the learned layer-1 dictionary. It is converted from basis1 here so this cell
# does not depend on the visualization cell having been run first.
# NOTE(review): the filename says "IMAGES_Vanhateren_10x", but this notebook trains on
# MNIST 5x5 patches. The name is kept unchanged for compatibility with existing loaders.
# Consider renaming it (and updating loaders) so it cannot overwrite a Van Hateren basis
# if another notebook writes the same file.
np.savez("basis1_IMAGES_Vanhateren_10x.npz", basis1=basis1.cpu().numpy())