In [1]:
import json
import os
import matplotlib.pyplot as plt

def try_parse_int(s, base=10, val=None):
    """Parse ``s`` as an integer, returning a fallback instead of raising.

    Used on Criterion's ``value_str``, which may be absent (``None``) or
    non-numeric for some benchmarks.

    Args:
        s: Text to parse. ``None`` is treated as "no value".
        base: Numeric base passed to ``int()``.
        val: Value returned when ``s`` is not a valid integer literal.

    Returns:
        ``int(s, base)`` on success, ``None`` when ``s`` is ``None``, and
        ``val`` when ``s`` cannot be parsed.
    """
    if s is None:
        return None
    try:
        return int(s, base)
    except (TypeError, ValueError):
        # TypeError: int() rejects an explicit base for non-str input
        # (e.g. a value that is already a number).
        return val

# Function to parse criterion output
def parse_criterion_output(directory, name_filter="matmul"):
    """Collect Criterion benchmark results found under ``directory``.

    A directory contributes one record when it contains both
    ``benchmark.json`` (benchmark identity and throughput) and
    ``estimates.json`` (timing estimates). Directories whose path contains
    "base" are skipped, as are directories whose path does not contain
    ``name_filter``.

    Args:
        directory: Root to walk, e.g. ``./target/criterion/``.
        name_filter: Substring a directory path must contain to be read
            (default ``"matmul"``); ``None`` disables the filter.

    Returns:
        A list of ``(group_id, size, throughput_bytes, function_id, mean)``
        tuples. ``size`` is ``value_str`` parsed with ``try_parse_int``
        (``None`` if absent or non-numeric). ``mean`` is
        ``mean.point_estimate`` from ``estimates.json``.
    """
    data = []
    print(f"search in {directory}")
    for root, _dirs, files in os.walk(directory):
        # NOTE(review): substring test, so a benchmark whose *name* contains
        # "base" is skipped too -- presumably meant for Criterion's "base"
        # baseline directories; confirm.
        if "base" in root:
            continue
        if name_filter is not None and name_filter not in root:
            continue
        # os.walk lists `files` in arbitrary order, so do not rely on
        # benchmark.json being seen before estimates.json: require both and
        # read them explicitly.
        if "benchmark.json" not in files or "estimates.json" not in files:
            continue
        with open(os.path.join(root, "benchmark.json")) as f:
            benchmark_data = json.load(f)
        with open(os.path.join(root, "estimates.json")) as f:
            estimates = json.load(f)
        data.append((
            benchmark_data['group_id'],
            try_parse_int(benchmark_data['value_str']),
            benchmark_data['throughput']['Bytes'],
            benchmark_data['function_id'],
            estimates['mean']['point_estimate'],
        ))
    return data


# Criterion output root, relative to the notebook's current working directory.
directory = "./target/criterion/"
data = parse_criterion_output(directory=directory)

seach in ./target/criterion/


In [6]:
%matplotlib qt
# Function to generate comparison table
def generate_comparison_table(data):
    """Print one tab-separated table of mean times per benchmark group.

    Rows are sizes (ascending, ``None`` first) and columns are function ids
    (sorted by name). Missing size/function combinations print as ``N/A``.

    Args:
        data: Records of ``(group, size, throughput, function_name, mean)``
            as produced by ``parse_criterion_output``.

    Returns:
        A list of ``(group, table)`` pairs in sorted group order, where
        ``table`` maps ``size -> {function_name: mean}``.
    """
    def size_key(size):
        # Sizes are ints or None (unparsable value_str); None and int are not
        # mutually orderable, so place None first explicitly.
        return (size is not None, size if size is not None else 0)

    tables = []
    for group in sorted({record[0] for record in data}):
        rows = [record for record in data if record[0] == group]
        table = {}
        for _g, size, _tp, function_name, mean in rows:
            table.setdefault(size, {})[function_name] = mean
        # Sorted so the column order is stable between runs (set iteration
        # order of strings varies with hash randomisation).
        functions = sorted({record[3] for record in rows})

        print(f"Comparison Table {group}")
        print("Size\t" + "".join(f"{function}\t" for function in functions))
        for size in sorted(table, key=size_key):
            results = table[size]
            cells = (
                f"{results[function]:.2f}\t" if function in results else "N/A\t"
                for function in functions
            )
            print(f"{size}\t" + "".join(cells))
        tables.append((group, table))
    return tables

