<a href="https://colab.research.google.com/github/mohitraosatya/used-flashattention-ttmetal-poc/blob/main/Untitled6.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [23]:
!nvidia-smi

Tue Mar  4 09:22:22 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   34C    P8              9W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [24]:
!apt-get update
!apt-get install -y cmake build-essential

Hit:1 https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/ InRelease
Hit:2 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  InRelease
Hit:3 http://archive.ubuntu.com/ubuntu jammy InRelease
Hit:4 http://archive.ubuntu.com/ubuntu jammy-updates InRelease
Hit:5 http://security.ubuntu.com/ubuntu jammy-security InRelease
Hit:6 http://archive.ubuntu.com/ubuntu jammy-backports InRelease
Hit:7 https://r2u.stat.illinois.edu/ubuntu jammy InRelease
Hit:8 https://ppa.launchpadcontent.net/deadsnakes/ppa/ubuntu jammy InRelease
Hit:9 https://ppa.launchpadcontent.net/graphics-drivers/ppa/ubuntu jammy InRelease
Hit:10 https://ppa.launchpadcontent.net/ubuntugis/ppa/ubuntu jammy InRelease
Reading package lists... Done
W: Skipping acquire of configured file 'main/source/Sources' as repository 'https://r2u.stat.illinois.edu/ubuntu jammy InRelease' does not seem to provide it (sources.list entry misspelt?)
Reading package lists... Done
Building dependency tree... Done
Reading

In [25]:
%%writefile CMakeLists.txt
cmake_minimum_required(VERSION 3.16)
project(FusedFlashAttention LANGUAGES CXX)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# We'll build one executable: fused_flash_attention
add_executable(fused_flash_attention
    fused_attention.cpp
    fused_attention.hpp
    naive_attention.cpp
    naive_attention.hpp
    main.cpp
)

Writing CMakeLists.txt


In [26]:
%%writefile fused_attention.hpp
#pragma once

namespace flash_attn {

struct FusedAttentionParams {
    float* Q;
    float* K;
    float* V;
    float* Output;

    int batch_size;
    int num_heads;
    int seq_len;
    int d_head;

    bool apply_scale;
    float scale_factor;
};

void fusedFlashAttentionKernel(const FusedAttentionParams& p);

} // namespace flash_attn

Writing fused_attention.hpp


In [27]:
%%writefile fused_attention.cpp
#include "fused_attention.hpp"
#include <cmath>
#include <algorithm>

namespace flash_attn {

// CPU "fused" approach: QK^T, softmax, multiply by V, in a single pass
void fusedFlashAttentionKernel(const FusedAttentionParams& p)
{
    int B = p.batch_size;
    int H = p.num_heads;
    int S = p.seq_len;
    int D = p.d_head;

    for(int b = 0; b < B; b++){
        for(int h = 0; h < H; h++){
            float* Qbh = p.Q + (b*H + h)*(S*D);
            float* Kbh = p.K + (b*H + h)*(S*D);
            float* Vbh = p.V + (b*H + h)*(S*D);
            float* Obh = p.Output + (b*H + h)*(S*D);

            // scores: S x S
            float* scores = new float[S*S];

            // 1) Q x K^T
            for(int i = 0; i < S; i++){
                for(int j = 0; j < S; j++){
                    float dot = 0.f;
                    for(int d_i=0; d_i < D; d_i++){
                        dot += Qbh[i*D + d_i] * Kbh[j*D + d_i];
                    }
                    if(p.apply_scale){
                        dot *= p.scale_factor;
                    }
                    scores[i*S + j] = dot;
                }
            }

            // 2) Softmax row by row
            for(int i = 0; i < S; i++){
                float row_max = scores[i*S];
                for(int j=1; j < S; j++){
                    float val = scores[i*S + j];
                    if(val > row_max) row_max = val;
                }
                float sum_exp = 0.f;
                for(int j=0; j < S; j++){
                    float exp_val = std::exp(scores[i*S + j] - row_max);
                    scores[i*S + j] = exp_val;
                    sum_exp += exp_val;
                }
                for(int j=0; j < S; j++){
                    scores[i*S + j] /= sum_exp;
                }
            }

            // 3) Multiply by V => Output
            for(int i=0; i < S*D; i++){
                Obh[i] = 0.f;
            }
            for(int i = 0; i < S; i++){
                for(int d_i=0; d_i < D; d_i++){
                    float sum_val = 0.f;
                    for(int j=0; j < S; j++){
                        sum_val += scores[i*S + j] * Vbh[j*D + d_i];
                    }
                    Obh[i*D + d_i] = sum_val;
                }
            }

            delete[] scores;
        }
    }
}

} // namespace flash_attn


