Merge pull request #33 from matlab-deep-learning/japanese-bert-fix
Updating predictMaskedToken.m to match new tokenizer.
debymf committed Apr 19, 2023
2 parents 6715db0 + 09ad027 commit 10fb90a
Showing 3 changed files with 53 additions and 4 deletions.
6 changes: 4 additions & 2 deletions README.md
@@ -71,12 +71,14 @@ Download or [clone](https://www.mathworks.com/help/matlab/matlab_prog/use-source
## Example: Classify Text Data Using BERT
The simplest use of a pretrained BERT model is to use it as a feature extractor. In particular, you can use the BERT model to convert documents to feature vectors which you can then use as inputs to train a deep learning classification network.

- The example [`ClassifyTextDataUsingBERT.m`](./ClassifyTextDataUsingBERT.m) shows how to use a pretrained BERT model to classify failure events given a data set of factory reports.
+ The example [`ClassifyTextDataUsingBERT.m`](./ClassifyTextDataUsingBERT.m) shows how to use a pretrained BERT model to classify failure events given a data set of factory reports. This example requires the `factoryReports.csv` data set from the Text Analytics Toolbox example [Prepare Text Data for Analysis](https://www.mathworks.com/help/textanalytics/ug/prepare-text-data-for-analysis.html).
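
The feature-extraction idea can be sketched in a few lines. This is a hedged illustration, not the repository's example: the `bert` and `predictMaskedToken` entry points appear elsewhere in this diff, but the tokenizer `encode` call, the `bert.model` call, and all variable names here are assumptions about the repository's API; see `ClassifyTextDataUsingBERT.m` for the authoritative workflow.

```matlab
% Sketch: use a pretrained BERT model as a feature extractor.
% API details below (encode on mdl.Tokenizer, bert.model) are assumptions.
mdl = bert("Model","tiny");            % load a small pretrained BERT model

str = "Coolant is pooling underneath sorter.";   % one factory report
seq = encode(mdl.Tokenizer, str);      % convert text to token IDs
x = dlarray(seq{1}, "CT");             % wrap the sequence for the model
z = bert.model(x, mdl.Parameters);     % contextual embedding per token
features = z(:,1);                     % [CLS] embedding as a document feature
```

The feature vectors produced this way can then serve as inputs to a conventional classification network, which is the workflow the example file implements.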

## Example: Fine-Tune Pretrained BERT Model
To get the most out of a pretrained BERT model, you can retrain and fine-tune the BERT parameter weights for your task.

- The example [`FineTuneBERT.m`](./FineTuneBERT.m) shows how to fine-tune a pretrained BERT model to classify failure events given a data set of factory reports.
+ The example [`FineTuneBERT.m`](./FineTuneBERT.m) shows how to fine-tune a pretrained BERT model to classify failure events given a data set of factory reports. This example requires the `factoryReports.csv` data set from the Text Analytics Toolbox example [Prepare Text Data for Analysis](https://www.mathworks.com/help/textanalytics/ug/prepare-text-data-for-analysis.html).

+ The example [`FineTuneBERTJapanese.m`](./FineTuneBERTJapanese.m) shows the same workflow using a pretrained Japanese-BERT model. This example requires the `factoryReportsJP.csv` data set from the Text Analytics Toolbox example [Analyze Japanese Text Data](https://www.mathworks.com/help/textanalytics/ug/analyze-japanese-text.html), available in R2023a or later.

## Example: Analyze Sentiment with FinBERT
FinBERT is a sentiment analysis model trained on financial text data and fine-tuned for sentiment analysis.
4 changes: 2 additions & 2 deletions predictMaskedToken.m
@@ -6,7 +6,7 @@
% replaces instances of mdl.Tokenizer.MaskToken in the string text with
% the most likely token according to the BERT model mdl.

- % Copyright 2021 The MathWorks, Inc.
+ % Copyright 2021-2023 The MathWorks, Inc.
arguments
mdl {mustBeA(mdl,'struct')}
str {mustBeText}
@@ -44,7 +44,7 @@
tokens = fulltok.tokenize(pieces(i));
if ~isempty(tokens)
% "" tokenizes to empty - awkward
-            x = cat(2,x,fulltok.encode(tokens));
+            x = cat(2,x,fulltok.encode(tokens{1}));
end
if i<numel(pieces)
x = cat(2,x,maskCode);
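
The one-line fix above follows from the updated tokenizer's behavior: `tokenize` now returns a cell array with one token list per input string, so `encode` must be given the unwrapped token list rather than the wrapping cell. A hedged illustration of the change, given a full tokenizer `fulltok` as in the surrounding function (the exact return shapes are assumptions inferred from this diff):

```matlab
% Assumed new tokenizer behavior: tokenize wraps its result in a cell array,
% e.g. {["factory","report"]} for a single input string.
tokens = fulltok.tokenize("factory report");

% Old call - passed the cell array itself, which the new tokenizer rejects:
% x = fulltok.encode(tokens);

% Fixed call - unwrap the single token list before encoding to token IDs:
x = fulltok.encode(tokens{1});
```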
47 changes: 47 additions & 0 deletions test/tpredictMaskedToken.m
@@ -0,0 +1,47 @@
classdef(SharedTestFixtures={
DownloadBERTFixture, DownloadJPBERTFixture}) tpredictMaskedToken < matlab.unittest.TestCase
% tpredictMaskedToken Unit test for predictMaskedToken

% Copyright 2023 The MathWorks, Inc.

properties(TestParameter)
Models = {"tiny","japanese-base-wwm"}
ValidText = iGetValidText;
end

methods(Test)
function verifyOutputDimSizes(test, Models, ValidText)
inSize = size(ValidText);
mdl = bert("Model", Models);
outputText = predictMaskedToken(mdl,ValidText);
test.verifyEqual(size(outputText), inSize);
end

function maskTokenIsRemoved(test, Models)
text = "This has a [MASK] token.";
mdl = bert("Model", Models);
outputText = predictMaskedToken(mdl,text);
test.verifyFalse(contains(outputText, "[MASK]"));
end

function inputWithoutMASKRemainsTheSame(test, Models)
text = "This has a no mask token.";
mdl = bert("Model", Models);
outputText = predictMaskedToken(mdl,text);
test.verifyEqual(text, outputText);
end
end
end

function validText = iGetValidText
manyStrs = ["Accelerating the pace of [MASK] and science";
"The cat [MASK] soundly.";
"The [MASK] set beautifully."];
singleStr = "Artificial intelligence continues to shape the future of industries," + ...
" as innovative applications emerge in fields such as healthcare, transportation," + ...
" entertainment, and finance, driving productivity and enhancing human capabilities.";
validText = struct('StringsAsColumns',manyStrs,...
'StringsAsRows',manyStrs',...
'ManyStrings',repmat(singleStr,3),...
'SingleString',singleStr);
end
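
The behavior these tests verify can be exercised directly. A minimal usage sketch, mirroring the unit tests above (the predicted word depends on the chosen model, so no specific output is shown):

```matlab
% Minimal usage of predictMaskedToken, as exercised by the tests above.
mdl = bert("Model","tiny");              % small pretrained English BERT
str = "The cat [MASK] soundly.";
out = predictMaskedToken(mdl, str);      % [MASK] replaced by the most likely token
% out has the same size as str and contains no "[MASK]" token.
```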
