Train Amazon Review Ranking Dataset using DistilBert

In this example, you learn how to train a model on the Amazon Review dataset. This dataset includes 30k reviews from Amazon customers on different products. Only two fields are used: review_body as the input data and star_rating as the label.
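To make the data/label split concrete, here is a minimal sketch of picking those two fields out of one tab-separated row. It assumes the reviews come as tab-separated text with a header line. The class name, the shortened header, and the sample row are made up for illustration and are not taken from TrainAmazonReviewRanking.java.

```java
import java.util.Arrays;
import java.util.List;

final class ReviewColumnsSketch {

    /** One training example: the review text (data) and its star rating (label). */
    static final class Example {
        final String text;
        final int label;

        Example(String text, int label) {
            this.text = text;
            this.label = label;
        }
    }

    /** Looks up review_body and star_rating by header name and ignores every other column. */
    static Example select(String headerLine, String rowLine) {
        List<String> header = Arrays.asList(headerLine.split("\t", -1));
        String[] row = rowLine.split("\t", -1);
        int bodyIdx = header.indexOf("review_body");
        int ratingIdx = header.indexOf("star_rating");
        if (bodyIdx < 0 || ratingIdx < 0) {
            throw new IllegalArgumentException("header must contain review_body and star_rating");
        }
        return new Example(row[bodyIdx], Integer.parseInt(row[ratingIdx].trim()));
    }

    public static void main(String[] args) {
        // Hypothetical, shortened header and row; a real file may carry more columns.
        String header = "product_id\tstar_rating\treview_headline\treview_body";
        String row = "B000EXAMPLE\t4\tNice\tWorks well for the price.";
        Example ex = select(header, row);
        System.out.println(ex.label + " <- " + ex.text);
    }
}
```

Looking columns up by name rather than by fixed position keeps the sketch independent of the exact column order in the file.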

You can find the example source code in: TrainAmazonReviewRanking.java.

Setup guide

Follow setup to configure your development environment.

Train the model

This example uses a pretrained DistilBert model from GluonNLP, followed by a simple MLP layer. The input is the BERT-formatted token sequence and the output is the star rating. We recommend training on a GPU, since CPU training is slow with this dataset.

cd examples
./gradlew run -Dmain=ai.djl.examples.training.transferlearning.TrainAmazonReviewRanking --args="-e 2 -b 8 -g 1"

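In this command, -e sets the number of epochs, -b the batch size, and -g the maximum number of GPUs to use.

You can increase the maxTokenLength variable (currently 64) so that less of each review is cut off. This may improve accuracy, at the cost of more memory and longer training time.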
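To illustrate how a token limit shapes the input, the following sketch formats a tokenized review in the usual BERT style: a leading [CLS], a trailing [SEP], truncation to the limit, and [PAD] filling. It is an approximation under stated assumptions, not the code in TrainAmazonReviewRanking.java. The class and method names are hypothetical, and whether the example reserves two positions for the special tokens or pads in exactly this way is an assumption.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

final class BertInputSketch {

    /**
     * Wraps tokens as [CLS] ... [SEP], truncating the review so the result
     * fits in maxTokenLength positions, then pads with [PAD] to that length.
     */
    static List<String> format(List<String> tokens, int maxTokenLength) {
        if (maxTokenLength < 2) {
            throw new IllegalArgumentException("need room for [CLS] and [SEP]");
        }
        // Assumption: two positions are reserved for the special tokens.
        int keep = Math.min(tokens.size(), maxTokenLength - 2);
        List<String> out = new ArrayList<>(maxTokenLength);
        out.add("[CLS]");
        out.addAll(tokens.subList(0, keep));
        out.add("[SEP]");
        while (out.size() < maxTokenLength) {
            out.add("[PAD]");
        }
        return out;
    }

    public static void main(String[] args) {
        List<String> review = Arrays.asList("works", "well", "for", "the", "price");
        // With a small limit, the end of the review is dropped.
        System.out.println(format(review, 5));
        // With a larger limit, the whole review fits and the rest is padding.
        System.out.println(format(review, 9));
    }
}
```

With the limit of 5, only the first three tokens of the review survive. With 9, all five are kept and two padding tokens are added. This is the trade-off behind raising maxTokenLength.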