In [None]:
from bittransgnn_ops_calc import bert_ops, bert_ops_info, bert_embedding_params, bert_attn_params, bert_ffn_params, bert_pooler_params, bert_layernorms, bert_clsif, gnn_params
from bittransgnn_byte_calc import transformer_bytes, return_bert_byte_info, return_bittransgnn_byte_info
from bittransgnn_energy_calc import bittransformer_energy_info, bittransgnn_energy_info, get_bittransformer_total_energy
import matplotlib.pyplot as plt
import itertools
import numpy as np
import os

In [None]:
# Output directory for every figure saved below; exist_ok=True makes re-running this cell harmless.
os.makedirs('efficiency_plots/', exist_ok=True)

In [None]:
# Classifier output dimension (number of classes) for each supported dataset.
_OUTPUT_DIMS = {
    "20ng": 20,
    "mr": 2,
    "ohsumed": 23,
    "r8": 8,
    "r52": 52,
}


def get_do(dataset_name):
    """Return the classifier output dimension ``do`` for ``dataset_name``.

    Args:
        dataset_name: one of "20ng", "mr", "ohsumed", "r8", "r52".

    Returns:
        The output dimension as an int.

    Raises:
        ValueError: if ``dataset_name`` is not a known dataset (previously this
            surfaced as an UnboundLocalError).
    """
    try:
        return _OUTPUT_DIMS[dataset_name]
    except KeyError:
        raise ValueError(
            f"Unknown dataset {dataset_name!r}; expected one of {sorted(_OUTPUT_DIMS)}"
        ) from None

# Total number of sequences (documents) in each supported dataset.
_NUM_SEQUENCES = {
    "20ng": 18846,
    "mr": 10662,
    "ohsumed": 7400,
    "r8": 7674,
    "r52": 9100,
}


def get_num_sequences(dataset_name):
    """Return the number of sequences in ``dataset_name``.

    Args:
        dataset_name: one of "20ng", "mr", "ohsumed", "r8", "r52".

    Returns:
        The sequence count as an int.

    Raises:
        ValueError: if ``dataset_name`` is not a known dataset (previously this
            surfaced as an UnboundLocalError).
    """
    try:
        return _NUM_SEQUENCES[dataset_name]
    except KeyError:
        raise ValueError(
            f"Unknown dataset {dataset_name!r}; expected one of {sorted(_NUM_SEQUENCES)}"
        ) from None

def get_dataset_conf(dataset_name):
    """Build, print and return the dataset configuration for the energy calculators.

    The returned dict has two keys, in this order: "num_sequences" (from
    get_num_sequences) and "do" (from get_do).
    """
    output_dim = get_do(dataset_name)
    conf = dict(num_sequences=get_num_sequences(dataset_name), do=output_dim)
    print(conf)
    return conf

def get_gnn_conf(dataset_name, bert_model_size, gnn_bits=32.0, dataset_dir="../bittransgnn/dataset"):
    """Build the GNN configuration dict consumed by the energy/byte calculators.

    Args:
        dataset_name: dataset key understood by get_do (e.g. "20ng").
        bert_model_size: "base" (dh=768) or "large" (dh=1024).
        gnn_bits: value stored under "gnn_bits".
        dataset_dir: directory containing the pickled ``ind.<dataset_name>.adj`` file.
            The default reproduces the previously hard-coded path.

    Returns:
        dict with keys "dg", "do", "dh", "num_nodes", "num_edges", "gnn_bits".

    Raises:
        ValueError: if ``bert_model_size`` is neither "base" nor "large"
            (previously this surfaced as an UnboundLocalError).
    """
    # Local import, as in the original; pickle must only be used on trusted,
    # locally produced files such as this adjacency dump.
    import pickle as pkl

    hidden_sizes = {"base": 768, "large": 1024}
    try:
        dh = hidden_sizes[bert_model_size]
    except KeyError:
        raise ValueError(
            f"Unknown bert_model_size {bert_model_size!r}; expected 'base' or 'large'"
        ) from None
    dg = 256  # presumably the GNN hidden width -- kept from the original
    do = get_do(dataset_name)
    adj_path = os.path.join(dataset_dir, "ind.{}.adj".format(dataset_name))
    with open(adj_path, 'rb') as f:
        adj = pkl.load(f)
    # adj is expected to expose .shape and .count_nonzero() (e.g. a scipy sparse
    # matrix) -- TODO confirm against the dataset builder.
    num_nodes = adj.shape[0]
    num_edges = adj.count_nonzero()
    return {"dg": dg, "do": do, "dh": dh, "num_nodes": num_nodes, "num_edges": num_edges, "gnn_bits": gnn_bits}

