In [8]:
ISTA_lassosolve <- function(X, y, lambda, max_iter = 1000, tol = 1e-6) {
  # Lasso regression via ISTA (proximal gradient descent).
  #
  # Minimizes (1 / (2n)) * ||y - b0 - X b||^2 + lambda * ||b||_1 over the
  # slopes b; the intercept b0 is not penalized. X and y are centred so the
  # intercept drops out of the optimization; it is recovered from the means
  # once the slopes are fitted.
  #
  # Arguments:
  #   X        Numeric n x p design matrix (anything as.matrix() accepts; a
  #            plain vector is treated as one column). Do not add a column of
  #            ones -- the intercept is handled separately.
  #   y        Numeric response of length n (vector or n x 1 matrix).
  #   lambda   Non-negative scalar L1 penalty weight.
  #   max_iter Maximum number of ISTA iterations (>= 0).
  #   tol      Stop when ||beta_new - beta|| / max(1, ||beta||) < tol.
  #
  # Returns a list with
  #   beta        numeric vector of length p + 1, intercept first;
  #   iter        iterations performed (max_iter when not converged, 0 when
  #               the design is degenerate and the solution is exact);
  #   convergence TRUE if the tolerance was met.
  X <- as.matrix(X)
  n <- nrow(X)
  p <- ncol(X)
  stopifnot(
    NROW(y) == n,
    is.numeric(lambda), length(lambda) == 1L, lambda >= 0,
    max_iter >= 0
  )

  # Centre X and y so the unpenalized intercept separates from the slopes.
  X_mean <- colMeans(X)
  y_mean <- mean(y)
  X_centered <- scale(X, center = X_mean, scale = FALSE)
  y_centered <- y - y_mean

  # Smooth part f(b) = ||y_c - X_c b||^2 / (2n) has gradient XtX b - Xty.
  XtX <- crossprod(X_centered) / n
  Xty <- crossprod(X_centered, y_centered) / n

  # Package slopes into the public result, restoring the intercept from the
  # column means (single place instead of duplicating it per exit path).
  make_result <- function(beta, iter, converged) {
    intercept <- y_mean - sum(X_mean * beta)
    list(beta = c(intercept, beta), iter = iter, convergence = converged)
  }

  # Lipschitz constant of the gradient: the largest eigenvalue of XtX.
  # The ISTA step size is 1 / L.
  L <- max(eigen(XtX, symmetric = TRUE, only.values = TRUE)$values)

  # Degenerate design: every column is constant after centring, so the loss
  # does not depend on beta and the L1 penalty is minimized at zero. Without
  # this guard 1 / L is Inf and the iterations turn into NaN.
  if (L <= 0) {
    return(make_result(rep(0, p), 0L, TRUE))
  }

  # Warm start from the ridge solution. It is only a starting point -- the
  # final accuracy is controlled by `tol` -- so fall back to zero when the
  # ridge system is singular (e.g. lambda = 0 with p > n).
  beta <- tryCatch(
    solve(XtX + diag(lambda, p), Xty),
    error = function(e) matrix(0, p, 1L)
  )

  # Proximal operator of gamma * ||.||_1 (elementwise soft-thresholding).
  soft_threshold <- function(z, gamma) {
    sign(z) * pmax(0, abs(z) - gamma)
  }

  # seq_len() so that max_iter = 0 runs no iterations (1:0 would run two).
  for (iter in seq_len(max_iter)) {
    gradient <- XtX %*% beta - Xty
    beta_new <- soft_threshold(beta - (1 / L) * gradient, lambda / L)

    # Step size relative to the current iterate (absolute when ||beta|| < 1).
    if (sqrt(sum((beta_new - beta)^2)) / max(1, sqrt(sum(beta^2))) < tol) {
      return(make_result(beta_new, iter, TRUE))
    }

    beta <- beta_new
  }

  make_result(beta, max_iter, FALSE)
}


In [9]:
# Simulation settings ----
n <- 100       # number of observations
p <- 10        # number of predictors
lambda <- 0.1  # L1 penalty weight passed to the solver

# Simulated data ----
set.seed(123)
X <- matrix(rnorm(n * p), nrow = n, ncol = p)  # n x p matrix of N(0, 1) draws
# True parameters: intercept (5) first, then p slopes of which two are non-zero.
beta_true <- c(5, 1, -1, rep(0, p - 2))
# Response: linear signal from the slopes, plus the intercept, plus N(0, 1) noise.
y <- X %*% beta_true[-1] + beta_true[1] + rnorm(n)

# Fit and report ----
result <- ISTA_lassosolve(X, y, lambda)

cat("Fitted coefficients (with intercept):\n")
print(result$beta)  # first element is the intercept, the rest are the slopes
cat("Number of iterations:", result$iter, "\n")
cat("Convergence:", result$convergence, "\n")


Fitted coefficients (with intercept):
 [1]  5.14042631  0.95608943 -0.91324410 -0.06293648  0.07228975  0.00000000
 [7]  0.00000000  0.00000000  0.04065107  0.00000000  0.06717730
Number of iterations: 19 
Convergence: TRUE 
