In [2]:
import random
import torch
from typing import Tuple
import triton

import deep_gemm
from deep_gemm import bench_kineto, calc_diff, cell_div, get_col_major_tma_aligned_tensor
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
import pandas as pd
from core import per_token_cast_to_fp8, per_block_cast_to_fp8, deep_matmul, DeepLinear
import os
os.environ['TRITON_PRINT_AUTOTUNING'] = '1'
from copy import deepcopy

In [2]:
def per_token_cast_to_fp82(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    assert x.dim() == 2 and x.size(1) % 128 == 0
    m, n = x.shape
    x_view = x.view(m, -1, 128)
    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
    return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)


def per_block_cast_to_fp82(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize a 2-D tensor to FP8 (e4m3fn) with one scale per 128x128 tile.

    Args:
        x: 2-D tensor of any size; it is zero-padded internally to whole tiles.

    Returns:
        ``(x_fp8, scales)``. ``x_fp8`` is a contiguous ``torch.float8_e4m3fn``
        tensor with the original (unpadded) shape. ``scales`` is float32 with
        one ``amax / 448`` entry per tile, shaped ``(row_tiles, col_tiles)``.
    """
    assert x.dim() == 2
    rows, cols = x.shape
    # cell_div comes from deep_gemm; presumably a ceiling division, so both
    # padded extents are multiples of 128.
    padded_rows = cell_div(rows, 128) * 128
    padded_cols = cell_div(cols, 128) * 128
    padded = torch.zeros((padded_rows, padded_cols), dtype=x.dtype, device=x.device)
    padded[:rows, :cols] = x
    # View as (row_tiles, 128, col_tiles, 128): dims 1 and 3 index inside a tile.
    tiles = padded.view(-1, 128, padded.size(1) // 128, 128)
    tile_amax = tiles.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    quantized = (tiles * (448.0 / tile_amax)).to(torch.float8_e4m3fn)
    # Drop the padding again; the slice is made contiguous for downstream use.
    cropped = quantized.view_as(padded)[:rows, :cols].contiguous()
    scales = (tile_amax / 448.0).view(tiles.size(0), tiles.size(2))
    return cropped, scales


def construct(x, y):
    """Quantize both GEMM operands to FP8 and allocate the bf16 output buffer.

    Args:
        x: left operand, quantized with ``per_token_cast_to_fp82``.
        y: right operand, quantized with ``per_block_cast_to_fp82``; its first
            dimension becomes the second dimension of the output.

    Returns:
        ``(x_fp8, y_fp8, out)``. The first two are ``(data, scales)`` pairs and
        ``out`` is an uninitialized ``(x.size(0), y.size(0))`` bf16 tensor
        allocated on ``'cuda'``.
    """
    out = torch.empty((x.size(0), y.size(0)), device='cuda', dtype=torch.bfloat16)

    x_data, x_scales = per_token_cast_to_fp82(x)
    y_fp8 = per_block_cast_to_fp82(y)
    # Re-lay the x scales here, ahead of time, so that the timed GEMM call does
    # not trigger a transposing kernel.
    x_fp8 = (x_data, get_col_major_tma_aligned_tensor(x_scales))
    return x_fp8, y_fp8, out

# def construct_grouped(list_x, list_y):
#     list_x_fp8 = []
#     list_y_fp8 = []
#     list_out = []
#     for idx in range(len(list_x)):

def construct_grouped(x, y, is_masked=False) -> \
        Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
    """Quantize grouped GEMM operands to FP8 and allocate the bf16 output buffer.

    Args:
        x: 3-D tensor unpacked as ``(num_groups, m, k)``; ``m`` must be a
            multiple of 4.
        y: 3-D tensor with one matrix per group; ``y.size(1)`` is taken as ``n``.
            Its last dimension is presumably ``k`` as well (the scale buffer is
            sized with ``k // 128``).
        is_masked: when False (default) the group and ``m`` dimensions of the
            x data, x scales and output are merged into a single leading dim.

    Returns:
        A 3-tuple ``(x_fp8, y_fp8, out)``. ``x_fp8`` and ``y_fp8`` are
        ``(data, scales)`` pairs; ``out`` is an uninitialized bf16 tensor on
        ``'cuda'`` shaped ``(num_groups, m, n)``, or ``(num_groups * m, n)``
        when ``is_masked`` is False.

    Raises:
        AssertionError: if ``m`` is not a multiple of 4.
    """
    num_groups, m, k = x.shape
    n = y.size(1)
    out = torch.empty((num_groups, m, n), device='cuda', dtype=torch.bfloat16)

    assert m % 4 == 0, f'TMA alignment error: {m}'
    # Pre-allocate FP8 data and float32 scale buffers, then fill them group by group.
    x_fp8 = (torch.empty_like(x, dtype=torch.float8_e4m3fn), torch.empty((num_groups, m, k // 128), device='cuda', dtype=torch.float))
    y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, (n + 127) // 128, k // 128), device='cuda', dtype=torch.float))
    for i in range(num_groups):
        x_fp8[0][i], x_fp8[1][i] = per_token_cast_to_fp8(x[i])
        y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i])

    # For non-masked input, we must merge the group and M dims.
    # The scales are recomputed from the flattened input rather than reshaped.
    if not is_masked:
        x_fp8 = (x_fp8[0].view(-1, k), per_token_cast_to_fp8(x.view(-1, k))[1])
        out = out.view(-1, n)

    # Transpose earlier so that the testing will not trigger transposing kernels
    x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1]))
    return x_fp8, y_fp8, out



def deep_matmul2(x, y):
    """Quantize ``x`` and ``y`` with ``construct`` and multiply them with DeepGEMM.

    The quantization happens inside this call, so timing this function
    includes the FP8 cast as well as the GEMM.

    Returns:
        The bf16 output buffer that ``deep_gemm.gemm_fp8_fp8_bf16_nt`` wrote into.
    """
    lhs_fp8, rhs_fp8, result = construct(x, y)
    deep_gemm.gemm_fp8_fp8_bf16_nt(lhs_fp8, rhs_fp8, result)
    return result

def group_deep_matmul(x, y):
    """Run DeepGEMM's contiguous grouped FP8 GEMM on equally sized groups.

    Args:
        x: 3-D tensor unpacked as ``(num_groups, m, k)``.
        y: per-group right operands, forwarded to ``construct_grouped``.

    Returns:
        The output buffer produced by ``construct_grouped`` (non-masked
        layout) after the grouped GEMM has written into it.
    """
    group_count, rows_per_group, _ = x.shape
    lhs_fp8, rhs_fp8, result = construct_grouped(x, y)
    # One int32 group id per flattened row: 0 repeated m times, then 1, and so on.
    group_ids = torch.arange(0, group_count, device='cuda', dtype=torch.int)
    row_to_group = group_ids.unsqueeze(-1).expand(group_count, rows_per_group).contiguous().view(-1)
    deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs_fp8, rhs_fp8, result, row_to_group)
    return result

# Transformer Engine FP8 recipe shared by the helpers below:
# delayed scaling, E4M3 format, zero margin.
fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3)


def te_matmul(x, fc):
    """Apply ``fc`` to ``x`` inside a Transformer Engine FP8 autocast region."""
    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        return fc(x)

def group_te_matmul(x, fc, num_groups):
    """Run a grouped layer on ``x`` with equal-sized row splits under FP8 autocast.

    Args:
        x: input whose first dimension is divided evenly across the groups.
        fc: callable invoked as ``fc(x, splits)``.
        num_groups: number of equal splits; must be positive and must divide
            ``x.size(0)`` exactly.

    Returns:
        Whatever ``fc(x, splits)`` returns.

    Raises:
        ValueError: if ``num_groups`` is not positive or does not divide
            ``x.size(0)``.
    """
    rows = x.size(0)
    if num_groups <= 0:
        raise ValueError(f'num_groups must be positive, got {num_groups}')
    if rows % num_groups != 0:
        # Floor division would otherwise build splits that do not cover every row.
        raise ValueError(f'x.size(0)={rows} is not divisible by num_groups={num_groups}')
    splits = [rows // num_groups] * num_groups
    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        out = fc(x, splits)
    return out

In [4]:
# Problem setup for the single-GEMM check: activation (m, k), weight (n, k), output (m, n).
device = 'cuda'
dtype = torch.bfloat16
m, n, k = 4096, 4096*2, 4096*2
# The activation's inner dimension must be k (the layer's in_features), not n;
# the two only happen to be equal for this problem size.
x = torch.randn(m, k, dtype=dtype, device=device)
fc = torch.nn.Linear(k, n, bias=False, device=device, dtype=dtype)
# The layer's weight doubles as the right-hand operand of the FP8 GEMMs.
y = fc.weight
x_fp8, y_fp8, out = construct(x, y)
# Reference result from the unquantized layer.
out_ref = fc(x)

In [5]:
# Quantize both operands with the helpers imported from `core`.
a_fp8, b_fp8 = per_token_cast_to_fp8(x), per_block_cast_to_fp8(y)

In [6]:
# Same product twice: the notebook-local wrapper (out1) and the `core` implementation (out2).
out1, out2 = deep_matmul2(x, y), deep_matmul(x, y)

In [7]:
# Max and mean absolute error of each FP8 result against the reference output.
for fp8_result in (out1, out2):
    abs_err = (fp8_result - out_ref).abs()
    print(abs_err.max(), abs_err.mean())

tensor(0.1133, device='cuda:0', dtype=torch.bfloat16, grad_fn=<MaxBackward1>) tensor(0.0167, device='cuda:0', dtype=torch.bfloat16, grad_fn=<MeanBackward0>)
tensor(0.1133, device='cuda:0', dtype=torch.bfloat16, grad_fn=<MaxBackward1>) tensor(0.0166, device='cuda:0', dtype=torch.bfloat16, grad_fn=<MeanBackward0>)


In [32]:
# Time the three paths with Triton's benchmark helper, in this order:
# local wrapper (quantization included), `core` implementation, plain layer.
for bench_fn in (
    lambda: deep_matmul2(x, y),
    lambda: deep_matmul(x, y),
    lambda: fc(x),
):
    print(triton.testing.do_bench(bench_fn))

1.8789821863174438
0.5845611095428467
0.7825400233268738


In [5]:
# Problem setup for comparing torch.nn.Linear against DeepLinear.
device = 'cuda'
dtype = torch.bfloat16
m, n, k = 4096*2, 4096*4, 4096

# Two separate leaf inputs with identical values, so each layer accumulates its own .grad.
x1 = torch.randn(m, k, dtype=dtype, device=device, requires_grad=True)
x2 = deepcopy(x1)

bias = False
fc1 = torch.nn.Linear(k, n, bias=bias, device=device, dtype=dtype)
fc2 = DeepLinear(k, n, bias=bias, device=device, dtype=dtype)
# Give fc2 the same parameters as fc1 so any difference comes from the implementation.
fc2.weight.data.copy_(fc1.weight.data)
if bias:
    fc2.bias.data.copy_(fc1.bias.data)

In [6]:
# Forward both layers, then push the same upstream gradient through each.
# Note: backward() accumulates into .grad, so re-running this cell without
# rebuilding the inputs and layers changes the gradient comparison.
y1, y2 = fc1(x1), fc2(x2)
dy = torch.rand_like(y1)
for layer_out in (y1, y2):
    layer_out.backward(dy)

# Max / mean absolute difference for: outputs, input grads, weight grads (and bias grads).
compared_pairs = [(y1, y2), (x1.grad, x2.grad), (fc1.weight.grad, fc2.weight.grad)]
if bias:
    compared_pairs.append((fc1.bias.grad, fc2.bias.grad))
for ref_t, test_t in compared_pairs:
    abs_diff = (ref_t - test_t).abs()
    print(abs_diff.max(), abs_diff.mean())

tensor(0.1328, device='cuda:0', dtype=torch.bfloat16, grad_fn=<MaxBackward1>) tensor(0.0167, device='cuda:0', dtype=torch.bfloat16, grad_fn=<MeanBackward0>)
tensor(0.1406, device='cuda:0', dtype=torch.bfloat16) tensor(0.0190, device='cuda:0', dtype=torch.bfloat16)
tensor(11., device='cuda:0', dtype=torch.bfloat16) tensor(1.5312, device='cuda:0', dtype=torch.bfloat16)


In [7]:
# Forward timings: fc1 first, then fc2.
for layer, layer_in in ((fc1, x1), (fc2, x2)):
    print(triton.testing.do_bench(lambda: layer(layer_in)))

# Build fresh graphs, then time backward. retain_graph=True lets the same graph
# be replayed on every benchmark iteration; grad_to_none is passed so do_bench
# can reset those gradients between runs.
y1, y2 = fc1(x1), fc2(x2)
dy = torch.rand_like(y1)
for layer_out, layer_in, layer in ((y1, x1, fc1), (y2, x2, fc2)):
    print(triton.testing.do_bench(
        lambda: layer_out.backward(dy, retain_graph=True),
        grad_to_none=[layer_in, layer.weight],
    ))

1.5591166019439697
1.143518328666687
3.2884576320648193
2.541518211364746


In [1]:
from transformer_engine.pytorch.module.deep_linear import Linear as TeLinear
from transformer_engine.pytorch.module.deep_layernorm_linear import LayerNormLinear as TeLayerNormLinear

import transformer_engine.pytorch as te
from transformer_engine.common import recipe
import torch
from copy import deepcopy
import triton

In [13]:
# FP8 recipe used by the autocast region in the comparison cell:
# delayed scaling, E4M3 format, zero margin.
fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3)
device = 'cuda'
dtype = torch.bfloat16
m, n, k = 256, 512, 1024

# Two separate leaf inputs with identical values, so each layer accumulates its own .grad.
x1 = torch.randn(m, k, dtype=dtype, device=device, requires_grad=True)
x2 = deepcopy(x1)

# Two identically configured TeLinear layers; fc2 then receives fc1's parameters.
# (Alternative for fc2: te.Linear(k, n, bias=bias, device=device, params_dtype=dtype).)
bias = False
fc1 = TeLinear(k, n, bias=bias, device=device, params_dtype=dtype)
fc2 = TeLinear(k, n, bias=bias, device=device, params_dtype=dtype)
fc2.weight.data.copy_(fc1.weight.data)
if bias:
    fc2.bias.data.copy_(fc1.bias.data)

In [14]:
# fc1 runs outside the FP8 autocast region, fc2 inside it.
y1 = fc1(x1)
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    y2 = fc2(x2)

# Same upstream gradient through both graphs.
dy = torch.rand_like(y1)
for layer_out in (y1, y2):
    layer_out.backward(dy)

# Max / mean absolute difference for: outputs, input grads, weight grads (and bias grads).
compared_pairs = [(y1, y2), (x1.grad, x2.grad), (fc1.weight.grad, fc2.weight.grad)]
if bias:
    compared_pairs.append((fc1.bias.grad, fc2.bias.grad))
for ref_t, test_t in compared_pairs:
    abs_diff = (ref_t - test_t).abs()
    print(abs_diff.max(), abs_diff.mean())

tensor(0.1328, device='cuda:0', dtype=torch.bfloat16, grad_fn=<MaxBackward1>) tensor(0.0217, device='cuda:0', dtype=torch.bfloat16, grad_fn=<MeanBackward0>)
tensor(0.0508, device='cuda:0', dtype=torch.bfloat16) tensor(0.0087, device='cuda:0', dtype=torch.bfloat16)
tensor(1.6406, device='cuda:0', dtype=torch.bfloat16) tensor(0.2637, device='cuda:0', dtype=torch.bfloat16)


In [8]:
# FP8 recipe used by the autocast region in the comparison cell:
# delayed scaling, E4M3 format, zero margin.
fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3)
device = 'cuda'
dtype = torch.bfloat16
m, n, k = 256, 512, 1024

# Two separate leaf inputs with identical values, so each layer accumulates its own .grad.
x1 = torch.randn(m, k, dtype=dtype, device=device, requires_grad=True)
x2 = deepcopy(x1)

# Two identically configured RMSNorm + linear layers; fc2 then receives fc1's
# linear weight and norm weight (and bias, when enabled).
# (Alternative for fc2: te.Linear(k, n, bias=bias, device=device, params_dtype=dtype).)
bias = False
fc1 = TeLayerNormLinear(k, n, normalization='RMSNorm', bias=bias, device=device, params_dtype=dtype)
fc2 = TeLayerNormLinear(k, n, normalization='RMSNorm', bias=bias, device=device, params_dtype=dtype)
fc2.weight.data.copy_(fc1.weight.data)
fc2.layer_norm_weight.data.copy_(fc1.layer_norm_weight.data)
if bias:
    fc2.bias.data.copy_(fc1.bias.data)

In [9]:
# fc1 runs outside the FP8 autocast region, fc2 inside it.
y1 = fc1(x1)
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    y2 = fc2(x2)

# Same upstream gradient through both graphs.
dy = torch.rand_like(y1)
for layer_out in (y1, y2):
    layer_out.backward(dy)

# Max / mean absolute difference for: outputs, input grads, weight grads,
# norm-weight grads (and bias grads).
compared_pairs = [
    (y1, y2),
    (x1.grad, x2.grad),
    (fc1.weight.grad, fc2.weight.grad),
    (fc1.layer_norm_weight.grad, fc2.layer_norm_weight.grad),
]
if bias:
    compared_pairs.append((fc1.bias.grad, fc2.bias.grad))
for ref_t, test_t in compared_pairs:
    abs_diff = (ref_t - test_t).abs()
    print(abs_diff.max(), abs_diff.mean())

tensor(0.1250, device='cuda:0', dtype=torch.bfloat16, grad_fn=<MaxBackward1>) tensor(0.0215, device='cuda:0', dtype=torch.bfloat16, grad_fn=<MeanBackward0>)
tensor(0.0488, device='cuda:0', dtype=torch.bfloat16) tensor(0.0088, device='cuda:0', dtype=torch.bfloat16)
tensor(1.6562, device='cuda:0', dtype=torch.bfloat16) tensor(0.2676, device='cuda:0', dtype=torch.bfloat16)
tensor(0.8750, device='cuda:0', dtype=torch.bfloat16) tensor(0.1328, device='cuda:0', dtype=torch.bfloat16)


In [12]:
# Display the installed PyTorch version string for reference.
torch.__version__

'2.5.1+cu124'