Skip to content

Commit

Permalink
Project structure refactoring (#387)
Browse files Browse the repository at this point in the history
* Models project refactoring

* Updated changelog
  • Loading branch information
guillaume-be committed Jun 2, 2023
1 parent 540c926 commit 7c10c25
Show file tree
Hide file tree
Showing 117 changed files with 40 additions and 25 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ All notable changes to this project will be documented in this file. The format
The `model_resources` field now needs to be wrapped in the corresponding enum variant, e.g. `model_resources: ModelResources::TORCH(model_resource)` for Torch-based models
- (BREAKING) Added the `forced_bos_token_id` and `forced_eos_token_id` fields to text generation models.
If these are not None, this will trigger a forced BOS/EOS token generation at the first of `max_length` positions (aligns with the Pytorch Transformers library)
- Project structure refactoring (torch-based models moved under common module). Non-breaking change via re-exports.

## Fixed
- MIN/MAX computation for float-like (was set to infinity instead of min/max)
Expand Down
31 changes: 6 additions & 25 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -727,34 +727,15 @@

extern crate core;

pub mod albert;
pub mod bart;
pub mod bert;
mod common;
pub mod deberta;
pub mod deberta_v2;
pub mod distilbert;
pub mod electra;
pub mod fnet;
pub mod gpt2;
pub mod gpt_j;
pub mod gpt_neo;
pub mod longformer;
pub mod longt5;
pub mod m2m_100;
pub mod marian;
pub mod mbart;
pub mod mobilebert;
pub mod nllb;
pub mod openai_gpt;
pub mod pegasus;
pub mod models;
pub mod pipelines;
pub mod prophetnet;
pub mod reformer;
pub mod roberta;
pub mod t5;
pub mod xlnet;

pub use common::error::RustBertError;
pub use common::resources;
pub use common::{Activation, Config};
pub use models::{
albert, bart, bert, deberta, deberta_v2, distilbert, electra, fnet, gpt2, gpt_j, gpt_neo,
longformer, longt5, m2m_100, marian, mbart, mobilebert, nllb, openai_gpt, pegasus, prophetnet,
reformer, roberta, t5, xlnet,
};
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
27 changes: 27 additions & 0 deletions src/models/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
//! # Torch implementation of language models

pub mod albert;
pub mod bart;
pub mod bert;
pub mod deberta;
pub mod deberta_v2;
pub mod distilbert;
pub mod electra;
pub mod fnet;
pub mod gpt2;
pub mod gpt_j;
pub mod gpt_neo;
pub mod longformer;
pub mod longt5;
pub mod m2m_100;
pub mod marian;
pub mod mbart;
pub mod mobilebert;
pub mod nllb;
pub mod openai_gpt;
pub mod pegasus;
pub mod prophetnet;
pub mod reformer;
pub mod roberta;
pub mod t5;
pub mod xlnet;
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,7 @@ impl RobertaForMaskedLM {
RobertaForMaskedLM { roberta, lm_head }
}

#[allow(rustdoc::invalid_html_tags)]
/// Forward pass through the model
///
/// # Arguments
Expand Down Expand Up @@ -517,6 +518,7 @@ impl RobertaForSequenceClassification {
})
}

#[allow(rustdoc::invalid_html_tags)]
/// Forward pass through the model
///
/// # Arguments
Expand Down Expand Up @@ -602,6 +604,7 @@ impl RobertaForSequenceClassification {
}
}

#[allow(rustdoc::invalid_html_tags)]
/// # RoBERTa for multiple choices
/// Multiple choices model using a RoBERTa base model and a linear classifier.
/// Input should be in the form `<s> Context </s> Possible choice </s>`. The choice is made along the batch axis,
Expand Down Expand Up @@ -653,6 +656,7 @@ impl RobertaForMultipleChoice {
}
}

#[allow(rustdoc::invalid_html_tags)]
/// Forward pass through the model
///
/// # Arguments
Expand Down Expand Up @@ -815,6 +819,7 @@ impl RobertaForTokenClassification {
})
}

#[allow(rustdoc::invalid_html_tags)]
/// Forward pass through the model
///
/// # Arguments
Expand Down Expand Up @@ -957,6 +962,7 @@ impl RobertaForQuestionAnswering {
}
}

#[allow(rustdoc::invalid_html_tags)]
/// Forward pass through the model
///
/// # Arguments
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.

0 comments on commit 7c10c25

Please sign in to comment.