Skip to content

Commit

Permalink
[Fix] Add TVM_DLL to Disco session
Browse files Browse the repository at this point in the history
  • Loading branch information
MasterJH5574 committed Mar 29, 2024
1 parent 3117505 commit 6d47d37
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 17 deletions.
16 changes: 8 additions & 8 deletions include/tvm/runtime/disco/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,28 +62,28 @@ inline std::string ReduceKind2String(ReduceKind kind) {
* \param device The default device used to initialize the RelaxVM
* \return The RelaxVM as a runtime Module
*/
Module LoadVMModule(std::string path, Device device);
TVM_DLL Module LoadVMModule(std::string path, Device device);
/*!
* \brief Create an uninitialized empty NDArray
* \param shape The shape of the NDArray
* \param dtype The dtype of the NDArray
* \param device The device the NDArray is created on. If None, use the thread local default device
* \return The NDArray created
*/
NDArray DiscoEmptyNDArray(ShapeTuple shape, DataType dtype, Device device);
TVM_DLL NDArray DiscoEmptyNDArray(ShapeTuple shape, DataType dtype, Device device);
/*!
* \brief Perform an allreduce operation using the underlying communication library
* \param send The array send to perform allreduce on
* \param reduce_kind The kind of reduction operation (e.g. sum, avg, min, max)
* \param recv The array receives the outcome of allreduce
*/
void AllReduce(NDArray send, ReduceKind reduce_kind, NDArray recv);
TVM_DLL void AllReduce(NDArray send, ReduceKind reduce_kind, NDArray recv);
/*!
* \brief Perform an allgather operation using the underlying communication library
* \param send The array send to perform allgather on
* \param recv The array receives the outcome of allgather
*/
void AllGather(NDArray send, NDArray recv);
TVM_DLL void AllGather(NDArray send, NDArray recv);
/*!
* \brief Perform a broadcast operation from worker-0
* \param send The buffer to be broadcasted
Expand All @@ -103,20 +103,20 @@ TVM_DLL void ScatterFromWorker0(Optional<NDArray> send, NDArray recv);
* \param recv For worker-0, it must be provided, and otherwise, the buffer must be None. The
* receiving buffer will be divided into equal parts and receive from each worker accordingly.
*/
void GatherToWorker0(NDArray send, Optional<NDArray> recv);
TVM_DLL void GatherToWorker0(NDArray send, Optional<NDArray> recv);
/*!
* \brief Receive a buffer from worker-0. No-op if the current worker is worker-0.
* \param buffer The buffer to be received
*/
void RecvFromWorker0(NDArray buffer);
TVM_DLL void RecvFromWorker0(NDArray buffer);
/*! \brief Get the local worker id */
int WorkerId();
TVM_DLL int WorkerId();
/*!
* \brief Called by the worker thread. Waiting until the worker completes all its tasks.
* As a specific example, on a CUDA worker, it blocks until all kernels are launched and
* cudaStreamSynchronize is complete.
*/
void SyncWorker();
TVM_DLL void SyncWorker();

} // namespace runtime
} // namespace tvm
Expand Down
18 changes: 9 additions & 9 deletions include/tvm/runtime/disco/session.h
Original file line number Diff line number Diff line change
Expand Up @@ -196,51 +196,51 @@ class SessionObj : public Object {
* The second element must be 0, which will later be updated by the session to return reg_id
* The thirtd element is the function to be called.
*/
virtual DRef CallWithPacked(const TVMArgs& args) = 0;
TVM_DLL virtual DRef CallWithPacked(const TVMArgs& args) = 0;
/*! \brief Get a global functions on workers. */
virtual DRef GetGlobalFunc(const std::string& name) = 0;
TVM_DLL virtual DRef GetGlobalFunc(const std::string& name) = 0;
/*!
* \brief Copy an NDArray from worker-0 to the controler-side NDArray
* \param host_array The array to be copied to worker-0
* \param remote_array The NDArray on worker-0
*/
virtual void CopyFromWorker0(const NDArray& host_array, const DRef& remote_array) = 0;
TVM_DLL virtual void CopyFromWorker0(const NDArray& host_array, const DRef& remote_array) = 0;
/*!
* \brief Copy the controler-side NDArray to worker-0
* \param host_array The array to be copied to worker-0
* \param remote_array The NDArray on worker-0
*/
virtual void CopyToWorker0(const NDArray& host_array, const DRef& remote_array) = 0;
TVM_DLL virtual void CopyToWorker0(const NDArray& host_array, const DRef& remote_array) = 0;
/*!
* \brief Synchrnoize the controler with a worker, and it will wait until worker finishes
* executing this instruction.
* \param worker_id The id of the worker to be synced with.
* \note This function is usually used for worker-0, because it is the only worker that is
* assumed to collocate with the controler. Syncing with other workers may not be supported.
*/
virtual void SyncWorker(int worker_id) = 0;
TVM_DLL virtual void SyncWorker(int worker_id) = 0;
/*! \brief Signal all the workers to shutdown */
virtual void Shutdown() = 0;
TVM_DLL virtual void Shutdown() = 0;
/*!
* \brief Initialize the data plane between workers.
* \param ccl The name of the communication backend, e.g., nccl, rccl, mpi.
* \param device_ids The device ids of the workers.
*/
virtual void InitCCL(String ccl, IntTuple device_ids) = 0;
TVM_DLL virtual void InitCCL(String ccl, IntTuple device_ids) = 0;
/*!
* \brief Get the value of a register from a remote worker.
* \param reg_id The id of the register to be fetched.
* \param worker_id The id of the worker to be fetched from.
* \return The value of the register.
*/
virtual TVMRetValue DebugGetFromRemote(int64_t reg_id, int worker_id) = 0;
TVM_DLL virtual TVMRetValue DebugGetFromRemote(int64_t reg_id, int worker_id) = 0;
/*!
* \brief Set the value of a register on a remote worker.
* \param reg_id The id of the register to be set.
* \param value The value to be set.
* \param worker_id The id of the worker to be set.
*/
virtual void DebugSetRegister(int64_t reg_id, TVMArgValue value, int worker_id) = 0;
TVM_DLL virtual void DebugSetRegister(int64_t reg_id, TVMArgValue value, int worker_id) = 0;

struct FFI;
friend struct SessionObj::FFI;
Expand Down

0 comments on commit 6d47d37

Please sign in to comment.