In [1]:
import os
import glob
import json
import wandb
import random
import numpy as np
import pandas as pd
from collections import defaultdict
import pathlib

from utils import *

import sys
sys.path.insert(1, '../') #ugly hack
import config.train_shakespeare_char as params
from train_torch_profiler import wandb_config, main as train


current_dir =  pathlib.Path().resolve()
os.environ['WANDB_NOTEBOOK_NAME'] =  os.path.join(current_dir)


In [2]:
#wandb.init(project=params.wandb_project, name=params.wandb_run_name, config=wandb_config, sync_tensorboard=True)
#wandb.login()


### Profiling with PyTorch

We will first perform a sanity check and see whether profiling with a different number of iterations leads to different results.

In [3]:
fixed_params = dict(n_layer = params.n_layer, 
                    n_head = params.n_head, 
                    n_embd = params.n_embd, 
                    block_size = params.block_size,
                    bias = params.bias, 
                    dropout = params.dropout) 

varying_params = dict(max_iters = [400, 1000], #less than 400 iterations didn't lead to wandb system metrics
                      flash = [True, False])

profiling_params = params.profiler_schedule_args
print(' profiling_params', profiling_params)

experiment_params = combine_all_keys_and_values(varying_params)
print('\n experiment_params',experiment_params)

 profiling_params {'skip_first': 5, 'wait': 5, 'warmup': 5, 'active': 3, 'repeat': 1}

 experiment_params [{'max_iters': 400, 'flash': True}, {'max_iters': 400, 'flash': False}, {'max_iters': 1000, 'flash': True}, {'max_iters': 1000, 'flash': False}]


In [4]:
for exp_params in experiment_params:
    print('\n \n train with', exp_params)
    with set_params(params, exp_params): #check if reinit=True in wandb is better
        random_port = 1024 + random.randint(1,1000)
        setattr(params, 'random_port', random_port)
        print(params.random_port)
        train()


 
 train with {'max_iters': 400, 'flash': True}
1629

 Using Torch DDP 
 
tokens per iteration will be: 16,384
found vocab_size = 65 (inside /home/nanoGPT/data/shakespeare_char/meta.pkl)
Initializing a new model from scratch
number of parameters: 10.65M
num decayed parameter tensors: 26, with 10,740,096 parameters
num non-decayed parameter tensors: 13, with 4,992 parameters
using fused AdamW: True
compiling the model... (takes a ~minute)
model.module


