Async gmem copy support on sm80+ (#1619)
* Initial support for cp.async on Ampere
shmsong committed May 24, 2022
1 parent 69354da commit 1bb7b65
Showing 23 changed files with 965 additions and 152 deletions.
16 changes: 16 additions & 0 deletions codegen.cpp
@@ -492,6 +492,15 @@ class CudaKernelGenerator : private OptOutConstDispatch {
return ss.str();
}

// Utility function to emit a cp.async intrinsic
void genCpAsync(const LoadStoreOp* ldst, int vec_size) {
auto dtype = ldst->in()->getDataType().value();

indent() << "Ampere::cpAsync("
<< genVectorPointer(ldst->out(), dtype, vec_size) << ","
<< genVectorPointer(ldst->in(), dtype, vec_size) << ");\n";
}

void genLdMatrix(const LoadStoreOp* ldst, int vector_word_size) {
auto dtype = ldst->in()->getDataType().value();
indent() << "Turing::ldMatrix";
@@ -1196,6 +1205,9 @@ class CudaKernelGenerator : private OptOutConstDispatch {
vectorize_op, "LdMatrix: Vectorization required: ", ldst);
genLdMatrix(ldst, vector_word_size);
break;
case LoadStoreOpType::CpAsync:
genCpAsync(ldst, vector_word_size);
break;
default:
TORCH_INTERNAL_ASSERT(false, "LoadStoreOp: Unknown op type");
}
@@ -2120,6 +2132,10 @@ class CudaKernelGenerator : private OptOutConstDispatch {
}
}

void handle(const kir::CpAsyncWait* cpasync_wait) final {
indent() << "Ampere::cpAsyncBarrier();\n";
}

void handle(const kir::GridSync* sync) final {
// Use a custom synchronization method if enabled
bool bidx = sync->syncDims().get(ParallelType::BIDx);
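The Ampere::cpAsync and Ampere::cpAsyncBarrier calls emitted above resolve to helpers in the fuser runtime headers, which are not among the files shown in this diff. Below is a minimal sketch of what such helpers can look like on sm80+; the names mirror the generated calls, but the exact signature, the byte-size template parameter, and the choice of the .ca cache operator are assumptions, not code from this commit.

// Sketch only: the real runtime helpers live outside this diff.
namespace Ampere {

// cp.async takes a 32-bit shared-memory window address, not a generic
// pointer, so convert first.
__device__ inline unsigned toSmem(const void* ptr) {
  return static_cast<unsigned>(__cvta_generic_to_shared(ptr));
}

// Asynchronously copy Bytes (4, 8, or 16) from global to shared memory.
// The destination is not safe to read until a cp.async wait is issued.
template <int Bytes>
__device__ inline void cpAsync(void* smem_ptr, const void* gmem_ptr) {
  static_assert(
      Bytes == 4 || Bytes == 8 || Bytes == 16,
      "cp.async supports only 4, 8, or 16 byte transfers");
  asm volatile(
      "cp.async.ca.shared.global [%0], [%1], %2;\n" ::"r"(toSmem(smem_ptr)),
      "l"(gmem_ptr),
      "n"(Bytes));
}

// Emitted for kir::CpAsyncWait: block this thread until all of its
// outstanding cp.async copies have completed.
__device__ inline void cpAsyncBarrier() {
  asm volatile("cp.async.wait_all;\n");
}

} // namespace Ampere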
15 changes: 15 additions & 0 deletions dispatch.cpp
@@ -151,6 +151,9 @@ void Expr::dispatch(T handler, Expr* expr) {
case ExprType::GridSync:
ptr(handler)->handle(expr->as<kir::GridSync>());
return;
case ExprType::CpAsyncWait:
ptr(handler)->handle(expr->as<kir::CpAsyncWait>());
return;
case ExprType::InitMagicZero:
ptr(handler)->handle(expr->as<kir::InitMagicZero>());
return;
@@ -304,6 +307,9 @@ void Expr::constDispatch(T handler, const Expr* expr) {
case ExprType::GridSync:
ptr(handler)->handle(expr->as<kir::GridSync>());
return;
case ExprType::CpAsyncWait:
ptr(handler)->handle(expr->as<kir::CpAsyncWait>());
return;
case ExprType::InitMagicZero:
ptr(handler)->handle(expr->as<kir::InitMagicZero>());
return;
@@ -465,6 +471,9 @@ void Expr::mutatorDispatch(T mutator, Expr* expr) {
case ExprType::GridSync:
ptr(mutator)->mutate(expr->as<kir::GridSync>());
return;
case ExprType::CpAsyncWait:
ptr(mutator)->mutate(expr->as<kir::CpAsyncWait>());
return;
case ExprType::InitMagicZero:
ptr(mutator)->mutate(expr->as<kir::InitMagicZero>());
return;
@@ -691,6 +700,9 @@ void OptOutConstDispatch::handle(const kir::BlockSync* stmt) {
void OptOutConstDispatch::handle(const kir::GridSync* stmt) {
unhandled(stmt);
}
void OptOutConstDispatch::handle(const kir::CpAsyncWait* stmt) {
unhandled(stmt);
}
void OptOutConstDispatch::handle(const kir::InitMagicZero* stmt) {
unhandled(stmt);
}
@@ -814,6 +826,9 @@ void OptOutDispatch::handle(kir::BlockSync* stmt) {
void OptOutDispatch::handle(kir::GridSync* stmt) {
unhandled(stmt);
}
void OptOutDispatch::handle(kir::CpAsyncWait* stmt) {
unhandled(stmt);
}
void OptOutDispatch::handle(kir::InitMagicZero* stmt) {
unhandled(stmt);
}
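Together with the virtual handler stubs above, these three switch tables (Expr::dispatch, Expr::constDispatch, Expr::mutatorDispatch) are what route the new node to downstream passes. A hypothetical pass (not part of this commit) then only needs to override the new virtual, e.g.:

// Hypothetical example: count kir::CpAsyncWait nodes. Expr::constDispatch
// routes each node to this override via the case added above.
class CpAsyncWaitCounter : public OptOutConstDispatch {
 public:
  void handle(const kir::CpAsyncWait*) final {
    ++count_;
  }
  int count_ = 0;
};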
4 changes: 4 additions & 0 deletions dispatch.h
@@ -94,6 +94,7 @@ class TensorIndex;
class Allocate;
class BlockSync;
class GridSync;
class CpAsyncWait;
class ForLoop;
class IfThenElse;
class GridReduction;
@@ -152,6 +153,7 @@ class TORCH_CUDA_CU_API OptOutConstDispatch : public PolymorphicBase {
virtual void handle(const kir::Allocate*);
virtual void handle(const kir::BlockSync*);
virtual void handle(const kir::GridSync*);
virtual void handle(const kir::CpAsyncWait*);
virtual void handle(const kir::InitMagicZero*);
virtual void handle(const kir::UpdateMagicZero*);
virtual void handle(const kir::ForLoop*);
@@ -208,6 +210,7 @@ class TORCH_CUDA_CU_API OptOutDispatch : public PolymorphicBase {
virtual void handle(kir::Allocate* stmt);
virtual void handle(kir::BlockSync* stmt);
virtual void handle(kir::GridSync* stmt);
virtual void handle(kir::CpAsyncWait* stmt);
virtual void handle(kir::InitMagicZero* stmt);
virtual void handle(kir::UpdateMagicZero* stmt);
virtual void handle(kir::ForLoop* stmt);
@@ -305,6 +308,7 @@ class TORCH_CUDA_CU_API OptOutMutator : public PolymorphicBase {
virtual void mutate(kir::Allocate*);
virtual void mutate(kir::BlockSync*);
virtual void mutate(kir::GridSync*);
virtual void mutate(kir::CpAsyncWait*);
virtual void mutate(kir::InitMagicZero*);
virtual void mutate(kir::UpdateMagicZero*);
virtual void mutate(kir::ForLoop*);
4 changes: 4 additions & 0 deletions ir_iostream.cpp
@@ -570,6 +570,10 @@ void IrPrinter::handle(const kir::BlockSync* node) {
<< ")\n";
}

void IrPrinter::handle(const kir::CpAsyncWait* node) {
indent() << "CPASYNC_WAIT()\n";
}

void IrPrinter::handle(const kir::GridSync* node) {
indent() << "GRIDSYNC(" << node->syncDims().toString() << ", ";
handle(node->syncBuffer());
1 change: 1 addition & 0 deletions ir_iostream.h
@@ -109,6 +109,7 @@ class TORCH_CUDA_CU_API IrPrinter : public OptInConstDispatch {
void handle(const kir::Allocate*) final;
void handle(const kir::BlockSync*) final;
void handle(const kir::GridSync*) final;
void handle(const kir::CpAsyncWait*) final;
void handle(const kir::InitMagicZero*) final;
void handle(const kir::UpdateMagicZero*) final;
void handle(const kir::AllocateFusedReduction*) final;
7 changes: 7 additions & 0 deletions kernel_ir.cpp
@@ -93,6 +93,13 @@ GridSync::GridSync(
sync_dims_(sync_dims),
sync_buffer_(sync_buffer) {}

CpAsyncWait::CpAsyncWait(IrBuilderPasskey passkey)
: Expr(passkey, ExprType::CpAsyncWait) {
TORCH_INTERNAL_ASSERT(
passkey.ir_container_->isA<kir::Kernel>(),
"IR type only valid for Kernel container.");
}

InitMagicZero::InitMagicZero(IrBuilderPasskey passkey)
: Expr(passkey, ExprType::InitMagicZero) {
TORCH_INTERNAL_ASSERT(
9 changes: 9 additions & 0 deletions kernel_ir.h
@@ -54,6 +54,7 @@ class TensorIndex;
class Allocate;
class BlockSync;
class GridSync;
class CpAsyncWait;
class InitMagicZero;
class UpdateMagicZero;
class ForLoop;
@@ -256,6 +257,14 @@ class TORCH_CUDA_CU_API BlockSync final : public Expr {
bool war_sync_ = false;
};

// CpAsyncWait represents wait intrinsics for cp.async
// TODO: expand to support different wait modes of the intrinsic
// as the analysis passes build out.
class TORCH_CUDA_CU_API CpAsyncWait final : public Expr {
public:
explicit CpAsyncWait(IrBuilderPasskey passkey);
};

// Synchronize all blocks in device, implies cooperative group launch is
// required.
class TORCH_CUDA_CU_API GridSync final : public Expr {
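Taken together with the codegen changes above, a LoadStoreOp of type CpAsync followed by a kir::CpAsyncWait lowers to kernel code of roughly the following shape. This is illustrative only; it assumes the sketched Ampere helpers from earlier and a 128-thread block.

// Illustrative kernel, not emitted verbatim by this commit.
__global__ void scaleThroughSmem(const float4* in, float4* out) {
  __shared__ float4 tile[128];
  // genCpAsync(): one vectorized 16-byte async copy per thread.
  Ampere::cpAsync<16>(&tile[threadIdx.x], &in[threadIdx.x]);
  // handle(kir::CpAsyncWait): wait on this thread's outstanding copies.
  Ampere::cpAsyncBarrier();
  // cp.async.wait_all only orders the issuing thread's copies, so a block
  // sync is still needed before other threads read the tile.
  __syncthreads();
  float4 v = tile[threadIdx.x];
  out[threadIdx.x] = make_float4(2.f * v.x, 2.f * v.y, 2.f * v.z, 2.f * v.w);
}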