In [19]:
import numpy as np
import random

In [20]:
!pip install pyspark

Defaulting to user installation because normal site-packages is not writeable


In [21]:
from pyspark.sql import SparkSession
spark = SparkSession\
    .builder\
    .appName("recommender")\
    .getOrCreate()
sc = spark.sparkContext

In [22]:
data = sc.textFile("ml-100k/u.data")
header = data.first() #extract header
data = data.filter(lambda row: row != header)
M = data.map(
  lambda l: l.split('\t')
).map(
  lambda l: (int(l[0]), int(l[1]), float(l[2]))
)


In [23]:
numWorkers = sc.defaultParallelism
numWorkers

4

In [24]:
M.max(lambda x: x[0])

(943, 58, 4.0)

In [25]:
def assignBlockIndex (index, numData, numWorkers):
    blockSize = numData/numWorkers
    if(numData % numWorkers != 0): blockSize = blockSize + 1
    return int(np.floor(index/np.ceil(blockSize)))+1


In [26]:
numFactors = 10
numRows = M.max(lambda x : x[0])[0] + 1
numCols = M.max(lambda x : x[1])[1] + 1
avgRating = M.map(lambda x: x[2]).mean()

#compute the scaling factor that the randomly initialized W and H matrices need to be scaled by so that dot(W_0,H_0) results in values that are similar to the average ratings
scaleRating = np.sqrt(avgRating / numFactors)

tau = 100


                                                                                

In [27]:
W = M.map(lambda x: tuple([int(x[0]),1])).reduceByKey(lambda x,y : x+y).map(lambda x: tuple([x[0], tuple([x[1], np.random.rand(1,numFactors).astype('float16')])])).persist()
H = M.map(lambda x: tuple([int(x[1]),1])).reduceByKey(lambda x,y : x+y).map(lambda x: tuple([x[0], tuple([x[1], np.random.rand(1,numFactors).astype('float16')])])).persist()

Iterater turn

In [28]:
it = 1
beta = 0.1
mse = sc.accumulator(0.0)
nUpdates = sc.accumulator(0.0)
#broadcast the stepsize
#stepSize = sc.broadcast(np.power(tau + it, -beta))
stepSize = sc.broadcast(0.5)
#generate random strata
lam = sc.broadcast(0.1)

perms = np.random.permutation(numWorkers)+1
perms

array([2, 4, 3, 1])

In [29]:
Mblocked = M.keyBy(lambda x: assignBlockIndex(x[0], numRows, numWorkers)).partitionBy(numWorkers)

In [30]:
def SGD(keyed_iterable, stepSize, numFactors,lam, mse, nUpdates, scaleRating):
    iterlist = (next(keyed_iterable))
    key = iterlist[0]
    Miter = iterlist[1][0]
    Hiter = iterlist[1][1]
    Witer = iterlist[1][2]
    
    Wdict = {}
    Hdict = {}
    
    Wout = {}
    Hout = {}
    
    #iterate through H and W and create dictionary of elements
    for h in Hiter:
        Hdict[h[0]] = h[1]
    
    for w in Witer:
        Wdict[w[0]] = w[1]
    #iterate through entries of M and compute L2-loss
    counter = 1
    for m in Miter:
        (i,j,rat) = m
        if i not in Wdict:
            Wdict[i] = tuple([i,np.random.uniform(0,1,numFactors).astype('float32')])
        if j not in Hdict:
            Hdict[j] = tuple([j,np.random.uniform(0,1,numFactors).astype('float32')])

        (Nw, Wprev) = Wdict[i]
        (Nh, Hprev) = Hdict[j]
        error = (rat - np.dot(Wprev, Hprev.T))
        mse += pow((error[0][0]), 2)
        
        Wnew = Wprev - stepSize.value*(-2*error*Hprev+ (2.0*lam.value)*Wprev)
        Hnew = Hprev - stepSize.value*(-2*error*Wprev + (2.0*lam.value)*Hprev)
        nUpdates += 1
        Wdict[i] = tuple([Nw, Wnew])
        Hdict[j] = tuple([Nh, Hnew])
        
    return (tuple(['W',Wdict.items()]), tuple(['H',Hdict.items()]))