[34m[1mwandb[0m: Currently logged in as: [33mm-motta[0m. Use [1m`wandb login --relogin`[0m to force relogin



 
 profiler_schedule_args:  {'skip_first': 5, 'wait': 5, 'warmup': 5, 'active': 3, 'repeat': 1} 
 



Using FallbackKernel: aten._scaled_dot_product_flash_attention
Using FallbackKernel: aten._scaled_dot_product_flash_attention


step 0: train loss 4.2885, val loss 4.2831


Using FallbackKernel: aten._scaled_dot_product_flash_attention
Using FallbackKernel: aten._scaled_dot_product_flash_attention
Using FallbackKernel: aten._scaled_dot_product_flash_attention_backward
Using FallbackKernel: aten._scaled_dot_product_flash_attention_backward


iter 0: loss 4.2659, time 18266.61ms, mfu -100.00%


[W kineto_shim.cpp:366] Profiler is not initialized: skipping step() invocation
[W kineto_shim.cpp:366] Profiler is not initialized: skipping step() invocation
[W kineto_shim.cpp:366] Profiler is not initialized: skipping step() invocation
[W kineto_shim.cpp:366] Profiler is not initialized: skipping step() invocation
[W kineto_shim.cpp:366] Profiler is not initialized: skipping step() invocation
[W kineto_shim.cpp:366] Profiler is not initialized: skipping step() invocation
[W kineto_shim.cpp:366] Profiler is not initialized: skipping step() invocation
[W kineto_shim.cpp:366] Profiler is not initialized: skipping step() invocation
[W kineto_shim.cpp:366] Profiler is not initialized: skipping step() invocation


iter 10: loss 3.0377, time 94.34ms, mfu 3.95%


STAGE:2023-10-17 16:19:43 313505:313505 ActivityProfilerController.cpp:312] Completed Stage: Warm Up
STAGE:2023-10-17 16:19:43 313505:313505 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2023-10-17 16:19:43 313505:313505 ActivityProfilerController.cpp:322] Completed Stage: Post Processing


iter 20: loss 2.6514, time 84.88ms, mfu 3.99%
step 25: train loss 2.5945, val loss 2.5938
saving checkpoint to out-shakespeare-char
iter 30: loss 2.5637, time 85.90ms, mfu 4.03%
iter 40: loss 2.5408, time 84.53ms, mfu 4.07%
step 50: train loss 2.5002, val loss 2.5042
saving checkpoint to out-shakespeare-char
iter 50: loss 2.5093, time 1409.77ms, mfu 3.69%
iter 60: loss 2.5066, time 87.92ms, mfu 3.74%
iter 70: loss 2.4900, time 84.99ms, mfu 3.81%
step 75: train loss 2.4680, val loss 2.4775
saving checkpoint to out-shakespeare-char
iter 80: loss 2.4855, time 85.87ms, mfu 3.86%
iter 90: loss 2.4728, time 87.43ms, mfu 3.90%
step 100: train loss 2.4522, val loss 2.4631
saving checkpoint to out-shakespeare-char
iter 100: loss 2.4862, time 1391.51ms, mfu 3.54%
iter 110: loss 2.4646, time 85.38ms, mfu 3.62%
iter 120: loss 2.4795, time 85.49ms, mfu 3.69%
step 125: train loss 2.4545, val loss 2.4612
saving checkpoint to out-shakespeare-char
iter 130: loss 2.4726, time 88.29ms, mfu 3.75%
iter 140




 
 train with {'max_iters': 400, 'flash': False}
1669

 Using Torch DDP 
 
tokens per iteration will be: 16,384
found vocab_size = 65 (inside /home/nanoGPT/data/shakespeare_char/meta.pkl)
Initializing a new model from scratch
number of parameters: 10.65M
num decayed parameter tensors: 26, with 10,740,096 parameters
num non-decayed parameter tensors: 13, with 4,992 parameters
using fused AdamW: True
compiling the model... (takes a ~minute)
model.module



 
 profiler_schedule_args:  {'skip_first': 5, 'wait': 5, 'warmup': 5, 'active': 3, 'repeat': 1} 
 

step 0: train loss 4.2885, val loss 4.2831
iter 0: loss 4.2676, time 14623.82ms, mfu -100.00%
iter 10: loss 3.0517, time 106.03ms, mfu 3.51%


STAGE:2023-10-17 16:21:05 313505:313505 ActivityProfilerController.cpp:312] Completed Stage: Warm Up
STAGE:2023-10-17 16:21:05 313505:313505 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2023-10-17 16:21:05 313505:313505 ActivityProfilerController.cpp:322] Completed Stage: Post Processing


iter 20: loss 2.6574, time 105.85ms, mfu 3.51%
step 25: train loss 2.5983, val loss 2.5957
saving checkpoint to out-shakespeare-char
iter 30: loss 2.5617, time 111.47ms, mfu 3.50%
iter 40: loss 2.5425, time 106.22ms, mfu 3.50%
step 50: train loss 2.5063, val loss 2.5123
saving checkpoint to out-shakespeare-char
iter 50: loss 2.5100, time 1676.53ms, mfu 3.17%
iter 60: loss 2.5059, time 106.12ms, mfu 3.21%
iter 70: loss 2.4909, time 106.64ms, mfu 3.23%
step 75: train loss 2.4691, val loss 2.4806
saving checkpoint to out-shakespeare-char
iter 80: loss 2.4858, time 108.20ms, mfu 3.25%
iter 90: loss 2.4677, time 108.43ms, mfu 3.27%
step 100: train loss 2.4538, val loss 2.4623
saving checkpoint to out-shakespeare-char
iter 100: loss 2.4861, time 1730.51ms, mfu 2.97%
iter 110: loss 2.4655, time 106.84ms, mfu 3.02%
iter 120: loss 2.4813, time 106.88ms, mfu 3.07%
step 125: train loss 2.4560, val loss 2.4632
saving checkpoint to out-shakespeare-char
iter 130: loss 2.4749, time 106.82ms, mfu 3.11




 
 train with {'max_iters': 1000, 'flash': True}
2018

 Using Torch DDP 
 
tokens per iteration will be: 16,384
found vocab_size = 65 (inside /home/nanoGPT/data/shakespeare_char/meta.pkl)
Initializing a new model from scratch
number of parameters: 10.65M
num decayed parameter tensors: 26, with 10,740,096 parameters
num non-decayed parameter tensors: 13, with 4,992 parameters
using fused AdamW: True
compiling the model... (takes a ~minute)
model.module



 
 profiler_schedule_args:  {'skip_first': 5, 'wait': 5, 'warmup': 5, 'active': 3, 'repeat': 1} 
 



Using FallbackKernel: aten._scaled_dot_product_flash_attention
Using FallbackKernel: aten._scaled_dot_product_flash_attention


step 0: train loss 4.2885, val loss 4.2831


Using FallbackKernel: aten._scaled_dot_product_flash_attention
Using FallbackKernel: aten._scaled_dot_product_flash_attention
Using FallbackKernel: aten._scaled_dot_product_flash_attention_backward
Using FallbackKernel: aten._scaled_dot_product_flash_attention_backward


iter 0: loss 4.2659, time 16463.89ms, mfu -100.00%
iter 10: loss 3.0374, time 85.24ms, mfu 4.37%


STAGE:2023-10-17 16:22:46 313505:313505 ActivityProfilerController.cpp:312] Completed Stage: Warm Up
STAGE:2023-10-17 16:22:46 313505:313505 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2023-10-17 16:22:46 313505:313505 ActivityProfilerController.cpp:322] Completed Stage: Post Processing


iter 20: loss 2.6515, time 85.28ms, mfu 4.37%
step 25: train loss 2.5944, val loss 2.5937
saving checkpoint to out-shakespeare-char
iter 30: loss 2.5637, time 90.22ms, mfu 4.35%
iter 40: loss 2.5408, time 87.91ms, mfu 4.34%
step 50: train loss 2.5002, val loss 2.5042
saving checkpoint to out-shakespeare-char
iter 50: loss 2.5093, time 1516.15ms, mfu 3.93%
iter 60: loss 2.5067, time 85.36ms, mfu 3.97%
iter 70: loss 2.4900, time 85.54ms, mfu 4.01%
step 75: train loss 2.4680, val loss 2.4775
saving checkpoint to out-shakespeare-char
iter 80: loss 2.4854, time 84.83ms, mfu 4.05%
iter 90: loss 2.4728, time 86.36ms, mfu 4.07%
step 100: train loss 2.4521, val loss 2.4631
saving checkpoint to out-shakespeare-char
iter 100: loss 2.4862, time 1502.10ms, mfu 3.69%
iter 110: loss 2.4645, time 86.19ms, mfu 3.76%
iter 120: loss 2.4795, time 87.15ms, mfu 3.81%
step 125: train loss 2.4545, val loss 2.4612
saving checkpoint to out-shakespeare-char
iter 130: loss 2.4725, time 85.70ms, mfu 3.86%
iter 140




 
 train with {'max_iters': 1000, 'flash': False}
1358

 Using Torch DDP 
 
tokens per iteration will be: 16,384
found vocab_size = 65 (inside /home/nanoGPT/data/shakespeare_char/meta.pkl)
Initializing a new model from scratch
number of parameters: 10.65M
num decayed parameter tensors: 26, with 10,740,096 parameters
num non-decayed parameter tensors: 13, with 4,992 parameters
using fused AdamW: True
compiling the model... (takes a ~minute)
model.module



 
 profiler_schedule_args:  {'skip_first': 5, 'wait': 5, 'warmup': 5, 'active': 3, 'repeat': 1} 
 

step 0: train loss 4.2885, val loss 4.2831
iter 0: loss 4.2676, time 11147.68ms, mfu -100.00%
iter 10: loss 3.0523, time 107.01ms, mfu 3.48%


STAGE:2023-10-17 16:25:39 313505:313505 ActivityProfilerController.cpp:312] Completed Stage: Warm Up
STAGE:2023-10-17 16:25:39 313505:313505 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2023-10-17 16:25:39 313505:313505 ActivityProfilerController.cpp:322] Completed Stage: Post Processing


