yet another gelu #293

Closed
wants to merge 4 commits into from

Contributor

@ngc92 ngc92 commented Apr 29, 2024

more complicated Packed128 for cleaner kernels

Contributor Author

ngc92 commented Apr 29, 2024

This is what it would look like if we moved all the casting into the load/store functions:

template<class ElementType>
struct alignas(16) Packed128 {
    __device__ ElementType& operator[](int index) {
        return payload[index];
    }
    __device__ const ElementType& operator[](int index) const {
        return payload[index];
    }
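    // read-only accessor, so it is const-qualified and usable on const packets
    __device__ float fp32(int index) const {
        return static_cast<float>(payload[index]);
    }
    // number of elements that fit into 128 bits
    static constexpr size_t size = sizeof(int4) / sizeof(ElementType);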

    ElementType payload[size];
};

// use this function to load a Packed128 from an aligned memory address
template<class ElementType, ELoadMode Mode=ELoadMode::CA>
__device__ __forceinline__ Packed128<ElementType> load_aligned(const ElementType* address, load_mode_t<Mode> mode = {}) {
    int4 bits = generic_load(reinterpret_cast<const int4*>(address), mode);
    Packed128<ElementType> result;
    static_assert(sizeof(bits) == sizeof(result), "Size mismatch.");
    memcpy(&result, &bits, sizeof(bits));
    return result;
}

// use this function to store a Packed128 to an aligned memory address
template<class ElementType, EStoreMode Mode=EStoreMode::WB>
__device__ void store_aligned(ElementType* target, Packed128<ElementType> value, store_mode_t<Mode> mode = {}) {
    int4 bits;
    static_assert(sizeof(bits) == sizeof(value), "Size mismatch.");
    memcpy(&bits, &value, sizeof(bits));
    generic_store(reinterpret_cast<int4*>(target), bits, mode);
}
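
A minimal host-side sketch of how a GELU loop might use these helpers. It is plain C++ rather than the PR's CUDA code: `memcpy` stands in for `generic_load`/`generic_store`, the load/store mode parameters are dropped, `gelu_forward_packed` is a hypothetical name, and `GELU_SCALING_FACTOR` is assumed to be sqrt(2/pi) as in the usual tanh approximation.

```cpp
#include <cmath>
#include <cstddef>
#include <cstring>

// Host-side stand-in for Packed128: 16 bytes of ElementType, 16-byte aligned.
template<class ElementType>
struct alignas(16) Packed128 {
    static constexpr std::size_t size = 16 / sizeof(ElementType);
    ElementType& operator[](int index) { return payload[index]; }
    const ElementType& operator[](int index) const { return payload[index]; }
    ElementType payload[size];
};

// Assumption: memcpy replaces generic_load; address must point at 16 readable bytes.
template<class ElementType>
Packed128<ElementType> load_aligned(const ElementType* address) {
    Packed128<ElementType> result;
    std::memcpy(&result, address, sizeof(result));
    return result;
}

// Assumption: memcpy replaces generic_store; target must point at 16 writable bytes.
template<class ElementType>
void store_aligned(ElementType* target, const Packed128<ElementType>& value) {
    std::memcpy(target, &value, sizeof(value));
}

// Hypothetical driver; n must be a multiple of Packed128<float>::size.
inline void gelu_forward_packed(float* out, const float* inp, int n) {
    // Assumed value of GELU_SCALING_FACTOR: sqrt(2 / pi).
    const float scaling = std::sqrt(2.0f / 3.14159265358979323846f);
    constexpr int step = static_cast<int>(Packed128<float>::size);
    for (int i = 0; i < n; i += step) {
        Packed128<float> packed_inp = load_aligned(inp + i);
        Packed128<float> packed_out;
        for (int k = 0; k < step; ++k) {
            float xi = packed_inp[k];
            float cube = 0.044715f * xi * xi * xi;
            packed_out[k] = 0.5f * xi * (1.0f + std::tanh(scaling * (xi + cube)));
        }
        store_aligned(out + i, packed_out);
    }
}
```

With all casting hidden in the load/store helpers, the inner loop only indexes into the packets, which is the "cleaner kernels" effect the PR aims for.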

float cube = 0.044715f * xi * xi * xi;
packet_out[k] = 0.5f * xi * (1.0f + tanhf(GELU_SCALING_FACTOR * (xi + cube)));
}
store_aligned(out + i, packet_out);
Owner

any reason we load with cs but store without cs?

Contributor Author

for this kernel, we read exactly once and store once, so these hints don't gain us anything within the kernel itself. But by not keeping the input in cache while keeping the output there, maybe the next kernel can be a bit faster. This is just guesswork though; I haven't actually measured it.

Owner

got it! that makes sense actually
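
The asymmetry discussed in this thread (streaming load, default store) can be sketched as below. This is an illustration, not the PR's code: `HOST_DEVICE`, `load_streaming`, `store_default`, and `scale_element` are hypothetical names, and the file is written so that it builds as plain C++ and should also build under nvcc, where the device path uses the CUDA `__ldcs` intrinsic. Whether the hint actually helps a following kernel is unmeasured, as noted above.

```cpp
// Expands to CUDA qualifiers under nvcc and to nothing in a plain C++ build.
#if defined(__CUDACC__)
#define HOST_DEVICE __host__ __device__
#else
#define HOST_DEVICE
#endif

// Streaming load: on the device, __ldcs hints that the data is unlikely to be reused.
HOST_DEVICE inline float load_streaming(const float* address) {
#if defined(__CUDA_ARCH__)
    return __ldcs(address);
#else
    return *address;  // host fallback: no cache hint available here
#endif
}

// Default store: the written line may remain cached for a following kernel.
HOST_DEVICE inline void store_default(float* address, float value) {
    *address = value;
}

// Hypothetical elementwise body: read once with the streaming hint, write once normally.
HOST_DEVICE inline void scale_element(float* out, const float* inp, int i, float factor) {
    store_default(out + i, factor * load_streaming(inp + i));
}
```

Using `__stcs` in `store_default` would make the store streaming as well, which is the symmetric alternative the reviewer asked about.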

Contributor Author

ngc92 commented Apr 29, 2024

closed by #298

@ngc92 ngc92 closed this Apr 29, 2024
@ngc92 ngc92 deleted the yet-another-gelu branch May 19, 2024 08:39