Roberta implementation #4

Merged
merged 5 commits on Feb 19, 2020
14 changes: 7 additions & 7 deletions README.md
@@ -8,13 +8,13 @@ Rust native BERT implementation. Port of Huggingface's [Transformers library](ht

The following models are currently implemented:

| |**DistilBERT**|**BERT**
:-----:|:-----:|:-----:
Masked LM|✅ |✅
Sequence classification|✅ |✅
Token classification|✅ |✅
Question answering|✅ |✅
Multiple choices| |✅
| |**DistilBERT**|**BERT**|**RoBERTa**
:-----:|:-----:|:-----:|:-----:
Masked LM|✅ |✅ |✅
Sequence classification|✅ |✅ |✅
Token classification|✅ |✅ | ✅
Question answering|✅ |✅ |✅
Multiple choices| |✅ |✅

An example for sentiment analysis classification is provided:

76 changes: 76 additions & 0 deletions examples/roberta.rs
@@ -0,0 +1,76 @@
extern crate failure;
extern crate dirs;

use std::path::PathBuf;
use tch::{Device, nn, Tensor, no_grad};
use rust_tokenizers::{TruncationStrategy, Tokenizer, Vocab, RobertaTokenizer};
use rust_bert::bert::bert::BertConfig;
use rust_bert::common::config::Config;
use rust_bert::roberta::roberta::RobertaForMaskedLM;


fn main() -> failure::Fallible<()> {
// Resources paths
let mut home: PathBuf = dirs::home_dir().unwrap();
home.push("rustbert");
home.push("roberta");
let config_path = &home.as_path().join("config.json");
let vocab_path = &home.as_path().join("vocab.txt");
let merges_path = &home.as_path().join("merges.txt");
let weights_path = &home.as_path().join("model.ot");

// Set-up masked LM model
let device = Device::Cpu;
let mut vs = nn::VarStore::new(device);
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap());
let config = BertConfig::from_file(config_path);
let bert_model = RobertaForMaskedLM::new(&vs.root(), &config);
vs.load(weights_path)?;

// Define input
let input = ["<pad> Looks like one thing is missing", "It\'s like comparing oranges to apples"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let mut tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
collect::<Vec<_>>();

// Masking the token [thing] of sentence 1 and [oranges] of sentence 2
tokenized_input[0][4] = 103;
tokenized_input[1][5] = 103;
let tokenized_input = tokenized_input.
iter().
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);

// Forward pass
let (output, _, _) = no_grad(|| {
bert_model
.forward_t(Some(input_tensor),
None,
None,
None,
None,
&None,
&None,
false)
});

// Print masked tokens
let index_1 = output.get(0).get(4).argmax(0, false);
let index_2 = output.get(1).get(5).argmax(0, false);
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));

println!("{}", word_1); // Outputs "some" : "Looks like [some] thing is missing"
println!("{}", word_2);// Outputs "apple" : "It\'s like comparing [apple] to apples"

Ok(())
}
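Assuming the standard Cargo example layout, this example should be runnable with `cargo run --example roberta` once `config.json`, `vocab.txt`, `merges.txt` and the converted `model.ot` weights sit under `~/rustbert/roberta/`, matching the paths built at the top of `main`.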
33 changes: 16 additions & 17 deletions src/bert/bert.rs
@@ -13,7 +13,7 @@

use serde::{Deserialize, Serialize};
use crate::common::config::Config;
use crate::bert::embeddings::BertEmbeddings;
use crate::bert::embeddings::{BertEmbeddings, BertEmbedding};
use crate::bert::encoder::{BertEncoder, BertPooler};
use tch::{nn, Tensor, Kind};
use tch::kind::Kind::Float;
@@ -54,21 +54,20 @@ pub struct BertConfig {

impl Config<BertConfig> for BertConfig {}

pub struct BertModel {
embeddings: BertEmbeddings,
pub struct BertModel<T: BertEmbedding> {
embeddings: T,
encoder: BertEncoder,
pooler: BertPooler,
is_decoder: bool,
}

impl BertModel {
pub fn new(p: &nn::Path, config: &BertConfig) -> BertModel {
let p = &(p / "bert");
impl <T: BertEmbedding> BertModel<T> {
pub fn new(p: &nn::Path, config: &BertConfig) -> BertModel<T> {
let is_decoder = match config.is_decoder {
Some(value) => value,
None => false
};
let embeddings = BertEmbeddings::new(&(p / "embeddings"), config);
let embeddings = T::new(&(p / "embeddings"), config);
let encoder = BertEncoder::new(&(p / "encoder"), config);
let pooler = BertPooler::new(&(p / "pooler"), config);

@@ -198,13 +197,13 @@ impl BertLMPredictionHead {
}

pub struct BertForMaskedLM {
bert: BertModel,
bert: BertModel<BertEmbeddings>,
cls: BertLMPredictionHead,
}

impl BertForMaskedLM {
pub fn new(p: &nn::Path, config: &BertConfig) -> BertForMaskedLM {
let bert = BertModel::new(&p, config);
let bert = BertModel::new(&(p / "bert"), config);
let cls = BertLMPredictionHead::new(&(p / "cls"), config);

BertForMaskedLM { bert, cls }
@@ -228,14 +227,14 @@ impl BertForMaskedLM {
}

pub struct BertForSequenceClassification {
bert: BertModel,
bert: BertModel<BertEmbeddings>,
dropout: Dropout,
classifier: nn::Linear,
}

impl BertForSequenceClassification {
pub fn new(p: &nn::Path, config: &BertConfig) -> BertForSequenceClassification {
let bert = BertModel::new(&p, config);
let bert = BertModel::new(&(p / "bert"), config);
let dropout = Dropout::new(config.hidden_dropout_prob);
let num_labels = config.num_labels.expect("num_labels not provided in configuration");
let classifier = nn::linear(p / "classifier", config.hidden_size, num_labels, Default::default());
@@ -259,14 +258,14 @@ impl BertForSequenceClassification {
}

pub struct BertForMultipleChoice {
bert: BertModel,
bert: BertModel<BertEmbeddings>,
dropout: Dropout,
classifier: nn::Linear,
}

impl BertForMultipleChoice {
pub fn new(p: &nn::Path, config: &BertConfig) -> BertForMultipleChoice {
let bert = BertModel::new(&p, config);
let bert = BertModel::new(&(p / "bert"), config);
let dropout = Dropout::new(config.hidden_dropout_prob);
let classifier = nn::linear(p / "classifier", config.hidden_size, 1, Default::default());

@@ -306,14 +305,14 @@ impl BertForMultipleChoice {
}

pub struct BertForTokenClassification {
bert: BertModel,
bert: BertModel<BertEmbeddings>,
dropout: Dropout,
classifier: nn::Linear,
}

impl BertForTokenClassification {
pub fn new(p: &nn::Path, config: &BertConfig) -> BertForTokenClassification {
let bert = BertModel::new(&p, config);
let bert = BertModel::new(&(p / "bert"), config);
let dropout = Dropout::new(config.hidden_dropout_prob);
let num_labels = config.num_labels.expect("num_labels not provided in configuration");
let classifier = nn::linear(p / "classifier", config.hidden_size, num_labels, Default::default());
@@ -337,13 +336,13 @@ impl BertForTokenClassification {
}

pub struct BertForQuestionAnswering {
bert: BertModel,
bert: BertModel<BertEmbeddings>,
qa_outputs: nn::Linear,
}

impl BertForQuestionAnswering {
pub fn new(p: &nn::Path, config: &BertConfig) -> BertForQuestionAnswering {
let bert = BertModel::new(&p, config);
let bert = BertModel::new(&(p / "bert"), config);
let num_labels = config.num_labels.expect("num_labels not provided in configuration");
let qa_outputs = nn::linear(p / "qa_outputs", config.hidden_size, num_labels, Default::default());

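Note that `BertModel::new` no longer appends `/ "bert"` to the variable-store path itself; each task-specific head now passes `&(p / "bert")` explicitly, which keeps existing BERT checkpoint layouts intact and presumably lets the RoBERTa heads choose a different root name for their weights.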
27 changes: 19 additions & 8 deletions src/bert/embeddings.rs
@@ -25,8 +25,19 @@ pub struct BertEmbeddings {
dropout: Dropout,
}

impl BertEmbeddings {
pub fn new(p: &nn::Path, config: &BertConfig) -> BertEmbeddings {
pub trait BertEmbedding {
fn new(p: &nn::Path, config: &BertConfig) -> Self;

fn forward_t(&self,
input_ids: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool) -> Result<Tensor, &'static str>;
}

impl BertEmbedding for BertEmbeddings {
fn new(p: &nn::Path, config: &BertConfig) -> BertEmbeddings {
let embedding_config = EmbeddingConfig { padding_idx: 0, ..Default::default() };

let word_embeddings: nn::Embedding = embedding(p / "word_embeddings",
@@ -50,12 +61,12 @@ impl BertEmbeddings {
BertEmbeddings { word_embeddings, position_embeddings, token_type_embeddings, layer_norm, dropout }
}

pub fn forward_t(&self,
input_ids: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool) -> Result<Tensor, &'static str> {
fn forward_t(&self,
input_ids: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool) -> Result<Tensor, &'static str> {
let (input_embeddings, input_shape) = match input_ids {
Some(input_value) => match input_embeds {
Some(_) => { return Err("Only one of input ids or input embeddings may be set"); }
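The net effect of the changes in `src/bert/bert.rs` and `src/bert/embeddings.rs` is that the encoder stack is reused as-is and only the embedding layer is swapped. A minimal sketch of how the two embedding types might plug into the same generic model (module paths and `pub` visibility are assumptions here, not taken from this diff):

```rust
// Hypothetical sketch: one generic encoder, two embedding implementations.
// Module paths and visibility are assumed, not shown in this PR.
use tch::{nn, Device};
use rust_bert::bert::bert::{BertConfig, BertModel};
use rust_bert::bert::embeddings::BertEmbeddings;
use rust_bert::roberta::embeddings::RobertaEmbeddings;

fn build_models(config: &BertConfig) {
    let vs = nn::VarStore::new(Device::Cpu);
    // Plain BERT: the original embeddings (padding_idx = 0).
    let _bert: BertModel<BertEmbeddings> = BertModel::new(&(vs.root() / "bert"), config);
    // RoBERTa: same encoder, but position ids are derived from the padding index (1).
    let _roberta: BertModel<RobertaEmbeddings> = BertModel::new(&(vs.root() / "roberta"), config);
}
```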
1 change: 1 addition & 0 deletions src/lib.rs
@@ -1,5 +1,6 @@
pub mod distilbert;
pub mod bert;
pub mod roberta;
pub mod common;

pub use distilbert::distilbert::{DistilBertConfig, DistilBertModel, DistilBertModelClassifier, DistilBertModelMaskedLM, DistilBertForTokenClassification, DistilBertForQuestionAnswering};
106 changes: 106 additions & 0 deletions src/roberta/embeddings.rs
@@ -0,0 +1,106 @@
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use tch::{nn, Tensor, Kind};
use crate::common::dropout::Dropout;
use crate::bert::embeddings::BertEmbedding;
use crate::BertConfig;
use tch::nn::{EmbeddingConfig, embedding};

#[derive(Debug)]
pub struct RobertaEmbeddings {
word_embeddings: nn::Embedding,
position_embeddings: nn::Embedding,
token_type_embeddings: nn::Embedding,
layer_norm: nn::LayerNorm,
dropout: Dropout,
padding_index: i64,
}

impl RobertaEmbeddings {
fn create_position_ids_from_input_ids(&self, x: &Tensor) -> Tensor {
let mask: Tensor = x.ne(self.padding_index).to_kind(Kind::Int64);
mask.cumsum(1, Kind::Int64) * mask + self.padding_index
}

fn create_position_ids_from_embeddings(&self, x: &Tensor) -> Tensor {
let input_shape = x.size();
let input_shape = vec!(input_shape[0], input_shape[1]);
// Generate one position id per sequence position, starting after the padding index
let position_ids: Tensor = Tensor::arange1(self.padding_index + 1, input_shape[1] + self.padding_index + 1, (Kind::Int64, x.device()));
position_ids.unsqueeze(0).expand(&input_shape, true)
}
}

impl BertEmbedding for RobertaEmbeddings {
fn new(p: &nn::Path, config: &BertConfig) -> RobertaEmbeddings {
let embedding_config = EmbeddingConfig { padding_idx: 1, ..Default::default() };

let word_embeddings: nn::Embedding = embedding(p / "word_embeddings",
config.vocab_size,
config.hidden_size,
embedding_config);

let position_embeddings: nn::Embedding = embedding(p / "position_embeddings",
config.max_position_embeddings,
config.hidden_size,
Default::default());

let token_type_embeddings: nn::Embedding = embedding(p / "token_type_embeddings",
config.type_vocab_size,
config.hidden_size,
Default::default());

let layer_norm_config = nn::LayerNormConfig { eps: 1e-12, ..Default::default() };
let layer_norm: nn::LayerNorm = nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
let dropout: Dropout = Dropout::new(config.hidden_dropout_prob);
RobertaEmbeddings { word_embeddings, position_embeddings, token_type_embeddings, layer_norm, dropout, padding_index: 1 }
}

fn forward_t(&self,
input_ids: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool) -> Result<Tensor, &'static str> {
let (input_embeddings, input_shape) = match &input_ids {
Some(input_value) => match &input_embeds {
Some(_) => { return Err("Only one of input ids or input embeddings may be set"); }
None => (input_value.apply_t(&self.word_embeddings, train), input_value.size())
}
None => match &input_embeds {
Some(embeds) => (embeds.copy(), vec!(embeds.size()[0], embeds.size()[1])),
None => { return Err("Only one of input ids or input embeddings may be set"); }
}
};


let position_ids = match position_ids {
Some(value) => value,
None => match input_ids {
Some(value) => self.create_position_ids_from_input_ids(&value),
None => self.create_position_ids_from_embeddings(&input_embeds.unwrap())
}
};

let token_type_ids = match token_type_ids {
Some(value) => value,
None => Tensor::zeros(&input_shape, (Kind::Int64, input_embeddings.device()))
};

let position_embeddings = position_ids.apply(&self.position_embeddings);
let token_type_embeddings = token_type_ids.apply(&self.token_type_embeddings);

let input_embeddings: Tensor = input_embeddings + position_embeddings + token_type_embeddings;
Ok(input_embeddings.apply(&self.layer_norm).apply_t(&self.dropout, train))
}
}
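To make the position-id scheme concrete, here is a small standalone sketch (not part of the PR; the token ids are made up) of what `create_position_ids_from_input_ids` computes: non-padding tokens are numbered upwards starting at `padding_index + 1`, while padding positions keep the padding index itself.

```rust
// Hypothetical illustration of the RoBERTa position-id scheme used above.
use tch::{Kind, Tensor};

fn main() {
    let padding_index = 1i64;
    // Four real tokens followed by two padding tokens (id 1).
    let input_ids = Tensor::of_slice(&[0i64, 31414, 232, 2, 1, 1]).unsqueeze(0);
    let mask = input_ids.ne(padding_index).to_kind(Kind::Int64);
    let position_ids = mask.cumsum(1, Kind::Int64) * mask + padding_index;
    position_ids.print(); // [[2, 3, 4, 5, 1, 1]]
}
```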
2 changes: 2 additions & 0 deletions src/roberta/mod.rs
@@ -0,0 +1,2 @@
pub mod embeddings;
pub mod roberta;