In [1]:
import itertools
import gc # for removing rdds from memory
from pyspark import SparkContext

In [2]:
sc = SparkContext(master='local', appName="Assignment1_E1")

22/04/24 11:43:43 WARN Utils: Your hostname, Luiss-MacBook-Pro.local resolves to a loopback address: 127.0.0.1; using 192.168.0.126 instead (on interface en0)
22/04/24 11:43:43 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
22/04/24 11:43:44 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:
#hdfs dfs -mkdir -p data
#hdfs dfs -put data/small_conditions.csv data/

In [4]:
SUPPORT_THRESHOLD = 1000
data = sc.textFile("data/conditions.csv.gz")
header = data.first() #extract header

                                                                                

In [5]:
# START,STOP,PATIENT,ENCOUNTER,CODE,DESCRIPTION
# PATIENT is the patient identifier
# CODE is a condition identifier 
# DESCRIPTION is the name of the condition

In [6]:
# Reorganizing the data to a list of baskets (lists)
# and removing the header row
lines = data.filter(lambda row: row != header) \
                .map(lambda line: tuple(line.split(",")))

lines.take(3)

                                                                                

[('2017-01-14',
  '2017-03-30',
  '09e4e8cb-29c2-4ef4-86c0-a6ff0ba25d2a',
  '88e540ab-a7d7-47de-93c1-720a06f3d601',
  '65363002',
  'Otitis media'),
 ('2012-09-15',
  '2012-09-16',
  'b0a03e8c-8d0f-4242-9548-40f4d294eba8',
  'e89414dc-d0c6-478f-86c0-d08bac6ad0a2',
  '241929008',
  'Acute allergic reaction'),
 ('2018-06-17',
  '2018-06-24',
  '09e4e8cb-29c2-4ef4-86c0-a6ff0ba25d2a',
  'c14325b0-f7ec-4314-bba8-dddc37f0067d',
  '444814009',
  'Viral sinusitis (disorder)')]

In [7]:
# Freeing memory
del data
gc.collect()

221

In [8]:
conditions = lines.map(lambda x: (int(x[4]), x[5])) \
                .distinct() \
                .collectAsMap()

#conditions

                                                                                

In [9]:
item_baskets = lines.map(lambda x: (x[2], {x[4]})) \
                    .reduceByKey(lambda a, b: a | b) \
                    .map(lambda x: tuple(x[1]))
                    

item_baskets.take(3)

                                                                                

[('241929008', '24079001', '444814009', '233678006', '43878008', '232353008'),
 ('444814009', '36971009', '58150001'),
 ('195662009', '44465007', '15777000', '271737000')]

In [10]:
#item_baskets.count()

In [11]:
# Freeing memory
del lines
gc.collect()

238

## Apriori Phase 1

In [12]:
# Flat listing all the baskets 
freqItemCounts = item_baskets.flatMap(lambda x: x) \
                    .map(lambda item: (item, 1)) \
                    .reduceByKey(lambda a, b: a + b) \
                    .filter(lambda item: item[1] >= SUPPORT_THRESHOLD)
                    

# Mapping -> create pairs (item, 1)
#itemPairs = items.map(lambda item: (item, 1))

# Reducing
#itemCounts = itemPairs.reduceByKey(lambda a, b: a + b)

# Keeping only the ones above the support threshold
#freqItemCounts = itemCounts.filter(lambda item: item[1] >= SUPPORT_THRESHOLD)


# Taking the 10 most frequent itemsets for k = 1
#freqItemCounts.takeOrdered(10, key=lambda x: -x[1])

## Intermediate step

In [13]:
# Creating the frequent items table
freqItemTable = freqItemCounts.map(lambda x: x[0]).collect()

#freqItemTable.take(10)

                                                                                

In [14]:
# In order for a pair to be frequent both its items have to be frequent. 
# As such, we can remove the unfrequent items from the baskets.

# Remove the unfrequent items from the baskets
item_baskets = item_baskets.filter(lambda basket: {item for item in basket if item in freqItemTable}) 
                            #.map(lambda x: tuple(x))    


item_baskets.take(3)

                                                                                

