In [None]:
import re
import ast
import pyspark
import time
import datetime
import html
from pyspark import SparkContext
from pyspark.sql import *
from pyspark.sql.types import *
from pyspark.sql.functions import unix_timestamp

In [None]:
# Reuse the SparkContext that is already running when this cell is executed
# again; constructing a second SparkContext() in the same kernel raises an error.
sc = SparkContext.getOrCreate()
sqlContext = SQLContext(sc)

In [None]:
# Helper: look up the given attribute in a line of an XML file;
# returns the value of the attribute, or None if the attribute does not exist

def attribute_search(attribute, string):
    """Return the raw value of the XML attribute ``attribute`` in ``string``.

    The name must match a whole attribute name: searching for ``Id`` does not
    match the tail of ``UserId`` or ``PostId``.  The name is treated as
    literal text, not as a regular expression.

    Args:
        attribute: Name of the attribute to look for, e.g. ``'Id'``.
        string: Text containing ``name="value"`` pairs.

    Returns:
        The text between the double quotes of the first matching attribute
        (possibly an empty string, still XML-escaped), or ``None`` when the
        attribute is not present.
    """
    # The negative lookbehind rejects matches that start in the middle of a
    # longer attribute name (XML names may contain word characters, ':', '.'
    # and '-').
    pattern = r'(?<![\w:.-])' + re.escape(attribute) + r'="(.*?)"'
    match = re.search(pattern, string)
    if match is None:
        return None
    # The captured group can never contain a double quote, so no clean-up of
    # the value is needed.
    return match.group(1)

In [None]:
# Helpers: parse the Stack Overflow XML files
# (each file has its own data schema)

def tags_from_xml(line):
    """Turn one ``<row .../>`` line of Tags.xml into a ``pyspark.Row``.

    The row carries ``Id`` (int), ``TagName`` (as returned by
    ``attribute_search``) and ``Count`` (int, or ``None`` when the attribute
    is missing or empty).
    """
    attrs = line.replace('<row', '').replace('/>', '')
    fields = {
        'Id': int(attribute_search('Id', attrs)),
        'TagName': attribute_search('TagName', attrs),
    }
    raw_count = attribute_search('Count', attrs)
    if raw_count:
        fields['Count'] = int(raw_count)
    else:
        fields['Count'] = None
    return pyspark.Row(**fields)

def badges_from_xml(line):
    """Turn one ``<row .../>`` line of Badges.xml into a ``pyspark.Row``.

    ``Id``, ``UserId`` and ``Class`` are converted to int, ``Date`` is parsed
    into a ``datetime`` and ``TagBased`` is evaluated as a Python literal;
    ``Name`` is kept as returned by ``attribute_search``.
    """
    attrs = line.replace('<row', '').replace('/>', '')

    def get(name):
        # Shorthand for reading one attribute of the current row.
        return attribute_search(name, attrs)

    return pyspark.Row(
        Id=int(get('Id')),
        UserId=int(get('UserId')),
        Name=get('Name'),
        Date=datetime.datetime.strptime(get('Date'), "%Y-%m-%dT%H:%M:%S.%f"),
        Class=int(get('Class')),
        TagBased=ast.literal_eval(get('TagBased')),
    )

def users_from_xml(line):
    """Parse one ``<row .../>`` line of Users.xml into a ``pyspark.Row``.

    Args:
        line: A single text line holding one ``<row`` element.

    Returns:
        pyspark.Row with the fields ``Id``, ``Reputation``, ``CreationDate``,
        ``DisplayName``, ``LastAccessDate``, ``WebsiteUrl``, ``Location``,
        ``Age``, ``Views``, ``UpVotes`` and ``DownVotes``.  The free-text
        fields are XML-unescaped (``None`` when the attribute is missing);
        ``Age`` is ``None`` when the attribute is missing or empty.

    Raises:
        TypeError: If a required numeric or date attribute is missing.
        ValueError: If a numeric or date attribute cannot be parsed.
    """
    date_format = "%Y-%m-%dT%H:%M:%S.%f"
    c = line.replace('<row', '').replace('/>', '')

    def text(name):
        # Attribute values are still XML-escaped in the raw line (e.g. '&amp;'
        # inside URLs); decode them and pass a missing/empty value through.
        value = attribute_search(name, c)
        return html.unescape(value) if value else value

    row = dict()
    row['Id'] = int(attribute_search('Id', c))
    row['Reputation'] = int(attribute_search('Reputation', c))
    row['CreationDate'] = datetime.datetime.strptime(attribute_search('CreationDate', c), date_format)
    row['DisplayName'] = text('DisplayName')
    row['LastAccessDate'] = datetime.datetime.strptime(attribute_search('LastAccessDate', c), date_format)
    row['WebsiteUrl'] = text('WebsiteUrl')
    row['Location'] = text('Location')
    age = attribute_search('Age', c)
    row['Age'] = int(age) if age else None
    row['Views'] = int(attribute_search('Views', c))
    row['UpVotes'] = int(attribute_search('UpVotes', c))
    row['DownVotes'] = int(attribute_search('DownVotes', c))
    return pyspark.Row(**row)

