# Data loader

As explained in the README, this notebook is not meant to be run again (you can, if you are willing to wait a whole night for the data collection).
If you still want to run it, download the full data manually [here]() and extract it into the data folder.

You can simply read the notebook to understand how we first process the data. In a nutshell, the idea is the following:

1. Load the XML data files and convert them into Spark DataFrames.

2. Filter the __Posts__ table to keep only the users' questions, and sample 10% of them.

3. Restrict the __Users__ and __Badges__ tables to the authors of the sampled posts (the __Tags__ table is kept whole). We also create a table (saved as __Country__) that maps users' locations to their countries.

4. Finally, save all these new tables in the Data/sample folder!

We have already uploaded those samples [here](https://drive.google.com/drive/folders/1ddsBX4I4hZ8pordSKf5cHRaVBnNVOcKk). The other notebooks (Analysis, ML, etc.) come with a function (`stack_overflow_functions.DataLoader.download_data`) that downloads the data if needed. So, once again, you have nothing to do regarding data loading. :)

## Package imports

In [None]:
import os

# Spark for data treatment
import pyspark
import sparknlp
from pyspark.sql.types import StructField, StructType, StringType
from pyspark.sql import functions as F
import pandas as pd

# Packages to geolocate users
from pycountry_convert import country_alpha2_to_continent_code, country_name_to_country_alpha2
from geopy.geocoders import Nominatim

# Global variables
seed = 2020  # Ensures reproducibility
sample_size = 0.1 # Sample proportion for the posts

### Custom functions to handle the XML files

In [2]:
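def line22csv(line, tags_list):
    """Maps one XML row ('<row Attr="value" ... />') to a tuple of values.

    Attributes are looked up in the order of `tags_list`, which must follow
    the order in which they appear in the rows; a missing attribute yields None.
    Values are returned as raw strings (XML entities are not decoded).
    """
    results = []
    offset = 0
    for tag in tags_list:
        # The leading space avoids matching e.g. 'Id' inside 'UserId'
        patt = " " + tag + '="'
        ind = line.find(patt, offset)
        if ind == -1:
            results.append(None)
            continue
        start = ind + len(patt)
        end = line.find('"', start)  # closing quote of the value
        if end == -1:
            # Malformed row: the value is never closed
            results.append(None)
            continue
        results.append(line[start:end])
        offset = end
    return tuple(results)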


def schema(fields):
    """
    Builds a DataFrame schema with one nullable string column per field.
    When you first collect data, it is good practice to keep
    every field as a string to avoid any casting problem.
    """
    return StructType([StructField(str(field), StringType(), True)
                       for field in fields])
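
`line22csv` relies on plain string search, so XML entities stay escaped in its output: in the dump, the HTML inside `Body` is stored as `&lt;p&gt;...`, and that is what ends up in the DataFrame. For comparison, here is a minimal sketch of a row parser built on the standard-library `xml.etree.ElementTree`, which decodes entities. It assumes that each data line holds exactly one self-closing `<row ... />` element; `row_to_tuple` is a hypothetical name and the sample row is made up.

```python
import xml.etree.ElementTree as ET


def row_to_tuple(line, fields):
    """Parses one '<row ... />' line into a tuple of attribute values.

    Lines that are not rows (XML declaration, opening/closing root tag)
    give a tuple of None values, like line22csv does.
    """
    line = line.strip()
    if not line.startswith("<row"):
        return tuple(None for _ in fields)
    attrib = ET.fromstring(line).attrib  # entities such as &lt; are decoded here
    return tuple(attrib.get(field) for field in fields)


# Made-up row in the format of the dump
sample_line = '  <row Id="4" PostTypeId="1" Body="&lt;p&gt;Hello&lt;/p&gt;" />'
print(row_to_tuple(sample_line, ["Id", "PostTypeId", "Body", "Title"]))
# ('4', '1', '<p>Hello</p>', None)
```

It could be passed to `raw.map(...)` in place of `line22csv`; the price is a full XML parse per line, which is likely slower on the larger files.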

### Start the Spark context

In [3]:
spark = sparknlp.start()  # SparkSession with Spark NLP loaded
sc = spark.sparkContext  # Reuse the SparkContext created by sparknlp.start()
sqlcontext = pyspark.SQLContext(sc)

## 1) Loading the XMLs

In [4]:
# Field names (XML attributes) to extract for each table
tags_fields = ['Id', 'TagName', 'Count', 'ExcerptPostId', 'WikiPostId']

post_fields = ['Id', 'PostTypeId', 'AcceptedAnswerId', 'CreationDate', 'Score',
               "ViewCount", "Body", "OwnerUserId", "LastEditorUserId",
               "LastEditorDisplayName",
               "LastEditDate", "LastActivityDate", "Title", "Tags",
               "AnswerCount", "CommentCount","FavoriteCount",
               "CommunityOwnedDate", "ContentLicense"
              ]

badge_fields = ["Id", "UserId", "Name", "Date", "Class", "TagBased"]

user_fields = [ "Id", "Reputation", "CreationDate", "DisplayName", "Location","Views",
               "LastAccessDate", "AboutMe", "UpVotes", "DownVotes"
              ]

In [5]:
tags_file = "Tags.xml"
post_file = "Posts.xml"
badge_file = "Badges.xml"
user_file = "Users.xml"



raw = (sc.textFile(post_file, 4))
posts = (raw.map(lambda x: line22csv(x, post_fields))
             .toDF(schema(post_fields))
             .where(F.col('PostTypeId') == '1')  # Only questions
             .where(F.col('Tags').isNotNull())  # With at least one tag
             .sample(False, sample_size, seed)  # Sampled at ratio of sample_size
        )

raw = (sc.textFile(tags_file, 4))
tags = raw.map(lambda x:line22csv(x, tags_fields)).toDF(schema(tags_fields))

raw = (sc.textFile(badge_file, 4))
badges = raw.map(lambda x:line22csv(x, badge_fields)).toDF(schema(badge_fields))

raw = (sc.textFile(user_file, 4))
users = raw.map(lambda x:line22csv(x, user_fields)).toDF(schema(user_fields))
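
Note that `sample(False, sample_size, seed)` samples without replacement, which Spark does by Bernoulli sampling: each question is kept independently with probability `sample_size`, so the result holds *about* 10% of the questions rather than exactly 10%. The seed makes the draw reproducible as long as the data and its partitioning do not change. The plain-Python sketch below only illustrates this behaviour (it is not Spark's implementation); `bernoulli_sample` is a hypothetical helper.

```python
import random


def bernoulli_sample(rows, fraction, seed):
    """Keeps each row independently with probability `fraction`."""
    rng = random.Random(seed)  # seeded generator: same seed, same draw
    return [row for row in rows if rng.random() < fraction]


rows = list(range(100_000))
sample_a = bernoulli_sample(rows, 0.1, seed=2020)
sample_b = bernoulli_sample(rows, 0.1, seed=2020)
print(len(sample_a))          # close to, but usually not exactly, 10 000
print(sample_a == sample_b)   # True: same seed, same rows
```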

### Estimate the loading time of each table and count its rows
#### Badge

In [6]:
%%time
badges.count()

CPU times: user 30.4 ms, sys: 19 ms, total: 49.5 ms
Wall time: 5min 41s


39178979

#### Tags

In [7]:
%%time
tags.count()

CPU times: user 3.93 ms, sys: 270 µs, total: 4.2 ms
Wall time: 793 ms


60537

#### Users

In [8]:
%%time
users.count()

CPU times: user 18.7 ms, sys: 14.4 ms, total: 33 ms
Wall time: 2min 30s


14080583

#### Posts

In [9]:
%%time
posts.count()

CPU times: user 461 ms, sys: 175 ms, total: 636 ms
Wall time: 1h 21min 32s


2091001

## 2) Sample the data
### Users & Country
#### Keep only the users who authored a sampled post

In [10]:
distinct_user_id = (posts
                    .select("OwnerUserId")
                    .distinct()
                   )
users = (users
         .join(distinct_user_id, users.Id == distinct_user_id.OwnerUserId)
        )
distinct_user_id.cache().count()

1071922
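
The join above keeps the users whose `Id` appears among the authors of the sampled questions; as a side effect, it also appends the `OwnerUserId` column to `users`. Spark also offers a left semi join (`how="left_semi"` in `DataFrame.join`), which returns only the columns of the left table. The plain-Python sketch below, on made-up rows, shows these semi-join semantics; `left_semi_join` and the `toy_*` names are hypothetical.

```python
def left_semi_join(left_rows, right_keys, key):
    """Keeps the rows of `left_rows` whose `key` value appears in `right_keys`.

    Only the left-hand columns are returned, and each left row appears at most
    once, however often its key is repeated on the right.
    """
    keys = {k for k in right_keys if k is not None}  # null keys never match
    return [row for row in left_rows if row[key] in keys]


toy_users = [
    {"Id": "1", "Location": "Paris, France"},
    {"Id": "2", "Location": None},
    {"Id": "3", "Location": "Lausanne"},
]
toy_owner_ids = ["1", "3", "3", None]  # OwnerUserId of the sampled questions
kept = left_semi_join(toy_users, toy_owner_ids, "Id")
print([row["Id"] for row in kept])  # ['1', '3']
```

In Spark, the corresponding call would be `users.join(distinct_user_id, users.Id == distinct_user_id.OwnerUserId, "left_semi")`.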

#### Get the country information of our users

To get country information about the users, we have to extract it from the free-text `Location` field.
This requires calls to a geocoding API through a library (here the `geopy` package with its Nominatim geocoder, but others can be used).
To avoid timeout issues caused by parallel requests, we build a pandas table of the distinct user locations (the table is quite small), geocode it on the driver, then convert it into a Spark DataFrame that can be joined efficiently with the users.
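
Since all requests are sent one after another from the driver, it is also worth spacing them out: the public Nominatim service asks clients for at most one request per second. geopy ships a helper for this (`geopy.extra.rate_limiter.RateLimiter`); the dependency-free sketch below shows the same idea. `rate_limited` and `fake_geocode` are hypothetical names, and `fake_geocode` only stands in for a real geocoding request.

```python
import time


def rate_limited(func, min_delay_seconds=1.0):
    """Wraps `func` so that consecutive calls start at least
    `min_delay_seconds` apart, sleeping when needed."""
    last_start = [None]  # start time of the previous call (mutable cell)

    def wrapper(*args, **kwargs):
        if last_start[0] is not None:
            wait = min_delay_seconds - (time.monotonic() - last_start[0])
            if wait > 0:
                time.sleep(wait)
        last_start[0] = time.monotonic()
        return func(*args, **kwargs)

    return wrapper


call_times = []


def fake_geocode(query):
    """Hypothetical stand-in for a real geocoding request."""
    call_times.append(time.monotonic())
    return query.upper()


geocode = rate_limited(fake_geocode, min_delay_seconds=0.05)
results = [geocode(q) for q in ["paris", "lausanne", "tokyo"]]
gaps = [b - a for a, b in zip(call_times, call_times[1:])]
print(results)  # ['PARIS', 'LAUSANNE', 'TOKYO']
```

In this notebook, such a wrapper could go around `geolocate_country` before the `.apply(...)` call, since each call sends one request; with geopy itself, `RateLimiter(geolocator.geocode, min_delay_seconds=1)` plays the same role.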

##### Functions to apply

In [11]:
# Function that retrieves the country information of a location
def geolocate_country(col):
    """
    Geolocates the country of a free-text location.
    args:
        col(str): The location
    returns:
        tuple: (country name, country ISO alpha-2 code, continent code,
                latitude, longitude); each element is None if not found
    """
    if col is None or col == "None":
        return (None, None, None, None, None)
    geolocator = Nominatim(user_agent="aa")

    try:
        # Geocode the location; the country is the last part of the address
        location = geolocator.geocode(col, timeout=None, language="en")
        lat, lon = location.latitude, location.longitude
        country = location.raw['display_name'].split(',')[-1].strip()
    except Exception:
        # No match (geocode returned None) or failed request
        country, lat, lon = None, None, None

    # Get the country ISO code and the continent code
    try:
        cn_a2_code = country_name_to_country_alpha2(country)
    except Exception:
        cn_a2_code = None
    try:
        cn_continent = country_alpha2_to_continent_code(cn_a2_code)
    except Exception:
        cn_continent = None

    return (country, cn_a2_code, cn_continent, lat, lon)
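
`geolocate_country` takes the country to be the last comma-separated part of Nominatim's `display_name`, then converts it with `pycountry_convert` (for example, `country_name_to_country_alpha2("France")` gives `"FR"` and `country_alpha2_to_continent_code("FR")` gives `"EU"`). The string step can be checked offline; the `display_name` values below are made-up examples in the usual "place, region, country" shape, and `country_from_display_name` is a hypothetical helper. It assumes the country always comes last, which is worth spot-checking on real results.

```python
def country_from_display_name(display_name):
    """Returns the last comma-separated component of an address string."""
    if not display_name:
        return None
    return display_name.split(",")[-1].strip()


examples = [
    "Lausanne, District de Lausanne, Vaud, Switzerland",
    "Paris, Île-de-France, Metropolitan France, France",
    "Tokyo, Japan",
]
print([country_from_display_name(name) for name in examples])
# ['Switzerland', 'France', 'Japan']
```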

##### Creation of the location table

In [12]:
location_mapping = (users
                    .select("Location")
                    .distinct()
                    .where(F.col("Location").isNotNull())
                    .where(F.col("Location") != 'None')
                    .toPandas()
                   )
location_mapping["Infos"] = (location_mapping
                               .loc[:,"Location"]
                               .apply(geolocate_country)
                              )

location_mapping[['Country', 'Coun_iso', 'Cont_iso', 'Lat', 'Lon']] = pd.DataFrame(
    location_mapping.Infos.tolist(), index=location_mapping.index
)
location_mapping.drop('Infos', axis=1, inplace=True)
location_mapping = sqlcontext.createDataFrame(location_mapping)

### Badges
#### Keep only the badges of users who authored a sampled post

In [13]:
badges = (badges
          .join(distinct_user_id, badges.UserId == distinct_user_id.OwnerUserId)
         )
badges.count()

19572945

## 3) Store all the prepared tables in Parquet format

In [14]:
tags_out_file = "Data/sample/Tags"
post_out_file = "Data/sample/Posts"
badge_out_file = "Data/sample/Badges"
user_out_file = "Data/sample/Users"
location_out_file = "Data/sample/Country"

In [15]:
# (DataFrame, output folder) pairs to save
lst_df = [
    (tags, tags_out_file),
    (badges, badge_out_file),
    (users, user_out_file),
    (posts, post_out_file),
    (location_mapping, location_out_file)
]

for df, out_path in lst_df:
    print("Starts transformation for: " + out_path)
    df.write.parquet(out_path)
    print("Ended successfully")


Starts transformation for: Data/sample/Tags
Ended successfully
Starts transformation for: Data/sample/Badges
Ended successfully
Starts transformation for: Data/sample/Users
Ended successfully
Starts transformation for: Data/sample/Posts
Ended successfully
Starts transformation for: Data/sample/Country
Ended successfully