# Hard-coded test accuracies (%) per dataset and model family.
# Quantized families hold one value per bit-width, in the order the plotting cells
# use for bit_list ([1, 1.58, 2.32]). The full-precision baselines ("bert",
# "bertgcn") hold a single value.
# "bertgcn" is kept separate from "bittransgnn". Previously both shared one
# variable, so the quantized BitTransGNN accuracies were overwritten by the single
# BERTGCN value.
_ACC_TABLE = {
    "20ng": {
        "bittrans": [81.49, 82.15, 83.05],
        "bittrans_ind": [82.97, 82.14, 84.36],
        "bittransgnnkd": [84.77, 84.62, 84.97],
        "bittransgnn": [88.65, 88.43, 88.64],
        "bittransgnn_static": [86.26, 87.84, 87.58],
        "bittransgnn_dynamic": [88.65, 88.43, 88.64],
        "bert": [85.30],
        "bertgcn": [89.30],
    },
    "mr": {
        "bittrans": [76.92, 77.66, 81.86],
        "bittrans_ind": [77.12, 78.14, 84.78],
        "bittransgnnkd": [78.47, 79.63, 84.67],
        "bittransgnn": [79.27, 80.23, 85.27],
        "bittransgnn_static": [78.64, 79.02, 83.24],
        "bittransgnn_dynamic": [79.27, 80.23, 85.27],
        "bert": [85.70],
        "bertgcn": [86.00],
    },
    "ohsumed": {
        "bittrans": [63.34, 64.08, 66.62],
        # NOTE(review): 55.48 breaks the upward trend of this row -- verify the source.
        "bittrans_ind": [66.41, 68.09, 55.48],
        "bittransgnnkd": [67.82, 69.34, 70.82],
        "bittransgnn": [69.36, 70.08, 71.49],
        "bittransgnn_static": [69.36, 67.84, 70.81],
        "bittransgnn_dynamic": [68.90, 67.84, 70.81],
        "bert": [70.50],
        "bertgcn": [72.80],
    },
    "r8": {
        "bittrans": [97.30, 97.29, 96.82],
        "bittrans_ind": [97.94, 97.76, 97.67],
        "bittransgnnkd": [97.32, 97.25, 97.80],
        "bittransgnn": [98.10, 97.78, 98.12],
        "bittransgnn_static": [98.09, 97.72, 98.02],
        "bittransgnn_dynamic": [98.10, 97.78, 98.12],
        "bert": [97.80],
        "bertgcn": [98.10],
    },
    "r52": {
        "bittrans": [94.42, 95.43, 94.85],
        "bittrans_ind": [95.64, 95.99, 96.07],
        "bittransgnnkd": [95.81, 95.84, 95.91],
        "bittransgnn": [95.94, 95.99, 96.26],
        "bittransgnn_static": [94.00, 95.50, 96.12],
        "bittransgnn_dynamic": [95.94, 95.99, 96.26],
        "bert": [96.40],
        "bertgcn": [96.60],
    },
}


def get_acc_list(dataset_name, for_energy=False, small=False, bert_only=False):
    """Return per-model accuracy lists aligned with ``get_model_conf_list``.

    Call it with the same ``for_energy``/``small``/``bert_only`` flags as
    ``get_model_conf_list``. Entry ``i`` of the result holds the accuracies of
    ``model_conf_list[i]``: three values for quantized families, one value for
    full-precision baselines.

    Args:
        dataset_name: one of "20ng", "mr", "ohsumed", "r8", "r52".
        for_energy: use the energy-plot families (BitTransGNN split into dynamic
            and static training).
        small: omit the full-precision baselines.
        bert_only: transformer families plus the full-precision transformer only;
            takes precedence over the other flags.

    Returns:
        A new list of new lists of floats.

    Raises:
        ValueError: if ``dataset_name`` is not a tabulated dataset.
    """
    try:
        table = _ACC_TABLE[dataset_name]
    except KeyError:
        raise ValueError(
            f"Unknown dataset {dataset_name!r}; expected one of {sorted(_ACC_TABLE)}"
        ) from None

    if bert_only:
        keys = ["bittrans", "bittrans_ind", "bittransgnnkd", "bert"]
    elif for_energy:
        keys = ["bittrans", "bittrans_ind", "bittransgnnkd", "bittransgnn_dynamic", "bittransgnn_static"]
        if not small:
            keys += ["bert", "bertgcn"]
    else:
        keys = ["bittrans", "bittrans_ind", "bittransgnnkd", "bittransgnn"]
        if not small:
            keys += ["bert", "bertgcn"]
    # Copy so callers cannot mutate the shared table.
    return [list(table[key]) for key in keys]

def get_model_conf_list(for_energy=False, small=False, bert_only=False):
    """Return the model-family labels plotted by the efficiency cells.

    The order matches ``get_acc_list`` called with the same flags, so entry ``i``
    of both lists describes the same model family. "Transformer" and the BERTGCN
    labels are the full-precision baselines; the plotting cells draw them at 32 bits.

    Args:
        for_energy: energy-plot variant, where BitTransGNN is split into its
            dynamic "(d)" and static "(s)" training types.
        small: drop the full-precision baselines.
        bert_only: only the transformer families plus the full-precision
            "Transformer"; takes precedence over the other flags.

    Returns:
        A new list of label strings.
    """
    labels = ["BitTransformer", "DS", "KD"]
    if bert_only:
        return labels + ["Transformer"]
    if for_energy:
        labels += ["BitTransGNN (d)", "BitTransGNN (s)"]
        baselines = ["Transformer", "BERTGCN (d)"]
    else:
        labels.append("BitTransGNN")
        baselines = ["Transformer", "BERTGCN"]
    if not small:
        labels += baselines
    return labels

