In [None]:
import json
from collections import defaultdict
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch
from nn_pruning.model_structure import BertStructure

from analysis.model_card_graphics import DensityBokehPlotter
from src.models.debiaser_pruned import DebiaserPruned

In [None]:
# Tag used in the output file names of this run.
block_name = 'b32'

# Label -> checkpoint file of each pruned model to analyse.
_CHECKPOINTS = {
    'last-sentence': '/remote/csifs1/disk1/przm/models/pruned/B32/bert-base-uncased/last-layer/sentence-debias/multiruns/2021-11-16_14-08-39/0/checkpoints/epoch=99-step=24599.ckpt',
    'last-token': '/remote/csifs1/disk1/przm/models/pruned/B32/bert-base-uncased/last-layer/token-debias/multiruns/2021-11-16_14-08-39/1/checkpoints/epoch=99-step=24599.ckpt',
    'all-sentence': '/remote/csifs1/disk1/przm/models/pruned/B32/bert-base-uncased/all-layer/sentence-debias/multiruns/2021-11-16_14-10-53/0/checkpoints/epoch=99-step=24599.ckpt',
    'all-token': '/remote/csifs1/disk1/przm/models/pruned/B32/bert-base-uncased/all-layer/token-debias/multiruns/2021-11-16_14-10-53/1/checkpoints/epoch=99-step=24599.ckpt',
}

# Ordered list of (label, checkpoint path) pairs consumed by the cells below.
models = list(_CHECKPOINTS.items())

In [None]:
# Disabled alternative configuration: checkpoints stored under `pruned/B64`.
# Uncommenting this cell rebinds `block_name` and `models`, replacing the values
# set by the active configuration cell.
# block_name = 'b64'
# models = [
#     ('last-sentence', '/remote/csifs1/disk1/przm/models/pruned/B64/bert-base-uncased/last-layer/sentence-debias/multiruns/2021-11-17_21-43-46/0/checkpoints/epoch=99-step=24599.ckpt'),
#     ('last-token', '/remote/csifs1/disk1/przm/models/pruned/B64/bert-base-uncased/last-layer/token-debias/multiruns/2021-11-17_21-43-46/1/checkpoints/epoch=99-step=24599.ckpt'),
#     ('all-sentence', '/remote/csifs1/disk1/przm/models/pruned/B64/bert-base-uncased/all-layer/sentence-debias/multiruns/2021-11-17_21-43-46/2/checkpoints/epoch=99-step=24599.ckpt'),
#     ('all-token', '/remote/csifs1/disk1/przm/models/pruned/B64/bert-base-uncased/all-layer/token-debias/multiruns/2021-11-17_21-43-46/3/checkpoints/epoch=99-step=24599.ckpt'),
# ]

In [None]:
# Disabled alternative configuration: checkpoints stored under `pruned/B128`.
# Uncommenting this cell rebinds `block_name` and `models`, replacing the values
# set by the active configuration cell.
# block_name = 'b128'
# models = [
#     ('last-sentence', '/remote/csifs1/disk1/przm/models/pruned/B128/bert-base-uncased/last-layer/sentence-debias/multiruns/2021-11-24_22-02-04/0/checkpoints/epoch=99-step=24599.ckpt'),
#     ('last-token', '/remote/csifs1/disk1/przm/models/pruned/B128/bert-base-uncased/last-layer/token-debias/multiruns/2021-11-24_22-02-04/1/checkpoints/epoch=99-step=24599.ckpt'),
#     ('all-sentence', '/remote/csifs1/disk1/przm/models/pruned/B128/bert-base-uncased/all-layer/sentence-debias/multiruns/2021-11-24_22-02-04/2/checkpoints/epoch=99-step=24599.ckpt'),
#     ('all-token', '/remote/csifs1/disk1/przm/models/pruned/B128/bert-base-uncased/all-layer/token-debias/multiruns/2021-11-24_22-02-04/3/checkpoints/epoch=99-step=24599.ckpt'),
# ]

