
Updated adamw to use packed data types #303

Open
wants to merge 3 commits into master

Conversation

ChrisDryden
Contributor

Before runtime:
total average iteration time: 38.547570 ms

After runtime:
total average iteration time: 37.901735 ms

Kernel development file specs (barely noticeable with the current test suite):
Before: time gpu 0.0098 ms
After:  time gpu 0.0097 ms

Contributor

@ademeure ademeure left a comment

I think we need feedback from others on what the right approach is here; we didn't think of these issues during the f128 discussion :(

x128 packed_params_memory = load128(params_memory+(i*x128::size));
f128 packed_m_memory = load128(m_memory+(i*f128::size));
f128 packed_v_memory = load128(v_memory+(i*f128::size));
for(int k = 0; k < packed_v_memory.size; ++k){
Contributor

This is iterating based on the size of an f128 = 4 elements, but packed_grads and packed_params are x128 = 8 elements (for BF16), so I think this means we are loading twice as much data as we need for the latter and wasting it (or, hopefully, the compiler optimises the loads away and we end up with LDG.64? That might be OK if so, tbh).

Ideally we'd assert that the number of elements in an x128 is an integer multiple of the number in an f128 (e.g. 1/2/4), and the kernel would work on the larger of the two element counts, with both an inner and an outer loop... more complicated than I expected, and one case where the fetch8 approach that always fetches 8 elements would have been very slightly simpler :(
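
As a minimal sketch, the divisibility check itself could just be a compile-time assertion (x128/f128 are the repo's existing Packed128 aliases; the exact wording is only illustrative):

// compile-time check that an x128 holds a whole number of f128 chunks (e.g. 8/4 = 2 for BF16 params)
static_assert(x128::size % f128::size == 0, "x128::size must be an integer multiple of f128::size");

The outer/inner loop structure that goes with it is what the full kernel later in this thread implements.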

@ngc92 do you have any thoughts on how this should work with your FP16 moments changes? Potentially we could just be lazy, combine both changes, and assert that the sizeof() of the params is the same as the sizeof() of the moments?

Contributor Author

Did some testing with hard-coding the load as 64-bit and didn't see a noticeable time difference; I think the majority of the latency is from the warp stalls, which appear to be the same in number whether it's a 64-bit or a 128-bit read.

Contributor

Let's not hard-code the assumption that these are the same datatype. Let's just code this kernel properly, so that it iterates correctly.

Contributor

we can generally assume sizeof(param) <= sizeof(moment), right? That should be enough for a simple implementation.

Contributor Author

I'll add that assumption and the boundary checks in the kernel instantiation.

f128 packed_m_memory = load128(m_memory+(i*f128::size));
f128 packed_v_memory = load128(v_memory+(i*f128::size));
for(int k = 0; k < packed_v_memory.size; ++k){
if (i*4 + k >= num_parameters) return; // guard
Contributor

Can we get rid of this guard by asserting "(num_parameters % 4) == 0" outside the kernel?

Contributor Author

I really like the idea of removing all of the bounds checks and bringing them outside the kernel. I'm fairly confident we have some kernels with incorrect sizing inputs that are only masked by having the bounds check in the kernel, the fused kernel with the softmax being one example.

Contributor

We really should codify our assumption in some prominent place, and ensure that any model we generate has nice enough shapes. Additional asserts won't hurt, though.
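
As one possible way to codify it, a check like the following could run once where the model is built (param_sizes and NUM_PARAMETER_TENSORS are the existing names in train_gpt2.cu; the check itself is only a hypothetical sketch):

// hypothetical shape check: every parameter tensor must hold a multiple of x128::size elements,
// so the vectorized optimizer kernels can drop their per-element bounds guards
for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) {
    assert(model->param_sizes[i] % x128::size == 0);
}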

train_gpt2.cu Outdated
@@ -1917,7 +1927,7 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo
}

int block_size = 512;
int num_blocks = CEIL_DIV(model->num_parameters, block_size);
int num_blocks = CEIL_DIV(model->num_parameters, block_size)/x128::size;
Contributor

Should the division by x128::size be inside the CEIL_DIV?

Contributor Author

I was having difficulty getting it to compile in that format in the kernel file; if you know what's going on there and what's blocking it, I would love to know.

Contributor Author

I think I should just be able to typecast it; I will modify it to follow the format inside the CEIL_DIV.
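
For reference, the typecast version might look something like the following (a sketch only; CEIL_DIV is the repo's existing macro, and the cast is just one way to keep the mixed size_t/int arithmetic consistent):

int block_size = 512;
int num_blocks = CEIL_DIV(model->num_parameters, (size_t)(block_size * x128::size));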

@ademeure
Contributor

ademeure commented May 2, 2024

I think the problem was that it can't work with only 1 loop; it was skipping some of the elements for some of the arrays because of the different f128/x128 sizes. Here's my attempt at fixing that in Chris' kernel, which seems to work (haven't looked into perf yet):

__global__ void adamw_kernel4(floatX* params_memory, const floatX* grads_memory, float* m_memory, float* v_memory, size_t num_parameters,
                              float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, float eps, float weight_decay,
                              unsigned int seed) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int idx_offset = idx*x128::size;
    if (idx_offset >= num_parameters) { return; }

    x128 packed_grads_memory = load128(grads_memory + idx_offset);
    x128 packed_params_memory = load128(params_memory + idx_offset);
    for (int n = 0; n < (x128::size / f128::size); n++) {
        int idx_n_offset = idx_offset + n*f128::size;
        f128 packed_m_memory = load128(m_memory + idx_n_offset);
        f128 packed_v_memory = load128(v_memory + idx_n_offset);
        for(int k = 0; k < f128::size; ++k){
            int k_n_offset = k + n*f128::size;
            float grad = (float)packed_grads_memory[k_n_offset];
            float m = packed_m_memory[k];
            float v = packed_v_memory[k];
            // update the first moment (momentum)
            m = lerp(grad, m, beta1);
            packed_m_memory[k] = m;
            // update the second moment (RMSprop)
            v = lerp(grad * grad, v, beta2);
            packed_v_memory[k] = v;
            m /= beta1_correction; // Setting these values explicitly due to compiler error for modifying
            v /= beta2_correction; // packed128 values when using
            // update the parameters (weight/bias)
            float param = (float)packed_params_memory[k_n_offset] - (learning_rate * (m / (sqrtf(v) + eps) + weight_decay * (float)packed_params_memory[k_n_offset]));
            unsigned int random = Get2dNoiseUint(threadIdx.x, blockIdx.x, seed);
            // todo - explain stochastic rounding here
            stochastic_rounding(param, &packed_params_memory[k_n_offset], random);
        }
        store128(m_memory + idx_n_offset, packed_m_memory);
        store128(v_memory + idx_n_offset, packed_v_memory);
    }
    store128(params_memory + idx_offset, packed_params_memory);
}

@ChrisDryden
Contributor Author

Updated the PR to show the new kernel; it does give a speedup in the train loop for me of:
total average iteration time: 38.287047 ms
to
total average iteration time: 37.143633 ms

params_memory[i] -= learning_rate * (m / (sqrtf(v) + eps) + weight_decay * (float) params_memory[i]);
}

// Optimized kernel to use lower precision data types for params memory and grads memory
Contributor

Is this comment accurate? kernel2 also uses floatX.

void adamw_dispatch4(floatX* params_memory, const floatX* grads_memory, float* m_memory, float* v_memory, long num_parameters,
                     float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, float eps, float weight_decay) {
    unsigned int block_size = 512;
    assert(num_parameters % 4 == 0 && f128::size <= x128::size); // asserting here to not require bounds check in kernel
Contributor

@ngc92 ngc92 May 2, 2024

f128::size <= x128::size is a compile-time property; best make that a static_assert inside the actual kernel. Also, num_parameters % x128::size == 0 would be the safer choice, I think.
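
A sketch of those two adjustments (hypothetical placement, not committed code):

// at the top of adamw_kernel4: the relationship between the packed types is known at compile time
static_assert(f128::size <= x128::size, "moment vectors must not be wider than param vectors");

// in adamw_dispatch4, instead of the % 4 check:
assert(num_parameters % x128::size == 0); // still correct if x128::size is ever != 4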

}
store128(m_memory+(i*f128::size), packed_m_memory);
store128(v_memory+(i*f128::size), packed_v_memory);
store128(params_memory+(i*x128::size), packed_params_memory);
Contributor

For sizeof(f128) != sizeof(x128), I believe this write might result in a race condition. Probably not in practice, because the optimizer ends up with just a 64-bit store.
