#### Sandbox Demo for CSC

In [6]:
import numpy as np
from numpy import linalg as LA
import matplotlib.pyplot as plt

import random
import os
import yaml

import torch
import torch.nn as nn
import torchvision
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch.utils.data.sampler import SubsetRandomSampler
import sparse_coding_classifier_functions as scc

from AuxiliaryFunctions import showFilters

%pdb on

Automatic pdb calling has been turned ON


Defining Model Class and functions:

In [13]:
# CSC CLASSES AND CONSISTENCY FUNCTIONS
class SL_CSC_FISTA(nn.Module):
    def __init__(self, stride=1, dp_channels=1, atom_r=1, atom_c=1, numb_atom=1, tau=1, T_SC=1, T_PM=1):
        super(SL_CSC_FISTA, self).__init__()
        self.D_trans = nn.Conv2d(dp_channels, numb_atom, (atom_r, atom_c), stride, padding=0, dilation=1, groups=1, bias=False)
        self.D = nn.ConvTranspose2d(numb_atom, dp_channels, (atom_c, atom_r), stride, padding=0, output_padding=0, groups=1, bias=False, dilation=1)
        self.normalise_weights()
        self.D_trans.weight.data = self.D.weight.data.permute(0,1,3,2)
        self.tau = tau
        self.T_SC = T_SC
        self.T_PM = T_PM
        self.forward_type = 'FISTA'

    def forward(self, Y):
    # Initialise t variables needed for FISTA
        t1 = 1
        # Initialise X1 Variable - note we need X1 and X2 as we need to use the prior two prior estimates for each update
        y_dims = list(Y.data.size())
        w_dims = list(self.D_trans.weight.data.size())
        # Initialise our guess for X
        X1 = Variable(torch.rand(y_dims[0], w_dims[0], (y_dims[2]-w_dims[2]+1),(y_dims[3]-w_dims[3]+1)))
        # Calculate first update
        X2, FISTA_error, alpha = self.linesearch(Y,X1)

        #raise Exception()

        # computing max eigenvalue with power method
#         for i in range(T_pm):
#             X1 = self.D_trans(self.D(X))
#             X1 = X1./norm(X1)
        
        for i in range(self.T_SC):
            # Update t variables
            t2 = (1 + np.sqrt(1+4*(t1**2)))/2 
            # Update Z
            Z = X2 + (X2-X1)*(t1 - 1)/t2


            # Update variables for next iteration
            X1 = X2.clone() #untoggle
            t1 = t2
            # Print at intervals to present progress
            if i==0 or (i+1)%5 == 0:
                av_num_zeros_per_image = X2.data.nonzero().numpy().shape[0]/y_dims[0]
                percent_zeros_per_image = 100*av_num_zeros_per_image/(y_dims[2]*y_dims[3])
                l2_error = np.sum((Y-self.D(X2)).data.numpy()**2)
                l1_error = np.sum(np.abs(X2.data.numpy()))
                # pix_error = l2_error/(y_dims[0]*y_dims[2]*y_dims[3])
                error_percent = l2_error*100/(np.sum((Y).data.numpy()**2))
                # print("Iteration: "+repr(i) + ", l2 error:{0:1.2f}".format(l2_error) + ", l1 error: {0:1.2f}".format(l1_error) + ", l2 error percent: {0:1.2f}".format(error_percent)+ "%, Total FISTA error: {0:1.2f}".format(FISTA_error) + ", Av. sparsity: {0:1.2f}".format(percent_zeros_per_image) +"%")
                print("After " +repr(i+1) + " iterations of FISTA, average l2 error over batch: {0:1.2f}".format(error_percent) + "% , Av. sparsity per image: {0:1.2f}".format(percent_zeros_per_image) +"%")
        return X2


    def reverse(self, X):
        out = self.D(X)
        return out

# ------------- ------------- ------------- ------------- ------------- ------------- ------------- -------------
def linesearch(self,Y,X):
		# Define search parameter for Armijo method
		c = 0.5
		alpha = 1
		g = self.D_trans(Y-self.D(X))
		ST_arg = X + alpha*g
		X_update = soft_thresh(ST_arg, self.tau*alpha)
		# Calculate cost of current X location
		l1_error = np.sum(np.abs(X.data.numpy()))
		l2_error = np.sum((Y-self.D(X)).data.numpy()**2)
		current_cost = l2_error + self.tau*l1_error
		# print("Cost at the beginning of the linesearch: {0:1.2f}".format(current_cost)+", l2 error:{0:1.2f}".format(l2_error) + ", l1 error: {0:1.2f}".format(l1_error))
		# Calculate the cost of the updated position
		update_cost = np.sum((Y-self.D(X_update)).data.numpy()**2) + self.tau*np.sum(np.abs(X.data.numpy()))
		# While the cost at the next location is higher than the current one iterate
		count = 0
		while update_cost >= current_cost and count<=15:
			alpha = alpha*c
			ST_arg = X + alpha*g
			X_update = soft_thresh(ST_arg, self.tau*alpha)
			l1_error = np.sum(np.abs(X_update.data.numpy()))
			l2_error = np.sum((Y-self.D(X_update)).data.numpy()**2)
			update_cost = l2_error + self.tau*l1_error
			count +=1
		# print("Cost at the end of the linesearch: {0:1.2f}".format(update_cost)+ ", l2 error:{0:1.2f}".format(l2_error) + ", l1 error: {0:1.2f}".format(l1_error))
		return X_update, update_cost, alpha

    
    
	def normalise_weights(self):
		print("Normalising kernels")
		filter_dims = list(np.shape(self.D.weight.data.numpy()))
		for i in range(filter_dims[0]):
			for j in range(filter_dims[1]):
				l2_norm = ((np.sum(self.D.weight.data[i][j].numpy()**2))**0.5)
				if l2_norm > 10^(-7): 
					self.D.weight.data[i][j] = self.D.weight.data[i][j]/l2_norm
				else:
					print("Kernel with 0 l2 norm identified, setting to zero")
					self.D.weight.data[i][j] = torch.zeros(filter_dims[2], filter_dims[3])
                    