In [31]:
#mse = sc.accumulator(0.0)
#nUpdates = sc.accumulator(0)
stepSize = 0.01
max_iter = 100
stepSize = sc.broadcast(stepSize)
for it in range(max_iter):
    mse = sc.accumulator(0.0)
    nUpdates = sc.accumulator(0)

    beta = 0.1
    #broadcast the stepsize
    #stepSize = sc.broadcast(1/np.power(tau + it, beta))
    stepSize = sc.broadcast(stepSize.value * 0.9)
    #generate random strata
    lam = sc.broadcast(0.01)

    perms = np.random.permutation(numWorkers)+1
    perms_dict = {i: val for i, val in enumerate(perms)}
    rev_perms=list(i+1 for i in (dict(sorted(perms_dict.items(), key=lambda item: item[1]))).keys())
    Mfilt = Mblocked.filter(lambda x: perms[x[0]-1]==assignBlockIndex(x[1][1],numCols,numWorkers)).persist()
    Hblocked = H.keyBy(lambda x: rev_perms[assignBlockIndex(x[0], numCols, numWorkers)-1])
    Wblocked = W.keyBy(lambda x: assignBlockIndex(x[0], numRows, numWorkers))
    groupRDD = Mfilt.groupWith(Hblocked, Wblocked).partitionBy(numWorkers)
    WH = groupRDD.mapPartitions(lambda x: SGD(x, stepSize, numFactors,lam, mse, nUpdates,scaleRating))
    W.unpersist()
    H.unpersist()
    W = WH.filter(lambda x: x[0]=='W').flatMap(lambda x: x[1]).persist()
    H = WH.filter(lambda x: x[0]=='H').flatMap(lambda x: x[1]).persist()
    Wvec = W.collect()
    Hvec = H.collect()
    print("MSE/update for {}-th iteration is: {}/{} ".format(it, mse.value, nUpdates.value))
    print("RMSE: {}".format((mse.value/nUpdates.value)))


                                                                                

MSE/update for 0-th iteration is: 70622.98502588272/51806 
RMSE: 1.3632201873505525


                                                                                

MSE/update for 1-th iteration is: 53182.59436607361/52344 
RMSE: 1.0160208307747518


                                                                                

MSE/update for 2-th iteration is: 49432.68571329117/51240 
RMSE: 0.9647284487371423


                                                                                

MSE/update for 3-th iteration is: 47256.09651851654/48286 
RMSE: 0.9786707641659392


                                                                                

MSE/update for 4-th iteration is: 49056.025265693665/53266 
RMSE: 0.9209631897588267


                                                                                

MSE/update for 5-th iteration is: 44715.931720256805/52344 
RMSE: 0.8542704363490907


                                                                                

MSE/update for 6-th iteration is: 41899.186269283295/47238 
RMSE: 0.8869805298548477


                                                                                

MSE/update for 7-th iteration is: 44834.97287321091/51866 
RMSE: 0.8644386085915804


                                                                                

MSE/update for 8-th iteration is: 42419.8792014122/51240 
RMSE: 0.8278664949533997


                                                                                

MSE/update for 9-th iteration is: 40780.8784403801/48236 
RMSE: 0.8454448635952421


                                                                                

MSE/update for 10-th iteration is: 44037.98388147354/52648 
RMSE: 0.83646071800398


                                                                                

MSE/update for 11-th iteration is: 41826.875583171844/51220 
RMSE: 0.8166121746031207


                                                                                

MSE/update for 12-th iteration is: 38627.90281677246/47266 
RMSE: 0.8172450136836724


                                                                                

MSE/update for 13-th iteration is: 40409.23011302948/51866 
RMSE: 0.7791082812059823


                                                                                

MSE/update for 14-th iteration is: 42757.88488292694/52344 
RMSE: 0.8168631530438435


                                                                                

MSE/update for 15-th iteration is: 41983.177961826324/48932 
RMSE: 0.8579902305613162


                                                                                

MSE/update for 16-th iteration is: 41779.956810474396/52648 
RMSE: 0.7935715850644735


                                                                                

MSE/update for 17-th iteration is: 43592.05365514755/47502 
RMSE: 0.9176888058428604


                                                                                

MSE/update for 18-th iteration is: 43493.329659461975/51484 
RMSE: 0.8447931330017476


                                                                                

