In [1]:
import torch 
from torch.nn.parallel.distributed import _MixedPrecision

In [68]:
def dynamic_distance_bias_matrix(start, stop )-> torch.Tensor:
    """Build one tile of the negative absolute-distance bias matrix.

    Entry ``[i, j]`` of the result is ``-abs((start[0] + i) - (stop[0] + j))``,
    which is the slice ``full[start[0]:start[1], stop[0]:stop[1]]`` of the
    matrix ``full[r, c] = -abs(r - c)``, computed without building ``full``.

    Args:
        start: ``(first_row, end_row)`` half-open row range of the tile.
            Despite the name this is a range, not a single starting index.
        stop: ``(first_col, end_col)`` half-open column range of the tile.

    Returns:
        Tensor of shape ``(start[1] - start[0], stop[1] - stop[0])``; the dtype
        is whatever ``torch.arange`` produces for the given bounds.
    """
    row_idx = torch.arange(start[0], start[1])
    col_idx = torch.arange(stop[0], stop[1])
    # Rows are broadcast down axis 0 and columns along axis 1, so the output
    # orientation is correct for every tile shape. The previous version built
    # the transposed layout and only flipped it when the tile was not wider
    # than tall, returning a wrongly oriented tile otherwise.
    return -torch.abs(row_idx[:, None] - col_idx[None, :])

In [3]:
# Square 8x8 reference tile: rows 0..7 against columns 0..7.
base = dynamic_distance_bias_matrix(start=(0, 8), stop=(0, 8))
base

tensor([[ 0, -1, -2, -3, -4, -5, -6, -7],
        [-1,  0, -1, -2, -3, -4, -5, -6],
        [-2, -1,  0, -1, -2, -3, -4, -5],
        [-3, -2, -1,  0, -1, -2, -3, -4],
        [-4, -3, -2, -1,  0, -1, -2, -3],
        [-5, -4, -3, -2, -1,  0, -1, -2],
        [-6, -5, -4, -3, -2, -1,  0, -1],
        [-7, -6, -5, -4, -3, -2, -1,  0]])

In [71]:
# Non-square tile: rows 0..3 against columns 0..1.
block_sram = dynamic_distance_bias_matrix(start=(0, 4), stop=(0, 2))
block_sram


tensor([[ 0, -1],
        [-1,  0],
        [-2, -1],
        [-3, -2]])

In [43]:
# Same 4x2 tile as above, this time with the ranges held in named variables.
start, stop = (0, 4), (0, 2)
actual = dynamic_distance_bias_matrix(start, stop)
actual

len(res[0])=2
mismatch in generated mask 0 dim


tensor([[ 0, -1],
        [-1,  0],
        [-2, -1],
        [-3, -2]])

In [108]:
import pytest

#class TestFlashAlibiMask:
#@pytest.fixture
# Side length of the reference matrix and the tile height / width used by the
# tile-walking tests below.
N = 1024
step_row, step_col = 64, 32
# Full N x N matrix that each test slices its expected tile from.
mask = dynamic_distance_bias_matrix(start=(0, N), stop=(0, N))
#print(f"{mask[0:128, 0:64]=}")

def test_walk_left_side(mask, N, step_row, step_col):
    """Check every row tile of the left-most column strip against ``mask``.

    Walks tiles of ``step_row`` rows over columns ``[0, step_col)`` and asserts
    that ``dynamic_distance_bias_matrix`` reproduces the matching slice of
    ``mask``.

    Args:
        mask: reference matrix; expected tiles are ``mask[rows, cols]`` slices.
        N: end of the row walk; must be a multiple of ``step_row``.
        step_row: tile height.
        step_col: tile width (also the stop column, since the strip starts at 0).
    """
    start_col = 0
    assert N % step_row ==0, f"mismatch in row step size vs N"
    # N + 1 so that the last tile, whose stop row is exactly N, is visited;
    # range(step_row, N, step_row) silently skipped it.
    for stop_row in range(step_row, N + 1, step_row):
        start_row = stop_row - step_row
        print(f"{start_row=}, {stop_row=}, {start_col=}, {step_col=}")
        expected = mask[start_row:stop_row, start_col:step_col]
        actual = dynamic_distance_bias_matrix((start_row, stop_row), (start_col,step_col))
        assert torch.equal(actual, expected)

