In [None]:
! pip install --quiet langdetect
! pip install networkx

In [None]:
from pyspark.sql import SparkSession, Row

# Build (or reuse) the Spark session. The GCS connector jar, the gs:// file
# system implementation and the service-account key file are configured here
# so that gs:// paths can be read.
spark = (
    SparkSession.builder
    .appName("Retweet Regression Data")
    .config("spark.jars", "/home/jovyan/work/gcs-connector-hadoop2-latest.jar")
    .config("spark.hadoop.fs.gs.impl", "com.google.cloud.hadoop.fs.gcs.GoogleHadoopFileSystem")
    .config("spark.hadoop.google.cloud.auth.service.account.enable", "true")
    .config("spark.hadoop.google.cloud.auth.service.account.json.keyfile", "/home/jovyan/work/key.json")
    .config("spark.driver.memory", "22g")
    .config("spark.driver.maxResultSize", "12g")
    .config("spark.executor.memory", "3g")
    .getOrCreate()
)

In [None]:
from datetime import datetime

# Load the tweet warehouse and keep only rows whose datestamp falls on
# 2017-08-01 (lower bound inclusive, upper bound exclusive).
tweets = spark.read.parquet('gs://spain-tweets-warehouse')
tweets = tweets.where(
    (tweets.datestamp >= datetime(2017, 8, 1))
    & (tweets.datestamp < datetime(2017, 8, 2))
)

In [None]:
# Register the filtered tweets under the name 'tweets' so the SQL queries
# below can refer to them, then show the schema for reference.
tweets.createOrReplaceTempView('tweets')
tweets.printSchema()

In [None]:
from langdetect import detect_langs
from langdetect.lang_detect_exception import LangDetectException
from itertools import permutations

def confident_lang(text, min_prob=0.75):
    """Return the detected language of ``text`` only when detection is confident.

    :param text: text to classify; ``None`` or empty text yields ``None``
    :param min_prob: the first candidate's probability must be strictly
        greater than this value for its language to be returned
    :returns: language code of the first candidate reported by
        ``detect_langs``, or ``None`` when the text is missing, detection
        fails, or the candidate is not confident enough
    """
    # Missing text cannot be classified; bail out before calling the detector
    # so a null text column does not raise inside a Spark task.
    if not text:
        return None

    try:
        langs = detect_langs(text)
    except LangDetectException:
        return None

    # Defensive: nothing to pick from when no candidate is reported.
    if not langs:
        return None

    # Low-confidence results are dropped regardless of which language they
    # point to.
    top = langs[0]
    return top.lang if top.prob > min_prob else None
    
    
def user_lang(di):
    """Choose a single language for a user from their language shares.

    Catalan ('ca') is chosen whenever its share is above 0.10; otherwise the
    language with the largest strictly positive share is chosen, the first
    one encountered winning ties.

    :param di: dictionary of language -> percentage
    :returns: the chosen language as a string, or None when no share is positive

    Examples
    --------
    >>> user_lang({ 'es': .7, 'en': .3 })
    'es'

    """
    catalan_share = di.get('ca', 0)
    if catalan_share > 0.10:
        return 'ca'

    # running (language, share) winner; only strictly larger shares replace it
    best = (None, 0)
    for entry in di.items():
        if entry[1] > best[1]:
            best = entry

    return best[0]

def create_user(di):
    """Reduce a user record to a ``(language, retweets)`` tuple.

    :param di: dictionary with a 'langs' entry (passed to ``user_lang``) and
        a 'retweets' entry
    :returns: tuple of the language picked by ``user_lang`` and the
        'retweets' value
    """
    return user_lang(di['langs']), di['retweets']


