-
Notifications
You must be signed in to change notification settings - Fork 2.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
yet another gelu #293
yet another gelu #293
Conversation
This is how it would look like if we moved all the casting into the load/store functions: template<class ElementType>
struct alignas(16) Packed128 {
__device__ ElementType& operator[](int index) {
return payload[index];
}
__device__ const ElementType& operator[](int index) const {
return payload[index];
}
__device__ float fp32(int index) {
return static_cast<float>(payload[index]);
}
static constexpr const size_t size = sizeof(int4) / sizeof(ElementType);
ElementType payload[size];
};
// use this function to load a Packet128 from an aligned memory address
template<class ElementType, ELoadMode Mode=ELoadMode::CA>
__device__ __forceinline__ Packed128<ElementType> load_aligned(const ElementType* address, load_mode_t<Mode> mode = {}) {
int4 bits = generic_load(reinterpret_cast<const int4*>(address), mode);
Packed128<ElementType> result;
static_assert(sizeof(bits) == sizeof(result), "Size mismatch.");
memcpy(&result, &bits, sizeof(bits));
return result;
}
// use this function to store a Packet128 to an aligned memory address
template<class ElementType, EStoreMode Mode=EStoreMode::WB>
__device__ void store_aligned(ElementType* target, Packed128<ElementType> value, store_mode_t<Mode> mode = {}) {
int4 bits;
static_assert(sizeof(bits) == sizeof(value), "Size mismatch.");
memcpy(&bits, &value, sizeof(bits));
generic_store(reinterpret_cast<int4*>(target), bits, mode);
}
|
float cube = 0.044715f * xi * xi * xi; | ||
packet_out[k] = 0.5f * xi * (1.0f + tanhf(GELU_SCALING_FACTOR * (xi + cube))); | ||
} | ||
store_aligned(out + i, packet_out); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
any reason we loadcs but store without cs?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for this kernel, we read exactly once and store once, so these hints don't gain us anything. But by not keeping the input in cache, but the output there, maybe the next kernel can be a bit faster. This is just guesswork though, I haven't actually measured this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
got it! that makes sense actually
close by #298 |
more complicated Packet128 for cleaner kernels