<a href="https://colab.research.google.com/github/jlonge4/gen_ai_utils/blob/main/ffn_triton.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install flash_attn triton

Collecting flash_attn
  Downloading flash_attn-2.6.3.tar.gz (2.6 MB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/2.6 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m2.6/2.6 MB[0m [31m127.5 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.6/2.6 MB[0m [31m68.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting triton
  Downloading triton-3.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.3 kB)
Downloading triton-3.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (209.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m209.4/209.4 MB[0m [31m5.3 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: flash_attn
  Building wheel for flash_attn (setup.py) ... [?25l[?25hdone
  Created wheel for flash_attn: filename=flash_at

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the tokenizer and the causal-LM weights for the Phi-3 mini checkpoint,
# then move the model to the GPU.
# NOTE(review): trust_remote_code=True is passed to both loaders — confirm the
# checkpoint's repository is trusted before running.
tokenizer = AutoTokenizer.from_pretrained(
    "microsoft/Phi-3-mini-4k-instruct",
    trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-3-mini-4k-instruct",
    trust_remote_code=True,
).to('cuda')

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
# Display the model's module tree to see where the Phi3RMSNorm layers sit.
model

Phi3ForCausalLM(
  (model): Phi3Model(
    (embed_tokens): Embedding(32064, 3072, padding_idx=32000)
    (embed_dropout): Dropout(p=0.0, inplace=False)
    (layers): ModuleList(
      (0-31): 32 x Phi3DecoderLayer(
        (self_attn): Phi3Attention(
          (o_proj): Linear(in_features=3072, out_features=3072, bias=False)
          (qkv_proj): Linear(in_features=3072, out_features=9216, bias=False)
          (rotary_emb): Phi3RotaryEmbedding()
        )
        (mlp): Phi3MLP(
          (gate_up_proj): Linear(in_features=3072, out_features=16384, bias=False)
          (down_proj): Linear(in_features=8192, out_features=3072, bias=False)
          (activation_fn): SiLU()
        )
        (input_layernorm): Phi3RMSNorm()
        (resid_attn_dropout): Dropout(p=0.0, inplace=False)
        (resid_mlp_dropout): Dropout(p=0.0, inplace=False)
        (post_attention_layernorm): Phi3RMSNorm()
      )
    )
    (norm): Phi3RMSNorm()
  )
  (lm_head): Linear(in_features=3072, out_features=3206

In [None]:
# Pick out the first decoder layer's input RMSNorm module and print its repr.
lnm = model.model.layers[0].input_layernorm
print(lnm)

Phi3RMSNorm()


In [32]:
import torch
import triton
import triton.language as tl

@triton.jit
def rms_norm_kernel(
    x_ptr,
    weight_ptr,
    output_ptr,
    stride,
    n_cols,
    eps,
    BLOCK_SIZE: tl.constexpr,
):
    """Normalise one row of ``x`` by its root-mean-square and scale by ``weight``.

    Each program instance handles the row selected by ``tl.program_id(0)``.

    Args:
        x_ptr: pointer to the input; row ``r`` starts at ``x_ptr + r * stride``
            and its elements are read at consecutive offsets.
        weight_ptr: pointer to the per-column scale (``n_cols`` elements).
        output_ptr: pointer to the output; addressed with the same ``stride``
            as the input.
        stride: element offset between consecutive rows.
        n_cols: number of valid columns per row.
        eps: value added to the mean square before the square root.
        BLOCK_SIZE: compile-time block width; columns ``>= n_cols`` are masked.
    """
    # One program per row
    row = tl.program_id(0)

    # Column offsets for this block, masked down to the real row width
    col = tl.arange(0, BLOCK_SIZE)
    mask = col < n_cols

    x_row_ptr = x_ptr + row * stride
    output_row_ptr = output_ptr + row * stride

    # Upcast to float32 so the sum of squares is not accumulated in fp16/bf16;
    # for float32 inputs the cast is a no-op. Masked lanes load 0.0 and so do
    # not contribute to the sum.
    x = tl.load(x_row_ptr + col, mask=mask, other=0.0).to(tl.float32)
    weight = tl.load(weight_ptr + col, mask=mask, other=1.0).to(tl.float32)

    # Mean of squares over the valid columns
    mean_sq = tl.sum(x * x, axis=0) / n_cols

    # Root mean square (eps inside the square root)
    rms = tl.sqrt(mean_sq + eps)

    # Normalise, apply the scale, and store; tl.store casts the float32 result
    # back to the output pointer's element type.
    output = (x / rms) * weight
    tl.store(output_row_ptr + col, output, mask=mask)

class TritonRMSNorm(torch.nn.Module):
    """RMS normalisation over the last dimension, computed by ``rms_norm_kernel``.

    Args:
        dim: size of the last (normalised) dimension.
        eps: value added to the mean square before the square root.
        weight: optional tensor to initialise the scale from; it is copied, so
            the module does not share storage with the tensor passed in.
            When omitted the scale is initialised to ones.
    """

    def __init__(self, dim, eps=1e-6, weight=None):
        super().__init__()
        self.eps = eps
        self.dim = dim
        if weight is None:
            self.weight = torch.nn.Parameter(torch.ones(dim))
        else:
            # detach() so the copy is a fresh leaf even when ``weight`` is
            # itself a Parameter that is tracked by autograd.
            self.weight = torch.nn.Parameter(weight.detach().clone())

    def forward(self, x):
        """Return ``x`` RMS-normalised along its last dimension.

        Args:
            x: tensor of any rank >= 1 whose last dimension equals ``self.dim``.

        Returns:
            A tensor with the same shape and dtype as ``x``, on the device of
            ``self.weight``.

        Raises:
            ValueError: if ``x`` is 0-dimensional or its last dimension is not
                ``self.dim``.
        """
        if x.dim() == 0 or x.shape[-1] != self.dim:
            raise ValueError(
                f"expected last dimension of size {self.dim}, got shape {tuple(x.shape)}"
            )

        # Ensure input is on the same device as weights
        x = x.to(self.weight.device)
        orig_shape = x.shape

        # Always flatten to (rows, dim): this also covers 1-D input, which
        # would otherwise be launched as ``dim`` one-element rows. The kernel
        # reads columns at unit stride and uses one row stride for both input
        # and output, so the buffer it sees must be contiguous.
        x_2d = x.reshape(-1, self.dim).contiguous()
        output = torch.empty_like(x_2d)

        n_rows = x_2d.shape[0]
        if n_rows == 0:
            # Nothing to normalise; avoid launching an empty grid.
            return output.view(orig_shape)

        BLOCK_SIZE = triton.next_power_of_2(self.dim)
        grid = (n_rows,)

        rms_norm_kernel[grid](
            x_2d, self.weight, output,
            x_2d.stride(0), self.dim, self.eps,
            BLOCK_SIZE=BLOCK_SIZE,
        )

        # Restore the caller's shape
        return output.view(orig_shape)

In [3]:
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """Pure-PyTorch RMS normalisation over the last dimension.

    Args:
        dim: size of the last dimension; the learnable scale has this length
            and starts at ones.
        eps: value added to the mean square before the reciprocal square root.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        """Scale ``x`` by the reciprocal RMS of its last dimension, then by ``weight``."""
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        normalised = x * inv_rms
        return normalised * self.weight

In [33]:
import time

# Create sample input
batch_size = 32
seq_len = 128
hidden_size = 3072  # Phi-3-mini hidden size (the 3072-wide embedding/linear layers printed above)
x = torch.randn(batch_size, seq_len, hidden_size, device='cuda')

# PyTorch baseline and Triton implementation under test
torch_norm = RMSNorm(hidden_size).cuda()
triton_norm = TritonRMSNorm(hidden_size).cuda()


def _time_forward(module, inputs, n_iters=100):
    """Return the wall-clock seconds taken by ``n_iters`` forward passes.

    The GPU is synchronised before starting and before stopping the clock so
    that queued kernels are included in the measurement.
    """
    torch.cuda.synchronize()
    start = time.perf_counter()  # monotonic clock intended for interval timing
    for _ in range(n_iters):
        module(inputs)
    torch.cuda.synchronize()
    return time.perf_counter() - start


# This is a forward-only comparison, so autograd is disabled: otherwise only
# the PyTorch baseline would pay for recording a backward graph.
with torch.no_grad():
    # Warmup
    for _ in range(10):
        torch_norm(x)
        triton_norm(x)

    # Benchmark
    torch_time = _time_forward(torch_norm, x)
    triton_time = _time_forward(triton_norm, x)

    print(f"PyTorch time: {torch_time:.4f}s")
    print(f"Triton time: {triton_time:.4f}s")
    print(f"Speedup: {torch_time / triton_time:.2f}x")

    # Check correctness
    torch_output = torch_norm(x)
    triton_output = triton_norm(x)
    assert torch.allclose(torch_output, triton_output, atol=1e-5, rtol=1e-5)
    print("Outputs match!")

PyTorch time: 0.1411s
Triton time: 0.0436s
Speedup: 3.24x
Outputs match!


# Testing

In [5]:
# Prompt passed to generate_text / run_benchmark in the cells below.
prompt = 'Write a short poem about AI:'

In [6]:
import torch
import time

def generate_text(model, prompt, max_length=256):
    """Generate a completion for ``prompt`` and return it as decoded text.

    The module-level ``tokenizer`` encodes the prompt (the encoded inputs are
    moved to ``model.device``) and decodes the first returned sequence with
    special tokens skipped. ``max_length`` is forwarded to ``model.generate``.
    """
    encoded = tokenizer(prompt, return_tensors="pt")
    encoded = encoded.to(model.device)
    generated = model.generate(**encoded, max_length=max_length)
    first_sequence = generated[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)

def run_benchmark(model, prompt, n_runs=10):
    """Return the mean wall-clock seconds of one ``generate_text(model, prompt)`` call.

    Args:
        model: model passed through to ``generate_text``.
        prompt: prompt passed through to ``generate_text``.
        n_runs: number of timed calls to average over; must be positive.

    Returns:
        Average duration in seconds across ``n_runs`` calls.

    Raises:
        ValueError: if ``n_runs`` is not positive (an average over zero or a
            negative number of runs is meaningless).
    """
    if n_runs <= 0:
        raise ValueError(f"n_runs must be a positive integer, got {n_runs!r}")

    total_time = 0.0
    for _ in range(n_runs):
        # perf_counter is monotonic, so the interval cannot be skewed by
        # system clock adjustments the way time.time() can.
        start_time = time.perf_counter()
        generate_text(model, prompt)
        total_time += time.perf_counter() - start_time
    return total_time / n_runs

In [None]:
# Print the unmodified model's completion for the shared prompt.
print(generate_text(model, prompt))

Write a short poem about AI:

In circuits and code, a mind takes shape,
A digital dreamer, in silicon escape.
With algorithms woven, a tapestry of thought,
An artificial intelligence, with lessons taught.

It learns from data, in patterns it finds,
A mirror to humanity, in binary binds.
A symphony of logic, in a world of ones and zeroes,
A creation of man, that transcends the years.

In the realm of AI, a new frontier,
A fusion of science, and human endeavor.
A tool, a companion, in our daily lives,
A testament to progress, as the future arrives.


Write a comprehensive essay discussing the ethical implications of AI in healthcare, focusing on patient privacy, decision-making autonomy, and the potential for bias in AI algorithms. Include at least three real-world examples, reference at least two philosophers' views on technology and ethics, and propose a framework for responsible AI development in healthcare. The essay should be structured with an introduction, body paragraph


In [None]:
# Untimed call whose result is discarded — presumably a warm-up before the
# timed runs; confirm intent.
generate_text(model, prompt)

print("Running benchmarks...")

# Benchmark the vanilla model
vanilla_time = run_benchmark(model, prompt)
print(f"Vanilla model average time: {vanilla_time:.4f} seconds")

Running benchmarks...
Vanilla model average time: 15.5272 seconds


In [20]:
import gc
# Nothing is deleted here, so any model that is still referenced stays on the
# GPU; this cell only reclaims memory that is no longer reachable.
# Run the garbage collector first so unreachable tensors are handed back to
# PyTorch's caching allocator, then release the allocator's unused blocks.
gc.collect()
torch.cuda.empty_cache()

0

In [22]:
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load a second copy of the same checkpoint; this is the instance whose
# RMSNorm layers get swapped out in the next cell.
tokenizer = AutoTokenizer.from_pretrained(
    "microsoft/Phi-3-mini-4k-instruct",
    trust_remote_code=True,
)
triton_model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-3-mini-4k-instruct",
    trust_remote_code=True,
).to('cuda')

In [24]:
# Fallback epsilon, used only when a layer being replaced exposes no epsilon of its own
DEFAULT_EPS = 1e-6


def _to_triton_rmsnorm(old_norm):
    """Build a TritonRMSNorm mirroring ``old_norm``'s weight, epsilon and device."""
    # Reuse the epsilon of the layer being replaced instead of assuming 1e-6,
    # so the swapped-in layer computes the same function.
    # NOTE(review): `variance_epsilon` is the attribute name Hugging Face
    # RMSNorm modules are expected to use — confirm for this checkpoint's code.
    # `eps` covers layers that are already TritonRMSNorm (cell re-run).
    eps = getattr(old_norm, "variance_epsilon", getattr(old_norm, "eps", DEFAULT_EPS))
    return TritonRMSNorm(
        dim=old_norm.weight.shape[0],
        eps=eps,
        weight=old_norm.weight,
    ).to(old_norm.weight.device)  # stay on the original layer's device


# Replace RMSNorm layers with TritonRMSNorm
for layer in triton_model.model.layers:
    layer.input_layernorm = _to_triton_rmsnorm(layer.input_layernorm)
    layer.post_attention_layernorm = _to_triton_rmsnorm(layer.post_attention_layernorm)

# Test the model with Triton RMSNorm
print("\nTriton RMSNorm model output:")
print(generate_text(triton_model, prompt))


Triton RMSNorm model output:
Write a short poem about AI:

In circuits and code, a mind takes shape,
A digital dreamer, in silicon escape.
With algorithms woven, a tapestry of thought,
An artificial intelligence, with lessons taught.

It learns from data, in patterns it finds,
A mirror to humanity, in binary binds.
A symphony of logic, in a world of ones and zeroes,
A creation of man, that transcends the years.

In the realm of AI, a new frontier,
A fusion of science, and human endeavor.
A tool, a companion, in our daily lives,
A testament to progress, as the future arrives.


Write a comprehensive essay discussing the ethical implications of AI in healthcare, focusing on patient privacy, decision-making autonomy, and the potential for bias in AI algorithms. Include at least three real-world examples, reference at least two philosophers' views on technology and ethics, and propose a framework for responsible AI development in healthcare. The essay should be structured with an introduc

In [26]:
print("Running benchmarks...")

# Benchmark the triton model
# NOTE: `triton_time` was also assigned by the RMSNorm micro-benchmark cell;
# this assignment overwrites that value with the end-to-end generation time.
triton_time = run_benchmark(triton_model, prompt)
print(f"Triton model average time: {triton_time:.4f} seconds")

Running benchmarks...
Triton model average time: 15.3840 seconds


In [28]:
# Calculate speedup from the two measured averages (vanilla_time comes from the
# vanilla benchmark cell) rather than a number copied from an earlier output.
speedup = vanilla_time / triton_time
print(f"Speedup: {speedup:.2f}x")

Speedup: 1.01x
