Skip to content

Commit

Permalink
Made batch part of kernel index in interp/gridding.
Browse files Browse the repository at this point in the history
  • Loading branch information
frankong committed Jul 29, 2018
1 parent eadc078 commit 966b0df
Showing 1 changed file with 58 additions and 67 deletions.
125 changes: 58 additions & 67 deletions sigpy/interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def gridding(input, shape, width, table, coord):
if device == util.cpu_device:
_gridding(output, input, width, table, coord)
else:
_gridding(output, input, width, table, coord, size=npts)
_gridding(output, input, width, table, coord, size=npts * batch)

return output.reshape(shape)

Expand Down Expand Up @@ -354,6 +354,8 @@ def _gridding3(output, input, width, table, coord):
"""
const int batch = input.shape()[0];
const int nx = input.shape()[1];
const int b = i % batch;
i /= batch;
const int coord_idx[] = {i, 0};
const S posx = coord[coord_idx];
Expand All @@ -362,13 +364,10 @@ def _gridding3(output, input, width, table, coord):
for (int x = startx; x < endx + 1; x++) {
const S w = lin_interp(&table[0], table.size(), fabsf((S) x - posx) / (width / 2.0));
for (int b = 0; b < batch; b++) {
const int input_idx[] = {b, pos_mod(x, nx)};
const T v = (T) w * input[input_idx];
const int output_idx[] = {b, i};
output[output_idx] += v;
}
const int input_idx[] = {b, pos_mod(x, nx)};
const T v = (T) w * input[input_idx];
const int output_idx[] = {b, i};
output[output_idx] += v;
}
""",
name='interp1', preamble=lin_interp_cuda + pos_mod_cuda, reduce_dims=False)
Expand All @@ -379,6 +378,8 @@ def _gridding3(output, input, width, table, coord):
"""
const int batch = output.shape()[0];
const int nx = output.shape()[1];
const int b = i % batch;
i /= batch;
const int coord_idx[] = {i, 0};
const S posx = coord[coord_idx];
Expand All @@ -387,13 +388,10 @@ def _gridding3(output, input, width, table, coord):
for (int x = startx; x < endx + 1; x++) {
const S w = lin_interp(&table[0], table.size(), fabsf((S) x - posx) / (width / 2.0));
for (int b = 0; b < batch; b++) {
const int input_idx[] = {b, i};
const T v = (T) w * input[input_idx];
const int output_idx[] = {b, pos_mod(x, nx)};
atomicAdd(&output[output_idx], v);
}
const int input_idx[] = {b, i};
const T v = (T) w * input[input_idx];
const int output_idx[] = {b, pos_mod(x, nx)};
atomicAdd(&output[output_idx], v);
}
""",
name='gridding1', preamble=lin_interp_cuda + pos_mod_cuda, reduce_dims=False)
Expand All @@ -404,6 +402,8 @@ def _gridding3(output, input, width, table, coord):
"""
const int batch = output.shape()[0];
const int nx = output.shape()[1];
const int b = i % batch;
i /= batch;
const int coord_idx[] = {i, 0};
const S posx = coord[coord_idx];
Expand All @@ -412,14 +412,11 @@ def _gridding3(output, input, width, table, coord):
for (int x = startx; x < endx + 1; x++) {
const S w = lin_interp(&table[0], table.size(), fabsf((S) x - posx) / (width / 2.0));
for (int b = 0; b < batch; b++) {
const int input_idx[] = {b, i};
const T v = (T) w * input[input_idx];
const int output_idx[] = {b, pos_mod(x, nx)};
atomicAdd(reinterpret_cast<T::value_type*>(&(output[output_idx])), v.real());
atomicAdd(reinterpret_cast<T::value_type*>(&(output[output_idx])) + 1, v.imag());
}
const int input_idx[] = {b, i};
const T v = (T) w * input[input_idx];
const int output_idx[] = {b, pos_mod(x, nx)};
atomicAdd(reinterpret_cast<T::value_type*>(&(output[output_idx])), v.real());
atomicAdd(reinterpret_cast<T::value_type*>(&(output[output_idx])) + 1, v.imag());
}
""",
name='gridding1_complex',
Expand All @@ -433,6 +430,8 @@ def _gridding3(output, input, width, table, coord):
const int batch = input.shape()[0];
const int ny = input.shape()[1];
const int nx = input.shape()[2];
const int b = i % batch;
i /= batch;
const int coordx_idx[] = {i, 1};
const S posx = coord[coordx_idx];
Expand All @@ -449,13 +448,10 @@ def _gridding3(output, input, width, table, coord):
const S wy = lin_interp(&table[0], table.size(), fabsf((S) y - posy) / (width / 2.0));
for (int x = startx; x < endx + 1; x++) {
const S w = wy * lin_interp(&table[0], table.size(), fabsf((S) x - posx) / (width / 2.0));
for (int b = 0; b < batch; b++) {
const int input_idx[] = {b, pos_mod(y, ny), pos_mod(x, nx)};
const T v = (T) w * input[input_idx];
const int output_idx[] = {b, i};
output[output_idx] += v;
}
const int input_idx[] = {b, pos_mod(y, ny), pos_mod(x, nx)};
const T v = (T) w * input[input_idx];
const int output_idx[] = {b, i};
output[output_idx] += v;
}
}
""",
Expand All @@ -468,6 +464,8 @@ def _gridding3(output, input, width, table, coord):
const int batch = output.shape()[0];
const int ny = output.shape()[1];
const int nx = output.shape()[2];
const int b = i % batch;
i /= batch;
const int coordx_idx[] = {i, 1};
const S posx = coord[coordx_idx];
Expand All @@ -484,13 +482,10 @@ def _gridding3(output, input, width, table, coord):
const S wy = lin_interp(&table[0], table.size(), fabsf((S) y - posy) / (width / 2.0));
for (int x = startx; x < endx + 1; x++) {
const S w = wy * lin_interp(&table[0], table.size(), fabsf((S) x - posx) / (width / 2.0));
for (int b = 0; b < batch; b++) {
const int input_idx[] = {b, i};
const T v = (T) w * input[input_idx];
const int output_idx[] = {b, pos_mod(y, ny), pos_mod(x, nx)};
atomicAdd(&output[output_idx], v);
}
const int input_idx[] = {b, i};
const T v = (T) w * input[input_idx];
const int output_idx[] = {b, pos_mod(y, ny), pos_mod(x, nx)};
atomicAdd(&output[output_idx], v);
}
}
""",
Expand All @@ -503,6 +498,8 @@ def _gridding3(output, input, width, table, coord):
const int batch = output.shape()[0];
const int ny = output.shape()[1];
const int nx = output.shape()[2];
const int b = i % batch;
i /= batch;
const int coordx_idx[] = {i, 1};
const S posx = coord[coordx_idx];
Expand All @@ -519,14 +516,11 @@ def _gridding3(output, input, width, table, coord):
const S wy = lin_interp(&table[0], table.size(), fabsf((S) y - posy) / (width / 2.0));
for (int x = startx; x < endx + 1; x++) {
const S w = wy * lin_interp(&table[0], table.size(), fabsf((S) x - posx) / (width / 2.0));
for (int b = 0; b < batch; b++) {
const int input_idx[] = {b, i};
const T v = (T) w * input[input_idx];
const int output_idx[] = {b, pos_mod(y, ny), pos_mod(x, nx)};
atomicAdd(reinterpret_cast<T::value_type*>(&(output[output_idx])), v.real());
atomicAdd(reinterpret_cast<T::value_type*>(&(output[output_idx])) + 1, v.imag());
}
const int input_idx[] = {b, i};
const T v = (T) w * input[input_idx];
const int output_idx[] = {b, pos_mod(y, ny), pos_mod(x, nx)};
atomicAdd(reinterpret_cast<T::value_type*>(&(output[output_idx])), v.real());
atomicAdd(reinterpret_cast<T::value_type*>(&(output[output_idx])) + 1, v.imag());
}
}
""",
Expand All @@ -542,6 +536,8 @@ def _gridding3(output, input, width, table, coord):
const int nz = input.shape()[1];
const int ny = input.shape()[2];
const int nx = input.shape()[3];
const int b = i % batch;
i /= batch;
const int coordz_idx[] = {i, 0};
const S posz = coord[coordz_idx];
Expand All @@ -564,13 +560,10 @@ def _gridding3(output, input, width, table, coord):
const S wy = wz * lin_interp(&table[0], table.size(), fabsf((S) y - posy) / (width / 2.0));
for (int x = startx; x < endx + 1; x++) {
const S w = wy * lin_interp(&table[0], table.size(), fabsf((S) x - posx) / (width / 2.0));
for (int b = 0; b < batch; b++) {
const int input_idx[] = {b, pos_mod(z, nz), pos_mod(y, ny), pos_mod(x, nx)};
const T v = (T) w * input[input_idx];
const int output_idx[] = {b, i};
output[output_idx] += v;
}
const int input_idx[] = {b, pos_mod(z, nz), pos_mod(y, ny), pos_mod(x, nx)};
const T v = (T) w * input[input_idx];
const int output_idx[] = {b, i};
output[output_idx] += v;
}
}
}
Expand All @@ -585,6 +578,8 @@ def _gridding3(output, input, width, table, coord):
const int nz = output.shape()[1];
const int ny = output.shape()[2];
const int nx = output.shape()[3];
const int b = i % batch;
i /= batch;
const int coordz_idx[] = {i, 0};
const S posz = coord[coordz_idx];
Expand All @@ -607,13 +602,10 @@ def _gridding3(output, input, width, table, coord):
const S wy = wz * lin_interp(&table[0], table.size(), fabsf((S) y - posy) / (width / 2.0));
for (int x = startx; x < endx + 1; x++) {
const S w = wy * lin_interp(&table[0], table.size(), fabsf((S) x - posx) / (width / 2.0));
for (int b = 0; b < batch; b++) {
const int input_idx[] = {b, i};
const T v = (T) w * input[input_idx];
const int output_idx[] = {b, pos_mod(z, nz), pos_mod(y, ny), pos_mod(x, nx)};
atomicAdd(&output[output_idx], v);
}
const int input_idx[] = {b, i};
const T v = (T) w * input[input_idx];
const int output_idx[] = {b, pos_mod(z, nz), pos_mod(y, ny), pos_mod(x, nx)};
atomicAdd(&output[output_idx], v);
}
}
}
Expand All @@ -628,6 +620,8 @@ def _gridding3(output, input, width, table, coord):
const int nz = output.shape()[1];
const int ny = output.shape()[2];
const int nx = output.shape()[3];
const int b = i % batch;
i /= batch;
const int coordz_idx[] = {i, 0};
const S posz = coord[coordz_idx];
Expand All @@ -650,14 +644,11 @@ def _gridding3(output, input, width, table, coord):
const S wy = wz * lin_interp(&table[0], table.size(), fabsf((S) y - posy) / (width / 2.0));
for (int x = startx; x < endx + 1; x++) {
const S w = wy * lin_interp(&table[0], table.size(), fabsf((S) x - posx) / (width / 2.0));
for (int b = 0; b < batch; b++) {
const int input_idx[] = {b, i};
const T v = (T) w * input[input_idx];
const int output_idx[] = {b, pos_mod(z, nz), pos_mod(y, ny), pos_mod(x, nx)};
atomicAdd(reinterpret_cast<T::value_type*>(&(output[output_idx])), v.real());
atomicAdd(reinterpret_cast<T::value_type*>(&(output[output_idx])) + 1, v.imag());
}
const int input_idx[] = {b, i};
const T v = (T) w * input[input_idx];
const int output_idx[] = {b, pos_mod(z, nz), pos_mod(y, ny), pos_mod(x, nx)};
atomicAdd(reinterpret_cast<T::value_type*>(&(output[output_idx])), v.real());
atomicAdd(reinterpret_cast<T::value_type*>(&(output[output_idx])) + 1, v.imag());
}
}
}
Expand Down

0 comments on commit 966b0df

Please sign in to comment.