In [None]:
# Byte-size breakdown (return_bert_byte_info) for BERT and RoBERTa, base and large.
# exact and add_clsif are passed positionally; add_clsif=True presumably counts the
# classifier head as well.
# In the notebook only the last call's return value is auto-displayed; the earlier
# calls' return values are discarded.
model_type = "bert"
model_size = "base"
exact = True
add_clsif = True
return_bert_byte_info(model_type, model_size, exact, add_clsif)
model_type = "bert"
model_size = "large"
return_bert_byte_info(model_type, model_size, exact, add_clsif)
model_type = "roberta"
model_size = "base"
return_bert_byte_info(model_type, model_size, exact, add_clsif)
model_type = "roberta"
model_size = "large"
return_bert_byte_info(model_type, model_size, exact, add_clsif)

In [None]:
# Reference hyper-parameters, presumably BERT-base: vocabulary size (V), max sequence
# length (S), token-type count (T), hidden size (dh), attention heads (H) and
# layers (L). None of the cells below read these names.
V, S, T, dh, H, L = 30522, 512, 2, 768, 12, 12

In [None]:
# Byte counts for BERT-base from transformer_bytes, all with add_clsif=True:
#   1) bl=1 (1-bit BitTransformer weights), 2) bl=32 (full precision),
#   3) bl=1 plus a GNN with bg=1 (add_gnn=True).
# be=32 is presumably the embedding bit-width -- TODO confirm in bittransgnn_byte_calc.
model_type = "bert"
model_size = "base"
print(transformer_bytes(model_type, model_size, be=32, bl=1, add_clsif=True))
print(transformer_bytes(model_type, model_size, be=32, bl=32, add_clsif=True))
print(transformer_bytes(model_type, model_size, be=32, bl=1, add_clsif=True, bg=1, add_gnn=True))

In [None]:
# Compression ratio from two byte counts hard-coded from an earlier run. Presumably
# these are the 32-bit and 1-bit outputs of the transformer_bytes cell above -- TODO
# confirm. The numbers go stale if the byte model changes.
437990480.0/106199138.5

In [None]:
# Byte-size breakdown of BitTransGNN on BERT-base. As the cell's last expression, its
# return value is auto-displayed by the notebook.
model_type = "bert"
model_size = "base"
return_bittransgnn_byte_info(model_type, model_size)

In [None]:
# Print the operation-count report (bert_ops_info from bittransgnn_ops_calc) for BERT
# base and large, separated by a dashed line.
model_type = "bert"
model_size_list = ["base", "large"]
for model_size in model_size_list:
    print(model_size)
    bert_ops_info(model_type, model_size, exact=False, inv=False)
    print("----------")

In [None]:
# Print the BitTransformer energy report (bittransformer_energy_info) for BERT base
# and large, without a dataset configuration.
model_type = "bert"
model_size_list = ["base", "large"]
for model_size in model_size_list:
    print(model_size)
    bittransformer_energy_info(model_type, model_size, exact=True, inv=False)
    print("----------")

In [None]:
# Same BitTransformer energy report, now with the 20ng dataset configuration passed
# in (full_batch=True). Only BERT-base is evaluated here.
model_type = "bert"
# To also report BERT-large: model_size_list = ["base", "large"]
model_size_list = ["base"]
dataset_name = "20ng"
dataset_conf = get_dataset_conf(dataset_name)
for model_size in model_size_list:
    print(model_size)
    bittransformer_energy_info(model_type, model_size, exact=True, inv=False, dataset_conf=dataset_conf, full_batch=True)
    print("----------")

In [None]:
# BitTransGNN energy report on 20ng with BERT-base, for both train_type values
# ("static" and "dynamic"). The GNN configuration is built once via get_gnn_conf,
# which reads the dataset's adjacency pickle.
model_type = "bert"
# Only BERT-base is evaluated here (was: model_size_list = ["base", "large"]).
model_size = "base"
dataset_name = "20ng"
batch_size = 32
dataset_conf = get_dataset_conf(dataset_name)
train_type_list = ["static", "dynamic"]
gnn_conf = get_gnn_conf(dataset_name, model_size)
for train_type in train_type_list:
    print(model_size)
    print("BitTransGNN-" + train_type)
    bittransgnn_energy_info(model_type, model_size, exact=True, inv=False, full_batch=True, gnn_conf=gnn_conf, dataset_conf=dataset_conf, train_type=train_type, batch_size=batch_size)
    print("----------")