def user_stats(user_info, user_networks, users, user):
    """Compute engagement statistics for one user within a group of users.

    :param user_info: mapping of user id -> ``(language, retweets)`` tuple
    :param user_networks: mapping of user id -> container of the ids in that
        user's network (must support ``in`` and ``len``)
    :param users: ids of the users being considered together
    :param user: the id to compute the statistics for
    :returns: tuple ``(user, user's language, number of ``users`` found in
        the user's network, how many of those have language 'ca', size of
        the user's network)``
    """
    # this allows network to be decoupled from this stage
    network = user_networks[user]

    # languages of the members of `users` that belong to this user's network
    engaged_langs = [user_info[u][0] for u in users if u in network]

    # compute wanted stats:
    tot_engaged = len(engaged_langs)
    # 'ca' is the Catalan code that user_lang stores in user_info
    tot_cat_engaged = sum(1 for lang in engaged_langs if lang == 'ca')
    net_size = len(network)
    # named own_lang so the user_lang function is not shadowed locally
    own_lang, _ = user_info[user]

    return user, own_lang, tot_engaged, tot_cat_engaged, net_size


def get_stats_for_users(user_info, user_networks, tweet, lang, users):
    """Build one statistics tuple per user for a single tweet.

    :param user_info: passed through to ``user_stats``
    :param user_networks: passed through to ``user_stats``
    :param tweet: value placed first in every returned tuple
    :param lang: value placed second in every returned tuple
    :param users: users to compute statistics for
    :returns: list of tuples ``(tweet, lang) + user_stats(...)``, one per
        entry of ``users`` and in the same order
    """
    # TODO: add tweet language...
    prefix = (tweet, lang)
    return [
        prefix + user_stats(user_info, user_networks, users, user)
        for user in users
    ]


In [None]:
# NOTE: not getting extended text!!!
# text should not be used for anything!!
# NOTE(review): `text` is nevertheless what is passed to confident_lang when
# the rows are language-tagged further down — confirm that is acceptable.
#
# For every row that carries a retweeted status, select the id and text of
# that retweeted status together with the id of the row's user.
query = """
SELECT retweeted_status.id AS id, 
       retweeted_status.text as text,
       user.id AS user
FROM tweets 
WHERE retweeted_status IS NOT NULL
"""

df = spark.sql(query)
df.printSchema()

In [None]:
# Spread the Python-side language detection over four times the current
# number of partitions.
partitions = df.rdd.getNumPartitions() * 4

Tweet = Row('id', 'text', 'user', 'lang')

# Tag every row with a confidently detected language and drop the rows for
# which no language could be determined.
tweets = (
    df.rdd
    .repartition(partitions)
    .map(lambda x: x.asDict(True))
    .map(lambda d: {**d, 'lang': confident_lang(d['text'])})
    .filter(lambda d: d['lang'] is not None)
    # build each Row by field name so the values cannot be misaligned with
    # the Tweet fields if the dict's key order ever differs
    .map(lambda d: Tweet(d['id'], d['text'], d['user'], d['lang']))
    .toDF()
    .cache()
)

# NOTE: this replaces the 'tweets' temp view registered earlier, so queries
# from here on see the (id, text, user, lang) rows, not the raw tweets.
tweets.createOrReplaceTempView('tweets')
tweets.printSchema()

In [None]:
# One row per tweet id: a language, the number of rows for that id and the
# set of distinct users; ids with a single row are dropped.
# first(lang) is aliased to `lang` so downstream code can read row['lang'].
query = """
with t as (
SELECT
    id,
    first(lang) as lang,
    count(*) as count,
    collect_set(user) as users
FROM tweets
GROUP BY id
)
SELECT *
FROM t
WHERE count > 1
"""

tweet_users = spark.sql(query).repartition(96).cache()
tweet_users.createOrReplaceTempView('tweet_users')
tweet_users.printSchema()

In [None]:
# Show how many partitions back tweet_users.
tweet_users.rdd.getNumPartitions()

In [None]:
%%time

from pyspark.sql.types import ArrayType, LongType

def perms(a):
    """Return all ordered pairs of elements of ``a`` taken from distinct positions.

    :param a: iterable of elements
    :returns: list of 2-tuples, in the order produced by
        ``itertools.permutations(a, 2)``
    """
    return [pair for pair in permutations(a, 2)]

