In [8]:
import os
import sys
import glob
import json
import wandb
import random
import numpy as np
import pandas as pd
import warnings
from collections import defaultdict
import pathlib

from utils import *

sys.path.insert(1, '../') #ugly hack
import config.train_shakespeare_char as params
from train_torch_profiler import wandb_config, main as train

# Absolute path of the directory the kernel is running in.
current_dir = pathlib.Path('.').resolve()
# NOTE(review): this stores the working *directory*, not a notebook file name —
# confirm that this is what wandb expects for WANDB_NOTEBOOK_NAME.
os.environ['WANDB_NOTEBOOK_NAME'] = str(current_dir)

# Show each distinct warning only the first time it is raised.
warnings.filterwarnings('once')
%load_ext autoreload
%autoreload 2



The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### Profiling with PyTorch

We will first perform a sanity check and see whether profiling with a different number of iterations leads to significantly different results.

We keep all parameters fixed, while varying the number of iterations and the type of attention.

In [None]:
# Model hyper-parameters read from the config module; they are not varied in
# this sanity check.
fixed_params = {
    name: getattr(params, name)
    for name in ('n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'dropout')
}

# Settings that are swept: number of training iterations and attention type.
varying_params = {
    'max_iters': [400, 1000],  # less than 400 iterations didn't lead to wandb system metrics
    'flash': [True, False],
}

profiling_params = params.profiler_schedule_args
print(' profiling_params', profiling_params)

# One entry per experiment, built by the helper from `utils`.
experiment_params = combine_all_keys_and_values(varying_params)
print('\n experiment_params', experiment_params)

In [None]:
# Sanity check: launch one training run per entry of experiment_params.
for exp_params in experiment_params:
    print('\n \n train with', exp_params)
    # check if reinit=True in wandb is better
    with set_params(params, exp_params):
        # Draw a port in the range 1025..2024 and expose it on the config
        # module before training starts.
        random_port = 1024 + random.randint(1, 1000)
        params.random_port = random_port
        print(params.random_port)
        train()

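For single runs with 400 and 1000 iterations respectively, we see that mean kernel occupancy and duration (time of kernel activity) are comparable: occupancy is essentially identical, and mean duration differs by at most roughly 7% between the two settings.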

In [10]:
# Kernels whose profiled statistics are compared, and the attention variants
# for which runs were recorded.
kernels_to_compare = [
    'triton__0d1d2d3d4d5d6d',
    'ampere_bf16_s1688gemm_bf16_128x128_ldg8_f2f_tn',
]
attention_variants = ['flash', 'slow']

# One table per (kernel, attention variant) pair.
for kernel in kernels_to_compare:
    print('\n kernel:', kernel)
    for attention in attention_variants:
        display(compare_runs(params.out_dir, attention, kernel))



 kernel: triton__0d1d2d3d4d5d6d


Unnamed: 0_level_0,mean_duration,mean_occupancy
flash,Unnamed: 1_level_1,Unnamed: 2_level_1
flash-block_size256-n_head6-n_embd384-max_iters400skip_first5-wait5-warmup5-active3-repeat1,174,83.0
flash-block_size256-n_head6-n_embd384-max_iters1000skip_first5-wait5-warmup5-active3-repeat1,184,83.0


Unnamed: 0_level_0,mean_duration,mean_occupancy
slow,Unnamed: 1_level_1,Unnamed: 2_level_1
slow-block_size256-n_head6-n_embd384-max_iters1000skip_first5-wait5-warmup5-active3-repeat1,882,99.16
slow-block_size256-n_head6-n_embd384-max_iters400skip_first5-wait5-warmup5-active3-repeat1,904,99.19



 kernel: ampere_bf16_s1688gemm_bf16_128x128_ldg8_f2f_tn


Unnamed: 0_level_0,mean_duration,mean_occupancy
flash,Unnamed: 1_level_1,Unnamed: 2_level_1
flash-block_size256-n_head6-n_embd384-max_iters400skip_first5-wait5-warmup5-active3-repeat1,528,17.0
flash-block_size256-n_head6-n_embd384-max_iters1000skip_first5-wait5-warmup5-active3-repeat1,507,17.0


Unnamed: 0_level_0,mean_duration,mean_occupancy
slow,Unnamed: 1_level_1,Unnamed: 2_level_1
slow-block_size256-n_head6-n_embd384-max_iters1000skip_first5-wait5-warmup5-active3-repeat1,534,17.0
slow-block_size256-n_head6-n_embd384-max_iters400skip_first5-wait5-warmup5-active3-repeat1,571,17.0


In [3]:
# Single comparison for the 'slow' attention runs and the triton kernel;
# the table returned by compare_runs is shown as the cell output.
compare_runs(params.out_dir, 'slow', 'triton__0d1d2d3d4d5d6d')

Unnamed: 0,mean_duration,mean_occupancy
slow-block_size256-n_head6-n_embd384-max_iters1000skip_first5-wait5-warmup5-active3-repeat1,882,99.16
slow-block_size256-n_head6-n_embd384-max_iters400skip_first5-wait5-warmup5-active3-repeat1,904,99.19


## Experiments:

vary: 

(n_heads*h_size = embedding dimensionality (n_embd)) 

n_heads 
            
h_size

seq_len  

(number of parameters in scaling_laws.ipynb)

In [None]:
def wandb_system_metrics(username, project):
    """Collect the 'events' history of every finished run in a W&B project.

    Args:
        username: W&B entity that owns the project.
        project: name of the W&B project.

    Returns:
        A ``defaultdict(dict)`` mapping run name -> {run id -> history}, where
        each history is what ``run.history(stream='events')`` returned.
        Runs whose state is not 'finished' are left out.
    """
    histories_by_name = defaultdict(dict)

    for candidate in wandb.Api().runs(f"{username}/{project}"):
        if candidate.state != 'finished':
            continue
        # Several runs can share a name, so they are kept apart by run id.
        histories_by_name[candidate.name][candidate.id] = candidate.history(stream='events')

    return histories_by_name

# Histories of all finished runs in the project, keyed by run name and run id.
sm = wandb_system_metrics(username="m-motta", project='profile-attention-nano-gpt')

About logged runtimes:

There is a difference between the runtime from `.history()` and `.history(stream='events')`. This is firstly because events are restricted to the GPU, and probably also because the system is sampled at specific intervals and checked once again after the iterations are finished.

In [None]:
def compute_average_runtime(system_metrics):
    """Average the final logged runtime over the runs of each group.

    Args:
        system_metrics: mapping group name -> {run id -> history}, where each
            history exposes a ``_runtime`` column supporting ``.iloc[-1]``.

    Returns:
        dict mapping each group name to the mean of the last ``_runtime``
        value of its runs. Progress details are printed along the way.
    """
    averages = {}
    for run_name, histories in system_metrics.items():
        print('params', run_name)
        final_runtimes = []
        for run_id, history in histories.items():
            print('id', run_id)
            # The last row holds the final runtime logged for this run.
            last_runtime = history._runtime.iloc[-1]
            print('.iloc[-1]', last_runtime)
            final_runtimes.append(last_runtime)
            print('')
        averages[run_name] = sum(final_runtimes) / len(final_runtimes)

    return averages

# Mean final runtime per run name; the resulting dict is the cell output.
compute_average_runtime(system_metrics=sm)

In [None]:
def aggregate_wandb_system_metrics(username, project):
    """Concatenate the 'events' histories of all runs in a W&B project, per run name.

    Every run in the project is included, whatever its state.

    Args:
        username: W&B entity that owns the project.
        project: name of the W&B project.

    Returns:
        dict mapping run name -> pandas DataFrame holding the row-wise
        concatenation (outer join on columns, fresh index) of the histories
        of all runs with that name. The keys 'flash-attention' and
        'slow-attention' are always present; a name with no runs maps to an
        empty DataFrame so every value has the same type.
    """
    api = wandb.Api()
    runs = api.runs(f"{username}/{project}")

    # Seed the two expected groups so they always appear in the result.
    histories = {'flash-attention': [], 'slow-attention': []}
    for run in runs:
        # setdefault: a run with any other name gets its own group instead of
        # raising KeyError.
        histories.setdefault(run.name, []).append(run.history(stream='events'))  # a pandas data frame

    return {
        name: (
            pd.concat(frames, axis=0, join='outer', ignore_index=True)
            if frames
            else pd.DataFrame()  # pd.concat rejects an empty list
        )
        for name, frames in histories.items()
    }

#system_metrics = aggregate_wandb_system_metrics("m-motta" , 'profile-attention-nano-gpt')

In [None]:
def filter_metrics(system_metrics, list_of_matches):
    """Keep only the metric columns matching any of the given patterns.

    Args:
        system_metrics: mapping key -> pandas DataFrame of logged metrics.
        list_of_matches: strings joined with '|' into one regular expression;
            a column is kept when the expression is found anywhere in its name.

    Returns:
        A new dict with the same keys, each mapped to a DataFrame holding only
        the matching columns, with everything up to and including the first
        'system.' removed from the column names. The input frames are not
        modified.
    """
    pattern = '|'.join(list_of_matches)

    def strip_system_prefix(column):
        # Names without 'system.' are returned unchanged instead of raising
        # IndexError, so already-filtered frames can be passed in again.
        return column.split('system.', 1)[-1]

    return {
        key: df.filter(regex=pattern, axis=1).rename(columns=strip_system_prefix)
        for key, df in system_metrics.items()
    }

In [None]:
# Pull the per-name concatenated histories from W&B, then keep only the GPU
# and disk columns.
system_metrics = aggregate_wandb_system_metrics(
    username="m-motta",
    project='profile-attention-nano-gpt',
)
system_metrics = filter_metrics(system_metrics, list_of_matches=['gpu', 'disk'])

In [None]:
# Preview the first five rows of the flash-attention metrics.
system_metrics['flash-attention'].head(n=5)

In [None]:

# NOTE(review): this cell shows the same flash-attention preview as the cell
# before it — confirm whether a different group was intended here.
system_metrics['flash-attention'].head(5)

In [None]:
# Substrings that identify the metric families of interest. Deliberately not
# named `params`: that name is the imported config module used by other cells
# (e.g. `params.out_dir`), and rebinding it would break re-running them.
metric_keywords = ['gpu', 'disk']

metric_columns = system_metrics['flash-attention'].columns
# Drop everything up to the first 'system.' when it is present. Columns that
# have already had the prefix removed (as filter_metrics does) are kept as they
# are instead of raising IndexError.
metrics = [
    column.split('system.', 1)[-1]
    for column in metric_columns
    if any(keyword in column for keyword in metric_keywords)
]
metrics


Notes from paper:

All models are trained with the same hyperparameters for 400K steps.

We run all implementations with mixed-precision training (PyTorch AMP).

Speedup also changes when we increase the head dimension. Each block
requires more memory, so we need to use smaller block sizes to fit into SRAM. Figure 6 shows speedup with
head dimension 128 on an A100 (batch size 16, 12 heads). We see less speedup overall—but we can still see
significant speedup (up to 3×) with a causal mask, where half the blocks are masked out.
                     
We confirm that the memory footprint
of FlashAttention scales linearly with seq. length and is up to 3× faster than standard attention for
common seq. lengths (up to 2K). We confirm that runtime of block-sparse FlashAttention scales linearly
in seq. length and is faster than all existing approximate attention baselines.
                     
We train the model on 8×A100-80GB GPUs. Each training run takes between 16 and 19 minutes, and we
average the results of 10 runs.
                     
attention head, seq length and block size

In [None]:
# TODO further refactor train.py?  https://github.com/pytorch/examples/blob/main/imagenet/main.py