In [1]:
import numpy as np
import time
import random

In [2]:
import findspark
findspark.init() 

import pyspark
sc = pyspark.SparkContext()

In [3]:
from mnist import MNIST
mndata = MNIST('/Users/dcusworth/Desktop/mnist/MNIST/python-mnist/data')
images, labels = mndata.load_training()

In [4]:
#Build feature map
N = 1000 #How many images I want to load
d = 784 #Pixels of MNIST data
    
#label_func = lambda x,choose_label: [1 if la == choose_label else -1 for la in x]
def label_func(x, choose_label):
    if x == choose_label:
        return 1
    else:
        return -1

In [8]:
#Retrieve data and labels - do preprocessing
y_labs = labels[0:N]

#Loop over set of regularization parameters
vaccs = []
lambdas = [10**q for q in np.linspace(-5,5,10)]

#Load images
feature_map = np.zeros((N,d))
for i in range(N): #Just do a subset of training for now
    feature_map[i,:] = images[i]

#Start spark instance on points
#Take train test split
sinds = range(N)
random.shuffle(sinds)
tint = int(.8*N)
tind = sinds[0:tint]
vind = sinds[tint:-1]

#Center - i.e. remove mean image
fpoints = sc.parallelize(feature_map)
fmean = fpoints.map(lambda x: x).reduce(lambda x,y: (x+y) ) / float(N)
x_c = fpoints.map(lambda x: x-fmean).collect()

#Create Spark context for feature matrix
x_t = sc.parallelize([xx for idx,xx in enumerate(x_c) if idx in tind])
xtb = sc.broadcast(x_t.collect())
x_v = sc.parallelize([xx for idx,xx in enumerate(x_c) if idx in vind])

start = time.time()
for ll in lambdas:

    ws = []
    iouts = []
    classes = []
    
    #Get denominator - depends on lambda/regularization and not label
    denom_map = x_t.map(lambda x: np.dot(x, x.T) + N*ll) 
    denom_sum = denom_map.reduce(lambda x,y: x+y)

    ### Loop over all labels
    for choose_label in range(10): 

        #Make Spark contexts for certain label/lambda
        y_label = [label_func(q,choose_label) for q in y_labs]
        tpoints = sc.parallelize(zip([yy for idx,yy in enumerate(y_labs) if idx in tind], xtb.value))
        vpoints = sc.parallelize([yy for idx,yy in enumerate(y_labs) if idx in vind])
        y_val = vpoints.map(lambda x:x).collect()

        #Analytical solution to problem for certain label
        #Do numerator first - requires label data
        numer_map = tpoints.map(lambda x:x[1] * (label_func(x[0],choose_label)))
        numer_sum = numer_map.reduce(lambda x,y: x+y)

        #Use previously computed denominator 
        iw = numer_sum / float(denom_sum)

        #Test on validation set
        ires = x_v.map(lambda x:np.dot(x,iw))
        iout = ires.collect()
        iclass = ires.map(lambda x: np.sign(x)).collect()

        #Append to output  - Add MPI communication or further spark-ize
        ws.append(iw)
        iouts.append(iout)
        classes.append(iclass)

    #Collect all digit predictions
    out_pred = zip(*iouts)

    #Make prediction among all digits
    preds = []
    for idx in range(len(out_pred)):
        ipreds = np.asarray(out_pred[idx])
        iclass = np.where(ipreds == np.max(ipreds))[0][0] 
        preds.append(iclass)

    #Determine accuracy on validation
    vacc = np.sum([y == p for y,p in zip(y_val, preds)]) / float(len(preds))
    
    #Append to lambda
    vaccs.append(vacc)

end = time.time()


In [9]:
best_val = np.where(vaccs == np.max(vaccs))[0][0]
print 'validation accuracy = ', vaccs[best_val]
print 'best lambda =', lambdas[best_val]
print 'elapsed time for', N, 'samples = ', end-start, 'seconds'

validation accuracy =  0.708542713568
best lambda = 1e-05
elapsed time for 1000 samples =  41.8670392036 seconds