def print_best_results(data, index):
    """Print, per group, every function's throughput at one size index, best first.

    For each function in a group, its records are sorted by size and the
    entry at position ``index`` is taken (negative indices count from the
    largest size). Throughput is ``throughput / mean``; with ``mean`` in ns
    (presumably Criterion's unit -- confirm), that is bytes per ns.

    Functions with too few sizes for ``index`` are left out instead of
    raising ``IndexError``. Groups where no function has enough sizes print
    only their ``Group:`` header.

    Args:
        data: Records of ``(group, size, throughput, function_name, mean)``.
        index: Position in each function's size-sorted results.
    """
    def size_key(point):
        # Sizes are ints or None; order None first without comparing it to int.
        size = point[0]
        return (size is not None, size if size is not None else 0)

    for group in sorted({record[0] for record in data}):
        print(f"\nGroup: {group}")
        per_function = {}
        for _g, size, tp, function_name, mean in (d for d in data if d[0] == group):
            per_function.setdefault(function_name, []).append((size, tp / mean))

        ranked = []
        sizes_seen = []
        for function_name, points in per_function.items():
            if not -len(points) <= index < len(points):
                continue
            size, throughput = sorted(points, key=size_key)[index]
            ranked.append((function_name, throughput))
            if str(size) not in sizes_seen:
                sizes_seen.append(str(size))
        if not ranked:
            continue

        # Functions may be measured at different size lists, so position
        # `index` can map to different sizes; list every size that occurs.
        print(f"(Size={'/'.join(sizes_seen)})")
        for name, throughput in sorted(ranked, key=lambda item: item[1], reverse=True):
            print(f"{name}: {throughput}")


# Function to plot benchmark results
def plot_benchmark_results(data):
    """Plot throughput versus size, one figure per benchmark group.

    Variants of a function that differ only by a ``_Prefetch`` and/or
    ``_NoPadded`` suffix share a colour. The suffixes select the line style:
    dashed for Prefetch, star markers for NoPadded. Groups in which no
    function has more than one size are skipped.

    Args:
        data: Records of ``(group, size, throughput, function_name, mean)``.
            The plotted value is ``throughput / mean``.
    """
    colors = plt.cm.tab20.colors
    # (has Prefetch, has NoPadded) -> matplotlib format string.
    line_formats = {
        (True, True): "*--",
        (True, False): "--",
        (False, True): "*-",
        (False, False): "-",
    }

    def size_key(point):
        # Sizes are ints or None; order None first without comparing it to int.
        size = point[0]
        return (size is not None, size if size is not None else 0)

    for group in sorted({record[0] for record in data}):
        per_function = {}
        for _g, size, tp, function_name, mean in (d for d in data if d[0] == group):
            per_function.setdefault(function_name, []).append((size, tp / mean))

        if not any(len(points) > 1 for points in per_function.values()):
            continue

        function_colors = {}
        fig, ax = plt.subplots(figsize=(10, 14))
        for function_name, points in per_function.items():
            points = sorted(points, key=size_key)
            # Sizes are plotted as strings, so the x axis is categorical and
            # categories appear in first-seen order across the figure's lines.
            sizes = [str(size) for size, _ in points]
            tps = [throughput for _, throughput in points]

            base_name = function_name.replace("_Prefetch", "").replace("_NoPadded", "")
            color_index = function_colors.setdefault(base_name, len(function_colors))
            # Wrap around so more base functions than palette entries do not
            # raise IndexError.
            color = colors[color_index % len(colors)]

            fmt = line_formats[("Prefetch" in function_name, "NoPadded" in function_name)]
            ax.plot(sizes, tps, fmt, label=function_name, color=color)

        ax.set_xlabel('Size')
        # throughput / mean: bytes per unit of `mean` (presumably ns) -- confirm.
        ax.set_ylabel('Throughput (Bytes/ns)')
        ax.set_title(f'Benchmark Comparison {group}')
        ax.set_yscale('log')
        ax.legend()
        ax.grid(True)
        plt.show()


# Run the analysis. The comparison table is currently disabled:
# comparison_table = generate_comparison_table(data)

print_best_results(data, index=0)
plot_benchmark_results(data)


# Save the comparison table to a file (requires re-enabling the
# `comparison_table` line above).
# NOTE(review): iterating `table` directly yields sizes, not function names,
# so the columns are collected from the table's rows instead.
# with open("benchmark_comparison_table.txt", "w") as f:
#     for (group, table) in comparison_table:
#         functions = sorted({name for results in table.values() for name in results})
#         for size, results in sorted(table.items()):
#             f.write(f"{size}\t")
#             for function in functions:
#                 if function in results:
#                     mean = results[function]
#                     f.write(f"{mean:.2f}\t")
#                 else:
#                     f.write("N/A\t")
#             f.write("\n")


