Permalink
Browse files

power ep/convex ep.

set power to 0 for convexity, 1 for standard. something else for craziness
  • Loading branch information...
1 parent d58eff6 commit 55dab7018c6ac2e7d6d9232aef6f2362a003e70c @dlwh committed May 3, 2012
Showing with 32 additions and 9 deletions.
  1. +9 −3 gpuep/ising.c
  2. +1 −1 gpuep/ising.h
  3. +12 −2 gpuep/kernel.cl
  4. +10 −3 gpuep/main.c
View
@@ -38,7 +38,7 @@ void random_fill_ising(ising_t *ising, float lowerBound, float upperBound, unsig
}
-int do_inference(ising_t* result, ising_t model, cl_context context, cl_device_id device_id, int numIter) {
+int do_inference(ising_t* result, ising_t model, cl_context context, cl_device_id device_id, float power, int numIter) {
construct_ising(result, model.rows, model.cols);
char* KernelSource=read_kernel("kernel.cl");
@@ -88,6 +88,11 @@ int do_inference(ising_t* result, ising_t model, cl_context context, cl_device_i
}
unsigned count = model.rows * model.cols;
+ // every node has 2 edges (one to the right, and one down) except for bottom row and the right column, which have one fewer.
+ if(power == 0) {
+ float numEdges = 2 * (model.rows * model.cols) - model.rows - model.cols;
+ power = numEdges;
+ }
cl_mem pair = clCreateBuffer(context, CL_MEM_READ_ONLY, sizeof(float) * count * 2, NULL, NULL);
cl_mem single = clCreateBuffer(context, CL_MEM_READ_ONLY, sizeof(float) * count, NULL, NULL);
@@ -126,8 +131,9 @@ int do_inference(ising_t* result, ising_t model, cl_context context, cl_device_i
err |= clSetKernelArg(kernelInf, 1, sizeof(cl_mem), &single_out);
err |= clSetKernelArg(kernelInf, 2, sizeof(cl_mem), &message1);
err |= clSetKernelArg(kernelInf, 3, sizeof(cl_mem), &message2);
- err |= clSetKernelArg(kernelInf, 4, sizeof(int), &model.rows);
- err |= clSetKernelArg(kernelInf, 5, sizeof(int), &model.cols);
+ err |= clSetKernelArg(kernelInf, 4, sizeof(float), &power);
+ err |= clSetKernelArg(kernelInf, 5, sizeof(int), &model.rows);
+ err |= clSetKernelArg(kernelInf, 6, sizeof(int), &model.cols);
if (err != CL_SUCCESS)
{
printf("Error: Failed to set kernel arguments! %d\n", err);
View
@@ -28,7 +28,7 @@ extern void ising_print_single(ising_t ising);
extern void ising_print_pair(ising_t ising);
/// Inference routines
-extern int do_inference(ising_t* result, ising_t model, cl_context context, cl_device_id device_id, int numIter);
+extern int do_inference(ising_t* result, ising_t model, cl_context context, cl_device_id device_id, float power, int numIter);
extern int sequential_inference(ising_t* result, ising_t model, int numIter);
static inline float get_ising_singleton(ising_t* ising, int row, int col) {
View
@@ -10,7 +10,9 @@ __kernel void updateFactor(__global float* ising_pair,
__global float* ising_single,
__global float* ising_message,
__global float* ising_message_out,
+ float numEdges,
int rows, int cols) {
+ float one_over_numEdges = native_recip(numEdges);
int r = get_global_id(0);
int c = get_global_id(1);
im_dir_t dir = get_global_id(2);
@@ -24,8 +26,8 @@ __kernel void updateFactor(__global float* ising_pair,
float edgeWeight = ising_pair[(r * cols + c) * 2 + dir];
float mesgToA = ising_message[(r * cols + c) * 4 + dir];
float mesgToB = ising_message[(nr * cols + nc) * 4 + otherDir];
- marginalWeightA -= mesgToA;
- marginalWeightB -= mesgToB;
+ marginalWeightA -= mesgToA * one_over_numEdges;
+ marginalWeightB -= mesgToB * one_over_numEdges;
float jointMarginal11 = exp(marginalWeightA + marginalWeightB + edgeWeight);
float jointMarginal10 = exp(marginalWeightA);
@@ -41,6 +43,14 @@ __kernel void updateFactor(__global float* ising_pair,
float adjA = newTargetA - marginalWeightA;
float adjB = newTargetB - marginalWeightB;
+ // damp updates
+ if(numEdges != 1.0f) {
+ adjA *= one_over_numEdges;
+ adjB *= one_over_numEdges;
+ adjA += (1 - one_over_numEdges) * mesgToA;
+ adjB += (1 - one_over_numEdges) * mesgToB;
+ }
+
ising_message_out[(r * cols + c) * 4 + dir] = adjA;
ising_message_out[(nr * cols + nc) * 4 + otherDir] = adjB;
}
View
@@ -12,9 +12,12 @@
void test(int a, int b);
+#define NUM_ROWS 5
+#define NUM_COLS 6
+
int main (int argc, const char * argv[]) {
ising_t input;
- construct_ising(&input, 5, 6);
+ construct_ising(&input, NUM_ROWS, NUM_COLS);
ising_t output;
unsigned seed = 3;
@@ -42,15 +45,19 @@ int main (int argc, const char * argv[]) {
printf("model:\n");
ising_print(input);
- do_inference(&output, input, context, device_id, 400);
+ do_inference(&output, input, context, device_id, 0, 400);
+ printf("EP parallel convex:\n");
+ ising_print_single(output);
+
+ do_inference(&output, input, context, device_id, 1, 400);
printf("EP parallel:\n");
ising_print_single(output);
printf("EP sequential:\n");
sequential_inference(&output, input, 400);
ising_print_single(output);
- //ising_t exact;
+ ising_t exact;
// exact_marginals_parallel(&exact, input, context, device_id);
// printf("Exact parallel log domain:\n");
// ising_print_single(exact);

0 comments on commit 55dab70

Please sign in to comment.