forked from tensorflow/tensorflow
Commit
Ref tensorflow#40: Simple example for text classification saving/restoring
1 parent 0631538, commit 0655342
Showing 1 changed file with 101 additions and 0 deletions.
@@ -0,0 +1,101 @@
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import csv
import numpy as np
from sklearn import metrics

import tensorflow as tf
from tensorflow.models.rnn import rnn, rnn_cell
import skflow

### Training data

# Download dbpedia_csv.tar.gz from
# https://drive.google.com/folderview?id=0Bz8a_Dbh9Qhbfll6bVpmNUtUcFdjYmF2SEpmZUZUcVNiMUw1TWN6RDV3a0JHT3kxLVhVR2M
# Unpack: tar -xvf dbpedia_csv.tar.gz

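# Expected row format of the DBpedia CSV files (an assumption based on how
# load_dataset indexes each row below): three columns,
#   class_index,"title","abstract"
# Column 0 (the class index) becomes the target and column 2 (the abstract
# text) becomes the input; the title in column 1 is ignored.
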
def load_dataset(filename):
    target = []
    data = []
    reader = csv.reader(open(filename), delimiter=',')
    for line in reader:
        target.append(int(line[0]))
        data.append(line[2])
    return data, np.array(target, np.float32)

X_train, y_train = load_dataset('dbpedia_csv/train.csv')
X_test, y_test = load_dataset('dbpedia_csv/test.csv')

### Process vocabulary

MAX_DOCUMENT_LENGTH = 10

vocab_processor = skflow.preprocessing.VocabularyProcessor(MAX_DOCUMENT_LENGTH)
X_train = np.array(list(vocab_processor.fit_transform(X_train)))
X_test = np.array(list(vocab_processor.transform(X_test)))

n_words = len(vocab_processor.vocabulary_)
print('Total words: %d' % n_words)
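
# Illustration (hypothetical values): each document is now a fixed-length row
# of MAX_DOCUMENT_LENGTH word ids, e.g. 'the quick brown fox' might become
# [1, 2, 3, 4, 0, 0, 0, 0, 0, 0], padded with zeros past the document's end
# and truncated if the document has more than MAX_DOCUMENT_LENGTH words.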

### Models

EMBEDDING_SIZE = 50

def average_model(X, y):
    """Baseline model: pool word embeddings over the document and classify."""
    word_vectors = skflow.ops.categorical_variable(X, n_classes=n_words,
        embedding_size=EMBEDDING_SIZE, name='words')
    # Max-pool over the word (time) dimension to get one vector per document.
    features = tf.reduce_max(word_vectors, reduction_indices=1)
    return skflow.models.logistic_regression(features, y)

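# Note: average_model is not used below; it is provided as a simpler
# alternative to rnn_model. To try it, one could construct the estimator
# further down with model_fn=average_model instead of model_fn=rnn_model.
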
def rnn_model(X, y):
    """Recurrent neural network model to predict from sequence of words
    to a class."""
    # Convert indexes of words into embeddings.
    # This creates an embeddings matrix of [n_words, EMBEDDING_SIZE] and then
    # maps word indexes of the sequence into [batch_size, sequence_length,
    # EMBEDDING_SIZE].
    word_vectors = skflow.ops.categorical_variable(X, n_classes=n_words,
        embedding_size=EMBEDDING_SIZE, name='words')
    # Split into a list of embeddings per word, removing the doc length dim.
    # word_list becomes a list of MAX_DOCUMENT_LENGTH tensors, each of shape
    # [batch_size, EMBEDDING_SIZE].
    word_list = skflow.ops.split_squeeze(1, MAX_DOCUMENT_LENGTH, word_vectors)
    # Create a Gated Recurrent Unit cell with hidden size of EMBEDDING_SIZE.
    cell = rnn_cell.GRUCell(EMBEDDING_SIZE)
    # Create an unrolled Recurrent Neural Network of MAX_DOCUMENT_LENGTH steps
    # and pass word_list as the input for each step.
    _, encoding = rnn.rnn(cell, word_list, dtype=tf.float32)
    # Take the encoding of the last step (i.e. the hidden state after the
    # final word) and use it as features for logistic regression over the
    # output classes.
    return skflow.models.logistic_regression(encoding[-1], y)

model_path = '/tmp/skflow_examples/text_classification'
if os.path.exists(model_path):
    classifier = skflow.TensorFlowEstimator.restore(model_path)
    score = metrics.accuracy_score(classifier.predict(X_test), y_test)
    print('Accuracy: {0:f}'.format(score))
else:
    classifier = skflow.TensorFlowEstimator(model_fn=rnn_model, n_classes=15,
        steps=100, optimizer='Adam', learning_rate=0.01, continue_training=True)

    # Continuously train (100 steps per fit() call); interrupt with Ctrl-C to
    # save the model to model_path and stop.
    while True:
        try:
            classifier.fit(X_train, y_train)
        except KeyboardInterrupt:
            classifier.save(model_path)
            break
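
For reference, a minimal sketch (not part of this commit) of reusing the saved model. It assumes the script above has already been run once and interrupted so that a checkpoint exists under model_path, and that vocab_processor has been fitted on the same training data; the document string is purely illustrative.

classifier = skflow.TensorFlowEstimator.restore(model_path)
new_docs = np.array(list(vocab_processor.transform(
    ['A short example document about a company.'])))
predictions = classifier.predict(new_docs)  # one predicted class index per document
print(predictions)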