Writing fused_attention.cpp


In [28]:
%%writefile naive_attention.hpp
#pragma once
#include "fused_attention.hpp"

namespace flash_attn {

using NaiveAttentionParams = FusedAttentionParams;

void naiveAttentionKernel(const NaiveAttentionParams& p);

}


Writing naive_attention.hpp


In [29]:
%%writefile naive_attention.cpp
#include "naive_attention.hpp"
#include <cmath>
#include <algorithm>

namespace flash_attn {

void naiveAttentionKernel(const NaiveAttentionParams& p)
{
    int B = p.batch_size;
    int H = p.num_heads;
    int S = p.seq_len;
    int D = p.d_head;

    // Step 1: QK^T => scores
    float* scores = new float[B * H * S * S];

    for(int b=0; b < B; b++){
        for(int h=0; h < H; h++){
            float* Qbh = p.Q + (b*H + h)*(S*D);
            float* Kbh = p.K + (b*H + h)*(S*D);

            for(int i=0; i < S; i++){
                for(int j=0; j < S; j++){
                    float dot = 0.f;
                    for(int d_i=0; d_i < D; d_i++){
                        dot += Qbh[i*D + d_i] * Kbh[j*D + d_i];
                    }
                    if(p.apply_scale){
                        dot *= p.scale_factor;
                    }
                    scores[(b*H + h)*S*S + i*S + j] = dot;
                }
            }
        }
    }

    // Step 2: softmax
    for(int b=0; b < B; b++){
        for(int h=0; h < H; h++){
            float* score_ptr = scores + (b*H + h)*S*S;
            for(int i=0; i < S; i++){
                float row_max = score_ptr[i*S];
                for(int j=1; j < S; j++){
                    float val = score_ptr[i*S + j];
                    if(val > row_max) row_max = val;
                }
                float sum_exp = 0.f;
                for(int j=0; j < S; j++){
                    float exp_val = std::exp(score_ptr[i*S + j] - row_max);
                    score_ptr[i*S + j] = exp_val;
                    sum_exp += exp_val;
                }
                for(int j=0; j < S; j++){
                    score_ptr[i*S + j] /= sum_exp;
                }
            }
        }
    }

    // Step 3: multiply by V => output
    for(int b=0; b < B; b++){
        for(int h=0; h < H; h++){
            float* score_ptr = scores + (b*H + h)*S*S;
            float* Vbh = p.V + (b*H + h)*(S*D);
            float* Obh = p.Output + (b*H + h)*(S*D);

            // zero out
            for(int i=0; i < S*D; i++){
                Obh[i] = 0.f;
            }
            for(int i=0; i < S; i++){
                for(int d_i=0; d_i < D; d_i++){
                    float sum_val = 0.f;
                    for(int j=0; j < S; j++){
                        sum_val += score_ptr[i*S + j] * Vbh[j*D + d_i];
                    }
                    Obh[i*S + d_i] = sum_val;
                }
            }
        }
    }

    delete[] scores;
}

} // namespace flash_attn


Writing naive_attention.cpp


