In [1]:
import pyspark
import numpy as np
from DataUtils import DataUtil
import time

sc = pyspark.SparkContext('local[*]')

In [2]:
dataUtil = DataUtil(sc, 'data/spam.data.txt', 'data/mean_std.txt', True)
rdd = dataUtil.read(sc)

In [None]:
rdd.zipWithIndex().take(3)

In [3]:
def stats(data):
    def f(y, pred):
        if y == pred:
            return 1,0,0,0 if y == 1 else 0,1,0,0
        elif pred == 1:
            return 0,0,1,0
        else:
            return 0,0,0,1
    tp, tn, fp, fn = data.map(lambda x: f(x[2], x[4])).reduce(lambda a, b: tuple(map(sum, zip(a, b))))
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1 = 2 * precision * recall / (precision + recall)
    accuracy = (tp + tn) / (tp + tn + fp + fn)
    return precision, recall, f1, accuracy

class ParallelLogReg():
    
    def __init__(self, sc, dataUtils, iterations, learning_rate, lambda_reg, fit_intercept, threshold=0.5):
        self.dataUtils = dataUtils
        # do we need to broadcast these?
        self.iterations = iterations
        self.lr = learning_rate
        self.lambda_reg = lambda_reg
        self.fit_intercept = fit_intercept
        self.data = self.dataUtils.read(sc)
        self.numberObservations = self.data.count()
        self.numberFeatures = 56 #len(self.data.first()[0])
        self.threshold = threshold
        
    def __add_intercept(self):
        self.data = self.data.map(lambda x: (1, x[0], x[1]))
    
    def __sigmoid(self, z):
        return 1 / (1 + np.exp(-z))
    
    def __predict_y(self, w, x):
        return self.__sigmoid(w[1].dot(x[1]) + w[0] * x[0])
    
    def calculateLoss(self):
        return 1
    
    def train(self):
        import sys
        eps = sys.float_info.epsilon
        self.__add_intercept()
        
        lambda_reg = self.lambda_reg
        m = self.numberObservations
        lr = self.lr
        threshold = self.threshold
        
        # initialize the weights
        # w[0]: bias weight
        # w[1]: rest of the weights
        w = (0, np.zeros(self.numberFeatures))
        
        # initialize prediction to rdd
        # x[0]: bias/intercept
        # x[1]: rest features
        # x[2]: true y
        # adding x[3]: predicted y
        data = self.data.map(lambda x: (x[0], x[1], x[2], 1 / (1 + np.exp(-(np.dot(x[1], w[1]) + x[0] * w[0]))))).cache()
                
        for i in range(self.iterations):
            start = time.time()
            # compute derivatives
            temp = data.map(lambda x: ((x[3] - x[2]) * x[0], (x[3] - x[2]) * x[1])) \
                         .reduce(lambda a,b: (a[0] + b[0], a[1] + b[1]))            
            dw = (temp[0]/m, (temp[1]/m) + (lambda_reg/m) * w[1])
            derivatives_end_time = time.time()

            # update weights
            w = (w[0] - lr * dw[0], w[1] - lr * dw[1])
            update_weights_end_time = time.time()

            # update prediction
            data = data.map(lambda x: (x[0], x[1], x[2], 1 / (1 + np.exp(-(np.dot(x[1], w[1]) + x[0] * w[0]))))).cache()
            update_prediction_end_time = time.time()

            if (i%10 == 0):
                loss = data.map(lambda x: x[2] * np.log(x[3] + eps) + (1 - x[2]) * np.log(1 - x[3] + eps)) \
                     .reduce(lambda a,b: a + b)
                loss_end_time = time.time()

                loss = -(1/self.numberObservations) * loss + (self.lambda_reg/(2*self.numberObservations)) * np.sum(w[1]**2)
                end = time.time()
                print('Loss: ' + str(loss))
                print('Total time: ' + str(end - start))
                print('Compute derivatives time: ' + str(derivatives_end_time - start))
                print('Update weights time: ' + str(update_weights_end_time - derivatives_end_time))
                print('Update prediciton time: ' + str(update_prediction_end_time - update_weights_end_time))
                print('Compute loss time: ' + str(loss_end_time - update_prediction_end_time))
                
                data = data.map(lambda x: (x[0], x[1], x[2], x[3], 1 if x[3] >= threshold else 0)).cache()
                precision, recall, f1, accuracy = stats(data)
                print('Precision: {}, Recall: {}, F1: {}, Accuracy: {}'.format(precision, recall, f1, accuracy))
                print('\n')
            
            
                
        data = data.map(lambda x: (x[0], x[1], x[2], x[3], 1 if x[3] >= 0.5 else 0))
        
        self.acc = data.map(lambda x: 1 if x[2] == x[4] else 0) \
                            .reduce(lambda a,b: a+b) / self.numberObservations
            
        return w
            

In [6]:
logReg = ParallelLogReg(sc, dataUtil, 100, 0.1, 0.1, True)

In [7]:
logReg.train()

Loss: 0.6503220242271294
Total time: 0.44766926765441895
Compute derivatives time: 0.26760101318359375
Update weights time: 1.3828277587890625e-05
Update prediciton time: 0.010751724243164062
Compute loss time: 0.16924691200256348
Precision: 0.9370725034199726, Recall: 0.9502890173410404, F1: 0.9436344851337389, Accuracy: 0.8932840686807216


Loss: 0.5221875677942007
Total time: 0.31861329078674316
Compute derivatives time: 0.13772130012512207
Update weights time: 1.2874603271484375e-05
Update prediciton time: 0.014154911041259766
Compute loss time: 0.1666562557220459
Precision: 0.9438202247191011, Recall: 0.9449035812672176, F1: 0.9443615922909258, Accuracy: 0.8945881330145621


Loss: 0.4581797205838391
Total time: 0.3738396167755127
Compute derivatives time: 0.20208072662353516
Update weights time: 1.2874603271484375e-05
Update prediciton time: 0.016008615493774414
Compute loss time: 0.15566635131835938
Precision: 0.9489772466099747, Recall: 0.9429093400319708, F1: 0.9459335624284078

(-0.4283915364779883,
 array([-0.00622156, -0.06707946,  0.12603675,  0.10679152,  0.28608055,
         0.20375194,  0.49958148,  0.24671533,  0.22264238,  0.09935942,
         0.12813877, -0.10040145,  0.05902925,  0.04332017,  0.18387406,
         0.36372831,  0.28676633,  0.20048329,  0.15539207,  0.22677141,
         0.31132481,  0.19344826,  0.4381361 ,  0.26738814, -0.29162514,
        -0.22380624, -0.25198122, -0.05837027, -0.11482588, -0.13995595,
        -0.07090135, -0.03828431, -0.15286331, -0.03645856, -0.11456854,
        -0.02073074, -0.1365954 , -0.05547121, -0.12921777, -0.00276266,
        -0.11337995, -0.20190582, -0.11778833, -0.13805032, -0.23019683,
        -0.22506779, -0.08137009, -0.11471476, -0.1139385 , -0.04269744,
        -0.05166024,  0.33324663,  0.4674863 ,  0.11167972,  0.12264795,
         0.28245354]))

In [None]:
logReg.acc