MSE/update for 19-th iteration is: 39705.18914651871/51240 
RMSE: 0.7748865953653143


                                                                                

MSE/update for 20-th iteration is: 44105.68939638138/53040 
RMSE: 0.8315552299468586


                                                                                

MSE/update for 21-th iteration is: 41891.039134025574/52640 
RMSE: 0.7958024151600603


                                                                                

MSE/update for 22-th iteration is: 41078.74190235138/52648 
RMSE: 0.7802526573155938


                                                                                

MSE/update for 23-th iteration is: 39473.067621707916/47238 
RMSE: 0.8356210597761954


                                                                                

MSE/update for 24-th iteration is: 41062.30830812454/47502 
RMSE: 0.8644332514025629


                                                                                

MSE/update for 25-th iteration is: 42613.115972042084/52718 
RMSE: 0.8083219388452157


                                                                                

MSE/update for 26-th iteration is: 40711.424865722656/51220 
RMSE: 0.7948345346685407


                                                                                

MSE/update for 27-th iteration is: 39200.09096670151/47336 
RMSE: 0.8281242810271571


                                                                                

MSE/update for 28-th iteration is: 39367.19683980942/51240 
RMSE: 0.7682903364521745


                                                                                

MSE/update for 29-th iteration is: 40747.125428676605/48286 
RMSE: 0.8438703853845132


                                                                                

MSE/update for 30-th iteration is: 42090.88639116287/53266 
RMSE: 0.7902017495431021


                                                                                

MSE/update for 31-th iteration is: 42671.75398015976/49098 
RMSE: 0.8691138942555656


                                                                                

MSE/update for 32-th iteration is: 37354.37569093704/47266 
RMSE: 0.7903011824765591


                                                                                

MSE/update for 33-th iteration is: 38315.76947641373/48236 
RMSE: 0.7943396939301295


                                                                                

MSE/update for 34-th iteration is: 39157.28702020645/47012 
RMSE: 0.8329211056795383


                                                                                

MSE/update for 35-th iteration is: 42900.91604280472/53040 
RMSE: 0.8088408002037089


                                                                                

MSE/update for 36-th iteration is: 38836.854192733765/47012 
RMSE: 0.8261051261961577


                                                                                

MSE/update for 37-th iteration is: 39313.57276010513/51240 
RMSE: 0.7672438087452211


                                                                                

MSE/update for 38-th iteration is: 39256.13779258728/48208 
RMSE: 0.8143075380141735


                                                                                

MSE/update for 39-th iteration is: 41491.2747130394/52640 
RMSE: 0.7882081062507484


                                                                                

MSE/update for 40-th iteration is: 42978.94082021713/51806 
RMSE: 0.8296131880519078


                                                                                

MSE/update for 41-th iteration is: 40050.11951494217/48932 
RMSE: 0.8184852349166634


                                                                                

MSE/update for 42-th iteration is: 40283.08359289169/47206 
RMSE: 0.8533466845928842


                                                                                

MSE/update for 43-th iteration is: 40946.50648403168/52344 
RMSE: 0.7822578802543114


                                                                                

MSE/update for 44-th iteration is: 40852.20942354202/52344 
RMSE: 0.7804563927774343


                                                                                

MSE/update for 45-th iteration is: 41360.97428226471/52640 
RMSE: 0.785732794115971


                                                                                

MSE/update for 46-th iteration is: 40212.7286567688/48286 
RMSE: 0.8328030621043118


                                                                                

MSE/update for 47-th iteration is: 42464.73464202881/49098 
RMSE: 0.8648974427070106


                                                                                

MSE/update for 48-th iteration is: 39437.03858470917/51240 
RMSE: 0.769653368163723


                                                                                

MSE/update for 49-th iteration is: 38804.65500497818/47012 
RMSE: 0.8254202119666931


                                                                                

MSE/update for 50-th iteration is: 38478.43241882324/47336 
RMSE: 0.8128788325761206


                                                                                

MSE/update for 51-th iteration is: 40174.446496486664/48286 
RMSE: 0.8320102409909014


                                                                                

MSE/update for 52-th iteration is: 42921.58204174042/51806 
RMSE: 0.8285060039713628


                                                                                

