Merge pull request NVIDIA#684 from gheinrich/dev/text-classification
Add text classification example
lukeyeager committed May 11, 2016
2 parents d06ae8f + 5861685 commit cc0475d
Showing 9 changed files with 430 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/GettingStarted.md
@@ -111,6 +111,7 @@ Once you have finished this guide, take a look at some of the other documentatio
* [Train an autoencoder network](../examples/autoencoder/README.md)
* [Train a regression network](../examples/regression/README.md)
* [Train a Siamese network](../examples/siamese/README.md)
* [Train a text classification network](../examples/text-classification/README.md)
* [Learn more about weight initialization](../examples/weight-init/README.md)
* [Use Python layers in your Caffe networks](../examples/python-layer/README.md)
* [Download a model and use it to classify an image outside of DIGITS](../examples/classification/README.md)
172 changes: 172 additions & 0 deletions examples/text-classification/README.md
@@ -0,0 +1,172 @@
# Text classification using DIGITS and Torch7

Table of Contents
=================
* [Introduction](#introduction)
* [Dataset Creation](#dataset-creation)
* [Model Creation](#model-creation)
* [Verification](#verification)
* [Alternative Method](#alternative-method)

## Introduction

This example follows the implementation in [Crepe](https://github.com/zhangxiangxiao/Crepe) of the following paper:

Xiang Zhang, Junbo Zhao, Yann LeCun. [Character-level Convolutional Networks for Text Classification](http://arxiv.org/abs/1509.01626). Advances in Neural Information Processing Systems 28 (NIPS 2015)

This example shows how to create a feed-forward convolutional neural network that is able to classify text with high accuracy.
The network operates at the character level and does not require any feature engineering, besides converting characters to arbitrary numbers.

## Dataset Creation

We will use the [DBPedia](http://wiki.dbpedia.org) ontology dataset.
This dataset is available in `.csv` format on @zhangxiangxiao's [Google Drive storage](http://goo.gl/JyCnZq).
Download the file `dbpedia_csv.tar.gz` and extract its contents into a folder which we will later refer to as `$DBPEDIA`.

The following sample is an example from the "company" class:

> "E. D. Abbott Ltd"," Abbott of Farnham E D Abbott Limited was a British coachbuilding business based in Farnham Surrey trading under that name from 1929. A major part of their output was under sub-contract to motor vehicle manufacturers. Their business closed in 1972."

The first step to creating the dataset is to convert the `.csv` files to a format that DIGITS can use:
```sh
$ cd $DIGITS_HOME/examples/text-classification
$ ./create_dataset.py $DBPEDIA/dbpedia_csv/train.csv dbpedia/train --labels $DBPEDIA/dbpedia_csv/classes.txt --create-images
```

This script parses `train.csv` in order to generate one sample per entry in the file.
Every entry is converted into a 1024-long vector of bytes.
Characters are converted using a very simple mapping: strings are first converted to lower case, then each character is replaced with its index in the extended alphabet (abc...xyz0...9 + a number of signs). Indices start from number `2`. Other characters (including those that are used for padding) are replaced with number `1`.
Note that we have not implemented the backward quantization order mentioned in paragraph 2.2 of the paper.
The script then reshapes the data into a `32x32` matrix and saves the matrix into an unencoded LMDB file.
The above command additionally enables saving each sample into an actual image file.
Image files are saved into sub-folders that are named after the sample's class.
This makes it possible to have DIGITS proceed as if we were creating an image classification network.
We will see later how that step may be skipped.
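
The character mapping described above can be sketched in a few lines of plain Python. This is only an illustration, not the code of `create_dataset.py`: the helper names `encode` and `to_matrix` are made up for this example, and the sketch takes the description literally (a character outside the alphabet occupies one position with value `1`), which may differ from what the script does in detail.

```python
# Extended alphabet, as defined in create_dataset.py
ALPHABET = "abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:'\"/\\|_@#$%^&*~`+ =<>()[]{}"
FEATURE_LEN = 1024  # 32 * 32
SIDE = 32


def encode(text, feature_len=FEATURE_LEN):
    """Hypothetical helper: turn a string into a fixed-length list of indices."""
    # Known characters map to 2, 3, 4, ... in alphabet order
    char_to_index = {c: i + 2 for i, c in enumerate(ALPHABET)}
    # 1 stands for padding and for characters outside the alphabet
    sample = [1] * feature_len
    for pos, char in enumerate(text.lower()[:feature_len]):
        sample[pos] = char_to_index.get(char, 1)
    return sample


def to_matrix(sample, side=SIDE):
    """Hypothetical helper: reshape the flat vector into `side` rows of `side` values."""
    return [sample[r * side:(r + 1) * side] for r in range(side)]


matrix = to_matrix(encode("E. D. Abbott Ltd"))
print(matrix[0][:4])
```

Each row of `matrix` corresponds to one row of pixels in the grayscale image that DIGITS ingests.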

On the DIGITS homepage, click `New Dataset > Images > Classification` then:
- change the image type to `grayscale` and the image height and width to `32`,
- point to the location of your dataset,
- use 10% of samples for validation and 1% for testing,
- make sure the image encoding is set to PNG (lossless),
- give your dataset a name then click the "Create" button.

![image classification dataset](dbpedia-dataset-creation.png)

## Model Creation

If you haven't done so already, install the `dpnn` Lua package:

```sh
luarocks install dpnn
```

On the DIGITS homepage, click `New Model > Images > Classification` then:
- select the dataset you just created,
- set the Mean Subtraction method to "None",
- select the "Custom Network" pane then click "Torch",
- in the Custom Network field paste this [network definition](text-classification-model.lua),
- give your model a name.

Optionally, for better results:
- set the number of training epochs to `15`,
- set the validation interval to `0.25`,
- click "Show advanced learning rate options",
- set the learning rate policy to "Exponential Decay",
- set Gamma to `0.98`.

The model resembles a typical image classification convolutional neural network, with convolutional layers, max pooling, dropouts and a linear classifier.
The main difference is that each character is one-hot encoded into a vector and 1D (temporal) convolutions are used instead of 2D (spatial) convolutions.
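
To make the two ideas concrete, here is a rough sketch in plain Python of one-hot encoding followed by a single-filter temporal convolution. It is not taken from the network definition: the function names, the tiny depth of 3 and the kernel values are all assumptions chosen for readability. As in most deep learning frameworks, the "convolution" below does not flip the kernel.

```python
def one_hot(indices, depth):
    """Map 1-based indices to vectors of length `depth` holding a single 1.0."""
    frames = []
    for idx in indices:
        frame = [0.0] * depth
        frame[idx - 1] = 1.0
        frames.append(frame)
    return frames


def temporal_conv(frames, kernel):
    """Slide one filter along the time axis (no padding, stride 1).

    `frames` is a list of vectors (time x depth);
    `kernel` is a list of weight vectors (kernel_width x depth).
    """
    width = len(kernel)
    outputs = []
    for t in range(len(frames) - width + 1):
        acc = 0.0
        for k in range(width):
            for d in range(len(kernel[k])):
                acc += frames[t + k][d] * kernel[k][d]
        outputs.append(acc)
    return outputs


frames = one_hot([1, 3, 2], depth=3)
# This filter responds when symbol 1 is immediately followed by symbol 3
kernel = [[1.0, 0.0, 0.0],
          [0.0, 0.0, 1.0]]
print(temporal_conv(frames, kernel))
```

The filter only moves along the character axis, which is what distinguishes a temporal convolution from the spatial convolutions used on images.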

When you are ready, click the "Create" button.

After a few hours of training, your network loss and accuracy may look like:

![loss](dbpedia-loss.png)

## Verification

At the bottom of the model page, select the model snapshot that achieved the best validation accuracy (this is not necessarily the last one).
Then in the "Test a list of images" section, upload the `test.txt` file from your dataset job folder.
This text file was created by DIGITS during dataset creation and is formatted in a way that allows DIGITS to extract the ground truth and compute accuracy and a confusion matrix.
There you can also see Top-1 and Top-5 average accuracy, and per-class accuracy:

![confusion matrix](dbpedia-confusion-matrix.png)

## Alternative Method

If you think creating image files to represent text is overkill, you might be interested in this: you can create LMDB files manually and use them in DIGITS directly.
When you created the dataset with `create_dataset.py`, the script also created an LMDB database out of `train.csv`.
You can use the same script to create another database out of `test.csv` (from the DBPedia ontology dataset), for validation purposes:

```sh
./create_dataset.py $DBPEDIA/dbpedia_csv/test.csv dbpedia/val
```

On the DIGITS homepage, click `New Dataset > Images > Other` then:
- in the "Images LMDB" column, select the paths to your train and validation databases, respectively (note that the labels are encoded in these databases so you don't need to specify an alternative database for labels)
- give your dataset a name then click 'Create'.

![Generic Dataset Creation](dbpedia-generic-dataset-creation.png)

What difference does it make to use this alternative method?
The main difference is that you do not need to create image files to create the dataset in DIGITS as you can pass LMDB files directly.
This can save a significant amount of time.
This also implies that you get more freedom in the data formats that you wish to use, as long as you stick to 3D blobs of data.
You may for example choose to work with 16-bit or 32-bit data, or you may choose to work with blobs that have a non-standard or unsupported number of channels.
There is a downside though: since DIGITS is not told that you are creating a classification model, DIGITS does not process the network outputs in any way.
For classification models, DIGITS is able to extract the predicted class by identifying which class had the highest probability in the SoftMax layer.
For generic ("other") models, DIGITS only shows the raw network output.
Besides, quality metrics like accuracy or confusion matrices are not computed automatically for those models.

In order to create the model, on the DIGITS homepage, click `New Model > Images > Other` then proceed exactly as you did when creating the image classification model.

After training you can test samples using the "Test a Database" section.
You just need to point to the location of an LMDB database, for example the validation database.

The following snapshot shows the first 5 inference outputs from the validation dataset:

![Generic Inference](dbpedia-generic-inference.png)

Each line shows the contents of the `logSoftMax` layer, for each sample.
The LMDB key format used by `create_dataset.py` is `'%d_%d' % (index, class)`: for the first item in the database, the key is `0_1`, which means that the item is from class `1`.
You can see that the output takes its maximum at index 1 (indices starting from 1), therefore the sample was correctly classified.
The predicted probability for class `1` is `math.exp(-5.72204590e-06)=0.99999427797047` (a high degree of confidence).
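
The same reading of a single output line can be done programmatically. The sketch below assumes, as described above, that the values are log-probabilities and that keys follow the `index_class` pattern; `parse_key` and `predict_class` are hypothetical helper names, not part of DIGITS.

```python
import math


def parse_key(key):
    """Split an LMDB key such as '0_1' into (sample index, class id)."""
    index, class_id = key.split("_")
    return int(index), int(class_id)


def predict_class(log_probs):
    """Return the 1-based class with the highest score and its probability."""
    best = max(range(len(log_probs)), key=lambda i: log_probs[i])
    return best + 1, math.exp(log_probs[best])


# First three values of the first inference output shown in this example
_, true_class = parse_key("0_1")
predicted, probability = predict_class([-5.7220458984375e-06, -12.106060028076, -25.820121765137])
print(predicted == true_class, probability)
```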

You may also choose to use the REST API to download predictions in JSON format.
This could be useful if you wish to implement any kind of post-processing of the data.
In the command below, replace the `job_id` and `db_path` with your job ID and LMDB path respectively:

```sh
curl localhost:5000/models/images/generic/infer_db.json -XPOST -F job_id=20160414-040451-9cc5 -F db_path="/path/to/dbpedia/test" > predictions.txt
```

Running this command will dump inference data in a format similar to:
```
{
  "outputs": {
    "0_1": {
      "output": [
        -5.7220458984375e-06,
        -12.106060028076,
        -25.820121765137,
        -29.935920715332,
        -27.315780639648,
        -17.158786773682,
        -22.92654800415,
        -15.421851158142,
        -23.10737991333,
        -29.26469039917,
        -16.862657546997,
        -28.460214614868,
        -21.428464889526,
        -17.860265731812
      ]
    },
    ...
```
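
As an example of post-processing, the following sketch recovers an accuracy figure from such a dump, since DIGITS does not compute one for generic models. It assumes the layout shown above (an `outputs` object keyed by `index_class`, each entry holding an `output` list); the function name and the two-sample document are made up for illustration and are not real inference results.

```python
import json


def accuracy_from_predictions(doc):
    """Fraction of samples whose highest score matches the class in the key."""
    outputs = doc["outputs"]
    correct = 0
    for key, entry in outputs.items():
        true_class = int(key.split("_")[1])
        scores = entry["output"]
        # Classes are numbered from 1, list positions from 0
        predicted = max(range(len(scores)), key=lambda i: scores[i]) + 1
        if predicted == true_class:
            correct += 1
    return correct / float(len(outputs))


# Illustrative values only, using three classes to keep the example short
sample_doc = json.loads("""
{"outputs": {
    "0_1": {"output": [-0.01, -5.0, -6.0]},
    "0_2": {"output": [-0.2, -1.9, -4.0]}
}}
""")
print(accuracy_from_predictions(sample_doc))
```

To run it on real data, load the file written by the `curl` command with `json.load(open("predictions.txt"))` and pass the result to the function.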






169 changes: 169 additions & 0 deletions examples/text-classification/create_dataset.py
@@ -0,0 +1,169 @@
#!/usr/bin/env python2
# Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved.
"""
Functions for creating a text classification dataset out of .csv files
The expected CSV structure is:
<Class>,<Text Field 1>, ..., <Text Field N>
"""

import argparse
import caffe
import csv
from collections import defaultdict
import h5py
import lmdb
import numpy as np
import os
import PIL.Image
import random
import re
import shutil
import sys
import time

# Find the best implementation available
try:
    from cStringIO import StringIO
except ImportError:
    from StringIO import StringIO

DB_BATCH_SIZE = 1024

ALPHABET = "abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:'\"/\\|_@#$%^&*~`+ =<>()[]{}"
FEATURE_LEN = 1024 # must have integer square root

def _save_image(image, filename):
    # convert from (channels, height, width) to (height, width)
    image = image[0]
    image = PIL.Image.fromarray(image)
    image.save(filename)

def create_dataset(folder, input_file_name, db_batch_size=None, create_images=False, labels_file=None):
    """
    Creates LMDB database and images (if create_images==True)
    """

    if db_batch_size is None:
        db_batch_size = DB_BATCH_SIZE

    # open output LMDB
    output_db = lmdb.open(folder, map_async=True, max_dbs=0)

    print "Reading input file %s..." % input_file_name
    # create character dict
    cdict = {}
    for i,c in enumerate(ALPHABET):
        cdict[c] = i + 2 # indices start at 1, skip first index for 'other' characters
    samples = {}
    with open(input_file_name) as f:
        reader = csv.DictReader(f,fieldnames=['class'],restkey='fields')
        for row in reader:
            label = row['class']
            if label not in samples:
                samples[label] = []
            sample = np.ones(FEATURE_LEN) # one by default (i.e. 'other' character)
            count = 0
            for field in row['fields']:
                for char in field.lower():
                    if char in cdict:
                        sample[count] = cdict[char]
                        count += 1
                    if count >= FEATURE_LEN-1:
                        break
                # also stop reading further fields once the vector is full
                if count >= FEATURE_LEN-1:
                    break
            samples[label].append(sample)
    samples_per_class = None
    # NOTE: class ids below follow the iteration order of these dict keys
    classes = samples.keys()
    class_samples = []
    for c in classes:
        if samples_per_class is None:
            samples_per_class = len(samples[c])
        else:
            # all classes must have the same number of samples
            assert samples_per_class == len(samples[c])
        class_samples.append(samples[c])

    indices = np.arange(samples_per_class)
    np.random.shuffle(indices)

    labels = None
    if labels_file is not None:
        labels = map(str.strip,open(labels_file, "r").readlines())
        # there must be one label per class found in the input file
        assert len(classes) == len(labels)
    else:
        labels = classes
    print "Class labels: %s" % repr(labels)

    if create_images:
        for label in labels:
            os.makedirs(os.path.join(folder, label))

    print "Storing data into %s..." % folder

    # reshape() requires integer dimensions
    side = int(np.sqrt(FEATURE_LEN))

    batch = []
    for idx in indices:
        for c,cname in enumerate(classes):
            class_id = c + 1 # indices start at 1
            sample = class_samples[c][idx].astype('uint8')
            sample = sample[np.newaxis, np.newaxis, ...]
            sample = sample.reshape((1,side,side))
            if create_images:
                filename = os.path.join(folder, labels[c], '%d.png' % idx)
                _save_image(sample, filename)
            datum = caffe.io.array_to_datum(sample, class_id)
            batch.append(('%d_%d' % (idx,class_id), datum))
        if len(batch) >= db_batch_size:
            _write_batch_to_lmdb(output_db, batch)
            batch = []

    # write the last, partially filled batch
    if len(batch) > 0:
        _write_batch_to_lmdb(output_db, batch)

    # close database
    output_db.close()

    return

def _write_batch_to_lmdb(db, batch):
    """
    Write a batch of (key,value) to db
    """
    try:
        with db.begin(write=True) as lmdb_txn:
            for key, datum in batch:
                lmdb_txn.put(key, datum.SerializeToString())
    except lmdb.MapFullError:
        # double the map_size
        curr_limit = db.info()['map_size']
        new_limit = curr_limit*2
        print('Doubling LMDB map size to %sMB ...' % (new_limit>>20,))
        try:
            db.set_mapsize(new_limit) # double it
        except AttributeError as e:
            version = tuple(int(x) for x in lmdb.__version__.split('.'))
            if version < (0,87):
                # set_mapsize() is missing from older py-lmdb releases
                raise RuntimeError('py-lmdb is out of date (%s vs 0.87)' % lmdb.__version__)
            else:
                raise e
        # try again
        _write_batch_to_lmdb(db, batch)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Create Dataset tool')

    ### Positional arguments

    parser.add_argument('input', help='Input .csv file')
    parser.add_argument('output', help='Output Folder')
    parser.add_argument('--create-images', action='store_true')
    parser.add_argument('--labels', default=None)

    args = vars(parser.parse_args())

    if os.path.exists(args['output']):
        shutil.rmtree(args['output'])

    os.makedirs(args['output'])

    start_time = time.time()

    create_dataset(args['output'], args['input'], create_images = args['create_images'], labels_file = args['labels'])

    print 'Done after %s seconds' % (time.time() - start_time,)

Binary file added examples/text-classification/dbpedia-loss.png