In this example, you learn how to train a model on the Amazon Review dataset.
This dataset includes 30k reviews from Amazon customers on different products.
We use only the review_body and star_rating fields, as the input data and the label respectively.
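The field selection described above can be sketched in plain Java. This is only an illustration, not the code in the example: it assumes each review is one tab-separated row under a header line, and the class name ReviewFieldsSketch, the product_title column, and the sample row are made up for the demonstration.

```java
import java.util.Arrays;
import java.util.List;

public class ReviewFieldsSketch {

    // Returns {review_body, star_rating} taken from one tab-separated row.
    // Assumption: the header names the columns and the row has the same layout.
    public static String[] select(String header, String row) {
        List<String> columns = Arrays.asList(header.split("\t", -1));
        String[] values = row.split("\t", -1);
        int body = columns.indexOf("review_body");
        int rating = columns.indexOf("star_rating");
        if (body < 0 || rating < 0) {
            throw new IllegalArgumentException("header lacks review_body or star_rating");
        }
        return new String[] {values[body], values[rating]};
    }

    public static void main(String[] args) {
        // Hypothetical header and row, for illustration only.
        String header = "product_title\tstar_rating\treview_body";
        String row = "USB cable\t5\tWorks as expected";
        String[] pair = select(header, row);
        System.out.println("data  = " + pair[0]); // data  = Works as expected
        System.out.println("label = " + pair[1]); // label = 5
    }
}
```

All other columns are ignored, so the model sees only the review text and learns to predict the rating.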
You can find the example source code in: TrainAmazonReviewRanking.java.
Follow setup to configure your development environment.
In this example, we use a GluonNLP pretrained DistilBert model followed by a simple MLP layer. The input is BERT-formatted tokens, and the output is the star rating. We recommend training on a GPU, since CPU training is slow with this dataset.
cd examples
./gradlew run -Dmain=ai.djl.examples.training.transferlearning.TrainAmazonReviewRanking --args="-e 2 -b 8 -g 1"
You can increase the maxTokenLength variable (currently 64) to achieve better accuracy.
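To make the effect of maxTokenLength concrete, here is a minimal sketch of BERT-style input formatting in plain Java. It is not the featurizer used in TrainAmazonReviewRanking.java, which may truncate and pad differently. The class name BertFormatSketch and the sample tokens are hypothetical, and the sketch assumes the usual [CLS], [SEP], and [PAD] markers with the token limit counting those markers.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class BertFormatSketch {

    // Wraps tokens as [CLS] ... [SEP], truncating the content so that the
    // result never exceeds maxTokenLength, then pads with [PAD] up to it.
    public static List<String> format(List<String> tokens, int maxTokenLength) {
        if (maxTokenLength < 2) {
            throw new IllegalArgumentException("need room for [CLS] and [SEP]");
        }
        int room = maxTokenLength - 2; // slots left after the two markers
        List<String> out = new ArrayList<>(maxTokenLength);
        out.add("[CLS]");
        out.addAll(tokens.subList(0, Math.min(room, tokens.size())));
        out.add("[SEP]");
        while (out.size() < maxTokenLength) {
            out.add("[PAD]");
        }
        return out;
    }

    public static void main(String[] args) {
        List<String> tokens = Arrays.asList("great", "product", "works", "well");
        // Enough room: all tokens kept, remainder padded.
        System.out.println(format(tokens, 8));
        // prints [[CLS], great, product, works, well, [SEP], [PAD], [PAD]]
        // Too little room: the tail of the review is dropped.
        System.out.println(format(tokens, 4));
        // prints [[CLS], great, product, [SEP]]
    }
}
```

With a small limit, the end of a long review is cut off before the model sees it, which is why a larger value can help accuracy, at the cost of more memory and compute per batch.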