In [2]:
import pandas as pd

# Table contents. Each key is a (group, temperature, metric) tuple, which pandas
# turns into a three-level column MultiIndex; "" fills the levels a column does
# not use. Every column holds one entry per experiment row (9 rows in total).
data = {
    ("Target", "", ""): ["Llama-3.1-70B"] * 4 + [
        "Mixtral-8x22B-Instruct-v0.1",
        "Gemma-2-9b",
        "CodeLlama-13b-Instruct-hf",
        "Phi-3-medium-128k-instruct",
        "Mixtral-8x22B-Instruct-v0.1",
    ],
    ("Task", "", ""): ["long-ctx summ"] * 4 + [
        "summ", "summ", "coding", "long-ctx summ", "long-ctx summ",
    ],
    ("Data", "", ""): [""] * 9,
    ("In Toks", "", ""): [""] * 9,
    ("Device", "", ""): ["A100 * 2"] * 4 + [
        "A100 * 4", "A6000", "A6000", "A6000", "A100 * 4",
    ],
    # Autoregressive baseline: one blank column per (temperature, metric) pair.
    **{
        ("AR", temperature, metric): [""] * 9
        for temperature in ("$t=0$", "$t>0$")
        for metric in ("Out Toks", "TTFT (ms)", "TPOT (ms)")
    },
    ("Drafter", "", ""): [
        "Qwen2-0.5B-Instruct",
        "Llama-3.2-1B",
        "Llama-3.2-3B",
        "Llama-3.1-8B",
        "vicuna-68m",
        "vicuna-68m",
        "tiny-starcoder-py",
        "Qwen2-0.5B-Instruct",
        "Qwen2-0.5B-Instruct",
    ],
    # Speculative-decoding columns are pre-filled with "N/A".
    **{
        ("SD", temperature, metric): ["N/A"] * 9
        for temperature in ("$t=0$", "$t>0$")
        for metric in ("Out Toks", "TTFT (ms)", "TPOT (ms)")
    },
    # The two algorithm groups, each with the same six blank sub-columns.
    **{
        (group, temperature, metric): [""] * 9
        for group in (
            "Alg \\ref{alg:exact-matching}",
            "Alg \\ref{alg:vocabs-intersection}",
        )
        for temperature in ("$t=0$", "$t>0$")
        for metric in ("Out Toks", "TTFT (ms)", "TPOT (ms)")
    },
}

# Tuple keys give the DataFrame hierarchical (MultiIndex) columns.
df = pd.DataFrame(data)

