In [6]:
# Dependency installs, currently disabled by wrapping them in a string literal
# (the pip output below appears to come from an earlier run when they were active).
# To reinstall in a fresh environment, run them as `%pip install ...` so they target
# the kernel's interpreter, and pin versions (the output shows sentencepiece 0.1.85).
"""
!python3.7 -m pip install sentencepiece
!python3.7 -m pip install bert-for-tf2
!python3.7 -m pip install tensorflow_hub
"""


Collecting sentencepiece
  Downloading https://files.pythonhosted.org/packages/11/e0/1264990c559fb945cfb6664742001608e1ed8359eeec6722830ae085062b/sentencepiece-0.1.85-cp37-cp37m-manylinux1_x86_64.whl (1.0MB)
    100% |████████████████████████████████| 1.0MB 1.2MB/s eta 0:00:01
Installing collected packages: sentencepiece
Successfully installed sentencepiece-0.1.85


In [22]:
import pyspark

from pyspark.context import SparkContext
from pyspark.sql.session import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import *

import pandas as pd

import random

import tensorflow as tf
import tensorflow_hub as thub
import bert

import matplotlib.pyplot as plt
%matplotlib inline

In [23]:
# In this notebook the TF-Hub BERT layer is only used to recover the tokenizer
# assets bundled with it, so that tokenization matches the model's vocabulary.
bert_layer = thub.KerasLayer("https://tfhub.dev/tensorflow/bert_en_wwm_cased_L-24_H-1024_A-16/1",
                            trainable=True)
vocabulary_file = bert_layer.resolved_object.vocab_file.asset_path.numpy()
# Read the casing flag from the model rather than hard-coding it: if the URL above
# is swapped for an uncased model, a hard-coded value would silently produce
# token IDs that do not match what the model was trained on.
do_lower_case = bool(bert_layer.resolved_object.do_lower_case.numpy())
BertTokenizer = bert.bert_tokenization.FullTokenizer
tokenizer = BertTokenizer(vocabulary_file, do_lower_case=do_lower_case)

In [24]:
# Spark settings for this job: dynamic executor allocation, 2 cores per executor.
config = (
    pyspark.SparkConf()
    .set("spark.dynamicAllocation.enabled", "True")
    .set("spark.executor.cores", "2")
)

In [25]:
# getOrCreate() reuses a running session/context instead of raising
# "Cannot run multiple SparkContexts at once" when this cell is executed again,
# which SparkContext(conf=...) does.
spark = SparkSession.builder.config(conf=config).getOrCreate()
sc = spark.sparkContext

In [26]:
%time tsarc = spark.read.csv("gs://sarc-bucket-5/reddit_trunc.csv", inferSchema=True, header=False, sep = ',')

CPU times: user 11.8 ms, sys: 3.98 ms, total: 15.8 ms
Wall time: 1min 8s


In [27]:
# Give the auto-named CSV columns meaningful names.
tsarc = (
    tsarc
    .withColumnRenamed("_c0", "label")
    .withColumnRenamed("_c1", "subreddit")
    .withColumnRenamed("_c2", "context")
)


In [28]:
# Preview the first rows after renaming.
tsarc.show(n=10)

+-----+---------+--------------------+
|label|subreddit|             context|
+-----+---------+--------------------+
|    0| Portland|All these fucking...|
|    0|     milf|Mother of one on ...|
|    0| gonewild|{F}uckable? ;) Th...|
|    0| politics|Took a loan. The ...|
|    0|     pics|I see your kitche...|
|    0| politics|Man, that was pol...|
|    0|  atheism|"Met my first fun...|
|    0|     gifs|Emma Watson danci...|
|    0|    ducks|The Duck is losin...|
|    0|      CFB|I would just like...|
+-----+---------+--------------------+
only showing top 10 rows



In [29]:
# Count NULLs per column: when() yields 1 for a NULL row and NULL otherwise,
# and count() only counts the non-NULL results.
tsarc.select([
    F.count(F.when(F.col(name).isNull(), F.lit(1))).alias(name)
    for name in tsarc.columns
]).show()

+-----+---------+-------+
|label|subreddit|context|
+-----+---------+-------+
|    0|      799|    799|
+-----+---------+-------+



In [30]:
# Discard rows with no comment text; they cannot be tokenized.
tsarc = tsarc.na.drop(subset=["context"])

In [31]:
# Re-check NULLs per column after filtering: when() yields 1 for a NULL row and
# NULL otherwise, and count() only counts the non-NULL results.
tsarc.select([
    F.count(F.when(F.col(name).isNull(), F.lit(1))).alias(name)
    for name in tsarc.columns
]).show()

+-----+---------+-------+
|label|subreddit|context|
+-----+---------+-------+
|    0|        0|      0|
+-----+---------+-------+



In [32]:
def tokenize_sample(context, max_len=None):
    """
    Convert one text value into BERT token IDs.

    To be applied over a DataFrame column via ``tokenize_sample_udf``. The text is
    split with the module-level ``tokenizer`` and wrapped in the ``[CLS]`` ... ``[SEP]``
    markers before being mapped to vocabulary IDs.

    Args:
        context: the text to tokenize. ``None`` (what a Spark Python UDF receives
            for a SQL NULL) yields ``None`` instead of failing inside the executor.
        max_len: optional cap on the total number of IDs, including ``[CLS]`` and
            ``[SEP]``. When given, word pieces beyond ``max_len - 2`` are dropped.
            ``None`` (the default) keeps every token.

    Returns:
        list of int token IDs, or ``None`` when ``context`` is ``None``.

    Raises:
        ValueError: if ``max_len`` is given and is smaller than 2.
    """
    if context is None:
        return None

    pieces = tokenizer.tokenize(context)
    if max_len is not None:
        if max_len < 2:
            raise ValueError("max_len must be at least 2 to fit [CLS] and [SEP]")
        pieces = pieces[:max_len - 2]

    return tokenizer.convert_tokens_to_ids(["[CLS]"] + pieces + ["[SEP]"])