iter 20: loss 2.6575, time 105.58ms, mfu 3.49%
step 25: train loss 2.5984, val loss 2.5957
saving checkpoint to out-shakespeare-char
iter 30: loss 2.5618, time 102.83ms, mfu 3.50%
iter 40: loss 2.5425, time 104.73ms, mfu 3.51%
step 50: train loss 2.5063, val loss 2.5123
saving checkpoint to out-shakespeare-char
iter 50: loss 2.5101, time 1613.61ms, mfu 3.18%
iter 60: loss 2.5058, time 106.94ms, mfu 3.21%
iter 70: loss 2.4909, time 107.93ms, mfu 3.23%
step 75: train loss 2.4691, val loss 2.4806
saving checkpoint to out-shakespeare-char
iter 80: loss 2.4859, time 104.57ms, mfu 3.27%
iter 90: loss 2.4677, time 105.69ms, mfu 3.29%
step 100: train loss 2.4538, val loss 2.4624
saving checkpoint to out-shakespeare-char
iter 100: loss 2.4862, time 1691.89ms, mfu 2.99%
iter 110: loss 2.4655, time 108.40ms, mfu 3.03%
iter 120: loss 2.4813, time 107.24ms, mfu 3.07%
step 125: train loss 2.4561, val loss 2.4632
saving checkpoint to out-shakespeare-char
iter 130: loss 2.4748, time 105.63ms, mfu 3.12