[('44465007',
  '70704007',
  '192127007',
  '10509002',
  '43878008',
  '195662009',
  '62106007',
  '65363002'),
 ('195662009', '444814009', '162864005'),
 ('19169002',
  '15777000',
  '59621000',
  '40055000',
  '271737000',
  '444814009',
  '62106007')]

## Phase 2, k = 2

In [15]:
# Receives as input: the baskets and the frequent items table
# Returns: candidate frequent pairs
def freq_pairs(basket,table):
    for item_1 in range(0, len(basket)):
        if basket[item_1] not in table:
            continue
        for item_2 in range(item_1 + 1, len(basket)): # j > i 
            if basket[item_2] in table:
                yield(tuple(sorted((basket[item_1], basket[item_2]))), 1)


In [16]:
def freq_pair(basket):
    candidate_pairs = itertools.combinations(basket, 2)
    for pair in candidate_pairs:
        yield((pair, 1))

In [17]:
# Convert to a list to avoid problems due to passing a rdd to another rdd
#table = freqItemTable.collect()

# counting pairs of frequent items
pairs = item_baskets.flatMap(lambda x: freq_pair(x)) \
                    .reduceByKey(lambda v1, v2: v1 + v2) \
                    .filter(lambda x: x[1] >= SUPPORT_THRESHOLD) \
                    .sortBy(lambda x: x[1], ascending=False)
                    
pairs.take(10)

                                                                                

[(('15777000', '271737000'), 289116),
 (('444814009', '195662009'), 265507),
 (('444814009', '162864005'), 240844),
 (('10509002', '444814009'), 238167),
 (('15777000', '444814009'), 222725),
 (('271737000', '444814009'), 218281),
 (('59621000', '444814009'), 174520),
 (('10509002', '195662009'), 167718),
 (('271737000', '195662009'), 152499),
 (('40055000', '444814009'), 150711)]

In [18]:
frequent_pairs = pairs.map(lambda x: x[0]).collect()

frequent_pairs