def posts_from_xml(line):
    """Parse one ``<row .../>`` line of Posts.xml into a ``pyspark.Row``.

    Args:
        line: A single text line holding one ``<row`` element.

    Returns:
        pyspark.Row with the fields ``Id``, ``PostTypeId``, ``ParentId``,
        ``AcceptedAnswerId``, ``CreationDate``, ``Score``, ``ViewCount``,
        ``OwnerUserId``, ``Body``, ``Title``, ``Tags``, ``ClosedDate``,
        ``AnswerCount``, ``CommentCount`` and ``FavoriteCount``.
        Optional attributes become ``None`` when missing or empty.  ``Body``
        is XML-unescaped and stripped of HTML comments and tags (an absent
        ``Body`` yields an empty string); ``Title`` is XML-unescaped;
        ``Tags`` is converted from ``<a><b>`` to ``'a b '``.

    Raises:
        TypeError: If a required numeric or date attribute is missing.
        ValueError: If a numeric or date attribute cannot be parsed.
    """
    date_format = "%Y-%m-%dT%H:%M:%S.%f"
    c = line.replace('<row', '').replace('/>', '')

    def optional_int(name):
        # Optional numeric attribute: None when absent or empty.
        value = attribute_search(name, c)
        return int(value) if value else None

    row = dict()
    row['Id'] = int(attribute_search('Id', c))
    row['PostTypeId'] = int(attribute_search('PostTypeId', c))
    row['ParentId'] = optional_int('ParentId')
    row['AcceptedAnswerId'] = optional_int('AcceptedAnswerId')
    row['CreationDate'] = datetime.datetime.strptime(attribute_search('CreationDate', c), date_format)
    row['Score'] = int(attribute_search('Score', c))
    row['ViewCount'] = optional_int('ViewCount')
    row['OwnerUserId'] = optional_int('OwnerUserId')

    # Unescaping turns '&#xA;' into real newlines, so DOTALL is needed for
    # HTML comments that span several lines.  A missing Body is treated like
    # an empty one instead of crashing in html.unescape(None).
    body = attribute_search('Body', c)
    row['Body'] = re.sub('(<!--.*?-->|<[^>]*>)', '', html.unescape(body or ''), flags=re.DOTALL)
    # The title is XML-escaped in the raw line just like the body.
    title = attribute_search('Title', c)
    row['Title'] = html.unescape(title) if title else None
    tags = attribute_search('Tags', c)
    row['Tags'] = html.unescape(tags).replace('<', '').replace('>', ' ') if tags else None
    closed = attribute_search('ClosedDate', c)
    row['ClosedDate'] = datetime.datetime.strptime(closed, date_format) if closed else None

    row['AnswerCount'] = optional_int('AnswerCount')
    row['CommentCount'] = optional_int('CommentCount')
    row['FavoriteCount'] = optional_int('FavoriteCount')

    return pyspark.Row(**row)

def comments_from_xml(line):
    """Turn one ``<row .../>`` line of Comments.xml into a ``pyspark.Row``.

    ``Id``, ``PostId`` and ``Score`` are converted to int, ``Text`` is
    unescaped and stripped of HTML comments/tags, ``CreationDate`` is parsed
    into a ``datetime`` and ``UserId`` is an int or ``None`` when the
    attribute is missing or empty.
    """
    attrs = line.replace('<row', '').replace('/>', '')

    def get(name):
        # Shorthand for reading one attribute of the current row.
        return attribute_search(name, attrs)

    fields = dict(
        Id=int(get('Id')),
        PostId=int(get('PostId')),
        Score=int(get('Score')),
        Text=re.sub('(<!--.*?-->|<[^>]*>)', '', html.unescape(get('Text'))),
        CreationDate=datetime.datetime.strptime(get('CreationDate'), "%Y-%m-%dT%H:%M:%S.%f"),
    )
    author = get('UserId')
    if author:
        fields['UserId'] = int(author)
    else:
        fields['UserId'] = None
    return pyspark.Row(**fields)

