In [6]:
"""
!python3.7 -m pip install sentencepiece
!python3.7 -m pip install bert-for-tf2
!python3.7 -m pip install tensorflow_hub
"""


Collecting sentencepiece
  Downloading https://files.pythonhosted.org/packages/11/e0/1264990c559fb945cfb6664742001608e1ed8359eeec6722830ae085062b/sentencepiece-0.1.85-cp37-cp37m-manylinux1_x86_64.whl (1.0MB)
[K    100% |████████████████████████████████| 1.0MB 1.2MB/s eta 0:00:01
[?25hInstalling collected packages: sentencepiece
Successfully installed sentencepiece-0.1.85


In [1]:
import pyspark

from pyspark.context import SparkContext
from pyspark.sql.session import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import *

import pandas as pd

import re

import random

import socket

import tensorflow as tf
import tensorflow_hub as thub
import bert

import matplotlib.pyplot as plt
%matplotlib inline

  from ._conv import register_converters as _register_converters


In [2]:
bert_layer = thub.KerasLayer("https://tfhub.dev/tensorflow/bert_en_wwm_cased_L-24_H-1024_A-16/1",
                            trainable=True)
vocabulary_file = bert_layer.resolved_object.vocab_file.asset_path.numpy()
BertTokenizer = bert.bert_tokenization.FullTokenizer
tokenizer = BertTokenizer(vocabulary_file, do_lower_case=False)

In [3]:
config = pyspark.SparkConf().setAll([("spark.dynamicAllocation.enabled","True"),
                                    ("spark.executor.cores","2")])

In [4]:
sc = SparkContext(conf=config)
spark = SparkSession(sc)

In [12]:
%time tsarc = spark.read.csv("gs://sarc-bucket-5/reddit_trunc.csv", inferSchema=True, header=False, sep = ',')

CPU times: user 13.4 ms, sys: 340 µs, total: 13.8 ms
Wall time: 1min 12s


In [13]:
# Rename columns: 

tsarc = tsarc.withColumnRenamed('_c0','label').withColumnRenamed('_c1','subreddit').withColumnRenamed('_c2','context')


In [14]:
tsarc.show(10)

+-----+---------+--------------------+
|label|subreddit|             context|
+-----+---------+--------------------+
|    0| Portland|All these fucking...|
|    0|     milf|Mother of one on ...|
|    0| gonewild|{F}uckable? ;) Th...|
|    0| politics|Took a loan. The ...|
|    0|     pics|I see your kitche...|
|    0| politics|Man, that was pol...|
|    0|  atheism|"Met my first fun...|
|    0|     gifs|Emma Watson danci...|
|    0|    ducks|The Duck is losin...|
|    0|      CFB|I would just like...|
+-----+---------+--------------------+
only showing top 10 rows



In [16]:
strings = "{F}uckable? ;) The dumb peopld yeah 123 !!1"
tokenized = ["[CLS]"] + tokenizer.tokenize(strings) + ["[SEP]"]
ids = tokenizer.convert_tokens_to_ids(tokenized)
ids

[101,
 196,
 143,
 198,
 190,
 2158,
 1895,
 136,
 132,
 114,
 1109,
 14908,
 185,
 8209,
 1643,
 5253,
 8147,
 13414,
 106,
 106,
 122,
 102]

In [21]:
def tokenize_sample(context):
    
    """
    To be applied over dataframe.
    Takes a string and converts it to token IDs via BERT tokenizer,
    adding the necessary beginning and end tokens
    """
    
    tokenized = ["[CLS]"] + tokenizer.tokenize(context) + ["[SEP]"]
    ids = tokenizer.convert_tokens_to_ids(tokenized)
    
    return ids

tokenize_sample_udf = F.udf(tokenize_sample, ArrayType(IntegerType()))

In [24]:
tsarc = tsarc.withColumn("tokens", tokenize_sample_udf(tsarc.context))

In [26]:
tsarc = tsarc.drop('context')
tsarc.show()

+-----+-------------------+--------------------+
|label|          subreddit|              tokens|
+-----+-------------------+--------------------+
|    0|           Portland|[101, 1398, 1292,...|
|    0|               milf|[101, 4872, 1104,...|
|    0|           gonewild|[101, 196, 143, 1...|
|    0|           politics|[101, 6466, 1377,...|
|    0|               pics|[101, 146, 1267, ...|
|    0|           politics|[101, 2268, 117, ...|
|    0|            atheism|[101, 107, 19415,...|
|    0|               gifs|[101, 4913, 7422,...|
|    0|              ducks|[101, 1109, 16627...|
|    0|                CFB|[101, 146, 1156, ...|
|    0|              funny|[101, 1284, 2028,...|
|    0|              funny|[101, 138, 16723,...|
|    0|              funny|[101, 1135, 112, ...|
|    0|          AskReddit|[101, 5749, 1207,...|
|    0|                aww|[101, 107, 107, 1...|
|    0|      todayilearned|[101, 157, 17656,...|
|    0|PoliticalDiscussion|[101, 2082, 7691,...|
|    0|          Ask

In [36]:
tsarc.printSchema()

root
 |-- label: integer (nullable = true)
 |-- subreddit: string (nullable = true)
 |-- tokens: array (nullable = true)
 |    |-- element: integer (containsNull = true)
 |-- tokens2: string (nullable = true)



In [35]:
tsarc = tsarc.withColumn("tokens2",F.col("tokens").cast(StringType()))

In [40]:
tsarc = tsarc.drop('tokens')

In [41]:
tsarc.show(10)

+-----+---------+--------------------+
|label|subreddit|             tokens2|
+-----+---------+--------------------+
|    0| Portland|[101, 1398, 1292,...|
|    0|     milf|[101, 4872, 1104,...|
|    0| gonewild|[101, 196, 143, 1...|
|    0| politics|[101, 6466, 1377,...|
|    0|     pics|[101, 146, 1267, ...|
|    0| politics|[101, 2268, 117, ...|
|    0|  atheism|[101, 107, 19415,...|
|    0|     gifs|[101, 4913, 7422,...|
|    0|    ducks|[101, 1109, 16627...|
|    0|      CFB|[101, 146, 1156, ...|
+-----+---------+--------------------+
only showing top 10 rows



In [43]:
%time tsarc.write.csv('gs://sarc-bucket-5/tokens.csv')
#%time tsarc.write.format("text").option("header", "true").save("tokens.txt")

Py4JJavaError: An error occurred while calling o320.csv.
: org.apache.spark.SparkException: Job aborted.
	at org.apache.spark.sql.execution.datasources.FileFormatWriter$.write(FileFormatWriter.scala:198)
	at org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand.run(InsertIntoHadoopFsRelationCommand.scala:159)
	at org.apache.spark.sql.execution.command.DataWritingCommandExec.sideEffectResult$lzycompute(commands.scala:104)
	at org.apache.spark.sql.execution.command.DataWritingCommandExec.sideEffectResult(commands.scala:102)
	at org.apache.spark.sql.execution.command.DataWritingCommandExec.doExecute(commands.scala:122)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:131)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:127)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:155)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:152)
	at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:127)
	at org.apache.spark.sql.execution.QueryExecution.toRdd$lzycompute(QueryExecution.scala:83)
	at org.apache.spark.sql.execution.QueryExecution.toRdd(QueryExecution.scala:81)
	at org.apache.spark.sql.DataFrameWriter$$anonfun$runCommand$1.apply(DataFrameWriter.scala:676)
	at org.apache.spark.sql.DataFrameWriter$$anonfun$runCommand$1.apply(DataFrameWriter.scala:676)
	at org.apache.spark.sql.execution.SQLExecution$$anonfun$withNewExecutionId$1.apply(SQLExecution.scala:80)
	at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:127)
	at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:75)
	at org.apache.spark.sql.DataFrameWriter.runCommand(DataFrameWriter.scala:676)
	at org.apache.spark.sql.DataFrameWriter.saveToV1Source(DataFrameWriter.scala:285)
	at org.apache.spark.sql.DataFrameWriter.save(DataFrameWriter.scala:271)
	at org.apache.spark.sql.DataFrameWriter.save(DataFrameWriter.scala:229)
	at org.apache.spark.sql.DataFrameWriter.csv(DataFrameWriter.scala:664)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:498)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.GatewayConnection.run(GatewayConnection.java:238)
	at java.lang.Thread.run(Thread.java:748)
Caused by: org.apache.spark.SparkException: Job aborted due to stage failure: Task 14 in stage 15.0 failed 4 times, most recent failure: Lost task 14.3 in stage 15.0 (TID 1368, sarc-cluster-w-2.us-central1-a.c.sarcasm-5.internal, executor 16): org.apache.spark.SparkException: Task failed while writing rows.
	at org.apache.spark.sql.execution.datasources.FileFormatWriter$.org$apache$spark$sql$execution$datasources$FileFormatWriter$$executeTask(FileFormatWriter.scala:257)
	at org.apache.spark.sql.execution.datasources.FileFormatWriter$$anonfun$write$1.apply(FileFormatWriter.scala:170)
	at org.apache.spark.sql.execution.datasources.FileFormatWriter$$anonfun$write$1.apply(FileFormatWriter.scala:169)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.run(Task.scala:123)
	at org.apache.spark.executor.Executor$TaskRunner$$anonfun$10.apply(Executor.scala:408)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1360)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:414)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	at java.lang.Thread.run(Thread.java:748)
Caused by: org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/worker.py", line 377, in main
    process()
  File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/worker.py", line 372, in process
    serializer.dump_stream(func(split_index, iterator), outfile)
  File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/serializers.py", line 352, in dump_stream
    self.serializer.dump_stream(self._batched(iterator), stream)
  File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/serializers.py", line 142, in dump_stream
    for obj in iterator:
  File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/serializers.py", line 341, in _batched
    for item in iterator:
  File "<string>", line 1, in <lambda>
  File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/worker.py", line 85, in <lambda>
    return lambda *a: f(*a)
  File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/util.py", line 99, in wrapper
    return f(*args, **kwargs)
  File "<ipython-input-21-eaec17c90935>", line 9, in tokenize_sample
  File "/opt/conda/default/lib/python3.6/site-packages/bert/tokenization/bert_tokenization.py", line 172, in tokenize
    for token in self.basic_tokenizer.tokenize(text):
  File "/opt/conda/default/lib/python3.6/site-packages/bert/tokenization/bert_tokenization.py", line 198, in tokenize
    text = convert_to_unicode(text)
  File "/opt/conda/default/lib/python3.6/site-packages/bert/tokenization/bert_tokenization.py", line 86, in convert_to_unicode
    raise ValueError("Unsupported string type: %s" % (type(text)))
ValueError: Unsupported string type: <class 'NoneType'>

	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:456)
	at org.apache.spark.sql.execution.python.PythonUDFRunner$$anon$1.read(PythonUDFRunner.scala:81)
	at org.apache.spark.sql.execution.python.PythonUDFRunner$$anon$1.read(PythonUDFRunner.scala:64)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:410)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator$$anon$12.hasNext(Iterator.scala:440)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:409)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:409)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$13$$anon$1.hasNext(WholeStageCodegenExec.scala:636)
	at org.apache.spark.sql.execution.datasources.FileFormatWriter$$anonfun$org$apache$spark$sql$execution$datasources$FileFormatWriter$$executeTask$3.apply(FileFormatWriter.scala:244)
	at org.apache.spark.sql.execution.datasources.FileFormatWriter$$anonfun$org$apache$spark$sql$execution$datasources$FileFormatWriter$$executeTask$3.apply(FileFormatWriter.scala:242)
	at org.apache.spark.util.Utils$.tryWithSafeFinallyAndFailureCallbacks(Utils.scala:1394)
	at org.apache.spark.sql.execution.datasources.FileFormatWriter$.org$apache$spark$sql$execution$datasources$FileFormatWriter$$executeTask(FileFormatWriter.scala:248)
	... 10 more

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1892)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1880)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1879)
	at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:1879)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:927)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:927)
	at scala.Option.foreach(Option.scala:257)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:927)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2113)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2062)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2051)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:738)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2061)
	at org.apache.spark.sql.execution.datasources.FileFormatWriter$.write(FileFormatWriter.scala:167)
	... 33 more
Caused by: org.apache.spark.SparkException: Task failed while writing rows.
	at org.apache.spark.sql.execution.datasources.FileFormatWriter$.org$apache$spark$sql$execution$datasources$FileFormatWriter$$executeTask(FileFormatWriter.scala:257)
	at org.apache.spark.sql.execution.datasources.FileFormatWriter$$anonfun$write$1.apply(FileFormatWriter.scala:170)
	at org.apache.spark.sql.execution.datasources.FileFormatWriter$$anonfun$write$1.apply(FileFormatWriter.scala:169)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.run(Task.scala:123)
	at org.apache.spark.executor.Executor$TaskRunner$$anonfun$10.apply(Executor.scala:408)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1360)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:414)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	... 1 more