def save_dataframe_to_latex_merged(df, filename="model_performance_table.tex"):
    """
    Write a DataFrame with (optionally hierarchical) column headers as a LaTeX tabularx table.

    Every column level except the last becomes one header row. In that row, runs of
    adjacent columns sharing the same label *and* the same parent labels are collapsed
    into a single ``\\multicolumn`` cell. The last level is written one cell per column.
    It is followed by one ``\\hline``-terminated row per DataFrame row.

    Header labels and cell values are inserted verbatim via ``str()`` (no LaTeX
    escaping), so they may deliberately contain markup such as ``$t=0$`` or
    ``\\ref{...}``; characters like ``&``, ``%`` or ``_`` must be escaped by the caller.

    Parameters:
        df (pd.DataFrame): DataFrame to convert. Columns may be a MultiIndex with any
            number of levels, or a plain single-level Index.
        filename (str): Path of the .tex file to write (UTF-8 encoded).

    Returns:
        None
    """
    from itertools import groupby

    # One "X" column per DataFrame column, with vertical rules between all of them.
    column_format = "|X" * len(df.columns) + "|"

    # Normalise labels to tuples so single-level frames take the same code path.
    n_levels = df.columns.nlevels
    col_tuples = [col if n_levels > 1 else (col,) for col in df.columns]

    def merge_level(level):
        """Build the header row for `level`, merging runs that share their full label prefix."""
        cells = []
        position = 0  # index of the first column covered by the current run
        # Group on the prefix col[:level+1], not just col[level], so a run never
        # spans two different parent groups.
        for prefix, run in groupby(col_tuples, key=lambda col: col[: level + 1]):
            width = sum(1 for _ in run)
            label = str(prefix[-1])
            if width > 1:
                # Only column 0 owns the table's left border. Later cells already get
                # their left rule from the previous column's trailing "|", so repeating
                # it would draw a double rule.
                spec = "|c|" if position == 0 else "c|"
                cells.append(rf"\multicolumn{{{width}}}{{{spec}}}{{{label}}}")
            else:
                cells.append(label)
            position += width
        return " & ".join(cells)

    header_rows = [merge_level(level) + r" \\ \hline" for level in range(n_levels - 1)]
    # The innermost level is never merged: one label per column.
    header_rows.append(" & ".join(str(col[-1]) for col in col_tuples) + r" \\ \hline")
    header = "\n".join(header_rows)

    rows = "\n".join(" & ".join(str(val) for val in row) + r" \\ \hline" for row in df.values)

    # NOTE: \usepackage is only legal in a LaTeX preamble. If this file is \input into
    # a document body, move those two lines into the main document's preamble.
    # The \arraystretch change sits outside the {\tiny ...} group, so it stays in
    # effect for the rest of the document.
    latex_str = f"""
\\usepackage{{array}}
\\usepackage{{tabularx}}
\\renewcommand{{\\arraystretch}}{{1.0}}

{{\\tiny
\\setlength{{\\tabcolsep}}{{1pt}}
\\begin{{tabularx}}{{\\textwidth}}{{{column_format}}}
\\hline
{header}
{rows}
\\end{{tabularx}}
}}
"""

    with open(filename, "w", encoding="utf-8") as f:
        f.write(latex_str.strip())
    print(f"LaTeX table saved to {filename}")

# Export the table to LaTeX, then leave `df` as the last expression so the notebook renders it.
save_dataframe_to_latex_merged(df, filename="model_performance_table.tex")
df

LaTeX table saved to model_performance_table.tex


Unnamed: 0_level_0,Target,Task,Data,In Toks,Device,AR,AR,AR,AR,AR,...,Alg \ref{alg:exact-matching},Alg \ref{alg:exact-matching},Alg \ref{alg:exact-matching},Alg \ref{alg:exact-matching},Alg \ref{alg:vocabs-intersection},Alg \ref{alg:vocabs-intersection},Alg \ref{alg:vocabs-intersection},Alg \ref{alg:vocabs-intersection},Alg \ref{alg:vocabs-intersection},Alg \ref{alg:vocabs-intersection}
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,$t=0$,$t=0$,$t=0$,$t>0$,$t>0$,...,$t=0$,$t>0$,$t>0$,$t>0$,$t=0$,$t=0$,$t=0$,$t>0$,$t>0$,$t>0$
Unnamed: 0_level_2,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Out Toks,TTFT (ms),TPOT (ms),Out Toks,TTFT (ms),...,TPOT (ms),Out Toks,TTFT (ms),TPOT (ms),Out Toks,TTFT (ms),TPOT (ms),Out Toks,TTFT (ms),TPOT (ms)
0,Llama-3.1-70B,long-ctx summ,,,A100 * 2,,,,,,...,,,,,,,,,,
1,Llama-3.1-70B,long-ctx summ,,,A100 * 2,,,,,,...,,,,,,,,,,
2,Llama-3.1-70B,long-ctx summ,,,A100 * 2,,,,,,...,,,,,,,,,,
3,Llama-3.1-70B,long-ctx summ,,,A100 * 2,,,,,,...,,,,,,,,,,
4,Mixtral-8x22B-Instruct-v0.1,summ,,,A100 * 4,,,,,,...,,,,,,,,,,
5,Gemma-2-9b,summ,,,A6000,,,,,,...,,,,,,,,,,
6,CodeLlama-13b-Instruct-hf,coding,,,A6000,,,,,,...,,,,,,,,,,
7,Phi-3-medium-128k-instruct,long-ctx summ,,,A6000,,,,,,...,,,,,,,,,,
8,Mixtral-8x22B-Instruct-v0.1,long-ctx summ,,,A100 * 4,,,,,,...,,,,,,,,,,