MSE/update for 53-th iteration is: 41845.15248441696/53266 
RMSE: 0.7855884144560689


                                                                                

MSE/update for 54-th iteration is: 38477.48880147934/47336 
RMSE: 0.8128588981215004


                                                                                

MSE/update for 55-th iteration is: 41840.01585435867/53266 
RMSE: 0.7854919808951052


                                                                                

MSE/update for 56-th iteration is: 40125.40970802307/48932 
RMSE: 0.820023904766269


                                                                                

MSE/update for 57-th iteration is: 39285.475838661194/48208 
RMSE: 0.8149161101614087


                                                                                

MSE/update for 58-th iteration is: 39436.763696193695/51240 
RMSE: 0.7696480034385967


                                                                                

MSE/update for 59-th iteration is: 42546.159205913544/53040 
RMSE: 0.8021523228867561


                                                                                

MSE/update for 60-th iteration is: 42282.368039131165/52718 
RMSE: 0.8020480298784317


                                                                                

MSE/update for 61-th iteration is: 38422.265840530396/48236 
RMSE: 0.7965475130717803


                                                                                

MSE/update for 62-th iteration is: 41021.81247711182/52648 
RMSE: 0.7791713356084147


                                                                                

MSE/update for 63-th iteration is: 40179.44080162048/48286 
RMSE: 0.8321136727337216


                                                                                

MSE/update for 64-th iteration is: 38795.94292259216/47012 
RMSE: 0.8252348958264307


                                                                                

MSE/update for 65-th iteration is: 42464.43295431137/49098 
RMSE: 0.864891298104024


                                                                                

MSE/update for 66-th iteration is: 39884.8533744812/51866 
RMSE: 0.7689980598943663


                                                                                

MSE/update for 67-th iteration is: 38088.8987531662/47238 
RMSE: 0.8063190387646851


                                                                                

MSE/update for 68-th iteration is: 42930.206290245056/51806 
RMSE: 0.8286724759727648


                                                                                

MSE/update for 69-th iteration is: 40126.056661605835/48932 
RMSE: 0.8200371262487909


                                                                                

MSE/update for 70-th iteration is: 41020.66218805313/52648 
RMSE: 0.7791494869330864


                                                                                

MSE/update for 71-th iteration is: 41833.76697778702/53266 
RMSE: 0.7853746663497732


                                                                                

MSE/update for 72-th iteration is: 41833.70531320572/53266 
RMSE: 0.7853735086773124


                                                                                

MSE/update for 73-th iteration is: 42540.765773296356/53040 
RMSE: 0.8020506367514396


                                                                                

MSE/update for 74-th iteration is: 40126.22468948364/48932 
RMSE: 0.8200405601545746


                                                                                

MSE/update for 75-th iteration is: 42930.26015996933/51806 
RMSE: 0.8286735158083877


                                                                                

MSE/update for 76-th iteration is: 38796.03603553772/47012 
RMSE: 0.8252368764472415


                                                                                

MSE/update for 77-th iteration is: 39937.93934345245/51220 
RMSE: 0.779733294483648


                                                                                

MSE/update for 78-th iteration is: 42540.713431835175/53040 
RMSE: 0.8020496499214776


                                                                                

MSE/update for 79-th iteration is: 42930.24806308746/51806 
RMSE: 0.8286732823048965


                                                                                

MSE/update for 80-th iteration is: 40837.327288627625/52344 
RMSE: 0.7801720787220622


                                                                                

MSE/update for 81-th iteration is: 40820.184403419495/47502 
RMSE: 0.8593361206563828


                                                                                

MSE/update for 82-th iteration is: 37223.019018650055/47266 
RMSE: 0.7875220881532191


                                                                                

MSE/update for 83-th iteration is: 41651.26286840439/48480 
RMSE: 0.8591432109819387


                                                                                

MSE/update for 84-th iteration is: 42279.646689891815/52718 
RMSE: 0.8019964090043593


                                                                                

MSE/update for 85-th iteration is: 40268.220101356506/47206 
RMSE: 0.8530318201363494


                                                                                

MSE/update for 86-th iteration is: 38796.082902908325/47012 
RMSE: 0.8252378733708059


                                                                                