In [None]:
# Disabled alternative configuration: checkpoints stored under `pruned/B64x768/V-only_no-fc`.
# Uncommenting this cell rebinds `block_name` and `models`, replacing the values
# set by the active configuration cell.
# block_name = 'b64x768-Vonly'
# models = [
#     ('last-sentence', '/remote/csifs1/disk1/przm/models/pruned/B64x768/V-only_no-fc/bert-base-uncased/last-layer/sentence-debias/multiruns/2021-12-14_22-52-38/0/checkpoints/epoch=99-step=24599.ckpt'),
#     ('last-token', '/remote/csifs1/disk1/przm/models/pruned/B64x768/V-only_no-fc/bert-base-uncased/last-layer/token-debias/multiruns/2021-12-14_22-52-38/1/checkpoints/epoch=99-step=24599.ckpt'),
#     ('all-sentence', '/remote/csifs1/disk1/przm/models/pruned/B64x768/V-only_no-fc/bert-base-uncased/all-layer/sentence-debias/multiruns/2021-12-14_22-52-38/2/checkpoints/epoch=99-step=24599.ckpt'),
#     ('all-token', '/remote/csifs1/disk1/przm/models/pruned/B64x768/V-only_no-fc/bert-base-uncased/all-layer/token-debias/multiruns/2021-12-14_22-52-38/3/checkpoints/epoch=99-step=24599.ckpt'),
# ]

In [None]:
# Disabled alternative configuration: checkpoints stored under `pruned/B64x768/att-only-shared`.
# Uncommenting this cell rebinds `block_name` and `models`, replacing the values
# set by the active configuration cell.
# block_name = 'b64x768-att-shared'
# models = [
#     ('last-sentence', '/remote/csifs1/disk1/przm/models/pruned/B64x768/att-only-shared/bert-base-uncased/last-layer/sentence-debias/multiruns/2021-12-14_23-57-20/0/checkpoints/epoch=99-step=24599.ckpt'),
#     ('last-token', '/remote/csifs1/disk1/przm/models/pruned/B64x768/att-only-shared/bert-base-uncased/last-layer/token-debias/multiruns/2021-12-14_23-57-20/1/checkpoints/epoch=99-step=24599.ckpt'),
#     ('all-sentence', '/remote/csifs1/disk1/przm/models/pruned/B64x768/att-only-shared/bert-base-uncased/all-layer/sentence-debias/multiruns/2021-12-14_23-57-20/2/checkpoints/epoch=99-step=24599.ckpt'),
#     ('all-token', '/remote/csifs1/disk1/przm/models/pruned/B64x768/att-only-shared/bert-base-uncased/all-layer/token-debias/multiruns/2021-12-14_23-57-20/3/checkpoints/epoch=99-step=24599.ckpt'),
# ]

In [None]:
def get_avg_density_per_layer(layer_info):
    """Compute the size-weighted average density of every layer.

    Each entry contributes ``density * numel`` to its layer, where ``numel`` is
    the product of all dimensions in the entry's ``size`` and the layer is
    ``BertStructure.layer_index(name)``. The per-layer sum is then divided by
    the total number of elements collected for that layer.

    Args:
        layer_info: iterable of mappings providing the keys ``'name'``,
            ``'density'`` and ``'size'`` (a sequence of dimension lengths of
            any rank).

    Returns:
        A ``defaultdict`` mapping layer index to the average density, expressed
        in the same unit as the input ``density`` values. Keys appear in the
        order in which layers are first encountered. A layer whose entries
        contain no elements at all maps to ``0.0``.
    """
    layer_nnz = defaultdict(int)
    layer_numel = defaultdict(int)

    for info in layer_info:
        # Multiply every dimension so 1-D and >2-D sizes are counted correctly,
        # not only 2-D matrices.
        numel = 1
        for dim in info['size']:
            numel *= dim

        layer_idx = BertStructure.layer_index(info['name'])
        layer_nnz[layer_idx] += info['density'] * numel
        layer_numel[layer_idx] += numel

    layer_avg_density = defaultdict(int)
    for layer_idx, nnz in layer_nnz.items():
        total = layer_numel[layer_idx]
        # Guard against a layer made only of empty tensors (avoids ZeroDivisionError).
        layer_avg_density[layer_idx] = nnz / total if total else 0.0

    return layer_avg_density

