Skip to content

Commit

Permalink
Support mixed precision with vexcl backend
Browse files Browse the repository at this point in the history
  • Loading branch information
ddemidov committed Aug 26, 2018
1 parent bc78704 commit 332fdc6
Showing 1 changed file with 19 additions and 21 deletions.
40 changes: 19 additions & 21 deletions amgcl/backend/vexcl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -305,17 +305,16 @@ struct bytes_impl< solver::vexcl_skyline_lu<V> > {
}
};

template < typename Alpha, typename Beta, typename V, typename C, typename P >
template < typename Alpha, typename Beta, typename Va, typename Vx, typename Vy, typename C, typename P >
struct spmv_impl<
Alpha, vex::sparse::distributed<vex::sparse::matrix<V,C,P>>, vex::vector<V>,
Beta, vex::vector<V>
Alpha, vex::sparse::distributed<vex::sparse::matrix<Va,C,P>>, vex::vector<Vx>,
Beta, vex::vector<Vy>
>
{
typedef vex::sparse::distributed<vex::sparse::matrix<V,C,P>> matrix;
typedef vex::vector<V> vector;
typedef vex::sparse::distributed<vex::sparse::matrix<Va,C,P>> matrix;

static void apply(Alpha alpha, const matrix &A, const vector &x,
Beta beta, vector &y)
static void apply(Alpha alpha, const matrix &A, const vex::vector<Vx> &x,
Beta beta, vex::vector<Vy> &y)
{
if (beta)
y = alpha * (A * x) + beta * y;
Expand All @@ -324,19 +323,18 @@ struct spmv_impl<
}
};

template < typename V, typename C, typename P >
template < typename Va, typename Vf, typename Vx, typename Vr, typename C, typename P >
struct residual_impl<
vex::sparse::distributed<vex::sparse::matrix<V,C,P>>,
vex::vector<V>,
vex::vector<V>,
vex::vector<V>
vex::sparse::distributed<vex::sparse::matrix<Va,C,P>>,
vex::vector<Vf>,
vex::vector<Vx>,
vex::vector<Vr>
>
{
typedef vex::sparse::distributed<vex::sparse::matrix<V,C,P>> matrix;
typedef vex::vector<V> vector;
typedef vex::sparse::distributed<vex::sparse::matrix<Va,C,P>> matrix;

static void apply(const vector &rhs, const matrix &A, const vector &x,
vector &r)
static void apply(const vex::vector<Vf> &rhs, const matrix &A, const vex::vector<Vx> &x,
vex::vector<Vr> &r)
{
r = rhs - A * x;
}
Expand Down Expand Up @@ -425,14 +423,14 @@ struct axpbypcz_impl<
}
};

template < typename A, typename B, typename V >
template < typename A, typename B, typename Vx, typename Vy, typename Vz >
struct vmul_impl<
A, vex::vector<V>, vex::vector<V>,
B, vex::vector<V>
A, vex::vector<Vx>, vex::vector<Vy>,
B, vex::vector<Vz>
>
{
static void apply(A a, const vex::vector<V> &x, const vex::vector<V> &y,
B b, vex::vector<V> &z)
static void apply(A a, const vex::vector<Vx> &x, const vex::vector<Vy> &y,
B b, vex::vector<Vz> &z)
{
if (b)
z = a * x * y + b * z;
Expand Down

0 comments on commit 332fdc6

Please sign in to comment.