In [None]:
# Memory and energy of the BitTransformer encoder alone (no classifier head, no GNN)
# for every weight bit-width (bl) and compute dtype.
# num_sequences is not accounted for here: only the energy spent on one sequence is
# computed, so the numbers stay dataset-independent for now.
model_type = "bert"
model_size = "base"
add_clsif = False
add_gnn = False

be = 32
bg = 32
quantized_bits = [1, 1.58, 2, 2.32]
energy_dtypes = ["float32", "int32", "int8"]

# Results collected in dicts instead of one variable per (bit-width, dtype) pair:
#   bittrans_bytes_by_bits[bl]           -> bytes from transformer_bytes
#   bittrans_energy_by_bits[(bl, dtype)] -> joules from get_bittransformer_total_energy
bittrans_bytes_by_bits = {}
bittrans_energy_by_bits = {}

for bl in quantized_bits:
    print(f"{bl}-bit BitTransformer")
    bittrans_bytes_by_bits[bl] = transformer_bytes(model_type, model_size=model_size, be=be, bg=bg, bl=bl, add_clsif=add_clsif, add_gnn=add_gnn)
    for dtype in energy_dtypes:
        bittrans_energy_by_bits[(bl, dtype)] = get_bittransformer_total_energy(model_type=model_type, model_size=model_size, bits=bl, dtype=dtype)
    print(f"Memory consumption: {bittrans_bytes_by_bits[bl]:e} bytes")
    for dtype in energy_dtypes:
        print(f"{dtype} energy: {bittrans_energy_by_bits[(bl, dtype)]:e} J")

# Full-precision baseline: 32-bit weights, evaluated in float32 only.
# The energy was previously stored under the copy-paste name `bert_fullprec_energy_232`.
bl = 32
dtype = "float32"
print("32-bit Full Precision Transformer (BERT)")
bert_fullprec_bytes = transformer_bytes(model_type, model_size=model_size, be=be, bg=bg, bl=bl, add_clsif=add_clsif, add_gnn=add_gnn)
bert_fullprec_energy = get_bittransformer_total_energy(model_type=model_type, model_size=model_size, bits=bl, dtype=dtype)
print(f"Memory consumption: {bert_fullprec_bytes:e} bytes")
print(f"float32 energy: {bert_fullprec_energy:e} J")



In [None]:
# Memory footprint including the classifier head (add_clsif=True), for the plain
# BitTransformer and for BitTransGNN (add_gnn=True), across weight bit-widths.
model_type = "bert"
model_size = "base"
add_clsif = True
add_gnn_list = [False, True]
dataset_name = "20ng"
dataset_conf = get_dataset_conf(dataset_name)
# Classifier output dimension for dataset_name. It is taken from the dataset config
# and passed to transformer_bytes via `do=`. Without it the classifier would be sized
# by transformer_bytes' own default rather than by this dataset.
d_o = dataset_conf["do"]
be = 32
bg = 32

# Bytes keyed by (add_gnn, bl) for reuse after the cell runs.
bittrans_bytes_by_conf = {}
for add_gnn in add_gnn_list:
    suffix = "GNN" if add_gnn else ""
    for bl in [1, 1.58, 2, 2.32]:
        print(f"{bl}-bit BitTrans{suffix}")
        bittrans_bytes_by_conf[(add_gnn, bl)] = transformer_bytes(model_type, model_size=model_size, be=be, bg=bg, bl=bl, add_clsif=add_clsif, do=d_o, add_gnn=add_gnn)
        print(f"Memory consumption: {bittrans_bytes_by_conf[(add_gnn, bl)]:e} bytes")

    bl = 32
    print("32-bit Full Precision Transformer" + suffix)
    bert_fullprec_bytes = transformer_bytes(model_type, model_size=model_size, be=be, bg=bg, bl=bl, add_clsif=add_clsif, do=d_o, add_gnn=add_gnn)
    print(f"Memory consumption: {bert_fullprec_bytes:e} bytes")

In [None]:
plt.style.use('default')

# Memory vs. energy of the plain BitTransformer for every weight bit-width and dtype.
model_type = "bert"
model_size = "base"
# Byte-count settings are set here rather than inherited from whatever earlier cells
# left in the namespace. Energy comes from get_bittransformer_total_energy, which is
# called without any GNN, so memory is computed without the GNN as well.
be = 32
bg = 32
add_clsif = True
add_gnn = False