def test_walk_right_side(mask, N, step_row, step_col):
    """Check every row tile of the right-most column strip against ``mask``.

    Walks tiles of ``step_row`` rows over columns ``[N - step_col, N)`` and
    asserts that ``dynamic_distance_bias_matrix`` reproduces the matching slice
    of ``mask``.

    Args:
        mask: reference matrix; expected tiles are ``mask[rows, cols]`` slices.
        N: end of the row walk and right edge of the column strip; must be a
            multiple of ``step_row``.
        step_row: tile height.
        step_col: tile width.
    """
    start_col = N - step_col
    stop_col = N
    assert N % step_row ==0, f"mismatch in row step size vs N"
    # N + 1 so that the last tile, whose stop row is exactly N, is visited;
    # range(step_row, N, step_row) silently skipped it.
    for stop_row in range(step_row, N + 1, step_row):
        start_row = stop_row - step_row
        print(f"{start_row=}, {stop_row=}, {start_col=}, {stop_col=}")
        expected = mask[start_row:stop_row, start_col:stop_col]
        actual = dynamic_distance_bias_matrix((start_row, stop_row), (start_col,stop_col))
        print(f"{actual.shape=}, {expected.shape=}")
        assert torch.equal(actual, expected)

def test_walk_diagonal(mask, N, step_row, step_col):
    """Check tiles along a top-left to bottom-right staircase against ``mask``.

    Each step moves the tile down by ``step_row`` rows and right by
    ``step_col`` columns, comparing ``dynamic_distance_bias_matrix`` with the
    matching slice of ``mask``. The walk stops early once the next column
    window would run past ``N``.

    Args:
        mask: reference matrix; expected tiles are ``mask[rows, cols]`` slices.
        N: end of the walk; must be a multiple of both step sizes.
        step_row: tile height.
        step_col: tile width.
    """
    assert N % step_row ==0, f"mismatch in row step size vs N"
    assert N % step_col ==0, f"mismatch in col step size vs N"

    start_col = 0
    stop_col = step_col
    # N + 1 so that the last tile, whose stop row is exactly N, is visited;
    # range(step_row, N, step_row) silently skipped it.
    for stop_row in range(step_row, N + 1, step_row):
        start_row = stop_row - step_row
        print(f"{start_row=}, {stop_row=}, {start_col=}, {stop_col=}")
        expected = mask[start_row:stop_row, start_col:stop_col]
        actual = dynamic_distance_bias_matrix((start_row, stop_row), (start_col,stop_col))
        print(f"{actual.shape=}, {expected.shape=}")
        assert torch.equal(actual, expected)
        start_col += step_col
        stop_col += step_col
        if stop_col > N:
            # cols don't necessarily fit evenly relative to rows
            break

def test_walk_reverse_diagonal(mask, N, step_row, step_col):
    """Check tiles along a top-right to bottom-left staircase against ``mask``.

    Each step moves the tile down by ``step_row`` rows and left by
    ``step_col`` columns, comparing ``dynamic_distance_bias_matrix`` with the
    matching slice of ``mask``. The walk stops early once the next column
    window would start before column 0.

    Args:
        mask: reference matrix; expected tiles are ``mask[rows, cols]`` slices.
        N: end of the row walk and right edge of the first column window; must
            be a multiple of both step sizes.
        step_row: tile height.
        step_col: tile width.
    """
    assert N % step_row ==0, f"mismatch in row step size vs N"
    assert N % step_col ==0, f"mismatch in col step size vs N"

    start_col = N-step_col
    stop_col = N
    # N + 1 so that the last tile, whose stop row is exactly N, is visited;
    # range(step_row, N, step_row) silently skipped it.
    for stop_row in range(step_row, N + 1, step_row):
        start_row = stop_row - step_row
        print(f"{start_row=}, {stop_row=}, {start_col=}, {stop_col=}")
        expected = mask[start_row:stop_row, start_col:stop_col]
        actual = dynamic_distance_bias_matrix((start_row, stop_row), (start_col,stop_col))
        print(f"{actual.shape=}, {expected.shape=}")
        assert torch.equal(actual, expected)
        start_col -= step_col
        stop_col -= step_col
        # Guard on the window's left edge: a window such as [-step_col, 0)
        # has stop_col == 0, which a `stop_col < 0` check lets through, and
        # the negative start index would then slice `mask` to an empty tensor.
        if start_col < 0:
            # cols don't necessarily fit evenly relative to rows
            break

def test_walk_all_rows(mask, N, step_row, step_col):
    """Check every tile of the grid against ``mask``, one row band at a time.

    For each band of ``step_row`` rows, sweeps all ``step_col``-wide column
    windows left to right and asserts that ``dynamic_distance_bias_matrix``
    reproduces the matching slice of ``mask``. Prints the number of tiles
    checked at the end.

    Args:
        mask: reference matrix; expected tiles are ``mask[rows, cols]`` slices.
        N: end of both walks; must be a multiple of both step sizes.
        step_row: tile height.
        step_col: tile width.
    """
    assert N % step_row ==0, f"mismatch in row step size vs N"
    assert N % step_col ==0, f"mismatch in col step size vs N"
    block_counter = 0
    # Both ranges use N + 1 so the last band / window (stop index exactly N)
    # is included; range(step, N, step) silently skipped it.
    for stop_row in range(step_row, N + 1, step_row):
        start_row = stop_row - step_row
        for stop_col in range(step_col, N + 1, step_col):
            start_col = stop_col - step_col

            print(f"{start_row=}, {stop_row=}, {start_col=}, {stop_col=}")
            expected = mask[start_row:stop_row, start_col:stop_col]
            actual = dynamic_distance_bias_matrix((start_row, stop_row), (start_col,stop_col))
            print(f"{actual.shape=}, {expected.shape=}")
            assert torch.equal(actual, expected)

            block_counter += 1
    print(f"checked {block_counter} blocks")

