In [1]:
import pandas as pd
import wandb
from pathlib import Path
import json
import numpy as np

In [2]:
method="procrustes" 
extend_exp="poly"
vm_name = {
    'resnet101': 'ResNet101',
    'resnet152': 'ResNet152',
    'resnet18': 'ResNet18',
    'resnet34': 'ResNet34',
    'resnet50': 'ResNet50',
    'segformer-b0-finetuned-ade-512-512': 'SegFormer<sub>B0</sub>',
    'segformer-b1-finetuned-ade-512-512': 'SegFormer<sub>B1</sub>',
    'segformer-b2-finetuned-ade-512-512': 'SegFormer<sub>B2</sub>',
    'segformer-b3-finetuned-ade-512-512': 'SegFormer<sub>B3</sub>',
    'segformer-b4-finetuned-ade-512-512': 'SegFormer<sub>B4</sub>',
    'segformer-b5-finetuned-ade-640-640': 'SegFormer<sub>B5</sub>',
    'vit-mae-base': 'MAE<sub>Base</sub>',
    'vit-mae-huge': 'MAE<sub>Huge</sub>',
    'vit-mae-large': 'MAE<sub>Large</sub>'
}

model_name = {'bert_uncased_L-2_H-128_A-2': "BERT<sub>TINY</sub>",
                'bert_uncased_L-4_H-256_A-4': "BERT<sub>MINI</sub>" ,
                'bert_uncased_L-4_H-512_A-8': "BERT<sub>SMALL</sub>",
                'bert_uncased_L-8_H-512_A-8': "BERT<sub>MEDIUM</sub>",
                'bert-base-uncased': "BERT<sub>BASE</sub>",
                'bert-large-uncased': "BERT<sub>LARGE</sub>",
                'gpt2': "GPT2<sub>BASE</sub>",
                'gpt2-medium': "GPT2<sub>MEDIUM</sub>",
                'gpt2-large': "GPT2<sub>LARGE</sub>",
                'gpt2-xl': "GPT2<sub>XL</sub>",
                'opt-125m': "OPT<sub>125M</sub>",
                'opt-1.3b': "OPT<sub>1.3B</sub>",
                'opt-6.7b': "OPT<sub>6.7B</sub>",
                'opt-30b': "OPT<sub>30B</sub>",
                'opt-66b': "OPT<sub>66B</sub>",
                'Llama-2-7b-hf': "Llama-2<sub>7B</sub>",
                'Llama-2-13b-hf': "Llama-2<sub>13B</sub>",
                'Llama-2-70b-hf': "Llama-2<sub>70B</sub>",
                "fasttext":"fastText"}

model_size = {'bert_uncased_L-2_H-128_A-2':4.4, 'bert_uncased_L-4_H-256_A-4':11.3, 'bert_uncased_L-4_H-512_A-8':29.1, 
    'bert_uncased_L-8_H-512_A-8':41.7, 'bert-base-uncased':110, 'bert-large-uncased':340, 
    'gpt2':117,  'gpt2-large':762, 'gpt2-xl':1542, 'opt-125m':125, 'opt-6.7b':6700, 'opt-30b':30000,'opt-66b':66000,
    'Llama-2-7b-hf': 7000,
    'Llama-2-13b-hf': 13000,
    'Llama-2-70b-hf': 70000,}

model_size = {'bert_uncased_L-2_H-128_A-2':4.4, 'bert_uncased_L-4_H-256_A-4':11.3, 'bert_uncased_L-4_H-512_A-8':29.1, 
    'bert_uncased_L-8_H-512_A-8':41.7, 'bert-base-uncased':110, 'bert-large-uncased':340, 
    'gpt2':117,  'gpt2-large':762, 'gpt2-xl':1542, 'opt-125m':125, 'opt-6.7b':6700, 'opt-30b':30000,'opt-66b':66000,
    'Llama-2-7b-hf': 7000,
    'Llama-2-13b-hf': 13000,
    'Llama-2-70b-hf': 70000}

model_tags = {
    "ft": "fasttext",
    "bert": "bert",
    "gpt2": "gpt2",
    "opt": "opt"
}

bins = {
    "_1":"1",
    "_2_or_3": "2-3",
    "_over_3":"4+"
}

bins_order = {
    "_1":"1",
    "_2_or_3": "2",
    "_over_3":"3"
}

set_name = "cleaned"

In [3]:


api = wandb.Api()
df = pd.DataFrame()

for model_alias in ["LM"]:
    # method_name = "Procrustes Analysis" if method == "procrustes" else "Ridge Regression"
    runs = api.runs(path=f"jiaangli00/image2{model_alias}-TACL-{set_name}{extend_exp}")
    # metrics = runs.summary["Results"]

    # data = []
    for i, single_run in enumerate(runs):
        metric = json.load(single_run.file(single_run.summary["Results"]["path"]).download(exist_ok=True))
        df = pd.concat([df,pd.DataFrame(metric["data"], columns=metric["columns"])])

# # df['Subjects'] = df['Subjects'].replace({f'brain_{i}': f'Subject-{i}' for i in range(1, num_subs + 1)})
# df['Layers'] = df['Layers'].replace({f'layer_{i}': f'layer-{i})
df['Models_size'] = df['LM'].copy()
df['Models_size'] = df['Models_size'].replace(model_size)
df['VM'] = df['VM'].replace(vm_name)
df['LM'] = df['LM'].replace(model_name)
df['Bins_order'] = df['Bins'].copy()
df['Bins_order'] = df['Bins_order'].replace(bins_order)
df["Bins"] = df["Bins"].replace(bins)
# #
group_names: list[str] = ["LM", "VM"]
if extend_exp != "":
    group_names.append("Bins")
    group_names.append("Bins_order")
