3535import org .nd4j .linalg .api .ndarray .INDArray ;
3636import org .nd4j .linalg .factory .Nd4j ;
3737import org .nd4j .linalg .indexing .NDArrayIndex ;
38+ import org .nd4j .linalg .io .ClassPathResource ;
3839import org .nd4j .linalg .learning .config .Adam ;
3940import org .nd4j .linalg .lossfunctions .LossFunctions ;
41+ import org .nd4j .resources .Downloader ;
4042
4143import java .io .File ;
44+ import java .io .IOException ;
4245import java .net .URL ;
4346import java .nio .charset .Charset ;
47+ import java .util .Scanner ;
4448
45- /**Example: Given a movie review (raw text), classify that movie review as either positive or negative based on the words it contains.
49+ /**
50+ * Example: Given a movie review (raw text), classify that movie review as either positive or negative based on the words it contains.
4651 * This is done by combining Word2Vec vectors and a recurrent neural network model. Each word in a review is vectorized
4752 * (using the Word2Vec model) and fed into a recurrent neural network.
4853 * Training data is the "Large Movie Review Dataset" from http://ai.stanford.edu/~amaas/data/sentiment/
4954 * This data set contains 25,000 training reviews + 25,000 testing reviews
50- *
55+ * <p>
5156 * Process:
57+ * 0. If path to the wordvectors is not set and a download not found previously in the default location you will be prompted if you want to download it.
5258 * 1. Automatic on first run of example: Download data (movie reviews) + extract
53- * 2. Load existing Word2Vec model (for example: Google News word vectors. You will have to download this MANUALLY )
59+ * 2. Load existing Word2Vec model (for example: Google News word vectors.)
5460 * 3. Load each each review. Convert words to vectors + reviews to sequences of vectors
5561 * 4. Train network
56- *
62+ * <p>
5763 * With the current configuration, gives approx. 83% accuracy after 1 epoch. Better performance may be possible with
5864 * additional tuning.
5965 *
60- * NOTE / INSTRUCTIONS:
61- * You will have to download the Google News word vector model manually. ~1.5GB
62- * The Google News vector model available here: https://code.google.com/p/word2vec/
63- * Download the GoogleNews-vectors-negative300.bin.gz file
64- * Then: set the WORD_VECTORS_PATH field to point to this location.
65- *
6666 * @author Alex Black
6767 */
6868public class Word2VecSentimentRNN {
6969
70- /** Data URL for downloading */
70+ /**
71+ * Data URL for downloading
72+ */
7173 public static final String DATA_URL = "http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz" ;
72- /** Location to save and extract the training/testing data */
74+ /**
75+ * Location to save and extract the training/testing data
76+ */
7377 public static final String DATA_PATH = FilenameUtils .concat (System .getProperty ("java.io.tmpdir" ), "dl4j_w2vSentiment/" );
74- /** Location (local file system) for the Google News vectors. Set this manually. */
75- public static final String WORD_VECTORS_PATH = "/PATH/TO/YOUR/VECTORS/GoogleNews-vectors-negative300.bin.gz" ;
76-
78+ /**
79+ * Location (local file system) for the Google News vectors. Set this manually.
80+ */
81+ public static String wordVectorsPath = "/PATH/TO/YOUR/VECTORS/GoogleNews-vectors-negative300.bin.gz" ;
7782
7883 public static void main (String [] args ) throws Exception {
79- if (WORD_VECTORS_PATH .startsWith ("/PATH/TO/YOUR/VECTORS/" )){
80- throw new RuntimeException ("Please set the WORD_VECTORS_PATH before running this example" );
84+ if (wordVectorsPath .startsWith ("/PATH/TO/YOUR/VECTORS/" )) {
85+ System .out .println ("wordVectorsPath has not been set. Checking default location in ~/dl4j-examples-data for download..." );
86+ checkDownloadW2VECModel ();
8187 }
82-
8388 //Download and extract data
8489 downloadData ();
8590
@@ -93,23 +98,23 @@ public static void main(String[] args) throws Exception {
9398
9499 //Set up network configuration
95100 MultiLayerConfiguration conf = new NeuralNetConfiguration .Builder ()
96- .seed (seed )
97- .updater (new Adam (5e-3 ))
98- .l2 (1e-5 )
99- .weightInit (WeightInit .XAVIER )
100- .gradientNormalization (GradientNormalization .ClipElementWiseAbsoluteValue ).gradientNormalizationThreshold (1.0 )
101- .list ()
102- .layer (new LSTM .Builder ().nIn (vectorSize ).nOut (256 )
103- .activation (Activation .TANH ).build ())
104- .layer (new RnnOutputLayer .Builder ().activation (Activation .SOFTMAX )
105- .lossFunction (LossFunctions .LossFunction .MCXENT ).nIn (256 ).nOut (2 ).build ())
106- .build ();
101+ .seed (seed )
102+ .updater (new Adam (5e-3 ))
103+ .l2 (1e-5 )
104+ .weightInit (WeightInit .XAVIER )
105+ .gradientNormalization (GradientNormalization .ClipElementWiseAbsoluteValue ).gradientNormalizationThreshold (1.0 )
106+ .list ()
107+ .layer (new LSTM .Builder ().nIn (vectorSize ).nOut (256 )
108+ .activation (Activation .TANH ).build ())
109+ .layer (new RnnOutputLayer .Builder ().activation (Activation .SOFTMAX )
110+ .lossFunction (LossFunctions .LossFunction .MCXENT ).nIn (256 ).nOut (2 ).build ())
111+ .build ();
107112
108113 MultiLayerNetwork net = new MultiLayerNetwork (conf );
109114 net .init ();
110115
111116 //DataSetIterators for training and testing respectively
112- WordVectors wordVectors = WordVectorSerializer .loadStaticModel (new File (WORD_VECTORS_PATH ));
117+ WordVectors wordVectors = WordVectorSerializer .loadStaticModel (new File (wordVectorsPath ));
113118 SentimentExampleIterator train = new SentimentExampleIterator (DATA_PATH , wordVectors , batchSize , truncateReviewsToLength , true );
114119 SentimentExampleIterator test = new SentimentExampleIterator (DATA_PATH , wordVectors , batchSize , truncateReviewsToLength , false );
115120
@@ -119,7 +124,7 @@ public static void main(String[] args) throws Exception {
119124
120125 //After training: load a single example and generate predictions
121126 File shortNegativeReviewFile = new File (FilenameUtils .concat (DATA_PATH , "aclImdb/test/neg/12100_1.txt" ));
122- String shortNegativeReview = FileUtils .readFileToString (shortNegativeReviewFile , (Charset )null );
127+ String shortNegativeReview = FileUtils .readFileToString (shortNegativeReviewFile , (Charset ) null );
123128
124129 INDArray features = test .loadFeaturesFromString (shortNegativeReview , truncateReviewsToLength );
125130 INDArray networkOutput = net .output (features );
@@ -138,15 +143,15 @@ public static void main(String[] args) throws Exception {
138143 public static void downloadData () throws Exception {
139144 //Create directory if required
140145 File directory = new File (DATA_PATH );
141- if (!directory .exists ()) directory .mkdir ();
146+ if (!directory .exists ()) directory .mkdir ();
142147
143148 //Download file:
144149 String archizePath = DATA_PATH + "aclImdb_v1.tar.gz" ;
145150 File archiveFile = new File (archizePath );
146151 String extractedPath = DATA_PATH + "aclImdb" ;
147152 File extractedFile = new File (extractedPath );
148153
149- if ( !archiveFile .exists () ) {
154+ if ( !archiveFile .exists ()) {
150155 System .out .println ("Starting data download (80MB)..." );
151156 FileUtils .copyURLToFile (new URL (DATA_URL ), archiveFile );
152157 System .out .println ("Data (.tar.gz file) downloaded to " + archiveFile .getAbsolutePath ());
@@ -155,14 +160,53 @@ public static void downloadData() throws Exception {
155160 } else {
156161 //Assume if archive (.tar.gz) exists, then data has already been extracted
157162 System .out .println ("Data (.tar.gz file) already exists at " + archiveFile .getAbsolutePath ());
158- if ( !extractedFile .exists ()){
159- //Extract tar.gz file to output directory
160- DataUtilities .extractTarGz (archizePath , DATA_PATH );
163+ if ( !extractedFile .exists ()) {
164+ //Extract tar.gz file to output directory
165+ DataUtilities .extractTarGz (archizePath , DATA_PATH );
161166 } else {
162- System .out .println ("Data (extracted) already exists at " + extractedFile .getAbsolutePath ());
167+ System .out .println ("Data (extracted) already exists at " + extractedFile .getAbsolutePath ());
163168 }
164169 }
165170 }
166171
167-
172+ public static void checkDownloadW2VECModel () throws IOException {
173+ String defaultwordVectorsPath = FilenameUtils .concat (System .getProperty ("user.home" ), "dl4j-examples-data/w2vec300" );
174+ wordVectorsPath = new File (defaultwordVectorsPath , "GoogleNews-vectors-negative300.bin.gz" ).getAbsolutePath ();
175+ if (new File (wordVectorsPath ).exists ()) {
176+ System .out .println ("\n \t GoogleNews-vectors-negative300.bin.gz file found at path: " + defaultwordVectorsPath );
177+ System .out .println ("\t Checking md5 of existing file.." );
178+ if (Downloader .checkMD5OfFile ("1c892c4707a8a1a508b01a01735c0339" , new File (wordVectorsPath ))) {
179+ System .out .println ("\t Existing file hash matches." );
180+ return ;
181+ } else {
182+ System .out .println ("\t Existing file hash doesn't match. Retrying download..." );
183+ }
184+ } else {
185+ System .out .println ("\n \t No previous download of GoogleNews-vectors-negative300.bin.gz found at path: " + defaultwordVectorsPath );
186+ }
187+ System .out .println ("\t WARNING: GoogleNews-vectors-negative300.bin.gz is a 1.5GB file." );
188+ System .out .println ("\t Press \" ENTER\" to start a download of GoogleNews-vectors-negative300.bin.gz to " + defaultwordVectorsPath );
189+ Scanner scanner = new Scanner (System .in );
190+ scanner .nextLine ();
191+ System .out .println ("Starting model download (1.5GB!)..." );
192+ String downloadScript = new ClassPathResource ("w2vecdownload/word2vec-download300model.sh" ).getFile ().getAbsolutePath ();
193+ ProcessBuilder processBuilder = new ProcessBuilder (downloadScript , defaultwordVectorsPath );
194+ try {
195+ processBuilder .inheritIO ();
196+ Process process = processBuilder .start ();
197+ int exitVal = process .waitFor ();
198+ if (exitVal == 0 ) {
199+ System .out .println ("Successfully downloaded word2vec model!" );
200+ } else {
201+ System .out .println ("Download failed. Please download model manually and set the \" wordVectorsPath\" in the code with the path to it." );
202+ System .exit (0 );
203+ }
204+ } catch (IOException e ) {
205+ e .printStackTrace ();
206+ } catch (InterruptedException e ) {
207+ e .printStackTrace ();
208+ }
209+ }
168210}
211+
212+
0 commit comments