Skip to content

Commit 40bb155

Browse files
committed
remove use of thrust vector everywhere and other refactoring changes
1 parent 9ce2976 commit 40bb155

File tree

50 files changed

+1999
-2503
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

50 files changed

+1999
-2503
lines changed

include/DataTypeOverloads.h

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
// ---------------------------------------------------------------------
2+
//
3+
// Copyright (c) 2017-2022 The Regents of the University of Michigan and DFT-FE
4+
// authors.
5+
//
6+
// This file is part of the DFT-FE code.
7+
//
8+
// The DFT-FE code is free software; you can use it, redistribute
9+
// it, and/or modify it under the terms of the GNU Lesser General
10+
// Public License as published by the Free Software Foundation; either
11+
// version 2.1 of the License, or (at your option) any later version.
12+
// The full text of the license can be found in the file LICENSE at
13+
// the top level of the DFT-FE distribution.
14+
//
15+
// ---------------------------------------------------------------------
16+
17+
18+
#ifndef dftfeDataTypeOverloads_h
19+
#define dftfeDataTypeOverloads_h
20+
21+
#include <complex>
22+
namespace dftfe
{
  namespace utils
  {
    /**
     * @brief Real part of a real double-precision value
     * @param[in] x input value
     * @returns x itself
     */
    inline double
    realPart(const double x)
    {
      return x;
    }

    /**
     * @brief Real part of a real single-precision value
     * @param[in] x input value
     * @returns x itself
     */
    inline float
    realPart(const float x)
    {
      return x;
    }

    /**
     * @brief Real part of a double-precision complex value
     * @param[in] x input value
     * @returns x.real()
     */
    inline double
    realPart(const std::complex<double> x)
    {
      return x.real();
    }

    /**
     * @brief Real part of a single-precision complex value
     * @param[in] x input value
     * @returns x.real()
     */
    inline float
    realPart(const std::complex<float> x)
    {
      return x.real();
    }

    /**
     * @brief Imaginary part of a real double-precision value
     * @returns zero; the argument is not read, so it is left unnamed to
     * avoid unused-parameter warnings
     */
    inline double
    imagPart(const double /*x*/)
    {
      return 0.0;
    }

    /**
     * @brief Imaginary part of a real single-precision value
     * @returns zero; the argument is not read, so it is left unnamed to
     * avoid unused-parameter warnings
     */
    inline float
    imagPart(const float /*x*/)
    {
      return 0.0f;
    }

    /**
     * @brief Imaginary part of a double-precision complex value
     * @param[in] x input value
     * @returns x.imag()
     */
    inline double
    imagPart(const std::complex<double> x)
    {
      return x.imag();
    }

    /**
     * @brief Imaginary part of a single-precision complex value
     * @param[in] x input value
     * @returns x.imag()
     */
    inline float
    imagPart(const std::complex<float> x)
    {
      return x.imag();
    }

    /**
     * @brief Complex conjugate of a real double-precision value
     * @param[in] x input value
     * @returns x itself
     */
    inline double
    complexConj(const double x)
    {
      return x;
    }

    /**
     * @brief Complex conjugate of a real single-precision value
     * @param[in] x input value
     * @returns x itself
     */
    inline float
    complexConj(const float x)
    {
      return x;
    }

    /**
     * @brief Complex conjugate of a double-precision complex value
     * @param[in] x input value
     * @returns std::conj(x)
     */
    inline std::complex<double>
    complexConj(const std::complex<double> x)
    {
      return std::conj(x);
    }

    /**
     * @brief Complex conjugate of a single-precision complex value
     * @param[in] x input value
     * @returns std::conj(x)
     */
    inline std::complex<float>
    complexConj(const std::complex<float> x)
    {
      return std::conj(x);
    }
  } // namespace utils
} // namespace dftfe
100+
101+
#endif

include/DeviceDataTypeOverloads.cu.h

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,12 @@ namespace dftfe
5555
return a;
5656
}
5757

58+
// Device-side conj overload for global_size_type: returns `a` unchanged.
__inline__ __device__ global_size_type
59+
conj(global_size_type a)
60+
{
61+
return a;
62+
}
63+
5864
__inline__ __device__ int
5965
conj(int a)
6066
{
@@ -94,6 +100,12 @@ namespace dftfe
94100
return a * b;
95101
}
96102