In [None]:
# Maps model label -> per-layer average density (filled by the loop below).
per_model = {}

# NOTE(review): one plotter instance is shared by all checkpoints; this assumes
# `dbp.run` rebuilds `dbp.layers` on every call instead of appending — TODO confirm.
dbp = DensityBokehPlotter("density", "$$JS_PATH$$")

for model_name, ckpt_path in models:

    print(f'Processing {model_name}')
    model = DebiaserPruned.load_from_checkpoint(ckpt_path)
    removed_heads, total_heads = model.model_patcher.compile_model(model.model_debias.model)

    print(f'{model_name}: removed {removed_heads}/{total_heads}')

    fig, _, _ = dbp.run(model.model_debias.model, 'tmp/imgs', 'tmp/imgs')

    per_model[model_name] = get_avg_density_per_layer(dbp.layers)

    # Everything needed to reproduce the density statistics without reloading the checkpoint.
    payload = {
        'model_name': model_name,
        'checkpoint': ckpt_path,
        'heads_pruned': removed_heads,
        'total_heads': total_heads,
        'layers': dbp.layers,
    }

    fname = f"data/densities/{block_name}/{model_name}.json"
    # Create the per-block output directory on first use so the dump cannot
    # fail with FileNotFoundError on a fresh checkout.
    Path(fname).parent.mkdir(parents=True, exist_ok=True)

    with open(fname, 'w') as f:
        json.dump(payload, f, indent=4)

In [None]:
# Outer keys of `per_model` (model labels) become the columns; the inner keys
# (layer indices) become the row index.
df = pd.DataFrame(data=per_model)
df.head()

In [None]:
# Configure the theme once; `sns.set` is an alias of `sns.set_theme`, so passing
# the figure size here replaces the former pair of calls.
sns.set_theme(rc={'figure.figsize': (12, 4)})

fig, ax = plt.subplots()

# One line per model (column of `df`), all drawn solid with the same marker.
ax = sns.lineplot(
    data=df,
    dashes=False,
    linestyle='-',
    markers=['o'] * len(df.columns),
    ax=ax,
)

# despine must be applied to the axes that were actually drawn; calling it before
# any figure exists only touches an implicit empty figure.
sns.despine(ax=ax, left=True, bottom=True)

# NOTE(review): the label says "%"; the plotted values are in whatever unit the
# per-layer `density` entries use — confirm they are percentages, not fractions.
ax.set(xlabel='Layer', ylabel='Average Density (%)')

fig_path = Path(f"data/fig/avg-sparsity-{block_name}.pdf")
# Make sure the output directory exists before saving.
fig_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(fig_path, bbox_inches='tight')
# plt.savefig("data/fig/avg-sparsity-b32.svg", bbox_inches='tight')
# plt.savefig("data/fig/avg-sparsity-b32.png", bbox_inches='tight')

plt.show()

# TODO
 - [ ] Choose a consistent colour palette for the density plot with `seaborn.set_palette`: https://seaborn.pydata.org/generated/seaborn.set_palette.html#seaborn.set_palette

### Visualize (HF code)

In [None]:
# Disabled: inline display of a figure through Bokeh's `show`.
# NOTE(review): `fig` is rebound by `plt.subplots()` in the matplotlib plotting cell,
# so by the time this cell runs it no longer refers to the object returned by
# `dbp.run` — keep a separate reference before enabling this. The bare `except:`
# below would also hide the reason for any failure.
# from bokeh.plotting import show, figure
# from bokeh.io import output_notebook, reset_output

# try:
#     reset_output()
#     output_notebook()
#     show(fig)
# except:
#     output_notebook()
#     show(fig)

# print('blue is preserved, pink is pruned')