[('15777000', '271737000'),
 ('444814009', '195662009'),
 ('444814009', '162864005'),
 ('10509002', '444814009'),
 ('15777000', '444814009'),
 ('271737000', '444814009'),
 ('59621000', '444814009'),
 ('10509002', '195662009'),
 ('271737000', '195662009'),
 ('40055000', '444814009'),
 ('15777000', '195662009'),
 ('162864005', '195662009'),
 ('15777000', '10509002'),
 ('10509002', '162864005'),
 ('19169002', '444814009'),
 ('59621000', '195662009'),
 ('15777000', '59621000'),
 ('10509002', '271737000'),
 ('40055000', '195662009'),
 ('15777000', '162864005'),
 ('15777000', '40055000'),
 ('59621000', '162864005'),
 ('271737000', '162864005'),
 ('59621000', '271737000'),
 ('40055000', '162864005'),
 ('444814009', '43878008'),
 ('19169002', '162864005'),
 ('40055000', '271737000'),
 ('19169002', '195662009'),
 ('10509002', '59621000'),
 ('19169002', '15777000'),
 ('55822004', '444814009'),
 ('72892002', '444814009'),
 ('19169002', '10509002'),
 ('72892002', '195662009'),
 ('195662009', '4448

In [19]:
# Only worth checking for triples in items that are in frequent pairs.
# So we create a table similar to the frequent items table in order to remove
# unfrequent items from the baskets.

freq_pair_table = pairs.flatMap(lambda x: x[0]) \
                        .distinct() \
                        .collect()

len(freq_pair_table)

128

In [22]:
# Removing unfrequent items from the baskets and droping baskets 
# with fewer than 3 items because we need ate least 3 items to make a triple.

item_baskets = item_baskets.filter(lambda basket: {item for item in basket if item in freq_pair_table}) \
                            .filter(lambda x: len(x) > 2 )
                            

item_baskets.take(3)

                                                                                

[('126906006',
  '55822004',
  '92691004',
  '10509002',
  '90560007',
  '7200002',
  '444814009',
  '162864005',
  '263102004',
  '195662009'),
 ('19169002',
  '429007001',
  '10509002',
  '79586000',
  '162864005',
  '195662009',
  '410429000'),
 ('40055000', '59621000', '444814009', '72892002', '75498004', '283385000')]

## Phase 2, k = 3

# Receives as input: the baskets, frequent items and frequent pairs
# Returns: candidate frequent triples
def freq_triples(basket, table, fqt_pairs):
    for item_1 in range(0, len(basket)):
        if basket[item_1] not in table:
            continue
        for item_2 in range(item_1 + 1, len(basket)):  # j > i
            if basket[item_2] not in table:
                continue
                        
            pair = tuple(sorted((basket[item_1], basket[item_2])))
            if pair not in fqt_pairs:  # if the pair is not frequent the triple also won't be frequent
                continue
            
            for item_3 in range(item_2 + 1, len(basket)):
                if basket[item_3] not in table:
                    continue
                
                candidate_pairs = list(itertools.combinations((item_1, item_2, item_3), 2))

                # if all candidate pairs are frequent pairs yield the candidate triple
                if all(candidate_pair in fqt_pairs for candidate_pair in candidate_pairs):
                    yield(tuple(sorted((basket[item_1], basket[item_2], basket[item_3]))), 1)


In [24]:
def freq_triple(basket):
    candidate_triples = itertools.combinations(basket, 3)

    for triple in candidate_triples:
        yield(triple, 1)

In [25]:
triples = item_baskets.flatMap(lambda x: freq_triple(x)) \
                    .reduceByKey(lambda v1, v2: v1 + v2) \
                    .filter(lambda x: x[1] >= SUPPORT_THRESHOLD) \
                    .sortBy(lambda x: x[1], ascending=False)

triples.take(10)

                                                                                

[(('15777000', '271737000', '444814009'), 177924),
 (('15777000', '271737000', '195662009'), 124307),
 (('10509002', '444814009', '195662009'), 112687),
 (('15777000', '444814009', '195662009'), 102564),
 (('271737000', '444814009', '195662009'), 100769),
 (('444814009', '162864005', '195662009'), 99780),
 (('15777000', '10509002', '271737000'), 95176),
 (('10509002', '444814009', '162864005'), 86641),
 (('15777000', '10509002', '444814009'), 85968),
 (('59621000', '444814009', '195662009'), 81822)]

In [26]:
frequent_triples = triples.map(lambda x: x[0]).collect()
frequent_triples

                                                                                

[('15777000', '271737000', '444814009'),
 ('15777000', '271737000', '195662009'),
 ('10509002', '444814009', '195662009'),
 ('15777000', '444814009', '195662009'),
 ('271737000', '444814009', '195662009'),
 ('444814009', '162864005', '195662009'),
 ('15777000', '10509002', '271737000'),
 ('10509002', '444814009', '162864005'),
 ('15777000', '10509002', '444814009'),
 ('59621000', '444814009', '195662009'),
 ('15777000', '59621000', '271737000'),
 ('15777000', '271737000', '162864005'),
 ('15777000', '40055000', '271737000'),
 ('15777000', '59621000', '444814009'),
 ('10509002', '271737000', '444814009'),
 ('40055000', '444814009', '195662009'),
 ('15777000', '444814009', '162864005'),
 ('59621000', '444814009', '162864005'),
 ('15777000', '40055000', '444814009'),
 ('40055000', '444814009', '162864005'),
 ('271737000', '444814009', '162864005'),
 ('15777000', '10509002', '195662009'),
 ('10509002', '162864005', '195662009'),
 ('59621000', '271737000', '444814009'),
 ('19169002', '44481

In [27]:
# triples.count()

                                                                                

12968

## Mining Association Rules

In [None]:
# joining all the frequent baskets in one list
frequent_baskets = (freqItemCounts + pairs + triples).collect()
frequent_baskets[0:10]

In [None]:
# Mining for k = 2
mapa = dict(frequent_baskets)

rules = {}
for pair in pairs.collect():
    
    
    print(mapa[pair[0][0]])
    # Rule 1  X -> Y = #(X u Y) / #X
    confidence = int(pair[1]) / mapa[pair[0][0]]
    rules[pair[0][0]] = (pair[0][1], confidence)
    
    print(mapa[pair[0][1]])
    # Rule 2  Y -> X = #(X u Y) / #Y
    confidence = int(pair[1]) / mapa[pair[0][1]]
    rules[pair[0][1]] = (pair[0][0], confidence)

#rules


In [None]:
for triple in triples.collect():
    pass