103+
// Device-side mult overload for global_size_type: returns the product a * b
// using the type's built-in operator*.
__inline__ __device__ global_size_type
104+
mult(global_size_type a, global_size_type b)
105+
{
106+
return a * b;
107+
}
108+
97109
__inline__ __device__ int
98110
mult(int a, int b)
99111
{
@@ -192,6 +204,12 @@ namespace dftfe
192204
return a + b;
193205
}
194206

207+
// Device-side add overload for global_size_type: returns the sum a + b
// using the type's built-in operator+.
__inline__ __device__ global_size_type
208+
add(global_size_type a, global_size_type b)
209+
{
210+
return a + b;
211+
}
212+
195213
__inline__ __device__ int
196214
add(int a, int b)
197215
{
@@ -229,6 +247,12 @@ namespace dftfe
229247
return a - b;
230248
}
231249

250+
// Device-side sub overload for global_size_type: returns the difference a - b.
// NOTE(review): if global_size_type is an unsigned integer type, the result
// wraps around when b > a -- confirm against the typedef and the callers.
__inline__ __device__ global_size_type
251+
sub(global_size_type a, global_size_type b)
252+
{
253+
return a - b;
254+
}
255+
232256
__inline__ __device__ int
233257
sub(int a, int b)
234258
{
@@ -265,6 +289,12 @@ namespace dftfe
265289
return a / b;
266290
}
267291

292+
// Device-side div overload for global_size_type: returns the quotient a / b.
// No check for b == 0 is made here; callers must supply a non-zero divisor.
// NOTE(review): presumably global_size_type is an integer type, in which case
// the quotient is truncated -- confirm against the typedef.
__inline__ __device__ global_size_type
293+
div(global_size_type a, global_size_type b)
294+
{
295+
return a / b;
296+
}
297+
268298
__inline__ __device__ int
269299
div(int a, int b)
270300
{
@@ -378,6 +408,18 @@ namespace dftfe
378408
return a;
379409
}
380410

411+
// Overload for a pointer to global_size_type: returns the pointer unchanged
// (no conversion is performed in this overload).
inline global_size_type *
412+
makeDataTypeDeviceCompatible(global_size_type *a)
413+
{
414+
return a;
415+
}
416+
417+
// Overload for a pointer to const global_size_type: returns the pointer
// unchanged (no conversion is performed in this overload).
inline const global_size_type *
418+
makeDataTypeDeviceCompatible(const global_size_type *a)
419+
{
420+
return a;
421+
}
422+
381423
inline double *
382424
makeDataTypeDeviceCompatible(double *a)
383425
{
@@ -438,6 +480,12 @@ namespace dftfe
438480
return a;
439481
}
440482

