From 4ac6c5aa0d9b42f4a0788ccffbbadcbb1101bac1 Mon Sep 17 00:00:00 2001 From: chainsawriot Date: Wed, 20 Mar 2024 14:09:44 +0100 Subject: [PATCH] Fix #30 --- DESCRIPTION | 2 +- R/misc.R | 6 ++++++ R/train.R | 6 ++---- data/supported_model_types.rda | Bin 0 -> 231 bytes man/grafzahl.Rd | 2 +- man/hydrate.Rd | 2 +- man/supported_model_types.Rd | 16 ++++++++++++++++ rawdata/createdata.R | 5 +++++ 8 files changed, 32 insertions(+), 7 deletions(-) create mode 100644 data/supported_model_types.rda create mode 100644 man/supported_model_types.Rd diff --git a/DESCRIPTION b/DESCRIPTION index 11d3871..6cc0bdd 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -8,7 +8,7 @@ Description: Duct tape the 'quanteda' ecosystem (Benoit et al., 2018) = 3) Encoding: UTF-8 Roxygen: list(markdown = TRUE) -RoxygenNote: 7.2.3 +RoxygenNote: 7.3.1 URL: https://github.com/chainsawriot/grafzahl BugReports: https://github.com/chainsawriot/grafzahl/issues Suggests: diff --git a/R/misc.R b/R/misc.R index 06b21ed..332b357 100644 --- a/R/misc.R +++ b/R/misc.R @@ -20,6 +20,12 @@ NULL #' Van Atteveldt, W., Van der Velden, M. A., & Boukes, M. (2021). The validity of sentiment analysis: Comparing manual annotation, crowd-coding, dictionary approaches, and machine learning algorithms. Communication Methods and Measures, 15(2), 121-140. "ecosent" +#' Supported model types +#' +#' A vector of all supported model types. +#' +"supported_model_types" + #' Download The Amharic News Text Classification Dataset #' #' This function downloads the training and test sets of the Amharic News Text Classification Dataset from Hugging Face. diff --git a/R/train.R b/R/train.R index 3999a2e..bdf7a66 100644 --- a/R/train.R +++ b/R/train.R @@ -32,9 +32,7 @@ model_type <- .infer_model_type(model_name) } model_type <- gsub("-", "", tolower(model_type)) - if (!model_type %in% c("albert", "bert", "bertweet", "bigbird", "camembert", "deberta", "distilbert", "electra", "flaubert", - "herbert", "layoutlm", "layoutlmv2", "longformer", "mpnet", "mobilebert", "rembert", "roberta", "squeezebert", - "squeezebert", "xlm", "xlmroberta", "xlnet", "debertav2")) { + if (!model_type %in% grafzahl::supported_model_types) { stop("Invalid `model_type`.", call. = FALSE) } return(model_type) @@ -118,7 +116,7 @@ #' @param train_size numeric, proportion of data in `x` and `y` to be used actually for training. The rest will be used for cross validation. #' @param args list, additionally parameters to be used in the underlying simple transformers #' @param cleanup logical, if `TRUE`, the `runs` directory generated will be removed when the training is done -#' @param model_type a string indicating model_type of the input model. If `NULL`, it will be inferred from `model_name`. It can only be one of the following: "albert", "bert", "bertweet", "bigbird", "camembert", "deberta", "debertav2", "distilbert", "electra", "flaubert", "herbert", "layoutlm", "layoutlmv2", "longformer", "mpnet", "mobilebert", "rembert", "roberta", "squeezebert", "squeezebert", "xlm", "xlmroberta", "xlnet". This will be lowercased and hyphens will be removed, e.g. "XLM-RoBERTa" will be normalized to "xlmroberta". +#' @param model_type a string indicating model_type of the input model. If `NULL`, it will be inferred from `model_name`. Supported model types are available in [supported_model_types]. #' @param manual_seed numeric, random seed #' @param verbose logical, if `TRUE`, debug messages will be displayed #' @param ... paramters pass to [grafzahl()] diff --git a/data/supported_model_types.rda b/data/supported_model_types.rda new file mode 100644 index 0000000000000000000000000000000000000000..1b7c90b2d7aa2705ece3f1cb530077b83cc5464d GIT binary patch literal 231 zcmVqydLT4p z&;g(g8UiMx(KOJ}pa1|G0V%1dj3yC@qd*3Y4Iwnqp_uZhuSfwJ#K!v_nhOOE(mWH> zB`Joe217DLCwCYFbg&rqKK$&$uN7=X{Ni+UR;-+7gjxU$VDE|F1Buo00YrLl5<^6k z19>S*% literal 0 HcmV?d00001 diff --git a/man/grafzahl.Rd b/man/grafzahl.Rd index e0607ab..9ae12a4 100644 --- a/man/grafzahl.Rd +++ b/man/grafzahl.Rd @@ -95,7 +95,7 @@ textmodel_transformer(...) \item{cleanup}{logical, if \code{TRUE}, the \code{runs} directory generated will be removed when the training is done} -\item{model_type}{a string indicating model_type of the input model. If \code{NULL}, it will be inferred from \code{model_name}. It can only be one of the following: "albert", "bert", "bertweet", "bigbird", "camembert", "deberta", "distilbert", "electra", "flaubert", "herbert", "layoutlm", "layoutlmv2", "longformer", "mpnet", "mobilebert", "rembert", "roberta", "squeezebert", "squeezebert", "xlm", "xlmroberta", "xlnet". This will be lowercased and hyphens will be removed, e.g. "XLM-RoBERTa" will be normalized to "xlmroberta".} +\item{model_type}{a string indicating model_type of the input model. If \code{NULL}, it will be inferred from \code{model_name}. Supported model types are available in \link{supported_model_types}.} \item{manual_seed}{numeric, random seed} diff --git a/man/hydrate.Rd b/man/hydrate.Rd index 65c1eff..e37564e 100644 --- a/man/hydrate.Rd +++ b/man/hydrate.Rd @@ -9,7 +9,7 @@ hydrate(output_dir, model_type = NULL, regression = FALSE) \arguments{ \item{output_dir}{string, location of the output model. If missing, the model will be stored in a temporary directory. Important: Please note that if this directory exists, it will be overwritten.} -\item{model_type}{a string indicating model_type of the input model. If \code{NULL}, it will be inferred from \code{model_name}. It can only be one of the following: "albert", "bert", "bertweet", "bigbird", "camembert", "deberta", "distilbert", "electra", "flaubert", "herbert", "layoutlm", "layoutlmv2", "longformer", "mpnet", "mobilebert", "rembert", "roberta", "squeezebert", "squeezebert", "xlm", "xlmroberta", "xlnet". This will be lowercased and hyphens will be removed, e.g. "XLM-RoBERTa" will be normalized to "xlmroberta".} +\item{model_type}{a string indicating model_type of the input model. If \code{NULL}, it will be inferred from \code{model_name}. Supported model types are available in \link{supported_model_types}.} \item{regression}{logical, if \code{TRUE}, the task is regression, classification otherwise.} } diff --git a/man/supported_model_types.Rd b/man/supported_model_types.Rd new file mode 100644 index 0000000..6bd4cb6 --- /dev/null +++ b/man/supported_model_types.Rd @@ -0,0 +1,16 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/misc.R +\docType{data} +\name{supported_model_types} +\alias{supported_model_types} +\title{Supported model types} +\format{ +An object of class \code{character} of length 23. +} +\usage{ +supported_model_types +} +\description{ +A vector of all supported model types. +} +\keyword{datasets} diff --git a/rawdata/createdata.R b/rawdata/createdata.R index 8da6e23..fb6b43e 100644 --- a/rawdata/createdata.R +++ b/rawdata/createdata.R @@ -22,3 +22,8 @@ download.file(url <- "https://raw.githubusercontent.com/vanatteveldt/ecosent/mas ecosent <- read.csv("rawdata/sentences_ml.csv", encoding = "UTF-8")[c("id", "headline", "value", "gold")] save(ecosent, file = "data/ecosent.rda", ascii = FALSE, compress = "xz") + +supported_model_types <- c("albert", "bert", "bertweet", "bigbird", "camembert", "deberta", "distilbert", "electra", "flaubert", + "herbert", "layoutlm", "layoutlmv2", "longformer", "mpnet", "mobilebert", "rembert", "roberta", "squeezebert", + "squeezebert", "xlm", "xlmroberta", "xlnet", "debertav2") +usethis::use_data(supported_model_types)