color = itertools.cycle(("red", "blue", "green", "purple", "orange"))
marker_list = []
color_list = []
bit_data = []
energy_data = []
labels = []
bit_list = [1, 1.58, 2, 2.32, 32]
for bl in bit_list:
    if bl == 32:
        # The full-precision model is only evaluated in float32.
        dtype_list = ["float32"]
        marker = itertools.cycle(("^",))
    else:
        dtype_list = ["int8", "int32", "float32"]
        marker = itertools.cycle(("s", "o", "*"))
    current_color = next(color)
    # Memory depends on the bit-width only, not on the compute dtype.
    bittrans_bytes = transformer_bytes(model_type, model_size=model_size, be=be, bg=bg, bl=bl, add_clsif=add_clsif, add_gnn=add_gnn)
    for dtype in dtype_list:
        bittrans_energy = get_bittransformer_total_energy(model_type=model_type, model_size=model_size, bits=bl, dtype=dtype)
        marker_list.append(next(marker))
        color_list.append(current_color)
        bit_data.append(bittrans_bytes * 1e-6)  # bytes -> million bytes
        energy_data.append(bittrans_energy * 1e3)  # J -> mJ
        labels.append(f"{bl}-bit {dtype}")

figure_size = 5
plt.figure(figsize=(figure_size + 3, figure_size))

axis_font = {"family": "times new roman", "weight": "bold", "fontsize": 16}
plt.xlabel("Memory Consumption (Million Bytes)", fontdict=axis_font, labelpad=10)
plt.ylabel("Energy Consumption (mJ)", fontdict=axis_font, labelpad=10)
plt.xscale('linear')
plt.yscale('linear')
plt.grid(visible=True)

for x, y, point_color, point_marker in zip(bit_data, energy_data, color_list, marker_list):
    plt.scatter(x=x, y=y, color=point_color, marker=point_marker)

plt.legend(labels, ncols=2, loc="lower right", fontsize="11")
plt.xticks(np.arange(100, 501, 50), np.arange(100, 501, 50), fontsize="13", fontfamily="times new roman")
plt.yticks(np.arange(0, 8, 1), np.arange(0, 8, 1), fontsize="13", fontfamily="times new roman")

plt.savefig("efficiency_plots/efficiency_plot_all.pdf", format="pdf", bbox_inches="tight")
plt.show()


In [None]:
# Zoom on the sub-32-bit BitTransformer configurations: memory vs. energy, no
# full-precision point.
model_type = "bert"
model_size = "base"
# Byte-count settings are set here rather than inherited from earlier cells. Energy
# is computed without a GNN, so memory is too.
be = 32
bg = 32
add_clsif = True
add_gnn = False

color = itertools.cycle(("red", "blue", "green", "purple", "orange", "black"))
marker_list = []
color_list = []
bit_data = []
energy_data = []
labels = []
bit_list = [1, 1.58, 2, 2.32]
dtype_list = ["int8", "int32", "float32"]
dtype_markers = ("s", "o", "*")  # one marker per dtype, same order as dtype_list
for bl in bit_list:
    current_color = next(color)
    # Memory depends on the bit-width only, not on the compute dtype.
    bittrans_bytes = transformer_bytes(model_type, model_size=model_size, be=be, bg=bg, bl=bl, add_clsif=add_clsif, add_gnn=add_gnn)
    for dtype, current_marker in zip(dtype_list, dtype_markers):
        bittrans_energy = get_bittransformer_total_energy(model_type=model_type, model_size=model_size, bits=bl, dtype=dtype)
        marker_list.append(current_marker)
        color_list.append(current_color)
        bit_data.append(bittrans_bytes * 1e-6)  # bytes -> million bytes
        energy_data.append(bittrans_energy * 1e3)  # J -> mJ
        labels.append(f"{bl}-bit {dtype}")

figure_size = 5
plt.figure(figsize=(figure_size + 3, figure_size))

axis_font = {"family": "times new roman", "weight": "bold", "fontsize": 16}
plt.xlabel("Memory Consumption (Million Bytes)", fontdict=axis_font, labelpad=10)
plt.ylabel("Energy Consumption (mJ)", fontdict=axis_font, labelpad=10)
plt.xscale('linear')
plt.yscale('linear')
plt.grid(visible=True, axis="both")

for x, y, point_color, point_marker in zip(bit_data, energy_data, color_list, marker_list):
    plt.scatter(x=x, y=y, color=point_color, marker=point_marker)

plt.legend(labels, ncols=2, fontsize="9.5", loc=(0.48, 0.30))
plt.xticks(fontsize="13", fontfamily="times new roman")
plt.yticks(fontsize="13", fontfamily="times new roman")
plt.savefig("efficiency_plots/efficiency_plot_small_bits.pdf", format="pdf", bbox_inches="tight")

plt.show()

PERFORMANCE VS MEMORY PLOTS

In [None]:
# Test accuracy vs. memory footprint for every model family and bit-width, per dataset.
# Two figures per dataset: all families ("_plot_all"), and a "small" variant without
# the full-precision baselines ("_plot_small_bits").
model_type = "bert"
model_size = "base"
# Byte-count settings are set explicitly instead of relying on values left over from
# earlier cells.
be = 32
bg = 32
add_clsif = True
for_energy = False
small_list = [False, True]
datasets = ["20ng", "mr", "r8", "r52", "ohsumed"]

