In [None]:
import csv
import itertools
import json
import logging
import math
import os
import sys
from dataclasses import dataclass, field
from typing import Optional

import datasets
import pandas as pd
import transformers
from datasets import load_dataset
from datasets import Dataset, DatasetDict, Features, Value
from huggingface_hub import notebook_login
from transformers import (
    CONFIG_MAPPING,
    MODEL_FOR_MASKED_LM_MAPPING,
    AutoConfig,
    AutoModelForMaskedLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version
from transformers.utils.versions import require_version

# Project-local star imports stay last and in their original order so that any
# name shadowing between them is unchanged.
from libs import *
from modelings.modelings_bert import *
from modelings.modelings_roberta import *
from modelings.modelings_gpt2 import *
from modelings.modelings_lstm import *

In [None]:
# Turn off progress bars emitted by the `datasets` library.
datasets.logging.disable_progress_bar()

In [None]:
# Log in to the Hugging Face Hub (helper imported from `huggingface_hub`).
# The last cell calls `push_to_hub`, which presumably needs these credentials.
notebook_login()

## Push the trained causal proxy models (CPMs) to the Hugging Face Hub

In [None]:
# Hyper-parameter grid. Every combination of the values below becomes one
# settings dict in `permutations_dicts`. The comment under each key lists the
# other values that have been used for that key.
grid = {
    "eval_split": ["test"],
    # dev,test
    "control": ["ks"],
    # baseline-random,baseline-blackbox,hdims,layers,ks,approximate,ablation
    "seed": [42, 66, 77],
    # 42, 66, 77
    "h_dim": [192],
    # 1,16,64,128,192
    # 1,16,64,75
    "interchange_layer": [10],
    # 0,1; 2,4,6,8,10,12
    "class_num": [5],
    "k": [19684],
    # 0;10,100,500,1000,3000,6000,9848,19684
    "alpha": [1.0],
    # 0.0,1.0
    "beta": [1.0],
    # 0.0,1.0
    "gemma": [3.0],
    # 0.0,3.0
    "model_arch": ["bert-base-uncased"],
    # lstm, bert-base-uncased, roberta-base, gpt2
    "lr": ["8e-05"],
    # 8e-05; 0.001
    "counterfactual_type": ["true"],
    # approximate,true
    "random_source": [False],
}

# Expand the grid into its cartesian product: one dict per combination, with
# the same keys (in the same order) as `grid`.
keys, values = zip(*grid.items())
permutations_dicts = [dict(zip(keys, v)) for v in itertools.product(*values)]

# Runtime settings shared by every combination.
device = 'cuda:1'
batch_size = 32

# Sanity check on the split: the "hdims" and "layers" controls must be paired
# with the dev split, every other control with the test split.
# NOTE(review): only the FIRST entry of grid["control"] / grid["eval_split"] is
# inspected, so additional entries in either list are not validated.
if grid["control"][0] in ("hdims", "layers"):
    assert grid["eval_split"][0] == "dev", (
        f'control={grid["control"][0]!r} expects eval_split "dev", '
        f'got {grid["eval_split"][0]!r}'
    )
else:
    assert grid["eval_split"][0] == "test", (
        f'control={grid["control"][0]!r} expects eval_split "test", '
        f'got {grid["eval_split"][0]!r}'
    )

In [None]:
# For every grid combination: build the black-box and CPM checkpoint paths,
# load them through the architecture-specific explainer class, and push the
# CPM's underlying model to the Hub under the "CPMs" organization.
# Relies on `permutations_dicts`, `device` and `batch_size` being defined
# earlier in the notebook.
for i, setting in enumerate(permutations_dicts):

    # Unpack the current combination. Not every value is used below; all names
    # are still assigned so the notebook namespace stays the same as before.
    eval_split = setting["eval_split"]
    seed = setting["seed"]
    class_num = setting["class_num"]
    alpha = setting["alpha"]
    beta = setting["beta"]
    gemma = setting["gemma"]
    h_dim = setting["h_dim"]
    dataset_type = f'{class_num}-way'
    control = setting["control"]
    model_arch = setting["model_arch"]
    k = setting["k"]
    interchange_layer = setting["interchange_layer"]
    lr = setting["lr"]
    counterfactual_type = setting["counterfactual_type"]
    random_source = setting["random_source"]

    # Pick the results-folder prefix and the model / explainer classes for the
    # architecture. The classes come from the project's star imports.
    if model_arch == "bert-base-uncased":
        model_path = "BERT"
        model_module = BERTForCEBaB
        explainer_module = CausalProxyModelForBERT
    elif model_arch == "roberta-base":
        model_path = "RoBERTa"
        model_module = RoBERTaForCEBaB
        explainer_module = CausalProxyModelForRoBERTa
    elif model_arch == "gpt2":
        model_path = "gpt2"
        model_module = GPT2ForCEBaB
        explainer_module = CausalProxyModelForGPT2
    elif model_arch == "lstm":
        model_path = "lstm"
        model_module = LSTMForCEBaB
        explainer_module = CausalProxyModelForLSTM
    else:
        # Without this branch an unknown architecture would either raise a
        # confusing NameError (first iteration) or silently reuse the previous
        # iteration's `model_path` / `explainer_module`.
        raise ValueError(
            f"Unsupported model_arch {model_arch!r}; expected one of "
            f"'bert-base-uncased', 'roberta-base', 'gpt2', 'lstm'."
        )
    model_path += f"-{control}"

    # Summary of the setting, used only for the progress message below.
    grid_conditions = (
        ("eval_split", eval_split),
        ("control", control),
        ("seed", seed),
        ("h_dim", h_dim),
        ("interchange_layer", interchange_layer),
        ("class_num", class_num),
        ("k", k),
        ("alpha", alpha),
        ("beta", beta),
        ("gemma", gemma),
        ("model_arch", model_arch),
        ("lr", lr),
        ("counterfactual_type", counterfactual_type)
    )
    print("Running for this setting: ", grid_conditions)

    # Identifier of the black-box model and local folder of the trained CPM.
    blackbox_model_path = f'CEBaB/{model_arch}.CEBaB.sa.'\
                          f'{class_num}-class.exclusive.seed_{seed}'
    cpm_model_path = f'../proxy_training_results/{model_path}/'\
                     f'cebab.alpha.{alpha}.beta.{beta}.gemma.{gemma}.'\
                     f'lr.{lr}.dim.{h_dim}.hightype.{model_arch}.'\
                     f'CEBaB.cls.dropout.0.1.enc.dropout.0.1.counter.type.'\
                     f'{counterfactual_type}.k.{k}.int.layer.{interchange_layer}.'\
                     f'seed_{seed}/'

    # NOTE(review): the repo name encodes only the architecture and the seed.
    # Grid points that differ in any other setting (h_dim, k, layer, ...) would
    # target the same repo - confirm before widening the grid.
    hub_name = f'cpm.{model_arch}.best.seed.{seed}'

    explainer = explainer_module(
        blackbox_model_path,
        cpm_model_path,
        device=device,
        batch_size=batch_size,
        intervention_h_dim=h_dim,
        random_source=random_source
    )

    # NOTE(review): recent `transformers` releases deprecate `organization=` in
    # favour of a namespaced repo id such as "CPMs/<name>". Left unchanged here
    # because the right form depends on the installed version - TODO confirm.
    explainer.cpm_model.model.push_to_hub(hub_name, organization="CPMs")