# Expose perms to Spark SQL; the declared return type is an array of arrays
# of longs (one inner array per pair).
spark.udf.register('perms', perms,  ArrayType(ArrayType(LongType())))

# One row per ordered pair (a, b) of users that appear together in the
# `users` set of a tweet.
query = """
with t as (
SELECT explode(perms(users)) as pairs
FROM tweet_users
)
SELECT pairs[0] as a, pairs[1] as b
FROM t
"""

user_pairs = spark.sql(query)

# Edges consumed by the graph-building cell, which reads row.a, row.b and
# row.weight. The weight is the number of times the pair occurs, i.e. the
# number of tweets the two users share.
# NOTE(review): confirm this is the intended edge weight.
# toLocalIterator yields rows on the driver one partition at a time; the
# iterator is single-use, so re-run this cell before rebuilding the graph.
user_edges = (
    user_pairs
    .groupBy('a', 'b')
    .count()
    .withColumnRenamed('count', 'weight')
    .toLocalIterator()
)

In [None]:
# Show how many partitions the result of the current `query` is split into.
spark.sql(query).rdd.getNumPartitions()

In [None]:
%%time

# Per user: a map of language -> share of that user's rows in `tweets`,
# plus the user's total number of rows (`retweets`).
query = """

with tt as (
with t as (
SELECT COUNT(*) as lang_counts, user, lang
FROM tweets
GROUP BY user, lang
)
SELECT SUM(lang_counts) OVER (partition by user) as retweets,
       lang_counts, 
       lang,
       user
FROM t
)
SELECT map_from_arrays(collect_list(lang), collect_list(lang_counts / retweets)) as langs,
       FIRST(retweets) as retweets,
       user
FROM tt
GROUP BY user
"""

# Collect to the driver as a plain dict: user -> (language, retweets).
user_info = (
    spark.sql(query)
    .rdd
    .map(lambda r: (r.user, create_user(r.asDict())))
    .collectAsMap()
)

In [None]:
%time

import networkx as nx
G = nx.Graph()

for user, (lang, count) in user_info.items(): 
    G.add_node(user, lang = lang, count = count)

for row in user_edges:
    G.add_edge(row.a, row.b, weight=row.weight)

In [None]:
# Display the graph object as a quick check that it was built.
G

In [None]:
RegressionData = Row('id', 'lang', 'user', 'user_lang', 'engaged', 'cat_engaged', 'net_size')

# Broadcast the lookup tables; the lambda below reads them through .value.
# NOTE(review): user_networks is not defined in the cells visible here —
# it must exist before this cell runs.
bc_ui = spark.sparkContext.broadcast(user_info)
bc_un = spark.sparkContext.broadcast(user_networks)

# A PySpark DataFrame has no flatMap; go through its underlying RDD of Rows.
# Each tweet_users row expands to one RegressionData row per user in `users`.
reg_data = (
    tweet_users.rdd
    .flatMap(lambda d: get_stats_for_users(bc_ui.value, bc_un.value, d['id'], d['lang'], d['users']))
    .map(lambda t: RegressionData(*t))
    .toDF()
)

In [None]:
# NOTE(review): this repeats the user_info broadcast already created in the
# previous cell; it looks redundant — confirm before removing.
bc_ui = spark.sparkContext.broadcast(user_info)

In [None]:
# Collect the Spark DataFrame to the driver as a pandas DataFrame.
# (toPandas is the Spark-side conversion; `pd` is not imported in this
# notebook, and the pandas constructor does not collect Spark rows.)
# NOTE: re-running this cell requires rebuilding reg_data first, because
# the name is rebound to the pandas result.
reg_data = reg_data.toPandas()

In [None]:
import numpy as np
import seaborn as sns

# Distribution of the sizes of the entries in user_networks.
edges = np.array(list(map(len, user_networks.values())))
sns.distplot(edges)