In [92]:
def collect_traces(attention='flash'):
    #returns a list of traceEvents dictionaries for all runs with selected 
    #pt.trace.json keys:['schemaVersion', 'deviceProperties', 'distributedInfo', 'with_flops', 
    # 'with_modules', 'with_stack', 'traceEvents', 'traceName']
    data = []
    for run in glob.glob(f"{params.out_dir}/**/*.pt.trace.json", recursive=True):
        if run.split('logs_')[1].split('-')[0]==attention:
            with open(run) as jsonFile:
                events_dict = json.load(jsonFile)
                name = events_dict['traceName']
                events = events_dict['traceEvents']#TODO: add traceName to identify which run
                #events.update(name)
                data.append(events)
    return data

In [95]:
collect_traces(attention='flash')[:1]

[[{'ph': 'X',
   'cat': 'cpu_op',
   'name': 'autograd::engine::evaluate_function: DivBackward0',
   'pid': 313505,
   'tid': 313818,
   'ts': 1697559583119593,
   'dur': 73,
   'args': {'External id': 2049,
    'Ev Idx': 0,
    'Fwd thread id': 1,
    'Sequence number': 1712}},
  {'ph': 'X',
   'cat': 'cpu_op',
   'name': 'DivBackward0',
   'pid': 313505,
   'tid': 313818,
   'ts': 1697559583119618,
   'dur': 42,
   'args': {'External id': 2050,
    'Ev Idx': 1,
    'Input Dims': [[]],
    'Input type': ['float'],
    'Concrete Inputs': [''],
    'Fwd thread id': 1,
    'Sequence number': 1712}},
  {'ph': 'f',
   'id': 1,
   'pid': 313505,
   'tid': 313818,
   'ts': 1697559583119618,
   'cat': 'fwdbwd',
   'name': 'fwdbwd',
   'bp': 'e'},
  {'ph': 'X',
   'cat': 'cpu_op',
   'name': 'aten::div',
   'pid': 313505,
   'tid': 313818,
   'ts': 1697559583119625,
   'dur': 34,
   'args': {'External id': 2051,
    'Ev Idx': 2,
    'Input Dims': [[], []],
    'Input type': ['float', 'long int

In [68]:
def kernel_mean_occupancy(runTrace, kernel_name):
    #aggregates 'est. achieved occupancy %' and call duration for all calls of runTrace=traceEvents
    #computes the weighted average of occupancy with weights=call duration for each run

    #TODO add lambda for regex match with kernel_name

    keys = {'name', 'args', 'dur'}
    kernel_events = [{k: trace[k] for k in keys} for trace in runTrace if 'cat' in trace.keys() if trace['cat']=='kernel']

    kernel_events = [event for event in kernel_events if kernel_name == event['name']]
    kernel_aggregate = {event['name']: {'occupancy': 0,  'calls':0, 'dur':0} for event in kernel_events}
    kernel_summary = {event['name']: {'mean_occupancy': 0} for event in kernel_events}

    for event in kernel_events:
        kernel_aggregate[event['name']]['occupancy'] += event['args']['est. achieved occupancy %']*event['dur']
        kernel_aggregate[event['name']]['calls'] += 1
        kernel_aggregate[event['name']]['dur'] += event['dur']
        weighted_avg = np.round(kernel_aggregate[event['name']]['occupancy']/ kernel_aggregate[event['name']]['dur'],2)
        kernel_summary[event['name']]['mean_occupancy'] = weighted_avg
    return kernel_summary

In [69]:
def kernel_mean_duration(runTrace, kernel_name):
    #aggregates call duration for all calls of runTrace=traceEvents
    #computes the average call duration for each run

    #TODO add lambda for regex match with kernel_name

    keys = {'name', 'dur'}
    kernel_events = [{k: trace[k] for k in keys} for trace in runTrace if 'cat' in trace.keys() if trace['cat']=='kernel']

    kernel_events = [event for event in kernel_events if kernel_name == event['name']]
    kernel_summary = {event['name']: {'dur': 0, 'calls':0} for event in kernel_events}

    for event in kernel_events:
        kernel_summary[event['name']]['dur'] += event['dur']
        kernel_summary[event['name']]['calls'] += 1
        avg =  np.round(kernel_summary[event['name']]['dur']/ kernel_summary[event['name']]['calls'])
        kernel_summary[event['name']]['mean_duration'] = int(avg)

    return kernel_summary

In [70]:
def compare_kernels(attention, kernel_name):
    #outputs the metrics = mean_duration, mean_occupancy for kernel=kernel_name 
    # for different runs
    #this recovers the exact same values in the tensorboard

    traceEvents = collect_traces(attention=attention)
    #TODO simplify methods such that filter name is not needed twice?
    mean_durations = [kernel_mean_duration(run, kernel_name)[kernel_name]['mean_duration'] for run in traceEvents]
    mean_occupancies = [kernel_mean_occupancy(run, kernel_name)[kernel_name]['mean_occupancy'] for run in traceEvents]

    return mean_durations, mean_occupancies

In [75]:
flash_runs = collect_traces(attention='flash')
#mm = kernel_mean_duration(flash_runs[0], kernel_name = 'triton__0d1d2d3d4d5d6d')
#flash_runs[0]

JSONDecodeError: Expecting value: line 1 column 1 (char 0)

In [58]:
compare_kernels('flash', 'triton__0d1d2d3d4d5d6d')

KeyError: 'name'

In [48]:
compare_kernels('slow', 'triton__0d1d2d3d4d5d6d')

[882, 904]
[99.16, 99.19]


In [42]:
#exploring profiler metrics
print(len(data))
traceEvents = data[0]['traceEvents']#for run 0
duration = defaultdict
keys = {'name', 'dur', 'args'}
kernel_events = [{k: trace[k] for k in keys} for trace in traceEvents if 'cat' in trace.keys() if trace['cat']=='kernel']
kernel_names = {event['name']: 0 for event in kernel_events}

for event in kernel_events:
    kernel_names[event['name']] += event['dur']
#kernel_names = {trace['name']:  for trace in kernel_events}
#kernel_names = [duration[trace['name']]+= trace['dur'] for trace in kernel_events]

kernel_names

#kernel_durations['triton__0d1d2d3d4d5d6d7d8d910d11d']

30


{'void at::native::(anonymous namespace)::multi_tensor_apply_kernel<at::native::(anonymous namespace)::TensorListMetadata<1>, at::native::(anonymous namespace)::BinaryOpScalarFunctor<float, 1, 1, 0>, std::plus<float>, float>(at::native::(anonymous namespace)::TensorListMetadata<1>, at::native::(anonymous namespace)::BinaryOpScalarFunctor<float, 1, 1, 0>, std::plus<float>, float)': 12,
 'void at::native::(anonymous namespace)::multi_tensor_apply_kernel<at::native::(anonymous namespace)::FusedOptimizerTensorListMetadata<4>, at::native::(anonymous namespace)::FusedAdamMathFunctor<float, 4>, double, double, double, double, double, bool, bool, float*, float*, at::native::ADAM_MODE>(at::native::(anonymous namespace)::FusedOptimizerTensorListMetadata<4>, at::native::(anonymous namespace)::FusedAdamMathFunctor<float, 4>, double, double, double, double, double, bool, bool, float*, float*, at::native::ADAM_MODE)': 28952,
 'triton__0d1d': 24,
 'void at::native::(anonymous namespace)::distribution

# run experiments:

vary:

n_iters  - to check if profiling changes with more or less iterations --it should not

        for (n_heads*h_size = embedding dimensionality (n_embd)) vary: 

            n_heads 
            
            h_size

            seq_len  (seq length affect number of operations for fix head size and nheads, as it increased the projections dimensions, i.e. sizes of K, V, Q)

In [61]:
def collect_system_metrics(username, project):

    api = wandb.Api()
    runs = api.runs(f"{username}/{project}")
    system_metrics = defaultdict(dict) 
    
    for run in runs:
        if run.state =='finished':         
            system_metrics[run.name][run.id] = run.history(stream='events') #run.history() is a pandas data frame 
    
    return system_metrics

sm = collect_system_metrics("m-motta" , 'profile-attention-nano-gpt')

In [62]:
sm

defaultdict(dict,
            {'slow-n_iters200-n_head6-h_size384-seq_len256': {'ym3qpq2l':    system.network.sent  system.network.recv  system.cpu.9.cpu_percent  \
              0            674391.91            682564.45                      2.05   
              
                 system.cpu.1.cpu_percent  system.cpu.7.cpu_percent  _wandb  \
              0                         3                      1.36    True   
              
                 system.disk.out  system.disk.\.usageGB  system.gpu.0.temp  \
              0            667.5                 179.52              65.93   
              
                 system.gpu.0.memory  ...  system.proc.memory.percent  \
              0                    0  ...                        4.56   
              
                 system.disk.\.usagePercent  system.cpu.10.cpu_percent  system.disk.in  \
              0                        20.8                       3.06               0   
              
                 system.gpu.0.mem

About logged runtimes:

there is a difference between the runtime from .history() and .history(stream='events). Firstly, because events are restricted to the GPU, but probably also because the system is checked at specific intervals/checked once again after iterations are finished.

In [None]:
def compute_average_runtime(system_metrics):
    
    runtimes = {}
    for params in system_metrics.keys():
        print('params',params)
        runtimes[params] = 0
        count = 0
        for id in system_metrics[params].keys():
            print('id',id)
            print('.iloc[-1]',system_metrics[params][id]._runtime.iloc[-1])
            runtimes[params]+= system_metrics[params][id]._runtime.iloc[-1]
            count += 1
            print('')
        runtimes[params] = runtimes[params]/count
    
    return runtimes    

compute_average_runtime(sm)

In [None]:
def aggregate_system_metrics(username, project):

    #this is appending all the runs, I don't think I'll need this

    api = wandb.Api()
    runs = api.runs(f"{username}/{project}")
    system_metrics = {'flash-attention': [], 'slow-attention': []}
    for run in runs:
        system_metrics[run.name].append(run.history(stream='events')) #run.history() is a pandas data frame 

    for key, group in system_metrics.items():
        if group:
            system_metrics[key] = pd.concat(group, axis=0, join='outer', ignore_index=True)
    
    return system_metrics

#system_metrics = aggregate_system_metrics("m-motta" , 'profile-attention-nano-gpt')

In [None]:
def filter_metrics(system_metrics, list_of_matches):
    #selects specific metris (columns) from the .history() dataframe, based on string matches
    pattern = '|'.join(list_of_matches)

    def split_select(s):
        return s.split('system.')[1]

    system_metrics = {
        key: df.filter(regex=pattern, axis=1).rename(columns=split_select)
        for key, df in system_metrics.items()
    }

    return system_metrics

In [None]:
system_metrics = aggregate_system_metrics("m-motta" , 'profile-attention-nano-gpt')
system_metrics = filter_metrics(system_metrics, ['gpu', 'disk'])

In [None]:
system_metrics['flash-attention'].head()

In [None]:

system_metrics['flash-attention'].head()

In [None]:
params = ['gpu', 'disk']
metrics = system_metrics['flash-attention'].columns
metrics = [m.split('system.')[1] for m in metrics if any([p in m for p in params])]
metrics

In [None]:
#['schemaVersion', 'deviceProperties', 'distributedInfo', 'with_flops', 'with_modules', 'with_stack', 'traceEvents', 'traceName']
#only deviceProperties and traceEvents contain relevant information
data = []
for run in glob.glob(f"{params.out_dir}/**/*.pt.trace.json", recursive=True):
  #  print('\n',run)
    with open(run) as jsonFile:
        data.append(json.load(jsonFile))
       
        


Notes from paper:

All models are trained with the same hyperparameters for 400K steps.

We run all implementations with mixed-precision training (PyTorch AMP).

Speedup also changes when we increase the head dimension. Each block
requires more memory, so we need to use smaller block sizes to fit into SRAM. Figure 6 shows speedup with
head dimension 128 on an A100 (batch size 16, 12 heads). We see less speedup overall—but we can still see
significant speedup (up to 3×) with a causal mask, where half the blocks are masked out.
                     
We confirm that the memory footprint
of FlashAttention scales linearly with seq. length and is up to 3× faster than standard attention for
common seq. lengths (up to 2K). We confirm that runtime of block-sparse FlashAttention scales linearly
in seq. length and is faster than all existing approximate attention baselines.
                     
We train the model on 8×A100-80GB GPUs. Each training run takes between 16 and 19 minutes, and we
average the results of 10 runs.
                     
attension head, seq length and block size

In [None]:
#exploring profiler metrics

trace=data[0]['traceEvents']
#ops = defaultdict

ops = [t for t in trace if 'cat' in t.keys() if t['cat']=='kernel']
#trace[100]
ops[:2]

In [None]:
https://github.com/pytorch/examples/blob/main/imagenet/main.py
use this to refactor train.py

In [None]:
data_transforms = transforms.Compose([
    transforms.Resize((256, 256)),  # Resize the images to a common size
    transforms.RandomCrop(224),     # Randomly crop the images to a smaller size
    transforms.RandomHorizontalFlip(),  # Apply random horizontal flip for data augmentation
    transforms.ToTensor(),         # Convert images to tensors
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),  # Normalize with ImageNet statistics
])

# Define the training dataset
train_dataset = ImageFolder('path/to/train', transform=data_transforms)

# Define the validation dataset (if available)
val_dataset = ImageFolder('path/to/val', transform=data_transforms)

# Define the DataLoader for both datasets
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=4)