label_font = {"family": "times new roman", "weight": "bold", "fontsize": 15}
# matplotlib ignores legend(fontsize=...) whenever prop= is also given, so the size
# has to live inside the font-properties dict.
legend_prop_small = {"family": "times new roman", "weight": "bold", "size": 9}
legend_prop_all = {"family": "times new roman", "weight": "bold", "size": 10}
# Legend positions for the "small" figures, hand-tuned per dataset.
small_legend_loc = {"20ng": (0.05, 0.55), "mr": "lower right", "r8": (0.35, 0.65), "r52": (0.40, 0.30), "ohsumed": "lower right"}
# (marker size, alpha) overrides on 20ng, presumably so that overlapping markers stay
# visible. Every other point uses (100, 1.0).
overlap_styles_20ng = {"BitTransformer 1.58-bit": (144, 0.5), "DS 1.58-bit": (81, 0.7)}

for small in small_list:
    for dataset_name in datasets:
        do = get_do(dataset_name)
        print(dataset_name)
        print(f"do={do}")
        model_conf_list = get_model_conf_list(for_energy=for_energy, small=small)
        confs_bits_accs_list = get_acc_list(for_energy=for_energy, dataset_name=dataset_name, small=small)
        color = itertools.cycle(("red", "blue", "green", "purple", "orange", "black"))
        labels = []
        marker_list = []
        color_list = []
        bit_data = []
        accs = []

        for model_conf, conf_acc_list in zip(model_conf_list, confs_bits_accs_list):
            if model_conf in ["Transformer", "BERTGCN"]:
                bit_list = [32]
                markers = ("^",)
            else:
                bit_list = [1, 1.58, 2.32]
                markers = ("s", "o", "*")
            # Fail with a clear message instead of an IndexError if the accuracy table
            # and bit_list disagree.
            if len(conf_acc_list) != len(bit_list):
                raise ValueError(
                    f"{model_conf} on {dataset_name}: expected {len(bit_list)} accuracies "
                    f"(one per bit-width), got {len(conf_acc_list)}"
                )
            current_color = next(color)
            add_gnn = "GCN" in model_conf or "GNN" in model_conf
            for bl, acc, current_marker in zip(bit_list, conf_acc_list, markers):
                marker_list.append(current_marker)
                color_list.append(current_color)
                bittrans_bytes = transformer_bytes(model_type, model_size=model_size, be=be, bg=bg, bl=bl, add_clsif=add_clsif, do=do, add_gnn=add_gnn)
                bit_data.append(bittrans_bytes * 1e-6)  # bytes -> million bytes
                labels.append(model_conf if bl == 32 else f"{model_conf} {bl}-bit")
                accs.append(acc)

        figure_size = 6
        plt.figure(figsize=(figure_size + 2, figure_size))
        ax = plt.gca()
        for spine in ax.spines.values():
            spine.set_linewidth(2)

        plt.xlabel("Memory Consumption (Million Bytes)", fontdict=label_font, labelpad=10)
        plt.xscale('linear')
        # Raw string: "\%" is an invalid escape sequence in a normal string literal.
        plt.ylabel(r"Test Accuracy $(\%)$", fontdict=label_font, labelpad=10)
        plt.grid(visible=True, axis="both")
        plt.tight_layout()

        for x, y, point_color, point_marker, label in zip(bit_data, accs, color_list, marker_list, labels):
            size, alpha = overlap_styles_20ng.get(label, (100, 1.0)) if dataset_name == "20ng" else (100, 1.0)
            plt.scatter(x=x, y=y, color=point_color, marker=point_marker, s=size, alpha=alpha)

        plt.xticks(fontsize="12", fontweight="bold", fontfamily="times new roman")
        plt.yticks(fontsize="12", fontweight="bold", fontfamily="times new roman")

        if small:
            plt.legend(labels, ncols=2, loc=small_legend_loc.get(dataset_name, "lower right"), prop=legend_prop_small)
            plt.savefig(f"efficiency_plots/efficiency_vs_acc_{dataset_name}_plot_small_bits.pdf", format="pdf", bbox_inches="tight")
        else:
            plt.legend(labels, ncols=2, loc="lower right", prop=legend_prop_all)
            plt.savefig(f"efficiency_plots/efficiency_vs_acc_{dataset_name}_plot_all.pdf", format="pdf", bbox_inches="tight")

        plt.show()


PERFORMANCE VS ENERGY PLOTS

In [None]:
# Performance vs. energy for the transformer-only families (bert_only=True), per
# compute dtype and dataset. The full-precision "Transformer" is always evaluated
# in float32.
model_type = "bert"
model_size = "base"
bert_only = True
small = False
for_energy = True
add_clsif = True
datasets = ["20ng", "mr", "r8", "r52", "ohsumed"]
dtype_list = ["float32", "int32", "int8"]

