In [1]:
from utils import *
# Set up interactive widgets for the variables
from ipywidgets import interact, IntSlider, Checkbox, BoundedIntText, BoundedFloatText, Dropdown
import ipywidgets as widgets


# Interactive control panel for the demand-curve plots
def create_interactive_widget():
    """Build and display the interactive control panel for the demand-curve plots.

    Creates the ipywidgets controls (graph type, models, quantization, batch
    size, use case, beam/token counts, system, node count, efficiency and the
    custom-system parameters), binds them to ``generate_demand_curve`` through
    ``widgets.interactive_output``, then clears the current cell output and
    displays the panel together with the plot output area.

    ``generate_demand_curve``, ``clear_output`` and ``display`` are not defined
    in this cell; presumably they come from ``from utils import *`` -- verify.

    Returns:
        None. The panel is rendered as a side effect, so callers should not
        pass the result to ``display``.
    """
    # (beam size, input tokens, output tokens) applied when a use case is
    # picked. 'Custom' is deliberately absent: it leaves the fields untouched.
    usecase_presets = {
        'Ques-Ans': (4, 1000, 200),
        'Text Summarization': (4, 15000, 1000),
        'Chatbots': (2, 2048, 128),
        'Code Gen.': (4, 20000, 50),
    }
    default_usecase = 'Chatbots'
    default_beams, default_input_tokens, default_output_tokens = usecase_presets[default_usecase]

    def token_count_input(label, default):
        """Return a bounded integer field for a token count."""
        return widgets.BoundedIntText(
            value=default, min=1, max=100000, step=1,
            description=label,
            layout=widgets.Layout(width='150px'),
            style={'description_width': 'initial'},
        )

    def custom_system_input(label, default):
        """Return a float field for one parameter of the 'Custom' system."""
        return widgets.FloatText(
            value=default,
            description=label,
            layout=widgets.Layout(width='200px'),
        )

    # Graph type dropdown
    graph_type = widgets.Dropdown(
        options=[
            ('1. ISO-HW: Model vs Throughput', 1),
            ('2. ISO-HW: Model vs Latency (TTFT, TPOT)', 2),
            ('3. ISO-HW, ISO-Model: Batch vs Throughput/Latency', 3),
        ],
        value=3,
        description='Graph Type:',
        layout=widgets.Layout(width='500px'),
        style={'description_width': 'initial'},
    )

    # Max batch size
    max_batch_size = widgets.BoundedIntText(
        value=8, min=1, max=128, step=1,
        description='Max Batch Size:',
        layout=widgets.Layout(width='150px'),
        style={'description_width': 'initial'},
    )

    # Use case selector; the initial token/beam values below match its default.
    usecases = widgets.Dropdown(
        options=list(usecase_presets) + ['Custom'],
        value=default_usecase,
        description='Usecases:',
    )

    # Beam size
    beam_size = widgets.IntSlider(
        value=default_beams, min=1, max=16,
        description='# of Parallel Beams:',
        style={'description_width': 'initial'},
    )

    # Input / output token counts
    input_tokens = token_count_input('Input Tokens:', default_input_tokens)
    output_tokens = token_count_input('Output Tokens:', default_output_tokens)

    # Quantization dropdown
    quantization = widgets.Dropdown(
        options=['fp8', 'bf16', 'int8', 'int4', 'int2', 'fp32'],
        value='int8',
        description='Quantization:',
        layout=widgets.Layout(width='150px'),
        style={'description_width': 'initial'},
    )

    # Models: (label shown to the user, identifier passed to the plot function)
    model_box = widgets.SelectMultiple(
        options=[
            ('meta-llama/Llama-2-7B', 'llama2_7b'),
            ('meta-llama/Meta-Llama-3-8B', 'llama3_8b'),
            ('meta-llama/Llama-2-13B', 'llama2_13b'),
            ('meta-llama/Llama-2-70B', 'LLaMA2_70b'),
            ('meta-llama/Meta-Llama-3.1-405B', 'llama_405b'),
            ('google/gemma-2B', 'gemma_2b'),
            ('google/gemma-7B', 'gemma_7b'),
            ('google/gemma-2-9B', 'gemma2_9b'),
            ('google/gemma-2-27B', 'gemma2_27b'),
            ('mistralai/mistral-7B', 'mistral_7b'),
            ('mistralai/Mixtral-8x7B', 'mixtral_8x7b'),
            ('microsoft/phi3mini', 'phi3mini'),
            ('microsoft/phi3small', 'phi3small'),
            ('microsoft/phi3medium', 'phi3medium'),
            ('databricks/dbrx-base', 'dbrx'),
            ('xai-org/grok-1', 'grok-1'),
            ('openai/gpt-3', 'gpt-3'),
            ('openai/gpt-4', 'gpt-4'),
            ('facebook/opt-125m', 'opt_125m'),
            ('facebook/opt-350m', 'opt_350m'),
            ('facebook/opt-1.3b', 'opt_1b'),
            ('facebook/opt-175b', 'opt_175b'),
        ],
        value=['llama2_7b'],
        description='Models:',
        layout=widgets.Layout(width='300px', height='150px'),
    )

    # System
    system = widgets.Dropdown(
        options=['A100_40GB_GPU', 'A100_80GB_GPU', 'H100_GPU', 'GH200_GPU',
                 'TPUv4', 'TPUv5e', 'MI300X', 'Gaudi3', 'Custom'],
        value='H100_GPU',
        description='System:',
    )

    # Number of nodes
    nodes = widgets.IntText(
        value=2,
        description='# Nodes:',
        layout=widgets.Layout(width='150px'),
    )

    # System efficiency; continuous_update=False so the bound callback fires
    # on release rather than on every intermediate slider position.
    system_efficiency = widgets.FloatSlider(
        value=0.80,
        min=0,
        max=1.0,
        step=0.01,
        description='System Efficiency:',
        continuous_update=False,
        orientation='horizontal',
        readout=True,
        readout_format='.2f',
    )

    # Parameters of the 'Custom' system (only shown when it is selected).
    flops = custom_system_input('FLOPS(T):', 1000)
    mem_bw = custom_system_input('MEM BW(TB/s):', 3.6)
    # Label fixed: this field is the memory capacity, not a second FLOPS input.
    mem_cap = custom_system_input('MEM CAP(GB):', 48)
    icn_bw = custom_system_input('ICN BW(GB/s):', 100.0)
    custom_system_fields = (flops, mem_bw, mem_cap, icn_bw)

    def set_custom_system_fields_visible(visible):
        """Show or hide all custom-system parameter fields together."""
        display_mode = '' if visible else 'none'
        for field in custom_system_fields:
            field.layout.display = display_mode

    def update_visibility_system_param(change):
        """Reveal the custom-system fields only while 'Custom' is selected."""
        set_custom_system_fields_visible(change['new'] == 'Custom')

    def apply_usecase_preset(change):
        """Load the beam/token preset of the newly selected use case, if any."""
        preset = usecase_presets.get(change['new'])
        if preset is None:
            return
        beam_size.value, input_tokens.value, output_tokens.value = preset

    system.observe(update_visibility_system_param, names='value')
    usecases.observe(apply_usecase_preset, names='value')

    # Initial visibility follows the system that is selected at start-up.
    set_custom_system_fields_visible(system.value == 'Custom')

    # Layout
    left_box = widgets.HBox([quantization, max_batch_size])
    input_param_box = widgets.VBox([
        usecases,
        beam_size,
        widgets.HBox([input_tokens, output_tokens]),
    ])
    top_box = widgets.VBox([left_box, input_param_box])
    bottom_box = widgets.HBox([system, nodes, system_efficiency])
    system_bottom_box = widgets.HBox(list(custom_system_fields))

    final_layout = widgets.VBox(
        [graph_type, widgets.HBox([model_box, top_box]), bottom_box, system_bottom_box],
        layout=widgets.Layout(justify_content='space-between'),
    )

    # Keys are the keyword names generate_demand_curve is called with.
    output = widgets.interactive_output(
        generate_demand_curve,
        dict(
            graph_type=graph_type,
            system_box=system,
            system_eff=system_efficiency,
            num_nodes_slider=nodes,
            model_box=model_box,
            quantization_box=quantization,
            batch_slider=max_batch_size,
            input_token_slider=input_tokens,
            output_token_slider=output_tokens,
            beam_size=beam_size,
            flops=flops,
            mem_bw=mem_bw,
            mem_cap=mem_cap,
            icn_bw=icn_bw,
        ),
    )
    clear_output(wait=True)
    display(final_layout, output)


## TTFT, TPOT, RPS, Latency
1. ISO-HW: Model vs Throughput
2. ISO-HW: Model vs Latency (TTFT, TPOT)
3. ISO-HW, ISO-Model: Batch vs Throughput/Latency
4. 
## Website support
## A visualization of the system being modelled


In [2]:
# create_interactive_widget() renders the panel itself and has no return
# value, so assigning the result and passing it to display() only prints a
# stray "None" below the widgets.
create_interactive_widget()

VBox(children=(Dropdown(description='Graph Type:', index=2, layout=Layout(width='500px'), options=(('1. ISO-HW…

Output()

None