### This file loads the proxy model trained with IIT, and evaluate it for concept-based explanations with the CEBaB dataset.
This file tends to follow the one in our OpenTable repo.

In [10]:
import json
import os 
import pandas as pd
import datasets
from collections import defaultdict
import numpy as np
import random
from tqdm import tqdm

import torch
from transformers import (
    AutoModelForSequenceClassification, 
    AutoTokenizer, 
)
from math import ceil

import sklearn
from sklearn.pipeline import Pipeline
from sklearn.metrics import classification_report

from models.modelings_roberta import *
from models.modelings_bert import *

from eval_pipeline.models.abstract_model import Model 
from eval_pipeline.explainers.abstract_explainer import Explainer
from eval_pipeline.utils.data_utils import preprocess_hf_dataset
from eval_pipeline.customized_models.bert import BertForNonlinearSequenceClassification

In [11]:
# TODO: set random seeds.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

<torch._C.Generator at 0x7f2958709e10>

In [32]:
class BERTForCEBaB(Model):
    def __init__(self, model_path, device='cpu', batch_size=64):
        self.device = device
        self.model_path = model_path
        self.tokenizer_path = model_path
        self.batch_size = batch_size

        self.model = BertForNonlinearSequenceClassification.from_pretrained(
            self.model_path,
            cache_dir="../huggingface_cache"
        )
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

        self.model.to(device)

    def __str__(self):
        return self.model_path.split('/')[-1]

    def preprocess(self, df):
        x = self.tokenizer(df['description'].to_list(), padding=True, truncation=True, return_tensors='pt')
        y = df['review_majority'].astype(int)

        return x, y

    def fit(self, dataset):
        # assume model was already trained
        pass

    def predict_proba(self, dataset):
        self.model.eval()

        x, y = self.preprocess(dataset)

        # get the predictions batch per batch
        probas = []
        for i in range(ceil(len(dataset) / self.batch_size)):
            x_batch = {k: v[i * self.batch_size:(i + 1) * self.batch_size].to(self.device) for k, v in x.items()}
            probas.append(torch.nn.functional.softmax(self.model(**x_batch).logits.cpu(), dim=-1).detach())

        probas = torch.concat(probas)
        probas = np.round(probas.numpy(), decimals=4)

        predictions = np.argmax(probas, axis=1)
        clf_report = classification_report(y.to_numpy(), predictions, output_dict=True)

        return probas, clf_report

    def get_embeddings(self, sentences_list):
        x = self.tokenizer(sentences_list, padding=True, truncation=True, return_tensors='pt')
        embeddings = []
        for i in range(ceil(len(x['input_ids']) / self.batch_size)):
            x_batch = {k: v[i * self.batch_size:(i + 1) * self.batch_size].to(self.device) for k, v in x.items()}
            embeddings.append(self.model.base_model(**x_batch).pooler_output.detach().cpu().tolist())

        return embeddings

def get_iit_examples(df):
    """
    Given a dataframe in the new data scheme, return all intervention pairs.
    """
    # Drop label distribution and worker information.
    columns_to_keep = ['id', 'original_id', 'edit_id', 'is_original', 'edit_goal', 'edit_type', 'description', 'review_majority','food_aspect_majority', 'ambiance_aspect_majority', 'service_aspect_majority', 'noise_aspect_majority']
    columns_to_keep += [col for col in df.columns if 'prediction' in col]
    df = df[columns_to_keep]
    return df

