## <font color='green'> <div align="center">In the name of God </div></font>

### <font color='red'> Author: Sayed Kamaledin Ghiasi-Shirazi <a href="http://profsite.um.ac.ir/~k.ghiasi">(http://profsite.um.ac.ir/~k.ghiasi)</a> </font>

# Experiments on the MNIST dataset with CCE (Competitive Cross Entropy)

### This is a Python implementation of the *competitive cross-entropy* algorithm introduced in the following paper:

### Ghiasi-Shirazi, K. Competitive Cross-Entropy Loss: A Study on Training Single-Layer Neural Networks for Solving Nonlinearly Separable Classification Problems. Neural Process Lett 50, 1115–1122 (2019). https://doi.org/10.1007/s11063-018-9906-5

##### Importing general modules.

In [None]:
import os.path
import pickle
import time

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import scipy.io as sio
import sklearn.cluster
import sklearn.metrics

##### Importing modules written by the author

In [None]:
from CCETrainingDataPreparation import TrainingData
from CompetitiveCrossEntropy import CompetitiveCrossEntropy

##### Preparing matplotlib to generate high-quality images for the paper

In [None]:
%matplotlib inline
mpl.rcParams['figure.dpi']= 600

##### Set K=6 subclasses per class, as in the <a href="https://link.springer.com/article/10.1007/s11063-018-9906-5">Competitive Cross-Entropy paper</a>.

In [None]:
# Problem geometry: C classes, each modelled by K subclasses (L units in total).
C, K = 10, 6
L = C * K
# Image size used to reshape flattened samples / weights.
width = height = 28
# Optional caps on the number of training / test samples (None = use all).
N = NTest = None
# Maximum number of k-means iterations used when finding subclasses.
maxVqIteration = 100
# Optimisation settings.
learning_rate = 0.01
# note that the weight-decay is multiplied by learning rate
weight_decay = 0.0001
lr_decay_mult = 0.95
max_epochs = 50
noise_std = 0
# Experiment identifier embedded in the output file names.
exp_no = 1

##### Load the training and test data and normalize them

In [None]:
# Training split. Dividing by 255.0 rescales the pixel values to floats
# (assumes the stored values are raw 0..255 intensities -- confirm).
XTrain = sio.loadmat('./mnist/MnistTrainX')['MnistTrainX'] / 255.0
yTrain = sio.loadmat('./mnist/MnistTrainY')['MnistTrainY']
if N:
    # Keep only the first N training samples.
    XTrain, yTrain = XTrain[:N, :], yTrain[:N]

# Test split, prepared the same way.
XTest = sio.loadmat('./mnist/MnistTestX')['MnistTestX'] / 255.0
yTest = sio.loadmat('./mnist/MnistTestY')['MnistTestY']
if NTest:
    # Keep only the first NTest test samples.
    XTest, yTest = XTest[:NTest, :], yTest[:NTest]

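##### Prepare data: find K subclasses per class with k-means (the subclass means are cached in mnist_cce_clusters.pickle)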

In [None]:
# Seed the global NumPy RNG; presumably so that KMeans (created below without
# a random_state) gives reproducible clusters.
np.random.seed(1)
td = TrainingData(XTrain, yTrain)
# NOTE(review): the cache file name encodes neither K nor N, so a cache written
# with different settings is silently reused -- delete it after changing them.
# NOTE(review): on a cache hit no clustering runs, so the seeded global RNG is
# consumed differently than on a cache miss; later code that draws from
# np.random may therefore behave differently between runs -- confirm.
filename = 'mnist_cce_clusters.pickle'
if not os.path.isfile(filename):
    # Cache miss: find K subclasses with k-means and cache their means.
    clusAlg = sklearn.cluster.KMeans(max_iter=maxVqIteration)
    start = time.time()
    td.findSubclasses(K, clusAlg)
    end = time.time()
    print('Time for clustering: ', end - start)
    with open(filename, 'wb') as file:
        pickle.dump(td.subclassMeans, file)
else:
    # Cache hit: reuse the previously computed subclass means.
    # pickle.load can execute arbitrary code; only load a trusted local file.
    with open(filename, 'rb') as file:
        subclassMeans = pickle.load(file)
    td.setSubclasses(subclassMeans)

##### Show Clustering Result

In [None]:
def _tile_subclass_means(means, n_classes, n_sub, height, width):
    """Arrange subclass means into a white-bordered RGB grid image.

    Column c of the grid holds class c and row k holds its k-th subclass;
    ``means[c * n_sub + k]`` must be reshapeable to ``(height, width)``.
    Tiles are separated (and surrounded) by a 1-pixel white border.

    Returns:
        float ndarray of shape
        ``(1 + n_sub * (height + 1), 1 + n_classes * (width + 1), 3)``.
    """
    grid = np.ones([1 + n_sub * (height + 1), 1 + n_classes * (width + 1), 3])
    for c in range(n_classes):
        col = c * (width + 1) + 1
        for k in range(n_sub):
            row = k * (height + 1) + 1
            patch = np.reshape(means[c * n_sub + k], [height, width])
            # Replicate the gray patch into all three RGB channels.
            grid[row:row + height, col:col + width, :] = patch[:, :, np.newaxis]
    return grid


# Size the canvas from td.K / td.C (the values the tiles are indexed with) so a
# cached clustering with a different K cannot overflow or leave blank rows.
img = _tile_subclass_means(td.subclassMeans, td.C, td.K, height, width)
plt.axis('off')
plt.imshow(img)
fn = 'mnist_clustering_cce_exp{}'.format(exp_no) + '.png'
plt.imsave(fn, img)
plt.show()

##### Run the experiment: train for max_epochs epochs, saving the weights and a weight image and reporting the test accuracy after every epoch.

In [None]:
class empty_class:
    """Bare attribute container used to checkpoint only the learned W and b.

    NOTE: pickle records this class by its qualified name, so it must stay
    defined under this name wherever the checkpoints are unpickled.
    """


# Presumably each cce.fit() call trains for max_epochs_one epoch (it is passed
# as the trainer's epoch budget -- confirm in CompetitiveCrossEntropy), so the
# outer loop can checkpoint and evaluate the model after every epoch.
max_epochs_one = 1
cce = CompetitiveCrossEntropy(td, learning_rate, lr_decay_mult, max_epochs_one, weight_decay, noise_std)

start = time.time()
# Output directories for per-epoch weight checkpoints and weight images.
os.makedirs('./pickle', exist_ok=True)
os.makedirs('./png', exist_ok=True)
for epoch in range(max_epochs):
    cce.fit()

    # Checkpoint only the weights and biases, not the whole trainer object.
    # Built before opening the file so a failure cannot leave an empty file.
    cce_small = empty_class()
    cce_small.W = cce.W
    cce_small.b = cce.b
    filename = './pickle/cce-epoch-{}-exp-{}.pickle'.format(epoch, exp_no)
    with open(filename, 'wb') as file:
        pickle.dump(cce_small, file)

    # Evaluate on the test set.
    yHat = np.array(cce.classifyByMaxClassifier(XTest), dtype='int')
    outVal = sklearn.metrics.accuracy_score(yTest, yHat)
    print('Test classification accuracy: ' + str(outVal))

    # Render the current weights as one image laid out with rows=K, cols=C.
    img = cce.GenerateImagesOfWeights(width, height, color='gray', n_images=1, rows=K, cols=C, eps=0.1)
    plt.axis('off')
    plt.imshow(img[0])
    fn = './png/mnist_cce_epoch_{}_acc_{}_exp_{}'.format(epoch, outVal, exp_no) + '.png'
    plt.imsave(fn, img[0])
    plt.show()

end = time.time()
# The measured interval covers training plus the per-epoch checkpointing,
# evaluation and plotting, not cce.fit() alone.
print('cce training loop took time: ', end - start)

### <font color='red'> Author: Sayed Kamaledin Ghiasi-Shirazi <a href="http://profsite.um.ac.ir/~k.ghiasi">(http://profsite.um.ac.ir/~k.ghiasi)</a> </font>