Caused by: org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/worker.py", line 377, in main
    process()
  File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/worker.py", line 372, in process
    serializer.dump_stream(func(split_index, iterator), outfile)
  File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/serializers.py", line 352, in dump_stream
    self.serializer.dump_stream(self._batched(iterator), stream)
  File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/serializers.py", line 142, in dump_stream
    for obj in iterator:
  File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/serializers.py", line 341, in _batched
    for item in iterator:
  File "<string>", line 1, in <lambda>
  File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/worker.py", line 85, in <lambda>
    return lambda *a: f(*a)
  File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/util.py", line 99, in wrapper
    return f(*args, **kwargs)
  File "<ipython-input-21-eaec17c90935>", line 9, in tokenize_sample
  File "/opt/conda/default/lib/python3.6/site-packages/bert/tokenization/bert_tokenization.py", line 172, in tokenize
    for token in self.basic_tokenizer.tokenize(text):
  File "/opt/conda/default/lib/python3.6/site-packages/bert/tokenization/bert_tokenization.py", line 198, in tokenize
    text = convert_to_unicode(text)
  File "/opt/conda/default/lib/python3.6/site-packages/bert/tokenization/bert_tokenization.py", line 86, in convert_to_unicode
    raise ValueError("Unsupported string type: %s" % (type(text)))