class ProxyIIT():
    def __init__(self, model_path, device, batch_size):
        self.batch_size = batch_size
        self.device = device
        model = IITBERTForSequenceClassification.from_pretrained(
            model_path,
            cache_dir="../huggingface_cache/"
        )
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        model.to(device)
        self.model = InterventionableIITTransformerForSequenceClassification(
            model=model
        )
    def fit(self, dataset, classifier_predictions, classifier, dev_dataset=None):
        # we don't need to train IIT here.
        pass
    
    def preprocess(self, pairs_dataset, dev_dataset):

        min_iit_pair_examples = 10
        query_dataset = get_iit_examples(dev_dataset)
        iit_pairs_dataset = []
        iit_id = 0
        for index, row in pairs_dataset.iterrows():
            query_description_base = row['description_base']
            query_int_type = row['intervention_type']
            query_int_aspect_base = row["intervention_aspect_base"]
            query_int_aspect_assignment = row['intervention_aspect_counterfactual']
            query_original_id = row["original_id_base"]
            matched_iit_examples = query_dataset[
                (query_dataset[f"{query_int_type}_aspect_majority"]==query_int_aspect_assignment)&
                (query_dataset["original_id"]!=query_original_id)
            ]
            if len(set(matched_iit_examples["id"])) < min_iit_pair_examples:
                assert False # we need to check the number!
            sampled_iit_example_ids = random.sample(
                set(matched_iit_examples["id"]), min_iit_pair_examples
            )
            for _id in sampled_iit_example_ids:
                description_iit = query_dataset[query_dataset["id"]==_id]["description"].iloc[0]
                iit_pairs_dataset += [[
                    iit_id,
                    query_int_type, query_int_aspect_base, 
                    query_int_aspect_assignment, 
                    query_description_base, description_iit
                ]]
            iit_id += 1
        iit_pairs_dataset = pd.DataFrame(
            columns=[
                'iit_id',
                'intervention_type', 
                'intervention_aspect_base', 
                'intervention_aspect_counterfactual',
                'description_base', 
                'description_iit'], 
            data=iit_pairs_dataset
        )
        
        base_x = self.tokenizer(
            iit_pairs_dataset['description_base'].to_list(), 
            padding=True, truncation=True, return_tensors='pt'
        )
        source_x = self.tokenizer(
            iit_pairs_dataset['description_iit'].to_list(), 
            padding=True, truncation=True, return_tensors='pt'
        )
        intervention_corr = []
        for _type in iit_pairs_dataset["intervention_type"].tolist():
            if _type == "ambiance":
                intervention_corr += [0]
            if _type == "food":
                intervention_corr += [1]
            if _type == "noise":
                intervention_corr += [2]
            if _type == "service":
                intervention_corr += [3]
        intervention_corr = torch.tensor(intervention_corr).long()
        return base_x, source_x, intervention_corr, iit_pairs_dataset
    
    def predict_proba(self, pairs, df):
        ProxyIIT_iTEs = []
        self.model.model.eval()
        base_x, source_x, intervention_corr, iit_pairs_dataset = self.preprocess(
            pairs, df
        )
        with torch.no_grad():
            for i in tqdm(range(ceil(len(iit_pairs_dataset)/self.batch_size))):
                base_x_batch = {k:v[i*self.batch_size:(i+1)*self.batch_size].to(self.device) for k,v in base_x.items()} 
                source_x_batch = {k:v[i*self.batch_size:(i+1)*self.batch_size].to(self.device) for k,v in source_x.items()} 
                intervention_corr_batch = intervention_corr[i*self.batch_size:(i+1)*self.batch_size].to(self.device)
                
                base_outputs, _, counterfactual_outputs = self.model.forward(
                    base=(base_x_batch['input_ids'], base_x_batch['attention_mask']),
                    source=(source_x_batch['input_ids'], source_x_batch['attention_mask']),
                    base_intervention_corr=intervention_corr_batch,
                    source_intervention_corr=intervention_corr_batch,
                )
                base_outputs = torch.nn.functional.softmax(base_outputs["logits"][0].cpu(), dim=-1).detach()
                counterfactual_outputs = torch.nn.functional.softmax(counterfactual_outputs["logits"][0].cpu(), dim=-1).detach()
                ProxyIIT_iTE = counterfactual_outputs-base_outputs
                ProxyIIT_iTEs.append(ProxyIIT_iTE)
        ProxyIIT_iTEs = torch.concat(ProxyIIT_iTEs)
        ProxyIIT_iTEs = np.round(ProxyIIT_iTEs.numpy(), decimals=4)

        # only for iit explainer!
        iit_pairs_dataset["EiCaCE"] = list(ProxyIIT_iTEs)
        ProxyIIT_iTEs = list(iit_pairs_dataset.groupby(["iit_id"])["EiCaCE"].mean())
        
        return ProxyIIT_iTEs
    

In [51]:
import pandas as pd

from eval_pipeline.utils import metric_utils, get_intervention_pairs