#
precision_csls, precision_nn = [], []
for k in [100]:
    precision_csls.append(f'P@{k}-CSLS')
    precision_nn.append(f'P@{k}-NN')

precision = precision_csls
#
precision.append("Models_size")

In [4]:
seed_avg = df.groupby(group_names)[precision].mean().round(1).reset_index()
_selected = seed_avg.loc[(seed_avg["LM"] == "BERT_LARGE") | (seed_avg["LM"] == "GPT2_XL") | (seed_avg["LM"] == "OPT_30B") | (seed_avg["LM"] == "Llama-2-13B")]
# _selected.to_csv("poly.csv")

In [5]:
table = _selected.loc[(seed_avg["VM"] == "MAE-Huge") | (seed_avg["VM"] == "SegFormer-B5") | (seed_avg["VM"] == "ResNet152") ].round(1)

In [6]:
table

Unnamed: 0,LM,VM,Bins,Bins_order,P@100-CSLS,Models_size


In [7]:
table2 = seed_avg.loc[(seed_avg["VM"] != "MAE<sub>Huge</sub>") & (seed_avg["VM"] != "SegFormer<sub>B5</sub>") & (seed_avg["VM"] != "ResNet152") ]
table2 = table2.sort_values(by=["Bins_order",'Models_size'], ascending=[True,True])
table2

Unnamed: 0,LM,VM,Bins,Bins_order,P@100-CSLS,Models_size
0,BERT<sub>LARGE</sub>,MAE<sub>Base</sub>,1,1,33.8,340.0
6,BERT<sub>LARGE</sub>,MAE<sub>Large</sub>,1,1,36.8,340.0
9,BERT<sub>LARGE</sub>,ResNet101,1,1,45.9,340.0
15,BERT<sub>LARGE</sub>,ResNet18,1,1,39.3,340.0
18,BERT<sub>LARGE</sub>,ResNet34,1,1,44.9,340.0
...,...,...,...,...,...,...
152,OPT<sub>30B</sub>,SegFormer<sub>B0</sub>,4+,3,18.3,30000.0
155,OPT<sub>30B</sub>,SegFormer<sub>B1</sub>,4+,3,22.2,30000.0
158,OPT<sub>30B</sub>,SegFormer<sub>B2</sub>,4+,3,24.1,30000.0
161,OPT<sub>30B</sub>,SegFormer<sub>B3</sub>,4+,3,24.3,30000.0


In [11]:
import plotly.express as px
# import plotly.graph_objects as go

# fig = go.Figure()
traces = []
# df['VM_Bins'] = table2['VM'] + '_' + table2['Bins']
color_scale = px.colors.qualitative.Set3

# Create stacked and grouped bar chart
fig = px.bar(table2, x='VM', y='P@100-CSLS', color='Bins', facet_col="LM",
             labels={'P@100-CSLS': 'P@100'},
            #  facet_col_wrap=3,
            #  category_orders={'VM': ['MAE-Huge', 'ResNet152', 'SegFormer-B5']},
             color_discrete_map={'1': color_scale[5],"2-3":color_scale[0],"4+":color_scale[3]},
             height=500, width=1500,
            #  facet_col_titles='none'
            #  barmode="group"
             )


for i, trace in enumerate(fig.data):
    # trace.marker.color = color_scale[i%3]  # Set color
    trace.marker.opacity = 0.65  # Set opacity


# Update layout for better appearance
fig.update_layout(barmode='overlay', showlegend=True,
                  font=dict(size=30),
                  yaxis=dict(title=dict(text='P@100', font=dict(size=30)), tickfont=dict(size=25)),  # Change y-axis font size
                  legend=dict(title=dict(text='Polysemy', font=dict(size=30))),
                  xaxis=dict(title=dict(text=''),tickfont=dict(size=20)),
                  xaxis2=dict(title=dict(text=''),tickfont=dict(size=20)),
                  xaxis3=dict(title=dict(text=''),tickfont=dict(size=20)),
                  xaxis4=dict(title=dict(text=''),tickfont=dict(size=20)),
                  template="plotly_white"
                  ) 

fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))

# fig.update_xaxes(title_text='', showgrid=False)

# fig.update_xaxes(title_font=dict(size=20), tickfont=dict(size=20))

# Show the plot
fig.show()
# fig.write_image(f"poly_remain_{set_name}.pdf")

In [7]:
# table.to_csv("poly.csv")

In [8]:
for j in ["BERT_LARGE", "GPT2_XL", "OPT_30B", "Llama-2-13B"]:
    print(j)
    for i in ["_1", "_2_or_3", "_over_3"]:
        all_values = table.loc[(_selected["LM"] == j) & (_selected["Bins"] == i), precision].values
        formatted_values = " & ".join(f"{value}" for one_model_value in all_values for value in one_model_value)
        print(f"& {i[1:]} & " + formatted_values )


BERT_LARGE
& 1 & 38.1 & 47.0 & 42.6
& 2_or_3 & 24.6 & 29.9 & 28.5
& over_3 & 17.5 & 23.7 & 19.5
GPT2_XL
& 1 & 42.1 & 52.4 & 48.9
& 2_or_3 & 28.0 & 34.9 & 32.7
& over_3 & 15.7 & 22.8 & 21.0
OPT_30B
& 1 & 47.4 & 57.2 & 55.9
& 2_or_3 & 31.8 & 39.6 & 34.4
& over_3 & 19.9 & 25.4 & 25.2
Llama-2-13B
& 1 & 40.6 & 51.8 & 48.9
& 2_or_3 & 26.1 & 35.5 & 31.3
& over_3 & 14.8 & 21.3 & 18.1