In [30]:
%%writefile main.cpp
#include <iostream>
#include <chrono>
#include <cmath>
#include "fused_attention.hpp"
#include "naive_attention.hpp"

static float randFloat() {
    return static_cast<float>(rand()) / RAND_MAX;
}

int main()
{
    // Increase seq_len for a bigger workload: set S=512
    // (Change to 1024 if desired)
    int B = 1;
    int H = 8;
    int S = 512;
    int D = 64;
    bool apply_scale = true;
    float scale_factor = 1.0f / std::sqrt((float)D);

    size_t qkv_size = (size_t)B * H * S * D;
    float* Q = new float[qkv_size];
    float* K = new float[qkv_size];
    float* V = new float[qkv_size];
    float* out_naive = new float[qkv_size];
    float* out_fused = new float[qkv_size];

    srand(42);
    for(size_t i=0; i < qkv_size; i++){
        Q[i] = randFloat();
        K[i] = randFloat();
        V[i] = randFloat();
        out_naive[i] = 0.f;
        out_fused[i] = 0.f;
    }

    flash_attn::NaiveAttentionParams naive_p {
        Q, K, V, out_naive,
        B, H, S, D,
        apply_scale, scale_factor
    };

    flash_attn::FusedAttentionParams fused_p {
        Q, K, V, out_fused,
        B, H, S, D,
        apply_scale, scale_factor
    };

    auto start_naive = std::chrono::high_resolution_clock::now();
    flash_attn::naiveAttentionKernel(naive_p);
    auto end_naive = std::chrono::high_resolution_clock::now();
    double naive_ms = std::chrono::duration<double,std::milli>(end_naive - start_naive).count();

    auto start_fused = std::chrono::high_resolution_clock::now();
    flash_attn::fusedFlashAttentionKernel(fused_p);
    auto end_fused = std::chrono::high_resolution_clock::now();
    double fused_ms = std::chrono::duration<double,std::milli>(end_fused - start_fused).count();

    std::cout << "Naive: " << naive_ms << " ms\n";
    std::cout << "Fused: " << fused_ms << " ms\n";
    std::cout << "Speedup: " << (naive_ms / fused_ms) << "x\n";

    // Compare outputs for correctness
    double sum_sq_diff = 0.0;
    for(size_t i=0; i < qkv_size; i++){
        double diff = (double)out_naive[i] - (double)out_fused[i];
        sum_sq_diff += diff * diff;
    }
    double rmse = std::sqrt(sum_sq_diff / qkv_size);
    std::cout << "RMSE: " << rmse << std::endl;

    delete[] Q;
    delete[] K;
    delete[] V;
    delete[] out_naive;
    delete[] out_fused;

    return 0;
}


Writing main.cpp


In [31]:
!mkdir -p build
%cd build
!cmake ..
!make
!./fused_flash_attention


/content/build/build/build/build/build
-- The CXX compiler identification is GNU 11.4.0
-- Detecting CXX compiler ABI info
-- Detecting CXX compiler ABI info - done
-- Check for working CXX compiler: /usr/bin/c++ - skipped
-- Detecting CXX compile features
-- Detecting CXX compile features - done
-- Configuring done (0.2s)
-- Generating done (0.0s)
-- Build files have been written to: /content/build/build/build/build/build
[ 25%] [32mBuilding CXX object CMakeFiles/fused_flash_attention.dir/fused_attention.cpp.o[0m
[ 50%] [32mBuilding CXX object CMakeFiles/fused_flash_attention.dir/naive_attention.cpp.o[0m
[ 75%] [32mBuilding CXX object CMakeFiles/fused_flash_attention.dir/main.cpp.o[0m
[100%] [32m[1mLinking CXX executable fused_flash_attention[0m
[100%] Built target fused_flash_attention
Naive: 879.1 ms
Fused: 860.264 ms
Speedup: 1.0219x
RMSE: 0.467702