ValueError: Unsupported string type: <class 'NoneType'>

	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:456)
	at org.apache.spark.sql.execution.python.PythonUDFRunner$$anon$1.read(PythonUDFRunner.scala:81)
	at org.apache.spark.sql.execution.python.PythonUDFRunner$$anon$1.read(PythonUDFRunner.scala:64)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:410)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator$$anon$12.hasNext(Iterator.scala:440)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:409)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:409)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$13$$anon$1.hasNext(WholeStageCodegenExec.scala:636)
	at org.apache.spark.sql.execution.datasources.FileFormatWriter$$anonfun$org$apache$spark$sql$execution$datasources$FileFormatWriter$$executeTask$3.apply(FileFormatWriter.scala:244)
	at org.apache.spark.sql.execution.datasources.FileFormatWriter$$anonfun$org$apache$spark$sql$execution$datasources$FileFormatWriter$$executeTask$3.apply(FileFormatWriter.scala:242)
	at org.apache.spark.util.Utils$.tryWithSafeFinallyAndFailureCallbacks(Utils.scala:1394)
	at org.apache.spark.sql.execution.datasources.FileFormatWriter$.org$apache$spark$sql$execution$datasources$FileFormatWriter$$executeTask(FileFormatWriter.scala:248)
	... 10 more
