-
Notifications
You must be signed in to change notification settings - Fork 5
/
cv-aglm.R
145 lines (134 loc) · 6.75 KB
/
cv-aglm.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
# cross-validation function for AGLM model
# written by Kenji Kondo @ 2019/2/24
#' Cross-validation for AGLM models
#'
#' Wraps `x` into an AGLM input object, expands it into a design matrix and
#' runs [glmnet::cv.glmnet()] on that matrix. The cross-validation summary is
#' returned as an `AccurateGLM` object.
#'
#' @param x An input matrix or data.frame to be fitted.
#' @param y An integer or numeric vector which represents the response
#'   variable. Its length must equal the number of rows of `x`.
#' @param qualitative_vars_UD_only A list of indices or names for specifying
#'   which columns are qualitative and need only U-dummy representations.
#' @param qualitative_vars_both A list of indices or names for specifying
#'   which columns are qualitative and need both U-dummy and O-dummy
#'   representations.
#' @param qualitative_vars_OD_only A list of indices or names for specifying
#'   which columns are qualitative and need only O-dummy representations.
#' @param quantitative_vars A list of indices or names for specifying which
#'   columns are quantitative.
#' @param add_linear_columns A boolean value which indicates whether this
#'   function uses linear effects or not.
#' @param add_OD_columns_of_qualitatives A boolean value which indicates
#'   whether this function uses O-dummy representations for qualitative and
#'   ordinal variables or not.
#' @param add_interaction_columns A boolean value which indicates whether this
#'   function uses interaction effects or not.
#' @param bins_list A list of numeric vectors, each element of which is used
#'   as breaks when binning a quantitative variable or a qualitative variable
#'   with order.
#' @param bins_names A list of column names or column indices. Each entry
#'   names the column of `x` that is binned with the element of `bins_list`
#'   in the same position.
#' @param family Response type. Currently "gaussian", "binomial", and
#'   "poisson" are supported.
#' @param type.measure Loss used for cross-validation, forwarded to
#'   `cv.glmnet()`. When not supplied, the first listed choice (`"mse"`) is
#'   used; a supplied value is forwarded as is.
#' @param weights,foldid,exclude Optional arguments of the backend. Each one
#'   is forwarded to `cv.glmnet()` only when it is supplied, so the backend's
#'   own default applies otherwise.
#' @param keep Forwarded to `cv.glmnet()`. When `FALSE`, the `fit.preval` and
#'   `foldid` entries of the result are replaced by empty placeholders.
#' @param lambda.min.ratio,dfmax,pmax,penalty.factor,type.gaussian Forwarded
#'   to `cv.glmnet()`. When `NULL`, a default is derived from the dimensions
#'   of the design matrix before the backend is called.
#' @param offset,lambda,nfolds,grouped,parallel,standardize,alpha,nlambda,intercept,thresh,lower.limits,upper.limits,maxit,type.logistic,standardize.response
#'   Forwarded unchanged to `cv.glmnet()`; see its documentation.
#'
#' @return An `AccurateGLM` object holding the backend's `glmnet.fit`, the
#'   cross-validation results (`lambda`, `cvm`, `cvsd`, `cvup`, `cvlo`,
#'   `nzero`, `name`, `lambda.min`, `lambda.1se`, `fit.preval`, `foldid`) and
#'   the variable information of the input object.
#'
#' @export
#' @importFrom assertthat assert_that
#' @importFrom glmnet cv.glmnet
cv.aglm <- function(x, y,
                    qualitative_vars_UD_only=NULL,
                    qualitative_vars_both=NULL,
                    qualitative_vars_OD_only=NULL,
                    quantitative_vars=NULL,
                    add_linear_columns=TRUE,
                    add_OD_columns_of_qualitatives=TRUE,
                    add_interaction_columns=TRUE,
                    bins_list=NULL,
                    bins_names=NULL,
                    weights,
                    offset=NULL,
                    lambda=NULL,
                    type.measure = c("mse", "deviance", "class", "auc", "mae"),
                    nfolds = 10,
                    foldid,
                    grouped = TRUE,
                    keep = FALSE,
                    parallel = FALSE,
                    standardize=TRUE,
                    family=c("gaussian","binomial","poisson"),
                    alpha=1.0,
                    nlambda=100,
                    lambda.min.ratio=NULL,
                    intercept=TRUE,
                    thresh=1e-7,
                    dfmax=NULL,
                    pmax=NULL,
                    exclude,
                    penalty.factor=NULL,
                    lower.limits=-Inf,
                    upper.limits=Inf,
                    maxit=100000,
                    type.gaussian=NULL,
                    type.logistic=c("Newton","modified.Newton"),
                    standardize.response=FALSE) {
  # Create an input object.
  # NOTE(review): bins_list and bins_names are still passed positionally;
  # naming them needs newInput()'s formals to be confirmed first.
  x <- newInput(x,
                qualitative_vars_UD_only=qualitative_vars_UD_only,
                qualitative_vars_both=qualitative_vars_both,
                qualitative_vars_OD_only=qualitative_vars_OD_only,
                quantitative_vars=quantitative_vars,
                add_linear_columns=add_linear_columns,
                add_OD_columns_of_qualitatives=add_OD_columns_of_qualitatives,
                add_interaction_columns=add_interaction_columns,
                bins_list,
                bins_names)

  # Check y. is.numeric() covers integer and double and always yields a single
  # flag, unlike comparing against class(y), which may have several entries.
  y <- drop(y)
  assert_that(is.numeric(y))
  assert_that(length(y) == dim(x@data)[1])

  # Check family
  family <- match.arg(family)

  # Collapse the default choice vector to its first entry here instead of
  # relying on the backend accepting the whole vector. A value supplied by the
  # caller is left for the backend to validate.
  if (missing(type.measure)) type.measure <- match.arg(type.measure)

  # Create a design matrix which is passed to backend API
  x_for_backend <- getDesignMatrix(x)

  # Data size
  nobs <- dim(x_for_backend)[1]
  nvars <- dim(x_for_backend)[2]
  assert_that(length(y) == nobs)

  # Set default values to some parameters if not given
  if (is.null(lambda.min.ratio)) lambda.min.ratio <- if (nobs < nvars) 1e-2 else 1e-4
  if (is.null(dfmax)) dfmax <- nvars + 1
  if (is.null(pmax)) pmax <- min(dfmax * 2 + 20, nvars)
  if (is.null(penalty.factor)) penalty.factor <- rep(1, nvars)
  if (is.null(type.gaussian)) type.gaussian <- if (nvars < 500) "covariance" else "naive"

  # The backend call is assembled from symbols, so the call recorded by the
  # backend stays as compact as a directly written call would be.
  backend_call <- quote(cv.glmnet(x=x_for_backend,
                                  y=y,
                                  type.measure=type.measure,
                                  nfolds=nfolds,
                                  grouped=grouped,
                                  keep=keep,
                                  parallel=parallel,
                                  family=family,
                                  offset=offset,
                                  alpha=alpha,
                                  nlambda=nlambda,
                                  lambda.min.ratio=lambda.min.ratio,
                                  lambda=lambda,
                                  standardize=standardize,
                                  intercept=intercept,
                                  thresh=thresh,
                                  dfmax=dfmax,
                                  pmax=pmax,
                                  penalty.factor=penalty.factor,
                                  lower.limits=lower.limits,
                                  upper.limits=upper.limits,
                                  maxit=maxit,
                                  type.gaussian=type.gaussian,
                                  type.logistic=type.logistic,
                                  standardize.response=standardize.response))

  # weights, foldid and exclude have no default here. Forwarding an unsupplied
  # one as `weights=weights` only works while the backend tests it with
  # missing(); leaving it out lets the backend fall back to its own default
  # however it detects absence.
  if (!missing(weights)) backend_call$weights <- quote(weights)
  if (!missing(foldid)) backend_call$foldid <- quote(foldid)
  if (!missing(exclude)) backend_call$exclude <- quote(exclude)

  # Evaluated in this function's frame, where all referenced symbols live.
  cv.glmnet_result <- eval(backend_call)

  # Without keep=TRUE the result is given empty placeholders for these two
  # entries before they are handed to the AccurateGLM constructor below.
  if (!keep) {
    cv.glmnet_result$fit.preval <- matrix(0)
    cv.glmnet_result$foldid <- integer(0)
  }

  return(new("AccurateGLM", backend_models=list(cv.glmnet=cv.glmnet_result$glmnet.fit),
             lambda=cv.glmnet_result$lambda,
             cvm=cv.glmnet_result$cvm,
             cvsd=cv.glmnet_result$cvsd,
             cvup=cv.glmnet_result$cvup,
             cvlo=cv.glmnet_result$cvlo,
             nzero=cv.glmnet_result$nzero,
             name=cv.glmnet_result$name,
             lambda.min=cv.glmnet_result$lambda.min,
             lambda.1se=cv.glmnet_result$lambda.1se,
             fit.preval=cv.glmnet_result$fit.preval,
             foldid=cv.glmnet_result$foldid,
             vars_info=x@vars_info))
}