# ------------- ------------- ------------- ------------- ------------- ------------- ------------- ------------- -------------

def train_SL_CSC(CSC, train_loader, num_epochs, T_DIC, cost_function, optimizer, batch_size):	
	print("Training SL-CSC. Batch size is: " + repr(batch_size))
	# Initialise variables needed to plot a random sample of three kernels as they are trained
	filter_dims = list(np.shape(CSC.D_trans.weight.data.numpy()))
	idx = random.sample(range(0, filter_dims[0]), 3)

	for epoch in range(num_epochs):
# 		print("Training epoch " + repr(epoch+1) + " of " + repr(num_epochs))
		for i, (inputs, labels) in enumerate(train_loader):
# 			print("Batch number " + repr(i+1))
            
			inputs = Variable(inputs)
			labels = Variable(labels)
			# Calculate and update step size for sparse coding step
			input_dims = list(inputs.size())
			# Fix dictionary and calculate sparse code
			if CSC.forward_type == 'FISTA_fixed_step':
				CSC.calc_L(input_dims)
			X = CSC.forward(inputs)
			# Fix sparse code and update dictionary
			print("Running dictionary update")
			for j in range(T_DIC):
				# Zero the gradient
				optimizer.zero_grad()
				# Calculate estimate of reconstructed Y
				inputs_recon = CSC.reverse(X)
				# Calculate loss according to the defined cost function between the true Y and reconstructed Y
				loss = cost_function(inputs_recon, inputs)
				# Calculate the gradient of the cost function wrt to each parameters
				loss.backward()
				# Update each parameter according to the optimizer update rule (single step)
				optimizer.step()
				# At the end of each batch plot a random sample of kernels to observe progress
				if j==0 or (j+1)%20 == 0:
					print("Average loss per data point at iteration {0:1.0f}".format(j+1) + " of SGD: {0:1.4f}".format(np.asscalar(loss.data.numpy())))
					plt.figure(1)
					plt.subplot(1,3,1)
					plt.imshow((CSC.D.weight[idx[0]][0].data.numpy()), cmap='gray')
					plt.title("Filter "+repr(idx[0]))
					plt.subplot(1,3,2)
					plt.imshow((CSC.D.weight[idx[1]][0].data.numpy()), cmap='gray', )
					plt.title("Filter "+repr(idx[1]))
					plt.xlabel("Epoch Number: " + repr(epoch)+ ", Batch number: " + repr(i+1) + ", Average loss: {0:1.4f}".format(np.asscalar(loss.data.numpy())))
					plt.subplot(1,3,3)
					plt.imshow((CSC.D.weight[idx[2]][0].data.numpy()), cmap='gray')
					plt.title("Filter "+repr(idx[2]))
					plt.draw()
					plt.pause(0.001)			
			
			l2_error_percent = 100*np.sum((inputs-CSC.D(X)).data.numpy()**2)/ np.sum((inputs).data.numpy()**2)
			print("After " +repr(j+1) + " iterations of SGD, average l2 error over batch: {0:1.2f}".format(l2_error_percent) + "%")
			# Normalise each atom / kernel
			CSC.normalise_weights()
			# Ensure that weights for the reverse and forward operations are consistent	
			CSC.D_trans.weight.data = CSC.D.weight.data.permute(0,1,3,2)
	# Return trained CSC
	return CSC

IndentationError: expected an indented block (<ipython-input-14-36f198c8fc2d>, line 46)

Defining Paramters and loading data

In [None]:
# Training hyperparameters
num_epochs = 1 #100
batch_size = 256
T_SC = 50
T_DIC = 10
T_PM = 8
stride = 1
learning_rate = 3
momentum = 0.9
weight_decay=0

# Weight importance of sparsity vs. reconstruction
tau = 0.9

# Local dictionary dimensions
atom_r = 28
atom_c = 28
numb_atom = 100
dp_channels = 1 

# Load MNIST
root = './data'
download = False  # download MNIST dataset or not

# Access MNIST dataset and define processing transforms to proces
# trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
trans = transforms.Compose([transforms.ToTensor()])
train_set = dsets.MNIST(root=root, train=True, transform=trans, download=download)
test_set = dsets.MNIST(root=root, train=False, transform=trans)

idx = list(range(10000))
train_sampler = SubsetRandomSampler(idx)

train_loader = torch.utils.data.DataLoader(
                 dataset=train_set,
                 batch_size=batch_size,
                 sampler = train_sampler,# None
                 shuffle=False) #True


test_loader = torch.utils.data.DataLoader(
                dataset=test_set,
                batch_size=batch_size,
                shuffle=False)

# Intitilise Convolutional Sparse Coder CSC
CSC = scc.SL_CSC_FISTA(stride, dp_channels, atom_r, atom_c, numb_atom, tau, T_SC, T_PM)

# Define optimisation parameters
CSC_parameters = [{'params': CSC.D.parameters()}]

# Define training settings/ options
cost_function = nn.MSELoss(size_average=True)
optimizer = torch.optim.SGD(CSC_parameters, lr=learning_rate, momentum=momentum, weight_decay=weight_decay, nesterov=True)
# optimizer = torch.optim.Adam(SSC.parameters(), lr=learning_rate)

Training!

In [None]:
CSC = scc.train_SL_CSC(CSC, train_loader, num_epochs, T_DIC, cost_function, optimizer, batch_size)