Group: matmul_(24x1536 * 1536x6144)
(Size=None)
wgpu_matmul_Matmul5_32_32(): 922.5797551910705
wgpu_matmul_Matmul5_32_32(_Prefetch): 685.5685242363462
wgpu_matmul_MatmulX: 656.4797558241759
wgpu_matmul_Matmul5_16_16: 637.8123107537132
wgpu_matmul_Matmul7: 587.0887150522702
wgpu_matmul_Matuml5_64_64(): 535.9628438449554
wgpu_matmul_Matuml5_64_64(_Prefetch): 466.20527097288476
wgpu_matmul_Matmul1: 263.87447541311127
wgpu_matmul_Matmul5_32_32(_NoPadded): 171.19553488461884
wgpu_matmul_Matmul5_32_32(_Prefetch_NoPadded): 154.7785604754766
wgpu_matmul_Matuml5_64_64(_Prefetch_NoPadded): 138.73582490895217
wgpu_matmul_Matuml5_64_64(_NoPadded): 138.65328455427368
wgpu_matmul_Matmul5_128_128(_Prefetch_NoPadded): 75.13673379128662
wgpu_matmul_Matmul5_128_128(_NoPadded): 75.06937726037803
wgpu_matmul_Matmul5_64_64_8_8(_NoPadded): 73.80776507714106
wgpu_matmul_Matmul5_64_64_8_8(_Prefetch_NoPadded): 72.90763099679333
wgpu_matmul_Matmul5_64_64_8_8(): 66.25237336813746
wgpu_matmul_Matmul5_64_64_8_8(_

In [3]:
# Rank functions by throughput at position 9 of each function's size-sorted results.
print_best_results(data, index=9)


Group: matmul_(24x1536 * 1536x6144)

Group: matmul_m_1
(Size=2048)
wgpu_matmul_Matmul5_16_16: 27.553693415204595
wgpu_matmul_Matmul5_16_64(): 25.74192547476765
wgpu_matmul: 24.24842933111788
wgpu_matmul_Matmul1: 23.879218836428457
wgpu_matmul_Matmul7: 22.59946910203689
wgpu_matmul_Matmul5_32_32(): 22.04527462708373
cpu_matmul: 21.3154725758459
wgpu_matmul_Matmul5_1_128(): 18.173622091670566
wgpu_matmul_Matmul5_16_64(_LoadB): 16.25266185948022
wgpu_matmul_Matuml5_64_64(): 13.554887091587627
wgpu_matmul_Matmul5_1_128(_LoadB): 11.787527517478617
wgpu_matmul_Matmul5_16_64(_Prefetch_LoadB): 11.480629516414767
wgpu_matmul_Matmul5_1_128(_Prefetch): 11.322072233944004
wgpu_matmul_Matmul5_1_128(_Prefetch_LoadB): 9.997974061002813
wgpu_matmul_Matmul5_64_64_8_8(): 8.936843476852262
wgpu_matmul_Matmul5_128_128(): 5.826312281347342
wgpu_matmul_Matmul5_1_128(_NoPadded): 5.523290747217718
wgpu_matmul_Matmul5_1_128(_Prefetch_NoPadded): 5.417819124550855
wgpu_matmul_Matmul5_16_64(_NoPadded_LoadB): 5.0

In [4]:


# Model the matmul time as a compute part (FLOPS at rate 1608.0) plus a data
# part (B at rate x), and solve for x so the total matches the time at rate
# 1044.77:
#     FLOPS / 1608.0 + B / x == FLOPS / 1044.77
# NOTE(review): 1608.0 and 1044.77 look like peak vs. measured throughput and
# B like bytes moved -- confirm their origin and units. Also note FLOPS counts
# n**3 operations, not the 2 * n**3 multiply-adds often used for matmul.
FLOPS = 2048 ** 3
B = 2048 * 2048 * 2
x = B / (FLOPS / 1044.77 - FLOPS / 1608.0)
t = FLOPS / 1608.0 + B / x  # modelled time
t2 = FLOPS / 1044.77        # time at the 1044.77 rate; should equal t
print(x)
print(t)
print(t2)



2.912869326252153
8221842.694564354
8221842.694564354
