In [1]:
import itertools

#split line
def getItemList(line):
	"""Split one basket line of the input file into its item IDs.

	Splits on any run of whitespace, so a trailing space, a trailing '\r',
	doubled spaces, or a line with no trailing delimiter all give the same
	clean list of IDs. An empty or blank line gives [].
	"""
	# split() (sep=None) rather than split(' ')[:-1]: slicing off the last
	# field only works if every line ends with exactly one space, and would
	# otherwise drop a real item; split(' ') also yields '' for double spaces.
	return line.split()

#count item
def getItem(itemList):
	"""Emit one (item, 1) support record per *distinct* item of a basket.

	Support is the number of baskets containing an item, so an item listed
	more than once in the same basket must contribute only 1. This matches
	pair/triple counting, where freqFilter reduces each basket to a set
	before combinations are generated. First-occurrence order is preserved.
	"""
	return [(i, 1) for i in dict.fromkeys(itemList)]

def getNum(tN, nN):
	"""reduceByKey combiner: add two partial counts for the same key."""
	total = tN + nN
	return total

#filter support
class checkVal(object):
	"""Callable predicate for RDD.filter on (key, value) records.

	Keeps a record when its value is at least `threshold` (inclusive).
	In this notebook it is used for minimum support counts and minimum
	confidence values.
	"""

	def __init__(self, threshold):
		# inclusive lower bound a record's value must reach
		self.threshold = threshold

	def __call__(self, data):
		_, value = data
		return self.threshold <= value

#save frequent
class freqFilter(object):
	"""Stores frequent (key, count) records and filters baskets down to them.

	freqNum: iterable of (key, count) pairs, e.g. the collected output of a
	support-filtered reduceByKey.
	Calling an instance on a basket (list of keys) returns the distinct keys
	that are frequent, as a list in arbitrary set order.
	"""

	def __init__(self, freqNum):
		# key -> support count
		self.freqDict = dict(freqNum)
		# frequent keys only, for O(1) membership tests
		self.freqSet = frozenset(self.freqDict)

	def __call__(self, itemList):
		kept = frozenset(itemList) & self.freqSet
		return list(kept)

#use item generate pair
def getPair(itemList):
	"""Emit ((u, v), 1) for every 2-combination of items in a basket.

	Each pair is stored as a sorted tuple (u < v), so the same pair coming
	from different baskets maps to the same reduceByKey key.
	"""
	return [(tuple(sorted(combo)), 1) for combo in itertools.combinations(itemList, 2)]

#1 pair -> 2 rules
class getPairConf(object):
	"""Turn one frequent-pair count into the two association rules it supports.

	For a record ((u, v), num), emits
	((u, v), num / support(u)) for the rule u -> v and
	((v, u), num / support(v)) for the rule v -> u.
	Item supports are read from the freqDict of the filter passed in.
	"""

	def __init__(self, freqF):
		# item -> support count, shared with the item-level filter
		self.freqDict = freqF.freqDict

	def __call__(self, data):
		(u, v), num = data
		supp = self.freqDict
		return [((u, v), num / supp[u]), ((v, u), num / supp[v])]

#print the top pair rules
def getPairRule(pairRuleList, ruleNum):
	"""Print and return the `ruleNum` highest-confidence pair rules.

	pairRuleList: list of ((u, v), conf) entries, each meaning the rule
	u -> v with confidence conf.
	Rules are ranked by descending confidence, and ties are broken by u and
	then v, so the ranking does not depend on the order of the input.
	A new list is returned; the caller's list is not modified.
	Each selected rule is printed as "u -> v: conf" with conf to 3 decimals.
	"""
	# full (u, v) tie-break: ordering on u alone left equal-confidence rules
	# with the same antecedent in whatever order collect() produced
	topRules = sorted(pairRuleList, key=lambda pRule: (-pRule[1], pRule[0][0], pRule[0][1]))[:ruleNum]
	for (u, v), conf in topRules:
		print('%s -> %s: %.3f' % (u, v, conf))
	return topRules

#use item generate triple, use pair to check
class getTrip(object):
	"""Candidate-triple generator for the triple pass of A-Priori.

	Built from the pair-level freqFilter. Calling it on a basket emits
	((u, v, w), 1) for every sorted triple (u < v < w) whose three
	sub-pairs are all in the frequent-pair set.
	"""

	def __init__(self, freqF):
		# frequent pairs; presumably sorted (u, v) tuples as produced by getPair
		self.freqSet = freqF.freqSet

	def __call__(self, itemList):
		candidates = (tuple(sorted(combo)) for combo in itertools.combinations(itemList, 3))
		return [(ordered, 1) for ordered in candidates if self._checkPair(ordered)]

	def _checkPair(self, trip):
		# monotonicity: a triple can only be frequent if every pair inside it is
		return all(sub in self.freqSet for sub in itertools.combinations(trip, 2))

