In [36]:
import glob
import json
import os

import pandas as pd
import torch
import wandb


In [2]:
# Global Parameters
# NOTE(review): presumably mirrors the settings passed to train_torch_profiler.py
# (config/train_shakespeare_char.py); apart from out_dir, none of these names is
# read by a visible cell of this notebook — confirm they are kept only for reference.

#I/O
out_dir = 'out-shakespeare-char'  # root directory searched recursively for *.pt.trace.json traces
#log_interval = 20 # don't print too too often

# data
dataset = 'shakespeare_char'
gradient_accumulation_steps = 1
batch_size = 64
block_size = 256 # context of up to 256 previous characters

# model
n_layer = 6
n_head = 6
n_embd = 384
dropout = 0.2
bias = False # TODO do we use bias inside LayerNorm and Linear layers?

# adamw optimizer
learning_rate = 1e-3 # with baby networks can afford to go a bit higher
weight_decay = 1e-1
max_iters = 400
beta1 = 0.9
beta2 = 0.99 # make a bit bigger because number of tokens per iter is small
grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0

eval_interval = max_iters//4  # = 100 with max_iters = 400
eval_iters = max_iters//5  # = 80 with max_iters = 400
eval_only = False # if True, script exits right after the first eval

# learning rate decay settings
warmup_iters = max(10, max_iters // 100)  # never fewer than 10; = 10 with max_iters = 400
min_lr = 1e-4 # learning_rate / 10 usually
decay_lr = True
lr_decay_iters = max_iters # make equal to max_iters usually (TODO - check cosine annealing)

# NOTE(review): the comment here used to read "we expect to overfit on this small dataset,
# so only save when val improves", which describes always_save_checkpoint = False,
# but the value below is True — confirm which behaviour is intended.
always_save_checkpoint = True

#model initialization
init_from = 'scratch' # 'scratch' or 'resume' or 'gpt2*'

# logging
#wandb_log = True # override via command line if you like
wandb_project = 'profile-attention-nano-gpt'  # same project name the wandb-query cells pass as a literal

# NOTE(review): presumably the backend/device/dtype/compile settings of the profiled runs — not read here.
config = {'backend': 'nccl', 'device': 'cuda', 'dtype': 'bfloat16', 'compile': True}

In [None]:
# Authenticate this kernel with Weights & Biases so the wandb.Api() queries below can read the project's runs.
wandb.login()

In [82]:
def collect_system_metrics(username, project, api=None):
    """Fetch the system-metrics history of every run in a wandb project, grouped by run name.

    Args:
        username: wandb entity (user or team) that owns the project.
        project: wandb project name.
        api: optional ``wandb.Api``-like object exposing ``runs(path)``. When omitted,
            a new ``wandb.Api()`` is created (pass one in to reuse a session or to test).

    Returns:
        dict mapping run name -> {run id -> DataFrame returned by
        ``run.history(stream='events')``}. The 'flash-attention' and
        'slow-attention' keys are always present (possibly empty). A run with any
        other name gets its own key instead of raising KeyError.
    """
    if api is None:
        api = wandb.Api()
    runs = api.runs(f"{username}/{project}")

    system_metrics = {'flash-attention': {}, 'slow-attention': {}}
    for run in runs:
        # stream='events' returns the system-metric stream (the 'system.*' GPU/disk/... columns)
        # rather than the metrics logged by the training loop.
        system_metrics.setdefault(run.name, {})[run.id] = run.history(stream='events')

    return system_metrics

# Per-run histories for the whole project, then the `_runtime` column of one flash-attention run.
sm = collect_system_metrics("m-motta", 'profile-attention-nano-gpt')
tt = sm['flash-attention']['oluy4uj9']['_runtime']

In [85]:
# Last logged `_runtime` value of the run selected above.
# (A leftover `type(tt)` probe was removed: its result was never displayed.)
tt.iloc[-1]

827.09263

In [None]:
def compute_average_runtime(system_metrics):
    """Average, per attention variant, the final `_runtime` value of its runs.

    Args:
        system_metrics: dict mapping attention name -> {run id -> DataFrame}, as
            returned by ``collect_system_metrics``. Every DataFrame must contain a
            '_runtime' column (the notebook treats it as seconds).

    Returns:
        dict mapping attention name -> mean over its runs of the last non-null
        '_runtime' value, or NaN when the variant has no run with a runtime.
    """
    runtimes = {}
    for attention, runs in system_metrics.items():
        final_runtimes = []
        for history in runs.values():
            # Positional access via .iloc: `series[-1]` would be a *label* lookup on the
            # integer RangeIndex and raise KeyError. NaNs are dropped first so a trailing
            # row without a runtime does not hide the real final value.
            runtime = history['_runtime'].dropna()
            if not runtime.empty:
                final_runtimes.append(float(runtime.iloc[-1]))
        runtimes[attention] = (
            sum(final_runtimes) / len(final_runtimes) if final_runtimes else float('nan')
        )
    return runtimes
        


In [80]:
# Final runtime of the selected run in minutes (assuming `_runtime` is in seconds),
# derived from `tt` rather than from a number hand-copied out of an earlier output.
tt.iloc[-1] / 60

13.784877166666666

In [65]:
def aggregate_system_metrics(username, project, api=None):
    """Stack the system-metrics history of all runs that share a run name.

    Args:
        username: wandb entity (user or team) that owns the project.
        project: wandb project name.
        api: optional ``wandb.Api``-like object exposing ``runs(path)``; a new
            ``wandb.Api()`` is created when omitted.

    Returns:
        dict mapping run name -> one DataFrame holding the rows of every run with that
        name, concatenated with a fresh RangeIndex (outer join on columns).
        'flash-attention' and 'slow-attention' are always present; a name with no
        runs maps to an empty DataFrame, so callers can always use DataFrame methods.

    NOTE: rows of different runs are appended one after another and the run id is not
    kept. Use ``collect_system_metrics`` when per-run frames are needed.
    """
    if api is None:
        api = wandb.Api()
    runs = api.runs(f"{username}/{project}")

    histories = {'flash-attention': [], 'slow-attention': []}
    for run in runs:
        # setdefault: unexpected run names get their own group instead of raising KeyError.
        histories.setdefault(run.name, []).append(run.history(stream='events'))

    return {
        name: (
            pd.concat(frames, axis=0, join='outer', ignore_index=True)
            if frames
            else pd.DataFrame()
        )
        for name, frames in histories.items()
    }

In [69]:
def filter_metrics(system_metrics, list_of_matches):
    """Keep only the columns matching any given pattern and drop their 'system.' prefix.

    Args:
        system_metrics: dict mapping run name -> DataFrame (e.g. the output of
            ``aggregate_system_metrics``). The input frames are not modified.
        list_of_matches: substrings / regex fragments, OR-ed together into one regex;
            a column is kept when the regex matches anywhere in its name.

    Returns:
        A new dict with the same keys whose DataFrames contain only the matching
        columns, renamed from 'system.<name>' to '<name>'. Matching columns without
        the 'system.' prefix keep their original name.
    """
    pattern = '|'.join(list_of_matches)

    def strip_system_prefix(column):
        # Split once and take the tail: removes everything up to the first 'system.'
        # when present and leaves other names untouched (indexing with [1] raised
        # IndexError for matched columns without the prefix).
        if not isinstance(column, str):
            return column
        return column.split('system.', 1)[-1]

    return {
        key: df.filter(regex=pattern, axis=1).rename(columns=strip_system_prefix)
        for key, df in system_metrics.items()
    }

In [71]:
# Stack every run's system-metric history from wandb, then keep only GPU and disk columns.
system_metrics = filter_metrics(
    aggregate_system_metrics("m-motta", 'profile-attention-nano-gpt'),
    ['gpu', 'disk'],
)

In [73]:
# First five rows of the filtered flash-attention metrics.
system_metrics['flash-attention'].head(n=5)

Unnamed: 0,gpu.0.powerPercent,gpu.0.powerWatts,disk.out,disk.\.usageGB,gpu.0.temp,gpu.0.memory,gpu.0.gpu,gpu.0.memoryAllocatedBytes,disk.\.usagePercent,disk.in,gpu.0.memoryAllocated
0,9.94,8.02,3.82,174.91,47.6,13.93,7.93,857468700.0,20.3,0.0,9.99
1,12.16,10.33,10.62,174.91,47.93,3.73,11.93,872196800.0,20.3,0.0,10.16
2,8.33,6.29,16.45,174.91,47.2,6.87,16.6,880563500.0,20.3,0.0,10.26
3,12.51,10.52,24.49,174.91,47.47,7.87,14.4,882368000.0,20.3,0.0,10.28
4,11.7,9.69,36.57,174.89,47.4,6.27,17.0,903575400.0,20.3,0.01,10.52


In [62]:

# NOTE(review): duplicate of the `.head()` display above. The output rendered below is stale:
# its columns still carry the 'system.' prefix that filter_metrics strips.
system_metrics['flash-attention'].head()



Unnamed: 0,system.gpu.0.powerPercent,system.gpu.0.powerWatts,system.disk.out,system.disk.\.usageGB,system.gpu.0.temp,system.gpu.0.memory,system.gpu.0.gpu,system.gpu.0.memoryAllocatedBytes,system.disk.\.usagePercent,system.disk.in,system.gpu.0.memoryAllocated
0,9.94,8.02,3.82,174.91,47.6,13.93,7.93,857468700.0,20.3,0.0,9.99
1,12.16,10.33,10.62,174.91,47.93,3.73,11.93,872196800.0,20.3,0.0,10.16
2,8.33,6.29,16.45,174.91,47.2,6.87,16.6,880563500.0,20.3,0.0,10.26
3,12.51,10.52,24.49,174.91,47.47,7.87,14.4,882368000.0,20.3,0.0,10.28
4,11.7,9.69,36.57,174.89,47.4,6.27,17.0,903575400.0,20.3,0.01,10.52


In [48]:
# Names (without the 'system.' prefix) of the GPU / disk metric columns.
# split('system.', 1)[-1] also works when filter_metrics has already stripped the prefix,
# which is the case on Restart & Run All; split('system.')[1] raised IndexError there.
params = ['gpu', 'disk']
metrics = system_metrics['flash-attention'].columns
metrics = [m.split('system.', 1)[-1] for m in metrics if any(p in m for p in params)]
metrics

['gpu.0.powerPercent',
 'gpu.0.powerWatts',
 'disk.out',
 'disk.\\.usageGB',
 'gpu.0.temp',
 'gpu.0.memory',
 'gpu.0.gpu',
 'gpu.0.memoryAllocatedBytes',
 'disk.\\.usagePercent',
 'disk.in',
 'gpu.0.memoryAllocated']

In [20]:
# After aggregate_system_metrics + filter_metrics each value is a single DataFrame;
# indexing it with [1] would look up a column labelled 1 and raise KeyError.
type(system_metrics['flash-attention'])

pandas.core.frame.DataFrame

In [20]:
# Run name used in wandb for each attention variant ('flash-attention' / 'slow-attention',
# the same keys the metric-collection functions group by).
# `flash` is not defined anywhere else in this notebook, so it gets an explicit default here.
# NOTE(review): presumably mirrors the flag in train_torch_profiler.py — confirm.
flash = True  # set to False for the slow-attention runs
wandb_run_name = 'flash-attention' if flash else 'slow-attention'

['schemaVersion',
 'deviceProperties',
 'distributedInfo',
 'with_flops',
 'with_modules',
 'with_stack',
 'traceEvents',
 'traceName']

In [21]:
# Load every *.pt.trace.json profiler trace found (recursively) under out_dir into `data`.
# Top-level keys of a trace: 'schemaVersion', 'deviceProperties', 'distributedInfo',
# 'with_flops', 'with_modules', 'with_stack', 'traceEvents', 'traceName';
# only 'deviceProperties' and 'traceEvents' contain relevant information.
# Paths are sorted because glob's order depends on the filesystem; sorting keeps the
# order of `data` reproducible across machines and reruns.
data = []
for trace_path in sorted(glob.glob(f"{out_dir}/**/*.pt.trace.json", recursive=True)):
    with open(trace_path, encoding="utf-8") as trace_file:
        data.append(json.load(trace_file))
       
        

In [73]:
# NOTE(review): the rendered output below is stale. It shows a single unfiltered DataFrame
# (columns still prefixed with 'system.'), whereas after filter_metrics this name holds a
# dict of DataFrames keyed by run name.
system_metrics

Unnamed: 0,system.network.sent,system.gpu.0.powerPercent,system.network.recv,system.cpu.9.cpu_percent,system.cpu.1.cpu_percent,system.gpu.0.powerWatts,system.cpu.7.cpu_percent,_wandb,system.disk.out,system.disk.\.usageGB,...,system.proc.memory.percent,system.disk.\.usagePercent,system.cpu.10.cpu_percent,system.disk.in,system.gpu.0.memoryAllocated,_timestamp,system.cpu.8.cpu_percent,system.cpu.18.cpu_percent,system.cpu.17.cpu_percent,system.cpu.14.cpu_percent
0,111203.73,62.39,122769.67,3.86,4.03,49.94,2.67,True,8.35,175.36,...,3.41,20.4,10.53,0,30.02,1696613000.0,5.83,4.47,8.35,6.39
1,405545.11,99.46,415949.67,1.28,0.89,84.49,0.17,True,386.3,175.38,...,4.07,20.4,1.17,0,46.48,1696613000.0,1.06,95.84,0.94,2.21


In [None]:
# Load a training checkpoint for inspection.
# NOTE(review): PATH was never defined in this notebook. The location below assumes the
# training script writes 'ckpt.pt' into out_dir — confirm against train_torch_profiler.py.
# torch.load unpickles the file, so only load checkpoints you produced yourself.
PATH = os.path.join(out_dir, 'ckpt.pt')
checkpoint = torch.load(PATH)


In [None]:
python3 train_torch_profiler.py config/train_shakespeare_char.py

In [None]:

All models are trained with the same hyperparameters for 400K steps.

We run all implementations with mixed-precision training (PyTorch AMP).

Speedup also changes when we increase the head dimension. Each block
requires more memory, so we need to use smaller block sizes to fit into SRAM. Figure 6 shows speedup with
head dimension 128 on an A100 (batch size 16, 12 heads). We see less speedup overall—but we can still see
significant speedup (up to 3×) with a causal mask, where half the blocks are masked out.
                     
We confirm that the memory footprint
of FlashAttention scales linearly with seq. length and is up to 3× faster than standard attention for
common seq. lengths (up to 2K). We confirm that runtime of block-sparse FlashAttention scales linearly
in seq. length and is faster than all existing approximate attention baselines.
                     
We train the model on 8×A100-80GB GPUs. Each training run takes between 16 and 19 minutes, and we
average the results of 10 runs.
                     
attention head, sequence length and block size