Skip to content

Commit

Permalink
The sampler function now can return indices instead. The sampler now …
Browse files Browse the repository at this point in the history
…has an indexing method. The last two changes allow making function calls faster for the structural test. Also, we now avoid generating lists of empty vectors when counting stats, which makes the function a bit faster as well.
  • Loading branch information
gvegayon committed Mar 7, 2019
1 parent f87ef50 commit 1b30006
Show file tree
Hide file tree
Showing 7 changed files with 145 additions and 44 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
@@ -1,5 +1,6 @@
# Generated by roxygen2: do not edit by hand

S3method("[",ergmito_sampler)
S3method(as.adjmat,formula)
S3method(as.adjmat,list)
S3method(as.adjmat,network)
Expand Down
12 changes: 9 additions & 3 deletions R/count_stats.R
Expand Up @@ -117,6 +117,7 @@ count_stats.formula <- function(X, ...) {

out <- matrix(nrow = nnets(LHS), ncol = length(ergm_model$names),
dimnames = list(NULL, ergm_model$passed))

for (j in 1:ncol(out)) {

out[, j] <- count_stats(
Expand All @@ -138,8 +139,8 @@ count_stats.list <- function(X, terms, attrs = NULL, ...) {

chunks <- make_chunks(length(X), 2e5)

if (!length(attrs))
attrs <- replicate(length(X), numeric(0), simplify = FALSE)
# if (!length(attrs))
# attrs <- replicate(length(X), numeric(0), simplify = FALSE)

ans <- matrix(NA, nrow = length(X), ncol=length(terms))

Expand All @@ -148,7 +149,12 @@ count_stats.list <- function(X, terms, attrs = NULL, ...) {
i <- chunks$from[s]
j <- chunks$to[s]

ans[i:j,] <- count_stats.(X[i:j], terms, attrs[i:j])
for (k in seq_along(terms)) {
if (!length(attrs))
ans[i:j, k] <- count_stats.(X[i:j], terms[k], list(numeric(0)))
else
ans[i:j, k] <- count_stats.(X[i:j], terms[k], attrs[i:j])
}

}

Expand Down
52 changes: 46 additions & 6 deletions R/sim.R
Expand Up @@ -245,8 +245,8 @@ new_rergmito <- function(model, theta = NULL, sizes = NULL, mc.cores = 2L,...) {
# Calling the prob function
ans$calc_prob()

# Sampling function
ans$sample <- function(n, s, theta = NULL) {
# Sampling functions ---------------------------------------------------------
ans$sample <- function(n, s, theta = NULL, as_indexes = FALSE) {

s <- as.character(s)
# All should be able to be sampled
Expand All @@ -262,16 +262,26 @@ new_rergmito <- function(model, theta = NULL, sizes = NULL, mc.cores = 2L,...) {
on.exit(ans$prob[s] <- oldp)
}

ans$networks[[s]][
if (!as_indexes) {
ans$networks[[s]][
sample.int(
n = length(ans$prob[[s]]),
size = n,
replace = TRUE,
prob = ans$prob[[s]],
useHash = FALSE
)
]
} else {
sample.int(
n = length(ans$prob[[s]]),
size = n,
replace = TRUE,
prob = ans$prob[[s]],
useHash = FALSE
)
]
)
}

}

# Call
Expand All @@ -286,6 +296,36 @@ new_rergmito <- function(model, theta = NULL, sizes = NULL, mc.cores = 2L,...) {

}

#' @export
#' @rdname new_rergmito
#' @param i,j `i` is an integer vector indicating the indexes of the networks to
#' draw, while `j` indicates the corresponding sizes. These need not be of the
#' same length.
#' @details The indexing method, `[.ergmito_sampler`, allows extracting networks
#' directly by passing indexes. `i` indicates the index of the networks to draw,
#' which go from 1 through `2^(n*(n-1))`, and `j` indicates the requested
#' size.
#' @return The indexing method `[.ergmito_sampler` returns a named list of length
#' `length(j)`. Each element is named after the requested size and holds the
#' subset of `x$networks[[size]]` selected by `i`.
`[.ergmito_sampler` <- function(x, i, j, ...) {

  # Sizes are used as names of `x$networks`, so they are matched as characters
  j <- as.character(j)

  # Checking sizes: every requested size must be available in the sampler
  test <- which(!(j %in% as.character(x$sizes)))
  if (length(test) > 0L)
    stop(
      "Some values of `j` (requested sizes) are not included in the sampling function: ",
      paste(j[test], collapse = ", "), ".", call. = FALSE
    )

  # Extracting networks. We fill by position (rather than by name) so that
  # each requested size gets its own slot, even when `j` has repeated values.
  # Indexing `ans` with the whole vector `j` would be interpreted by `[[<-`
  # as recursive indexing and fail whenever `length(j) > 1`.
  ans <- structure(vector("list", length(j)), names = j)
  for (k in seq_along(j))
    ans[[k]] <- x$networks[[j[k]]][i]

  ans

}

#' @export
#' @rdname new_rergmito
print.ergmito_sampler <- function(x, ...) {
Expand Down
15 changes: 15 additions & 0 deletions man/new_rergmito.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 10 additions & 0 deletions playground/speed-count_stats.R
Expand Up @@ -43,3 +43,13 @@ microbenchmark::microbenchmark(
plot(b)

z <- new_rergmito(nets[[1]] ~ edges + mutual + nodeicov("a"))


# Should we still use ergm.allstats? -------------------------------------------
pset <- powerset(5)

microbenchmark::microbenchmark(
ERGMito = count_stats(pset ~ edges + mutual + ctriad),
ergm = ergm.allstats(pset[[1]] ~ edges + mutual + ctriad),
times=5
)
1 change: 1 addition & 0 deletions src/Makevars
Expand Up @@ -4,6 +4,7 @@ CXX_STD = CXX11

# This is necessary since ARMADILLO now supports OpenMP
# PKG_CXXFLAGS=-fopenmp -DARMA_64BIT_WORD
# PKG_CXXFLAGS=-DERGMITO_COUNT_STATS_DEBUG

# For testing
#PKG_CXXFLAGS=-Wall
Expand Down
98 changes: 63 additions & 35 deletions src/count_stats.cpp
@@ -1,32 +1,45 @@
#include <Rcpp.h>
using namespace Rcpp;

/* This function may be useful in the future:
* List x(1);
* x.containsElementNamed("Casa")
*/

inline double count_mutual(const IntegerMatrix & x, const NumericVector & A) {

double count = 0.0;
int count = 0;
for (int i = 0; i < x.nrow(); ++i)
for (int j = i; j < x.nrow(); ++j)
if (i != j && x(i,j) + x(j, i) > 1)
count += 1.0;
++count;

#ifdef ERGMITO_COUNT_STATS_DEBUG
print(x);
Rprintf("[debug count_mutual] %d\n", count);
#endif

return count;
return (double) count;
}

inline double count_edges(const IntegerMatrix & x, const NumericVector & A) {

double count = 0.0;

for (int i = 0; i < x.nrow(); ++i)
for (int j = 0; j < x.nrow(); ++j)
if (x(i,j) > 0)
count += 1;
int count = 0;
for (IntegerMatrix::const_iterator it = x.begin(); it != x.end(); ++it)
if (*it > 0)
++count;

#ifdef ERGMITO_COUNT_STATS_DEBUG
print(x);
Rprintf("[debug count_edges] %d\n", count);
#endif

return count;
return (double) count;
}

inline double count_ttriad(const IntegerMatrix & x, const NumericVector & A) {

double count = 0.0;
int count = 0;

for (int i = 0; i < x.nrow(); ++i)
for (int j = 0; j < x.nrow(); ++j) {
Expand All @@ -39,18 +52,18 @@ inline double count_ttriad(const IntegerMatrix & x, const NumericVector & A) {
// Label 1
if (x(i,j) == 1 && x(i,k) == 1 && x(j,k) == 1)
// if (x(j, i) == 0 && x(k,i) == 0 && x(k,j) == 0)
count += 1.0;
++count;

}
}

return count;
return (double) count;

}

inline double count_ctriad(const IntegerMatrix & x, const NumericVector & A) {

double count = 0.0;
int count = 0;

for (int i = 0; i < x.nrow(); ++i)
for (int j = 0; j < i; ++j) {
Expand All @@ -63,12 +76,12 @@ inline double count_ctriad(const IntegerMatrix & x, const NumericVector & A) {
// Label 1
if (x(i, j) == 1 && x(j, k) == 1 && x(k, i) == 1)
// if (x(j, i) == 0 && x(k, j) == 0 && x(i, k) == 0)
count += 1.0;
++count;

}
}

return count;
return (double) count;

}

Expand All @@ -79,7 +92,7 @@ inline double count_nodecov(const IntegerMatrix & x, const NumericVector & A, bo
for (int i = 0; i < x.nrow(); ++i)
for (int j = 0; j < x.nrow(); ++j)
if (x(i,j) == 1)
count += A.at(ego ? i : j);
count += A[ego ? i : j];

return count;

Expand All @@ -95,14 +108,14 @@ inline double count_nodeocov(const IntegerMatrix & x, const NumericVector & A) {


inline double count_nodematch(const IntegerMatrix & x, const NumericVector & A) {
double count = 0.0;
int count = 0;

for (int i = 0; i < x.nrow(); ++i)
for (int j = 0; j < x.nrow(); ++j)
if (x(i,j) == 1 && A.at(i) == A.at(j))
count += 1.0;
++count;

return count;
return (double) count;
}

inline double count_triangle(const IntegerMatrix & x, const NumericVector & A) {
Expand All @@ -115,14 +128,14 @@ typedef double (*ergm_term_fun)(const IntegerMatrix & x, const NumericVector & A

void get_ergm_term(std::string term, ergm_term_fun & fun) {

if (term == "mutual") fun = &count_mutual;
else if (term == "edges") fun = &count_edges;
else if (term == "ttriad") fun = &count_ttriad;
else if (term == "ctriad") fun = &count_ctriad;
else if (term == "nodeicov") fun = &count_nodeicov;
else if (term == "nodeocov") fun = &count_nodeocov;
if (term == "mutual") fun = &count_mutual;
else if (term == "edges") fun = &count_edges;
else if (term == "ttriad") fun = &count_ttriad;
else if (term == "ctriad") fun = &count_ctriad;
else if (term == "nodeicov") fun = &count_nodeicov;
else if (term == "nodeocov") fun = &count_nodeocov;
else if (term == "nodematch") fun = &count_nodematch;
else if (term == "triangle") fun = &count_triangle;
else if (term == "triangle") fun = &count_triangle;
else
stop("The term %s is not available.", term);

Expand Down Expand Up @@ -156,22 +169,37 @@ NumericMatrix count_stats(
int n = X.size();
int k = terms.size();

bool uses_attributes = false;
NumericVector A_empty(0);
if (A[0].size() != 0) {
if (A.size() != n)
stop("The number of attributes in `A` differs from the number of adjacency matrices.");

uses_attributes = true;
}

NumericMatrix ans(n, k);
ergm_term_fun fun;

int i = 0;
for (int j = 0; j < k; ++j) {

// Getting the function
get_ergm_term(terms.at(j), fun);

for (ListOf< IntegerMatrix >::const_iterator x = X.begin(); x != X.end(); ++x)
ans(i++, j) = fun(*x, A[x.index()]);

i = 0;
get_ergm_term(terms[j], fun);

if (uses_attributes) {
for (int i = 0; i < n; ++i)
ans.at(i, j) = fun(X[i], A[i]);
} else {
for (int i = 0; i < n; ++i)
ans.at(i, j) = fun(X[i], A_empty);
}

}

//
// #ifdef ERGMITO_COUNT_STATS_DEBUG
// print(ans);
// #endif
//
return ans;

}
Expand Down

0 comments on commit 1b30006

Please sign in to comment.