# NOTE(review): the UDF calls tokenize_sample without max_len, so sequences are not
# truncated; BERT models have a fixed maximum number of positions (commonly 512 —
# confirm for this model), so long comments may need truncating before training.
tokenize_sample_udf = F.udf(tokenize_sample, ArrayType(IntegerType()))

In [33]:
# Tokenize every comment with the Python UDF (runs on the executors).
tsarc = tsarc.withColumn("tokens", tokenize_sample_udf(F.col("context")))

In [34]:
# The raw text is not needed once it has been tokenized.
tsarc = tsarc.drop("context")
tsarc.show(n=20)

+-----+-------------------+--------------------+
|label|          subreddit|              tokens|
+-----+-------------------+--------------------+
|    0|           Portland|[101, 1398, 1292,...|
|    0|               milf|[101, 4872, 1104,...|
|    0|           gonewild|[101, 196, 143, 1...|
|    0|           politics|[101, 6466, 1377,...|
|    0|               pics|[101, 146, 1267, ...|
|    0|           politics|[101, 2268, 117, ...|
|    0|            atheism|[101, 107, 19415,...|
|    0|               gifs|[101, 4913, 7422,...|
|    0|              ducks|[101, 1109, 16627...|
|    0|                CFB|[101, 146, 1156, ...|
|    0|              funny|[101, 1284, 2028,...|
|    0|              funny|[101, 138, 16723,...|
|    0|              funny|[101, 1135, 112, ...|
|    0|          AskReddit|[101, 5749, 1207,...|
|    0|                aww|[101, 107, 107, 1...|
|    0|      todayilearned|[101, 157, 17656,...|
|    0|PoliticalDiscussion|[101, 2082, 7691,...|
|    0|          Ask

In [35]:
def dense_format_udf(col):
    """
    Render an array-of-integers column as a bracketed, comma-separated string.

    CSV cannot hold array columns, so token IDs are serialized as text such as
    ``[101, 1398, 1292]`` — the same layout Python's ``str(list)`` gives for a list
    of ints. Built from native Spark functions, so rows are not shipped to a Python
    worker and back just to be stringified. The name is kept so existing callers
    (``dense_format_udf(column)``) continue to work.

    Args:
        col: a Column, or a column name, holding an array of integers.

    Returns:
        A string Column. NULL arrays stay NULL instead of becoming the text "None".
    """
    arr = F.col(col) if isinstance(col, str) else col
    joined = F.concat(F.lit("["), F.concat_ws(", ", arr.cast("array<string>")), F.lit("]"))
    # when() without otherwise() yields NULL for NULL input arrays.
    return F.when(arr.isNotNull(), joined)

In [36]:
# Serialize the token-ID arrays to text so they can be written out as CSV.
tsarc = tsarc.withColumn(
    "tokens_string",
    dense_format_udf(F.col("tokens")),
)
tsarc.show(n=20)

+-----+-------------------+--------------------+--------------------+
|label|          subreddit|              tokens|       tokens_string|
+-----+-------------------+--------------------+--------------------+
|    0|           Portland|[101, 1398, 1292,...|[101, 1398, 1292,...|
|    0|               milf|[101, 4872, 1104,...|[101, 4872, 1104,...|
|    0|           gonewild|[101, 196, 143, 1...|[101, 196, 143, 1...|
|    0|           politics|[101, 6466, 1377,...|[101, 6466, 1377,...|
|    0|               pics|[101, 146, 1267, ...|[101, 146, 1267, ...|
|    0|           politics|[101, 2268, 117, ...|[101, 2268, 117, ...|
|    0|            atheism|[101, 107, 19415,...|[101, 107, 19415,...|
|    0|               gifs|[101, 4913, 7422,...|[101, 4913, 7422,...|
|    0|              ducks|[101, 1109, 16627...|[101, 1109, 16627...|
|    0|                CFB|[101, 146, 1156, ...|[101, 146, 1156, ...|
|    0|              funny|[101, 1284, 2028,...|[101, 1284, 2028,...|
|    0|             

In [37]:
# Keep only the serialized token IDs for export.
# NOTE(review): this also drops `label`, so the exported file carries no labels and
# its rows cannot be paired with them afterwards — confirm this is intended.
tsarc = tsarc.drop("tokens", "subreddit", "label")
tsarc.printSchema()

root
 |-- tokens_string: string (nullable = true)



In [38]:
# Preview what will be written.
tsarc.show(n=20)

+--------------------+
|       tokens_string|
+--------------------+
|[101, 1398, 1292,...|
|[101, 4872, 1104,...|
|[101, 196, 143, 1...|
|[101, 6466, 1377,...|
|[101, 146, 1267, ...|
|[101, 2268, 117, ...|
|[101, 107, 19415,...|
|[101, 4913, 7422,...|
|[101, 1109, 16627...|
|[101, 146, 1156, ...|
|[101, 1284, 2028,...|
|[101, 138, 16723,...|
|[101, 1135, 112, ...|
|[101, 5749, 1207,...|
|[101, 107, 107, 1...|
|[101, 157, 17656,...|
|[101, 2082, 7691,...|
|[101, 6682, 1118,...|
|[101, 1422, 4126,...|
|[101, 157, 17656,...|
+--------------------+
only showing top 20 rows



In [None]:
%time tsarc.write.csv('gs://sarc-bucket-5/tokens.csv')

In [None]:
# Shut down through the session: this stops the underlying SparkContext and also
# clears the cached session object.
spark.stop()