def test_walk_all_cols(mask, N, step_row, step_col):
    """Check every tile of the grid against ``mask``, one column band at a time.

    For each window of ``step_col`` columns, sweeps all ``step_row``-high row
    bands top to bottom and asserts that ``dynamic_distance_bias_matrix``
    reproduces the matching slice of ``mask``. Prints the number of tiles
    checked at the end.

    Args:
        mask: reference matrix; expected tiles are ``mask[rows, cols]`` slices.
        N: end of both walks; must be a multiple of both step sizes.
        step_row: tile height.
        step_col: tile width.
    """
    assert N % step_row ==0, f"mismatch in row step size vs N"
    assert N % step_col ==0, f"mismatch in col step size vs N"
    block_counter = 0

    # Both ranges use N + 1 so the last window / band (stop index exactly N)
    # is included; range(step, N, step) silently skipped it.
    for stop_col in range(step_col, N + 1, step_col):
        start_col = stop_col - step_col
        for stop_row in range(step_row, N + 1, step_row):
            start_row = stop_row - step_row

            print(f"{start_row=}, {stop_row=}, {start_col=}, {stop_col=}")
            expected = mask[start_row:stop_row, start_col:stop_col]
            actual = dynamic_distance_bias_matrix((start_row, stop_row), (start_col,stop_col))
            print(f"{actual.shape=}, {expected.shape=}")
            assert torch.equal(actual, expected)

            block_counter += 1
    print(f"checked {block_counter} blocks")

            
            


In [109]:
# Run the column-major sweep over the reference mask.
test_walk_all_cols(mask=mask, N=N, step_row=step_row, step_col=step_col)

start_row=0, stop_row=64, start_col=0, stop_col=32
actual.shape=torch.Size([64, 32]), expected.shape=torch.Size([64, 32])
start_row=64, stop_row=128, start_col=0, stop_col=32
actual.shape=torch.Size([64, 32]), expected.shape=torch.Size([64, 32])
start_row=128, stop_row=192, start_col=0, stop_col=32
actual.shape=torch.Size([64, 32]), expected.shape=torch.Size([64, 32])
start_row=192, stop_row=256, start_col=0, stop_col=32
actual.shape=torch.Size([64, 32]), expected.shape=torch.Size([64, 32])
start_row=256, stop_row=320, start_col=0, stop_col=32
actual.shape=torch.Size([64, 32]), expected.shape=torch.Size([64, 32])
start_row=320, stop_row=384, start_col=0, stop_col=32
actual.shape=torch.Size([64, 32]), expected.shape=torch.Size([64, 32])
start_row=384, stop_row=448, start_col=0, stop_col=32
actual.shape=torch.Size([64, 32]), expected.shape=torch.Size([64, 32])
start_row=448, stop_row=512, start_col=0, stop_col=32
actual.shape=torch.Size([64, 32]), expected.shape=torch.Size([64, 32])
star

In [5]:

# Same construction on the shifted index range 5..9: each entry is the negated
# absolute difference between a column index and a row index.
distance_bias_matrix = -(
    torch.arange(5, 10)[None, :] - torch.arange(5, 10)[:, None]
).abs()
print(f"{distance_bias_matrix[0:10]=}")

distance_bias_matrix[0:10]=tensor([[ 0, -1, -2, -3, -4],
        [-1,  0, -1, -2, -3],
        [-2, -1,  0, -1, -2],
        [-3, -2, -1,  0, -1],
        [-4, -3, -2, -1,  0]])


In [2]:
bf16_policy = _MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)



In [None]:
# Small linear layer (12 inputs -> 24 outputs) used as the module to wrap.
model = torch.nn.Linear(in_features=12, out_features=24)

In [4]:
from torch.nn.parallel import DistributedDataParallel as DDP

In [None]:
# Wrap the layer in DDP with the bf16 mixed-precision policy.
# NOTE(review): no torch.distributed.init_process_group() call is visible in
# this notebook; DDP presumably needs one before this cell can run — confirm.
model = DDP(module=model, mixed_precision=bf16_policy)