label_font = {"family": "times new roman", "weight": "bold", "fontsize": 15}
# matplotlib ignores legend(fontsize=...) whenever prop= is also given, so the size
# lives inside the font-properties dict.
legend_prop = {"family": "times new roman", "weight": "bold", "size": 10}
# (marker size, alpha) overrides on 20ng, presumably to keep overlapping markers
# visible. Every other point uses (100, 1.0).
overlap_styles_20ng = {"BitTransformer 1.58-bit": (144, 0.5), "DS 1.58-bit": (81, 0.7)}

for dtype in dtype_list:
    for dataset_name in datasets:
        dataset_conf = get_dataset_conf(dataset_name)
        do, num_sequences = dataset_conf["do"], dataset_conf["num_sequences"]
        print(dataset_name)
        print(f"do={do}")
        model_conf_list = get_model_conf_list(for_energy=for_energy, small=small, bert_only=bert_only)
        confs_bits_accs_list = get_acc_list(for_energy=for_energy, dataset_name=dataset_name, small=small, bert_only=bert_only)

        color = itertools.cycle(("red", "blue", "green", "purple", "orange", "black"))
        labels = []
        marker_list = []
        color_list = []
        energy_data = []
        accs = []

        for model_conf, conf_acc_list in zip(model_conf_list, confs_bits_accs_list):
            if model_conf in ["Transformer", "BERTGCN"]:
                bit_list = [32]
                markers = ("^",)
                dtype_manual = "float32"
            else:
                bit_list = [1, 1.58, 2.32]
                markers = ("s", "o", "*")
                dtype_manual = dtype
            if len(conf_acc_list) != len(bit_list):
                raise ValueError(
                    f"{model_conf} on {dataset_name}: expected {len(bit_list)} accuracies "
                    f"(one per bit-width), got {len(conf_acc_list)}"
                )
            current_color = next(color)
            add_gnn = "GCN" in model_conf or "GNN" in model_conf
            # The adjacency pickle depends only on the dataset; load it once per model
            # family instead of once per plotted point.
            base_gnn_conf = None
            if add_gnn:
                base_gnn_conf = get_gnn_conf(dataset_name, bert_model_size=model_size)
                base_gnn_conf["gnn_bits"] = 32
            for bl, acc, current_marker in zip(bit_list, conf_acc_list, markers):
                marker_list.append(current_marker)
                color_list.append(current_color)
                # Fresh copy per call, in case the calculator mutates gnn_conf (the
                # original rebuilt it for every point).
                gnn_conf = dict(base_gnn_conf) if add_gnn else None
                bittrans_energy = get_bittransformer_total_energy(model_type=model_type, model_size=model_size, bits=bl, dtype=dtype_manual, dataset_conf=dataset_conf, add_gnn=add_gnn, gnn_conf=gnn_conf)
                energy_data.append(bittrans_energy)
                labels.append(model_conf if bl == 32 else f"{model_conf} {bl}-bit")
                accs.append(acc)

        figure_size = 6
        plt.figure(figsize=(figure_size + 2, figure_size))
        ax = plt.gca()
        for spine in ax.spines.values():
            spine.set_linewidth(2)

        plt.xlabel("Energy Consumption (J)", fontdict=label_font, labelpad=10)
        plt.xscale('linear')
        plt.yscale('linear')
        # Raw string: "\%" is an invalid escape sequence in a normal string literal.
        plt.ylabel(r"Test Accuracy $(\%)$", fontdict=label_font, labelpad=10)
        plt.grid(visible=True, axis="both")
        plt.tight_layout()

        for x, y, point_color, point_marker, label in zip(energy_data, accs, color_list, marker_list, labels):
            size, alpha = overlap_styles_20ng.get(label, (100, 1.0)) if dataset_name == "20ng" else (100, 1.0)
            plt.scatter(x=x, y=y, color=point_color, marker=point_marker, s=size, alpha=alpha)

        plt.legend(labels, ncols=2, loc="lower right", prop=legend_prop)
        plt.xticks(fontsize="12", fontweight="bold", fontfamily="times new roman")
        plt.yticks(fontsize="12", fontweight="bold", fontfamily="times new roman")

        if dtype == "float32":
            plt.savefig(f"efficiency_plots/energy_vs_acc_{dataset_name}_plot_all.pdf", format="pdf", bbox_inches="tight")
        else:
            plt.savefig(f"efficiency_plots/energy_vs_acc_{dataset_name}_plot_{dtype}_all.pdf", format="pdf", bbox_inches="tight")

        plt.show()


In [None]:
# Performance vs. energy including the GNN families: BitTransGNN dynamic "(d)" and
# static "(s)", plus the full-precision Transformer and BERTGCN (d), per dataset and
# compute dtype. Full-precision models are always evaluated in float32.
model_type = "bert"
model_size = "base"
bert_only = False
small = False
for_energy = True
full_batch = True
add_clsif = True
datasets = ["20ng", "mr", "r8", "r52", "ohsumed"]
dtype_list = ["float32", "int32", "int8"]