def post_history_from_xml(line):
    """Parse one ``<row .../>`` line of PostHistory.xml into a ``pyspark.Row``.

    Args:
        line: A single text line holding one ``<row`` element.

    Returns:
        pyspark.Row with the fields ``Id``, ``PostHistoryTypeId``, ``PostId``,
        ``Comment`` and ``Text``.  ``Comment`` is XML-unescaped; ``Text`` is
        XML-unescaped and stripped of HTML comments and tags.  Both are
        ``None`` when the attribute is missing or empty.

    Raises:
        TypeError: If one of the required id attributes is missing.
        ValueError: If an id attribute is not a valid integer.
    """
    c = line.replace('<row', '').replace('/>', '')
    row = dict()
    row['Id'] = int(attribute_search('Id', c))
    row['PostHistoryTypeId'] = int(attribute_search('PostHistoryTypeId', c))
    row['PostId'] = int(attribute_search('PostId', c))
    # The comment is XML-escaped in the raw line just like the text.
    comm = attribute_search('Comment', c)
    row['Comment'] = html.unescape(comm) if comm else None
    # Unescaping turns '&#xA;' into real newlines, so DOTALL is needed for
    # HTML comments that span several lines.
    text = attribute_search('Text', c)
    row['Text'] = re.sub('(<!--.*?-->|<[^>]*>)', '', html.unescape(text), flags=re.DOTALL) if text else None
    return pyspark.Row(**row)

def post_links_from_xml(line):
    """Turn one ``<row .../>`` line of PostLinks.xml into a ``pyspark.Row``.

    ``Id``, ``PostId``, ``RelatedPostId`` and ``LinkTypeId`` are converted to
    int and ``CreationDate`` is parsed into a ``datetime``.
    """
    attrs = line.replace('<row', '').replace('/>', '')

    def get(name):
        # Shorthand for reading one attribute of the current row.
        return attribute_search(name, attrs)

    return pyspark.Row(
        Id=int(get('Id')),
        CreationDate=datetime.datetime.strptime(get('CreationDate'), "%Y-%m-%dT%H:%M:%S.%f"),
        PostId=int(get('PostId')),
        RelatedPostId=int(get('RelatedPostId')),
        LinkTypeId=int(get('LinkTypeId')),
    )

In [None]:
# Load the data from the XML files into RDDs, then convert the RDDs to DataFrames.

# NOTE(review): absolute, machine-specific location of the dump files;
# adjust it to wherever the XML files live.
xml_load_path = 'file:///home/marek/Dokumenty/Notebooks/gis_stack_spark/data/'

# Dictionary: file name -> helper function that interprets the data schema
# of the rows in that XML file.
xml_load_list = {'Tags.xml': tags_from_xml, 'Badges.xml': badges_from_xml,
                 'Users.xml': users_from_xml, 'Posts.xml': posts_from_xml,
                 'Comments.xml': comments_from_xml, 'PostHistory.xml': post_history_from_xml,
                 'PostLinks.xml': post_links_from_xml}


def _load_xml_rdd(file_name):
    """Read ``file_name`` from ``xml_load_path`` and parse its rows into an RDD.

    Only lines that contain a ``<row`` element are handed to the parser
    registered for ``file_name`` in ``xml_load_list``; the XML declaration
    and the enclosing root element are skipped.
    """
    parser = xml_load_list[file_name]
    return sc.textFile(xml_load_path + file_name) \
             .filter(lambda line: '<row' in line) \
             .map(parser)


tags_rdd = _load_xml_rdd('Tags.xml')
badges_rdd = _load_xml_rdd('Badges.xml')
users_rdd = _load_xml_rdd('Users.xml')
posts_rdd = _load_xml_rdd('Posts.xml')
comments_rdd = _load_xml_rdd('Comments.xml')
post_history_rdd = _load_xml_rdd('PostHistory.xml')
post_links_rdd = _load_xml_rdd('PostLinks.xml')

# Conversion to DataFrames:
users = sqlContext.createDataFrame(users_rdd)
badges = sqlContext.createDataFrame(badges_rdd)
posts = sqlContext.createDataFrame(posts_rdd)
tags = sqlContext.createDataFrame(tags_rdd)
comments = sqlContext.createDataFrame(comments_rdd)
post_history = sqlContext.createDataFrame(post_history_rdd)
post_links = sqlContext.createDataFrame(post_links_rdd)

In [None]:
# Sanity check: print the schema of the PostLinks DataFrame and preview its rows.
post_links.printSchema()
post_links.show()

In [None]:
# Sanity check: print the schema of the Users DataFrame and preview its rows.
users.printSchema()
users.show()