#1 triple -> 3 rules
class getTripConf(object):
	"""Turn one frequent-triple count into the three rules {pair} -> item.

	For a record ((u, v, w), num), emits (((a, b), c), num / support((a, b)))
	for each split of the triple into an antecedent pair (a, b) and a
	consequent c. Pair supports are read from the freqDict of the pair-level
	filter passed in.
	"""

	def __init__(self, freqF):
		# pair -> support count, shared with the pair-level filter
		self.freqDict = freqF.freqDict

	def __call__(self, data):
		trip, num = data
		rules = []
		for lhs, rhs in self._splitTrip(trip):
			rules.append(((lhs, rhs), num / self.freqDict[lhs]))
		return rules

	def _splitTrip(self, trip):
		u, v, w = trip
		# if trip is sorted, every antecedent pair below is also sorted,
		# matching the sorted pair keys built by getPair
		return [((a, b), c) for a, b, c in ((u, v, w), (u, w, v), (v, w, u))]

#print the top triple rules
def getTripRule(tripRuleList, ruleNum):
	"""Print and return the `ruleNum` highest-confidence triple rules.

	tripRuleList: list of (((u, v), w), conf) entries, each meaning the rule
	{u, v} -> w with confidence conf.
	Rules are ranked by descending confidence, and ties are broken by u, then
	v, then w, so the ranking does not depend on the order of the input.
	A new list is returned; the caller's list is not modified.
	Each selected rule is printed as "u + v -> w: conf" with conf to 3 decimals.
	"""
	# tie-break includes the consequent w; previously rules differing only in w
	# kept the arbitrary order they arrived in from collect()
	topRules = sorted(
		tripRuleList,
		key=lambda tRule: (-tRule[1], tRule[0][0][0], tRule[0][0][1], tRule[0][1]),
	)[:ruleNum]
	for ((u, v), w), conf in topRules:
		print('%s + %s -> %s: %.3f' % (u, v, w, conf))
	return topRules

Starting Spark application


ID,YARN Application ID,Kind,State,Spark UI,Driver log,Current session?
640,application_1548171028063_0192,pyspark,idle,Link,Link,✔


SparkSession available as 'spark'.


In [2]:
data = sc.textFile('/user/hz333/data/hw1/Q2/browsing.txt')

#all items: one list of item IDs per basket (input line)
line = data.map(getItemList)

#get freq items: support count per item, keep items with support >= 100
item = line.flatMap(getItem)
itemNum = item.reduceByKey(getNum)
itemNum = itemNum.filter(checkVal(100))

#save freq items on the driver (fewer than 1000 items for this dataset)
freqItemNum = itemNum.collect()
itemF = freqFilter(freqItemNum)

#freq items: reduce every basket to its frequent items.
#cache(): this RDD feeds the pair pass below and the triple pass in the next
#cell; without persisting it, every pass re-reads and re-filters the text file.
line = line.map(itemF).cache()

#get pair candidates
pair = line.flatMap(getPair)

#get freq pairs (support >= 100).
#cache(): pairNum is consumed by two actions (collected for pair supports,
#then expanded into rules), so keep the small filtered result in memory.
pairNum = pair.reduceByKey(getNum)
pairNum = pairNum.filter(checkVal(100)).cache()

In [3]:
#save freq pairs on the driver
freqPairNum = pairNum.collect()
pairF = freqFilter(freqPairNum)

#A-Priori pruning: an item can be part of a frequent triple only if it occurs
#in at least one frequent pair, so drop every other item before enumerating
#3-combinations. getTrip would reject those triples anyway, but only after
#generating them. The count 1 is a placeholder; only pairItemF's key set is used.
pairItemF = freqFilter((u, 1) for freqPair in pairF.freqSet for u in freqPair)
tripBase = line.map(pairItemF)

#get triple candidates
trip = tripBase.flatMap(getTrip(pairF))

#get freq triples (support >= 100)
tripNum = trip.reduceByKey(getNum)
tripNum = tripNum.filter(checkVal(100))

In [4]:
#get pair confidence: each frequent pair (u, v) yields rules u -> v and v -> u
pairConf = pairNum.flatMap(getPairConf(itemF))

#take the 5 most confident rules directly (ties broken by the rule itself)
#instead of pre-filtering at an arbitrary 0.9 cut-off, which could leave
#fewer than 5 rules to report
pairRule = pairConf.takeOrdered(5, key=lambda pRule: (-pRule[1], pRule[0]))

#output
pairRule = getPairRule(pairRule, 5)

DAI93865 -> FRO40251: 1.000
GRO85051 -> FRO40251: 0.999
GRO38636 -> FRO40251: 0.991
ELE12951 -> FRO40251: 0.991
DAI88079 -> FRO40251: 0.987

In [10]:
#get triple confidence: each frequent triple yields three rules {pair} -> item
tripConf = tripNum.flatMap(getTripConf(pairF))

#take the 5 most confident rules directly (ties broken by the rule itself)
#instead of pre-filtering at an arbitrary 0.9 cut-off, which could leave
#fewer than 5 rules to report
tripRule = tripConf.takeOrdered(5, key=lambda tRule: (-tRule[1], tRule[0]))

#output
tripRule = getTripRule(tripRule, 5)

DAI23334 + ELE92920 -> DAI62779: 1.000
DAI31081 + GRO85051 -> FRO40251: 1.000
DAI55911 + GRO85051 -> FRO40251: 1.000
DAI62779 + DAI88079 -> FRO40251: 1.000
DAI75645 + GRO85051 -> FRO40251: 1.000

In [11]:
# Shut down the SparkContext once all results have been produced.
sc.stop()