MSE/update for 87-th iteration is: 41833.49755334854/53266 
RMSE: 0.785369608255708


                                                                                

MSE/update for 88-th iteration is: 41020.5260052681/52648 
RMSE: 0.7791469002672105


                                                                                

MSE/update for 89-th iteration is: 37223.04571390152/47266 
RMSE: 0.7875226529408352


                                                                                

MSE/update for 90-th iteration is: 40837.28349971771/52344 
RMSE: 0.7801712421618087


                                                                                

MSE/update for 91-th iteration is: 39440.79470348358/51240 
RMSE: 0.7697266725894532


                                                                                

MSE/update for 92-th iteration is: 42669.223048210144/51484 
RMSE: 0.828786089818393


                                                                                

MSE/update for 93-th iteration is: 42540.6158823967/53040 
RMSE: 0.8020478107540856


                                                                                

MSE/update for 94-th iteration is: 41020.51183748245/52648 
RMSE: 0.7791466311632437


                                                                                

MSE/update for 95-th iteration is: 39288.755153656006/48208 
RMSE: 0.8149841344518753


                                                                                

MSE/update for 96-th iteration is: 40268.2315530777/47206 
RMSE: 0.8530320627267233


                                                                                

MSE/update for 97-th iteration is: 38422.84214544296/48236 
RMSE: 0.79655946068171


                                                                                

MSE/update for 98-th iteration is: 38482.13399839401/47336 
RMSE: 0.8129570305558985


                                                                                

MSE/update for 99-th iteration is: 39440.781863212585/51240 
RMSE: 0.7697264219986844


                                                                                

In [32]:
M.filter(lambda x: x[0]==15).collect()

[(15, 405, 2.0),
 (15, 749, 1.0),
 (15, 25, 3.0),
 (15, 331, 3.0),
 (15, 222, 3.0),
 (15, 473, 1.0),
 (15, 678, 1.0),
 (15, 932, 1.0),
 (15, 127, 2.0),
 (15, 685, 4.0),
 (15, 20, 3.0),
 (15, 301, 4.0),
 (15, 278, 1.0),
 (15, 620, 4.0),
 (15, 742, 2.0),
 (15, 137, 4.0),
 (15, 696, 2.0),
 (15, 924, 3.0),
 (15, 289, 3.0),
 (15, 508, 2.0),
 (15, 754, 5.0),
 (15, 18, 1.0),
 (15, 286, 2.0),
 (15, 148, 3.0),
 (15, 864, 4.0),
 (15, 244, 2.0),
 (15, 274, 4.0),
 (15, 9, 4.0),
 (15, 307, 1.0),
 (15, 458, 5.0),
 (15, 476, 4.0),
 (15, 471, 4.0),
 (15, 7, 1.0),
 (15, 937, 4.0),
 (15, 929, 1.0),
 (15, 889, 3.0),
 (15, 591, 2.0),
 (15, 1, 1.0),
 (15, 933, 1.0),
 (15, 459, 5.0),
 (15, 411, 2.0),
 (15, 744, 4.0),
 (15, 815, 1.0),
 (15, 300, 4.0),
 (15, 926, 1.0),
 (15, 409, 3.0),
 (15, 308, 5.0),
 (15, 303, 3.0),
 (15, 251, 2.0),
 (15, 936, 5.0),
 (15, 455, 1.0),
 (15, 306, 5.0),
 (15, 928, 1.0),
 (15, 14, 4.0),
 (15, 269, 5.0),
 (15, 13, 1.0),
 (15, 252, 2.0),
 (15, 927, 4.0),
 (15, 255, 5.0),
 (15, 50

In [33]:
W.filter(lambda x: x[0]==186).map(lambda x: x[1][1]).collect()[0]

array([[0.842 , 0.785 , 0.225 , 0.2095, 0.489 , 0.579 , 0.983 , 0.8164,
        0.405 , 0.738 ]], dtype=float16)

In [34]:
w_i = np.array(W.filter(lambda x: x[0]==15).map(lambda x: x[1][1]).collect()[0])
w_j = np.array(W.filter(lambda x: x[0]==222).map(lambda x: x[1][1]).collect()[0])
np.dot(w_i, w_j.T)

array([[2.908]], dtype=float16)