label_font = {"family": "times new roman", "weight": "bold", "fontsize": 15}
# matplotlib ignores legend(fontsize=...) whenever prop= is also given, so the size
# lives inside the font-properties dict.
legend_prop = {"family": "times new roman", "weight": "bold", "size": 10}
# (marker size, alpha) overrides on 20ng, presumably to keep overlapping markers
# visible. Every other point uses (100, 1.0).
overlap_styles_20ng = {"BitTransformer 1.58-bit": (144, 0.5), "DS 1.58-bit": (81, 0.7)}

for dataset_name in datasets:
    print(dataset_name)
    # Build the dataset config before printing `do`. Previously `do` was printed first
    # and so showed the value left over from the previous dataset.
    dataset_conf = get_dataset_conf(dataset_name)
    do, num_sequences = dataset_conf["do"], dataset_conf["num_sequences"]
    print(f"do={do}")
    model_conf_list = get_model_conf_list(for_energy=for_energy, small=small, bert_only=bert_only)
    confs_bits_accs_list = get_acc_list(for_energy=for_energy, dataset_name=dataset_name, small=small, bert_only=bert_only)

    # The adjacency pickle depends only on the dataset. Load it once here rather than
    # once per plotted point and dtype.
    base_gnn_conf = None
    if any("GCN" in conf or "GNN" in conf for conf in model_conf_list):
        base_gnn_conf = get_gnn_conf(dataset_name, bert_model_size=model_size)
        base_gnn_conf["gnn_bits"] = 32

    for dtype in dtype_list:
        print("dtype:", dtype)
        color = itertools.cycle(("red", "blue", "green", "purple", "orange", "black"))
        labels = []
        marker_list = []
        color_list = []
        energy_data = []
        accs = []

        for model_conf, conf_acc_list in zip(model_conf_list, confs_bits_accs_list):
            if "(d)" in model_conf:
                train_type = "dynamic"
            elif "(s)" in model_conf:
                train_type = "static"
            else:
                train_type = None
            if model_conf in ["Transformer", "BERTGCN", "BERTGCN (d)"]:
                bit_list = [32]
                markers = ("^",)
                dtype_manual = "float32"
            else:
                bit_list = [1, 1.58, 2.32]
                markers = ("s", "o", "*")
                dtype_manual = dtype
            if len(conf_acc_list) != len(bit_list):
                raise ValueError(
                    f"{model_conf} on {dataset_name}: expected {len(bit_list)} accuracies "
                    f"(one per bit-width), got {len(conf_acc_list)}"
                )
            current_color = next(color)
            add_gnn = "GCN" in model_conf or "GNN" in model_conf
            for bl, acc, current_marker in zip(bit_list, conf_acc_list, markers):
                marker_list.append(current_marker)
                color_list.append(current_color)
                # Fresh copy per call, in case the calculator mutates gnn_conf (the
                # original rebuilt it for every point).
                gnn_conf = dict(base_gnn_conf) if add_gnn else None
                bittrans_energy = get_bittransformer_total_energy(model_type=model_type, model_size=model_size, bits=bl, dtype=dtype_manual, dataset_conf=dataset_conf, add_gnn=add_gnn, gnn_conf=gnn_conf, full_batch=full_batch, train_type=train_type)
                energy_data.append(bittrans_energy)
                labels.append(model_conf if bl == 32 else f"{model_conf} {bl}-bit")
                accs.append(acc)

        figure_size = 6
        plt.figure(figsize=(figure_size + 2, figure_size))
        ax = plt.gca()
        for spine in ax.spines.values():
            spine.set_linewidth(2)

        plt.xlabel("Energy Consumption (J)", fontdict=label_font, labelpad=10)
        plt.xscale('linear')
        plt.yscale('linear')
        # Raw string: "\%" is an invalid escape sequence in a normal string literal.
        plt.ylabel(r"Test Accuracy $(\%)$", fontdict=label_font, labelpad=10)
        plt.grid(visible=True, axis="both")
        plt.tight_layout()

        for x, y, point_color, point_marker, label in zip(energy_data, accs, color_list, marker_list, labels):
            size, alpha = overlap_styles_20ng.get(label, (100, 1.0)) if dataset_name == "20ng" else (100, 1.0)
            plt.scatter(x=x, y=y, color=point_color, marker=point_marker, s=size, alpha=alpha)

        plt.legend(labels, ncols=2, loc="lower right", prop=legend_prop)
        plt.xticks(fontsize="12", fontweight="bold", fontfamily="times new roman")
        plt.yticks(fontsize="12", fontweight="bold", fontfamily="times new roman")

        if dtype == "float32":
            plt.savefig(f"efficiency_plots/energy_vs_acc_wgnn_{dataset_name}_plot_all.pdf", format="pdf", bbox_inches="tight")
        else:
            plt.savefig(f"efficiency_plots/energy_vs_acc_wgnn_{dataset_name}_plot_{dtype}_all.pdf", format="pdf", bbox_inches="tight")

        plt.show()
