Skip to content

Commit

Permalink
Merge pull request #6 from sdmccabe/r_version
Browse files Browse the repository at this point in the history
Iterative solvers for R version + generative model optimizations
  • Loading branch information
dblarremore committed Jul 26, 2018
2 parents 08fb381 + ec6c7d7 commit 0733be6
Showing 1 changed file with 31 additions and 45 deletions.
76 changes: 31 additions & 45 deletions r/springrank.R
Expand Up @@ -2,7 +2,8 @@ library(igraph)
library(Matrix)
library(Rlinsolve)

spring_rank <- function(A, alpha = 0, l0 = 1.0, l1 = 1.0, shift = TRUE) {
spring_rank <- function(A, alpha = 0, l0 = 1.0, l1 = 1.0,
shift = TRUE, solver = Rlinsolve::lsolve.bicgstab) {
#' Core function for calculating SpringRank.
#' Default parameters follow stanadard model.
#'
Expand All @@ -13,13 +14,17 @@ spring_rank <- function(A, alpha = 0, l0 = 1.0, l1 = 1.0, shift = TRUE) {
#' @param l1 Interaction springs' rest length.
#' @param shift (Optional, default TRUE) normalize such that the lowest-ranked
#' node has a SpringRank value of zero.
#' @param solver (Optional, default Rlinsolve::bicgstab) your preferred
#' solver for Ax=B. Should be able to handle dgCMatrix sparse matrices.
#' if the solve is not from Rlinsolve may throw a spurious warning for
#' unused parameters.
#'
#' @return A vector of SpringRank scores for each node. Sort or order the
#' vector for ordinal rankings of each node.

if (class(A) == "matrix") {
# coerce dense matrix to sparse matrice so the user doesn't have to.
A <- as(Matrix(A, sparse = TRUE, doDiag = FALSE), "dgCMatrix")
# coerce dense matrix to sparse matrice so the user doesn't have to.
A <- as(Matrix(A, sparse = TRUE, doDiag = FALSE), "dgCMatrix")
} else {
# confirm it's the right kind of sparse matrix. might throw an error
# if it's one of the less common sparse matrix types.
Expand All @@ -45,24 +50,9 @@ spring_rank <- function(A, alpha = 0, l0 = 1.0, l1 = 1.0, shift = TRUE) {

B = One * l0 + D2 %*% One
A_ = alpha * diag(nrow = N, ncol = N) + D1 - C

solvable = class(try(as.matrix(solve(A_, B)), silent = T)) == "matrix"
if (solvable)
{
print("using solve")
rank <- solve(A_, B)
} else {
print("using bigcstab")
rank <- Rlinsolve::lsolve.bicgstab(A_, B, verbose = F)
rank <- rank$x
}

if (shift) {
rank <- rank - min(rank)
}

} else {
print("fixing a rank degree of freedom")

C <- C +
matrix(rep(A[N,] , times = N),
ncol = N,
Expand All @@ -72,22 +62,26 @@ spring_rank <- function(A, alpha = 0, l0 = 1.0, l1 = 1.0, shift = TRUE) {
ncol = N,
nrow = N,
byrow = T)

D3 <- as(Matrix(0, ncol = N, nrow = N), "dgCMatrix")
for (i in 1:N) {
D3[i, i] <- l1 * (k_out[N] - k_in[N])
}

B <- D2 %*% One + D3 %*% One
A_ <- D1 - C
}

rank <- solve(A_, B)
rank <- solver(A_, B, verbose = F)
if (class(rank) == "list") {
rank <- rank$x # accomodates both Rlinsolve and Matrix solves
}

if (shift) {
rank <- rank - min(rank)
}
if (shift) {
rank <- rank - min(rank)
}

# matrix to vector, so we can use names()
# coerce matrix to vector, so we can use names()
rank <- rank[,1]
names(rank) <- colnames(A)
return(rank)
Expand All @@ -114,27 +108,19 @@ spring_rank_network <- function(N, beta, alpha, K, l0 = 0.5, l1 = 1.0) {
Z <- 0
for (i in 1:N) {
for (j in 1:N) {
Z = Z + exp(-0.5 * beta * (scores[i] - scores[j] - l1)^2)
Z <- Z + exp(-0.5 * beta * (scores[i] - scores[j] - l1)^2)
}
}
C = (K*N)/Z
A = Matrix(0, N, N)

for (i in 1:N) {
for (j in 1:N) {

H_ij = .5 * (scores[i] - scores[j] - l1)^2
lambda_ij = C * exp(-1 * beta * H_ij)

A_ij = rpois(1, lambda_ij)

if (A_ij > 0) {
A[i, j] = A_ij
}
}
}

return(graph_from_adjacency_matrix(A, mode = "directed", weighted = TRUE))
C <- (K*N)/Z

# for loops are slow in R so make a matrix of element-wise subtractions,
# each element i, j being scores[i]-scores[j]
# basically, trading off increased memory usage (dense matrix) for speed.
scores_mat <- matrix(1, length(scores), 1) %*% t(scores)
scores_mat <- scores_mat - scores - l1
H <- .5 * scores_mat^2
lambda <- C * exp(-1*beta*H)
A <- rpois(length(scores)^2, lambda) %>% matrix(nrow = dim(lambda)[1])

return(graph_from_adjacency_matrix(A, mode = "directed", weighted = "weight"))
}


0 comments on commit 0733be6

Please sign in to comment.