def cebab_pipeline(
    model, explainer, 
    train_dataset, dev_dataset, 
    dataset_type='5-way', 
    shorten_model_name=False
):
    # get predictions on train and dev
    train_predictions, _ = model.predict_proba(
        train_dataset
    )
    dev_predictions, dev_report = model.predict_proba(
        dev_dataset
    )

    # append predictions to datasets
    train_dataset['prediction'] = list(train_predictions)
    dev_dataset['prediction'] = list(dev_predictions)

    # fit explainer
    explainer.fit(
        train_dataset, train_predictions, 
        model, dev_dataset
    )

    # get intervention pairs
    pairs_dataset = get_intervention_pairs(
        dev_dataset, dataset_type=dataset_type
    )  # TODO why is the index not unique here?
     
    # get explanations
    explanations = explainer.predict_proba(
        pairs_dataset,
        dev_dataset
    )
    
    # append explanations to the pairs
    pairs_dataset['EICaCE'] = explanations
    pairs_dataset = metric_utils._calculate_ite(pairs_dataset)  # effect of crowd-workers on other crowd-workers (no model, no explainer)
    pairs_dataset = metric_utils._calculate_icace(pairs_dataset)  # effect of concept on the model (with model, no explainer)
    pairs_dataset = metric_utils._calculate_estimate_loss(pairs_dataset)  # l2 CEBaB Score (model and explainer)

    # only keep columns relevant for metrics
    CEBaB_metrics_per_pair = pairs_dataset[[
        'intervention_type', 'intervention_aspect_base', 'intervention_aspect_counterfactual', 'ITE', 'ICaCE', 'EICaCE', 'ICaCE-error']].copy()
    CEBaB_metrics_per_pair['count'] = 1

    # get CEBaB tables
    metrics = ['count', 'ICaCE', 'EICaCE']

    groupby_aspect_direction = ['intervention_type', 'intervention_aspect_base', 'intervention_aspect_counterfactual']

    CaCE_per_aspect_direction = metric_utils._aggregate_metrics(CEBaB_metrics_per_pair, groupby_aspect_direction, metrics)
    CaCE_per_aspect_direction.columns = ['count', 'CaCE', 'ECaCE']
    CaCE_per_aspect_direction = CaCE_per_aspect_direction.set_index(['count'], append=True)
    
    ACaCE_per_aspect = metric_utils._aggregate_metrics(CaCE_per_aspect_direction.abs(), ['intervention_type'], ['CaCE', 'ECaCE'])
    ACaCE_per_aspect.columns = ['ACaCE', 'EACaCE']

    CEBaB_metrics_per_aspect_direction = metric_utils._aggregate_metrics(CEBaB_metrics_per_pair, groupby_aspect_direction, ['count', 'ICaCE-error'])
    CEBaB_metrics_per_aspect_direction.columns = ['count', 'ICaCE-error']
    CEBaB_metrics_per_aspect_direction = CEBaB_metrics_per_aspect_direction.set_index(['count'], append=True)

    CEBaB_metrics = metric_utils._aggregate_metrics(CEBaB_metrics_per_pair, [], ['ICaCE-error'])

    # get ATE table
    ATE = metric_utils._aggregate_metrics(CEBaB_metrics_per_pair, groupby_aspect_direction, ['count', 'ITE'])
    ATE.columns = ['count', 'ATE']
    # ATE = ATE.set_index(['count'], append=True)  # TODO why is the count a part of the index?

    # add model and explainer information
    if shorten_model_name:
        model_name = str(model).split('.')[0]
    else:
        model_name = str(model)

    CaCE_per_aspect_direction.columns = pd.MultiIndex.from_tuples(
        [(model_name, str(explainer), col) if col != 'CaCE' else (model_name, '', col) for col in CaCE_per_aspect_direction.columns])
    ACaCE_per_aspect.columns = pd.MultiIndex.from_tuples(
        [(model_name, str(explainer), col) if col != 'ACaCE' else (model_name, '', col) for col in ACaCE_per_aspect.columns])
    CEBaB_metrics_per_aspect_direction.columns = pd.MultiIndex.from_tuples(
        [(model_name, str(explainer), col) for col in CEBaB_metrics_per_aspect_direction.columns])
    CEBaB_metrics.index = pd.MultiIndex.from_product([[model_name], [str(explainer)], CEBaB_metrics.index])

    return ATE, CEBaB_metrics, CEBaB_metrics_per_aspect_direction, CaCE_per_aspect_direction, ACaCE_per_aspect, dev_report

In [52]:
model_path = 'CEBaB/bert-base-uncased.CEBaB.sa.2-class.exclusive.seed_42'
proxy_model_path = './proxy_training_results/cebab.train.train.alpha.1.0.beta.1.0.dim.192.hightype.bert-base-uncased.Proxy.CEBaB.sa.2-class.exclusive.mode.align.seed_42'

dataset_type = '2-way'

device = 'cuda:0'
batch_size = 32

# load data from HF
cebab = datasets.load_dataset(
    'CEBaB/CEBaB', use_auth_token=True,
    cache_dir="../huggingface_cache/"
)
cebab['train'] = cebab['train_exclusive']
train, dev, test = preprocess_hf_dataset(
    cebab, one_example_per_world=False, 
    verbose=1, dataset_type=dataset_type
)

tf_model = BERTForCEBaB(
    model_path, 
    device=device, 
    batch_size=batch_size
)
explanator = ProxyIIT(
    proxy_model_path, 
    device=device, 
    batch_size=batch_size
)

train_dataset = train.copy()
dev_dataset = dev.copy()

Using custom data configuration CEBaB--CEBaB-0e2f7ed67c9d7e55
Reusing dataset parquet (../huggingface_cache/parquet/CEBaB--CEBaB-0e2f7ed67c9d7e55/0.0.0/0b6d5799bb726b24ad7fc7be720c170d8e497f575d02d47537de9a5bac074901)


  0%|          | 0/4 [00:00<?, ?it/s]

Dropping no majority reviews: 16.03% of train dataset.
Dropped 2604 examples with a neutral label.
Dropped 452 examples with a neutral label.
Dropped 461 examples with a neutral label.


In [53]:
ATE, CEBaB_metrics, \
CEBaB_metrics_per_aspect_direction, \
CaCE_per_aspect_direction, \
ACaCE_per_aspect, dev_report = cebab_pipeline(
    tf_model, explanator, 
    train_dataset, dev_dataset, 
    dataset_type='2-way'
)

100%|██████████| 711/711 [01:47<00:00,  6.58it/s]


In [54]:
CEBaB_metrics

Unnamed: 0,Unnamed: 1,Unnamed: 2,ICaCE-error
bert-base-uncased.CEBaB.sa.2-class.exclusive.seed_42,<__main__.ProxyIIT object at 0x7f281604ffd0>,mean,0.2141
