# LSTM Example: 20 Newsgroup Classification

This notebook shows how to train an LSTM classifier on the 20 Newsgroups collection.

## Requirements

Running this notebook requires a minimum of 16GB of executor memory.  Starting it in an ordinary Jupyter notebook 
won't work: training will fail with an out-of-memory exception.  Quit Jupyter and make sure you restart it with at least `--executor-memory 16gb`.

In [1]:
import itertools
import re
from optparse import OptionParser

from bigdl.dataset import news20
from bigdl.nn.layer import *
from bigdl.nn.criterion import *
from bigdl.optim.optimizer import *
from bigdl.util.common import *
from bigdl.util.common import Sample
import datetime as dt



In [2]:

def text_to_words(review_text):
    """Tokenize raw text into lowercase alphabetic words.

    Every character that is not an ASCII letter is replaced by a space
    before splitting, so digits and punctuation act purely as separators.

    :param review_text: the text to tokenize.
    :return: list of lowercase words, in order of appearance.
    """
    return re.sub("[^a-zA-Z]", " ", review_text).lower().split()


def analyze_texts(data_rdd):
    """Build a frequency-ranked vocabulary from an RDD of documents.

    Element [0] of every record is tokenized with text_to_words; the
    occurrences of each word are then counted across the whole RDD.

    :param data_rdd: RDD whose records carry the document text at index 0.
    :return: list of (word, (rank, count)) tuples collected to the driver,
        ordered by descending count, with rank starting at 1.
    """
    def to_ranked_entry(entry):
        # zipWithIndex is zero-based; ranks are reported one-based.
        (word, count), position = entry
        return word, (position + 1, count)

    words = data_rdd.flatMap(
        lambda text_label: text_to_words(text_label[0]))
    counts = words.map(lambda word: (word, 1)) \
        .reduceByKey(lambda a, b: a + b)
    ranked = counts.sortBy(lambda w_c: - w_c[1]).zipWithIndex()
    return ranked.map(to_ranked_entry).collect()


# Example: pad([1, 2, 3, 4, 5], 0, 6) -> [1, 2, 3, 4, 5, 0]
def pad(l, fill_value, width):
    """Return a copy of ``l`` that is exactly ``width`` items long.

    Longer inputs are truncated to their first ``width`` items; shorter
    inputs are right-padded with ``fill_value``.  The input list is never
    modified: a new list is returned in both cases, so callers can keep
    using the original sequence safely.

    :param l: the sequence to truncate or pad.
    :param fill_value: value appended until the result reaches ``width``.
    :param width: desired length of the result.
    :return: a new list holding the truncated or padded items.
    """
    padded = list(l[0: width])
    # A non-positive repeat count yields an empty list, so nothing is
    # appended once the (possibly truncated) copy already has enough items.
    padded.extend([fill_value] * (width - len(padded)))
    return padded


def to_vec(token, b_w2v, embedding_dim):
    """Look up the embedding vector of a token.

    :param token: the word to look up.
    :param b_w2v: mapping from word to embedding vector.
    :param embedding_dim: length of the vector returned for unknown tokens.
    :return: the stored vector when the token is known, otherwise a list
        of ``embedding_dim`` zeros.
    """
    if token not in b_w2v:
        # Unknown token: fall back to an all-zero vector.
        return [0] * embedding_dim
    return b_w2v[token]


def to_sample(vectors, label, embedding_dim):
    """Turn one document's word vectors and label into a BigDL Sample.

    The per-word vectors are concatenated and reshaped into a
    (sequence_len, embedding_dim) float array; ``sequence_len`` is read
    from the module-level setting of that name.

    :param vectors: iterable of per-word vectors for one document.
    :param label: the document's class label.
    :param embedding_dim: length of each word vector.
    :return: the result of Sample.from_ndarray(features, label array).
    """
    # Concatenate the per-word vectors into one flat list of numbers.
    flat_values = list(itertools.chain.from_iterable(vectors))
    target_shape = [sequence_len, embedding_dim]
    feature_matrix = np.array(flat_values, dtype='float').reshape(target_shape)
    label_array = np.array(label)
    return Sample.from_ndarray(feature_matrix, label_array)



## The Model: An LSTM

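The model feeds each document's sequence of word embeddings through a single LSTM layer with 256 hidden units, keeps only the output of the last time step, and classifies it with two fully connected layers (256 → 128 → number of classes) followed by a log-softmax.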

In [None]:

def build_model(class_num):
    """Assemble the LSTM text classifier.

    Layers, in order: a Recurrent container wrapping an LSTM cell,
    Select(2, -1), Linear, Dropout(0.2), ReLU, Linear and LogSoftMax.
    The LSTM input size and its third argument come from the module-level
    settings ``embedding_dim`` and ``p``.

    :param class_num: number of output classes (size of the last Linear).
    :return: the assembled Sequential model.
    """
    lstm_hidden = 256
    dense_hidden = 128

    model = Sequential()

    # Recurrent part: one LSTM cell wrapped in a Recurrent container.
    recurrent = Recurrent().add(LSTM(embedding_dim, lstm_hidden, p))
    model.add(recurrent)
    # Select index -1 along dimension 2
    # (presumably the last time step of the sequence - TODO confirm).
    model.add(Select(2, -1))

    # Classification head.
    model.add(Linear(lstm_hidden, dense_hidden))
    model.add(Dropout(0.2))
    model.add(ReLU())
    model.add(Linear(dense_hidden, class_num))
    model.add(LogSoftMax())

    return model



In [3]:

# --- Run settings -----------------------------------------------------------
# action and model_type are not referenced by the cells shown in this notebook.
action = "train"
model_type = "lstm"
# Directory handed to news20.get_news20 as source_dir.
data_path = "/tmp/news20"

# --- Text preprocessing -----------------------------------------------------
# Every document is truncated or padded to this many tokens.
sequence_len = 500
# Upper bound of the slice taken from the frequency-ranked vocabulary.
max_words = 5000
# Length of each word vector (GloVe dimension and LSTM input size).
embedding_dim = 300

# --- Training ---------------------------------------------------------------
# Weight of the training part in randomSplit; the rest is used for validation.
training_split = 0.8
batch_size = 128
max_epoch = 30
# Learning rate passed to Adagrad.
learning_rate = 0.05
# Third argument of the LSTM layer
# (presumably its dropout probability - TODO confirm).
p = 0.0

In [4]:

# One-time BigDL setup; all three helpers are expected to come from the
# star-import of bigdl.util.common at the top of the notebook.
# NOTE(review): "redire" is the spelling used by the library function itself
# (presumably it redirects Spark's log output - confirm against BigDL docs).
redire_spark_logs()
# Presumably surfaces BigDL INFO-level log messages - confirm.
show_bigdl_info_logs()
# Called before any model or optimizer is created in the cells below.
init_engine()

In [13]:
print('Processing text dataset')
texts = news20.get_news20(source_dir=data_path)
# Records are indexed below as (text, label): [0] is tokenized and [1] is
# carried through unchanged.
data_rdd = sc.parallelize(texts, 2)

# analyze_texts returns (word, (rank, count)) pairs, most frequent first.
# Skip the 10 most frequent words and keep the entries up to index max_words.
word_to_ic = dict(analyze_texts(data_rdd)[10: max_words])
bword_to_ic = sc.broadcast(word_to_ic)

# Keep only the GloVe vectors of words that survived the vocabulary cut.
w2v = news20.get_glove_w2v(dim=embedding_dim)
filtered_w2v = {word: vec for word, vec in w2v.items() if word in word_to_ic}
bfiltered_w2v = sc.broadcast(filtered_w2v)

# Tokenize every document and drop the words outside the vocabulary.
tokens_rdd = data_rdd.map(
    lambda doc: ([word for word in text_to_words(doc[0])
                  if word in bword_to_ic.value], doc[1]))

# Truncate or right-pad every token list to exactly sequence_len entries.
# text_to_words only emits letters, so the "##" filler can never be a
# vocabulary word and is later mapped to the all-zero vector by to_vec.
padded_tokens_rdd = tokens_rdd.map(
    lambda doc: (pad(doc[0], "##", sequence_len), doc[1]))

# Replace every token by its embedding vector.
vector_rdd = padded_tokens_rdd.map(
    lambda doc: ([to_vec(token, bfiltered_w2v.value, embedding_dim)
                  for token in doc[0]], doc[1]))

# Pack each document's vectors and label into a Sample.
sample_rdd = vector_rdd.map(
    lambda doc: to_sample(doc[0], doc[1], embedding_dim))

# Random train/validation split weighted by training_split.
train_rdd, val_rdd = sample_rdd.randomSplit(
    [training_split, 1-training_split])

# Distributed training with Adagrad for max_epoch epochs over train_rdd.
optimizer = Optimizer(
    model=build_model(news20.CLASS_NUM),
    training_rdd=train_rdd,
    criterion=ClassNLLCriterion(),
    end_trigger=MaxEpoch(max_epoch),
    batch_size=batch_size,
    optim_method=Adagrad(learningrate=learning_rate, learningrate_decay=0.001))

# Evaluate top-1 accuracy on the held-out split after every epoch.
optimizer.set_validation(
    batch_size=batch_size,
    val_rdd=val_rdd,
    trigger=EveryEpoch(),
    val_method=[Top1Accuracy()]
)

logdir = '/tmp/.bigdl/'
# Label the run after the optimisation method actually in use (Adagrad);
# the timestamp keeps successive runs in separate summary folders.
app_name = 'adagrad-' + dt.datetime.now().strftime("%Y%m%d-%H%M%S")

# Training and validation summaries are written under logdir/app_name.
train_summary = TrainSummary(log_dir=logdir, app_name=app_name)
# Record the "Parameters" summary every 50 iterations.
train_summary.set_summary_trigger("Parameters", SeveralIteration(50))
val_summary = ValidationSummary(log_dir=logdir, app_name=app_name)
optimizer.set_train_summary(train_summary)
optimizer.set_val_summary(val_summary)

# Starts the actual training job; the result is kept in train_model.
train_model = optimizer.optimize()





Processing text dataset
Found 19997 texts.
creating: createSequential
creating: createRecurrent
creating: createLSTM
creating: createSelect
creating: createLinear
creating: createDropout
creating: createReLU
creating: createLinear
creating: createLogSoftMax
creating: createClassNLLCriterion
creating: createMaxEpoch
creating: createAdagrad
creating: createOptimizer
creating: createEveryEpoch
creating: createTop1Accuracy
creating: createTrainSummary
creating: createSeveralIteration
creating: createValidationSummary


Py4JJavaError: An error occurred while calling o601.optimize.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 31.0 failed 1 times, most recent failure: Lost task 0.0 in stage 31.0 (TID 38, localhost, executor driver): java.lang.OutOfMemoryError: Java heap space
	at java.nio.HeapByteBuffer.<init>(HeapByteBuffer.java:57)
	at java.nio.ByteBuffer.allocate(ByteBuffer.java:335)
	at org.apache.spark.storage.ShuffleBlockFetcherIterator$$anonfun$5.apply(ShuffleBlockFetcherIterator.scala:390)
	at org.apache.spark.storage.ShuffleBlockFetcherIterator$$anonfun$5.apply(ShuffleBlockFetcherIterator.scala:390)
	at org.apache.spark.util.io.ChunkedByteBufferOutputStream.allocateNewChunkIfNeeded(ChunkedByteBufferOutputStream.scala:87)
	at org.apache.spark.util.io.ChunkedByteBufferOutputStream.write(ChunkedByteBufferOutputStream.scala:75)
	at org.apache.spark.util.Utils$$anonfun$copyStream$1.apply$mcJ$sp(Utils.scala:342)
	at org.apache.spark.util.Utils$$anonfun$copyStream$1.apply(Utils.scala:327)
	at org.apache.spark.util.Utils$$anonfun$copyStream$1.apply(Utils.scala:327)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1337)
	at org.apache.spark.util.Utils$.copyStream(Utils.scala:348)
	at org.apache.spark.storage.ShuffleBlockFetcherIterator.next(ShuffleBlockFetcherIterator.scala:395)
	at org.apache.spark.storage.ShuffleBlockFetcherIterator.next(ShuffleBlockFetcherIterator.scala:59)
	at scala.collection.Iterator$$anon$12.nextCur(Iterator.scala:434)
	at scala.collection.Iterator$$anon$12.hasNext(Iterator.scala:440)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:408)
	at org.apache.spark.util.CompletionIterator.hasNext(CompletionIterator.scala:32)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator$$anon$12.hasNext(Iterator.scala:438)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:408)
	at scala.collection.Iterator$class.foreach(Iterator.scala:893)
	at scala.collection.AbstractIterator.foreach(Iterator.scala:1336)
	at scala.collection.generic.Growable$class.$plus$plus$eq(Growable.scala:59)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:104)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:48)
	at scala.collection.TraversableOnce$class.to(TraversableOnce.scala:310)
	at scala.collection.AbstractIterator.to(Iterator.scala:1336)
	at scala.collection.TraversableOnce$class.toBuffer(TraversableOnce.scala:302)
	at scala.collection.AbstractIterator.toBuffer(Iterator.scala:1336)
	at scala.collection.TraversableOnce$class.toArray(TraversableOnce.scala:289)
	at scala.collection.AbstractIterator.toArray(Iterator.scala:1336)
	at com.intel.analytics.bigdl.dataset.DataSet$$anonfun$5.apply(DataSet.scala:360)

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1499)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1487)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1486)
	at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:1486)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:814)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:814)
	at scala.Option.foreach(Option.scala:257)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:814)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:1714)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1669)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1658)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:48)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:630)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2022)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2043)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2062)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2087)
	at org.apache.spark.rdd.RDD.count(RDD.scala:1158)
	at com.intel.analytics.bigdl.dataset.DistributedDataSet$$anon$5.cache(DataSet.scala:188)
	at com.intel.analytics.bigdl.optim.DistriOptimizer.prepareInput(DistriOptimizer.scala:757)
	at com.intel.analytics.bigdl.optim.DistriOptimizer.optimize(DistriOptimizer.scala:777)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:498)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:280)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.GatewayConnection.run(GatewayConnection.java:214)
	at java.lang.Thread.run(Thread.java:748)
Caused by: java.lang.OutOfMemoryError: Java heap space
	at java.nio.HeapByteBuffer.<init>(HeapByteBuffer.java:57)
	at java.nio.ByteBuffer.allocate(ByteBuffer.java:335)
	at org.apache.spark.storage.ShuffleBlockFetcherIterator$$anonfun$5.apply(ShuffleBlockFetcherIterator.scala:390)
	at org.apache.spark.storage.ShuffleBlockFetcherIterator$$anonfun$5.apply(ShuffleBlockFetcherIterator.scala:390)
	at org.apache.spark.util.io.ChunkedByteBufferOutputStream.allocateNewChunkIfNeeded(ChunkedByteBufferOutputStream.scala:87)
	at org.apache.spark.util.io.ChunkedByteBufferOutputStream.write(ChunkedByteBufferOutputStream.scala:75)
	at org.apache.spark.util.Utils$$anonfun$copyStream$1.apply$mcJ$sp(Utils.scala:342)
	at org.apache.spark.util.Utils$$anonfun$copyStream$1.apply(Utils.scala:327)
	at org.apache.spark.util.Utils$$anonfun$copyStream$1.apply(Utils.scala:327)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1337)
	at org.apache.spark.util.Utils$.copyStream(Utils.scala:348)
	at org.apache.spark.storage.ShuffleBlockFetcherIterator.next(ShuffleBlockFetcherIterator.scala:395)
	at org.apache.spark.storage.ShuffleBlockFetcherIterator.next(ShuffleBlockFetcherIterator.scala:59)
	at scala.collection.Iterator$$anon$12.nextCur(Iterator.scala:434)
	at scala.collection.Iterator$$anon$12.hasNext(Iterator.scala:440)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:408)
	at org.apache.spark.util.CompletionIterator.hasNext(CompletionIterator.scala:32)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator$$anon$12.hasNext(Iterator.scala:438)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:408)
	at scala.collection.Iterator$class.foreach(Iterator.scala:893)
	at scala.collection.AbstractIterator.foreach(Iterator.scala:1336)
	at scala.collection.generic.Growable$class.$plus$plus$eq(Growable.scala:59)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:104)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:48)
	at scala.collection.TraversableOnce$class.to(TraversableOnce.scala:310)
	at scala.collection.AbstractIterator.to(Iterator.scala:1336)
	at scala.collection.TraversableOnce$class.toBuffer(TraversableOnce.scala:302)
	at scala.collection.AbstractIterator.toBuffer(Iterator.scala:1336)
	at scala.collection.TraversableOnce$class.toArray(TraversableOnce.scala:289)
	at scala.collection.AbstractIterator.toArray(Iterator.scala:1336)
	at com.intel.analytics.bigdl.dataset.DataSet$$anonfun$5.apply(DataSet.scala:360)


----------------------------------------
Exception happened during processing of request from ('127.0.0.1', 41902)
----------------------------------------


Traceback (most recent call last):
  File "/opt/conda/envs/py27/lib/python2.7/SocketServer.py", line 290, in _handle_request_noblock
    self.process_request(request, client_address)
  File "/opt/conda/envs/py27/lib/python2.7/SocketServer.py", line 318, in process_request
    self.finish_request(request, client_address)
  File "/opt/conda/envs/py27/lib/python2.7/SocketServer.py", line 331, in finish_request
    self.RequestHandlerClass(request, client_address, self)
  File "/opt/conda/envs/py27/lib/python2.7/SocketServer.py", line 652, in __init__
    self.handle()
  File "/usr/local/spark/python/pyspark/accumulators.py", line 235, in handle
    num_updates = read_int(self.rfile)
  File "/usr/local/spark/python/pyspark/serializers.py", line 577, in read_int
    raise EOFError
EOFError
