# Document Classification
This tutorial will show how to perform document classification in Tribuo, using a variety of different methods to extract features from the text. We'll use the venerable [20-newsgroups dataset](http://qwone.com/~jason/20Newsgroups/) where the task is to predict what newsgroup a particular post is from, though this tutorial would be equally applicable to any document classification task (including tasks like sentiment analysis).

# Setup

You'll need a copy of the 20 newsgroups dataset, so first download and unpack it:

```
wget http://qwone.com/~jason/20Newsgroups/20news-bydate.tar.gz
mkdir 20news
cd 20news
tar -zxf ../20news-bydate.tar.gz
```

This leaves you with two directories `20news-bydate-train` and `20news-bydate-test`, which contain the standard train and test split for this data.

20 newsgroups comes in a fairly standard format: the dataset is represented by a set of directories where the directory name is the class label, and each directory contains a collection of documents, one per file. Each file is a single Usenet post. For the purposes of this tutorial, we'll use the subject and body of the post as the input text for classification.

Here's an example:

```
$ ls 20news-bydate-train/
alt.atheism/               comp.sys.mac.hardware/  rec.motorcycles/     sci.electronics/         talk.politics.guns/
comp.graphics/             comp.windows.x/         rec.sport.baseball/  sci.med/                 talk.politics.mideast/
comp.os.ms-windows.misc/   misc.forsale/           rec.sport.hockey/    sci.space/               talk.politics.misc/
comp.sys.ibm.pc.hardware/  rec.autos/              sci.crypt/           soc.religion.christian/  talk.religion.misc/
$ ls 20news-bydate-train/comp.graphics/
37261  37949  38233  38270  38305  38344  38381  38417  38454  38489  38525  38562  38598  38633  38668  38703  38739
37913  37950  38234  38271  38306  38346  38382  38418  38455  38490  38526  38563  38599  38634  38669  38704  38740
37914  37951  38235  38272  38307  38347  38383  38420  38456  38491  38527  38564  38600  38635  38670  38705  38741
37915  37952  38236  38273  38308  38348  38384  38421  38457  38492  38528  38565  38601  38636  38671  38706  38742
...
```

As this is a pretty common format, Tribuo has a specific `DataSource` which can be used to read in this sort of data, `org.tribuo.data.text.DirectoryFileSource`.
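To make the layout concrete, here's a minimal sketch in plain Java that walks a directory-per-class tree and pairs each file with its label. It is not Tribuo's `DirectoryFileSource`; the class name `DirectoryLayoutSketch`, the method `listLabelledFiles` and the tiny temporary corpus are all made up for illustration.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;

// Illustrative only: mimics the directory-per-class layout, not Tribuo's DirectoryFileSource.
class DirectoryLayoutSketch {

    /** Returns (label, file) pairs, where the label is the name of the file's parent directory. */
    static List<Map.Entry<String, Path>> listLabelledFiles(Path root) throws IOException {
        List<Map.Entry<String, Path>> pairs = new ArrayList<>();
        try (Stream<Path> labelDirs = Files.list(root)) {
            // Sort so the result is deterministic; Files.list makes no ordering promise.
            for (Path labelDir : labelDirs.filter(Files::isDirectory).sorted().collect(Collectors.toList())) {
                String label = labelDir.getFileName().toString();
                try (Stream<Path> docs = Files.list(labelDir)) {
                    for (Path doc : docs.filter(Files::isRegularFile).sorted().collect(Collectors.toList())) {
                        pairs.add(Map.entry(label, doc));
                    }
                }
            }
        }
        return pairs;
    }

    public static void main(String[] args) throws IOException {
        // Build a tiny throwaway corpus so the sketch runs without the real dataset.
        Path root = Files.createTempDirectory("layout-sketch");
        Files.createDirectories(root.resolve("sci.space"));
        Files.createDirectories(root.resolve("rec.autos"));
        Files.writeString(root.resolve("sci.space").resolve("1"), "Subject: orbit\n\nbody");
        Files.writeString(root.resolve("rec.autos").resolve("2"), "Subject: engine\n\nbody");
        for (Map.Entry<String, Path> pair : listLabelledFiles(root)) {
            System.out.println(pair.getKey() + " <- " + pair.getValue().getFileName());
        }
    }
}
```

Pointing `listLabelledFiles` at `20news-bydate-train` should yield one pair per post, with the newsgroup name as the label.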

We're going to use the classification experiments jar, along with the ONNX jar which provides support for loading in contextual word embedding models like [BERT](https://ai.googleblog.com/2018/11/open-sourcing-bert-state-of-art-pre.html).

In [1]:
%jars ./tribuo-classification-experiments-4.1.0-SNAPSHOT-jar-with-dependencies.jar
%jars ./tribuo-onnx-4.1.0-SNAPSHOT-jar-with-dependencies.jar

We'll also need a selection of imports from the `org.tribuo.data.text` package, along with the usual imports from `org.tribuo` and `org.tribuo.classification` we use when working with classification tasks. We'll load in the BERT support from the `org.tribuo.interop.onnx.extractors` package. Tribuo's BERT support loads in models and tokenizers from [HuggingFace's Transformers](https://huggingface.co/transformers/) package, and can be easily extended to support non-BERT models.

In [2]:
import java.nio.file.Paths;
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.*;
import org.tribuo.data.text.*;
import org.tribuo.data.text.impl.*;
import org.tribuo.classification.*;
import org.tribuo.classification.evaluation.*;
import org.tribuo.classification.sgd.linear.LogisticRegressionTrainer;
import org.tribuo.interop.onnx.extractors.BERTFeatureExtractor;
import org.tribuo.util.tokens.universal.UniversalTokenizer;

We'll instantiate a few objects that we'll use throughout this tutorial: the label factory, the evaluator, and the paths to the train and test data.

In [3]:
var labelFactory = new LabelFactory();
var labelEvaluator = new LabelEvaluator();
var trainPath = Paths.get("./20news/20news-bydate-train");
var testPath = Paths.get("./20news/20news-bydate-test");

# Extracting features from text
Much of the work of machine learning is in presenting an appropriate representation of the data to the model. This is especially true when working with text data, as there is a plethora of approaches for converting text into the numbers that ML algorithms operate on. The `DirectoryFileSource` allows the user to choose the feature extraction, as it requires a `TextFeatureExtractor` which converts the `String` representing the input text into a Tribuo `Example`. We'll cover several different implementations of the `TextFeatureExtractor` interface in this tutorial, and we expect that users will implement it in their own classes to cope with specific feature extraction requirements.

We'll start with the simplest approach, a "bag of words", where each document is represented by the counts of the words in that document. This means the size of the feature space is equal to the number of unique words, and most documents only have a positive value for a small number of words (as most words don't appear in any given document). This is particularly well suited to Tribuo's sparse vector representation of examples, and this suitability for NLP tasks is the reason that Tribuo is designed this way.

Of course, first we'll need to tell the extractor what a word is, and for this we use a `Tokenizer`. Tokenizers split up a `String` into a stream of tokens. Tribuo provides several basic tokenizers, and an interface for tokenization. We're going to use Tribuo's `UniversalTokenizer`, which is descended from tokenizers developed at Sun Labs in the 90s and used in a variety of Sun products since that time. Once we've got tokens, we'll build a `TokenPipeline` which can convert `String`s into features, and then pass that to the basic `TextFeatureExtractor` implementation, helpfully called `TextFeatureExtractorImpl`.

In [4]:
var tokenizer = new UniversalTokenizer();
var unigramPipeline = new TokenPipeline(tokenizer, 1, true);
var unigramExtractor = new TextFeatureExtractorImpl<Label>(unigramPipeline);
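As a rough illustration of what a bag of words looks like, here's a standalone sketch in plain Java. It uses a deliberately naive tokenizer (lowercase, split on anything that isn't a letter or digit), which is an assumption made for brevity and not how `UniversalTokenizer` or `TokenPipeline` work internally. The class and method names are hypothetical.

```java
import java.util.Locale;
import java.util.Map;
import java.util.TreeMap;

// Illustrative only: a naive tokenizer and term counter, not Tribuo's UniversalTokenizer or TokenPipeline.
class BagOfWordsSketch {

    /** Lowercases the text, splits on anything that is not a letter or digit, and counts each token. */
    static Map<String, Integer> countWords(String text) {
        Map<String, Integer> counts = new TreeMap<>();
        for (String token : text.toLowerCase(Locale.ROOT).split("[^\\p{L}\\p{N}]+")) {
            if (!token.isEmpty()) {
                counts.merge(token, 1, Integer::sum);
            }
        }
        return counts;
    }

    public static void main(String[] args) {
        // prints {boards=1, crew=1, launch=1, shuttle=2, the=3}
        System.out.println(countWords("The shuttle launch: the crew boards the shuttle."));
    }
}
```

Only the words that occur in the document appear in the map, which is the sparsity the paragraph above describes: every other word in the vocabulary implicitly has a count of zero.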

We're now almost ready to make our train and test data sources, and load in the data. The `DirectoryFileSource` also accepts an array of `DocumentPreprocessor`s which can be used to transform the text before feature extraction takes place. We're going to use a specific preprocessor which standardises the 20 newsgroups data by stripping out the mail headers and returning only the subject and the body of the post. Preprocessors are generally dataset and task specific, so Tribuo ships with few implementations; in most cases users will need to write one from scratch for their specific task. We'll pass in an instance of the `NewsPreprocessor` and construct our data sources.

In [5]:
var newsProc = new NewsPreprocessor();
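To give a feel for what such a preprocessor does, here's a hypothetical header stripper in plain Java. It is not the `NewsPreprocessor` implementation; it assumes each post has RFC-style headers, then a blank line, then the body, and it keeps only the text of the `Subject:` header plus the body.

```java
// Illustrative only: a hypothetical header stripper, not Tribuo's NewsPreprocessor.
class HeaderStripSketch {

    /** Keeps the Subject header's text plus everything after the first blank line (the body). */
    static String subjectAndBody(String post) {
        String normalised = post.replace("\r\n", "\n");
        int split = normalised.indexOf("\n\n");
        // Assumption: if there is no blank line, the whole post is treated as headers.
        String headers = split >= 0 ? normalised.substring(0, split) : normalised;
        String body = split >= 0 ? normalised.substring(split + 2) : "";
        String subject = "";
        for (String line : headers.split("\n")) {
            // Case-insensitive match on the header name.
            if (line.regionMatches(true, 0, "Subject:", 0, 8)) {
                subject = line.substring(8).trim();
                break;
            }
        }
        return (subject + "\n" + body).trim();
    }

    public static void main(String[] args) {
        String post = "From: someone@example.com\nSubject: Re: orbit question\nLines: 2\n\nFirst line.\nSecond line.";
        System.out.println(subjectAndBody(post));
    }
}
```

A real preprocessor would plug into Tribuo through the `DocumentPreprocessor` interface; the sketch only shows the text transformation itself.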

We'll make a helper function to load the data sources and create the datasets. We're also going to restrict the test dataset so it only contains valid examples, as 20 newsgroups has some test examples that share no words with the train examples (and so have no features we could use to make predictions with).

Let's check our datasets and see if everything has loaded in correctly.

In [6]:
public Pair<Dataset<Label>,Dataset<Label>> mkDatasets(String name, TextFeatureExtractor<Label> extractor) {
    var trainSource = new DirectoryFileSource<>(trainPath,labelFactory,extractor,newsProc);
    var testSource = new DirectoryFileSource<>(testPath,labelFactory,extractor,newsProc);
    var trainDS = new MutableDataset<>(trainSource);
    var testDS = new ImmutableDataset<>(testSource,trainDS.getFeatureIDMap(),trainDS.getOutputIDInfo(),true);
    System.out.println(String.format(name + " training data size = %d, number of features = %d, number of classes = %d",trainDS.size(),trainDS.getFeatureMap().size(),trainDS.getOutputInfo().size()));
    System.out.println(String.format(name + " testing data size = %d, number of features = %d, number of classes = %d",testDS.size(),testDS.getFeatureMap().size(),testDS.getOutputInfo().size()));
    return new Pair<>(trainDS,testDS);
}

var unigramPair = mkDatasets("unigram",unigramExtractor);

unigram training data size = 11314, number of features = 146037, number of classes = 20
unigram testing data size = 7531, number of features = 146037, number of classes = 20


We've loaded in 11,314 training documents containing 146,037 unique words and 7,531 test documents, each with the expected 20 classes.
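The idea behind restricting the test set to valid examples can be sketched in a few lines of plain Java: a test document is kept only if at least one of its tokens was seen in training. This is a simplified illustration using whitespace tokenization, not how `ImmutableDataset` implements the check, and the names are hypothetical.

```java
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Set;

// Illustrative only: shows the idea of dropping test documents that have no known features.
class ValidExampleSketch {

    /** Keeps only the test documents containing at least one token seen in the training documents. */
    static List<String> keepValid(List<String> trainDocs, List<String> testDocs) {
        Set<String> vocabulary = new HashSet<>();
        for (String doc : trainDocs) {
            for (String token : doc.toLowerCase(Locale.ROOT).split("\\s+")) {
                if (!token.isEmpty()) {
                    vocabulary.add(token);
                }
            }
        }
        List<String> valid = new ArrayList<>();
        for (String doc : testDocs) {
            for (String token : doc.toLowerCase(Locale.ROOT).split("\\s+")) {
                if (vocabulary.contains(token)) {
                    valid.add(doc);
                    break; // one shared token is enough to keep the document
                }
            }
        }
        return valid;
    }

    public static void main(String[] args) {
        List<String> train = List.of("the shuttle launch", "engine oil change");
        List<String> test = List.of("shuttle crew", "zzz qqq");
        // prints [shuttle crew]
        System.out.println(keepValid(train, test));
    }
}
```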

Now we're ready to train a model. Let's start with a simple logistic regression.

In [7]:
var lrTrainer = new LogisticRegressionTrainer();
var unigramModel = lrTrainer.train(unigramPair.getA());
var unigramEval = labelEvaluator.evaluate(unigramModel,unigramPair.getB());
System.out.println(unigramEval);

Class                                n          tp          fn          fp      recall        prec          f1
soc.religion.christian             398         314          84         131       0.789       0.706       0.745
rec.autos                          396         305          91          80       0.770       0.792       0.781
talk.religion.misc                 251         147         104         146       0.586       0.502       0.540
comp.windows.x                     394         306          88         106       0.777       0.743       0.759
rec.sport.baseball                 397         322          75          63       0.811       0.836       0.824
comp.graphics                      389         258         131         127       0.663       0.670       0.667
talk.politics.mideast              376         287          89          46       0.763       0.862       0.810
comp.sys.ibm.pc.hardware           392         253         139         183       0.645       0.580       0.611
...

We see that the logistic regression trained on unigrams gets about 74% accuracy.
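The per-class columns in the table follow from the raw counts: recall is tp / (tp + fn), precision is tp / (tp + fp), and F1 is their harmonic mean. As a quick check, this standalone sketch (the class name is made up) recomputes the `soc.religion.christian` row of the unigram table from its tp, fn and fp values.

```java
import java.util.Locale;

// Illustrative only: recomputes the per-class columns of the evaluation table from raw counts.
class PerClassMetricsSketch {

    static double recall(int tp, int fn) {
        return tp / (double) (tp + fn);
    }

    static double precision(int tp, int fp) {
        return tp / (double) (tp + fp);
    }

    /** Harmonic mean of precision and recall, written directly in terms of the counts. */
    static double f1(int tp, int fn, int fp) {
        return 2.0 * tp / (2.0 * tp + fn + fp);
    }

    public static void main(String[] args) {
        // Counts copied from the soc.religion.christian row of the unigram table.
        int tp = 314, fn = 84, fp = 131;
        // prints recall=0.789 prec=0.706 f1=0.745
        System.out.printf(Locale.ROOT, "recall=%.3f prec=%.3f f1=%.3f%n",
                recall(tp, fn), precision(tp, fp), f1(tp, fn, fp));
    }
}
```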

Let's try a little more complicated feature extractor. The natural step from unigrams is to include word pairs (or bigrams) and count the occurrence of those. This allows us to get simple negations (e.g., "not bad" rather than "not" and "bad") along with places like "New York" rather than "new" and "york". In Tribuo this is as straightforward as telling the token pipeline we'd like bigrams.

In [8]:
var bigramPipeline = new TokenPipeline(tokenizer, 2, true);
var bigramExtractor = new TextFeatureExtractorImpl<Label>(bigramPipeline);
var bigramPair = mkDatasets("bigram",bigramExtractor);

bigram training data size = 11314, number of features = 1253665, number of classes = 20
bigram testing data size = 7531, number of features = 1253665, number of classes = 20


We can see the feature space has massively increased due to the presence of bigram features: we've now got roughly 1.25 million features from the same 11,314 documents.
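To see where the extra features come from, here's a small plain-Java sketch that emits every n-gram up to a maximum length from a token list. It assumes unigrams are kept alongside bigrams and joins tokens with a `/` separator; both choices are assumptions for illustration, and the class name is hypothetical rather than part of Tribuo.

```java
import java.util.ArrayList;
import java.util.List;

// Illustrative only: generates word n-grams, not Tribuo's internal n-gram processing.
class NgramSketch {

    /** Returns every n-gram of length 1..maxN, joining the tokens of each n-gram with '/'. */
    static List<String> ngrams(List<String> tokens, int maxN) {
        List<String> features = new ArrayList<>();
        for (int n = 1; n <= maxN; n++) {
            for (int start = 0; start + n <= tokens.size(); start++) {
                features.add(String.join("/", tokens.subList(start, start + n)));
            }
        }
        return features;
    }

    public static void main(String[] args) {
        // prints [not, bad, in, new, york, not/bad, bad/in, in/new, new/york]
        System.out.println(ngrams(List.of("not", "bad", "in", "new", "york"), 2));
    }
}
```

Five tokens yield five unigrams plus four bigrams. Across a whole corpus most bigrams are rare, so the number of distinct features grows much faster than the number of distinct words.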

Now to train another logistic regression.

In [9]:
var bigramModel = lrTrainer.train(bigramPair.getA());
var bigramEval = labelEvaluator.evaluate(bigramModel,bigramPair.getB());
System.out.println(bigramEval);

Class                                n          tp          fn          fp      recall        prec          f1
soc.religion.christian             398         337          61         110       0.847       0.754       0.798
rec.autos                          396         315          81          99       0.795       0.761       0.778
talk.religion.misc                 251         144         107         110       0.574       0.567       0.570
comp.windows.x                     394         287         107          88       0.728       0.765       0.746
rec.sport.baseball                 397         327          70          58       0.824       0.849       0.836
comp.graphics                      389         218         171         108       0.560       0.669       0.610
talk.politics.mideast              376         311          65          54       0.827       0.852       0.839
comp.sys.ibm.pc.hardware           392         264         128         199       0.673       0.570       0.618
...

Our performance only improved a little, from 74.2% to 74.5%. Although the bigram features carry more information, there are also many, many more of them, which makes it easier to confuse this simple linear model. As we add higher-order n-gram features we'll start to see diminishing returns, because the model complexity increases without a commensurate increase in training data.

One popular technique for reducing the feature space when dealing with such large problems is feature hashing. This is where the features are mapped back down to a smaller space using a hash function. It induces collisions between the features, so the model might treat "New York" and "San Francisco" as the same feature, but the collisions are generated essentially at random based on the hash function, and so provide a strong regularising effect which frequently improves performance.
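The mechanics can be sketched in plain Java as follows. This version uses `String.hashCode` reduced modulo the hash dimension purely for brevity; that is an assumption of the sketch, as production feature hashers generally use a stronger hash function, and none of these names come from Tribuo.

```java
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

// Illustrative only: the hashing trick with String.hashCode, chosen for brevity.
class FeatureHashingSketch {

    /** Maps a feature name to a bucket index in [0, dimension). */
    static int bucket(String feature, int dimension) {
        // floorMod keeps the index non-negative even when hashCode() is negative.
        return Math.floorMod(feature.hashCode(), dimension);
    }

    /** Sums feature counts per bucket; features that collide share a single count. */
    static Map<Integer, Integer> hashCounts(List<String> features, int dimension) {
        Map<Integer, Integer> hashed = new TreeMap<>();
        for (String feature : features) {
            hashed.merge(bucket(feature, dimension), 1, Integer::sum);
        }
        return hashed;
    }

    public static void main(String[] args) {
        List<String> features = List.of("new", "york", "new/york", "san", "francisco", "san/francisco");
        // With only 4 buckets, 6 features must collide somewhere.
        System.out.println(hashCounts(features, 4));
    }
}
```

However many distinct features go in, the output never has more than `dimension` entries, and the total count is preserved because colliding features are summed.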

To use feature hashing in Tribuo simply pass a hash dimension to the `TokenPipeline` on construction. We'll map everything down to 50,000 features and see how that affects the model.

In [10]:
var hashPipeline = new TokenPipeline(tokenizer, 2, true, 50000);
var hashExtractor = new TextFeatureExtractorImpl<Label>(hashPipeline);
var hashPair = mkDatasets("hash-50k",hashExtractor);

hash-50k training data size = 11314, number of features = 50000, number of classes = 20
hash-50k testing data size = 7532, number of features = 50000, number of classes = 20


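As expected, we have the same number of training examples, but now there are only 50,000 features. The test set has 7,532 examples rather than the 7,531 we saw before, presumably because after hashing every test document has at least one feature the model knows about, so none are dropped as invalid. Let's build another logistic regression.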

In [12]:
var hashModel = lrTrainer.train(hashPair.getA());
var hashEval = labelEvaluator.evaluate(hashModel,hashPair.getB());
System.out.println(hashEval);

Class                                n          tp          fn          fp      recall        prec          f1
soc.religion.christian             398         319          79         105       0.802       0.752       0.776
rec.autos                          396         282         114          76       0.712       0.788       0.748
talk.religion.misc                 251         135         116         128       0.538       0.513       0.525
comp.windows.x                     395         276         119          98       0.699       0.738       0.718
rec.sport.baseball                 397         333          64          97       0.839       0.774       0.805
comp.graphics                      389         211         178         119       0.542       0.639       0.587
talk.politics.mideast              376         261         115          30       0.694       0.897       0.783
comp.sys.ibm.pc.hardware           392         244         148         201       0.622       0.548       0.583
...

- discuss word embeddings
- BERT CLS
- BERT CLS + Average token embeddings
- Provenance and big models (i.e. BERT needs to be on disk, it's not in your model), the feature extractor is provenance not an object.
- Conclusion