483+
// Overload for a global_size_type value: returns the value unchanged
// (no conversion is performed in this overload).
inline global_size_type
484+
makeDataTypeDeviceCompatible(global_size_type a)
485+
{
486+
return a;
487+
}
488+
441489
inline double
442490
makeDataTypeDeviceCompatible(double a)
443491
{

include/DeviceTypeConfig.cu.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,13 @@
1717
#ifndef dftfeDeviceTypeConfig_cuh
1818
#define dftfeDeviceTypeConfig_cuh
1919

20+
# include <cuComplex.h>
2021

2122
namespace dftfe
2223
{
2324
namespace utils
2425
{
26+
typedef cuDoubleComplex deviceDoubleComplex;
2527
typedef cudaStream_t deviceStream_t;
2628
typedef cudaEvent_t deviceEvent_t;
2729
typedef cudaError_t deviceError_t;

include/MemoryStorage.h

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,13 @@ namespace dftfe
8181
*/
8282
~MemoryStorage();
8383

84+
/**
85+
* @brief Clear the storage and set d_data to nullptr
86+
*/
87+
void
88+
clear();
89+
90+
8491
/**
8592
* @brief Set all the entries to a given value
8693
* @param[in] val The value to which the entries are to be set
@@ -142,26 +149,23 @@ namespace dftfe
142149
MemoryStorage &
143150
operator=(MemoryStorage &&rhs) noexcept;
144151

145-
// // This part does not work for GPU version, will work on this
146-
// until
147-
// // having cleaner solution.
148-
// /**
149-
// * @brief Operator to get a reference to a element of the Vector
150-
// * @param[in] i is the index to the element of the Vector
151-
// * @returns reference to the element of the Vector
152-
// * @throws exception if i >= size of the Vector
153-
// */
154-
// reference
155-
// operator[](size_type i);
156-
//
157-
// /**
158-
// * @brief Operator to get a const reference to a element of the Vector
159-
// * @param[in] i is the index to the element of the Vector
160-
// * @returns const reference to the element of the Vector
161-
// * @throws exception if i >= size of the Vector
162-
// */
163-
// const_reference
164-
// operator[](size_type i) const;
152+
/**
153+
* @brief Operator to get a reference to an element of the Vector
154+
* @param[in] i is the index to the element of the Vector
155+
* @returns reference to the element of the Vector
156+
* @throws exception if i >= size of the Vector
157+
*/
158+
reference
159+
operator[](size_type i);
160+
161+
/**
162+
* @brief Operator to get a const reference to an element of the Vector
163+
* @param[in] i is the index to the element of the Vector
164+
* @returns const reference to the element of the Vector
165+
* @throws exception if i >= size of the Vector
166+
*/
167+
const_reference
168+
operator[](size_type i) const;
165169

166170
void
167171
swap(MemoryStorage &rhs);

include/chebyshevOrthogonalizedSubspaceIterationSolverDevice.h

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,8 @@ namespace dftfe
6060
double
6161
solve(operatorDFTDeviceClass & operatorMatrix,
6262
elpaScalaManager & elpaScala,
63-
dataTypes::numberDevice *eigenVectorsFlattenedDevice,
64-
dataTypes::numberDevice *eigenVectorsRotFracDensityFlattenedDevice,
63+
dataTypes::number *eigenVectorsFlattenedDevice,
64+
dataTypes::number *eigenVectorsRotFracDensityFlattenedDevice,
6565
const unsigned int flattenedSize,
6666
const unsigned int totalNumberWaveFunctions,
6767
std::vector<double> & eigenValues,
@@ -80,7 +80,7 @@ namespace dftfe
8080
void
8181
solveNoRR(operatorDFTDeviceClass & operatorMatrix,
8282
elpaScalaManager & elpaScala,
83-
dataTypes::numberDevice *eigenVectorsFlattenedDevice,
83+
dataTypes::number *eigenVectorsFlattenedDevice,
8484
const unsigned int flattenedSize,
8585
const unsigned int totalNumberWaveFunctions,
8686
std::vector<double> & eigenValues,
@@ -96,7 +96,7 @@ namespace dftfe
9696
void
9797
densityMatrixEigenBasisFirstOrderResponse(
9898
operatorDFTDeviceClass & operatorMatrix,
99-
dataTypes::numberDevice * eigenVectorsFlattenedDevice,
99+
dataTypes::number * eigenVectorsFlattenedDevice,
100100
const unsigned int flattenedSize,
101101
const unsigned int totalNumberWaveFunctions,
102102
const std::vector<double> &eigenValues,
@@ -137,16 +137,16 @@ namespace dftfe
137137
//
138138
// temporary parallel vectors needed for Chebyshev filtering
139139
//
140-
distributedDeviceVec<dataTypes::numberDevice> d_YArray;
140+
distributedDeviceVec<dataTypes::number> d_YArray;
141141

142-
distributedDeviceVec<dataTypes::numberFP32Device>
142+
distributedDeviceVec<dataTypes::numberFP32>
143143
d_deviceFlattenedFloatArrayBlock;
144144

145-
distributedDeviceVec<dataTypes::numberDevice> d_deviceFlattenedArrayBlock2;
145+
distributedDeviceVec<dataTypes::number> d_deviceFlattenedArrayBlock2;
146146

147-
distributedDeviceVec<dataTypes::numberDevice> d_YArray2;
147+
distributedDeviceVec<dataTypes::number> d_YArray2;
148148

149-
distributedDeviceVec<dataTypes::numberDevice> d_projectorKetTimesVector2;
149+
distributedDeviceVec<dataTypes::number> d_projectorKetTimesVector2;
150150

151151
bool d_isTemporaryParallelVectorsCreated;
152152

include/constraintMatrixInfoDevice.h

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,10 @@
1919
# ifndef constraintMatrixInfoDevice_H_
2020
# define constraintMatrixInfoDevice_H_
2121

22-
# include <thrust/device_vector.h>
23-
22+
#include <MemoryStorage.h>
2423
# include <vector>
2524

26-
# include "headers.h"
25+
# include <headers.h>
2726

2827
namespace dftfe
2928
{
@@ -105,7 +104,7 @@ namespace dftfe
105104

106105
inline void
107106
distribute_slave_to_master(
108-
distributedDeviceVec<cuDoubleComplex> &fieldVector,
107+
distributedDeviceVec<std::complex<double>> &fieldVector,
109108
const unsigned int blockSize) const
110109
{}
111110

@@ -121,7 +120,7 @@ namespace dftfe
121120
*/
122121
void
123122
distribute_slave_to_master(
124-
distributedDeviceVec<cuDoubleComplex> &fieldVector,
123+
distributedDeviceVec<std::complex<double>> &fieldVector,
125124
double * tempReal,
126125
double * tempImag,
127126
const unsigned int blockSize) const;
@@ -146,7 +145,7 @@ namespace dftfe
146145
*/
147146
void
148147
distribute_slave_to_master(
149-
distributedDeviceVec<cuFloatComplex> &fieldVector,
148+
distributedDeviceVec<std::complex<float>> &fieldVector,
150149
float * tempReal,
151150
float * tempImag,
152151
const unsigned int blockSize) const;
@@ -180,13 +179,13 @@ namespace dftfe
180179
std::vector<dealii::types::global_dof_index>
181180
d_localIndexMapUnflattenedToFlattened;
182181

183-
thrust::device_vector<unsigned int> d_rowIdsLocalDevice;
184-
thrust::device_vector<unsigned int> d_columnIdsLocalDevice;
185-
thrust::device_vector<double> d_columnValuesDevice;
186-
thrust::device_vector<double> d_inhomogenitiesDevice;
187-
thrust::device_vector<unsigned int> d_rowSizesDevice;
188-
thrust::device_vector<unsigned int> d_rowSizesAccumulatedDevice;
189-
thrust::device_vector<dealii::types::global_dof_index>
182+
dftfe::utils::MemoryStorage<unsigned int, dftfe::utils::MemorySpace::DEVICE> d_rowIdsLocalDevice;
183+
dftfe::utils::MemoryStorage<unsigned int, dftfe::utils::MemorySpace::DEVICE> d_columnIdsLocalDevice;
184+
dftfe::utils::MemoryStorage<double, dftfe::utils::MemorySpace::DEVICE> d_columnValuesDevice;
185+
dftfe::utils::MemoryStorage<double, dftfe::utils::MemorySpace::DEVICE> d_inhomogenitiesDevice;
186+
dftfe::utils::MemoryStorage<unsigned int, dftfe::utils::MemorySpace::DEVICE> d_rowSizesDevice;
187+
dftfe::utils::MemoryStorage<unsigned int, dftfe::utils::MemorySpace::DEVICE> d_rowSizesAccumulatedDevice;
188+
dftfe::utils::MemoryStorage<dealii::types::global_dof_index, dftfe::utils::MemorySpace::DEVICE>
190189
d_localIndexMapUnflattenedToFlattenedDevice;
191190

192191
std::vector<unsigned int> d_rowIdsLocalBins;
@@ -196,10 +195,10 @@ namespace dftfe
196195
std::vector<unsigned int> d_binColumnSizes;
197196
std::vector<unsigned int> d_binColumnSizesAccumulated;
198197

199-
thrust::device_vector<unsigned int> d_rowIdsLocalBinsDevice;
200-
thrust::device_vector<unsigned int> d_columnIdsLocalBinsDevice;
201-
thrust::device_vector<unsigned int> d_columnIdToRowIdMapBinsDevice;
202-
thrust::device_vector<double> d_columnValuesBinsDevice;
198+
dftfe::utils::MemoryStorage<unsigned int, dftfe::utils::MemorySpace::DEVICE> d_rowIdsLocalBinsDevice;
199+
dftfe::utils::MemoryStorage<unsigned int, dftfe::utils::MemorySpace::DEVICE> d_columnIdsLocalBinsDevice;
200+
dftfe::utils::MemoryStorage<unsigned int, dftfe::utils::MemorySpace::DEVICE> d_columnIdToRowIdMapBinsDevice;
201+
dftfe::utils::MemoryStorage<double, dftfe::utils::MemorySpace::DEVICE> d_columnValuesBinsDevice;
203202

204203
unsigned int d_numConstrainedDofs;
205204
};

0 commit comments

Comments
 (0)