In [1]:
using MatrixLMnet
using DataFrames
using Random
using StatsModels
using Plots
Random.seed!(120)

TaskLocalRNG()

## Overview 

MatrixLMnet is a package for L<sub>1</sub>- and L<sub>2</sub>-penalized estimation of matrix linear models. It provides fast, general methods for fitting sparse matrix linear models to structured high-throughput data. In this demo, we learn to use the package through a simple simulation study.


## Data Generation

First, construct a `RawData` object consisting of the response variable `Y` and the row and column covariates `X` and `Z`. All three matrices must be passed in as 2-dimensional arrays. Note that the `contr` function can be used to set up treatment and/or sum contrasts for categorical variables stored in a DataFrame. By default, `contr` generates treatment contrasts (`"treat"`) for all specified categorical variables. Other options include `"sum"` for sum contrasts, `"noint"` for treatment contrasts with no intercept, and `"sumnoint"` for sum contrasts with no intercept.
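To make the two codings concrete, here is a minimal sketch in plain Julia that does not call MatrixLMnet. The helper names `treatment_contrast` and `sum_contrast` are hypothetical, and the choice of reference level (first sorted level for treatment coding, last sorted level for sum coding) is an assumption; `contr` may order or reference the levels differently.

```julia
# Hypothetical helpers illustrating treatment vs. sum coding; not MatrixLMnet's `contr`.

# Treatment coding: one indicator column per non-reference level
# (assumed reference: the first sorted level).
function treatment_contrast(v::AbstractVector)
    lv = sort(unique(v))
    return [Float64(x == l) for x in v, l in lv[2:end]]
end

# Sum coding: indicator columns for all but the last level;
# rows belonging to the last level are set to -1 in every column.
function sum_contrast(v::AbstractVector)
    lv = sort(unique(v))
    M = [Float64(x == l) for x in v, l in lv[1:end-1]]
    M[v .== lv[end], :] .= -1.0
    return M
end

v = ["A", "B", "C", "A"]
T = treatment_contrast(v)   # 4×2 matrix, columns for "B" and "C"
S = sum_contrast(v)         # 4×2 matrix, columns for "A" and "B"
```

Under either coding, a variable with k levels contributes k - 1 columns, which is why a 5-level and a 3-level variable plus 4 numerical columns give 10 columns in the cell that follows.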

In [2]:
# Dimensions of matrices 
n = 100
m = 250
# Number of column covariates
q = 20

# Generate data with two categorical variables and 4 numerical variables.
X_df = hcat(DataFrame(catvar1=rand(1:5, n), catvar2=rand(["A", "B", "C"], n)), DataFrame(rand(n,4),:auto))

# Convert the DataFrame to a predictor matrix
X = Matrix(contr(X_df, [:catvar1, :catvar2], ["treat", "sum"]))

#X = design_matrix(@mlmFormula(catvar1 + catvar2 + x1 + x2 + x3 + x4), X_df, [(:catvar1, )])

p = size(X)[2]

10

In [3]:
Z = rand(m,q)
B = rand(-5:5,p,q)
E = randn(n,m)
Y = X*B*transpose(Z)+E

# Construct a RawData object
dat = RawData(Response(Y), Predictors(X, Z))

RawData(Response([-19.491660660714384 -11.599246293938792 … -9.846025890509917 -3.910552342504766; 1.6869050129842436 15.809262625546225 … 16.77316280061852 27.19168838618037; … ; -36.871293059346954 -45.445727213984966 … -25.67130221861735 -35.67748199691569; 34.25129202722653 28.917993558825945 … 22.068755759265283 6.449966296553046]), Predictors([0.0 0.0 … 0.7629354798627414 0.6997729240121372; 1.0 0.0 … 0.0683641291444198 0.13975716342002353; … ; 0.0 0.0 … 0.8916305292309327 0.6631051207998134; 0.0 0.0 … 0.9728491655926802 0.4753251921489988], [0.21736110326399694 0.9680284954520674 … 0.8414451324120746 0.16389350803169955; 0.4533491360031554 0.9653047071200448 … 0.6532052569477869 0.2122431874934696; … ; 0.17491325894688858 0.221650638553477 … 0.47280492128347373 0.8279142008963407; 0.025477967039622795 0.5373624856736039 … 0.5306255198534721 0.9982313780729508], false, false), 100, 250, 10, 20)
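The simulated response follows the matrix linear model `Y = X*B*transpose(Z) + E`. As a side note, the mean of this model can be rewritten as an ordinary linear model on the vectorized response through a Kronecker product. The sketch below checks that identity on small, arbitrary dimensions; it uses only the standard library, and nothing in it is taken from MatrixLMnet's internals.

```julia
using LinearAlgebra, Random

rng = MersenneTwister(1)
n, m, p, q = 6, 5, 3, 2          # small, arbitrary dimensions
X = rand(rng, n, p)              # row covariates
Z = rand(rng, m, q)              # column covariates
B = rand(rng, -5:5, p, q)        # interaction coefficients
Y = X * B * transpose(Z)         # noiseless mean of the response

# vec(X*B*Z') equals kron(Z, X) * vec(B)
lhs = vec(Y)
rhs = kron(Z, X) * vec(B)
```

The Kronecker design matrix has `n*m` rows and `p*q` columns, so forming it explicitly quickly becomes expensive as the dimensions grow.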

Create a 1d array of lambda penalty values at which to fit the estimates. If the lambdas are not in descending order, they will be automatically sorted by `mlmnet`.

In [4]:
lambdas = reverse(1.8.^(1:10))

10-element Vector{Float64}:
 357.0467226624001
 198.35929036800005
 110.19960576000003
  61.22200320000001
  34.012224
  18.895680000000002
  10.4976
   5.832000000000001
   3.24
   1.8
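If you prefer to check or sort the grid yourself instead of relying on the automatic sorting, base Julia is enough. A minimal sketch with no package calls:

```julia
lambdas = reverse(1.8 .^ (1:10))

# The grid above is already in descending order.
is_desc = issorted(lambdas; rev=true)

# An unsorted grid can be put in descending order explicitly.
shuffled = lambdas[[3, 1, 10, 7, 2, 5, 4, 9, 8, 6]]
resorted = sort(shuffled; rev=true)
```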

Create a 1d array of alpha values that determine the mix between the L<sub>1</sub> and L<sub>2</sub> penalties in the Elastic Net penalization method. Setting alpha to 1 gives Lasso regression (L<sub>1</sub> regularization), and setting it to 0 gives Ridge regression (L<sub>2</sub> regularization). If the alphas are not in descending order, they will be automatically sorted by `mlmnet`.

In [5]:
alphas = reverse(collect(0:0.1:1))

11-element Vector{Float64}:
 1.0
 0.9
 0.8
 0.7
 0.6
 0.5
 0.4
 0.3
 0.2
 0.1
 0.0
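The role of alpha can be illustrated with a small penalty function. This sketch assumes the common glmnet-style parameterization, lambda * (alpha * ‖B‖₁ + (1 - alpha)/2 * ‖B‖₂²); the exact scaling of the L<sub>2</sub> term inside MatrixLMnet may differ, and `enet_penalty` is a hypothetical name.

```julia
# Hypothetical helper; assumes the glmnet-style Elastic Net parameterization.
function enet_penalty(B::AbstractMatrix, lambda::Real, alpha::Real)
    l1 = sum(abs, B)            # L1 norm of the coefficients
    l2 = sum(abs2, B)           # squared Frobenius (L2) norm
    return lambda * (alpha * l1 + (1 - alpha) / 2 * l2)
end

B = [1.0 -2.0; 0.0 3.0]
lasso_pen = enet_penalty(B, 2.0, 1.0)   # pure L1: 2 * 6 = 12
ridge_pen = enet_penalty(B, 2.0, 0.0)   # pure L2: 2 * 14 / 2 = 14
mixed_pen = enet_penalty(B, 2.0, 0.5)   # equal mix of the two penalties: 13
```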

L<sub>1</sub>- and L<sub>2</sub>-penalized estimates for matrix linear models can be obtained by running `mlmnet`. In addition to the `RawData` object, `lambdas` and `alphas`, `mlmnet` requires, as the keyword argument `method`, the name of the algorithm used to fit the Elastic Net penalized estimates. Current methods are: `"cd"` (coordinate descent), `"cd_active"` (active coordinate descent), `"ista"` (ISTA with fixed step size), `"fista"` (FISTA with fixed step size), `"fista_bt"` (FISTA with backtracking), and `"admm"` (ADMM).

An object of type `Mlmnet` will be returned, with variables for the penalized coefficient estimates (`B`) along with the lambda and alpha penalty values used (`lambdas`, `alphas`). By default, `mlmnet` estimates both row and column main effects (X and Z intercepts), but this behavior can be suppressed by setting `hasXIntercept=false` and/or `hasZIntercept=false`; the intercepts will be regularized unless `toXInterceptReg=false` and/or `toZInterceptReg=false`. Individual `X` (row) and `Z` (column) effects can be left unregularized by manually passing in 1d boolean arrays of length `p` and `q` to indicate which effects should be regularized (`true`) or not (`false`) into `toXReg` and `toZReg`. By default, `mlmnet` centers and normalizes the columns of `X` and `Z` to have mean 0 and norm 1 (`toNormalize=true`). Additional keyword arguments include `isVerbose`, which controls message printing; `thresh`, the threshold at which the coefficients are considered to have converged; and `maxiter`, the maximum number of iterations. 
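The centering and normalization step can be pictured as follows. This is a minimal sketch using the standard library only; `center_and_normalize` is a hypothetical helper, and how MatrixLMnet treats intercept or constant columns during this step is not shown here.

```julia
using LinearAlgebra, Statistics

# Hypothetical helper: center each column, then scale it to unit Euclidean norm.
# Constant columns would have zero norm after centering and are not handled here.
function center_and_normalize(A::AbstractMatrix)
    centered = A .- mean(A; dims=1)
    norms = sqrt.(sum(abs2, centered; dims=1))
    return centered ./ norms
end

A = [1.0 10.0; 2.0 20.0; 4.0 50.0]
A_std = center_and_normalize(A)   # columns now have mean 0 and norm 1
```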

In [6]:
est = mlmnet(dat, lambdas, alphas; method = "admm")

Criterion: 19.184826606351727
Number of iterations: 16
Criterion: 247.5541389838715
Number of iterations: 40
Criterion: 472.28160566570335
Number of iterations: 37
Criterion: 638.5915230212398
Number of iterations: 43
Criterion: 671.2990866722264
Number of iterations: 51
Criterion: 540.8364818964126
Number of iterations: 35
Criterion: 369.674759526384
Number of iterations: 34
Criterion: 232.29710840917514
Number of iterations: 21
Criterion: 138.68623212253885
Number of iterations: 32
Criterion: 80.50308000596563
Number of iterations: 20
Criterion: 19.18482660756022
Number of iterations: 13
Criterion: 46.36094706718012
Number of iterations: 36
Criterion: 114.86068039304456
Number of iterations: 39
Criterion: 242.0544810377271
Number of iterations: 37
Criterion: 395.2786268995153
Number of iterations: 35
Criterion: 509.68756204105665
Number of iterations: 34
Criterion: 547.6205866746141
Number of iterations: 26
Criterion: 501.64364080134715
Number of iterations: 19
Criterion: 402.9180227

Mlmnet([-7.924291302073635 -4.278340378104617 … -0.6961706174677971 -8.095375939192282; 9.335652102696374 -0.0 … 0.0 0.0; … ; -8.640383579637758 -0.0 … -0.0 -0.0; -2.0251037280126893 -0.0 … -0.0 -0.0;;; -7.948047450767887 -4.27936123307744 … -0.6961706959214091 -8.095376098990863; 9.335751525357358 -0.0 … 0.0 0.0; … ; -8.640426696294458 -0.0 … -0.0 -0.0; -2.0250706175077946 -0.0 … -0.0 -0.0;;; -8.42610211224222 -4.289130793122506 … -0.6961706665360483 -8.093712770645718; 9.316173391488505 -0.0 … 0.0 0.0; … ; -8.640417183691067 -0.0 … -0.0 -0.0; -2.0250779177933493 -0.0 … -0.0 -0.0;;; -9.326633913816554 -4.294767281307395 … -0.7438087631605879 -8.093642317293847; 9.248902493378182 -0.0 … 0.0 0.0; … ; -8.561858300904698 -0.0 … -0.3714496413073767 -0.0; -2.025082215581464 -0.0 … -0.0 -0.0;;; -10.979144940978708 -4.138647841850664 … -0.7224095995265675 -7.980494615503578; 9.209709373336757 -0.0 … 0.0 0.0; … ; -8.387466427715625 -1.1356587568242758 … -2.0454188333799315 -0.0; -1.89019587565

If the `alphas` argument is omitted, a Lasso regression is applied, which is equivalent to `alphas = [1]`.

The functions for the algorithms used to fit the Elastic Net penalized estimates have keyword arguments that can be passed into `mlmnet` when non-default behavior is desired. Irrelevant keyword arguments will be ignored. 

`"cd"` (coordinate descent)
- `isRandom = true`: Bool; whether to use random or cyclic updates

`"cd_active"` (active coordinate descent)
- `isRandom = true`: Bool; whether to use random or cyclic updates

`"ista"` (ISTA with fixed step size)
- `stepsize = 0.01`: Float64; fixed step size for updates
- `setStepsize = true`: Bool; whether the fixed step size should be calculated, overriding `stepsize`

`"fista"` (FISTA with fixed step size)
- `stepsize = 0.01`: Float64; fixed step size for updates
- `setStepsize = true`: Bool; whether the fixed step size should be calculated, overriding `stepsize`

`"fista_bt"` (FISTA with backtracking)
- `stepsize = 0.01`: Float64; initial step size for updates
- `gamma = 0.5`: Float64; multiplying factor for step size backtracking/line search

`"admm"` (ADMM)
- `rho = 1.0`: Float64; parameter that controls ADMM tuning
- `setRho = true`: Bool; whether the ADMM tuning parameter should be calculated, overriding `rho`
- `tau_incr = 2.0`: Float64; parameter that controls the factor at which the ADMM tuning parameter increases
- `tau_decr = 2.0`: Float64; parameter that controls the factor at which the ADMM tuning parameter decreases
- `mu = 10.0`: Float64; parameter that controls the factor within which the primal and dual residuals should be kept of each other
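To give a feel for what a fixed step size does in ISTA, here is a textbook sketch for an ordinary (vector) Lasso regression. It is not MatrixLMnet's implementation: the function names are hypothetical, and the default step size of one over the largest eigenvalue of `X'X` is only an assumption about what a calculated step size might look like.

```julia
using LinearAlgebra

# Soft-thresholding operator, the proximal map of t * |x|.
soft_threshold(x::Real, t::Real) = sign(x) * max(abs(x) - t, 0)

# Lasso objective for a vector regression: 0.5 * ||y - X*b||^2 + lambda * ||b||_1
lasso_obj(X, y, b, lambda) = 0.5 * sum(abs2, y - X * b) + lambda * sum(abs, b)

# Textbook ISTA with a fixed step size (assumed default: 1 / largest eigenvalue of X'X).
function ista(X, y, lambda; stepsize=1 / opnorm(X)^2, maxiter=500)
    b = zeros(size(X, 2))
    for _ in 1:maxiter
        grad = X' * (X * b - y)   # gradient of the smooth least-squares part
        b = soft_threshold.(b - stepsize * grad, stepsize * lambda)
    end
    return b
end

X = [1.0 0.0; 0.0 1.0; 1.0 1.0]
y = [1.0, 2.0, 3.0]
b_hat = ista(X, y, 0.1)
```

Each iteration takes a gradient step on the least-squares term and then shrinks the result toward zero, which is what produces exact zeros in the estimates.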

The 4d array of coefficient estimates can be accessed using `coef(est)`. Predicted values and residuals can be obtained by calling `predict` and `resid`. By default, both of these functions use the same data used to fit the model. However, a new `Predictors` object can be passed into `predict` as the `newPredictors` argument and a new `RawData` object can be passed into `resid` as the `newData` argument. For convenience, `fitted(est)` will return the fitted values by calling `predict` with the default arguments. 

In [7]:
preds = predict(est)
resids = resid(est)

estPerms = mlmnet_perms(dat, lambdas, alphas; method = "fista_bt")

Criterion: 135.04907782289692
Number of iterations: 197
Criterion: 135.04730367033596
Number of iterations: 36
Criterion: 135.04730309446663
Number of iterations: 1
Criterion: 156.28402752501606
Number of iterations: 675
Criterion: 207.04214578963789
Number of iterations: 369
Criterion: 231.9668636721151
Number of iterations: 524
Criterion: 219.63465697913898
Number of iterations: 429
Criterion: 198.2961556548988
Number of iterations: 864
Criterion: 175.6608859554169
Number of iterations: 369
Criterion: 158.99716489719935
Number of iterations: 172
Criterion: 135.04907782289692
Number of iterations: 197
Criterion: 135.04730367033596
Number of iterations: 36
Criterion: 135.04730309446663
Number of iterations: 1
Criterion: 138.20508044858119
Number of iterations: 196
Criterion: 152.88673764808834
Number of iterations: 411
Criterion: 170.75507495643203
Number of iterations: 189
Criterion: 185.353871805172
Number of iterations: 646
Criterion: 188.8874311470762
Number of iterations: 694
Crit

Mlmnet([12.443408216415428 -4.230847274888264 … -0.7032917026718124 -8.128196227303997; 7.158177878137566 0.0 … -0.0 -0.0; … ; -16.37081848292532 -0.0 … -0.0 -0.0; 3.736983164260677 0.0 … -0.0 0.0;;; 12.412315525537553 -4.274627234898147 … -0.6988790877560112 -8.099946454427602; 7.116824610009 0.0 … -0.0 -0.0; … ; -16.302280355398683 -0.0 … -0.0 -0.0; 3.758312692957218 0.0 … -0.0 0.0;;; 12.412260380992796 -4.2746568332583585 … -0.6988553452978227 -8.099910559914402; 7.116795849517562 0.0 … -0.0 -0.0; … ; -16.302239651115617 -0.0 … -0.0 -0.0; 3.7582974536786278 0.0 … -0.0 0.0;;; 12.662645035762274 -4.27924085617419 … -0.6954752134246621 -8.093906923673142; 7.113492704407816 0.0 … -0.0 -0.0; … ; -16.241877558129005 -0.0 … -0.0 -0.0; 3.757931544273566 0.0 … -0.0 0.0;;; 13.971785347656352 -4.223741916180748 … -0.6800536668750533 -8.095771816162907; 7.028441046585453 0.0 … -0.0 -0.0; … ; -16.068983130741934 -1.4401634293742056 … -0.2956434967493522 -0.0; 3.7400108540375605 0.0 … -0.0 0.0;;;

All four of these functions take optional `lambda` and `alpha` arguments, in which case only the 2d array corresponding to that pair of lambda and alpha values will be returned, e.g. `coef(est, lambdas[1], alphas[1])`. If a lambda or alpha value that was not used in the fitting of the `Mlmnet` object is passed in, an error will be raised. One can also extract the coefficients as a flattened 3d array by calling `coef_3d(est)`, for convenience when writing the results to flat files.
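The lookup behavior described above can be mimicked on a toy 4d array. In this sketch the dimension order (p × q × lambda × alpha) is an assumption, and `coef_at` is a hypothetical helper, not part of MatrixLMnet.

```julia
# Hypothetical helper mimicking lookup of one (lambda, alpha) slice of a 4d array.
function coef_at(B::AbstractArray{<:Real,4}, lambdas, alphas, lambda, alpha)
    i = findfirst(==(lambda), lambdas)
    j = findfirst(==(alpha), alphas)
    (i === nothing || j === nothing) &&
        throw(ArgumentError("lambda or alpha was not used when fitting"))
    return B[:, :, i, j]
end

lambdas = [10.0, 1.0]
alphas = [1.0, 0.5, 0.0]
# Toy array with assumed layout p × q × lambda × alpha
B = reshape(collect(1.0:36.0), 2, 3, length(lambdas), length(alphas))

slice = coef_at(B, lambdas, alphas, 1.0, 0.5)   # 2×3 matrix for lambda = 1.0, alpha = 0.5
```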

`mlmnet_perms` permutes the response matrix `Y` in a `RawData` object and then calls `mlmnet`. By default, the function used to permute `Y` is `shuffle_rows`, which shuffles the rows of `Y`. Alternative functions for permuting `Y`, such as `shuffle_cols`, can be passed into the argument `permFun`. Non-default behavior for `mlmnet` can be specified by passing its keyword arguments into `mlmnet_perms`.
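Permuting the rows of `Y` breaks the pairing between each response row and its row covariates while keeping every row intact. A minimal sketch with the standard library, where `permute_rows` is a hypothetical stand-in rather than the package's `shuffle_rows`:

```julia
using Random

# Hypothetical stand-in for a row-permuting function; not the package's `shuffle_rows`.
permute_rows(rng::AbstractRNG, Y::AbstractMatrix) = Y[shuffle(rng, 1:size(Y, 1)), :]

rng = MersenneTwister(120)
Y = reshape(collect(1.0:12.0), 4, 3)
Y_perm = permute_rows(rng, Y)   # same rows as Y, in a random order
```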

Cross-validation for `mlmnet` is implemented by `mlmnet_cv`. The user can either manually specify the row/column folds of `Y` as a 1d array of 1d arrays of row/column indices, or specify the number of folds that should be used. If the number of folds is specified, disjoint folds of approximately equal size will be generated from a call to `make_folds`. Passing in `1` for the number of row (or column) folds indicates that all of the rows (or columns) of `Y` should be used in each fold. The advantage of manually passing in the row and/or column folds is that it allows the user to stratify or otherwise control the nature of the folds. For example, `make_folds_conds` will generate folds for a set of categorical conditions and ensure that each condition is represented in each fold. Cross validation is computed in parallel when possible. Non-default behavior for `mlmnet` can be specified by passing its keyword arguments into `mlmnet_cv`. 
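For intuition about what disjoint folds of approximately equal size look like, here is a small sketch in plain Julia. `make_folds_sketch` is a hypothetical helper and makes no claim about how `make_folds` is implemented.

```julia
using Random

# Hypothetical helper: split 1:n into k disjoint folds of near-equal size.
# With k == 1, the single fold contains every index.
function make_folds_sketch(rng::AbstractRNG, n::Integer, k::Integer)
    idx = shuffle(rng, 1:n)
    return [idx[i:k:n] for i in 1:k]
end

folds = make_folds_sketch(MersenneTwister(1), 100, 10)    # 10 folds of 10 row indices
all_rows = make_folds_sketch(MersenneTwister(1), 100, 1)  # one fold holding all rows
```

A manually built vector of index vectors with the same shape can be used to stratify the folds, for example by condition.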

In the call below, `mlmnet_cv` generates 10 disjoint row folds but uses all columns of `Y` in each fold (indicated by the `1`). The function returns an `Mlmnet_cv` object, which contains an array of the `Mlmnet` objects for each fold (`MLMNets`); the lambda penalty values used (`lambdas`); the row and column folds (`rowFolds` and `colFolds`); an array of the mean-squared error for each fold (`mse`); and an array of the proportion of zero interaction effects for each fold (`propZero`). The keyword argument `dig` in `mlmnet_cv` adjusts the level of precision used when calculating the proportion of zero coefficients. It defaults to `12`.
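One plausible reading of `dig` is that coefficients are rounded to that many digits before being compared to zero. The sketch below illustrates that reading only; `prop_zero` is a hypothetical helper, and the package's actual computation may differ.

```julia
using Statistics

# Hypothetical helper: share of coefficients equal to zero after rounding to `dig` digits.
prop_zero(B::AbstractArray; dig::Integer=12) = mean(round.(B; digits=dig) .== 0)

B = [0.0 1e-14; 0.5 -2.0]
p_default = prop_zero(B)          # 1e-14 rounds to zero at 12 digits
p_strict = prop_zero(B; dig=16)   # 1e-14 survives rounding at 16 digits
```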

In [8]:
estCVObjs = mlmnet_cv(dat, lambdas, alphas, 10, 1, method = "fista_bt")
println(mlmnet_cv_summary(estCVObjs))

Performing 10-fold cross validation.
Criterion: 19.016385002245304
Number of iterations: 270
Criterion: 219.58082566151023
Number of iterations: 1791
Criterion: 382.91946404489516
Number of iterations: 1482
Criterion: 548.945422439869
Number of iterations: 1166
Criterion: 607.3122393835783
Number of iterations: 1577
Criterion: 500.21442510375385
Number of iterations: 222
Criterion: 347.0867552532113
Number of iterations: 1203
Criterion: 218.74585251142926
Number of iterations: 289
Criterion: 131.16705101770128
Number of iterations: 704
Criterion: 76.28900490597903
Number of iterations: 224
Criterion: 19.016385002245304
Number of iterations: 270
Criterion: 29.853024124757088
Number of iterations: 311
Criterion: 59.73782669105457
Number of iterations: 473
Criterion: 115.60135465466414
Number of iterations: 586
Criterion: 192.55855879327368
Number of iterations: 740
Criterion: 256.22869759491687
Number of iterations: 779
Criterion: 296.2040333444169
Number of iterations: 867
Criterion: 29


`mlmnet_cv_summary` displays a table of the average mean-squared error and proportion of zero coefficients across the folds for each pair of lambda and alpha values. The optimal pair might be the one that minimizes the mean-squared error (MSE), or it can be chosen based on a pre-determined proportion of zeros that is desired in the coefficient estimates.

The `lambda_min` function returns the summary information for two (lambda, alpha) pairs: the one with the minimum average test MSE across folds, and the one corresponding to the MSE that is one standard error greater than that minimum.
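The selection logic can be sketched on toy numbers. This example assumes the usual one-standard-error rule over a single descending lambda grid (the largest lambda whose average MSE is within one standard error of the minimum); how `lambda_min` handles the alpha dimension and computes the standard error is not shown here, and the MSE values are made up for illustration.

```julia
using Statistics

# mse[f, l]: test MSE of fold f at the l-th lambda (lambdas in descending order).
# Toy numbers for illustration only.
mse = [5.0 2.3 2.0 2.2;
       5.4 2.5 2.6 2.0;
       4.6 2.1 1.4 2.4]

avg = vec(mean(mse; dims=1))                       # average MSE per lambda
se = vec(std(mse; dims=1)) ./ sqrt(size(mse, 1))   # standard error across folds

i_min = argmin(avg)                                # index of the minimum average MSE
# One-standard-error rule: first (largest) lambda whose average MSE
# is within one standard error of the minimum.
i_1se = findfirst(<=(avg[i_min] + se[i_min]), avg)
```

Here the minimum sits at the third lambda, while the one-standard-error choice moves to the larger second lambda, trading a slightly higher MSE for a sparser fit.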

In [9]:
lambda_min(estCVObjs)

|   | Name | Index | Lambda | Alpha | AvgMSE | AvgPropZero |
|---|------|-------|--------|-------|--------|-------------|
|   | String | Tuple… | Float64 | Float64 | Float64 | Float64 |
| 1 | (𝜆, 𝛼)_min | (2, 1) | 198.359 | 1.0 | 34.6111 | 0.966 |
| 2 | (𝜆, 𝛼)_min1se | (1, 1) | 357.047 | 1.0 | 38.9581 | 1.0 |
